MCP Security Configuration

This guide covers security configuration and best practices for MCP (Model Context Protocol) implementations.

Authentication

API Key Authentication

# auth.py
import os
import hashlib
from functools import wraps
from flask import request, jsonify

class APIKeyAuth:
    """Bearer-token API key authentication for Flask views.

    Keys are read once, at construction time, from the comma-separated
    ``MCP_API_KEYS`` environment variable. Verification compares SHA-256
    digests of the keys rather than the raw values.
    """

    def __init__(self):
        raw_keys = os.environ.get('MCP_API_KEYS', '')
        # Strip whitespace and drop empty entries so that "a, b" or a
        # trailing comma does not produce unusable or empty keys.
        self.valid_keys = {key.strip() for key in raw_keys.split(',') if key.strip()}
        self.key_hashes = {self._hash_key(key) for key in self.valid_keys}

    def _hash_key(self, key):
        """Return the hex SHA-256 digest of ``key``."""
        return hashlib.sha256(key.encode()).hexdigest()

    def verify_key(self, provided_key):
        """Return True if ``provided_key`` matches one of the configured keys.

        Empty or missing keys are always rejected.
        """
        if not provided_key:
            return False
        return self._hash_key(provided_key) in self.key_hashes

    def require_auth(self, func):
        """Decorate a Flask view so it requires ``Authorization: Bearer <key>``.

        Responds with a JSON error and HTTP 401 when the header is missing,
        malformed, uses another scheme, or carries an unknown key.
        """
        @wraps(func)
        def wrapper(*args, **kwargs):
            auth_header = request.headers.get('Authorization')
            if not auth_header:
                return jsonify({'error': 'Missing Authorization header'}), 401

            # partition() instead of split()+try/except: the view call below
            # must not sit inside a ValueError handler, otherwise a ValueError
            # raised by the view itself would be reported as a 401.
            scheme, separator, token = auth_header.partition(' ')
            if not separator:
                return jsonify({'error': 'Invalid authorization format'}), 401

            if scheme.lower() != 'bearer':
                return jsonify({'error': 'Invalid authorization scheme'}), 401

            if not self.verify_key(token.strip()):
                return jsonify({'error': 'Invalid API key'}), 401

            return func(*args, **kwargs)

        return wrapper

JWT Authentication

# jwt_auth.py
import jwt
import os
from datetime import datetime, timedelta
from functools import wraps
from flask import request, jsonify

class JWTAuth:
    """HS256 JSON Web Token issuing and verification for Flask views.

    The signing secret is read from the ``MCP_JWT_SECRET`` environment
    variable at construction time; this class does not check that it is set.
    """

    def __init__(self):
        self.secret_key = os.environ.get('MCP_JWT_SECRET')
        self.algorithm = 'HS256'
        self.token_expiry = timedelta(hours=24)

    def generate_token(self, user_id, permissions=None):
        """Return a signed token carrying ``user_id`` and ``permissions``.

        The token's ``exp`` claim is ``token_expiry`` after its ``iat`` claim.
        """
        # Read the clock once so ``exp - iat`` is exactly ``token_expiry``.
        issued_at = datetime.utcnow()
        payload = {
            'user_id': user_id,
            'permissions': permissions or [],
            'exp': issued_at + self.token_expiry,
            'iat': issued_at,
        }
        return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)

    def verify_token(self, token):
        """Return the decoded payload, or None if the token is invalid or expired."""
        try:
            return jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
        except jwt.ExpiredSignatureError:
            return None
        except jwt.InvalidTokenError:
            return None

    def require_auth(self, required_permissions=None):
        """Build a decorator requiring a valid bearer token on a Flask view.

        When ``required_permissions`` is given, the token must carry at least
        one of them (any-of semantics); otherwise the view answers 403. On
        success the decoded payload is stored on ``request.user``.
        """
        # A bare string would otherwise be intersected character by character.
        if isinstance(required_permissions, str):
            required = {required_permissions}
        else:
            required = set(required_permissions or ())

        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                auth_header = request.headers.get('Authorization')
                if not auth_header:
                    return jsonify({'error': 'Missing Authorization header'}), 401

                # partition() instead of split()+try/except: the view call
                # below must not sit inside a ValueError handler, otherwise a
                # ValueError raised by the view would be reported as a 401.
                scheme, separator, token = auth_header.partition(' ')
                if not separator:
                    return jsonify({'error': 'Invalid authorization format'}), 401

                if scheme.lower() != 'bearer':
                    return jsonify({'error': 'Invalid authorization scheme'}), 401

                payload = self.verify_token(token.strip())
                if not payload:
                    return jsonify({'error': 'Invalid or expired token'}), 401

                if required:
                    # ``or []`` tolerates a token whose claim is null.
                    granted = set(payload.get('permissions') or [])
                    if not granted.intersection(required):
                        return jsonify({'error': 'Insufficient permissions'}), 403

                request.user = payload
                return func(*args, **kwargs)

            return wrapper
        return decorator

