You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Allow users to pass their own agent card to to_a2a method
PiperOrigin-RevId: 802763510
This commit is contained in:
committed by
Copybara-Service
parent
a30851ee16
commit
a1679dae3f
@@ -21,6 +21,7 @@ try:
|
||||
from a2a.server.apps import A2AStarletteApplication
|
||||
from a2a.server.request_handlers import DefaultRequestHandler
|
||||
from a2a.server.tasks import InMemoryTaskStore
|
||||
from a2a.types import AgentCard
|
||||
except ImportError as e:
|
||||
if sys.version_info < (3, 10):
|
||||
raise ImportError(
|
||||
@@ -29,6 +30,9 @@ except ImportError as e:
|
||||
else:
|
||||
raise e
|
||||
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from starlette.applications import Starlette
|
||||
|
||||
from ...agents.base_agent import BaseAgent
|
||||
@@ -43,6 +47,41 @@ from ..experimental import a2a_experimental
|
||||
from .agent_card_builder import AgentCardBuilder
|
||||
|
||||
|
||||
def _load_agent_card(
|
||||
agent_card: Optional[Union[AgentCard, str]],
|
||||
) -> Optional[AgentCard]:
|
||||
"""Load agent card from various sources.
|
||||
|
||||
Args:
|
||||
agent_card: AgentCard object, path to JSON file, or None
|
||||
|
||||
Returns:
|
||||
AgentCard object or None if no agent card provided
|
||||
|
||||
Raises:
|
||||
ValueError: If loading agent card from file fails
|
||||
"""
|
||||
if agent_card is None:
|
||||
return None
|
||||
|
||||
if isinstance(agent_card, str):
|
||||
# Load agent card from file path
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
path = Path(agent_card)
|
||||
with path.open("r", encoding="utf-8") as f:
|
||||
agent_card_data = json.load(f)
|
||||
return AgentCard(**agent_card_data)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Failed to load agent card from {agent_card}: {e}"
|
||||
) from e
|
||||
else:
|
||||
return agent_card
|
||||
|
||||
|
||||
@a2a_experimental
|
||||
def to_a2a(
|
||||
agent: BaseAgent,
|
||||
@@ -50,6 +89,7 @@ def to_a2a(
|
||||
host: str = "localhost",
|
||||
port: int = 8000,
|
||||
protocol: str = "http",
|
||||
agent_card: Optional[Union[AgentCard, str]] = None,
|
||||
) -> Starlette:
|
||||
"""Convert an ADK agent to a A2A Starlette application.
|
||||
|
||||
@@ -58,6 +98,9 @@ def to_a2a(
|
||||
host: The host for the A2A RPC URL (default: "localhost")
|
||||
port: The port for the A2A RPC URL (default: 8000)
|
||||
protocol: The protocol for the A2A RPC URL (default: "http")
|
||||
agent_card: Optional pre-built AgentCard object or path to agent card
|
||||
JSON. If not provided, will be built automatically from the
|
||||
agent.
|
||||
|
||||
Returns:
|
||||
A Starlette application that can be run with uvicorn
|
||||
@@ -66,6 +109,9 @@ def to_a2a(
|
||||
agent = MyAgent()
|
||||
app = to_a2a(agent, host="localhost", port=8000, protocol="http")
|
||||
# Then run with: uvicorn module:app --host localhost --port 8000
|
||||
|
||||
# Or with custom agent card:
|
||||
app = to_a2a(agent, agent_card=my_custom_agent_card)
|
||||
"""
|
||||
# Set up ADK logging to ensure logs are visible when using uvicorn directly
|
||||
setup_adk_logger(logging.INFO)
|
||||
@@ -93,8 +139,10 @@ def to_a2a(
|
||||
agent_executor=agent_executor, task_store=task_store
|
||||
)
|
||||
|
||||
# Build agent card
|
||||
# Use provided agent card or build one from the agent
|
||||
rpc_url = f"{protocol}://{host}:{port}/"
|
||||
provided_agent_card = _load_agent_card(agent_card)
|
||||
|
||||
card_builder = AgentCardBuilder(
|
||||
agent=agent,
|
||||
rpc_url=rpc_url,
|
||||
@@ -105,12 +153,15 @@ def to_a2a(
|
||||
|
||||
# Add startup handler to build the agent card and configure A2A routes
|
||||
async def setup_a2a():
|
||||
# Build the agent card asynchronously
|
||||
agent_card = await card_builder.build()
|
||||
# Use provided agent card or build one asynchronously
|
||||
if provided_agent_card is not None:
|
||||
final_agent_card = provided_agent_card
|
||||
else:
|
||||
final_agent_card = await card_builder.build()
|
||||
|
||||
# Create the A2A Starlette application
|
||||
a2a_app = A2AStarletteApplication(
|
||||
agent_card=agent_card,
|
||||
agent_card=final_agent_card,
|
||||
http_handler=request_handler,
|
||||
)
|
||||
|
||||
|
||||
@@ -689,3 +689,183 @@ class TestToA2A:
|
||||
mock_card_builder_class.assert_called_once_with(
|
||||
agent=self.mock_agent, rpc_url="http://192.168.1.1:8000/"
|
||||
)
|
||||
|
||||
@patch("google.adk.a2a.utils.agent_to_a2a.A2aAgentExecutor")
|
||||
@patch("google.adk.a2a.utils.agent_to_a2a.DefaultRequestHandler")
|
||||
@patch("google.adk.a2a.utils.agent_to_a2a.InMemoryTaskStore")
|
||||
@patch("google.adk.a2a.utils.agent_to_a2a.AgentCardBuilder")
|
||||
@patch("google.adk.a2a.utils.agent_to_a2a.Starlette")
|
||||
@patch("google.adk.a2a.utils.agent_to_a2a.A2AStarletteApplication")
|
||||
async def test_to_a2a_with_custom_agent_card_object(
|
||||
self,
|
||||
mock_a2a_app_class,
|
||||
mock_starlette_class,
|
||||
mock_card_builder_class,
|
||||
mock_task_store_class,
|
||||
mock_request_handler_class,
|
||||
mock_agent_executor_class,
|
||||
):
|
||||
"""Test to_a2a with custom AgentCard object."""
|
||||
# Arrange
|
||||
mock_app = Mock(spec=Starlette)
|
||||
mock_starlette_class.return_value = mock_app
|
||||
mock_task_store = Mock(spec=InMemoryTaskStore)
|
||||
mock_task_store_class.return_value = mock_task_store
|
||||
mock_agent_executor = Mock(spec=A2aAgentExecutor)
|
||||
mock_agent_executor_class.return_value = mock_agent_executor
|
||||
mock_request_handler = Mock(spec=DefaultRequestHandler)
|
||||
mock_request_handler_class.return_value = mock_request_handler
|
||||
mock_card_builder = Mock(spec=AgentCardBuilder)
|
||||
mock_card_builder_class.return_value = mock_card_builder
|
||||
mock_a2a_app = Mock(spec=A2AStarletteApplication)
|
||||
mock_a2a_app_class.return_value = mock_a2a_app
|
||||
|
||||
# Create a custom agent card
|
||||
custom_agent_card = Mock(spec=AgentCard)
|
||||
custom_agent_card.name = "custom_agent"
|
||||
|
||||
# Act
|
||||
result = to_a2a(self.mock_agent, agent_card=custom_agent_card)
|
||||
|
||||
# Assert
|
||||
assert result == mock_app
|
||||
# Get the setup_a2a function that was added as startup handler
|
||||
startup_handler = mock_app.add_event_handler.call_args[0][1]
|
||||
|
||||
# Call the setup_a2a function
|
||||
await startup_handler()
|
||||
|
||||
# Verify the card builder build method was NOT called since we provided a card
|
||||
mock_card_builder.build.assert_not_called()
|
||||
|
||||
# Verify A2A Starlette application was created with custom card
|
||||
mock_a2a_app_class.assert_called_once_with(
|
||||
agent_card=custom_agent_card,
|
||||
http_handler=mock_request_handler,
|
||||
)
|
||||
|
||||
# Verify routes were added to the main app
|
||||
mock_a2a_app.add_routes_to_app.assert_called_once_with(mock_app)
|
||||
|
||||
@patch("google.adk.a2a.utils.agent_to_a2a.A2aAgentExecutor")
|
||||
@patch("google.adk.a2a.utils.agent_to_a2a.DefaultRequestHandler")
|
||||
@patch("google.adk.a2a.utils.agent_to_a2a.InMemoryTaskStore")
|
||||
@patch("google.adk.a2a.utils.agent_to_a2a.AgentCardBuilder")
|
||||
@patch("google.adk.a2a.utils.agent_to_a2a.Starlette")
|
||||
@patch("google.adk.a2a.utils.agent_to_a2a.A2AStarletteApplication")
|
||||
@patch("json.load")
|
||||
@patch("pathlib.Path.open")
|
||||
@patch("pathlib.Path")
|
||||
async def test_to_a2a_with_agent_card_file_path(
|
||||
self,
|
||||
mock_path_class,
|
||||
mock_open,
|
||||
mock_json_load,
|
||||
mock_a2a_app_class,
|
||||
mock_starlette_class,
|
||||
mock_card_builder_class,
|
||||
mock_task_store_class,
|
||||
mock_request_handler_class,
|
||||
mock_agent_executor_class,
|
||||
):
|
||||
"""Test to_a2a with agent card file path."""
|
||||
# Arrange
|
||||
mock_app = Mock(spec=Starlette)
|
||||
mock_starlette_class.return_value = mock_app
|
||||
mock_task_store = Mock(spec=InMemoryTaskStore)
|
||||
mock_task_store_class.return_value = mock_task_store
|
||||
mock_agent_executor = Mock(spec=A2aAgentExecutor)
|
||||
mock_agent_executor_class.return_value = mock_agent_executor
|
||||
mock_request_handler = Mock(spec=DefaultRequestHandler)
|
||||
mock_request_handler_class.return_value = mock_request_handler
|
||||
mock_card_builder = Mock(spec=AgentCardBuilder)
|
||||
mock_card_builder_class.return_value = mock_card_builder
|
||||
mock_a2a_app = Mock(spec=A2AStarletteApplication)
|
||||
mock_a2a_app_class.return_value = mock_a2a_app
|
||||
|
||||
# Mock file operations
|
||||
mock_path = Mock()
|
||||
mock_path_class.return_value = mock_path
|
||||
mock_file_handle = Mock()
|
||||
# Create a proper context manager mock
|
||||
mock_context_manager = Mock()
|
||||
mock_context_manager.__enter__ = Mock(return_value=mock_file_handle)
|
||||
mock_context_manager.__exit__ = Mock(return_value=None)
|
||||
mock_path.open = Mock(return_value=mock_context_manager)
|
||||
|
||||
# Mock agent card data from file with all required fields
|
||||
agent_card_data = {
|
||||
"name": "file_agent",
|
||||
"url": "http://example.com",
|
||||
"description": "Test agent from file",
|
||||
"version": "1.0.0",
|
||||
"capabilities": {},
|
||||
"skills": [],
|
||||
"defaultInputModes": ["text/plain"],
|
||||
"defaultOutputModes": ["text/plain"],
|
||||
"supportsAuthenticatedExtendedCard": False,
|
||||
}
|
||||
mock_json_load.return_value = agent_card_data
|
||||
|
||||
# Act
|
||||
result = to_a2a(self.mock_agent, agent_card="/path/to/agent_card.json")
|
||||
|
||||
# Assert
|
||||
assert result == mock_app
|
||||
# Get the setup_a2a function that was added as startup handler
|
||||
startup_handler = mock_app.add_event_handler.call_args[0][1]
|
||||
|
||||
# Call the setup_a2a function
|
||||
await startup_handler()
|
||||
|
||||
# Verify file was opened and JSON was loaded
|
||||
mock_path_class.assert_called_once_with("/path/to/agent_card.json")
|
||||
mock_path.open.assert_called_once_with("r", encoding="utf-8")
|
||||
mock_json_load.assert_called_once_with(mock_file_handle)
|
||||
|
||||
# Verify the card builder build method was NOT called since we provided a card
|
||||
mock_card_builder.build.assert_not_called()
|
||||
|
||||
# Verify A2A Starlette application was created with loaded card
|
||||
mock_a2a_app_class.assert_called_once()
|
||||
args, kwargs = mock_a2a_app_class.call_args
|
||||
assert kwargs["http_handler"] == mock_request_handler
|
||||
# The agent_card should be an AgentCard object created from loaded data
|
||||
assert hasattr(kwargs["agent_card"], "name")
|
||||
|
||||
@patch("google.adk.a2a.utils.agent_to_a2a.A2aAgentExecutor")
|
||||
@patch("google.adk.a2a.utils.agent_to_a2a.DefaultRequestHandler")
|
||||
@patch("google.adk.a2a.utils.agent_to_a2a.InMemoryTaskStore")
|
||||
@patch("google.adk.a2a.utils.agent_to_a2a.AgentCardBuilder")
|
||||
@patch("google.adk.a2a.utils.agent_to_a2a.Starlette")
|
||||
@patch("pathlib.Path.open", side_effect=FileNotFoundError("File not found"))
|
||||
@patch("pathlib.Path")
|
||||
def test_to_a2a_with_invalid_agent_card_file_path(
|
||||
self,
|
||||
mock_path_class,
|
||||
mock_open,
|
||||
mock_starlette_class,
|
||||
mock_card_builder_class,
|
||||
mock_task_store_class,
|
||||
mock_request_handler_class,
|
||||
mock_agent_executor_class,
|
||||
):
|
||||
"""Test to_a2a with invalid agent card file path."""
|
||||
# Arrange
|
||||
mock_app = Mock(spec=Starlette)
|
||||
mock_starlette_class.return_value = mock_app
|
||||
mock_task_store = Mock(spec=InMemoryTaskStore)
|
||||
mock_task_store_class.return_value = mock_task_store
|
||||
mock_agent_executor = Mock(spec=A2aAgentExecutor)
|
||||
mock_agent_executor_class.return_value = mock_agent_executor
|
||||
mock_request_handler = Mock(spec=DefaultRequestHandler)
|
||||
mock_request_handler_class.return_value = mock_request_handler
|
||||
mock_card_builder = Mock(spec=AgentCardBuilder)
|
||||
mock_card_builder_class.return_value = mock_card_builder
|
||||
|
||||
mock_path = Mock()
|
||||
mock_path_class.return_value = mock_path
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Failed to load agent card from"):
|
||||
to_a2a(self.mock_agent, agent_card="/invalid/path.json")
|
||||
|
||||
Reference in New Issue
Block a user