feat: Add A2A endpoints to fast api server when --a2a option is specified (WIP)

PiperOrigin-RevId: 776211580
This commit is contained in:
Xiang (Sean) Zhou
2025-06-26 11:18:12 -07:00
committed by Copybara-Service
parent 22629a17bd
commit e79651cd86
2 changed files with 240 additions and 5 deletions
+81 -4
View File
@@ -16,6 +16,7 @@ from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
import json
import logging
import os
from pathlib import Path
@@ -32,7 +33,6 @@ from fastapi import FastAPI
from fastapi import HTTPException
from fastapi import Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.responses import RedirectResponse
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
@@ -53,7 +53,6 @@ from typing_extensions import override
from ..agents import RunConfig
from ..agents.live_request_queue import LiveRequest
from ..agents.live_request_queue import LiveRequestQueue
from ..agents.llm_agent import Agent
from ..agents.run_config import StreamingMode
from ..artifacts.gcs_artifact_service import GcsArtifactService
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
@@ -65,8 +64,6 @@ from ..evaluation.eval_metrics import EvalMetric
from ..evaluation.eval_metrics import EvalMetricResult
from ..evaluation.eval_metrics import EvalMetricResultPerInvocation
from ..evaluation.eval_result import EvalSetResult
from ..evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager
from ..evaluation.gcs_eval_sets_manager import GcsEvalSetsManager
from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
from ..events.event import Event
@@ -965,6 +962,86 @@ def get_fast_api_app(
runner_dict[app_name] = runner
return runner
if a2a:
try:
from a2a.server.apps import A2AStarletteApplication
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.server.tasks import InMemoryTaskStore
from a2a.types import AgentCard
from ..a2a.executor.a2a_agent_executor import A2aAgentExecutor
except ImportError as e:
import sys
if sys.version_info < (3, 10):
raise ImportError(
"A2A requires Python 3.10 or above. Please upgrade your Python"
" version."
) from e
else:
raise e
# locate all a2a agent apps in the agents directory
base_path = Path.cwd() / agents_dir
# the root agents directory should be an existing folder
if base_path.exists() and base_path.is_dir():
a2a_task_store = InMemoryTaskStore()
def create_a2a_runner_loader(captured_app_name: str):
"""Factory function to create A2A runner with proper closure."""
async def _get_a2a_runner_async() -> Runner:
return await _get_runner_async(captured_app_name)
return _get_a2a_runner_async
for p in base_path.iterdir():
# only folders with an agent.json file representing agent card are valid
# a2a agents
if (
p.is_file()
or p.name.startswith((".", "__pycache__"))
or not (p / "agent.json").is_file()
):
continue
app_name = p.name
logger.info("Setting up A2A agent: %s", app_name)
try:
a2a_rpc_path = f"http://{host}:{port}/a2a/{app_name}"
agent_executor = A2aAgentExecutor(
runner=create_a2a_runner_loader(app_name),
)
request_handler = DefaultRequestHandler(
agent_executor=agent_executor, task_store=a2a_task_store
)
with (p / "agent.json").open("r", encoding="utf-8") as f:
data = json.load(f)
agent_card = AgentCard(**data)
agent_card.url = a2a_rpc_path
a2a_app = A2AStarletteApplication(
agent_card=agent_card,
http_handler=request_handler,
)
routes = a2a_app.routes(
rpc_url=f"/a2a/{app_name}",
agent_card_url=f"/a2a/{app_name}/.well-known/agent.json",
)
for new_route in routes:
app.router.routes.append(new_route)
logger.info("Successfully configured A2A agent: %s", app_name)
except Exception as e:
logger.error("Failed to setup A2A agent %s: %s", app_name, e)
# Continue with other agents even if one fails
if web:
import mimetypes
+159 -1
View File
@@ -13,10 +13,14 @@
# limitations under the License.
import asyncio
import json
import logging
import os
from pathlib import Path
import sys
import tempfile
import time
from typing import Any
from typing import Optional
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -465,6 +469,9 @@ def test_app(
artifact_service_uri="",
memory_service_uri="",
allow_origins=["*"],
a2a=False, # Disable A2A for most tests
host="127.0.0.1",
port=8000,
)
# Create a TestClient that doesn't start a real server
@@ -520,6 +527,134 @@ async def create_test_eval_set(
return test_session_info
@pytest.fixture
@pytest.mark.skipif(
sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
)
def temp_agents_dir_with_a2a():
"""Create a temporary agents directory with A2A agent configurations for testing."""
with tempfile.TemporaryDirectory() as temp_dir:
# Create test agent directory
agent_dir = Path(temp_dir) / "test_a2a_agent"
agent_dir.mkdir()
# Create agent.json file
agent_card = {
"name": "test_a2a_agent",
"description": "Test A2A agent",
"version": "1.0.0",
"author": "test",
"capabilities": ["text"],
}
with open(agent_dir / "agent.json", "w") as f:
json.dump(agent_card, f)
# Create a simple agent.py file
agent_py_content = """
from google.adk.agents.base_agent import BaseAgent
class TestA2AAgent(BaseAgent):
def __init__(self):
super().__init__(name="test_a2a_agent")
"""
with open(agent_dir / "agent.py", "w") as f:
f.write(agent_py_content)
yield temp_dir
@pytest.fixture
@pytest.mark.skipif(
sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
)
def test_app_with_a2a(
mock_session_service,
mock_artifact_service,
mock_memory_service,
mock_agent_loader,
mock_eval_sets_manager,
mock_eval_set_results_manager,
temp_agents_dir_with_a2a,
):
"""Create a TestClient for the FastAPI app with A2A enabled."""
# Mock A2A related classes
with (
patch("signal.signal", return_value=None),
patch(
"google.adk.cli.fast_api.InMemorySessionService",
return_value=mock_session_service,
),
patch(
"google.adk.cli.fast_api.InMemoryArtifactService",
return_value=mock_artifact_service,
),
patch(
"google.adk.cli.fast_api.InMemoryMemoryService",
return_value=mock_memory_service,
),
patch(
"google.adk.cli.fast_api.AgentLoader",
return_value=mock_agent_loader,
),
patch(
"google.adk.cli.fast_api.LocalEvalSetsManager",
return_value=mock_eval_sets_manager,
),
patch(
"google.adk.cli.fast_api.LocalEvalSetResultsManager",
return_value=mock_eval_set_results_manager,
),
patch(
"google.adk.cli.cli_eval.run_evals",
new=mock_run_evals_for_fast_api,
),
patch("a2a.server.tasks.InMemoryTaskStore") as mock_task_store,
patch(
"google.adk.a2a.executor.a2a_agent_executor.A2aAgentExecutor"
) as mock_executor,
patch(
"a2a.server.request_handlers.DefaultRequestHandler"
) as mock_handler,
patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app,
):
# Configure mocks
mock_task_store.return_value = MagicMock()
mock_executor.return_value = MagicMock()
mock_handler.return_value = MagicMock()
# Mock A2AStarletteApplication
mock_app_instance = MagicMock()
mock_app_instance.routes.return_value = (
[]
) # Return empty routes for testing
mock_a2a_app.return_value = mock_app_instance
# Change to temp directory
original_cwd = os.getcwd()
os.chdir(temp_agents_dir_with_a2a)
try:
app = get_fast_api_app(
agents_dir=".",
web=True,
session_service_uri="",
artifact_service_uri="",
memory_service_uri="",
allow_origins=["*"],
a2a=True,
host="127.0.0.1",
port=8000,
)
client = TestClient(app)
yield client
finally:
os.chdir(original_cwd)
#################################################
# Test Cases
#################################################
@@ -760,5 +895,28 @@ def test_debug_trace(test_app):
logger.info("Debug trace test completed successfully")
@pytest.mark.skipif(
sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
)
def test_a2a_agent_discovery(test_app_with_a2a):
"""Test that A2A agents are properly discovered and configured."""
# This test mainly verifies that the A2A setup doesn't break the app
response = test_app_with_a2a.get("/list-apps")
assert response.status_code == 200
logger.info("A2A agent discovery test passed")
@pytest.mark.skipif(
sys.version_info < (3, 10), reason="A2A requires Python 3.10+"
)
def test_a2a_disabled_by_default(test_app):
"""Test that A2A functionality is disabled by default."""
# The regular test_app fixture has a2a=False
# This test ensures no A2A routes are added
response = test_app.get("/list-apps")
assert response.status_code == 200
logger.info("A2A disabled by default test passed")
if __name__ == "__main__":
pytest.main(["-xvs", __file__])