feat: Add new methods in the artifact service interface

PiperOrigin-RevId: 818473733
This commit is contained in:
Google Team Member
2025-10-12 21:18:50 -07:00
committed by Copybara-Service
parent e63180cb62
commit e212ff558e
4 changed files with 163 additions and 1 deletions
@@ -15,9 +15,23 @@ from __future__ import annotations
from abc import ABC
from abc import abstractmethod
from typing import Any
from typing import Optional
from google.genai import types
from pydantic import BaseModel
from pydantic import Field
class ArtifactVersion(BaseModel):
"""Represents the metadata of a specific version of an artifact."""
version: int
"""The version number of the artifact."""
canonical_uri: str
"""The canonical URI of the artifact version."""
custom_metadata: dict[str, Any] = Field(default_factory=dict)
"""A dictionary of custom metadata associated with the artifact version."""
class BaseArtifactService(ABC):
@@ -32,6 +46,7 @@ class BaseArtifactService(ABC):
filename: str,
artifact: types.Part,
session_id: Optional[str] = None,
custom_metadata: Optional[dict[str, Any]] = None,
) -> int:
"""Saves an artifact to the artifact service storage.
@@ -43,8 +58,12 @@ class BaseArtifactService(ABC):
app_name: The app name.
user_id: The user ID.
filename: The filename of the artifact.
artifact: The artifact to save.
artifact: The artifact to save. If the artifact consists of `file_data`,
the artifact service assumes its content has been uploaded separately,
and this method will associate the `file_data` with the artifact if
necessary.
session_id: The session ID. If `None`, the artifact is user-scoped.
custom_metadata: custom metadata to associate with the artifact.
Returns:
The revision ID. The first version of the artifact has a revision ID of 0.
@@ -136,3 +155,54 @@ class BaseArtifactService(ABC):
Returns:
A list of all available versions of the artifact.
"""
@abstractmethod
async def list_artifact_versions(
self,
*,
app_name: str,
user_id: str,
filename: str,
session_id: Optional[str] = None,
) -> list[ArtifactVersion]:
"""Lists all versions and their metadata for a specific artifact.
Args:
app_name: The name of the application.
user_id: The ID of the user.
filename: The name of the artifact file.
session_id: The ID of the session. If `None`, lists versions of the
user-scoped artifact. Otherwise, lists versions of the artifact within
the specified session.
Returns:
A list of ArtifactVersion objects, each representing a version of the
artifact and its associated metadata.
"""
@abstractmethod
async def get_artifact_version(
self,
*,
app_name: str,
user_id: str,
filename: str,
session_id: Optional[str] = None,
version: Optional[int] = None,
) -> Optional[ArtifactVersion]:
"""Gets the metadata for a specific version of an artifact.
Args:
app_name: The name of the application.
user_id: The ID of the user.
filename: The name of the artifact file.
session_id: The ID of the session. If `None`, the artifact will be fetched
from the user-scoped artifacts. Otherwise, it will be fetched from the
specified session.
version: The version number of the artifact to retrieve. If `None`, the
latest version will be returned.
Returns:
An ArtifactVersion object containing the metadata of the specified
artifact version, or `None` if the artifact version is not found.
"""
@@ -24,12 +24,14 @@ from __future__ import annotations
import asyncio
import logging
from typing import Any
from typing import Optional
from google.cloud import storage
from google.genai import types
from typing_extensions import override
from .base_artifact_service import ArtifactVersion
from .base_artifact_service import BaseArtifactService
logger = logging.getLogger("google_adk." + __name__)
@@ -58,6 +60,7 @@ class GcsArtifactService(BaseArtifactService):
filename: str,
artifact: types.Part,
session_id: Optional[str] = None,
custom_metadata: Optional[dict[str, Any]] = None,
) -> int:
return await asyncio.to_thread(
self._save_artifact,
@@ -66,6 +69,7 @@ class GcsArtifactService(BaseArtifactService):
session_id,
filename,
artifact,
custom_metadata,
)
@override
@@ -180,7 +184,12 @@ class GcsArtifactService(BaseArtifactService):
session_id: Optional[str],
filename: str,
artifact: types.Part,
custom_metadata: Optional[dict[str, Any]] = None,
) -> int:
if custom_metadata:
# TODO: b/447451270 - support saving artifact with custom metadata.
raise NotImplementedError("custom_metadata is not supported yet.")
versions = self._list_versions(
app_name=app_name,
user_id=user_id,
@@ -316,3 +325,28 @@ class GcsArtifactService(BaseArtifactService):
*_, version = blob.name.split("/")
versions.append(int(version))
return versions
@override
async def list_artifact_versions(
self,
*,
app_name: str,
user_id: str,
filename: str,
session_id: Optional[str] = None,
) -> list[ArtifactVersion]:
# TODO: b/447451270 - Support list_artifact_versions.
raise NotImplementedError("list_artifact_versions is not implemented yet.")
@override
async def get_artifact_version(
self,
*,
app_name: str,
user_id: str,
filename: str,
session_id: Optional[str] = None,
version: Optional[int] = None,
) -> Optional[ArtifactVersion]:
# TODO: b/447451270 - Support get_artifact_version.
raise NotImplementedError("get_artifact_version is not implemented yet.")
@@ -14,6 +14,7 @@
from __future__ import annotations
import logging
from typing import Any
from typing import Optional
from google.genai import types
@@ -21,6 +22,7 @@ from pydantic import BaseModel
from pydantic import Field
from typing_extensions import override
from .base_artifact_service import ArtifactVersion
from .base_artifact_service import BaseArtifactService
logger = logging.getLogger("google_adk." + __name__)
@@ -83,7 +85,12 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
filename: str,
artifact: types.Part,
session_id: Optional[str] = None,
custom_metadata: Optional[dict[str, Any]] = None,
) -> int:
# TODO: b/447451270 - Support saving artifact with custom metadata.
if custom_metadata:
raise NotImplementedError("custom_metadata is not supported yet.")
path = self._artifact_path(app_name, user_id, filename, session_id)
if path not in self.artifacts:
self.artifacts[path] = []
@@ -155,3 +162,28 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
if not versions:
return []
return list(range(len(versions)))
@override
async def list_artifact_versions(
self,
*,
app_name: str,
user_id: str,
filename: str,
session_id: Optional[str] = None,
) -> list[ArtifactVersion]:
# TODO: b/447451270 - Support list_artifact_versions.
raise NotImplementedError("list_artifact_versions is not implemented yet.")
@override
async def get_artifact_version(
self,
*,
app_name: str,
user_id: str,
filename: str,
session_id: Optional[str] = None,
version: Optional[int] = None,
) -> Optional[ArtifactVersion]:
# TODO: b/447451270 - Support get_artifact_version.
raise NotImplementedError("get_artifact_version is not implemented yet.")
@@ -14,12 +14,14 @@
from __future__ import annotations
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
from google.genai import types
from typing_extensions import override
from ..artifacts.base_artifact_service import ArtifactVersion
from ..artifacts.base_artifact_service import BaseArtifactService
if TYPE_CHECKING:
@@ -42,6 +44,7 @@ class ForwardingArtifactService(BaseArtifactService):
filename: str,
artifact: types.Part,
session_id: Optional[str] = None,
custom_metadata: Optional[dict[str, Any]] = None,
) -> int:
return await self.tool_context.save_artifact(
filename=filename, artifact=artifact
@@ -104,3 +107,26 @@ class ForwardingArtifactService(BaseArtifactService):
session_id=self._invocation_context.session.id,
filename=filename,
)
@override
async def list_artifact_versions(
self,
*,
app_name: str,
user_id: str,
filename: str,
session_id: Optional[str] = None,
) -> list[ArtifactVersion]:
raise NotImplementedError("list_artifact_versions is not implemented yet.")
@override
async def get_artifact_version(
self,
*,
app_name: str,
user_id: str,
filename: str,
session_id: Optional[str] = None,
version: Optional[int] = None,
) -> Optional[ArtifactVersion]:
raise NotImplementedError("get_artifact_version is not implemented yet.")