3685 lines
143 KiB
Python
3685 lines
143 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""
|
|
Interface to the Gateway system.
Manages users and mandates for authentication.

Multi-Tenant Design:
- Users no longer belong directly to a mandate.
- mandateId is passed in from the request context (X-Mandate-Id header).
|
|
"""
|
|
|
|
import logging
|
|
import math
|
|
from typing import Dict, Any, List, Optional, Union
|
|
from passlib.context import CryptContext
|
|
import uuid
|
|
|
|
from modules.connectors.connectorDbPostgre import DatabaseConnector, _get_cached_connector
|
|
from modules.shared.configuration import APP_CONFIG
|
|
from modules.shared.timeUtils import getUtcTimestamp, parseTimestamp
|
|
from modules.interfaces.interfaceBootstrap import initBootstrap
|
|
from modules.interfaces.interfaceRbac import getRecordsetWithRBAC
|
|
from modules.security.rbac import RbacClass
|
|
from modules.datamodels.datamodelUam import (
|
|
User,
|
|
Mandate,
|
|
UserInDB,
|
|
UserConnection,
|
|
AuthAuthority,
|
|
ConnectionStatus,
|
|
)
|
|
from modules.datamodels.datamodelRbac import (
|
|
AccessRule,
|
|
AccessRuleContext,
|
|
Role,
|
|
)
|
|
from modules.datamodels.datamodelUam import AccessLevel
|
|
from modules.datamodels.datamodelSecurity import Token, AuthEvent, TokenStatus, TokenPurpose
|
|
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResult
|
|
from modules.datamodels.datamodelMembership import (
|
|
UserMandate,
|
|
UserMandateRole,
|
|
FeatureAccess,
|
|
FeatureAccessRole,
|
|
)
|
|
from modules.datamodels.datamodelFeatures import Feature, FeatureInstance
|
|
from modules.datamodels.datamodelInvitation import Invitation
|
|
from modules.datamodels.datamodelNotification import UserNotification
|
|
|
|
logger = logging.getLogger(__name__)

# Singleton cache for AppObjects instances per context.
# NOTE(review): not read or written in this part of the file - confirm the key scheme where it is populated.
_gatewayInterfaces = {}

# Root interface instance (assigned outside this view - verify where).
_rootAppObjects = None

# Bootstrap completion flag - ensures bootstrap runs only ONCE per application lifecycle.
# Read, set and (on failure) reset by AppObjects._initRecords.
_bootstrapCompleted = False

# Password hashing context (argon2); used by AppObjects._getPasswordHash / _verifyPassword.
pwdContext = CryptContext(schemes=["argon2"], deprecated="auto")
|
|
|
|
|
|
class AppObjects:
|
|
"""
|
|
Interface to the Gateway system.
|
|
Manages users and mandates.
|
|
"""
|
|
|
|
def __init__(self, currentUser: Optional[User] = None):
    """Initialize the Gateway interface.

    Connects to the application database, runs the one-time bootstrap and, when
    ``currentUser`` is given, applies the user context via ``setUserContext``
    (without a mandate; the mandate context is supplied separately).

    Args:
        currentUser: Authenticated user, or None for an interface without user context.
    """
    self.currentUser = currentUser
    # userId must be set before the connector is created: _initializeDatabase passes it on.
    self.userId = currentUser.id if currentUser else None
    self.mandateId = None  # mandate context comes from setUserContext, never from the User
    self.featureInstanceId = None  # no feature-instance context by default
    # Defaults for attributes that are otherwise only created by setUserContext, so an
    # interface built without a user does not raise AttributeError (checkRbacPermission
    # reads self.rbac and is meant to return False in that case).
    self.rbac = None
    self.userLanguage = None

    self._initializeDatabase()
    self._initRecords()

    if currentUser:
        self.setUserContext(currentUser)
|
|
|
|
def setUserContext(self, currentUser: User, mandateId: Optional[str] = None):
    """
    Set the user context for the interface.

    Multi-tenant design:
    - mandateId is passed explicitly (from the request context / X-Mandate-Id header);
      it is never derived from the User.
    - mandateId is optional here; sysadmin users need no mandateId for system operations.

    Args:
        currentUser: User object; a falsy value leaves the current context unchanged.
        mandateId: Explicit mandate context from the request header, or None.

    Raises:
        ValueError: If ``currentUser`` has no id. The existing context is left untouched.
    """
    if not currentUser:
        logger.info("Initializing interface without user context")
        return

    # Validate before mutating any state, so a rejected user cannot leave the
    # interface half-switched to a new identity.
    if not currentUser.id:
        raise ValueError("Invalid user context: id is required")

    self.currentUser = currentUser
    self.userId = currentUser.id
    self.mandateId = mandateId
    self.userLanguage = currentUser.language  # default user language

    # This interface works on the app database, so self.db serves as both the
    # RBAC database and dbApp.
    self.rbac = RbacClass(self.db, dbApp=self.db)

    self.db.updateContext(self.userId)
|
|
|
|
def __del__(self):
    """Best-effort close of ``self.db`` when the interface is garbage-collected.

    Errors from ``close()`` are logged, not raised.

    NOTE(review): ``self.db`` is obtained from ``_get_cached_connector``; if that
    connector is shared between interfaces, closing it here may affect other
    instances - confirm ``close()`` is safe for cached connectors.
    """
    connector = getattr(self, "db", None)
    if connector is None:
        return
    try:
        connector.close()
    except Exception as e:
        logger.error(f"Error closing database connection: {e}")
|
|
|
|
def _initializeDatabase(self):
    """Obtain the connector for the ``poweron_app`` database and store it on ``self.db``.

    Connection settings come from APP_CONFIG (DB_HOST, DB_USER, DB_PASSWORD_SECRET,
    DB_PORT with default 5432). The connector is requested from
    ``_get_cached_connector`` with the current ``self.userId``. Any failure is logged
    and re-raised.
    """
    try:
        connectionSettings = {
            "dbHost": APP_CONFIG.get("DB_HOST", "_no_config_default_data"),
            "dbDatabase": "poweron_app",
            "dbUser": APP_CONFIG.get("DB_USER"),
            "dbPassword": APP_CONFIG.get("DB_PASSWORD_SECRET"),
            "dbPort": int(APP_CONFIG.get("DB_PORT", 5432)),
        }
        self.db = _get_cached_connector(userId=self.userId, **connectionSettings)
        logger.info(f"Database initialized successfully for user {self.userId}")
    except Exception as e:
        logger.error(f"Failed to initialize database: {str(e)}")
        raise
|
|
|
|
def _separateObjectFields(self, model_class, data: Dict[str, Any]) -> tuple[Dict[str, Any], Dict[str, Any]]:
    """Split ``data`` into fields to persist and fields to filter out, based on the Pydantic model.

    A value goes into the first ("simple") dict when:
    - its model field is annotated as a dict/list container (stored as JSONB),
      including ``Optional[...]`` / ``Union[...]`` / ``X | None`` wrappers around one;
    - it is a scalar (str, int, float, bool, None); or
    - its key is not a model field but starts with "_" (metadata such as
      _createdBy / _createdAt, passed through to the connector).
    Every other entry goes into the second ("object") dict.

    Args:
        model_class: Pydantic v2 model class; its ``model_fields`` are inspected.
        data: Field name -> value mapping to split.

    Returns:
        (simpleFields, objectFields)
    """
    import types
    import typing

    scalarTypes = (str, int, float, bool, type(None))
    # typing.Union covers Optional[...] / Union[...]; types.UnionType covers "X | None" (3.10+).
    unionOrigins = tuple(
        origin for origin in (typing.Union, getattr(types, "UnionType", None)) if origin is not None
    )

    def isJsonContainer(annotation: Any) -> bool:
        # Purely type-based detection - no hardcoded field names.
        if annotation in (dict, list):
            return True
        origin = typing.get_origin(annotation)
        if origin in (dict, list):
            return True
        if origin in unionOrigins:
            return any(isJsonContainer(arg) for arg in typing.get_args(annotation))
        return False

    modelFields = model_class.model_fields
    simpleFields: Dict[str, Any] = {}
    objectFields: Dict[str, Any] = {}

    for fieldName, value in data.items():
        if fieldName in modelFields:
            keep = isJsonContainer(modelFields[fieldName].annotation) or isinstance(value, scalarTypes)
        else:
            # Unknown fields: metadata ("_"-prefixed) always passes through, otherwise scalars only.
            keep = fieldName.startswith("_") or isinstance(value, scalarTypes)

        if keep:
            simpleFields[fieldName] = value
        else:
            objectFields[fieldName] = value

    return simpleFields, objectFields
|
|
|
|
def _initRecords(self):
    """Run the application bootstrap at most once per process.

    The module-level ``_bootstrapCompleted`` flag is set *before* ``initBootstrap``
    runs, so any re-entrant call made while the bootstrap is in progress returns
    immediately. If the bootstrap raises, the flag is cleared again (allowing a later
    retry), the error is logged, and the exception is re-raised.
    """
    global _bootstrapCompleted

    if not _bootstrapCompleted:
        _bootstrapCompleted = True
        logger.info("Starting bootstrap (will only run once)")
        try:
            initBootstrap(self.db)
            logger.info("Bootstrap completed successfully")
        except Exception as e:
            _bootstrapCompleted = False
            logger.error(f"Bootstrap failed: {e}")
            raise
|
|
|
|
|
|
def checkRbacPermission(
    self,
    modelClass: type,
    operation: str,
    recordId: Optional[str] = None
) -> bool:
    """
    Decide whether the current user may perform ``operation`` on ``modelClass``'s table.

    Permissions are looked up in the DATA context under the object key that
    ``buildDataObjectKey`` produces for the model's class name, scoped to
    ``self.mandateId``.

    Args:
        modelClass: Pydantic model class for the table
        operation: 'create', 'update', 'delete' or 'read'; anything else is denied
        recordId: Accepted for API compatibility; not used by this check

    Returns:
        True when the access level for ``operation`` is not NONE; False when there is
        no RBAC instance or current user, or when ``operation`` is unknown.
    """
    if not (self.rbac and self.currentUser):
        return False

    # Semantic namespace lookup for the table's object key
    from modules.interfaces.interfaceRbac import buildDataObjectKey

    permissions = self.rbac.getUserPermissions(
        self.currentUser,
        AccessRuleContext.DATA,
        buildDataObjectKey(modelClass.__name__),
        mandateId=self.mandateId
    )

    if operation not in ("create", "update", "delete", "read"):
        return False
    # The permission object exposes one access level per operation name.
    return getattr(permissions, operation) != AccessLevel.NONE
|
|
|
|
def _applyFilters(self, records: List[Dict[str, Any]], filters: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Apply filter criteria to records.

    Supports:
    - General search: {"search": "text"} - case-insensitive substring search across all
      string and numeric (including bool) fields; a record matches when any field
      contains the term. An empty or None search term is ignored.
    - Field-specific filters (all must match; a record lacking the field never matches):
      - Simple: {"status": "running"} - equals match
      - With operator: {"status": {"operator": "equals", "value": "running"}}
      - Operators: equals/eq, contains, startsWith, endsWith, gt, gte, lt, lte, in, notIn.
        Unknown operators fall back to equals.

    Equality and the string operators compare lower-cased string forms (None becomes
    ""), because the frontend sends filter values as strings while records hold typed
    values. ``in`` / ``notIn`` accept either an exact value match or that same
    string-form match. Numeric operators convert both sides with float(); values that
    cannot be converted never match, and None is treated as -inf for gt/gte and +inf
    for lt/lte.

    Args:
        records: List of record dictionaries to filter
        filters: Filter criteria dictionary

    Returns:
        Filtered list of records (``records`` itself when filters or records are empty)
    """
    if not filters or not records:
        return records

    def normalize(value: Any) -> str:
        return str(value).lower() if value is not None else ""

    # operator -> (substitute used for None, predicate that REJECTS the record)
    numericRejects = {
        "gt": (float("-inf"), lambda recordNum, filterNum: recordNum <= filterNum),
        "gte": (float("-inf"), lambda recordNum, filterNum: recordNum < filterNum),
        "lt": (float("inf"), lambda recordNum, filterNum: recordNum >= filterNum),
        "lte": (float("inf"), lambda recordNum, filterNum: recordNum > filterNum),
    }

    def fieldMatches(recordValue: Any, filterSpec: Any) -> bool:
        # Simple value: equals match on normalized strings
        if not isinstance(filterSpec, dict):
            return normalize(recordValue) == normalize(filterSpec)

        operator = filterSpec.get("operator", "equals")
        filterVal = filterSpec.get("value")

        if operator in numericRejects:
            noneSubstitute, rejects = numericRejects[operator]
            try:
                recordNum = float(recordValue) if recordValue is not None else noneSubstitute
                filterNum = float(filterVal) if filterVal is not None else noneSubstitute
            except (ValueError, TypeError):
                return False
            return not rejects(recordNum, filterNum)

        if operator in ("in", "notIn"):
            candidates = filterVal if isinstance(filterVal, list) else [filterVal]
            recordStr = normalize(recordValue)
            # Exact match first, then the same string-form match "equals" uses, so a
            # frontend "1" matches the stored int 1.
            isMember = recordValue in candidates or any(normalize(c) == recordStr for c in candidates)
            return isMember if operator == "in" else not isMember

        recordStr = normalize(recordValue)
        filterStr = normalize(filterVal)
        if operator == "contains":
            return filterStr in recordStr
        if operator == "startsWith":
            return recordStr.startswith(filterStr)
        if operator == "endsWith":
            return recordStr.endswith(filterStr)
        # "equals" / "eq", and unknown operators (default to equals)
        return recordStr == filterStr

    # Loop-invariant: computed once instead of per record.
    searchTerm = normalize(filters.get("search"))
    fieldFilters = [(name, spec) for name, spec in filters.items() if name != "search"]

    def recordMatches(record: Dict[str, Any]) -> bool:
        if searchTerm and not any(
            isinstance(value, (str, int, float)) and searchTerm in str(value).lower()
            for value in record.values()
        ):
            return False
        return all(
            name in record and fieldMatches(record[name], spec)
            for name, spec in fieldFilters
        )

    return [record for record in records if recordMatches(record)]
|
|
|
|
def _applySorting(self, records: List[Dict[str, Any]], sortFields: List[Any]) -> List[Dict[str, Any]]:
    """Apply multi-level sorting to records.

    Each sort field is a dict or an object with ``field`` and optional ``direction``
    ("asc" default, or "desc"). Fields are applied from least to most significant
    using Python's stable sort, which yields proper multi-level ordering.

    Ordering rules per field:
    - None values always come last, in both directions.
    - Numbers (including bools) and strings can share a column without a TypeError:
      numbers form one group and strings (or str() of other types) another, and each
      group is ordered internally.

    Args:
        records: Records to sort (not modified).
        sortFields: Sort specifications, most significant first.

    Returns:
        A new sorted list, or ``records`` itself when ``sortFields`` is empty.
    """
    if not sortFields:
        return records

    sortedRecords = list(records)

    for sortField in reversed(sortFields):
        if isinstance(sortField, dict):
            fieldName = sortField.get("field")
            direction = sortField.get("direction", "asc")
        else:
            fieldName = getattr(sortField, "field", None)
            direction = getattr(sortField, "direction", "asc")

        if not fieldName:
            continue

        isDesc = direction == "desc"
        # reverse=True also reverses the presence rank, so swap the ranks for descending
        # sorts to keep None values last in both directions.
        noneRank, valueRank = (0, 1) if isDesc else (1, 0)

        def sortKey(record):
            value = record.get(fieldName)
            if value is None:
                return (noneRank, 0, "")
            if isinstance(value, (int, float)):  # bool is an int subclass
                return (valueRank, 0, value)
            return (valueRank, 1, value if isinstance(value, str) else str(value))

        # The key is consumed immediately, so the closure sees this iteration's values.
        sortedRecords.sort(key=sortKey, reverse=isDesc)

    return sortedRecords
|
|
|
|
def getInitialId(self, model_class: type) -> Optional[str]:
    """Return the initial ID for ``model_class``'s table, as reported by the database connector."""
    connector = self.db
    return connector.getInitialId(model_class)
|
|
|
|
def _getRootMandateId(self) -> Optional[str]:
    """Look up the ID of the system root mandate.

    Returns:
        The ``id`` of the first Mandate record with name "root" and isSystem=True,
        or None when there is no such record.
    """
    rootFilter = {"name": "root", "isSystem": True}
    for mandate in self.db.getRecordset(Mandate, recordFilter=rootFilter) or ():
        return mandate.get("id")
    return None
|
|
|
|
def _getPasswordHash(self, password: str) -> str:
    """Hash ``password`` with the module-level passlib context (argon2)."""
    hashed = pwdContext.hash(password)
    return hashed
|
|
|
|
def _verifyPassword(self, plainPassword: str, hashedPassword: str) -> bool:
    """Check ``plainPassword`` against ``hashedPassword`` using the module-level passlib context."""
    isMatch = pwdContext.verify(plainPassword, hashedPassword)
    return isMatch
