You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
chore: Add session patch endpoint to api server for state update
This is allow user to update session state without running the agent. e.g. if I want to test some case when session has certain state on adk web. PiperOrigin-RevId: 814252851
This commit is contained in:
committed by
Copybara-Service
parent
822efe0065
commit
c46308b7cf
@@ -220,6 +220,13 @@ class UpdateMemoryRequest(common.BaseModel):
|
||||
"""The ID of the session to add to memory."""
|
||||
|
||||
|
||||
class UpdateSessionRequest(common.BaseModel):
|
||||
"""Request to update session state without running the agent."""
|
||||
|
||||
state_delta: dict[str, Any]
|
||||
"""The state changes to apply to the session."""
|
||||
|
||||
|
||||
class RunEvalResult(common.BaseModel):
|
||||
eval_set_file: str
|
||||
eval_set_id: str
|
||||
@@ -767,6 +774,56 @@ class AdkWebServer:
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
|
||||
@app.patch(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
async def update_session(
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
req: UpdateSessionRequest,
|
||||
) -> Session:
|
||||
"""Updates session state without running the agent.
|
||||
|
||||
Args:
|
||||
app_name: The name of the application.
|
||||
user_id: The ID of the user.
|
||||
session_id: The ID of the session to update.
|
||||
req: The patch request containing state changes.
|
||||
|
||||
Returns:
|
||||
The updated session.
|
||||
|
||||
Raises:
|
||||
HTTPException: If the session is not found.
|
||||
"""
|
||||
session = await self.session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# Create an event to record the state change
|
||||
import uuid
|
||||
|
||||
from ..events.event import Event
|
||||
from ..events.event import EventActions
|
||||
|
||||
state_update_event = Event(
|
||||
invocation_id="p-" + str(uuid.uuid4()),
|
||||
author="user",
|
||||
actions=EventActions(state_delta=req.state_delta),
|
||||
)
|
||||
|
||||
# Append the event to the session
|
||||
# This will automatically update the session state through __update_session_state
|
||||
await self.session_service.append_event(
|
||||
session=session, event=state_update_event
|
||||
)
|
||||
|
||||
return session
|
||||
|
||||
@app.post(
|
||||
"/apps/{app_name}/eval-sets",
|
||||
response_model_exclude_none=True,
|
||||
|
||||
@@ -176,6 +176,36 @@ class AdkWebServerClient:
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
async def update_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
state_delta: Dict[str, Any],
|
||||
) -> Session:
|
||||
"""Update session state without running the agent.
|
||||
|
||||
Args:
|
||||
app_name: Name of the application
|
||||
user_id: User identifier
|
||||
session_id: Session identifier to update
|
||||
state_delta: The state changes to apply to the session
|
||||
|
||||
Returns:
|
||||
The updated Session object
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If the request fails or session not found
|
||||
"""
|
||||
async with self._get_client() as client:
|
||||
response = await client.patch(
|
||||
f"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
|
||||
json={"state_delta": state_delta},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return Session.model_validate(response.json())
|
||||
|
||||
async def run_agent(
|
||||
self,
|
||||
request: RunAgentRequest,
|
||||
|
||||
@@ -127,6 +127,43 @@ async def test_delete_session():
|
||||
mock_response.raise_for_status.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_session():
|
||||
client = AdkWebServerClient()
|
||||
|
||||
# Mock the HTTP response
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"id": "test_session",
|
||||
"app_name": "test_app",
|
||||
"user_id": "test_user",
|
||||
"events": [],
|
||||
"state": {"key": "updated_value", "new_key": "new_value"},
|
||||
}
|
||||
|
||||
with patch("httpx.AsyncClient") as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client.patch.return_value = mock_response
|
||||
mock_client_class.return_value = mock_client
|
||||
|
||||
state_delta = {"key": "updated_value", "new_key": "new_value"}
|
||||
session = await client.update_session(
|
||||
app_name="test_app",
|
||||
user_id="test_user",
|
||||
session_id="test_session",
|
||||
state_delta=state_delta,
|
||||
)
|
||||
|
||||
assert isinstance(session, Session)
|
||||
assert session.id == "test_session"
|
||||
assert session.state == {"key": "updated_value", "new_key": "new_value"}
|
||||
mock_client.patch.assert_called_once_with(
|
||||
"/apps/test_app/users/test_user/sessions/test_session",
|
||||
json={"state_delta": state_delta},
|
||||
)
|
||||
mock_response.raise_for_status.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent():
|
||||
client = AdkWebServerClient()
|
||||
|
||||
@@ -288,6 +288,22 @@ def mock_session_service():
|
||||
):
|
||||
del session_data[app_name][user_id][session_id]
|
||||
|
||||
async def append_event(self, session, event):
|
||||
"""Append an event to a session."""
|
||||
# Update session state if event has state_delta
|
||||
if event.actions and event.actions.state_delta:
|
||||
session["state"].update(event.actions.state_delta)
|
||||
|
||||
# Add event to session events
|
||||
session["events"].append(event.model_dump())
|
||||
|
||||
# Update the session in storage
|
||||
session_data[session["app_name"]][session["user_id"]][
|
||||
session["id"]
|
||||
] = session
|
||||
|
||||
return event
|
||||
|
||||
# Return an instance of our mock service
|
||||
return MockSessionService()
|
||||
|
||||
@@ -725,6 +741,80 @@ def test_delete_session(test_app, create_test_session):
|
||||
logger.info("Session deleted successfully")
|
||||
|
||||
|
||||
def test_update_session(test_app, create_test_session):
|
||||
"""Test patching a session state."""
|
||||
info = create_test_session
|
||||
url = f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/{info['session_id']}"
|
||||
|
||||
# Get the original session
|
||||
response = test_app.get(url)
|
||||
assert response.status_code == 200
|
||||
original_session = response.json()
|
||||
original_state = original_session.get("state", {})
|
||||
|
||||
# Prepare state delta
|
||||
state_delta = {"test_key": "test_value", "counter": 42}
|
||||
|
||||
# Patch the session
|
||||
response = test_app.patch(url, json={"state_delta": state_delta})
|
||||
assert response.status_code == 200
|
||||
|
||||
# Verify the response
|
||||
patched_session = response.json()
|
||||
assert patched_session["id"] == info["session_id"]
|
||||
|
||||
# Verify state was updated correctly
|
||||
expected_state = {**original_state, **state_delta}
|
||||
assert patched_session["state"] == expected_state
|
||||
|
||||
# Verify the session was actually updated in storage
|
||||
response = test_app.get(url)
|
||||
assert response.status_code == 200
|
||||
retrieved_session = response.json()
|
||||
assert retrieved_session["state"] == expected_state
|
||||
|
||||
# Verify an event was created for the state change
|
||||
events = retrieved_session.get("events", [])
|
||||
assert len(events) > len(original_session.get("events", []))
|
||||
|
||||
# Find the state patch event (looking for "p-" prefix pattern)
|
||||
state_patch_events = [
|
||||
event
|
||||
for event in events
|
||||
if (
|
||||
event.get("invocationId") or event.get("invocation_id", "")
|
||||
).startswith("p-")
|
||||
]
|
||||
|
||||
assert len(state_patch_events) == 1, (
|
||||
f"Expected 1 state_patch event, found {len(state_patch_events)}. Events:"
|
||||
f" {events}"
|
||||
)
|
||||
state_patch_event = state_patch_events[0]
|
||||
assert state_patch_event["author"] == "user"
|
||||
|
||||
# Check for actions in both camelCase and snake_case
|
||||
actions = state_patch_event.get("actions") or state_patch_event.get("actions")
|
||||
assert actions is not None, f"No actions found in event: {state_patch_event}"
|
||||
state_delta_in_event = actions.get("state_delta") or actions.get("stateDelta")
|
||||
assert state_delta_in_event == state_delta
|
||||
|
||||
logger.info("Session state patched successfully")
|
||||
|
||||
|
||||
def test_patch_session_not_found(test_app, test_session_info):
|
||||
"""Test patching a non-existent session."""
|
||||
info = test_session_info
|
||||
url = f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/nonexistent"
|
||||
|
||||
state_delta = {"test_key": "test_value"}
|
||||
response = test_app.patch(url, json={"state_delta": state_delta})
|
||||
|
||||
assert response.status_code == 404
|
||||
assert "Session not found" in response.json()["detail"]
|
||||
logger.info("Patch session not found test passed")
|
||||
|
||||
|
||||
def test_agent_run(test_app, create_test_session):
|
||||
"""Test running an agent with a message."""
|
||||
info = create_test_session
|
||||
|
||||
Reference in New Issue
Block a user