You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
chore: Add back unit tests for CLI utility to deploy to AgentEngine
Co-authored-by: Yeesian Ng <ysian@google.com> PiperOrigin-RevId: 856749290
This commit is contained in:
committed by
Copybara-Service
parent
6ad18cc2fc
commit
6dbe851fca
@@ -20,6 +20,7 @@ import shutil
|
||||
import subprocess
|
||||
from typing import Final
|
||||
from typing import Optional
|
||||
import warnings
|
||||
|
||||
import click
|
||||
from packaging.version import parse
|
||||
@@ -27,6 +28,36 @@ from packaging.version import parse
|
||||
_IS_WINDOWS = os.name == 'nt'
|
||||
_GCLOUD_CMD = 'gcloud.cmd' if _IS_WINDOWS else 'gcloud'
|
||||
_LOCAL_STORAGE_FLAG_MIN_VERSION: Final[str] = '1.21.0'
|
||||
_AGENT_ENGINE_REQUIREMENT: Final[str] = (
|
||||
'google-cloud-aiplatform[adk,agent_engines]'
|
||||
)
|
||||
|
||||
|
||||
def _ensure_agent_engine_dependency(requirements_txt_path: str) -> None:
|
||||
"""Ensures staged requirements include Agent Engine dependencies."""
|
||||
if not os.path.exists(requirements_txt_path):
|
||||
raise FileNotFoundError(
|
||||
f'requirements.txt not found at: {requirements_txt_path}'
|
||||
)
|
||||
|
||||
requirements = ''
|
||||
with open(requirements_txt_path, 'r', encoding='utf-8') as f:
|
||||
requirements = f.read()
|
||||
|
||||
for line in requirements.splitlines():
|
||||
stripped = line.strip()
|
||||
if (
|
||||
stripped
|
||||
and not stripped.startswith('#')
|
||||
and stripped.startswith('google-cloud-aiplatform')
|
||||
):
|
||||
return
|
||||
|
||||
with open(requirements_txt_path, 'a', encoding='utf-8') as f:
|
||||
if requirements and not requirements.endswith('\n'):
|
||||
f.write('\n')
|
||||
f.write(_AGENT_ENGINE_REQUIREMENT + '\n')
|
||||
|
||||
|
||||
_DOCKERFILE_TEMPLATE: Final[str] = """
|
||||
FROM python:3.11-slim
|
||||
@@ -656,7 +687,7 @@ def to_agent_engine(
|
||||
agent_folder: str,
|
||||
temp_folder: Optional[str] = None,
|
||||
adk_app: str,
|
||||
staging_bucket: str,
|
||||
staging_bucket: Optional[str] = None,
|
||||
trace_to_cloud: Optional[bool] = None,
|
||||
api_key: Optional[str] = None,
|
||||
adk_app_object: Optional[str] = None,
|
||||
@@ -699,7 +730,8 @@ def to_agent_engine(
|
||||
files. It will be replaced with the generated files if it already exists.
|
||||
adk_app (str): The name of the file (without .py) containing the AdkApp
|
||||
instance.
|
||||
staging_bucket (str): The GCS bucket for staging the deployment artifacts.
|
||||
staging_bucket (str): Deprecated. This argument is no longer required or
|
||||
used.
|
||||
trace_to_cloud (bool): Whether to enable Cloud Trace.
|
||||
api_key (str): Optional. The API key to use for Express Mode.
|
||||
If not provided, the API key from the GOOGLE_API_KEY environment variable
|
||||
@@ -729,13 +761,6 @@ def to_agent_engine(
|
||||
app_name = os.path.basename(agent_folder)
|
||||
display_name = display_name or app_name
|
||||
parent_folder = os.path.dirname(agent_folder)
|
||||
if parent_folder != os.getcwd():
|
||||
click.echo(f'Please deploy from the project dir: {parent_folder}')
|
||||
return
|
||||
tmp_app_name = app_name + '_tmp' + datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
temp_folder = temp_folder or tmp_app_name
|
||||
agent_src_path = os.path.join(parent_folder, temp_folder)
|
||||
click.echo(f'Staging all files in: {agent_src_path}')
|
||||
adk_app_object = adk_app_object or 'root_agent'
|
||||
if adk_app_object not in ['root_agent', 'app']:
|
||||
click.echo(
|
||||
@@ -743,12 +768,34 @@ def to_agent_engine(
|
||||
' or "app".'
|
||||
)
|
||||
return
|
||||
if staging_bucket:
|
||||
warnings.warn(
|
||||
'WARNING: `staging_bucket` is deprecated and will be removed in a'
|
||||
' future release. Please drop it from the list of arguments.',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
original_cwd = os.getcwd()
|
||||
did_change_cwd = False
|
||||
if parent_folder != original_cwd:
|
||||
click.echo(
|
||||
'Agent Engine deployment uses relative paths; temporarily switching '
|
||||
f'working directory to: {parent_folder}'
|
||||
)
|
||||
os.chdir(parent_folder)
|
||||
did_change_cwd = True
|
||||
tmp_app_name = app_name + '_tmp' + datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
temp_folder = temp_folder or tmp_app_name
|
||||
agent_src_path = os.path.join(parent_folder, temp_folder)
|
||||
click.echo(f'Staging all files in: {agent_src_path}')
|
||||
# remove agent_src_path if it exists
|
||||
if os.path.exists(agent_src_path):
|
||||
click.echo('Removing existing files')
|
||||
shutil.rmtree(agent_src_path)
|
||||
|
||||
try:
|
||||
click.echo(f'Staging all files in: {agent_src_path}')
|
||||
ignore_patterns = None
|
||||
ae_ignore_path = os.path.join(agent_folder, '.ae_ignore')
|
||||
if os.path.exists(ae_ignore_path):
|
||||
@@ -757,15 +804,18 @@ def to_agent_engine(
|
||||
patterns = [pattern.strip() for pattern in f.readlines()]
|
||||
ignore_patterns = shutil.ignore_patterns(*patterns)
|
||||
click.echo('Copying agent source code...')
|
||||
shutil.copytree(agent_folder, agent_src_path, ignore=ignore_patterns)
|
||||
shutil.copytree(
|
||||
agent_folder,
|
||||
agent_src_path,
|
||||
ignore=ignore_patterns,
|
||||
dirs_exist_ok=True,
|
||||
)
|
||||
click.echo('Copying agent source code complete.')
|
||||
|
||||
project = _resolve_project(project)
|
||||
|
||||
click.echo('Resolving files and dependencies...')
|
||||
agent_config = {}
|
||||
if staging_bucket:
|
||||
agent_config['staging_bucket'] = staging_bucket
|
||||
if not agent_engine_config_file:
|
||||
# Attempt to read the agent engine config from .agent_engine_config.json in the dir (if any).
|
||||
agent_engine_config_file = os.path.join(
|
||||
@@ -808,8 +858,9 @@ def to_agent_engine(
|
||||
if not os.path.exists(requirements_txt_path):
|
||||
click.echo(f'Creating {requirements_txt_path}...')
|
||||
with open(requirements_txt_path, 'w', encoding='utf-8') as f:
|
||||
f.write('google-cloud-aiplatform[adk,agent_engines]')
|
||||
f.write(_AGENT_ENGINE_REQUIREMENT + '\n')
|
||||
click.echo(f'Created {requirements_txt_path}')
|
||||
_ensure_agent_engine_dependency(requirements_txt_path)
|
||||
agent_config['requirements_file'] = f'{temp_folder}/requirements.txt'
|
||||
|
||||
env_vars = {}
|
||||
@@ -940,7 +991,9 @@ def to_agent_engine(
|
||||
click.secho(f'✅ Updated agent engine: {agent_engine_id}', fg='green')
|
||||
finally:
|
||||
click.echo(f'Cleaning up the temp folder: {temp_folder}')
|
||||
shutil.rmtree(temp_folder)
|
||||
shutil.rmtree(agent_src_path)
|
||||
if did_change_cwd:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
def to_gke(
|
||||
|
||||
@@ -1031,6 +1031,19 @@ def web_options():
|
||||
return decorator
|
||||
|
||||
|
||||
def _deprecate_staging_bucket(ctx, param, value):
|
||||
if value:
|
||||
click.echo(
|
||||
click.style(
|
||||
f"WARNING: --{param} is deprecated and will be removed. Please"
|
||||
" leave it unspecified.",
|
||||
fg="yellow",
|
||||
),
|
||||
err=True,
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
def deprecated_adk_services_options():
|
||||
"""Deprecated ADK services options."""
|
||||
|
||||
@@ -1689,10 +1702,8 @@ def cli_migrate_session(
|
||||
"--staging_bucket",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Optional. GCS bucket for staging the deployment artifacts. It will be"
|
||||
" ignored if api_key is set."
|
||||
),
|
||||
help="Deprecated. This argument is no longer required or used.",
|
||||
callback=_deprecate_staging_bucket,
|
||||
)
|
||||
@click.option(
|
||||
"--agent_engine_id",
|
||||
@@ -1827,8 +1838,7 @@ def cli_deploy_agent_engine(
|
||||
|
||||
# With Google Cloud Project and Region
|
||||
adk deploy agent_engine --project=[project] --region=[region]
|
||||
--staging_bucket=[staging_bucket] --display_name=[app_name]
|
||||
my_agent
|
||||
--display_name=[app_name] my_agent
|
||||
"""
|
||||
logging.getLogger("vertexai_genai.agentengines").setLevel(logging.INFO)
|
||||
try:
|
||||
@@ -1836,7 +1846,6 @@ def cli_deploy_agent_engine(
|
||||
agent_folder=agent,
|
||||
project=project,
|
||||
region=region,
|
||||
staging_bucket=staging_bucket,
|
||||
agent_engine_id=agent_engine_id,
|
||||
trace_to_cloud=trace_to_cloud,
|
||||
api_key=api_key,
|
||||
|
||||
@@ -26,7 +26,6 @@ import types
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import Dict
|
||||
from typing import Generator
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
from unittest import mock
|
||||
@@ -227,6 +226,72 @@ def test_get_service_option_by_adk_version(
|
||||
assert actual.rstrip() == expected.rstrip()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("include_requirements", [True, False])
|
||||
def test_to_agent_engine_happy_path(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
agent_dir: Callable[[bool, bool], Path],
|
||||
include_requirements: bool,
|
||||
) -> None:
|
||||
"""Tests the happy path for the `to_agent_engine` function."""
|
||||
rmtree_recorder = _Recorder()
|
||||
monkeypatch.setattr(shutil, "rmtree", rmtree_recorder)
|
||||
create_recorder = _Recorder()
|
||||
|
||||
fake_vertexai = types.ModuleType("vertexai")
|
||||
|
||||
class _FakeAgentEngines:
|
||||
|
||||
def create(self, *, config: Dict[str, Any]) -> Any:
|
||||
create_recorder(config=config)
|
||||
return types.SimpleNamespace(
|
||||
api_resource=types.SimpleNamespace(
|
||||
name="projects/p/locations/l/reasoningEngines/e"
|
||||
)
|
||||
)
|
||||
|
||||
def update(self, *, name: str, config: Dict[str, Any]) -> None:
|
||||
del name
|
||||
del config
|
||||
|
||||
class _FakeVertexClient:
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
del args
|
||||
del kwargs
|
||||
self.agent_engines = _FakeAgentEngines()
|
||||
|
||||
fake_vertexai.Client = _FakeVertexClient
|
||||
monkeypatch.setitem(sys.modules, "vertexai", fake_vertexai)
|
||||
src_dir = agent_dir(include_requirements, False)
|
||||
tmp_dir = src_dir.parent / "tmp"
|
||||
cli_deploy.to_agent_engine(
|
||||
agent_folder=str(src_dir),
|
||||
temp_folder="tmp",
|
||||
adk_app="my_adk_app",
|
||||
trace_to_cloud=True,
|
||||
project="my-gcp-project",
|
||||
region="us-central1",
|
||||
display_name="My Test Agent",
|
||||
description="A test agent.",
|
||||
)
|
||||
agent_file = tmp_dir / "agent.py"
|
||||
assert agent_file.is_file()
|
||||
init_file = tmp_dir / "__init__.py"
|
||||
assert init_file.is_file()
|
||||
adk_app_file = tmp_dir / "my_adk_app.py"
|
||||
assert adk_app_file.is_file()
|
||||
content = adk_app_file.read_text()
|
||||
assert "from .agent import root_agent" in content
|
||||
assert "adk_app = AdkApp(" in content
|
||||
assert "agent=root_agent" in content
|
||||
assert "enable_tracing=True" in content
|
||||
reqs_path = tmp_dir / "requirements.txt"
|
||||
assert reqs_path.is_file()
|
||||
assert "google-cloud-aiplatform[adk,agent_engines]" in reqs_path.read_text()
|
||||
assert len(create_recorder.calls) == 1
|
||||
assert str(rmtree_recorder.get_last_call_args()[0]) == str(tmp_dir)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("include_requirements", [True, False])
|
||||
def test_to_gke_happy_path(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
|
||||
@@ -400,8 +400,6 @@ def test_cli_deploy_agent_engine_success(
|
||||
"test-proj",
|
||||
"--region",
|
||||
"us-central1",
|
||||
"--staging_bucket",
|
||||
"gs://mybucket",
|
||||
str(agent_dir),
|
||||
],
|
||||
)
|
||||
@@ -410,7 +408,6 @@ def test_cli_deploy_agent_engine_success(
|
||||
called_kwargs = rec.calls[0][1]
|
||||
assert called_kwargs.get("project") == "test-proj"
|
||||
assert called_kwargs.get("region") == "us-central1"
|
||||
assert called_kwargs.get("staging_bucket") == "gs://mybucket"
|
||||
|
||||
|
||||
# cli deploy gke
|
||||
|
||||
Reference in New Issue
Block a user