From 120cbabeb23c16d9ce4be511e768885f19a8c2d2 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Mon, 23 Jun 2025 12:22:53 -0700 Subject: [PATCH] refactor: Rename long util function name in runner.py and move it to functions.py PiperOrigin-RevId: 774880990 --- src/google/adk/flows/llm_flows/functions.py | 32 ++++ src/google/adk/runners.py | 37 +--- .../flows/llm_flows/test_functions_simple.py | 136 ++++++++++++++ tests/unittests/test_runners.py | 171 ------------------ 4 files changed, 170 insertions(+), 206 deletions(-) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 2772550c..5c690f1f 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -519,3 +519,35 @@ def merge_parallel_function_response_events( # Use the base_event as the timestamp merged_event.timestamp = base_event.timestamp return merged_event + + +def find_matching_function_call( + events: list[Event], +) -> Optional[Event]: + """Finds the function call event that matches the function response id of the last event.""" + if not events: + return None + + last_event = events[-1] + if ( + last_event.content + and last_event.content.parts + and any(part.function_response for part in last_event.content.parts) + ): + + function_call_id = next( + part.function_response.id + for part in last_event.content.parts + if part.function_response + ) + for i in range(len(events) - 2, -1, -1): + event = events[i] + # looking for the system long running request euc function call + function_calls = event.get_function_calls() + if not function_calls: + continue + + for function_call in function_calls: + if function_call.id == function_call_id: + return event + return None diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 936bc520..017997bb 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -36,6 +36,7 @@ from .artifacts.in_memory_artifact_service import InMemoryArtifactService from .auth.credential_service.base_credential_service import BaseCredentialService from .code_executors.built_in_code_executor import BuiltInCodeExecutor from .events.event import Event +from .flows.llm_flows.functions import find_matching_function_call from .memory.base_memory_service import BaseMemoryService from .memory.in_memory_memory_service import InMemoryMemoryService from .platform.thread import create_thread @@ -354,9 +355,7 @@ class Runner: # the agent that returned the corressponding function call regardless the # type of the agent. e.g. a remote a2a agent may surface a credential # request as a special long running function tool call. - event = _find_function_call_event_if_last_event_is_function_response( - session - ) + event = find_matching_function_call(session.events) if event and event.author: return root_agent.find_agent(event.author) for event in filter(lambda e: e.author != 'user', reversed(session.events)): @@ -538,35 +537,3 @@ class InMemoryRunner(Runner): session_service=self._in_memory_session_service, memory_service=InMemoryMemoryService(), ) - - -def _find_function_call_event_if_last_event_is_function_response( - session: Session, -) -> Optional[Event]: - events = session.events - if not events: - return None - - last_event = events[-1] - if ( - last_event.content - and last_event.content.parts - and any(part.function_response for part in last_event.content.parts) - ): - - function_call_id = next( - part.function_response.id - for part in last_event.content.parts - if part.function_response - ) - for i in range(len(events) - 2, -1, -1): - event = events[i] - # looking for the system long running request euc function call - function_calls = event.get_function_calls() - if not function_calls: - continue - - for function_call in function_calls: - if function_call.id == function_call_id: - return event - return None diff --git a/tests/unittests/flows/llm_flows/test_functions_simple.py b/tests/unittests/flows/llm_flows/test_functions_simple.py index 2c5ef9bc..720af516 100644 --- a/tests/unittests/flows/llm_flows/test_functions_simple.py +++ b/tests/unittests/flows/llm_flows/test_functions_simple.py @@ -17,6 +17,9 @@ from typing import AsyncGenerator from typing import Callable from google.adk.agents import Agent +from google.adk.events.event import Event +from google.adk.flows.llm_flows.functions import find_matching_function_call +from google.adk.sessions.session import Session from google.adk.tools import ToolContext from google.adk.tools.function_tool import FunctionTool from google.genai import types @@ -256,3 +259,136 @@ def test_function_call_id(): assert part.function_response.id is None assert events[0].content.parts[0].function_call.id.startswith('adk-') assert events[1].content.parts[0].function_response.id.startswith('adk-') + + +def test_find_function_call_event_no_function_response_in_last_event(): + """Test when last event has no function response.""" + events = [ + Event( + invocation_id='inv1', + author='user', + content=types.Content(role='user', parts=[types.Part(text='Hello')]), + ) + ] + + result = find_matching_function_call(events) + assert result is None + + +def test_find_function_call_event_empty_session_events(): + """Test when session has no events.""" + events = [] + + result = find_matching_function_call(events) + assert result is None + + +def test_find_function_call_event_function_response_but_no_matching_call(): + """Test when last event has function response but no matching call found.""" + # Create a function response + function_response = types.FunctionResponse( + id='func_123', name='test_func', response={} + ) + + events = [ + Event( + invocation_id='inv1', + author='agent1', + content=types.Content( + role='model', + parts=[types.Part(text='Some other response')], + ), + ), + Event( + invocation_id='inv2', + author='user', + content=types.Content( + role='user', + parts=[types.Part(function_response=function_response)], + ), + ), + ] + + result = find_matching_function_call(events) + assert result is None + + +def test_find_function_call_event_function_response_with_matching_call(): + """Test when last event has function response with matching function call.""" + # Create a function call + function_call = types.FunctionCall(id='func_123', name='test_func', args={}) + + # Create a function response with matching ID + function_response = types.FunctionResponse( + id='func_123', name='test_func', response={} + ) + + call_event = Event( + invocation_id='inv1', + author='agent1', + content=types.Content( + role='model', parts=[types.Part(function_call=function_call)] + ), + ) + + response_event = Event( + invocation_id='inv2', + author='user', + content=types.Content( + role='user', parts=[types.Part(function_response=function_response)] + ), + ) + + events = [call_event, response_event] + + result = find_matching_function_call(events) + assert result == call_event + + +def test_find_function_call_event_multiple_function_responses(): + """Test when last event has multiple function responses.""" + # Create function calls + function_call1 = types.FunctionCall(id='func_123', name='test_func1', args={}) + function_call2 = types.FunctionCall(id='func_456', name='test_func2', args={}) + + # Create function responses + function_response1 = types.FunctionResponse( + id='func_123', name='test_func1', response={} + ) + function_response2 = types.FunctionResponse( + id='func_456', name='test_func2', response={} + ) + + call_event1 = Event( + invocation_id='inv1', + author='agent1', + content=types.Content( + role='model', parts=[types.Part(function_call=function_call1)] + ), + ) + + call_event2 = Event( + invocation_id='inv2', + author='agent2', + content=types.Content( + role='model', parts=[types.Part(function_call=function_call2)] + ), + ) + + response_event = Event( + invocation_id='inv3', + author='user', + content=types.Content( + role='user', + parts=[ + types.Part(function_response=function_response1), + types.Part(function_response=function_response2), + ], + ), + ) + + events = [call_event1, call_event2, response_event] + + # Should return the first matching function call event found + result = find_matching_function_call(events) + assert result == call_event1 # First match (func_123) diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index 56d7667a..8d5bd241 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -18,7 +18,6 @@ from google.adk.agents.base_agent import BaseAgent from google.adk.agents.llm_agent import LlmAgent from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService from google.adk.events.event import Event -from google.adk.runners import _find_function_call_event_if_last_event_is_function_response from google.adk.runners import Runner from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.adk.sessions.session import Session @@ -73,176 +72,6 @@ class MockLlmAgent(LlmAgent): ) -class TestFindFunctionCallEventIfLastEventIsFunctionResponse: - """Tests for _find_function_call_event_if_last_event_is_function_response function.""" - - def test_no_function_response_in_last_event(self): - """Test when last event has no function response.""" - session = Session( - id="test_session", - user_id="test_user", - app_name="test_app", - events=[ - Event( - invocation_id="inv1", - author="user", - content=types.Content( - role="user", parts=[types.Part(text="Hello")] - ), - ) - ], - ) - - result = _find_function_call_event_if_last_event_is_function_response( - session - ) - assert result is None - - def test_empty_session_events(self): - """Test when session has no events.""" - session = Session( - id="test_session", user_id="test_user", app_name="test_app", events=[] - ) - - result = _find_function_call_event_if_last_event_is_function_response( - session - ) - assert result is None - - def test_last_event_has_function_response_but_no_matching_call(self): - """Test when last event has function response but no matching call found.""" - # Create a function response - function_response = types.FunctionResponse( - id="func_123", name="test_func", response={} - ) - - session = Session( - id="test_session", - user_id="test_user", - app_name="test_app", - events=[ - Event( - invocation_id="inv1", - author="agent1", - content=types.Content( - role="model", - parts=[types.Part(text="Some other response")], - ), - ), - Event( - invocation_id="inv2", - author="user", - content=types.Content( - role="user", - parts=[types.Part(function_response=function_response)], - ), - ), - ], - ) - - result = _find_function_call_event_if_last_event_is_function_response( - session - ) - assert result is None - - def test_last_event_has_function_response_with_matching_call(self): - """Test when last event has function response with matching function call.""" - # Create a function call - function_call = types.FunctionCall(id="func_123", name="test_func", args={}) - - # Create a function response with matching ID - function_response = types.FunctionResponse( - id="func_123", name="test_func", response={} - ) - - call_event = Event( - invocation_id="inv1", - author="agent1", - content=types.Content( - role="model", parts=[types.Part(function_call=function_call)] - ), - ) - - response_event = Event( - invocation_id="inv2", - author="user", - content=types.Content( - role="user", parts=[types.Part(function_response=function_response)] - ), - ) - - session = Session( - id="test_session", - user_id="test_user", - app_name="test_app", - events=[call_event, response_event], - ) - - result = _find_function_call_event_if_last_event_is_function_response( - session - ) - assert result == call_event - - def test_last_event_has_multiple_function_responses(self): - """Test when last event has multiple function responses.""" - # Create function calls - function_call1 = types.FunctionCall( - id="func_123", name="test_func1", args={} - ) - function_call2 = types.FunctionCall( - id="func_456", name="test_func2", args={} - ) - - # Create function responses - function_response1 = types.FunctionResponse( - id="func_123", name="test_func1", response={} - ) - function_response2 = types.FunctionResponse( - id="func_456", name="test_func2", response={} - ) - - call_event1 = Event( - invocation_id="inv1", - author="agent1", - content=types.Content( - role="model", parts=[types.Part(function_call=function_call1)] - ), - ) - - call_event2 = Event( - invocation_id="inv2", - author="agent2", - content=types.Content( - role="model", parts=[types.Part(function_call=function_call2)] - ), - ) - - response_event = Event( - invocation_id="inv3", - author="user", - content=types.Content( - role="user", - parts=[ - types.Part(function_response=function_response1), - types.Part(function_response=function_response2), - ], - ), - ) - - session = Session( - id="test_session", - user_id="test_user", - app_name="test_app", - events=[call_event1, call_event2, response_event], - ) - - # Should return the first matching function call event found - result = _find_function_call_event_if_last_event_is_function_response( - session - ) - assert result == call_event1 # First match (func_123) - - class TestRunnerFindAgentToRun: """Tests for Runner._find_agent_to_run method."""