gateway/modules/interfaces/interfaceDbApp.py
2026-01-25 03:01:01 +01:00

2535 lines
95 KiB
Python

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Interface to the Gateway system.
Manages users and mandates for authentication.
Multi-Tenant Design:
- User gehört nicht mehr direkt zu einem Mandanten
- mandateId wird aus Request-Context übergeben (X-Mandate-Id Header)
"""
import logging
import math
from typing import Dict, Any, List, Optional, Union
from passlib.context import CryptContext
import uuid
from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG
from modules.shared.timeUtils import getUtcTimestamp, parseTimestamp
from modules.interfaces.interfaceBootstrap import initBootstrap
from modules.interfaces.interfaceRbac import getRecordsetWithRBAC
from modules.security.rbac import RbacClass
from modules.datamodels.datamodelUam import (
User,
Mandate,
UserInDB,
UserConnection,
AuthAuthority,
ConnectionStatus,
)
from modules.datamodels.datamodelRbac import (
AccessRule,
AccessRuleContext,
Role,
)
from modules.datamodels.datamodelUam import AccessLevel
from modules.datamodels.datamodelSecurity import Token, AuthEvent, TokenStatus
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResult
from modules.datamodels.datamodelMembership import (
UserMandate,
UserMandateRole,
FeatureAccess,
FeatureAccessRole,
)
from modules.datamodels.datamodelFeatures import Feature, FeatureInstance
from modules.datamodels.datamodelInvitation import Invitation
logger = logging.getLogger(__name__)
# Singleton factory cache for AppObjects instances per context
# (presumably filled by a factory function further down this module — confirm)
_gatewayInterfaces = {}
# Root interface instance (None until created; creation happens outside this view)
_rootAppObjects = None
# Password hashing: argon2 is the only configured scheme; deprecated="auto"
# tells passlib to treat any non-default scheme as deprecated.
pwdContext = CryptContext(schemes=["argon2"], deprecated="auto")
class AppObjects:
    """
    Interface to the Gateway system.
    Manages users and mandates.

    Each instance opens its own DatabaseConnector to the ``poweron_app``
    database, runs bootstrap initialisation and, when a user is supplied,
    builds an RBAC helper for that user. The mandate context is never taken
    from the user; it is supplied through setUserContext(mandateId=...).
    """
def __init__(self, currentUser: Optional[User] = None):
    """Initialize the gateway interface.

    Opens the database connection, runs bootstrap record initialisation and,
    if ``currentUser`` is given, applies that user's context.

    Args:
        currentUser: Acting user, or None for a context-free instance
            (e.g. bootstrap/root usage).
    """
    # Identity of the acting user; the mandate comes later from the request
    # context via setUserContext(), never from the User object.
    self.currentUser = currentUser
    self.userId = currentUser.id if currentUser else None
    self.mandateId = None
    # These are otherwise only assigned in setUserContext(). Without defaults
    # a user-less instance raises AttributeError (e.g. in
    # checkRbacPermission) instead of simply denying access.
    self.rbac = None
    self.userLanguage = None
    # userId must be set first: the database connector is created with it.
    self._initializeDatabase()
    # Create standard records if needed
    self._initRecords()
    if currentUser:
        self.setUserContext(currentUser)
def setUserContext(self, currentUser: User, mandateId: Optional[str] = None):
    """Bind this interface to an acting user and an optional mandate.

    Multi-tenant design: users are not assigned to mandates. The mandate
    context arrives explicitly with each request (X-Mandate-Id header) and is
    stored as given; it is always optional here, and sysadmin users may
    additionally operate across mandates.

    Args:
        currentUser: The acting user. A falsy value leaves the existing
            context untouched (only an info message is logged).
        mandateId: Mandate context from the request header, or None.

    Raises:
        ValueError: If the user carries no id.
    """
    if not currentUser:
        logger.info("Initializing interface without user context")
        return

    self.currentUser = currentUser
    self.userId = currentUser.id
    self.mandateId = mandateId  # taken only from the parameter

    if not self.userId:
        raise ValueError("Invalid user context: id is required")

    # Preferred language of the acting user.
    self.userLanguage = currentUser.language

    # This interface works on the app database, so the same connector serves
    # both as RBAC data source and as its dbApp.
    appDatabase = self.db
    self.rbac = RbacClass(appDatabase, dbApp=appDatabase)
    appDatabase.updateContext(self.userId)
def __del__(self):
    """Best-effort close of the database connection when the object is collected.

    Errors while closing are logged, never raised.
    """
    connection = getattr(self, "db", None)
    if connection is None:
        return
    try:
        connection.close()
    except Exception as e:
        logger.error(f"Error closing database connection: {e}")
def _initializeDatabase(self):
    """Open the connection to the ``poweron_app`` PostgreSQL database.

    Host, user, password and port come from APP_CONFIG (port defaults to
    5432); the database name is fixed. The connector is bound to
    ``self.userId`` (may be None) and ``initDbSystem()`` is run on it.

    Raises:
        Exception: Any failure (config parsing, connecting, initDbSystem) is
            logged and re-raised unchanged.
    """
    try:
        connectionSettings = {
            "dbHost": APP_CONFIG.get("DB_HOST", "_no_config_default_data"),
            "dbDatabase": "poweron_app",
            "dbUser": APP_CONFIG.get("DB_USER"),
            "dbPassword": APP_CONFIG.get("DB_PASSWORD_SECRET"),
            "dbPort": int(APP_CONFIG.get("DB_PORT", 5432)),
        }
        self.db = DatabaseConnector(userId=self.userId, **connectionSettings)
        self.db.initDbSystem()
        logger.info(f"Database initialized successfully for user {self.userId}")
    except Exception as e:
        logger.error(f"Failed to initialize database: {str(e)}")
        raise
def _separateObjectFields(self, model_class, data: Dict[str, Any]) -> tuple[Dict[str, Any], Dict[str, Any]]:
    """Split ``data`` into storable fields and complex objects to filter out.

    A field is kept (first dict) when:
    - it is declared on ``model_class`` with a dict/list annotation
      (``dict``, ``list``, ``Dict[...]``, ``List[...]``, also when wrapped
      in ``Optional[...]``, ``Union[...]`` or ``X | None``); such values are
      stored as JSONB;
    - its value is a scalar (str, int, float, bool or None);
    - it is not declared on the model but its name starts with ``_``
      (metadata such as ``_createdBy`` that the connector handles).
    Everything else goes to the second dict.

    Args:
        model_class: Pydantic v2 model class (its ``model_fields`` is read).
        data: Field name -> value mapping to split.

    Returns:
        Tuple ``(simpleFields, objectFields)``.
    """
    import types
    import typing

    scalarTypes = (str, int, float, bool, type(None))
    unionOrigins = {typing.Union}
    if hasattr(types, "UnionType"):  # PEP 604 ``X | Y`` (Python 3.10+)
        unionOrigins.add(types.UnionType)

    def isJsonContainer(annotation) -> bool:
        """True if the annotation is a dict/list type, possibly inside a Union."""
        if annotation in (dict, list):
            return True
        origin = typing.get_origin(annotation)
        if origin in (dict, list):
            return True
        # Optional[Dict[...]] has origin Union rather than dict, so inspect
        # the union members instead of dropping the value as an "object".
        if origin in unionOrigins:
            return any(isJsonContainer(arg) for arg in typing.get_args(annotation))
        return False

    modelFields = model_class.model_fields
    simpleFields: Dict[str, Any] = {}
    objectFields: Dict[str, Any] = {}
    for fieldName, value in data.items():
        if fieldName in modelFields:
            keep = isJsonContainer(modelFields[fieldName].annotation) or isinstance(value, scalarTypes)
        else:
            # Undeclared fields: metadata passes through, scalars are kept.
            keep = fieldName.startswith("_") or isinstance(value, scalarTypes)
        (simpleFields if keep else objectFields)[fieldName] = value
    return simpleFields, objectFields
def _initRecords(self):
    """Create the standard bootstrap records on this interface's database.

    Delegates entirely to ``initBootstrap``; which records are created, and
    whether existing ones are skipped, is decided there.
    """
    bootstrapTarget = self.db
    initBootstrap(bootstrapTarget)
def checkRbacPermission(
    self,
    modelClass: type,
    operation: str,
    recordId: Optional[str] = None
) -> bool:
    """
    Check whether the current user may perform an operation on a table.

    Permissions are resolved in the DATA context for the table named after
    ``modelClass``, within the current mandate (``self.mandateId``). Any
    access level other than ``AccessLevel.NONE`` counts as permitted.

    Args:
        modelClass: Pydantic model class whose class name is the table name.
        operation: One of 'create', 'update', 'delete', 'read'.
        recordId: Accepted for API compatibility; not used in the decision.

    Returns:
        True if permitted. False for unknown operations, or when no RBAC
        helper or user context is available.
    """
    # getattr: rbac/currentUser may never have been set if no user context
    # was applied; deny instead of raising AttributeError.
    rbac = getattr(self, "rbac", None)
    currentUser = getattr(self, "currentUser", None)
    if not rbac or not currentUser:
        return False
    if operation not in ("create", "update", "delete", "read"):
        return False
    permissions = rbac.getUserPermissions(
        currentUser,
        AccessRuleContext.DATA,
        modelClass.__name__,
        mandateId=getattr(self, "mandateId", None)
    )
    # The permission object exposes one access level attribute per operation.
    return getattr(permissions, operation) != AccessLevel.NONE
def _applyFilters(self, records: List[Dict[str, Any]], filters: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Apply filter criteria to records.

    Supports:
    - General search: {"search": "text"} - case-insensitive substring match
      against every string value of a record, and substring match against
      the str() of its int/float values.
    - Field-specific filters (all must match; a record lacking the field
      never matches):
      - Simple: {"status": "running"} - equals match
      - With operator: {"status": {"operator": "equals", "value": "running"}}
      - Operators: equals/eq, contains, startsWith, endsWith, gt, gte, lt,
        lte, in, notIn. Unknown operators fall back to equals.

    Equality-style checks (simple values, equals, in, notIn, unknown
    operators) and the string operators compare lower-cased string forms,
    because the frontend sends filter values as strings; None compares as "".
    gt/gte/lt/lte compare as floats, with None treated as -inf for gt/gte and
    +inf for lt/lte; a value that cannot be converted does not match.

    Args:
        records: List of record dictionaries to filter
        filters: Filter criteria dictionary

    Returns:
        New list of matching records (``records`` itself when either
        argument is empty).
    """
    if not filters or not records:
        return records

    def asText(value: Any) -> str:
        # Normalised form shared by all equality and string comparisons.
        return str(value).lower() if value is not None else ""

    def asNumber(value: Any, noneDefault: float) -> float:
        return float(value) if value is not None else noneDefault

    # Loop-invariant: computed once instead of once per record.
    searchTerm = str(filters["search"]).lower() if "search" in filters else ""
    fieldFilters = [(name, spec) for name, spec in filters.items() if name != "search"]

    def matchesSearch(record: Dict[str, Any]) -> bool:
        """True if the search term is empty or found in any value of the record."""
        if not searchTerm:
            return True
        for value in record.values():
            if isinstance(value, str):
                if searchTerm in value.lower():
                    return True
            elif isinstance(value, (int, float)) and searchTerm in str(value):
                return True
        return False

    def matchesField(record: Dict[str, Any], fieldName: str, spec: Any) -> bool:
        """True if the record satisfies one field filter."""
        if fieldName not in record:
            return False
        recordValue = record[fieldName]
        if not isinstance(spec, dict):
            return asText(recordValue) == asText(spec)

        operator = spec.get("operator", "equals")
        filterValue = spec.get("value")

        if operator in ("contains", "startsWith", "endsWith"):
            recordText, filterText = asText(recordValue), asText(filterValue)
            if operator == "contains":
                return filterText in recordText
            if operator == "startsWith":
                return recordText.startswith(filterText)
            return recordText.endswith(filterText)

        if operator in ("gt", "gte", "lt", "lte"):
            noneDefault = float("-inf") if operator in ("gt", "gte") else float("inf")
            try:
                recordNum = asNumber(recordValue, noneDefault)
                filterNum = asNumber(filterValue, noneDefault)
            except (ValueError, TypeError):
                return False
            if operator == "gt":
                return recordNum > filterNum
            if operator == "gte":
                return recordNum >= filterNum
            if operator == "lt":
                return recordNum < filterNum
            return recordNum <= filterNum

        if operator in ("in", "notIn"):
            candidates = filterValue if isinstance(filterValue, (list, tuple, set, frozenset)) else [filterValue]
            isMember = asText(recordValue) in {asText(candidate) for candidate in candidates}
            return isMember if operator == "in" else not isMember

        # "equals"/"eq" and any unknown operator use the same equality rule.
        return asText(recordValue) == asText(filterValue)

    return [
        record
        for record in records
        if matchesSearch(record)
        and all(matchesField(record, name, spec) for name, spec in fieldFilters)
    ]
