diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index e57e0c8f..b757cc03 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -1474,17 +1474,7 @@ class AdkWebServer: yield f"data: {sse_event}\n\n" except Exception as e: logger.exception("Error in event_generator: %s", e) - # Yield a proper Event object for the error - error_event = Event( - author="system", - content=types.Content( - role="model", parts=[types.Part(text=f"Error: {e}")] - ), - ) - yield ( - "data:" - f" {error_event.model_dump_json(by_alias=True, exclude_none=True)}\n\n" - ) + yield f"data: {json.dumps({'error': str(e)})}\n\n" # Returns a streaming response with the proper media type for SSE return StreamingResponse( diff --git a/src/google/adk/cli/conformance/adk_web_server_client.py b/src/google/adk/cli/conformance/adk_web_server_client.py index a249413b..59bca621 100644 --- a/src/google/adk/cli/conformance/adk_web_server_client.py +++ b/src/google/adk/cli/conformance/adk_web_server_client.py @@ -228,6 +228,7 @@ class AdkWebServerClient: ValueError: If mode is provided but test_case_dir or user_message_index is None httpx.HTTPStatusError: If the request fails json.JSONDecodeError: If event data cannot be parsed + RuntimeError: If the server streams an error payload """ # Add recording parameters to state_delta for conformance tests if mode: @@ -262,6 +263,8 @@ class AdkWebServerClient: async for line in response.aiter_lines(): if line.startswith("data:") and (data := line[5:].strip()): event_data = json.loads(data) + if isinstance(event_data, dict) and "error" in event_data: + raise RuntimeError(event_data["error"]) yield Event.model_validate(event_data) else: logger.debug("Non data line received: %s", line) diff --git a/tests/unittests/cli/conformance/test_adk_web_server_client.py b/tests/unittests/cli/conformance/test_adk_web_server_client.py index 3e5cc2f1..3676fb92 100644 --- a/tests/unittests/cli/conformance/test_adk_web_server_client.py +++ b/tests/unittests/cli/conformance/test_adk_web_server_client.py @@ -224,6 +224,44 @@ async def test_run_agent(): assert events[1].invocation_id == "test_invocation_2" +@pytest.mark.asyncio +async def test_run_agent_raises_on_streamed_error(): + client = AdkWebServerClient() + + class MockStreamResponse: + + def raise_for_status(self): + pass + + async def aiter_lines(self): + yield 'data: {"error": "boom"}' + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + def mock_stream(*_args, **_kwargs): + return MockStreamResponse() + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client.stream = mock_stream + mock_client_class.return_value = mock_client + + request = RunAgentRequest( + app_name="test_app", + user_id="test_user", + session_id="test_session", + new_message=types.Content(role="user", parts=[types.Part(text="Hi")]), + ) + + with pytest.raises(RuntimeError, match="boom"): + async for _ in client.run_agent(request): + pass + + @pytest.mark.asyncio async def test_close(): client = AdkWebServerClient() diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index fa89021e..4da804d5 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -1019,6 +1019,48 @@ def test_agent_run_sse_splits_artifact_delta( assert sse_events[1]["actions"]["artifactDelta"] == {"artifact.txt": 0} +def test_agent_run_sse_yields_error_object_on_exception( + test_app, create_test_session, monkeypatch +): + """Test /run_sse streams an error object if streaming raises.""" + info = create_test_session + + async def run_async_raises( + self, + *, + user_id: str, + session_id: str, + invocation_id: Optional[str] = None, + new_message: Optional[types.Content] = None, + state_delta: Optional[dict[str, Any]] = None, + run_config: Optional[RunConfig] = None, + ): + del user_id, session_id, invocation_id, new_message, state_delta, run_config + raise ValueError("boom") + if False: # pylint: disable=using-constant-test + yield _event_1() + + monkeypatch.setattr(Runner, "run_async", run_async_raises) + + payload = { + "app_name": info["app_name"], + "user_id": info["user_id"], + "session_id": info["session_id"], + "new_message": {"role": "user", "parts": [{"text": "Hello agent"}]}, + "streaming": True, + } + + response = test_app.post("/run_sse", json=payload) + assert response.status_code == 200 + + sse_events = [ + json.loads(line.removeprefix("data: ")) + for line in response.text.splitlines() + if line.startswith("data: ") + ] + assert sse_events == [{"error": "boom"}] + + def test_list_artifact_names(test_app, create_test_session): """Test listing artifact names for a session.""" info = create_test_session