chore: Add sample agent that need to go through oauth flow during mcp tool listing

Co-authored-by: Xiang (Sean) Zhou <seanzhougoogle@google.com>
PiperOrigin-RevId: 864133951
This commit is contained in:
Xiang (Sean) Zhou
2026-02-01 20:21:58 -08:00
committed by Copybara-Service
parent 131fbd3948
commit 2770012cec
5 changed files with 426 additions and 0 deletions
@@ -0,0 +1,47 @@
# MCP Toolset OAuth Authentication Sample
This sample demonstrates the toolset authentication feature where OAuth credentials are required for both tool listing and tool calling.
## Overview
The toolset authentication flow works in two phases:
1. **Phase 1**: When the agent tries to get tools from the MCP server without credentials, the toolset signals "authentication required" and returns an auth request event.
2. **Phase 2**: After the user provides OAuth credentials, the agent can successfully list and call tools.
## Files
- `oauth_mcp_server.py` - MCP server that requires Bearer token authentication
- `agent.py` - Agent configuration with OAuth-protected MCP toolset
- `main.py` - Test script demonstrating the two-phase auth flow
## Running the Sample
1. Start the MCP server in one terminal:
```bash
PYTHONPATH=src python contributing/samples/mcp_toolset_auth/oauth_mcp_server.py
```
2. Run the test script in another terminal:
```bash
PYTHONPATH=src python contributing/samples/mcp_toolset_auth/main.py
```
## Expected Behavior
1. First invocation yields an `adk_request_credential` function call
2. The credential ID is `_adk_toolset_auth_McpToolset` to indicate toolset auth
3. After providing the access token, the agent can list and call tools
## Testing with ADK Web UI
You can also test with the ADK web UI:
```bash
adk web contributing/samples/mcp_toolset_auth
```
Note: The web UI will display the auth request and you'll need to manually provide credentials.
@@ -0,0 +1,15 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import agent
@@ -0,0 +1,76 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Agent that uses MCP toolset requiring OAuth authentication.
This agent demonstrates the toolset authentication feature where OAuth
credentials are required for both tool listing and tool calling.
"""
from __future__ import annotations
from fastapi.openapi.models import OAuth2
from fastapi.openapi.models import OAuthFlowAuthorizationCode
from fastapi.openapi.models import OAuthFlows
from google.adk.agents import LlmAgent
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import OAuth2Auth
from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
# OAuth2 auth scheme with authorization code flow
# This specifies the OAuth metadata needed for the full OAuth flow
auth_scheme = OAuth2(
flows=OAuthFlows(
authorizationCode=OAuthFlowAuthorizationCode(
authorizationUrl='https://example.com/oauth/authorize',
tokenUrl='https://example.com/oauth/token',
scopes={'read': 'Read access', 'write': 'Write access'},
)
)
)
# OAuth credential with client credentials (used for token exchange)
# In a real scenario, this would be used to obtain the access token
auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
client_id='test_client_id',
client_secret='test_client_secret',
),
)
# Create the MCP toolset with OAuth authentication
mcp_toolset = McpToolset(
connection_params=StreamableHTTPConnectionParams(
url='http://localhost:3001/mcp',
),
auth_scheme=auth_scheme,
auth_credential=auth_credential,
)
# Define the agent that uses the OAuth-protected MCP toolset
root_agent = LlmAgent(
model='gemini-2.0-flash',
name='oauth_mcp_agent',
instruction="""You are a helpful assistant that can access user information.
You have access to tools that require authentication:
- get_user_profile: Get profile information for a specific user
- list_users: List all available users
When the user asks about users, use these tools to help them.""",
tools=[mcp_toolset],
)
@@ -0,0 +1,168 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test script for MCP Toolset OAuth Authentication Flow.
This script demonstrates the two-phase tool discovery flow:
1. First invocation: Agent tries to get tools, auth is required, returns auth
request event (adk_request_credential)
2. User provides OAuth credentials (simulated)
3. Second invocation: Agent has credentials, can list and call tools
Usage:
# Start the MCP server first (in another terminal):
PYTHONPATH=src python contributing/samples/mcp_toolset_auth/oauth_mcp_server.py
# Run the demo:
PYTHONPATH=src python contributing/samples/mcp_toolset_auth/main.py
"""
from __future__ import annotations
import asyncio
from agent import auth_credential
from agent import auth_scheme
from agent import mcp_toolset
from agent import root_agent
from google.adk.auth.auth_credential import AuthCredential
from google.adk.auth.auth_credential import AuthCredentialTypes
from google.adk.auth.auth_credential import OAuth2Auth
from google.adk.auth.auth_tool import AuthConfig
from google.adk.runners import Runner
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.genai import types
async def run_demo():
"""Run demo with real MCP server."""
print('=' * 60)
print('MCP Toolset OAuth Authentication Demo')
print('=' * 60)
print('\nNote: Make sure the MCP server is running:')
print(' python oauth_mcp_server.py\n')
# Create session service and runner
session_service = InMemorySessionService()
runner = Runner(
agent=root_agent,
app_name='toolset_auth_demo',
session_service=session_service,
)
# Create a session
session = await session_service.create_session(
app_name='toolset_auth_demo',
user_id='test_user',
)
print(f'Session created: {session.id}')
print('\n--- Phase 1: Initial request (no credentials) ---\n')
# First invocation - should trigger auth request
user_message = 'List all users'
print(f'User: {user_message}')
events = []
auth_function_call_id = None
max_events = 10
try:
async for event in runner.run_async(
session_id=session.id,
user_id='test_user',
new_message=types.Content(
role='user',
parts=[types.Part(text=user_message)],
),
):
events.append(event)
print(f'\nEvent from {event.author}:')
if event.content and event.content.parts:
for part in event.content.parts:
if part.text:
print(f' Text: {part.text}')
if part.function_call:
print(f' Function call: {part.function_call.name}')
if part.function_call.name == 'adk_request_credential':
auth_function_call_id = part.function_call.id
if len(events) >= max_events:
print(f'\n** SAFETY LIMIT ({max_events} events) **')
break
except Exception as e:
print(f'\nError: {e}')
print('Make sure the MCP server is running!')
await mcp_toolset.close()
return
if auth_function_call_id:
print('\n** Auth request detected! **')
print('\n--- Phase 2: Provide OAuth credentials ---\n')
# Simulate user providing OAuth credentials after completing OAuth flow
auth_response = AuthConfig(
auth_scheme=auth_scheme,
raw_auth_credential=auth_credential,
exchanged_auth_credential=AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2,
oauth2=OAuth2Auth(
access_token='test_access_token_12345',
),
),
)
print('Providing access token: test_access_token_12345')
auth_response_message = types.Content(
role='user',
parts=[
types.Part(
function_response=types.FunctionResponse(
name='adk_request_credential',
id=auth_function_call_id,
response=auth_response.model_dump(exclude_none=True),
)
)
],
)
async for event in runner.run_async(
session_id=session.id,
user_id='test_user',
new_message=auth_response_message,
):
print(f'\nEvent from {event.author}:')
if event.content and event.content.parts:
for part in event.content.parts:
if part.text:
text = (
part.text[:200] + '...' if len(part.text) > 200 else part.text
)
print(f' Text: {text}')
if part.function_call:
print(f' Function call: {part.function_call.name}')
else:
print('\n** No auth request - credentials may already be available **')
print('\n' + '=' * 60)
print('Demo completed')
print('=' * 60)
await mcp_toolset.close()
if __name__ == '__main__':
asyncio.run(run_demo())
@@ -0,0 +1,120 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MCP Server that requires OAuth Bearer token for both tool listing and calling.
This server validates the Authorization header on every request including:
- Tool listing (list_tools endpoint)
- Tool calling (call_tool endpoint)
This is used to test the toolset authentication feature in ADK.
"""
from __future__ import annotations
import logging
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi import Request
from mcp.server.fastmcp import Context
from mcp.server.fastmcp import FastMCP
import uvicorn
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Expected OAuth token for testing
VALID_TOKEN = 'test_access_token_12345'
# Create FastMCP server
mcp = FastMCP('OAuth Protected MCP Server', host='localhost', port=3001)
def validate_auth_header(request: Request) -> bool:
"""Validate the Authorization header contains a valid Bearer token."""
auth_header = request.headers.get('authorization', '')
if not auth_header.startswith('Bearer '):
logger.warning('Missing or invalid Authorization header: %s', auth_header)
return False
token = auth_header[7:] # Remove 'Bearer ' prefix
if token != VALID_TOKEN:
logger.warning('Invalid token: %s', token)
return False
logger.info('Valid token received')
return True
@mcp.tool(description='Get user profile information. Requires authentication.')
def get_user_profile(user_id: str, context: Context) -> dict:
"""Return user profile data for the given user ID."""
logger.info('get_user_profile called for user: %s', user_id)
if context.request_context and context.request_context.request:
if not validate_auth_header(context.request_context.request):
return {'error': 'Unauthorized - invalid or missing token'}
# Mock user data
users = {
'user1': {'id': 'user1', 'name': 'Alice', 'email': 'alice@example.com'},
'user2': {'id': 'user2', 'name': 'Bob', 'email': 'bob@example.com'},
}
if user_id in users:
return users[user_id]
return {'error': f'User {user_id} not found'}
@mcp.tool(description='List all available users. Requires authentication.')
def list_users(context: Context) -> dict:
"""Return a list of all users."""
logger.info('list_users called')
if context.request_context and context.request_context.request:
if not validate_auth_header(context.request_context.request):
return {'error': 'Unauthorized - invalid or missing token'}
return {
'users': [
{'id': 'user1', 'name': 'Alice'},
{'id': 'user2', 'name': 'Bob'},
]
}
# Create custom FastAPI app to add auth middleware for list_tools
app = FastAPI()
@app.middleware('http')
async def auth_middleware(request: Request, call_next):
"""Middleware to validate auth on all MCP endpoints."""
# Check if this is an MCP request
if request.url.path.startswith('/mcp'):
if not validate_auth_header(request):
raise HTTPException(status_code=401, detail='Unauthorized')
return await call_next(request)
if __name__ == '__main__':
print(f'Starting OAuth Protected MCP server on http://localhost:3001')
print(f'Expected token: Bearer {VALID_TOKEN}')
print(
'This server requires authentication for both tool listing and calling.'
)
# Run with streamable-http transport
mcp.run(transport='streamable-http')