# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
CSRF Protection Middleware for PowerOn Gateway

This module provides CSRF protection for state-changing operations.
"""

import logging
import re
from typing import Optional, Set

from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

logger = logging.getLogger(__name__)

# A well-formed CSRF token is 16-64 ASCII hexadecimal characters and nothing
# else. A regex with fullmatch is used instead of int(token, 16) because int()
# also accepts a "0x" prefix, a leading sign, "_" digit separators and
# surrounding whitespace, none of which belong in a hex token.
_CSRF_TOKEN_PATTERN = re.compile(r"[0-9a-fA-F]{16,64}")


class CSRFMiddleware(BaseHTTPMiddleware):
    """
    CSRF protection middleware that validates CSRF tokens for state-changing operations.

    Requests whose method is in ``protected_methods`` (POST, PUT, DELETE,
    PATCH) must carry an ``X-CSRF-Token`` header. Its value must consist of
    16-64 hexadecimal characters; otherwise a 403 JSON response is returned
    and the request is not forwarded.

    The following requests bypass the check:
    - requests to an exact path in ``exempt_paths``;
    - service-to-service callbacks under an exempt prefix whose path contains
      ``/bridge/`` or ``/bot/``.

    NOTE(review): the token is only checked for *format*. It is not compared
    against a session, cookie or server-side value. Presumably protection
    relies on cross-site forms being unable to set custom headers without a
    CORS preflight. Confirm that this is the intended threat model.
    """

    def __init__(self, app, exempt_paths: Optional[Set[str]] = None):
        """
        Args:
            app: The ASGI application wrapped by this middleware.
            exempt_paths: Exact request paths that skip CSRF validation.
                When None, a built-in set of login/registration, OAuth
                redirect/callback and webhook endpoints is used. An explicitly
                passed empty set disables all exact-path exemptions.
        """
        super().__init__(app)
        # Paths that are exempt from CSRF protection
        if exempt_paths is None:
            exempt_paths = {
                "/api/local/login",
                "/api/local/register",
                # OAuth Auth app + Data app (GET redirects / callbacks)
                "/api/msft/auth/login",
                "/api/msft/auth/login/callback",
                "/api/msft/auth/connect",
                "/api/msft/auth/connect/callback",
                "/api/msft/adminconsent",
                "/api/msft/adminconsent/callback",
                "/api/google/auth/login",
                "/api/google/auth/login/callback",
                "/api/google/auth/connect",
                "/api/google/auth/connect/callback",
                "/api/clickup/auth/connect",
                "/api/clickup/auth/connect/callback",
                "/api/billing/webhook/stripe",  # Stripe webhook (auth via Stripe-Signature)
            }
        self.exempt_paths = exempt_paths

        # Path prefixes exempt from CSRF (for service-to-service callbacks)
        self._exemptPrefixes = [
            "/api/teamsbot/",  # .NET Media Bridge callbacks (bridge/status, bridge/audio)
        ]

        # State-changing HTTP methods that require CSRF protection
        self.protected_methods = {"POST", "PUT", "DELETE", "PATCH"}

    async def dispatch(self, request: Request, call_next):
        """
        Check CSRF token for state-changing operations.

        Args:
            request: The incoming request.
            call_next: Forwards the request to the next handler and returns
                its response.

        Returns:
            The downstream response. If the ``X-CSRF-Token`` header is missing
            or malformed on a protected request, a 403 ``JSONResponse`` is
            returned instead.
        """
        path = request.url.path
        method = request.method

        # Skip CSRF check for exempt paths (exact match)
        if path in self.exempt_paths:
            return await call_next(request)

        # Skip CSRF check for service-to-service callbacks. The path must both
        # start with an exempt prefix AND contain a bridge/bot segment, so other
        # routes under the same prefix remain protected.
        if path.startswith(tuple(self._exemptPrefixes)) and (
            "/bridge/" in path or "/bot/" in path
        ):
            return await call_next(request)

        # Skip CSRF check for non-state-changing methods. OPTIONS (CORS
        # preflight) is always allowed through, even if someone adds it to
        # protected_methods.
        if method not in self.protected_methods or method == "OPTIONS":
            return await call_next(request)

        # Get CSRF token from header
        csrf_token = request.headers.get("X-CSRF-Token")
        if not csrf_token:
            logger.warning("CSRF token missing for %s %s", method, path)
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"detail": "CSRF token missing"},
            )

        # Validate CSRF token format (basic validation)
        if not self._is_valid_csrf_token(csrf_token):
            logger.warning("Invalid CSRF token format for %s %s", method, path)
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"detail": "Invalid CSRF token format"},
            )

        # Additional CSRF validation could be added here:
        # - Check token against session
        # - Validate token expiration
        # - Verify token origin
        return await call_next(request)

    def _is_valid_csrf_token(self, token: str) -> bool:
        """
        Basic validation of CSRF token format.

        A valid token is a string of 16 to 64 ASCII hexadecimal characters
        (``0-9``, ``a-f``, ``A-F``). No prefix, sign, separator or whitespace
        is allowed.

        Args:
            token: The CSRF token to validate

        Returns:
            bool: True if token format is valid
        """
        if not token or not isinstance(token, str):
            return False
        # fullmatch anchors both ends, so a trailing newline or any extra
        # character rejects the token (unlike a "$"-anchored search).
        return _CSRF_TOKEN_PATTERN.fullmatch(token) is not None