Security configuration and best practices for MCP implementations
# auth.py
import os
import hashlib
from functools import wraps
from flask import request, jsonify
class APIKeyAuth:
    """Authenticate requests with static API keys sent as Bearer tokens.

    Keys are read once, at construction time, from the comma-separated
    ``MCP_API_KEYS`` environment variable. Verification compares SHA-256
    digests of the configured keys with the digest of the presented key.
    """

    def __init__(self):
        raw_keys = os.environ.get('MCP_API_KEYS', '')
        # Strip surrounding whitespace and drop empty entries so that a value
        # such as "key1, key2," yields exactly {"key1", "key2"}.
        self.valid_keys = {
            key.strip() for key in raw_keys.split(',') if key.strip()
        }
        self.key_hashes = {self._hash_key(key) for key in self.valid_keys}

    def _hash_key(self, key):
        """Return the hex-encoded SHA-256 digest of ``key``."""
        return hashlib.sha256(key.encode()).hexdigest()

    def verify_key(self, provided_key):
        """Return True when ``provided_key`` is one of the configured keys.

        Empty, ``None`` and non-string values are rejected.
        """
        if not provided_key or not isinstance(provided_key, str):
            return False
        return self._hash_key(provided_key) in self.key_hashes

    def require_auth(self, func):
        """Decorate a Flask view so it only runs with a valid API key.

        The request must carry ``Authorization: Bearer <key>``. Any failure
        produces a JSON error body with HTTP status 401.
        """
        @wraps(func)
        def wrapper(*args, **kwargs):
            auth_header = request.headers.get('Authorization')
            if not auth_header:
                return jsonify({'error': 'Missing Authorization header'}), 401

            # Parse the header without a try/except around the view call, so
            # a ValueError raised inside the view is not misreported as a
            # malformed Authorization header.
            scheme, separator, token = auth_header.partition(' ')
            if not separator:
                return jsonify({'error': 'Invalid authorization format'}), 401
            if scheme.lower() != 'bearer':
                return jsonify({'error': 'Invalid authorization scheme'}), 401
            if not self.verify_key(token):
                return jsonify({'error': 'Invalid API key'}), 401
            return func(*args, **kwargs)
        return wrapper
# jwt_auth.py
import jwt
import os
from datetime import datetime, timedelta
from functools import wraps
from flask import request, jsonify
class JWTAuth:
    """Issue and verify HS256-signed JWTs for Flask views.

    The signing secret is read from the ``MCP_JWT_SECRET`` environment
    variable when the instance is created. Tokens expire 24 hours after
    they are issued.
    """

    def __init__(self):
        self.secret_key = os.environ.get('MCP_JWT_SECRET')
        self.algorithm = 'HS256'
        self.token_expiry = timedelta(hours=24)

    def generate_token(self, user_id, permissions=None):
        """Return a signed token carrying ``user_id`` and ``permissions``.

        Raises:
            ValueError: if no signing secret is configured.
        """
        if not self.secret_key:
            # Fail with a clear message instead of an obscure error from the
            # JWT library when it is handed a missing key.
            raise ValueError("MCP_JWT_SECRET environment variable is required")

        from datetime import timezone

        # Take one timezone-aware timestamp so ``iat`` and ``exp`` are derived
        # from the same instant (``datetime.utcnow()`` is deprecated).
        issued_at = datetime.now(timezone.utc)
        payload = {
            'user_id': user_id,
            'permissions': permissions or [],
            'exp': issued_at + self.token_expiry,
            'iat': issued_at,
        }
        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    def verify_token(self, token):
        """Return the decoded payload, or ``None`` if the token is not valid.

        ``None`` is also returned when no signing secret is configured, so a
        misconfigured server rejects every token rather than crashing.
        """
        if not self.secret_key:
            return None
        try:
            return jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except jwt.InvalidTokenError:
            # ExpiredSignatureError is a subclass of InvalidTokenError, so
            # this single clause covers expired tokens as well.
            return None

    def require_auth(self, required_permissions=None):
        """Build a decorator that protects a Flask view with JWT auth.

        Args:
            required_permissions: optional iterable of permission names. The
                caller is admitted when the token holds at least one of them.

        The decoded payload is stored on ``request.user`` before the view
        runs. Authentication failures return 401 and permission failures
        return 403, both with a JSON error body.
        """
        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                auth_header = request.headers.get('Authorization')
                if not auth_header:
                    return jsonify({'error': 'Missing Authorization header'}), 401

                # Parse the header without wrapping the view call in a
                # try/except, so a ValueError raised inside the view is not
                # misreported as a malformed Authorization header.
                scheme, separator, token = auth_header.partition(' ')
                if not separator:
                    return jsonify({'error': 'Invalid authorization format'}), 401
                if scheme.lower() != 'bearer':
                    return jsonify({'error': 'Invalid authorization scheme'}), 401

                payload = self.verify_token(token)
                if not payload:
                    return jsonify({'error': 'Invalid or expired token'}), 401

                # Check permissions
                if required_permissions:
                    user_permissions = set(payload.get('permissions', []))
                    if not user_permissions.intersection(required_permissions):
                        return jsonify({'error': 'Insufficient permissions'}), 403

                request.user = payload
                return func(*args, **kwargs)
            return wrapper
        return decorator
