MCP Monitoring

This guide covers monitoring and observability for MCP (Model Context Protocol) implementations in Kubiya.

Overview

MCP monitoring provides visibility into:

  • Server health and performance
  • Tool execution metrics
  • Resource utilization
  • Error tracking and debugging
  • Client-server communication

Health Monitoring

Health Check Endpoints

# server.py
from flask import Flask, jsonify
import time

app = Flask(__name__)
start_time = time.time()

@app.route('/health')
def health_check():
    return jsonify({
        "status": "healthy",
        "timestamp": time.time(),
        "uptime": time.time() - start_time,
        "version": "1.0.0"
    })

@app.route('/ready')
def readiness_check():
    # Check dependencies; the commented checks below are
    # placeholders to replace with real connectivity tests
    try:
        # Test database connection
        # Test external APIs
        # Test file system access
        return jsonify({"status": "ready"})
    except Exception as e:
        return jsonify({"status": "not ready", "error": str(e)}), 503

Kubernetes Health Probes

# deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: mcp-server
spec:
  template:
    spec:
      containers:
      - name: mcp-server
        image: mcp-server:latest
        ports:
        - containerPort: 8000
        livenessProbe:
          httpGet:
            path: /health
            port: 8000
          initialDelaySeconds: 30
          periodSeconds: 10
          failureThreshold: 3
        readinessProbe:
          httpGet:
            path: /ready
            port: 8000
          initialDelaySeconds: 5
          periodSeconds: 5
          failureThreshold: 3

Metrics Collection

Prometheus Metrics

# metrics.py
from prometheus_client import Counter, Histogram, Gauge, generate_latest
import time

# Define metrics
REQUEST_COUNT = Counter('mcp_requests_total', 'Total requests', ['method', 'endpoint'])
REQUEST_DURATION = Histogram('mcp_request_duration_seconds', 'Request duration')
ACTIVE_CONNECTIONS = Gauge('mcp_active_connections', 'Active connections')
TOOL_EXECUTIONS = Counter('mcp_tool_executions_total', 'Tool executions', ['tool', 'status'])

class MetricsCollector:
    def __init__(self):
        self.start_time = time.time()
    
    def record_request(self, method, endpoint, duration):
        REQUEST_COUNT.labels(method=method, endpoint=endpoint).inc()
        REQUEST_DURATION.observe(duration)
    
    def record_tool_execution(self, tool_name, status):
        TOOL_EXECUTIONS.labels(tool=tool_name, status=status).inc()
    
    def set_active_connections(self, count):
        ACTIVE_CONNECTIONS.set(count)
    
    def get_metrics(self):
        return generate_latest()
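
For one-off timing, prometheus_client histograms also provide a time() context manager, so durations do not have to be computed by hand. A short usage sketch (the sleep stands in for real request handling):

# usage sketch
import time
from metrics import MetricsCollector, REQUEST_DURATION

collector = MetricsCollector()
collector.record_tool_execution("kubectl_get_pods", "success")

# time a block directly with the histogram's context manager
with REQUEST_DURATION.time():
    time.sleep(0.01)  # stand-in for actual request handling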

Metrics Endpoint

# app.py
from flask import Flask, Response, g, request
from prometheus_client import CONTENT_TYPE_LATEST
from metrics import MetricsCollector
import time

app = Flask(__name__)
metrics = MetricsCollector()

@app.route('/metrics')
def metrics_endpoint():
    return Response(metrics.get_metrics(), mimetype=CONTENT_TYPE_LATEST)

@app.before_request
def before_request():
    g.start_time = time.time()

@app.after_request
def after_request(response):
    duration = time.time() - g.start_time
    metrics.record_request(
        method=request.method,
        endpoint=request.endpoint,
        duration=duration
    )
    return response

Logging

Structured Logging

# logging_config.py
import logging
import json
import sys
from datetime import datetime, timezone

class StructuredFormatter(logging.Formatter):
    def format(self, record):
        log_entry = {
            "timestamp": datetime.utcnow().isoformat(),
            "level": record.levelname,
            "message": record.getMessage(),
            "logger": record.name,
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno
        }
        
        # Add any extra fields passed via `extra=` (tool_name,
        # execution_id, request_id, etc.) without hard-coding
        # each field name
        standard_attrs = logging.makeLogRecord({}).__dict__
        for key, value in record.__dict__.items():
            if key not in standard_attrs and key not in ("message", "asctime"):
                log_entry[key] = value

        return json.dumps(log_entry, default=str)

def setup_logging():
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(StructuredFormatter())
    logger.addHandler(handler)
    
    return logger
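
Once configured, extra fields appear directly in the JSON output. A short usage example (field values are illustrative):

# usage example
from logging_config import setup_logging

logger = setup_logging()
logger.info(
    "Tool execution started",
    extra={"tool_name": "kubectl_get_pods", "execution_id": "abc-123"}
)
# emits a line roughly like:
# {"timestamp": "2024-01-01T12:00:00+00:00", "level": "INFO",
#  "message": "Tool execution started", ..., "tool_name": "kubectl_get_pods",
#  "execution_id": "abc-123"}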

Request Logging

# middleware.py
import logging
import uuid
from flask import request, g

from app import app  # the Flask app created in app.py

logger = logging.getLogger(__name__)

@app.before_request
def log_request():
    g.request_id = str(uuid.uuid4())
    logger.info(
        "Request started",
        extra={
            "request_id": g.request_id,
            "method": request.method,
            "path": request.path,
            "remote_addr": request.remote_addr,
            "user_agent": request.headers.get('User-Agent')
        }
    )

@app.after_request
def log_response(response):
    logger.info(
        "Request completed",
        extra={
            "request_id": g.request_id,
            "status_code": response.status_code,
            "response_size": len(response.data) if response.data else 0
        }
    )
    return response

Tracing

OpenTelemetry Integration

# tracing.py
from opentelemetry import trace
from opentelemetry.exporter.jaeger.thrift import JaegerExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.instrumentation.flask import FlaskInstrumentor
from opentelemetry.instrumentation.requests import RequestsInstrumentor

def setup_tracing(app):
    # Configure tracer
    trace.set_tracer_provider(TracerProvider())
    tracer = trace.get_tracer(__name__)
    
    # Configure Jaeger exporter
    jaeger_exporter = JaegerExporter(
        agent_host_name="localhost",
        agent_port=6831,
    )
    
    span_processor = BatchSpanProcessor(jaeger_exporter)
    trace.get_tracer_provider().add_span_processor(span_processor)
    
    # Instrument Flask
    FlaskInstrumentor().instrument_app(app)
    RequestsInstrumentor().instrument()
    
    return tracer
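
Note that the Jaeger thrift exporter is deprecated in recent OpenTelemetry releases; current Jaeger versions ingest OTLP directly. A minimal OTLP-based variant, assuming the opentelemetry-exporter-otlp package is installed and a collector listens on localhost:4317:

# tracing_otlp.py (sketch)
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

def setup_tracing_otlp():
    trace.set_tracer_provider(TracerProvider())
    exporter = OTLPSpanExporter(endpoint="localhost:4317", insecure=True)
    trace.get_tracer_provider().add_span_processor(BatchSpanProcessor(exporter))
    return trace.get_tracer(__name__)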

Custom Spans

# tool_execution.py
from opentelemetry import trace
import logging

tracer = trace.get_tracer(__name__)
logger = logging.getLogger(__name__)

class ToolExecutor:
    def execute_tool(self, tool_name, parameters):
        with tracer.start_as_current_span(f"tool_execution_{tool_name}") as span:
            span.set_attribute("tool.name", tool_name)
            span.set_attribute("tool.parameters", str(parameters))
            
            try:
                logger.info(
                    "Tool execution started",
                    extra={
                        "tool_name": tool_name,
                        "parameters": parameters,
                        "trace_id": span.get_span_context().trace_id
                    }
                )
                
                # Execute tool
                result = self._execute_tool_logic(tool_name, parameters)
                
                span.set_attribute("tool.result", str(result))
                span.set_status(trace.Status(trace.StatusCode.OK))
                
                logger.info(
                    "Tool execution completed",
                    extra={
                        "tool_name": tool_name,
                        "result": result,
                        "trace_id": span.get_span_context().trace_id
                    }
                )
                
                return result
                
            except Exception as e:
                span.set_status(trace.Status(trace.StatusCode.ERROR, str(e)))
                span.record_exception(e)
                
                logger.error(
                    "Tool execution failed",
                    extra={
                        "tool_name": tool_name,
                        "error": str(e),
                        "trace_id": span.get_span_context().trace_id
                    }
                )
                
                raise

Error Tracking

Error Aggregation

# error_tracker.py
import logging
from collections import defaultdict
from datetime import datetime, timedelta

class ErrorTracker:
    def __init__(self):
        self.error_counts = defaultdict(int)
        self.error_details = defaultdict(list)
        
    def record_error(self, error_type, message, context=None):
        self.error_counts[error_type] += 1
        self.error_details[error_type].append({
            "message": message,
            "timestamp": datetime.utcnow(),
            "context": context or {}
        })
        
        # Alert when more than 10 errors of this type occurred in the
        # last 5 minutes (a windowed rate rather than a lifetime total)
        window_start = datetime.utcnow() - timedelta(minutes=5)
        recent = [d for d in self.error_details[error_type]
                  if d["timestamp"] >= window_start]
        if len(recent) > 10:
            self._send_alert(error_type, message)
    
    def _send_alert(self, error_type, message):
        # Send to alerting system
        logging.critical(
            f"High error rate detected: {error_type}",
            extra={
                "error_type": error_type,
                "error_count": self.error_counts[error_type],
                "recent_message": message
            }
        )
    
    def get_error_summary(self):
        return {
            "total_errors": sum(self.error_counts.values()),
            "error_types": dict(self.error_counts),
            "recent_errors": {
                error_type: details[-5:]  # Last 5 errors
                for error_type, details in self.error_details.items()
            }
        }
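
A short usage sketch (error type and context values are illustrative):

# usage example
from error_tracker import ErrorTracker

tracker = ErrorTracker()
tracker.record_error(
    "TimeoutError",
    "kubectl call timed out",
    context={"tool": "kubectl_get_pods"}
)
print(tracker.get_error_summary())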

Exception Handling

# exception_handler.py
import logging
from functools import wraps

from error_tracker import ErrorTracker

logger = logging.getLogger(__name__)
error_tracker = ErrorTracker()

def handle_exceptions(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            error_tracker.record_error(
                error_type=type(e).__name__,
                message=str(e),
                context={
                    "function": func.__name__,
                    "args": str(args),
                    "kwargs": str(kwargs)
                }
            )
            logger.exception(f"Error in {func.__name__}")
            raise
    return wrapper
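
Applied to a tool entry point, the decorator records the failure and re-raises it. A usage sketch with a hypothetical tool registry:

# usage example (REGISTERED_TOOLS is a hypothetical name-to-callable map)
from exception_handler import handle_exceptions

REGISTERED_TOOLS = {"echo": lambda **kw: kw}

@handle_exceptions
def run_tool(tool_name, **parameters):
    if tool_name not in REGISTERED_TOOLS:
        raise ValueError(f"Unknown tool: {tool_name}")
    return REGISTERED_TOOLS[tool_name](**parameters)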

Performance Monitoring

Response Time Tracking

# performance_monitor.py
import time
from collections import deque
from threading import Lock

class PerformanceMonitor:
    def __init__(self, window_size=1000):
        self.response_times = deque(maxlen=window_size)
        self.lock = Lock()
    
    def record_response_time(self, duration):
        with self.lock:
            self.response_times.append(duration)
    
    def get_stats(self):
        with self.lock:
            if not self.response_times:
                return {}
            
            times = list(self.response_times)
            return {
                "count": len(times),
                "average": sum(times) / len(times),
                "min": min(times),
                "max": max(times),
                "p95": self._percentile(times, 95),
                "p99": self._percentile(times, 99)
            }
    
    def _percentile(self, data, percentile):
        sorted_data = sorted(data)
        index = int(len(sorted_data) * percentile / 100)
        return sorted_data[min(index, len(sorted_data) - 1)]
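
This pairs naturally with the request-timing hook shown earlier; a minimal usage sketch:

# usage example
from performance_monitor import PerformanceMonitor

monitor = PerformanceMonitor()

# e.g. call this from the after_request hook with the measured duration
monitor.record_response_time(0.042)
monitor.record_response_time(0.250)
print(monitor.get_stats())  # count, average, min, max, p95, p99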

Dashboard Configuration

Grafana Dashboard

{
  "dashboard": {
    "title": "MCP Server Monitoring",
    "panels": [
      {
        "title": "Request Rate",
        "type": "graph",
        "targets": [
          {
            "expr": "rate(mcp_requests_total[5m])",
            "legendFormat": "{{method}} {{endpoint}}"
          }
        ]
      },
      {
        "title": "Response Time",
        "type": "graph",
        "targets": [
          {
            "expr": "histogram_quantile(0.95, rate(mcp_request_duration_seconds_bucket[5m]))",
            "legendFormat": "95th percentile"
          }
        ]
      },
      {
        "title": "Error Rate",
        "type": "graph",
        "targets": [
          {
            "expr": "rate(mcp_tool_executions_total{status=\"error\"}[5m])",
            "legendFormat": "{{tool}} errors"
          }
        ]
      },
      {
        "title": "Active Connections",
        "type": "stat",
        "targets": [
          {
            "expr": "mcp_active_connections",
            "legendFormat": "Active Connections"
          }
        ]
      }
    ]
  }
}

Alerting Rules

Prometheus Alerting

# alerts.yaml
groups:
  - name: mcp_server_alerts
    rules:
      - alert: HighErrorRate
        expr: rate(mcp_tool_executions_total{status="error"}[5m]) > 0.1
        for: 2m
        labels:
          severity: warning
        annotations:
          summary: "High error rate detected"
          description: "Error rate is {{ $value }} errors per second"
      
      - alert: HighResponseTime
        expr: histogram_quantile(0.95, rate(mcp_request_duration_seconds_bucket[5m])) > 1
        for: 5m
        labels:
          severity: warning
        annotations:
          summary: "High response time detected"
          description: "95th percentile response time is {{ $value }} seconds"
      
      - alert: ServerDown
        expr: up{job="mcp-server"} == 0
        for: 1m
        labels:
          severity: critical
        annotations:
          summary: "MCP Server is down"
          description: "MCP Server has been down for more than 1 minute"

Monitoring Best Practices

Resource Monitoring

# resource_monitor.py
import psutil
import time
from prometheus_client import Gauge

# Define resource metrics
CPU_USAGE = Gauge('mcp_cpu_usage_percent', 'CPU usage percentage')
MEMORY_USAGE = Gauge('mcp_memory_usage_bytes', 'Memory usage in bytes')
DISK_USAGE = Gauge('mcp_disk_usage_percent', 'Disk usage percentage')

class ResourceMonitor:
    def __init__(self):
        self.running = True
    
    def start_monitoring(self):
        while self.running:
            # CPU usage
            cpu_percent = psutil.cpu_percent(interval=1)
            CPU_USAGE.set(cpu_percent)
            
            # Memory usage
            memory = psutil.virtual_memory()
            MEMORY_USAGE.set(memory.used)
            
            # Disk usage
            disk = psutil.disk_usage('/')
            DISK_USAGE.set(disk.percent)
            
            time.sleep(30)  # Update every 30 seconds
    
    def stop_monitoring(self):
        self.running = False

Log Aggregation

# fluentd.conf
<source>
  @type tail
  path /var/log/mcp-server/*.log
  pos_file /var/log/fluentd/mcp-server.log.pos
  tag mcp.server
  <parse>
    @type json
  </parse>
</source>

<match mcp.server>
  @type elasticsearch
  host elasticsearch.default.svc.cluster.local
  port 9200
  index_name mcp-server
  # type_name is ignored by Elasticsearch 7+ and can be omitted
  type_name logs
</match>