You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Add new methods in the artifact service interface
PiperOrigin-RevId: 818473733
This commit is contained in:
committed by
Copybara-Service
parent
e63180cb62
commit
e212ff558e
@@ -15,9 +15,23 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class ArtifactVersion(BaseModel):
|
||||
"""Represents the metadata of a specific version of an artifact."""
|
||||
|
||||
version: int
|
||||
"""The version number of the artifact."""
|
||||
canonical_uri: str
|
||||
"""The canonical URI of the artifact version."""
|
||||
custom_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
"""A dictionary of custom metadata associated with the artifact version."""
|
||||
|
||||
|
||||
class BaseArtifactService(ABC):
|
||||
@@ -32,6 +46,7 @@ class BaseArtifactService(ABC):
|
||||
filename: str,
|
||||
artifact: types.Part,
|
||||
session_id: Optional[str] = None,
|
||||
custom_metadata: Optional[dict[str, Any]] = None,
|
||||
) -> int:
|
||||
"""Saves an artifact to the artifact service storage.
|
||||
|
||||
@@ -43,8 +58,12 @@ class BaseArtifactService(ABC):
|
||||
app_name: The app name.
|
||||
user_id: The user ID.
|
||||
filename: The filename of the artifact.
|
||||
artifact: The artifact to save.
|
||||
artifact: The artifact to save. If the artifact consists of `file_data`,
|
||||
the artifact service assumes its content has been uploaded separately,
|
||||
and this method will associate the `file_data` with the artifact if
|
||||
necessary.
|
||||
session_id: The session ID. If `None`, the artifact is user-scoped.
|
||||
custom_metadata: custom metadata to associate with the artifact.
|
||||
|
||||
Returns:
|
||||
The revision ID. The first version of the artifact has a revision ID of 0.
|
||||
@@ -136,3 +155,54 @@ class BaseArtifactService(ABC):
|
||||
Returns:
|
||||
A list of all available versions of the artifact.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def list_artifact_versions(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
filename: str,
|
||||
session_id: Optional[str] = None,
|
||||
) -> list[ArtifactVersion]:
|
||||
"""Lists all versions and their metadata for a specific artifact.
|
||||
|
||||
Args:
|
||||
app_name: The name of the application.
|
||||
user_id: The ID of the user.
|
||||
filename: The name of the artifact file.
|
||||
session_id: The ID of the session. If `None`, lists versions of the
|
||||
user-scoped artifact. Otherwise, lists versions of the artifact within
|
||||
the specified session.
|
||||
|
||||
Returns:
|
||||
A list of ArtifactVersion objects, each representing a version of the
|
||||
artifact and its associated metadata.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_artifact_version(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
filename: str,
|
||||
session_id: Optional[str] = None,
|
||||
version: Optional[int] = None,
|
||||
) -> Optional[ArtifactVersion]:
|
||||
"""Gets the metadata for a specific version of an artifact.
|
||||
|
||||
Args:
|
||||
app_name: The name of the application.
|
||||
user_id: The ID of the user.
|
||||
filename: The name of the artifact file.
|
||||
session_id: The ID of the session. If `None`, the artifact will be fetched
|
||||
from the user-scoped artifacts. Otherwise, it will be fetched from the
|
||||
specified session.
|
||||
version: The version number of the artifact to retrieve. If `None`, the
|
||||
latest version will be returned.
|
||||
|
||||
Returns:
|
||||
An ArtifactVersion object containing the metadata of the specified
|
||||
artifact version, or `None` if the artifact version is not found.
|
||||
"""
|
||||
|
||||
@@ -24,12 +24,14 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from google.cloud import storage
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from .base_artifact_service import ArtifactVersion
|
||||
from .base_artifact_service import BaseArtifactService
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
@@ -58,6 +60,7 @@ class GcsArtifactService(BaseArtifactService):
|
||||
filename: str,
|
||||
artifact: types.Part,
|
||||
session_id: Optional[str] = None,
|
||||
custom_metadata: Optional[dict[str, Any]] = None,
|
||||
) -> int:
|
||||
return await asyncio.to_thread(
|
||||
self._save_artifact,
|
||||
@@ -66,6 +69,7 @@ class GcsArtifactService(BaseArtifactService):
|
||||
session_id,
|
||||
filename,
|
||||
artifact,
|
||||
custom_metadata,
|
||||
)
|
||||
|
||||
@override
|
||||
@@ -180,7 +184,12 @@ class GcsArtifactService(BaseArtifactService):
|
||||
session_id: Optional[str],
|
||||
filename: str,
|
||||
artifact: types.Part,
|
||||
custom_metadata: Optional[dict[str, Any]] = None,
|
||||
) -> int:
|
||||
if custom_metadata:
|
||||
# TODO: b/447451270 - support saving artifact with custom metadata.
|
||||
raise NotImplementedError("custom_metadata is not supported yet.")
|
||||
|
||||
versions = self._list_versions(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
@@ -316,3 +325,28 @@ class GcsArtifactService(BaseArtifactService):
|
||||
*_, version = blob.name.split("/")
|
||||
versions.append(int(version))
|
||||
return versions
|
||||
|
||||
@override
|
||||
async def list_artifact_versions(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
filename: str,
|
||||
session_id: Optional[str] = None,
|
||||
) -> list[ArtifactVersion]:
|
||||
# TODO: b/447451270 - Support list_artifact_versions.
|
||||
raise NotImplementedError("list_artifact_versions is not implemented yet.")
|
||||
|
||||
@override
|
||||
async def get_artifact_version(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
filename: str,
|
||||
session_id: Optional[str] = None,
|
||||
version: Optional[int] = None,
|
||||
) -> Optional[ArtifactVersion]:
|
||||
# TODO: b/447451270 - Support get_artifact_version.
|
||||
raise NotImplementedError("get_artifact_version is not implemented yet.")
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from google.genai import types
|
||||
@@ -21,6 +22,7 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from .base_artifact_service import ArtifactVersion
|
||||
from .base_artifact_service import BaseArtifactService
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
@@ -83,7 +85,12 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
|
||||
filename: str,
|
||||
artifact: types.Part,
|
||||
session_id: Optional[str] = None,
|
||||
custom_metadata: Optional[dict[str, Any]] = None,
|
||||
) -> int:
|
||||
# TODO: b/447451270 - Support saving artifact with custom metadata.
|
||||
if custom_metadata:
|
||||
raise NotImplementedError("custom_metadata is not supported yet.")
|
||||
|
||||
path = self._artifact_path(app_name, user_id, filename, session_id)
|
||||
if path not in self.artifacts:
|
||||
self.artifacts[path] = []
|
||||
@@ -155,3 +162,28 @@ class InMemoryArtifactService(BaseArtifactService, BaseModel):
|
||||
if not versions:
|
||||
return []
|
||||
return list(range(len(versions)))
|
||||
|
||||
@override
|
||||
async def list_artifact_versions(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
filename: str,
|
||||
session_id: Optional[str] = None,
|
||||
) -> list[ArtifactVersion]:
|
||||
# TODO: b/447451270 - Support list_artifact_versions.
|
||||
raise NotImplementedError("list_artifact_versions is not implemented yet.")
|
||||
|
||||
@override
|
||||
async def get_artifact_version(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
filename: str,
|
||||
session_id: Optional[str] = None,
|
||||
version: Optional[int] = None,
|
||||
) -> Optional[ArtifactVersion]:
|
||||
# TODO: b/447451270 - Support get_artifact_version.
|
||||
raise NotImplementedError("get_artifact_version is not implemented yet.")
|
||||
|
||||
@@ -14,12 +14,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from ..artifacts.base_artifact_service import ArtifactVersion
|
||||
from ..artifacts.base_artifact_service import BaseArtifactService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -42,6 +44,7 @@ class ForwardingArtifactService(BaseArtifactService):
|
||||
filename: str,
|
||||
artifact: types.Part,
|
||||
session_id: Optional[str] = None,
|
||||
custom_metadata: Optional[dict[str, Any]] = None,
|
||||
) -> int:
|
||||
return await self.tool_context.save_artifact(
|
||||
filename=filename, artifact=artifact
|
||||
@@ -104,3 +107,26 @@ class ForwardingArtifactService(BaseArtifactService):
|
||||
session_id=self._invocation_context.session.id,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
@override
|
||||
async def list_artifact_versions(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
filename: str,
|
||||
session_id: Optional[str] = None,
|
||||
) -> list[ArtifactVersion]:
|
||||
raise NotImplementedError("list_artifact_versions is not implemented yet.")
|
||||
|
||||
@override
|
||||
async def get_artifact_version(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
filename: str,
|
||||
session_id: Optional[str] = None,
|
||||
version: Optional[int] = None,
|
||||
) -> Optional[ArtifactVersion]:
|
||||
raise NotImplementedError("get_artifact_version is not implemented yet.")
|
||||
|
||||
Reference in New Issue
Block a user