You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Add service factory for configurable session and artifact backends
this creates service_factory to handle .adk folder changes (including per-agent .adk defaults and in-memory/custom URI handling) Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 833875524
This commit is contained in:
committed by
Copybara-Service
parent
8eb1bdbc58
commit
a12ae812d3
@@ -66,10 +66,12 @@ from __future__ import annotations
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import Protocol
|
||||
from urllib.parse import unquote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ..artifacts.base_artifact_service import BaseArtifactService
|
||||
@@ -218,6 +220,11 @@ def _register_builtin_services(registry: ServiceRegistry) -> None:
|
||||
"""Register built-in service implementations."""
|
||||
|
||||
# -- Session Services --
|
||||
def memory_session_factory(uri: str, **kwargs):
|
||||
from ..sessions.in_memory_session_service import InMemorySessionService
|
||||
|
||||
return InMemorySessionService()
|
||||
|
||||
def agentengine_session_factory(uri: str, **kwargs):
|
||||
from ..sessions.vertex_ai_session_service import VertexAiSessionService
|
||||
|
||||
@@ -240,19 +247,26 @@ def _register_builtin_services(registry: ServiceRegistry) -> None:
|
||||
parsed = urlparse(uri)
|
||||
db_path = parsed.path
|
||||
if not db_path:
|
||||
return InMemorySessionService()
|
||||
# Treat sqlite:// without a path as an in-memory session service.
|
||||
return memory_session_factory("memory://", **kwargs)
|
||||
elif db_path.startswith("/"):
|
||||
db_path = db_path[1:]
|
||||
kwargs_copy = kwargs.copy()
|
||||
kwargs_copy.pop("agents_dir", None)
|
||||
return SqliteSessionService(db_path=db_path, **kwargs_copy)
|
||||
|
||||
registry.register_session_service("memory", memory_session_factory)
|
||||
registry.register_session_service("agentengine", agentengine_session_factory)
|
||||
registry.register_session_service("sqlite", sqlite_session_factory)
|
||||
for scheme in ["postgresql", "mysql"]:
|
||||
registry.register_session_service(scheme, database_session_factory)
|
||||
|
||||
# -- Artifact Services --
|
||||
def memory_artifact_factory(uri: str, **kwargs):
|
||||
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
|
||||
|
||||
return InMemoryArtifactService()
|
||||
|
||||
def gcs_artifact_factory(uri: str, **kwargs):
|
||||
from ..artifacts.gcs_artifact_service import GcsArtifactService
|
||||
|
||||
@@ -262,7 +276,27 @@ def _register_builtin_services(registry: ServiceRegistry) -> None:
|
||||
bucket_name = parsed_uri.netloc
|
||||
return GcsArtifactService(bucket_name=bucket_name, **kwargs_copy)
|
||||
|
||||
def file_artifact_factory(uri: str, **kwargs):
|
||||
from ..artifacts.file_artifact_service import FileArtifactService
|
||||
|
||||
per_agent = kwargs.get("per_agent", False)
|
||||
if per_agent:
|
||||
raise ValueError(
|
||||
"file:// artifact URIs are not supported in multi-agent mode."
|
||||
)
|
||||
parsed_uri = urlparse(uri)
|
||||
if parsed_uri.netloc not in ("", "localhost"):
|
||||
raise ValueError(
|
||||
"file:// artifact URIs must reference the local filesystem."
|
||||
)
|
||||
if not parsed_uri.path:
|
||||
raise ValueError("file:// artifact URIs must include a path component.")
|
||||
artifact_path = Path(unquote(parsed_uri.path))
|
||||
return FileArtifactService(root_dir=artifact_path)
|
||||
|
||||
registry.register_artifact_service("memory", memory_artifact_factory)
|
||||
registry.register_artifact_service("gs", gcs_artifact_factory)
|
||||
registry.register_artifact_service("file", file_artifact_factory)
|
||||
|
||||
# -- Memory Services --
|
||||
def rag_memory_factory(uri: str, **kwargs):
|
||||
@@ -270,7 +304,7 @@ def _register_builtin_services(registry: ServiceRegistry) -> None:
|
||||
|
||||
rag_corpus = urlparse(uri).netloc
|
||||
if not rag_corpus:
|
||||
raise ValueError("Rag corpus cannot be empty.")
|
||||
raise ValueError("Rag corpus can not be empty.")
|
||||
agents_dir = kwargs.get("agents_dir")
|
||||
project, location = _load_gcp_config(agents_dir, "RAG memory service")
|
||||
return VertexAiRagMemoryService(
|
||||
|
||||
Reference in New Issue
Block a user