fix: Add a FastAPI endpoint for saving artifacts

This change adds new `POST` endpoint `/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts` to the ADK web server. This endpoint lets clients to save new artifacts associated with a specific session. The endpoint uses `SaveArtifactRequest` and returns `SaveArtifactResponse`, including the version and canonical URI of the saved artifact.

Close #1975

Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 838977880
This commit is contained in:
George Weale
2025-12-01 16:28:45 -08:00
committed by Copybara-Service
parent ed9da3fa45
commit 7e8eeca6aa
8 changed files with 363 additions and 57 deletions
@@ -32,6 +32,7 @@ from pydantic import Field
from pydantic import ValidationError
from typing_extensions import override
from ..errors.input_validation_error import InputValidationError
from .base_artifact_service import ArtifactVersion
from .base_artifact_service import BaseArtifactService
@@ -100,14 +101,14 @@ def _resolve_scoped_artifact_path(
to `scope_root`.
Raises:
ValueError: If `filename` resolves outside of `scope_root`.
InputValidationError: If `filename` resolves outside of `scope_root`.
"""
stripped = _strip_user_namespace(filename).strip()
pure_path = _to_posix_path(stripped)
scope_root_resolved = scope_root.resolve(strict=False)
if pure_path.is_absolute():
raise ValueError(
raise InputValidationError(
f"Absolute artifact filename {filename!r} is not permitted; "
"provide a path relative to the storage scope."
)
@@ -118,7 +119,7 @@ def _resolve_scoped_artifact_path(
try:
relative = candidate.relative_to(scope_root_resolved)
except ValueError as exc:
raise ValueError(
raise InputValidationError(
f"Artifact filename {filename!r} escapes storage directory "
f"{scope_root_resolved}"
) from exc
@@ -230,7 +231,7 @@ class FileArtifactService(BaseArtifactService):
if _is_user_scoped(session_id, filename):
return _user_artifacts_dir(base)
if not session_id:
raise ValueError(
raise InputValidationError(
"Session ID must be provided for session-scoped artifacts."
)
return _session_artifacts_dir(base, session_id)
@@ -371,7 +372,9 @@ class FileArtifactService(BaseArtifactService):
content_path.write_text(artifact.text, encoding="utf-8")
mime_type = None
else:
raise ValueError("Artifact must have either inline_data or text content.")
raise InputValidationError(
"Artifact must have either inline_data or text content."
)
canonical_uri = self._canonical_uri(
user_id=user_id,
@@ -30,6 +30,7 @@ from typing import Optional
from google.genai import types
from typing_extensions import override
from ..errors.input_validation_error import InputValidationError
from .base_artifact_service import ArtifactVersion
from .base_artifact_service import BaseArtifactService
@@ -161,7 +162,7 @@ class GcsArtifactService(BaseArtifactService):
return f"{app_name}/{user_id}/user/{filename}"
if session_id is None:
raise ValueError(
raise InputValidationError(
"Session ID must be provided for session-scoped artifacts."
)
return f"{app_name}/{user_id}/{session_id}/{filename}"
@@ -230,7 +231,9 @@ class GcsArtifactService(BaseArtifactService):
" GcsArtifactService."
)
else:
raise ValueError("Artifact must have either inline_data or text.")
raise InputValidationError(
"Artifact must have either inline_data or text."
)
return version
@@ -18,12 +18,13 @@ import logging
from typing import Any
from typing import Optional
from google.adk.artifacts import artifact_util
from google.genai import types
from pydantic import BaseModel
from pydantic import Field
from typing_extensions import override
from . import artifact_util
from ..errors.input_validation_error import InputValidationError
from .base_artifact_service import ArtifactVersion
from .base_artifact_service import BaseArtifactService
@@ -86,7 +87,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
return f"{app_name}/{user_id}/user/{filename}"
if session_id is None:
raise ValueError(
raise InputValidationError(
"Session ID must be provided for session-scoped artifacts."
)
return f"{app_name}/{user_id}/{session_id}/{filename}"
@@ -125,7 +126,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
elif artifact.file_data is not None:
if artifact_util.is_artifact_ref(artifact):
if not artifact_util.parse_artifact_uri(artifact.file_data.file_uri):
raise ValueError(
raise InputValidationError(
f"Invalid artifact reference URI: {artifact.file_data.file_uri}"
)
# If it's a valid artifact URI, we store the artifact part as-is.
@@ -133,7 +134,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
else:
artifact_version.mime_type = artifact.file_data.mime_type
else:
raise ValueError("Not supported artifact type.")
raise InputValidationError("Not supported artifact type.")
self.artifacts[path].append(
_ArtifactEntry(data=artifact, artifact_version=artifact_version)
@@ -172,7 +173,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
artifact_data.file_data.file_uri
)
if not parsed_uri:
raise ValueError(
raise InputValidationError(
"Invalid artifact reference URI:"
f" {artifact_data.file_data.file_uri}"
)
+62
View File
@@ -61,9 +61,11 @@ from ..agents.live_request_queue import LiveRequestQueue
from ..agents.run_config import RunConfig
from ..agents.run_config import StreamingMode
from ..apps.app import App
from ..artifacts.base_artifact_service import ArtifactVersion
from ..artifacts.base_artifact_service import BaseArtifactService
from ..auth.credential_service.base_credential_service import BaseCredentialService
from ..errors.already_exists_error import AlreadyExistsError
from ..errors.input_validation_error import InputValidationError
from ..errors.not_found_error import NotFoundError
from ..evaluation.base_eval_service import InferenceConfig
from ..evaluation.base_eval_service import InferenceRequest
@@ -194,6 +196,19 @@ class CreateSessionRequest(common.BaseModel):
)
class SaveArtifactRequest(common.BaseModel):
"""Request payload for saving a new artifact."""
filename: str = Field(description="Artifact filename.")
artifact: types.Part = Field(
description="Artifact payload encoded as google.genai.types.Part."
)
custom_metadata: Optional[dict[str, Any]] = Field(
default=None,
description="Optional metadata to associate with the artifact version.",
)
class AddSessionToEvalSetRequest(common.BaseModel):
eval_id: str
session_id: str
@@ -1316,6 +1331,53 @@ class AdkWebServer:
raise HTTPException(status_code=404, detail="Artifact not found")
return artifact
@app.post(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts",
response_model=ArtifactVersion,
response_model_exclude_none=True,
)
async def save_artifact(
app_name: str,
user_id: str,
session_id: str,
req: SaveArtifactRequest,
) -> ArtifactVersion:
try:
version = await self.artifact_service.save_artifact(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=req.filename,
artifact=req.artifact,
custom_metadata=req.custom_metadata,
)
except InputValidationError as ive:
raise HTTPException(status_code=400, detail=str(ive)) from ive
except Exception as exc: # pylint: disable=broad-exception-caught
logger.error(
"Internal error while saving artifact %s for app=%s user=%s"
" session=%s: %s",
req.filename,
app_name,
user_id,
session_id,
exc,
exc_info=True,
)
raise HTTPException(status_code=500, detail=str(exc)) from exc
artifact_version = await self.artifact_service.get_artifact_version(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=req.filename,
version=version,
)
if artifact_version is None:
raise HTTPException(
status_code=500, detail="Artifact metadata unavailable"
)
return artifact_version
@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts",
response_model_exclude_none=True,
+60 -22
View File
@@ -14,6 +14,7 @@
from __future__ import annotations
import importlib
import json
import logging
import os
@@ -34,22 +35,43 @@ from opentelemetry.sdk.trace import TracerProvider
from starlette.types import Lifespan
from watchdog.observers import Observer
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService
from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
from ..memory.in_memory_memory_service import InMemoryMemoryService
from ..runners import Runner
from ..sessions.in_memory_session_service import InMemorySessionService
from .adk_web_server import AdkWebServer
from .service_registry import get_service_registry
from .service_registry import load_services_module
from .utils import envs
from .utils import evals
from .utils.agent_change_handler import AgentChangeEventHandler
from .utils.agent_loader import AgentLoader
from .utils.service_factory import create_artifact_service_from_options
from .utils.service_factory import create_memory_service_from_options
from .utils.service_factory import create_session_service_from_options
logger = logging.getLogger("google_adk." + __name__)
_LAZY_SERVICE_IMPORTS: dict[str, str] = {
"AgentLoader": ".utils.agent_loader",
"InMemoryArtifactService": "..artifacts.in_memory_artifact_service",
"InMemoryMemoryService": "..memory.in_memory_memory_service",
"InMemorySessionService": "..sessions.in_memory_session_service",
"LocalEvalSetResultsManager": "..evaluation.local_eval_set_results_manager",
"LocalEvalSetsManager": "..evaluation.local_eval_sets_manager",
}
def __getattr__(name: str):
"""Lazily import defaults so patching in tests keeps working."""
if name not in _LAZY_SERVICE_IMPORTS:
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
module = importlib.import_module(_LAZY_SERVICE_IMPORTS[name], __package__)
attr = getattr(module, name)
globals()[name] = attr
return attr
def get_fast_api_app(
*,
@@ -73,8 +95,6 @@ def get_fast_api_app(
logo_text: Optional[str] = None,
logo_image_url: Optional[str] = None,
) -> FastAPI:
# Convert to absolute path for consistency
agents_dir = str(Path(agents_dir).resolve())
# Set up eval managers.
if eval_storage_uri:
@@ -92,30 +112,48 @@ def get_fast_api_app(
# Load services.py from agents_dir for custom service registration.
load_services_module(agents_dir)
service_registry = get_service_registry()
# Build the Memory service
try:
memory_service = create_memory_service_from_options(
base_dir=agents_dir,
memory_service_uri=memory_service_uri,
if memory_service_uri:
memory_service = service_registry.create_memory_service(
memory_service_uri, agents_dir=agents_dir
)
except ValueError as exc:
raise click.ClickException(str(exc)) from exc
if not memory_service:
raise click.ClickException(
"Unsupported memory service URI: %s" % memory_service_uri
)
else:
memory_service = InMemoryMemoryService()
# Build the Session service
session_service = create_session_service_from_options(
base_dir=agents_dir,
session_service_uri=session_service_uri,
session_db_kwargs=session_db_kwargs,
)
if session_service_uri:
session_kwargs = session_db_kwargs or {}
session_service = service_registry.create_session_service(
session_service_uri, agents_dir=agents_dir, **session_kwargs
)
if not session_service:
# Fallback to DatabaseSessionService if the service registry doesn't
# support the session service URI scheme.
from ..sessions.database_session_service import DatabaseSessionService
session_service = DatabaseSessionService(
db_url=session_service_uri, **session_kwargs
)
else:
session_service = InMemorySessionService()
# Build the Artifact service
try:
artifact_service = create_artifact_service_from_options(
base_dir=agents_dir,
artifact_service_uri=artifact_service_uri,
if artifact_service_uri:
artifact_service = service_registry.create_artifact_service(
artifact_service_uri, agents_dir=agents_dir
)
except ValueError as exc:
raise click.ClickException(str(exc)) from exc
if not artifact_service:
raise click.ClickException(
"Unsupported artifact service URI: %s" % artifact_service_uri
)
else:
artifact_service = InMemoryArtifactService()
# Build the Credential service
credential_service = InMemoryCredentialService()
@@ -0,0 +1,28 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
class InputValidationError(ValueError):
"""Represents an error raised when user input fails validation."""
def __init__(self, message="Invalid input."):
"""Initializes the InputValidationError exception.
Args:
message (str): A message describing why the input is invalid.
"""
self.message = message
super().__init__(self.message)
@@ -32,6 +32,7 @@ from google.adk.artifacts.base_artifact_service import ArtifactVersion
from google.adk.artifacts.file_artifact_service import FileArtifactService
from google.adk.artifacts.gcs_artifact_service import GcsArtifactService
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.errors.input_validation_error import InputValidationError
from google.genai import types
import pytest
@@ -732,7 +733,7 @@ async def test_file_save_artifact_rejects_out_of_scope_paths(
"""FileArtifactService prevents path traversal outside of its storage roots."""
artifact_service = FileArtifactService(root_dir=tmp_path / "artifacts")
part = types.Part(text="content")
with pytest.raises(ValueError):
with pytest.raises(InputValidationError):
await artifact_service.save_artifact(
app_name="myapp",
user_id="user123",
@@ -757,7 +758,7 @@ async def test_file_save_artifact_rejects_absolute_path_within_scope(tmp_path):
/ "diagram.png"
)
part = types.Part(text="content")
with pytest.raises(ValueError):
with pytest.raises(InputValidationError):
await artifact_service.save_artifact(
app_name="myapp",
user_id="user123",
+191 -21
View File
@@ -30,7 +30,9 @@ from fastapi.testclient import TestClient
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.run_config import RunConfig
from google.adk.apps.app import App
from google.adk.artifacts.base_artifact_service import ArtifactVersion
from google.adk.cli.fast_api import get_fast_api_app
from google.adk.errors.input_validation_error import InputValidationError
from google.adk.evaluation.eval_case import EvalCase
from google.adk.evaluation.eval_case import Invocation
from google.adk.evaluation.eval_result import EvalSetResult
@@ -211,48 +213,135 @@ def mock_session_service():
def mock_artifact_service():
"""Create a mock artifact service."""
# Storage for artifacts
artifacts = {}
artifacts: dict[str, list[dict[str, Any]]] = {}
def _artifact_key(
app_name: str, user_id: str, session_id: Optional[str], filename: str
) -> str:
if session_id is None:
return f"{app_name}:{user_id}:user:{filename}"
return f"{app_name}:{user_id}:{session_id}:{filename}"
def _canonical_uri(
app_name: str,
user_id: str,
session_id: Optional[str],
filename: str,
version: int,
) -> str:
if session_id is None:
return (
f"artifact://apps/{app_name}/users/{user_id}/artifacts/"
f"{filename}/versions/{version}"
)
return (
f"artifact://apps/{app_name}/users/{user_id}/sessions/{session_id}/"
f"artifacts/{filename}/versions/{version}"
)
class MockArtifactService:
def __init__(self):
self._artifacts = artifacts
self.save_artifact_side_effect: Optional[BaseException] = None
async def save_artifact(
self,
*,
app_name: str,
user_id: str,
filename: str,
artifact: types.Part,
session_id: Optional[str] = None,
custom_metadata: Optional[dict[str, Any]] = None,
) -> int:
if self.save_artifact_side_effect is not None:
effect = self.save_artifact_side_effect
if isinstance(effect, BaseException):
raise effect
raise TypeError(
"save_artifact_side_effect must be an exception instance."
)
key = _artifact_key(app_name, user_id, session_id, filename)
entries = artifacts.setdefault(key, [])
version = len(entries)
artifact_version = ArtifactVersion(
version=version,
canonical_uri=_canonical_uri(
app_name, user_id, session_id, filename, version
),
custom_metadata=custom_metadata or {},
)
if artifact.inline_data is not None:
artifact_version.mime_type = artifact.inline_data.mime_type
elif artifact.text is not None:
artifact_version.mime_type = "text/plain"
elif artifact.file_data is not None:
artifact_version.mime_type = artifact.file_data.mime_type
entries.append({
"version": version,
"artifact": artifact,
"metadata": artifact_version,
})
return version
async def load_artifact(
self, app_name, user_id, session_id, filename, version=None
):
"""Load an artifact by filename."""
key = f"{app_name}:{user_id}:{session_id}:{filename}"
key = _artifact_key(app_name, user_id, session_id, filename)
if key not in artifacts:
return None
if version is not None:
# Get a specific version
for v in artifacts[key]:
if v["version"] == version:
return v["artifact"]
for entry in artifacts[key]:
if entry["version"] == version:
return entry["artifact"]
return None
# Get the latest version
return sorted(artifacts[key], key=lambda x: x["version"])[-1]["artifact"]
return artifacts[key][-1]["artifact"]
async def list_artifact_keys(self, app_name, user_id, session_id):
"""List artifact names for a session."""
prefix = f"{app_name}:{user_id}:{session_id}:"
return [
k.split(":")[-1] for k in artifacts.keys() if k.startswith(prefix)
key.split(":")[-1]
for key in artifacts.keys()
if key.startswith(prefix)
]
async def list_versions(self, app_name, user_id, session_id, filename):
"""List versions of an artifact."""
key = f"{app_name}:{user_id}:{session_id}:{filename}"
key = _artifact_key(app_name, user_id, session_id, filename)
if key not in artifacts:
return []
return [a["version"] for a in artifacts[key]]
return [entry["version"] for entry in artifacts[key]]
async def delete_artifact(self, app_name, user_id, session_id, filename):
"""Delete an artifact."""
key = f"{app_name}:{user_id}:{session_id}:{filename}"
if key in artifacts:
del artifacts[key]
key = _artifact_key(app_name, user_id, session_id, filename)
artifacts.pop(key, None)
async def get_artifact_version(
self,
*,
app_name: str,
user_id: str,
filename: str,
session_id: Optional[str] = None,
version: Optional[int] = None,
) -> Optional[ArtifactVersion]:
key = _artifact_key(app_name, user_id, session_id, filename)
entries = artifacts.get(key)
if not entries:
return None
if version is None:
return entries[-1]["metadata"]
for entry in entries:
if entry["version"] == version:
return entry["metadata"]
return None
return MockArtifactService()
@@ -327,15 +416,15 @@ def test_app(
with (
patch("signal.signal", return_value=None),
patch(
"google.adk.cli.fast_api.create_session_service_from_options",
"google.adk.cli.fast_api.InMemorySessionService",
return_value=mock_session_service,
),
patch(
"google.adk.cli.fast_api.create_artifact_service_from_options",
"google.adk.cli.fast_api.InMemoryArtifactService",
return_value=mock_artifact_service,
),
patch(
"google.adk.cli.fast_api.create_memory_service_from_options",
"google.adk.cli.fast_api.InMemoryMemoryService",
return_value=mock_memory_service,
),
patch(
@@ -472,15 +561,15 @@ def test_app_with_a2a(
with (
patch("signal.signal", return_value=None),
patch(
"google.adk.cli.fast_api.create_session_service_from_options",
"google.adk.cli.fast_api.InMemorySessionService",
return_value=mock_session_service,
),
patch(
"google.adk.cli.fast_api.create_artifact_service_from_options",
"google.adk.cli.fast_api.InMemoryArtifactService",
return_value=mock_artifact_service,
),
patch(
"google.adk.cli.fast_api.create_memory_service_from_options",
"google.adk.cli.fast_api.InMemoryMemoryService",
return_value=mock_memory_service,
),
patch(
@@ -810,6 +899,87 @@ def test_list_artifact_names(test_app, create_test_session):
logger.info(f"Listed {len(data)} artifacts")
def test_save_artifact(test_app, create_test_session, mock_artifact_service):
"""Test saving an artifact through the FastAPI endpoint."""
info = create_test_session
url = (
f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/"
f"{info['session_id']}/artifacts"
)
artifact_part = types.Part(text="hello world")
payload = {
"filename": "greeting.txt",
"artifact": artifact_part.model_dump(by_alias=True, exclude_none=True),
}
response = test_app.post(url, json=payload)
assert response.status_code == 200
data = response.json()
assert data["version"] == 0
assert data["customMetadata"] == {}
assert data["mimeType"] in (None, "text/plain")
assert data["canonicalUri"].endswith(
f"/sessions/{info['session_id']}/artifacts/"
f"{payload['filename']}/versions/0"
)
assert isinstance(data["createTime"], float)
key = (
f"{info['app_name']}:{info['user_id']}:{info['session_id']}:"
f"{payload['filename']}"
)
stored = mock_artifact_service._artifacts[key][0]
assert stored["artifact"].text == "hello world"
def test_save_artifact_returns_400_on_validation_error(
test_app, create_test_session, mock_artifact_service
):
"""Test save artifact endpoint surfaces validation errors as HTTP 400."""
info = create_test_session
url = (
f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/"
f"{info['session_id']}/artifacts"
)
artifact_part = types.Part(text="bad data")
payload = {
"filename": "invalid.txt",
"artifact": artifact_part.model_dump(by_alias=True, exclude_none=True),
}
mock_artifact_service.save_artifact_side_effect = InputValidationError(
"invalid artifact"
)
response = test_app.post(url, json=payload)
assert response.status_code == 400
assert response.json()["detail"] == "invalid artifact"
def test_save_artifact_returns_500_on_unexpected_error(
test_app, create_test_session, mock_artifact_service
):
"""Test save artifact endpoint surfaces unexpected errors as HTTP 500."""
info = create_test_session
url = (
f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/"
f"{info['session_id']}/artifacts"
)
artifact_part = types.Part(text="bad data")
payload = {
"filename": "invalid.txt",
"artifact": artifact_part.model_dump(by_alias=True, exclude_none=True),
}
mock_artifact_service.save_artifact_side_effect = RuntimeError(
"unexpected failure"
)
response = test_app.post(url, json=payload)
assert response.status_code == 500
assert response.json()["detail"] == "unexpected failure"
def test_create_eval_set(test_app, test_session_info):
"""Test creating an eval set."""
url = f"/apps/{test_session_info['app_name']}/eval_sets/test_eval_set_id"