|
|
|
|
# User methods
|
|
|
|
def getUsersByMandate(self, mandateId: str, pagination: Optional[PaginationParams] = None) -> Union[List[User], PaginatedResult]:
    """
    Return the users that are members of a mandate.

    Membership is resolved through the UserMandate junction table; the matching
    UserInDB records are then loaded (optionally paginated). Internal "_"-prefixed
    fields are removed and a None ``roleLabels`` is replaced by an empty list.

    Args:
        mandateId: The mandate whose users are returned
        pagination: Optional pagination parameters. If None, all items are returned.

    Returns:
        List[User] when pagination is None, otherwise a PaginatedResult with items
        and metadata.
    """
    memberships = self.db.getRecordset(UserMandate, recordFilter={"mandateId": mandateId})
    memberIds = [m.get("userId") for m in memberships if m.get("userId")]

    if not memberIds:
        return [] if pagination is None else PaginatedResult(items=[], totalItems=0, totalPages=0)

    page = self.db.getRecordsetPaginated(
        UserInDB,
        pagination=pagination,
        recordFilter={"id": memberIds}
    )

    users = []
    for record in page["items"]:
        publicFields = {key: val for key, val in record.items() if not key.startswith("_")}
        if publicFields.get("roleLabels") is None:
            publicFields["roleLabels"] = []
        users.append(User(**publicFields))

    if pagination is None:
        return users
    return PaginatedResult(
        items=users,
        totalItems=page["totalItems"],
        totalPages=page["totalPages"]
    )
|
|
|
|
def getUserByUsername(self, username: str) -> Optional[User]:
    """Return the user with ``username`` if the current user may see it under RBAC.

    Returns:
        The first matching User (internal "_"-prefixed fields removed, a None
        ``roleLabels`` normalized to []), or None when nothing is visible or an
        error occurs (errors are logged).
    """
    try:
        visible = getRecordsetWithRBAC(
            self.db,
            UserInDB,
            self.currentUser,
            recordFilter={"username": username},
            mandateId=self.mandateId
        )
        if not visible:
            logger.info(f"No user found with username {username}")
            return None

        # Usernames are expected to be unique; take the first match.
        userFields = {key: val for key, val in visible[0].items() if not key.startswith("_")}
        if userFields.get("roleLabels") is None:
            userFields["roleLabels"] = []
        return User(**userFields)
    except Exception as e:
        logger.error(f"Error getting user by username: {str(e)}")
        return None
|
|
|
|
def getUser(self, userId: str) -> Optional[User]:
    """Return the user with ``userId`` if the current user may see it under RBAC.

    Returns:
        The matching User (internal "_"-prefixed fields removed, a None
        ``roleLabels`` normalized to []), or None when it is not visible or an
        error occurs (errors are logged).
    """
    try:
        records = getRecordsetWithRBAC(
            self.db,
            UserInDB,
            self.currentUser,
            recordFilter={"id": userId},
            mandateId=self.mandateId
        )
        if not records:
            return None

        # RBAC filtering already happened; only strip internal fields.
        userFields = {name: val for name, val in records[0].items() if not name.startswith("_")}
        if userFields.get("roleLabels") is None:
            userFields["roleLabels"] = []
        return User(**userFields)
    except Exception as e:
        logger.error(f"Error getting user by ID: {str(e)}")
        return None
|
|
|
|
def _getUserForAuthentication(self, username: str) -> Optional[Dict[str, Any]]:
    """
    Load the raw user record for ``username`` for authentication purposes.

    SECURITY NOTE: This deliberately bypasses RBAC because:
    1. Users are NOT mandate-bound (multi-tenant design)
    2. Authentication must work regardless of mandate context
    3. RBAC filtering of the User table needs a mandate context, which does not
       exist at login time

    Use this ONLY in authentication flows; every other user lookup should go
    through getUserByUsername(), which applies RBAC.

    Returns:
        The full UserInDB record as a dict (including credential fields), or None
        when no user matches or an error occurs (errors are logged).
    """
    try:
        matches = self.db.getRecordset(UserInDB, recordFilter={"username": username})
        return matches[0] if matches else None
    except Exception as e:
        logger.error(f"Error getting user for authentication: {str(e)}")
        return None
|
|
|
|
def authenticateLocalUser(self, username: str, password: str) -> Optional[User]:
    """
    Authenticate a user by username and password against local credentials.

    SECURITY NOTE: Uses _getUserForAuthentication(), which bypasses RBAC; this is
    intentional because users are mandate-independent.

    Returns:
        The authenticated User, without internal ("_"-prefixed) fields and without
        hashedPassword / resetToken / resetTokenExpires.

    Raises:
        ValueError: "User not found", "User is disabled", "User does not have local
            authentication enabled", "User has no password set" or "Invalid password".
    """
    userRecord = self._getUserForAuthentication(username)
    if not userRecord:
        raise ValueError("User not found")

    if not userRecord.get("enabled", True):
        raise ValueError("User is disabled")

    # The stored authority may be the enum member or its raw string value.
    authAuthority = userRecord.get("authenticationAuthority", AuthAuthority.LOCAL)
    if authAuthority not in (AuthAuthority.LOCAL, AuthAuthority.LOCAL.value):
        raise ValueError("User does not have local authentication enabled")

    storedHash = userRecord.get("hashedPassword")
    if not storedHash:
        raise ValueError("User has no password set")

    if not self._verifyPassword(password, storedHash):
        raise ValueError("Invalid password")

    hiddenKeys = {"hashedPassword", "resetToken", "resetTokenExpires"}
    publicFields = {
        key: val for key, val in userRecord.items()
        if not key.startswith("_") and key not in hiddenKeys
    }
    if publicFields.get("roleLabels") is None:
        publicFields["roleLabels"] = []
    return User(**publicFields)
|
|
|
|
def createUser(
    self,
    username: str,
    password: Optional[str] = None,
    email: Optional[str] = None,
    fullName: Optional[str] = None,
    language: str = "en",
    enabled: bool = True,
    authenticationAuthority: AuthAuthority = AuthAuthority.LOCAL,
    externalId: Optional[str] = None,
    externalUsername: Optional[str] = None,
    externalEmail: Optional[str] = None,
    isSysAdmin: bool = False,
    addExternalIdentityConnection: bool = True,
) -> User:
    """
    Create a new user.

    Role assignment is done via createUserMandate(), not via User fields.

    Args:
        username: Login name; surrounding whitespace is stripped, and the result
            must be non-empty and globally unique (across all mandates).
        password: Optional initial password (internal/admin use only). For LOCAL auth
            the password is normally set via magic link, not during registration.
        email: Optional email address; stored stripped and lower-cased, because the
            email lookups (findAllUsersByEmailLocalAuth and friends) match the
            lower-cased address exactly.
        fullName, language, enabled, isSysAdmin: Stored on the user record.
        authenticationAuthority: Authentication authority of the user.
        externalId, externalUsername, externalEmail: External IdP identity.
        addExternalIdentityConnection: If True (default) and externalId/externalUsername
            are set, a UserConnection row is created. OAuth login-only flows should
            pass False (the data connection is created separately via /auth/connect).

    Returns:
        The created User, without internal ("_"-prefixed) fields and without
        hashedPassword / resetToken / resetTokenExpires.

    Raises:
        ValueError: On invalid input, a taken username, or any failure while creating.
    """
    try:
        username = str(username).strip()
        if not username:
            raise ValueError("Username cannot be empty")

        # Usernames are unique across ALL mandates
        if not self.isUsernameGloballyUnique(username):
            raise ValueError(f"Username '{username}' is already taken")

        if password:
            if not isinstance(password, str):
                raise ValueError("Password must be a string")
            if not password.strip():
                raise ValueError("Password cannot be empty")

        if email:
            email = email.strip().lower()

        # mandateId and roleLabels are not stored here - use UserMandate + UserMandateRole
        userData = UserInDB(
            username=username,
            email=email,
            fullName=fullName,
            language=language,
            enabled=enabled,
            isSysAdmin=isSysAdmin,
            authenticationAuthority=authenticationAuthority,
            hashedPassword=self._getPasswordHash(password) if password else None,
        )

        createdRecord = self.db.recordCreate(UserInDB, userData)
        if not createdRecord or not createdRecord.get("id"):
            raise ValueError("Failed to create user record")
        newUserId = createdRecord["id"]

        # Optional: mirror the external IdP identity into UserConnections (data/API OAuth).
        # Auth-only login (Google/MSFT JWT) must NOT create a connection - see OAuth split.
        if addExternalIdentityConnection and externalId and externalUsername:
            self.addUserConnection(
                newUserId,
                authenticationAuthority,
                externalId,
                externalUsername,
                externalEmail,
            )

        createdUsers = self.db.getRecordset(UserInDB, recordFilter={"id": newUserId})
        if not createdUsers:
            raise ValueError("Failed to retrieve created user")

        # Root mandate assignment is intentionally not done here: users get their own
        # mandate via _provisionMandateForUser during registration.
        hiddenKeys = {"hashedPassword", "resetToken", "resetTokenExpires"}
        publicFields = {
            key: val for key, val in createdUsers[0].items()
            if not key.startswith("_") and key not in hiddenKeys
        }
        if publicFields.get("roleLabels") is None:
            publicFields["roleLabels"] = []
        return User(**publicFields)

    except ValueError as e:
        logger.error(f"Error creating user: {str(e)}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error creating user: {str(e)}")
        raise ValueError(f"Failed to create user: {str(e)}") from e
|
|
|
|
def updateUser(self, userId: str, updateData: Union[Dict[str, Any], User], allowSysAdminChange: bool = False) -> User:
    """Update a user's information.

    The current user record is loaded through getUser() (RBAC-filtered), merged with
    ``updateData`` and written back; the re-read user is returned.

    Args:
        userId: ID of the user to update; always wins over any "id" in updateData.
        updateData: User data to update (dict or User model).
        allowSysAdminChange: If True, the isSysAdmin field may be changed. Only set
            this when a SysAdmin is explicitly changing another user's admin status.

    Returns:
        The updated User.

    Raises:
        ValueError: If the user is not found/visible or the update fails. The original
            exception is chained as ``__cause__``.
    """
    try:
        user = self.getUser(userId)
        if not user:
            raise ValueError(f"User {userId} not found")

        if isinstance(updateData, User):
            updateDict = updateData.model_dump()
        else:
            updateDict = updateData.copy() if isinstance(updateData, dict) else updateData

        # The userId parameter is authoritative; never take an id from the payload.
        updateDict.pop("id", None)

        # SECURITY: Protect sensitive fields from being overwritten by profile updates.
        # Profile forms may send them as default values (e.g. isSysAdmin=False).
        protectedFields = ["isSysAdmin"]
        if not allowSysAdminChange:
            for field in protectedFields:
                updateDict.pop(field, None)

        mergedData = user.model_dump()
        mergedData.update(updateDict)
        mergedData["id"] = userId
        mergedUser = User(**mergedData)

        self.db.recordModify(UserInDB, userId, mergedUser)

        refreshedUser = self.getUser(userId)
        if not refreshedUser:
            raise ValueError("Failed to retrieve updated user")
        return refreshedUser

    except Exception as e:
        logger.error(f"Error updating user: {str(e)}")
        raise ValueError(f"Failed to update user: {str(e)}") from e
|
|
|
|
def _assignUserToRootMandate(self, userId: str) -> None:
    """
    Assign a user to the root mandate with the mandate-instance 'user' role.

    Uses the mandate-instance role (mandateId=rootMandateId, no featureInstanceId),
    not the global template. Feature instance access is NOT granted here; it is
    managed separately via invitations or admin assignment.

    Does nothing if there is no root mandate or the user is already a member. If no
    'user' role exists in the root mandate, the membership is created without roles
    and a warning is logged. Errors are logged and swallowed so callers do not fail.

    Args:
        userId: User ID to assign
    """
    try:
        rootMandateId = self._getRootMandateId()
        if not rootMandateId:
            logger.warning("No root mandate found, skipping root mandate assignment")
            return

        if self.getUserMandate(userId, rootMandateId):
            logger.debug(f"User {userId} already assigned to root mandate")
            return

        # Role comes from the module-level datamodelRbac import.
        mandateUserRoles = self.db.getRecordset(
            Role,
            recordFilter={"roleLabel": "user", "mandateId": rootMandateId, "featureInstanceId": None}
        )
        userRoleId = mandateUserRoles[0].get("id") if mandateUserRoles else None
        roleIds = [userRoleId] if userRoleId else []

        self.createUserMandate(userId, rootMandateId, roleIds)
        if roleIds:
            logger.info(f"Assigned user {userId} to root mandate with user role")
        else:
            logger.warning(f"Assigned user {userId} to root mandate without roles: no 'user' role found")

    except Exception as e:
        # Log but don't fail user creation
        logger.error(f"Error assigning user {userId} to root mandate: {e}")
|
|
|
|
def disableUser(self, userId: str) -> User:
    """Disable ``userId`` by setting ``enabled`` to False via updateUser (which applies permission checks)."""
    changes = {"enabled": False}
    return self.updateUser(userId, changes)
|
|
|
|
def enableUser(self, userId: str) -> User:
    """Enable ``userId`` by setting ``enabled`` to True via updateUser (which applies permission checks)."""
    changes = {"enabled": True}
    return self.updateUser(userId, changes)
|
|
|
|
def resetUserPassword(self, userId: str, newPassword: str) -> bool:
    """Set a new password for a user (admin function).

    The password must be at least 8 characters; it is hashed and written to the
    user's ``hashedPassword`` field.

    Returns:
        True on success, False on any error, including a too-short password
        (errors are logged, never raised).
    """
    try:
        tooShort = not newPassword or len(newPassword) < 8
        if tooShort:
            raise ValueError("Password must be at least 8 characters long")

        newHash = self._getPasswordHash(newPassword)
        self.db.recordModify(UserInDB, userId, {"hashedPassword": newHash})
        logger.info(f"Password reset for user {userId}")
        return True
    except Exception as e:
        logger.error(f"Error resetting password for user {userId}: {str(e)}")
        return False
|
|
|
|
def generateResetTokenAndExpiry(self) -> tuple:
    """Create a password-reset token and its expiry time.

    The token is a random UUID4 string. The expiry is the current UTC timestamp
    plus ``Auth_RESET_TOKEN_EXPIRY_HOURS`` hours (default 24) from APP_CONFIG.

    Returns:
        tuple: (tokenUuid: str, expiresTimestamp: float)
    """
    resetToken = str(uuid.uuid4())
    validityHours = int(APP_CONFIG.get("Auth_RESET_TOKEN_EXPIRY_HOURS", "24"))
    validitySeconds = validityHours * 3600
    return resetToken, getUtcTimestamp() + validitySeconds
|
|
|
|
def findUserByEmailLocalAuth(self, email: str) -> Optional[User]:
    """Find a LOCAL-auth user by email (searches across all mandates).

    If several users share the email, only the first one found is returned; use
    findAllUsersByEmailLocalAuth() to get all of them.

    Args:
        email: Email address to search for (case-insensitive)

    Returns:
        The first matching User, or None.
    """
    matches = self.findAllUsersByEmailLocalAuth(email)
    if not matches:
        return None
    return matches[0]
|
|
|
|
def findAllUsersByEmailLocalAuth(self, email: str) -> List[User]:
    """Find ALL LOCAL-auth users with the given email (searches across all mandates).

    Use this when a person may have several accounts with the same email in
    different mandates. The lookup bypasses RBAC and matches the stripped,
    lower-cased email exactly.

    Args:
        email: Email address to search for (case-insensitive)

    Returns:
        List of Users (internal "_"-prefixed fields removed, a None roleLabels
        normalized to []); empty when none match or an error occurs (errors are logged).
    """
    if not email:
        return []

    lookupEmail = email.lower().strip()

    try:
        # Direct query without RBAC for the cross-mandate search
        records = self.db.getRecordset(
            UserInDB,
            recordFilter={
                "email": lookupEmail,
                "authenticationAuthority": AuthAuthority.LOCAL.value
            }
        )

        def toUser(record):
            fields = {key: val for key, val in record.items() if not key.startswith("_")}
            if fields.get("roleLabels") is None:
                fields["roleLabels"] = []
            return User(**fields)

        return [toUser(record) for record in records]
    except Exception as e:
        logger.error(f"Error finding users by email: {str(e)}")
        return []
|
|
|
|
def findUserByEmailAndUsernameLocalAuth(self, email: str, username: str) -> Optional[User]:
    """Locate a LOCAL-auth user by the (email, username) pair, across all mandates.

    The pair pins down one account even when the same email exists in several
    mandates. The query bypasses RBAC.

    Args:
        email: Email address; lower-cased and stripped before matching.
        username: Username; matched exactly as given.

    Returns:
        The first matching User (``roleLabels`` defaulted to ``[]`` when stored
        as None), or None when either argument is empty, nothing matches, or the
        lookup fails (the error is logged).
    """
    if not (email and username):
        return None

    normalizedEmail = email.lower().strip()

    try:
        matches = self.db.getRecordset(
            UserInDB,
            recordFilter={
                "email": normalizedEmail,
                "username": username,
                "authenticationAuthority": AuthAuthority.LOCAL.value,
            },
        )
        if not matches:
            return None

        userFields = {key: value for key, value in matches[0].items() if not key.startswith("_")}
        if userFields.get("roleLabels") is None:
            userFields["roleLabels"] = []
        return User(**userFields)
    except Exception as e:
        logger.error(f"Error finding user by email and username: {str(e)}")
        return None
|
|
|
|
def isUsernameGloballyUnique(self, username: str) -> bool:
    """Tell whether ``username`` is still free across every mandate.

    Queries UserInDB directly (no RBAC) for any record holding this username,
    regardless of authentication authority. Used during registration.

    Args:
        username: Username to test; matched exactly as given.

    Returns:
        True when no user holds the name. False when it is taken, when
        ``username`` is empty, or when the lookup fails (fail-safe; logged).
    """
    if not username:
        return False

    try:
        # Cross-mandate lookup without RBAC filtering
        takenBy = self.db.getRecordset(UserInDB, recordFilter={"username": username})
        return len(takenBy) == 0
    except Exception as e:
        logger.error(f"Error checking username uniqueness: {str(e)}")
        return False  # Fail safe - assume not unique on error
