You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Support passing fully qualified agent engine resource name when constructing session service and memory service
Resolves https://github.com/google/adk-python/issues/1760 PiperOrigin-RevId: 784353411
This commit is contained in:
committed by
Copybara-Service
parent
36e45cdab3
commit
2e778049d0
@@ -461,7 +461,10 @@ def adk_services_options():
|
||||
"--session_service_uri",
|
||||
help=(
|
||||
"""Optional. The URI of the session service.
|
||||
- Use 'agentengine://<agent_engine_resource_id>' to connect to Agent Engine sessions.
|
||||
- Use 'agentengine://<agent_engine>' to connect to Agent Engine
|
||||
sessions. <agent_engine> can either be the full qualified resource
|
||||
name 'projects/abc/locations/us-central1/reasoningEngines/123' or
|
||||
the resource id '123'.
|
||||
- Use 'sqlite://<path_to_sqlite_file>' to connect to a SQLite DB.
|
||||
- See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported database URIs."""
|
||||
),
|
||||
@@ -487,11 +490,12 @@ def adk_services_options():
|
||||
@click.option(
|
||||
"--memory_service_uri",
|
||||
type=str,
|
||||
help=(
|
||||
"""Optional. The URI of the memory service.
|
||||
help=("""Optional. The URI of the memory service.
|
||||
- Use 'rag://<rag_corpus_id>' to connect to Vertex AI Rag Memory Service.
|
||||
- Use 'agentengine://<agent_engine_resource_id>' to connect to Vertex AI Memory Bank Service. e.g. agentengine://12345"""
|
||||
),
|
||||
- Use 'agentengine://<agent_engine>' to connect to Agent Engine
|
||||
sessions. <agent_engine> can either be the full qualified resource
|
||||
name 'projects/abc/locations/us-central1/reasoningEngines/123' or
|
||||
the resource id '123'."""),
|
||||
default=None,
|
||||
)
|
||||
@functools.wraps(func)
|
||||
|
||||
@@ -297,6 +297,31 @@ def get_fast_api_app(
|
||||
eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir)
|
||||
eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)
|
||||
|
||||
def _parse_agent_engine_resource_name(agent_engine_id_or_resource_name):
|
||||
if not agent_engine_id_or_resource_name:
|
||||
raise click.ClickException(
|
||||
"Agent engine resource name or resource id can not be empty."
|
||||
)
|
||||
|
||||
# "projects/my-project/locations/us-central1/reasoningEngines/1234567890",
|
||||
if "/" in agent_engine_id_or_resource_name:
|
||||
# Validate resource name.
|
||||
if len(agent_engine_id_or_resource_name.split("/")) != 6:
|
||||
raise click.ClickException(
|
||||
"Agent engine resource name is mal-formatted. It should be of"
|
||||
" format :"
|
||||
" projects/{project_id}/locations/{location}/reasoningEngines/{resource_id}"
|
||||
)
|
||||
project = agent_engine_id_or_resource_name.split("/")[1]
|
||||
location = agent_engine_id_or_resource_name.split("/")[3]
|
||||
agent_engine_id = agent_engine_id_or_resource_name.split("/")[-1]
|
||||
else:
|
||||
envs.load_dotenv_for_agent("", agents_dir)
|
||||
project = os.environ["GOOGLE_CLOUD_PROJECT"]
|
||||
location = os.environ["GOOGLE_CLOUD_LOCATION"]
|
||||
agent_engine_id = agent_engine_id_or_resource_name
|
||||
return project, location, agent_engine_id
|
||||
|
||||
# Build the Memory service
|
||||
if memory_service_uri:
|
||||
if memory_service_uri.startswith("rag://"):
|
||||
@@ -308,13 +333,13 @@ def get_fast_api_app(
|
||||
rag_corpus=f'projects/{os.environ["GOOGLE_CLOUD_PROJECT"]}/locations/{os.environ["GOOGLE_CLOUD_LOCATION"]}/ragCorpora/{rag_corpus}'
|
||||
)
|
||||
elif memory_service_uri.startswith("agentengine://"):
|
||||
agent_engine_id = memory_service_uri.split("://")[1]
|
||||
if not agent_engine_id:
|
||||
raise click.ClickException("Agent engine id can not be empty.")
|
||||
envs.load_dotenv_for_agent("", agents_dir)
|
||||
agent_engine_id_or_resource_name = memory_service_uri.split("://")[1]
|
||||
project, location, agent_engine_id = _parse_agent_engine_resource_name(
|
||||
agent_engine_id_or_resource_name
|
||||
)
|
||||
memory_service = VertexAiMemoryBankService(
|
||||
project=os.environ["GOOGLE_CLOUD_PROJECT"],
|
||||
location=os.environ["GOOGLE_CLOUD_LOCATION"],
|
||||
project=project,
|
||||
location=location,
|
||||
agent_engine_id=agent_engine_id,
|
||||
)
|
||||
else:
|
||||
@@ -327,14 +352,13 @@ def get_fast_api_app(
|
||||
# Build the Session service
|
||||
if session_service_uri:
|
||||
if session_service_uri.startswith("agentengine://"):
|
||||
# Create vertex session service
|
||||
agent_engine_id = session_service_uri.split("://")[1]
|
||||
if not agent_engine_id:
|
||||
raise click.ClickException("Agent engine id can not be empty.")
|
||||
envs.load_dotenv_for_agent("", agents_dir)
|
||||
agent_engine_id_or_resource_name = session_service_uri.split("://")[1]
|
||||
project, location, agent_engine_id = _parse_agent_engine_resource_name(
|
||||
agent_engine_id_or_resource_name
|
||||
)
|
||||
session_service = VertexAiSessionService(
|
||||
project=os.environ["GOOGLE_CLOUD_PROJECT"],
|
||||
location=os.environ["GOOGLE_CLOUD_LOCATION"],
|
||||
project=project,
|
||||
location=location,
|
||||
agent_engine_id=agent_engine_id,
|
||||
)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user