def _applySorting(self, records: List[Dict[str, Any]], sortFields: List[Any]) -> List[Dict[str, Any]]:
    """Return ``records`` sorted by several fields (multi-level, stable).

    Each entry of ``sortFields`` is a dict ``{"field": ..., "direction": ...}``
    or an object with ``field``/``direction`` attributes. The first entry is
    the most significant key, and entries without a field name are ignored.
    A direction equal to "desc" (case-insensitive) sorts descending; anything
    else sorts ascending.

    Records whose value is missing or None always go last, in both
    directions. Numbers (int, float, bool) sort before all other values,
    which are compared as strings, so mixed-type columns do not raise
    TypeError.

    Args:
        records: Records to sort (not modified).
        sortFields: Sort specifications, most significant first.

    Returns:
        A new sorted list, or ``records`` itself when no sort fields are given.
    """
    if not sortFields:
        return records

    def sortKey(value: Any) -> tuple:
        # A type rank first, so numbers and strings are never compared directly.
        if isinstance(value, (int, float)):
            return (0, value)
        return (1, value if isinstance(value, str) else str(value))

    sortedRecords = list(records)
    # Python's sort is stable, so sorting from the least to the most
    # significant field produces a proper multi-level ordering.
    for sortField in reversed(sortFields):
        if isinstance(sortField, dict):
            fieldName = sortField.get("field")
            direction = sortField.get("direction", "asc")
        else:
            fieldName = getattr(sortField, "field", None)
            direction = getattr(sortField, "direction", "asc")
        if not fieldName:
            continue
        isDesc = str(direction).lower() == "desc"
        present = [r for r in sortedRecords if r.get(fieldName) is not None]
        missing = [r for r in sortedRecords if r.get(fieldName) is None]
        present.sort(key=lambda r: sortKey(r[fieldName]), reverse=isDesc)
        # Appending None values after the sort keeps them last even when
        # reverse=True; both partitions keep their previous relative order.
        sortedRecords = present + missing
    return sortedRecords
def getInitialId(self, model_class: type) -> Optional[str]:
    """Return the ID of the initial record of ``model_class``'s table.

    Pass-through to ``self.db.getInitialId``; the result may be None.
    """
    database = self.db
    return database.getInitialId(model_class)
def _getDefaultMandateId(self) -> str:
    """Return the initial (default) mandate's ID, creating the root mandate if absent.

    Raises:
        ValueError: If no mandate ID is available even after the attempt to
            create the root mandate.
    """
    mandateId = self.getInitialId(Mandate)
    if mandateId:
        return mandateId
    logger.warning("No default mandate found, creating Root mandate")
    # NOTE(review): _initRootMandate is not defined in the visible part of
    # this file — confirm it exists on AppObjects.
    self._initRootMandate()
    mandateId = self.getInitialId(Mandate)
    if not mandateId:
        raise ValueError("Failed to get or create default mandate")
    return mandateId
def _getPasswordHash(self, password: str) -> str:
    """Hash ``password`` with the module-level passlib context (argon2)."""
    hashed = pwdContext.hash(password)
    return hashed
def _verifyPassword(self, plainPassword: str, hashedPassword: str) -> bool:
    """Return True if ``plainPassword`` matches ``hashedPassword``.

    Delegates to passlib's ``CryptContext.verify``. That call may raise
    (e.g. for a hash it cannot identify) and this method does not catch it.
    """
    isMatch = pwdContext.verify(plainPassword, hashedPassword)
    return isMatch
# User methods
def getUsersByMandate(self, mandateId: str, pagination: Optional[PaginationParams] = None) -> Union[List[User], PaginatedResult]:
    """
    Return the users visible to the current user, optionally scoped to a mandate.

    Records are read through ``getRecordsetWithRBAC``, so RBAC decides what
    is visible (whether sysadmins see all users is decided there and is not
    visible from here). ``mandateId`` is added as a record filter when it is
    truthy. Internal ``_``-prefixed columns are dropped and ``roleLabels`` is
    normalised to a list.

    Args:
        mandateId: Mandate ID used as record filter; falsy means no filter.
        pagination: Optional pagination parameters. If None, returns all items.

    Returns:
        If pagination is None: List[User]
        If pagination is provided: PaginatedResult with the requested page;
        totals are computed after filtering.
    """
    # NOTE(review): createUser states users no longer carry a mandateId
    # (membership is in UserMandate). Confirm this record filter still matches.
    userRecords = getRecordsetWithRBAC(
        self.db,
        UserInDB,
        self.currentUser,
        recordFilter={"mandateId": mandateId} if mandateId else None
    )

    # Drop internal columns and make roleLabels a list, once per record.
    cleanedUsers = []
    for record in userRecords:
        cleaned = {k: v for k, v in record.items() if not k.startswith("_")}
        if cleaned.get("roleLabels") is None:
            cleaned["roleLabels"] = []
        cleanedUsers.append(cleaned)

    if pagination is None:
        return [User(**user) for user in cleanedUsers]

    if pagination.filters:
        cleanedUsers = self._applyFilters(cleanedUsers, pagination.filters)
    if pagination.sort:
        cleanedUsers = self._applySorting(cleanedUsers, pagination.sort)

    totalItems = len(cleanedUsers)
    pageSize = pagination.pageSize
    totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0
    # Clamp so a page number below 1 yields the first page instead of a
    # negative (wrapped-around) slice.
    startIdx = max(0, (pagination.page - 1) * pageSize)
    pagedUsers = cleanedUsers[startIdx:startIdx + pageSize]

    # Filtering/sorting keep the already-normalised dicts, so roleLabels
    # needs no second pass here.
    return PaginatedResult(
        items=[User(**user) for user in pagedUsers],
        totalItems=totalItems,
        totalPages=totalPages
    )
def getUserByUsername(self, username: str) -> Optional[User]:
    """Return the user named ``username`` if the current user may see it.

    The lookup goes through ``getRecordsetWithRBAC``; only the first match
    is used (usernames are expected to be unique). ``_``-prefixed columns
    are dropped and ``roleLabels`` defaults to an empty list.

    Returns:
        The User, or None when nothing matches or an error occurs (logged).
    """
    try:
        matches = getRecordsetWithRBAC(self.db,
            UserInDB,
            self.currentUser,
            recordFilter={"username": username}
        )
        if not matches:
            logger.info(f"No user found with username {username}")
            return None
        publicFields = {
            key: value
            for key, value in matches[0].items()
            if not key.startswith("_")
        }
        labels = publicFields.get("roleLabels")
        publicFields["roleLabels"] = [] if labels is None else labels
        return User(**publicFields)
    except Exception as e:
        logger.error(f"Error getting user by username: {str(e)}")
        return None
def getUser(self, userId: str) -> Optional[User]:
    """Return the user with ``userId`` if the current user may see it.

    Reads via ``getRecordsetWithRBAC``, so a user hidden by RBAC looks the
    same as a missing one. ``_``-prefixed columns are dropped and
    ``roleLabels`` defaults to an empty list.

    Returns:
        The User, or None if not found/visible or on error (logged).
    """
    try:
        visibleUsers = getRecordsetWithRBAC(self.db,
            UserInDB,
            self.currentUser,
            recordFilter={"id": userId}
        )
        if not visibleUsers:
            return None
        userFields = dict(
            (key, value)
            for key, value in visibleUsers[0].items()
            if not key.startswith("_")
        )
        if userFields.get("roleLabels") is None:
            userFields["roleLabels"] = []
        return User(**userFields)
    except Exception as e:
        logger.error(f"Error getting user by ID: {str(e)}")
        return None
def _getUserForAuthentication(self, username: str) -> Optional[Dict[str, Any]]:
    """
    Fetch the raw user record for ``username``, for login flows only.

    SECURITY NOTE: reads the database directly and therefore bypasses RBAC.
    This is deliberate: users are not mandate-bound (multi-tenant design),
    authentication must work without a mandate context, and RBAC filtering
    on the User table needs one. Use getUserByUsername() (RBAC-filtered)
    for every other user query.

    Returns:
        The first matching UserInDB record as a dict (including password
        hash and internal fields), or None if nothing matches or an error
        occurs (logged).
    """
    try:
        records = self.db.getRecordset(UserInDB, recordFilter={"username": username})
        return records[0] if records else None
    except Exception as e:
        logger.error(f"Error getting user for authentication: {str(e)}")
        return None
