You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
fix: Add a FastAPI endpoint for saving artifacts
This change adds new `POST` endpoint `/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts` to the ADK web server. This endpoint lets clients to save new artifacts associated with a specific session. The endpoint uses `SaveArtifactRequest` and returns `SaveArtifactResponse`, including the version and canonical URI of the saved artifact.
Close #1975
Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 838977880
This commit is contained in:
committed by
Copybara-Service
parent
ed9da3fa45
commit
7e8eeca6aa
@@ -32,6 +32,7 @@ from pydantic import Field
|
||||
from pydantic import ValidationError
|
||||
from typing_extensions import override
|
||||
|
||||
from ..errors.input_validation_error import InputValidationError
|
||||
from .base_artifact_service import ArtifactVersion
|
||||
from .base_artifact_service import BaseArtifactService
|
||||
|
||||
@@ -100,14 +101,14 @@ def _resolve_scoped_artifact_path(
|
||||
to `scope_root`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `filename` resolves outside of `scope_root`.
|
||||
InputValidationError: If `filename` resolves outside of `scope_root`.
|
||||
"""
|
||||
stripped = _strip_user_namespace(filename).strip()
|
||||
pure_path = _to_posix_path(stripped)
|
||||
|
||||
scope_root_resolved = scope_root.resolve(strict=False)
|
||||
if pure_path.is_absolute():
|
||||
raise ValueError(
|
||||
raise InputValidationError(
|
||||
f"Absolute artifact filename {filename!r} is not permitted; "
|
||||
"provide a path relative to the storage scope."
|
||||
)
|
||||
@@ -118,7 +119,7 @@ def _resolve_scoped_artifact_path(
|
||||
try:
|
||||
relative = candidate.relative_to(scope_root_resolved)
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
raise InputValidationError(
|
||||
f"Artifact filename {filename!r} escapes storage directory "
|
||||
f"{scope_root_resolved}"
|
||||
) from exc
|
||||
@@ -230,7 +231,7 @@ class FileArtifactService(BaseArtifactService):
|
||||
if _is_user_scoped(session_id, filename):
|
||||
return _user_artifacts_dir(base)
|
||||
if not session_id:
|
||||
raise ValueError(
|
||||
raise InputValidationError(
|
||||
"Session ID must be provided for session-scoped artifacts."
|
||||
)
|
||||
return _session_artifacts_dir(base, session_id)
|
||||
@@ -371,7 +372,9 @@ class FileArtifactService(BaseArtifactService):
|
||||
content_path.write_text(artifact.text, encoding="utf-8")
|
||||
mime_type = None
|
||||
else:
|
||||
raise ValueError("Artifact must have either inline_data or text content.")
|
||||
raise InputValidationError(
|
||||
"Artifact must have either inline_data or text content."
|
||||
)
|
||||
|
||||
canonical_uri = self._canonical_uri(
|
||||
user_id=user_id,
|
||||
|
||||
@@ -30,6 +30,7 @@ from typing import Optional
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from ..errors.input_validation_error import InputValidationError
|
||||
from .base_artifact_service import ArtifactVersion
|
||||
from .base_artifact_service import BaseArtifactService
|
||||
|
||||
@@ -161,7 +162,7 @@ class GcsArtifactService(BaseArtifactService):
|
||||
return f"{app_name}/{user_id}/user/{filename}"
|
||||
|
||||
if session_id is None:
|
||||
raise ValueError(
|
||||
raise InputValidationError(
|
||||
"Session ID must be provided for session-scoped artifacts."
|
||||
)
|
||||
return f"{app_name}/{user_id}/{session_id}/{filename}"
|
||||
@@ -230,7 +231,9 @@ class GcsArtifactService(BaseArtifactService):
|
||||
" GcsArtifactService."
|
||||
)
|
||||
else:
|
||||
raise ValueError("Artifact must have either inline_data or text.")
|
||||
raise InputValidationError(
|
||||
"Artifact must have either inline_data or text."
|
||||
)
|
||||
|
||||
return version
|
||||
|
||||
|
||||
@@ -18,12 +18,13 @@ import logging
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from google.adk.artifacts import artifact_util
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from . import artifact_util
|
||||
from ..errors.input_validation_error import InputValidationError
|
||||
from .base_artifact_service import ArtifactVersion
|
||||
from .base_artifact_service import BaseArtifactService
|
||||
|
||||
@@ -86,7 +87,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
|
||||
return f"{app_name}/{user_id}/user/{filename}"
|
||||
|
||||
if session_id is None:
|
||||
raise ValueError(
|
||||
raise InputValidationError(
|
||||
"Session ID must be provided for session-scoped artifacts."
|
||||
)
|
||||
return f"{app_name}/{user_id}/{session_id}/{filename}"
|
||||
@@ -125,7 +126,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
|
||||
elif artifact.file_data is not None:
|
||||
if artifact_util.is_artifact_ref(artifact):
|
||||
if not artifact_util.parse_artifact_uri(artifact.file_data.file_uri):
|
||||
raise ValueError(
|
||||
raise InputValidationError(
|
||||
f"Invalid artifact reference URI: {artifact.file_data.file_uri}"
|
||||
)
|
||||
# If it's a valid artifact URI, we store the artifact part as-is.
|
||||
@@ -133,7 +134,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
|
||||
else:
|
||||
artifact_version.mime_type = artifact.file_data.mime_type
|
||||
else:
|
||||
raise ValueError("Not supported artifact type.")
|
||||
raise InputValidationError("Not supported artifact type.")
|
||||
|
||||
self.artifacts[path].append(
|
||||
_ArtifactEntry(data=artifact, artifact_version=artifact_version)
|
||||
@@ -172,7 +173,7 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
|
||||
artifact_data.file_data.file_uri
|
||||
)
|
||||
if not parsed_uri:
|
||||
raise ValueError(
|
||||
raise InputValidationError(
|
||||
"Invalid artifact reference URI:"
|
||||
f" {artifact_data.file_data.file_uri}"
|
||||
)
|
||||
|
||||
@@ -61,9 +61,11 @@ from ..agents.live_request_queue import LiveRequestQueue
|
||||
from ..agents.run_config import RunConfig
|
||||
from ..agents.run_config import StreamingMode
|
||||
from ..apps.app import App
|
||||
from ..artifacts.base_artifact_service import ArtifactVersion
|
||||
from ..artifacts.base_artifact_service import BaseArtifactService
|
||||
from ..auth.credential_service.base_credential_service import BaseCredentialService
|
||||
from ..errors.already_exists_error import AlreadyExistsError
|
||||
from ..errors.input_validation_error import InputValidationError
|
||||
from ..errors.not_found_error import NotFoundError
|
||||
from ..evaluation.base_eval_service import InferenceConfig
|
||||
from ..evaluation.base_eval_service import InferenceRequest
|
||||
@@ -194,6 +196,19 @@ class CreateSessionRequest(common.BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class SaveArtifactRequest(common.BaseModel):
|
||||
"""Request payload for saving a new artifact."""
|
||||
|
||||
filename: str = Field(description="Artifact filename.")
|
||||
artifact: types.Part = Field(
|
||||
description="Artifact payload encoded as google.genai.types.Part."
|
||||
)
|
||||
custom_metadata: Optional[dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Optional metadata to associate with the artifact version.",
|
||||
)
|
||||
|
||||
|
||||
class AddSessionToEvalSetRequest(common.BaseModel):
|
||||
eval_id: str
|
||||
session_id: str
|
||||
@@ -1316,6 +1331,53 @@ class AdkWebServer:
|
||||
raise HTTPException(status_code=404, detail="Artifact not found")
|
||||
return artifact
|
||||
|
||||
@app.post(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts",
|
||||
response_model=ArtifactVersion,
|
||||
response_model_exclude_none=True,
|
||||
)
|
||||
async def save_artifact(
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
req: SaveArtifactRequest,
|
||||
) -> ArtifactVersion:
|
||||
try:
|
||||
version = await self.artifact_service.save_artifact(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=req.filename,
|
||||
artifact=req.artifact,
|
||||
custom_metadata=req.custom_metadata,
|
||||
)
|
||||
except InputValidationError as ive:
|
||||
raise HTTPException(status_code=400, detail=str(ive)) from ive
|
||||
except Exception as exc: # pylint: disable=broad-exception-caught
|
||||
logger.error(
|
||||
"Internal error while saving artifact %s for app=%s user=%s"
|
||||
" session=%s: %s",
|
||||
req.filename,
|
||||
app_name,
|
||||
user_id,
|
||||
session_id,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
artifact_version = await self.artifact_service.get_artifact_version(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
filename=req.filename,
|
||||
version=version,
|
||||
)
|
||||
if artifact_version is None:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Artifact metadata unavailable"
|
||||
)
|
||||
return artifact_version
|
||||
|
||||
@app.get(
|
||||
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts",
|
||||
response_model_exclude_none=True,
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -34,22 +35,43 @@ from opentelemetry.sdk.trace import TracerProvider
|
||||
from starlette.types import Lifespan
|
||||
from watchdog.observers import Observer
|
||||
|
||||
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService
|
||||
from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
|
||||
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
|
||||
from ..memory.in_memory_memory_service import InMemoryMemoryService
|
||||
from ..runners import Runner
|
||||
from ..sessions.in_memory_session_service import InMemorySessionService
|
||||
from .adk_web_server import AdkWebServer
|
||||
from .service_registry import get_service_registry
|
||||
from .service_registry import load_services_module
|
||||
from .utils import envs
|
||||
from .utils import evals
|
||||
from .utils.agent_change_handler import AgentChangeEventHandler
|
||||
from .utils.agent_loader import AgentLoader
|
||||
from .utils.service_factory import create_artifact_service_from_options
|
||||
from .utils.service_factory import create_memory_service_from_options
|
||||
from .utils.service_factory import create_session_service_from_options
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
_LAZY_SERVICE_IMPORTS: dict[str, str] = {
|
||||
"AgentLoader": ".utils.agent_loader",
|
||||
"InMemoryArtifactService": "..artifacts.in_memory_artifact_service",
|
||||
"InMemoryMemoryService": "..memory.in_memory_memory_service",
|
||||
"InMemorySessionService": "..sessions.in_memory_session_service",
|
||||
"LocalEvalSetResultsManager": "..evaluation.local_eval_set_results_manager",
|
||||
"LocalEvalSetsManager": "..evaluation.local_eval_sets_manager",
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
"""Lazily import defaults so patching in tests keeps working."""
|
||||
if name not in _LAZY_SERVICE_IMPORTS:
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
module = importlib.import_module(_LAZY_SERVICE_IMPORTS[name], __package__)
|
||||
attr = getattr(module, name)
|
||||
globals()[name] = attr
|
||||
return attr
|
||||
|
||||
|
||||
def get_fast_api_app(
|
||||
*,
|
||||
@@ -73,8 +95,6 @@ def get_fast_api_app(
|
||||
logo_text: Optional[str] = None,
|
||||
logo_image_url: Optional[str] = None,
|
||||
) -> FastAPI:
|
||||
# Convert to absolute path for consistency
|
||||
agents_dir = str(Path(agents_dir).resolve())
|
||||
|
||||
# Set up eval managers.
|
||||
if eval_storage_uri:
|
||||
@@ -92,30 +112,48 @@ def get_fast_api_app(
|
||||
# Load services.py from agents_dir for custom service registration.
|
||||
load_services_module(agents_dir)
|
||||
|
||||
service_registry = get_service_registry()
|
||||
|
||||
# Build the Memory service
|
||||
try:
|
||||
memory_service = create_memory_service_from_options(
|
||||
base_dir=agents_dir,
|
||||
memory_service_uri=memory_service_uri,
|
||||
if memory_service_uri:
|
||||
memory_service = service_registry.create_memory_service(
|
||||
memory_service_uri, agents_dir=agents_dir
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise click.ClickException(str(exc)) from exc
|
||||
if not memory_service:
|
||||
raise click.ClickException(
|
||||
"Unsupported memory service URI: %s" % memory_service_uri
|
||||
)
|
||||
else:
|
||||
memory_service = InMemoryMemoryService()
|
||||
|
||||
# Build the Session service
|
||||
session_service = create_session_service_from_options(
|
||||
base_dir=agents_dir,
|
||||
session_service_uri=session_service_uri,
|
||||
session_db_kwargs=session_db_kwargs,
|
||||
)
|
||||
if session_service_uri:
|
||||
session_kwargs = session_db_kwargs or {}
|
||||
session_service = service_registry.create_session_service(
|
||||
session_service_uri, agents_dir=agents_dir, **session_kwargs
|
||||
)
|
||||
if not session_service:
|
||||
# Fallback to DatabaseSessionService if the service registry doesn't
|
||||
# support the session service URI scheme.
|
||||
from ..sessions.database_session_service import DatabaseSessionService
|
||||
|
||||
session_service = DatabaseSessionService(
|
||||
db_url=session_service_uri, **session_kwargs
|
||||
)
|
||||
else:
|
||||
session_service = InMemorySessionService()
|
||||
|
||||
# Build the Artifact service
|
||||
try:
|
||||
artifact_service = create_artifact_service_from_options(
|
||||
base_dir=agents_dir,
|
||||
artifact_service_uri=artifact_service_uri,
|
||||
if artifact_service_uri:
|
||||
artifact_service = service_registry.create_artifact_service(
|
||||
artifact_service_uri, agents_dir=agents_dir
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise click.ClickException(str(exc)) from exc
|
||||
if not artifact_service:
|
||||
raise click.ClickException(
|
||||
"Unsupported artifact service URI: %s" % artifact_service_uri
|
||||
)
|
||||
else:
|
||||
artifact_service = InMemoryArtifactService()
|
||||
|
||||
# Build the Credential service
|
||||
credential_service = InMemoryCredentialService()
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class InputValidationError(ValueError):
|
||||
"""Represents an error raised when user input fails validation."""
|
||||
|
||||
def __init__(self, message="Invalid input."):
|
||||
"""Initializes the InputValidationError exception.
|
||||
|
||||
Args:
|
||||
message (str): A message describing why the input is invalid.
|
||||
"""
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
@@ -32,6 +32,7 @@ from google.adk.artifacts.base_artifact_service import ArtifactVersion
|
||||
from google.adk.artifacts.file_artifact_service import FileArtifactService
|
||||
from google.adk.artifacts.gcs_artifact_service import GcsArtifactService
|
||||
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
from google.adk.errors.input_validation_error import InputValidationError
|
||||
from google.genai import types
|
||||
import pytest
|
||||
|
||||
@@ -732,7 +733,7 @@ async def test_file_save_artifact_rejects_out_of_scope_paths(
|
||||
"""FileArtifactService prevents path traversal outside of its storage roots."""
|
||||
artifact_service = FileArtifactService(root_dir=tmp_path / "artifacts")
|
||||
part = types.Part(text="content")
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(InputValidationError):
|
||||
await artifact_service.save_artifact(
|
||||
app_name="myapp",
|
||||
user_id="user123",
|
||||
@@ -757,7 +758,7 @@ async def test_file_save_artifact_rejects_absolute_path_within_scope(tmp_path):
|
||||
/ "diagram.png"
|
||||
)
|
||||
part = types.Part(text="content")
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(InputValidationError):
|
||||
await artifact_service.save_artifact(
|
||||
app_name="myapp",
|
||||
user_id="user123",
|
||||
|
||||
@@ -30,7 +30,9 @@ from fastapi.testclient import TestClient
|
||||
from google.adk.agents.base_agent import BaseAgent
|
||||
from google.adk.agents.run_config import RunConfig
|
||||
from google.adk.apps.app import App
|
||||
from google.adk.artifacts.base_artifact_service import ArtifactVersion
|
||||
from google.adk.cli.fast_api import get_fast_api_app
|
||||
from google.adk.errors.input_validation_error import InputValidationError
|
||||
from google.adk.evaluation.eval_case import EvalCase
|
||||
from google.adk.evaluation.eval_case import Invocation
|
||||
from google.adk.evaluation.eval_result import EvalSetResult
|
||||
@@ -211,48 +213,135 @@ def mock_session_service():
|
||||
def mock_artifact_service():
|
||||
"""Create a mock artifact service."""
|
||||
|
||||
# Storage for artifacts
|
||||
artifacts = {}
|
||||
artifacts: dict[str, list[dict[str, Any]]] = {}
|
||||
|
||||
def _artifact_key(
|
||||
app_name: str, user_id: str, session_id: Optional[str], filename: str
|
||||
) -> str:
|
||||
if session_id is None:
|
||||
return f"{app_name}:{user_id}:user:{filename}"
|
||||
return f"{app_name}:{user_id}:{session_id}:{filename}"
|
||||
|
||||
def _canonical_uri(
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: Optional[str],
|
||||
filename: str,
|
||||
version: int,
|
||||
) -> str:
|
||||
if session_id is None:
|
||||
return (
|
||||
f"artifact://apps/{app_name}/users/{user_id}/artifacts/"
|
||||
f"{filename}/versions/{version}"
|
||||
)
|
||||
return (
|
||||
f"artifact://apps/{app_name}/users/{user_id}/sessions/{session_id}/"
|
||||
f"artifacts/{filename}/versions/{version}"
|
||||
)
|
||||
|
||||
class MockArtifactService:
|
||||
|
||||
def __init__(self):
|
||||
self._artifacts = artifacts
|
||||
self.save_artifact_side_effect: Optional[BaseException] = None
|
||||
|
||||
async def save_artifact(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
filename: str,
|
||||
artifact: types.Part,
|
||||
session_id: Optional[str] = None,
|
||||
custom_metadata: Optional[dict[str, Any]] = None,
|
||||
) -> int:
|
||||
if self.save_artifact_side_effect is not None:
|
||||
effect = self.save_artifact_side_effect
|
||||
if isinstance(effect, BaseException):
|
||||
raise effect
|
||||
raise TypeError(
|
||||
"save_artifact_side_effect must be an exception instance."
|
||||
)
|
||||
key = _artifact_key(app_name, user_id, session_id, filename)
|
||||
entries = artifacts.setdefault(key, [])
|
||||
version = len(entries)
|
||||
artifact_version = ArtifactVersion(
|
||||
version=version,
|
||||
canonical_uri=_canonical_uri(
|
||||
app_name, user_id, session_id, filename, version
|
||||
),
|
||||
custom_metadata=custom_metadata or {},
|
||||
)
|
||||
if artifact.inline_data is not None:
|
||||
artifact_version.mime_type = artifact.inline_data.mime_type
|
||||
elif artifact.text is not None:
|
||||
artifact_version.mime_type = "text/plain"
|
||||
elif artifact.file_data is not None:
|
||||
artifact_version.mime_type = artifact.file_data.mime_type
|
||||
|
||||
entries.append({
|
||||
"version": version,
|
||||
"artifact": artifact,
|
||||
"metadata": artifact_version,
|
||||
})
|
||||
return version
|
||||
|
||||
async def load_artifact(
|
||||
self, app_name, user_id, session_id, filename, version=None
|
||||
):
|
||||
"""Load an artifact by filename."""
|
||||
key = f"{app_name}:{user_id}:{session_id}:{filename}"
|
||||
key = _artifact_key(app_name, user_id, session_id, filename)
|
||||
if key not in artifacts:
|
||||
return None
|
||||
|
||||
if version is not None:
|
||||
# Get a specific version
|
||||
for v in artifacts[key]:
|
||||
if v["version"] == version:
|
||||
return v["artifact"]
|
||||
for entry in artifacts[key]:
|
||||
if entry["version"] == version:
|
||||
return entry["artifact"]
|
||||
return None
|
||||
|
||||
# Get the latest version
|
||||
return sorted(artifacts[key], key=lambda x: x["version"])[-1]["artifact"]
|
||||
return artifacts[key][-1]["artifact"]
|
||||
|
||||
async def list_artifact_keys(self, app_name, user_id, session_id):
|
||||
"""List artifact names for a session."""
|
||||
prefix = f"{app_name}:{user_id}:{session_id}:"
|
||||
return [
|
||||
k.split(":")[-1] for k in artifacts.keys() if k.startswith(prefix)
|
||||
key.split(":")[-1]
|
||||
for key in artifacts.keys()
|
||||
if key.startswith(prefix)
|
||||
]
|
||||
|
||||
async def list_versions(self, app_name, user_id, session_id, filename):
|
||||
"""List versions of an artifact."""
|
||||
key = f"{app_name}:{user_id}:{session_id}:{filename}"
|
||||
key = _artifact_key(app_name, user_id, session_id, filename)
|
||||
if key not in artifacts:
|
||||
return []
|
||||
return [a["version"] for a in artifacts[key]]
|
||||
return [entry["version"] for entry in artifacts[key]]
|
||||
|
||||
async def delete_artifact(self, app_name, user_id, session_id, filename):
|
||||
"""Delete an artifact."""
|
||||
key = f"{app_name}:{user_id}:{session_id}:{filename}"
|
||||
if key in artifacts:
|
||||
del artifacts[key]
|
||||
key = _artifact_key(app_name, user_id, session_id, filename)
|
||||
artifacts.pop(key, None)
|
||||
|
||||
async def get_artifact_version(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
filename: str,
|
||||
session_id: Optional[str] = None,
|
||||
version: Optional[int] = None,
|
||||
) -> Optional[ArtifactVersion]:
|
||||
key = _artifact_key(app_name, user_id, session_id, filename)
|
||||
entries = artifacts.get(key)
|
||||
if not entries:
|
||||
return None
|
||||
if version is None:
|
||||
return entries[-1]["metadata"]
|
||||
for entry in entries:
|
||||
if entry["version"] == version:
|
||||
return entry["metadata"]
|
||||
return None
|
||||
|
||||
return MockArtifactService()
|
||||
|
||||
@@ -327,15 +416,15 @@ def test_app(
|
||||
with (
|
||||
patch("signal.signal", return_value=None),
|
||||
patch(
|
||||
"google.adk.cli.fast_api.create_session_service_from_options",
|
||||
"google.adk.cli.fast_api.InMemorySessionService",
|
||||
return_value=mock_session_service,
|
||||
),
|
||||
patch(
|
||||
"google.adk.cli.fast_api.create_artifact_service_from_options",
|
||||
"google.adk.cli.fast_api.InMemoryArtifactService",
|
||||
return_value=mock_artifact_service,
|
||||
),
|
||||
patch(
|
||||
"google.adk.cli.fast_api.create_memory_service_from_options",
|
||||
"google.adk.cli.fast_api.InMemoryMemoryService",
|
||||
return_value=mock_memory_service,
|
||||
),
|
||||
patch(
|
||||
@@ -472,15 +561,15 @@ def test_app_with_a2a(
|
||||
with (
|
||||
patch("signal.signal", return_value=None),
|
||||
patch(
|
||||
"google.adk.cli.fast_api.create_session_service_from_options",
|
||||
"google.adk.cli.fast_api.InMemorySessionService",
|
||||
return_value=mock_session_service,
|
||||
),
|
||||
patch(
|
||||
"google.adk.cli.fast_api.create_artifact_service_from_options",
|
||||
"google.adk.cli.fast_api.InMemoryArtifactService",
|
||||
return_value=mock_artifact_service,
|
||||
),
|
||||
patch(
|
||||
"google.adk.cli.fast_api.create_memory_service_from_options",
|
||||
"google.adk.cli.fast_api.InMemoryMemoryService",
|
||||
return_value=mock_memory_service,
|
||||
),
|
||||
patch(
|
||||
@@ -810,6 +899,87 @@ def test_list_artifact_names(test_app, create_test_session):
|
||||
logger.info(f"Listed {len(data)} artifacts")
|
||||
|
||||
|
||||
def test_save_artifact(test_app, create_test_session, mock_artifact_service):
|
||||
"""Test saving an artifact through the FastAPI endpoint."""
|
||||
info = create_test_session
|
||||
url = (
|
||||
f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/"
|
||||
f"{info['session_id']}/artifacts"
|
||||
)
|
||||
artifact_part = types.Part(text="hello world")
|
||||
payload = {
|
||||
"filename": "greeting.txt",
|
||||
"artifact": artifact_part.model_dump(by_alias=True, exclude_none=True),
|
||||
}
|
||||
|
||||
response = test_app.post(url, json=payload)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["version"] == 0
|
||||
assert data["customMetadata"] == {}
|
||||
assert data["mimeType"] in (None, "text/plain")
|
||||
assert data["canonicalUri"].endswith(
|
||||
f"/sessions/{info['session_id']}/artifacts/"
|
||||
f"{payload['filename']}/versions/0"
|
||||
)
|
||||
assert isinstance(data["createTime"], float)
|
||||
|
||||
key = (
|
||||
f"{info['app_name']}:{info['user_id']}:{info['session_id']}:"
|
||||
f"{payload['filename']}"
|
||||
)
|
||||
stored = mock_artifact_service._artifacts[key][0]
|
||||
assert stored["artifact"].text == "hello world"
|
||||
|
||||
|
||||
def test_save_artifact_returns_400_on_validation_error(
|
||||
test_app, create_test_session, mock_artifact_service
|
||||
):
|
||||
"""Test save artifact endpoint surfaces validation errors as HTTP 400."""
|
||||
info = create_test_session
|
||||
url = (
|
||||
f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/"
|
||||
f"{info['session_id']}/artifacts"
|
||||
)
|
||||
artifact_part = types.Part(text="bad data")
|
||||
payload = {
|
||||
"filename": "invalid.txt",
|
||||
"artifact": artifact_part.model_dump(by_alias=True, exclude_none=True),
|
||||
}
|
||||
|
||||
mock_artifact_service.save_artifact_side_effect = InputValidationError(
|
||||
"invalid artifact"
|
||||
)
|
||||
|
||||
response = test_app.post(url, json=payload)
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "invalid artifact"
|
||||
|
||||
|
||||
def test_save_artifact_returns_500_on_unexpected_error(
|
||||
test_app, create_test_session, mock_artifact_service
|
||||
):
|
||||
"""Test save artifact endpoint surfaces unexpected errors as HTTP 500."""
|
||||
info = create_test_session
|
||||
url = (
|
||||
f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/"
|
||||
f"{info['session_id']}/artifacts"
|
||||
)
|
||||
artifact_part = types.Part(text="bad data")
|
||||
payload = {
|
||||
"filename": "invalid.txt",
|
||||
"artifact": artifact_part.model_dump(by_alias=True, exclude_none=True),
|
||||
}
|
||||
|
||||
mock_artifact_service.save_artifact_side_effect = RuntimeError(
|
||||
"unexpected failure"
|
||||
)
|
||||
|
||||
response = test_app.post(url, json=payload)
|
||||
assert response.status_code == 500
|
||||
assert response.json()["detail"] == "unexpected failure"
|
||||
|
||||
|
||||
def test_create_eval_set(test_app, test_session_info):
|
||||
"""Test creating an eval set."""
|
||||
url = f"/apps/{test_session_info['app_name']}/eval_sets/test_eval_set_id"
|
||||
|
||||
Reference in New Issue
Block a user