You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Add A2A endpoints to fast api server when --a2a option is specified (WIP)
PiperOrigin-RevId: 776211580
This commit is contained in:
committed by
Copybara-Service
parent
22629a17bd
commit
e79651cd86
@@ -16,6 +16,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -32,7 +33,6 @@ from fastapi import FastAPI
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
@@ -53,7 +53,6 @@ from typing_extensions import override
|
||||
from ..agents import RunConfig
|
||||
from ..agents.live_request_queue import LiveRequest
|
||||
from ..agents.live_request_queue import LiveRequestQueue
|
||||
from ..agents.llm_agent import Agent
|
||||
from ..agents.run_config import StreamingMode
|
||||
from ..artifacts.gcs_artifact_service import GcsArtifactService
|
||||
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
@@ -65,8 +64,6 @@ from ..evaluation.eval_metrics import EvalMetric
|
||||
from ..evaluation.eval_metrics import EvalMetricResult
|
||||
from ..evaluation.eval_metrics import EvalMetricResultPerInvocation
|
||||
from ..evaluation.eval_result import EvalSetResult
|
||||
from ..evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager
|
||||
from ..evaluation.gcs_eval_sets_manager import GcsEvalSetsManager
|
||||
from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
|
||||
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
|
||||
from ..events.event import Event
|
||||
@@ -965,6 +962,86 @@ def get_fast_api_app(
|
||||
runner_dict[app_name] = runner
|
||||
return runner
|
||||
|
||||
if a2a:
|
||||
try:
|
||||
from a2a.server.apps import A2AStarletteApplication
|
||||
from a2a.server.request_handlers import DefaultRequestHandler
|
||||
from a2a.server.tasks import InMemoryTaskStore
|
||||
from a2a.types import AgentCard
|
||||
|
||||
from ..a2a.executor.a2a_agent_executor import A2aAgentExecutor
|
||||
|
||||
except ImportError as e:
|
||||
import sys
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
raise ImportError(
|
||||
"A2A requires Python 3.10 or above. Please upgrade your Python"
|
||||
" version."
|
||||
) from e
|
||||
else:
|
||||
raise e
|
||||
# locate all a2a agent apps in the agents directory
|
||||
base_path = Path.cwd() / agents_dir
|
||||
# the root agents directory should be an existing folder
|
||||
if base_path.exists() and base_path.is_dir():
|
||||
a2a_task_store = InMemoryTaskStore()
|
||||
|
||||
def create_a2a_runner_loader(captured_app_name: str):
|
||||
"""Factory function to create A2A runner with proper closure."""
|
||||
|
||||
async def _get_a2a_runner_async() -> Runner:
|
||||
return await _get_runner_async(captured_app_name)
|
||||
|
||||
return _get_a2a_runner_async
|
||||
|
||||
for p in base_path.iterdir():
|
||||
# only folders with an agent.json file representing agent card are valid
|
||||
# a2a agents
|
||||
if (
|
||||
p.is_file()
|
||||
or p.name.startswith((".", "__pycache__"))
|
||||
or not (p / "agent.json").is_file()
|
||||
):
|
||||
continue
|
||||
|
||||
app_name = p.name
|
||||
logger.info("Setting up A2A agent: %s", app_name)
|
||||
|
||||
try:
|
||||
a2a_rpc_path = f"http://{host}:{port}/a2a/{app_name}"
|
||||
|
||||
agent_executor = A2aAgentExecutor(
|
||||
runner=create_a2a_runner_loader(app_name),
|
||||
)
|
||||
|
||||
request_handler = DefaultRequestHandler(
|
||||
agent_executor=agent_executor, task_store=a2a_task_store
|
||||
)
|
||||
|
||||
with (p / "agent.json").open("r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
agent_card = AgentCard(**data)
|
||||
agent_card.url = a2a_rpc_path
|
||||
|
||||
a2a_app = A2AStarletteApplication(
|
||||
agent_card=agent_card,
|
||||
http_handler=request_handler,
|
||||
)
|
||||
|
||||
routes = a2a_app.routes(
|
||||
rpc_url=f"/a2a/{app_name}",
|
||||
agent_card_url=f"/a2a/{app_name}/.well-known/agent.json",
|
||||
)
|
||||
|
||||
for new_route in routes:
|
||||
app.router.routes.append(new_route)
|
||||
|
||||
logger.info("Successfully configured A2A agent: %s", app_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to setup A2A agent %s: %s", app_name, e)
|
||||
# Continue with other agents even if one fails
|
||||
if web:
|
||||
import mimetypes
|
||||
|
||||
|
||||
@@ -13,10 +13,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -465,6 +469,9 @@ def test_app(
|
||||
artifact_service_uri="",
|
||||
memory_service_uri="",
|
||||
allow_origins=["*"],
|
||||
a2a=False, # Disable A2A for most tests
|
||||
host="127.0.0.1",
|
||||
port=8000,
|
||||
)
|
||||
|
||||
# Create a TestClient that doesn't start a real server
|
||||
@@ -520,6 +527,134 @@ async def create_test_eval_set(
|
||||
return test_session_info
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
|
||||
)
|
||||
def temp_agents_dir_with_a2a():
|
||||
"""Create a temporary agents directory with A2A agent configurations for testing."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Create test agent directory
|
||||
agent_dir = Path(temp_dir) / "test_a2a_agent"
|
||||
agent_dir.mkdir()
|
||||
|
||||
# Create agent.json file
|
||||
agent_card = {
|
||||
"name": "test_a2a_agent",
|
||||
"description": "Test A2A agent",
|
||||
"version": "1.0.0",
|
||||
"author": "test",
|
||||
"capabilities": ["text"],
|
||||
}
|
||||
|
||||
with open(agent_dir / "agent.json", "w") as f:
|
||||
json.dump(agent_card, f)
|
||||
|
||||
# Create a simple agent.py file
|
||||
agent_py_content = """
|
||||
from google.adk.agents.base_agent import BaseAgent
|
||||
|
||||
class TestA2AAgent(BaseAgent):
|
||||
def __init__(self):
|
||||
super().__init__(name="test_a2a_agent")
|
||||
"""
|
||||
|
||||
with open(agent_dir / "agent.py", "w") as f:
|
||||
f.write(agent_py_content)
|
||||
|
||||
yield temp_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
|
||||
)
|
||||
def test_app_with_a2a(
|
||||
mock_session_service,
|
||||
mock_artifact_service,
|
||||
mock_memory_service,
|
||||
mock_agent_loader,
|
||||
mock_eval_sets_manager,
|
||||
mock_eval_set_results_manager,
|
||||
temp_agents_dir_with_a2a,
|
||||
):
|
||||
"""Create a TestClient for the FastAPI app with A2A enabled."""
|
||||
|
||||
# Mock A2A related classes
|
||||
with (
|
||||
patch("signal.signal", return_value=None),
|
||||
patch(
|
||||
"google.adk.cli.fast_api.InMemorySessionService",
|
||||
return_value=mock_session_service,
|
||||
),
|
||||
patch(
|
||||
"google.adk.cli.fast_api.InMemoryArtifactService",
|
||||
return_value=mock_artifact_service,
|
||||
),
|
||||
patch(
|
||||
"google.adk.cli.fast_api.InMemoryMemoryService",
|
||||
return_value=mock_memory_service,
|
||||
),
|
||||
patch(
|
||||
"google.adk.cli.fast_api.AgentLoader",
|
||||
return_value=mock_agent_loader,
|
||||
),
|
||||
patch(
|
||||
"google.adk.cli.fast_api.LocalEvalSetsManager",
|
||||
return_value=mock_eval_sets_manager,
|
||||
),
|
||||
patch(
|
||||
"google.adk.cli.fast_api.LocalEvalSetResultsManager",
|
||||
return_value=mock_eval_set_results_manager,
|
||||
),
|
||||
patch(
|
||||
"google.adk.cli.cli_eval.run_evals",
|
||||
new=mock_run_evals_for_fast_api,
|
||||
),
|
||||
patch("a2a.server.tasks.InMemoryTaskStore") as mock_task_store,
|
||||
patch(
|
||||
"google.adk.a2a.executor.a2a_agent_executor.A2aAgentExecutor"
|
||||
) as mock_executor,
|
||||
patch(
|
||||
"a2a.server.request_handlers.DefaultRequestHandler"
|
||||
) as mock_handler,
|
||||
patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app,
|
||||
):
|
||||
# Configure mocks
|
||||
mock_task_store.return_value = MagicMock()
|
||||
mock_executor.return_value = MagicMock()
|
||||
mock_handler.return_value = MagicMock()
|
||||
|
||||
# Mock A2AStarletteApplication
|
||||
mock_app_instance = MagicMock()
|
||||
mock_app_instance.routes.return_value = (
|
||||
[]
|
||||
) # Return empty routes for testing
|
||||
mock_a2a_app.return_value = mock_app_instance
|
||||
|
||||
# Change to temp directory
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(temp_agents_dir_with_a2a)
|
||||
|
||||
try:
|
||||
app = get_fast_api_app(
|
||||
agents_dir=".",
|
||||
web=True,
|
||||
session_service_uri="",
|
||||
artifact_service_uri="",
|
||||
memory_service_uri="",
|
||||
allow_origins=["*"],
|
||||
a2a=True,
|
||||
host="127.0.0.1",
|
||||
port=8000,
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
yield client
|
||||
finally:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
#################################################
|
||||
# Test Cases
|
||||
#################################################
|
||||
@@ -760,5 +895,28 @@ def test_debug_trace(test_app):
|
||||
logger.info("Debug trace test completed successfully")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
|
||||
)
|
||||
def test_a2a_agent_discovery(test_app_with_a2a):
|
||||
"""Test that A2A agents are properly discovered and configured."""
|
||||
# This test mainly verifies that the A2A setup doesn't break the app
|
||||
response = test_app_with_a2a.get("/list-apps")
|
||||
assert response.status_code == 200
|
||||
logger.info("A2A agent discovery test passed")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
|
||||
)
|
||||
def test_a2a_disabled_by_default(test_app):
|
||||
"""Test that A2A functionality is disabled by default."""
|
||||
# The regular test_app fixture has a2a=False
|
||||
# This test ensures no A2A routes are added
|
||||
response = test_app.get("/list-apps")
|
||||
assert response.status_code == 200
|
||||
logger.info("A2A disabled by default test passed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main(["-xvs", __file__])
|
||||
|
||||
Reference in New Issue
Block a user