You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
chore: Raise AlreadyExistsError when trying to create a resource with same ID
Move the dedupe logic into session service so that the internal error can be surfaced to client PiperOrigin-RevId: 822294430
This commit is contained in:
committed by
Copybara-Service
parent
c850da3a07
commit
2a901d12f4
@@ -63,6 +63,7 @@ from ..agents.run_config import StreamingMode
|
||||
from ..apps.app import App
|
||||
from ..artifacts.base_artifact_service import BaseArtifactService
|
||||
from ..auth.credential_service.base_credential_service import BaseCredentialService
|
||||
from ..errors.already_exists_error import AlreadyExistsError
|
||||
from ..errors.not_found_error import NotFoundError
|
||||
from ..evaluation.base_eval_service import InferenceConfig
|
||||
from ..evaluation.base_eval_service import InferenceRequest
|
||||
@@ -583,6 +584,33 @@ class AdkWebServer:
|
||||
"Failed to write runtime config file %s: %s", runtime_config_path, e
|
||||
)
|
||||
|
||||
async def _create_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: Optional[str] = None,
|
||||
state: Optional[dict[str, Any]] = None,
|
||||
) -> Session:
|
||||
try:
|
||||
session = await self.session_service.create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
state=state,
|
||||
session_id=session_id,
|
||||
)
|
||||
logger.info("New session created: %s", session.id)
|
||||
return session
|
||||
except AlreadyExistsError as e:
|
||||
raise HTTPException(
|
||||
status_code=409, detail=f"Session already exists: {session_id}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Internal server error during session creation: %s", e, exc_info=True
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
def get_fast_api_app(
|
||||
self,
|
||||
lifespan: Optional[Lifespan[FastAPI]] = None,
|
||||
@@ -740,20 +768,12 @@ class AdkWebServer:
|
||||
session_id: str,
|
||||
state: Optional[dict[str, Any]] = None,
|
||||
) -> Session:
|
||||
if (
|
||||
await self.session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
)
|
||||
is not None
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=409, detail=f"Session already exists: {session_id}"
|
||||
)
|
||||
session = await self.session_service.create_session(
|
||||
app_name=app_name, user_id=user_id, state=state, session_id=session_id
|
||||
return await self._create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
state=state,
|
||||
session_id=session_id,
|
||||
)
|
||||
logger.info("New session created: %s", session_id)
|
||||
return session
|
||||
|
||||
@app.post(
|
||||
"/apps/{app_name}/users/{user_id}/sessions",
|
||||
@@ -765,18 +785,9 @@ class AdkWebServer:
|
||||
req: Optional[CreateSessionRequest] = None,
|
||||
) -> Session:
|
||||
if not req:
|
||||
return await self.session_service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
return await self._create_session(app_name=app_name, user_id=user_id)
|
||||
|
||||
if req.session_id and await self.session_service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=req.session_id
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=409, detail=f"Session already exists: {req.session_id}"
|
||||
)
|
||||
|
||||
session = await self.session_service.create_session(
|
||||
session = await self._create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
state=req.state,
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class AlreadyExistsError(Exception):
|
||||
"""Represents an error that occurs when an entity already exists."""
|
||||
|
||||
def __init__(self, message="The resource already exists."):
|
||||
"""Initializes the AlreadyExistsError exception.
|
||||
|
||||
Args:
|
||||
message (str): An optional custom message to describe the error.
|
||||
"""
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
@@ -53,6 +53,7 @@ from typing_extensions import override
|
||||
from tzlocal import get_localzone
|
||||
|
||||
from . import _session_util
|
||||
from ..errors.already_exists_error import AlreadyExistsError
|
||||
from ..events.event import Event
|
||||
from .base_session_service import BaseSessionService
|
||||
from .base_session_service import GetSessionConfig
|
||||
@@ -465,6 +466,12 @@ class DatabaseSessionService(BaseSessionService):
|
||||
# 5. Return the session
|
||||
|
||||
with self.database_session_factory() as sql_session:
|
||||
if session_id and sql_session.get(
|
||||
StorageSession, (app_name, user_id, session_id)
|
||||
):
|
||||
raise AlreadyExistsError(
|
||||
f"Session with id {session_id} already exists."
|
||||
)
|
||||
# Fetch app and user states from storage
|
||||
storage_app_state = sql_session.get(StorageAppState, (app_name))
|
||||
if not storage_app_state:
|
||||
|
||||
@@ -23,6 +23,7 @@ import uuid
|
||||
from typing_extensions import override
|
||||
|
||||
from . import _session_util
|
||||
from ..errors.already_exists_error import AlreadyExistsError
|
||||
from ..events.event import Event
|
||||
from .base_session_service import BaseSessionService
|
||||
from .base_session_service import GetSessionConfig
|
||||
@@ -89,6 +90,10 @@ class InMemorySessionService(BaseSessionService):
|
||||
state: Optional[dict[str, Any]] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Session:
|
||||
if session_id and self._get_session_impl(
|
||||
app_name=app_name, user_id=user_id, session_id=session_id
|
||||
):
|
||||
raise AlreadyExistsError(f'Session with id {session_id} already exists.')
|
||||
state_deltas = _session_util.extract_state_delta(state)
|
||||
app_state_delta = state_deltas['app']
|
||||
user_state_delta = state_deltas['user']
|
||||
|
||||
@@ -39,8 +39,9 @@ from google.adk.evaluation.in_memory_eval_sets_manager import InMemoryEvalSetsMa
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.events.event_actions import EventActions
|
||||
from google.adk.runners import Runner
|
||||
from google.adk.sessions.base_session_service import ListSessionsResponse
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.adk.sessions.session import Session
|
||||
from google.adk.sessions.state import State
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
import pytest
|
||||
@@ -194,98 +195,8 @@ def mock_agent_loader():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_service():
|
||||
"""Create a mock session service that uses an in-memory dictionary."""
|
||||
|
||||
# In-memory database to store sessions during testing
|
||||
session_data = {
|
||||
"test_app": {
|
||||
"test_user": {
|
||||
"test_session": {
|
||||
"id": "test_session",
|
||||
"app_name": "test_app",
|
||||
"user_id": "test_user",
|
||||
"events": [],
|
||||
"state": {},
|
||||
"created_at": time.time(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Mock session service class that operates on the in-memory database
|
||||
class MockSessionService:
|
||||
|
||||
async def get_session(self, app_name, user_id, session_id):
|
||||
"""Retrieve a session by ID."""
|
||||
if (
|
||||
app_name in session_data
|
||||
and user_id in session_data[app_name]
|
||||
and session_id in session_data[app_name][user_id]
|
||||
):
|
||||
return session_data[app_name][user_id][session_id]
|
||||
return None
|
||||
|
||||
async def create_session(
|
||||
self, app_name, user_id, state=None, session_id=None
|
||||
):
|
||||
"""Create a new session."""
|
||||
if session_id is None:
|
||||
session_id = f"session_{int(time.time())}"
|
||||
|
||||
# Initialize app_name and user_id if they don't exist
|
||||
if app_name not in session_data:
|
||||
session_data[app_name] = {}
|
||||
if user_id not in session_data[app_name]:
|
||||
session_data[app_name][user_id] = {}
|
||||
|
||||
# Create the session
|
||||
session = {
|
||||
"id": session_id,
|
||||
"app_name": app_name,
|
||||
"user_id": user_id,
|
||||
"events": [],
|
||||
"state": state or {},
|
||||
}
|
||||
|
||||
session_data[app_name][user_id][session_id] = session
|
||||
return session
|
||||
|
||||
async def list_sessions(self, app_name, user_id):
|
||||
"""List all sessions for a user."""
|
||||
if app_name not in session_data or user_id not in session_data[app_name]:
|
||||
return {"sessions": []}
|
||||
|
||||
return ListSessionsResponse(
|
||||
sessions=list(session_data[app_name][user_id].values())
|
||||
)
|
||||
|
||||
async def delete_session(self, app_name, user_id, session_id):
|
||||
"""Delete a session."""
|
||||
if (
|
||||
app_name in session_data
|
||||
and user_id in session_data[app_name]
|
||||
and session_id in session_data[app_name][user_id]
|
||||
):
|
||||
del session_data[app_name][user_id][session_id]
|
||||
|
||||
async def append_event(self, session, event):
|
||||
"""Append an event to a session."""
|
||||
# Update session state if event has state_delta
|
||||
if event.actions and event.actions.state_delta:
|
||||
session["state"].update(event.actions.state_delta)
|
||||
|
||||
# Add event to session events
|
||||
session["events"].append(event.model_dump())
|
||||
|
||||
# Update the session in storage
|
||||
session_data[session["app_name"]][session["user_id"]][
|
||||
session["id"]
|
||||
] = session
|
||||
|
||||
return event
|
||||
|
||||
# Return an instance of our mock service
|
||||
return MockSessionService()
|
||||
"""Create an in-memory session service instance for testing."""
|
||||
return InMemorySessionService()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -465,7 +376,7 @@ async def create_test_session(
|
||||
state={},
|
||||
)
|
||||
|
||||
logger.info(f"Created test session: {session['id']}")
|
||||
logger.info(f"Created test session: {session.id}")
|
||||
return test_session_info
|
||||
|
||||
|
||||
@@ -654,6 +565,22 @@ def test_create_session_with_id(test_app, test_session_info):
|
||||
logger.info(f"Created session with ID: {data['id']}")
|
||||
|
||||
|
||||
def test_create_session_with_id_already_exists(test_app, test_session_info):
|
||||
"""Test creating a session with an ID that already exists."""
|
||||
session_id = "existing_session_id"
|
||||
url = f"/apps/{test_session_info['app_name']}/users/{test_session_info['user_id']}/sessions/{session_id}"
|
||||
|
||||
# Create the session for the first time
|
||||
response = test_app.post(url, json={"state": {}})
|
||||
assert response.status_code == 200
|
||||
|
||||
# Attempt to create it again
|
||||
response = test_app.post(url, json={"state": {}})
|
||||
assert response.status_code == 409
|
||||
assert "Session already exists" in response.json()["detail"]
|
||||
logger.info("Verified 409 on duplicate session creation.")
|
||||
|
||||
|
||||
def test_create_session_without_id(test_app, test_session_info):
|
||||
"""Test creating a session with a generated ID."""
|
||||
url = f"/apps/{test_session_info['app_name']}/users/{test_session_info['user_id']}/sessions"
|
||||
@@ -753,9 +680,7 @@ def test_update_session(test_app, create_test_session):
|
||||
state_patch_events = [
|
||||
event
|
||||
for event in events
|
||||
if (
|
||||
event.get("invocationId") or event.get("invocation_id", "")
|
||||
).startswith("p-")
|
||||
if event.get("invocationId", "").startswith("p-")
|
||||
]
|
||||
|
||||
assert len(state_patch_events) == 1, (
|
||||
@@ -766,9 +691,9 @@ def test_update_session(test_app, create_test_session):
|
||||
assert state_patch_event["author"] == "user"
|
||||
|
||||
# Check for actions in both camelCase and snake_case
|
||||
actions = state_patch_event.get("actions") or state_patch_event.get("actions")
|
||||
actions = state_patch_event.get("actions")
|
||||
assert actions is not None, f"No actions found in event: {state_patch_event}"
|
||||
state_delta_in_event = actions.get("state_delta") or actions.get("stateDelta")
|
||||
state_delta_in_event = actions.get("stateDelta")
|
||||
assert state_delta_in_event == state_delta
|
||||
|
||||
logger.info("Session state patched successfully")
|
||||
@@ -818,7 +743,7 @@ def test_agent_run(test_app, create_test_session):
|
||||
)
|
||||
|
||||
# Third event should have interrupted flag
|
||||
assert data[2]["interrupted"] == True
|
||||
assert data[2]["interrupted"] is True
|
||||
|
||||
logger.info("Agent run test completed successfully")
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
import enum
|
||||
|
||||
from google.adk.errors.already_exists_error import AlreadyExistsError
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.events.event_actions import EventActions
|
||||
from google.adk.sessions.base_session_service import GetSessionConfig
|
||||
@@ -336,6 +337,32 @@ async def test_get_session_respects_user_id(service_type):
|
||||
assert len(session2_got.events) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
)
|
||||
async def test_create_session_with_existing_id_raises_error(service_type):
|
||||
session_service = get_session_service(service_type)
|
||||
app_name = 'my_app'
|
||||
user_id = 'test_user'
|
||||
session_id = 'existing_session'
|
||||
|
||||
# Create the first session
|
||||
await session_service.create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Attempt to create a session with the same ID
|
||||
with pytest.raises(AlreadyExistsError):
|
||||
await session_service.create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
'service_type', [SessionServiceType.IN_MEMORY, SessionServiceType.DATABASE]
|
||||
|
||||
Reference in New Issue
Block a user