Input Validation

Request Validation

# validation.py
import re
from functools import wraps
from flask import request, jsonify
from marshmallow import Schema, fields, ValidationError

def _validate_tool_name(value):
    """Raise ValidationError unless ``value`` is only letters, digits, '_' or '-'."""
    # fullmatch: a '$'-anchored match would also accept a trailing newline.
    if not re.fullmatch(r'[a-zA-Z0-9_-]+', value):
        raise ValidationError('Tool name may only contain letters, digits, "_" and "-".')


def _validate_timeout(value):
    """Raise ValidationError unless ``value`` lies in the range 1..300."""
    if not 1 <= value <= 300:
        raise ValidationError('Timeout must be between 1 and 300.')


class ToolExecutionSchema(Schema):
    """Schema for a tool-execution request body.

    Validators raise ``ValidationError`` explicitly: a validator that merely
    returns a falsy value such as ``None`` (what ``re.match`` yields on a
    mismatch) is not treated as a failure by marshmallow.
    """

    tool = fields.Str(required=True, validate=_validate_tool_name)
    parameters = fields.Dict(required=True)
    # ``load_default`` is the current name of the deprecated ``missing`` argument.
    timeout = fields.Int(load_default=30, validate=_validate_timeout)

class RequestValidator:
    """Validates JSON request bodies against named marshmallow schemas."""

    def __init__(self):
        # Maps a schema name to the schema instance used to load request bodies.
        self.schemas = {
            'tool_execution': ToolExecutionSchema(),
        }

    def validate_request(self, schema_name):
        """Build a decorator that validates the request body before the view runs.

        On success the loaded data is stored on ``request.validated_data``.
        Responds with HTTP 400 and the validation messages when loading fails,
        and HTTP 500 when ``schema_name`` is not registered.
        """
        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                schema = self.schemas.get(schema_name)
                if not schema:
                    return jsonify({'error': 'Unknown schema'}), 500

                # silent=True yields None for a missing or malformed JSON body,
                # which schema.load() then rejects, so the client gets the same
                # JSON error shape as for any other validation failure.
                data = request.get_json(silent=True)

                # Only schema.load() is guarded: a ValidationError raised
                # inside the view must not be reported as a bad request body.
                try:
                    validated_data = schema.load(data)
                except ValidationError as e:
                    return jsonify({'error': 'Validation failed', 'details': e.messages}), 400

                request.validated_data = validated_data
                return func(*args, **kwargs)

            return wrapper
        return decorator

Parameter Sanitization

# sanitization.py
import html
import re
from typing import Any, Dict

class ParameterSanitizer:
    """Strips script-like fragments from strings and HTML-escapes the rest."""

    def __init__(self):
        # Regexes (applied case-insensitively) whose matches are removed.
        self.dangerous_patterns = [
            r'<script[^>]*>.*?</script>',
            r'javascript:',
            r'on\w+\s*=',
            r'<iframe[^>]*>.*?</iframe>',
        ]

    def sanitize_string(self, value: str) -> str:
        """Return ``value`` with dangerous patterns removed and HTML escaped.

        Non-string values are returned unchanged.
        """
        if not isinstance(value, str):
            return value

        # Strip first, escape second: once '<' has become '&lt;' the tag
        # patterns can no longer match. DOTALL lets '.*?' span newlines
        # inside a <script>/<iframe> body.
        for pattern in self.dangerous_patterns:
            value = re.sub(pattern, '', value, flags=re.IGNORECASE | re.DOTALL)

        # Escaping last neutralises any markup the patterns did not remove.
        return html.escape(value)

    def _sanitize_value(self, value: Any) -> Any:
        """Sanitize one value, recursing into dicts and lists."""
        if isinstance(value, str):
            return self.sanitize_string(value)
        if isinstance(value, dict):
            return self.sanitize_parameters(value)
        if isinstance(value, list):
            return [self._sanitize_value(item) for item in value]
        return value

    def sanitize_parameters(self, parameters: Dict[str, Any]) -> Dict[str, Any]:
        """Return a sanitized copy of ``parameters``.

        Keys and string values are passed through ``sanitize_string``; nested
        dicts and lists are processed recursively; other values are copied
        as-is. The input mapping is not modified.
        """
        return {
            self.sanitize_string(key): self._sanitize_value(value)
            for key, value in parameters.items()
        }

Rate Limiting

Token Bucket Implementation

# rate_limiter.py
import time
from collections import defaultdict
from functools import wraps
from threading import Lock

from flask import jsonify, request

class TokenBucket:
    """Token bucket that refills continuously at ``refill_rate`` tokens per second.

    The bucket starts full and never holds more than ``capacity`` tokens.
    """

    def __init__(self, capacity, refill_rate):
        self.capacity = capacity
        self.tokens = capacity
        self.refill_rate = refill_rate
        # Monotonic clock: wall-clock time can be stepped backwards, which
        # would produce a negative refill and drain the bucket.
        self.last_refill = time.monotonic()
        self.lock = Lock()

    def consume(self, tokens=1):
        """Try to take ``tokens`` from the bucket.

        Returns True and deducts the tokens when enough are available,
        otherwise returns False and leaves the balance untouched.
        """
        with self.lock:
            now = time.monotonic()
            # Credit the tokens earned since the previous call, capped at capacity.
            elapsed = now - self.last_refill
            self.tokens = min(self.capacity, self.tokens + elapsed * self.refill_rate)
            self.last_refill = now

            if self.tokens >= tokens:
                self.tokens -= tokens
                return True
            return False

class RateLimiter:
    """Per-key rate limiter backed by one ``TokenBucket`` per key.

    Buckets are created on first use and are never evicted.
    """

    def __init__(self, capacity=100, refill_rate=10):
        """Create a limiter whose buckets use ``capacity`` and ``refill_rate``."""
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.buckets = defaultdict(
            lambda: TokenBucket(capacity=self.capacity, refill_rate=self.refill_rate)
        )
        self.lock = Lock()

    def is_allowed(self, key, tokens=1):
        """Return True if ``key`` may spend ``tokens`` right now."""
        # Guard bucket creation so two concurrent first requests for the same
        # key cannot each build, and then use, a separate bucket.
        with self.lock:
            bucket = self.buckets[key]
        return bucket.consume(tokens)

    def require_rate_limit(self, key_func=None):
        """Build a decorator that rate-limits a Flask view.

        The bucket key is ``key_func()`` when provided, otherwise the
        client's ``request.remote_addr``. Responds with HTTP 429 when the
        key's bucket is empty.
        """
        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                key = key_func() if key_func else request.remote_addr

                if not self.is_allowed(key):
                    return jsonify({'error': 'Rate limit exceeded'}), 429

                return func(*args, **kwargs)

            return wrapper
        return decorator

Secure Configuration

Environment Variables

# config.py
import os
from typing import Optional

class SecurityConfig:
    """Security settings read from ``MCP_*`` environment variables.

    Construction only reads the environment; call ``validate()`` to check
    that the combination of settings is usable.
    """

    # Values of MCP_REQUIRE_HTTPS (case-insensitive) that enable the requirement.
    _TRUE_VALUES = frozenset({'1', 'true', 'yes', 'on'})

    def __init__(self):
        self.api_keys = self._get_api_keys()
        self.jwt_secret = self._get_jwt_secret()
        self.cors_origins = self._get_cors_origins()
        self.tls_cert_path = os.environ.get('MCP_TLS_CERT_PATH')
        self.tls_key_path = os.environ.get('MCP_TLS_KEY_PATH')
        # Defaults to required. Accepting "1"/"yes"/"on" as well as "true"
        # avoids silently switching the requirement off on a common spelling.
        https_flag = os.environ.get('MCP_REQUIRE_HTTPS', 'true')
        self.require_https = https_flag.strip().lower() in self._TRUE_VALUES

    def _get_api_keys(self) -> set:
        """Return the non-empty, stripped entries of ``MCP_API_KEYS``."""
        keys = os.environ.get('MCP_API_KEYS', '')
        return {key.strip() for key in keys.split(',') if key.strip()}

    def _get_jwt_secret(self) -> Optional[str]:
        """Return ``MCP_JWT_SECRET``, or None when it is unset or empty.

        A missing secret is not an error at this point: ``validate()``
        accepts a configuration that has API keys instead.
        """
        return os.environ.get('MCP_JWT_SECRET') or None

    def _get_cors_origins(self) -> list:
        """Return the non-empty, stripped entries of ``MCP_CORS_ORIGINS``."""
        origins = os.environ.get('MCP_CORS_ORIGINS', '')
        return [origin.strip() for origin in origins.split(',') if origin.strip()]

    def validate(self):
        """Check that the loaded settings are usable.

        Raises:
            ValueError: if neither API keys nor a JWT secret are configured,
                or if HTTPS is required but the certificate or key path is
                missing.
        """
        if not self.api_keys and not self.jwt_secret:
            raise ValueError("Either MCP_API_KEYS or MCP_JWT_SECRET must be configured")

        if self.require_https and not (self.tls_cert_path and self.tls_key_path):
            raise ValueError("TLS certificate and key paths required when HTTPS is enabled")

CORS Configuration

# cors_config.py
from flask_cors import CORS

def configure_cors(app, config):
    """Enable CORS on ``app`` for the origins listed in ``config.cors_origins``.

    When no origins are configured, only the local development origins on
    port 3000 are allowed.
    """
    development_origins = ['http://localhost:3000', 'https://localhost:3000']
    allowed_origins = config.cors_origins if config.cors_origins else development_origins
    CORS(app, origins=allowed_origins)

Encryption

Data Encryption

# encryption.py
import os
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
import base64

