You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Support registering custom services from local files
This change introduces a mechanism for users to register their own custom backend services for sessions, memory, and artifacts without modifying the ADK framework. This enhances the extensibility of ADK.
Two methods of registration are supported, both by placing a file in the parent directory of the agents.
**YAML Configuration (services.yaml or .yml)**
This is the recommended approach for simple services that can be instantiated with a constructor like MyService(uri="...", **kwargs).
Example services.yaml:
```
services:
- scheme: mysession
type: session
class: my_package.my_module.MyCustomSessionService
```
**Python Registration (services.py)**
For services requiring more complex initialization logic, users can define factory functions in a services.py file.
Example services.py
```
from google.adk.cli.service_registry import get_service_registry
from my_package.my_module import MyCustomSessionService
def my_session_factory(uri: str, **kwargs):
# custom initialization logic
return MyCustomSessionService(...)
get_service_registry().register_session_service("mysession", my_session_factory)
```
ADK will load services from services.yaml/.yml first, and then from services.py. If the same service scheme is defined in both, the registration in services.py will take precedence.
To use a registered service, specify its URI via the corresponding command-line flag, e.g., `--session_service_uri=mysession://....`
Co-authored-by: Shangjie Chen <deanchen@google.com>
PiperOrigin-RevId: 831211371
This commit is contained in:
committed by
Copybara-Service
parent
99fc17b336
commit
a501c59ac4
@@ -0,0 +1,96 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Dummy service implementations for testing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.adk.memory.base_memory_service import BaseMemoryService
|
||||
from google.adk.memory.base_memory_service import SearchMemoryResponse
|
||||
from google.adk.memory.memory_entry import MemoryEntry
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google.adk.sessions.session import Session
|
||||
|
||||
|
||||
class FooMemoryService(BaseMemoryService):
|
||||
"""A dummy memory service that returns a fixed response."""
|
||||
|
||||
def __init__(self, uri: str | None = None, **kwargs):
|
||||
"""Initializes the foo memory service.
|
||||
|
||||
Args:
|
||||
uri: The service URI.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
del uri, kwargs # Unused in this dummy implementation.
|
||||
|
||||
@override
|
||||
async def add_session_to_memory(self, session: Session):
|
||||
print('FooMemoryService.add_session_to_memory')
|
||||
|
||||
@override
|
||||
async def search_memory(
|
||||
self, *, app_name: str, user_id: str, query: str
|
||||
) -> SearchMemoryResponse:
|
||||
print('FooMemoryService.search_memory')
|
||||
return SearchMemoryResponse(
|
||||
memories=[
|
||||
MemoryEntry(
|
||||
content=types.Content(
|
||||
parts=[types.Part(text='I love ADK from Foo')]
|
||||
),
|
||||
author='bot',
|
||||
timestamp=datetime.now().isoformat(),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class BarMemoryService(BaseMemoryService):
|
||||
"""A dummy memory service that returns a fixed response."""
|
||||
|
||||
def __init__(self, uri: str | None = None, **kwargs):
|
||||
"""Initializes the bar memory service.
|
||||
|
||||
Args:
|
||||
uri: The service URI.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
del uri, kwargs # Unused in this dummy implementation.
|
||||
|
||||
@override
|
||||
async def add_session_to_memory(self, session: Session):
|
||||
print('BarMemoryService.add_session_to_memory')
|
||||
|
||||
@override
|
||||
async def search_memory(
|
||||
self, *, app_name: str, user_id: str, query: str
|
||||
) -> SearchMemoryResponse:
|
||||
print('BarMemoryService.search_memory')
|
||||
return SearchMemoryResponse(
|
||||
memories=[
|
||||
MemoryEntry(
|
||||
content=types.Content(
|
||||
parts=[types.Part(text='I love ADK from Bar')]
|
||||
),
|
||||
author='bot',
|
||||
timestamp=datetime.now().isoformat(),
|
||||
)
|
||||
]
|
||||
)
|
||||
@@ -0,0 +1,32 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Example of Python-based service registration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dummy_services import FooMemoryService
|
||||
from google.adk.cli.service_registry import get_service_registry
|
||||
|
||||
|
||||
def foo_memory_factory(uri: str, **kwargs) -> FooMemoryService:
|
||||
"""Factory for FooMemoryService."""
|
||||
return FooMemoryService(uri=uri, **kwargs)
|
||||
|
||||
|
||||
# Register the foo memory service with scheme "foo".
|
||||
# To use this memory service, set --memory_service_uri=foo:// in the ADK CLI.
|
||||
get_service_registry().register_memory_service("foo", foo_memory_factory)
|
||||
|
||||
# The BarMemoryService is registered in services.yaml with scheme "bar".
|
||||
# To use it, set --memory_service_uri=bar:// in the ADK CLI.
|
||||
@@ -0,0 +1,7 @@
|
||||
# Example of YAML-based service registration.
|
||||
# The BarMemoryService is registered here with scheme "bar".
|
||||
# To use this memory service, set --memory_service_uri=bar:// in the ADK CLI.
|
||||
services:
|
||||
- scheme: bar
|
||||
type: memory
|
||||
class: dummy_services.BarMemoryService
|
||||
@@ -19,6 +19,7 @@ import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import sys
|
||||
from typing import Any
|
||||
from typing import Mapping
|
||||
from typing import Optional
|
||||
@@ -42,6 +43,7 @@ from ..runners import Runner
|
||||
from ..sessions.in_memory_session_service import InMemorySessionService
|
||||
from .adk_web_server import AdkWebServer
|
||||
from .service_registry import get_service_registry
|
||||
from .service_registry import load_services_module
|
||||
from .utils import envs
|
||||
from .utils import evals
|
||||
from .utils.agent_change_handler import AgentChangeEventHandler
|
||||
@@ -72,6 +74,7 @@ def get_fast_api_app(
|
||||
logo_text: Optional[str] = None,
|
||||
logo_image_url: Optional[str] = None,
|
||||
) -> FastAPI:
|
||||
|
||||
# Set up eval managers.
|
||||
if eval_storage_uri:
|
||||
gcs_eval_managers = evals.create_gcs_eval_managers_from_uri(
|
||||
@@ -83,6 +86,11 @@ def get_fast_api_app(
|
||||
eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir)
|
||||
eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)
|
||||
|
||||
# initialize Agent Loader
|
||||
agent_loader = AgentLoader(agents_dir)
|
||||
# Load services.py from agents_dir for custom service registration.
|
||||
load_services_module(agents_dir)
|
||||
|
||||
service_registry = get_service_registry()
|
||||
|
||||
# Build the Memory service
|
||||
@@ -129,9 +137,6 @@ def get_fast_api_app(
|
||||
# Build the Credential service
|
||||
credential_service = InMemoryCredentialService()
|
||||
|
||||
# initialize Agent Loader
|
||||
agent_loader = AgentLoader(agents_dir)
|
||||
|
||||
adk_web_server = AdkWebServer(
|
||||
agent_loader=agent_loader,
|
||||
session_service=session_service,
|
||||
|
||||
@@ -11,12 +11,64 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
ADK Service Registry.
|
||||
|
||||
This module manages pluggable backend services for sessions, artifacts, and memory.
|
||||
ADK includes built-in support for common backends like SQLite, PostgreSQL,
|
||||
GCS, and Vertex AI Agent Engine. You can also extend ADK by registering
|
||||
custom services.
|
||||
|
||||
There are two ways to register custom services:
|
||||
|
||||
1. YAML Configuration (Recommended for simple cases)
|
||||
If your custom service can be instantiated with `MyService(uri="...", **kwargs)`,
|
||||
you can register it without writing Python code by creating a `services.yaml`
|
||||
or `services.yml` file in your agent directory (e.g., `my_agent/services.yaml`).
|
||||
|
||||
Example `services.yaml`:
|
||||
```yaml
|
||||
services:
|
||||
- scheme: mysession
|
||||
type: session
|
||||
class: my_package.my_module.MyCustomSessionService
|
||||
- scheme: mymemory
|
||||
type: memory
|
||||
class: my_package.other_module.MyCustomMemoryService
|
||||
```
|
||||
|
||||
2. Python Registration (`services.py`)
|
||||
For more complex initialization logic, create a `services.py` file in your
|
||||
agent directory (e.g., `my_agent/services.py`). In this file, get the
|
||||
registry instance and register your custom factory functions. This file can
|
||||
be used for registration in addition to, or instead of, `services.yaml`.
|
||||
|
||||
Example `services.py`:
|
||||
```python
|
||||
from google.adk.cli.service_registry import get_service_registry
|
||||
from my_package.my_module import MyCustomSessionService
|
||||
|
||||
def my_session_factory(uri: str, **kwargs):
|
||||
# custom logic
|
||||
return MyCustomSessionService(...)
|
||||
|
||||
get_service_registry().register_session_service("mysession", my_session_factory)
|
||||
```
|
||||
|
||||
Note: If both `services.yaml` (or `.yml`) and `services.py` are present in the
|
||||
same directory, services from **both** files will be loaded. YAML files are
|
||||
processed first, then `services.py`. If the same service scheme is defined in
|
||||
both, the definition in `services.py` will overwrite the one from YAML.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Protocol
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -24,62 +76,9 @@ from ..artifacts.base_artifact_service import BaseArtifactService
|
||||
from ..memory.base_memory_service import BaseMemoryService
|
||||
from ..sessions import InMemorySessionService
|
||||
from ..sessions.base_session_service import BaseSessionService
|
||||
from ..utils import yaml_utils
|
||||
|
||||
|
||||
def _load_gcp_config(
|
||||
agents_dir: str | None, service_name: str
|
||||
) -> tuple[str, str]:
|
||||
"""Loads GCP project and location from environment."""
|
||||
if not agents_dir:
|
||||
raise ValueError(f"agents_dir must be provided for {service_name}")
|
||||
|
||||
from .utils import envs
|
||||
|
||||
envs.load_dotenv_for_agent("", agents_dir)
|
||||
|
||||
project = os.environ.get("GOOGLE_CLOUD_PROJECT")
|
||||
location = os.environ.get("GOOGLE_CLOUD_LOCATION")
|
||||
|
||||
if not project or not location:
|
||||
raise ValueError("GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_LOCATION not set.")
|
||||
|
||||
return project, location
|
||||
|
||||
|
||||
def _parse_agent_engine_kwargs(
|
||||
uri_part: str, agents_dir: str | None
|
||||
) -> dict[str, Any]:
|
||||
"""Helper to parse agent engine resource name."""
|
||||
if not uri_part:
|
||||
raise ValueError(
|
||||
"Agent engine resource name or resource id cannot be empty."
|
||||
)
|
||||
if "/" in uri_part:
|
||||
parts = uri_part.split("/")
|
||||
if not (
|
||||
len(parts) == 6
|
||||
and parts[0] == "projects"
|
||||
and parts[2] == "locations"
|
||||
and parts[4] == "reasoningEngines"
|
||||
):
|
||||
raise ValueError(
|
||||
"Agent engine resource name is mal-formatted. It should be of"
|
||||
" format :"
|
||||
" projects/{project_id}/locations/{location}/reasoningEngines/{resource_id}"
|
||||
)
|
||||
project = parts[1]
|
||||
location = parts[3]
|
||||
agent_engine_id = parts[5]
|
||||
else:
|
||||
project, location = _load_gcp_config(
|
||||
agents_dir, "short-form agent engine IDs"
|
||||
)
|
||||
agent_engine_id = uri_part
|
||||
return {
|
||||
"project": project,
|
||||
"location": location,
|
||||
"agent_engine_id": agent_engine_id,
|
||||
}
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
|
||||
class ServiceFactory(Protocol):
|
||||
@@ -95,9 +94,9 @@ class ServiceRegistry:
|
||||
"""Registry for custom service URI schemes."""
|
||||
|
||||
def __init__(self):
|
||||
self._session_factories: Dict[str, ServiceFactory] = {}
|
||||
self._artifact_factories: Dict[str, ServiceFactory] = {}
|
||||
self._memory_factories: Dict[str, ServiceFactory] = {}
|
||||
self._session_factories: dict[str, ServiceFactory] = {}
|
||||
self._artifact_factories: dict[str, ServiceFactory] = {}
|
||||
self._memory_factories: dict[str, ServiceFactory] = {}
|
||||
|
||||
def register_session_service(
|
||||
self, scheme: str, factory: ServiceFactory
|
||||
@@ -151,6 +150,70 @@ class ServiceRegistry:
|
||||
return None
|
||||
|
||||
|
||||
def get_service_registry() -> ServiceRegistry:
|
||||
"""Gets the singleton ServiceRegistry instance, initializing it if needed."""
|
||||
global _service_registry_instance
|
||||
if _service_registry_instance is None:
|
||||
_service_registry_instance = ServiceRegistry()
|
||||
_register_builtin_services(_service_registry_instance)
|
||||
return _service_registry_instance
|
||||
|
||||
|
||||
def load_services_module(agents_dir: str) -> None:
|
||||
"""Load services.py or services.yaml from agents_dir for custom service registration.
|
||||
|
||||
If services.yaml or services.yml is found, it will be loaded first,
|
||||
followed by services.py if it exists.
|
||||
|
||||
Skip if neither services.yaml/yml nor services.py is not found.
|
||||
"""
|
||||
if not os.path.isdir(agents_dir):
|
||||
logger.debug(
|
||||
"agents_dir %s is not a valid directory, skipping service loading.",
|
||||
agents_dir,
|
||||
)
|
||||
return
|
||||
if agents_dir not in sys.path:
|
||||
sys.path.insert(0, agents_dir)
|
||||
|
||||
# Try loading services.yaml or services.yml first
|
||||
for yaml_file in ["services.yaml", "services.yml"]:
|
||||
yaml_path = os.path.join(agents_dir, yaml_file)
|
||||
if os.path.exists(yaml_path):
|
||||
try:
|
||||
config = yaml_utils.load_yaml_file(yaml_path)
|
||||
_register_services_from_yaml_config(config, get_service_registry())
|
||||
logger.debug(
|
||||
"Loaded custom services from %s in %s.", yaml_file, agents_dir
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to load %s from %s: %s",
|
||||
yaml_file,
|
||||
agents_dir,
|
||||
e,
|
||||
)
|
||||
return # If yaml exists but fails to load, stop.
|
||||
|
||||
try:
|
||||
importlib.import_module("services")
|
||||
logger.debug(
|
||||
"Loaded services.py from %s for custom service registration.",
|
||||
agents_dir,
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
logger.debug("services.py not found in %s, skipping.", agents_dir)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to load services.py from %s: %s",
|
||||
agents_dir,
|
||||
e,
|
||||
)
|
||||
|
||||
|
||||
_service_registry_instance: ServiceRegistry | None = None
|
||||
|
||||
|
||||
def _register_builtin_services(registry: ServiceRegistry) -> None:
|
||||
"""Register built-in service implementations."""
|
||||
|
||||
@@ -229,11 +292,108 @@ def _register_builtin_services(registry: ServiceRegistry) -> None:
|
||||
registry.register_memory_service("agentengine", agentengine_memory_factory)
|
||||
|
||||
|
||||
# Global registry instance
|
||||
_global_registry = ServiceRegistry()
|
||||
_register_builtin_services(_global_registry)
|
||||
def _load_gcp_config(
|
||||
agents_dir: Optional[str], service_name: str
|
||||
) -> tuple[str, str]:
|
||||
"""Loads GCP project and location from environment."""
|
||||
if not agents_dir:
|
||||
raise ValueError(f"agents_dir must be provided for {service_name}")
|
||||
|
||||
from .utils import envs
|
||||
|
||||
envs.load_dotenv_for_agent("", agents_dir)
|
||||
|
||||
project = os.environ.get("GOOGLE_CLOUD_PROJECT")
|
||||
location = os.environ.get("GOOGLE_CLOUD_LOCATION")
|
||||
|
||||
if not project or not location:
|
||||
raise ValueError("GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_LOCATION not set.")
|
||||
|
||||
return project, location
|
||||
|
||||
|
||||
def get_service_registry() -> ServiceRegistry:
|
||||
"""Get the global service registry instance."""
|
||||
return _global_registry
|
||||
def _parse_agent_engine_kwargs(
|
||||
uri_part: str, agents_dir: Optional[str]
|
||||
) -> dict[str, Any]:
|
||||
"""Helper to parse agent engine resource name."""
|
||||
if not uri_part:
|
||||
raise ValueError(
|
||||
"Agent engine resource name or resource id cannot be empty."
|
||||
)
|
||||
|
||||
# If uri_part is just an ID, load project/location from env
|
||||
if "/" not in uri_part:
|
||||
project, location = _load_gcp_config(
|
||||
agents_dir, "short-form agent engine IDs"
|
||||
)
|
||||
return {
|
||||
"project": project,
|
||||
"location": location,
|
||||
"agent_engine_id": uri_part,
|
||||
}
|
||||
|
||||
# If uri_part is a full resource name, parse it
|
||||
parts = uri_part.split("/")
|
||||
if not (
|
||||
len(parts) == 6
|
||||
and parts[0] == "projects"
|
||||
and parts[2] == "locations"
|
||||
and parts[4] == "reasoningEngines"
|
||||
):
|
||||
raise ValueError(
|
||||
"Agent engine resource name is mal-formatted. It should be of"
|
||||
" format :"
|
||||
" projects/{project_id}/locations/{location}/reasoningEngines/{resource_id}"
|
||||
)
|
||||
return {
|
||||
"project": parts[1],
|
||||
"location": parts[3],
|
||||
"agent_engine_id": parts[5],
|
||||
}
|
||||
|
||||
|
||||
def _get_class_from_string(class_path: str) -> Any:
|
||||
"""Dynamically import a class from a string path."""
|
||||
try:
|
||||
module_name, class_name = class_path.rsplit(".", 1)
|
||||
module = importlib.import_module(module_name)
|
||||
return getattr(module, class_name)
|
||||
except Exception as e:
|
||||
raise ImportError(f"Could not import class {class_path}: {e}") from e
|
||||
|
||||
|
||||
def _create_generic_factory(class_path: str) -> ServiceFactory:
|
||||
"""Create a generic factory for a service class."""
|
||||
cls = _get_class_from_string(class_path)
|
||||
|
||||
def factory(uri: str, **kwargs):
|
||||
return cls(uri=uri, **kwargs)
|
||||
|
||||
return factory
|
||||
|
||||
|
||||
def _register_services_from_yaml_config(
|
||||
config: dict[str, Any], registry: ServiceRegistry
|
||||
) -> None:
|
||||
"""Register services defined in a YAML configuration."""
|
||||
if not config or "services" not in config:
|
||||
return
|
||||
|
||||
for service_config in config["services"]:
|
||||
scheme = service_config.get("scheme")
|
||||
service_type = service_config.get("type")
|
||||
class_path = service_config.get("class")
|
||||
|
||||
if not all([scheme, service_type, class_path]):
|
||||
logger.warning("Invalid service config in YAML: %s", service_config)
|
||||
continue
|
||||
|
||||
factory = _create_generic_factory(class_path)
|
||||
if service_type == "session":
|
||||
registry.register_session_service(scheme, factory)
|
||||
elif service_type == "artifact":
|
||||
registry.register_artifact_service(scheme, factory)
|
||||
elif service_type == "memory":
|
||||
registry.register_memory_service(scheme, factory)
|
||||
else:
|
||||
logger.warning("Unknown service type in YAML: %s", service_type)
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
@@ -26,6 +27,25 @@ if TYPE_CHECKING:
|
||||
from pydantic.main import IncEx
|
||||
|
||||
|
||||
def load_yaml_file(file_path: Union[str, Path]) -> Any:
|
||||
"""Loads a YAML file and returns its content.
|
||||
|
||||
Args:
|
||||
file_path: Path to the YAML file.
|
||||
|
||||
Returns:
|
||||
The content of the YAML file.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the file_path does not exist.
|
||||
"""
|
||||
file_path = Path(file_path)
|
||||
if not file_path.is_file():
|
||||
raise FileNotFoundError(f'YAML file not found: {file_path}')
|
||||
with file_path.open('r', encoding='utf-8') as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
|
||||
def dump_pydantic_to_yaml(
|
||||
model: BaseModel,
|
||||
file_path: Union[str, Path],
|
||||
|
||||
Reference in New Issue
Block a user