"""
Authentication module for backend API.

Handles JWT-based authentication, token generation, and user context, plus
persistence of login sessions and authentication events.
"""

import logging
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any, Tuple

from fastapi import Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from slowapi import Limiter
from slowapi.util import get_remote_address

from modules.shared.configuration import APP_CONFIG
from modules.interfaces.serviceAppClass import getRootInterface
from modules.interfaces.serviceAppModel import Session, AuthEvent, UserPrivilege, User

# Get Config Data
# NOTE(review): these key names must match the deployed configuration exactly,
# so they are kept as-is even though "APP_JWT_SECRET_SECRET" and
# "Auth_ALGORITHM" look irregular. A missing "APP_TOKEN_EXPIRY" makes the
# int() conversion below fail at import time.
SECRET_KEY = APP_CONFIG.get("APP_JWT_SECRET_SECRET")
ALGORITHM = APP_CONFIG.get("Auth_ALGORITHM")
ACCESS_TOKEN_EXPIRE_MINUTES = int(APP_CONFIG.get("APP_TOKEN_EXPIRY"))
REFRESH_TOKEN_EXPIRE_DAYS = int(APP_CONFIG.get("APP_REFRESH_TOKEN_EXPIRY", "7"))

# Value of the "type" claim that createRefreshToken stamps on refresh tokens.
REFRESH_TOKEN_TYPE = "refresh"

# OAuth2 Setup
oauth2Scheme = OAuth2PasswordBearer(tokenUrl="token")

# Rate Limiter
limiter = Limiter(key_func=get_remote_address)

# Logger
logger = logging.getLogger(__name__)


def _getClientInfo(request: Request) -> Tuple[Optional[str], Optional[str]]:
    """Return ``(ipAddress, userAgent)`` for *request*; either may be None."""
    ipAddress = request.client.host if request.client else None
    return ipAddress, request.headers.get("user-agent")


def _recordAuthEvent(appInterface, userId: str, eventType: str, details: Dict[str, Any], request: Request) -> None:
    """Build an AuthEvent from *request* and persist it to the "auth_events" table."""
    ipAddress, userAgent = _getClientInfo(request)
    event = AuthEvent(
        userId=userId,
        eventType=eventType,
        details=details,
        ipAddress=ipAddress,
        userAgent=userAgent,
    )
    appInterface.db.recordCreate("auth_events", event.to_dict())


def _toAwareUtc(value: Any) -> Optional[datetime]:
    """
    Normalise a stored timestamp to a timezone-aware UTC datetime.

    Comparing an aware datetime with a naive one raises TypeError, so naive
    datetimes are interpreted as UTC (NOTE(review): assumes the storage layer
    writes UTC values - confirm). ISO-8601 strings (optionally ending in "Z")
    and POSIX timestamps are also accepted.

    Returns:
        The aware datetime, or None if *value* cannot be interpreted.
    """
    if isinstance(value, datetime):
        if value.tzinfo is None:
            return value.replace(tzinfo=timezone.utc)
        return value.astimezone(timezone.utc)
    if isinstance(value, str):
        text = value[:-1] + "+00:00" if value.endswith("Z") else value
        try:
            parsed = datetime.fromisoformat(text)
        except ValueError:
            return None
        return _toAwareUtc(parsed)
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        try:
            return datetime.fromtimestamp(value, tz=timezone.utc)
        except (OverflowError, OSError, ValueError):
            return None
    return None


def createAccessToken(data: dict, expiresDelta: Optional[timedelta] = None) -> Tuple[str, datetime]:
    """
    Creates a JWT Access Token.

    Args:
        data: Data to encode (usually user ID or username). Not mutated.
        expiresDelta: Validity duration of the token (optional). When omitted,
            ACCESS_TOKEN_EXPIRE_MINUTES is used. An explicit ``timedelta(0)``
            is honoured rather than replaced by the default.

    Returns:
        Tuple of (JWT Token as string, expiration datetime in UTC)
    """
    if expiresDelta is None:
        expiresDelta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    expire = datetime.now(timezone.utc) + expiresDelta

    toEncode = data.copy()
    toEncode["exp"] = expire
    encodedJwt = jwt.encode(toEncode, SECRET_KEY, algorithm=ALGORITHM)
    return encodedJwt, expire


def createRefreshToken(data: dict) -> Tuple[str, datetime]:
    """
    Creates a JWT Refresh Token.

    The token carries ``"type": "refresh"`` so that _getUserBase can refuse
    it as a bearer credential. It is valid for REFRESH_TOKEN_EXPIRE_DAYS.

    Args:
        data: Data to encode (usually user ID or username). Not mutated.

    Returns:
        Tuple of (JWT Token as string, expiration datetime in UTC)
    """
    expire = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)

    toEncode = data.copy()
    toEncode.update({"exp": expire, "type": REFRESH_TOKEN_TYPE})
    encodedJwt = jwt.encode(toEncode, SECRET_KEY, algorithm=ALGORITHM)
    return encodedJwt, expire


