You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: aclose all async generators to fix OTel tracing context
See https://github.com/google/adk-python/issues/1670#issuecomment-3115891100
PiperOrigin-RevId: 794659547
This commit is contained in:
committed by
Copybara-Service
parent
c5af44cfc0
commit
a30c63c593
@@ -46,13 +46,19 @@ async def main():
|
||||
role='user', parts=[types.Part.from_text(text=new_message)]
|
||||
)
|
||||
print('** User says:', content.model_dump(exclude_none=True))
|
||||
async for event in runner.run_async(
|
||||
# TODO - migrate try...finally to contextlib.aclosing after Python 3.9 is
|
||||
# no longer supported.
|
||||
agen = runner.run_async(
|
||||
user_id=user_id_1,
|
||||
session_id=session.id,
|
||||
new_message=content,
|
||||
):
|
||||
if event.content.parts and event.content.parts[0].text:
|
||||
print(f'** {event.author}: {event.content.parts[0].text}')
|
||||
)
|
||||
try:
|
||||
async for event in agen:
|
||||
if event.content.parts and event.content.parts[0].text:
|
||||
print(f'** {event.author}: {event.content.parts[0].text}')
|
||||
finally:
|
||||
await agen.aclose()
|
||||
|
||||
async def run_prompt_bytes(session: Session, new_message: str):
|
||||
content = types.Content(
|
||||
@@ -64,14 +70,20 @@ async def main():
|
||||
],
|
||||
)
|
||||
print('** User says:', content.model_dump(exclude_none=True))
|
||||
async for event in runner.run_async(
|
||||
# TODO - migrate try...finally to contextlib.aclosing after Python 3.9 is
|
||||
# no longer supported.
|
||||
agen = runner.run_async(
|
||||
user_id=user_id_1,
|
||||
session_id=session.id,
|
||||
new_message=content,
|
||||
run_config=RunConfig(save_input_blobs_as_artifacts=True),
|
||||
):
|
||||
if event.content.parts and event.content.parts[0].text:
|
||||
print(f'** {event.author}: {event.content.parts[0].text}')
|
||||
)
|
||||
try:
|
||||
async for event in agen:
|
||||
if event.content.parts and event.content.parts[0].text:
|
||||
print(f'** {event.author}: {event.content.parts[0].text}')
|
||||
finally:
|
||||
await agen.aclose()
|
||||
|
||||
start_time = time.time()
|
||||
print('Start time:', start_time)
|
||||
|
||||
@@ -24,6 +24,8 @@ from typing import Callable
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from ...utils.context_utils import Aclosing
|
||||
|
||||
try:
|
||||
from a2a.server.agent_execution import AgentExecutor
|
||||
from a2a.server.agent_execution.context import RequestContext
|
||||
@@ -212,12 +214,13 @@ class A2aAgentExecutor(AgentExecutor):
|
||||
)
|
||||
|
||||
task_result_aggregator = TaskResultAggregator()
|
||||
async for adk_event in runner.run_async(**run_args):
|
||||
for a2a_event in convert_event_to_a2a_events(
|
||||
adk_event, invocation_context, context.task_id, context.context_id
|
||||
):
|
||||
task_result_aggregator.process_event(a2a_event)
|
||||
await event_queue.enqueue_event(a2a_event)
|
||||
async with Aclosing(runner.run_async(**run_args)) as agen:
|
||||
async for adk_event in agen:
|
||||
for a2a_event in convert_event_to_a2a_events(
|
||||
adk_event, invocation_context, context.task_id, context.context_id
|
||||
):
|
||||
task_result_aggregator.process_event(a2a_event)
|
||||
await event_queue.enqueue_event(a2a_event)
|
||||
|
||||
# publish the task result event - this is final
|
||||
if (
|
||||
|
||||
@@ -39,6 +39,7 @@ from typing_extensions import override
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from ..events.event import Event
|
||||
from ..utils.context_utils import Aclosing
|
||||
from ..utils.feature_decorator import experimental
|
||||
from .base_agent_config import BaseAgentConfig
|
||||
from .callback_context import CallbackContext
|
||||
@@ -212,21 +213,27 @@ class BaseAgent(BaseModel):
|
||||
Event: the events generated by the agent.
|
||||
"""
|
||||
|
||||
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
|
||||
ctx = self._create_invocation_context(parent_context)
|
||||
async def _run_with_trace() -> AsyncGenerator[Event, None]:
|
||||
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
|
||||
ctx = self._create_invocation_context(parent_context)
|
||||
|
||||
if event := await self.__handle_before_agent_callback(ctx):
|
||||
yield event
|
||||
if ctx.end_invocation:
|
||||
return
|
||||
if event := await self.__handle_before_agent_callback(ctx):
|
||||
yield event
|
||||
if ctx.end_invocation:
|
||||
return
|
||||
|
||||
async for event in self._run_async_impl(ctx):
|
||||
yield event
|
||||
async with Aclosing(self._run_async_impl(ctx)) as agen:
|
||||
async for event in agen:
|
||||
yield event
|
||||
|
||||
if ctx.end_invocation:
|
||||
return
|
||||
if ctx.end_invocation:
|
||||
return
|
||||
|
||||
if event := await self.__handle_after_agent_callback(ctx):
|
||||
if event := await self.__handle_after_agent_callback(ctx):
|
||||
yield event
|
||||
|
||||
async with Aclosing(_run_with_trace()) as agen:
|
||||
async for event in agen:
|
||||
yield event
|
||||
|
||||
@final
|
||||
@@ -243,18 +250,25 @@ class BaseAgent(BaseModel):
|
||||
Yields:
|
||||
Event: the events generated by the agent.
|
||||
"""
|
||||
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
|
||||
ctx = self._create_invocation_context(parent_context)
|
||||
|
||||
if event := await self.__handle_before_agent_callback(ctx):
|
||||
yield event
|
||||
if ctx.end_invocation:
|
||||
return
|
||||
async def _run_with_trace() -> AsyncGenerator[Event, None]:
|
||||
with tracer.start_as_current_span(f'agent_run [{self.name}]'):
|
||||
ctx = self._create_invocation_context(parent_context)
|
||||
|
||||
async for event in self._run_live_impl(ctx):
|
||||
yield event
|
||||
if event := await self.__handle_before_agent_callback(ctx):
|
||||
yield event
|
||||
if ctx.end_invocation:
|
||||
return
|
||||
|
||||
if event := await self.__handle_after_agent_callback(ctx):
|
||||
async with Aclosing(self._run_live_impl(ctx)) as agen:
|
||||
async for event in agen:
|
||||
yield event
|
||||
|
||||
if event := await self.__handle_after_agent_callback(ctx):
|
||||
yield event
|
||||
|
||||
async with Aclosing(_run_with_trace()) as agen:
|
||||
async for event in agen:
|
||||
yield event
|
||||
|
||||
async def _run_async_impl(
|
||||
|
||||
@@ -51,6 +51,7 @@ from ..tools.base_toolset import BaseToolset
|
||||
from ..tools.function_tool import FunctionTool
|
||||
from ..tools.tool_configs import ToolConfig
|
||||
from ..tools.tool_context import ToolContext
|
||||
from ..utils.context_utils import Aclosing
|
||||
from ..utils.feature_decorator import experimental
|
||||
from .base_agent import BaseAgent
|
||||
from .base_agent_config import BaseAgentConfig
|
||||
@@ -283,19 +284,21 @@ class LlmAgent(BaseAgent):
|
||||
async def _run_async_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
async for event in self._llm_flow.run_async(ctx):
|
||||
self.__maybe_save_output_to_state(event)
|
||||
yield event
|
||||
async with Aclosing(self._llm_flow.run_async(ctx)) as agen:
|
||||
async for event in agen:
|
||||
self.__maybe_save_output_to_state(event)
|
||||
yield event
|
||||
|
||||
@override
|
||||
async def _run_live_impl(
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
async for event in self._llm_flow.run_live(ctx):
|
||||
self.__maybe_save_output_to_state(event)
|
||||
yield event
|
||||
if ctx.end_invocation:
|
||||
return
|
||||
async with Aclosing(self._llm_flow.run_live(ctx)) as agen:
|
||||
async for event in agen:
|
||||
self.__maybe_save_output_to_state(event)
|
||||
yield event
|
||||
if ctx.end_invocation:
|
||||
return
|
||||
|
||||
@property
|
||||
def canonical_model(self) -> BaseLlm:
|
||||
|
||||
@@ -27,6 +27,7 @@ from typing_extensions import override
|
||||
|
||||
from ..agents.invocation_context import InvocationContext
|
||||
from ..events.event import Event
|
||||
from ..utils.context_utils import Aclosing
|
||||
from ..utils.feature_decorator import experimental
|
||||
from .base_agent import BaseAgent
|
||||
from .base_agent_config import BaseAgentConfig
|
||||
@@ -58,10 +59,11 @@ class LoopAgent(BaseAgent):
|
||||
while not self.max_iterations or times_looped < self.max_iterations:
|
||||
for sub_agent in self.sub_agents:
|
||||
should_exit = False
|
||||
async for event in sub_agent.run_async(ctx):
|
||||
yield event
|
||||
if event.actions.escalate:
|
||||
should_exit = True
|
||||
async with Aclosing(sub_agent.run_async(ctx)) as agen:
|
||||
async for event in agen:
|
||||
yield event
|
||||
if event.actions.escalate:
|
||||
should_exit = True
|
||||
|
||||
if should_exit:
|
||||
return
|
||||
|
||||
@@ -26,6 +26,7 @@ from typing import Type
|
||||
from typing_extensions import override
|
||||
|
||||
from ..events.event import Event
|
||||
from ..utils.context_utils import Aclosing
|
||||
from .base_agent import BaseAgent
|
||||
from .base_agent_config import BaseAgentConfig
|
||||
from .invocation_context import InvocationContext
|
||||
@@ -111,8 +112,10 @@ class ParallelAgent(BaseAgent):
|
||||
)
|
||||
for sub_agent in self.sub_agents
|
||||
]
|
||||
async for event in _merge_agent_run(agent_runs):
|
||||
yield event
|
||||
|
||||
async with Aclosing(_merge_agent_run(agent_runs)) as agen:
|
||||
async for event in agen:
|
||||
yield event
|
||||
|
||||
@override
|
||||
async def _run_live_impl(
|
||||
|
||||
@@ -22,6 +22,7 @@ from typing import Type
|
||||
from typing_extensions import override
|
||||
|
||||
from ..events.event import Event
|
||||
from ..utils.context_utils import Aclosing
|
||||
from .base_agent import BaseAgent
|
||||
from .base_agent import BaseAgentConfig
|
||||
from .invocation_context import InvocationContext
|
||||
@@ -40,8 +41,9 @@ class SequentialAgent(BaseAgent):
|
||||
self, ctx: InvocationContext
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
for sub_agent in self.sub_agents:
|
||||
async for event in sub_agent.run_async(ctx):
|
||||
yield event
|
||||
async with Aclosing(sub_agent.run_async(ctx)) as agen:
|
||||
async for event in agen:
|
||||
yield event
|
||||
|
||||
@override
|
||||
async def _run_live_impl(
|
||||
@@ -78,5 +80,6 @@ class SequentialAgent(BaseAgent):
|
||||
do not generate any text other than the function call."""
|
||||
|
||||
for sub_agent in self.sub_agents:
|
||||
async for event in sub_agent.run_live(ctx):
|
||||
yield event
|
||||
async with Aclosing(sub_agent.run_live(ctx)) as agen:
|
||||
async for event in agen:
|
||||
yield event
|
||||
|
||||
@@ -73,6 +73,7 @@ from ..memory.base_memory_service import BaseMemoryService
|
||||
from ..runners import Runner
|
||||
from ..sessions.base_session_service import BaseSessionService
|
||||
from ..sessions.session import Session
|
||||
from ..utils.context_utils import Aclosing
|
||||
from .cli_eval import EVAL_SESSION_ID_PREFIX
|
||||
from .cli_eval import EvalStatus
|
||||
from .utils import cleanup
|
||||
@@ -828,14 +829,16 @@ class AdkWebServer:
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
runner = await self.get_runner_async(req.app_name)
|
||||
events = [
|
||||
event
|
||||
async for event in runner.run_async(
|
||||
|
||||
events = []
|
||||
async with Aclosing(
|
||||
runner.run_async(
|
||||
user_id=req.user_id,
|
||||
session_id=req.session_id,
|
||||
new_message=req.new_message,
|
||||
)
|
||||
]
|
||||
) as agen:
|
||||
events = [event async for event in agen]
|
||||
logger.info("Generated %s events in agent run", len(events))
|
||||
logger.debug("Events generated: %s", events)
|
||||
return events
|
||||
@@ -856,19 +859,24 @@ class AdkWebServer:
|
||||
StreamingMode.SSE if req.streaming else StreamingMode.NONE
|
||||
)
|
||||
runner = await self.get_runner_async(req.app_name)
|
||||
async for event in runner.run_async(
|
||||
user_id=req.user_id,
|
||||
session_id=req.session_id,
|
||||
new_message=req.new_message,
|
||||
state_delta=req.state_delta,
|
||||
run_config=RunConfig(streaming_mode=stream_mode),
|
||||
):
|
||||
# Format as SSE data
|
||||
sse_event = event.model_dump_json(exclude_none=True, by_alias=True)
|
||||
logger.debug(
|
||||
"Generated event in agent run streaming: %s", sse_event
|
||||
)
|
||||
yield f"data: {sse_event}\n\n"
|
||||
async with Aclosing(
|
||||
runner.run_async(
|
||||
user_id=req.user_id,
|
||||
session_id=req.session_id,
|
||||
new_message=req.new_message,
|
||||
state_delta=req.state_delta,
|
||||
run_config=RunConfig(streaming_mode=stream_mode),
|
||||
)
|
||||
) as agen:
|
||||
async for event in agen:
|
||||
# Format as SSE data
|
||||
sse_event = event.model_dump_json(
|
||||
exclude_none=True, by_alias=True
|
||||
)
|
||||
logger.debug(
|
||||
"Generated event in agent run streaming: %s", sse_event
|
||||
)
|
||||
yield f"data: {sse_event}\n\n"
|
||||
except Exception as e:
|
||||
logger.exception("Error in event_generator: %s", e)
|
||||
# You might want to yield an error event here
|
||||
@@ -954,12 +962,15 @@ class AdkWebServer:
|
||||
|
||||
async def forward_events():
|
||||
runner = await self.get_runner_async(app_name)
|
||||
async for event in runner.run_live(
|
||||
session=session, live_request_queue=live_request_queue
|
||||
):
|
||||
await websocket.send_text(
|
||||
event.model_dump_json(exclude_none=True, by_alias=True)
|
||||
)
|
||||
async with Aclosing(
|
||||
runner.run_live(
|
||||
session=session, live_request_queue=live_request_queue
|
||||
)
|
||||
) as agen:
|
||||
async for event in agen:
|
||||
await websocket.send_text(
|
||||
event.model_dump_json(exclude_none=True, by_alias=True)
|
||||
)
|
||||
|
||||
async def process_messages():
|
||||
try:
|
||||
|
||||
+23
-14
@@ -30,6 +30,7 @@ from ..runners import Runner
|
||||
from ..sessions.base_session_service import BaseSessionService
|
||||
from ..sessions.in_memory_session_service import InMemorySessionService
|
||||
from ..sessions.session import Session
|
||||
from ..utils.context_utils import Aclosing
|
||||
from .utils import envs
|
||||
from .utils.agent_loader import AgentLoader
|
||||
|
||||
@@ -65,12 +66,15 @@ async def run_input_file(
|
||||
for query in input_file.queries:
|
||||
click.echo(f'[user]: {query}')
|
||||
content = types.Content(role='user', parts=[types.Part(text=query)])
|
||||
async for event in runner.run_async(
|
||||
user_id=session.user_id, session_id=session.id, new_message=content
|
||||
):
|
||||
if event.content and event.content.parts:
|
||||
if text := ''.join(part.text or '' for part in event.content.parts):
|
||||
click.echo(f'[{event.author}]: {text}')
|
||||
async with Aclosing(
|
||||
runner.run_async(
|
||||
user_id=session.user_id, session_id=session.id, new_message=content
|
||||
)
|
||||
) as agen:
|
||||
async for event in agen:
|
||||
if event.content and event.content.parts:
|
||||
if text := ''.join(part.text or '' for part in event.content.parts):
|
||||
click.echo(f'[{event.author}]: {text}')
|
||||
return session
|
||||
|
||||
|
||||
@@ -94,14 +98,19 @@ async def run_interactively(
|
||||
continue
|
||||
if query == 'exit':
|
||||
break
|
||||
async for event in runner.run_async(
|
||||
user_id=session.user_id,
|
||||
session_id=session.id,
|
||||
new_message=types.Content(role='user', parts=[types.Part(text=query)]),
|
||||
):
|
||||
if event.content and event.content.parts:
|
||||
if text := ''.join(part.text or '' for part in event.content.parts):
|
||||
click.echo(f'[{event.author}]: {text}')
|
||||
async with Aclosing(
|
||||
runner.run_async(
|
||||
user_id=session.user_id,
|
||||
session_id=session.id,
|
||||
new_message=types.Content(
|
||||
role='user', parts=[types.Part(text=query)]
|
||||
),
|
||||
)
|
||||
) as agen:
|
||||
async for event in agen:
|
||||
if event.content and event.content.parts:
|
||||
if text := ''.join(part.text or '' for part in event.content.parts):
|
||||
click.echo(f'[{event.author}]: {text}')
|
||||
await runner.close()
|
||||
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ from ..evaluation.eval_result import EvalCaseResult
|
||||
from ..evaluation.evaluator import EvalStatus
|
||||
from ..evaluation.evaluator import Evaluator
|
||||
from ..sessions.base_session_service import BaseSessionService
|
||||
from ..utils.context_utils import Aclosing
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
@@ -159,10 +160,11 @@ async def _collect_inferences(
|
||||
"""
|
||||
inference_results = []
|
||||
for inference_request in inference_requests:
|
||||
async for inference_result in eval_service.perform_inference(
|
||||
inference_request=inference_request
|
||||
):
|
||||
inference_results.append(inference_result)
|
||||
async with Aclosing(
|
||||
eval_service.perform_inference(inference_request=inference_request)
|
||||
) as agen:
|
||||
async for inference_result in agen:
|
||||
inference_results.append(inference_result)
|
||||
return inference_results
|
||||
|
||||
|
||||
@@ -180,10 +182,11 @@ async def _collect_eval_results(
|
||||
inference_results=inference_results,
|
||||
evaluate_config=EvaluateConfig(eval_metrics=eval_metrics),
|
||||
)
|
||||
async for eval_result in eval_service.evaluate(
|
||||
evaluate_request=evaluate_request
|
||||
):
|
||||
eval_results.append(eval_result)
|
||||
async with Aclosing(
|
||||
eval_service.evaluate(evaluate_request=evaluate_request)
|
||||
) as agen:
|
||||
async for eval_result in agen:
|
||||
eval_results.append(eval_result)
|
||||
|
||||
return eval_results
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ..agents.base_agent import BaseAgent
|
||||
from ..utils.context_utils import Aclosing
|
||||
from .constants import MISSING_EVAL_DEPENDENCIES_MESSAGE
|
||||
from .eval_case import IntermediateData
|
||||
from .eval_case import Invocation
|
||||
@@ -538,10 +539,11 @@ class AgentEvaluator:
|
||||
# Generate inferences
|
||||
inference_results = []
|
||||
for inference_request in inference_requests:
|
||||
async for inference_result in eval_service.perform_inference(
|
||||
inference_request=inference_request
|
||||
):
|
||||
inference_results.append(inference_result)
|
||||
async with Aclosing(
|
||||
eval_service.perform_inference(inference_request=inference_request)
|
||||
) as agen:
|
||||
async for inference_result in agen:
|
||||
inference_results.append(inference_result)
|
||||
|
||||
# Evaluate metrics
|
||||
# As we perform more than one run for an eval case, we collect eval results
|
||||
@@ -551,14 +553,15 @@ class AgentEvaluator:
|
||||
inference_results=inference_results,
|
||||
evaluate_config=EvaluateConfig(eval_metrics=eval_metrics),
|
||||
)
|
||||
async for eval_result in eval_service.evaluate(
|
||||
evaluate_request=evaluate_request
|
||||
):
|
||||
eval_id = eval_result.eval_id
|
||||
if eval_id not in eval_results_by_eval_id:
|
||||
eval_results_by_eval_id[eval_id] = []
|
||||
async with Aclosing(
|
||||
eval_service.evaluate(evaluate_request=evaluate_request)
|
||||
) as agen:
|
||||
async for eval_result in agen:
|
||||
eval_id = eval_result.eval_id
|
||||
if eval_id not in eval_results_by_eval_id:
|
||||
eval_results_by_eval_id[eval_id] = []
|
||||
|
||||
eval_results_by_eval_id[eval_id].append(eval_result)
|
||||
eval_results_by_eval_id[eval_id].append(eval_result)
|
||||
|
||||
return eval_results_by_eval_id
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ from ..runners import Runner
|
||||
from ..sessions.base_session_service import BaseSessionService
|
||||
from ..sessions.in_memory_session_service import InMemorySessionService
|
||||
from ..sessions.session import Session
|
||||
from ..utils.context_utils import Aclosing
|
||||
from .eval_case import EvalCase
|
||||
from .eval_case import IntermediateData
|
||||
from .eval_case import Invocation
|
||||
@@ -189,18 +190,25 @@ class EvaluationGenerator:
|
||||
tool_uses = []
|
||||
invocation_id = ""
|
||||
|
||||
async for event in runner.run_async(
|
||||
user_id=user_id, session_id=session_id, new_message=user_content
|
||||
):
|
||||
invocation_id = (
|
||||
event.invocation_id if not invocation_id else invocation_id
|
||||
)
|
||||
async with Aclosing(
|
||||
runner.run_async(
|
||||
user_id=user_id, session_id=session_id, new_message=user_content
|
||||
)
|
||||
) as agen:
|
||||
async for event in agen:
|
||||
invocation_id = (
|
||||
event.invocation_id if not invocation_id else invocation_id
|
||||
)
|
||||
|
||||
if event.is_final_response() and event.content and event.content.parts:
|
||||
final_response = event.content
|
||||
elif event.get_function_calls():
|
||||
for call in event.get_function_calls():
|
||||
tool_uses.append(call)
|
||||
if (
|
||||
event.is_final_response()
|
||||
and event.content
|
||||
and event.content.parts
|
||||
):
|
||||
final_response = event.content
|
||||
elif event.get_function_calls():
|
||||
for call in event.get_function_calls():
|
||||
tool_uses.append(call)
|
||||
|
||||
response_invocations.append(
|
||||
Invocation(
|
||||
|
||||
@@ -24,6 +24,7 @@ from ..models.base_llm import BaseLlm
|
||||
from ..models.llm_request import LlmRequest
|
||||
from ..models.llm_response import LlmResponse
|
||||
from ..models.registry import LLMRegistry
|
||||
from ..utils.context_utils import Aclosing
|
||||
from .eval_case import Invocation
|
||||
from .eval_metrics import EvalMetric
|
||||
from .evaluator import EvaluationResult
|
||||
@@ -109,21 +110,22 @@ class LlmAsJudge(Evaluator):
|
||||
num_samples = self._judge_model_options.num_samples
|
||||
invocation_result_samples = []
|
||||
for _ in range(num_samples):
|
||||
async for llm_response in self._judge_model.generate_content_async(
|
||||
llm_request
|
||||
):
|
||||
# Non-streaming call, so there is only one response content.
|
||||
score = self.convert_auto_rater_response_to_score(llm_response)
|
||||
invocation_result_samples.append(
|
||||
PerInvocationResult(
|
||||
actual_invocation=actual,
|
||||
expected_invocation=expected,
|
||||
score=score,
|
||||
eval_status=get_eval_status(
|
||||
score, self._eval_metric.threshold
|
||||
),
|
||||
)
|
||||
)
|
||||
async with Aclosing(
|
||||
self._judge_model.generate_content_async(llm_request)
|
||||
) as agen:
|
||||
async for llm_response in agen:
|
||||
# Non-streaming call, so there is only one response content.
|
||||
score = self.convert_auto_rater_response_to_score(llm_response)
|
||||
invocation_result_samples.append(
|
||||
PerInvocationResult(
|
||||
actual_invocation=actual,
|
||||
expected_invocation=expected,
|
||||
score=score,
|
||||
eval_status=get_eval_status(
|
||||
score, self._eval_metric.threshold
|
||||
),
|
||||
)
|
||||
)
|
||||
if not invocation_result_samples:
|
||||
continue
|
||||
per_invocation_results.append(
|
||||
|
||||
@@ -39,6 +39,7 @@ from ...code_executors.code_executor_context import CodeExecutorContext
|
||||
from ...events.event import Event
|
||||
from ...events.event_actions import EventActions
|
||||
from ...models.llm_response import LlmResponse
|
||||
from ...utils.context_utils import Aclosing
|
||||
from ._base_llm_processor import BaseLlmRequestProcessor
|
||||
from ._base_llm_processor import BaseLlmResponseProcessor
|
||||
|
||||
@@ -122,8 +123,11 @@ class _CodeExecutionRequestProcessor(BaseLlmRequestProcessor):
|
||||
if not invocation_context.agent.code_executor:
|
||||
return
|
||||
|
||||
async for event in _run_pre_processor(invocation_context, llm_request):
|
||||
yield event
|
||||
async with Aclosing(
|
||||
_run_pre_processor(invocation_context, llm_request)
|
||||
) as agen:
|
||||
async for event in agen:
|
||||
yield event
|
||||
|
||||
# Convert the code execution parts to text parts.
|
||||
if not isinstance(invocation_context.agent.code_executor, BaseCodeExecutor):
|
||||
@@ -152,8 +156,11 @@ class _CodeExecutionResponseProcessor(BaseLlmResponseProcessor):
|
||||
if llm_response.partial:
|
||||
return
|
||||
|
||||
async for event in _run_post_processor(invocation_context, llm_response):
|
||||
yield event
|
||||
async with Aclosing(
|
||||
_run_post_processor(invocation_context, llm_response)
|
||||
) as agen:
|
||||
async for event in agen:
|
||||
yield event
|
||||
|
||||
|
||||
response_processor = _CodeExecutionResponseProcessor()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -40,6 +40,7 @@ from ...telemetry import trace_tool_call
|
||||
from ...telemetry import tracer
|
||||
from ...tools.base_tool import BaseTool
|
||||
from ...tools.tool_context import ToolContext
|
||||
from ...utils.context_utils import Aclosing
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...agents.llm_agent import LlmAgent
|
||||
@@ -510,21 +511,24 @@ async def _process_function_live_helper(
|
||||
# we require the function to be a async generator function
|
||||
async def run_tool_and_update_queue(tool, function_args, tool_context):
|
||||
try:
|
||||
async for result in __call_tool_live(
|
||||
tool=tool,
|
||||
args=function_args,
|
||||
tool_context=tool_context,
|
||||
invocation_context=invocation_context,
|
||||
):
|
||||
updated_content = types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part.from_text(
|
||||
text=f'Function {tool.name} returned: {result}'
|
||||
)
|
||||
],
|
||||
)
|
||||
invocation_context.live_request_queue.send_content(updated_content)
|
||||
async with Aclosing(
|
||||
__call_tool_live(
|
||||
tool=tool,
|
||||
args=function_args,
|
||||
tool_context=tool_context,
|
||||
invocation_context=invocation_context,
|
||||
)
|
||||
) as agen:
|
||||
async for result in agen:
|
||||
updated_content = types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part.from_text(
|
||||
text=f'Function {tool.name} returned: {result}'
|
||||
)
|
||||
],
|
||||
)
|
||||
invocation_context.live_request_queue.send_content(updated_content)
|
||||
except asyncio.CancelledError:
|
||||
raise # Re-raise to properly propagate the cancellation
|
||||
|
||||
@@ -586,12 +590,15 @@ async def __call_tool_live(
|
||||
invocation_context: InvocationContext,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
"""Calls the tool asynchronously (awaiting the coroutine)."""
|
||||
async for item in tool._call_live(
|
||||
args=args,
|
||||
tool_context=tool_context,
|
||||
invocation_context=invocation_context,
|
||||
):
|
||||
yield item
|
||||
async with Aclosing(
|
||||
tool._call_live(
|
||||
args=args,
|
||||
tool_context=tool_context,
|
||||
invocation_context=invocation_context,
|
||||
)
|
||||
) as agen:
|
||||
async for item in agen:
|
||||
yield item
|
||||
|
||||
|
||||
async def __call_tool_async(
|
||||
|
||||
@@ -21,6 +21,7 @@ from typing import Union
|
||||
from google.genai import live
|
||||
from google.genai import types
|
||||
|
||||
from ..utils.context_utils import Aclosing
|
||||
from .base_llm_connection import BaseLlmConnection
|
||||
from .llm_response import LlmResponse
|
||||
|
||||
@@ -142,90 +143,92 @@ class GeminiLlmConnection(BaseLlmConnection):
|
||||
"""
|
||||
|
||||
text = ''
|
||||
async for message in self._gemini_session.receive():
|
||||
logger.debug('Got LLM Live message: %s', message)
|
||||
if message.server_content:
|
||||
content = message.server_content.model_turn
|
||||
if content and content.parts:
|
||||
llm_response = LlmResponse(
|
||||
content=content, interrupted=message.server_content.interrupted
|
||||
)
|
||||
if content.parts[0].text:
|
||||
text += content.parts[0].text
|
||||
llm_response.partial = True
|
||||
# don't yield the merged text event when receiving audio data
|
||||
elif text and not content.parts[0].inline_data:
|
||||
async with Aclosing(self._gemini_session.receive()) as agen:
|
||||
async for message in agen:
|
||||
logger.debug('Got LLM Live message: %s', message)
|
||||
if message.server_content:
|
||||
content = message.server_content.model_turn
|
||||
if content and content.parts:
|
||||
llm_response = LlmResponse(
|
||||
content=content, interrupted=message.server_content.interrupted
|
||||
)
|
||||
if content.parts[0].text:
|
||||
text += content.parts[0].text
|
||||
llm_response.partial = True
|
||||
# don't yield the merged text event when receiving audio data
|
||||
elif text and not content.parts[0].inline_data:
|
||||
yield self.__build_full_text_response(text)
|
||||
text = ''
|
||||
yield llm_response
|
||||
if (
|
||||
message.server_content.input_transcription
|
||||
and message.server_content.input_transcription.text
|
||||
):
|
||||
user_text = message.server_content.input_transcription.text
|
||||
parts = [
|
||||
types.Part.from_text(
|
||||
text=user_text,
|
||||
)
|
||||
]
|
||||
llm_response = LlmResponse(
|
||||
content=types.Content(role='user', parts=parts)
|
||||
)
|
||||
yield llm_response
|
||||
if (
|
||||
message.server_content.output_transcription
|
||||
and message.server_content.output_transcription.text
|
||||
):
|
||||
# TODO: Right now, we just support output_transcription without
|
||||
# changing interface and data protocol. Later, we can consider to
|
||||
# support output_transcription as a separate field in LlmResponse.
|
||||
|
||||
# Transcription is always considered as partial event
|
||||
# We rely on other control signals to determine when to yield the
|
||||
# full text response(turn_complete, interrupted, or tool_call).
|
||||
text += message.server_content.output_transcription.text
|
||||
parts = [
|
||||
types.Part.from_text(
|
||||
text=message.server_content.output_transcription.text
|
||||
)
|
||||
]
|
||||
llm_response = LlmResponse(
|
||||
content=types.Content(role='model', parts=parts), partial=True
|
||||
)
|
||||
yield llm_response
|
||||
|
||||
if message.server_content.turn_complete:
|
||||
if text:
|
||||
yield self.__build_full_text_response(text)
|
||||
text = ''
|
||||
yield LlmResponse(
|
||||
turn_complete=True,
|
||||
interrupted=message.server_content.interrupted,
|
||||
)
|
||||
break
|
||||
# in case of empty content or parts, we sill surface it
|
||||
# in case it's an interrupted message, we merge the previous partial
|
||||
# text. Other we don't merge. because content can be none when model
|
||||
# safety threshold is triggered
|
||||
if message.server_content.interrupted and text:
|
||||
yield self.__build_full_text_response(text)
|
||||
text = ''
|
||||
yield llm_response
|
||||
if (
|
||||
message.server_content.input_transcription
|
||||
and message.server_content.input_transcription.text
|
||||
):
|
||||
user_text = message.server_content.input_transcription.text
|
||||
parts = [
|
||||
types.Part.from_text(
|
||||
text=user_text,
|
||||
)
|
||||
]
|
||||
llm_response = LlmResponse(
|
||||
content=types.Content(role='user', parts=parts)
|
||||
)
|
||||
yield llm_response
|
||||
if (
|
||||
message.server_content.output_transcription
|
||||
and message.server_content.output_transcription.text
|
||||
):
|
||||
# TODO: Right now, we just support output_transcription without
|
||||
# changing interface and data protocol. Later, we can consider to
|
||||
# support output_transcription as a separate field in LlmResponse.
|
||||
|
||||
# Transcription is always considered as partial event
|
||||
# We rely on other control signals to determine when to yield the
|
||||
# full text response(turn_complete, interrupted, or tool_call).
|
||||
text += message.server_content.output_transcription.text
|
||||
parts = [
|
||||
types.Part.from_text(
|
||||
text=message.server_content.output_transcription.text
|
||||
)
|
||||
]
|
||||
llm_response = LlmResponse(
|
||||
content=types.Content(role='model', parts=parts), partial=True
|
||||
)
|
||||
yield llm_response
|
||||
|
||||
if message.server_content.turn_complete:
|
||||
yield LlmResponse(interrupted=message.server_content.interrupted)
|
||||
if message.tool_call:
|
||||
if text:
|
||||
yield self.__build_full_text_response(text)
|
||||
text = ''
|
||||
yield LlmResponse(
|
||||
turn_complete=True, interrupted=message.server_content.interrupted
|
||||
parts = [
|
||||
types.Part(function_call=function_call)
|
||||
for function_call in message.tool_call.function_calls
|
||||
]
|
||||
yield LlmResponse(content=types.Content(role='model', parts=parts))
|
||||
if message.session_resumption_update:
|
||||
logger.info('Redeived session reassumption message: %s', message)
|
||||
yield (
|
||||
LlmResponse(
|
||||
live_session_resumption_update=message.session_resumption_update
|
||||
)
|
||||
)
|
||||
break
|
||||
# in case of empty content or parts, we sill surface it
|
||||
# in case it's an interrupted message, we merge the previous partial
|
||||
# text. Other we don't merge. because content can be none when model
|
||||
# safety threshold is triggered
|
||||
if message.server_content.interrupted and text:
|
||||
yield self.__build_full_text_response(text)
|
||||
text = ''
|
||||
yield LlmResponse(interrupted=message.server_content.interrupted)
|
||||
if message.tool_call:
|
||||
if text:
|
||||
yield self.__build_full_text_response(text)
|
||||
text = ''
|
||||
parts = [
|
||||
types.Part(function_call=function_call)
|
||||
for function_call in message.tool_call.function_calls
|
||||
]
|
||||
yield LlmResponse(content=types.Content(role='model', parts=parts))
|
||||
if message.session_resumption_update:
|
||||
logger.info('Redeived session reassumption message: %s', message)
|
||||
yield (
|
||||
LlmResponse(
|
||||
live_session_resumption_update=message.session_resumption_update
|
||||
)
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
"""Closes the llm server connection."""
|
||||
|
||||
@@ -32,6 +32,7 @@ from google.genai.types import FinishReason
|
||||
from typing_extensions import override
|
||||
|
||||
from .. import version
|
||||
from ..utils.context_utils import Aclosing
|
||||
from ..utils.variant_utils import GoogleLLMVariant
|
||||
from .base_llm import BaseLlm
|
||||
from .base_llm_connection import BaseLlmConnection
|
||||
@@ -141,39 +142,40 @@ class Gemini(BaseLlm):
|
||||
# contents are sent, we send an accumulated event which contains all the
|
||||
# previous partial content. The only difference is that bidi relies on
|
||||
# complete_turn flag to detect end while sse depends on finish_reason.
|
||||
async for response in responses:
|
||||
logger.debug(_build_response_log(response))
|
||||
llm_response = LlmResponse.create(response)
|
||||
usage_metadata = llm_response.usage_metadata
|
||||
if (
|
||||
llm_response.content
|
||||
and llm_response.content.parts
|
||||
and llm_response.content.parts[0].text
|
||||
):
|
||||
part0 = llm_response.content.parts[0]
|
||||
if part0.thought:
|
||||
thought_text += part0.text
|
||||
else:
|
||||
text += part0.text
|
||||
llm_response.partial = True
|
||||
elif (thought_text or text) and (
|
||||
not llm_response.content
|
||||
or not llm_response.content.parts
|
||||
# don't yield the merged text event when receiving audio data
|
||||
or not llm_response.content.parts[0].inline_data
|
||||
):
|
||||
parts = []
|
||||
if thought_text:
|
||||
parts.append(types.Part(text=thought_text, thought=True))
|
||||
if text:
|
||||
parts.append(types.Part.from_text(text=text))
|
||||
yield LlmResponse(
|
||||
content=types.ModelContent(parts=parts),
|
||||
usage_metadata=llm_response.usage_metadata,
|
||||
)
|
||||
thought_text = ''
|
||||
text = ''
|
||||
yield llm_response
|
||||
async with Aclosing(responses) as agen:
|
||||
async for response in agen:
|
||||
logger.debug(_build_response_log(response))
|
||||
llm_response = LlmResponse.create(response)
|
||||
usage_metadata = llm_response.usage_metadata
|
||||
if (
|
||||
llm_response.content
|
||||
and llm_response.content.parts
|
||||
and llm_response.content.parts[0].text
|
||||
):
|
||||
part0 = llm_response.content.parts[0]
|
||||
if part0.thought:
|
||||
thought_text += part0.text
|
||||
else:
|
||||
text += part0.text
|
||||
llm_response.partial = True
|
||||
elif (thought_text or text) and (
|
||||
not llm_response.content
|
||||
or not llm_response.content.parts
|
||||
# don't yield the merged text event when receiving audio data
|
||||
or not llm_response.content.parts[0].inline_data
|
||||
):
|
||||
parts = []
|
||||
if thought_text:
|
||||
parts.append(types.Part(text=thought_text, thought=True))
|
||||
if text:
|
||||
parts.append(types.Part.from_text(text=text))
|
||||
yield LlmResponse(
|
||||
content=types.ModelContent(parts=parts),
|
||||
usage_metadata=llm_response.usage_metadata,
|
||||
)
|
||||
thought_text = ''
|
||||
text = ''
|
||||
yield llm_response
|
||||
|
||||
# generate an aggregated content at the end regardless of the
|
||||
# response.candidates[0].finish_reason
|
||||
|
||||
+74
-57
@@ -53,6 +53,7 @@ from .sessions.in_memory_session_service import InMemorySessionService
|
||||
from .sessions.session import Session
|
||||
from .telemetry import tracer
|
||||
from .tools.base_toolset import BaseToolset
|
||||
from .utils.context_utils import Aclosing
|
||||
|
||||
logger = logging.getLogger('google_adk.' + __name__)
|
||||
|
||||
@@ -146,13 +147,16 @@ class Runner:
|
||||
|
||||
async def _invoke_run_async():
|
||||
try:
|
||||
async for event in self.run_async(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
new_message=new_message,
|
||||
run_config=run_config,
|
||||
):
|
||||
event_queue.put(event)
|
||||
async with Aclosing(
|
||||
self.run_async(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
new_message=new_message,
|
||||
run_config=run_config,
|
||||
)
|
||||
) as agen:
|
||||
async for event in agen:
|
||||
event_queue.put(event)
|
||||
finally:
|
||||
event_queue.put(None)
|
||||
|
||||
@@ -195,47 +199,55 @@ class Runner:
|
||||
Yields:
|
||||
The events generated by the agent.
|
||||
"""
|
||||
with tracer.start_as_current_span('invocation'):
|
||||
session = await self.session_service.get_session(
|
||||
app_name=self.app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
if not session:
|
||||
raise ValueError(f'Session not found: {session_id}')
|
||||
|
||||
invocation_context = self._new_invocation_context(
|
||||
session,
|
||||
new_message=new_message,
|
||||
run_config=run_config,
|
||||
)
|
||||
root_agent = self.agent
|
||||
|
||||
# Modify user message before execution.
|
||||
modified_user_message = (
|
||||
await invocation_context.plugin_manager.run_on_user_message_callback(
|
||||
invocation_context=invocation_context, user_message=new_message
|
||||
)
|
||||
)
|
||||
if modified_user_message is not None:
|
||||
new_message = modified_user_message
|
||||
|
||||
if new_message:
|
||||
await self._append_new_message_to_session(
|
||||
session,
|
||||
new_message,
|
||||
invocation_context,
|
||||
run_config.save_input_blobs_as_artifacts,
|
||||
state_delta,
|
||||
async def _run_with_trace(
|
||||
new_message: types.Content,
|
||||
) -> AsyncGenerator[Event, None]:
|
||||
with tracer.start_as_current_span('invocation'):
|
||||
session = await self.session_service.get_session(
|
||||
app_name=self.app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
if not session:
|
||||
raise ValueError(f'Session not found: {session_id}')
|
||||
|
||||
invocation_context.agent = self._find_agent_to_run(session, root_agent)
|
||||
invocation_context = self._new_invocation_context(
|
||||
session,
|
||||
new_message=new_message,
|
||||
run_config=run_config,
|
||||
)
|
||||
root_agent = self.agent
|
||||
|
||||
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
|
||||
async for event in ctx.agent.run_async(ctx):
|
||||
yield event
|
||||
# Modify user message before execution.
|
||||
modified_user_message = await invocation_context.plugin_manager.run_on_user_message_callback(
|
||||
invocation_context=invocation_context, user_message=new_message
|
||||
)
|
||||
if modified_user_message is not None:
|
||||
new_message = modified_user_message
|
||||
|
||||
async for event in self._exec_with_plugin(
|
||||
invocation_context, session, execute
|
||||
):
|
||||
if new_message:
|
||||
await self._append_new_message_to_session(
|
||||
session,
|
||||
new_message,
|
||||
invocation_context,
|
||||
run_config.save_input_blobs_as_artifacts,
|
||||
state_delta,
|
||||
)
|
||||
|
||||
invocation_context.agent = self._find_agent_to_run(session, root_agent)
|
||||
|
||||
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
|
||||
async with Aclosing(ctx.agent.run_async(ctx)) as agen:
|
||||
async for event in agen:
|
||||
yield event
|
||||
|
||||
async with Aclosing(
|
||||
self._exec_with_plugin(invocation_context, session, execute)
|
||||
) as agen:
|
||||
async for event in agen:
|
||||
yield event
|
||||
|
||||
async with Aclosing(_run_with_trace(new_message)) as agen:
|
||||
async for event in agen:
|
||||
yield event
|
||||
|
||||
async def _exec_with_plugin(
|
||||
@@ -274,14 +286,17 @@ class Runner:
|
||||
yield early_exit_event
|
||||
else:
|
||||
# Step 2: Otherwise continue with normal execution
|
||||
async for event in execute_fn(invocation_context):
|
||||
if not event.partial:
|
||||
await self.session_service.append_event(session=session, event=event)
|
||||
# Step 3: Run the on_event callbacks to optionally modify the event.
|
||||
modified_event = await plugin_manager.run_on_event_callback(
|
||||
invocation_context=invocation_context, event=event
|
||||
)
|
||||
yield (modified_event if modified_event else event)
|
||||
async with Aclosing(execute_fn(invocation_context)) as agen:
|
||||
async for event in agen:
|
||||
if not event.partial:
|
||||
await self.session_service.append_event(
|
||||
session=session, event=event
|
||||
)
|
||||
# Step 3: Run the on_event callbacks to optionally modify the event.
|
||||
modified_event = await plugin_manager.run_on_event_callback(
|
||||
invocation_context=invocation_context, event=event
|
||||
)
|
||||
yield (modified_event if modified_event else event)
|
||||
|
||||
# Step 4: Run the after_run callbacks to optionally modify the context.
|
||||
await plugin_manager.run_after_run_callback(
|
||||
@@ -439,13 +454,15 @@ class Runner:
|
||||
)
|
||||
|
||||
async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]:
|
||||
async for event in ctx.agent.run_live(ctx):
|
||||
yield event
|
||||
async with Aclosing(ctx.agent.run_live(ctx)) as agen:
|
||||
async for event in agen:
|
||||
yield event
|
||||
|
||||
async for event in self._exec_with_plugin(
|
||||
invocation_context, session, execute
|
||||
):
|
||||
yield event
|
||||
async with Aclosing(
|
||||
self._exec_with_plugin(invocation_context, session, execute)
|
||||
) as agen:
|
||||
async for event in agen:
|
||||
yield event
|
||||
|
||||
def _find_agent_to_run(
|
||||
self, session: Session, root_agent: BaseAgent
|
||||
|
||||
@@ -26,6 +26,7 @@ from typing_extensions import override
|
||||
from . import _automatic_function_calling_util
|
||||
from ..agents.common_configs import AgentRefConfig
|
||||
from ..memory.in_memory_memory_service import InMemoryMemoryService
|
||||
from ..utils.context_utils import Aclosing
|
||||
from ._forwarding_artifact_service import ForwardingArtifactService
|
||||
from .base_tool import BaseTool
|
||||
from .tool_configs import BaseToolConfig
|
||||
@@ -141,13 +142,16 @@ class AgentTool(BaseTool):
|
||||
)
|
||||
|
||||
last_event = None
|
||||
async for event in runner.run_async(
|
||||
user_id=session.user_id, session_id=session.id, new_message=content
|
||||
):
|
||||
# Forward state delta to parent session.
|
||||
if event.actions.state_delta:
|
||||
tool_context.state.update(event.actions.state_delta)
|
||||
last_event = event
|
||||
async with Aclosing(
|
||||
runner.run_async(
|
||||
user_id=session.user_id, session_id=session.id, new_message=content
|
||||
)
|
||||
) as agen:
|
||||
async for event in agen:
|
||||
# Forward state delta to parent session.
|
||||
if event.actions.state_delta:
|
||||
tool_context.state.update(event.actions.state_delta)
|
||||
last_event = event
|
||||
|
||||
if not last_event or not last_event.content or not last_event.content.parts:
|
||||
return ''
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user