class DataEncryption:
    """Symmetric string encryption using Fernet with a password-derived key.

    The key is derived with PBKDF2-HMAC-SHA256 from the
    ``MCP_ENCRYPTION_PASSWORD`` and ``MCP_ENCRYPTION_SALT`` environment
    variables when the instance is created.
    """

    def __init__(self):
        self.key = self._get_or_create_key()
        self.fernet = Fernet(self.key)

    def _get_or_create_key(self):
        """Derive and return the urlsafe-base64 Fernet key.

        Raises:
            ValueError: if ``MCP_ENCRYPTION_PASSWORD`` is unset or empty.
        """
        password = os.environ.get('MCP_ENCRYPTION_PASSWORD', '')
        if not password:
            # Deriving a key from an empty password would "encrypt" data with
            # a key that anyone who knows the salt can reproduce.
            raise ValueError("MCP_ENCRYPTION_PASSWORD environment variable is required")

        # NOTE(review): the fallback salt is a fixed public value, kept so
        # that data encrypted under it stays decryptable. Deployments should
        # set MCP_ENCRYPTION_SALT explicitly.
        salt = os.environ.get('MCP_ENCRYPTION_SALT', 'default_salt').encode()

        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt,
            iterations=100000,
        )
        return base64.urlsafe_b64encode(kdf.derive(password.encode()))

    def encrypt(self, data: str) -> str:
        """Encrypt ``data`` and return the Fernet token as text."""
        return self.fernet.encrypt(data.encode()).decode()

    def decrypt(self, encrypted_data: str) -> str:
        """Decrypt a token produced by ``encrypt`` and return the original text."""
        return self.fernet.decrypt(encrypted_data.encode()).decode()

TLS Configuration

# tls_config.py
import ssl
import os

def create_ssl_context():
    """Build a server-side TLS context from the configured certificate and key.

    Reads ``MCP_TLS_CERT_PATH`` and ``MCP_TLS_KEY_PATH`` from the environment.
    Returns None when either is unset or empty; otherwise returns an
    ``ssl.SSLContext`` with the certificate chain loaded, TLS 1.2 as the
    minimum version, and a restricted cipher list.
    """
    environment = os.environ
    certificate_file = environment.get('MCP_TLS_CERT_PATH')
    private_key_file = environment.get('MCP_TLS_KEY_PATH')

    if not (certificate_file and private_key_file):
        return None

    tls = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    tls.load_cert_chain(certificate_file, private_key_file)

    # Hardening: no protocol older than TLS 1.2, AEAD suites with forward
    # secrecy only.
    tls.minimum_version = ssl.TLSVersion.TLSv1_2
    tls.set_ciphers('ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20:!aNULL:!MD5:!DSS')

    return tls

Logging and Auditing

Security Logging

# security_logging.py
import logging
import json
from datetime import datetime
from flask import request, g

class SecurityLogger:
    """Writes security events as JSON lines to the ``security`` logger.

    Every method reads the client address from the current Flask request.
    """

    def __init__(self):
        security_log = logging.getLogger('security')
        security_log.setLevel(logging.INFO)
        self.logger = security_log

    def log_auth_event(self, event_type, user_id=None, success=True, details=None):
        """Log an authentication event at INFO level."""
        record = {
            'event_type': event_type,
            'user_id': user_id,
            'success': success,
            'timestamp': datetime.utcnow().isoformat(),
            'ip_address': request.remote_addr,
            'user_agent': request.headers.get('User-Agent'),
            'details': details or {},
        }
        serialized = json.dumps(record)
        self.logger.info(serialized)

    def log_access_event(self, resource, action, user_id=None, allowed=True):
        """Log an access decision for ``resource``/``action`` at INFO level."""
        record = {
            'event_type': 'access',
            'resource': resource,
            'action': action,
            'user_id': user_id,
            'allowed': allowed,
            'timestamp': datetime.utcnow().isoformat(),
            'ip_address': request.remote_addr,
            # Present only if something earlier stored a request id on ``g``.
            'request_id': getattr(g, 'request_id', None),
        }
        serialized = json.dumps(record)
        self.logger.info(serialized)

    def log_security_event(self, event_type, severity='medium', details=None):
        """Log a security incident at WARNING level."""
        record = {
            'event_type': event_type,
            'severity': severity,
            'timestamp': datetime.utcnow().isoformat(),
            'ip_address': request.remote_addr,
            'details': details or {},
        }
        serialized = json.dumps(record)
        self.logger.warning(serialized)

Security Headers

HTTP Security Headers

# security_headers.py
from flask import Flask

