You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: remove duplicate session GET when using API server, unbreak auto_session_create when using API server
Co-authored-by: Sasha Sobran <asobran@google.com> PiperOrigin-RevId: 874188082
This commit is contained in:
committed by
Copybara-Service
parent
2dbd1f25bd
commit
445dc189e9
@@ -68,6 +68,7 @@ from ..auth.credential_service.base_credential_service import BaseCredentialServ
|
||||
from ..errors.already_exists_error import AlreadyExistsError
|
||||
from ..errors.input_validation_error import InputValidationError
|
||||
from ..errors.not_found_error import NotFoundError
|
||||
from ..errors.session_not_found_error import SessionNotFoundError
|
||||
from ..evaluation.base_eval_service import InferenceConfig
|
||||
from ..evaluation.base_eval_service import InferenceRequest
|
||||
from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE
|
||||
@@ -1558,53 +1559,68 @@ class AdkWebServer:
|
||||
|
||||
@app.post("/run", response_model_exclude_none=True)
|
||||
async def run_agent(req: RunAgentRequest) -> list[Event]:
|
||||
session = await self.session_service.get_session(
|
||||
app_name=req.app_name, user_id=req.user_id, session_id=req.session_id
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
runner = await self.get_runner_async(req.app_name)
|
||||
async with Aclosing(
|
||||
runner.run_async(
|
||||
user_id=req.user_id,
|
||||
session_id=req.session_id,
|
||||
new_message=req.new_message,
|
||||
state_delta=req.state_delta,
|
||||
invocation_id=req.invocation_id,
|
||||
)
|
||||
) as agen:
|
||||
events = [event async for event in agen]
|
||||
try:
|
||||
async with Aclosing(
|
||||
runner.run_async(
|
||||
user_id=req.user_id,
|
||||
session_id=req.session_id,
|
||||
new_message=req.new_message,
|
||||
state_delta=req.state_delta,
|
||||
invocation_id=req.invocation_id,
|
||||
)
|
||||
) as agen:
|
||||
events = [event async for event in agen]
|
||||
except SessionNotFoundError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e)) from e
|
||||
logger.info("Generated %s events in agent run", len(events))
|
||||
logger.debug("Events generated: %s", events)
|
||||
return events
|
||||
|
||||
@app.post("/run_sse")
|
||||
async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse:
|
||||
# SSE endpoint
|
||||
session = await self.session_service.get_session(
|
||||
app_name=req.app_name, user_id=req.user_id, session_id=req.session_id
|
||||
stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
|
||||
runner = await self.get_runner_async(req.app_name)
|
||||
agen = runner.run_async(
|
||||
user_id=req.user_id,
|
||||
session_id=req.session_id,
|
||||
new_message=req.new_message,
|
||||
state_delta=req.state_delta,
|
||||
run_config=RunConfig(streaming_mode=stream_mode),
|
||||
invocation_id=req.invocation_id,
|
||||
)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# Eagerly advance the generator to trigger session validation
|
||||
# before the streaming response is created. This lets us return
|
||||
# a proper HTTP 404 for missing sessions without a redundant
|
||||
# get_session call — the Runner's single _get_or_create_session
|
||||
# call is the only one that runs.
|
||||
first_event = None
|
||||
first_error = None
|
||||
try:
|
||||
first_event = await anext(agen)
|
||||
except SessionNotFoundError as e:
|
||||
await agen.aclose()
|
||||
raise HTTPException(status_code=404, detail=str(e)) from e
|
||||
except StopAsyncIteration:
|
||||
await agen.aclose()
|
||||
except Exception as e:
|
||||
first_error = e
|
||||
|
||||
# Convert the events to properly formatted SSE
|
||||
async def event_generator():
|
||||
try:
|
||||
stream_mode = (
|
||||
StreamingMode.SSE if req.streaming else StreamingMode.NONE
|
||||
)
|
||||
runner = await self.get_runner_async(req.app_name)
|
||||
async with Aclosing(
|
||||
runner.run_async(
|
||||
user_id=req.user_id,
|
||||
session_id=req.session_id,
|
||||
new_message=req.new_message,
|
||||
state_delta=req.state_delta,
|
||||
run_config=RunConfig(streaming_mode=stream_mode),
|
||||
invocation_id=req.invocation_id,
|
||||
)
|
||||
) as agen:
|
||||
async for event in agen:
|
||||
async with Aclosing(agen):
|
||||
try:
|
||||
if first_error:
|
||||
raise first_error
|
||||
|
||||
async def all_events():
|
||||
if first_event is not None:
|
||||
yield first_event
|
||||
async for event in agen:
|
||||
yield event
|
||||
|
||||
async for event in all_events():
|
||||
# ADK Web renders artifacts from `actions.artifactDelta`
|
||||
# during part processing *and* during action processing
|
||||
# 1) the original event with `artifactDelta` cleared (content)
|
||||
@@ -1630,9 +1646,9 @@ class AdkWebServer:
|
||||
"Generated event in agent run streaming: %s", sse_event
|
||||
)
|
||||
yield f"data: {sse_event}\n\n"
|
||||
except Exception as e:
|
||||
logger.exception("Error in event_generator: %s", e)
|
||||
yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
||||
except Exception as e:
|
||||
logger.exception("Error in event_generator: %s", e)
|
||||
yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
||||
|
||||
# Returns a streaming response with the proper media type for SSE
|
||||
return StreamingResponse(
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .not_found_error import NotFoundError
|
||||
|
||||
|
||||
class SessionNotFoundError(ValueError, NotFoundError):
|
||||
"""Raised when a session cannot be found.
|
||||
|
||||
Inherits from both ValueError (for backward compatibility) and NotFoundError
|
||||
(for semantic consistency with the project's error hierarchy).
|
||||
"""
|
||||
|
||||
def __init__(self, message="Session not found."):
|
||||
super().__init__(message)
|
||||
@@ -45,6 +45,7 @@ from .artifacts.base_artifact_service import BaseArtifactService
|
||||
from .artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
from .auth.credential_service.base_credential_service import BaseCredentialService
|
||||
from .code_executors.built_in_code_executor import BuiltInCodeExecutor
|
||||
from .errors.session_not_found_error import SessionNotFoundError
|
||||
from .events.event import Event
|
||||
from .events.event import EventActions
|
||||
from .flows.llm_flows import contents
|
||||
@@ -358,7 +359,7 @@ class Runner:
|
||||
|
||||
This helper first attempts to retrieve the session. If not found and
|
||||
auto_create_session is True, it creates a new session with the provided
|
||||
identifiers. Otherwise, it raises a ValueError with a helpful message.
|
||||
identifiers. Otherwise, it raises a SessionNotFoundError.
|
||||
|
||||
Args:
|
||||
user_id: The user ID of the session.
|
||||
@@ -368,7 +369,8 @@ class Runner:
|
||||
The existing or newly created `Session`.
|
||||
|
||||
Raises:
|
||||
ValueError: If the session is not found and auto_create_session is False.
|
||||
SessionNotFoundError: If the session is not found and
|
||||
auto_create_session is False.
|
||||
"""
|
||||
session = await self.session_service.get_session(
|
||||
app_name=self.app_name, user_id=user_id, session_id=session_id
|
||||
@@ -380,7 +382,7 @@ class Runner:
|
||||
)
|
||||
else:
|
||||
message = self._format_session_not_found_message(session_id)
|
||||
raise ValueError(message)
|
||||
raise SessionNotFoundError(message)
|
||||
return session
|
||||
|
||||
def run(
|
||||
|
||||
@@ -32,6 +32,7 @@ from google.adk.artifacts.base_artifact_service import ArtifactVersion
|
||||
from google.adk.cli import fast_api as fast_api_module
|
||||
from google.adk.cli.fast_api import get_fast_api_app
|
||||
from google.adk.errors.input_validation_error import InputValidationError
|
||||
from google.adk.errors.session_not_found_error import SessionNotFoundError
|
||||
from google.adk.evaluation.eval_case import EvalCase
|
||||
from google.adk.evaluation.eval_case import Invocation
|
||||
from google.adk.evaluation.eval_result import EvalSetResult
|
||||
@@ -451,18 +452,28 @@ def mock_eval_set_results_manager():
|
||||
return MockEvalSetResultsManager()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_app(
|
||||
def _create_test_client(
|
||||
mock_session_service,
|
||||
mock_artifact_service,
|
||||
mock_memory_service,
|
||||
mock_agent_loader,
|
||||
mock_eval_sets_manager,
|
||||
mock_eval_set_results_manager,
|
||||
**app_kwargs,
|
||||
):
|
||||
"""Create a TestClient for the FastAPI app without starting a server."""
|
||||
|
||||
# Patch multiple services and signal handlers
|
||||
"""Helper to create a TestClient with the given get_fast_api_app overrides."""
|
||||
defaults = dict(
|
||||
agents_dir=".",
|
||||
web=True,
|
||||
session_service_uri="",
|
||||
artifact_service_uri="",
|
||||
memory_service_uri="",
|
||||
allow_origins=["*"],
|
||||
a2a=False,
|
||||
host="127.0.0.1",
|
||||
port=8000,
|
||||
)
|
||||
defaults.update(app_kwargs)
|
||||
with (
|
||||
patch.object(signal, "signal", autospec=True, return_value=None),
|
||||
patch.object(
|
||||
@@ -502,23 +513,28 @@ def test_app(
|
||||
return_value=mock_eval_set_results_manager,
|
||||
),
|
||||
):
|
||||
# Get the FastAPI app, but don't actually run it
|
||||
app = get_fast_api_app(
|
||||
agents_dir=".",
|
||||
web=True,
|
||||
session_service_uri="",
|
||||
artifact_service_uri="",
|
||||
memory_service_uri="",
|
||||
allow_origins=["*"],
|
||||
a2a=False, # Disable A2A for most tests
|
||||
host="127.0.0.1",
|
||||
port=8000,
|
||||
)
|
||||
app = get_fast_api_app(**defaults)
|
||||
return TestClient(app)
|
||||
|
||||
# Create a TestClient that doesn't start a real server
|
||||
client = TestClient(app)
|
||||
|
||||
return client
|
||||
@pytest.fixture
|
||||
def test_app(
|
||||
mock_session_service,
|
||||
mock_artifact_service,
|
||||
mock_memory_service,
|
||||
mock_agent_loader,
|
||||
mock_eval_sets_manager,
|
||||
mock_eval_set_results_manager,
|
||||
):
|
||||
"""Create a TestClient for the FastAPI app without starting a server."""
|
||||
return _create_test_client(
|
||||
mock_session_service,
|
||||
mock_artifact_service,
|
||||
mock_memory_service,
|
||||
mock_agent_loader,
|
||||
mock_eval_sets_manager,
|
||||
mock_eval_set_results_manager,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -1106,20 +1122,9 @@ def test_agent_run_sse_yields_error_object_on_exception(
|
||||
"""Test /run_sse streams an error object if streaming raises."""
|
||||
info = create_test_session
|
||||
|
||||
async def run_async_raises(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
invocation_id: Optional[str] = None,
|
||||
new_message: Optional[types.Content] = None,
|
||||
state_delta: Optional[dict[str, Any]] = None,
|
||||
run_config: Optional[RunConfig] = None,
|
||||
):
|
||||
del user_id, session_id, invocation_id, new_message, state_delta, run_config
|
||||
async def run_async_raises(self, **kwargs):
|
||||
raise ValueError("boom")
|
||||
if False: # pylint: disable=using-constant-test
|
||||
yield _event_1()
|
||||
yield # make it an async generator # pylint: disable=unreachable
|
||||
|
||||
monkeypatch.setattr(Runner, "run_async", run_async_raises)
|
||||
|
||||
@@ -1637,5 +1642,80 @@ def test_version_endpoint(test_app):
|
||||
assert "language_version" in data
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_app_auto_session(
|
||||
mock_session_service,
|
||||
mock_artifact_service,
|
||||
mock_memory_service,
|
||||
mock_agent_loader,
|
||||
mock_eval_sets_manager,
|
||||
mock_eval_set_results_manager,
|
||||
):
|
||||
"""Create a TestClient with auto_create_session=True."""
|
||||
return _create_test_client(
|
||||
mock_session_service,
|
||||
mock_artifact_service,
|
||||
mock_memory_service,
|
||||
mock_agent_loader,
|
||||
mock_eval_sets_manager,
|
||||
mock_eval_set_results_manager,
|
||||
web=False,
|
||||
auto_create_session=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("endpoint", ["/run", "/run_sse"])
|
||||
def test_auto_creates_session(
|
||||
test_app_auto_session, test_session_info, endpoint
|
||||
):
|
||||
"""Test /run and /run_sse auto-create sessions when auto_create_session=True."""
|
||||
payload = {
|
||||
"app_name": test_session_info["app_name"],
|
||||
"user_id": test_session_info["user_id"],
|
||||
"session_id": "nonexistent_session",
|
||||
"new_message": {"role": "user", "parts": [{"text": "Hello"}]},
|
||||
}
|
||||
|
||||
response = test_app_auto_session.post(endpoint, json=payload)
|
||||
assert response.status_code == 200
|
||||
|
||||
if endpoint == "/run":
|
||||
data = response.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) > 0
|
||||
else:
|
||||
sse_events = [
|
||||
json.loads(line.removeprefix("data: "))
|
||||
for line in response.text.splitlines()
|
||||
if line.startswith("data: ")
|
||||
]
|
||||
assert len(sse_events) > 0
|
||||
assert not any("error" in e for e in sse_events)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("endpoint", ["/run", "/run_sse"])
|
||||
def test_returns_404_without_auto_create(
|
||||
test_app, test_session_info, monkeypatch, endpoint
|
||||
):
|
||||
"""Test /run and /run_sse return 404 for missing sessions without auto_create."""
|
||||
|
||||
async def run_async_session_not_found(self, **kwargs):
|
||||
raise SessionNotFoundError(f"Session not found: {kwargs['session_id']}")
|
||||
yield # make it an async generator # pylint: disable=unreachable
|
||||
|
||||
monkeypatch.setattr(Runner, "run_async", run_async_session_not_found)
|
||||
|
||||
payload = {
|
||||
"app_name": test_session_info["app_name"],
|
||||
"user_id": test_session_info["user_id"],
|
||||
"session_id": "nonexistent_session",
|
||||
"new_message": {"role": "user", "parts": [{"text": "Hello"}]},
|
||||
}
|
||||
|
||||
response = test_app.post(endpoint, json=payload)
|
||||
assert response.status_code == 404
|
||||
assert "Session not found" in response.json()["detail"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main(["-xvs", __file__])
|
||||
|
||||
@@ -29,6 +29,7 @@ from google.adk.apps.app import App
|
||||
from google.adk.apps.app import ResumabilityConfig
|
||||
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
from google.adk.cli.utils.agent_loader import AgentLoader
|
||||
from google.adk.errors.session_not_found_error import SessionNotFoundError
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.plugins.base_plugin import BasePlugin
|
||||
from google.adk.runners import Runner
|
||||
@@ -243,7 +244,7 @@ async def test_session_not_found_message_includes_alignment_hint():
|
||||
new_message=types.Content(role="user", parts=[]),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
with pytest.raises(SessionNotFoundError) as excinfo:
|
||||
await agen.__anext__()
|
||||
|
||||
await agen.aclose()
|
||||
|
||||
Reference in New Issue
Block a user