fix: Stream errors as simple JSON objects in ADK web server SSE

The ADK web server's /run_sse endpoint now yields a JSON object like {"error": "..."} when an exception occurs during event generation. The adk_web_server_client is updated to detect this error payload and raise a RuntimeError.

Close #4291

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 863475838
This commit is contained in:
George Weale
2026-01-30 18:17:56 -08:00
committed by Copybara-Service
parent d0102ecea3
commit 798d0053c8
4 changed files with 84 additions and 11 deletions
+1 -11
View File
@@ -1474,17 +1474,7 @@ class AdkWebServer:
yield f"data: {sse_event}\n\n"
except Exception as e:
logger.exception("Error in event_generator: %s", e)
# Yield a proper Event object for the error
error_event = Event(
author="system",
content=types.Content(
role="model", parts=[types.Part(text=f"Error: {e}")]
),
)
yield (
"data:"
f" {error_event.model_dump_json(by_alias=True, exclude_none=True)}\n\n"
)
yield f"data: {json.dumps({'error': str(e)})}\n\n"
# Returns a streaming response with the proper media type for SSE
return StreamingResponse(
@@ -228,6 +228,7 @@ class AdkWebServerClient:
ValueError: If mode is provided but test_case_dir or user_message_index is None
httpx.HTTPStatusError: If the request fails
json.JSONDecodeError: If event data cannot be parsed
RuntimeError: If the server streams an error payload
"""
# Add recording parameters to state_delta for conformance tests
if mode:
@@ -262,6 +263,8 @@ class AdkWebServerClient:
async for line in response.aiter_lines():
if line.startswith("data:") and (data := line[5:].strip()):
event_data = json.loads(data)
if isinstance(event_data, dict) and "error" in event_data:
raise RuntimeError(event_data["error"])
yield Event.model_validate(event_data)
else:
logger.debug("Non data line received: %s", line)
@@ -224,6 +224,44 @@ async def test_run_agent():
assert events[1].invocation_id == "test_invocation_2"
@pytest.mark.asyncio
async def test_run_agent_raises_on_streamed_error():
client = AdkWebServerClient()
class MockStreamResponse:
def raise_for_status(self):
pass
async def aiter_lines(self):
yield 'data: {"error": "boom"}'
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
pass
def mock_stream(*_args, **_kwargs):
return MockStreamResponse()
with patch("httpx.AsyncClient") as mock_client_class:
mock_client = AsyncMock()
mock_client.stream = mock_stream
mock_client_class.return_value = mock_client
request = RunAgentRequest(
app_name="test_app",
user_id="test_user",
session_id="test_session",
new_message=types.Content(role="user", parts=[types.Part(text="Hi")]),
)
with pytest.raises(RuntimeError, match="boom"):
async for _ in client.run_agent(request):
pass
@pytest.mark.asyncio
async def test_close():
client = AdkWebServerClient()
+42
View File
@@ -1019,6 +1019,48 @@ def test_agent_run_sse_splits_artifact_delta(
assert sse_events[1]["actions"]["artifactDelta"] == {"artifact.txt": 0}
def test_agent_run_sse_yields_error_object_on_exception(
test_app, create_test_session, monkeypatch
):
"""Test /run_sse streams an error object if streaming raises."""
info = create_test_session
async def run_async_raises(
self,
*,
user_id: str,
session_id: str,
invocation_id: Optional[str] = None,
new_message: Optional[types.Content] = None,
state_delta: Optional[dict[str, Any]] = None,
run_config: Optional[RunConfig] = None,
):
del user_id, session_id, invocation_id, new_message, state_delta, run_config
raise ValueError("boom")
if False: # pylint: disable=using-constant-test
yield _event_1()
monkeypatch.setattr(Runner, "run_async", run_async_raises)
payload = {
"app_name": info["app_name"],
"user_id": info["user_id"],
"session_id": info["session_id"],
"new_message": {"role": "user", "parts": [{"text": "Hello agent"}]},
"streaming": True,
}
response = test_app.post("/run_sse", json=payload)
assert response.status_code == 200
sse_events = [
json.loads(line.removeprefix("data: "))
for line in response.text.splitlines()
if line.startswith("data: ")
]
assert sse_events == [{"error": "boom"}]
def test_list_artifact_names(test_app, create_test_session):
"""Test listing artifact names for a session."""
info = create_test_session