def authenticateLocalUser(self, username: str, password: str) -> Optional[User]:
    """
    Authenticate ``username``/``password`` against the local password store.

    SECURITY NOTE: the record is loaded with _getUserForAuthentication(),
    which bypasses RBAC on purpose because users are mandate-independent.

    Checks, in this order: the user exists; it is enabled (missing flag
    counts as enabled); its authentication authority is LOCAL (enum or its
    value); no reset token is pending; a password hash exists; the password
    matches.

    Returns:
        The authenticated User without password hash, reset-token fields and
        ``_``-prefixed columns.

    Raises:
        ValueError: With a distinct message for each failed check.
    """
    # NOTE(review): distinct messages, and checking the enabled/reset state
    # before the password, let a caller tell "unknown user" from "wrong
    # password". Confirm the API layer maps these to one generic response.
    userRecord = self._getUserForAuthentication(username)
    if not userRecord:
        raise ValueError("User not found")
    if not userRecord.get("enabled", True):
        raise ValueError("User is disabled")

    authority = userRecord.get("authenticationAuthority", AuthAuthority.LOCAL)
    if authority not in (AuthAuthority.LOCAL, AuthAuthority.LOCAL.value):
        raise ValueError("User does not have local authentication enabled")

    # A pending reset token blocks login (message: "Password reset required.
    # Please check your e-mail.").
    if userRecord.get("resetToken"):
        raise ValueError("Passwort-Zurücksetzung erforderlich. Bitte prüfen Sie Ihre E-Mail.")

    storedHash = userRecord.get("hashedPassword")
    if not storedHash:
        raise ValueError("User has no password set")
    if not self._verifyPassword(password, storedHash):
        raise ValueError("Invalid password")

    hiddenKeys = {"hashedPassword", "resetToken", "resetTokenExpires"}
    cleanedUser = {
        key: value
        for key, value in userRecord.items()
        if not key.startswith("_") and key not in hiddenKeys
    }
    if cleanedUser.get("roleLabels") is None:
        cleanedUser["roleLabels"] = []
    return User(**cleanedUser)
def createUser(
    self,
    username: str,
    password: str = None,
    email: str = None,
    fullName: str = None,
    language: str = "en",
    enabled: bool = True,
    authenticationAuthority: AuthAuthority = AuthAuthority.LOCAL,
    externalId: str = None,
    externalUsername: str = None,
    externalEmail: str = None,
    isSysAdmin: bool = False,
) -> User:
    """
    Create a new user.

    Note: Role assignment is done via createUserMandate(), not via User fields.

    The username is converted with ``str()`` and stripped, then checked for
    global uniqueness. For LOCAL auth the password is normally set later via
    magic link; ``password`` is meant for internal/admin use and is hashed
    when given. A string e-mail is stored lower-cased and stripped, because
    the LOCAL-auth lookups query with that normalised form. If both
    ``externalId`` and ``externalUsername`` are given, an external
    connection is added as well.

    Returns:
        The created user, re-read from the database, without ``_``-prefixed
        columns, password hash or reset-token fields.

    Raises:
        ValueError: For a missing or taken username, an invalid password, or
            any failure while creating the user (original error chained).
    """
    try:
        # None must not turn into the literal username "None".
        username = str(username).strip() if username is not None else ""
        if not username:
            raise ValueError("Username is required")
        # Usernames are unique across ALL mandates.
        if not self.isUsernameGloballyUnique(username):
            raise ValueError(f"Username '{username}' is already taken")

        if password:
            if not isinstance(password, str):
                raise ValueError("Password must be a string")
            if not password.strip():
                raise ValueError("Password cannot be empty")

        # Store the form that findAllUsersByEmailLocalAuth() and friends
        # query with (email.lower().strip()).
        normalizedEmail = email.strip().lower() if isinstance(email, str) else email

        # mandateId and roleLabels are not user fields anymore; membership is
        # handled through UserMandate + UserMandateRole.
        userData = UserInDB(
            username=username,
            email=normalizedEmail,
            fullName=fullName,
            language=language,
            enabled=enabled,
            isSysAdmin=isSysAdmin,
            authenticationAuthority=authenticationAuthority,
            hashedPassword=self._getPasswordHash(password) if password else None,
        )
        createdRecord = self.db.recordCreate(UserInDB, userData)
        if not createdRecord or not createdRecord.get("id"):
            raise ValueError("Failed to create user record")
        newUserId = createdRecord["id"]

        if externalId and externalUsername:
            self.addUserConnection(
                newUserId,
                authenticationAuthority,
                externalId,
                externalUsername,
                externalEmail,
            )

        createdUsers = self.db.getRecordset(UserInDB, recordFilter={"id": newUserId})
        if not createdUsers:
            raise ValueError("Failed to retrieve created user")

        # Same cleanup as the login path: no internal columns, never return
        # the password hash or reset-token data.
        hiddenKeys = {"hashedPassword", "resetToken", "resetTokenExpires"}
        cleanedUser = {
            k: v for k, v in createdUsers[0].items()
            if not k.startswith("_") and k not in hiddenKeys
        }
        if cleanedUser.get("roleLabels") is None:
            cleanedUser["roleLabels"] = []
        return User(**cleanedUser)
    except ValueError as e:
        logger.error(f"Error creating user: {str(e)}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error creating user: {str(e)}")
        raise ValueError(f"Failed to create user: {str(e)}") from e
def updateUser(self, userId: str, updateData: Union[Dict[str, Any], User]) -> User:
    """Update a user's information.

    The current record (read via getUser, i.e. RBAC-filtered) is merged with
    ``updateData`` and written back. An ``id`` in ``updateData`` is ignored;
    ``userId`` always wins.

    Args:
        userId: ID of the user to update.
        updateData: Partial field mapping, or a full User model (all of its
            fields, including defaults, then overwrite the stored values).

    Returns:
        The user as re-read after the update.

    Raises:
        ValueError: If the user is not found/visible, ``updateData`` has an
            unsupported type, or the update fails (original error chained).
    """
    # NOTE(review): unlike deleteUser, there is no explicit RBAC "update"
    # check here; visibility through getUser is the only gate. Confirm intent.
    try:
        user = self.getUser(userId)
        if not user:
            raise ValueError(f"User {userId} not found")

        if isinstance(updateData, User):
            changes = updateData.model_dump()
        elif isinstance(updateData, dict):
            changes = dict(updateData)  # copy: never mutate the caller's dict
        else:
            raise ValueError(f"Unsupported update data type: {type(updateData).__name__}")
        changes.pop("id", None)

        merged = user.model_dump()
        merged.update(changes)
        merged["id"] = userId
        self.db.recordModify(UserInDB, userId, User(**merged))

        refreshed = self.getUser(userId)
        if not refreshed:
            raise ValueError("Failed to retrieve updated user")
        return refreshed
    except Exception as e:
        logger.error(f"Error updating user: {str(e)}")
        raise ValueError(f"Failed to update user: {str(e)}") from e
def disableUser(self, userId: str) -> User:
    """Disable a user by setting ``enabled`` to False via updateUser().

    Permission handling is whatever updateUser() enforces.
    """
    disablePatch = {"enabled": False}
    return self.updateUser(userId, disablePatch)
def enableUser(self, userId: str) -> User:
    """Re-enable a user by setting ``enabled`` to True via updateUser().

    Permission handling is whatever updateUser() enforces.
    """
    enablePatch = {"enabled": True}
    return self.updateUser(userId, enablePatch)
def resetUserPassword(self, userId: str, newPassword: str) -> bool:
    """Set a new password for a user (admin function).

    Any pending reset token is cleared as well. authenticateLocalUser()
    rejects logins while ``resetToken`` is set, so leaving it in place would
    make the freshly set password unusable.

    Args:
        userId: ID of the user whose password is replaced.
        newPassword: New plain-text password, at least 8 characters.

    Returns:
        True on success, False on any error (including a too-short password);
        errors are logged, not raised.
    """
    try:
        if not newPassword or len(newPassword) < 8:
            raise ValueError("Password must be at least 8 characters long")
        hashedPassword = self._getPasswordHash(newPassword)
        self.db.recordModify(UserInDB, userId, {
            "hashedPassword": hashedPassword,
            "resetToken": None,
            "resetTokenExpires": None,
        })
        logger.info(f"Password reset for user {userId}")
        return True
    except Exception as e:
        logger.error(f"Error resetting password for user {userId}: {str(e)}")
        return False
def generateResetTokenAndExpiry(self) -> tuple:
    """Create a fresh password-reset token and its expiry time.

    The token is a random UUID4 string. Its lifetime in hours comes from the
    ``Auth_RESET_TOKEN_EXPIRY_HOURS`` config value (default 24).

    Returns:
        tuple: (tokenUuid: str, expiresTimestamp), where the expiry is
        ``getUtcTimestamp()`` plus the lifetime in seconds.
    """
    resetToken = str(uuid.uuid4())
    lifetimeHours = int(APP_CONFIG.get("Auth_RESET_TOKEN_EXPIRY_HOURS", "24"))
    lifetimeSeconds = lifetimeHours * 3600
    return resetToken, getUtcTimestamp() + lifetimeSeconds
def findUserByEmailLocalAuth(self, email: str) -> Optional[User]:
    """Return the first LOCAL-auth user registered with ``email``, or None.

    Thin wrapper over findAllUsersByEmailLocalAuth(). If several accounts
    share the address, the first one in that lookup's result order is
    returned; use findAllUsersByEmailLocalAuth() to get all of them.

    Args:
        email: Email address to search for (normalised by the wrapped lookup).
    """
    candidates = self.findAllUsersByEmailLocalAuth(email)
    if not candidates:
        return None
    return candidates[0]
def findAllUsersByEmailLocalAuth(self, email: str) -> List[User]:
    """Return every LOCAL-auth user registered with ``email``.

    The address is lower-cased and stripped before querying. The query reads
    the database directly (no RBAC), so it is not limited to a mandate
    context. NOTE(review): the record filter presumably matches exactly, so
    only users whose stored e-mail is already in normalised form are found
    — confirm against DatabaseConnector.getRecordset.

    Args:
        email: Email address to search for.

    Returns:
        Matching users (``_``-prefixed columns removed, roleLabels defaulted
        to []), or an empty list for a falsy email or on error (logged).
    """
    if not email:
        return []
    normalizedEmail = email.lower().strip()

    def toUser(record):
        fields = {k: v for k, v in record.items() if not k.startswith("_")}
        if fields.get("roleLabels") is None:
            fields["roleLabels"] = []
        return User(**fields)

    try:
        records = self.db.getRecordset(
            UserInDB,
            recordFilter={
                "email": normalizedEmail,
                "authenticationAuthority": AuthAuthority.LOCAL.value
            }
        )
        return [toUser(record) for record in records]
    except Exception as e:
        logger.error(f"Error finding users by email: {str(e)}")
        return []
def findUserByEmailAndUsernameLocalAuth(self, email: str, username: str) -> Optional[User]:
    """Return the LOCAL-auth user matching both ``email`` and ``username``.

    The e-mail is lower-cased and stripped; the username is used verbatim.
    The lookup reads the database directly (no RBAC), so it is not limited
    to a mandate context.

    Args:
        email: Email address (normalised before querying).
        username: Username, matched exactly as given.

    Returns:
        The first matching User (``_``-prefixed columns removed, roleLabels
        defaulted to []), or None if an argument is empty, nothing matches,
        or an error occurs (logged).
    """
    if not (email and username):
        return None
    normalizedEmail = email.lower().strip()
    try:
        lookup = {
            "email": normalizedEmail,
            "username": username,
            "authenticationAuthority": AuthAuthority.LOCAL.value,
        }
        matches = self.db.getRecordset(UserInDB, recordFilter=lookup)
        if not matches:
            return None
        fields = {key: val for key, val in matches[0].items() if not key.startswith("_")}
        if fields.get("roleLabels") is None:
            fields["roleLabels"] = []
        return User(**fields)
    except Exception as e:
        logger.error(f"Error finding user by email and username: {str(e)}")
        return None
