diff --git a/src/google/adk/cli/utils/service_factory.py b/src/google/adk/cli/utils/service_factory.py index c03ac10b..d8903ece 100644 --- a/src/google/adk/cli/utils/service_factory.py +++ b/src/google/adk/cli/utils/service_factory.py @@ -19,6 +19,9 @@ import os from pathlib import Path from typing import Any from typing import Optional +from urllib.parse import parse_qsl +from urllib.parse import urlsplit +from urllib.parse import urlunsplit from ...artifacts.base_artifact_service import BaseArtifactService from ...memory.base_memory_service import BaseMemoryService @@ -42,6 +45,41 @@ _CLOUD_RUN_SERVICE_ENV = "K_SERVICE" _KUBERNETES_HOST_ENV = "KUBERNETES_SERVICE_HOST" +def _redact_uri_for_log(uri: str) -> str: + """Returns a safe-to-log representation of a URI. + + Redacts user info (username/password) and query parameter values. + """ + if not uri or not uri.strip(): + return "" + sanitized = uri.replace("\r", "\\r").replace("\n", "\\n") + if "://" not in sanitized: + return "" + try: + parsed = urlsplit(sanitized) + except ValueError: + return "" + + if not parsed.scheme: + return "" + + netloc = parsed.netloc + if "@" in netloc: + _, netloc = netloc.rsplit("@", 1) + + if parsed.query: + try: + redacted_pairs = parse_qsl(parsed.query, keep_blank_values=True) + except ValueError: + query = "" + else: + query = "&".join(f"{key}=" for key, _ in redacted_pairs) + else: + query = "" + + return urlunsplit((parsed.scheme, netloc, parsed.path, query, "")) + + def _is_cloud_run() -> bool: """Returns True when running in Cloud Run.""" return bool(os.environ.get(_CLOUD_RUN_SERVICE_ENV)) @@ -148,7 +186,10 @@ def create_session_service_from_options( kwargs.update(session_db_kwargs) if session_service_uri: - logger.info("Using session service URI: %s", session_service_uri) + logger.info( + "Using session service URI: %s", + _redact_uri_for_log(session_service_uri), + ) service = registry.create_session_service(session_service_uri, **kwargs) if service is not None: return service @@ -162,7 +203,7 @@ def create_session_service_from_options( fallback_kwargs.pop("agents_dir", None) logger.info( "Falling back to DatabaseSessionService for URI: %s", - session_service_uri, + _redact_uri_for_log(session_service_uri), ) return DatabaseSessionService(db_url=session_service_uri, **fallback_kwargs) @@ -208,13 +249,18 @@ def create_memory_service_from_options( registry = get_service_registry() if memory_service_uri: - logger.info("Using memory service URI: %s", memory_service_uri) + logger.info( + "Using memory service URI: %s", _redact_uri_for_log(memory_service_uri) + ) service = registry.create_memory_service( memory_service_uri, agents_dir=str(base_path), ) if service is None: - raise ValueError(f"Unsupported memory service URI: {memory_service_uri}") + raise ValueError( + "Unsupported memory service URI: %s" + % _redact_uri_for_log(memory_service_uri) + ) return service logger.info("Using in-memory memory service") @@ -235,7 +281,10 @@ def create_artifact_service_from_options( registry = get_service_registry() if artifact_service_uri: - logger.info("Using artifact service URI: %s", artifact_service_uri) + logger.info( + "Using artifact service URI: %s", + _redact_uri_for_log(artifact_service_uri), + ) service = registry.create_artifact_service( artifact_service_uri, agents_dir=str(base_path), @@ -243,11 +292,12 @@ def create_artifact_service_from_options( if service is None: if strict_uri: raise ValueError( - f"Unsupported artifact service URI: {artifact_service_uri}" + "Unsupported artifact service URI: %s" + % _redact_uri_for_log(artifact_service_uri) ) return _create_in_memory_artifact_service( "Unsupported artifact service URI: %s, falling back to in-memory", - artifact_service_uri, + _redact_uri_for_log(artifact_service_uri), ) return service diff --git a/tests/unittests/cli/utils/test_service_factory.py b/tests/unittests/cli/utils/test_service_factory.py index 87b567be..ad9a2389 100644 --- a/tests/unittests/cli/utils/test_service_factory.py +++ b/tests/unittests/cli/utils/test_service_factory.py @@ -16,12 +16,14 @@ from __future__ import annotations +import logging import os from pathlib import Path -from unittest.mock import Mock +from unittest import mock from google.adk.artifacts.file_artifact_service import FileArtifactService from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.cli.service_registry import ServiceRegistry from google.adk.cli.utils.local_storage import PerAgentDatabaseSessionService import google.adk.cli.utils.service_factory as service_factory from google.adk.memory.in_memory_memory_service import InMemoryMemoryService @@ -31,7 +33,7 @@ import pytest def test_create_session_service_uses_registry(tmp_path: Path, monkeypatch): - registry = Mock() + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) expected = object() registry.create_session_service.return_value = expected monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) @@ -48,6 +50,87 @@ def test_create_session_service_uses_registry(tmp_path: Path, monkeypatch): ) +def test_create_session_service_logs_redacted_uri( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +) -> None: + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) + registry.create_session_service.return_value = object() + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + session_service_uri = ( + "postgresql://user:supersecret@localhost:5432/dbname?sslmode=require" + ) + caplog.set_level(logging.INFO, logger=service_factory.logger.name) + + service_factory.create_session_service_from_options( + base_dir=tmp_path, + session_service_uri=session_service_uri, + ) + + assert "supersecret" not in caplog.text + assert "sslmode=require" not in caplog.text + assert "localhost:5432" in caplog.text + + +def test_redact_uri_for_log_removes_credentials_with_at_in_password() -> None: + uri = "postgresql://user:super@secret@localhost:5432/dbname" + + assert ( + service_factory._redact_uri_for_log(uri) + == "postgresql://localhost:5432/dbname" + ) + + +def test_redact_uri_for_log_preserves_host_when_no_credentials() -> None: + uri = "postgresql://localhost:5432/dbname?sslmode=require&password=secret" + + redacted = service_factory._redact_uri_for_log(uri) + + assert redacted.startswith("postgresql://localhost:5432/dbname?") + assert "require" not in redacted + assert "secret" not in redacted + assert "sslmode=" in redacted + assert "password=" in redacted + + +def test_redact_uri_for_log_redacts_when_parse_qsl_fails( + monkeypatch: pytest.MonkeyPatch, +) -> None: + def _raise_value_error(*_args, **_kwargs): + raise ValueError("bad query") + + monkeypatch.setattr(service_factory, "parse_qsl", _raise_value_error) + + uri = "postgresql://user:pass@localhost:5432/dbname?sslmode=require" + redacted = service_factory._redact_uri_for_log(uri) + + assert "pass" not in redacted + assert "require" not in redacted + assert redacted.endswith("?") + + +def test_redact_uri_for_log_escapes_crlf() -> None: + uri = ( + "postgresql://user:pass@localhost:5432/dbname\rINJECT\nINJECT" + "?sslmode=require" + ) + + redacted = service_factory._redact_uri_for_log(uri) + + assert "\r" not in redacted + assert "\n" not in redacted + assert "\\rINJECT\\nINJECT" in redacted + + +def test_redact_uri_for_log_returns_scheme_missing_without_separator() -> None: + assert ( + service_factory._redact_uri_for_log("user:pass@localhost:5432/dbname") + == "" + ) + + @pytest.mark.asyncio async def test_create_session_service_defaults_to_per_agent_sqlite( tmp_path: Path, @@ -88,7 +171,7 @@ async def test_create_session_service_respects_app_name_mapping( def test_create_session_service_fallbacks_to_database( tmp_path: Path, monkeypatch ): - registry = Mock() + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) registry.create_session_service.return_value = None monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) @@ -109,7 +192,7 @@ def test_create_session_service_fallbacks_to_database( def test_create_artifact_service_uses_registry(tmp_path: Path, monkeypatch): - registry = Mock() + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) expected = object() registry.create_artifact_service.return_value = expected monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) @@ -129,7 +212,7 @@ def test_create_artifact_service_uses_registry(tmp_path: Path, monkeypatch): def test_create_artifact_service_raises_on_unknown_scheme_when_strict( tmp_path: Path, monkeypatch ): - registry = Mock() + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) registry.create_artifact_service.return_value = None monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) @@ -142,7 +225,7 @@ def test_create_artifact_service_raises_on_unknown_scheme_when_strict( def test_create_memory_service_uses_registry(tmp_path: Path, monkeypatch): - registry = Mock() + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) expected = object() registry.create_memory_service.return_value = expected monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) @@ -170,7 +253,7 @@ def test_create_memory_service_defaults_to_in_memory(tmp_path: Path): def test_create_memory_service_raises_on_unknown_scheme( tmp_path: Path, monkeypatch ): - registry = Mock() + registry = mock.create_autospec(ServiceRegistry, instance=True, spec_set=True) registry.create_memory_service.return_value = None monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry)