186 lines
6.4 KiB
Python
186 lines
6.4 KiB
Python
"""
|
|
Token Refresh Middleware for PowerOn Gateway
|
|
|
|
This middleware automatically refreshes expired OAuth tokens
|
|
when API endpoints are accessed, providing seamless user experience.
|
|
"""
|
|
|
|
import asyncio
import logging
from typing import Callable, Dict, Optional, Set

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware

from modules.auth.tokenRefreshService import token_refresh_service
from modules.shared.timeUtils import getUtcTimestamp
|
|
|
|
# Module-level logger used by both middleware classes below.
logger = logging.getLogger(name=__name__)
|
|
|
|
class TokenRefreshMiddleware(BaseHTTPMiddleware):
    """
    Middleware that automatically refreshes expired OAuth tokens
    when API endpoints are accessed.

    The middleware acts on a request only when both of these hold:
    the path lies under one of ``refresh_endpoints``, and
    ``request.state`` carries a ``user_id``. It then starts a background
    task that calls ``token_refresh_service.refresh_expired_tokens(user_id)``.
    The request is forwarded immediately and is never delayed or failed by
    the refresh.
    """

    def __init__(self, app, enabled: bool = True):
        """
        Args:
            app: The downstream ASGI application.
            enabled: When False, every request is passed straight through.
        """
        super().__init__(app)
        self.enabled = enabled
        # Path prefixes that trigger a refresh. They are matched on whole
        # path segments (see _should_check_tokens).
        self.refresh_endpoints = {
            '/api/connections',
            '/api/files',
            '/api/chat',
            '/api/msft',
            '/api/google',
        }
        # Running refresh task per user. Holding a strong reference serves
        # two purposes:
        # - the event loop cannot garbage-collect a fire-and-forget task
        #   before it finishes;
        # - dispatch can skip starting a second refresh for a user whose
        #   previous refresh is still in flight.
        self._refresh_tasks: Dict[str, asyncio.Task] = {}

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Process request and refresh tokens if needed.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response, unchanged.
        """
        if self.enabled and self._should_check_tokens(request):
            user_id = self._extract_user_id(request)
            if user_id:
                # Fire-and-forget: the request does not wait for the refresh.
                self._schedule_refresh(user_id)

        return await call_next(request)

    def _schedule_refresh(self, user_id: str) -> None:
        """
        Start a background refresh for ``user_id`` unless one is already running.

        Scheduling failures are logged and swallowed so that the request
        still proceeds.
        """
        if user_id in self._refresh_tasks:
            return

        coro = self._silent_refresh_tokens(user_id)
        try:
            task = asyncio.create_task(coro)
        except Exception as e:
            # Close the never-started coroutine; otherwise Python emits a
            # "coroutine was never awaited" RuntimeWarning.
            coro.close()
            logger.warning("Error scheduling token refresh: %s", e)
            return

        self._refresh_tasks[user_id] = task
        task.add_done_callback(
            lambda finished, uid=user_id: self._forget_task(uid, finished)
        )

    def _forget_task(self, user_id: str, task: asyncio.Task) -> None:
        """Drop the bookkeeping entry for ``user_id`` once ``task`` has finished."""
        # Only remove the entry if it still refers to this exact task.
        if self._refresh_tasks.get(user_id) is task:
            del self._refresh_tasks[user_id]

    def _should_check_tokens(self, request: Request) -> bool:
        """
        Check if this request should trigger token refresh.

        Returns:
            True when the path equals a configured endpoint or lies beneath
            it. For example, ``/api/files`` matches ``/api/files`` and
            ``/api/files/x`` but not ``/api/filesystem``.
        """
        path = request.url.path
        for endpoint in self.refresh_endpoints:
            base = endpoint.rstrip('/')
            if path == endpoint or path.startswith(base + '/'):
                return True
        return False

    def _extract_user_id(self, request: Request) -> Optional[str]:
        """
        Extract user ID from request context.

        Returns:
            ``request.state.user_id`` if it has been set (presumably by an
            upstream auth layer), otherwise None.
        """
        try:
            # NOTE: No fallback to a JWT in cookies or headers is implemented.
            # Requests without request.state.user_id are simply skipped.
            return getattr(request.state, 'user_id', None)
        except Exception as e:
            logger.debug("Could not extract user ID: %s", e)
            return None

    async def _silent_refresh_tokens(self, user_id: str) -> None:
        """
        Perform silent token refresh for the user.

        Runs as a background task that nothing awaits, so every exception is
        logged here instead of being raised.
        """
        try:
            logger.debug("Starting silent token refresh for user %s", user_id)

            result = await token_refresh_service.refresh_expired_tokens(user_id)

            # Guard against a falsy/None result so it isn't misreported as an error.
            refreshed = result.get("refreshed", 0) if result else 0
            if refreshed > 0:
                logger.info("Silently refreshed %s tokens for user %s", refreshed, user_id)

        except Exception as e:
            logger.error("Error in silent token refresh for user %s: %s", user_id, e)
|
|
|
class ProactiveTokenRefreshMiddleware(BaseHTTPMiddleware):
    """
    Middleware that proactively refreshes tokens before they expire.

    For any request whose ``request.state`` carries a ``user_id``, this
    schedules ``token_refresh_service.proactive_refresh(user_id)`` as a
    background task. It does so at most once per ``check_interval_minutes``
    for each user. The request itself is forwarded without waiting.
    """

    def __init__(self, app, enabled: bool = True, check_interval_minutes: int = 5):
        """
        Args:
            app: The downstream ASGI application.
            enabled: When False, every request is passed straight through.
            check_interval_minutes: Minimum gap between two proactive
                refreshes scheduled for the same user.
        """
        super().__init__(app)
        self.enabled = enabled
        self.check_interval_minutes = check_interval_minutes
        # user_id -> getUtcTimestamp() value when a refresh was last scheduled.
        self.last_check = {}
        # getUtcTimestamp() value of the last sweep of stale last_check entries.
        self._last_prune = 0
        # Strong references to running refresh tasks, so the event loop cannot
        # garbage-collect them mid-run. Each task removes itself when done.
        self._background_tasks: Set[asyncio.Task] = set()

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Process request and check for proactive refresh needs.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response, unchanged.
        """
        if self.enabled:
            user_id = self._extract_user_id(request)
            if user_id and self._should_check_proactive_refresh(user_id):
                # Fire-and-forget: the request does not wait for the refresh.
                self._schedule_refresh(user_id)

        return await call_next(request)

    def _schedule_refresh(self, user_id: str) -> None:
        """
        Start a background proactive refresh and record when it was scheduled.

        Failures are logged and swallowed so that the request still proceeds.
        """
        coro = self._proactive_refresh_tokens(user_id)
        try:
            task = asyncio.create_task(coro)
        except Exception as e:
            # Close the never-started coroutine; otherwise Python emits a
            # "coroutine was never awaited" RuntimeWarning.
            coro.close()
            logger.warning("Error scheduling proactive refresh: %s", e)
            return

        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

        try:
            now = getUtcTimestamp()
            self.last_check[user_id] = now
            self._prune_last_check(now)
        except Exception as e:
            logger.warning("Error scheduling proactive refresh: %s", e)

    def _prune_last_check(self, now) -> None:
        """
        Drop ``last_check`` entries older than the check interval.

        A stale entry and a missing entry both make
        ``_should_check_proactive_refresh`` return True. Removing stale
        entries therefore does not change which requests trigger a refresh;
        it only stops the dict from growing with every user ever seen.
        The sweep runs at most once per interval, so its cost stays linear
        in the number of users per interval.
        """
        interval = self.check_interval_minutes * 60
        if now - self._last_prune <= interval:
            return
        self._last_prune = now

        cutoff = now - interval
        stale = [uid for uid, ts in self.last_check.items() if ts < cutoff]
        for uid in stale:
            del self.last_check[uid]

    def _extract_user_id(self, request: Request) -> Optional[str]:
        """
        Extract user ID from request context.

        Returns:
            ``request.state.user_id`` if it has been set, otherwise None.
        """
        try:
            return getattr(request.state, 'user_id', None)
        except Exception:
            return None

    def _should_check_proactive_refresh(self, user_id: str) -> bool:
        """
        Check if we should perform proactive refresh for this user.

        Returns:
            True when more than ``check_interval_minutes`` have passed since
            the last refresh scheduled for ``user_id``, or when none has been
            recorded. Returns False if the clock read fails.
        """
        try:
            current_time = getUtcTimestamp()
            last_check = self.last_check.get(user_id, 0)

            # Throttle to once per check_interval_minutes per user.
            # NOTE(review): This assumes getUtcTimestamp() returns seconds,
            # which the "* 60" conversion implies. Confirm.
            return (current_time - last_check) > (self.check_interval_minutes * 60)

        except Exception:
            # Best effort: skip proactive refresh rather than fail the request.
            return False

    async def _proactive_refresh_tokens(self, user_id: str) -> None:
        """
        Perform proactive token refresh for the user.

        Runs as a background task that nothing awaits, so every exception is
        logged here instead of being raised.
        """
        try:
            logger.debug("Starting proactive token refresh for user %s", user_id)

            result = await token_refresh_service.proactive_refresh(user_id)

            # Guard against a falsy/None result so it isn't misreported as an error.
            refreshed = result.get("refreshed", 0) if result else 0
            if refreshed > 0:
                logger.info("Proactively refreshed %s tokens for user %s", refreshed, user_id)

        except Exception as e:
            logger.error("Error in proactive token refresh for user %s: %s", user_id, e)
|
|