def isUsernameGloballyUnique(self, username: str) -> bool:
    """Report whether ``username`` is still free system-wide.

    Queries the user table directly (no RBAC), so all mandates are covered.
    The username is matched as given, with no trimming or case folding.

    Args:
        username: Username to check.

    Returns:
        True if no user has this username. False if it is taken, empty, or
        the lookup fails (fail-safe; the error is logged).
    """
    if not username:
        return False
    try:
        existing = self.db.getRecordset(UserInDB, recordFilter={"username": username})
        isTaken = len(existing) > 0
        return not isTaken
    except Exception as e:
        logger.error(f"Error checking username uniqueness: {str(e)}")
        # Fail safe: treat the name as taken when uniqueness can't be verified.
        return False
def findUserByUsernameLocalAuth(self, username: str) -> Optional[User]:
    """Return the LOCAL-auth user with ``username``, or None.

    Usernames are globally unique, so at most one user is expected; the
    first match is used. The lookup reads the database directly (no RBAC).

    Args:
        username: Username to search for (matched as given).

    Returns:
        The User (``_``-prefixed columns removed, roleLabels defaulted to
        []), or None if the username is empty, nothing matches, or an error
        occurs (logged).
    """
    if not username:
        return None
    try:
        criteria = {
            "username": username,
            "authenticationAuthority": AuthAuthority.LOCAL.value,
        }
        found = self.db.getRecordset(UserInDB, recordFilter=criteria)
        if not found:
            return None
        fields = {key: val for key, val in found[0].items() if not key.startswith("_")}
        if fields.get("roleLabels") is None:
            fields["roleLabels"] = []
        return User(**fields)
    except Exception as e:
        logger.error(f"Error finding user by username: {str(e)}")
        return None
def setResetToken(self, userId: str, token: str, expires: float, clearPassword: bool = True) -> bool:
    """Store a password-reset token and its expiry on a user record.

    Args:
        userId: User ID
        token: Reset token UUID
        expires: Expiration timestamp (float)
        clearPassword: If True, the password hash is also set to None, so
            the previous password stops working.

    Returns:
        True if the record update succeeded, False if it raised (the error
        is logged).
    """
    try:
        fieldsToWrite = dict(resetToken=token, resetTokenExpires=expires)
        if clearPassword:
            fieldsToWrite["hashedPassword"] = None
        self.db.recordModify(UserInDB, userId, fieldsToWrite)
    except Exception as e:
        logger.error(f"Error setting reset token for user {userId}: {str(e)}")
        return False
    return True
def verifyResetToken(self, token: str) -> Optional[User]:
    """Return the user owning a reset token, if the token has not expired.

    The token is looked up directly in the user table (no RBAC).
    ``resetTokenExpires`` is converted to float. A missing, zero,
    non-numeric or past expiry makes the token invalid.

    Returns:
        The User (``_``-prefixed columns removed, roleLabels defaulted to
        []) for a valid token; None otherwise, including on error (logged).
    """
    if not token:
        return None
    try:
        candidates = self.db.getRecordset(UserInDB, recordFilter={"resetToken": token})
        if not candidates:
            return None
        record = candidates[0]
        recordId = record.get("id")

        rawExpiry = record.get("resetTokenExpires")
        expiry = rawExpiry
        if rawExpiry is not None:
            try:
                expiry = float(rawExpiry)
            except (ValueError, TypeError):
                logger.warning(f"Invalid resetTokenExpires value for user {recordId}: {rawExpiry}")
                return None

        # A missing/zero expiry counts as expired, as does a past timestamp.
        isExpired = not expiry or getUtcTimestamp() > expiry
        if isExpired:
            logger.warning(f"Reset token expired for user {recordId}")
            return None

        fields = {key: val for key, val in record.items() if not key.startswith("_")}
        if fields.get("roleLabels") is None:
            fields["roleLabels"] = []
        return User(**fields)
    except Exception as e:
        logger.error(f"Error verifying reset token: {str(e)}")
        return None
def resetPasswordWithToken(self, token: str, newPassword: str) -> bool:
    """Complete a password reset using a token from setResetToken().

    On success, a single record update sets the new hash, clears the token
    and its expiry, and re-enables the user. NOTE(review): the token check
    (verifyResetToken) and that update are two separate database calls, so
    presumably two concurrent requests with the same token could both pass
    the check unless the connector serialises them — confirm.

    Returns:
        True if the password was reset; False for an invalid/expired token,
        a password shorter than 8 characters, or any other error (logged).
    """
    try:
        owner = self.verifyResetToken(token)
        if not owner:
            return False
        if not newPassword or len(newPassword) < 8:
            raise ValueError("Password must be at least 8 characters long")
        newHash = self._getPasswordHash(newPassword)
        resetFields = {
            "hashedPassword": newHash,
            "resetToken": None,
            "resetTokenExpires": None,
            "enabled": True,
        }
        self.db.recordModify(UserInDB, owner.id, resetFields)
        logger.info(f"Password reset completed for user {owner.id}")
        return True
    except Exception as e:
        logger.error(f"Error in resetPasswordWithToken: {str(e)}")
        return False
def _deleteUserReferencedData(self, userId: str) -> None:
    """Delete the auth events, tokens and connections belonging to a user.

    For AuthEvent, Token and UserConnection (in that order), all records
    with this ``userId`` are fetched and deleted one by one.

    Raises:
        Exception: Any database error is logged and re-raised.
    """
    # NOTE(review): membership records (UserMandate, FeatureAccess, ...) are
    # imported by this module but not removed here — confirm they are cleaned
    # up elsewhere when a user is deleted.
    try:
        for modelClass in (AuthEvent, Token, UserConnection):
            ownedRecords = self.db.getRecordset(modelClass, recordFilter={"userId": userId})
            for ownedRecord in ownedRecords:
                self.db.recordDelete(modelClass, ownedRecord["id"])
        logger.info(f"All referenced data for user {userId} has been deleted")
    except Exception as e:
        logger.error(f"Error deleting referenced data for user {userId}: {str(e)}")
        raise
def deleteUser(self, userId: str) -> bool:
    """Delete a user and its dependent data if the current user may.

    The user must be visible via getUser(), and RBAC must grant the
    "delete" operation on UserInDB. Auth events, tokens and connections are
    then removed before the user record itself.

    Returns:
        True on success.

    Raises:
        ValueError: For a missing user, missing permission or a failed
            deletion; the underlying error is chained and in the message.
    """
    try:
        if not self.getUser(userId):
            raise ValueError(f"User {userId} not found")
        # Deleting requires the "delete" permission, not merely "update".
        if not self.checkRbacPermission(UserInDB, "delete", userId):
            raise PermissionError(f"No permission to delete user {userId}")

        # Dependent data first, then the user record.
        self._deleteUserReferencedData(userId)
        if not self.db.recordDelete(UserInDB, userId):
            raise ValueError(f"Failed to delete user {userId}")

        logger.info(f"User {userId} successfully deleted")
        return True
    except Exception as e:
        logger.error(f"Error deleting user: {str(e)}")
        raise ValueError(f"Failed to delete user: {str(e)}") from e
def _getInitialUser(self) -> Optional[Dict[str, Any]]:
    """
    Return the initial user record as a raw dict, or None.

    The ID comes from ``self.getInitialId(UserInDB)``. The record is then
    loaded through ``getRecordsetWithRBAC`` with ``self.currentUser``.

    NOTE(review): an earlier description said "directly from database without
    access control", but the code applies RBAC filtering for the current
    user. If the current user cannot see the initial user, this returns None.
    Confirm which behavior is intended.

    Returns:
        The first matching record dict. Returns None when there is no initial
        ID, no visible record, or an error (errors are logged).
    """
    try:
        initialUserId = self.getInitialId(UserInDB)
        if not initialUserId:
            return None
        users = getRecordsetWithRBAC(self.db,
            UserInDB,
            self.currentUser,
            recordFilter={"id": initialUserId}
        )
        return users[0] if users else None
    except Exception as e:
        logger.error(f"Error getting initial user: {str(e)}")
        return None
def checkUsernameAvailability(self, checkData: Dict[str, Any]) -> Dict[str, Any]:
    """
    Report whether a username can still be registered.

    Usernames must be unique system-wide (across all mandates), so only the
    "username" key of ``checkData`` is consulted. Any authentication
    authority in the payload is ignored.

    Args:
        checkData: Request payload containing "username".

    Returns:
        Dict with "available" (bool) and a human-readable "message". Errors
        are logged and reported as unavailable.
    """
    try:
        requested = checkData.get("username")
        if not requested:
            return {"available": False, "message": "Username is required"}
        # Global check: no RBAC or mandate scoping applies here.
        isFree = self.isUsernameGloballyUnique(requested)
        if isFree:
            return {"available": True, "message": "Username is available"}
        return {"available": False, "message": "Username is already taken"}
    except Exception as e:
        logger.error(f"Error checking username availability: {str(e)}")
        return {
            "available": False,
            "message": f"Error checking username availability: {str(e)}",
        }
# Connection methods
def getUserConnections(self, userId: str) -> List[UserConnection]:
    """
    List the external-service connections stored for a user.

    Rows that cannot be converted into a ``UserConnection`` are logged and
    skipped, so one bad row does not hide the others.

    Args:
        userId: Owner of the connections.

    Returns:
        The converted connections. Returns an empty list if the lookup fails.
    """
    try:
        rows = self.db.getRecordset(UserConnection, recordFilter={"userId": userId})
        connections: List[UserConnection] = []
        for row in rows:
            try:
                connections.append(
                    UserConnection(
                        id=row["id"],
                        userId=row["userId"],
                        authority=row.get("authority"),
                        externalId=row.get("externalId", ""),
                        externalUsername=row.get("externalUsername", ""),
                        externalEmail=row.get("externalEmail"),
                        status=row.get("status", "pending"),
                        connectedAt=row.get("connectedAt"),
                        lastChecked=row.get("lastChecked"),
                        expiresAt=row.get("expiresAt"),
                    )
                )
            except Exception as e:
                # Best-effort listing: skip malformed rows.
                logger.error(
                    f"Error converting connection dict to object: {str(e)}"
                )
        return connections
    except Exception as e:
        logger.error(f"Error getting user connections: {str(e)}")
        return []
