gateway/modules/auth/tokenRefreshMiddleware.py
2025-12-09 23:25:06 +01:00

186 lines
6.4 KiB
Python

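"""
Token Refresh Middleware for PowerOn Gateway

Provides two middlewares that keep users' OAuth tokens fresh in the background
while API endpoints are accessed, so users are not interrupted by expired tokens:

- TokenRefreshMiddleware refreshes expired tokens when selected API paths are hit.
- ProactiveTokenRefreshMiddleware periodically refreshes tokens before they expire.
"""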
import asyncio
import logging
from typing import Callable, Optional

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware

from modules.auth.tokenRefreshService import token_refresh_service
from modules.shared.timeUtils import getUtcTimestamp
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class TokenRefreshMiddleware(BaseHTTPMiddleware):
    """
    Middleware that refreshes expired OAuth tokens in the background when
    selected API endpoints are accessed.

    The refresh is fire-and-forget: the request is forwarded immediately and
    never waits for, or fails because of, the refresh.

    Args:
        app: The ASGI application to wrap.
        enabled: When False, every request passes straight through.
    """

    def __init__(self, app, enabled: bool = True):
        super().__init__(app)
        self.enabled = enabled
        # Path prefixes that trigger a refresh check. This is a plain string
        # prefix match, so '/api/files/123' matches '/api/files'.
        self.refresh_endpoints = {
            '/api/connections',
            '/api/files',
            '/api/chat',
            '/api/msft',
            '/api/google'
        }
        # In-flight refresh task per user. Storing the task keeps a strong
        # reference to it: the event loop only holds a weak one, so an
        # unreferenced task may be garbage-collected mid-run. It also lets
        # concurrent requests for the same user share one refresh instead of
        # each starting their own.
        self._inflight_refreshes: dict[str, asyncio.Task] = {}

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Schedule a background token refresh if applicable, then forward the request.

        A refresh is scheduled only when all of the following hold:
        - the middleware is enabled;
        - the request path starts with one of ``refresh_endpoints``;
        - a truthy user ID is found on ``request.state``.

        Scheduling errors are logged and ignored.
        """
        if self.enabled and self._should_check_tokens(request):
            user_id = self._extract_user_id(request)
            if user_id:
                try:
                    self._schedule_refresh(user_id)
                except Exception as e:
                    # Best effort: refresh scheduling must never break the request.
                    logger.warning("Error scheduling token refresh: %s", e)
        return await call_next(request)

    def _should_check_tokens(self, request: Request) -> bool:
        """Return True if the request path starts with any of ``refresh_endpoints``."""
        return request.url.path.startswith(tuple(self.refresh_endpoints))

    def _extract_user_id(self, request: Request) -> Optional[str]:
        """
        Return ``request.state.user_id``, or None if it is not set.

        Only ``request.state`` is consulted. Presumably an upstream auth
        middleware sets ``user_id`` there; confirm that middleware runs first.
        """
        try:
            return getattr(request.state, 'user_id', None)
        except Exception as e:
            logger.debug("Could not extract user ID: %s", e)
            return None

    def _schedule_refresh(self, user_id: str) -> None:
        """Start a background refresh for ``user_id`` unless one is already running."""
        existing = self._inflight_refreshes.get(user_id)
        if existing is not None and not existing.done():
            return
        task = asyncio.create_task(self._silent_refresh_tokens(user_id))
        self._inflight_refreshes[user_id] = task
        task.add_done_callback(lambda finished: self._forget_refresh(user_id, finished))

    def _forget_refresh(self, user_id: str, task: asyncio.Task) -> None:
        """Remove the bookkeeping entry for a finished task if it is still the current one."""
        if self._inflight_refreshes.get(user_id) is task:
            del self._inflight_refreshes[user_id]

    async def _silent_refresh_tokens(self, user_id: str) -> None:
        """
        Refresh the user's expired tokens via ``token_refresh_service``.

        This runs as an un-awaited background task, so every exception is
        caught and logged with its traceback instead of being propagated.
        """
        try:
            logger.debug("Starting silent token refresh for user %s", user_id)
            result = await token_refresh_service.refresh_expired_tokens(user_id)
            refreshed = result.get("refreshed", 0)
            if refreshed > 0:
                logger.info("Silently refreshed %s tokens for user %s", refreshed, user_id)
        except Exception as e:
            logger.error(
                "Error in silent token refresh for user %s: %s", user_id, e, exc_info=True
            )
class ProactiveTokenRefreshMiddleware(BaseHTTPMiddleware):
    """
    Middleware that proactively refreshes a user's tokens before they expire.

    At most one background refresh is scheduled per user per
    ``check_interval_minutes``. The request itself never waits for the refresh.

    Args:
        app: The ASGI application to wrap.
        enabled: When False, every request passes straight through.
        check_interval_minutes: Minimum time between proactive refreshes for
            the same user.
    """

    def __init__(self, app, enabled: bool = True, check_interval_minutes: int = 5):
        super().__init__(app)
        self.enabled = enabled
        self.check_interval_minutes = check_interval_minutes
        # user_id -> timestamp (from getUtcTimestamp) of the last scheduled refresh.
        self.last_check = {}
        # Timestamp of the last sweep of stale ``last_check`` entries.
        self._last_prune = 0
        # Strong references to running background tasks. The event loop only
        # keeps weak ones, so an unreferenced task may be collected mid-run.
        self._background_tasks: set[asyncio.Task] = set()

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Schedule a proactive token refresh if one is due, then forward the request.

        Scheduling errors are logged and ignored.
        """
        if self.enabled:
            user_id = self._extract_user_id(request)
            if user_id and self._should_check_proactive_refresh(user_id):
                try:
                    task = asyncio.create_task(self._proactive_refresh_tokens(user_id))
                    self._background_tasks.add(task)
                    task.add_done_callback(self._background_tasks.discard)
                    now = getUtcTimestamp()
                    self.last_check[user_id] = now
                    self._prune_stale_checks(now)
                except Exception as e:
                    # Best effort: refresh scheduling must never break the request.
                    logger.warning("Error scheduling proactive refresh: %s", e)
        return await call_next(request)

    def _extract_user_id(self, request: Request) -> Optional[str]:
        """
        Return ``request.state.user_id``, or None if it is not set.

        Presumably an upstream auth middleware sets this value; confirm that
        middleware runs first.
        """
        try:
            return getattr(request.state, 'user_id', None)
        except Exception:
            return None

    def _should_check_proactive_refresh(self, user_id: str) -> bool:
        """
        Return True if more than ``check_interval_minutes`` have passed since
        this user's last proactive refresh.

        Users without an entry are compared against timestamp 0. Returns False
        if the current time cannot be obtained.
        """
        try:
            current_time = getUtcTimestamp()
            last_check = self.last_check.get(user_id, 0)
            # NOTE(review): the *60 conversion assumes getUtcTimestamp() returns seconds.
            return (current_time - last_check) > (self.check_interval_minutes * 60)
        except Exception:
            return False

    def _prune_stale_checks(self, now) -> None:
        """
        Drop ``last_check`` entries older than the check interval.

        A stale entry already permits a new refresh, just as a missing entry
        does (assuming epoch-second timestamps, so ``now - 0`` exceeds the
        interval). Removing such entries therefore does not change behaviour.
        It only stops the dict from growing with every user ever seen.

        The sweep runs at most once per interval, so its O(n) cost stays amortized.
        """
        interval_seconds = self.check_interval_minutes * 60
        if now - self._last_prune <= interval_seconds:
            return
        self._last_prune = now
        stale = [uid for uid, ts in self.last_check.items() if now - ts > interval_seconds]
        for uid in stale:
            del self.last_check[uid]

    async def _proactive_refresh_tokens(self, user_id: str) -> None:
        """
        Ask ``token_refresh_service`` to proactively refresh the user's tokens.

        This runs as an un-awaited background task, so every exception is
        caught and logged with its traceback instead of being propagated.
        """
        try:
            logger.debug("Starting proactive token refresh for user %s", user_id)
            result = await token_refresh_service.proactive_refresh(user_id)
            refreshed = result.get("refreshed", 0)
            if refreshed > 0:
                logger.info("Proactively refreshed %s tokens for user %s", refreshed, user_id)
        except Exception as e:
            logger.error(
                "Error in proactive token refresh for user %s: %s", user_id, e, exc_info=True
            )