fix: aclose all async generators to fix OTel tracing context

See https://github.com/google/adk-python/issues/1670#issuecomment-3115891100

PiperOrigin-RevId: 794659547
This commit is contained in:
Kacper Jawoszek
2025-08-13 11:17:52 -07:00
committed by Copybara-Service
parent c5af44cfc0
commit a30c63c593
23 changed files with 735 additions and 514 deletions
+20 -8
View File
@@ -46,13 +46,19 @@ async def main():
role='user', parts=[types.Part.from_text(text=new_message)]
)
print('** User says:', content.model_dump(exclude_none=True))
async for event in runner.run_async(
# TODO - migrate try...finally to contextlib.aclosing after Python 3.9 is
# no longer supported.
agen = runner.run_async(
user_id=user_id_1,
session_id=session.id,
new_message=content,
):
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
)
try:
async for event in agen:
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
finally:
await agen.aclose()
async def run_prompt_bytes(session: Session, new_message: str):
content = types.Content(
@@ -64,14 +70,20 @@ async def main():
],
)
print('** User says:', content.model_dump(exclude_none=True))
async for event in runner.run_async(
# TODO - migrate try...finally to contextlib.aclosing after Python 3.9 is
# no longer supported.
agen = runner.run_async(
user_id=user_id_1,
session_id=session.id,
new_message=content,
run_config=RunConfig(save_input_blobs_as_artifacts=True),
):
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
)
try:
async for event in agen:
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')
finally:
await agen.aclose()
start_time = time.time()
print('Start time:', start_time)
@@ -24,6 +24,8 @@ from typing import Callable
from typing import Optional
import uuid
from ...utils.context_utils import Aclosing
try:
from a2a.server.agent_execution import AgentExecutor
from a2a.server.agent_execution.context import RequestContext
@@ -212,12 +214,13 @@ class A2aAgentExecutor(AgentExecutor):
)
task_result_aggregator = TaskResultAggregator()
async for adk_event in runner.run_async(**run_args):
for a2a_event in convert_event_to_a2a_events(
adk_event, invocation_context, context.task_id, context.context_id
):
task_result_aggregator.process_event(a2a_event)
await event_queue.enqueue_event(a2a_event)
async with Aclosing(runner.run_async(**run_args)) as agen:
async for adk_event in agen:
for a2a_event in convert_event_to_a2a_events(
adk_event, invocation_context, context.task_id, context.context_id
):
task_result_aggregator.process_event(a2a_event)
await event_queue.enqueue_event(a2a_event)
# publish the task result event - this is final
if (
+34 -20
View File
@@ -39,6 +39,7 @@ from typing_extensions import override
from typing_extensions import TypeAlias
from ..events.event import Event
from ..utils.context_utils import Aclosing
from ..utils.feature_decorator import experimental
from .base_agent_config import BaseAgentConfig
from .callback_context import CallbackContext
@@ -212,21 +213,27 @@ class BaseAgent(BaseModel):
Event: the events generated by the agent.
"""
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
ctx = self._create_invocation_context(parent_context)
async def _run_with_trace() -> AsyncGenerator[Event, None]:
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
ctx = self._create_invocation_context(parent_context)
if event := await self.__handle_before_agent_callback(ctx):
yield event
if ctx.end_invocation:
return
if event := await self.__handle_before_agent_callback(ctx):
yield event
if ctx.end_invocation:
return
async for event in self._run_async_impl(ctx):
yield event
async with Aclosing(self._run_async_impl(ctx)) as agen:
async for event in agen:
yield event
if ctx.end_invocation:
return
if ctx.end_invocation:
return
if event := await self.__handle_after_agent_callback(ctx):
if event := await self.__handle_after_agent_callback(ctx):
yield event
async with Aclosing(_run_with_trace()) as agen:
async for event in agen:
yield event
@final
@@ -243,18 +250,25 @@ class BaseAgent(BaseModel):
Yields:
Event: the events generated by the agent.
"""
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
ctx = self._create_invocation_context(parent_context)
if event := await self.__handle_before_agent_callback(ctx):
yield event
if ctx.end_invocation:
return
async def _run_with_trace() -> AsyncGenerator[Event, None]:
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
ctx = self._create_invocation_context(parent_context)
async for event in self._run_live_impl(ctx):
yield event
if event := await self.__handle_before_agent_callback(ctx):
yield event
if ctx.end_invocation:
return
if event := await self.__handle_after_agent_callback(ctx):
async with Aclosing(self._run_live_impl(ctx)) as agen:
async for event in agen:
yield event
if event := await self.__handle_after_agent_callback(ctx):
yield event
async with Aclosing(_run_with_trace()) as agen:
async for event in agen:
yield event
async def _run_async_impl(
+11 -8
View File
@@ -51,6 +51,7 @@ from ..tools.base_toolset import BaseToolset
from ..tools.function_tool import FunctionTool
from ..tools.tool_configs import ToolConfig
from ..tools.tool_context import ToolContext
from ..utils.context_utils import Aclosing
from ..utils.feature_decorator import experimental
from .base_agent import BaseAgent
from .base_agent_config import BaseAgentConfig
@@ -283,19 +284,21 @@ class LlmAgent(BaseAgent):
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
async for event in self._llm_flow.run_async(ctx):
self.__maybe_save_output_to_state(event)
yield event
async with Aclosing(self._llm_flow.run_async(ctx)) as agen:
async for event in agen:
self.__maybe_save_output_to_state(event)
yield event
@override
async def _run_live_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
async for event in self._llm_flow.run_live(ctx):
self.__maybe_save_output_to_state(event)
yield event
if ctx.end_invocation:
return
async with Aclosing(self._llm_flow.run_live(ctx)) as agen:
async for event in agen:
self.__maybe_save_output_to_state(event)
yield event
if ctx.end_invocation:
return
@property
def canonical_model(self) -> BaseLlm:
+6 -4
View File
@@ -27,6 +27,7 @@ from typing_extensions import override
from ..agents.invocation_context import InvocationContext
from ..events.event import Event
from ..utils.context_utils import Aclosing
from ..utils.feature_decorator import experimental
from .base_agent import BaseAgent
from .base_agent_config import BaseAgentConfig
@@ -58,10 +59,11 @@ class LoopAgent(BaseAgent):
while not self.max_iterations or times_looped < self.max_iterations:
for sub_agent in self.sub_agents:
should_exit = False
async for event in sub_agent.run_async(ctx):
yield event
if event.actions.escalate:
should_exit = True
async with Aclosing(sub_agent.run_async(ctx)) as agen:
async for event in agen:
yield event
if event.actions.escalate:
should_exit = True
if should_exit:
return
+5 -2
View File
@@ -26,6 +26,7 @@ from typing import Type
from typing_extensions import override
from ..events.event import Event
from ..utils.context_utils import Aclosing
from .base_agent import BaseAgent
from .base_agent_config import BaseAgentConfig
from .invocation_context import InvocationContext
@@ -111,8 +112,10 @@ class ParallelAgent(BaseAgent):
)
for sub_agent in self.sub_agents
]
async for event in _merge_agent_run(agent_runs):
yield event
async with Aclosing(_merge_agent_run(agent_runs)) as agen:
async for event in agen:
yield event
@override
async def _run_live_impl(
+7 -4
View File
@@ -22,6 +22,7 @@ from typing import Type
from typing_extensions import override
from ..events.event import Event
from ..utils.context_utils import Aclosing
from .base_agent import BaseAgent
from .base_agent import BaseAgentConfig
from .invocation_context import InvocationContext
@@ -40,8 +41,9 @@ class SequentialAgent(BaseAgent):
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
for sub_agent in self.sub_agents:
async for event in sub_agent.run_async(ctx):
yield event
async with Aclosing(sub_agent.run_async(ctx)) as agen:
async for event in agen:
yield event
@override
async def _run_live_impl(
@@ -78,5 +80,6 @@ class SequentialAgent(BaseAgent):
do not generate any text other than the function call."""
for sub_agent in self.sub_agents:
async for event in sub_agent.run_live(ctx):
yield event
async with Aclosing(sub_agent.run_live(ctx)) as agen:
async for event in agen:
yield event
+34 -23
View File
@@ -73,6 +73,7 @@ from ..memory.base_memory_service import BaseMemoryService
from ..runners import Runner
from ..sessions.base_session_service import BaseSessionService
from ..sessions.session import Session
from ..utils.context_utils import Aclosing
from .cli_eval import EVAL_SESSION_ID_PREFIX
from .cli_eval import EvalStatus
from .utils import cleanup
@@ -828,14 +829,16 @@ class AdkWebServer:
if not session:
raise HTTPException(status_code=404, detail="Session not found")
runner = await self.get_runner_async(req.app_name)
events = [
event
async for event in runner.run_async(
events = []
async with Aclosing(
runner.run_async(
user_id=req.user_id,
session_id=req.session_id,
new_message=req.new_message,
)
]
) as agen:
events = [event async for event in agen]
logger.info("Generated %s events in agent run", len(events))
logger.debug("Events generated: %s", events)
return events
@@ -856,19 +859,24 @@ class AdkWebServer:
StreamingMode.SSE if req.streaming else StreamingMode.NONE
)
runner = await self.get_runner_async(req.app_name)
async for event in runner.run_async(
user_id=req.user_id,
session_id=req.session_id,
new_message=req.new_message,
state_delta=req.state_delta,
run_config=RunConfig(streaming_mode=stream_mode),
):
# Format as SSE data
sse_event = event.model_dump_json(exclude_none=True, by_alias=True)
logger.debug(
"Generated event in agent run streaming: %s", sse_event
)
yield f"data: {sse_event}\n\n"
async with Aclosing(
runner.run_async(
user_id=req.user_id,
session_id=req.session_id,
new_message=req.new_message,
state_delta=req.state_delta,
run_config=RunConfig(streaming_mode=stream_mode),
)
) as agen:
async for event in agen:
# Format as SSE data
sse_event = event.model_dump_json(
exclude_none=True, by_alias=True
)
logger.debug(
"Generated event in agent run streaming: %s", sse_event
)
yield f"data: {sse_event}\n\n"
except Exception as e:
logger.exception("Error in event_generator: %s", e)
# You might want to yield an error event here
@@ -954,12 +962,15 @@ class AdkWebServer:
async def forward_events():
runner = await self.get_runner_async(app_name)
async for event in runner.run_live(
session=session, live_request_queue=live_request_queue
):
await websocket.send_text(
event.model_dump_json(exclude_none=True, by_alias=True)
)
async with Aclosing(
runner.run_live(
session=session, live_request_queue=live_request_queue
)
) as agen:
async for event in agen:
await websocket.send_text(
event.model_dump_json(exclude_none=True, by_alias=True)
)
async def process_messages():
try:
+23 -14
View File
@@ -30,6 +30,7 @@ from ..runners import Runner
from ..sessions.base_session_service import BaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
from ..sessions.session import Session
from ..utils.context_utils import Aclosing
from .utils import envs
from .utils.agent_loader import AgentLoader
@@ -65,12 +66,15 @@ async def run_input_file(
for query in input_file.queries:
click.echo(f'[user]: {query}')
content = types.Content(role='user', parts=[types.Part(text=query)])
async for event in runner.run_async(
user_id=session.user_id, session_id=session.id, new_message=content
):
if event.content and event.content.parts:
if text := ''.join(part.text or '' for part in event.content.parts):
click.echo(f'[{event.author}]: {text}')
async with Aclosing(
runner.run_async(
user_id=session.user_id, session_id=session.id, new_message=content
)
) as agen:
async for event in agen:
if event.content and event.content.parts:
if text := ''.join(part.text or '' for part in event.content.parts):
click.echo(f'[{event.author}]: {text}')
return session
@@ -94,14 +98,19 @@ async def run_interactively(
continue
if query == 'exit':
break
async for event in runner.run_async(
user_id=session.user_id,
session_id=session.id,
new_message=types.Content(role='user', parts=[types.Part(text=query)]),
):
if event.content and event.content.parts:
if text := ''.join(part.text or '' for part in event.content.parts):
click.echo(f'[{event.author}]: {text}')
async with Aclosing(
runner.run_async(
user_id=session.user_id,
session_id=session.id,
new_message=types.Content(
role='user', parts=[types.Part(text=query)]
),
)
) as agen:
async for event in agen:
if event.content and event.content.parts:
if text := ''.join(part.text or '' for part in event.content.parts):
click.echo(f'[{event.author}]: {text}')
await runner.close()
+11 -8
View File
@@ -45,6 +45,7 @@ from ..evaluation.eval_result import EvalCaseResult
from ..evaluation.evaluator import EvalStatus
from ..evaluation.evaluator import Evaluator
from ..sessions.base_session_service import BaseSessionService
from ..utils.context_utils import Aclosing
logger = logging.getLogger("google_adk." + __name__)
@@ -159,10 +160,11 @@ async def _collect_inferences(
"""
inference_results = []
for inference_request in inference_requests:
async for inference_result in eval_service.perform_inference(
inference_request=inference_request
):
inference_results.append(inference_result)
async with Aclosing(
eval_service.perform_inference(inference_request=inference_request)
) as agen:
async for inference_result in agen:
inference_results.append(inference_result)
return inference_results
@@ -180,10 +182,11 @@ async def _collect_eval_results(
inference_results=inference_results,
evaluate_config=EvaluateConfig(eval_metrics=eval_metrics),
)
async for eval_result in eval_service.evaluate(
evaluate_request=evaluate_request
):
eval_results.append(eval_result)
async with Aclosing(
eval_service.evaluate(evaluate_request=evaluate_request)
) as agen:
async for eval_result in agen:
eval_results.append(eval_result)
return eval_results
+14 -11
View File
@@ -32,6 +32,7 @@ from pydantic import BaseModel
from pydantic import ValidationError
from ..agents.base_agent import BaseAgent
from ..utils.context_utils import Aclosing
from .constants import MISSING_EVAL_DEPENDENCIES_MESSAGE
from .eval_case import IntermediateData
from .eval_case import Invocation
@@ -538,10 +539,11 @@ class AgentEvaluator:
# Generate inferences
inference_results = []
for inference_request in inference_requests:
async for inference_result in eval_service.perform_inference(
inference_request=inference_request
):
inference_results.append(inference_result)
async with Aclosing(
eval_service.perform_inference(inference_request=inference_request)
) as agen:
async for inference_result in agen:
inference_results.append(inference_result)
# Evaluate metrics
# As we perform more than one run for an eval case, we collect eval results
@@ -551,14 +553,15 @@ class AgentEvaluator:
inference_results=inference_results,
evaluate_config=EvaluateConfig(eval_metrics=eval_metrics),
)
async for eval_result in eval_service.evaluate(
evaluate_request=evaluate_request
):
eval_id = eval_result.eval_id
if eval_id not in eval_results_by_eval_id:
eval_results_by_eval_id[eval_id] = []
async with Aclosing(
eval_service.evaluate(evaluate_request=evaluate_request)
) as agen:
async for eval_result in agen:
eval_id = eval_result.eval_id
if eval_id not in eval_results_by_eval_id:
eval_results_by_eval_id[eval_id] = []
eval_results_by_eval_id[eval_id].append(eval_result)
eval_results_by_eval_id[eval_id].append(eval_result)
return eval_results_by_eval_id
@@ -30,6 +30,7 @@ from ..runners import Runner
from ..sessions.base_session_service import BaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
from ..sessions.session import Session
from ..utils.context_utils import Aclosing
from .eval_case import EvalCase
from .eval_case import IntermediateData
from .eval_case import Invocation
@@ -189,18 +190,25 @@ class EvaluationGenerator:
tool_uses = []
invocation_id = ""
async for event in runner.run_async(
user_id=user_id, session_id=session_id, new_message=user_content
):
invocation_id = (
event.invocation_id if not invocation_id else invocation_id
)
async with Aclosing(
runner.run_async(
user_id=user_id, session_id=session_id, new_message=user_content
)
) as agen:
async for event in agen:
invocation_id = (
event.invocation_id if not invocation_id else invocation_id
)
if event.is_final_response() and event.content and event.content.parts:
final_response = event.content
elif event.get_function_calls():
for call in event.get_function_calls():
tool_uses.append(call)
if (
event.is_final_response()
and event.content
and event.content.parts
):
final_response = event.content
elif event.get_function_calls():
for call in event.get_function_calls():
tool_uses.append(call)
response_invocations.append(
Invocation(
+17 -15
View File
@@ -24,6 +24,7 @@ from ..models.base_llm import BaseLlm
from ..models.llm_request import LlmRequest
from ..models.llm_response import LlmResponse
from ..models.registry import LLMRegistry
from ..utils.context_utils import Aclosing
from .eval_case import Invocation
from .eval_metrics import EvalMetric
from .evaluator import EvaluationResult
@@ -109,21 +110,22 @@ class LlmAsJudge(Evaluator):
num_samples = self._judge_model_options.num_samples
invocation_result_samples = []
for _ in range(num_samples):
async for llm_response in self._judge_model.generate_content_async(
llm_request
):
# Non-streaming call, so there is only one response content.
score = self.convert_auto_rater_response_to_score(llm_response)
invocation_result_samples.append(
PerInvocationResult(
actual_invocation=actual,
expected_invocation=expected,
score=score,
eval_status=get_eval_status(
score, self._eval_metric.threshold
),
)
)
async with Aclosing(
self._judge_model.generate_content_async(llm_request)
) as agen:
async for llm_response in agen:
# Non-streaming call, so there is only one response content.
score = self.convert_auto_rater_response_to_score(llm_response)
invocation_result_samples.append(
PerInvocationResult(
actual_invocation=actual,
expected_invocation=expected,
score=score,
eval_status=get_eval_status(
score, self._eval_metric.threshold
),
)
)
if not invocation_result_samples:
continue
per_invocation_results.append(
@@ -39,6 +39,7 @@ from ...code_executors.code_executor_context import CodeExecutorContext
from ...events.event import Event
from ...events.event_actions import EventActions
from ...models.llm_response import LlmResponse
from ...utils.context_utils import Aclosing
from ._base_llm_processor import BaseLlmRequestProcessor
from ._base_llm_processor import BaseLlmResponseProcessor
@@ -122,8 +123,11 @@ class _CodeExecutionRequestProcessor(BaseLlmRequestProcessor):
if not invocation_context.agent.code_executor:
return
async for event in _run_pre_processor(invocation_context, llm_request):
yield event
async with Aclosing(
_run_pre_processor(invocation_context, llm_request)
) as agen:
async for event in agen:
yield event
# Convert the code execution parts to text parts.
if not isinstance(invocation_context.agent.code_executor, BaseCodeExecutor):
@@ -152,8 +156,11 @@ class _CodeExecutionResponseProcessor(BaseLlmResponseProcessor):
if llm_response.partial:
return
async for event in _run_post_processor(invocation_context, llm_response):
yield event
async with Aclosing(
_run_post_processor(invocation_context, llm_response)
) as agen:
async for event in agen:
yield event
response_processor = _CodeExecutionResponseProcessor()
File diff suppressed because it is too large Load Diff
+28 -21
View File
@@ -40,6 +40,7 @@ from ...telemetry import trace_tool_call
from ...telemetry import tracer
from ...tools.base_tool import BaseTool
from ...tools.tool_context import ToolContext
from ...utils.context_utils import Aclosing
if TYPE_CHECKING:
from ...agents.llm_agent import LlmAgent
@@ -510,21 +511,24 @@ async def _process_function_live_helper(
# we require the function to be a async generator function
async def run_tool_and_update_queue(tool, function_args, tool_context):
try:
async for result in __call_tool_live(
tool=tool,
args=function_args,
tool_context=tool_context,
invocation_context=invocation_context,
):
updated_content = types.Content(
role='user',
parts=[
types.Part.from_text(
text=f'Function {tool.name} returned: {result}'
)
],
)
invocation_context.live_request_queue.send_content(updated_content)
async with Aclosing(
__call_tool_live(
tool=tool,
args=function_args,
tool_context=tool_context,
invocation_context=invocation_context,
)
) as agen:
async for result in agen:
updated_content = types.Content(
role='user',
parts=[
types.Part.from_text(
text=f'Function {tool.name} returned: {result}'
)
],
)
invocation_context.live_request_queue.send_content(updated_content)
except asyncio.CancelledError:
raise # Re-raise to properly propagate the cancellation
@@ -586,12 +590,15 @@ async def __call_tool_live(
invocation_context: InvocationContext,
) -> AsyncGenerator[Event, None]:
"""Calls the tool asynchronously (awaiting the coroutine)."""
async for item in tool._call_live(
args=args,
tool_context=tool_context,
invocation_context=invocation_context,
):
yield item
async with Aclosing(
tool._call_live(
args=args,
tool_context=tool_context,
invocation_context=invocation_context,
)
) as agen:
async for item in agen:
yield item
async def __call_tool_async(
+81 -78
View File
@@ -21,6 +21,7 @@ from typing import Union
from google.genai import live
from google.genai import types
from ..utils.context_utils import Aclosing
from .base_llm_connection import BaseLlmConnection
from .llm_response import LlmResponse
@@ -142,90 +143,92 @@ class GeminiLlmConnection(BaseLlmConnection):
"""
text = ''
async for message in self._gemini_session.receive():
logger.debug('Got LLM Live message: %s', message)
if message.server_content:
content = message.server_content.model_turn
if content and content.parts:
llm_response = LlmResponse(
content=content, interrupted=message.server_content.interrupted
)
if content.parts[0].text:
text += content.parts[0].text
llm_response.partial = True
# don't yield the merged text event when receiving audio data
elif text and not content.parts[0].inline_data:
async with Aclosing(self._gemini_session.receive()) as agen:
async for message in agen:
logger.debug('Got LLM Live message: %s', message)
if message.server_content:
content = message.server_content.model_turn
if content and content.parts:
llm_response = LlmResponse(
content=content, interrupted=message.server_content.interrupted
)
if content.parts[0].text:
text += content.parts[0].text
llm_response.partial = True
# don't yield the merged text event when receiving audio data
elif text and not content.parts[0].inline_data:
yield self.__build_full_text_response(text)
text = ''
yield llm_response
if (
message.server_content.input_transcription
and message.server_content.input_transcription.text
):
user_text = message.server_content.input_transcription.text
parts = [
types.Part.from_text(
text=user_text,
)
]
llm_response = LlmResponse(
content=types.Content(role='user', parts=parts)
)
yield llm_response
if (
message.server_content.output_transcription
and message.server_content.output_transcription.text
):
# TODO: Right now, we just support output_transcription without
# changing interface and data protocol. Later, we can consider to
# support output_transcription as a separate field in LlmResponse.
# Transcription is always considered as partial event
# We rely on other control signals to determine when to yield the
# full text response(turn_complete, interrupted, or tool_call).
text += message.server_content.output_transcription.text
parts = [
types.Part.from_text(
text=message.server_content.output_transcription.text
)
]
llm_response = LlmResponse(
content=types.Content(role='model', parts=parts), partial=True
)
yield llm_response
if message.server_content.turn_complete:
if text:
yield self.__build_full_text_response(text)
text = ''
yield LlmResponse(
turn_complete=True,
interrupted=message.server_content.interrupted,
)
break
# in case of empty content or parts, we still surface it
# in case it's an interrupted message, we merge the previous partial
# text. Otherwise we don't merge, because content can be none when model
# safety threshold is triggered
if message.server_content.interrupted and text:
yield self.__build_full_text_response(text)
text = ''
yield llm_response
if (
message.server_content.input_transcription
and message.server_content.input_transcription.text
):
user_text = message.server_content.input_transcription.text
parts = [
types.Part.from_text(
text=user_text,
)
]
llm_response = LlmResponse(
content=types.Content(role='user', parts=parts)
)
yield llm_response
if (
message.server_content.output_transcription
and message.server_content.output_transcription.text
):
# TODO: Right now, we just support output_transcription without
# changing interface and data protocol. Later, we can consider to
# support output_transcription as a separate field in LlmResponse.
# Transcription is always considered as partial event
# We rely on other control signals to determine when to yield the
# full text response(turn_complete, interrupted, or tool_call).
text += message.server_content.output_transcription.text
parts = [
types.Part.from_text(
text=message.server_content.output_transcription.text
)
]
llm_response = LlmResponse(
content=types.Content(role='model', parts=parts), partial=True
)
yield llm_response
if message.server_content.turn_complete:
yield LlmResponse(interrupted=message.server_content.interrupted)
if message.tool_call:
if text:
yield self.__build_full_text_response(text)
text = ''
yield LlmResponse(
turn_complete=True, interrupted=message.server_content.interrupted
parts = [
types.Part(function_call=function_call)
for function_call in message.tool_call.function_calls
]
yield LlmResponse(content=types.Content(role='model', parts=parts))
if message.session_resumption_update:
logger.info('Redeived session reassumption message: %s', message)
yield (
LlmResponse(
live_session_resumption_update=message.session_resumption_update
)
)
break
# in case of empty content or parts, we still surface it
# in case it's an interrupted message, we merge the previous partial
# text. Otherwise we don't merge, because content can be none when model
# safety threshold is triggered
if message.server_content.interrupted and text:
yield self.__build_full_text_response(text)
text = ''
yield LlmResponse(interrupted=message.server_content.interrupted)
if message.tool_call:
if text:
yield self.__build_full_text_response(text)
text = ''
parts = [
types.Part(function_call=function_call)
for function_call in message.tool_call.function_calls
]
yield LlmResponse(content=types.Content(role='model', parts=parts))
if message.session_resumption_update:
logger.info('Redeived session reassumption message: %s', message)
yield (
LlmResponse(
live_session_resumption_update=message.session_resumption_update
)
)
async def close(self):
"""Closes the llm server connection."""
+35 -33
View File
@@ -32,6 +32,7 @@ from google.genai.types import FinishReason
from typing_extensions import override
from .. import version
from ..utils.context_utils import Aclosing
from ..utils.variant_utils import GoogleLLMVariant
from .base_llm import BaseLlm
from .base_llm_connection import BaseLlmConnection
@@ -141,39 +142,40 @@ class Gemini(BaseLlm):
# contents are sent, we send an accumulated event which contains all the
# previous partial content. The only difference is bidi rely on
# complete_turn flag to detect end while sse depends on finish_reason.
async for response in responses:
logger.debug(_build_response_log(response))
llm_response = LlmResponse.create(response)
usage_metadata = llm_response.usage_metadata
if (
llm_response.content
and llm_response.content.parts
and llm_response.content.parts[0].text
):
part0 = llm_response.content.parts[0]
if part0.thought:
thought_text += part0.text
else:
text += part0.text
llm_response.partial = True
elif (thought_text or text) and (
not llm_response.content
or not llm_response.content.parts
# don't yield the merged text event when receiving audio data
or not llm_response.content.parts[0].inline_data
):
parts = []
if thought_text:
parts.append(types.Part(text=thought_text, thought=True))
if text:
parts.append(types.Part.from_text(text=text))
yield LlmResponse(
content=types.ModelContent(parts=parts),
usage_metadata=llm_response.usage_metadata,
)
thought_text = ''
text = ''
yield llm_response
async with Aclosing(responses) as agen:
async for response in agen:
logger.debug(_build_response_log(response))
llm_response = LlmResponse.create(response)
usage_metadata = llm_response.usage_metadata
if (
llm_response.content
and llm_response.content.parts
and llm_response.content.parts[0].text
):
part0 = llm_response.content.parts[0]
if part0.thought:
thought_text += part0.text
else:
text += part0.text
llm_response.partial = True
elif (thought_text or text) and (
not llm_response.content
or not llm_response.content.parts
# don't yield the merged text event when receiving audio data
or not llm_response.content.parts[0].inline_data
):
parts = []
if thought_text:
parts.append(types.Part(text=thought_text, thought=True))
if text:
parts.append(types.Part.from_text(text=text))
yield LlmResponse(
content=types.ModelContent(parts=parts),
usage_metadata=llm_response.usage_metadata,
)
thought_text = ''
text = ''
yield llm_response
# generate an aggregated content at the end regardless the
# response.candidates[0].finish_reason
+74 -57
View File
@@ -53,6 +53,7 @@ from .sessions.in_memory_session_service import InMemorySessionService
from .sessions.session import Session
from .telemetry import tracer
from .tools.base_toolset import BaseToolset
from .utils.context_utils import Aclosing
logger = logging.getLogger('google_adk.' + __name__)
@@ -146,13 +147,16 @@ class Runner:
async def _invoke_run_async():
try:
async for event in self.run_async(
user_id=user_id,
session_id=session_id,
new_message=new_message,
run_config=run_config,
):
event_queue.put(event)
async with Aclosing(
self.run_async(
user_id=user_id,
session_id=session_id,
new_message=new_message,
run_config=run_config,
)
) as agen:
async for event in agen:
event_queue.put(event)
finally:
event_queue.put(None)
@@ -195,47 +199,55 @@ class Runner:
Yields:
The events generated by the agent.
"""
with tracer.start_as_current_span('invocation'):
session = await self.session_service.get_session(
app_name=self.app_name, user_id=user_id, session_id=session_id
)
if not session:
raise ValueError(f'Session not found: {session_id}')
invocation_context = self._new_invocation_context(
session,
new_message=new_message,
run_config=run_config,
)
root_agent = self.agent
# Modify user message before execution.
modified_user_message = (
await invocation_context.plugin_manager.run_on_user_message_callback(
invocation_context=invocation_context, user_message=new_message
)
)
if modified_user_message is not None:
new_message = modified_user_message
if new_message:
await self._append_new_message_to_session(
session,
new_message,
invocation_context,
run_config.save_input_blobs_as_artifacts,
state_delta,
async def _run_with_trace(
new_message: types.Content,
) -> AsyncGenerator[Event, None]:
with tracer.start_as_current_span('invocation'):
session = await self.session_service.get_session(
app_name=self.app_name, user_id=user_id, session_id=session_id
)
if not session:
raise ValueError(f'Session not found: {session_id}')
invocation_context.agent = self._find_agent_to_run(session, root_agent)
invocation_context = self._new_invocation_context(
session,
new_message=new_message,
run_config=run_config,
)
root_agent = self.agent
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
async for event in ctx.agent.run_async(ctx):
yield event
# Modify user message before execution.
modified_user_message = await invocation_context.plugin_manager.run_on_user_message_callback(
invocation_context=invocation_context, user_message=new_message
)
if modified_user_message is not None:
new_message = modified_user_message
async for event in self._exec_with_plugin(
invocation_context, session, execute
):
if new_message:
await self._append_new_message_to_session(
session,
new_message,
invocation_context,
run_config.save_input_blobs_as_artifacts,
state_delta,
)
invocation_context.agent = self._find_agent_to_run(session, root_agent)
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
async with Aclosing(ctx.agent.run_async(ctx)) as agen:
async for event in agen:
yield event
async with Aclosing(
self._exec_with_plugin(invocation_context, session, execute)
) as agen:
async for event in agen:
yield event
async with Aclosing(_run_with_trace(new_message)) as agen:
async for event in agen:
yield event
async def _exec_with_plugin(
@@ -274,14 +286,17 @@ class Runner:
yield early_exit_event
else:
# Step 2: Otherwise continue with normal execution
async for event in execute_fn(invocation_context):
if not event.partial:
await self.session_service.append_event(session=session, event=event)
# Step 3: Run the on_event callbacks to optionally modify the event.
modified_event = await plugin_manager.run_on_event_callback(
invocation_context=invocation_context, event=event
)
yield (modified_event if modified_event else event)
async with Aclosing(execute_fn(invocation_context)) as agen:
async for event in agen:
if not event.partial:
await self.session_service.append_event(
session=session, event=event
)
# Step 3: Run the on_event callbacks to optionally modify the event.
modified_event = await plugin_manager.run_on_event_callback(
invocation_context=invocation_context, event=event
)
yield (modified_event if modified_event else event)
# Step 4: Run the after_run callbacks to optionally modify the context.
await plugin_manager.run_after_run_callback(
@@ -439,13 +454,15 @@ class Runner:
)
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
async for event in ctx.agent.run_live(ctx):
yield event
async with Aclosing(ctx.agent.run_live(ctx)) as agen:
async for event in agen:
yield event
async for event in self._exec_with_plugin(
invocation_context, session, execute
):
yield event
async with Aclosing(
self._exec_with_plugin(invocation_context, session, execute)
) as agen:
async for event in agen:
yield event
def _find_agent_to_run(
self, session: Session, root_agent: BaseAgent
+11 -7
View File
@@ -26,6 +26,7 @@ from typing_extensions import override
from . import _automatic_function_calling_util
from ..agents.common_configs import AgentRefConfig
from ..memory.in_memory_memory_service import InMemoryMemoryService
from ..utils.context_utils import Aclosing
from ._forwarding_artifact_service import ForwardingArtifactService
from .base_tool import BaseTool
from .tool_configs import BaseToolConfig
@@ -141,13 +142,16 @@ class AgentTool(BaseTool):
)
last_event = None
async for event in runner.run_async(
user_id=session.user_id, session_id=session.id, new_message=content
):
# Forward state delta to parent session.
if event.actions.state_delta:
tool_context.state.update(event.actions.state_delta)
last_event = event
async with Aclosing(
runner.run_async(
user_id=session.user_id, session_id=session.id, new_message=content
)
) as agen:
async for event in agen:
# Forward state delta to parent session.
if event.actions.state_delta:
tool_context.state.update(event.actions.state_delta)
last_event = event
if not last_event or not last_event.content or not last_event.content.parts:
return ''

Some files were not shown because too many files have changed in this diff Show More