You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Support authentication for MCP tool listing
Currently only tool calling supports MCP auth. This refactors the auth logic into a auth_utils file and uses it for tool listing as well. Fixes https://github.com/google/adk-python/issues/2168. Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com> PiperOrigin-RevId: 859201722
This commit is contained in:
committed by
Copybara-Service
parent
d62f9c896c
commit
e3d542a5ba
@@ -1,110 +0,0 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utility functions for MCP tool authentication."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
from fastapi.openapi import models as openapi_models
|
||||
from fastapi.openapi.models import APIKey
|
||||
from fastapi.openapi.models import HTTPBase
|
||||
|
||||
from ...auth.auth_credential import AuthCredential
|
||||
from ...auth.auth_schemes import AuthScheme
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
|
||||
def get_mcp_auth_headers(
|
||||
auth_scheme: Optional[AuthScheme], credential: Optional[AuthCredential]
|
||||
) -> Optional[Dict[str, str]]:
|
||||
"""Generates HTTP authentication headers for MCP calls.
|
||||
|
||||
Args:
|
||||
auth_scheme: The authentication scheme.
|
||||
credential: The resolved authentication credential.
|
||||
|
||||
Returns:
|
||||
A dictionary of headers, or None if no auth is applicable.
|
||||
|
||||
Raises:
|
||||
ValueError: If the auth scheme is unsupported or misconfigured.
|
||||
"""
|
||||
if not credential:
|
||||
return None
|
||||
|
||||
headers: Optional[Dict[str, str]] = None
|
||||
|
||||
if credential.oauth2:
|
||||
headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"}
|
||||
elif credential.http:
|
||||
if not auth_scheme or not isinstance(auth_scheme, HTTPBase):
|
||||
logger.warning(
|
||||
"HTTP credential provided, but auth_scheme is missing or not"
|
||||
" HTTPBase."
|
||||
)
|
||||
return None
|
||||
|
||||
scheme = auth_scheme.scheme.lower()
|
||||
if scheme == "bearer" and credential.http.credentials.token:
|
||||
headers = {"Authorization": f"Bearer {credential.http.credentials.token}"}
|
||||
elif scheme == "basic":
|
||||
if (
|
||||
credential.http.credentials.username
|
||||
and credential.http.credentials.password
|
||||
):
|
||||
creds = f"{credential.http.credentials.username}:{credential.http.credentials.password}"
|
||||
encoded_creds = base64.b64encode(creds.encode()).decode()
|
||||
headers = {"Authorization": f"Basic {encoded_creds}"}
|
||||
else:
|
||||
logger.warning("Basic auth scheme missing username or password.")
|
||||
elif credential.http.credentials.token:
|
||||
# Handle other HTTP schemes like Digest, etc. if token is present
|
||||
headers = {
|
||||
"Authorization": (
|
||||
f"{auth_scheme.scheme} {credential.http.credentials.token}"
|
||||
)
|
||||
}
|
||||
else:
|
||||
logger.warning(f"Unsupported or incomplete HTTP auth scheme '{scheme}'.")
|
||||
elif credential.api_key:
|
||||
if not auth_scheme or not isinstance(auth_scheme, APIKey):
|
||||
logger.warning(
|
||||
"API key credential provided, but auth_scheme is missing or not"
|
||||
" APIKey."
|
||||
)
|
||||
return None
|
||||
|
||||
if auth_scheme.in_ != openapi_models.APIKeyIn.header:
|
||||
error_msg = (
|
||||
"MCP tools only support header-based API key authentication. "
|
||||
f"Configured location: {auth_scheme.in_}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
headers = {auth_scheme.name: credential.api_key}
|
||||
elif credential.service_account:
|
||||
logger.warning(
|
||||
"Service account credentials should be exchanged for an access token "
|
||||
"before calling get_mcp_auth_headers."
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Unsupported credential type: {type(credential)}")
|
||||
|
||||
return headers
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Any
|
||||
@@ -23,6 +24,7 @@ from typing import Optional
|
||||
from typing import Union
|
||||
import warnings
|
||||
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from google.genai.types import FunctionDeclaration
|
||||
from mcp.types import Tool as McpBaseTool
|
||||
from typing_extensions import override
|
||||
@@ -37,7 +39,6 @@ from .._gemini_schema_util import _to_gemini_schema
|
||||
from ..base_authenticated_tool import BaseAuthenticatedTool
|
||||
# import
|
||||
from ..tool_context import ToolContext
|
||||
from .mcp_auth_utils import get_mcp_auth_headers
|
||||
from .mcp_session_manager import MCPSessionManager
|
||||
from .mcp_session_manager import retry_on_errors
|
||||
|
||||
@@ -194,12 +195,7 @@ class McpTool(BaseAuthenticatedTool):
|
||||
Any: The response from the tool.
|
||||
"""
|
||||
# Extract headers from credential for session pooling
|
||||
auth_scheme = (
|
||||
self._auth_config.auth_scheme
|
||||
if hasattr(self, "_auth_config") and self._auth_config
|
||||
else None
|
||||
)
|
||||
auth_headers = get_mcp_auth_headers(auth_scheme, credential)
|
||||
auth_headers = await self._get_headers(tool_context, credential)
|
||||
dynamic_headers = None
|
||||
if self._header_provider:
|
||||
dynamic_headers = self._header_provider(
|
||||
@@ -221,6 +217,90 @@ class McpTool(BaseAuthenticatedTool):
|
||||
response = await session.call_tool(self._mcp_tool.name, arguments=args)
|
||||
return response.model_dump(exclude_none=True, mode="json")
|
||||
|
||||
async def _get_headers(
|
||||
self, tool_context: ToolContext, credential: AuthCredential
|
||||
) -> Optional[dict[str, str]]:
|
||||
"""Extracts authentication headers from credentials.
|
||||
|
||||
Args:
|
||||
tool_context: The tool context of the current invocation.
|
||||
credential: The authentication credential to process.
|
||||
|
||||
Returns:
|
||||
Dictionary of headers to add to the request, or None if no auth.
|
||||
|
||||
Raises:
|
||||
ValueError: If API key authentication is configured for non-header location.
|
||||
"""
|
||||
headers: Optional[dict[str, str]] = None
|
||||
if credential:
|
||||
if credential.oauth2:
|
||||
headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"}
|
||||
elif credential.http:
|
||||
# Handle HTTP authentication schemes
|
||||
if (
|
||||
credential.http.scheme.lower() == "bearer"
|
||||
and credential.http.credentials.token
|
||||
):
|
||||
headers = {
|
||||
"Authorization": f"Bearer {credential.http.credentials.token}"
|
||||
}
|
||||
elif credential.http.scheme.lower() == "basic":
|
||||
# Handle basic auth
|
||||
if (
|
||||
credential.http.credentials.username
|
||||
and credential.http.credentials.password
|
||||
):
|
||||
|
||||
credentials = f"{credential.http.credentials.username}:{credential.http.credentials.password}"
|
||||
encoded_credentials = base64.b64encode(
|
||||
credentials.encode()
|
||||
).decode()
|
||||
headers = {"Authorization": f"Basic {encoded_credentials}"}
|
||||
elif credential.http.credentials.token:
|
||||
# Handle other HTTP schemes with token
|
||||
headers = {
|
||||
"Authorization": (
|
||||
f"{credential.http.scheme} {credential.http.credentials.token}"
|
||||
)
|
||||
}
|
||||
elif credential.api_key:
|
||||
if (
|
||||
not self._credentials_manager
|
||||
or not self._credentials_manager._auth_config
|
||||
):
|
||||
error_msg = (
|
||||
"Cannot find corresponding auth scheme for API key credential"
|
||||
f" {credential}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
elif (
|
||||
self._credentials_manager._auth_config.auth_scheme.in_
|
||||
!= APIKeyIn.header
|
||||
):
|
||||
error_msg = (
|
||||
"McpTool only supports header-based API key authentication."
|
||||
" Configured location:"
|
||||
f" {self._credentials_manager._auth_config.auth_scheme.in_}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
else:
|
||||
headers = {
|
||||
self._credentials_manager._auth_config.auth_scheme.name: (
|
||||
credential.api_key
|
||||
)
|
||||
}
|
||||
elif credential.service_account:
|
||||
# Service accounts should be exchanged for access tokens before reaching this point
|
||||
logger.warning(
|
||||
"Service account credentials should be exchanged before MCP"
|
||||
" session creation"
|
||||
)
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
class MCPTool(McpTool):
|
||||
"""Deprecated name, use `McpTool` instead."""
|
||||
|
||||
@@ -33,14 +33,11 @@ from typing_extensions import override
|
||||
from ...agents.readonly_context import ReadonlyContext
|
||||
from ...auth.auth_credential import AuthCredential
|
||||
from ...auth.auth_schemes import AuthScheme
|
||||
from ...auth.auth_tool import AuthConfig
|
||||
from ...auth.credential_manager import CredentialManager
|
||||
from ..base_tool import BaseTool
|
||||
from ..base_toolset import BaseToolset
|
||||
from ..base_toolset import ToolPredicate
|
||||
from ..tool_configs import BaseToolConfig
|
||||
from ..tool_configs import ToolArgsConfig
|
||||
from .mcp_auth_utils import get_mcp_auth_headers
|
||||
from .mcp_session_manager import MCPSessionManager
|
||||
from .mcp_session_manager import retry_on_errors
|
||||
from .mcp_session_manager import SseConnectionParams
|
||||
@@ -157,50 +154,13 @@ class McpToolset(BaseToolset):
|
||||
Returns:
|
||||
List[BaseTool]: A list of tools available under the specified context.
|
||||
"""
|
||||
provided_headers = (
|
||||
headers = (
|
||||
self._header_provider(readonly_context)
|
||||
if self._header_provider and readonly_context
|
||||
else {}
|
||||
else None
|
||||
)
|
||||
|
||||
auth_headers = {}
|
||||
if self._auth_scheme:
|
||||
try:
|
||||
# Instantiate CredentialsManager to resolve credentials
|
||||
auth_config = AuthConfig(
|
||||
auth_scheme=self._auth_scheme,
|
||||
raw_auth_credential=self._auth_credential,
|
||||
)
|
||||
credentials_manager = CredentialManager(auth_config)
|
||||
|
||||
# Resolve the credential
|
||||
resolved_credential = await credentials_manager.get_auth_credential(
|
||||
readonly_context
|
||||
)
|
||||
|
||||
if resolved_credential:
|
||||
auth_headers = get_mcp_auth_headers(
|
||||
self._auth_scheme, resolved_credential
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Failed to resolve credential for tool listing, proceeding"
|
||||
" without auth headers."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Error generating auth headers for tool listing: %s, proceeding"
|
||||
" without auth headers.",
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
merged_headers = {**(provided_headers or {}), **(auth_headers or {})}
|
||||
|
||||
# Get session from session manager
|
||||
session = await self._mcp_session_manager.create_session(
|
||||
headers=merged_headers
|
||||
)
|
||||
session = await self._mcp_session_manager.create_session(headers=headers)
|
||||
|
||||
# Fetch available tools from the MCP server
|
||||
timeout_in_seconds = (
|
||||
|
||||
@@ -1,240 +0,0 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import base64
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi.openapi import models as openapi_models
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import HttpAuth
|
||||
from google.adk.auth.auth_credential import HttpCredentials
|
||||
from google.adk.auth.auth_credential import OAuth2Auth
|
||||
from google.adk.auth.auth_credential import ServiceAccount
|
||||
from google.adk.auth.auth_schemes import AuthSchemeType
|
||||
from google.adk.tools.mcp_tool import mcp_auth_utils
|
||||
import pytest
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_no_credential():
|
||||
"""Test header generation with no credentials."""
|
||||
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=None
|
||||
)
|
||||
assert headers is None
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_no_auth_scheme():
|
||||
"""Test header generation with no auth_scheme."""
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(access_token="test_token"),
|
||||
)
|
||||
with patch.object(mcp_auth_utils, "logger") as mock_logger:
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=None, credential=credential
|
||||
)
|
||||
assert headers == {"Authorization": "Bearer test_token"}
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_oauth2():
|
||||
"""Test header generation for OAuth2 credentials."""
|
||||
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(access_token="test_token"),
|
||||
)
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
assert headers == {"Authorization": "Bearer test_token"}
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_http_bearer():
|
||||
"""Test header generation for HTTP Bearer credentials."""
|
||||
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer", credentials=HttpCredentials(token="bearer_token")
|
||||
),
|
||||
)
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
assert headers == {"Authorization": "Bearer bearer_token"}
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_http_basic():
|
||||
"""Test header generation for HTTP Basic credentials."""
|
||||
auth_scheme = openapi_models.HTTPBase(scheme="basic")
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="basic",
|
||||
credentials=HttpCredentials(username="user", password="pass"),
|
||||
),
|
||||
)
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
expected_encoded = base64.b64encode(b"user:pass").decode()
|
||||
assert headers == {"Authorization": f"Basic {expected_encoded}"}
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_http_basic_missing_credentials():
|
||||
"""Test header generation for HTTP Basic with missing credentials."""
|
||||
auth_scheme = openapi_models.HTTPBase(scheme="basic")
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="basic",
|
||||
credentials=HttpCredentials(username="user", password=None),
|
||||
),
|
||||
)
|
||||
with patch.object(mcp_auth_utils, "logger") as mock_logger:
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
assert headers is None
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"Basic auth scheme missing username or password."
|
||||
)
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_http_custom_scheme():
|
||||
"""Test header generation for custom HTTP scheme."""
|
||||
auth_scheme = openapi_models.HTTPBase(scheme="custom")
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="custom", credentials=HttpCredentials(token="custom_token")
|
||||
),
|
||||
)
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
assert headers == {"Authorization": "custom custom_token"}
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_http_cred_wrong_scheme():
|
||||
"""Test HTTP credential with non-HTTPBase auth scheme."""
|
||||
auth_scheme = openapi_models.APIKey(**{
|
||||
"type": AuthSchemeType.apiKey,
|
||||
"in": openapi_models.APIKeyIn.header,
|
||||
"name": "X-API-Key",
|
||||
})
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer", credentials=HttpCredentials(token="bearer_token")
|
||||
),
|
||||
)
|
||||
with patch.object(mcp_auth_utils, "logger") as mock_logger:
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
assert headers is None
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"HTTP credential provided, but auth_scheme is missing or not HTTPBase."
|
||||
)
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_api_key_header():
|
||||
"""Test header generation for API Key in header."""
|
||||
auth_scheme = openapi_models.APIKey(**{
|
||||
"type": AuthSchemeType.apiKey,
|
||||
"in": openapi_models.APIKeyIn.header,
|
||||
"name": "X-Custom-API-Key",
|
||||
})
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
|
||||
)
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
assert headers == {"X-Custom-API-Key": "my_api_key"}
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_api_key_query_raises_error():
|
||||
"""Test API Key in query raises ValueError."""
|
||||
auth_scheme = openapi_models.APIKey(**{
|
||||
"type": AuthSchemeType.apiKey,
|
||||
"in": openapi_models.APIKeyIn.query,
|
||||
"name": "api_key",
|
||||
})
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="MCP tools only support header-based API key authentication.",
|
||||
):
|
||||
mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_api_key_cookie_raises_error():
|
||||
"""Test API Key in cookie raises ValueError."""
|
||||
auth_scheme = openapi_models.APIKey(**{
|
||||
"type": AuthSchemeType.apiKey,
|
||||
"in": openapi_models.APIKeyIn.cookie,
|
||||
"name": "session_id",
|
||||
})
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="MCP tools only support header-based API key authentication.",
|
||||
):
|
||||
mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_api_key_cred_wrong_scheme():
|
||||
"""Test API key credential with non-APIKey auth scheme."""
|
||||
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
|
||||
)
|
||||
with patch.object(mcp_auth_utils, "logger") as mock_logger:
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
assert headers is None
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"API key credential provided, but auth_scheme is missing or not APIKey."
|
||||
)
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_service_account():
|
||||
"""Test header generation for service account credentials."""
|
||||
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=ServiceAccount(scopes=["test"]),
|
||||
)
|
||||
with patch.object(mcp_auth_utils, "logger") as mock_logger:
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
assert headers is None
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"Service account credentials should be exchanged for an access "
|
||||
"token before calling get_mcp_auth_headers."
|
||||
)
|
||||
@@ -18,7 +18,10 @@ from unittest.mock import patch
|
||||
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import HttpAuth
|
||||
from google.adk.auth.auth_credential import HttpCredentials
|
||||
from google.adk.auth.auth_credential import OAuth2Auth
|
||||
from google.adk.auth.auth_credential import ServiceAccount
|
||||
from google.adk.features import FeatureName
|
||||
from google.adk.features._feature_registry import temporary_feature_override
|
||||
from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
|
||||
@@ -258,6 +261,240 @@ class TestMCPTool:
|
||||
headers = call_args[1]["headers"]
|
||||
assert headers == {"Authorization": "Bearer test_access_token"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_oauth2(self):
|
||||
"""Test header generation for OAuth2 credentials."""
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
)
|
||||
|
||||
oauth2_auth = OAuth2Auth(access_token="test_token")
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth
|
||||
)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
headers = await tool._get_headers(tool_context, credential)
|
||||
|
||||
assert headers == {"Authorization": "Bearer test_token"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_http_bearer(self):
|
||||
"""Test header generation for HTTP Bearer credentials."""
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
)
|
||||
|
||||
http_auth = HttpAuth(
|
||||
scheme="bearer", credentials=HttpCredentials(token="bearer_token")
|
||||
)
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP, http=http_auth
|
||||
)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
headers = await tool._get_headers(tool_context, credential)
|
||||
|
||||
assert headers == {"Authorization": "Bearer bearer_token"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_http_basic(self):
|
||||
"""Test header generation for HTTP Basic credentials."""
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
)
|
||||
|
||||
http_auth = HttpAuth(
|
||||
scheme="basic",
|
||||
credentials=HttpCredentials(username="user", password="pass"),
|
||||
)
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP, http=http_auth
|
||||
)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
headers = await tool._get_headers(tool_context, credential)
|
||||
|
||||
# Should create Basic auth header with base64 encoded credentials
|
||||
import base64
|
||||
|
||||
expected_encoded = base64.b64encode(b"user:pass").decode()
|
||||
assert headers == {"Authorization": f"Basic {expected_encoded}"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_api_key_with_valid_header_scheme(self):
|
||||
"""Test header generation for API Key credentials with header-based auth scheme."""
|
||||
from fastapi.openapi.models import APIKey
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from google.adk.auth.auth_schemes import AuthSchemeType
|
||||
|
||||
# Create auth scheme for header-based API key
|
||||
auth_scheme = APIKey(**{
|
||||
"type": AuthSchemeType.apiKey,
|
||||
"in": APIKeyIn.header,
|
||||
"name": "X-Custom-API-Key",
|
||||
})
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
|
||||
)
|
||||
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
auth_scheme=auth_scheme,
|
||||
auth_credential=auth_credential,
|
||||
)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
headers = await tool._get_headers(tool_context, auth_credential)
|
||||
|
||||
assert headers == {"X-Custom-API-Key": "my_api_key"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_api_key_with_query_scheme_raises_error(self):
|
||||
"""Test that API Key with query-based auth scheme raises ValueError."""
|
||||
from fastapi.openapi.models import APIKey
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from google.adk.auth.auth_schemes import AuthSchemeType
|
||||
|
||||
# Create auth scheme for query-based API key (not supported)
|
||||
auth_scheme = APIKey(**{
|
||||
"type": AuthSchemeType.apiKey,
|
||||
"in": APIKeyIn.query,
|
||||
"name": "api_key",
|
||||
})
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
|
||||
)
|
||||
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
auth_scheme=auth_scheme,
|
||||
auth_credential=auth_credential,
|
||||
)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="McpTool only supports header-based API key authentication",
|
||||
):
|
||||
await tool._get_headers(tool_context, auth_credential)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_api_key_with_cookie_scheme_raises_error(self):
|
||||
"""Test that API Key with cookie-based auth scheme raises ValueError."""
|
||||
from fastapi.openapi.models import APIKey
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from google.adk.auth.auth_schemes import AuthSchemeType
|
||||
|
||||
# Create auth scheme for cookie-based API key (not supported)
|
||||
auth_scheme = APIKey(**{
|
||||
"type": AuthSchemeType.apiKey,
|
||||
"in": APIKeyIn.cookie,
|
||||
"name": "session_id",
|
||||
})
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
|
||||
)
|
||||
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
auth_scheme=auth_scheme,
|
||||
auth_credential=auth_credential,
|
||||
)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="McpTool only supports header-based API key authentication",
|
||||
):
|
||||
await tool._get_headers(tool_context, auth_credential)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_api_key_without_auth_config_raises_error(self):
|
||||
"""Test that API Key without auth config raises ValueError."""
|
||||
# Create tool without auth scheme/config
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
)
|
||||
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
|
||||
)
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Cannot find corresponding auth scheme for API key credential",
|
||||
):
|
||||
await tool._get_headers(tool_context, credential)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_api_key_without_credentials_manager_raises_error(
|
||||
self,
|
||||
):
|
||||
"""Test that API Key without credentials manager raises ValueError."""
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
)
|
||||
|
||||
# Manually set credentials manager to None to simulate error condition
|
||||
tool._credentials_manager = None
|
||||
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
|
||||
)
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Cannot find corresponding auth scheme for API key credential",
|
||||
):
|
||||
await tool._get_headers(tool_context, credential)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_no_credential(self):
|
||||
"""Test header generation with no credentials."""
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
headers = await tool._get_headers(tool_context, None)
|
||||
|
||||
assert headers is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_service_account(self):
|
||||
"""Test header generation for service account credentials."""
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
)
|
||||
|
||||
# Create service account credential
|
||||
service_account = ServiceAccount(scopes=["test"])
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=service_account,
|
||||
)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
headers = await tool._get_headers(tool_context, credential)
|
||||
|
||||
# Should return None as service account credentials are not supported for direct header generation
|
||||
assert headers is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_impl_with_api_key_header_auth(self):
|
||||
"""Test running tool with API key header authentication end-to-end."""
|
||||
@@ -314,6 +551,65 @@ class TestMCPTool:
|
||||
# Check that the method has the retry decorator
|
||||
assert hasattr(tool._run_async_impl, "__wrapped__")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_http_custom_scheme(self):
|
||||
"""Test header generation for custom HTTP scheme."""
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
)
|
||||
|
||||
http_auth = HttpAuth(
|
||||
scheme="custom", credentials=HttpCredentials(token="custom_token")
|
||||
)
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP, http=http_auth
|
||||
)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
headers = await tool._get_headers(tool_context, credential)
|
||||
|
||||
assert headers == {"Authorization": "custom custom_token"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_api_key_error_logging(self):
|
||||
"""Test that API key errors are logged correctly."""
|
||||
from fastapi.openapi.models import APIKey
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from google.adk.auth.auth_schemes import AuthSchemeType
|
||||
|
||||
# Create auth scheme for query-based API key (not supported)
|
||||
auth_scheme = APIKey(**{
|
||||
"type": AuthSchemeType.apiKey,
|
||||
"in": APIKeyIn.query,
|
||||
"name": "api_key",
|
||||
})
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
|
||||
)
|
||||
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
auth_scheme=auth_scheme,
|
||||
auth_credential=auth_credential,
|
||||
)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
|
||||
# Test with logging
|
||||
with patch("google.adk.tools.mcp_tool.mcp_tool.logger") as mock_logger:
|
||||
with pytest.raises(ValueError):
|
||||
await tool._get_headers(tool_context, auth_credential)
|
||||
|
||||
# Verify error was logged
|
||||
mock_logger.error.assert_called_once()
|
||||
logged_message = mock_logger.error.call_args[0][0]
|
||||
assert (
|
||||
"McpTool only supports header-based API key authentication"
|
||||
in logged_message
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_require_confirmation_true_no_confirmation(self):
|
||||
"""Test require_confirmation=True with no confirmation in context."""
|
||||
|
||||
@@ -30,8 +30,6 @@ from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnecti
|
||||
from google.adk.tools.mcp_tool.mcp_tool import MCPTool
|
||||
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
|
||||
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
|
||||
from google.adk.tools.mcp_tool.mcp_toolset import McpToolsetConfig
|
||||
from google.adk.tools.tool_configs import ToolArgsConfig
|
||||
from mcp import StdioServerParameters
|
||||
import pytest
|
||||
|
||||
@@ -247,94 +245,6 @@ class TestMCPToolset:
|
||||
headers=expected_headers
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tools_with_auth_headers(self):
|
||||
"""Test get_tools with auth headers."""
|
||||
from fastapi.openapi import models as openapi_models
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import OAuth2Auth
|
||||
|
||||
mock_tools = [MockMCPTool("tool1")]
|
||||
self.mock_session.list_tools = AsyncMock(
|
||||
return_value=MockListToolsResult(mock_tools)
|
||||
)
|
||||
mock_readonly_context = Mock(spec=ReadonlyContext)
|
||||
|
||||
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(access_token="test_token"),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"google.adk.tools.mcp_tool.mcp_toolset.CredentialManager"
|
||||
) as MockCredentialManager:
|
||||
mock_manager_instance = MockCredentialManager.return_value
|
||||
mock_manager_instance.get_auth_credential = AsyncMock(
|
||||
return_value=auth_credential
|
||||
)
|
||||
|
||||
toolset = MCPToolset(
|
||||
connection_params=self.mock_stdio_params,
|
||||
auth_scheme=auth_scheme,
|
||||
auth_credential=auth_credential,
|
||||
)
|
||||
toolset._mcp_session_manager = self.mock_session_manager
|
||||
|
||||
await toolset.get_tools(readonly_context=mock_readonly_context)
|
||||
|
||||
self.mock_session_manager.create_session.assert_called_once()
|
||||
call_args = self.mock_session_manager.create_session.call_args
|
||||
headers = call_args[1]["headers"]
|
||||
assert headers == {"Authorization": "Bearer test_token"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tools_with_auth_and_header_provider(self):
|
||||
"""Test get_tools with auth and header_provider."""
|
||||
from fastapi.openapi import models as openapi_models
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import OAuth2Auth
|
||||
|
||||
mock_tools = [MockMCPTool("tool1")]
|
||||
self.mock_session.list_tools = AsyncMock(
|
||||
return_value=MockListToolsResult(mock_tools)
|
||||
)
|
||||
mock_readonly_context = Mock(spec=ReadonlyContext)
|
||||
provided_headers = {"X-Tenant-ID": "test-tenant"}
|
||||
header_provider = Mock(return_value=provided_headers)
|
||||
|
||||
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(access_token="test_token"),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"google.adk.tools.mcp_tool.mcp_toolset.CredentialManager"
|
||||
) as MockCredentialManager:
|
||||
mock_manager_instance = MockCredentialManager.return_value
|
||||
mock_manager_instance.get_auth_credential = AsyncMock(
|
||||
return_value=auth_credential
|
||||
)
|
||||
|
||||
toolset = MCPToolset(
|
||||
connection_params=self.mock_stdio_params,
|
||||
auth_scheme=auth_scheme,
|
||||
auth_credential=auth_credential,
|
||||
header_provider=header_provider,
|
||||
)
|
||||
toolset._mcp_session_manager = self.mock_session_manager
|
||||
|
||||
await toolset.get_tools(readonly_context=mock_readonly_context)
|
||||
|
||||
self.mock_session_manager.create_session.assert_called_once()
|
||||
call_args = self.mock_session_manager.create_session.call_args
|
||||
headers = call_args[1]["headers"]
|
||||
assert headers == {
|
||||
"X-Tenant-ID": "test-tenant",
|
||||
"Authorization": "Bearer test_token",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_success(self):
|
||||
"""Test successful cleanup."""
|
||||
|
||||
Reference in New Issue
Block a user