feat: Support authentication for MCP tool listing

Currently only tool calling supports MCP auth. This refactors the auth logic into a auth_utils file and uses it for tool listing as well. Fixes https://github.com/google/adk-python/issues/2168.

Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com>
PiperOrigin-RevId: 859201722
This commit is contained in:
Xiang (Sean) Zhou
2026-01-21 11:57:05 -08:00
committed by Copybara-Service
parent d62f9c896c
commit e3d542a5ba
6 changed files with 386 additions and 490 deletions
@@ -1,240 +0,0 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
from unittest.mock import patch
from fastapi.openapi import models as openapi_models
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import HttpAuth
from google.adk.auth.auth_credential import HttpCredentials
from google.adk.auth.auth_credential import OAuth2Auth
from google.adk.auth.auth_credential import ServiceAccount
from google.adk.auth.auth_schemes import AuthSchemeType
from google.adk.tools.mcp_tool import mcp_auth_utils
import pytest
def test_get_mcp_auth_headers_no_credential():
"""Test header generation with no credentials."""
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=None
)
assert headers is None
def test_get_mcp_auth_headers_no_auth_scheme():
"""Test header generation with no auth_scheme."""
credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(access_token="test_token"),
)
with patch.object(mcp_auth_utils, "logger") as mock_logger:
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=None, credential=credential
)
assert headers == {"Authorization": "Bearer test_token"}
def test_get_mcp_auth_headers_oauth2():
"""Test header generation for OAuth2 credentials."""
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(access_token="test_token"),
)
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
assert headers == {"Authorization": "Bearer test_token"}
def test_get_mcp_auth_headers_http_bearer():
"""Test header generation for HTTP Bearer credentials."""
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer", credentials=HttpCredentials(token="bearer_token")
),
)
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
assert headers == {"Authorization": "Bearer bearer_token"}
def test_get_mcp_auth_headers_http_basic():
"""Test header generation for HTTP Basic credentials."""
auth_scheme = openapi_models.HTTPBase(scheme="basic")
credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="basic",
credentials=HttpCredentials(username="user", password="pass"),
),
)
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
expected_encoded = base64.b64encode(b"user:pass").decode()
assert headers == {"Authorization": f"Basic {expected_encoded}"}
def test_get_mcp_auth_headers_http_basic_missing_credentials():
"""Test header generation for HTTP Basic with missing credentials."""
auth_scheme = openapi_models.HTTPBase(scheme="basic")
credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="basic",
credentials=HttpCredentials(username="user", password=None),
),
)
with patch.object(mcp_auth_utils, "logger") as mock_logger:
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
assert headers is None
mock_logger.warning.assert_called_once_with(
"Basic auth scheme missing username or password."
)
def test_get_mcp_auth_headers_http_custom_scheme():
"""Test header generation for custom HTTP scheme."""
auth_scheme = openapi_models.HTTPBase(scheme="custom")
credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="custom", credentials=HttpCredentials(token="custom_token")
),
)
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
assert headers == {"Authorization": "custom custom_token"}
def test_get_mcp_auth_headers_http_cred_wrong_scheme():
"""Test HTTP credential with non-HTTPBase auth scheme."""
auth_scheme = openapi_models.APIKey(**{
"type": AuthSchemeType.apiKey,
"in": openapi_models.APIKeyIn.header,
"name": "X-API-Key",
})
credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer", credentials=HttpCredentials(token="bearer_token")
),
)
with patch.object(mcp_auth_utils, "logger") as mock_logger:
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
assert headers is None
mock_logger.warning.assert_called_once_with(
"HTTP credential provided, but auth_scheme is missing or not HTTPBase."
)
def test_get_mcp_auth_headers_api_key_header():
"""Test header generation for API Key in header."""
auth_scheme = openapi_models.APIKey(**{
"type": AuthSchemeType.apiKey,
"in": openapi_models.APIKeyIn.header,
"name": "X-Custom-API-Key",
})
credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
assert headers == {"X-Custom-API-Key": "my_api_key"}
def test_get_mcp_auth_headers_api_key_query_raises_error():
"""Test API Key in query raises ValueError."""
auth_scheme = openapi_models.APIKey(**{
"type": AuthSchemeType.apiKey,
"in": openapi_models.APIKeyIn.query,
"name": "api_key",
})
credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
with pytest.raises(
ValueError,
match="MCP tools only support header-based API key authentication.",
):
mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
def test_get_mcp_auth_headers_api_key_cookie_raises_error():
"""Test API Key in cookie raises ValueError."""
auth_scheme = openapi_models.APIKey(**{
"type": AuthSchemeType.apiKey,
"in": openapi_models.APIKeyIn.cookie,
"name": "session_id",
})
credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
with pytest.raises(
ValueError,
match="MCP tools only support header-based API key authentication.",
):
mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
def test_get_mcp_auth_headers_api_key_cred_wrong_scheme():
"""Test API key credential with non-APIKey auth scheme."""
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
with patch.object(mcp_auth_utils, "logger") as mock_logger:
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
assert headers is None
mock_logger.warning.assert_called_once_with(
"API key credential provided, but auth_scheme is missing or not APIKey."
)
def test_get_mcp_auth_headers_service_account():
"""Test header generation for service account credentials."""
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(scopes=["test"]),
)
with patch.object(mcp_auth_utils, "logger") as mock_logger:
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
assert headers is None
mock_logger.warning.assert_called_once_with(
"Service account credentials should be exchanged for an access "
"token before calling get_mcp_auth_headers."
)
@@ -18,7 +18,10 @@ from unittest.mock import patch
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import HttpAuth
from google.adk.auth.auth_credential import HttpCredentials
from google.adk.auth.auth_credential import OAuth2Auth
from google.adk.auth.auth_credential import ServiceAccount
from google.adk.features import FeatureName
from google.adk.features._feature_registry import temporary_feature_override
from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
@@ -258,6 +261,240 @@ class TestMCPTool:
headers = call_args[1]["headers"]
assert headers == {"Authorization": "Bearer test_access_token"}
@pytest.mark.asyncio
async def test_get_headers_oauth2(self):
"""Test header generation for OAuth2 credentials."""
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)
oauth2_auth = OAuth2Auth(access_token="test_token")
credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth
)
tool_context = Mock(spec=ToolContext)
headers = await tool._get_headers(tool_context, credential)
assert headers == {"Authorization": "Bearer test_token"}
@pytest.mark.asyncio
async def test_get_headers_http_bearer(self):
"""Test header generation for HTTP Bearer credentials."""
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)
http_auth = HttpAuth(
scheme="bearer", credentials=HttpCredentials(token="bearer_token")
)
credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP, http=http_auth
)
tool_context = Mock(spec=ToolContext)
headers = await tool._get_headers(tool_context, credential)
assert headers == {"Authorization": "Bearer bearer_token"}
@pytest.mark.asyncio
async def test_get_headers_http_basic(self):
"""Test header generation for HTTP Basic credentials."""
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)
http_auth = HttpAuth(
scheme="basic",
credentials=HttpCredentials(username="user", password="pass"),
)
credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP, http=http_auth
)
tool_context = Mock(spec=ToolContext)
headers = await tool._get_headers(tool_context, credential)
# Should create Basic auth header with base64 encoded credentials
import base64
expected_encoded = base64.b64encode(b"user:pass").decode()
assert headers == {"Authorization": f"Basic {expected_encoded}"}
@pytest.mark.asyncio
async def test_get_headers_api_key_with_valid_header_scheme(self):
"""Test header generation for API Key credentials with header-based auth scheme."""
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from google.adk.auth.auth_schemes import AuthSchemeType
# Create auth scheme for header-based API key
auth_scheme = APIKey(**{
"type": AuthSchemeType.apiKey,
"in": APIKeyIn.header,
"name": "X-Custom-API-Key",
})
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
auth_scheme=auth_scheme,
auth_credential=auth_credential,
)
tool_context = Mock(spec=ToolContext)
headers = await tool._get_headers(tool_context, auth_credential)
assert headers == {"X-Custom-API-Key": "my_api_key"}
@pytest.mark.asyncio
async def test_get_headers_api_key_with_query_scheme_raises_error(self):
"""Test that API Key with query-based auth scheme raises ValueError."""
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from google.adk.auth.auth_schemes import AuthSchemeType
# Create auth scheme for query-based API key (not supported)
auth_scheme = APIKey(**{
"type": AuthSchemeType.apiKey,
"in": APIKeyIn.query,
"name": "api_key",
})
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
auth_scheme=auth_scheme,
auth_credential=auth_credential,
)
tool_context = Mock(spec=ToolContext)
with pytest.raises(
ValueError,
match="McpTool only supports header-based API key authentication",
):
await tool._get_headers(tool_context, auth_credential)
@pytest.mark.asyncio
async def test_get_headers_api_key_with_cookie_scheme_raises_error(self):
"""Test that API Key with cookie-based auth scheme raises ValueError."""
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from google.adk.auth.auth_schemes import AuthSchemeType
# Create auth scheme for cookie-based API key (not supported)
auth_scheme = APIKey(**{
"type": AuthSchemeType.apiKey,
"in": APIKeyIn.cookie,
"name": "session_id",
})
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
auth_scheme=auth_scheme,
auth_credential=auth_credential,
)
tool_context = Mock(spec=ToolContext)
with pytest.raises(
ValueError,
match="McpTool only supports header-based API key authentication",
):
await tool._get_headers(tool_context, auth_credential)
@pytest.mark.asyncio
async def test_get_headers_api_key_without_auth_config_raises_error(self):
"""Test that API Key without auth config raises ValueError."""
# Create tool without auth scheme/config
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)
credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
tool_context = Mock(spec=ToolContext)
with pytest.raises(
ValueError,
match="Cannot find corresponding auth scheme for API key credential",
):
await tool._get_headers(tool_context, credential)
@pytest.mark.asyncio
async def test_get_headers_api_key_without_credentials_manager_raises_error(
self,
):
"""Test that API Key without credentials manager raises ValueError."""
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)
# Manually set credentials manager to None to simulate error condition
tool._credentials_manager = None
credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
tool_context = Mock(spec=ToolContext)
with pytest.raises(
ValueError,
match="Cannot find corresponding auth scheme for API key credential",
):
await tool._get_headers(tool_context, credential)
@pytest.mark.asyncio
async def test_get_headers_no_credential(self):
"""Test header generation with no credentials."""
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)
tool_context = Mock(spec=ToolContext)
headers = await tool._get_headers(tool_context, None)
assert headers is None
@pytest.mark.asyncio
async def test_get_headers_service_account(self):
"""Test header generation for service account credentials."""
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)
# Create service account credential
service_account = ServiceAccount(scopes=["test"])
credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=service_account,
)
tool_context = Mock(spec=ToolContext)
headers = await tool._get_headers(tool_context, credential)
# Should return None as service account credentials are not supported for direct header generation
assert headers is None
@pytest.mark.asyncio
async def test_run_async_impl_with_api_key_header_auth(self):
"""Test running tool with API key header authentication end-to-end."""
@@ -314,6 +551,65 @@ class TestMCPTool:
# Check that the method has the retry decorator
assert hasattr(tool._run_async_impl, "__wrapped__")
@pytest.mark.asyncio
async def test_get_headers_http_custom_scheme(self):
"""Test header generation for custom HTTP scheme."""
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)
http_auth = HttpAuth(
scheme="custom", credentials=HttpCredentials(token="custom_token")
)
credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP, http=http_auth
)
tool_context = Mock(spec=ToolContext)
headers = await tool._get_headers(tool_context, credential)
assert headers == {"Authorization": "custom custom_token"}
@pytest.mark.asyncio
async def test_get_headers_api_key_error_logging(self):
"""Test that API key errors are logged correctly."""
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from google.adk.auth.auth_schemes import AuthSchemeType
# Create auth scheme for query-based API key (not supported)
auth_scheme = APIKey(**{
"type": AuthSchemeType.apiKey,
"in": APIKeyIn.query,
"name": "api_key",
})
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
auth_scheme=auth_scheme,
auth_credential=auth_credential,
)
tool_context = Mock(spec=ToolContext)
# Test with logging
with patch("google.adk.tools.mcp_tool.mcp_tool.logger") as mock_logger:
with pytest.raises(ValueError):
await tool._get_headers(tool_context, auth_credential)
# Verify error was logged
mock_logger.error.assert_called_once()
logged_message = mock_logger.error.call_args[0][0]
assert (
"McpTool only supports header-based API key authentication"
in logged_message
)
@pytest.mark.asyncio
async def test_run_async_require_confirmation_true_no_confirmation(self):
"""Test require_confirmation=True with no confirmation in context."""
@@ -30,8 +30,6 @@ from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnecti
from google.adk.tools.mcp_tool.mcp_tool import MCPTool
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
from google.adk.tools.mcp_tool.mcp_toolset import McpToolsetConfig
from google.adk.tools.tool_configs import ToolArgsConfig
from mcp import StdioServerParameters
import pytest
@@ -247,94 +245,6 @@ class TestMCPToolset:
headers=expected_headers
)
@pytest.mark.asyncio
async def test_get_tools_with_auth_headers(self):
"""Test get_tools with auth headers."""
from fastapi.openapi import models as openapi_models
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import OAuth2Auth
mock_tools = [MockMCPTool("tool1")]
self.mock_session.list_tools = AsyncMock(
return_value=MockListToolsResult(mock_tools)
)
mock_readonly_context = Mock(spec=ReadonlyContext)
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(access_token="test_token"),
)
with patch(
"google.adk.tools.mcp_tool.mcp_toolset.CredentialManager"
) as MockCredentialManager:
mock_manager_instance = MockCredentialManager.return_value
mock_manager_instance.get_auth_credential = AsyncMock(
return_value=auth_credential
)
toolset = MCPToolset(
connection_params=self.mock_stdio_params,
auth_scheme=auth_scheme,
auth_credential=auth_credential,
)
toolset._mcp_session_manager = self.mock_session_manager
await toolset.get_tools(readonly_context=mock_readonly_context)
self.mock_session_manager.create_session.assert_called_once()
call_args = self.mock_session_manager.create_session.call_args
headers = call_args[1]["headers"]
assert headers == {"Authorization": "Bearer test_token"}
@pytest.mark.asyncio
async def test_get_tools_with_auth_and_header_provider(self):
"""Test get_tools with auth and header_provider."""
from fastapi.openapi import models as openapi_models
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import OAuth2Auth
mock_tools = [MockMCPTool("tool1")]
self.mock_session.list_tools = AsyncMock(
return_value=MockListToolsResult(mock_tools)
)
mock_readonly_context = Mock(spec=ReadonlyContext)
provided_headers = {"X-Tenant-ID": "test-tenant"}
header_provider = Mock(return_value=provided_headers)
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(access_token="test_token"),
)
with patch(
"google.adk.tools.mcp_tool.mcp_toolset.CredentialManager"
) as MockCredentialManager:
mock_manager_instance = MockCredentialManager.return_value
mock_manager_instance.get_auth_credential = AsyncMock(
return_value=auth_credential
)
toolset = MCPToolset(
connection_params=self.mock_stdio_params,
auth_scheme=auth_scheme,
auth_credential=auth_credential,
header_provider=header_provider,
)
toolset._mcp_session_manager = self.mock_session_manager
await toolset.get_tools(readonly_context=mock_readonly_context)
self.mock_session_manager.create_session.assert_called_once()
call_args = self.mock_session_manager.create_session.call_args
headers = call_args[1]["headers"]
assert headers == {
"X-Tenant-ID": "test-tenant",
"Authorization": "Bearer test_token",
}
@pytest.mark.asyncio
async def test_close_success(self):
"""Test successful cleanup."""