""" Authentication module for backend API. Handles JWT-based authentication, token generation, and user context. """ from datetime import datetime, timedelta, timezone from typing import Optional, Dict, Any, Tuple from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer from jose import JWTError, jwt import logging from modules.interfaces.gatewayInterface import getGatewayInterface from modules.shared.configuration import APP_CONFIG # Get Config Data SECRET_KEY = APP_CONFIG.get("APP_JWT_SECRET_SECRET") ALGORITHM = APP_CONFIG.get("Auth_ALGORITHM") ACCESS_TOKEN_EXPIRE_MINUTES = int(APP_CONFIG.get("APP_TOKEN_EXPIRY")) # OAuth2 Setup oauth2Scheme = OAuth2PasswordBearer(tokenUrl="token") # Logger logger = logging.getLogger(__name__) def createAccessToken(data: dict, expiresDelta: Optional[timedelta] = None) -> str: """ Creates a JWT Access Token. Args: data: Data to encode (usually user ID or username) expiresDelta: Validity duration of the token (optional) Returns: JWT Token as string """ toEncode = data.copy() if expiresDelta: expire = datetime.now(timezone.utc) + expiresDelta else: expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) toEncode.update({"exp": expire}) encodedJwt = jwt.encode(toEncode, SECRET_KEY, algorithm=ALGORITHM) return encodedJwt async def getCurrentUser(token: str = Depends(oauth2Scheme)) -> Dict[str, Any]: """ Extracts and validates the current user from the JWT token. Args: token: JWT Token from the Authorization header Returns: User data Raises: HTTPException: For invalid token or user """ credentialsException = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication credentials", headers={"WWW-Authenticate": "Bearer"}, ) try: # Decode token payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) # Extract username from token username: str = payload.get("sub") if username is None: raise credentialsException # Extract mandate ID and user ID from token _mandateId: str = payload.get("_mandateId") _userId: str = payload.get("_userId") if not _mandateId or not _userId: logger.error(f"Missing context in token: _mandateId={_mandateId}, _userId={_userId}") raise credentialsException except JWTError: logger.warning("Invalid JWT Token") raise credentialsException # Initialize Gateway Interface with context gateway = getGatewayInterface(_mandateId, _userId) # Retrieve user from database user = gateway.getUserByUsername(username) if user is None: logger.warning(f"User {username} not found") raise credentialsException if user.get("disabled", False): logger.warning(f"User {username} is disabled") raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User is disabled") # Ensure the user has the correct context if str(user.get("_mandateId")) != str(_mandateId) or str(user.get("id")) != str(_userId): logger.error(f"User context mismatch: token(_mandateId={_mandateId}, _userId={_userId}) vs user(_mandateId={user.get('_mandateId')}, id={user.get('id')})") raise credentialsException # Add authentication authority to user data user["authenticationAuthority"] = user.get("authenticationAuthority", "local") return user async def getUserContext(currentUser: Dict[str, Any]) -> Tuple[str, str]: """ Extracts the mandate ID and user ID from the current user. Args: currentUser: The current user Returns: Tuple of (_mandateId, _userId) as strings Raises: HTTPException: If mandate or user ID is missing """ # Extract _mandateId _mandateId = currentUser.get("_mandateId") if not _mandateId: logger.error("No _mandateId found in currentUser") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing mandate context" ) # Extract _userId _userId = currentUser.get("id") # Note: using 'id' instead of '_userId' if not _userId: logger.error("No _userId found in currentUser") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing user context" ) return str(_mandateId), str(_userId) def getInitialContext() -> tuple[str, str]: """ Returns the initial mandate and user IDs from the gateway. This is used by other interfaces to get their context. Returns: tuple[str, str]: (_mandateId, _userId) or (None, None) if not available """ gateway = getGatewayInterface() mandateId = gateway.getInitialId("mandates") userId = gateway.getInitialId("users") return mandateId, userId async def getCurrentActiveUser(currentUser: Dict[str, Any] = Depends(getCurrentUser)) -> Dict[str, Any]: """ Gets the current active user and verifies their authentication authority. Args: currentUser: The current user from getCurrentUser Returns: The current user data Raises: HTTPException: If user is disabled or has invalid authentication authority """ if currentUser.get("disabled", False): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="User is disabled" ) auth_authority = currentUser.get("authenticationAuthority", "local") if auth_authority not in ["local", "microsoft"]: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid authentication authority: {auth_authority}" ) return currentUser