fix: remove duplicate session GET when using API server, unbreak auto_session_create when using API server

Co-authored-by: Sasha Sobran <asobran@google.com>
PiperOrigin-RevId: 874188082
This commit is contained in:
Sasha Sobran
2026-02-23 12:00:59 -08:00
committed by Copybara-Service
parent 2dbd1f25bd
commit 445dc189e9
5 changed files with 203 additions and 76 deletions
+55 -39
View File
@@ -68,6 +68,7 @@ from ..auth.credential_service.base_credential_service import BaseCredentialServ
from ..errors.already_exists_error import AlreadyExistsError
from ..errors.input_validation_error import InputValidationError
from ..errors.not_found_error import NotFoundError
from ..errors.session_not_found_error import SessionNotFoundError
from ..evaluation.base_eval_service import InferenceConfig
from ..evaluation.base_eval_service import InferenceRequest
from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE
@@ -1558,53 +1559,68 @@ class AdkWebServer:
@app.post("/run", response_model_exclude_none=True)
async def run_agent(req: RunAgentRequest) -> list[Event]:
session = await self.session_service.get_session(
app_name=req.app_name, user_id=req.user_id, session_id=req.session_id
)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
runner = await self.get_runner_async(req.app_name)
async with Aclosing(
runner.run_async(
user_id=req.user_id,
session_id=req.session_id,
new_message=req.new_message,
state_delta=req.state_delta,
invocation_id=req.invocation_id,
)
) as agen:
events = [event async for event in agen]
try:
async with Aclosing(
runner.run_async(
user_id=req.user_id,
session_id=req.session_id,
new_message=req.new_message,
state_delta=req.state_delta,
invocation_id=req.invocation_id,
)
) as agen:
events = [event async for event in agen]
except SessionNotFoundError as e:
raise HTTPException(status_code=404, detail=str(e)) from e
logger.info("Generated %s events in agent run", len(events))
logger.debug("Events generated: %s", events)
return events
@app.post("/run_sse")
async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse:
# SSE endpoint
session = await self.session_service.get_session(
app_name=req.app_name, user_id=req.user_id, session_id=req.session_id
stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
runner = await self.get_runner_async(req.app_name)
agen = runner.run_async(
user_id=req.user_id,
session_id=req.session_id,
new_message=req.new_message,
state_delta=req.state_delta,
run_config=RunConfig(streaming_mode=stream_mode),
invocation_id=req.invocation_id,
)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
# Eagerly advance the generator to trigger session validation
# before the streaming response is created. This lets us return
# a proper HTTP 404 for missing sessions without a redundant
# get_session call — the Runner's single _get_or_create_session
# call is the only one that runs.
first_event = None
first_error = None
try:
first_event = await anext(agen)
except SessionNotFoundError as e:
await agen.aclose()
raise HTTPException(status_code=404, detail=str(e)) from e
except StopAsyncIteration:
await agen.aclose()
except Exception as e:
first_error = e
# Convert the events to properly formatted SSE
async def event_generator():
try:
stream_mode = (
StreamingMode.SSE if req.streaming else StreamingMode.NONE
)
runner = await self.get_runner_async(req.app_name)
async with Aclosing(
runner.run_async(
user_id=req.user_id,
session_id=req.session_id,
new_message=req.new_message,
state_delta=req.state_delta,
run_config=RunConfig(streaming_mode=stream_mode),
invocation_id=req.invocation_id,
)
) as agen:
async for event in agen:
async with Aclosing(agen):
try:
if first_error:
raise first_error
async def all_events():
if first_event is not None:
yield first_event
async for event in agen:
yield event
async for event in all_events():
# ADK Web renders artifacts from `actions.artifactDelta`
# during part processing *and* during action processing
# 1) the original event with `artifactDelta` cleared (content)
@@ -1630,9 +1646,9 @@ class AdkWebServer:
"Generated event in agent run streaming: %s", sse_event
)
yield f"data: {sse_event}\n\n"
except Exception as e:
logger.exception("Error in event_generator: %s", e)
yield f"data: {json.dumps({'error': str(e)})}\n\n"
except Exception as e:
logger.exception("Error in event_generator: %s", e)
yield f"data: {json.dumps({'error': str(e)})}\n\n"
# Returns a streaming response with the proper media type for SSE
return StreamingResponse(
@@ -0,0 +1,28 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from .not_found_error import NotFoundError
class SessionNotFoundError(ValueError, NotFoundError):
"""Raised when a session cannot be found.
Inherits from both ValueError (for backward compatibility) and NotFoundError
(for semantic consistency with the project's error hierarchy).
"""
def __init__(self, message="Session not found."):
super().__init__(message)
+5 -3
View File
@@ -45,6 +45,7 @@ from .artifacts.base_artifact_service import BaseArtifactService
from .artifacts.in_memory_artifact_service import InMemoryArtifactService
from .auth.credential_service.base_credential_service import BaseCredentialService
from .code_executors.built_in_code_executor import BuiltInCodeExecutor
from .errors.session_not_found_error import SessionNotFoundError
from .events.event import Event
from .events.event import EventActions
from .flows.llm_flows import contents
@@ -358,7 +359,7 @@ class Runner:
This helper first attempts to retrieve the session. If not found and
auto_create_session is True, it creates a new session with the provided
identifiers. Otherwise, it raises a ValueError with a helpful message.
identifiers. Otherwise, it raises a SessionNotFoundError.
Args:
user_id: The user ID of the session.
@@ -368,7 +369,8 @@ class Runner:
The existing or newly created `Session`.
Raises:
ValueError: If the session is not found and auto_create_session is False.
SessionNotFoundError: If the session is not found and
auto_create_session is False.
"""
session = await self.session_service.get_session(
app_name=self.app_name, user_id=user_id, session_id=session_id
@@ -380,7 +382,7 @@ class Runner:
)
else:
message = self._format_session_not_found_message(session_id)
raise ValueError(message)
raise SessionNotFoundError(message)
return session
def run(
+113 -33
View File
@@ -32,6 +32,7 @@ from google.adk.artifacts.base_artifact_service import ArtifactVersion
from google.adk.cli import fast_api as fast_api_module
from google.adk.cli.fast_api import get_fast_api_app
from google.adk.errors.input_validation_error import InputValidationError
from google.adk.errors.session_not_found_error import SessionNotFoundError
from google.adk.evaluation.eval_case import EvalCase
from google.adk.evaluation.eval_case import Invocation
from google.adk.evaluation.eval_result import EvalSetResult
@@ -451,18 +452,28 @@ def mock_eval_set_results_manager():
return MockEvalSetResultsManager()
@pytest.fixture
def test_app(
def _create_test_client(
mock_session_service,
mock_artifact_service,
mock_memory_service,
mock_agent_loader,
mock_eval_sets_manager,
mock_eval_set_results_manager,
**app_kwargs,
):
"""Create a TestClient for the FastAPI app without starting a server."""
# Patch multiple services and signal handlers
"""Helper to create a TestClient with the given get_fast_api_app overrides."""
defaults = dict(
agents_dir=".",
web=True,
session_service_uri="",
artifact_service_uri="",
memory_service_uri="",
allow_origins=["*"],
a2a=False,
host="127.0.0.1",
port=8000,
)
defaults.update(app_kwargs)
with (
patch.object(signal, "signal", autospec=True, return_value=None),
patch.object(
@@ -502,23 +513,28 @@ def test_app(
return_value=mock_eval_set_results_manager,
),
):
# Get the FastAPI app, but don't actually run it
app = get_fast_api_app(
agents_dir=".",
web=True,
session_service_uri="",
artifact_service_uri="",
memory_service_uri="",
allow_origins=["*"],
a2a=False, # Disable A2A for most tests
host="127.0.0.1",
port=8000,
)
app = get_fast_api_app(**defaults)
return TestClient(app)
# Create a TestClient that doesn't start a real server
client = TestClient(app)
return client
@pytest.fixture
def test_app(
mock_session_service,
mock_artifact_service,
mock_memory_service,
mock_agent_loader,
mock_eval_sets_manager,
mock_eval_set_results_manager,
):
"""Create a TestClient for the FastAPI app without starting a server."""
return _create_test_client(
mock_session_service,
mock_artifact_service,
mock_memory_service,
mock_agent_loader,
mock_eval_sets_manager,
mock_eval_set_results_manager,
)
@pytest.fixture
@@ -1106,20 +1122,9 @@ def test_agent_run_sse_yields_error_object_on_exception(
"""Test /run_sse streams an error object if streaming raises."""
info = create_test_session
async def run_async_raises(
self,
*,
user_id: str,
session_id: str,
invocation_id: Optional[str] = None,
new_message: Optional[types.Content] = None,
state_delta: Optional[dict[str, Any]] = None,
run_config: Optional[RunConfig] = None,
):
del user_id, session_id, invocation_id, new_message, state_delta, run_config
async def run_async_raises(self, **kwargs):
raise ValueError("boom")
if False: # pylint: disable=using-constant-test
yield _event_1()
yield # make it an async generator # pylint: disable=unreachable
monkeypatch.setattr(Runner, "run_async", run_async_raises)
@@ -1637,5 +1642,80 @@ def test_version_endpoint(test_app):
assert "language_version" in data
@pytest.fixture
def test_app_auto_session(
mock_session_service,
mock_artifact_service,
mock_memory_service,
mock_agent_loader,
mock_eval_sets_manager,
mock_eval_set_results_manager,
):
"""Create a TestClient with auto_create_session=True."""
return _create_test_client(
mock_session_service,
mock_artifact_service,
mock_memory_service,
mock_agent_loader,
mock_eval_sets_manager,
mock_eval_set_results_manager,
web=False,
auto_create_session=True,
)
@pytest.mark.parametrize("endpoint", ["/run", "/run_sse"])
def test_auto_creates_session(
test_app_auto_session, test_session_info, endpoint
):
"""Test /run and /run_sse auto-create sessions when auto_create_session=True."""
payload = {
"app_name": test_session_info["app_name"],
"user_id": test_session_info["user_id"],
"session_id": "nonexistent_session",
"new_message": {"role": "user", "parts": [{"text": "Hello"}]},
}
response = test_app_auto_session.post(endpoint, json=payload)
assert response.status_code == 200
if endpoint == "/run":
data = response.json()
assert isinstance(data, list)
assert len(data) > 0
else:
sse_events = [
json.loads(line.removeprefix("data: "))
for line in response.text.splitlines()
if line.startswith("data: ")
]
assert len(sse_events) > 0
assert not any("error" in e for e in sse_events)
@pytest.mark.parametrize("endpoint", ["/run", "/run_sse"])
def test_returns_404_without_auto_create(
test_app, test_session_info, monkeypatch, endpoint
):
"""Test /run and /run_sse return 404 for missing sessions without auto_create."""
async def run_async_session_not_found(self, **kwargs):
raise SessionNotFoundError(f"Session not found: {kwargs['session_id']}")
yield # make it an async generator # pylint: disable=unreachable
monkeypatch.setattr(Runner, "run_async", run_async_session_not_found)
payload = {
"app_name": test_session_info["app_name"],
"user_id": test_session_info["user_id"],
"session_id": "nonexistent_session",
"new_message": {"role": "user", "parts": [{"text": "Hello"}]},
}
response = test_app.post(endpoint, json=payload)
assert response.status_code == 404
assert "Session not found" in response.json()["detail"]
if __name__ == "__main__":
pytest.main(["-xvs", __file__])
+2 -1
View File
@@ -29,6 +29,7 @@ from google.adk.apps.app import App
from google.adk.apps.app import ResumabilityConfig
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.cli.utils.agent_loader import AgentLoader
from google.adk.errors.session_not_found_error import SessionNotFoundError
from google.adk.events.event import Event
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.runners import Runner
@@ -243,7 +244,7 @@ async def test_session_not_found_message_includes_alignment_hint():
new_message=types.Content(role="user", parts=[]),
)
with pytest.raises(ValueError) as excinfo:
with pytest.raises(SessionNotFoundError) as excinfo:
await agen.__anext__()
await agen.aclose()