chore: Add session patch endpoint to api server for state update

This is allow user to update session state without running the agent. e.g. if I want to test some case when session has certain state on adk web.

PiperOrigin-RevId: 814252851
This commit is contained in:
Xiang (Sean) Zhou
2025-10-02 08:51:43 -07:00
committed by Copybara-Service
parent 822efe0065
commit c46308b7cf
4 changed files with 214 additions and 0 deletions
+57
View File
@@ -220,6 +220,13 @@ class UpdateMemoryRequest(common.BaseModel):
"""The ID of the session to add to memory."""
class UpdateSessionRequest(common.BaseModel):
"""Request to update session state without running the agent."""
state_delta: dict[str, Any]
"""The state changes to apply to the session."""
class RunEvalResult(common.BaseModel):
eval_set_file: str
eval_set_id: str
@@ -767,6 +774,56 @@ class AdkWebServer:
app_name=app_name, user_id=user_id, session_id=session_id
)
@app.patch(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
response_model_exclude_none=True,
)
async def update_session(
app_name: str,
user_id: str,
session_id: str,
req: UpdateSessionRequest,
) -> Session:
"""Updates session state without running the agent.
Args:
app_name: The name of the application.
user_id: The ID of the user.
session_id: The ID of the session to update.
req: The patch request containing state changes.
Returns:
The updated session.
Raises:
HTTPException: If the session is not found.
"""
session = await self.session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
# Create an event to record the state change
import uuid
from ..events.event import Event
from ..events.event import EventActions
state_update_event = Event(
invocation_id="p-" + str(uuid.uuid4()),
author="user",
actions=EventActions(state_delta=req.state_delta),
)
# Append the event to the session
# This will automatically update the session state through __update_session_state
await self.session_service.append_event(
session=session, event=state_update_event
)
return session
@app.post(
"/apps/{app_name}/eval-sets",
response_model_exclude_none=True,
@@ -176,6 +176,36 @@ class AdkWebServerClient:
)
response.raise_for_status()
async def update_session(
self,
*,
app_name: str,
user_id: str,
session_id: str,
state_delta: Dict[str, Any],
) -> Session:
"""Update session state without running the agent.
Args:
app_name: Name of the application
user_id: User identifier
session_id: Session identifier to update
state_delta: The state changes to apply to the session
Returns:
The updated Session object
Raises:
httpx.HTTPStatusError: If the request fails or session not found
"""
async with self._get_client() as client:
response = await client.patch(
f"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
json={"state_delta": state_delta},
)
response.raise_for_status()
return Session.model_validate(response.json())
async def run_agent(
self,
request: RunAgentRequest,
@@ -127,6 +127,43 @@ async def test_delete_session():
mock_response.raise_for_status.assert_called_once()
@pytest.mark.asyncio
async def test_update_session():
client = AdkWebServerClient()
# Mock the HTTP response
mock_response = MagicMock()
mock_response.json.return_value = {
"id": "test_session",
"app_name": "test_app",
"user_id": "test_user",
"events": [],
"state": {"key": "updated_value", "new_key": "new_value"},
}
with patch("httpx.AsyncClient") as mock_client_class:
mock_client = AsyncMock()
mock_client.patch.return_value = mock_response
mock_client_class.return_value = mock_client
state_delta = {"key": "updated_value", "new_key": "new_value"}
session = await client.update_session(
app_name="test_app",
user_id="test_user",
session_id="test_session",
state_delta=state_delta,
)
assert isinstance(session, Session)
assert session.id == "test_session"
assert session.state == {"key": "updated_value", "new_key": "new_value"}
mock_client.patch.assert_called_once_with(
"/apps/test_app/users/test_user/sessions/test_session",
json={"state_delta": state_delta},
)
mock_response.raise_for_status.assert_called_once()
@pytest.mark.asyncio
async def test_run_agent():
client = AdkWebServerClient()
+90
View File
@@ -288,6 +288,22 @@ def mock_session_service():
):
del session_data[app_name][user_id][session_id]
async def append_event(self, session, event):
"""Append an event to a session."""
# Update session state if event has state_delta
if event.actions and event.actions.state_delta:
session["state"].update(event.actions.state_delta)
# Add event to session events
session["events"].append(event.model_dump())
# Update the session in storage
session_data[session["app_name"]][session["user_id"]][
session["id"]
] = session
return event
# Return an instance of our mock service
return MockSessionService()
@@ -725,6 +741,80 @@ def test_delete_session(test_app, create_test_session):
logger.info("Session deleted successfully")
def test_update_session(test_app, create_test_session):
"""Test patching a session state."""
info = create_test_session
url = f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/{info['session_id']}"
# Get the original session
response = test_app.get(url)
assert response.status_code == 200
original_session = response.json()
original_state = original_session.get("state", {})
# Prepare state delta
state_delta = {"test_key": "test_value", "counter": 42}
# Patch the session
response = test_app.patch(url, json={"state_delta": state_delta})
assert response.status_code == 200
# Verify the response
patched_session = response.json()
assert patched_session["id"] == info["session_id"]
# Verify state was updated correctly
expected_state = {**original_state, **state_delta}
assert patched_session["state"] == expected_state
# Verify the session was actually updated in storage
response = test_app.get(url)
assert response.status_code == 200
retrieved_session = response.json()
assert retrieved_session["state"] == expected_state
# Verify an event was created for the state change
events = retrieved_session.get("events", [])
assert len(events) > len(original_session.get("events", []))
# Find the state patch event (looking for "p-" prefix pattern)
state_patch_events = [
event
for event in events
if (
event.get("invocationId") or event.get("invocation_id", "")
).startswith("p-")
]
assert len(state_patch_events) == 1, (
f"Expected 1 state_patch event, found {len(state_patch_events)}. Events:"
f" {events}"
)
state_patch_event = state_patch_events[0]
assert state_patch_event["author"] == "user"
# Check for actions in both camelCase and snake_case
actions = state_patch_event.get("actions") or state_patch_event.get("actions")
assert actions is not None, f"No actions found in event: {state_patch_event}"
state_delta_in_event = actions.get("state_delta") or actions.get("stateDelta")
assert state_delta_in_event == state_delta
logger.info("Session state patched successfully")
def test_patch_session_not_found(test_app, test_session_info):
"""Test patching a non-existent session."""
info = test_session_info
url = f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/nonexistent"
state_delta = {"test_key": "test_value"}
response = test_app.patch(url, json={"state_delta": state_delta})
assert response.status_code == 404
assert "Session not found" in response.json()["detail"]
logger.info("Patch session not found test passed")
def test_agent_run(test_app, create_test_session):
"""Test running an agent with a message."""
info = create_test_session