def addUserConnection(
    self,
    userId: str,
    authority: AuthAuthority,
    externalId: str,
    externalUsername: str,
    externalEmail: Optional[str] = None,
    status: ConnectionStatus = ConnectionStatus.PENDING,
) -> UserConnection:
    """
    Create and store a new external connection for an existing user.

    Args:
        userId: The ID of the user (must resolve via ``self.getUser``).
        authority: The authentication authority (e.g., MSFT, GOOGLE).
        externalId: The external ID from the authority.
        externalUsername: The username from the authority.
        externalEmail: Optional email from the authority.
        status: The connection status (defaults to PENDING).

    Returns:
        The created UserConnection object.

    Raises:
        ValueError: If the user is unknown or storing the connection fails.
            The underlying error is chained as ``__cause__``.
    """
    try:
        if not self.getUser(userId):
            raise ValueError(f"User not found: {userId}")
        # One timestamp for both fields, so a brand-new connection reports
        # identical connectedAt and lastChecked values.
        now = getUtcTimestamp()
        connection = UserConnection(
            id=str(uuid.uuid4()),
            userId=userId,
            authority=authority,
            externalId=externalId,
            externalUsername=externalUsername,
            externalEmail=externalEmail,
            status=status,
            connectedAt=now,
            lastChecked=now,
            expiresAt=None,  # no expiry known at creation time
        )
        self.db.recordCreate(UserConnection, connection)
        return connection
    except Exception as e:
        logger.error(f"Error adding user connection: {str(e)}")
        raise ValueError(f"Failed to add user connection: {str(e)}") from e
def removeUserConnection(self, connectionId: str) -> None:
    """
    Delete a stored connection to an external service.

    NOTE(review): no ownership or RBAC check is done here. Callers must make
    sure the current user is allowed to remove this connection.

    Args:
        connectionId: ID of the connection record.

    Raises:
        ValueError: If the connection does not exist or deletion fails.
    """
    try:
        matching = self.db.getRecordset(UserConnection, recordFilter={"id": connectionId})
        if not matching:
            raise ValueError(f"Connection {connectionId} not found")
        self.db.recordDelete(UserConnection, connectionId)
    except Exception as e:
        logger.error(f"Error removing user connection: {str(e)}")
        raise ValueError(f"Failed to remove user connection: {str(e)}")
# Mandate methods
def getAllMandates(self, pagination: Optional[PaginationParams] = None) -> Union[List[Mandate], PaginatedResult]:
    """
    Return the mandates visible to the current user.

    Visibility is decided by ``getRecordsetWithRBAC``. Internal columns
    (keys starting with "_") are stripped before building models.

    Args:
        pagination: Optional filtering, sorting and paging parameters. When
            None, every visible mandate is returned.

    Returns:
        A list of Mandate when ``pagination`` is None. Otherwise a
        PaginatedResult holding the requested page plus totals computed
        after filtering.
    """
    visible = [
        {key: value for key, value in row.items() if not key.startswith("_")}
        for row in getRecordsetWithRBAC(self.db, Mandate, self.currentUser)
    ]
    if pagination is None:
        return [Mandate(**row) for row in visible]

    if pagination.filters:
        visible = self._applyFilters(visible, pagination.filters)
    if pagination.sort:
        visible = self._applySorting(visible, pagination.sort)

    totalItems = len(visible)
    size = pagination.pageSize
    totalPages = math.ceil(totalItems / size) if totalItems > 0 else 0
    offset = (pagination.page - 1) * size  # pages are 1-based
    return PaginatedResult(
        items=[Mandate(**row) for row in visible[offset:offset + size]],
        totalItems=totalItems,
        totalPages=totalPages,
    )
def getMandate(self, mandateId: str) -> Optional[Mandate]:
    """
    Look up one mandate by ID, subject to RBAC visibility.

    Args:
        mandateId: ID of the mandate.

    Returns:
        The Mandate built from the first visible match (internal "_" keys
        removed), or None if the current user cannot see it.
    """
    matches = getRecordsetWithRBAC(
        self.db, Mandate, self.currentUser, recordFilter={"id": mandateId}
    )
    cleaned = [
        {key: value for key, value in row.items() if not key.startswith("_")}
        for row in (matches or [])
    ]
    return Mandate(**cleaned[0]) if cleaned else None
def createMandate(self, name: str, description: Optional[str] = None, enabled: bool = True) -> Mandate:
    """
    Create a mandate, provided the current user has the "create" right.

    Args:
        name: Display name of the mandate.
        description: Optional free-text description.
        enabled: Whether the mandate starts enabled.

    Returns:
        The Mandate built from the stored record.

    Raises:
        PermissionError: If RBAC denies creation.
        ValueError: If the database did not return a record with an id.
    """
    if not self.checkRbacPermission(Mandate, "create"):
        raise PermissionError("No permission to create mandates")
    stored = self.db.recordCreate(
        Mandate, Mandate(name=name, description=description, enabled=enabled)
    )
    if not (stored and stored.get("id")):
        raise ValueError("Failed to create mandate record")
    return Mandate(**stored)
def updateMandate(self, mandateId: str, updateData: Dict[str, Any]) -> Mandate:
    """
    Apply a partial update to a mandate the current user may modify.

    Args:
        mandateId: ID of the mandate to update. It is authoritative: an "id"
            key inside ``updateData`` is ignored, so a payload cannot retarget
            or rewrite the record's identity.
        updateData: Field values merged over the mandate's current state
            (None is treated as "no changes").

    Returns:
        The mandate as re-read from the database after the update.

    Raises:
        ValueError: On missing permission, unknown mandate, or any failure.
            All errors are wrapped into ValueError; the cause is chained.
    """
    try:
        if not self.checkRbacPermission(Mandate, "update", mandateId):
            raise PermissionError(f"No permission to update mandate {mandateId}")
        mandate = self.getMandate(mandateId)
        if not mandate:
            raise ValueError(f"Mandate {mandateId} not found")
        # Merge onto current state. Drop "id" from the payload, mirroring
        # updateAccessRule, where the passed-in id is authoritative.
        mergedData = mandate.model_dump()
        mergedData.update(
            {key: value for key, value in (updateData or {}).items() if key != "id"}
        )
        self.db.recordModify(Mandate, mandateId, Mandate(**mergedData))
        refreshed = self.getMandate(mandateId)
        if not refreshed:
            raise ValueError("Failed to retrieve updated mandate")
        return refreshed
    except Exception as e:
        logger.error(f"Error updating mandate: {str(e)}")
        raise ValueError(f"Failed to update mandate: {str(e)}") from e
def deleteMandate(self, mandateId: str) -> bool:
    """
    Delete a mandate the current user can see and may delete.

    Args:
        mandateId: ID of the mandate.

    Returns:
        False if the mandate is not visible/found. Otherwise the result of
        the database delete.

    Raises:
        ValueError: On missing "delete" permission, when users are still
            attached to the mandate, or on any database error (all wrapped).
    """
    try:
        if not self.getMandate(mandateId):
            return False
        if not self.checkRbacPermission(Mandate, "delete", mandateId):
            raise PermissionError(f"No permission to delete mandate {mandateId}")
        # Refuse to remove a mandate that still has users.
        if self.getUsersByMandate(mandateId):
            raise ValueError(
                f"Cannot delete mandate {mandateId} with existing users"
            )
        return self.db.recordDelete(Mandate, mandateId)
    except Exception as e:
        logger.error(f"Error deleting mandate: {str(e)}")
        raise ValueError(f"Failed to delete mandate: {str(e)}")
# ============================================
# User-Mandate Membership Methods (Multi-Tenant)
# ============================================
def getUserMandate(self, userId: str, mandateId: str) -> Optional[UserMandate]:
    """
    Fetch the membership record linking a user to one mandate.

    Args:
        userId: User ID.
        mandateId: Mandate ID.

    Returns:
        The first matching UserMandate (internal "_" keys removed), or None
        when there is no match or the lookup fails (failures are logged).
    """
    try:
        rows = self.db.getRecordset(
            UserMandate,
            recordFilter={"userId": userId, "mandateId": mandateId},
        )
        if rows:
            first = rows[0]
            return UserMandate(**{k: v for k, v in first.items() if not k.startswith("_")})
        return None
    except Exception as e:
        logger.error(f"Error getting UserMandate: {e}")
        return None
def getUserMandates(self, userId: str) -> List[UserMandate]:
    """
    List the enabled mandate memberships of a user.

    Args:
        userId: User ID.

    Returns:
        UserMandate objects for records with ``enabled=True``. Returns an
        empty list on error (logged).
    """
    try:
        rows = self.db.getRecordset(
            UserMandate,
            recordFilter={"userId": userId, "enabled": True},
        )
        return [
            UserMandate(**{k: v for k, v in row.items() if not k.startswith("_")})
            for row in rows
        ]
    except Exception as e:
        logger.error(f"Error getting UserMandates: {e}")
        return []
def createUserMandate(self, userId: str, mandateId: str, roleIds: Optional[List[str]] = None) -> UserMandate:
    """
    Add a user to a mandate, optionally assigning roles.

    Roles are stored as one ``UserMandateRole`` junction row per role ID.

    Args:
        userId: User ID.
        mandateId: Mandate ID.
        roleIds: Optional role IDs to attach to the new membership.

    Returns:
        The created UserMandate (internal "_" keys removed).

    Raises:
        ValueError: If the membership already exists or creation fails.
    """
    try:
        if self.getUserMandate(userId, mandateId):
            raise ValueError(f"User {userId} is already member of mandate {mandateId}")
        membership = UserMandate(userId=userId, mandateId=mandateId, enabled=True)
        created = self.db.recordCreate(UserMandate, membership.model_dump())
        if roleIds and created:
            membershipId = created.get("id")
            for roleId in roleIds:
                link = UserMandateRole(userMandateId=membershipId, roleId=roleId)
                self.db.recordCreate(UserMandateRole, link.model_dump())
        return UserMandate(**{k: v for k, v in created.items() if not k.startswith("_")})
    except Exception as e:
        logger.error(f"Error creating UserMandate: {e}")
        raise ValueError(f"Failed to create UserMandate: {e}")
def deleteUserMandate(self, userId: str, mandateId: str) -> bool:
    """
    Remove a user from a mandate by deleting the membership record.

    Only the UserMandate row is deleted here. Per the original design notes,
    the related UserMandateRole rows are expected to be removed by a
    database CASCADE (NOTE(review): not verified in this file).

    Args:
        userId: User ID.
        mandateId: Mandate ID.

    Returns:
        False if no membership exists. Otherwise the database delete result.

    Raises:
        ValueError: If the deletion fails.
    """
    try:
        membership = self.getUserMandate(userId, mandateId)
        if not membership:
            return False
        return self.db.recordDelete(UserMandate, membership.id)
    except Exception as e:
        logger.error(f"Error deleting UserMandate: {e}")
        raise ValueError(f"Failed to delete UserMandate: {e}")
