You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: support regex for allowed origins
fixes https://github.com/google/adk-python/issues/3908 Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com> PiperOrigin-RevId: 845397350
This commit is contained in:
committed by
Copybara-Service
parent
b6f6dcbeb4
commit
2ea6e513cf
@@ -104,6 +104,36 @@ _EVAL_SET_FILE_EXTENSION = ".evalset.json"
|
||||
TAG_DEBUG = "Debug"
|
||||
TAG_EVALUATION = "Evaluation"
|
||||
|
||||
_REGEX_PREFIX = "regex:"
|
||||
|
||||
|
||||
def _parse_cors_origins(
|
||||
allow_origins: list[str],
|
||||
) -> tuple[list[str], Optional[str]]:
|
||||
"""Parse allow_origins into literal origins and a combined regex pattern.
|
||||
|
||||
Args:
|
||||
allow_origins: List of origin strings. Entries prefixed with 'regex:' are
|
||||
treated as regex patterns; all others are treated as literal origins.
|
||||
|
||||
Returns:
|
||||
A tuple of (literal_origins, combined_regex) where combined_regex is None
|
||||
if no regex patterns were provided, or a single pattern joining all regex
|
||||
patterns with '|'.
|
||||
"""
|
||||
literal_origins = []
|
||||
regex_patterns = []
|
||||
for origin in allow_origins:
|
||||
if origin.startswith(_REGEX_PREFIX):
|
||||
pattern = origin[len(_REGEX_PREFIX) :]
|
||||
if pattern:
|
||||
regex_patterns.append(pattern)
|
||||
else:
|
||||
literal_origins.append(origin)
|
||||
|
||||
combined_regex = "|".join(regex_patterns) if regex_patterns else None
|
||||
return literal_origins, combined_regex
|
||||
|
||||
|
||||
class ApiServerSpanExporter(export_lib.SpanExporter):
|
||||
|
||||
@@ -662,6 +692,8 @@ class AdkWebServer:
|
||||
Args:
|
||||
lifespan: The lifespan of the FastAPI app.
|
||||
allow_origins: The origins that are allowed to make cross-origin requests.
|
||||
Entries can be literal origins (e.g., 'https://example.com') or regex
|
||||
patterns prefixed with 'regex:' (e.g., 'regex:https://.*\\.example\\.com').
|
||||
web_assets_dir: The directory containing the web assets to serve.
|
||||
setup_observer: Callback for setting up the file system observer.
|
||||
tear_down_observer: Callback for cleaning up the file system observer.
|
||||
@@ -714,9 +746,11 @@ class AdkWebServer:
|
||||
app = FastAPI(lifespan=internal_lifespan)
|
||||
|
||||
if allow_origins:
|
||||
literal_origins, combined_regex = _parse_cors_origins(allow_origins)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=allow_origins,
|
||||
allow_origins=literal_origins,
|
||||
allow_origin_regex=combined_regex,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
|
||||
@@ -512,7 +512,8 @@ def to_cloud_run(
|
||||
with_ui: Whether to deploy with UI.
|
||||
verbosity: The verbosity level of the CLI.
|
||||
adk_version: The ADK version to use in Cloud Run.
|
||||
allow_origins: The list of allowed origins for the ADK api server.
|
||||
allow_origins: Origins to allow for CORS. Can be literal origins or regex
|
||||
patterns prefixed with 'regex:'.
|
||||
session_service_uri: The URI of the session service.
|
||||
artifact_service_uri: The URI of the artifact service.
|
||||
memory_service_uri: The URI of the memory service.
|
||||
@@ -961,7 +962,8 @@ def to_gke(
|
||||
with_ui: Whether to deploy with UI.
|
||||
log_level: The logging level.
|
||||
adk_version: The ADK version to use in GKE.
|
||||
allow_origins: The list of allowed origins for the ADK api server.
|
||||
allow_origins: Origins to allow for CORS. Can be literal origins or regex
|
||||
patterns prefixed with 'regex:'.
|
||||
session_service_uri: The URI of the session service.
|
||||
artifact_service_uri: The URI of the artifact service.
|
||||
memory_service_uri: The URI of the memory service.
|
||||
|
||||
@@ -1002,7 +1002,11 @@ def fast_api_common_options():
|
||||
)
|
||||
@click.option(
|
||||
"--allow_origins",
|
||||
help="Optional. Any additional origins to allow for CORS.",
|
||||
help=(
|
||||
"Optional. Origins to allow for CORS. Can be literal origins"
|
||||
" (e.g., 'https://example.com') or regex patterns prefixed with"
|
||||
" 'regex:' (e.g., 'regex:https://.*\\.example\\.com')."
|
||||
),
|
||||
multiple=True,
|
||||
)
|
||||
@click.option(
|
||||
@@ -1390,7 +1394,11 @@ def cli_api_server(
|
||||
)
|
||||
@click.option(
|
||||
"--allow_origins",
|
||||
help="Optional. Any additional origins to allow for CORS.",
|
||||
help=(
|
||||
"Optional. Origins to allow for CORS. Can be literal origins"
|
||||
" (e.g., 'https://example.com') or regex patterns prefixed with"
|
||||
" 'regex:' (e.g., 'regex:https://.*\\.example\\.com')."
|
||||
),
|
||||
multiple=True,
|
||||
)
|
||||
# TODO: Add eval_storage_uri option back when evals are supported in Cloud Run.
|
||||
|
||||
@@ -0,0 +1,182 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for CORS configuration with regex prefix support."""
|
||||
|
||||
from unittest import mock
|
||||
|
||||
from google.adk.artifacts.base_artifact_service import BaseArtifactService
|
||||
from google.adk.auth.credential_service.base_credential_service import BaseCredentialService
|
||||
from google.adk.cli.adk_web_server import _parse_cors_origins
|
||||
from google.adk.cli.adk_web_server import AdkWebServer
|
||||
from google.adk.cli.utils.base_agent_loader import BaseAgentLoader
|
||||
from google.adk.evaluation.eval_set_results_manager import EvalSetResultsManager
|
||||
from google.adk.evaluation.eval_sets_manager import EvalSetsManager
|
||||
from google.adk.memory.base_memory_service import BaseMemoryService
|
||||
from google.adk.sessions.base_session_service import BaseSessionService
|
||||
import pytest
|
||||
|
||||
|
||||
class MockAgentLoader:
|
||||
"""Mock agent loader for testing."""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def load_agent(self, app_name):
|
||||
del self, app_name
|
||||
return mock.MagicMock()
|
||||
|
||||
def list_agents(self):
|
||||
del self
|
||||
return ["test_app"]
|
||||
|
||||
def list_agents_detailed(self):
|
||||
del self
|
||||
return []
|
||||
|
||||
|
||||
def create_adk_web_server():
|
||||
"""Create an AdkWebServer instance for testing."""
|
||||
return AdkWebServer(
|
||||
agent_loader=MockAgentLoader(),
|
||||
session_service=mock.create_autospec(BaseSessionService, instance=True),
|
||||
memory_service=mock.create_autospec(BaseMemoryService, instance=True),
|
||||
artifact_service=mock.create_autospec(BaseArtifactService, instance=True),
|
||||
credential_service=mock.create_autospec(
|
||||
BaseCredentialService, instance=True
|
||||
),
|
||||
eval_sets_manager=mock.create_autospec(EvalSetsManager, instance=True),
|
||||
eval_set_results_manager=mock.create_autospec(
|
||||
EvalSetResultsManager, instance=True
|
||||
),
|
||||
agents_dir=".",
|
||||
)
|
||||
|
||||
|
||||
def _get_cors_middleware(app):
|
||||
"""Extract CORSMiddleware from app's middleware stack.
|
||||
|
||||
Returns:
|
||||
The CORSMiddleware instance, or None if not found.
|
||||
"""
|
||||
for middleware in app.user_middleware:
|
||||
if middleware.cls.__name__ == "CORSMiddleware":
|
||||
return middleware
|
||||
return None
|
||||
|
||||
|
||||
CORS_ORIGINS_TEST_CASES = [
|
||||
# Literal origins only
|
||||
(
|
||||
["https://example.com", "https://test.com"],
|
||||
["https://example.com", "https://test.com"],
|
||||
None,
|
||||
),
|
||||
# Regex patterns only
|
||||
(
|
||||
[
|
||||
"regex:https://.*\\.example\\.com",
|
||||
"regex:https://.*\\.test\\.com",
|
||||
],
|
||||
[],
|
||||
"https://.*\\.example\\.com|https://.*\\.test\\.com",
|
||||
),
|
||||
# Mixed literal and regex
|
||||
(
|
||||
[
|
||||
"https://example.com",
|
||||
"regex:https://.*\\.subdomain\\.com",
|
||||
"https://test.com",
|
||||
"regex:https://tenant-.*\\.myapp\\.com",
|
||||
],
|
||||
["https://example.com", "https://test.com"],
|
||||
"https://.*\\.subdomain\\.com|https://tenant-.*\\.myapp\\.com",
|
||||
),
|
||||
# Wildcard origin
|
||||
(["*"], ["*"], None),
|
||||
# Single regex
|
||||
(
|
||||
["regex:https://.*\\.example\\.com"],
|
||||
[],
|
||||
"https://.*\\.example\\.com",
|
||||
),
|
||||
]
|
||||
|
||||
CORS_ORIGINS_TEST_IDS = [
|
||||
"literal_only",
|
||||
"regex_only",
|
||||
"mixed",
|
||||
"wildcard",
|
||||
"single_regex",
|
||||
]
|
||||
|
||||
|
||||
class TestParseCorsOrigins:
|
||||
"""Tests for the _parse_cors_origins helper function."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"allow_origins,expected_literal,expected_regex",
|
||||
CORS_ORIGINS_TEST_CASES,
|
||||
ids=CORS_ORIGINS_TEST_IDS,
|
||||
)
|
||||
def test_parse_cors_origins(
|
||||
self, allow_origins, expected_literal, expected_regex
|
||||
):
|
||||
"""Test parsing of allow_origins into literal and regex components."""
|
||||
literal_origins, combined_regex = _parse_cors_origins(allow_origins)
|
||||
assert literal_origins == expected_literal
|
||||
assert combined_regex == expected_regex
|
||||
|
||||
|
||||
class TestCorsMiddlewareConfiguration:
|
||||
"""Tests for CORS middleware configuration in AdkWebServer."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"allow_origins,expected_literal,expected_regex",
|
||||
CORS_ORIGINS_TEST_CASES,
|
||||
ids=CORS_ORIGINS_TEST_IDS,
|
||||
)
|
||||
def test_cors_middleware_configuration(
|
||||
self, allow_origins, expected_literal, expected_regex
|
||||
):
|
||||
"""Test CORS middleware is configured correctly with various origin types."""
|
||||
server = create_adk_web_server()
|
||||
app = server.get_fast_api_app(
|
||||
allow_origins=allow_origins,
|
||||
setup_observer=lambda _o, _s: None,
|
||||
tear_down_observer=lambda _o, _s: None,
|
||||
)
|
||||
|
||||
cors_middleware = _get_cors_middleware(app)
|
||||
assert cors_middleware is not None
|
||||
assert cors_middleware.kwargs["allow_origins"] == expected_literal
|
||||
assert cors_middleware.kwargs["allow_origin_regex"] == expected_regex
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"allow_origins",
|
||||
[None, []],
|
||||
ids=["none", "empty_list"],
|
||||
)
|
||||
def test_cors_middleware_not_added_when_no_origins(self, allow_origins):
|
||||
"""Test that no CORS middleware is added when allow_origins is None or empty."""
|
||||
server = create_adk_web_server()
|
||||
app = server.get_fast_api_app(
|
||||
allow_origins=allow_origins,
|
||||
setup_observer=lambda _o, _s: None,
|
||||
tear_down_observer=lambda _o, _s: None,
|
||||
)
|
||||
|
||||
cors_middleware = _get_cors_middleware(app)
|
||||
assert cors_middleware is None
|
||||
Reference in New Issue
Block a user