def add_security_headers(app):
    """Register an ``after_request`` hook on ``app`` that sets security headers.

    The hook overwrites the listed headers on every response and returns
    the same response object.
    """
    @app.after_request
    def add_headers(response):
        # Prevent clickjacking
        response.headers['X-Frame-Options'] = 'DENY'

        # Prevent MIME type sniffing
        response.headers['X-Content-Type-Options'] = 'nosniff'

        # Explicitly disable the legacy browser XSS auditor. The header is
        # deprecated and the "1; mode=block" setting is discouraged; the
        # Content-Security-Policy below is the actual XSS mitigation.
        response.headers['X-XSS-Protection'] = '0'

        # HTTPS enforcement
        response.headers['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains'

        # Content Security Policy
        response.headers['Content-Security-Policy'] = (
            "default-src 'self'; "
            "script-src 'self' 'unsafe-inline'; "
            "style-src 'self' 'unsafe-inline'; "
            "img-src 'self' data: https:; "
            "font-src 'self'; "
            "connect-src 'self'; "
            "frame-ancestors 'none';"
        )

        # Referrer Policy
        response.headers['Referrer-Policy'] = 'strict-origin-when-cross-origin'

        return response

Complete Security Setup

Main Application

# secure_app.py
from flask import Flask, request, jsonify
# Module names follow the "# <file>.py" headers of the snippets in this
# guide: SecurityConfig is defined in config.py and JWTAuth in jwt_auth.py.
from config import SecurityConfig
from auth import APIKeyAuth
from jwt_auth import JWTAuth
from validation import RequestValidator
from rate_limiter import RateLimiter
from security_logging import SecurityLogger
from security_headers import add_security_headers
from cors_config import configure_cors
from tls_config import create_ssl_context

app = Flask(__name__)

# Initialize security components
config = SecurityConfig()
api_auth = APIKeyAuth()
jwt_auth = JWTAuth()
validator = RequestValidator()
rate_limiter = RateLimiter()
security_logger = SecurityLogger()

# Configure security
add_security_headers(app)
configure_cors(app, config)

@app.route('/execute', methods=['POST'])
@rate_limiter.require_rate_limit()
@api_auth.require_auth
@validator.validate_request('tool_execution')
def execute_tool():
    """Handle POST /execute.

    The decorators apply rate limiting, then API-key authentication, then
    request-body validation before this function runs.
    """
    # Body as loaded by the validate_request decorator.
    data = request.validated_data

    # Record the access in the security log.
    security_logger.log_access_event(resource='tool_execution', action='execute')

    # Execute tool logic here
    response_body = {"status": "success", "result": "Tool executed"}
    return jsonify(response_body)

if __name__ == '__main__':
    # Refuse to start on an unusable configuration.
    config.validate()

    # Serve TLS on 8443 when a certificate and key are configured,
    # otherwise plain HTTP on 8080.
    tls_context = create_ssl_context()
    listen_port = 8443 if tls_context else 8080

    app.run(host='0.0.0.0', port=listen_port, ssl_context=tls_context, debug=False)

Security Best Practices

Development Guidelines

  1. Input Validation: Always validate and sanitize inputs
  2. Authentication: Use strong authentication mechanisms
  3. Authorization: Implement proper access controls
  4. Encryption: Encrypt sensitive data at rest and in transit
  5. Logging: Log all security events
  6. Rate Limiting: Prevent abuse with rate limiting
  7. HTTPS: Always use HTTPS in production
  8. Security Headers: Implement security headers
  9. Regular Updates: Keep dependencies updated
  10. Security Testing: Regularly test for vulnerabilities

Production Checklist

  • API keys are properly configured
  • JWT secrets are secure and rotated
  • HTTPS is enforced
  • Rate limiting is enabled
  • Input validation is implemented
  • Security logging is configured
  • CORS is properly configured
  • Security headers are enabled
  • Dependencies are updated
  • Vulnerability scanning is performed