You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
refactor: Rename long util function name in runner.py and move it to functions.py
PiperOrigin-RevId: 774880990
This commit is contained in:
committed by
Copybara-Service
parent
29cd183aa1
commit
120cbabeb2
@@ -519,3 +519,35 @@ def merge_parallel_function_response_events(
|
||||
# Use the base_event as the timestamp
|
||||
merged_event.timestamp = base_event.timestamp
|
||||
return merged_event
|
||||
|
||||
|
||||
def find_matching_function_call(
|
||||
events: list[Event],
|
||||
) -> Optional[Event]:
|
||||
"""Finds the function call event that matches the function response id of the last event."""
|
||||
if not events:
|
||||
return None
|
||||
|
||||
last_event = events[-1]
|
||||
if (
|
||||
last_event.content
|
||||
and last_event.content.parts
|
||||
and any(part.function_response for part in last_event.content.parts)
|
||||
):
|
||||
|
||||
function_call_id = next(
|
||||
part.function_response.id
|
||||
for part in last_event.content.parts
|
||||
if part.function_response
|
||||
)
|
||||
for i in range(len(events) - 2, -1, -1):
|
||||
event = events[i]
|
||||
# looking for the system long running request euc function call
|
||||
function_calls = event.get_function_calls()
|
||||
if not function_calls:
|
||||
continue
|
||||
|
||||
for function_call in function_calls:
|
||||
if function_call.id == function_call_id:
|
||||
return event
|
||||
return None
|
||||
|
||||
@@ -36,6 +36,7 @@ from .artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
from .auth.credential_service.base_credential_service import BaseCredentialService
|
||||
from .code_executors.built_in_code_executor import BuiltInCodeExecutor
|
||||
from .events.event import Event
|
||||
from .flows.llm_flows.functions import find_matching_function_call
|
||||
from .memory.base_memory_service import BaseMemoryService
|
||||
from .memory.in_memory_memory_service import InMemoryMemoryService
|
||||
from .platform.thread import create_thread
|
||||
@@ -354,9 +355,7 @@ class Runner:
|
||||
# the agent that returned the corressponding function call regardless the
|
||||
# type of the agent. e.g. a remote a2a agent may surface a credential
|
||||
# request as a special long running function tool call.
|
||||
event = _find_function_call_event_if_last_event_is_function_response(
|
||||
session
|
||||
)
|
||||
event = find_matching_function_call(session.events)
|
||||
if event and event.author:
|
||||
return root_agent.find_agent(event.author)
|
||||
for event in filter(lambda e: e.author != 'user', reversed(session.events)):
|
||||
@@ -538,35 +537,3 @@ class InMemoryRunner(Runner):
|
||||
session_service=self._in_memory_session_service,
|
||||
memory_service=InMemoryMemoryService(),
|
||||
)
|
||||
|
||||
|
||||
def _find_function_call_event_if_last_event_is_function_response(
|
||||
session: Session,
|
||||
) -> Optional[Event]:
|
||||
events = session.events
|
||||
if not events:
|
||||
return None
|
||||
|
||||
last_event = events[-1]
|
||||
if (
|
||||
last_event.content
|
||||
and last_event.content.parts
|
||||
and any(part.function_response for part in last_event.content.parts)
|
||||
):
|
||||
|
||||
function_call_id = next(
|
||||
part.function_response.id
|
||||
for part in last_event.content.parts
|
||||
if part.function_response
|
||||
)
|
||||
for i in range(len(events) - 2, -1, -1):
|
||||
event = events[i]
|
||||
# looking for the system long running request euc function call
|
||||
function_calls = event.get_function_calls()
|
||||
if not function_calls:
|
||||
continue
|
||||
|
||||
for function_call in function_calls:
|
||||
if function_call.id == function_call_id:
|
||||
return event
|
||||
return None
|
||||
|
||||
@@ -17,6 +17,9 @@ from typing import AsyncGenerator
|
||||
from typing import Callable
|
||||
|
||||
from google.adk.agents import Agent
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.flows.llm_flows.functions import find_matching_function_call
|
||||
from google.adk.sessions.session import Session
|
||||
from google.adk.tools import ToolContext
|
||||
from google.adk.tools.function_tool import FunctionTool
|
||||
from google.genai import types
|
||||
@@ -256,3 +259,136 @@ def test_function_call_id():
|
||||
assert part.function_response.id is None
|
||||
assert events[0].content.parts[0].function_call.id.startswith('adk-')
|
||||
assert events[1].content.parts[0].function_response.id.startswith('adk-')
|
||||
|
||||
|
||||
def test_find_function_call_event_no_function_response_in_last_event():
|
||||
"""Test when last event has no function response."""
|
||||
events = [
|
||||
Event(
|
||||
invocation_id='inv1',
|
||||
author='user',
|
||||
content=types.Content(role='user', parts=[types.Part(text='Hello')]),
|
||||
)
|
||||
]
|
||||
|
||||
result = find_matching_function_call(events)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_find_function_call_event_empty_session_events():
|
||||
"""Test when session has no events."""
|
||||
events = []
|
||||
|
||||
result = find_matching_function_call(events)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_find_function_call_event_function_response_but_no_matching_call():
|
||||
"""Test when last event has function response but no matching call found."""
|
||||
# Create a function response
|
||||
function_response = types.FunctionResponse(
|
||||
id='func_123', name='test_func', response={}
|
||||
)
|
||||
|
||||
events = [
|
||||
Event(
|
||||
invocation_id='inv1',
|
||||
author='agent1',
|
||||
content=types.Content(
|
||||
role='model',
|
||||
parts=[types.Part(text='Some other response')],
|
||||
),
|
||||
),
|
||||
Event(
|
||||
invocation_id='inv2',
|
||||
author='user',
|
||||
content=types.Content(
|
||||
role='user',
|
||||
parts=[types.Part(function_response=function_response)],
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
result = find_matching_function_call(events)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_find_function_call_event_function_response_with_matching_call():
|
||||
"""Test when last event has function response with matching function call."""
|
||||
# Create a function call
|
||||
function_call = types.FunctionCall(id='func_123', name='test_func', args={})
|
||||
|
||||
# Create a function response with matching ID
|
||||
function_response = types.FunctionResponse(
|
||||
id='func_123', name='test_func', response={}
|
||||
)
|
||||
|
||||
call_event = Event(
|
||||
invocation_id='inv1',
|
||||
author='agent1',
|
||||
content=types.Content(
|
||||
role='model', parts=[types.Part(function_call=function_call)]
|
||||
),
|
||||
)
|
||||
|
||||
response_event = Event(
|
||||
invocation_id='inv2',
|
||||
author='user',
|
||||
content=types.Content(
|
||||
role='user', parts=[types.Part(function_response=function_response)]
|
||||
),
|
||||
)
|
||||
|
||||
events = [call_event, response_event]
|
||||
|
||||
result = find_matching_function_call(events)
|
||||
assert result == call_event
|
||||
|
||||
|
||||
def test_find_function_call_event_multiple_function_responses():
|
||||
"""Test when last event has multiple function responses."""
|
||||
# Create function calls
|
||||
function_call1 = types.FunctionCall(id='func_123', name='test_func1', args={})
|
||||
function_call2 = types.FunctionCall(id='func_456', name='test_func2', args={})
|
||||
|
||||
# Create function responses
|
||||
function_response1 = types.FunctionResponse(
|
||||
id='func_123', name='test_func1', response={}
|
||||
)
|
||||
function_response2 = types.FunctionResponse(
|
||||
id='func_456', name='test_func2', response={}
|
||||
)
|
||||
|
||||
call_event1 = Event(
|
||||
invocation_id='inv1',
|
||||
author='agent1',
|
||||
content=types.Content(
|
||||
role='model', parts=[types.Part(function_call=function_call1)]
|
||||
),
|
||||
)
|
||||
|
||||
call_event2 = Event(
|
||||
invocation_id='inv2',
|
||||
author='agent2',
|
||||
content=types.Content(
|
||||
role='model', parts=[types.Part(function_call=function_call2)]
|
||||
),
|
||||
)
|
||||
|
||||
response_event = Event(
|
||||
invocation_id='inv3',
|
||||
author='user',
|
||||
content=types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part(function_response=function_response1),
|
||||
types.Part(function_response=function_response2),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
events = [call_event1, call_event2, response_event]
|
||||
|
||||
# Should return the first matching function call event found
|
||||
result = find_matching_function_call(events)
|
||||
assert result == call_event1 # First match (func_123)
|
||||
|
||||
@@ -18,7 +18,6 @@ from google.adk.agents.base_agent import BaseAgent
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.runners import _find_function_call_event_if_last_event_is_function_response
|
||||
from google.adk.runners import Runner
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.adk.sessions.session import Session
|
||||
@@ -73,176 +72,6 @@ class MockLlmAgent(LlmAgent):
|
||||
)
|
||||
|
||||
|
||||
class TestFindFunctionCallEventIfLastEventIsFunctionResponse:
|
||||
"""Tests for _find_function_call_event_if_last_event_is_function_response function."""
|
||||
|
||||
def test_no_function_response_in_last_event(self):
|
||||
"""Test when last event has no function response."""
|
||||
session = Session(
|
||||
id="test_session",
|
||||
user_id="test_user",
|
||||
app_name="test_app",
|
||||
events=[
|
||||
Event(
|
||||
invocation_id="inv1",
|
||||
author="user",
|
||||
content=types.Content(
|
||||
role="user", parts=[types.Part(text="Hello")]
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
result = _find_function_call_event_if_last_event_is_function_response(
|
||||
session
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_empty_session_events(self):
|
||||
"""Test when session has no events."""
|
||||
session = Session(
|
||||
id="test_session", user_id="test_user", app_name="test_app", events=[]
|
||||
)
|
||||
|
||||
result = _find_function_call_event_if_last_event_is_function_response(
|
||||
session
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_last_event_has_function_response_but_no_matching_call(self):
|
||||
"""Test when last event has function response but no matching call found."""
|
||||
# Create a function response
|
||||
function_response = types.FunctionResponse(
|
||||
id="func_123", name="test_func", response={}
|
||||
)
|
||||
|
||||
session = Session(
|
||||
id="test_session",
|
||||
user_id="test_user",
|
||||
app_name="test_app",
|
||||
events=[
|
||||
Event(
|
||||
invocation_id="inv1",
|
||||
author="agent1",
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[types.Part(text="Some other response")],
|
||||
),
|
||||
),
|
||||
Event(
|
||||
invocation_id="inv2",
|
||||
author="user",
|
||||
content=types.Content(
|
||||
role="user",
|
||||
parts=[types.Part(function_response=function_response)],
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
result = _find_function_call_event_if_last_event_is_function_response(
|
||||
session
|
||||
)
|
||||
assert result is None
|
||||
|
||||
def test_last_event_has_function_response_with_matching_call(self):
|
||||
"""Test when last event has function response with matching function call."""
|
||||
# Create a function call
|
||||
function_call = types.FunctionCall(id="func_123", name="test_func", args={})
|
||||
|
||||
# Create a function response with matching ID
|
||||
function_response = types.FunctionResponse(
|
||||
id="func_123", name="test_func", response={}
|
||||
)
|
||||
|
||||
call_event = Event(
|
||||
invocation_id="inv1",
|
||||
author="agent1",
|
||||
content=types.Content(
|
||||
role="model", parts=[types.Part(function_call=function_call)]
|
||||
),
|
||||
)
|
||||
|
||||
response_event = Event(
|
||||
invocation_id="inv2",
|
||||
author="user",
|
||||
content=types.Content(
|
||||
role="user", parts=[types.Part(function_response=function_response)]
|
||||
),
|
||||
)
|
||||
|
||||
session = Session(
|
||||
id="test_session",
|
||||
user_id="test_user",
|
||||
app_name="test_app",
|
||||
events=[call_event, response_event],
|
||||
)
|
||||
|
||||
result = _find_function_call_event_if_last_event_is_function_response(
|
||||
session
|
||||
)
|
||||
assert result == call_event
|
||||
|
||||
def test_last_event_has_multiple_function_responses(self):
|
||||
"""Test when last event has multiple function responses."""
|
||||
# Create function calls
|
||||
function_call1 = types.FunctionCall(
|
||||
id="func_123", name="test_func1", args={}
|
||||
)
|
||||
function_call2 = types.FunctionCall(
|
||||
id="func_456", name="test_func2", args={}
|
||||
)
|
||||
|
||||
# Create function responses
|
||||
function_response1 = types.FunctionResponse(
|
||||
id="func_123", name="test_func1", response={}
|
||||
)
|
||||
function_response2 = types.FunctionResponse(
|
||||
id="func_456", name="test_func2", response={}
|
||||
)
|
||||
|
||||
call_event1 = Event(
|
||||
invocation_id="inv1",
|
||||
author="agent1",
|
||||
content=types.Content(
|
||||
role="model", parts=[types.Part(function_call=function_call1)]
|
||||
),
|
||||
)
|
||||
|
||||
call_event2 = Event(
|
||||
invocation_id="inv2",
|
||||
author="agent2",
|
||||
content=types.Content(
|
||||
role="model", parts=[types.Part(function_call=function_call2)]
|
||||
),
|
||||
)
|
||||
|
||||
response_event = Event(
|
||||
invocation_id="inv3",
|
||||
author="user",
|
||||
content=types.Content(
|
||||
role="user",
|
||||
parts=[
|
||||
types.Part(function_response=function_response1),
|
||||
types.Part(function_response=function_response2),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
session = Session(
|
||||
id="test_session",
|
||||
user_id="test_user",
|
||||
app_name="test_app",
|
||||
events=[call_event1, call_event2, response_event],
|
||||
)
|
||||
|
||||
# Should return the first matching function call event found
|
||||
result = _find_function_call_event_if_last_event_is_function_response(
|
||||
session
|
||||
)
|
||||
assert result == call_event1 # First match (func_123)
|
||||
|
||||
|
||||
class TestRunnerFindAgentToRun:
|
||||
"""Tests for Runner._find_agent_to_run method."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user