from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any, Tuple
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
import logging

from modules.gateway_interface import get_gateway_interface
from modules.configuration import APP_CONFIG

# Get Config Data
# NOTE(review): the key "APP_JWT_SECRET_SECRET" looks like it may contain a
# duplicated word — confirm it matches the configuration source.
SECRET_KEY = APP_CONFIG.get("APP_JWT_SECRET_SECRET")
ALGORITHM = APP_CONFIG.get("Auth_ALGORITHM")
# NOTE(review): int(None) raises TypeError at import time if
# APP_TOKEN_EXPIRY is missing from the configuration.
ACCESS_TOKEN_EXPIRE_MINUTES = int(APP_CONFIG.get("APP_TOKEN_EXPIRY"))

# OAuth2 Setup
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

# Logger
logger = logging.getLogger(__name__)


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a signed JWT access token.

    Args:
        data: Claims to encode (typically includes "sub"). The dict is copied,
            not mutated.
        expires_delta: Lifetime of the token. When None, the configured
            ACCESS_TOKEN_EXPIRE_MINUTES is used. An explicit timedelta(0)
            (or a negative delta) is honored rather than replaced by the default.

    Returns:
        The encoded JWT as a string.
    """
    to_encode = data.copy()
    # Compare against None explicitly: timedelta(0) is falsy and would otherwise
    # silently fall back to the default lifetime.
    lifetime = (
        expires_delta
        if expires_delta is not None
        else timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    )
    to_encode["exp"] = datetime.now(timezone.utc) + lifetime
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)


async def get_current_user(token: str = Depends(oauth2_scheme)) -> Dict[str, Any]:
    """
    Extract and validate the current user from a JWT bearer token.

    Args:
        token: JWT token taken from the Authorization header.

    Returns:
        The user record returned by the gateway interface.

    Raises:
        HTTPException: 401 if the token is invalid, has no subject, or the
            user does not exist; 403 if the user is disabled.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as exc:
        logger.warning("Invalid JWT Token")
        raise credentials_exception from exc

    # The subject claim identifies the user. A missing or empty subject
    # cannot identify anyone.
    username: Optional[str] = payload.get("sub")
    if not username:
        raise credentials_exception

    # Initialize Gateway Interface without context
    gateway = get_gateway_interface()

    user = gateway.get_user_by_username(username)
    if user is None:
        logger.warning("User %s not found", username)
        raise credentials_exception

    if user.get("disabled", False):
        logger.warning("User %s is disabled", username)
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User is disabled")

    return user


async def get_current_active_user(current_user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
    """
    Ensure that the current user is not disabled.

    Args:
        current_user: User record resolved by get_current_user.

    Returns:
        The same user record.

    Raises:
        HTTPException: 403 if the user is disabled.
    """
    if current_user.get("disabled", False):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User is disabled")
    return current_user


def _coerce_id(value: Any, field_name: str, default: int) -> int:
    """Convert an ID to int, logging and returning ``default`` if missing or invalid."""
    if value is None:
        logger.warning("No %s found in current_user, using default: %s", field_name, default)
        return default
    try:
        return int(value)
    except (ValueError, TypeError, OverflowError):
        # OverflowError covers non-finite floats such as float("inf").
        logger.error("Invalid %s value: %s, using default: %s", field_name, value, default)
        return default


async def get_user_context(current_user: Dict[str, Any]) -> Tuple[int, int]:
    """
    Extract the mandate ID and user ID from the current user record.

    Missing or non-integer values are logged and replaced with 0.

    Args:
        current_user: The current user record. The IDs are read from its
            "mandate_id" and "id" keys.

    Returns:
        Tuple of (mandate_id, user_id).
    """
    default_mandate_id = 0
    default_user_id = 0

    mandate_id = _coerce_id(current_user.get("mandate_id"), "mandate_id", default_mandate_id)
    user_id = _coerce_id(current_user.get("id"), "user_id", default_user_id)

    logger.debug("User context: mandate_id=%s, user_id=%s", mandate_id, user_id)
    return mandate_id, user_id