diff --git a/contributing/samples/telemetry/main.py b/contributing/samples/telemetry/main.py index 3998c2a7..e580060d 100755 --- a/contributing/samples/telemetry/main.py +++ b/contributing/samples/telemetry/main.py @@ -46,13 +46,19 @@ async def main(): role='user', parts=[types.Part.from_text(text=new_message)] ) print('** User says:', content.model_dump(exclude_none=True)) - async for event in runner.run_async( + # TODO - migrate try...finally to contextlib.aclosing after Python 3.9 is + # no longer supported. + agen = runner.run_async( user_id=user_id_1, session_id=session.id, new_message=content, - ): - if event.content.parts and event.content.parts[0].text: - print(f'** {event.author}: {event.content.parts[0].text}') + ) + try: + async for event in agen: + if event.content.parts and event.content.parts[0].text: + print(f'** {event.author}: {event.content.parts[0].text}') + finally: + await agen.aclose() async def run_prompt_bytes(session: Session, new_message: str): content = types.Content( @@ -64,14 +70,20 @@ async def main(): ], ) print('** User says:', content.model_dump(exclude_none=True)) - async for event in runner.run_async( + # TODO - migrate try...finally to contextlib.aclosing after Python 3.9 is + # no longer supported. + agen = runner.run_async( user_id=user_id_1, session_id=session.id, new_message=content, run_config=RunConfig(save_input_blobs_as_artifacts=True), - ): - if event.content.parts and event.content.parts[0].text: - print(f'** {event.author}: {event.content.parts[0].text}') + ) + try: + async for event in agen: + if event.content.parts and event.content.parts[0].text: + print(f'** {event.author}: {event.content.parts[0].text}') + finally: + await agen.aclose() start_time = time.time() print('Start time:', start_time) diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py index 6cd34a98..29b681a8 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor.py @@ -24,6 +24,8 @@ from typing import Callable from typing import Optional import uuid +from ...utils.context_utils import Aclosing + try: from a2a.server.agent_execution import AgentExecutor from a2a.server.agent_execution.context import RequestContext @@ -212,12 +214,13 @@ class A2aAgentExecutor(AgentExecutor): ) task_result_aggregator = TaskResultAggregator() - async for adk_event in runner.run_async(**run_args): - for a2a_event in convert_event_to_a2a_events( - adk_event, invocation_context, context.task_id, context.context_id - ): - task_result_aggregator.process_event(a2a_event) - await event_queue.enqueue_event(a2a_event) + async with Aclosing(runner.run_async(**run_args)) as agen: + async for adk_event in agen: + for a2a_event in convert_event_to_a2a_events( + adk_event, invocation_context, context.task_id, context.context_id + ): + task_result_aggregator.process_event(a2a_event) + await event_queue.enqueue_event(a2a_event) # publish the task result event - this is final if ( diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index f26ee42d..45f72760 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -39,6 +39,7 @@ from typing_extensions import override from typing_extensions import TypeAlias from ..events.event import Event +from ..utils.context_utils import Aclosing from ..utils.feature_decorator import experimental from .base_agent_config import BaseAgentConfig from .callback_context import CallbackContext 
@@ -212,21 +213,27 @@ class BaseAgent(BaseModel): Event: the events generated by the agent. """ - with tracer.start_as_current_span(f'agent_run [{self.name}]'): - ctx = self._create_invocation_context(parent_context) + async def _run_with_trace() -> AsyncGenerator[Event, None]: + with tracer.start_as_current_span(f'agent_run [{self.name}]'): + ctx = self._create_invocation_context(parent_context) - if event := await self.__handle_before_agent_callback(ctx): - yield event - if ctx.end_invocation: - return + if event := await self.__handle_before_agent_callback(ctx): + yield event + if ctx.end_invocation: + return - async for event in self._run_async_impl(ctx): - yield event + async with Aclosing(self._run_async_impl(ctx)) as agen: + async for event in agen: + yield event - if ctx.end_invocation: - return + if ctx.end_invocation: + return - if event := await self.__handle_after_agent_callback(ctx): + if event := await self.__handle_after_agent_callback(ctx): + yield event + + async with Aclosing(_run_with_trace()) as agen: + async for event in agen: yield event @final @@ -243,18 +250,25 @@ class BaseAgent(BaseModel): Yields: Event: the events generated by the agent. """ - with tracer.start_as_current_span(f'agent_run [{self.name}]'): - ctx = self._create_invocation_context(parent_context) - if event := await self.__handle_before_agent_callback(ctx): - yield event - if ctx.end_invocation: - return + async def _run_with_trace() -> AsyncGenerator[Event, None]: + with tracer.start_as_current_span(f'agent_run [{self.name}]'): + ctx = self._create_invocation_context(parent_context) - async for event in self._run_live_impl(ctx): - yield event + if event := await self.__handle_before_agent_callback(ctx): + yield event + if ctx.end_invocation: + return - if event := await self.__handle_after_agent_callback(ctx): + async with Aclosing(self._run_live_impl(ctx)) as agen: + async for event in agen: + yield event + + if event := await self.__handle_after_agent_callback(ctx): + yield event + + async with Aclosing(_run_with_trace()) as agen: + async for event in agen: yield event async def _run_async_impl( diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index ae55dd1e..99302e2f 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -51,6 +51,7 @@ from ..tools.base_toolset import BaseToolset from ..tools.function_tool import FunctionTool from ..tools.tool_configs import ToolConfig from ..tools.tool_context import ToolContext +from ..utils.context_utils import Aclosing from ..utils.feature_decorator import experimental from .base_agent import BaseAgent from .base_agent_config import BaseAgentConfig @@ -283,19 +284,21 @@ class LlmAgent(BaseAgent): async def _run_async_impl( self, ctx: InvocationContext ) -> AsyncGenerator[Event, None]: - async for event in self._llm_flow.run_async(ctx): - self.__maybe_save_output_to_state(event) - yield event + async with Aclosing(self._llm_flow.run_async(ctx)) as agen: + async for event in agen: + self.__maybe_save_output_to_state(event) + yield event @override async def _run_live_impl( self, ctx: InvocationContext ) -> AsyncGenerator[Event, None]: - async for event in self._llm_flow.run_live(ctx): - self.__maybe_save_output_to_state(event) - yield event - if ctx.end_invocation: - return + async with Aclosing(self._llm_flow.run_live(ctx)) as agen: + async for event in agen: + self.__maybe_save_output_to_state(event) + yield event + if ctx.end_invocation: + return @property def canonical_model(self) -> 
BaseLlm: diff --git a/src/google/adk/agents/loop_agent.py b/src/google/adk/agents/loop_agent.py index da879f7d..1313d208 100644 --- a/src/google/adk/agents/loop_agent.py +++ b/src/google/adk/agents/loop_agent.py @@ -27,6 +27,7 @@ from typing_extensions import override from ..agents.invocation_context import InvocationContext from ..events.event import Event +from ..utils.context_utils import Aclosing from ..utils.feature_decorator import experimental from .base_agent import BaseAgent from .base_agent_config import BaseAgentConfig @@ -58,10 +59,11 @@ class LoopAgent(BaseAgent): while not self.max_iterations or times_looped < self.max_iterations: for sub_agent in self.sub_agents: should_exit = False - async for event in sub_agent.run_async(ctx): - yield event - if event.actions.escalate: - should_exit = True + async with Aclosing(sub_agent.run_async(ctx)) as agen: + async for event in agen: + yield event + if event.actions.escalate: + should_exit = True if should_exit: return diff --git a/src/google/adk/agents/parallel_agent.py b/src/google/adk/agents/parallel_agent.py index ffc29a09..96fea31c 100644 --- a/src/google/adk/agents/parallel_agent.py +++ b/src/google/adk/agents/parallel_agent.py @@ -26,6 +26,7 @@ from typing import Type from typing_extensions import override from ..events.event import Event +from ..utils.context_utils import Aclosing from .base_agent import BaseAgent from .base_agent_config import BaseAgentConfig from .invocation_context import InvocationContext @@ -111,8 +112,10 @@ class ParallelAgent(BaseAgent): ) for sub_agent in self.sub_agents ] - async for event in _merge_agent_run(agent_runs): - yield event + + async with Aclosing(_merge_agent_run(agent_runs)) as agen: + async for event in agen: + yield event @override async def _run_live_impl( diff --git a/src/google/adk/agents/sequential_agent.py b/src/google/adk/agents/sequential_agent.py index 6ed203ed..8ec1e43b 100644 --- a/src/google/adk/agents/sequential_agent.py +++ b/src/google/adk/agents/sequential_agent.py @@ -22,6 +22,7 @@ from typing import Type from typing_extensions import override from ..events.event import Event +from ..utils.context_utils import Aclosing from .base_agent import BaseAgent from .base_agent import BaseAgentConfig from .invocation_context import InvocationContext @@ -40,8 +41,9 @@ class SequentialAgent(BaseAgent): self, ctx: InvocationContext ) -> AsyncGenerator[Event, None]: for sub_agent in self.sub_agents: - async for event in sub_agent.run_async(ctx): - yield event + async with Aclosing(sub_agent.run_async(ctx)) as agen: + async for event in agen: + yield event @override async def _run_live_impl( @@ -78,5 +80,6 @@ class SequentialAgent(BaseAgent): do not generate any text other than the function call.""" for sub_agent in self.sub_agents: - async for event in sub_agent.run_live(ctx): - yield event + async with Aclosing(sub_agent.run_live(ctx)) as agen: + async for event in agen: + yield event diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index 44df7908..2d809262 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -73,6 +73,7 @@ from ..memory.base_memory_service import BaseMemoryService from ..runners import Runner from ..sessions.base_session_service import BaseSessionService from ..sessions.session import Session +from ..utils.context_utils import Aclosing from .cli_eval import EVAL_SESSION_ID_PREFIX from .cli_eval import EvalStatus from .utils import cleanup @@ -828,14 +829,16 @@ class AdkWebServer: if 
not session: raise HTTPException(status_code=404, detail="Session not found") runner = await self.get_runner_async(req.app_name) - events = [ - event - async for event in runner.run_async( + + events = [] + async with Aclosing( + runner.run_async( user_id=req.user_id, session_id=req.session_id, new_message=req.new_message, ) - ] + ) as agen: + events = [event async for event in agen] logger.info("Generated %s events in agent run", len(events)) logger.debug("Events generated: %s", events) return events @@ -856,19 +859,24 @@ class AdkWebServer: StreamingMode.SSE if req.streaming else StreamingMode.NONE ) runner = await self.get_runner_async(req.app_name) - async for event in runner.run_async( - user_id=req.user_id, - session_id=req.session_id, - new_message=req.new_message, - state_delta=req.state_delta, - run_config=RunConfig(streaming_mode=stream_mode), - ): - # Format as SSE data - sse_event = event.model_dump_json(exclude_none=True, by_alias=True) - logger.debug( - "Generated event in agent run streaming: %s", sse_event - ) - yield f"data: {sse_event}\n\n" + async with Aclosing( + runner.run_async( + user_id=req.user_id, + session_id=req.session_id, + new_message=req.new_message, + state_delta=req.state_delta, + run_config=RunConfig(streaming_mode=stream_mode), + ) + ) as agen: + async for event in agen: + # Format as SSE data + sse_event = event.model_dump_json( + exclude_none=True, by_alias=True + ) + logger.debug( + "Generated event in agent run streaming: %s", sse_event + ) + yield f"data: {sse_event}\n\n" except Exception as e: logger.exception("Error in event_generator: %s", e) # You might want to yield an error event here @@ -954,12 +962,15 @@ class AdkWebServer: async def forward_events(): runner = await self.get_runner_async(app_name) - async for event in runner.run_live( - session=session, live_request_queue=live_request_queue - ): - await websocket.send_text( - event.model_dump_json(exclude_none=True, by_alias=True) - ) + async with Aclosing( + runner.run_live( + session=session, live_request_queue=live_request_queue + ) + ) as agen: + async for event in agen: + await websocket.send_text( + event.model_dump_json(exclude_none=True, by_alias=True) + ) async def process_messages(): try: diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py index bf149a21..70c58d04 100644 --- a/src/google/adk/cli/cli.py +++ b/src/google/adk/cli/cli.py @@ -30,6 +30,7 @@ from ..runners import Runner from ..sessions.base_session_service import BaseSessionService from ..sessions.in_memory_session_service import InMemorySessionService from ..sessions.session import Session +from ..utils.context_utils import Aclosing from .utils import envs from .utils.agent_loader import AgentLoader @@ -65,12 +66,15 @@ async def run_input_file( for query in input_file.queries: click.echo(f'[user]: {query}') content = types.Content(role='user', parts=[types.Part(text=query)]) - async for event in runner.run_async( - user_id=session.user_id, session_id=session.id, new_message=content - ): - if event.content and event.content.parts: - if text := ''.join(part.text or '' for part in event.content.parts): - click.echo(f'[{event.author}]: {text}') + async with Aclosing( + runner.run_async( + user_id=session.user_id, session_id=session.id, new_message=content + ) + ) as agen: + async for event in agen: + if event.content and event.content.parts: + if text := ''.join(part.text or '' for part in event.content.parts): + click.echo(f'[{event.author}]: {text}') return session @@ -94,14 +98,19 @@ async def 
run_interactively( continue if query == 'exit': break - async for event in runner.run_async( - user_id=session.user_id, - session_id=session.id, - new_message=types.Content(role='user', parts=[types.Part(text=query)]), - ): - if event.content and event.content.parts: - if text := ''.join(part.text or '' for part in event.content.parts): - click.echo(f'[{event.author}]: {text}') + async with Aclosing( + runner.run_async( + user_id=session.user_id, + session_id=session.id, + new_message=types.Content( + role='user', parts=[types.Part(text=query)] + ), + ) + ) as agen: + async for event in agen: + if event.content and event.content.parts: + if text := ''.join(part.text or '' for part in event.content.parts): + click.echo(f'[{event.author}]: {text}') await runner.close() diff --git a/src/google/adk/cli/cli_eval.py b/src/google/adk/cli/cli_eval.py index 2f1d090c..89e7f415 100644 --- a/src/google/adk/cli/cli_eval.py +++ b/src/google/adk/cli/cli_eval.py @@ -45,6 +45,7 @@ from ..evaluation.eval_result import EvalCaseResult from ..evaluation.evaluator import EvalStatus from ..evaluation.evaluator import Evaluator from ..sessions.base_session_service import BaseSessionService +from ..utils.context_utils import Aclosing logger = logging.getLogger("google_adk." + __name__) @@ -159,10 +160,11 @@ async def _collect_inferences( """ inference_results = [] for inference_request in inference_requests: - async for inference_result in eval_service.perform_inference( - inference_request=inference_request - ): - inference_results.append(inference_result) + async with Aclosing( + eval_service.perform_inference(inference_request=inference_request) + ) as agen: + async for inference_result in agen: + inference_results.append(inference_result) return inference_results @@ -180,10 +182,11 @@ async def _collect_eval_results( inference_results=inference_results, evaluate_config=EvaluateConfig(eval_metrics=eval_metrics), ) - async for eval_result in eval_service.evaluate( - evaluate_request=evaluate_request - ): - eval_results.append(eval_result) + async with Aclosing( + eval_service.evaluate(evaluate_request=evaluate_request) + ) as agen: + async for eval_result in agen: + eval_results.append(eval_result) return eval_results diff --git a/src/google/adk/evaluation/agent_evaluator.py b/src/google/adk/evaluation/agent_evaluator.py index c1872884..710d6e48 100644 --- a/src/google/adk/evaluation/agent_evaluator.py +++ b/src/google/adk/evaluation/agent_evaluator.py @@ -32,6 +32,7 @@ from pydantic import BaseModel from pydantic import ValidationError from ..agents.base_agent import BaseAgent +from ..utils.context_utils import Aclosing from .constants import MISSING_EVAL_DEPENDENCIES_MESSAGE from .eval_case import IntermediateData from .eval_case import Invocation @@ -538,10 +539,11 @@ class AgentEvaluator: # Generate inferences inference_results = [] for inference_request in inference_requests: - async for inference_result in eval_service.perform_inference( - inference_request=inference_request - ): - inference_results.append(inference_result) + async with Aclosing( + eval_service.perform_inference(inference_request=inference_request) + ) as agen: + async for inference_result in agen: + inference_results.append(inference_result) # Evaluate metrics # As we perform more than one run for an eval case, we collect eval results @@ -551,14 +553,15 @@ class AgentEvaluator: inference_results=inference_results, evaluate_config=EvaluateConfig(eval_metrics=eval_metrics), ) - async for eval_result in eval_service.evaluate( - 
evaluate_request=evaluate_request - ): - eval_id = eval_result.eval_id - if eval_id not in eval_results_by_eval_id: - eval_results_by_eval_id[eval_id] = [] + async with Aclosing( + eval_service.evaluate(evaluate_request=evaluate_request) + ) as agen: + async for eval_result in agen: + eval_id = eval_result.eval_id + if eval_id not in eval_results_by_eval_id: + eval_results_by_eval_id[eval_id] = [] - eval_results_by_eval_id[eval_id].append(eval_result) + eval_results_by_eval_id[eval_id].append(eval_result) return eval_results_by_eval_id diff --git a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py index 7d643649..7f1c94f1 100644 --- a/src/google/adk/evaluation/evaluation_generator.py +++ b/src/google/adk/evaluation/evaluation_generator.py @@ -30,6 +30,7 @@ from ..runners import Runner from ..sessions.base_session_service import BaseSessionService from ..sessions.in_memory_session_service import InMemorySessionService from ..sessions.session import Session +from ..utils.context_utils import Aclosing from .eval_case import EvalCase from .eval_case import IntermediateData from .eval_case import Invocation @@ -189,18 +190,25 @@ class EvaluationGenerator: tool_uses = [] invocation_id = "" - async for event in runner.run_async( - user_id=user_id, session_id=session_id, new_message=user_content - ): - invocation_id = ( - event.invocation_id if not invocation_id else invocation_id - ) + async with Aclosing( + runner.run_async( + user_id=user_id, session_id=session_id, new_message=user_content + ) + ) as agen: + async for event in agen: + invocation_id = ( + event.invocation_id if not invocation_id else invocation_id + ) - if event.is_final_response() and event.content and event.content.parts: - final_response = event.content - elif event.get_function_calls(): - for call in event.get_function_calls(): - tool_uses.append(call) + if ( + event.is_final_response() + and event.content + and event.content.parts + ): + final_response = event.content + elif event.get_function_calls(): + for call in event.get_function_calls(): + tool_uses.append(call) response_invocations.append( Invocation( diff --git a/src/google/adk/evaluation/llm_as_judge.py b/src/google/adk/evaluation/llm_as_judge.py index ac1b3306..b17ee82d 100644 --- a/src/google/adk/evaluation/llm_as_judge.py +++ b/src/google/adk/evaluation/llm_as_judge.py @@ -24,6 +24,7 @@ from ..models.base_llm import BaseLlm from ..models.llm_request import LlmRequest from ..models.llm_response import LlmResponse from ..models.registry import LLMRegistry +from ..utils.context_utils import Aclosing from .eval_case import Invocation from .eval_metrics import EvalMetric from .evaluator import EvaluationResult @@ -109,21 +110,22 @@ class LlmAsJudge(Evaluator): num_samples = self._judge_model_options.num_samples invocation_result_samples = [] for _ in range(num_samples): - async for llm_response in self._judge_model.generate_content_async( - llm_request - ): - # Non-streaming call, so there is only one response content. - score = self.convert_auto_rater_response_to_score(llm_response) - invocation_result_samples.append( - PerInvocationResult( - actual_invocation=actual, - expected_invocation=expected, - score=score, - eval_status=get_eval_status( - score, self._eval_metric.threshold - ), - ) - ) + async with Aclosing( + self._judge_model.generate_content_async(llm_request) + ) as agen: + async for llm_response in agen: + # Non-streaming call, so there is only one response content. 
+ score = self.convert_auto_rater_response_to_score(llm_response) + invocation_result_samples.append( + PerInvocationResult( + actual_invocation=actual, + expected_invocation=expected, + score=score, + eval_status=get_eval_status( + score, self._eval_metric.threshold + ), + ) + ) if not invocation_result_samples: continue per_invocation_results.append( diff --git a/src/google/adk/flows/llm_flows/_code_execution.py b/src/google/adk/flows/llm_flows/_code_execution.py index c2252f97..5c0a5777 100644 --- a/src/google/adk/flows/llm_flows/_code_execution.py +++ b/src/google/adk/flows/llm_flows/_code_execution.py @@ -39,6 +39,7 @@ from ...code_executors.code_executor_context import CodeExecutorContext from ...events.event import Event from ...events.event_actions import EventActions from ...models.llm_response import LlmResponse +from ...utils.context_utils import Aclosing from ._base_llm_processor import BaseLlmRequestProcessor from ._base_llm_processor import BaseLlmResponseProcessor @@ -122,8 +123,11 @@ class _CodeExecutionRequestProcessor(BaseLlmRequestProcessor): if not invocation_context.agent.code_executor: return - async for event in _run_pre_processor(invocation_context, llm_request): - yield event + async with Aclosing( + _run_pre_processor(invocation_context, llm_request) + ) as agen: + async for event in agen: + yield event # Convert the code execution parts to text parts. if not isinstance(invocation_context.agent.code_executor, BaseCodeExecutor): @@ -152,8 +156,11 @@ class _CodeExecutionResponseProcessor(BaseLlmResponseProcessor): if llm_response.partial: return - async for event in _run_post_processor(invocation_context, llm_response): - yield event + async with Aclosing( + _run_post_processor(invocation_context, llm_response) + ) as agen: + async for event in agen: + yield event response_processor = _CodeExecutionResponseProcessor() diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 90cf0fbc..0adaea1d 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -46,6 +46,7 @@ from ...telemetry import trace_send_data from ...telemetry import tracer from ...tools.base_toolset import BaseToolset from ...tools.tool_context import ToolContext +from ...utils.context_utils import Aclosing if TYPE_CHECKING: from ...agents.llm_agent import LlmAgent @@ -77,8 +78,11 @@ class BaseLlmFlow(ABC): event_id = Event.new_id() # Preprocess before calling the LLM. - async for event in self._preprocess_async(invocation_context, llm_request): - yield event + async with Aclosing( + self._preprocess_async(invocation_context, llm_request) + ) as agen: + async for event in agen: + yield event if invocation_context.end_invocation: return @@ -110,75 +114,78 @@ class BaseLlmFlow(ABC): async with llm.connect(llm_request) as llm_connection: if llm_request.contents: # Sends the conversation history to the model. - with tracer.start_as_current_span('send_data'): + # with tracer.start_as_current_span('send_data'): - if invocation_context.transcription_cache: - from . import audio_transcriber + if invocation_context.transcription_cache: + from . 
import audio_transcriber - audio_transcriber = audio_transcriber.AudioTranscriber( - init_client=True - if invocation_context.run_config.input_audio_transcription - is None - else False - ) - contents = audio_transcriber.transcribe_file(invocation_context) - logger.debug('Sending history to model: %s', contents) - await llm_connection.send_history(contents) - invocation_context.transcription_cache = None - trace_send_data(invocation_context, event_id, contents) - else: - await llm_connection.send_history(llm_request.contents) - trace_send_data( - invocation_context, event_id, llm_request.contents - ) + audio_transcriber = audio_transcriber.AudioTranscriber( + init_client=True + if invocation_context.run_config.input_audio_transcription + is None + else False + ) + contents = audio_transcriber.transcribe_file(invocation_context) + logger.debug('Sending history to model: %s', contents) + await llm_connection.send_history(contents) + invocation_context.transcription_cache = None + trace_send_data(invocation_context, event_id, contents) + else: + await llm_connection.send_history(llm_request.contents) + trace_send_data( + invocation_context, event_id, llm_request.contents ) send_task = asyncio.create_task( self._send_to_model(llm_connection, invocation_context) ) try: - async for event in self._receive_from_model( - llm_connection, - event_id, - invocation_context, - llm_request, - ): - # Empty event means the queue is closed. - if not event: - break - logger.debug('Receive new event: %s', event) - yield event - # send back the function response - if event.get_function_responses(): - logger.debug( - 'Sending back last function response event: %s', event + async with Aclosing( + self._receive_from_model( + llm_connection, + event_id, + invocation_context, + llm_request, ) - invocation_context.live_request_queue.send_content( + ) as agen: + async for event in agen: + # Empty event means the queue is closed. + if not event: + break + logger.debug('Receive new event: %s', event) + yield event + # send back the function response + if event.get_function_responses(): + logger.debug( + 'Sending back last function response event: %s', event + ) + invocation_context.live_request_queue.send_content( + event.content + ) + if ( event.content - ) - if ( - event.content - and event.content.parts - and event.content.parts[0].function_response - and event.content.parts[0].function_response.name - == 'transfer_to_agent' - ): - await asyncio.sleep(1) - # cancel the tasks that belongs to the closed connection. - send_task.cancel() - await llm_connection.close() - if ( - event.content - and event.content.parts - and event.content.parts[0].function_response - and event.content.parts[0].function_response.name - == 'task_completed' - ): - # this is used for sequential agent to signal the end of the agent. - await asyncio.sleep(1) - # cancel the tasks that belongs to the closed connection. - send_task.cancel() - return + and event.content.parts + and event.content.parts[0].function_response + and event.content.parts[0].function_response.name + == 'transfer_to_agent' + ): + await asyncio.sleep(1) + # cancel the tasks that belong to the closed connection. + send_task.cancel() + await llm_connection.close() + if ( + event.content + and event.content.parts + and event.content.parts[0].function_response + and event.content.parts[0].function_response.name + == 'task_completed' + ): + # this is used for the sequential agent to signal the end of the agent. 
+ await asyncio.sleep(1) + # cancel the tasks that belong to the closed connection. + send_task.cancel() + return finally: # Clean up if not send_task.done(): @@ -282,45 +289,49 @@ class BaseLlmFlow(ABC): assert invocation_context.live_request_queue try: while True: - async for llm_response in llm_connection.receive(): - if llm_response.live_session_resumption_update: - logger.info( - 'Update session resumption hanlde:' - f' {llm_response.live_session_resumption_update}.' - ) - invocation_context.live_session_resumption_handle = ( - llm_response.live_session_resumption_update.new_handle - ) - model_response_event = Event( - id=Event.new_id(), - invocation_id=invocation_context.invocation_id, - author=get_author_for_event(llm_response), - ) - async for event in self._postprocess_live( - invocation_context, - llm_request, - llm_response, - model_response_event, - ): - if ( - event.content - and event.content.parts - and event.content.parts[0].inline_data is None - and not event.partial - ): - # This can be either user data or transcription data. - # when output transcription enabled, it will contain model's - # transcription. - # when input transcription enabled, it will contain user - # transcription. - if not invocation_context.transcription_cache: - invocation_context.transcription_cache = [] - invocation_context.transcription_cache.append( - TranscriptionEntry( - role=event.content.role, data=event.content - ) + async with Aclosing(llm_connection.receive()) as agen: + async for llm_response in agen: + if llm_response.live_session_resumption_update: + logger.info( + 'Update session resumption handle:' + f' {llm_response.live_session_resumption_update}.' ) - yield event + invocation_context.live_session_resumption_handle = ( + llm_response.live_session_resumption_update.new_handle + ) + model_response_event = Event( + id=Event.new_id(), + invocation_id=invocation_context.invocation_id, + author=get_author_for_event(llm_response), + ) + async with Aclosing( + self._postprocess_live( + invocation_context, + llm_request, + llm_response, + model_response_event, + ) + ) as agen: + async for event in agen: + if ( + event.content + and event.content.parts + and event.content.parts[0].inline_data is None + and not event.partial + ): + # This can be either user data or transcription data. + # when output transcription is enabled, it will contain model's + # transcription. + # when input transcription is enabled, it will contain user + # transcription. + if not invocation_context.transcription_cache: + invocation_context.transcription_cache = [] + invocation_context.transcription_cache.append( + TranscriptionEntry( + role=event.content.role, data=event.content + ) + ) + yield event # Give opportunity for other tasks to run. await asyncio.sleep(0) except ConnectionClosedOK: @@ -332,9 +343,10 @@ class BaseLlmFlow(ABC): """Runs the flow.""" while True: last_event = None - async for event in self._run_one_step_async(invocation_context): - last_event = event - yield event + async with Aclosing(self._run_one_step_async(invocation_context)) as agen: + async for event in agen: + last_event = event + yield event if not last_event or last_event.is_final_response() or last_event.partial: if last_event and last_event.partial: logger.warning('The last event is partial, which is not expected.') @@ -348,8 +360,11 @@ class BaseLlmFlow(ABC): llm_request = LlmRequest() # Preprocess before calling the LLM. 
- async for event in self._preprocess_async(invocation_context, llm_request): - yield event + async with Aclosing( + self._preprocess_async(invocation_context, llm_request) + ) as agen: + async for event in agen: + yield event if invocation_context.end_invocation: return @@ -360,17 +375,26 @@ class BaseLlmFlow(ABC): author=invocation_context.agent.name, branch=invocation_context.branch, ) - async for llm_response in self._call_llm_async( - invocation_context, llm_request, model_response_event - ): - # Postprocess after calling the LLM. - async for event in self._postprocess_async( - invocation_context, llm_request, llm_response, model_response_event - ): - # Update the mutable event id to avoid conflict - model_response_event.id = Event.new_id() - model_response_event.timestamp = datetime.datetime.now().timestamp() - yield event + async with Aclosing( + self._call_llm_async( + invocation_context, llm_request, model_response_event + ) + ) as agen: + async for llm_response in agen: + # Postprocess after calling the LLM. + async with Aclosing( + self._postprocess_async( + invocation_context, + llm_request, + llm_response, + model_response_event, + ) + ) as agen: + async for event in agen: + # Update the mutable event id to avoid conflict + model_response_event.id = Event.new_id() + model_response_event.timestamp = datetime.datetime.now().timestamp() + yield event async def _preprocess_async( self, invocation_context: InvocationContext, llm_request: LlmRequest @@ -383,8 +407,11 @@ class BaseLlmFlow(ABC): # Runs processors. for processor in self.request_processors: - async for event in processor.run_async(invocation_context, llm_request): - yield event + async with Aclosing( + processor.run_async(invocation_context, llm_request) + ) as agen: + async for event in agen: + yield event # Run processors for tools. for tool_union in agent.tools: @@ -427,10 +454,11 @@ class BaseLlmFlow(ABC): """ # Runs processors. - async for event in self._postprocess_run_processors_async( - invocation_context, llm_response - ): - yield event + async with Aclosing( + self._postprocess_run_processors_async(invocation_context, llm_response) + ) as agen: + async for event in agen: + yield event # Skip the model response event if there is no content and no error code. # This is needed for the code executor to trigger another loop. @@ -449,10 +477,13 @@ class BaseLlmFlow(ABC): # Handles function calls. if model_response_event.get_function_calls(): - async for event in self._postprocess_handle_function_calls_async( - invocation_context, model_response_event, llm_request - ): - yield event + async with Aclosing( + self._postprocess_handle_function_calls_async( + invocation_context, model_response_event, llm_request + ) + ) as agen: + async for event in agen: + yield event async def _postprocess_live( self, @@ -474,10 +505,11 @@ class BaseLlmFlow(ABC): """ # Runs processors. - async for event in self._postprocess_run_processors_async( - invocation_context, llm_response - ): - yield event + async with Aclosing( + self._postprocess_run_processors_async(invocation_context, llm_response) + ) as agen: + async for event in agen: + yield event # Skip the model response event if there is no content and no error code. # This is needed for the code executor to trigger another loop. 
@@ -521,15 +553,19 @@ class BaseLlmFlow(ABC): agent_to_run = self._get_agent_to_run( invocation_context, transfer_to_agent ) - async for item in agent_to_run.run_live(invocation_context): - yield item + async with Aclosing(agent_to_run.run_live(invocation_context)) as agen: + async for item in agen: + yield item async def _postprocess_run_processors_async( self, invocation_context: InvocationContext, llm_response: LlmResponse ) -> AsyncGenerator[Event, None]: for processor in self.response_processors: - async for event in processor.run_async(invocation_context, llm_response): - yield event + async with Aclosing( + processor.run_async(invocation_context, llm_response) + ) as agen: + async for event in agen: + yield event async def _postprocess_handle_function_calls_async( self, @@ -565,8 +601,9 @@ class BaseLlmFlow(ABC): agent_to_run = self._get_agent_to_run( invocation_context, transfer_to_agent ) - async for event in agent_to_run.run_async(invocation_context): - yield event + async with Aclosing(agent_to_run.run_async(invocation_context)) as agen: + async for event in agen: + yield event def _get_agent_to_run( self, invocation_context: InvocationContext, agent_name: str @@ -602,58 +639,71 @@ class BaseLlmFlow(ABC): # Calls the LLM. llm = self.__get_llm(invocation_context) - with tracer.start_as_current_span('call_llm'): + + async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]: + # with tracer.start_as_current_span('call_llm'): if invocation_context.run_config.support_cfc: invocation_context.live_request_queue = LiveRequestQueue() responses_generator = self.run_live(invocation_context) - async for llm_response in self._run_and_handle_error( - responses_generator, - invocation_context, - llm_request, - model_response_event, - ): - # Runs after_model_callback if it exists. - if altered_llm_response := await self._handle_after_model_callback( - invocation_context, llm_response, model_response_event - ): - llm_response = altered_llm_response - # only yield partial response in SSE streaming mode - if ( - invocation_context.run_config.streaming_mode == StreamingMode.SSE - or not llm_response.partial - ): - yield llm_response - if llm_response.turn_complete: - invocation_context.live_request_queue.close() + async with Aclosing( + self._run_and_handle_error( + responses_generator, + invocation_context, + llm_request, + model_response_event, + ) + ) as agen: + async for llm_response in agen: + # Runs after_model_callback if it exists. + if altered_llm_response := await self._handle_after_model_callback( + invocation_context, llm_response, model_response_event + ): + llm_response = altered_llm_response + # only yield partial response in SSE streaming mode + if ( + invocation_context.run_config.streaming_mode + == StreamingMode.SSE + or not llm_response.partial + ): + yield llm_response + if llm_response.turn_complete: + invocation_context.live_request_queue.close() else: - # Check if we can make this llm call or not. If the current call pushes - # the counter beyond the max set value, then the execution is stopped - # right here, and exception is thrown. + # Check if we can make this llm call or not. If the current call + # pushes the counter beyond the max set value, then the execution is + # stopped right here, and an exception is thrown. 
invocation_context.increment_llm_call_count() responses_generator = llm.generate_content_async( llm_request, stream=invocation_context.run_config.streaming_mode == StreamingMode.SSE, ) - async for llm_response in self._run_and_handle_error( - responses_generator, - invocation_context, - llm_request, - model_response_event, - ): - trace_call_llm( - invocation_context, - model_response_event.id, - llm_request, - llm_response, - ) - # Runs after_model_callback if it exists. - if altered_llm_response := await self._handle_after_model_callback( - invocation_context, llm_response, model_response_event - ): - llm_response = altered_llm_response + async with Aclosing( + self._run_and_handle_error( + responses_generator, + invocation_context, + llm_request, + model_response_event, + ) + ) as agen: + async for llm_response in agen: + trace_call_llm( + invocation_context, + model_response_event.id, + llm_request, + llm_response, + ) + # Runs after_model_callback if it exists. + if altered_llm_response := await self._handle_after_model_callback( + invocation_context, llm_response, model_response_event + ): + llm_response = altered_llm_response - yield llm_response + yield llm_response + + async with Aclosing(_call_llm_with_tracing()) as agen: + async for event in agen: + yield event async def _handle_before_model_callback( self, @@ -775,8 +825,9 @@ class BaseLlmFlow(ABC): A generator of LlmResponse. """ try: - async for response in response_generator: - yield response + async with Aclosing(response_generator) as agen: + async for response in agen: + yield response except Exception as model_error: callback_context = CallbackContext( invocation_context, event_actions=model_response_event.actions diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 86f7e30a..0c8fa86a 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -40,6 +40,7 @@ from ...telemetry import trace_tool_call from ...telemetry import tracer from ...tools.base_tool import BaseTool from ...tools.tool_context import ToolContext +from ...utils.context_utils import Aclosing if TYPE_CHECKING: from ...agents.llm_agent import LlmAgent @@ -510,21 +511,24 @@ async def _process_function_live_helper( # we require the function to be a async generator function async def run_tool_and_update_queue(tool, function_args, tool_context): try: - async for result in __call_tool_live( - tool=tool, - args=function_args, - tool_context=tool_context, - invocation_context=invocation_context, - ): - updated_content = types.Content( - role='user', - parts=[ - types.Part.from_text( - text=f'Function {tool.name} returned: {result}' - ) - ], - ) - invocation_context.live_request_queue.send_content(updated_content) + async with Aclosing( + __call_tool_live( + tool=tool, + args=function_args, + tool_context=tool_context, + invocation_context=invocation_context, + ) + ) as agen: + async for result in agen: + updated_content = types.Content( + role='user', + parts=[ + types.Part.from_text( + text=f'Function {tool.name} returned: {result}' + ) + ], + ) + invocation_context.live_request_queue.send_content(updated_content) except asyncio.CancelledError: raise # Re-raise to properly propagate the cancellation @@ -586,12 +590,15 @@ async def __call_tool_live( invocation_context: InvocationContext, ) -> AsyncGenerator[Event, None]: """Calls the tool asynchronously (awaiting the coroutine).""" - async for item in tool._call_live( - args=args, - tool_context=tool_context, 
- invocation_context=invocation_context, - ): - yield item + async with Aclosing( + tool._call_live( + args=args, + tool_context=tool_context, + invocation_context=invocation_context, + ) + ) as agen: + async for item in agen: + yield item async def __call_tool_async( diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index 3b46c91a..fd6f4a78 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -21,6 +21,7 @@ from typing import Union from google.genai import live from google.genai import types +from ..utils.context_utils import Aclosing from .base_llm_connection import BaseLlmConnection from .llm_response import LlmResponse @@ -142,90 +143,92 @@ class GeminiLlmConnection(BaseLlmConnection): """ text = '' - async for message in self._gemini_session.receive(): - logger.debug('Got LLM Live message: %s', message) - if message.server_content: - content = message.server_content.model_turn - if content and content.parts: - llm_response = LlmResponse( - content=content, interrupted=message.server_content.interrupted - ) - if content.parts[0].text: - text += content.parts[0].text - llm_response.partial = True - # don't yield the merged text event when receiving audio data - elif text and not content.parts[0].inline_data: + async with Aclosing(self._gemini_session.receive()) as agen: + async for message in agen: + logger.debug('Got LLM Live message: %s', message) + if message.server_content: + content = message.server_content.model_turn + if content and content.parts: + llm_response = LlmResponse( + content=content, interrupted=message.server_content.interrupted + ) + if content.parts[0].text: + text += content.parts[0].text + llm_response.partial = True + # don't yield the merged text event when receiving audio data + elif text and not content.parts[0].inline_data: + yield self.__build_full_text_response(text) + text = '' + yield llm_response + if ( + message.server_content.input_transcription + and message.server_content.input_transcription.text + ): + user_text = message.server_content.input_transcription.text + parts = [ + types.Part.from_text( + text=user_text, + ) + ] + llm_response = LlmResponse( + content=types.Content(role='user', parts=parts) + ) + yield llm_response + if ( + message.server_content.output_transcription + and message.server_content.output_transcription.text + ): + # TODO: Right now, we just support output_transcription without + # changing the interface and data protocol. Later, we can consider + # supporting output_transcription as a separate field in LlmResponse. + + # Transcription is always considered a partial event + # We rely on other control signals to determine when to yield the + # full text response (turn_complete, interrupted, or tool_call). + text += message.server_content.output_transcription.text + parts = [ + types.Part.from_text( + text=message.server_content.output_transcription.text + ) + ] + llm_response = LlmResponse( + content=types.Content(role='model', parts=parts), partial=True + ) + yield llm_response + + if message.server_content.turn_complete: + if text: + yield self.__build_full_text_response(text) + text = '' + yield LlmResponse( + turn_complete=True, + interrupted=message.server_content.interrupted, + ) + break + # in case of empty content or parts, we still surface it + # in case it's an interrupted message, we merge the previous partial + # text. Otherwise we don't merge, 
because content can be none when the model + safety threshold is triggered + if message.server_content.interrupted and text: yield self.__build_full_text_response(text) text = '' - yield llm_response - if ( - message.server_content.input_transcription - and message.server_content.input_transcription.text - ): - user_text = message.server_content.input_transcription.text - parts = [ - types.Part.from_text( - text=user_text, - ) - ] - llm_response = LlmResponse( - content=types.Content(role='user', parts=parts) - ) - yield llm_response - if ( - message.server_content.output_transcription - and message.server_content.output_transcription.text - ): - # TODO: Right now, we just support output_transcription without - # changing interface and data protocol. Later, we can consider to - # support output_transcription as a separate field in LlmResponse. - - # Transcription is always considered as partial event - # We rely on other control signals to determine when to yield the - # full text response(turn_complete, interrupted, or tool_call). - text += message.server_content.output_transcription.text - parts = [ - types.Part.from_text( - text=message.server_content.output_transcription.text - ) - ] - llm_response = LlmResponse( - content=types.Content(role='model', parts=parts), partial=True - ) - yield llm_response - - if message.server_content.turn_complete: + yield LlmResponse(interrupted=message.server_content.interrupted) + if message.tool_call: if text: yield self.__build_full_text_response(text) text = '' - yield LlmResponse( - turn_complete=True, interrupted=message.server_content.interrupted + parts = [ + types.Part(function_call=function_call) + for function_call in message.tool_call.function_calls + ] + yield LlmResponse(content=types.Content(role='model', parts=parts)) + if message.session_resumption_update: + logger.info('Received session resumption message: %s', message) + yield ( + LlmResponse( + live_session_resumption_update=message.session_resumption_update + ) ) - break - # in case of empty content or parts, we sill surface it - # in case it's an interrupted message, we merge the previous partial - # text. Other we don't merge. because content can be none when model - safety threshold is triggered - if message.server_content.interrupted and text: - yield self.__build_full_text_response(text) - text = '' - yield LlmResponse(interrupted=message.server_content.interrupted) - if message.tool_call: - if text: - yield self.__build_full_text_response(text) - text = '' - parts = [ - types.Part(function_call=function_call) - for function_call in message.tool_call.function_calls - ] - yield LlmResponse(content=types.Content(role='model', parts=parts)) - if message.session_resumption_update: - logger.info('Redeived session reassumption message: %s', message) - yield ( - LlmResponse( - live_session_resumption_update=message.session_resumption_update - ) - ) async def close(self): """Closes the llm server connection.""" diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index b1cad1c5..86515db1 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -32,6 +32,7 @@ from google.genai.types import FinishReason from typing_extensions import override from .. 
import version +from ..utils.context_utils import Aclosing from ..utils.variant_utils import GoogleLLMVariant from .base_llm import BaseLlm from .base_llm_connection import BaseLlmConnection @@ -141,39 +142,40 @@ class Gemini(BaseLlm): # contents are sent, we send an accumulated event which contains all the # previous partial content. The only difference is bidi rely on # complete_turn flag to detect end while sse depends on finish_reason. - async for response in responses: - logger.debug(_build_response_log(response)) - llm_response = LlmResponse.create(response) - usage_metadata = llm_response.usage_metadata - if ( - llm_response.content - and llm_response.content.parts - and llm_response.content.parts[0].text - ): - part0 = llm_response.content.parts[0] - if part0.thought: - thought_text += part0.text - else: - text += part0.text - llm_response.partial = True - elif (thought_text or text) and ( - not llm_response.content - or not llm_response.content.parts - # don't yield the merged text event when receiving audio data - or not llm_response.content.parts[0].inline_data - ): - parts = [] - if thought_text: - parts.append(types.Part(text=thought_text, thought=True)) - if text: - parts.append(types.Part.from_text(text=text)) - yield LlmResponse( - content=types.ModelContent(parts=parts), - usage_metadata=llm_response.usage_metadata, - ) - thought_text = '' - text = '' - yield llm_response + async with Aclosing(responses) as agen: + async for response in agen: + logger.debug(_build_response_log(response)) + llm_response = LlmResponse.create(response) + usage_metadata = llm_response.usage_metadata + if ( + llm_response.content + and llm_response.content.parts + and llm_response.content.parts[0].text + ): + part0 = llm_response.content.parts[0] + if part0.thought: + thought_text += part0.text + else: + text += part0.text + llm_response.partial = True + elif (thought_text or text) and ( + not llm_response.content + or not llm_response.content.parts + # don't yield the merged text event when receiving audio data + or not llm_response.content.parts[0].inline_data + ): + parts = [] + if thought_text: + parts.append(types.Part(text=thought_text, thought=True)) + if text: + parts.append(types.Part.from_text(text=text)) + yield LlmResponse( + content=types.ModelContent(parts=parts), + usage_metadata=llm_response.usage_metadata, + ) + thought_text = '' + text = '' + yield llm_response # generate an aggregated content at the end regardless the # response.candidates[0].finish_reason diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 51fdb965..45d0c81c 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -53,6 +53,7 @@ from .sessions.in_memory_session_service import InMemorySessionService from .sessions.session import Session from .telemetry import tracer from .tools.base_toolset import BaseToolset +from .utils.context_utils import Aclosing logger = logging.getLogger('google_adk.' + __name__) @@ -146,13 +147,16 @@ class Runner: async def _invoke_run_async(): try: - async for event in self.run_async( - user_id=user_id, - session_id=session_id, - new_message=new_message, - run_config=run_config, - ): - event_queue.put(event) + async with Aclosing( + self.run_async( + user_id=user_id, + session_id=session_id, + new_message=new_message, + run_config=run_config, + ) + ) as agen: + async for event in agen: + event_queue.put(event) finally: event_queue.put(None) @@ -195,47 +199,55 @@ class Runner: Yields: The events generated by the agent. 
""" - with tracer.start_as_current_span('invocation'): - session = await self.session_service.get_session( - app_name=self.app_name, user_id=user_id, session_id=session_id - ) - if not session: - raise ValueError(f'Session not found: {session_id}') - invocation_context = self._new_invocation_context( - session, - new_message=new_message, - run_config=run_config, - ) - root_agent = self.agent - - # Modify user message before execution. - modified_user_message = ( - await invocation_context.plugin_manager.run_on_user_message_callback( - invocation_context=invocation_context, user_message=new_message - ) - ) - if modified_user_message is not None: - new_message = modified_user_message - - if new_message: - await self._append_new_message_to_session( - session, - new_message, - invocation_context, - run_config.save_input_blobs_as_artifacts, - state_delta, + async def _run_with_trace( + new_message: types.Content, + ) -> AsyncGenerator[Event, None]: + with tracer.start_as_current_span('invocation'): + session = await self.session_service.get_session( + app_name=self.app_name, user_id=user_id, session_id=session_id ) + if not session: + raise ValueError(f'Session not found: {session_id}') - invocation_context.agent = self._find_agent_to_run(session, root_agent) + invocation_context = self._new_invocation_context( + session, + new_message=new_message, + run_config=run_config, + ) + root_agent = self.agent - async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: - async for event in ctx.agent.run_async(ctx): - yield event + # Modify user message before execution. + modified_user_message = await invocation_context.plugin_manager.run_on_user_message_callback( + invocation_context=invocation_context, user_message=new_message + ) + if modified_user_message is not None: + new_message = modified_user_message - async for event in self._exec_with_plugin( - invocation_context, session, execute - ): + if new_message: + await self._append_new_message_to_session( + session, + new_message, + invocation_context, + run_config.save_input_blobs_as_artifacts, + state_delta, + ) + + invocation_context.agent = self._find_agent_to_run(session, root_agent) + + async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: + async with Aclosing(ctx.agent.run_async(ctx)) as agen: + async for event in agen: + yield event + + async with Aclosing( + self._exec_with_plugin(invocation_context, session, execute) + ) as agen: + async for event in agen: + yield event + + async with Aclosing(_run_with_trace(new_message)) as agen: + async for event in agen: yield event async def _exec_with_plugin( @@ -274,14 +286,17 @@ class Runner: yield early_exit_event else: # Step 2: Otherwise continue with normal execution - async for event in execute_fn(invocation_context): - if not event.partial: - await self.session_service.append_event(session=session, event=event) - # Step 3: Run the on_event callbacks to optionally modify the event. - modified_event = await plugin_manager.run_on_event_callback( - invocation_context=invocation_context, event=event - ) - yield (modified_event if modified_event else event) + async with Aclosing(execute_fn(invocation_context)) as agen: + async for event in agen: + if not event.partial: + await self.session_service.append_event( + session=session, event=event + ) + # Step 3: Run the on_event callbacks to optionally modify the event. 
+ modified_event = await plugin_manager.run_on_event_callback( + invocation_context=invocation_context, event=event + ) + yield (modified_event if modified_event else event) # Step 4: Run the after_run callbacks to optionally modify the context. await plugin_manager.run_after_run_callback( @@ -439,13 +454,15 @@ class Runner: ) async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: - async for event in ctx.agent.run_live(ctx): - yield event + async with Aclosing(ctx.agent.run_live(ctx)) as agen: + async for event in agen: + yield event - async for event in self._exec_with_plugin( - invocation_context, session, execute - ): - yield event + async with Aclosing( + self._exec_with_plugin(invocation_context, session, execute) + ) as agen: + async for event in agen: + yield event def _find_agent_to_run( self, session: Session, root_agent: BaseAgent diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index c0d07238..6a1edcc6 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -26,6 +26,7 @@ from typing_extensions import override from . import _automatic_function_calling_util from ..agents.common_configs import AgentRefConfig from ..memory.in_memory_memory_service import InMemoryMemoryService +from ..utils.context_utils import Aclosing from ._forwarding_artifact_service import ForwardingArtifactService from .base_tool import BaseTool from .tool_configs import BaseToolConfig @@ -141,13 +142,16 @@ class AgentTool(BaseTool): ) last_event = None - async for event in runner.run_async( - user_id=session.user_id, session_id=session.id, new_message=content - ): - # Forward state delta to parent session. - if event.actions.state_delta: - tool_context.state.update(event.actions.state_delta) - last_event = event + async with Aclosing( + runner.run_async( + user_id=session.user_id, session_id=session.id, new_message=content + ) + ) as agen: + async for event in agen: + # Forward state delta to parent session. + if event.actions.state_delta: + tool_context.state.update(event.actions.state_delta) + last_event = event if not last_event or not last_event.content or not last_event.content.parts: return '' diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py index 2687f120..69f5934b 100644 --- a/src/google/adk/tools/function_tool.py +++ b/src/google/adk/tools/function_tool.py @@ -22,6 +22,7 @@ from typing import Optional from google.genai import types from typing_extensions import override +from ..utils.context_utils import Aclosing from ._automatic_function_calling_util import build_function_declaration from .base_tool import BaseTool from .tool_context import ToolContext @@ -136,8 +137,9 @@ You could retry calling this tool, but it is IMPORTANT for you to provide all th ].stream if 'tool_context' in signature.parameters: args_to_call['tool_context'] = tool_context - async for item in self.func(**args_to_call): - yield item + async with Aclosing(self.func(**args_to_call)) as agen: + async for item in agen: + yield item def _get_mandatory_args( self, diff --git a/src/google/adk/utils/context_utils.py b/src/google/adk/utils/context_utils.py new file mode 100644 index 00000000..243d5edf --- /dev/null +++ b/src/google/adk/utils/context_utils.py @@ -0,0 +1,49 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for ADK context management. + +This module is for ADK internal use only. +Please do not rely on the implementation details. +""" + +from __future__ import annotations + +from contextlib import AbstractAsyncContextManager +from typing import Any +from typing import AsyncGenerator + + +class Aclosing(AbstractAsyncContextManager): + """Async context manager for safely finalizing an asynchronously cleaned-up + resource, such as an async generator, by calling its ``aclose()`` method. + Needed to correctly close contexts for OTel spans. + See https://github.com/google/adk-python/issues/1670#issuecomment-3115891100. + + Based on + https://docs.python.org/3/library/contextlib.html#contextlib.aclosing + which is available in Python 3.10+. + + TODO: replace all occurrences with contextlib.aclosing once Python 3.9 is no + longer supported. + """ + + def __init__(self, async_generator: AsyncGenerator[Any, None]): + self.async_generator = async_generator + + async def __aenter__(self): + return self.async_generator + + async def __aexit__(self, *exc_info): + await self.async_generator.aclose() diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index 9004245c..e37f856e 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -47,6 +47,9 @@ class MockAsyncIterator: except StopIteration as exc: raise StopAsyncIteration from exc + async def aclose(self): + pass + @pytest.fixture def generate_content_response():
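
For reference, the `Aclosing` helper added above behaves like `contextlib.aclosing` (Python 3.10+). Below is a minimal, illustrative sketch of the pattern this diff applies throughout; the `ticker` generator is hypothetical and merely stands in for generators like `runner.run_async(...)`:

```python
import asyncio

from google.adk.utils.context_utils import Aclosing  # or contextlib.aclosing on 3.10+


async def ticker(n: int):
  # Hypothetical async generator standing in for runner.run_async(...).
  try:
    for i in range(n):
      yield i
  finally:
    # Runs deterministically even if the consumer stops iterating early:
    # Aclosing.__aexit__ awaits aclose(), which throws GeneratorExit into
    # the suspended generator.
    print('generator finalized')


async def main():
  async with Aclosing(ticker(10)) as agen:
    async for i in agen:
      if i == 2:
        break  # early exit; the finally block above still runs promptly


asyncio.run(main())
```

Without the wrapper, a generator abandoned by `break` is finalized only whenever garbage collection gets to it, which is what left OTel span contexts dangling in issue #1670; closing it eagerly keeps span enter/exit properly nested.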