feat: Support authentication for MCP tool listing

Currently only tool calling supports MCP auth. This refactors the auth logic into a auth_utils file and uses it for tool listing as well. Fixes https://github.com/google/adk-python/issues/2168.

Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com>
PiperOrigin-RevId: 859201722
This commit is contained in:
Xiang (Sean) Zhou
2026-01-21 11:57:05 -08:00
committed by Copybara-Service
parent d62f9c896c
commit e3d542a5ba
6 changed files with 386 additions and 490 deletions
@@ -1,110 +0,0 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for MCP tool authentication."""
from __future__ import annotations
import base64
import logging
from typing import Dict
from typing import Optional
from fastapi.openapi import models as openapi_models
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import HTTPBase
from ...auth.auth_credential import AuthCredential
from ...auth.auth_schemes import AuthScheme
logger = logging.getLogger("google_adk." + __name__)
def get_mcp_auth_headers(
auth_scheme: Optional[AuthScheme], credential: Optional[AuthCredential]
) -> Optional[Dict[str, str]]:
"""Generates HTTP authentication headers for MCP calls.
Args:
auth_scheme: The authentication scheme.
credential: The resolved authentication credential.
Returns:
A dictionary of headers, or None if no auth is applicable.
Raises:
ValueError: If the auth scheme is unsupported or misconfigured.
"""
if not credential:
return None
headers: Optional[Dict[str, str]] = None
if credential.oauth2:
headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"}
elif credential.http:
if not auth_scheme or not isinstance(auth_scheme, HTTPBase):
logger.warning(
"HTTP credential provided, but auth_scheme is missing or not"
" HTTPBase."
)
return None
scheme = auth_scheme.scheme.lower()
if scheme == "bearer" and credential.http.credentials.token:
headers = {"Authorization": f"Bearer {credential.http.credentials.token}"}
elif scheme == "basic":
if (
credential.http.credentials.username
and credential.http.credentials.password
):
creds = f"{credential.http.credentials.username}:{credential.http.credentials.password}"
encoded_creds = base64.b64encode(creds.encode()).decode()
headers = {"Authorization": f"Basic {encoded_creds}"}
else:
logger.warning("Basic auth scheme missing username or password.")
elif credential.http.credentials.token:
# Handle other HTTP schemes like Digest, etc. if token is present
headers = {
"Authorization": (
f"{auth_scheme.scheme} {credential.http.credentials.token}"
)
}
else:
logger.warning(f"Unsupported or incomplete HTTP auth scheme '{scheme}'.")
elif credential.api_key:
if not auth_scheme or not isinstance(auth_scheme, APIKey):
logger.warning(
"API key credential provided, but auth_scheme is missing or not"
" APIKey."
)
return None
if auth_scheme.in_ != openapi_models.APIKeyIn.header:
error_msg = (
"MCP tools only support header-based API key authentication. "
f"Configured location: {auth_scheme.in_}"
)
logger.error(error_msg)
raise ValueError(error_msg)
headers = {auth_scheme.name: credential.api_key}
elif credential.service_account:
logger.warning(
"Service account credentials should be exchanged for an access token "
"before calling get_mcp_auth_headers."
)
else:
logger.warning(f"Unsupported credential type: {type(credential)}")
return headers
+87 -7
View File
@@ -14,6 +14,7 @@
from __future__ import annotations
import base64
import inspect
import logging
from typing import Any
@@ -23,6 +24,7 @@ from typing import Optional
from typing import Union
import warnings
from fastapi.openapi.models import APIKeyIn
from google.genai.types import FunctionDeclaration
from mcp.types import Tool as McpBaseTool
from typing_extensions import override
@@ -37,7 +39,6 @@ from .._gemini_schema_util import _to_gemini_schema
from ..base_authenticated_tool import BaseAuthenticatedTool
# import
from ..tool_context import ToolContext
from .mcp_auth_utils import get_mcp_auth_headers
from .mcp_session_manager import MCPSessionManager
from .mcp_session_manager import retry_on_errors
@@ -194,12 +195,7 @@ class McpTool(BaseAuthenticatedTool):
Any: The response from the tool.
"""
# Extract headers from credential for session pooling
auth_scheme = (
self._auth_config.auth_scheme
if hasattr(self, "_auth_config") and self._auth_config
else None
)
auth_headers = get_mcp_auth_headers(auth_scheme, credential)
auth_headers = await self._get_headers(tool_context, credential)
dynamic_headers = None
if self._header_provider:
dynamic_headers = self._header_provider(
@@ -221,6 +217,90 @@ class McpTool(BaseAuthenticatedTool):
response = await session.call_tool(self._mcp_tool.name, arguments=args)
return response.model_dump(exclude_none=True, mode="json")
async def _get_headers(
self, tool_context: ToolContext, credential: AuthCredential
) -> Optional[dict[str, str]]:
"""Extracts authentication headers from credentials.
Args:
tool_context: The tool context of the current invocation.
credential: The authentication credential to process.
Returns:
Dictionary of headers to add to the request, or None if no auth.
Raises:
ValueError: If API key authentication is configured for non-header location.
"""
headers: Optional[dict[str, str]] = None
if credential:
if credential.oauth2:
headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"}
elif credential.http:
# Handle HTTP authentication schemes
if (
credential.http.scheme.lower() == "bearer"
and credential.http.credentials.token
):
headers = {
"Authorization": f"Bearer {credential.http.credentials.token}"
}
elif credential.http.scheme.lower() == "basic":
# Handle basic auth
if (
credential.http.credentials.username
and credential.http.credentials.password
):
credentials = f"{credential.http.credentials.username}:{credential.http.credentials.password}"
encoded_credentials = base64.b64encode(
credentials.encode()
).decode()
headers = {"Authorization": f"Basic {encoded_credentials}"}
elif credential.http.credentials.token:
# Handle other HTTP schemes with token
headers = {
"Authorization": (
f"{credential.http.scheme} {credential.http.credentials.token}"
)
}
elif credential.api_key:
if (
not self._credentials_manager
or not self._credentials_manager._auth_config
):
error_msg = (
"Cannot find corresponding auth scheme for API key credential"
f" {credential}"
)
logger.error(error_msg)
raise ValueError(error_msg)
elif (
self._credentials_manager._auth_config.auth_scheme.in_
!= APIKeyIn.header
):
error_msg = (
"McpTool only supports header-based API key authentication."
" Configured location:"
f" {self._credentials_manager._auth_config.auth_scheme.in_}"
)
logger.error(error_msg)
raise ValueError(error_msg)
else:
headers = {
self._credentials_manager._auth_config.auth_scheme.name: (
credential.api_key
)
}
elif credential.service_account:
# Service accounts should be exchanged for access tokens before reaching this point
logger.warning(
"Service account credentials should be exchanged before MCP"
" session creation"
)
return headers
class MCPTool(McpTool):
"""Deprecated name, use `McpTool` instead."""
+3 -43
View File
@@ -33,14 +33,11 @@ from typing_extensions import override
from ...agents.readonly_context import ReadonlyContext
from ...auth.auth_credential import AuthCredential
from ...auth.auth_schemes import AuthScheme
from ...auth.auth_tool import AuthConfig
from ...auth.credential_manager import CredentialManager
from ..base_tool import BaseTool
from ..base_toolset import BaseToolset
from ..base_toolset import ToolPredicate
from ..tool_configs import BaseToolConfig
from ..tool_configs import ToolArgsConfig
from .mcp_auth_utils import get_mcp_auth_headers
from .mcp_session_manager import MCPSessionManager
from .mcp_session_manager import retry_on_errors
from .mcp_session_manager import SseConnectionParams
@@ -157,50 +154,13 @@ class McpToolset(BaseToolset):
Returns:
List[BaseTool]: A list of tools available under the specified context.
"""
provided_headers = (
headers = (
self._header_provider(readonly_context)
if self._header_provider and readonly_context
else {}
else None
)
auth_headers = {}
if self._auth_scheme:
try:
# Instantiate CredentialsManager to resolve credentials
auth_config = AuthConfig(
auth_scheme=self._auth_scheme,
raw_auth_credential=self._auth_credential,
)
credentials_manager = CredentialManager(auth_config)
# Resolve the credential
resolved_credential = await credentials_manager.get_auth_credential(
readonly_context
)
if resolved_credential:
auth_headers = get_mcp_auth_headers(
self._auth_scheme, resolved_credential
)
else:
logger.warning(
"Failed to resolve credential for tool listing, proceeding"
" without auth headers."
)
except Exception as e:
logger.warning(
"Error generating auth headers for tool listing: %s, proceeding"
" without auth headers.",
e,
exc_info=True,
)
merged_headers = {**(provided_headers or {}), **(auth_headers or {})}
# Get session from session manager
session = await self._mcp_session_manager.create_session(
headers=merged_headers
)
session = await self._mcp_session_manager.create_session(headers=headers)
# Fetch available tools from the MCP server
timeout_in_seconds = (
@@ -1,240 +0,0 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
from unittest.mock import patch
from fastapi.openapi import models as openapi_models
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import HttpAuth
from google.adk.auth.auth_credential import HttpCredentials
from google.adk.auth.auth_credential import OAuth2Auth
from google.adk.auth.auth_credential import ServiceAccount
from google.adk.auth.auth_schemes import AuthSchemeType
from google.adk.tools.mcp_tool import mcp_auth_utils
import pytest
def test_get_mcp_auth_headers_no_credential():
"""Test header generation with no credentials."""
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=None
)
assert headers is None
def test_get_mcp_auth_headers_no_auth_scheme():
"""Test header generation with no auth_scheme."""
credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(access_token="test_token"),
)
with patch.object(mcp_auth_utils, "logger") as mock_logger:
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=None, credential=credential
)
assert headers == {"Authorization": "Bearer test_token"}
def test_get_mcp_auth_headers_oauth2():
"""Test header generation for OAuth2 credentials."""
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(access_token="test_token"),
)
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
assert headers == {"Authorization": "Bearer test_token"}
def test_get_mcp_auth_headers_http_bearer():
"""Test header generation for HTTP Bearer credentials."""
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer", credentials=HttpCredentials(token="bearer_token")
),
)
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
assert headers == {"Authorization": "Bearer bearer_token"}
def test_get_mcp_auth_headers_http_basic():
"""Test header generation for HTTP Basic credentials."""
auth_scheme = openapi_models.HTTPBase(scheme="basic")
credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="basic",
credentials=HttpCredentials(username="user", password="pass"),
),
)
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
expected_encoded = base64.b64encode(b"user:pass").decode()
assert headers == {"Authorization": f"Basic {expected_encoded}"}
def test_get_mcp_auth_headers_http_basic_missing_credentials():
"""Test header generation for HTTP Basic with missing credentials."""
auth_scheme = openapi_models.HTTPBase(scheme="basic")
credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="basic",
credentials=HttpCredentials(username="user", password=None),
),
)
with patch.object(mcp_auth_utils, "logger") as mock_logger:
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
assert headers is None
mock_logger.warning.assert_called_once_with(
"Basic auth scheme missing username or password."
)
def test_get_mcp_auth_headers_http_custom_scheme():
"""Test header generation for custom HTTP scheme."""
auth_scheme = openapi_models.HTTPBase(scheme="custom")
credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="custom", credentials=HttpCredentials(token="custom_token")
),
)
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
assert headers == {"Authorization": "custom custom_token"}
def test_get_mcp_auth_headers_http_cred_wrong_scheme():
"""Test HTTP credential with non-HTTPBase auth scheme."""
auth_scheme = openapi_models.APIKey(**{
"type": AuthSchemeType.apiKey,
"in": openapi_models.APIKeyIn.header,
"name": "X-API-Key",
})
credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP,
http=HttpAuth(
scheme="bearer", credentials=HttpCredentials(token="bearer_token")
),
)
with patch.object(mcp_auth_utils, "logger") as mock_logger:
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
assert headers is None
mock_logger.warning.assert_called_once_with(
"HTTP credential provided, but auth_scheme is missing or not HTTPBase."
)
def test_get_mcp_auth_headers_api_key_header():
"""Test header generation for API Key in header."""
auth_scheme = openapi_models.APIKey(**{
"type": AuthSchemeType.apiKey,
"in": openapi_models.APIKeyIn.header,
"name": "X-Custom-API-Key",
})
credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
assert headers == {"X-Custom-API-Key": "my_api_key"}
def test_get_mcp_auth_headers_api_key_query_raises_error():
"""Test API Key in query raises ValueError."""
auth_scheme = openapi_models.APIKey(**{
"type": AuthSchemeType.apiKey,
"in": openapi_models.APIKeyIn.query,
"name": "api_key",
})
credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
with pytest.raises(
ValueError,
match="MCP tools only support header-based API key authentication.",
):
mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
def test_get_mcp_auth_headers_api_key_cookie_raises_error():
"""Test API Key in cookie raises ValueError."""
auth_scheme = openapi_models.APIKey(**{
"type": AuthSchemeType.apiKey,
"in": openapi_models.APIKeyIn.cookie,
"name": "session_id",
})
credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
with pytest.raises(
ValueError,
match="MCP tools only support header-based API key authentication.",
):
mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
def test_get_mcp_auth_headers_api_key_cred_wrong_scheme():
"""Test API key credential with non-APIKey auth scheme."""
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
with patch.object(mcp_auth_utils, "logger") as mock_logger:
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
assert headers is None
mock_logger.warning.assert_called_once_with(
"API key credential provided, but auth_scheme is missing or not APIKey."
)
def test_get_mcp_auth_headers_service_account():
"""Test header generation for service account credentials."""
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=ServiceAccount(scopes=["test"]),
)
with patch.object(mcp_auth_utils, "logger") as mock_logger:
headers = mcp_auth_utils.get_mcp_auth_headers(
auth_scheme=auth_scheme, credential=credential
)
assert headers is None
mock_logger.warning.assert_called_once_with(
"Service account credentials should be exchanged for an access "
"token before calling get_mcp_auth_headers."
)
@@ -18,7 +18,10 @@ from unittest.mock import patch
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import HttpAuth
from google.adk.auth.auth_credential import HttpCredentials
from google.adk.auth.auth_credential import OAuth2Auth
from google.adk.auth.auth_credential import ServiceAccount
from google.adk.features import FeatureName
from google.adk.features._feature_registry import temporary_feature_override
from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
@@ -258,6 +261,240 @@ class TestMCPTool:
headers = call_args[1]["headers"]
assert headers == {"Authorization": "Bearer test_access_token"}
@pytest.mark.asyncio
async def test_get_headers_oauth2(self):
"""Test header generation for OAuth2 credentials."""
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)
oauth2_auth = OAuth2Auth(access_token="test_token")
credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth
)
tool_context = Mock(spec=ToolContext)
headers = await tool._get_headers(tool_context, credential)
assert headers == {"Authorization": "Bearer test_token"}
@pytest.mark.asyncio
async def test_get_headers_http_bearer(self):
"""Test header generation for HTTP Bearer credentials."""
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)
http_auth = HttpAuth(
scheme="bearer", credentials=HttpCredentials(token="bearer_token")
)
credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP, http=http_auth
)
tool_context = Mock(spec=ToolContext)
headers = await tool._get_headers(tool_context, credential)
assert headers == {"Authorization": "Bearer bearer_token"}
@pytest.mark.asyncio
async def test_get_headers_http_basic(self):
"""Test header generation for HTTP Basic credentials."""
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)
http_auth = HttpAuth(
scheme="basic",
credentials=HttpCredentials(username="user", password="pass"),
)
credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP, http=http_auth
)
tool_context = Mock(spec=ToolContext)
headers = await tool._get_headers(tool_context, credential)
# Should create Basic auth header with base64 encoded credentials
import base64
expected_encoded = base64.b64encode(b"user:pass").decode()
assert headers == {"Authorization": f"Basic {expected_encoded}"}
@pytest.mark.asyncio
async def test_get_headers_api_key_with_valid_header_scheme(self):
"""Test header generation for API Key credentials with header-based auth scheme."""
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from google.adk.auth.auth_schemes import AuthSchemeType
# Create auth scheme for header-based API key
auth_scheme = APIKey(**{
"type": AuthSchemeType.apiKey,
"in": APIKeyIn.header,
"name": "X-Custom-API-Key",
})
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
auth_scheme=auth_scheme,
auth_credential=auth_credential,
)
tool_context = Mock(spec=ToolContext)
headers = await tool._get_headers(tool_context, auth_credential)
assert headers == {"X-Custom-API-Key": "my_api_key"}
@pytest.mark.asyncio
async def test_get_headers_api_key_with_query_scheme_raises_error(self):
"""Test that API Key with query-based auth scheme raises ValueError."""
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from google.adk.auth.auth_schemes import AuthSchemeType
# Create auth scheme for query-based API key (not supported)
auth_scheme = APIKey(**{
"type": AuthSchemeType.apiKey,
"in": APIKeyIn.query,
"name": "api_key",
})
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
auth_scheme=auth_scheme,
auth_credential=auth_credential,
)
tool_context = Mock(spec=ToolContext)
with pytest.raises(
ValueError,
match="McpTool only supports header-based API key authentication",
):
await tool._get_headers(tool_context, auth_credential)
@pytest.mark.asyncio
async def test_get_headers_api_key_with_cookie_scheme_raises_error(self):
"""Test that API Key with cookie-based auth scheme raises ValueError."""
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from google.adk.auth.auth_schemes import AuthSchemeType
# Create auth scheme for cookie-based API key (not supported)
auth_scheme = APIKey(**{
"type": AuthSchemeType.apiKey,
"in": APIKeyIn.cookie,
"name": "session_id",
})
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
auth_scheme=auth_scheme,
auth_credential=auth_credential,
)
tool_context = Mock(spec=ToolContext)
with pytest.raises(
ValueError,
match="McpTool only supports header-based API key authentication",
):
await tool._get_headers(tool_context, auth_credential)
@pytest.mark.asyncio
async def test_get_headers_api_key_without_auth_config_raises_error(self):
"""Test that API Key without auth config raises ValueError."""
# Create tool without auth scheme/config
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)
credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
tool_context = Mock(spec=ToolContext)
with pytest.raises(
ValueError,
match="Cannot find corresponding auth scheme for API key credential",
):
await tool._get_headers(tool_context, credential)
@pytest.mark.asyncio
async def test_get_headers_api_key_without_credentials_manager_raises_error(
self,
):
"""Test that API Key without credentials manager raises ValueError."""
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)
# Manually set credentials manager to None to simulate error condition
tool._credentials_manager = None
credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
tool_context = Mock(spec=ToolContext)
with pytest.raises(
ValueError,
match="Cannot find corresponding auth scheme for API key credential",
):
await tool._get_headers(tool_context, credential)
@pytest.mark.asyncio
async def test_get_headers_no_credential(self):
"""Test header generation with no credentials."""
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)
tool_context = Mock(spec=ToolContext)
headers = await tool._get_headers(tool_context, None)
assert headers is None
@pytest.mark.asyncio
async def test_get_headers_service_account(self):
"""Test header generation for service account credentials."""
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)
# Create service account credential
service_account = ServiceAccount(scopes=["test"])
credential = AuthCredential(
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
service_account=service_account,
)
tool_context = Mock(spec=ToolContext)
headers = await tool._get_headers(tool_context, credential)
# Should return None as service account credentials are not supported for direct header generation
assert headers is None
@pytest.mark.asyncio
async def test_run_async_impl_with_api_key_header_auth(self):
"""Test running tool with API key header authentication end-to-end."""
@@ -314,6 +551,65 @@ class TestMCPTool:
# Check that the method has the retry decorator
assert hasattr(tool._run_async_impl, "__wrapped__")
@pytest.mark.asyncio
async def test_get_headers_http_custom_scheme(self):
"""Test header generation for custom HTTP scheme."""
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
)
http_auth = HttpAuth(
scheme="custom", credentials=HttpCredentials(token="custom_token")
)
credential = AuthCredential(
auth_type=AuthCredentialTypes.HTTP, http=http_auth
)
tool_context = Mock(spec=ToolContext)
headers = await tool._get_headers(tool_context, credential)
assert headers == {"Authorization": "custom custom_token"}
@pytest.mark.asyncio
async def test_get_headers_api_key_error_logging(self):
"""Test that API key errors are logged correctly."""
from fastapi.openapi.models import APIKey
from fastapi.openapi.models import APIKeyIn
from google.adk.auth.auth_schemes import AuthSchemeType
# Create auth scheme for query-based API key (not supported)
auth_scheme = APIKey(**{
"type": AuthSchemeType.apiKey,
"in": APIKeyIn.query,
"name": "api_key",
})
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
)
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
auth_scheme=auth_scheme,
auth_credential=auth_credential,
)
tool_context = Mock(spec=ToolContext)
# Test with logging
with patch("google.adk.tools.mcp_tool.mcp_tool.logger") as mock_logger:
with pytest.raises(ValueError):
await tool._get_headers(tool_context, auth_credential)
# Verify error was logged
mock_logger.error.assert_called_once()
logged_message = mock_logger.error.call_args[0][0]
assert (
"McpTool only supports header-based API key authentication"
in logged_message
)
@pytest.mark.asyncio
async def test_run_async_require_confirmation_true_no_confirmation(self):
"""Test require_confirmation=True with no confirmation in context."""
@@ -30,8 +30,6 @@ from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnecti
from google.adk.tools.mcp_tool.mcp_tool import MCPTool
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
from google.adk.tools.mcp_tool.mcp_toolset import McpToolsetConfig
from google.adk.tools.tool_configs import ToolArgsConfig
from mcp import StdioServerParameters
import pytest
@@ -247,94 +245,6 @@ class TestMCPToolset:
headers=expected_headers
)
@pytest.mark.asyncio
async def test_get_tools_with_auth_headers(self):
"""Test get_tools with auth headers."""
from fastapi.openapi import models as openapi_models
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import OAuth2Auth
mock_tools = [MockMCPTool("tool1")]
self.mock_session.list_tools = AsyncMock(
return_value=MockListToolsResult(mock_tools)
)
mock_readonly_context = Mock(spec=ReadonlyContext)
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(access_token="test_token"),
)
with patch(
"google.adk.tools.mcp_tool.mcp_toolset.CredentialManager"
) as MockCredentialManager:
mock_manager_instance = MockCredentialManager.return_value
mock_manager_instance.get_auth_credential = AsyncMock(
return_value=auth_credential
)
toolset = MCPToolset(
connection_params=self.mock_stdio_params,
auth_scheme=auth_scheme,
auth_credential=auth_credential,
)
toolset._mcp_session_manager = self.mock_session_manager
await toolset.get_tools(readonly_context=mock_readonly_context)
self.mock_session_manager.create_session.assert_called_once()
call_args = self.mock_session_manager.create_session.call_args
headers = call_args[1]["headers"]
assert headers == {"Authorization": "Bearer test_token"}
@pytest.mark.asyncio
async def test_get_tools_with_auth_and_header_provider(self):
"""Test get_tools with auth and header_provider."""
from fastapi.openapi import models as openapi_models
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import OAuth2Auth
mock_tools = [MockMCPTool("tool1")]
self.mock_session.list_tools = AsyncMock(
return_value=MockListToolsResult(mock_tools)
)
mock_readonly_context = Mock(spec=ReadonlyContext)
provided_headers = {"X-Tenant-ID": "test-tenant"}
header_provider = Mock(return_value=provided_headers)
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(access_token="test_token"),
)
with patch(
"google.adk.tools.mcp_tool.mcp_toolset.CredentialManager"
) as MockCredentialManager:
mock_manager_instance = MockCredentialManager.return_value
mock_manager_instance.get_auth_credential = AsyncMock(
return_value=auth_credential
)
toolset = MCPToolset(
connection_params=self.mock_stdio_params,
auth_scheme=auth_scheme,
auth_credential=auth_credential,
header_provider=header_provider,
)
toolset._mcp_session_manager = self.mock_session_manager
await toolset.get_tools(readonly_context=mock_readonly_context)
self.mock_session_manager.create_session.assert_called_once()
call_args = self.mock_session_manager.create_session.call_args
headers = call_args[1]["headers"]
assert headers == {
"X-Tenant-ID": "test-tenant",
"Authorization": "Bearer test_token",
}
@pytest.mark.asyncio
async def test_close_success(self):
"""Test successful cleanup."""