From c46308b7cfcfcdbc300e6dda8eeef86ac73d2e64 Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Thu, 2 Oct 2025 08:51:43 -0700 Subject: [PATCH] chore: Add session patch endpoint to api server for state update This is allow user to update session state without running the agent. e.g. if I want to test some case when session has certain state on adk web. PiperOrigin-RevId: 814252851 --- src/google/adk/cli/adk_web_server.py | 57 ++++++++++++ .../cli/conformance/adk_web_server_client.py | 30 +++++++ .../conformance/test_adk_web_server_client.py | 37 ++++++++ tests/unittests/cli/test_fast_api.py | 90 +++++++++++++++++++ 4 files changed, 214 insertions(+) diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index 92d1e94c..2a2dbc05 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -220,6 +220,13 @@ class UpdateMemoryRequest(common.BaseModel): """The ID of the session to add to memory.""" +class UpdateSessionRequest(common.BaseModel): + """Request to update session state without running the agent.""" + + state_delta: dict[str, Any] + """The state changes to apply to the session.""" + + class RunEvalResult(common.BaseModel): eval_set_file: str eval_set_id: str @@ -767,6 +774,56 @@ class AdkWebServer: app_name=app_name, user_id=user_id, session_id=session_id ) + @app.patch( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}", + response_model_exclude_none=True, + ) + async def update_session( + app_name: str, + user_id: str, + session_id: str, + req: UpdateSessionRequest, + ) -> Session: + """Updates session state without running the agent. + + Args: + app_name: The name of the application. + user_id: The ID of the user. + session_id: The ID of the session to update. + req: The patch request containing state changes. + + Returns: + The updated session. + + Raises: + HTTPException: If the session is not found. + """ + session = await self.session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + + # Create an event to record the state change + import uuid + + from ..events.event import Event + from ..events.event import EventActions + + state_update_event = Event( + invocation_id="p-" + str(uuid.uuid4()), + author="user", + actions=EventActions(state_delta=req.state_delta), + ) + + # Append the event to the session + # This will automatically update the session state through __update_session_state + await self.session_service.append_event( + session=session, event=state_update_event + ) + + return session + @app.post( "/apps/{app_name}/eval-sets", response_model_exclude_none=True, diff --git a/src/google/adk/cli/conformance/adk_web_server_client.py b/src/google/adk/cli/conformance/adk_web_server_client.py index e1f29478..88fe2ead 100644 --- a/src/google/adk/cli/conformance/adk_web_server_client.py +++ b/src/google/adk/cli/conformance/adk_web_server_client.py @@ -176,6 +176,36 @@ class AdkWebServerClient: ) response.raise_for_status() + async def update_session( + self, + *, + app_name: str, + user_id: str, + session_id: str, + state_delta: Dict[str, Any], + ) -> Session: + """Update session state without running the agent. + + Args: + app_name: Name of the application + user_id: User identifier + session_id: Session identifier to update + state_delta: The state changes to apply to the session + + Returns: + The updated Session object + + Raises: + httpx.HTTPStatusError: If the request fails or session not found + """ + async with self._get_client() as client: + response = await client.patch( + f"/apps/{app_name}/users/{user_id}/sessions/{session_id}", + json={"state_delta": state_delta}, + ) + response.raise_for_status() + return Session.model_validate(response.json()) + async def run_agent( self, request: RunAgentRequest, diff --git a/tests/unittests/cli/conformance/test_adk_web_server_client.py b/tests/unittests/cli/conformance/test_adk_web_server_client.py index 642b10a1..b2bfc43c 100644 --- a/tests/unittests/cli/conformance/test_adk_web_server_client.py +++ b/tests/unittests/cli/conformance/test_adk_web_server_client.py @@ -127,6 +127,43 @@ async def test_delete_session(): mock_response.raise_for_status.assert_called_once() +@pytest.mark.asyncio +async def test_update_session(): + client = AdkWebServerClient() + + # Mock the HTTP response + mock_response = MagicMock() + mock_response.json.return_value = { + "id": "test_session", + "app_name": "test_app", + "user_id": "test_user", + "events": [], + "state": {"key": "updated_value", "new_key": "new_value"}, + } + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client.patch.return_value = mock_response + mock_client_class.return_value = mock_client + + state_delta = {"key": "updated_value", "new_key": "new_value"} + session = await client.update_session( + app_name="test_app", + user_id="test_user", + session_id="test_session", + state_delta=state_delta, + ) + + assert isinstance(session, Session) + assert session.id == "test_session" + assert session.state == {"key": "updated_value", "new_key": "new_value"} + mock_client.patch.assert_called_once_with( + "/apps/test_app/users/test_user/sessions/test_session", + json={"state_delta": state_delta}, + ) + mock_response.raise_for_status.assert_called_once() + + @pytest.mark.asyncio async def test_run_agent(): client = AdkWebServerClient() diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 4bcf2f11..a18b409e 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -288,6 +288,22 @@ def mock_session_service(): ): del session_data[app_name][user_id][session_id] + async def append_event(self, session, event): + """Append an event to a session.""" + # Update session state if event has state_delta + if event.actions and event.actions.state_delta: + session["state"].update(event.actions.state_delta) + + # Add event to session events + session["events"].append(event.model_dump()) + + # Update the session in storage + session_data[session["app_name"]][session["user_id"]][ + session["id"] + ] = session + + return event + # Return an instance of our mock service return MockSessionService() @@ -725,6 +741,80 @@ def test_delete_session(test_app, create_test_session): logger.info("Session deleted successfully") +def test_update_session(test_app, create_test_session): + """Test patching a session state.""" + info = create_test_session + url = f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/{info['session_id']}" + + # Get the original session + response = test_app.get(url) + assert response.status_code == 200 + original_session = response.json() + original_state = original_session.get("state", {}) + + # Prepare state delta + state_delta = {"test_key": "test_value", "counter": 42} + + # Patch the session + response = test_app.patch(url, json={"state_delta": state_delta}) + assert response.status_code == 200 + + # Verify the response + patched_session = response.json() + assert patched_session["id"] == info["session_id"] + + # Verify state was updated correctly + expected_state = {**original_state, **state_delta} + assert patched_session["state"] == expected_state + + # Verify the session was actually updated in storage + response = test_app.get(url) + assert response.status_code == 200 + retrieved_session = response.json() + assert retrieved_session["state"] == expected_state + + # Verify an event was created for the state change + events = retrieved_session.get("events", []) + assert len(events) > len(original_session.get("events", [])) + + # Find the state patch event (looking for "p-" prefix pattern) + state_patch_events = [ + event + for event in events + if ( + event.get("invocationId") or event.get("invocation_id", "") + ).startswith("p-") + ] + + assert len(state_patch_events) == 1, ( + f"Expected 1 state_patch event, found {len(state_patch_events)}. Events:" + f" {events}" + ) + state_patch_event = state_patch_events[0] + assert state_patch_event["author"] == "user" + + # Check for actions in both camelCase and snake_case + actions = state_patch_event.get("actions") or state_patch_event.get("actions") + assert actions is not None, f"No actions found in event: {state_patch_event}" + state_delta_in_event = actions.get("state_delta") or actions.get("stateDelta") + assert state_delta_in_event == state_delta + + logger.info("Session state patched successfully") + + +def test_patch_session_not_found(test_app, test_session_info): + """Test patching a non-existent session.""" + info = test_session_info + url = f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/nonexistent" + + state_delta = {"test_key": "test_value"} + response = test_app.patch(url, json={"state_delta": state_delta}) + + assert response.status_code == 404 + assert "Session not found" in response.json()["detail"] + logger.info("Patch session not found test passed") + + def test_agent_run(test_app, create_test_session): """Test running an agent with a message.""" info = create_test_session