def getRoleIdsForUserMandate(self, userMandateId: str) -> List[str]:
    """
    Collect the role IDs attached to a mandate membership.

    Args:
        userMandateId: UserMandate ID.

    Returns:
        Non-empty ``roleId`` values from the junction table. Returns an empty
        list on error (logged).
    """
    try:
        links = self.db.getRecordset(
            UserMandateRole,
            recordFilter={"userMandateId": userMandateId},
        )
        roleIds: List[str] = []
        for link in links:
            roleId = link.get("roleId")
            if roleId:
                roleIds.append(roleId)
        return roleIds
    except Exception as e:
        logger.error(f"Error getting role IDs for UserMandate: {e}")
        return []
def addRoleToUserMandate(self, userMandateId: str, roleId: str) -> UserMandateRole:
    """
    Attach a role to a mandate membership (idempotent).

    If the junction row already exists, it is returned instead of creating
    a duplicate.

    Args:
        userMandateId: UserMandate ID.
        roleId: Role ID to add.

    Returns:
        The existing or newly created UserMandateRole.

    Raises:
        ValueError: If the lookup or creation fails.
    """
    try:
        matches = self.db.getRecordset(
            UserMandateRole,
            recordFilter={"userMandateId": userMandateId, "roleId": roleId},
        )
        if matches:
            row = matches[0]
        else:
            newLink = UserMandateRole(userMandateId=userMandateId, roleId=roleId)
            row = self.db.recordCreate(UserMandateRole, newLink.model_dump())
        return UserMandateRole(**{k: v for k, v in row.items() if not k.startswith("_")})
    except Exception as e:
        logger.error(f"Error adding role to UserMandate: {e}")
        raise ValueError(f"Failed to add role: {e}")
def removeRoleFromUserMandate(self, userMandateId: str, roleId: str) -> bool:
    """
    Detach a role from a mandate membership.

    When the last role is removed, the membership itself is deleted as well
    (application-level cleanup).

    Args:
        userMandateId: UserMandate ID.
        roleId: Role ID to remove.

    Returns:
        False if the role was not attached, True after removal.

    Raises:
        ValueError: If a database operation fails.
    """
    try:
        matches = self.db.getRecordset(
            UserMandateRole,
            recordFilter={"userMandateId": userMandateId, "roleId": roleId},
        )
        if not matches:
            return False
        self.db.recordDelete(UserMandateRole, matches[0].get("id"))

        leftover = self.db.getRecordset(
            UserMandateRole,
            recordFilter={"userMandateId": userMandateId},
        )
        if not leftover:
            # A membership without any role is removed as well.
            self.db.recordDelete(UserMandate, userMandateId)
            logger.info(f"Deleted empty UserMandate {userMandateId}")
        return True
    except Exception as e:
        logger.error(f"Error removing role from UserMandate: {e}")
        raise ValueError(f"Failed to remove role: {e}")
# ============================================
# Feature Access Methods (Multi-Tenant)
# ============================================
def getFeatureAccess(self, userId: str, featureInstanceId: str) -> Optional[FeatureAccess]:
    """
    Fetch a user's access record for one feature instance.

    Args:
        userId: User ID.
        featureInstanceId: FeatureInstance ID.

    Returns:
        The first matching FeatureAccess (internal "_" keys removed), or None
        when there is no match or the lookup fails (logged).
    """
    try:
        rows = self.db.getRecordset(
            FeatureAccess,
            recordFilter={"userId": userId, "featureInstanceId": featureInstanceId},
        )
        if rows:
            first = rows[0]
            return FeatureAccess(**{k: v for k, v in first.items() if not k.startswith("_")})
        return None
    except Exception as e:
        logger.error(f"Error getting FeatureAccess: {e}")
        return None
def getFeatureAccessesForUser(self, userId: str) -> List[FeatureAccess]:
    """
    List the enabled feature-instance accesses of a user.

    Args:
        userId: User ID.

    Returns:
        FeatureAccess objects for records with ``enabled=True``. Returns an
        empty list on error (logged).
    """
    try:
        rows = self.db.getRecordset(
            FeatureAccess,
            recordFilter={"userId": userId, "enabled": True},
        )
        return [
            FeatureAccess(**{k: v for k, v in row.items() if not k.startswith("_")})
            for row in rows
        ]
    except Exception as e:
        logger.error(f"Error getting FeatureAccesses: {e}")
        return []
def createFeatureAccess(self, userId: str, featureInstanceId: str, roleIds: Optional[List[str]] = None) -> FeatureAccess:
    """
    Grant a user access to a feature instance, optionally with roles.

    Roles are stored as one ``FeatureAccessRole`` junction row per role ID.

    Args:
        userId: User ID.
        featureInstanceId: FeatureInstance ID.
        roleIds: Optional role IDs to attach to the new access record.

    Returns:
        The created FeatureAccess (internal "_" keys removed).

    Raises:
        ValueError: If access already exists or creation fails.
    """
    try:
        if self.getFeatureAccess(userId, featureInstanceId):
            raise ValueError(f"User {userId} already has access to feature instance {featureInstanceId}")
        grant = FeatureAccess(userId=userId, featureInstanceId=featureInstanceId, enabled=True)
        created = self.db.recordCreate(FeatureAccess, grant.model_dump())
        if roleIds and created:
            grantId = created.get("id")
            for roleId in roleIds:
                link = FeatureAccessRole(featureAccessId=grantId, roleId=roleId)
                self.db.recordCreate(FeatureAccessRole, link.model_dump())
        return FeatureAccess(**{k: v for k, v in created.items() if not k.startswith("_")})
    except Exception as e:
        logger.error(f"Error creating FeatureAccess: {e}")
        raise ValueError(f"Failed to create FeatureAccess: {e}")
def getRoleIdsForFeatureAccess(self, featureAccessId: str) -> List[str]:
    """
    Collect the role IDs attached to a feature-access record.

    Args:
        featureAccessId: FeatureAccess ID.

    Returns:
        Non-empty ``roleId`` values from the junction table. Returns an empty
        list on error (logged).
    """
    try:
        links = self.db.getRecordset(
            FeatureAccessRole,
            recordFilter={"featureAccessId": featureAccessId},
        )
        roleIds: List[str] = []
        for link in links:
            roleId = link.get("roleId")
            if roleId:
                roleIds.append(roleId)
        return roleIds
    except Exception as e:
        logger.error(f"Error getting role IDs for FeatureAccess: {e}")
        return []
# Token methods
def saveAccessToken(self, token: Token, replace_existing: bool = True) -> None:
    """
    Persist an access token for the current user.

    Access tokens must not carry a ``connectionId``; connection tokens go
    through ``saveConnectionToken``. The token is bound to the current user
    and gets an ``id``/``createdAt`` if missing, then it is stored. Only
    after a successful insert, and when ``replace_existing`` is True, are
    the user's other access tokens for the same authority deleted. This
    matches ``saveConnectionToken``: if the insert fails, the user still
    has their previous tokens.

    Args:
        token: Token to store. Mutated in place (userId, id, createdAt).
        replace_existing: Delete older access tokens of this user/authority
            after saving. Cleanup failures are logged, not raised.

    Raises:
        ValueError: If the token has a connectionId or no user context exists.
        Exception: Any error while storing the token is logged and re-raised.
    """
    try:
        if token.connectionId:
            raise ValueError(
                "Access tokens cannot have connectionId - use saveConnectionToken instead"
            )
        if not self.currentUser or not self.currentUser.id:
            raise ValueError("No valid user context available for token storage")

        # Bind the token to the current user.
        token.userId = self.currentUser.id
        if not token.id:
            token.id = str(uuid.uuid4())
        if not token.createdAt:
            token.createdAt = getUtcTimestamp()

        token_dict = token.model_dump()
        token_dict["userId"] = self.currentUser.id
        self.db.recordCreate(Token, token_dict)

        if replace_existing:
            try:
                old_tokens = self.db.getRecordset(
                    Token,
                    recordFilter={
                        "userId": self.currentUser.id,
                        "authority": token.authority,
                        "connectionId": None,  # only access tokens, never connection tokens
                    },
                )
                deleted_count = 0
                for old_token in old_tokens:
                    # Keep the token we just stored.
                    if old_token["id"] != token.id:
                        self.db.recordDelete(Token, old_token["id"])
                        deleted_count += 1
                if deleted_count > 0:
                    logger.info(
                        f"Replaced {deleted_count} old access tokens for user {self.currentUser.id} and authority {token.authority}"
                    )
            except Exception as e:
                # Best-effort cleanup: the new token is already saved.
                logger.warning(
                    f"Failed to delete old access tokens for user {self.currentUser.id} and authority {token.authority}: {str(e)}"
                )
    except Exception as e:
        logger.error(f"Error saving access token: {str(e)}")
        raise
def saveConnectionToken(self, token: Token, replace_existing: bool = True) -> None:
    """
    Persist a token that belongs to an external connection.

    The token must carry a ``connectionId``. It is bound to the current
    user, completed with ``id``/``createdAt`` if missing, and stored. When
    ``replace_existing`` is True, every other token with the same
    ``connectionId`` is then deleted (best effort; failures are only logged).

    Args:
        token: Token to store. Mutated in place (userId, id, createdAt).
        replace_existing: Remove older tokens of the same connection after saving.

    Raises:
        ValueError: If ``connectionId`` is missing or no user context exists.
        Exception: Any error while storing the token is logged and re-raised.
    """
    try:
        if not token.connectionId:
            raise ValueError(
                "Connection tokens must have connectionId - use saveAccessToken instead"
            )
        owner = self.currentUser
        if not owner or not owner.id:
            raise ValueError("No valid user context available for token storage")

        token.userId = owner.id
        if not token.id:
            token.id = str(uuid.uuid4())
        if not token.createdAt:
            token.createdAt = getUtcTimestamp()

        record = token.model_dump()
        record["userId"] = owner.id
        self.db.recordCreate(Token, record)

        if not replace_existing:
            return
        try:
            siblings = self.db.getRecordset(
                Token, recordFilter={"connectionId": token.connectionId}
            )
            removed = 0
            for sibling in siblings:
                if sibling["id"] == token.id:
                    continue  # keep the token we just saved
                self.db.recordDelete(Token, sibling["id"])
                removed += 1
            if removed > 0:
                logger.info(
                    f"Replaced {removed} old tokens for connectionId {token.connectionId}"
                )
        except Exception as e:
            # The new token is stored; stale ones can be cleaned up later.
            logger.warning(
                f"Failed to delete old tokens for connectionId {token.connectionId}: {str(e)}"
            )
    except Exception as e:
        logger.error(f"Error saving connection token: {str(e)}")
        raise