|
|
|
|
def findUserByUsernameLocalAuth(self, username: str) -> Optional[User]:
    """Look up the LOCAL-auth user owning ``username`` across all mandates.

    The query goes straight to the database (no RBAC). Usernames are expected
    to be globally unique, so only the first match is used.

    Args:
        username: Exact username to look for.

    Returns:
        The matching User (``roleLabels`` defaulted to ``[]`` when stored as
        None), or None if ``username`` is empty, nothing matches, or the lookup
        fails (the error is logged).
    """
    if not username:
        return None

    try:
        localAccounts = self.db.getRecordset(
            UserInDB,
            recordFilter={
                "username": username,
                "authenticationAuthority": AuthAuthority.LOCAL.value,
            },
        )
        if not localAccounts:
            return None

        userFields = {
            fieldName: fieldValue
            for fieldName, fieldValue in localAccounts[0].items()
            if not fieldName.startswith("_")
        }
        if userFields.get("roleLabels") is None:
            userFields["roleLabels"] = []
        return User(**userFields)
    except Exception as e:
        logger.error(f"Error finding user by username: {str(e)}")
        return None
|
|
|
|
def setResetToken(self, userId: str, token: str, expires: float, clearPassword: bool = True) -> bool:
    """Store a password-reset token and its expiry on a user record.

    Args:
        userId: ID of the user to update.
        token: Reset token (UUID string).
        expires: Expiration timestamp.
        clearPassword: When True (the default), ``hashedPassword`` is set to
            None in the same update. NOTE(review): this means a reset request
            immediately disables password login until the reset is completed.

    Returns:
        True when the update call completed; False if it raised (logged).
    """
    changes = {"resetToken": token, "resetTokenExpires": expires}
    if clearPassword:
        changes["hashedPassword"] = None

    try:
        self.db.recordModify(UserInDB, userId, changes)
    except Exception as e:
        logger.error(f"Error setting reset token for user {userId}: {str(e)}")
        return False
    return True
|
|
|
|
def verifyResetToken(self, token: str) -> Optional[User]:
    """Resolve a password-reset token to its user if the token is still valid.

    The first UserInDB record whose ``resetToken`` equals ``token`` is used.
    Its ``resetTokenExpires`` value is coerced to float. The token is rejected
    when that value cannot be parsed, is missing or zero, or lies in the past
    relative to ``getUtcTimestamp()``.

    Returns:
        The User (``roleLabels`` defaulted to ``[]`` when None), or None when
        the token is empty, unknown, expired or invalid, or the lookup fails
        (the error is logged).
    """
    if not token:
        return None

    try:
        candidates = self.db.getRecordset(UserInDB, recordFilter={"resetToken": token})
        if not candidates:
            return None
        record = candidates[0]

        rawExpiry = record.get("resetTokenExpires")
        expiresAt = rawExpiry
        if rawExpiry is not None:
            try:
                expiresAt = float(rawExpiry)
            except (ValueError, TypeError):
                logger.warning(f"Invalid resetTokenExpires value for user {record.get('id')}: {rawExpiry}")
                return None

        # A missing/zero expiry is treated like an expired token
        if not expiresAt or getUtcTimestamp() > expiresAt:
            logger.warning(f"Reset token expired for user {record.get('id')}")
            return None

        userFields = {key: value for key, value in record.items() if not key.startswith("_")}
        if userFields.get("roleLabels") is None:
            userFields["roleLabels"] = []
        return User(**userFields)
    except Exception as e:
        logger.error(f"Error verifying reset token: {str(e)}")
        return None
|
|
|
|
def resetPasswordWithToken(self, token: str, newPassword: str) -> bool:
    """Consume a reset token and set a new password in a single record update.

    The token is validated via ``verifyResetToken``. On success, a single
    ``recordModify`` call does all of the following:
    - replaces ``hashedPassword``;
    - clears ``resetToken`` and ``resetTokenExpires``;
    - sets ``enabled`` to True.

    Passwords shorter than 8 characters are rejected. The ValueError raised for
    that is caught by this method's own handler, so, like every other failure,
    callers only see a logged error and ``False``.

    Returns:
        True if the password was updated, False otherwise.
    """
    try:
        tokenOwner = self.verifyResetToken(token)
        if not tokenOwner:
            return False

        passwordTooShort = not newPassword or len(newPassword) < 8
        if passwordTooShort:
            raise ValueError("Password must be at least 8 characters long")

        newHash = self._getPasswordHash(newPassword)
        self.db.recordModify(UserInDB, tokenOwner.id, {
            "hashedPassword": newHash,
            "resetToken": None,
            "resetTokenExpires": None,
            "enabled": True,
        })
        logger.info(f"Password reset completed for user {tokenOwner.id}")
        return True
    except Exception as e:
        logger.error(f"Error in resetPasswordWithToken: {str(e)}")
        return False
|
|
|
|
def _deleteUserReferencedData(self, userId: str) -> None:
    """Remove a user's auth events, tokens and external connections.

    Each model is cleared in turn: all AuthEvent rows, then Token rows, then
    UserConnection rows whose ``userId`` matches.

    Raises:
        Exception: any DB error is logged and re-raised unchanged.
    """
    try:
        for modelClass in (AuthEvent, Token, UserConnection):
            ownedRecords = self.db.getRecordset(modelClass, recordFilter={"userId": userId})
            for owned in ownedRecords:
                self.db.recordDelete(modelClass, owned["id"])

        logger.info(f"All referenced data for user {userId} has been deleted")
    except Exception as e:
        logger.error(f"Error deleting referenced data for user {userId}: {str(e)}")
        raise
|
|
|
|
def deleteUser(self, userId: str) -> bool:
    """Delete a user and the data that references it, subject to RBAC.

    Steps:
    1. Load the user through ``getUser`` (RBAC-filtered).
    2. Require the ``"delete"`` permission on UserInDB.
    3. Remove auth events, tokens and connections.
    4. Delete the user record.

    Returns:
        True on success.

    Raises:
        ValueError: for every failure — user not found/visible, missing delete
            permission, or a DB error. The original exception is chained.
    """
    try:
        user = self.getUser(userId)
        if not user:
            raise ValueError(f"User {userId} not found")

        # Deleting must be authorised by the "delete" permission, matching the
        # operation (and deleteMandate); checking "update" let editors delete users.
        if not self.checkRbacPermission(UserInDB, "delete", userId):
            raise PermissionError(f"No permission to delete user {userId}")

        # Referenced data first, so nothing is left pointing at a deleted user
        self._deleteUserReferencedData(userId)

        if not self.db.recordDelete(UserInDB, userId):
            raise ValueError(f"Failed to delete user {userId}")

        logger.info(f"User {userId} successfully deleted")
        return True

    except Exception as e:
        logger.error(f"Error deleting user: {str(e)}")
        raise ValueError(f"Failed to delete user: {str(e)}") from e
|
|
|
|
def _getInitialUser(self) -> Optional[Dict[str, Any]]:
    """Return the raw record of the initial user, or None.

    The initial user ID comes from ``getInitialId(UserInDB)``. The record is
    loaded via ``getRecordsetWithRBAC`` using the current user and mandate, so
    RBAC filtering does apply; the result depends on the caller's permissions.

    Returns:
        The first matching record dict, or None when there is no initial ID,
        the record is not visible, or an error occurs (logged).
    """
    try:
        initialUserId = self.getInitialId(UserInDB)
        if not initialUserId:
            return None

        matches = getRecordsetWithRBAC(
            self.db,
            UserInDB,
            self.currentUser,
            recordFilter={"id": initialUserId},
            mandateId=self.mandateId,
        )
        if not matches:
            return None
        return matches[0]
    except Exception as e:
        logger.error(f"Error getting initial user: {str(e)}")
        return None
|
|
|
|
def checkUsernameAvailability(self, checkData: Dict[str, Any]) -> Dict[str, Any]:
    """Report whether the username in ``checkData`` can be registered.

    Usernames must be unique system-wide, so availability is decided by
    ``isUsernameGloballyUnique``. Any ``authenticationAuthority`` key in
    ``checkData`` is ignored.

    Returns:
        ``{"available": bool, "message": str}``. An error while checking yields
        ``available=False`` with the error text in ``message``.
    """
    try:
        username = checkData.get("username")

        if not username:
            verdict = (False, "Username is required")
        elif not self.isUsernameGloballyUnique(username):
            verdict = (False, "Username is already taken")
        else:
            verdict = (True, "Username is available")

        available, message = verdict
        return {"available": available, "message": message}

    except Exception as e:
        logger.error(f"Error checking username availability: {str(e)}")
        return {
            "available": False,
            "message": f"Error checking username availability: {str(e)}",
        }
|
|
|
|
# Connection methods
|
|
|
|
def getUserConnections(self, userId: str) -> List[UserConnection]:
    """Return every stored external connection belonging to ``userId``.

    Each DB row is mapped field-by-field onto a UserConnection. Missing
    ``externalId`` / ``externalUsername`` default to ``""`` and a missing
    ``status`` to ``"pending"``. Rows that fail conversion are logged and
    skipped; any other failure is logged and yields an empty list.
    """
    try:
        rawConnections = self.db.getRecordset(
            UserConnection, recordFilter={"userId": userId}
        )

        converted = []
        for raw in rawConnections:
            try:
                converted.append(UserConnection(
                    id=raw["id"],
                    userId=raw["userId"],
                    authority=raw.get("authority"),
                    externalId=raw.get("externalId", ""),
                    externalUsername=raw.get("externalUsername", ""),
                    externalEmail=raw.get("externalEmail"),
                    status=raw.get("status", "pending"),
                    connectedAt=raw.get("connectedAt"),
                    lastChecked=raw.get("lastChecked"),
                    expiresAt=raw.get("expiresAt"),
                ))
            except Exception as e:
                logger.error(
                    f"Error converting connection dict to object: {str(e)}"
                )
        return converted

    except Exception as e:
        logger.error(f"Error getting user connections: {str(e)}")
        return []
|
|
|
|
def getUserConnectionById(self, connectionId: str) -> Optional[UserConnection]:
    """Fetch one UserConnection by its ID.

    Only the ``id`` is filtered on. NOTE(review): no ownership/RBAC check is
    done here; callers presumably verify the connection belongs to the acting
    user.

    Returns:
        The connection, or None when it does not exist or the lookup fails
        (the error is logged).
    """
    try:
        found = self.db.getRecordset(
            UserConnection, recordFilter={"id": connectionId}
        )
        if not found:
            return None

        row = found[0]
        return UserConnection(
            id=row["id"],
            userId=row["userId"],
            authority=row.get("authority"),
            externalId=row.get("externalId", ""),
            externalUsername=row.get("externalUsername", ""),
            externalEmail=row.get("externalEmail"),
            status=row.get("status", "pending"),
            connectedAt=row.get("connectedAt"),
            lastChecked=row.get("lastChecked"),
            expiresAt=row.get("expiresAt"),
        )
    except Exception as e:
        logger.error(f"Error getting user connection by ID: {str(e)}")
        return None
|
|
|
|
def addUserConnection(
    self,
    userId: str,
    authority: AuthAuthority,
    externalId: str,
    externalUsername: str,
    externalEmail: Optional[str] = None,
    status: ConnectionStatus = ConnectionStatus.PENDING,
) -> UserConnection:
    """
    Create and persist a new external connection for a user.

    The existence of ``userId`` is deliberately not verified:
    - the calling route already has an authenticated user;
    - users must always be able to create their own connections;
    - ``getUser()`` is RBAC-filtered and may fail for users without UserInDB
      view permission.

    Args:
        userId: The ID of the user.
        authority: The authentication authority (e.g., MSFT, GOOGLE).
        externalId: The external ID from the authority.
        externalUsername: The username from the authority.
        externalEmail: Optional email from the authority.
        status: The connection status (defaults to PENDING).

    Returns:
        The created UserConnection (fresh UUID, ``connectedAt``/``lastChecked``
        set from ``getUtcTimestamp()``, ``expiresAt`` None).

    Raises:
        ValueError: if building or saving the connection fails.
    """
    try:
        newConnection = UserConnection(
            id=str(uuid.uuid4()),
            userId=userId,
            authority=authority,
            externalId=externalId,
            externalUsername=externalUsername,
            externalEmail=externalEmail,
            status=status,
            connectedAt=getUtcTimestamp(),
            lastChecked=getUtcTimestamp(),
            expiresAt=None,  # Optional; not known at creation time
        )
        self.db.recordCreate(UserConnection, newConnection)
    except Exception as e:
        logger.error(f"Error adding user connection: {str(e)}")
        raise ValueError(f"Failed to add user connection: {str(e)}")

    return newConnection
|
|
|
|
def removeUserConnection(self, connectionId: str) -> None:
    """Delete a stored external-service connection by ID.

    Raises:
        ValueError: when the connection does not exist or the delete fails;
            the message is prefixed with "Failed to remove user connection:".
    """
    try:
        existing = self.db.getRecordset(
            UserConnection, recordFilter={"id": connectionId}
        )
        if not existing:
            raise ValueError(f"Connection {connectionId} not found")

        self.db.recordDelete(UserConnection, connectionId)
    except Exception as e:
        logger.error(f"Error removing user connection: {str(e)}")
        raise ValueError(f"Failed to remove user connection: {str(e)}")
|
|
|
|
# Mandate methods
|
|
|
|
def getAllMandates(self, pagination: Optional[PaginationParams] = None) -> Union[List[Mandate], PaginatedResult]:
    """
    List the mandates visible to the current user (RBAC-filtered).

    Args:
        pagination: Optional pagination parameters. When given, results are
            filtered (``filters``), then sorted (``sort``), then sliced to the
            requested 1-based ``page`` of ``pageSize`` items.

    Returns:
        Without pagination: List[Mandate] of every visible mandate.
        With pagination: PaginatedResult with the page's items, the total
        number of items after filtering, and the resulting number of pages.
    """
    visibleRecords = getRecordsetWithRBAC(self.db, Mandate, self.currentUser, mandateId=self.mandateId)

    # Drop DB-internal fields (leading underscore)
    mandateDicts = [
        {key: value for key, value in record.items() if not key.startswith("_")}
        for record in visibleRecords
    ]

    if pagination is None:
        return [Mandate(**data) for data in mandateDicts]

    if pagination.filters:
        mandateDicts = self._applyFilters(mandateDicts, pagination.filters)
    if pagination.sort:
        mandateDicts = self._applySorting(mandateDicts, pagination.sort)

    pageSize = pagination.pageSize
    totalItems = len(mandateDicts)
    totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0

    offset = (pagination.page - 1) * pageSize
    pageSlice = mandateDicts[offset:offset + pageSize]

    return PaginatedResult(
        items=[Mandate(**data) for data in pageSlice],
        totalItems=totalItems,
        totalPages=totalPages,
    )
|
|
|
|
def getMandate(self, mandateId: str) -> Optional[Mandate]:
    """Return the mandate with ``mandateId`` if the current user may see it.

    The lookup is RBAC-filtered. DB-internal (``_``-prefixed) fields are
    stripped before building the model.

    Returns:
        The Mandate, or None when it does not exist or is not visible.
    """
    visible = getRecordsetWithRBAC(
        self.db,
        Mandate,
        self.currentUser,
        recordFilter={"id": mandateId},
        mandateId=self.mandateId,
    )

    cleaned = [
        {key: value for key, value in record.items() if not key.startswith("_")}
        for record in (visible or [])
    ]
    return Mandate(**cleaned[0]) if cleaned else None
|
|
|
|
def createMandate(self, name: str, label: str = None, enabled: bool = True) -> Mandate:
    """
    Create a new mandate, provided the current user has the "create" permission.

    After the record is stored, the system template roles are copied to the new
    mandate via ``copySystemRolesToMandate``. That step is best-effort: a
    failure is logged and the mandate is still returned.

    Raises:
        PermissionError: if the user lacks the "create" permission on Mandate.
        ValueError: if the mandate record could not be created.
    """
    if not self.checkRbacPermission(Mandate, "create"):
        raise PermissionError("No permission to create mandates")

    createdRecord = self.db.recordCreate(Mandate, Mandate(name=name, label=label, enabled=enabled))
    newMandateId = createdRecord.get("id") if createdRecord else None
    if not newMandateId:
        raise ValueError("Failed to create mandate record")

    try:
        from modules.interfaces.interfaceBootstrap import copySystemRolesToMandate
        copiedCount = copySystemRolesToMandate(self.db, newMandateId)
        logger.info(f"Copied {copiedCount} system roles to new mandate {newMandateId}")
    except Exception as e:
        logger.error(f"Error copying system roles to mandate {newMandateId}: {e}")

    return Mandate(**createdRecord)
|
|
|
|
def _provisionMandateForUser(self, userId: str, mandateType: str, mandateName: str, planKey: str) -> Dict[str, Any]:
    """
    Provision a new mandate for ``userId``, rolling back on failure.

    Creates, in order:
    1. The Mandate record.
    2. Copies of the system template roles.
    3. A UserMandate membership, with the mandate's admin role when one is found.
    4. A PENDING MandateSubscription for ``planKey``.
    5. One FeatureInstance per feature whose definition sets
       ``autoCreateInstance``. For each, the user gets FeatureAccess plus the
       instance's admin role when one exists. A failing feature is logged and
       skipped.

    Internal method — bypasses RBAC (used during registration, when the user
    has no permissions yet).

    Args:
        userId: ID of the user who becomes admin of the new mandate.
        mandateType: A value accepted by ``MandateType``.
        mandateName: Used as both ``name`` and ``label`` of the mandate.
        planKey: Key into ``BUILTIN_PLANS``.

    Returns:
        Dict with ``mandateId``, ``planKey``, ``mandateType`` and the IDs of the
        auto-created ``featureInstances``.

    Raises:
        ValueError: for an unknown plan, an invalid mandate type, a failed
            mandate insert, or any failure after the mandate exists. In the last
            case the partially provisioned records are removed best-effort first.
    """
    from modules.datamodels.datamodelUam import MandateType
    from modules.datamodels.datamodelSubscription import MandateSubscription, SubscriptionStatusEnum, BUILTIN_PLANS
    from modules.interfaces.interfaceBootstrap import copySystemRolesToMandate
    from modules.interfaces.interfaceFeatures import getFeatureInterface
    from modules.system.registry import loadFeatureMainModules

    def _firstAdminRoleId(roles: List[Dict[str, Any]]) -> Optional[str]:
        # ID of the first role whose label contains "admin" (case-insensitive), else None
        return next(
            (r.get("id") for r in roles if "admin" in (r.get("roleLabel") or "").lower()),
            None,
        )

    plan = BUILTIN_PLANS.get(planKey)
    if not plan:
        raise ValueError(f"Unknown plan: {planKey}")

    mandateData = Mandate(
        name=mandateName,
        label=mandateName,
        enabled=True,
        isSystem=False,
        mandateType=MandateType(mandateType),
    )
    createdMandate = self.db.recordCreate(Mandate, mandateData)
    if not createdMandate or not createdMandate.get("id"):
        raise ValueError("Failed to create mandate")
    mandateId = createdMandate["id"]

    try:
        copySystemRolesToMandate(self.db, mandateId)

        # Membership, with the mandate-level admin role if one was copied
        mandateRoles = self.db.getRecordset(Role, recordFilter={"mandateId": mandateId, "featureInstanceId": None})
        adminRoleId = _firstAdminRoleId(mandateRoles)
        createdUm = self.db.recordCreate(
            UserMandate, UserMandate(userId=userId, mandateId=mandateId, enabled=True).model_dump()
        )
        if adminRoleId and createdUm:
            self.db.recordCreate(
                UserMandateRole,
                UserMandateRole(userMandateId=createdUm["id"], roleId=adminRoleId).model_dump(),
            )

        # Subscription starts PENDING; trial/period dates are set on activation
        # (see _activatePendingSubscriptions).
        subscription = MandateSubscription(
            mandateId=mandateId,
            planKey=planKey,
            status=SubscriptionStatusEnum.PENDING,
        )
        self.db.recordCreate(MandateSubscription, subscription.model_dump())

        featureInterface = getFeatureInterface(self.db)
        createdInstances = []
        for featureName, module in loadFeatureMainModules().items():
            if not hasattr(module, "getFeatureDefinition"):
                continue
            try:
                featureDef = module.getFeatureDefinition()
                if not featureDef.get("autoCreateInstance", False):
                    continue
                instance = featureInterface.createFeatureInstance(
                    featureCode=featureDef.get("code", featureName),
                    mandateId=mandateId,
                    label=featureDef.get("label", {}).get("en", featureName),
                    enabled=True,
                    copyTemplateRoles=True,
                )
                if not instance:
                    continue
                instanceId = instance.get("id") if isinstance(instance, dict) else instance.id
                createdInstances.append(instanceId)

                instanceRoles = self.db.getRecordset(Role, recordFilter={"featureInstanceId": instanceId})
                adminInstRoleId = _firstAdminRoleId(instanceRoles)
                createdFa = self.db.recordCreate(
                    FeatureAccess,
                    FeatureAccess(userId=userId, featureInstanceId=instanceId, enabled=True).model_dump(),
                )
                if adminInstRoleId and createdFa:
                    self.db.recordCreate(
                        FeatureAccessRole,
                        FeatureAccessRole(featureAccessId=createdFa["id"], roleId=adminInstRoleId).model_dump(),
                    )
            except Exception as e:
                # Best effort per feature: one broken feature must not block registration
                logger.error(f"Error auto-creating instance for '{featureName}': {e}")

        logger.info(f"Provisioned mandate {mandateId} (type={mandateType}, plan={planKey}) for user {userId}, instances={createdInstances}")
        return {
            "mandateId": mandateId,
            "planKey": planKey,
            "mandateType": mandateType,
            "featureInstances": createdInstances,
        }

    except Exception as e:
        logger.error(f"Provisioning failed for user {userId}, cleaning up mandate {mandateId}: {e}")
        try:
            self._rollbackProvisionedMandate(mandateId)
        except Exception as cleanupError:
            # Never let a cleanup problem mask the original provisioning error
            logger.error(f"Cleanup of partially provisioned mandate {mandateId} failed: {cleanupError}")
        raise ValueError(f"Mandate provisioning failed: {e}") from e

def _rollbackProvisionedMandate(self, mandateId: str) -> None:
    """Best-effort removal of what a failed ``_provisionMandateForUser`` created.

    Bypasses RBAC. Children are deleted before parents, in this order:
    1. Feature instances of the mandate, with their access-role links,
       accesses and instance roles.
    2. Memberships, with their membership-role links.
    3. Subscriptions.
    4. Mandate roles, with their access rules.
    5. The mandate itself.

    Every fetch and delete is attempted independently; failures are logged as
    warnings and never raised.
    """
    from modules.datamodels.datamodelSubscription import MandateSubscription

    def _fetch(model, recordFilter: Dict[str, Any]) -> List[Dict[str, Any]]:
        try:
            return self.db.getRecordset(model, recordFilter=recordFilter) or []
        except Exception as err:
            logger.warning(f"Rollback of mandate {mandateId}: listing {model.__name__} failed: {err}")
            return []

    def _delete(model, recordId: Any) -> None:
        try:
            self.db.recordDelete(model, recordId)
        except Exception as err:
            logger.warning(f"Rollback of mandate {mandateId}: deleting {model.__name__} {recordId} failed: {err}")

    def _deleteRoles(recordFilter: Dict[str, Any]) -> None:
        for role in _fetch(Role, recordFilter):
            roleId = role.get("id")
            if not roleId:
                continue  # a None filter value could match unrelated rows
            for rule in _fetch(AccessRule, {"roleId": roleId}):
                _delete(AccessRule, rule.get("id"))
            _delete(Role, roleId)

    for inst in _fetch(FeatureInstance, {"mandateId": mandateId}):
        instId = inst.get("id")
        if not instId:
            continue
        for access in _fetch(FeatureAccess, {"featureInstanceId": instId}):
            accessId = access.get("id")
            if accessId:
                for accessRole in _fetch(FeatureAccessRole, {"featureAccessId": accessId}):
                    _delete(FeatureAccessRole, accessRole.get("id"))
            _delete(FeatureAccess, accessId)
        _deleteRoles({"featureInstanceId": instId})
        _delete(FeatureInstance, instId)

    for membership in _fetch(UserMandate, {"mandateId": mandateId}):
        membershipId = membership.get("id")
        if membershipId:
            for membershipRole in _fetch(UserMandateRole, {"userMandateId": membershipId}):
                _delete(UserMandateRole, membershipRole.get("id"))
        _delete(UserMandate, membershipId)

    for sub in _fetch(MandateSubscription, {"mandateId": mandateId}):
        _delete(MandateSubscription, sub.get("id"))

    _deleteRoles({"mandateId": mandateId})
    _delete(Mandate, mandateId)
|
|
|
|
def _activatePendingSubscriptions(self, userId: str) -> int:
    """
    Activate the PENDING subscriptions of every mandate ``userId`` is an enabled member of.

    Called on login, so a trial period starts at first login rather than at
    registration. For each pending subscription:
    - plan has ``trialDays``: status TRIALING; ``trialEndsAt`` and
      ``currentPeriodEnd`` are set to now + trialDays.
    - otherwise status ACTIVE; with a MONTHLY/YEARLY ``billingPeriod``,
      ``currentPeriodEnd`` is now + 30/365 days.
    ``currentPeriodStart`` is always now (UTC, ISO-8601). A failing update is
    logged and skipped.

    Returns:
        Number of subscriptions successfully updated.
    """
    from modules.datamodels.datamodelSubscription import (
        MandateSubscription, SubscriptionStatusEnum, BUILTIN_PLANS,
    )
    from datetime import datetime, timezone, timedelta

    activatedCount = 0
    memberships = self.db.getRecordset(
        UserMandate, recordFilter={"userId": userId, "enabled": True}
    )

    for membership in memberships:
        mandateId = membership.get("mandateId")
        pendingSubs = self.db.getRecordset(
            MandateSubscription,
            recordFilter={"mandateId": mandateId, "status": SubscriptionStatusEnum.PENDING.value},
        )

        for pending in pendingSubs:
            subId = pending.get("id")
            planKey = pending.get("planKey")
            plan = BUILTIN_PLANS.get(planKey)
            startedAt = datetime.now(timezone.utc)
            hasTrial = bool(plan and plan.trialDays)

            statusEnum = SubscriptionStatusEnum.TRIALING if hasTrial else SubscriptionStatusEnum.ACTIVE
            updateData = {
                "status": statusEnum.value,
                "currentPeriodStart": startedAt.isoformat(),
            }

            if hasTrial:
                trialEndIso = (startedAt + timedelta(days=plan.trialDays)).isoformat()
                updateData["trialEndsAt"] = trialEndIso
                updateData["currentPeriodEnd"] = trialEndIso
            elif plan and plan.billingPeriod:
                from modules.datamodels.datamodelSubscription import BillingPeriodEnum
                if plan.billingPeriod == BillingPeriodEnum.MONTHLY:
                    updateData["currentPeriodEnd"] = (startedAt + timedelta(days=30)).isoformat()
                elif plan.billingPeriod == BillingPeriodEnum.YEARLY:
                    updateData["currentPeriodEnd"] = (startedAt + timedelta(days=365)).isoformat()

            try:
                self.db.recordModify(MandateSubscription, subId, updateData)
            except Exception as e:
                logger.error(f"Failed to activate subscription {subId}: {e}")
                continue
            activatedCount += 1
            logger.info(f"Activated subscription {subId} (plan={planKey}) for mandate {mandateId}: {updateData.get('status')}")

    return activatedCount
|
|
|
|
def updateMandate(self, mandateId: str, updateData: Dict[str, Any]) -> Mandate:
    """Apply ``updateData`` to a mandate the current user may update.

    The permission check happens first; then the current mandate is loaded
    (RBAC-filtered). The ``id`` and ``isSystem`` fields are never changed.
    The remaining keys are merged over the current data, validated through the
    Mandate model and written back. The freshly re-read mandate is returned.

    Raises:
        ValueError: wraps every failure (missing permission, unknown mandate,
            validation or DB errors, failed re-read).
    """
    try:
        if not self.checkRbacPermission(Mandate, "update", mandateId):
            raise PermissionError(f"No permission to update mandate {mandateId}")

        current = self.getMandate(mandateId)
        if not current:
            raise ValueError(f"Mandate {mandateId} not found")

        merged = current.model_dump()
        merged.update({
            field: value
            for field, value in updateData.items()
            if field not in ("id", "isSystem")  # immutable / protected fields
        })
        self.db.recordModify(Mandate, mandateId, Mandate(**merged))

        refreshed = self.getMandate(mandateId)
        if not refreshed:
            raise ValueError("Failed to retrieve updated mandate")
        return refreshed

    except Exception as e:
        logger.error(f"Error updating mandate: {str(e)}")
        raise ValueError(f"Failed to update mandate: {str(e)}")
|
|
|
|
def deleteMandate(self, mandateId: str, force: bool = False) -> bool:
    """
    Delete a mandate.

    force=False (default): soft delete — only sets ``enabled=False``.
    force=True: hard delete with cascade, children before parents:
    1. Each feature instance: its FeatureAccessRole links, FeatureAccess rows
       and instance-bound roles (with their AccessRules), then the instance.
    2. Each membership: its UserMandateRole links, then the UserMandate.
    3. All MandateSubscriptions.
    4. The remaining mandate roles, with their AccessRules.
    5. The mandate record.

    System mandates (isSystem=True) cannot be deleted.

    Returns:
        False if the mandate is not visible to the current user; True after a
        soft delete; the ``recordDelete`` result for the mandate after a hard
        delete.

    Raises:
        ValueError: wraps every failure, including the system-mandate guard and
            a missing "delete" permission (original cause chained).
    """
    try:
        mandate = self.getMandate(mandateId)
        if not mandate:
            return False

        if getattr(mandate, "isSystem", False):
            raise ValueError(f"System mandate '{mandate.name}' cannot be deleted")

        if not self.checkRbacPermission(Mandate, "delete", mandateId):
            raise PermissionError(f"No permission to delete mandate {mandateId}")

        if not force:
            self.db.recordModify(Mandate, mandateId, {"enabled": False})
            logger.info(f"Soft-deleted mandate {mandateId}")
            return True

        from modules.datamodels.datamodelSubscription import MandateSubscription

        def _deleteRolesWithRules(recordFilter: Dict[str, Any]) -> int:
            # Delete the roles matching recordFilter plus their AccessRules; return the role count
            roles = self.db.getRecordset(Role, recordFilter=recordFilter)
            for role in roles:
                roleId = role.get("id")
                for rule in self.db.getRecordset(AccessRule, recordFilter={"roleId": roleId}):
                    self.db.recordDelete(AccessRule, rule.get("id"))
                self.db.recordDelete(Role, roleId)
            return len(roles)

        # 1. Feature instances: FeatureAccessRole -> FeatureAccess -> instance roles -> FeatureInstance
        instances = self.db.getRecordset(FeatureInstance, recordFilter={"mandateId": mandateId})
        for inst in instances:
            instId = inst.get("id")
            for access in self.db.getRecordset(FeatureAccess, recordFilter={"featureInstanceId": instId}):
                accessId = access.get("id")
                for accessRole in self.db.getRecordset(FeatureAccessRole, recordFilter={"featureAccessId": accessId}):
                    self.db.recordDelete(FeatureAccessRole, accessRole.get("id"))
                self.db.recordDelete(FeatureAccess, accessId)
            if instId:
                # Guarded: a None filter value would match every non-instance (template/mandate) role
                _deleteRolesWithRules({"featureInstanceId": instId})
            self.db.recordDelete(FeatureInstance, instId)
        logger.info(f"Cascade: deleted {len(instances)} FeatureInstances for mandate {mandateId}")

        # 2. Memberships: UserMandateRole -> UserMandate
        memberships = self.db.getRecordset(UserMandate, recordFilter={"mandateId": mandateId})
        for um in memberships:
            umId = um.get("id")
            for umRole in self.db.getRecordset(UserMandateRole, recordFilter={"userMandateId": umId}):
                self.db.recordDelete(UserMandateRole, umRole.get("id"))
            self.db.recordDelete(UserMandate, umId)
        logger.info(f"Cascade: deleted {len(memberships)} UserMandates for mandate {mandateId}")

        # 3. Subscriptions
        subs = self.db.getRecordset(MandateSubscription, recordFilter={"mandateId": mandateId})
        for sub in subs:
            self.db.recordDelete(MandateSubscription, sub.get("id"))
        logger.info(f"Cascade: deleted {len(subs)} subscriptions for mandate {mandateId}")

        # 4. Remaining mandate-level roles and their AccessRules
        roleCount = _deleteRolesWithRules({"mandateId": mandateId})
        logger.info(f"Cascade: deleted {roleCount} Roles for mandate {mandateId}")

        # 5. The mandate record itself
        success = self.db.recordDelete(Mandate, mandateId)
        logger.info(f"Hard-deleted mandate {mandateId}")
        return success

    except Exception as e:
        logger.error(f"Error deleting mandate: {str(e)}")
        raise ValueError(f"Failed to delete mandate: {str(e)}") from e
|
|
|
|
# ============================================
|
|
# User-Mandate Membership Methods (Multi-Tenant)
|
|
# ============================================
|
|
|
|
def getUserMandate(self, userId: str, mandateId: str) -> Optional[UserMandate]:
    """
    Fetch the membership record linking ``userId`` to ``mandateId``.

    Reads the UserMandate table directly (no RBAC, no ``enabled`` filter).

    Args:
        userId: User ID
        mandateId: Mandate ID

    Returns:
        The first matching UserMandate with ``_``-prefixed DB fields removed,
        or None when no membership exists or the lookup fails (logged).
    """
    try:
        memberships = self.db.getRecordset(
            UserMandate,
            recordFilter={"userId": userId, "mandateId": mandateId},
        )
        if memberships:
            first = memberships[0]
            return UserMandate(**{key: value for key, value in first.items() if not key.startswith("_")})
        return None
    except Exception as e:
        logger.error(f"Error getting UserMandate: {e}")
        return None
|
|
|
|
def getUserMandates(self, userId: str) -> List[UserMandate]:
    """
    List the user's enabled mandate memberships.

    Only UserMandate rows with ``enabled=True`` are returned; disabled
    memberships are excluded.

    Args:
        userId: User ID

    Returns:
        List of UserMandate objects (``_``-prefixed DB fields removed); an
        empty list on error (logged).
    """
    try:
        activeMemberships = self.db.getRecordset(
            UserMandate,
            recordFilter={"userId": userId, "enabled": True},
        )
        return [
            UserMandate(**{key: value for key, value in membership.items() if not key.startswith("_")})
            for membership in activeMemberships
        ]
    except Exception as e:
        logger.error(f"Error getting UserMandates: {e}")
        return []
|
|
|
|
def createUserMandate(self, userId: str, mandateId: str, roleIds: Optional[List[str]] = None) -> UserMandate:
    """
    Add a user to a mandate (create a UserMandate record).

    Steps:
    1. Refuse duplicate memberships.
    2. Enforce the subscription's user capacity.
    3. Insert the membership and link the given roles via UserMandateRole.
    4. Ensure the user's billing account (if billing is configured).
    5. Sync the Stripe subscription quantity.

    Args:
        userId: User ID
        mandateId: Mandate ID
        roleIds: Optional list of role IDs to assign

    Returns:
        Created UserMandate object

    Raises:
        ValueError: wraps every failure — already a member, capacity exceeded,
            a failed insert, or any downstream error (original cause chained).
    """
    try:
        if self.getUserMandate(userId, mandateId):
            raise ValueError(f"User {userId} is already member of mandate {mandateId}")

        # Subscription capacity check (before insert)
        self._checkSubscriptionCapacity(mandateId, "users", delta=1)

        userMandate = UserMandate(
            userId=userId,
            mandateId=mandateId,
            enabled=True
        )
        createdRecord = self.db.recordCreate(UserMandate, userMandate.model_dump())
        if not createdRecord or not createdRecord.get("id"):
            # Stop before linking roles, creating billing accounts or syncing Stripe
            # for a membership that was never persisted.
            raise ValueError(f"Failed to persist UserMandate for user {userId} in mandate {mandateId}")

        userMandateId = createdRecord["id"]
        for roleId in roleIds or []:
            userMandateRole = UserMandateRole(
                userMandateId=userMandateId,
                roleId=roleId
            )
            self.db.recordCreate(UserMandateRole, userMandateRole.model_dump())

        # Create billing account for user if billing is configured
        self._ensureUserBillingAccount(userId, mandateId)

        # Sync Stripe quantity after successful insert
        self._syncSubscriptionQuantity(mandateId)

        cleanedRecord = {k: v for k, v in createdRecord.items() if not k.startswith("_")}
        return UserMandate(**cleanedRecord)
    except Exception as e:
        logger.error(f"Error creating UserMandate: {e}")
        raise ValueError(f"Failed to create UserMandate: {e}") from e
|
|
|
|
def _ensureUserBillingAccount(self, userId: str, mandateId: str) -> None:
    """
    Make sure ``userId`` has a billing account in ``mandateId`` when billing is configured.

    A user account is created for every billing model (for the audit trail).
    Its start balance is:
    - PREPAY_USER: ``defaultUserCredit`` from the mandate's billing settings,
      but only when ``mandateId`` is the root mandate. A missing or None
      value counts as 0.0, and other mandates start at 0.0.
    - Any other model (e.g. PREPAY_MANDATE): 0.0 — the budget lives on the
      shared pool.

    Failures are logged as warnings and never raised (non-critical).

    Args:
        userId: User ID
        mandateId: Mandate ID
    """
    try:
        from modules.interfaces.interfaceDbBilling import _getRootInterface as getBillingRootInterface
        from modules.datamodels.datamodelBilling import BillingModelEnum, parseBillingModelFromStoredValue

        billingInterface = getBillingRootInterface()
        settings = billingInterface.getSettings(mandateId)

        if not settings:
            return  # No billing configured for this mandate

        billingModel = parseBillingModelFromStoredValue(settings.get("billingModel"))

        initialBalance = 0.0
        if billingModel == BillingModelEnum.PREPAY_USER:
            # Start credit is granted only on the root mandate
            rootMandateId = self._getRootMandateId()
            isRootMandate = rootMandateId is not None and str(mandateId) == str(rootMandateId)
            # .get(key, 0.0) does not cover a stored None, and float(None) would abort account creation
            configuredCredit = settings.get("defaultUserCredit")
            if isRootMandate and configuredCredit is not None:
                initialBalance = float(configuredCredit)

        billingInterface.getOrCreateUserAccount(mandateId, userId, initialBalance=initialBalance)

        # Logging must not fail (and report "Failed to create") after the account already exists
        modelLabel = getattr(billingModel, "value", billingModel)
        logger.info(f"Ensured billing account for user {userId} in mandate {mandateId} (model={modelLabel}, initial={initialBalance} CHF)")

    except Exception as e:
        logger.warning(f"Failed to create billing account for user {userId} (non-critical): {e}")
|
|
|
|
def _checkSubscriptionCapacity(self, mandateId: str, resourceType: str, delta: int = 1) -> None:
    """
    Check subscription capacity before creating a resource. Raises on cap violation.

    Delegates to the subscription interface's ``assertCapacity``. An exception
    whose class, or any of its base classes, has a name containing
    ``SubscriptionCapacityException`` is re-raised to the caller. Any other
    failure (including import errors) is logged at debug level, and the check
    is skipped (best effort).

    Args:
        mandateId: Mandate the resource is created in
        resourceType: Resource kind passed through to ``assertCapacity`` (e.g. "users")
        delta: Number of resources about to be added
    """
    try:
        from modules.interfaces.interfaceDbSubscription import getInterface as getSubInterface
        from modules.security.rootAccess import getRootUser
        subIf = getSubInterface(getRootUser(), mandateId)
        subIf.assertCapacity(mandateId, resourceType, delta)
    except Exception as e:
        # The exception class is matched by name (it is not imported here). Walk
        # the MRO so subclasses of the capacity exception are also enforced
        # instead of being silently skipped (fail-open).
        if any("SubscriptionCapacityException" in cls.__name__ for cls in type(e).__mro__):
            raise
        logger.debug(f"Subscription capacity check skipped: {e}")
|
|
|
|
def _syncSubscriptionQuantity(self, mandateId: str) -> None:
    """
    Push the mandate's current subscription quantities to Stripe.

    Best effort: any exception (including failing imports) is logged at debug
    level and swallowed, so the calling mutation is never aborted by it.

    Args:
        mandateId: Mandate whose subscription should be synced
    """
    try:
        from modules.interfaces.interfaceDbSubscription import getInterface as getSubInterface
        from modules.security.rootAccess import getRootUser

        getSubInterface(getRootUser(), mandateId).syncQuantityToStripe(mandateId)
    except Exception as e:
        logger.debug(f"Subscription quantity sync skipped: {e}")
|
|
|
|
def deleteUserMandate(self, userId: str, mandateId: str) -> bool:
    """
    Delete a UserMandate record (remove a user from a mandate).

    Before the membership row is deleted, the user's FeatureAccess rows for every
    feature instance owned by this mandate are removed. Per the schema design,
    UserMandateRole rows (and FeatureAccessRole rows of the removed accesses) are
    expected to be removed by database CASCADE; this method does not delete
    them itself.

    Args:
        userId: User ID
        mandateId: Mandate ID

    Returns:
        The result of deleting the UserMandate row; False if the user is not a member.

    Raises:
        ValueError: If any lookup or delete fails (wraps the underlying error).
    """
    try:
        existing = self.getUserMandate(userId, mandateId)
        if not existing:
            return False

        # IDs of all feature instances owned by this mandate.
        mandateInstanceIds = {
            str(row.get("id"))
            for row in self.db.getRecordset(FeatureInstance, recordFilter={"mandateId": mandateId})
            if row.get("id")
        }

        # Fetch the user's FeatureAccess rows once and filter in memory, instead
        # of issuing one query per feature instance.
        if mandateInstanceIds:
            accessRows = self.db.getRecordset(FeatureAccess, recordFilter={"userId": userId})
            for acc in accessRows:
                accId = acc.get("id")
                instId = acc.get("featureInstanceId")
                if accId and instId and str(instId) in mandateInstanceIds:
                    self.db.recordDelete(FeatureAccess, accId)

        return self.db.recordDelete(UserMandate, existing.id)
    except Exception as e:
        logger.error(f"Error deleting UserMandate: {e}")
        raise ValueError(f"Failed to delete UserMandate: {e}") from e
|
|
|
|
def getUserMandatesByMandate(self, mandateId: str) -> List[UserMandate]:
    """
    List every UserMandate membership row belonging to one mandate.

    Args:
        mandateId: Mandate whose memberships are requested

    Returns:
        UserMandate models built from the stored rows (keys starting with "_"
        are dropped first); an empty list if the lookup or any model
        construction fails (the error is logged).
    """
    try:
        rows = self.db.getRecordset(UserMandate, recordFilter={"mandateId": mandateId})
        return [
            UserMandate(**{key: value for key, value in row.items() if not key.startswith("_")})
            for row in rows
        ]
    except Exception as e:
        logger.error(f"Error getting UserMandates for mandate {mandateId}: {e}")
        return []
|
|
|
|
def getUserMandateRoles(self, userMandateId: str) -> List[UserMandateRole]:
    """
    List the role-assignment (junction) rows of one UserMandate.

    Args:
        userMandateId: UserMandate ID

    Returns:
        UserMandateRole models with "_"-prefixed storage keys removed; an empty
        list if anything fails (the error is logged).
    """
    try:
        junctionRows = self.db.getRecordset(
            UserMandateRole, recordFilter={"userMandateId": userMandateId}
        )
        assignments: List[UserMandateRole] = []
        for row in junctionRows:
            publicFields = {key: value for key, value in row.items() if not key.startswith("_")}
            assignments.append(UserMandateRole(**publicFields))
        return assignments
    except Exception as e:
        logger.error(f"Error getting UserMandateRoles: {e}")
        return []
|
|
|
|
def deleteUserMandateRoles(self, userMandateId: str) -> int:
    """
    Delete all role assignments (UserMandateRole rows) for a UserMandate.

    Rows without an id are skipped. Only deletions that ``recordDelete``
    reports as successful are counted.

    Args:
        userMandateId: UserMandate ID

    Returns:
        Number of deleted role assignments. If an error occurs part-way, it is
        logged and the number of deletions completed so far is returned.
    """
    deletedCount = 0
    try:
        records = self.db.getRecordset(
            UserMandateRole,
            recordFilter={"userMandateId": userMandateId}
        )
        for record in records:
            recordId = record.get("id")
            if not recordId:
                continue  # nothing addressable to delete
            if self.db.recordDelete(UserMandateRole, recordId):
                deletedCount += 1
    except Exception as e:
        logger.error(f"Error deleting UserMandateRoles: {e}")
    return deletedCount
|
|
|
|
def validateRoleForMandate(self, roleId: str, mandateId: str) -> Role:
    """
    Look up a role and confirm it may be assigned at the level of a mandate.

    A role qualifies when it exists, is either global (no mandateId) or owned by
    ``mandateId``, and is not bound to a feature instance.

    Args:
        roleId: Role ID to validate
        mandateId: Mandate the role is about to be assigned in

    Returns:
        The Role object returned by ``self.getRole``.

    Raises:
        ValueError: If the role does not exist, belongs to another mandate, or
            is a feature-instance role.
    """
    role = self.getRole(roleId)
    if not role:
        raise ValueError(f"Role {roleId} not found")

    ownerMandate = role.mandateId
    if ownerMandate and str(ownerMandate) != str(mandateId):
        # Mandate-scoped roles may only be used inside their own mandate.
        raise ValueError(f"Role {roleId} belongs to a different mandate")

    if role.featureInstanceId:
        # Roles bound to a feature instance are not assignable at mandate level.
        raise ValueError(f"Role {roleId} is a feature-instance role and cannot be assigned at mandate level")

    return role
|
|
|
|
def getRoleIdsForUserMandate(self, userMandateId: str) -> List[str]:
    """
    Collect the role IDs assigned to a UserMandate.

    Args:
        userMandateId: UserMandate ID

    Returns:
        The non-empty ``roleId`` values of the UserMandateRole rows; an empty
        list on error (the error is logged).
    """
    try:
        junctionRows = self.db.getRecordset(
            UserMandateRole, recordFilter={"userMandateId": userMandateId}
        )
        roleIds: List[str] = []
        for row in junctionRows:
            assignedRoleId = row.get("roleId")
            if assignedRoleId:
                roleIds.append(assignedRoleId)
        return roleIds
    except Exception as e:
        logger.error(f"Error getting role IDs for UserMandate: {e}")
        return []
|
|
|
|
def addRoleToUserMandate(self, userMandateId: str, roleId: str) -> UserMandateRole:
    """
    Assign a role to a UserMandate, reusing an existing assignment if present.

    Args:
        userMandateId: UserMandate ID
        roleId: Role ID to add

    Returns:
        The existing or newly created UserMandateRole.

    Raises:
        ValueError: If the lookup or insert fails (wraps the underlying error).
    """
    try:
        matches = self.db.getRecordset(
            UserMandateRole,
            recordFilter={"userMandateId": userMandateId, "roleId": roleId}
        )
        if matches:
            # Idempotent: hand back the existing assignment instead of duplicating it.
            row = matches[0]
        else:
            newAssignment = UserMandateRole(userMandateId=userMandateId, roleId=roleId)
            row = self.db.recordCreate(UserMandateRole, newAssignment.model_dump())
        return UserMandateRole(**{key: value for key, value in row.items() if not key.startswith("_")})
    except Exception as e:
        logger.error(f"Error adding role to UserMandate: {e}")
        raise ValueError(f"Failed to add role: {e}")
|
|
|
|
def removeRoleFromUserMandate(self, userMandateId: str, roleId: str) -> bool:
    """
    Unassign a role from a UserMandate.

    Deletes the first matching UserMandateRole row. If the UserMandate has no
    role rows left afterwards, the UserMandate itself is deleted
    (application-level cleanup).

    Args:
        userMandateId: UserMandate ID
        roleId: Role ID to remove

    Returns:
        True if a matching assignment was found and removed, False if none existed.

    Raises:
        ValueError: If a lookup or delete fails (wraps the underlying error).
    """
    try:
        matches = self.db.getRecordset(
            UserMandateRole,
            recordFilter={"userMandateId": userMandateId, "roleId": roleId}
        )
        if not matches:
            return False
        self.db.recordDelete(UserMandateRole, matches[0].get("id"))

        # A membership without any role is removed.
        # NOTE(review): unlike deleteUserMandate, this path does not remove the
        # user's FeatureAccess rows for the mandate's feature instances — confirm intended.
        stillAssigned = self.db.getRecordset(
            UserMandateRole, recordFilter={"userMandateId": userMandateId}
        )
        if not stillAssigned:
            self.db.recordDelete(UserMandate, userMandateId)
            logger.info(f"Deleted empty UserMandate {userMandateId}")
        return True
    except Exception as e:
        logger.error(f"Error removing role from UserMandate: {e}")
        raise ValueError(f"Failed to remove role: {e}")
|
|
|
|
# ============================================
|
|
# Feature Access Methods (Multi-Tenant)
|
|
# ============================================
|
|
|
|
def getFeatureAccess(self, userId: str, featureInstanceId: str) -> Optional[FeatureAccess]:
    """
    Fetch the FeatureAccess linking a user to a feature instance.

    Args:
        userId: User ID
        featureInstanceId: FeatureInstance ID

    Returns:
        The first matching FeatureAccess ("_"-prefixed keys removed), or None
        if there is none or an error occurred (the error is logged).
    """
    try:
        matches = self.db.getRecordset(
            FeatureAccess,
            recordFilter={"userId": userId, "featureInstanceId": featureInstanceId}
        )
        if matches:
            publicFields = {key: value for key, value in matches[0].items() if not key.startswith("_")}
            return FeatureAccess(**publicFields)
        return None
    except Exception as e:
        logger.error(f"Error getting FeatureAccess: {e}")
        return None
|
|
|
|
def getFeatureAccessesForUser(self, userId: str) -> List[FeatureAccess]:
    """
    List the user's enabled FeatureAccess records.

    Only rows with ``enabled == True`` are returned; disabled accesses are omitted.

    Args:
        userId: User ID

    Returns:
        FeatureAccess models ("_"-prefixed keys removed); an empty list on
        error (the error is logged).
    """
    try:
        rows = self.db.getRecordset(FeatureAccess, recordFilter={"userId": userId, "enabled": True})
        return [
            FeatureAccess(**{key: value for key, value in row.items() if not key.startswith("_")})
            for row in rows
        ]
    except Exception as e:
        logger.error(f"Error getting FeatureAccesses: {e}")
        return []
|
|
|
|
def getFeatureAccessesByInstance(self, featureInstanceId: str) -> List[FeatureAccess]:
    """
    List every FeatureAccess record (enabled or not) for one feature instance.

    Args:
        featureInstanceId: FeatureInstance ID

    Returns:
        FeatureAccess models ("_"-prefixed keys removed); an empty list on
        error (the error is logged).
    """
    try:
        rows = self.db.getRecordset(
            FeatureAccess, recordFilter={"featureInstanceId": featureInstanceId}
        )
        accesses: List[FeatureAccess] = []
        for row in rows:
            accesses.append(FeatureAccess(**{key: value for key, value in row.items() if not key.startswith("_")}))
        return accesses
    except Exception as e:
        logger.error(f"Error getting FeatureAccesses for instance {featureInstanceId}: {e}")
        return []
|
|
|
|
def createFeatureAccess(self, userId: str, featureInstanceId: str, roleIds: List[str] = None) -> FeatureAccess:
    """
    Grant a user access to a feature instance.

    Before the FeatureAccess row is created, the user is made a member of the
    instance's mandate via ``_ensureUserMandateMembership``. That helper does
    not propagate its own errors. Any given role IDs are linked through
    FeatureAccessRole rows.

    Args:
        userId: User ID
        featureInstanceId: FeatureInstance ID
        roleIds: Optional list of role IDs to assign

    Returns:
        The created FeatureAccess.

    Raises:
        ValueError: If access already exists or any step fails (all errors are
            wrapped as "Failed to create FeatureAccess: ...").
    """
    try:
        if self.getFeatureAccess(userId, featureInstanceId):
            raise ValueError(f"User {userId} already has access to feature instance {featureInstanceId}")

        self._ensureUserMandateMembership(userId, featureInstanceId)

        newAccess = FeatureAccess(userId=userId, featureInstanceId=featureInstanceId, enabled=True)
        stored = self.db.recordCreate(FeatureAccess, newAccess.model_dump())

        if roleIds and stored:
            accessId = stored.get("id")
            for assignedRoleId in roleIds:
                link = FeatureAccessRole(featureAccessId=accessId, roleId=assignedRoleId)
                self.db.recordCreate(FeatureAccessRole, link.model_dump())

        return FeatureAccess(**{key: value for key, value in stored.items() if not key.startswith("_")})
    except Exception as e:
        logger.error(f"Error creating FeatureAccess: {e}")
        raise ValueError(f"Failed to create FeatureAccess: {e}")
|
|
|
|
def _ensureUserMandateMembership(self, userId: str, featureInstanceId: str) -> None:
    """
    Ensure the user is a member of the mandate that owns a feature instance.

    If the user has no UserMandate for that mandate yet, one is created. It gets
    the mandate-level 'user' role (a membership marker without access rights)
    when such a role exists, and no roles otherwise.

    Best effort: every failure is logged and none is raised to the caller.

    Args:
        userId: User ID
        featureInstanceId: FeatureInstance whose owning mandate is used
    """
    try:
        from modules.interfaces.interfaceFeatures import getFeatureInterface

        featureInterface = getFeatureInterface(self.db)
        instance = featureInterface.getFeatureInstance(featureInstanceId)
        if not instance or not instance.mandateId:
            logger.warning(f"Cannot auto-assign mandate: feature instance {featureInstanceId} not found or has no mandateId")
            return

        mandateId = str(instance.mandateId)

        # Check if user already has mandate membership
        existing = self.getUserMandate(userId, mandateId)
        if existing:
            logger.debug(f"User {userId} already member of mandate {mandateId}")
            return

        # Find the mandate-level 'user' role (membership marker, no access rights)
        userRoles = self.db.getRecordset(
            Role,
            recordFilter={"roleLabel": "user", "mandateId": mandateId, "featureInstanceId": None}
        )
        userRoleId = userRoles[0].get("id") if userRoles else None
        roleIds = [userRoleId] if userRoleId else []

        self.createUserMandate(userId, mandateId, roleIds)
        logger.info(f"Auto-assigned user {userId} to mandate {mandateId} with 'user' role (via feature instance {featureInstanceId})")

    except ValueError as e:
        # createUserMandate wraps every failure in ValueError: duplicate
        # membership, subscription capacity violations, database errors.
        # Membership was already checked above, so this is most likely a real
        # failure (or a concurrent insert). Keep the best-effort semantics, but
        # never drop the error silently.
        logger.warning(f"Could not auto-assign user {userId} to the mandate of feature instance {featureInstanceId}: {e}")
    except Exception as e:
        logger.error(f"Error auto-assigning user {userId} to mandate: {e}")
|
|
|
|
def getRoleIdsForFeatureAccess(self, featureAccessId: str) -> List[str]:
    """
    Collect the role IDs linked to a FeatureAccess.

    Args:
        featureAccessId: FeatureAccess ID

    Returns:
        The non-empty ``roleId`` values of the FeatureAccessRole rows; an empty
        list on error (the error is logged).
    """
    try:
        links = self.db.getRecordset(
            FeatureAccessRole, recordFilter={"featureAccessId": featureAccessId}
        )
        linkedRoleIds: List[str] = []
        for link in links:
            linkedRoleId = link.get("roleId")
            if linkedRoleId:
                linkedRoleIds.append(linkedRoleId)
        return linkedRoleIds
    except Exception as e:
        logger.error(f"Error getting role IDs for FeatureAccess: {e}")
        return []
|
|
|
|
def addRoleToFeatureAccess(self, featureAccessId: str, roleId: str) -> None:
    """
    Link a role to a FeatureAccess through the FeatureAccessRole junction table.

    Does nothing if the role is already linked.

    Args:
        featureAccessId: FeatureAccess ID
        roleId: Role ID to add

    Raises:
        ValueError: If the lookup or insert fails (wraps the underlying error).
    """
    try:
        alreadyLinked = self.db.getRecordset(
            FeatureAccessRole,
            recordFilter={"featureAccessId": featureAccessId, "roleId": roleId}
        )
        if not alreadyLinked:
            link = FeatureAccessRole(featureAccessId=featureAccessId, roleId=roleId)
            self.db.recordCreate(FeatureAccessRole, link.model_dump())
            logger.debug(f"Added role {roleId} to FeatureAccess {featureAccessId}")
    except Exception as e:
        logger.error(f"Error adding role to FeatureAccess: {e}")
        raise ValueError(f"Failed to add role to FeatureAccess: {e}")
|
|
|
|
def deleteFeatureAccessRoles(self, featureAccessId: str) -> int:
    """
    Delete all FeatureAccessRole records for a FeatureAccess.

    Rows without an id are skipped. A row is counted once ``recordDelete``
    returns for it without raising.

    Args:
        featureAccessId: FeatureAccess ID

    Returns:
        Number of records deleted. If an error occurs part-way, it is logged
        and the number of deletions completed so far is returned.
    """
    count = 0
    try:
        records = self.db.getRecordset(
            FeatureAccessRole,
            recordFilter={"featureAccessId": featureAccessId}
        )
        for record in records:
            recordId = record.get("id")
            if recordId:
                self.db.recordDelete(FeatureAccessRole, recordId)
                count += 1
    except Exception as e:
        logger.error(f"Error deleting FeatureAccessRoles for {featureAccessId}: {e}")
    return count
|
|
|
|
# ============================================
|
|
# Invitation Methods
|
|
# ============================================
|
|
|
|
def getInvitation(self, invitationId: str) -> Optional[Invitation]:
    """
    Fetch one invitation by its ID.

    Args:
        invitationId: Invitation ID

    Returns:
        The first matching Invitation ("_"-prefixed keys removed), or None if
        it is not found or an error occurred (the error is logged).
    """
    try:
        rows = self.db.getRecordset(Invitation, recordFilter={"id": invitationId})
        if not rows:
            return None
        publicFields = {key: value for key, value in rows[0].items() if not key.startswith("_")}
        return Invitation(**publicFields)
    except Exception as e:
        logger.error(f"Error getting invitation {invitationId}: {e}")
        return None
|
|
|
|
def getInvitationByToken(self, token: str) -> Optional[Invitation]:
    """
    Fetch one invitation by its token.

    The token value is deliberately not included in the error log.

    Args:
        token: Invitation token

    Returns:
        The first matching Invitation ("_"-prefixed keys removed), or None if
        it is not found or an error occurred (the error is logged).
    """
    try:
        rows = self.db.getRecordset(Invitation, recordFilter={"token": token})
        if not rows:
            return None
        publicFields = {key: value for key, value in rows[0].items() if not key.startswith("_")}
        return Invitation(**publicFields)
    except Exception as e:
        logger.error(f"Error getting invitation by token: {e}")
        return None
|
|
|
|
def getInvitationsByMandate(self, mandateId: str) -> List[Invitation]:
    """
    List all invitations issued for a mandate.

    Args:
        mandateId: Mandate ID

    Returns:
        Invitation models ("_"-prefixed keys removed); an empty list on error
        (the error is logged).
    """
    try:
        rows = self.db.getRecordset(Invitation, recordFilter={"mandateId": mandateId})
        return [
            Invitation(**{key: value for key, value in row.items() if not key.startswith("_")})
            for row in rows
        ]
    except Exception as e:
        logger.error(f"Error getting invitations for mandate {mandateId}: {e}")
        return []
|
|
|
|
def getInvitationsByCreator(self, creatorId: str) -> List[Invitation]:
    """
    List all invitations whose ``createdBy`` is the given user.

    Args:
        creatorId: User ID who created the invitations

    Returns:
        Invitation models ("_"-prefixed keys removed); an empty list on error
        (the error is logged).
    """
    try:
        rows = self.db.getRecordset(Invitation, recordFilter={"createdBy": creatorId})
        return [
            Invitation(**{key: value for key, value in row.items() if not key.startswith("_")})
            for row in rows
        ]
    except Exception as e:
        logger.error(f"Error getting invitations by creator {creatorId}: {e}")
        return []
|
|
|
|
def getInvitationsByUsedBy(self, usedById: str) -> List[Invitation]:
    """
    List all invitations whose ``usedBy`` is the given user.

    Args:
        usedById: User ID who used the invitations

    Returns:
        Invitation models ("_"-prefixed keys removed); an empty list on error
        (the error is logged).
    """
    try:
        rows = self.db.getRecordset(Invitation, recordFilter={"usedBy": usedById})
        return [
            Invitation(**{key: value for key, value in row.items() if not key.startswith("_")})
            for row in rows
        ]
    except Exception as e:
        logger.error(f"Error getting invitations used by {usedById}: {e}")
        return []
|
|
|
|
def getInvitationsByTargetUsername(self, targetUsername: str) -> List[Invitation]:
    """
    List all invitations addressed to a target username.

    Args:
        targetUsername: Target username for the invitations

    Returns:
        Invitation models ("_"-prefixed keys removed); an empty list on error
        (the error is logged).
    """
    try:
        rows = self.db.getRecordset(Invitation, recordFilter={"targetUsername": targetUsername})
        return [
            Invitation(**{key: value for key, value in row.items() if not key.startswith("_")})
            for row in rows
        ]
    except Exception as e:
        logger.error(f"Error getting invitations for target username {targetUsername}: {e}")
        return []
|
|
|
|
# ============================================
|
|
# Additional Helper Methods
|
|
# ============================================
|
|
|
|
def getAllUsers(self, pagination: Optional[PaginationParams] = None) -> Union[List[User], PaginatedResult]:
    """
    List all users (for SysAdmin only), with credential fields stripped.

    ``hashedPassword``, ``resetToken``, ``resetTokenExpires`` and every
    "_"-prefixed storage key are removed before the User model is built; a
    null ``roleLabels`` is normalized to an empty list.

    Args:
        pagination: Optional pagination parameters. If None, returns all items.

    Returns:
        A plain List[User] when ``pagination`` is None, otherwise a
        PaginatedResult with items and totals. On error the matching empty
        value is returned and the error is logged.
    """
    try:
        page = self.db.getRecordsetPaginated(UserInDB, pagination=pagination)

        hiddenFields = ("hashedPassword", "resetToken", "resetTokenExpires")
        users = []
        for row in page["items"]:
            publicFields = {
                key: value for key, value in row.items()
                if not (key.startswith("_") or key in hiddenFields)
            }
            if publicFields.get("roleLabels") is None:
                publicFields["roleLabels"] = []
            users.append(User(**publicFields))

        if pagination is not None:
            return PaginatedResult(
                items=users,
                totalItems=page["totalItems"],
                totalPages=page["totalPages"]
            )
        return users
    except Exception as e:
        logger.error(f"Error getting all users: {e}")
        return [] if pagination is None else PaginatedResult(items=[], totalItems=0, totalPages=0)
|
|
|
|
def getUserMandateById(self, userMandateId: str) -> Optional[UserMandate]:
    """
    Fetch one UserMandate by its primary key.

    Args:
        userMandateId: UserMandate ID

    Returns:
        The UserMandate ("_"-prefixed keys removed), or None if it is not found
        or an error occurred (the error is logged).
    """
    try:
        rows = self.db.getRecordset(UserMandate, recordFilter={"id": userMandateId})
        if not rows:
            return None
        return UserMandate(**{key: value for key, value in rows[0].items() if not key.startswith("_")})
    except Exception as e:
        logger.error(f"Error getting UserMandate {userMandateId}: {e}")
        return None
|
|
|
|
def getUserMandateRolesByRole(self, roleId: str) -> List[UserMandateRole]:
    """
    List every UserMandate role assignment that references a given role.

    Args:
        roleId: Role ID

    Returns:
        UserMandateRole models ("_"-prefixed keys removed); an empty list on
        error (the error is logged).
    """
    try:
        rows = self.db.getRecordset(UserMandateRole, recordFilter={"roleId": roleId})
        assignments: List[UserMandateRole] = []
        for row in rows:
            assignments.append(UserMandateRole(**{key: value for key, value in row.items() if not key.startswith("_")}))
        return assignments
    except Exception as e:
        logger.error(f"Error getting UserMandateRoles for role {roleId}: {e}")
        return []
|
|
|
|
def getFeatureInstance(self, instanceId: str):
    """
    Fetch one FeatureInstance by its ID.

    Args:
        instanceId: FeatureInstance ID

    Returns:
        The FeatureInstance ("_"-prefixed keys removed), or None if it is not
        found or an error occurred (the error is logged).
    """
    try:
        rows = self.db.getRecordset(FeatureInstance, recordFilter={"id": instanceId})
        if not rows:
            return None
        return FeatureInstance(**{key: value for key, value in rows[0].items() if not key.startswith("_")})
    except Exception as e:
        logger.error(f"Error getting FeatureInstance {instanceId}: {e}")
        return None
|
|
|
|
def getFeatureByCode(self, featureCode: str) -> Optional[Feature]:
    """
    Fetch a Feature by its ``code``.

    Args:
        featureCode: Feature code

    Returns:
        The first matching Feature ("_"-prefixed keys removed), or None if it
        is not found or an error occurred (the error is logged).
    """
    try:
        rows = self.db.getRecordset(Feature, recordFilter={"code": featureCode})
        if not rows:
            return None
        return Feature(**{key: value for key, value in rows[0].items() if not key.startswith("_")})
    except Exception as e:
        logger.error(f"Error getting Feature by code {featureCode}: {e}")
        return None
|
|
|
|
def getFeatureInstancesByMandate(self, mandateId: str, enabledOnly: bool = False) -> List[FeatureInstance]:
    """
    List the feature instances owned by a mandate.

    Args:
        mandateId: Mandate ID
        enabledOnly: When True, only instances with ``enabled == True`` are returned

    Returns:
        FeatureInstance models ("_"-prefixed keys removed); an empty list on
        error (the error is logged).
    """
    try:
        criteria = {"mandateId": mandateId, **({"enabled": True} if enabledOnly else {})}
        rows = self.db.getRecordset(FeatureInstance, recordFilter=criteria)
        return [
            FeatureInstance(**{key: value for key, value in row.items() if not key.startswith("_")})
            for row in rows
        ]
    except Exception as e:
        logger.error(f"Error getting FeatureInstances for mandate {mandateId}: {e}")
        return []
|
|
|
|
# ============================================
|
|
# Notification Methods
|
|
# ============================================
|
|
|
|
def getNotification(self, notificationId: str) -> Optional[UserNotification]:
    """
    Fetch one user notification by its ID.

    Args:
        notificationId: Notification ID

    Returns:
        The UserNotification ("_"-prefixed keys removed), or None if it is not
        found or an error occurred (the error is logged).
    """
    try:
        rows = self.db.getRecordset(UserNotification, recordFilter={"id": notificationId})
        if not rows:
            return None
        return UserNotification(**{key: value for key, value in rows[0].items() if not key.startswith("_")})
    except Exception as e:
        logger.error(f"Error getting notification {notificationId}: {e}")
        return None
|
|
|
|
def getNotificationsByUser(
    self,
    userId: str,
    status: Optional[str] = None,
    limit: Optional[int] = None
) -> List[UserNotification]:
    """
    List a user's notifications, newest first.

    Args:
        userId: User ID
        status: Optional status filter (e.g., 'unread'); ignored when falsy
        limit: Optional maximum number of results; ignored when falsy (None or 0)

    Returns:
        UserNotification models sorted by ``createdAt`` descending (a missing
        createdAt sorts as 0), truncated to ``limit``; an empty list on error
        (the error is logged).
    """
    try:
        criteria = {"userId": userId}
        if status:
            criteria["status"] = status
        rows = self.db.getRecordset(UserNotification, recordFilter=criteria)

        newestFirst = sorted(
            (UserNotification(**{key: value for key, value in row.items() if not key.startswith("_")})
             for row in rows),
            key=lambda notification: notification.createdAt or 0,
            reverse=True,
        )
        return newestFirst[:limit] if limit else newestFirst
    except Exception as e:
        logger.error(f"Error getting notifications for user {userId}: {e}")
        return []
|
|
|
|
# ============================================
|
|
# AccessRule Methods
|
|
# ============================================
|
|
|
|
def getAccessRule(self, ruleId: str) -> Optional[AccessRule]:
    """
    Fetch one AccessRule by its ID.

    Args:
        ruleId: AccessRule ID

    Returns:
        The AccessRule ("_"-prefixed keys removed), or None if it is not found
        or an error occurred (the error is logged).
    """
    try:
        rows = self.db.getRecordset(AccessRule, recordFilter={"id": ruleId})
        if not rows:
            return None
        return AccessRule(**{key: value for key, value in rows[0].items() if not key.startswith("_")})
    except Exception as e:
        logger.error(f"Error getting AccessRule {ruleId}: {e}")
        return None
|
|
|
|
def getAccessRulesByRole(self, roleId: str) -> List[AccessRule]:
    """
    List the AccessRules attached to a role.

    Args:
        roleId: Role ID

    Returns:
        AccessRule models ("_"-prefixed keys removed); an empty list on error
        (the error is logged).
    """
    try:
        rows = self.db.getRecordset(AccessRule, recordFilter={"roleId": roleId})
        return [
            AccessRule(**{key: value for key, value in row.items() if not key.startswith("_")})
            for row in rows
        ]
    except Exception as e:
        logger.error(f"Error getting AccessRules for role {roleId}: {e}")
        return []
|
|
|
|
def getRolesByFeatureInstance(self, featureInstanceId: str) -> List[Role]:
    """
    List the roles scoped to a feature instance.

    Args:
        featureInstanceId: FeatureInstance ID

    Returns:
        Role models ("_"-prefixed keys removed); an empty list on error (the
        error is logged).
    """
    try:
        rows = self.db.getRecordset(Role, recordFilter={"featureInstanceId": featureInstanceId})
        return [
            Role(**{key: value for key, value in row.items() if not key.startswith("_")})
            for row in rows
        ]
    except Exception as e:
        logger.error(f"Error getting roles for feature instance {featureInstanceId}: {e}")
        return []
