You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: handle App instances returned by agent_loader.load_agent
The `agent_loader.load_agent` method can now return an `App` object. This change unwraps the `App` to get its `root_agent` before passing it to the graph builder, makes sure a `BaseAgent` instance is always used PiperOrigin-RevId: 817209601
This commit is contained in:
committed by
Copybara-Service
parent
55aa6f669b
commit
847df1638c
@@ -1429,7 +1429,14 @@ class AdkWebServer:
|
|||||||
|
|
||||||
function_calls = event.get_function_calls()
|
function_calls = event.get_function_calls()
|
||||||
function_responses = event.get_function_responses()
|
function_responses = event.get_function_responses()
|
||||||
root_agent = self.agent_loader.load_agent(app_name)
|
agent_or_app = self.agent_loader.load_agent(app_name)
|
||||||
|
# The loader may return an App; unwrap to its root agent so the graph builder
|
||||||
|
# receives a BaseAgent instance.
|
||||||
|
root_agent = (
|
||||||
|
agent_or_app.root_agent
|
||||||
|
if isinstance(agent_or_app, App)
|
||||||
|
else agent_or_app
|
||||||
|
)
|
||||||
dot_graph = None
|
dot_graph = None
|
||||||
if function_calls:
|
if function_calls:
|
||||||
function_call_highlights = []
|
function_call_highlights = []
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from unittest.mock import patch
|
|||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from google.adk.agents.base_agent import BaseAgent
|
from google.adk.agents.base_agent import BaseAgent
|
||||||
from google.adk.agents.run_config import RunConfig
|
from google.adk.agents.run_config import RunConfig
|
||||||
|
from google.adk.apps.app import App
|
||||||
from google.adk.cli.fast_api import get_fast_api_app
|
from google.adk.cli.fast_api import get_fast_api_app
|
||||||
from google.adk.evaluation.eval_case import EvalCase
|
from google.adk.evaluation.eval_case import EvalCase
|
||||||
from google.adk.evaluation.eval_case import Invocation
|
from google.adk.evaluation.eval_case import Invocation
|
||||||
@@ -39,6 +40,7 @@ from google.adk.events.event import Event
|
|||||||
from google.adk.events.event_actions import EventActions
|
from google.adk.events.event_actions import EventActions
|
||||||
from google.adk.runners import Runner
|
from google.adk.runners import Runner
|
||||||
from google.adk.sessions.base_session_service import ListSessionsResponse
|
from google.adk.sessions.base_session_service import ListSessionsResponse
|
||||||
|
from google.adk.sessions.session import Session
|
||||||
from google.genai import types
|
from google.genai import types
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import pytest
|
import pytest
|
||||||
@@ -1007,6 +1009,56 @@ def test_debug_trace(test_app):
|
|||||||
logger.info("Debug trace test completed successfully")
|
logger.info("Debug trace test completed successfully")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_event_graph_returns_dot_src_for_app_agent():
|
||||||
|
"""Ensure graph endpoint unwraps App instances before building the graph."""
|
||||||
|
from google.adk.cli.adk_web_server import AdkWebServer
|
||||||
|
|
||||||
|
root_agent = DummyAgent(name="dummy_agent")
|
||||||
|
app_agent = App(name="test_app", root_agent=root_agent)
|
||||||
|
|
||||||
|
class Loader:
|
||||||
|
|
||||||
|
def load_agent(self, app_name):
|
||||||
|
return app_agent
|
||||||
|
|
||||||
|
def list_agents(self):
|
||||||
|
return [app_agent.name]
|
||||||
|
|
||||||
|
session_service = AsyncMock()
|
||||||
|
session = Session(
|
||||||
|
id="session_id",
|
||||||
|
app_name="test_app",
|
||||||
|
user_id="user",
|
||||||
|
state={},
|
||||||
|
events=[Event(author="dummy_agent")],
|
||||||
|
)
|
||||||
|
event_id = session.events[0].id
|
||||||
|
session_service.get_session.return_value = session
|
||||||
|
|
||||||
|
adk_web_server = AdkWebServer(
|
||||||
|
agent_loader=Loader(),
|
||||||
|
session_service=session_service,
|
||||||
|
memory_service=MagicMock(),
|
||||||
|
artifact_service=MagicMock(),
|
||||||
|
credential_service=MagicMock(),
|
||||||
|
eval_sets_manager=MagicMock(),
|
||||||
|
eval_set_results_manager=MagicMock(),
|
||||||
|
agents_dir=".",
|
||||||
|
)
|
||||||
|
|
||||||
|
fast_api_app = adk_web_server.get_fast_api_app(
|
||||||
|
setup_observer=lambda _observer, _server: None,
|
||||||
|
tear_down_observer=lambda _observer, _server: None,
|
||||||
|
)
|
||||||
|
|
||||||
|
client = TestClient(fast_api_app)
|
||||||
|
response = client.get(
|
||||||
|
f"/apps/test_app/users/user/sessions/session_id/events/{event_id}/graph"
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "dotSrc" in response.json()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
|
sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user