def getConnectionToken(self, connectionId: str) -> Optional[Token]:
    """
    Return the stored token of a connection with the latest expiry.

    No refresh is attempted here. Callers that need a fresh token should use
    a higher-level service.

    Args:
        connectionId: Connection whose token is requested (required).

    Returns:
        The Token whose ``expiresAt`` is latest. A missing expiry ranks as 0;
        on ties, the first such row wins. Returns None if no token exists or
        an error occurs (logged).
    """
    try:
        if not connectionId:
            raise ValueError("connectionId is required for getConnectionToken")
        candidates = self.db.getRecordset(
            Token, recordFilter={"connectionId": connectionId}
        )
        if not candidates:
            logger.warning(
                f"No connection token found for connectionId: {connectionId}"
            )
            return None
        # max() picks the first row with the highest key, the same row a
        # stable descending sort would put first.
        newest = max(
            candidates,
            key=lambda row: parseTimestamp(row.get("expiresAt"), default=0),
        )
        return Token(**newest)
    except Exception as e:
        logger.error(
            f"Error getting connection token for connectionId {connectionId}: {str(e)}"
        )
        return None
def findActiveTokenById(
    self,
    tokenId: str,
    userId: str,
    authority: AuthAuthority,
    sessionId: str = None,
    mandateId: str = None,
) -> Optional[Token]:
    """
    Look up an ACTIVE token by its id (jti) for a given user and authority.

    Args:
        tokenId: Token id.
        userId: Owner of the token.
        authority: Authority; its ``.value`` is used when present, else ``str()``.
        sessionId: If not None, also require this session.
        mandateId: If not None, also require this mandate.

    Returns:
        The first matching Token, or None if nothing matches or an error
        occurs (logged).
    """
    try:
        authorityValue = authority.value if hasattr(authority, "value") else str(authority)
        recordFilter = {
            "id": tokenId,
            "userId": userId,
            "authority": authorityValue,
            "status": TokenStatus.ACTIVE,
        }
        optionalScopes = {"sessionId": sessionId, "mandateId": mandateId}
        recordFilter.update(
            {key: value for key, value in optionalScopes.items() if value is not None}
        )
        matches = self.db.getRecordset(Token, recordFilter=recordFilter)
        return Token(**matches[0]) if matches else None
    except Exception as e:
        logger.error(f"Error finding active token by id {tokenId}: {str(e)}")
        return None
def revokeTokenById(self, tokenId: str, revokedBy: str, reason: str = None) -> bool:
    """
    Mark one token as revoked; the record is kept, not deleted.

    Args:
        tokenId: Token id.
        revokedBy: Who revoked the token (stored on the record).
        reason: Optional reason; defaults to "revoked".

    Returns:
        True if the token is (now or already) revoked. False if it does not
        exist or an error occurred (logged).
    """
    try:
        rows = self.db.getRecordset(Token, recordFilter={"id": tokenId})
        if not rows:
            return False
        # Already revoked tokens are left untouched (idempotent).
        if rows[0].get("status") != TokenStatus.REVOKED:
            self.db.recordModify(
                Token,
                tokenId,
                {
                    "status": TokenStatus.REVOKED,
                    "revokedAt": getUtcTimestamp(),
                    "revokedBy": revokedBy,
                    "reason": reason or "revoked",
                },
            )
        return True
    except Exception as e:
        logger.error(f"Error revoking token {tokenId}: {str(e)}")
        return False
def revokeTokensBySessionId(
    self,
    sessionId: str,
    userId: str,
    authority: AuthAuthority,
    revokedBy: str,
    reason: str = None,
) -> int:
    """
    Revoke every ACTIVE token of one session for a user/authority.

    Args:
        sessionId: Session whose tokens are revoked.
        userId: Owner of the tokens.
        authority: Authority; its ``.value`` is used when present, else ``str()``.
        revokedBy: Who revoked the tokens.
        reason: Optional reason; defaults to "session logout".

    Returns:
        Number of tokens revoked. Returns 0 on error (logged).
    """
    try:
        authorityValue = authority.value if hasattr(authority, "value") else str(authority)
        activeTokens = self.db.getRecordset(
            Token,
            recordFilter={
                "userId": userId,
                "authority": authorityValue,
                "sessionId": sessionId,
                "status": TokenStatus.ACTIVE,
            },
        )
        revoked = 0
        for entry in activeTokens:
            update = {
                "status": TokenStatus.REVOKED,
                "revokedAt": getUtcTimestamp(),
                "revokedBy": revokedBy,
                "reason": reason or "session logout",
            }
            self.db.recordModify(Token, entry["id"], update)
            revoked += 1
        return revoked
    except Exception as e:
        logger.error(f"Error revoking tokens for session {sessionId}: {str(e)}")
        return 0
def revokeTokensByUser(
    self,
    userId: str,
    authority: AuthAuthority = None,
    mandateId: str = None,
    revokedBy: str = None,
    reason: str = None,
) -> int:
    """
    Revoke all ACTIVE tokens of a user, optionally scoped by authority/mandate.

    Args:
        userId: Owner of the tokens.
        authority: If given, only tokens of this authority are revoked
            (``.value`` is used when present, else ``str()``).
        mandateId: If given, only tokens with this ``mandateId`` are revoked.
            Before this fix the argument was accepted but ignored, which
            revoked the user's tokens in every mandate.
        revokedBy: Who revoked the tokens (stored on each record).
        reason: Optional reason; defaults to "admin revoke".

    Returns:
        Number of tokens revoked. Returns 0 on error (logged).
    """
    try:
        recordFilter = {
            "userId": userId,
            "status": TokenStatus.ACTIVE,
        }
        if authority is not None:
            recordFilter["authority"] = (
                authority.value if hasattr(authority, "value") else str(authority)
            )
        # Same tenant-scoping field that findActiveTokenById filters on.
        if mandateId is not None:
            recordFilter["mandateId"] = mandateId
        tokens = self.db.getRecordset(Token, recordFilter=recordFilter)
        count = 0
        for t in tokens:
            self.db.recordModify(
                Token,
                t["id"],
                {
                    "status": TokenStatus.REVOKED,
                    "revokedAt": getUtcTimestamp(),
                    "revokedBy": revokedBy,
                    "reason": reason or "admin revoke",
                },
            )
            count += 1
        return count
    except Exception as e:
        logger.error(f"Error revoking tokens for user {userId}: {str(e)}")
        return 0
def cleanupExpiredTokens(self) -> int:
    """
    Delete every token whose ``expiresAt`` lies before the current time.

    Tokens with no (or falsy) parsed expiry are kept.
    NOTE(review): this loads the whole Token table into memory.

    Returns:
        Number of deleted tokens. Returns 0 on error (logged).
    """
    try:
        now = getUtcTimestamp()
        removed = 0
        for row in self.db.getRecordset(Token, recordFilter={}):
            expiry = parseTimestamp(row.get("expiresAt"))
            if not (expiry and expiry < now):
                continue
            self.db.recordDelete(Token, row["id"])
            removed += 1
        if removed > 0:
            logger.info(f"Cleaned up {removed} expired tokens")
        return removed
    except Exception as e:
        logger.error(f"Error cleaning up expired tokens: {str(e)}")
        return 0
def logout(self) -> None:
    """
    Drop the per-user state held by this interface instance.

    ``currentUser``, ``userId``, ``mandateId`` and ``rbac`` are set to None.
    If a ``db`` attribute exists, its context is reset with an empty string.

    Raises:
        Exception: Any error is logged and re-raised.
    """
    try:
        for attributeName in ("currentUser", "userId", "mandateId", "rbac"):
            setattr(self, attributeName, None)
        if hasattr(self, "db"):
            self.db.updateContext("")
        logger.info("User logged out successfully")
    except Exception as e:
        logger.error(f"Error during logout: {str(e)}")
        raise
# RBAC CRUD Methods
def createAccessRule(self, accessRule: AccessRule) -> AccessRule:
    """
    Store a new access rule.

    Args:
        accessRule: Rule to create.

    Returns:
        The AccessRule built from the stored record.

    Raises:
        Exception: Any database error is logged and re-raised.
    """
    try:
        stored = self.db.recordCreate(AccessRule, accessRule)
        logger.info(f"Created access rule with ID {stored.get('id')}")
        return AccessRule(**stored)
    except Exception as e:
        logger.error(f"Error creating access rule: {str(e)}")
        raise
def getAccessRule(self, ruleId: str) -> Optional[AccessRule]:
    """
    Fetch one access rule by ID (no RBAC filtering).

    Args:
        ruleId: Access rule ID.

    Returns:
        The AccessRule, or None if absent or on error (logged).
    """
    try:
        found = self.db.getRecordset(AccessRule, recordFilter={"id": ruleId})
        return AccessRule(**found[0]) if found else None
    except Exception as e:
        logger.error(f"Error getting access rule {ruleId}: {str(e)}")
        return None
def updateAccessRule(self, ruleId: str, accessRule: AccessRule) -> AccessRule:
    """
    Overwrite an existing access rule.

    Args:
        ruleId: ID of the rule to update. It is authoritative, so the rule's
            own ``id`` field is excluded from the stored payload.
        accessRule: New rule contents.

    Returns:
        The AccessRule built from the record returned by the database.

    Raises:
        Exception: Any database error is logged and re-raised.
    """
    try:
        payload = accessRule.model_dump(exclude={"id"})
        stored = self.db.recordModify(AccessRule, ruleId, payload)
        logger.info(f"Updated access rule with ID {ruleId}")
        return AccessRule(**stored)
    except Exception as e:
        logger.error(f"Error updating access rule {ruleId}: {str(e)}")
        raise
def deleteAccessRule(self, ruleId: str) -> bool:
    """
    Delete an access rule.

    The result of ``recordDelete`` is honored, as in deleteUser and
    deleteMandate. A delete the database reports as unsuccessful is no
    longer logged as "Deleted" and reported as True.

    Args:
        ruleId: Access rule ID.

    Returns:
        True if the database reported the deletion as successful. False if
        it did not, or on error (logged).
    """
    try:
        deleted = self.db.recordDelete(AccessRule, ruleId)
        if not deleted:
            logger.warning(f"Access rule {ruleId} could not be deleted")
            return False
        logger.info(f"Deleted access rule with ID {ruleId}")
        return True
    except Exception as e:
        logger.error(f"Error deleting access rule {ruleId}: {str(e)}")
        return False