|
|
|
|
def getRolesByFeatureCode(self, featureCode: str, featureInstanceId: Optional[str] = None) -> List[Role]:
    """
    List the roles for a feature code, optionally narrowed to one instance.

    Args:
        featureCode: Feature code
        featureInstanceId: Optional FeatureInstance ID filter (ignored when falsy)

    Returns:
        Role models ("_"-prefixed keys removed); an empty list on error (the
        error is logged).
    """
    try:
        criteria = {"featureCode": featureCode}
        if featureInstanceId:
            criteria["featureInstanceId"] = featureInstanceId
        rows = self.db.getRecordset(Role, recordFilter=criteria)
        matchingRoles: List[Role] = []
        for row in rows:
            matchingRoles.append(Role(**{key: value for key, value in row.items() if not key.startswith("_")}))
        return matchingRoles
    except Exception as e:
        logger.error(f"Error getting roles for feature code {featureCode}: {e}")
        return []
|
|
|
|
# Token methods
|
|
|
|
def saveAccessToken(self, token: Token, replace_existing: bool = True) -> None:
    """
    Save a gateway session (access) token for the current user.

    The token must not carry a connectionId (use saveConnectionToken for
    provider tokens), and its tokenPurpose must be AUTH_SESSION. The token is
    bound to the current user. A missing id or createdAt is filled in.

    When ``replace_existing`` is True, the user's other AUTH_SESSION tokens
    (without connectionId) for the same authority are deleted. This happens
    only after the new token has been stored, so a failed insert never leaves
    the user without any session token. Cleanup failures are logged and do
    not fail the call.

    Args:
        token: Token to persist. It is mutated: userId is set, and id/createdAt
            are set when missing.
        replace_existing: Delete the user's previous session tokens for this authority.

    Raises:
        ValueError: If the token has a connectionId, has the wrong purpose, or
            there is no current user.
        Exception: Errors from storing the token are logged and re-raised.
    """
    try:
        # Validate that this is NOT a connection token
        if token.connectionId:
            raise ValueError(
                "Access tokens cannot have connectionId - use saveConnectionToken instead"
            )

        _tp = (
            token.tokenPurpose.value
            if isinstance(token.tokenPurpose, TokenPurpose)
            else token.tokenPurpose
        )
        if _tp != TokenPurpose.AUTH_SESSION.value:
            raise ValueError(
                "saveAccessToken requires tokenPurpose=authSession (gateway session JWT)"
            )

        # Validate user context
        if not self.currentUser or not self.currentUser.id:
            raise ValueError("No valid user context available for token storage")

        # Bind the token to the current user
        token.userId = self.currentUser.id

        # Ensure token has required fields
        if not token.id:
            token.id = str(uuid.uuid4())
        if not token.createdAt:
            token.createdAt = getUtcTimestamp()

        # Persist first, so a failing insert cannot leave the user without any session token
        token_dict = token.model_dump()
        token_dict["userId"] = self.currentUser.id
        self.db.recordCreate(Token, token_dict)

        # Then remove the user's older session tokens for this authority
        if replace_existing:
            try:
                old_tokens = self.db.getRecordset(
                    Token,
                    recordFilter={
                        "userId": self.currentUser.id,
                        "authority": token.authority,
                        "connectionId": None,  # Ensure we only delete access tokens
                        "tokenPurpose": TokenPurpose.AUTH_SESSION.value,
                    },
                )
                deleted_count = 0
                for old_token in old_tokens:
                    if old_token["id"] != token.id:  # Never delete the token just saved
                        self.db.recordDelete(Token, old_token["id"])
                        deleted_count += 1

                if deleted_count > 0:
                    logger.info(
                        f"Replaced {deleted_count} old access tokens for user {self.currentUser.id} and authority {token.authority}"
                    )
            except Exception as e:
                logger.warning(
                    f"Failed to delete old access tokens for user {self.currentUser.id} and authority {token.authority}: {str(e)}"
                )
                # The new token is already stored; cleanup can be retried later

    except Exception as e:
        logger.error(f"Error saving access token: {str(e)}")
        raise
|
|
|
|
def saveConnectionToken(self, token: Token, replace_existing: bool = True) -> None:
    """Persist a provider OAuth token that is bound to a connection.

    The token must carry a ``connectionId`` and a ``tokenPurpose`` equal to
    ``TokenPurpose.DATA_CONNECTION``. Before it is written, the token is
    stamped with the current user's ID. It also gets a fresh UUID ``id`` and
    a ``createdAt`` timestamp when those are missing.

    When ``replace_existing`` is true, every other token stored for the same
    ``connectionId`` is deleted *after* the new one was saved. A failure in
    that cleanup step is only logged, so the new token is kept.

    Args:
        token: Token to store. It is mutated in place (userId, id, createdAt).
        replace_existing: Whether to remove older tokens of the same connection.

    Raises:
        ValueError: The connectionId is missing, the tokenPurpose is wrong, or
            there is no user context.
        Exception: Any error during saving is logged and re-raised.
    """
    try:
        if not token.connectionId:
            raise ValueError(
                "Connection tokens must have connectionId - use saveAccessToken instead"
            )

        purpose = token.tokenPurpose
        # tokenPurpose may arrive either as the enum member or as its raw value.
        purposeValue = purpose.value if isinstance(purpose, TokenPurpose) else purpose
        if purposeValue != TokenPurpose.DATA_CONNECTION.value:
            raise ValueError(
                "saveConnectionToken requires tokenPurpose=dataConnection (provider OAuth)"
            )

        owner = self.currentUser
        if not (owner and owner.id):
            raise ValueError("No valid user context available for token storage")

        token.userId = owner.id
        if not token.id:
            token.id = str(uuid.uuid4())
        if not token.createdAt:
            token.createdAt = getUtcTimestamp()

        record = token.model_dump()
        # Ownership is always the authenticated user, whatever the payload said.
        record["userId"] = owner.id
        self.db.recordCreate(Token, record)

        if not replace_existing:
            return

        # Cleanup only starts once the new token is safely stored.
        connectionId = token.connectionId
        try:
            removed = 0
            for stale in self.db.getRecordset(Token, recordFilter={"connectionId": connectionId}):
                if stale["id"] == token.id:
                    continue
                self.db.recordDelete(Token, stale["id"])
                removed += 1
            if removed:
                logger.info(
                    f"Replaced {removed} old tokens for connectionId {connectionId}"
                )
        except Exception as e:
            # Keep the newly saved token; the cleanup can be retried later.
            logger.warning(
                f"Failed to delete old tokens for connectionId {connectionId}: {str(e)}"
            )
    except Exception as e:
        logger.error(f"Error saving connection token: {str(e)}")
        raise
|
|
|
|
def getConnectionToken(self, connectionId: str) -> Optional[Token]:
    """Return the stored token of ``connectionId`` with the latest expiry.

    No refresh is attempted here. Callers that need a fresh token must go
    through a higher-level service.

    Args:
        connectionId: ID of the connection whose token is requested.

    Returns:
        The Token with the greatest ``expiresAt``. ``parseTimestamp(...,
        default=0)`` supplies the sort key, so records without a usable
        expiry rank lowest. Returns None if no token exists, if
        ``connectionId`` is empty, or if any error occurs (errors are logged).
    """
    try:
        if not connectionId:
            # Raised inside the try on purpose: it is logged and reported as None.
            raise ValueError("connectionId is required for getConnectionToken")

        tokens = self.db.getRecordset(Token, recordFilter={"connectionId": connectionId})
        if not tokens:
            logger.warning(
                f"No connection token found for connectionId: {connectionId}"
            )
            return None

        # max() returns the first record with the greatest key. That is the
        # same record a stable descending sort would put first, but it takes
        # O(n) and does not mutate the list handed back by the connector.
        latestRecord = max(
            tokens, key=lambda rec: parseTimestamp(rec.get("expiresAt"), default=0)
        )
        # Strip DB-internal "_"-prefixed columns, as the other token getters do.
        cleanedRecord = {k: v for k, v in latestRecord.items() if not k.startswith("_")}
        return Token(**cleanedRecord)
    except Exception as e:
        logger.error(
            f"Error getting connection token for connectionId {connectionId}: {str(e)}"
        )
        return None
|
|
|
|
def getTokensByConnectionIdAndAuthority(
    self, connectionId: str, authority: AuthAuthority
) -> List[Token]:
    """Return all tokens stored for ``connectionId`` under ``authority``.

    ``authority`` may be an enum member, in which case its ``.value`` is used,
    or any other value, which is converted with ``str()``. DB-internal fields
    whose names start with ``_`` are dropped before each Token is built. On
    any error the problem is logged and an empty list is returned.
    """
    try:
        criteria = {
            "connectionId": connectionId,
            "authority": authority.value if hasattr(authority, 'value') else str(authority),
        }
        records = self.db.getRecordset(Token, recordFilter=criteria)
        return [
            Token(**{key: val for key, val in rec.items() if not key.startswith("_")})
            for rec in records
        ]
    except Exception as e:
        logger.error(f"Error getting tokens by connection and authority: {str(e)}")
        return []
|
|
|
|
def getTokensByUserIdNoConnection(
    self, userId: str, authority: AuthAuthority
) -> List[Token]:
    """Return the tokens of ``userId`` that are not tied to any connection.

    These are the records whose ``connectionId`` is None, i.e. access
    tokens, issued under ``authority``. An enum authority contributes its
    ``.value``; anything else is converted with ``str()``. Fields prefixed
    with ``_`` are removed before each Token is built. Errors are logged and
    yield an empty list.
    """
    try:
        criteria = {
            "userId": userId,
            "connectionId": None,
            "authority": authority.value if hasattr(authority, 'value') else str(authority),
        }
        records = self.db.getRecordset(Token, recordFilter=criteria)
        return [
            Token(**{key: val for key, val in rec.items() if not key.startswith("_")})
            for rec in records
        ]
    except Exception as e:
        logger.error(f"Error getting tokens by user and authority: {str(e)}")
        return []
|
|
|
|
def getAllTokens(self, recordFilter: dict = None) -> List[dict]:
    """Return raw token records (dicts), optionally restricted by ``recordFilter``.

    A missing or empty filter is passed to the connector as ``{}``. Errors are
    logged and an empty list is returned.
    """
    try:
        effectiveFilter = recordFilter if recordFilter else {}
        return self.db.getRecordset(Token, recordFilter=effectiveFilter)
    except Exception as e:
        logger.error(f"Error getting all tokens: {str(e)}")
        return []
|
|
|
|
def findActiveTokenById(
    self,
    tokenId: str,
    userId: str,
    authority: AuthAuthority,
    sessionId: str = None,
    mandateId: str = None,
    tokenPurpose: str = None,
) -> Optional[Token]:
    """Look up an ACTIVE token by its id (the ``jti``) for a user and authority.

    Each of ``sessionId``, ``mandateId`` and ``tokenPurpose`` narrows the
    lookup only when it is not None.

    Returns:
        The first matching Token, or None if nothing matches or an error
        occurs (errors are logged).
    """
    try:
        criteria = {
            "id": tokenId,
            "userId": userId,
            "authority": authority.value
            if hasattr(authority, "value")
            else str(authority),
            "status": TokenStatus.ACTIVE,
        }
        optionalScopes = (
            ("sessionId", sessionId),
            ("mandateId", mandateId),
            ("tokenPurpose", tokenPurpose),
        )
        for column, value in optionalScopes:
            if value is not None:
                criteria[column] = value

        matches = self.db.getRecordset(Token, recordFilter=criteria)
        return Token(**matches[0]) if matches else None
    except Exception as e:
        logger.error(f"Error finding active token by id {tokenId}: {str(e)}")
        return None
|
|
|
|
def revokeTokenById(self, tokenId: str, revokedBy: str, reason: str = None) -> bool:
    """Mark a single token as revoked; the record itself is not deleted.

    Args:
        tokenId: ID of the token to revoke.
        revokedBy: Recorded as the revoking principal.
        reason: Recorded reason; "revoked" is used when it is falsy.

    Returns:
        True if the token is now revoked, including when it already was.
        False if no such token exists or an error occurs (errors are logged).
    """
    try:
        matches = self.db.getRecordset(Token, recordFilter={"id": tokenId})
        if not matches:
            return False

        if matches[0].get("status") == TokenStatus.REVOKED:
            return True  # already revoked: nothing to write

        revocation = {
            "status": TokenStatus.REVOKED,
            "revokedAt": getUtcTimestamp(),
            "revokedBy": revokedBy,
            "reason": reason or "revoked",
        }
        self.db.recordModify(Token, tokenId, revocation)
        return True
    except Exception as e:
        logger.error(f"Error revoking token {tokenId}: {str(e)}")
        return False
|
|
|
|
def revokeTokensBySessionId(
    self,
    sessionId: str,
    userId: str,
    authority: AuthAuthority,
    revokedBy: str,
    reason: str = None,
) -> int:
    """Revoke every ACTIVE token of one session for a user and authority.

    Each matching token is updated in place (status, revokedAt, revokedBy,
    reason), not deleted. The reason defaults to "session logout".

    Returns:
        Number of tokens revoked. Returns 0 if an error occurs (the error is
        logged); tokens revoked before the failure stay revoked.
    """
    try:
        authorityKey = authority.value if hasattr(authority, "value") else str(authority)
        activeTokens = self.db.getRecordset(
            Token,
            recordFilter={
                "userId": userId,
                "authority": authorityKey,
                "sessionId": sessionId,
                "status": TokenStatus.ACTIVE,
            },
        )

        revoked = 0
        for record in activeTokens:
            recordId = record["id"]
            self.db.recordModify(
                Token,
                recordId,
                {
                    "status": TokenStatus.REVOKED,
                    "revokedAt": getUtcTimestamp(),
                    "revokedBy": revokedBy,
                    "reason": reason or "session logout",
                },
            )
            revoked += 1
        return revoked
    except Exception as e:
        logger.error(f"Error revoking tokens for session {sessionId}: {str(e)}")
        return 0
|
|
|
|
def revokeTokensByUser(
    self,
    userId: str,
    authority: AuthAuthority = None,
    mandateId: str = None,
    revokedBy: str = None,
    reason: str = None,
) -> int:
    """Revoke all ACTIVE tokens of a user, optionally limited to one authority.

    Matching tokens are marked revoked (status, revokedAt, revokedBy, reason)
    rather than deleted.

    Args:
        userId: Owner of the tokens.
        authority: If not None, only tokens of this authority are revoked.
        mandateId: Accepted for interface compatibility but currently NOT
            applied as a filter: the user's tokens in every mandate are
            revoked.
            NOTE(review): adding a mandate filter would narrow revocation.
            Confirm whether session tokens carry a mandateId before doing so,
            or mandate-less sessions would stay valid.
        revokedBy: Recorded as the revoking principal.
        reason: Recorded reason; "admin revoke" is used when it is falsy.

    Returns:
        Number of tokens revoked. Returns 0 if an error occurs (the error is
        logged); tokens revoked before the failure stay revoked.
    """
    try:
        recordFilter = {
            "userId": userId,
            "status": TokenStatus.ACTIVE,
        }
        if authority is not None:
            recordFilter["authority"] = (
                authority.value if hasattr(authority, "value") else str(authority)
            )

        tokens = self.db.getRecordset(Token, recordFilter=recordFilter)
        count = 0
        for t in tokens:
            self.db.recordModify(
                Token,
                t["id"],
                {
                    "status": TokenStatus.REVOKED,
                    "revokedAt": getUtcTimestamp(),
                    "revokedBy": revokedBy,
                    "reason": reason or "admin revoke",
                },
            )
            count += 1
        return count
    except Exception as e:
        logger.error(f"Error revoking tokens for user {userId}: {str(e)}")
        return 0
|
|
|
|
def cleanupExpiredTokens(self) -> int:
    """Delete every stored token whose ``expiresAt`` lies in the past.

    This applies to all tokens, both access/session and connection tokens,
    not only to connection tokens. A token whose ``expiresAt`` parses to a
    falsy value (e.g. a missing expiry) is kept.

    A failure on one token (parse or delete) is logged and skipped. The
    remaining expired tokens are still removed, and the returned count
    reflects the deletions that actually happened.

    Returns:
        Number of tokens deleted. Returns 0 if the token table cannot be read.
    """
    try:
        current_time = getUtcTimestamp()
        all_tokens = self.db.getRecordset(Token, recordFilter={})
    except Exception as e:
        logger.error(f"Error cleaning up expired tokens: {str(e)}")
        return 0

    cleaned_count = 0
    for token_data in all_tokens:
        try:
            expiresAt = parseTimestamp(token_data.get("expiresAt"))
            if expiresAt and expiresAt < current_time:
                self.db.recordDelete(Token, token_data["id"])
                cleaned_count += 1
        except Exception as e:
            # Best-effort cleanup: one bad record must not stop the sweep.
            logger.warning(
                f"Error cleaning up expired token {token_data.get('id')}: {str(e)}"
            )

    if cleaned_count > 0:
        logger.info(f"Cleaned up {cleaned_count} expired tokens")

    return cleaned_count
|
|
|
|
def logout(self) -> None:
    """Drop the user context held by this interface instance.

    Resets ``currentUser``, ``userId``, ``mandateId`` and ``rbac`` to None.
    If a ``db`` connector is attached, its context is reset to an empty
    string. Stored tokens are not touched by this method. Errors are logged
    and re-raised.

    NOTE(review): instances may be cached and reused (getInterface keeps them
    in a module-level dict), so clearing the context here affects later
    callers of the same instance until its context is set again.
    """
    try:
        for attrName in ("currentUser", "userId", "mandateId", "rbac"):
            setattr(self, attrName, None)

        if hasattr(self, "db"):
            self.db.updateContext("")

        logger.info("User logged out successfully")
    except Exception as e:
        logger.error(f"Error during logout: {str(e)}")
        raise
