You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
feat: Support authentication for MCP tool listing
Currently only tool calling supports MCP auth. This refactors the auth logic into a auth_utils file and uses it for tool listing as well. Fixes https://github.com/google/adk-python/issues/2168. Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com> PiperOrigin-RevId: 859201722
This commit is contained in:
committed by
Copybara-Service
parent
d62f9c896c
commit
e3d542a5ba
@@ -1,240 +0,0 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import base64
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi.openapi import models as openapi_models
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import HttpAuth
|
||||
from google.adk.auth.auth_credential import HttpCredentials
|
||||
from google.adk.auth.auth_credential import OAuth2Auth
|
||||
from google.adk.auth.auth_credential import ServiceAccount
|
||||
from google.adk.auth.auth_schemes import AuthSchemeType
|
||||
from google.adk.tools.mcp_tool import mcp_auth_utils
|
||||
import pytest
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_no_credential():
|
||||
"""Test header generation with no credentials."""
|
||||
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=None
|
||||
)
|
||||
assert headers is None
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_no_auth_scheme():
|
||||
"""Test header generation with no auth_scheme."""
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(access_token="test_token"),
|
||||
)
|
||||
with patch.object(mcp_auth_utils, "logger") as mock_logger:
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=None, credential=credential
|
||||
)
|
||||
assert headers == {"Authorization": "Bearer test_token"}
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_oauth2():
|
||||
"""Test header generation for OAuth2 credentials."""
|
||||
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(access_token="test_token"),
|
||||
)
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
assert headers == {"Authorization": "Bearer test_token"}
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_http_bearer():
|
||||
"""Test header generation for HTTP Bearer credentials."""
|
||||
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer", credentials=HttpCredentials(token="bearer_token")
|
||||
),
|
||||
)
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
assert headers == {"Authorization": "Bearer bearer_token"}
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_http_basic():
|
||||
"""Test header generation for HTTP Basic credentials."""
|
||||
auth_scheme = openapi_models.HTTPBase(scheme="basic")
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="basic",
|
||||
credentials=HttpCredentials(username="user", password="pass"),
|
||||
),
|
||||
)
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
expected_encoded = base64.b64encode(b"user:pass").decode()
|
||||
assert headers == {"Authorization": f"Basic {expected_encoded}"}
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_http_basic_missing_credentials():
|
||||
"""Test header generation for HTTP Basic with missing credentials."""
|
||||
auth_scheme = openapi_models.HTTPBase(scheme="basic")
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="basic",
|
||||
credentials=HttpCredentials(username="user", password=None),
|
||||
),
|
||||
)
|
||||
with patch.object(mcp_auth_utils, "logger") as mock_logger:
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
assert headers is None
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"Basic auth scheme missing username or password."
|
||||
)
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_http_custom_scheme():
|
||||
"""Test header generation for custom HTTP scheme."""
|
||||
auth_scheme = openapi_models.HTTPBase(scheme="custom")
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="custom", credentials=HttpCredentials(token="custom_token")
|
||||
),
|
||||
)
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
assert headers == {"Authorization": "custom custom_token"}
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_http_cred_wrong_scheme():
|
||||
"""Test HTTP credential with non-HTTPBase auth scheme."""
|
||||
auth_scheme = openapi_models.APIKey(**{
|
||||
"type": AuthSchemeType.apiKey,
|
||||
"in": openapi_models.APIKeyIn.header,
|
||||
"name": "X-API-Key",
|
||||
})
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP,
|
||||
http=HttpAuth(
|
||||
scheme="bearer", credentials=HttpCredentials(token="bearer_token")
|
||||
),
|
||||
)
|
||||
with patch.object(mcp_auth_utils, "logger") as mock_logger:
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
assert headers is None
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"HTTP credential provided, but auth_scheme is missing or not HTTPBase."
|
||||
)
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_api_key_header():
|
||||
"""Test header generation for API Key in header."""
|
||||
auth_scheme = openapi_models.APIKey(**{
|
||||
"type": AuthSchemeType.apiKey,
|
||||
"in": openapi_models.APIKeyIn.header,
|
||||
"name": "X-Custom-API-Key",
|
||||
})
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
|
||||
)
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
assert headers == {"X-Custom-API-Key": "my_api_key"}
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_api_key_query_raises_error():
|
||||
"""Test API Key in query raises ValueError."""
|
||||
auth_scheme = openapi_models.APIKey(**{
|
||||
"type": AuthSchemeType.apiKey,
|
||||
"in": openapi_models.APIKeyIn.query,
|
||||
"name": "api_key",
|
||||
})
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="MCP tools only support header-based API key authentication.",
|
||||
):
|
||||
mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_api_key_cookie_raises_error():
|
||||
"""Test API Key in cookie raises ValueError."""
|
||||
auth_scheme = openapi_models.APIKey(**{
|
||||
"type": AuthSchemeType.apiKey,
|
||||
"in": openapi_models.APIKeyIn.cookie,
|
||||
"name": "session_id",
|
||||
})
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
|
||||
)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="MCP tools only support header-based API key authentication.",
|
||||
):
|
||||
mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_api_key_cred_wrong_scheme():
|
||||
"""Test API key credential with non-APIKey auth scheme."""
|
||||
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
|
||||
)
|
||||
with patch.object(mcp_auth_utils, "logger") as mock_logger:
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
assert headers is None
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"API key credential provided, but auth_scheme is missing or not APIKey."
|
||||
)
|
||||
|
||||
|
||||
def test_get_mcp_auth_headers_service_account():
|
||||
"""Test header generation for service account credentials."""
|
||||
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=ServiceAccount(scopes=["test"]),
|
||||
)
|
||||
with patch.object(mcp_auth_utils, "logger") as mock_logger:
|
||||
headers = mcp_auth_utils.get_mcp_auth_headers(
|
||||
auth_scheme=auth_scheme, credential=credential
|
||||
)
|
||||
assert headers is None
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"Service account credentials should be exchanged for an access "
|
||||
"token before calling get_mcp_auth_headers."
|
||||
)
|
||||
@@ -18,7 +18,10 @@ from unittest.mock import patch
|
||||
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import HttpAuth
|
||||
from google.adk.auth.auth_credential import HttpCredentials
|
||||
from google.adk.auth.auth_credential import OAuth2Auth
|
||||
from google.adk.auth.auth_credential import ServiceAccount
|
||||
from google.adk.features import FeatureName
|
||||
from google.adk.features._feature_registry import temporary_feature_override
|
||||
from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
|
||||
@@ -258,6 +261,240 @@ class TestMCPTool:
|
||||
headers = call_args[1]["headers"]
|
||||
assert headers == {"Authorization": "Bearer test_access_token"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_oauth2(self):
|
||||
"""Test header generation for OAuth2 credentials."""
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
)
|
||||
|
||||
oauth2_auth = OAuth2Auth(access_token="test_token")
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2, oauth2=oauth2_auth
|
||||
)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
headers = await tool._get_headers(tool_context, credential)
|
||||
|
||||
assert headers == {"Authorization": "Bearer test_token"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_http_bearer(self):
|
||||
"""Test header generation for HTTP Bearer credentials."""
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
)
|
||||
|
||||
http_auth = HttpAuth(
|
||||
scheme="bearer", credentials=HttpCredentials(token="bearer_token")
|
||||
)
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP, http=http_auth
|
||||
)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
headers = await tool._get_headers(tool_context, credential)
|
||||
|
||||
assert headers == {"Authorization": "Bearer bearer_token"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_http_basic(self):
|
||||
"""Test header generation for HTTP Basic credentials."""
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
)
|
||||
|
||||
http_auth = HttpAuth(
|
||||
scheme="basic",
|
||||
credentials=HttpCredentials(username="user", password="pass"),
|
||||
)
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP, http=http_auth
|
||||
)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
headers = await tool._get_headers(tool_context, credential)
|
||||
|
||||
# Should create Basic auth header with base64 encoded credentials
|
||||
import base64
|
||||
|
||||
expected_encoded = base64.b64encode(b"user:pass").decode()
|
||||
assert headers == {"Authorization": f"Basic {expected_encoded}"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_api_key_with_valid_header_scheme(self):
|
||||
"""Test header generation for API Key credentials with header-based auth scheme."""
|
||||
from fastapi.openapi.models import APIKey
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from google.adk.auth.auth_schemes import AuthSchemeType
|
||||
|
||||
# Create auth scheme for header-based API key
|
||||
auth_scheme = APIKey(**{
|
||||
"type": AuthSchemeType.apiKey,
|
||||
"in": APIKeyIn.header,
|
||||
"name": "X-Custom-API-Key",
|
||||
})
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
|
||||
)
|
||||
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
auth_scheme=auth_scheme,
|
||||
auth_credential=auth_credential,
|
||||
)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
headers = await tool._get_headers(tool_context, auth_credential)
|
||||
|
||||
assert headers == {"X-Custom-API-Key": "my_api_key"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_api_key_with_query_scheme_raises_error(self):
|
||||
"""Test that API Key with query-based auth scheme raises ValueError."""
|
||||
from fastapi.openapi.models import APIKey
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from google.adk.auth.auth_schemes import AuthSchemeType
|
||||
|
||||
# Create auth scheme for query-based API key (not supported)
|
||||
auth_scheme = APIKey(**{
|
||||
"type": AuthSchemeType.apiKey,
|
||||
"in": APIKeyIn.query,
|
||||
"name": "api_key",
|
||||
})
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
|
||||
)
|
||||
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
auth_scheme=auth_scheme,
|
||||
auth_credential=auth_credential,
|
||||
)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="McpTool only supports header-based API key authentication",
|
||||
):
|
||||
await tool._get_headers(tool_context, auth_credential)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_api_key_with_cookie_scheme_raises_error(self):
|
||||
"""Test that API Key with cookie-based auth scheme raises ValueError."""
|
||||
from fastapi.openapi.models import APIKey
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from google.adk.auth.auth_schemes import AuthSchemeType
|
||||
|
||||
# Create auth scheme for cookie-based API key (not supported)
|
||||
auth_scheme = APIKey(**{
|
||||
"type": AuthSchemeType.apiKey,
|
||||
"in": APIKeyIn.cookie,
|
||||
"name": "session_id",
|
||||
})
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
|
||||
)
|
||||
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
auth_scheme=auth_scheme,
|
||||
auth_credential=auth_credential,
|
||||
)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="McpTool only supports header-based API key authentication",
|
||||
):
|
||||
await tool._get_headers(tool_context, auth_credential)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_api_key_without_auth_config_raises_error(self):
|
||||
"""Test that API Key without auth config raises ValueError."""
|
||||
# Create tool without auth scheme/config
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
)
|
||||
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
|
||||
)
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Cannot find corresponding auth scheme for API key credential",
|
||||
):
|
||||
await tool._get_headers(tool_context, credential)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_api_key_without_credentials_manager_raises_error(
|
||||
self,
|
||||
):
|
||||
"""Test that API Key without credentials manager raises ValueError."""
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
)
|
||||
|
||||
# Manually set credentials manager to None to simulate error condition
|
||||
tool._credentials_manager = None
|
||||
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
|
||||
)
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Cannot find corresponding auth scheme for API key credential",
|
||||
):
|
||||
await tool._get_headers(tool_context, credential)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_no_credential(self):
|
||||
"""Test header generation with no credentials."""
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
headers = await tool._get_headers(tool_context, None)
|
||||
|
||||
assert headers is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_service_account(self):
|
||||
"""Test header generation for service account credentials."""
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
)
|
||||
|
||||
# Create service account credential
|
||||
service_account = ServiceAccount(scopes=["test"])
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.SERVICE_ACCOUNT,
|
||||
service_account=service_account,
|
||||
)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
headers = await tool._get_headers(tool_context, credential)
|
||||
|
||||
# Should return None as service account credentials are not supported for direct header generation
|
||||
assert headers is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_impl_with_api_key_header_auth(self):
|
||||
"""Test running tool with API key header authentication end-to-end."""
|
||||
@@ -314,6 +551,65 @@ class TestMCPTool:
|
||||
# Check that the method has the retry decorator
|
||||
assert hasattr(tool._run_async_impl, "__wrapped__")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_http_custom_scheme(self):
|
||||
"""Test header generation for custom HTTP scheme."""
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
)
|
||||
|
||||
http_auth = HttpAuth(
|
||||
scheme="custom", credentials=HttpCredentials(token="custom_token")
|
||||
)
|
||||
credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.HTTP, http=http_auth
|
||||
)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
headers = await tool._get_headers(tool_context, credential)
|
||||
|
||||
assert headers == {"Authorization": "custom custom_token"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_headers_api_key_error_logging(self):
|
||||
"""Test that API key errors are logged correctly."""
|
||||
from fastapi.openapi.models import APIKey
|
||||
from fastapi.openapi.models import APIKeyIn
|
||||
from google.adk.auth.auth_schemes import AuthSchemeType
|
||||
|
||||
# Create auth scheme for query-based API key (not supported)
|
||||
auth_scheme = APIKey(**{
|
||||
"type": AuthSchemeType.apiKey,
|
||||
"in": APIKeyIn.query,
|
||||
"name": "api_key",
|
||||
})
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.API_KEY, api_key="my_api_key"
|
||||
)
|
||||
|
||||
tool = MCPTool(
|
||||
mcp_tool=self.mock_mcp_tool,
|
||||
mcp_session_manager=self.mock_session_manager,
|
||||
auth_scheme=auth_scheme,
|
||||
auth_credential=auth_credential,
|
||||
)
|
||||
|
||||
tool_context = Mock(spec=ToolContext)
|
||||
|
||||
# Test with logging
|
||||
with patch("google.adk.tools.mcp_tool.mcp_tool.logger") as mock_logger:
|
||||
with pytest.raises(ValueError):
|
||||
await tool._get_headers(tool_context, auth_credential)
|
||||
|
||||
# Verify error was logged
|
||||
mock_logger.error.assert_called_once()
|
||||
logged_message = mock_logger.error.call_args[0][0]
|
||||
assert (
|
||||
"McpTool only supports header-based API key authentication"
|
||||
in logged_message
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_async_require_confirmation_true_no_confirmation(self):
|
||||
"""Test require_confirmation=True with no confirmation in context."""
|
||||
|
||||
@@ -30,8 +30,6 @@ from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnecti
|
||||
from google.adk.tools.mcp_tool.mcp_tool import MCPTool
|
||||
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
|
||||
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
|
||||
from google.adk.tools.mcp_tool.mcp_toolset import McpToolsetConfig
|
||||
from google.adk.tools.tool_configs import ToolArgsConfig
|
||||
from mcp import StdioServerParameters
|
||||
import pytest
|
||||
|
||||
@@ -247,94 +245,6 @@ class TestMCPToolset:
|
||||
headers=expected_headers
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tools_with_auth_headers(self):
|
||||
"""Test get_tools with auth headers."""
|
||||
from fastapi.openapi import models as openapi_models
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import OAuth2Auth
|
||||
|
||||
mock_tools = [MockMCPTool("tool1")]
|
||||
self.mock_session.list_tools = AsyncMock(
|
||||
return_value=MockListToolsResult(mock_tools)
|
||||
)
|
||||
mock_readonly_context = Mock(spec=ReadonlyContext)
|
||||
|
||||
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(access_token="test_token"),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"google.adk.tools.mcp_tool.mcp_toolset.CredentialManager"
|
||||
) as MockCredentialManager:
|
||||
mock_manager_instance = MockCredentialManager.return_value
|
||||
mock_manager_instance.get_auth_credential = AsyncMock(
|
||||
return_value=auth_credential
|
||||
)
|
||||
|
||||
toolset = MCPToolset(
|
||||
connection_params=self.mock_stdio_params,
|
||||
auth_scheme=auth_scheme,
|
||||
auth_credential=auth_credential,
|
||||
)
|
||||
toolset._mcp_session_manager = self.mock_session_manager
|
||||
|
||||
await toolset.get_tools(readonly_context=mock_readonly_context)
|
||||
|
||||
self.mock_session_manager.create_session.assert_called_once()
|
||||
call_args = self.mock_session_manager.create_session.call_args
|
||||
headers = call_args[1]["headers"]
|
||||
assert headers == {"Authorization": "Bearer test_token"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tools_with_auth_and_header_provider(self):
|
||||
"""Test get_tools with auth and header_provider."""
|
||||
from fastapi.openapi import models as openapi_models
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import OAuth2Auth
|
||||
|
||||
mock_tools = [MockMCPTool("tool1")]
|
||||
self.mock_session.list_tools = AsyncMock(
|
||||
return_value=MockListToolsResult(mock_tools)
|
||||
)
|
||||
mock_readonly_context = Mock(spec=ReadonlyContext)
|
||||
provided_headers = {"X-Tenant-ID": "test-tenant"}
|
||||
header_provider = Mock(return_value=provided_headers)
|
||||
|
||||
auth_scheme = openapi_models.HTTPBase(scheme="bearer")
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(access_token="test_token"),
|
||||
)
|
||||
|
||||
with patch(
|
||||
"google.adk.tools.mcp_tool.mcp_toolset.CredentialManager"
|
||||
) as MockCredentialManager:
|
||||
mock_manager_instance = MockCredentialManager.return_value
|
||||
mock_manager_instance.get_auth_credential = AsyncMock(
|
||||
return_value=auth_credential
|
||||
)
|
||||
|
||||
toolset = MCPToolset(
|
||||
connection_params=self.mock_stdio_params,
|
||||
auth_scheme=auth_scheme,
|
||||
auth_credential=auth_credential,
|
||||
header_provider=header_provider,
|
||||
)
|
||||
toolset._mcp_session_manager = self.mock_session_manager
|
||||
|
||||
await toolset.get_tools(readonly_context=mock_readonly_context)
|
||||
|
||||
self.mock_session_manager.create_session.assert_called_once()
|
||||
call_args = self.mock_session_manager.create_session.call_args
|
||||
headers = call_args[1]["headers"]
|
||||
assert headers == {
|
||||
"X-Tenant-ID": "test-tenant",
|
||||
"Authorization": "Bearer test_token",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_success(self):
|
||||
"""Test successful cleanup."""
|
||||
|
||||
Reference in New Issue
Block a user