def getAccessRules(
    self,
    roleLabel: Optional[str] = None,
    roleId: Optional[str] = None,
    context: Optional[AccessRuleContext] = None,
    item: Optional[str] = None,
    pagination: Optional[PaginationParams] = None
) -> Union[List[AccessRule], PaginatedResult]:
    """
    List access rules visible to the current user, with optional filters.

    Args:
        roleLabel: Legacy role-label filter; used only when ``roleId`` is empty.
        roleId: Role ID filter (takes precedence over ``roleLabel``).
        context: Context filter (its ``.value`` is matched).
        item: Item filter.
        pagination: Optional filtering, sorting and paging parameters. When
            None, all matching rules are returned.

    Returns:
        A list of AccessRule when ``pagination`` is None. Otherwise a
        PaginatedResult. On error (logged), an empty list or empty result of
        the matching shape.
    """
    try:
        criteria: Dict[str, Any] = {}
        if roleId:
            criteria["roleId"] = roleId
        elif roleLabel:
            criteria["roleLabel"] = roleLabel
        if context:
            criteria["context"] = context.value
        if item:
            criteria["item"] = item

        rows = [
            {key: value for key, value in rule.items() if not key.startswith("_")}
            for rule in getRecordsetWithRBAC(
                self.db,
                AccessRule,
                self.currentUser,
                recordFilter=criteria or None,  # no filter dict when empty
            )
        ]
        if pagination is None:
            return [AccessRule(**row) for row in rows]

        if pagination.filters:
            rows = self._applyFilters(rows, pagination.filters)
        if pagination.sort:
            rows = self._applySorting(rows, pagination.sort)

        totalItems = len(rows)
        size = pagination.pageSize
        totalPages = math.ceil(totalItems / size) if totalItems > 0 else 0
        offset = (pagination.page - 1) * size  # pages are 1-based
        return PaginatedResult(
            items=[AccessRule(**row) for row in rows[offset:offset + size]],
            totalItems=totalItems,
            totalPages=totalPages,
        )
    except Exception as e:
        logger.error(f"Error getting access rules: {str(e)}")
        if pagination is None:
            return []
        return PaginatedResult(items=[], totalItems=0, totalPages=0)
def getAccessRulesForRoles(
    self,
    roleLabels: List[str],
    context: AccessRuleContext,
    item: str
) -> List[AccessRule]:
    """
    Pick, for each role, its most specific rule matching ``item`` in ``context``.

    Rule lookup and specificity are delegated to ``RbacClass``. Roles
    without a matching rule contribute nothing.

    Args:
        roleLabels: Role labels to evaluate.
        context: Context type.
        item: Item identifier.

    Returns:
        One AccessRule per role that has a match, in ``roleLabels`` order.
        Returns an empty list on error (logged).
    """
    try:
        # This interface's own DB serves as the DbApp database for RBAC.
        rbac = RbacClass(self.db, dbApp=self.db)
        selected: List[AccessRule] = []
        for label in roleLabels:
            candidates = rbac._getRulesForRole(label, context)
            best = rbac.findMostSpecificRule(candidates, item)
            if best:
                selected.append(best)
        return selected
    except Exception as e:
        logger.error(f"Error getting access rules for roles: {str(e)}")
        return []
def createRole(self, role: Role) -> Role:
    """
    Store a new role, rejecting duplicate role labels.

    Args:
        role: Role to create.

    Returns:
        The Role built from the stored record.

    Raises:
        ValueError: If a role with the same ``roleLabel`` already exists.
        Exception: Any database error is logged and re-raised.
    """
    try:
        if self.db.getRecordset(Role, recordFilter={"roleLabel": role.roleLabel}):
            raise ValueError(f"Role with label '{role.roleLabel}' already exists")
        stored = self.db.recordCreate(Role, role)
        logger.info(f"Created role with ID {stored.get('id')} and label {role.roleLabel}")
        return Role(**stored)
    except Exception as e:
        logger.error(f"Error creating role: {str(e)}")
        raise
def getRole(self, roleId: str) -> Optional[Role]:
    """
    Fetch one role by ID.

    Args:
        roleId: Role ID.

    Returns:
        The Role, or None if absent or on error (logged).
    """
    try:
        found = self.db.getRecordset(Role, recordFilter={"id": roleId})
        return Role(**found[0]) if found else None
    except Exception as e:
        logger.error(f"Error getting role {roleId}: {str(e)}")
        return None
def getRoleByLabel(self, roleLabel: str) -> Optional[Role]:
    """
    Fetch one role by its label.

    Args:
        roleLabel: Role label.

    Returns:
        The first matching Role, or None if absent or on error (logged).
    """
    try:
        found = self.db.getRecordset(Role, recordFilter={"roleLabel": roleLabel})
        return Role(**found[0]) if found else None
    except Exception as e:
        logger.error(f"Error getting role by label {roleLabel}: {str(e)}")
        return None
def getAllRoles(self, pagination: Optional[PaginationParams] = None) -> Union[List[Role], PaginatedResult]:
    """
    List roles, optionally filtered, sorted and paginated in memory.
    Args:
        pagination: Optional pagination parameters. When None, every role is
            returned and no filtering or sorting is applied.
    Returns:
        List[Role] when pagination is None; otherwise a PaginatedResult holding
        the requested page plus totalItems/totalPages computed after filtering.
        On any error the error is logged and an empty list / empty
        PaginatedResult is returned instead of raising.
    """
    try:
        # Drop internal/database-specific columns (keys starting with "_")
        cleaned = [
            {key: value for key, value in record.items() if not key.startswith("_")}
            for record in self.db.getRecordset(Role)
        ]
        if pagination is None:
            return [Role(**entry) for entry in cleaned]
        if pagination.filters:
            cleaned = self._applyFilters(cleaned, pagination.filters)
        if pagination.sort:
            cleaned = self._applySorting(cleaned, pagination.sort)
        pageSize = pagination.pageSize
        totalItems = len(cleaned)
        totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0
        # Pages are 1-based
        offset = (pagination.page - 1) * pageSize
        window = cleaned[offset:offset + pageSize]
        return PaginatedResult(
            items=[Role(**entry) for entry in window],
            totalItems=totalItems,
            totalPages=totalPages,
        )
    except Exception as e:
        logger.error(f"Error getting all roles: {str(e)}")
        return [] if pagination is None else PaginatedResult(items=[], totalItems=0, totalPages=0)
def countRoleAssignments(self) -> Dict[str, int]:
    """
    Count how many UserMandateRole assignments reference each role.
    Returns:
        Dictionary mapping roleId (as a string) to its number of assignments.
        Assignments whose roleId is missing, None or empty are ignored.
        On any error the error is logged and an empty dictionary is returned.
    """
    try:
        assignments = self.db.getRecordset(UserMandateRole)
        roleCounts: Dict[str, int] = {}
        for assignment in assignments:
            rawRoleId = assignment.get("roleId")
            # Check the raw value before converting: str(None) is the truthy
            # string "None" and would count NULL references under a bogus key.
            if rawRoleId is None or rawRoleId == "":
                continue
            roleId = str(rawRoleId)
            roleCounts[roleId] = roleCounts.get(roleId, 0) + 1
        return roleCounts
    except Exception as e:
        logger.error(f"Error counting role assignments: {str(e)}")
        return {}
def updateRole(self, roleId: str, role: Role) -> Role:
    """
    Overwrite an existing role with new field values.
    Args:
        roleId: ID of the role to update (authoritative; any id in `role` is ignored)
        role: Role object carrying the new values
    Returns:
        The updated Role, rebuilt from the record returned by the database
    Raises:
        ValueError: if no role with roleId exists, or if the new label is
            already used by a different role
        Exception: any database error is logged and re-raised unchanged
    """
    try:
        current = self.getRole(roleId)
        if not current:
            raise ValueError(f"Role with ID {roleId} not found")
        newLabel = role.roleLabel
        labelChanged = newLabel != current.roleLabel
        # Only a renamed role can collide with another role's label
        if labelChanged:
            clash = self.getRoleByLabel(newLabel)
            if clash and clash.id != roleId:
                raise ValueError(f"Role with label '{newLabel}' already exists")
        # The roleId from the URL is authoritative, so the payload's id is never persisted
        payload = role.model_dump(exclude={"id"})
        stored = self.db.recordModify(Role, roleId, payload)
        logger.info(f"Updated role with ID {roleId}")
        return Role(**stored)
    except Exception as e:
        logger.error(f"Error updating role {roleId}: {str(e)}")
        raise
def deleteRole(self, roleId: str) -> bool:
    """
    Delete a role that is neither a system role nor still referenced.
    Args:
        roleId: Role ID
    Returns:
        True once the role has been deleted; False if no role with roleId exists
    Raises:
        ValueError: if the role is a system role, is still assigned to users
            via UserMandateRole, or is referenced by access rules
        Exception: any database error is logged and re-raised unchanged
    """
    try:
        target = self.getRole(roleId)
        if not target:
            return False
        label = target.roleLabel
        if target.isSystemRole:
            raise ValueError(f"Cannot delete system role '{label}'")
        # Refuse to leave dangling user assignments or access rules behind
        if self.db.getRecordset(UserMandateRole, recordFilter={"roleId": roleId}):
            raise ValueError(f"Cannot delete role '{label}' - it is assigned to users")
        if self.getAccessRules(roleId=roleId):
            raise ValueError(f"Cannot delete role '{label}' - it is used in access rules")
        self.db.recordDelete(Role, roleId)
        logger.info(f"Deleted role with ID {roleId}")
        return True
    except Exception as e:
        logger.error(f"Error deleting role {roleId}: {str(e)}")
        raise
# Public Methods
def getInterface(currentUser: User, mandateId: Optional[str] = None) -> AppObjects:
    """
    Return the cached AppObjects instance for a (mandate, user) pair, creating it on first use.
    Multi-tenant design:
    - mandateId is passed explicitly (from the request context / X-Mandate-Id header)
    Args:
        currentUser: User object
        mandateId: Explicit mandate context (from request header). Intended to be
            required for non-sysadmin users, but not validated here; None is
            accepted and cached under its own key.
    Returns:
        AppObjects instance for the user context
    Raises:
        ValueError: if currentUser is falsy
    """
    if not currentUser:
        raise ValueError("Invalid user context: user is required")
    # One instance per mandate/user combination; a None mandate is keyed as "None_<userId>"
    cacheKey = f"{mandateId}_{currentUser.id}"
    if cacheKey in _gatewayInterfaces:
        # NOTE(review): the cached instance keeps the User object from the first call;
        # later calls with an updated User for the same key do not refresh its context.
        return _gatewayInterfaces[cacheKey]
    interface = AppObjects(currentUser)
    interface.setUserContext(currentUser, mandateId=mandateId)
    _gatewayInterfaces[cacheKey] = interface
    return interface
def getRootInterface() -> AppObjects:
    """
    Return the module-level AppObjects instance bound to the root user, creating it on first call.
    This is used for initial setup and user creation.
    Note: This function uses security.rootAccess internally to avoid circular dependencies.
    Routes can continue using this function, but connectors/interfaces should use
    security.rootAccess.getRootDbAppConnector() or security.rootAccess.getRootUser() directly.
    Returns:
        The shared root AppObjects instance
    Raises:
        ValueError: if the root user cannot be obtained or the interface cannot be
            created; the underlying exception is chained as __cause__
    """
    global _rootAppObjects
    if _rootAppObjects is None:
        try:
            # Imported here rather than at module level to avoid circular dependencies
            from modules.security.rootAccess import getRootUser
            rootUser = getRootUser()
            _rootAppObjects = AppObjects(rootUser)
        except Exception as e:
            # logger.exception keeps the traceback that logger.error would drop
            logger.exception(f"Error getting root user: {str(e)}")
            # Chain explicitly so callers and logs see the real root cause
            raise ValueError(f"Failed to get root user: {str(e)}") from e
    return _rootAppObjects