gateway/modules/auth.py

159 lines
No EOL
4.8 KiB
Python

from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any, Tuple
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
import logging
from modules.gateway_interface import get_gateway_interface
from gateway.modules.configuration import APP_CONFIG
# Authentication settings read from APP_CONFIG once, at import time.
# NOTE(review): the key "APP_JWT_SECRET_SECRET" repeats "SECRET" and "Auth_ALGORITHM" is
# mixed-case unlike the other keys -- confirm both match the keys defined in the configuration.
SECRET_KEY = APP_CONFIG.get("APP_JWT_SECRET_SECRET")
ALGORITHM = APP_CONFIG.get("Auth_ALGORITHM")
# NOTE(review): int() raises at import time if this value is non-numeric, or if it is None
# (presumably what APP_CONFIG.get returns for an absent key -- confirm).
ACCESS_TOKEN_EXPIRE_MINUTES = int(APP_CONFIG.get("APP_TOKEN_EXPIRY"))
# OAuth2 bearer scheme; used below as the dependency that supplies the token to get_current_user
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Module-level logger, named after this module
logger = logging.getLogger(__name__)
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """
    Creates a JWT Access Token.

    Args:
        data: Data to encode (usually user ID or username). The mapping is
            copied, so the caller's dict is not modified.
        expires_delta: Validity duration of the token (optional). When None,
            ACCESS_TOKEN_EXPIRE_MINUTES is used instead.

    Returns:
        JWT Token as string
    """
    # Compare against None instead of relying on truthiness: timedelta(0) is
    # falsy, so an explicitly requested zero lifetime would otherwise be
    # silently replaced by the (longer) default expiry.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

    to_encode = data.copy()
    # Expiry is an absolute, timezone-aware UTC timestamp stored in the "exp" claim
    to_encode.update({"exp": datetime.now(timezone.utc) + expires_delta})
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
async def get_current_user(token: str = Depends(oauth2_scheme)) -> Dict[str, Any]:
    """
    Extracts and validates the current user from the JWT token.

    Args:
        token: JWT Token from the Authorization header

    Returns:
        User data as returned by the gateway interface

    Raises:
        HTTPException: 401 if the token cannot be decoded, carries no "sub"
            claim, or names a user that does not exist; 403 if the user is
            marked as disabled
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Only the decode call can raise JWTError, so keep the try body minimal
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as err:
        logger.warning("Invalid JWT Token")
        raise credentials_exception from err

    # The username is carried in the "sub" claim.
    # NOTE: a "mandate_id" claim in the token is not consulted here; the user
    # record from the gateway is returned unchanged.
    username = payload.get("sub")
    if username is None:
        raise credentials_exception

    # Initialize Gateway Interface without context
    gateway = get_gateway_interface()

    # Retrieve user from database
    user = gateway.get_user_by_username(username)
    if user is None:
        logger.warning(f"User {username} not found")
        raise credentials_exception

    if user.get("disabled", False):
        logger.warning(f"User {username} is disabled")
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User is disabled")

    return user
async def get_current_active_user(current_user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
    """
    Passes the current user through unless the account is disabled.

    Args:
        current_user: User data resolved by get_current_user

    Returns:
        The same user data, unchanged

    Raises:
        HTTPException: 403 if the user's "disabled" entry is truthy
    """
    is_disabled = current_user.get("disabled", False)
    if not is_disabled:
        return current_user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User is disabled")
def _coerce_context_id(current_user: Dict[str, Any], key: str, label: str, default: int) -> int:
    """Reads current_user[key] as an int, logging and returning default when it is missing or invalid."""
    raw_value = current_user.get(key, None)
    if raw_value is None:
        logger.warning(f"No {label} found in current_user, using default: {default}")
        return default
    try:
        return int(raw_value)
    except (ValueError, TypeError):
        logger.error(f"Invalid {label} value: {raw_value}, using default: {default}")
        return default


async def get_user_context(current_user: Dict[str, Any]) -> Tuple[int, int]:
    """
    Extracts the mandate ID and user ID from the current user.

    The mandate ID is read from the "mandate_id" entry and the user ID from
    the "id" entry. Each value is converted with int(); a missing value is
    logged as a warning and an unconvertible value as an error, and in both
    cases 0 is used instead.

    Args:
        current_user: The current user

    Returns:
        Tuple of (mandate_id, user_id)
    """
    # Default values
    default_mandate_id = 0
    default_user_id = 0

    mandate_id = _coerce_context_id(current_user, "mandate_id", "mandate_id", default_mandate_id)
    # The user ID lives under the "id" key but is reported as "user_id" in log messages
    user_id = _coerce_context_id(current_user, "id", "user_id", default_user_id)
    return mandate_id, user_id