"""
Token Refresh Middleware for PowerOn Gateway

This middleware automatically refreshes expired OAuth tokens when API
endpoints are accessed, providing seamless user experience.
"""

import asyncio
import logging
from typing import Any, Callable, Coroutine, Optional, Set

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response as StarletteResponse

from modules.security.tokenRefreshService import token_refresh_service
from modules.shared.timezoneUtils import get_utc_timestamp

logger = logging.getLogger(__name__)


def _schedule_background(
    coro: Coroutine[Any, Any, None],
    registry: Set[asyncio.Task],
    failure_message: str,
) -> bool:
    """
    Schedule ``coro`` as a fire-and-forget task on the running event loop.

    The event loop only keeps weak references to tasks, so a task whose
    result is discarded may be garbage-collected before it finishes. The
    task is therefore stored in ``registry`` until it completes.

    Args:
        coro: Coroutine object to run in the background.
        registry: Set that holds strong references to in-flight tasks.
        failure_message: Prefix for the warning logged when scheduling fails.

    Returns:
        True if the task was scheduled, False if scheduling raised.
    """
    try:
        task = asyncio.create_task(coro)
    except Exception as e:
        # Close the coroutine so it does not emit a "never awaited" warning.
        coro.close()
        logger.warning("%s: %s", failure_message, e)
        return False

    registry.add(task)
    # Drop the strong reference as soon as the task is done.
    task.add_done_callback(registry.discard)
    return True


class TokenRefreshMiddleware(BaseHTTPMiddleware):
    """
    Middleware that automatically refreshes expired OAuth tokens
    when API endpoints are accessed.
    """

    def __init__(self, app, enabled: bool = True):
        """
        Args:
            app: The ASGI application to wrap.
            enabled: When False, every request is passed straight through.
        """
        super().__init__(app)
        self.enabled = enabled
        # Path prefixes whose requests trigger a background token refresh.
        self.refresh_endpoints = {
            '/api/connections',
            '/api/files',
            '/api/chat',
            '/api/msft',
            '/api/google',
        }
        # Strong references to in-flight refresh tasks (see _schedule_background).
        self._background_tasks: Set[asyncio.Task] = set()

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Process request and refresh tokens if needed.

        The refresh is scheduled in the background and never awaited here,
        so it does not slow down the request; a scheduling failure is logged
        and the request continues normally.
        """
        if not self.enabled:
            return await call_next(request)

        # Check if this is an endpoint that might need token refresh
        if not self._should_check_tokens(request):
            return await call_next(request)

        # Extract user ID from request (if available)
        user_id = self._extract_user_id(request)
        if not user_id:
            return await call_next(request)

        _schedule_background(
            self._silent_refresh_tokens(user_id),
            self._background_tasks,
            "Error scheduling token refresh",
        )

        # Process the original request
        return await call_next(request)

    def _should_check_tokens(self, request: Request) -> bool:
        """
        Return True if the request path starts with one of
        ``self.refresh_endpoints``.
        """
        # str.startswith accepts a tuple of prefixes; built per call because
        # refresh_endpoints is a public, mutable attribute.
        return request.url.path.startswith(tuple(self.refresh_endpoints))

    def _extract_user_id(self, request: Request) -> Optional[str]:
        """
        Extract user ID from request context.

        Returns:
            ``request.state.user_id`` if that attribute exists, else None.
            (Presumably set by an upstream auth middleware - verify against
            the application's middleware order.)
        """
        try:
            # No fallback (e.g. decoding a JWT from cookies/headers) is
            # implemented; a missing attribute simply yields None.
            return getattr(request.state, 'user_id', None)
        except Exception as e:
            logger.debug("Could not extract user ID: %s", e)
            return None

    async def _silent_refresh_tokens(self, user_id: str) -> None:
        """
        Perform silent token refresh for the user.

        All exceptions are caught and logged, because this runs as a
        background task that nobody awaits.
        """
        try:
            logger.debug("Starting silent token refresh for user %s", user_id)

            # Refresh expired tokens
            result = await token_refresh_service.refresh_expired_tokens(user_id)

            refreshed = result.get("refreshed", 0)
            if refreshed > 0:
                logger.info(
                    "Silently refreshed %s tokens for user %s", refreshed, user_id
                )
        except Exception as e:
            logger.error(
                "Error in silent token refresh for user %s: %s", user_id, e
            )


class ProactiveTokenRefreshMiddleware(BaseHTTPMiddleware):
    """
    Middleware that proactively refreshes tokens before they expire
    """

    def __init__(self, app, enabled: bool = True, check_interval_minutes: int = 5):
        """
        Args:
            app: The ASGI application to wrap.
            enabled: When False, every request is passed straight through.
            check_interval_minutes: Minimum time between two proactive
                refresh attempts for the same user.
        """
        super().__init__(app)
        self.enabled = enabled
        self.check_interval_minutes = check_interval_minutes
        # user_id -> timestamp of the last scheduled proactive refresh.
        # NOTE(review): entries are never evicted, so this grows with the
        # number of distinct users seen by this process.
        self.last_check = {}
        # Strong references to in-flight refresh tasks (see _schedule_background).
        self._background_tasks: Set[asyncio.Task] = set()

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Process request and check for proactive refresh needs.

        At most one proactive refresh per user is scheduled per check
        interval; the refresh runs in the background and is not awaited.
        """
        if not self.enabled:
            return await call_next(request)

        # Extract user ID from request
        user_id = self._extract_user_id(request)
        if not user_id:
            return await call_next(request)

        # Check if we need to do proactive refresh
        if self._should_check_proactive_refresh(user_id):
            scheduled = _schedule_background(
                self._proactive_refresh_tokens(user_id),
                self._background_tasks,
                "Error scheduling proactive refresh",
            )
            if scheduled:
                try:
                    # Only record the check once the task is actually scheduled.
                    self.last_check[user_id] = get_utc_timestamp()
                except Exception as e:
                    logger.warning("Error scheduling proactive refresh: %s", e)

        # Process the original request
        return await call_next(request)

    def _extract_user_id(self, request: Request) -> Optional[str]:
        """
        Extract user ID from request context.

        Returns:
            ``request.state.user_id`` if that attribute exists, else None.
        """
        try:
            return getattr(request.state, 'user_id', None)
        except Exception:
            return None

    def _should_check_proactive_refresh(self, user_id: str) -> bool:
        """
        Check if we should perform proactive refresh for this user.

        Returns:
            True when more than ``check_interval_minutes`` have elapsed since
            the last recorded check for ``user_id`` (users never seen before
            are compared against 0); False if the comparison raises.
        """
        try:
            # Interval is configured in minutes; the subtraction assumes
            # get_utc_timestamp() returns seconds - TODO confirm.
            elapsed = get_utc_timestamp() - self.last_check.get(user_id, 0)
            return elapsed > (self.check_interval_minutes * 60)
        except Exception:
            return False

    async def _proactive_refresh_tokens(self, user_id: str) -> None:
        """
        Perform proactive token refresh for the user.

        All exceptions are caught and logged, because this runs as a
        background task that nobody awaits.
        """
        try:
            logger.debug("Starting proactive token refresh for user %s", user_id)

            result = await token_refresh_service.proactive_refresh(user_id)

            refreshed = result.get("refreshed", 0)
            if refreshed > 0:
                logger.info(
                    "Proactively refreshed %s tokens for user %s", refreshed, user_id
                )
        except Exception as e:
            logger.error(
                "Error in proactive token refresh for user %s: %s", user_id, e
            )