# validation.py
import re
from functools import wraps
from flask import request, jsonify
from marshmallow import Schema, fields, ValidationError
def _validate_tool_name(value):
    """Raise ValidationError unless ``value`` is letters, digits, '_' or '-'."""
    # A validator that returns the result of re.match() yields None on
    # failure, which marshmallow does not treat as a rejection; raising is
    # the unambiguous way to fail. fullmatch also rejects a trailing newline
    # that a '^...$' pattern would let through.
    if not re.fullmatch(r'[a-zA-Z0-9_-]+', value):
        raise ValidationError(
            'Tool name may only contain letters, digits, "_" and "-".'
        )


def _validate_timeout(value):
    """Raise ValidationError unless ``value`` lies in the range 1..300."""
    if not 1 <= value <= 300:
        raise ValidationError('Timeout must be between 1 and 300.')


class ToolExecutionSchema(Schema):
    """Schema for a tool execution request body.

    Fields:
        tool: required name, restricted to letters, digits, '_' and '-'.
        parameters: required dictionary of tool arguments.
        timeout: optional integer between 1 and 300; defaults to 30.
    """

    tool = fields.Str(required=True, validate=_validate_tool_name)
    parameters = fields.Dict(required=True)
    # NOTE(review): newer marshmallow releases spell `missing` as
    # `load_default` — confirm against the pinned version before changing.
    timeout = fields.Int(missing=30, validate=_validate_timeout)
class RequestValidator:
    """Validate JSON request bodies against named marshmallow schemas."""

    def __init__(self):
        # Registry of schema instances, looked up by name in validate_request.
        self.schemas = dict(tool_execution=ToolExecutionSchema())

    def validate_request(self, schema_name):
        """Build a decorator that validates the request body before the view.

        The loaded data is stored on ``request.validated_data``. An unknown
        ``schema_name`` yields a 500 response; a validation failure yields a
        400 response carrying the validation messages.
        """
        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                selected = self.schemas.get(schema_name)
                if not selected:
                    return jsonify({'error': 'Unknown schema'}), 500

                try:
                    body = request.get_json()
                    request.validated_data = selected.load(body)
                    return func(*args, **kwargs)
                except ValidationError as exc:
                    failure = {'error': 'Validation failed', 'details': exc.messages}
                    return jsonify(failure), 400
            return wrapper
        return decorator