def _getUserBase(token: str = Depends(oauth2Scheme)) -> User:
    """
    Extracts and validates the current user from the JWT token.

    The token must decode with SECRET_KEY/ALGORITHM, must not be a refresh
    token, and must carry "sub", "mandateId" and "userId" claims that match
    the stored user.

    Args:
        token: JWT Token from the Authorization header

    Returns:
        User model instance

    Raises:
        HTTPException: 401 for an invalid/mismatched token or unknown user,
            403 if the user is disabled
    """
    credentialsException = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Only the decode step can raise JWTError; keep the try narrow.
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        logger.warning("Invalid JWT Token")
        raise credentialsException from None

    # Refresh tokens are long-lived and meant only to be exchanged for new
    # access tokens; accepting them here would let them act as access tokens.
    if payload.get("type") == REFRESH_TOKEN_TYPE:
        logger.warning("Refresh token presented as access token")
        raise credentialsException

    # Extract username from token
    username: Optional[str] = payload.get("sub")
    if username is None:
        raise credentialsException

    # Extract mandate ID and user ID from token
    mandateId: Optional[str] = payload.get("mandateId")
    userId: Optional[str] = payload.get("userId")
    if not mandateId or not userId:
        logger.error("Missing context in token: mandateId=%s, userId=%s", mandateId, userId)
        raise credentialsException

    # Initialize Gateway Interface with context
    appInterface = getRootInterface()

    # Retrieve user from database
    user = appInterface.getUserByUsername(username)
    if user is None:
        logger.warning("User %s not found", username)
        raise credentialsException

    if user.disabled:
        logger.warning("User %s is disabled", username)
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User is disabled")

    # Ensure the token's context still matches the stored user (compared as
    # strings because claim and model ID types may differ).
    if str(user.mandateId) != str(mandateId) or str(user.id) != str(userId):
        logger.error(
            "User context mismatch: token(mandateId=%s, userId=%s) vs user(mandateId=%s, id=%s)",
            mandateId, userId, user.mandateId, user.id,
        )
        raise credentialsException

    return user


def getCurrentUser(currentUser: User = Depends(_getUserBase)) -> User:
    """
    Get current active user with additional validation.

    The disabled check repeats the one in _getUserBase; it is kept as
    defense in depth in case the base dependency changes.

    Raises:
        HTTPException: 403 if the user is disabled
    """
    if currentUser.disabled:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User is disabled"
        )
    return currentUser


def createUserSession(userId: str, tokenId: str, request: Request) -> Session:
    """
    Create a new user session and record a "login" auth event.

    The session expires after ACCESS_TOKEN_EXPIRE_MINUTES. Client IP and
    User-Agent are taken from *request* when available.

    Args:
        userId: ID of the user logging in
        tokenId: Identifier of the token bound to this session
        request: Incoming request, used for client metadata

    Returns:
        The persisted Session model
    """
    appInterface = getRootInterface()
    ipAddress, userAgent = _getClientInfo(request)

    session = Session(
        userId=userId,
        tokenId=tokenId,
        expiresAt=datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
        ipAddress=ipAddress,
        userAgent=userAgent,
    )

    # Save session to database
    appInterface.db.recordCreate("sessions", session.to_dict())

    # Log auth event
    _recordAuthEvent(appInterface, userId, "login", {"method": "local"}, request)

    return session


def logAuthEvent(userId: str, eventType: str, details: Dict[str, Any], request: Request) -> None:
    """
    Log an authentication event to the "auth_events" table.

    Args:
        userId: ID of the user the event concerns
        eventType: Event kind (e.g. "login")
        details: Free-form event details
        request: Incoming request, used for client IP and User-Agent
    """
    _recordAuthEvent(getRootInterface(), userId, eventType, details, request)


def validateSession(sessionId: str) -> bool:
    """
    Validate a user session.

    A session is valid when its record exists and its ``expiresAt`` lies in
    the future. On success its ``lastActivity`` is updated to now. A session
    whose ``expiresAt`` cannot be interpreted is treated as invalid (fails
    closed) instead of raising.

    Args:
        sessionId: ID of the session record

    Returns:
        True if the session exists and has not expired, otherwise False
    """
    appInterface = getRootInterface()
    records = appInterface.db.getRecordset("sessions", recordFilter={"id": sessionId})
    if not records:
        return False

    expiresAt = _toAwareUtc(records[0]["expiresAt"])
    if expiresAt is None:
        logger.warning("Session %s has an unreadable expiresAt value", sessionId)
        return False

    now = datetime.now(timezone.utc)
    if now > expiresAt:
        return False

    # Update last activity
    appInterface.db.recordModify("sessions", sessionId, {"lastActivity": now})
    return True


def revokeSession(sessionId: str) -> None:
    """Revoke a user session by deleting its record."""
    appInterface = getRootInterface()

    # Delete session
    appInterface.db.recordDelete("sessions", sessionId)


def revokeAllUserSessions(userId: str) -> None:
    """Revoke all sessions for a user by deleting each of their session records."""
    appInterface = getRootInterface()

    # Get all sessions for user (tolerate an empty/None result)
    sessions = appInterface.db.getRecordset("sessions", recordFilter={"userId": userId}) or []

    # Delete each session
    for session in sessions:
        appInterface.db.recordDelete("sessions", session["id"])