diff --git a/contributing/samples/mcp_toolset_auth/README.md b/contributing/samples/mcp_toolset_auth/README.md new file mode 100644 index 00000000..20ead6a2 --- /dev/null +++ b/contributing/samples/mcp_toolset_auth/README.md @@ -0,0 +1,47 @@ +# MCP Toolset OAuth Authentication Sample + +This sample demonstrates the toolset authentication feature where OAuth credentials are required for both tool listing and tool calling. + +## Overview + +The toolset authentication flow works in two phases: + +1. **Phase 1**: When the agent tries to get tools from the MCP server without credentials, the toolset signals "authentication required" and returns an auth request event. + +2. **Phase 2**: After the user provides OAuth credentials, the agent can successfully list and call tools. + +## Files + +- `oauth_mcp_server.py` - MCP server that requires Bearer token authentication +- `agent.py` - Agent configuration with OAuth-protected MCP toolset +- `main.py` - Test script demonstrating the two-phase auth flow + +## Running the Sample + +1. Start the MCP server in one terminal: + +```bash +PYTHONPATH=src python contributing/samples/mcp_toolset_auth/oauth_mcp_server.py +``` + +2. Run the test script in another terminal: + +```bash +PYTHONPATH=src python contributing/samples/mcp_toolset_auth/main.py +``` + +## Expected Behavior + +1. First invocation yields an `adk_request_credential` function call +2. The credential ID is `_adk_toolset_auth_McpToolset` to indicate toolset auth +3. After providing the access token, the agent can list and call tools + +## Testing with ADK Web UI + +You can also test with the ADK web UI: + +```bash +adk web contributing/samples/mcp_toolset_auth +``` + +Note: The web UI will display the auth request and you'll need to manually provide credentials. diff --git a/contributing/samples/mcp_toolset_auth/__init__.py b/contributing/samples/mcp_toolset_auth/__init__.py new file mode 100644 index 00000000..c48963cd --- /dev/null +++ b/contributing/samples/mcp_toolset_auth/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/mcp_toolset_auth/agent.py b/contributing/samples/mcp_toolset_auth/agent.py new file mode 100644 index 00000000..ad417a6e --- /dev/null +++ b/contributing/samples/mcp_toolset_auth/agent.py @@ -0,0 +1,76 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Agent that uses MCP toolset requiring OAuth authentication. + +This agent demonstrates the toolset authentication feature where OAuth +credentials are required for both tool listing and tool calling. +""" + +from __future__ import annotations + +from fastapi.openapi.models import OAuth2 +from fastapi.openapi.models import OAuthFlowAuthorizationCode +from fastapi.openapi.models import OAuthFlows +from google.adk.agents import LlmAgent +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +from google.adk.tools.mcp_tool.mcp_toolset import McpToolset + +# OAuth2 auth scheme with authorization code flow +# This specifies the OAuth metadata needed for the full OAuth flow +auth_scheme = OAuth2( + flows=OAuthFlows( + authorizationCode=OAuthFlowAuthorizationCode( + authorizationUrl='https://example.com/oauth/authorize', + tokenUrl='https://example.com/oauth/token', + scopes={'read': 'Read access', 'write': 'Write access'}, + ) + ) +) + +# OAuth credential with client credentials (used for token exchange) +# In a real scenario, this would be used to obtain the access token +auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id='test_client_id', + client_secret='test_client_secret', + ), +) + +# Create the MCP toolset with OAuth authentication +mcp_toolset = McpToolset( + connection_params=StreamableHTTPConnectionParams( + url='http://localhost:3001/mcp', + ), + auth_scheme=auth_scheme, + auth_credential=auth_credential, +) + +# Define the agent that uses the OAuth-protected MCP toolset +root_agent = LlmAgent( + model='gemini-2.0-flash', + name='oauth_mcp_agent', + instruction="""You are a helpful assistant that can access user information. + +You have access to tools that require authentication: +- get_user_profile: Get profile information for a specific user +- list_users: List all available users + +When the user asks about users, use these tools to help them.""", + tools=[mcp_toolset], +) diff --git a/contributing/samples/mcp_toolset_auth/main.py b/contributing/samples/mcp_toolset_auth/main.py new file mode 100644 index 00000000..e9b8950a --- /dev/null +++ b/contributing/samples/mcp_toolset_auth/main.py @@ -0,0 +1,168 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test script for MCP Toolset OAuth Authentication Flow. + +This script demonstrates the two-phase tool discovery flow: +1. First invocation: Agent tries to get tools, auth is required, returns auth + request event (adk_request_credential) +2. User provides OAuth credentials (simulated) +3. Second invocation: Agent has credentials, can list and call tools + +Usage: + # Start the MCP server first (in another terminal): + PYTHONPATH=src python contributing/samples/mcp_toolset_auth/oauth_mcp_server.py + + # Run the demo: + PYTHONPATH=src python contributing/samples/mcp_toolset_auth/main.py +""" + +from __future__ import annotations + +import asyncio + +from agent import auth_credential +from agent import auth_scheme +from agent import mcp_toolset +from agent import root_agent +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import OAuth2Auth +from google.adk.auth.auth_tool import AuthConfig +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types + + +async def run_demo(): + """Run demo with real MCP server.""" + print('=' * 60) + print('MCP Toolset OAuth Authentication Demo') + print('=' * 60) + print('\nNote: Make sure the MCP server is running:') + print(' python oauth_mcp_server.py\n') + + # Create session service and runner + session_service = InMemorySessionService() + runner = Runner( + agent=root_agent, + app_name='toolset_auth_demo', + session_service=session_service, + ) + + # Create a session + session = await session_service.create_session( + app_name='toolset_auth_demo', + user_id='test_user', + ) + + print(f'Session created: {session.id}') + print('\n--- Phase 1: Initial request (no credentials) ---\n') + + # First invocation - should trigger auth request + user_message = 'List all users' + print(f'User: {user_message}') + + events = [] + auth_function_call_id = None + max_events = 10 + + try: + async for event in runner.run_async( + session_id=session.id, + user_id='test_user', + new_message=types.Content( + role='user', + parts=[types.Part(text=user_message)], + ), + ): + events.append(event) + print(f'\nEvent from {event.author}:') + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + print(f' Text: {part.text}') + if part.function_call: + print(f' Function call: {part.function_call.name}') + if part.function_call.name == 'adk_request_credential': + auth_function_call_id = part.function_call.id + + if len(events) >= max_events: + print(f'\n** SAFETY LIMIT ({max_events} events) **') + break + + except Exception as e: + print(f'\nError: {e}') + print('Make sure the MCP server is running!') + await mcp_toolset.close() + return + + if auth_function_call_id: + print('\n** Auth request detected! **') + print('\n--- Phase 2: Provide OAuth credentials ---\n') + + # Simulate user providing OAuth credentials after completing OAuth flow + auth_response = AuthConfig( + auth_scheme=auth_scheme, + raw_auth_credential=auth_credential, + exchanged_auth_credential=AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + access_token='test_access_token_12345', + ), + ), + ) + + print('Providing access token: test_access_token_12345') + + auth_response_message = types.Content( + role='user', + parts=[ + types.Part( + function_response=types.FunctionResponse( + name='adk_request_credential', + id=auth_function_call_id, + response=auth_response.model_dump(exclude_none=True), + ) + ) + ], + ) + + async for event in runner.run_async( + session_id=session.id, + user_id='test_user', + new_message=auth_response_message, + ): + print(f'\nEvent from {event.author}:') + if event.content and event.content.parts: + for part in event.content.parts: + if part.text: + text = ( + part.text[:200] + '...' if len(part.text) > 200 else part.text + ) + print(f' Text: {text}') + if part.function_call: + print(f' Function call: {part.function_call.name}') + else: + print('\n** No auth request - credentials may already be available **') + + print('\n' + '=' * 60) + print('Demo completed') + print('=' * 60) + + await mcp_toolset.close() + + +if __name__ == '__main__': + asyncio.run(run_demo()) diff --git a/contributing/samples/mcp_toolset_auth/oauth_mcp_server.py b/contributing/samples/mcp_toolset_auth/oauth_mcp_server.py new file mode 100644 index 00000000..0eeab51c --- /dev/null +++ b/contributing/samples/mcp_toolset_auth/oauth_mcp_server.py @@ -0,0 +1,120 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MCP Server that requires OAuth Bearer token for both tool listing and calling. + +This server validates the Authorization header on every request including: +- Tool listing (list_tools endpoint) +- Tool calling (call_tool endpoint) + +This is used to test the toolset authentication feature in ADK. +""" + +from __future__ import annotations + +import logging + +from fastapi import FastAPI +from fastapi import HTTPException +from fastapi import Request +from mcp.server.fastmcp import Context +from mcp.server.fastmcp import FastMCP +import uvicorn + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Expected OAuth token for testing +VALID_TOKEN = 'test_access_token_12345' + +# Create FastMCP server +mcp = FastMCP('OAuth Protected MCP Server', host='localhost', port=3001) + + +def validate_auth_header(request: Request) -> bool: + """Validate the Authorization header contains a valid Bearer token.""" + auth_header = request.headers.get('authorization', '') + if not auth_header.startswith('Bearer '): + logger.warning('Missing or invalid Authorization header: %s', auth_header) + return False + + token = auth_header[7:] # Remove 'Bearer ' prefix + if token != VALID_TOKEN: + logger.warning('Invalid token: %s', token) + return False + + logger.info('Valid token received') + return True + + +@mcp.tool(description='Get user profile information. Requires authentication.') +def get_user_profile(user_id: str, context: Context) -> dict: + """Return user profile data for the given user ID.""" + logger.info('get_user_profile called for user: %s', user_id) + + if context.request_context and context.request_context.request: + if not validate_auth_header(context.request_context.request): + return {'error': 'Unauthorized - invalid or missing token'} + + # Mock user data + users = { + 'user1': {'id': 'user1', 'name': 'Alice', 'email': 'alice@example.com'}, + 'user2': {'id': 'user2', 'name': 'Bob', 'email': 'bob@example.com'}, + } + + if user_id in users: + return users[user_id] + return {'error': f'User {user_id} not found'} + + +@mcp.tool(description='List all available users. Requires authentication.') +def list_users(context: Context) -> dict: + """Return a list of all users.""" + logger.info('list_users called') + + if context.request_context and context.request_context.request: + if not validate_auth_header(context.request_context.request): + return {'error': 'Unauthorized - invalid or missing token'} + + return { + 'users': [ + {'id': 'user1', 'name': 'Alice'}, + {'id': 'user2', 'name': 'Bob'}, + ] + } + + +# Create custom FastAPI app to add auth middleware for list_tools +app = FastAPI() + + +@app.middleware('http') +async def auth_middleware(request: Request, call_next): + """Middleware to validate auth on all MCP endpoints.""" + # Check if this is an MCP request + if request.url.path.startswith('/mcp'): + if not validate_auth_header(request): + raise HTTPException(status_code=401, detail='Unauthorized') + return await call_next(request) + + +if __name__ == '__main__': + print(f'Starting OAuth Protected MCP server on http://localhost:3001') + print(f'Expected token: Bearer {VALID_TOKEN}') + print( + 'This server requires authentication for both tool listing and calling.' + ) + + # Run with streamable-http transport + mcp.run(transport='streamable-http')