# File: gateway/modules/auth/csrf.py

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
CSRF Protection Middleware for PowerOn Gateway
This module provides CSRF protection for state-changing operations.
"""
import logging
from fastapi import Request, HTTPException, status
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Set
logger = logging.getLogger(__name__)
class CSRFMiddleware(BaseHTTPMiddleware):
    """
    CSRF protection middleware for state-changing operations.

    For every request whose method is in ``protected_methods`` and whose path
    is not exempt, the ``X-CSRF-Token`` header must be present and must look
    like a hex token (16-64 ASCII hex characters). Requests failing either
    check are answered with a 403 JSON response and never reach the app.

    NOTE: only the presence and *format* of the token are checked here. The
    token is not compared against a session value, checked for expiration or
    bound to an origin; such validation would have to be added separately.
    """

    # The only characters accepted in a token. ``int(token, 16)`` is
    # deliberately not used for validation: it also accepts a sign, a "0x"
    # prefix, "_" separators and surrounding whitespace.
    _HEX_CHARS = frozenset("0123456789abcdefABCDEF")
    _MIN_TOKEN_LENGTH = 16
    _MAX_TOKEN_LENGTH = 64

    def __init__(self, app, exempt_paths: Set[str] = None):
        """
        Args:
            app: The ASGI application wrapped by this middleware.
            exempt_paths: Exact request paths that bypass the CSRF check.
                When omitted (or empty, since an empty set is falsy), the
                built-in default set below is used.
        """
        super().__init__(app)
        # Paths that are exempt from CSRF protection
        self.exempt_paths = exempt_paths or {
            "/api/local/login",
            "/api/local/register",
            "/api/msft/login",
            "/api/google/login",
            "/api/msft/callback",
            "/api/google/callback",
            "/api/billing/webhook/stripe",  # Stripe webhook (auth via Stripe-Signature)
        }
        # Path prefixes exempt from CSRF (for service-to-service callbacks)
        self._exemptPrefixes = [
            "/api/teamsbot/",  # .NET Media Bridge callbacks (bridge/status, bridge/audio)
        ]
        # State-changing HTTP methods that require CSRF protection
        self.protected_methods = {"POST", "PUT", "DELETE", "PATCH"}

    async def dispatch(self, request: Request, call_next):
        """
        Check the CSRF token for state-changing operations.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response when the request is exempt or carries a
            well-formed token, otherwise a 403 JSON response.
        """
        path = request.url.path
        method = request.method

        # Non-state-changing methods are never checked. This includes OPTIONS
        # (CORS preflight), which is not part of ``protected_methods``.
        if method not in self.protected_methods:
            return await call_next(request)

        if self._is_exempt_path(path):
            return await call_next(request)

        csrf_token = request.headers.get("X-CSRF-Token")
        if not csrf_token:
            logger.warning("CSRF token missing for %s %s", method, path)
            return self._forbidden("CSRF token missing")

        if not self._is_valid_csrf_token(csrf_token):
            logger.warning("Invalid CSRF token format for %s %s", method, path)
            return self._forbidden("Invalid CSRF token format")

        # Additional CSRF validation could be added here:
        # - Check token against session
        # - Validate token expiration
        # - Verify token origin
        return await call_next(request)

    def _is_exempt_path(self, path: str) -> bool:
        """Return True if *path* bypasses the CSRF check."""
        # Exact-match exemptions.
        if path in self.exempt_paths:
            return True
        # Prefix exemptions apply only to the bridge/bot sub-paths, not to
        # everything below the prefix.
        if path.startswith(tuple(self._exemptPrefixes)):
            return "/bridge/" in path or "/bot/" in path
        return False

    @staticmethod
    def _forbidden(detail: str):
        """Build the 403 JSON response returned for a rejected request."""
        # Imported locally, as in the original code, so the module-level
        # import block does not need to change.
        from fastapi.responses import JSONResponse

        return JSONResponse(
            status_code=status.HTTP_403_FORBIDDEN,
            content={"detail": detail},
        )

    def _is_valid_csrf_token(self, token: str) -> bool:
        """
        Basic validation of CSRF token format.

        Args:
            token: The CSRF token to validate

        Returns:
            bool: True if the token is a string of 16 to 64 characters that
            consists solely of ASCII hex digits.
        """
        if not token or not isinstance(token, str):
            return False
        # Basic format validation (reasonable length)
        if not self._MIN_TOKEN_LENGTH <= len(token) <= self._MAX_TOKEN_LENGTH:
            return False
        # Strict character check: rejects signs, "0x" prefixes, underscores
        # and whitespace, all of which ``int(token, 16)`` would accept.
        return all(char in self._HEX_CHARS for char in token)