# sanitization.py
import html
import re
from typing import Any, Dict
class ParameterSanitizer:
    """Strip script-like fragments from strings and HTML-escape the rest."""

    def __init__(self):
        # Regular expressions (matched case-insensitively) whose matches are
        # removed from every sanitised string.
        self.dangerous_patterns = [
            r'<script[^>]*>.*?</script>',
            r'javascript:',
            r'on\w+\s*=',
            r'<iframe[^>]*>.*?</iframe>',
        ]

    def sanitize_string(self, value: str) -> str:
        """Return ``value`` with dangerous patterns removed, then HTML-escaped.

        Non-string values are returned unchanged.
        """
        if not isinstance(value, str):
            return value

        # Remove dangerous patterns BEFORE escaping: once '<' has become
        # '&lt;' the tag patterns can no longer match. Repeat until nothing
        # changes so fragments that reassemble after one removal (for
        # example "jajavascript:vascript:") are caught too. Every
        # substitution only shortens the string, so the loop terminates.
        previous = None
        while previous != value:
            previous = value
            for pattern in self.dangerous_patterns:
                # DOTALL lets '.*?' span newlines inside a tag body.
                value = re.sub(pattern, '', value, flags=re.IGNORECASE | re.DOTALL)

        # HTML escape whatever remains
        return html.escape(value)

    def _sanitize_value(self, value: Any) -> Any:
        """Sanitise one value, recursing into nested dicts and lists."""
        if isinstance(value, str):
            return self.sanitize_string(value)
        if isinstance(value, dict):
            return self.sanitize_parameters(value)
        if isinstance(value, list):
            return [self._sanitize_value(item) for item in value]
        return value

    def sanitize_parameters(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Return a sanitised copy of ``parameters``.

        Keys and string values are passed through ``sanitize_string``; dicts
        and lists are processed recursively at any depth; all other values
        are copied as-is.
        """
        return {
            self.sanitize_string(key): self._sanitize_value(value)
            for key, value in parameters.items()
        }
# rate_limiter.py
import time
from collections import defaultdict
from threading import Lock
class TokenBucket:
    """Token bucket holding up to ``capacity`` tokens.

    Tokens are replenished at ``refill_rate`` tokens per second of elapsed
    time, measured each time ``consume`` is called.
    """

    def __init__(self, capacity, refill_rate):
        self.capacity = capacity
        self.tokens = capacity
        self.refill_rate = refill_rate
        # Use the monotonic clock: wall-clock time can jump backwards or
        # forwards (NTP, manual changes), which would remove or grant tokens.
        self.last_refill = time.monotonic()
        self.lock = Lock()

    def consume(self, tokens=1):
        """Try to take ``tokens`` from the bucket.

        Returns:
            True if enough tokens were available (they are deducted),
            False otherwise (the bucket is left at its refilled level).
        """
        with self.lock:
            now = time.monotonic()
            # Add tokens based on time elapsed, capped at capacity
            tokens_to_add = (now - self.last_refill) * self.refill_rate
            self.tokens = min(self.capacity, self.tokens + tokens_to_add)
            self.last_refill = now

            if self.tokens >= tokens:
                self.tokens -= tokens
                return True
            return False
class RateLimiter:
    """Per-key rate limiter backed by one ``TokenBucket`` per key.

    Each new key gets a bucket with capacity 100 and a refill rate of 10.
    """

    def __init__(self):
        self.buckets = defaultdict(lambda: TokenBucket(capacity=100, refill_rate=10))
        self.lock = Lock()

    def is_allowed(self, key, tokens=1):
        """Return True if ``key`` may spend ``tokens`` right now."""
        # Guard the lookup with the instance lock (it was previously created
        # but never used) so concurrent first requests for the same key
        # share a single bucket.
        with self.lock:
            bucket = self.buckets[key]
        return bucket.consume(tokens)

    def require_rate_limit(self, key_func=None):
        """Build a decorator that rate-limits a Flask view.

        Args:
            key_func: optional zero-argument callable returning the
                rate-limit key. When omitted, ``request.remote_addr`` is used.

        Requests over the limit get a JSON error with HTTP status 429.
        """
        # Imported here because this section's own import list does not
        # provide these names.
        from functools import wraps
        from flask import request, jsonify

        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                # Generate key for rate limiting
                if key_func:
                    key = key_func()
                else:
                    key = request.remote_addr

                if not self.is_allowed(key):
                    return jsonify({'error': 'Rate limit exceeded'}), 429
                return func(*args, **kwargs)
            return wrapper
        return decorator
# config.py
import os
from typing import Optional
class SecurityConfig:
    """Security settings loaded from ``MCP_*`` environment variables.

    Construction only reads the environment; call ``validate()`` to check
    that the combination of settings is usable.
    """

    def __init__(self):
        self.api_keys = self._get_api_keys()
        self.jwt_secret = self._get_jwt_secret()
        self.cors_origins = self._get_cors_origins()
        self.tls_cert_path = os.environ.get('MCP_TLS_CERT_PATH')
        self.tls_key_path = os.environ.get('MCP_TLS_KEY_PATH')
        # Only the literal string "true" (any case) enables the HTTPS check.
        self.require_https = os.environ.get('MCP_REQUIRE_HTTPS', 'true').lower() == 'true'

    def _get_api_keys(self) -> set:
        """Return the non-empty, whitespace-stripped keys in MCP_API_KEYS."""
        keys = os.environ.get('MCP_API_KEYS', '')
        return {key.strip() for key in keys.split(',') if key.strip()}

    def _get_jwt_secret(self) -> Optional[str]:
        """Return MCP_JWT_SECRET, or ``None`` when it is unset or empty.

        Raising here made the "either API keys or JWT secret" rule in
        ``validate()`` unreachable and blocked API-key-only deployments, so
        absence is reported as ``None`` and left for ``validate()`` to judge.
        """
        return os.environ.get('MCP_JWT_SECRET') or None

    def _get_cors_origins(self) -> list:
        """Return the non-empty, whitespace-stripped MCP_CORS_ORIGINS entries."""
        origins = os.environ.get('MCP_CORS_ORIGINS', '')
        return [origin.strip() for origin in origins.split(',') if origin.strip()]

    def validate(self):
        """Check that the loaded settings are usable.

        Raises:
            ValueError: if neither API keys nor a JWT secret is configured,
                or if HTTPS is required without both TLS paths set.
        """
        if not self.api_keys and not self.jwt_secret:
            raise ValueError("Either MCP_API_KEYS or MCP_JWT_SECRET must be configured")
        if self.require_https and not (self.tls_cert_path and self.tls_key_path):
            raise ValueError("TLS certificate and key paths required when HTTPS is enabled")
# cors_config.py
from flask_cors import CORS
def configure_cors(app, config):
    """Enable CORS on ``app`` for the origins listed in ``config``.

    When ``config.cors_origins`` is empty, the localhost:3000 development
    origins (http and https) are allowed instead.
    """
    # Default to localhost for development
    development_origins = ['http://localhost:3000', 'https://localhost:3000']
    allowed_origins = config.cors_origins or development_origins
    CORS(app, origins=allowed_origins)
# encryption.py
import os
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
import base64
class DataEncryption:
    """Encrypt and decrypt strings with Fernet using a password-derived key.

    The key is derived with PBKDF2-HMAC-SHA256 from the
    ``MCP_ENCRYPTION_PASSWORD`` and ``MCP_ENCRYPTION_SALT`` environment
    variables when the instance is created.
    """

    def __init__(self):
        self.key = self._get_or_create_key()
        self.fernet = Fernet(self.key)

    def _get_or_create_key(self):
        """Derive the urlsafe-base64 Fernet key from the configured password.

        Raises:
            ValueError: if ``MCP_ENCRYPTION_PASSWORD`` is unset or empty.
        """
        password = os.environ.get('MCP_ENCRYPTION_PASSWORD', '')
        if not password:
            # An empty password combined with the public default salt would
            # yield a key anyone can reproduce, so refuse to continue.
            raise ValueError("MCP_ENCRYPTION_PASSWORD environment variable is required")

        # The fallback salt is kept so data encrypted under it stays
        # readable; deployments should set MCP_ENCRYPTION_SALT explicitly.
        salt = os.environ.get('MCP_ENCRYPTION_SALT', 'default_salt').encode()

        # Use password-based key derivation
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt,
            iterations=100000,
        )
        return base64.urlsafe_b64encode(kdf.derive(password.encode()))

    def encrypt(self, data: str) -> str:
        """Return the Fernet token for ``data`` as text."""
        return self.fernet.encrypt(data.encode()).decode()

    def decrypt(self, encrypted_data: str) -> str:
        """Return the plaintext for a token produced by ``encrypt``."""
        return self.fernet.decrypt(encrypted_data.encode()).decode()
# tls_config.py
import ssl
import os
def create_ssl_context():
    """Build a server-side TLS context from the configured certificate.

    Reads ``MCP_TLS_CERT_PATH`` and ``MCP_TLS_KEY_PATH`` from the
    environment. Returns ``None`` when either is unset or empty; otherwise
    returns an ``ssl.SSLContext`` with the certificate chain loaded.
    """
    certificate_file, private_key_file = (
        os.environ.get(variable)
        for variable in ('MCP_TLS_CERT_PATH', 'MCP_TLS_KEY_PATH')
    )
    if not (certificate_file and private_key_file):
        return None

    tls = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    tls.load_cert_chain(certificate_file, private_key_file)

    # Security settings: TLS 1.2 or newer and an explicit cipher list.
    tls.minimum_version = ssl.TLSVersion.TLSv1_2
    tls.set_ciphers('ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20:!aNULL:!MD5:!DSS')
    return tls
# security_logging.py
import logging
import json
from datetime import datetime
from flask import request, g
class SecurityLogger:
    """Write security events as JSON lines to the ``security`` logger.

    Every method reads the current Flask ``request``, so each must be called
    while a request is being handled.
    """

    def __init__(self):
        security_log = logging.getLogger('security')
        security_log.setLevel(logging.INFO)
        self.logger = security_log

    def log_auth_event(self, event_type, user_id=None, success=True, details=None):
        """Log an authentication event at INFO level."""
        entry = {
            'event_type': event_type,
            'user_id': user_id,
            'success': success,
            'timestamp': datetime.utcnow().isoformat(),
            'ip_address': request.remote_addr,
            'user_agent': request.headers.get('User-Agent'),
            'details': details or {}
        }
        self.logger.info(json.dumps(entry))

    def log_access_event(self, resource, action, user_id=None, allowed=True):
        """Log an access decision for ``resource`` at INFO level.

        ``request_id`` is taken from ``flask.g`` when present, else ``None``.
        """
        entry = {
            'event_type': 'access',
            'resource': resource,
            'action': action,
            'user_id': user_id,
            'allowed': allowed,
            'timestamp': datetime.utcnow().isoformat(),
            'ip_address': request.remote_addr,
            'request_id': getattr(g, 'request_id', None)
        }
        self.logger.info(json.dumps(entry))

    def log_security_event(self, event_type, severity='medium', details=None):
        """Log a security event at WARNING level."""
        entry = {
            'event_type': event_type,
            'severity': severity,
            'timestamp': datetime.utcnow().isoformat(),
            'ip_address': request.remote_addr,
            'details': details or {}
        }
        self.logger.warning(json.dumps(entry))
# security_headers.py
from flask import Flask
def add_security_headers(app):
    """Register an ``after_request`` hook on ``app`` that sets security headers.

    The hook overwrites the listed headers on every response and returns
    the same response object.
    """
    # Content Security Policy directives, joined with "; " and closed by ";".
    csp_directives = (
        "default-src 'self'",
        "script-src 'self' 'unsafe-inline'",
        "style-src 'self' 'unsafe-inline'",
        "img-src 'self' data: https:",
        "font-src 'self'",
        "connect-src 'self'",
        "frame-ancestors 'none'",
    )
    header_values = {
        # Prevent clickjacking
        'X-Frame-Options': 'DENY',
        # Prevent MIME type sniffing
        'X-Content-Type-Options': 'nosniff',
        # Browser XSS filter header
        'X-XSS-Protection': '1; mode=block',
        # HTTPS enforcement
        'Strict-Transport-Security': 'max-age=31536000; includeSubDomains',
        'Content-Security-Policy': "; ".join(csp_directives) + ";",
        # Referrer Policy
        'Referrer-Policy': 'strict-origin-when-cross-origin',
    }

    @app.after_request
    def add_headers(response):
        for header_name, header_value in header_values.items():
            response.headers[header_name] = header_value
        return response
# secure_app.py
from flask import Flask, request, jsonify
from security_config import SecurityConfig
from auth import APIKeyAuth, JWTAuth
from validation import RequestValidator
from rate_limiter import RateLimiter
from security_logging import SecurityLogger
from security_headers import add_security_headers
from cors_config import configure_cors
from tls_config import create_ssl_context
# Module-level wiring: create the Flask app and one shared instance of each
# security component. All of this runs at import time.
# NOTE(review): the imports above take SecurityConfig from `security_config`
# and JWTAuth from `auth`, while the section headers in this file name those
# modules `config.py` and `jwt_auth.py` — confirm the real module names.
app = Flask(__name__)
# Initialize security components
config = SecurityConfig()
api_auth = APIKeyAuth()
jwt_auth = JWTAuth()
validator = RequestValidator()
rate_limiter = RateLimiter()
security_logger = SecurityLogger()
# Configure security
# Registers the response-header hook, then enables CORS using the origins
# held by `config`.
add_security_headers(app)
configure_cors(app, config)
@app.route('/execute', methods=['POST'])
@rate_limiter.require_rate_limit()
@api_auth.require_auth
@validator.validate_request('tool_execution')
def execute_tool():
    """Handle ``POST /execute``.

    The decorators apply rate limiting first, then API-key authentication,
    then body validation against the ``tool_execution`` schema. The access
    is logged and a fixed success payload is returned as JSON.
    """
    # Validated request body, stored on the request by validate_request.
    validated_body = request.validated_data

    # Log access
    security_logger.log_access_event('tool_execution', 'execute')

    # Execute tool logic here
    return jsonify({"status": "success", "result": "Tool executed"})
if __name__ == '__main__':
    # Refuse to start on an unusable configuration.
    config.validate()

    # Serve on 8443 when a TLS context is available, otherwise on 8080.
    tls_context = create_ssl_context()
    listen_port = 8080 if not tls_context else 8443
    app.run(host='0.0.0.0', port=listen_port, ssl_context=tls_context, debug=False)
Was this page helpful?