|
|
|
|
# RBAC CRUD Methods
|
|
|
|
def createAccessRule(self, accessRule: AccessRule) -> AccessRule:
    """Insert a new access rule and return it as stored by the database.

    Args:
        accessRule: The rule to insert.

    Returns:
        AccessRule built from the record returned by ``recordCreate``.

    Raises:
        Exception: Any database or validation error is logged and re-raised.
    """
    try:
        storedRecord = self.db.recordCreate(AccessRule, accessRule)
        newId = storedRecord.get('id')
        logger.info(f"Created access rule with ID {newId}")
        return AccessRule(**storedRecord)
    except Exception as e:
        logger.error(f"Error creating access rule: {str(e)}")
        raise
|
|
|
|
def getAccessRule(self, ruleId: str) -> Optional[AccessRule]:
    """Fetch one access rule by its ID.

    Args:
        ruleId: ID of the access rule.

    Returns:
        The first matching AccessRule, or None if none exists or an error
        occurs (errors are logged).
    """
    try:
        found = self.db.getRecordset(AccessRule, recordFilter={"id": ruleId})
        return AccessRule(**found[0]) if found else None
    except Exception as e:
        logger.error(f"Error getting access rule {ruleId}: {str(e)}")
        return None
|
|
|
|
def updateAccessRule(self, ruleId: str, accessRule: AccessRule) -> AccessRule:
    """Overwrite the stored access rule ``ruleId`` with the fields of ``accessRule``.

    Args:
        ruleId: ID of the rule to update. This value is authoritative, so any
            ``id`` carried by ``accessRule`` is ignored.
        accessRule: New field values.

    Returns:
        AccessRule built from the record returned by ``recordModify``.

    Raises:
        Exception: Any database or validation error is logged and re-raised.
    """
    try:
        changes = accessRule.model_dump(exclude={"id"})
        modifiedRecord = self.db.recordModify(AccessRule, ruleId, changes)
        logger.info(f"Updated access rule with ID {ruleId}")
        return AccessRule(**modifiedRecord)
    except Exception as e:
        logger.error(f"Error updating access rule {ruleId}: {str(e)}")
        raise
|
|
|
|
def deleteAccessRule(self, ruleId: str) -> bool:
    """Remove the access rule ``ruleId``.

    Args:
        ruleId: ID of the rule to delete.

    Returns:
        True if ``recordDelete`` completed without raising. False if it
        raised (the error is logged).
    """
    try:
        self.db.recordDelete(AccessRule, ruleId)
    except Exception as e:
        logger.error(f"Error deleting access rule {ruleId}: {str(e)}")
        return False
    logger.info(f"Deleted access rule with ID {ruleId}")
    return True
|
|
|
|
def getAccessRules(
    self,
    roleLabel: Optional[str] = None,
    roleId: Optional[str] = None,
    context: Optional[AccessRuleContext] = None,
    item: Optional[str] = None,
    pagination: Optional[PaginationParams] = None
) -> Union[List[AccessRule], PaginatedResult]:
    """List the access rules visible to the current user (RBAC-filtered).

    Args:
        roleLabel: Legacy filter. It is only used when ``roleId`` is not given.
        roleId: Filter by role ID. It takes precedence over ``roleLabel``.
        context: Filter by rule context (its ``.value`` is queried).
        item: Filter by item identifier.
        pagination: If None, every matching rule is returned as a list.
            Otherwise ``pagination.filters`` and ``pagination.sort`` are
            applied, followed by 1-based page slicing.

    Returns:
        ``List[AccessRule]`` without pagination, or ``PaginatedResult`` with
        it. On error the problem is logged and an empty list, or an empty
        PaginatedResult, is returned.
    """
    try:
        criteria = {}
        if roleId:
            criteria["roleId"] = roleId
        elif roleLabel:
            criteria["roleLabel"] = roleLabel
        if context:
            criteria["context"] = context.value
        if item:
            criteria["item"] = item

        visibleRecords = getRecordsetWithRBAC(
            self.db,
            AccessRule,
            self.currentUser,
            recordFilter=criteria or None,
            mandateId=self.mandateId,
        )
        # Drop DB-internal "_"-prefixed columns before building models.
        rows = [
            {key: val for key, val in rec.items() if not key.startswith("_")}
            for rec in visibleRecords
        ]

        if pagination is None:
            return [AccessRule(**row) for row in rows]

        if pagination.filters:
            rows = self._applyFilters(rows, pagination.filters)
        if pagination.sort:
            rows = self._applySorting(rows, pagination.sort)

        pageSize = pagination.pageSize
        totalItems = len(rows)
        totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0
        offset = (pagination.page - 1) * pageSize
        pageRows = rows[offset:offset + pageSize]

        return PaginatedResult(
            items=[AccessRule(**row) for row in pageRows],
            totalItems=totalItems,
            totalPages=totalPages,
        )
    except Exception as e:
        logger.error(f"Error getting access rules: {str(e)}")
        return [] if pagination is None else PaginatedResult(items=[], totalItems=0, totalPages=0)
|
|
|
|
def getAccessRulesForRoles(
    self,
    roleLabels: List[str],
    context: AccessRuleContext,
    item: str
) -> List[AccessRule]:
    """Collect, for each role label, the most specific rule matching ``item``.

    Args:
        roleLabels: Role labels to evaluate, in order.
        context: Rule context passed to the RBAC rule lookup.
        item: Item identifier to match.

    Returns:
        One rule per role for which ``findMostSpecificRule`` returned
        something truthy, in ``roleLabels`` order. Returns an empty list on
        error (the error is logged).
    """
    try:
        # This interface's own connector doubles as the DbApp database.
        rbacEngine = RbacClass(self.db, dbApp=self.db)
        matched = []
        for label in roleLabels:
            candidates = rbacEngine._getRulesForRole(label, context)
            best = rbacEngine.findMostSpecificRule(candidates, item)
            if best:
                matched.append(best)
        return matched
    except Exception as e:
        logger.error(f"Error getting access rules for roles: {str(e)}")
        return []
|
|
|
|
def createRole(self, role: Role) -> Role:
    """Insert a new role after checking that its label is not yet taken.

    The uniqueness check looks at every stored role with the same
    ``roleLabel``, regardless of mandate or feature scope.

    Args:
        role: The role to insert.

    Returns:
        Role built from the record returned by ``recordCreate``.

    Raises:
        ValueError: A role with this label already exists.
        Exception: Any other error is logged and re-raised.
    """
    try:
        duplicates = self.db.getRecordset(Role, recordFilter={"roleLabel": role.roleLabel})
        if duplicates:
            raise ValueError(f"Role with label '{role.roleLabel}' already exists")

        storedRecord = self.db.recordCreate(Role, role)
        logger.info(f"Created role with ID {storedRecord.get('id')} and label {role.roleLabel}")
        return Role(**storedRecord)
    except Exception as e:
        logger.error(f"Error creating role: {str(e)}")
        raise
|
|
|
|
def getRole(self, roleId: str) -> Optional[Role]:
    """Fetch one role by its ID.

    Args:
        roleId: ID of the role.

    Returns:
        The first matching Role, or None if none exists or an error occurs
        (errors are logged).
    """
    try:
        found = self.db.getRecordset(Role, recordFilter={"id": roleId})
        return Role(**found[0]) if found else None
    except Exception as e:
        logger.error(f"Error getting role {roleId}: {str(e)}")
        return None
|
|
|
|
def getRoleByLabel(self, roleLabel: str) -> Optional[Role]:
    """Fetch a role by its label, without any scope filtering.

    If several roles share the label, the first record returned by the
    database is used.

    Args:
        roleLabel: Label of the role.

    Returns:
        The first matching Role, or None if none exists or an error occurs
        (errors are logged).
    """
    try:
        found = self.db.getRecordset(Role, recordFilter={"roleLabel": roleLabel})
        return Role(**found[0]) if found else None
    except Exception as e:
        logger.error(f"Error getting role by label {roleLabel}: {str(e)}")
        return None
|
|
|
|
def getRoleByLabelAndScope(
    self,
    roleLabel: str,
    mandateId: Optional[str] = None,
    featureInstanceId: Optional[str] = None,
    featureCode: Optional[str] = None
) -> Optional[Role]:
    """Find a role by label, narrowed by whichever scope arguments are given.

    A scope argument is added to the query only when it is not None. A None
    argument therefore means "do not filter on this column". In particular,
    ``mandateId=None`` does NOT restrict the search to global roles: a role
    of any mandate with this label can match.

    NOTE(review): callers that want global roles only would need a
    ``"mandateId": None`` entry in the record filter. getRolesForMandate uses
    a None filter value that way for featureInstanceId. This function never
    adds such an entry.

    Args:
        roleLabel: Role label (always filtered).
        mandateId: Mandate ID, or None for no mandate filter.
        featureInstanceId: Feature instance ID, or None for no filter.
        featureCode: Feature code, or None for no filter.

    Returns:
        The first matching Role as returned by the database. Returns None if
        nothing matches or an error occurs (errors are logged).
    """
    try:
        criteria = {"roleLabel": roleLabel}
        scopes = (
            ("mandateId", mandateId),
            ("featureInstanceId", featureInstanceId),
            ("featureCode", featureCode),
        )
        for column, value in scopes:
            if value is not None:
                criteria[column] = value

        matches = self.db.getRecordset(Role, recordFilter=criteria)
        return Role(**matches[0]) if matches else None
    except Exception as e:
        logger.error(f"Error getting role by label and scope {roleLabel}: {str(e)}")
        return None
|
|
|
|
def getRolesForMandate(self, mandateId: str) -> List[Role]:
    """Return the mandate-level roles of ``mandateId``.

    Only roles whose ``featureInstanceId`` is NULL are included, i.e. the
    roles that copySystemRolesToMandate creates during bootstrap.
    DB-internal ``_`` fields are stripped. Errors are logged and an empty
    list is returned.
    """
    try:
        records = self.db.getRecordset(
            Role,
            recordFilter={"mandateId": mandateId, "featureInstanceId": None}
        )
        mandateRoles = []
        for rec in records:
            publicFields = {name: val for name, val in rec.items() if not name.startswith("_")}
            mandateRoles.append(Role(**publicFields))
        return mandateRoles
    except Exception as e:
        logger.error(f"Error getting roles for mandate {mandateId}: {e}")
        return []
|
|
|
|
def getAllRoles(self, pagination: Optional[PaginationParams] = None) -> Union[List[Role], PaginatedResult]:
    """List all roles, delegating paging, sorting and filtering to the connector.

    Args:
        pagination: Passed straight to ``getRecordsetPaginated``. If None, a
            plain list is returned.

    Returns:
        ``List[Role]`` when ``pagination`` is None. Otherwise a
        ``PaginatedResult`` carrying the connector's totalItems and
        totalPages. On error the problem is logged and an empty list, or an
        empty PaginatedResult, is returned.
    """
    try:
        page = self.db.getRecordsetPaginated(Role, pagination=pagination)
        roles = [
            Role(**{key: val for key, val in rec.items() if not key.startswith("_")})
            for rec in page["items"]
        ]

        if pagination is not None:
            return PaginatedResult(
                items=roles,
                totalItems=page["totalItems"],
                totalPages=page["totalPages"]
            )
        return roles
    except Exception as e:
        logger.error(f"Error getting all roles: {str(e)}")
        return [] if pagination is None else PaginatedResult(items=[], totalItems=0, totalPages=0)
|
|
|
|
def countRoleAssignments(self) -> Dict[str, int]:
    """Count UserMandateRole assignments per role.

    Each record's ``roleId`` is stringified. Records without a roleId yield
    an empty key and are skipped.

    NOTE(review): a roleId stored as None becomes the string "None" and is
    therefore counted under that key.

    Returns:
        Mapping of roleId to number of assignments. Returns an empty dict on
        error (the error is logged).
    """
    try:
        tally: Dict[str, int] = {}
        for row in self.db.getRecordset(UserMandateRole):
            key = str(row.get("roleId", ""))
            if not key:
                continue
            tally[key] = tally.get(key, 0) + 1
        return tally
    except Exception as e:
        logger.error(f"Error counting role assignments: {str(e)}")
        return {}
|
|
|
|
def updateRole(self, roleId: str, role: Role) -> Role:
    """Overwrite the stored role ``roleId`` with the fields of ``role``.

    If the label changes, another role with the new label is treated as a
    conflict. The lookup uses getRoleByLabel, i.e. it is not scope-aware.

    Args:
        roleId: ID of the role to update. This value is authoritative, so any
            ``id`` in ``role`` is ignored.
        role: New field values.

    Returns:
        Role built from the record returned by ``recordModify``.

    Raises:
        ValueError: The role does not exist, or the new label is already used.
        Exception: Any other error is logged and re-raised.
    """
    try:
        current = self.getRole(roleId)
        if not current:
            raise ValueError(f"Role with ID {roleId} not found")

        labelChanged = role.roleLabel != current.roleLabel
        clash = self.getRoleByLabel(role.roleLabel) if labelChanged else None
        if clash and clash.id != roleId:
            raise ValueError(f"Role with label '{role.roleLabel}' already exists")

        modifiedRecord = self.db.recordModify(Role, roleId, role.model_dump(exclude={"id"}))
        logger.info(f"Updated role with ID {roleId}")
        return Role(**modifiedRecord)
    except Exception as e:
        logger.error(f"Error updating role {roleId}: {str(e)}")
        raise
|
|
|
|
def deleteRole(self, roleId: str) -> bool:
    """Delete a non-system role that is no longer referenced.

    Args:
        roleId: ID of the role to delete.

    Returns:
        True if the role was deleted. False if no role with this ID could be
        loaded.

    Raises:
        ValueError: The role is a system role, is still assigned to users
            (UserMandateRole), or is still referenced by access rules.
        Exception: Any other error is logged and re-raised.
    """
    try:
        role = self.getRole(roleId)
        if not role:
            return False

        if role.isSystemRole:
            raise ValueError(f"Cannot delete system role '{role.roleLabel}'")

        roleAssignments = self.db.getRecordset(UserMandateRole, recordFilter={"roleId": roleId})
        if roleAssignments:
            raise ValueError(f"Cannot delete role '{role.roleLabel}' - it is assigned to users")

        # Query the table directly. getAccessRules() applies the caller's RBAC
        # view and returns [] on errors; either could hide referencing rules
        # and let the role be deleted while rules still point at it.
        accessRules = self.db.getRecordset(AccessRule, recordFilter={"roleId": roleId})
        if accessRules:
            raise ValueError(f"Cannot delete role '{role.roleLabel}' - it is used in access rules")

        self.db.recordDelete(Role, roleId)
        logger.info(f"Deleted role with ID {roleId}")
        return True
    except Exception as e:
        logger.error(f"Error deleting role {roleId}: {str(e)}")
        raise
|
|
|
|
|
|
# Public Methods
|
|
|
|
|
|
def getInterface(currentUser: User, mandateId: Optional[str] = None) -> AppObjects:
    """Return the cached AppObjects instance for a (mandate, user) pair.

    Multi-tenant design: the mandate is not derived from the user. It is
    passed in explicitly from the request context (X-Mandate-Id header).

    Instances are kept in the module-level ``_gatewayInterfaces`` dict under
    the key ``"<mandateId>_<userId>"``. They are created on first use, and
    this function never evicts them.

    Args:
        currentUser: The authenticated user; must be truthy.
        mandateId: Mandate context from the request header. It is described as
            required for non-sysadmin users, but this function does not
            enforce that.

    Returns:
        The AppObjects instance bound to this user and mandate.

    Raises:
        ValueError: ``currentUser`` is falsy.
    """
    if not currentUser:
        raise ValueError("Invalid user context: user is required")

    cacheKey = f"{mandateId}_{currentUser.id}"

    if cacheKey in _gatewayInterfaces:
        # Re-bind the user on every lookup. An earlier code path (e.g. a
        # legacy logout) may have cleared currentUser on this cached
        # instance, and OAuth/login must never see a stale context.
        _gatewayInterfaces[cacheKey].setUserContext(currentUser, mandateId=mandateId)
    else:
        # Only cache the instance once its context was set successfully.
        freshInstance = AppObjects(currentUser)
        freshInstance.setUserContext(currentUser, mandateId=mandateId)
        _gatewayInterfaces[cacheKey] = freshInstance

    return _gatewayInterfaces[cacheKey]
|
|
|
|
|
|
def getRootInterface() -> AppObjects:
    """Return the process-wide AppObjects instance with root privileges.

    It is used for initial setup and user creation. The instance is created
    lazily on the first call and cached in the module-level
    ``_rootAppObjects``.

    The root user comes from ``modules.security.rootAccess.getRootUser``.
    That module is imported inside the function to avoid a circular import.
    Routes may keep using this function; connectors and interfaces should
    call ``security.rootAccess.getRootDbAppConnector()`` or
    ``security.rootAccess.getRootUser()`` directly.

    Raises:
        ValueError: The root user could not be obtained or the interface
            could not be constructed. The original exception is chained as
            ``__cause__``.
    """
    global _rootAppObjects

    if _rootAppObjects is None:
        try:
            from modules.security.rootAccess import getRootUser
            rootUser = getRootUser()

            _rootAppObjects = AppObjects(rootUser)

        except Exception as e:
            logger.error(f"Error getting root user: {str(e)}")
            # Chain the cause so the original traceback is preserved.
            raise ValueError(f"Failed to get root user: {str(e)}") from e

    return _rootAppObjects
|