From 847df1638cbf1686aa43e8e094121d4e23e40245 Mon Sep 17 00:00:00 2001 From: George Weale Date: Thu, 9 Oct 2025 08:59:48 -0700 Subject: [PATCH] fix: handle `App` instances returned by `agent_loader.load_agent` The `agent_loader.load_agent` method can now return an `App` object. This change unwraps the `App` to get its `root_agent` before passing it to the graph builder, makes sure a `BaseAgent` instance is always used PiperOrigin-RevId: 817209601 --- src/google/adk/cli/adk_web_server.py | 9 ++++- tests/unittests/cli/test_fast_api.py | 52 ++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index 98c77f5d..76ba4ce9 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -1429,7 +1429,14 @@ class AdkWebServer: function_calls = event.get_function_calls() function_responses = event.get_function_responses() - root_agent = self.agent_loader.load_agent(app_name) + agent_or_app = self.agent_loader.load_agent(app_name) + # The loader may return an App; unwrap to its root agent so the graph builder + # receives a BaseAgent instance. + root_agent = ( + agent_or_app.root_agent + if isinstance(agent_or_app, App) + else agent_or_app + ) dot_graph = None if function_calls: function_call_highlights = [] diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index a18b409e..90bba32c 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -29,6 +29,7 @@ from unittest.mock import patch from fastapi.testclient import TestClient from google.adk.agents.base_agent import BaseAgent from google.adk.agents.run_config import RunConfig +from google.adk.apps.app import App from google.adk.cli.fast_api import get_fast_api_app from google.adk.evaluation.eval_case import EvalCase from google.adk.evaluation.eval_case import Invocation @@ -39,6 +40,7 @@ from google.adk.events.event import Event from google.adk.events.event_actions import EventActions from google.adk.runners import Runner from google.adk.sessions.base_session_service import ListSessionsResponse +from google.adk.sessions.session import Session from google.genai import types from pydantic import BaseModel import pytest @@ -1007,6 +1009,56 @@ def test_debug_trace(test_app): logger.info("Debug trace test completed successfully") +def test_get_event_graph_returns_dot_src_for_app_agent(): + """Ensure graph endpoint unwraps App instances before building the graph.""" + from google.adk.cli.adk_web_server import AdkWebServer + + root_agent = DummyAgent(name="dummy_agent") + app_agent = App(name="test_app", root_agent=root_agent) + + class Loader: + + def load_agent(self, app_name): + return app_agent + + def list_agents(self): + return [app_agent.name] + + session_service = AsyncMock() + session = Session( + id="session_id", + app_name="test_app", + user_id="user", + state={}, + events=[Event(author="dummy_agent")], + ) + event_id = session.events[0].id + session_service.get_session.return_value = session + + adk_web_server = AdkWebServer( + agent_loader=Loader(), + session_service=session_service, + memory_service=MagicMock(), + artifact_service=MagicMock(), + credential_service=MagicMock(), + eval_sets_manager=MagicMock(), + eval_set_results_manager=MagicMock(), + agents_dir=".", + ) + + fast_api_app = adk_web_server.get_fast_api_app( + setup_observer=lambda _observer, _server: None, + tear_down_observer=lambda _observer, _server: None, + ) + + client = TestClient(fast_api_app) + response = client.get( + f"/apps/test_app/users/user/sessions/session_id/events/{event_id}/graph" + ) + assert response.status_code == 200 + assert "dotSrc" in response.json() + + @pytest.mark.skipif( sys.version_info < (3, 10), reason="A2A requires Python 3.10+" )