feat: Introduce OAuth2DiscoveryManager to fetch metadata needed for OAuth

This is the first step to bring ADK to compliance with MCP Authorization Spec.

PiperOrigin-RevId: 811177152
This commit is contained in:
Google Team Member
2025-09-24 21:52:57 -07:00
committed by Copybara-Service
parent 5a485b01cd
commit 2a2da0fe03
2 changed files with 433 additions and 0 deletions
+148
View File
@@ -0,0 +1,148 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
import logging
from typing import List
from typing import Optional
from urllib.parse import urlparse
import httpx
from pydantic import BaseModel
from pydantic import ValidationError
from ..utils.feature_decorator import experimental
logger = logging.getLogger("google_adk." + __name__)
@experimental
class AuthorizationServerMetadata(BaseModel):
"""Represents the OAuth2 authorization server metadata per RFC8414."""
issuer: str
authorization_endpoint: str
token_endpoint: str
scopes_supported: Optional[List[str]] = None
registration_endpoint: Optional[str] = None
@experimental
class ProtectedResourceMetadata(BaseModel):
"""Represents the OAuth2 protected resource metadata per RFC9728."""
resource: str
authorization_servers: List[str] = []
@experimental
class OAuth2DiscoveryManager:
"""Implements Metadata discovery for OAuth2 following RFC8414 and RFC9728."""
async def discover_auth_server_metadata(
self, issuer_url: str
) -> Optional[AuthorizationServerMetadata]:
"""Discovers the OAuth2 authorization server metadata."""
try:
parsed_url = urlparse(issuer_url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
path = parsed_url.path
except ValueError as e:
logger.warning("Failed to parse issuer_url %s: %s", issuer_url, e)
return None
# Try the standard well-known endpoints in order.
if path and path != "/":
endpoints_to_try = [
# 1. OAuth 2.0 Authorization Server Metadata with path insertion
f"{base_url}/.well-known/oauth-authorization-server{path}",
# 2. OpenID Connect Discovery 1.0 with path insertion
f"{base_url}/.well-known/openid-configuration{path}",
# 3. OpenID Connect Discovery 1.0 with path appending
f"{base_url}{path}/.well-known/openid-configuration",
]
else:
endpoints_to_try = [
# 1. OAuth 2.0 Authorization Server Metadata
f"{base_url}/.well-known/oauth-authorization-server",
# 2. OpenID Connect Discovery 1.0
f"{base_url}/.well-known/openid-configuration",
]
async with httpx.AsyncClient() as client:
for endpoint in endpoints_to_try:
try:
response = await client.get(endpoint, timeout=5)
response.raise_for_status()
metadata = AuthorizationServerMetadata.model_validate(response.json())
# Validate issuer to defend against MIX-UP attacks
if metadata.issuer == issuer_url.rstrip("/"):
return metadata
else:
logger.warning(
"Issuer in metadata %s does not match issuer_url %s",
metadata.issuer,
issuer_url,
)
except httpx.HTTPError as e:
logger.debug("Failed to fetch metadata from %s: %s", endpoint, e)
except (json.decoder.JSONDecodeError, ValidationError) as e:
logger.debug("Failed to parse metadata from %s: %s", endpoint, e)
return None
async def discover_resource_metadata(
self, resource_url: str
) -> Optional[ProtectedResourceMetadata]:
"""Discovers the OAuth2 protected resource metadata."""
try:
parsed_url = urlparse(resource_url)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
path = parsed_url.path
except ValueError as e:
logger.warning("Failed to parse resource_url %s: %s", resource_url, e)
return None
if path and path != "/":
well_known_endpoint = (
f"{base_url}/.well-known/oauth-protected-resource{path}"
)
else:
well_known_endpoint = f"{base_url}/.well-known/oauth-protected-resource"
async with httpx.AsyncClient() as client:
try:
response = await client.get(well_known_endpoint, timeout=5)
response.raise_for_status()
metadata = ProtectedResourceMetadata.model_validate(response.json())
# Validate resource to defend against MIX-UP attacks
if metadata.resource == resource_url.rstrip("/"):
return metadata
else:
logger.warning(
"Resource in metadata %s does not match resource_url %s",
metadata.resource,
resource_url,
)
except httpx.HTTPError as e:
logger.debug(
"Failed to fetch metadata from %s: %s", well_known_endpoint, e
)
except (json.decoder.JSONDecodeError, ValidationError) as e:
logger.debug(
"Failed to parse metadata from %s: %s", well_known_endpoint, e
)
return None
@@ -0,0 +1,285 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from unittest.mock import call
from unittest.mock import Mock
from unittest.mock import patch
from google.adk.auth.oauth2_discovery import AuthorizationServerMetadata
from google.adk.auth.oauth2_discovery import OAuth2DiscoveryManager
from google.adk.auth.oauth2_discovery import ProtectedResourceMetadata
import httpx
import pytest
class TestOAuth2Discovery:
"""Tests for the OAuth2DiscoveryManager class."""
@pytest.fixture
def auth_server_metadata(self):
"""Create AuthorizationServerMetadata object."""
return AuthorizationServerMetadata(
issuer="https://auth.example.com",
authorization_endpoint="https://auth.example.com/authorize",
token_endpoint="https://auth.example.com/token",
scopes_supported=["read", "write"],
)
@pytest.fixture
def resource_metadata(self):
"""Create ProtectedResourceMetadata object."""
return ProtectedResourceMetadata(
resource="https://resource.example.com",
authorization_servers=["https://auth.example.com"],
)
@pytest.fixture
def mock_failed_response(self):
"""Create a mock HTTP response with a failure status."""
response = Mock()
response.raise_for_status.side_effect = httpx.HTTPError("Failed")
return response
@pytest.fixture
def mock_empty_response(self):
"""Create a mock HTTP response with an empty JSON body."""
response = Mock()
response.json = lambda: {}
return response
@pytest.fixture
def mock_invalid_json_response(self):
"""Create a mock HTTP response with an invalid JSON body."""
response = Mock()
response.json.side_effect = json.decoder.JSONDecodeError(
"Invalid JSON", "invalid_json", 0
)
return response
def mock_success_response(self, json_data):
"""Create a mock HTTP successful response with auth server metadata."""
response = Mock()
response.json = json_data.model_dump
return response
@patch("httpx.AsyncClient.get")
@pytest.mark.asyncio
async def test_discover_auth_server_metadata_failed(
self,
mock_get,
mock_failed_response,
):
"""Test discovering auth server metadata with failed response."""
mock_get.side_effect = mock_failed_response
discovery_manager = OAuth2DiscoveryManager()
result = await discovery_manager.discover_auth_server_metadata(
"https://auth.example.com"
)
assert not result
mock_get.assert_has_calls([
call(
"https://auth.example.com/.well-known/oauth-authorization-server",
timeout=5,
),
call(
"https://auth.example.com/.well-known/openid-configuration",
timeout=5,
),
])
@pytest.mark.asyncio
async def test_discover_metadata_invalid_url(self):
"""Test discovering resource/auth metadata with an invalid URL."""
discovery_manager = OAuth2DiscoveryManager()
result = await discovery_manager.discover_auth_server_metadata("bad_url")
assert not result
result = await discovery_manager.discover_resource_metadata("bad_url")
assert not result
@patch("httpx.AsyncClient.get")
@pytest.mark.asyncio
async def test_discover_auth_server_metadata_without_path(
self,
mock_get,
auth_server_metadata,
mock_empty_response,
):
"""Test discovering auth server metadata with an issuer URL without a path."""
mock_get.side_effect = [
mock_empty_response,
self.mock_success_response(auth_server_metadata),
]
discovery_manager = OAuth2DiscoveryManager()
result = await discovery_manager.discover_auth_server_metadata(
"https://auth.example.com/"
)
assert result == auth_server_metadata
mock_get.assert_has_calls([
call(
"https://auth.example.com/.well-known/oauth-authorization-server",
timeout=5,
),
call(
"https://auth.example.com/.well-known/openid-configuration",
timeout=5,
),
])
@patch("httpx.AsyncClient.get")
@pytest.mark.asyncio
async def test_discover_auth_server_metadata_with_path(
self,
mock_get,
auth_server_metadata,
mock_failed_response,
mock_invalid_json_response,
):
"""Test discovering auth server metadata with an issuer URL with a path."""
auth_server_metadata.issuer = "https://auth.example.com/oauth"
mock_get.side_effect = [
mock_failed_response,
mock_invalid_json_response,
self.mock_success_response(auth_server_metadata),
]
discovery_manager = OAuth2DiscoveryManager()
result = await discovery_manager.discover_auth_server_metadata(
"https://auth.example.com/oauth"
)
assert result == auth_server_metadata
mock_get.assert_has_calls([
call(
"https://auth.example.com/.well-known/oauth-authorization-server/oauth",
timeout=5,
),
call(
"https://auth.example.com/.well-known/openid-configuration/oauth",
timeout=5,
),
call(
"https://auth.example.com/oauth/.well-known/openid-configuration",
timeout=5,
),
])
@patch("httpx.AsyncClient.get")
@pytest.mark.asyncio
async def test_discover_auth_server_metadata_discard_mismatched_issuer(
self,
mock_get,
auth_server_metadata,
):
"""Test discover_auth_server_metadata() discards response with mismatched issuer."""
bad_auth_server_metadata = auth_server_metadata.model_copy(
update={"issuer": "https://bad.example.com"}
)
mock_get.side_effect = [
self.mock_success_response(bad_auth_server_metadata),
self.mock_success_response(auth_server_metadata),
]
discovery_manager = OAuth2DiscoveryManager()
result = await discovery_manager.discover_auth_server_metadata(
"https://auth.example.com"
)
assert result == auth_server_metadata
mock_get.assert_has_calls([
call(
"https://auth.example.com/.well-known/oauth-authorization-server",
timeout=5,
),
call(
"https://auth.example.com/.well-known/openid-configuration",
timeout=5,
),
])
@patch("httpx.AsyncClient.get")
@pytest.mark.asyncio
async def test_discover_resource_metadata_failed(
self,
mock_get,
mock_failed_response,
):
"""Test discovering resource metadata fails."""
mock_get.return_value = mock_failed_response
discovery_manager = OAuth2DiscoveryManager()
result = await discovery_manager.discover_resource_metadata(
"https://resource.example.com"
)
assert not result
mock_get.assert_called_once_with(
"https://resource.example.com/.well-known/oauth-protected-resource",
timeout=5,
)
@patch("httpx.AsyncClient.get")
@pytest.mark.asyncio
async def test_discover_resource_metadata_without_path(
self, mock_get, resource_metadata
):
"""Test discovering resource metadata with a resource URL without a path."""
mock_get.return_value = self.mock_success_response(resource_metadata)
discovery_manager = OAuth2DiscoveryManager()
result = await discovery_manager.discover_resource_metadata(
"https://resource.example.com/"
)
assert result == resource_metadata
mock_get.assert_called_once_with(
"https://resource.example.com/.well-known/oauth-protected-resource",
timeout=5,
)
@patch("httpx.AsyncClient.get")
@pytest.mark.asyncio
async def test_discover_resource_metadata_with_path(
self, mock_get, resource_metadata
):
"""Test discovering resource metadata with a resource URL with a path."""
resource_metadata.resource = "https://resource.example.com/tenant1"
mock_get.return_value = self.mock_success_response(resource_metadata)
discovery_manager = OAuth2DiscoveryManager()
result = await discovery_manager.discover_resource_metadata(
"https://resource.example.com/tenant1"
)
assert result == resource_metadata
mock_get.assert_called_once_with(
"https://resource.example.com/.well-known/oauth-protected-resource/tenant1",
timeout=5,
)
@patch("httpx.AsyncClient.get")
@pytest.mark.asyncio
async def test_discover_resource_metadata_discard_mismatched_resource(
self,
mock_get,
resource_metadata,
):
"""Test discover_resource_metadata() discards response with mismatched resource."""
resource_metadata.resource = "https://bad.example.com"
mock_get.return_value = self.mock_success_response(resource_metadata)
discovery_manager = OAuth2DiscoveryManager()
result = await discovery_manager.discover_resource_metadata(
"https://resource.example.com"
)
assert not result
mock_get.assert_called_once_with(
"https://resource.example.com/.well-known/oauth-protected-resource",
timeout=5,
)