You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Introduce OAuth2DiscoveryManager to fetch metadata needed for OAuth
This is the first step to bring ADK to compliance with MCP Authorization Spec. PiperOrigin-RevId: 811177152
This commit is contained in:
committed by
Copybara-Service
parent
5a485b01cd
commit
2a2da0fe03
@@ -0,0 +1,148 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ..utils.feature_decorator import experimental
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
|
||||
@experimental
|
||||
class AuthorizationServerMetadata(BaseModel):
|
||||
"""Represents the OAuth2 authorization server metadata per RFC8414."""
|
||||
|
||||
issuer: str
|
||||
authorization_endpoint: str
|
||||
token_endpoint: str
|
||||
scopes_supported: Optional[List[str]] = None
|
||||
registration_endpoint: Optional[str] = None
|
||||
|
||||
|
||||
@experimental
|
||||
class ProtectedResourceMetadata(BaseModel):
|
||||
"""Represents the OAuth2 protected resource metadata per RFC9728."""
|
||||
|
||||
resource: str
|
||||
authorization_servers: List[str] = []
|
||||
|
||||
|
||||
@experimental
|
||||
class OAuth2DiscoveryManager:
|
||||
"""Implements Metadata discovery for OAuth2 following RFC8414 and RFC9728."""
|
||||
|
||||
async def discover_auth_server_metadata(
|
||||
self, issuer_url: str
|
||||
) -> Optional[AuthorizationServerMetadata]:
|
||||
"""Discovers the OAuth2 authorization server metadata."""
|
||||
try:
|
||||
parsed_url = urlparse(issuer_url)
|
||||
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
path = parsed_url.path
|
||||
except ValueError as e:
|
||||
logger.warning("Failed to parse issuer_url %s: %s", issuer_url, e)
|
||||
return None
|
||||
|
||||
# Try the standard well-known endpoints in order.
|
||||
if path and path != "/":
|
||||
endpoints_to_try = [
|
||||
# 1. OAuth 2.0 Authorization Server Metadata with path insertion
|
||||
f"{base_url}/.well-known/oauth-authorization-server{path}",
|
||||
# 2. OpenID Connect Discovery 1.0 with path insertion
|
||||
f"{base_url}/.well-known/openid-configuration{path}",
|
||||
# 3. OpenID Connect Discovery 1.0 with path appending
|
||||
f"{base_url}{path}/.well-known/openid-configuration",
|
||||
]
|
||||
else:
|
||||
endpoints_to_try = [
|
||||
# 1. OAuth 2.0 Authorization Server Metadata
|
||||
f"{base_url}/.well-known/oauth-authorization-server",
|
||||
# 2. OpenID Connect Discovery 1.0
|
||||
f"{base_url}/.well-known/openid-configuration",
|
||||
]
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
for endpoint in endpoints_to_try:
|
||||
try:
|
||||
response = await client.get(endpoint, timeout=5)
|
||||
response.raise_for_status()
|
||||
metadata = AuthorizationServerMetadata.model_validate(response.json())
|
||||
# Validate issuer to defend against MIX-UP attacks
|
||||
if metadata.issuer == issuer_url.rstrip("/"):
|
||||
return metadata
|
||||
else:
|
||||
logger.warning(
|
||||
"Issuer in metadata %s does not match issuer_url %s",
|
||||
metadata.issuer,
|
||||
issuer_url,
|
||||
)
|
||||
except httpx.HTTPError as e:
|
||||
logger.debug("Failed to fetch metadata from %s: %s", endpoint, e)
|
||||
except (json.decoder.JSONDecodeError, ValidationError) as e:
|
||||
logger.debug("Failed to parse metadata from %s: %s", endpoint, e)
|
||||
return None
|
||||
|
||||
async def discover_resource_metadata(
|
||||
self, resource_url: str
|
||||
) -> Optional[ProtectedResourceMetadata]:
|
||||
"""Discovers the OAuth2 protected resource metadata."""
|
||||
try:
|
||||
parsed_url = urlparse(resource_url)
|
||||
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
|
||||
path = parsed_url.path
|
||||
except ValueError as e:
|
||||
logger.warning("Failed to parse resource_url %s: %s", resource_url, e)
|
||||
return None
|
||||
|
||||
if path and path != "/":
|
||||
well_known_endpoint = (
|
||||
f"{base_url}/.well-known/oauth-protected-resource{path}"
|
||||
)
|
||||
else:
|
||||
well_known_endpoint = f"{base_url}/.well-known/oauth-protected-resource"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.get(well_known_endpoint, timeout=5)
|
||||
response.raise_for_status()
|
||||
metadata = ProtectedResourceMetadata.model_validate(response.json())
|
||||
# Validate resource to defend against MIX-UP attacks
|
||||
if metadata.resource == resource_url.rstrip("/"):
|
||||
return metadata
|
||||
else:
|
||||
logger.warning(
|
||||
"Resource in metadata %s does not match resource_url %s",
|
||||
metadata.resource,
|
||||
resource_url,
|
||||
)
|
||||
except httpx.HTTPError as e:
|
||||
logger.debug(
|
||||
"Failed to fetch metadata from %s: %s", well_known_endpoint, e
|
||||
)
|
||||
except (json.decoder.JSONDecodeError, ValidationError) as e:
|
||||
logger.debug(
|
||||
"Failed to parse metadata from %s: %s", well_known_endpoint, e
|
||||
)
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,285 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from unittest.mock import call
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
|
||||
from google.adk.auth.oauth2_discovery import AuthorizationServerMetadata
|
||||
from google.adk.auth.oauth2_discovery import OAuth2DiscoveryManager
|
||||
from google.adk.auth.oauth2_discovery import ProtectedResourceMetadata
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
|
||||
class TestOAuth2Discovery:
|
||||
"""Tests for the OAuth2DiscoveryManager class."""
|
||||
|
||||
@pytest.fixture
|
||||
def auth_server_metadata(self):
|
||||
"""Create AuthorizationServerMetadata object."""
|
||||
return AuthorizationServerMetadata(
|
||||
issuer="https://auth.example.com",
|
||||
authorization_endpoint="https://auth.example.com/authorize",
|
||||
token_endpoint="https://auth.example.com/token",
|
||||
scopes_supported=["read", "write"],
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def resource_metadata(self):
|
||||
"""Create ProtectedResourceMetadata object."""
|
||||
return ProtectedResourceMetadata(
|
||||
resource="https://resource.example.com",
|
||||
authorization_servers=["https://auth.example.com"],
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_failed_response(self):
|
||||
"""Create a mock HTTP response with a failure status."""
|
||||
response = Mock()
|
||||
response.raise_for_status.side_effect = httpx.HTTPError("Failed")
|
||||
return response
|
||||
|
||||
@pytest.fixture
|
||||
def mock_empty_response(self):
|
||||
"""Create a mock HTTP response with an empty JSON body."""
|
||||
response = Mock()
|
||||
response.json = lambda: {}
|
||||
return response
|
||||
|
||||
@pytest.fixture
|
||||
def mock_invalid_json_response(self):
|
||||
"""Create a mock HTTP response with an invalid JSON body."""
|
||||
response = Mock()
|
||||
response.json.side_effect = json.decoder.JSONDecodeError(
|
||||
"Invalid JSON", "invalid_json", 0
|
||||
)
|
||||
return response
|
||||
|
||||
def mock_success_response(self, json_data):
|
||||
"""Create a mock HTTP successful response with auth server metadata."""
|
||||
response = Mock()
|
||||
response.json = json_data.model_dump
|
||||
return response
|
||||
|
||||
@patch("httpx.AsyncClient.get")
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_auth_server_metadata_failed(
|
||||
self,
|
||||
mock_get,
|
||||
mock_failed_response,
|
||||
):
|
||||
"""Test discovering auth server metadata with failed response."""
|
||||
|
||||
mock_get.side_effect = mock_failed_response
|
||||
discovery_manager = OAuth2DiscoveryManager()
|
||||
result = await discovery_manager.discover_auth_server_metadata(
|
||||
"https://auth.example.com"
|
||||
)
|
||||
assert not result
|
||||
mock_get.assert_has_calls([
|
||||
call(
|
||||
"https://auth.example.com/.well-known/oauth-authorization-server",
|
||||
timeout=5,
|
||||
),
|
||||
call(
|
||||
"https://auth.example.com/.well-known/openid-configuration",
|
||||
timeout=5,
|
||||
),
|
||||
])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_metadata_invalid_url(self):
|
||||
"""Test discovering resource/auth metadata with an invalid URL."""
|
||||
discovery_manager = OAuth2DiscoveryManager()
|
||||
result = await discovery_manager.discover_auth_server_metadata("bad_url")
|
||||
assert not result
|
||||
result = await discovery_manager.discover_resource_metadata("bad_url")
|
||||
assert not result
|
||||
|
||||
@patch("httpx.AsyncClient.get")
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_auth_server_metadata_without_path(
|
||||
self,
|
||||
mock_get,
|
||||
auth_server_metadata,
|
||||
mock_empty_response,
|
||||
):
|
||||
"""Test discovering auth server metadata with an issuer URL without a path."""
|
||||
|
||||
mock_get.side_effect = [
|
||||
mock_empty_response,
|
||||
self.mock_success_response(auth_server_metadata),
|
||||
]
|
||||
discovery_manager = OAuth2DiscoveryManager()
|
||||
result = await discovery_manager.discover_auth_server_metadata(
|
||||
"https://auth.example.com/"
|
||||
)
|
||||
assert result == auth_server_metadata
|
||||
mock_get.assert_has_calls([
|
||||
call(
|
||||
"https://auth.example.com/.well-known/oauth-authorization-server",
|
||||
timeout=5,
|
||||
),
|
||||
call(
|
||||
"https://auth.example.com/.well-known/openid-configuration",
|
||||
timeout=5,
|
||||
),
|
||||
])
|
||||
|
||||
@patch("httpx.AsyncClient.get")
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_auth_server_metadata_with_path(
|
||||
self,
|
||||
mock_get,
|
||||
auth_server_metadata,
|
||||
mock_failed_response,
|
||||
mock_invalid_json_response,
|
||||
):
|
||||
"""Test discovering auth server metadata with an issuer URL with a path."""
|
||||
|
||||
auth_server_metadata.issuer = "https://auth.example.com/oauth"
|
||||
mock_get.side_effect = [
|
||||
mock_failed_response,
|
||||
mock_invalid_json_response,
|
||||
self.mock_success_response(auth_server_metadata),
|
||||
]
|
||||
discovery_manager = OAuth2DiscoveryManager()
|
||||
result = await discovery_manager.discover_auth_server_metadata(
|
||||
"https://auth.example.com/oauth"
|
||||
)
|
||||
assert result == auth_server_metadata
|
||||
mock_get.assert_has_calls([
|
||||
call(
|
||||
"https://auth.example.com/.well-known/oauth-authorization-server/oauth",
|
||||
timeout=5,
|
||||
),
|
||||
call(
|
||||
"https://auth.example.com/.well-known/openid-configuration/oauth",
|
||||
timeout=5,
|
||||
),
|
||||
call(
|
||||
"https://auth.example.com/oauth/.well-known/openid-configuration",
|
||||
timeout=5,
|
||||
),
|
||||
])
|
||||
|
||||
@patch("httpx.AsyncClient.get")
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_auth_server_metadata_discard_mismatched_issuer(
|
||||
self,
|
||||
mock_get,
|
||||
auth_server_metadata,
|
||||
):
|
||||
"""Test discover_auth_server_metadata() discards response with mismatched issuer."""
|
||||
|
||||
bad_auth_server_metadata = auth_server_metadata.model_copy(
|
||||
update={"issuer": "https://bad.example.com"}
|
||||
)
|
||||
mock_get.side_effect = [
|
||||
self.mock_success_response(bad_auth_server_metadata),
|
||||
self.mock_success_response(auth_server_metadata),
|
||||
]
|
||||
discovery_manager = OAuth2DiscoveryManager()
|
||||
result = await discovery_manager.discover_auth_server_metadata(
|
||||
"https://auth.example.com"
|
||||
)
|
||||
assert result == auth_server_metadata
|
||||
mock_get.assert_has_calls([
|
||||
call(
|
||||
"https://auth.example.com/.well-known/oauth-authorization-server",
|
||||
timeout=5,
|
||||
),
|
||||
call(
|
||||
"https://auth.example.com/.well-known/openid-configuration",
|
||||
timeout=5,
|
||||
),
|
||||
])
|
||||
|
||||
@patch("httpx.AsyncClient.get")
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_resource_metadata_failed(
|
||||
self,
|
||||
mock_get,
|
||||
mock_failed_response,
|
||||
):
|
||||
"""Test discovering resource metadata fails."""
|
||||
|
||||
mock_get.return_value = mock_failed_response
|
||||
discovery_manager = OAuth2DiscoveryManager()
|
||||
result = await discovery_manager.discover_resource_metadata(
|
||||
"https://resource.example.com"
|
||||
)
|
||||
assert not result
|
||||
mock_get.assert_called_once_with(
|
||||
"https://resource.example.com/.well-known/oauth-protected-resource",
|
||||
timeout=5,
|
||||
)
|
||||
|
||||
@patch("httpx.AsyncClient.get")
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_resource_metadata_without_path(
|
||||
self, mock_get, resource_metadata
|
||||
):
|
||||
"""Test discovering resource metadata with a resource URL without a path."""
|
||||
mock_get.return_value = self.mock_success_response(resource_metadata)
|
||||
discovery_manager = OAuth2DiscoveryManager()
|
||||
result = await discovery_manager.discover_resource_metadata(
|
||||
"https://resource.example.com/"
|
||||
)
|
||||
assert result == resource_metadata
|
||||
mock_get.assert_called_once_with(
|
||||
"https://resource.example.com/.well-known/oauth-protected-resource",
|
||||
timeout=5,
|
||||
)
|
||||
|
||||
@patch("httpx.AsyncClient.get")
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_resource_metadata_with_path(
|
||||
self, mock_get, resource_metadata
|
||||
):
|
||||
"""Test discovering resource metadata with a resource URL with a path."""
|
||||
resource_metadata.resource = "https://resource.example.com/tenant1"
|
||||
mock_get.return_value = self.mock_success_response(resource_metadata)
|
||||
discovery_manager = OAuth2DiscoveryManager()
|
||||
result = await discovery_manager.discover_resource_metadata(
|
||||
"https://resource.example.com/tenant1"
|
||||
)
|
||||
assert result == resource_metadata
|
||||
mock_get.assert_called_once_with(
|
||||
"https://resource.example.com/.well-known/oauth-protected-resource/tenant1",
|
||||
timeout=5,
|
||||
)
|
||||
|
||||
@patch("httpx.AsyncClient.get")
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_resource_metadata_discard_mismatched_resource(
|
||||
self,
|
||||
mock_get,
|
||||
resource_metadata,
|
||||
):
|
||||
"""Test discover_resource_metadata() discards response with mismatched resource."""
|
||||
|
||||
resource_metadata.resource = "https://bad.example.com"
|
||||
mock_get.return_value = self.mock_success_response(resource_metadata)
|
||||
discovery_manager = OAuth2DiscoveryManager()
|
||||
result = await discovery_manager.discover_resource_metadata(
|
||||
"https://resource.example.com"
|
||||
)
|
||||
assert not result
|
||||
mock_get.assert_called_once_with(
|
||||
"https://resource.example.com/.well-known/oauth-protected-resource",
|
||||
timeout=5,
|
||||
)
|
||||
Reference in New Issue
Block a user