You've already forked adk-python
mirror of
https://github.com/encounter/adk-python.git
synced 2026-03-30 10:57:20 -07:00
chore: Add sample agent that need to go through oauth flow during mcp tool listing
Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com> PiperOrigin-RevId: 864133951
This commit is contained in:
committed by
Copybara-Service
parent
131fbd3948
commit
2770012cec
@@ -0,0 +1,47 @@
|
||||
# MCP Toolset OAuth Authentication Sample
|
||||
|
||||
This sample demonstrates the toolset authentication feature where OAuth credentials are required for both tool listing and tool calling.
|
||||
|
||||
## Overview
|
||||
|
||||
The toolset authentication flow works in two phases:
|
||||
|
||||
1. **Phase 1**: When the agent tries to get tools from the MCP server without credentials, the toolset signals "authentication required" and returns an auth request event.
|
||||
|
||||
2. **Phase 2**: After the user provides OAuth credentials, the agent can successfully list and call tools.
|
||||
|
||||
## Files
|
||||
|
||||
- `oauth_mcp_server.py` - MCP server that requires Bearer token authentication
|
||||
- `agent.py` - Agent configuration with OAuth-protected MCP toolset
|
||||
- `main.py` - Test script demonstrating the two-phase auth flow
|
||||
|
||||
## Running the Sample
|
||||
|
||||
1. Start the MCP server in one terminal:
|
||||
|
||||
```bash
|
||||
PYTHONPATH=src python contributing/samples/mcp_toolset_auth/oauth_mcp_server.py
|
||||
```
|
||||
|
||||
2. Run the test script in another terminal:
|
||||
|
||||
```bash
|
||||
PYTHONPATH=src python contributing/samples/mcp_toolset_auth/main.py
|
||||
```
|
||||
|
||||
## Expected Behavior
|
||||
|
||||
1. First invocation yields an `adk_request_credential` function call
|
||||
2. The credential ID is `_adk_toolset_auth_McpToolset` to indicate toolset auth
|
||||
3. After providing the access token, the agent can list and call tools
|
||||
|
||||
## Testing with ADK Web UI
|
||||
|
||||
You can also test with the ADK web UI:
|
||||
|
||||
```bash
|
||||
adk web contributing/samples/mcp_toolset_auth
|
||||
```
|
||||
|
||||
Note: The web UI will display the auth request and you'll need to manually provide credentials.
|
||||
@@ -0,0 +1,15 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import agent
|
||||
@@ -0,0 +1,76 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Agent that uses MCP toolset requiring OAuth authentication.
|
||||
|
||||
This agent demonstrates the toolset authentication feature where OAuth
|
||||
credentials are required for both tool listing and tool calling.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi.openapi.models import OAuth2
|
||||
from fastapi.openapi.models import OAuthFlowAuthorizationCode
|
||||
from fastapi.openapi.models import OAuthFlows
|
||||
from google.adk.agents import LlmAgent
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import OAuth2Auth
|
||||
from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
|
||||
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
|
||||
|
||||
# OAuth2 auth scheme with authorization code flow
|
||||
# This specifies the OAuth metadata needed for the full OAuth flow
|
||||
auth_scheme = OAuth2(
|
||||
flows=OAuthFlows(
|
||||
authorizationCode=OAuthFlowAuthorizationCode(
|
||||
authorizationUrl='https://example.com/oauth/authorize',
|
||||
tokenUrl='https://example.com/oauth/token',
|
||||
scopes={'read': 'Read access', 'write': 'Write access'},
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# OAuth credential with client credentials (used for token exchange)
|
||||
# In a real scenario, this would be used to obtain the access token
|
||||
auth_credential = AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
client_id='test_client_id',
|
||||
client_secret='test_client_secret',
|
||||
),
|
||||
)
|
||||
|
||||
# Create the MCP toolset with OAuth authentication
|
||||
mcp_toolset = McpToolset(
|
||||
connection_params=StreamableHTTPConnectionParams(
|
||||
url='http://localhost:3001/mcp',
|
||||
),
|
||||
auth_scheme=auth_scheme,
|
||||
auth_credential=auth_credential,
|
||||
)
|
||||
|
||||
# Define the agent that uses the OAuth-protected MCP toolset
|
||||
root_agent = LlmAgent(
|
||||
model='gemini-2.0-flash',
|
||||
name='oauth_mcp_agent',
|
||||
instruction="""You are a helpful assistant that can access user information.
|
||||
|
||||
You have access to tools that require authentication:
|
||||
- get_user_profile: Get profile information for a specific user
|
||||
- list_users: List all available users
|
||||
|
||||
When the user asks about users, use these tools to help them.""",
|
||||
tools=[mcp_toolset],
|
||||
)
|
||||
@@ -0,0 +1,168 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test script for MCP Toolset OAuth Authentication Flow.
|
||||
|
||||
This script demonstrates the two-phase tool discovery flow:
|
||||
1. First invocation: Agent tries to get tools, auth is required, returns auth
|
||||
request event (adk_request_credential)
|
||||
2. User provides OAuth credentials (simulated)
|
||||
3. Second invocation: Agent has credentials, can list and call tools
|
||||
|
||||
Usage:
|
||||
# Start the MCP server first (in another terminal):
|
||||
PYTHONPATH=src python contributing/samples/mcp_toolset_auth/oauth_mcp_server.py
|
||||
|
||||
# Run the demo:
|
||||
PYTHONPATH=src python contributing/samples/mcp_toolset_auth/main.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from agent import auth_credential
|
||||
from agent import auth_scheme
|
||||
from agent import mcp_toolset
|
||||
from agent import root_agent
|
||||
from google.adk.auth.auth_credential import AuthCredential
|
||||
from google.adk.auth.auth_credential import AuthCredentialTypes
|
||||
from google.adk.auth.auth_credential import OAuth2Auth
|
||||
from google.adk.auth.auth_tool import AuthConfig
|
||||
from google.adk.runners import Runner
|
||||
from google.adk.sessions.in_memory_session_service import InMemorySessionService
|
||||
from google.genai import types
|
||||
|
||||
|
||||
async def run_demo():
|
||||
"""Run demo with real MCP server."""
|
||||
print('=' * 60)
|
||||
print('MCP Toolset OAuth Authentication Demo')
|
||||
print('=' * 60)
|
||||
print('\nNote: Make sure the MCP server is running:')
|
||||
print(' python oauth_mcp_server.py\n')
|
||||
|
||||
# Create session service and runner
|
||||
session_service = InMemorySessionService()
|
||||
runner = Runner(
|
||||
agent=root_agent,
|
||||
app_name='toolset_auth_demo',
|
||||
session_service=session_service,
|
||||
)
|
||||
|
||||
# Create a session
|
||||
session = await session_service.create_session(
|
||||
app_name='toolset_auth_demo',
|
||||
user_id='test_user',
|
||||
)
|
||||
|
||||
print(f'Session created: {session.id}')
|
||||
print('\n--- Phase 1: Initial request (no credentials) ---\n')
|
||||
|
||||
# First invocation - should trigger auth request
|
||||
user_message = 'List all users'
|
||||
print(f'User: {user_message}')
|
||||
|
||||
events = []
|
||||
auth_function_call_id = None
|
||||
max_events = 10
|
||||
|
||||
try:
|
||||
async for event in runner.run_async(
|
||||
session_id=session.id,
|
||||
user_id='test_user',
|
||||
new_message=types.Content(
|
||||
role='user',
|
||||
parts=[types.Part(text=user_message)],
|
||||
),
|
||||
):
|
||||
events.append(event)
|
||||
print(f'\nEvent from {event.author}:')
|
||||
if event.content and event.content.parts:
|
||||
for part in event.content.parts:
|
||||
if part.text:
|
||||
print(f' Text: {part.text}')
|
||||
if part.function_call:
|
||||
print(f' Function call: {part.function_call.name}')
|
||||
if part.function_call.name == 'adk_request_credential':
|
||||
auth_function_call_id = part.function_call.id
|
||||
|
||||
if len(events) >= max_events:
|
||||
print(f'\n** SAFETY LIMIT ({max_events} events) **')
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
print(f'\nError: {e}')
|
||||
print('Make sure the MCP server is running!')
|
||||
await mcp_toolset.close()
|
||||
return
|
||||
|
||||
if auth_function_call_id:
|
||||
print('\n** Auth request detected! **')
|
||||
print('\n--- Phase 2: Provide OAuth credentials ---\n')
|
||||
|
||||
# Simulate user providing OAuth credentials after completing OAuth flow
|
||||
auth_response = AuthConfig(
|
||||
auth_scheme=auth_scheme,
|
||||
raw_auth_credential=auth_credential,
|
||||
exchanged_auth_credential=AuthCredential(
|
||||
auth_type=AuthCredentialTypes.OAUTH2,
|
||||
oauth2=OAuth2Auth(
|
||||
access_token='test_access_token_12345',
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
print('Providing access token: test_access_token_12345')
|
||||
|
||||
auth_response_message = types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part(
|
||||
function_response=types.FunctionResponse(
|
||||
name='adk_request_credential',
|
||||
id=auth_function_call_id,
|
||||
response=auth_response.model_dump(exclude_none=True),
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
async for event in runner.run_async(
|
||||
session_id=session.id,
|
||||
user_id='test_user',
|
||||
new_message=auth_response_message,
|
||||
):
|
||||
print(f'\nEvent from {event.author}:')
|
||||
if event.content and event.content.parts:
|
||||
for part in event.content.parts:
|
||||
if part.text:
|
||||
text = (
|
||||
part.text[:200] + '...' if len(part.text) > 200 else part.text
|
||||
)
|
||||
print(f' Text: {text}')
|
||||
if part.function_call:
|
||||
print(f' Function call: {part.function_call.name}')
|
||||
else:
|
||||
print('\n** No auth request - credentials may already be available **')
|
||||
|
||||
print('\n' + '=' * 60)
|
||||
print('Demo completed')
|
||||
print('=' * 60)
|
||||
|
||||
await mcp_toolset.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(run_demo())
|
||||
@@ -0,0 +1,120 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""MCP Server that requires OAuth Bearer token for both tool listing and calling.
|
||||
|
||||
This server validates the Authorization header on every request including:
|
||||
- Tool listing (list_tools endpoint)
|
||||
- Tool calling (call_tool endpoint)
|
||||
|
||||
This is used to test the toolset authentication feature in ADK.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from mcp.server.fastmcp import Context
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
import uvicorn
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Expected OAuth token for testing
|
||||
VALID_TOKEN = 'test_access_token_12345'
|
||||
|
||||
# Create FastMCP server
|
||||
mcp = FastMCP('OAuth Protected MCP Server', host='localhost', port=3001)
|
||||
|
||||
|
||||
def validate_auth_header(request: Request) -> bool:
|
||||
"""Validate the Authorization header contains a valid Bearer token."""
|
||||
auth_header = request.headers.get('authorization', '')
|
||||
if not auth_header.startswith('Bearer '):
|
||||
logger.warning('Missing or invalid Authorization header: %s', auth_header)
|
||||
return False
|
||||
|
||||
token = auth_header[7:] # Remove 'Bearer ' prefix
|
||||
if token != VALID_TOKEN:
|
||||
logger.warning('Invalid token: %s', token)
|
||||
return False
|
||||
|
||||
logger.info('Valid token received')
|
||||
return True
|
||||
|
||||
|
||||
@mcp.tool(description='Get user profile information. Requires authentication.')
|
||||
def get_user_profile(user_id: str, context: Context) -> dict:
|
||||
"""Return user profile data for the given user ID."""
|
||||
logger.info('get_user_profile called for user: %s', user_id)
|
||||
|
||||
if context.request_context and context.request_context.request:
|
||||
if not validate_auth_header(context.request_context.request):
|
||||
return {'error': 'Unauthorized - invalid or missing token'}
|
||||
|
||||
# Mock user data
|
||||
users = {
|
||||
'user1': {'id': 'user1', 'name': 'Alice', 'email': 'alice@example.com'},
|
||||
'user2': {'id': 'user2', 'name': 'Bob', 'email': 'bob@example.com'},
|
||||
}
|
||||
|
||||
if user_id in users:
|
||||
return users[user_id]
|
||||
return {'error': f'User {user_id} not found'}
|
||||
|
||||
|
||||
@mcp.tool(description='List all available users. Requires authentication.')
|
||||
def list_users(context: Context) -> dict:
|
||||
"""Return a list of all users."""
|
||||
logger.info('list_users called')
|
||||
|
||||
if context.request_context and context.request_context.request:
|
||||
if not validate_auth_header(context.request_context.request):
|
||||
return {'error': 'Unauthorized - invalid or missing token'}
|
||||
|
||||
return {
|
||||
'users': [
|
||||
{'id': 'user1', 'name': 'Alice'},
|
||||
{'id': 'user2', 'name': 'Bob'},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# Create custom FastAPI app to add auth middleware for list_tools
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.middleware('http')
|
||||
async def auth_middleware(request: Request, call_next):
|
||||
"""Middleware to validate auth on all MCP endpoints."""
|
||||
# Check if this is an MCP request
|
||||
if request.url.path.startswith('/mcp'):
|
||||
if not validate_auth_header(request):
|
||||
raise HTTPException(status_code=401, detail='Unauthorized')
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(f'Starting OAuth Protected MCP server on http://localhost:3001')
|
||||
print(f'Expected token: Bearer {VALID_TOKEN}')
|
||||
print(
|
||||
'This server requires authentication for both tool listing and calling.'
|
||||
)
|
||||
|
||||
# Run with streamable-http transport
|
||||
mcp.run(transport='streamable-http')
|
||||
Reference in New Issue
Block a user