gateway/modules/security/auth.py

173 lines
No EOL
5.4 KiB
Python

"""
Authentication module for backend API.
Handles JWT-based authentication, token generation, and user context.
"""
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any, Tuple
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
import logging
from modules.interfaces.gatewayInterface import getGatewayInterface
from modules.shared.configuration import APP_CONFIG
# Module-level logger for authentication events.
logger = logging.getLogger(__name__)

# JWT settings pulled from the application configuration.
# NOTE(review): the key names are used verbatim ("APP_JWT_SECRET_SECRET",
# "Auth_ALGORITHM") — confirm they match the configuration source.
SECRET_KEY = APP_CONFIG.get("APP_JWT_SECRET_SECRET")
ALGORITHM = APP_CONFIG.get("Auth_ALGORITHM")
# Default token lifetime in minutes. Presumably get() returns None for a
# missing key, in which case int() fails at import time — TODO confirm.
ACCESS_TOKEN_EXPIRE_MINUTES = int(APP_CONFIG.get("APP_TOKEN_EXPIRY"))

# FastAPI dependency that extracts the bearer token from the Authorization
# header; tokenUrl is the relative URL advertised for obtaining a token.
oauth2Scheme = OAuth2PasswordBearer(
    tokenUrl="token",
)
def createAccessToken(data: dict, expiresDelta: Optional[timedelta] = None) -> str:
    """
    Creates a signed JWT access token.

    Args:
        data: Claims to embed in the token (for example the "sub",
            "_mandateId" and "_userId" claims that getCurrentUser reads back).
            The dict is shallow-copied; the caller's object is not modified.
        expiresDelta: Lifetime of the token. When None, the default of
            ACCESS_TOKEN_EXPIRE_MINUTES is used. A zero or negative delta is
            honoured as given, producing a token that is already expired.

    Returns:
        The encoded JWT as a string.
    """
    # Explicit None check: timedelta(0) is falsy, so a truthiness test would
    # silently replace a caller-supplied zero lifetime with the default.
    if expiresDelta is None:
        expiresDelta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    toEncode = data.copy()
    # Any "exp" the caller put in data is overwritten by the computed expiry.
    toEncode["exp"] = datetime.now(timezone.utc) + expiresDelta
    return jwt.encode(toEncode, SECRET_KEY, algorithm=ALGORITHM)
async def getCurrentUser(token: str = Depends(oauth2Scheme)) -> Dict[str, Any]:
    """
    Resolve the authenticated user for a request from its bearer token.

    The token is decoded and verified with SECRET_KEY / ALGORITHM. It must
    carry a "sub" claim (the username) plus non-empty "_mandateId" and
    "_userId" claims. The user is then loaded through a gateway interface
    built for that mandate/user, and the stored record's "_mandateId" and
    "id" must match the token's claims.

    Args:
        token: JWT taken from the Authorization header by oauth2Scheme.

    Returns:
        The user record returned by the gateway's getUserByUsername().

    Raises:
        HTTPException: 401 (with a "WWW-Authenticate: Bearer" header) when the
            token fails to decode, lacks a required claim, names an unknown
            user, or disagrees with the stored user's mandate/id;
            403 when the stored user has a truthy "disabled" flag.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid authentication credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    # Only JWTError is translated here; the 401s raised explicitly inside
    # the block are HTTPExceptions and propagate unchanged.
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username: str = claims.get("sub")
        if username is None:
            raise unauthorized
        tokenMandateId: str = claims.get("_mandateId")
        tokenUserId: str = claims.get("_userId")
        if not (tokenMandateId and tokenUserId):
            logger.error(f"Missing context in token: _mandateId={tokenMandateId}, _userId={tokenUserId}")
            raise unauthorized
    except JWTError:
        logger.warning("Invalid JWT Token")
        raise unauthorized

    user = getGatewayInterface(tokenMandateId, tokenUserId).getUserByUsername(username)
    if user is None:
        logger.warning(f"User {username} not found")
        raise unauthorized

    if user.get("disabled", False):
        logger.warning(f"User {username} is disabled")
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User is disabled")

    # Both sides are compared as strings, presumably so that non-string ids
    # in the stored record still match the token's claims.
    contextMatches = (
        str(user.get("_mandateId")) == str(tokenMandateId)
        and str(user.get("id")) == str(tokenUserId)
    )
    if not contextMatches:
        storedMandateId = user.get("_mandateId")
        storedUserId = user.get("id")
        logger.error(
            f"User context mismatch: token(_mandateId={tokenMandateId}, _userId={tokenUserId}) "
            f"vs user(_mandateId={storedMandateId}, id={storedUserId})"
        )
        raise unauthorized
    return user
async def getCurrentActiveUser(currentUser: Dict[str, Any] = Depends(getCurrentUser)) -> Dict[str, Any]:
    """
    Dependency that passes the authenticated user through only if active.

    Args:
        currentUser: User record resolved by getCurrentUser.

    Returns:
        The same user record, unchanged.

    Raises:
        HTTPException: 403 when the record's "disabled" flag is truthy.
    """
    # getCurrentUser performs the same disabled check; it is repeated here
    # as a guard at this dependency's level.
    isDisabled = currentUser.get("disabled", False)
    if not isDisabled:
        return currentUser
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User is disabled")
async def getUserContext(currentUser: Dict[str, Any]) -> Tuple[str, str]:
    """
    Pull the (mandate ID, user ID) pair out of an authenticated user record.

    The mandate is read from the record's "_mandateId" key. The user ID is
    read from its "id" key, not "_userId". Both are returned as strings.

    Args:
        currentUser: User record, e.g. as returned by getCurrentUser.

    Returns:
        Tuple of (_mandateId, _userId), each converted with str().

    Raises:
        HTTPException: 401 if either value is missing or falsy (so 0 or ""
            also count as missing). The mandate is validated first.
    """
    # (record key, log message, HTTP detail) for each required value, listed
    # in the order they are validated.
    requiredFields = (
        ("_mandateId", "No _mandateId found in currentUser", "Missing mandate context"),
        ("id", "No _userId found in currentUser", "Missing user context"),
    )
    values = []
    for recordKey, logMessage, httpDetail in requiredFields:
        value = currentUser.get(recordKey)
        if not value:
            logger.error(logMessage)
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=httpDetail,
            )
        values.append(str(value))
    return values[0], values[1]
def getInitialContext() -> Tuple[Optional[str], Optional[str]]:
    """
    Returns the initial mandate and user IDs from the gateway.

    This is used by other interfaces to get their context. The gateway is
    created without a mandate/user context, unlike in getCurrentUser.

    Returns:
        (_mandateId, _userId) exactly as returned by the gateway's
        getInitialId("mandates") and getInitialId("users"). The values are
        passed through unchanged, with no str() conversion. Either element may
        be None when no initial record is available; this is the documented
        contract, to be confirmed against getInitialId.
    """
    gateway = getGatewayInterface()
    # Tuple elements are evaluated left to right: mandates first, then users.
    return gateway.getInitialId("mandates"), gateway.getInitialId("users")