4181 lines
165 KiB
Python
4181 lines
165 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
||
# All rights reserved.
|
||
"""
|
||
Interface to the Gateway system.
|
||
Manages users and mandates for authentication.
|
||
|
||
Multi-Tenant Design:
|
||
- Users no longer belong directly to a mandate
|
||
- mandateId is passed in from the request context (X-Mandate-Id header)
|
||
"""
|
||
|
||
import logging
|
||
import math
|
||
from typing import Dict, Any, List, Optional, Union
|
||
from passlib.context import CryptContext
|
||
import uuid
|
||
|
||
from modules.connectors.connectorDbPostgre import DatabaseConnector, getCachedConnector
|
||
from modules.shared.configuration import APP_CONFIG
|
||
from modules.shared.dbRegistry import registerDatabase
|
||
from modules.shared.timeUtils import getUtcTimestamp, parseTimestamp
|
||
from modules.shared.i18nRegistry import resolveText
|
||
from modules.interfaces.interfaceRbac import getRecordsetWithRBAC
|
||
from modules.security.rbac import RbacClass
|
||
from modules.datamodels.datamodelUam import (
|
||
User,
|
||
Mandate,
|
||
UserInDB,
|
||
UserConnection,
|
||
AuthAuthority,
|
||
ConnectionStatus,
|
||
)
|
||
from modules.datamodels.datamodelRbac import (
|
||
AccessRule,
|
||
AccessRuleContext,
|
||
Role,
|
||
)
|
||
from modules.datamodels.datamodelUam import AccessLevel
|
||
from modules.datamodels.datamodelSecurity import Token, AuthEvent, TokenStatus, TokenPurpose
|
||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResult
|
||
from modules.datamodels.datamodelMembership import (
|
||
UserMandate,
|
||
UserMandateRole,
|
||
FeatureAccess,
|
||
FeatureAccessRole,
|
||
)
|
||
from modules.datamodels.datamodelFeatures import Feature, FeatureInstance
|
||
from modules.datamodels.datamodelInvitation import Invitation
|
||
from modules.datamodels.datamodelNotification import UserNotification
|
||
|
||
logger = logging.getLogger(__name__)

# Name of the application database; registered with the shared DB registry
# as an import-time side effect.
appDatabase = "poweron_app"
registerDatabase(appDatabase)

# Cache of AppObjects instances keyed per context (starts empty at import).
_gatewayInterfaces: Dict[Any, Any] = {}

# Root-level AppObjects instance; None at import time.
_rootAppObjects: Optional[Any] = None

# passlib hashing context: argon2 is the only configured scheme;
# deprecated="auto" treats every non-default scheme as deprecated.
pwdContext = CryptContext(schemes=["argon2"], deprecated="auto")
|
||
|
||
|
||
class AppObjects:
|
||
"""
|
||
Interface to the Gateway system.
|
||
Manages users and mandates.
|
||
"""
|
||
|
||
def __init__(self, currentUser: Optional[User] = None):
    """Initialize the gateway interface, optionally bound to a user.

    Args:
        currentUser: The acting user. When given, setUserContext() is invoked
            with no mandate context; when None, the instance stays user-less.
    """
    self.currentUser = currentUser  # keep the User object itself
    self.userId = currentUser.id if currentUser else None
    self.mandateId = None  # supplied later via setUserContext(), never taken from User
    self.featureInstanceId = None  # supplied later via setUserContext()

    # Defaults for attributes that setUserContext() would otherwise be the only
    # one to create. Without them, a user-less instance raises AttributeError
    # on e.g. `self.rbac` (checkRbacPermission reads it).
    self.rbac = None
    self.userLanguage = None

    self._initializeDatabase()

    if currentUser:
        self.setUserContext(currentUser)
|
||
|
||
def setUserContext(self, currentUser: User, mandateId: Optional[str] = None):
    """Bind this interface to a user and an optional mandate context.

    Multi-tenant design: users are not assigned to mandates. The mandate
    context is provided per request (e.g. from the X-Mandate-Id header) and
    passed in here explicitly. This method does NOT require a mandateId;
    sysadmin users may additionally perform cross-mandate operations.

    Args:
        currentUser: The acting user. If falsy, the call only logs and returns.
        mandateId: Mandate context for this request, or None.

    Raises:
        ValueError: If ``currentUser.id`` is empty. Note that currentUser,
            userId and mandateId are already assigned when this is raised.
    """
    if not currentUser:
        logger.info("Initializing interface without user context")
        return

    resolvedUserId = currentUser.id
    # Mandate context comes solely from the argument, never from the User.
    self.currentUser, self.userId, self.mandateId = currentUser, resolvedUserId, mandateId

    if not self.userId:
        raise ValueError("Invalid user context: id is required")

    self.userLanguage = currentUser.language

    # The RBAC engine is given this interface's own database as its app DB.
    self.rbac = RbacClass(self.db, dbApp=self.db)

    self.db.updateContext(self.userId)
|
||
|
||
def __del__(self):
|
||
"""Cleanup method to close database connection."""
|
||
if hasattr(self, "db") and self.db is not None:
|
||
try:
|
||
self.db.close()
|
||
except Exception as e:
|
||
logger.error(f"Error closing database connection: {e}")
|
||
|
||
def _initializeDatabase(self):
    """Obtain the connector for the app database and store it as ``self.db``.

    Connection settings are read from APP_CONFIG (DB_HOST, DB_USER,
    DB_PASSWORD_SECRET, DB_PORT defaulting to 5432). The connector is obtained
    via getCachedConnector() for the current userId.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        connectorSettings = {
            "dbHost": APP_CONFIG.get("DB_HOST", "_no_config_default_data"),
            "dbDatabase": appDatabase,
            "dbUser": APP_CONFIG.get("DB_USER"),
            "dbPassword": APP_CONFIG.get("DB_PASSWORD_SECRET"),
            "dbPort": int(APP_CONFIG.get("DB_PORT", 5432)),
            "userId": self.userId,
        }
        self.db = getCachedConnector(**connectorSettings)
        logger.info(f"Database initialized successfully for user {self.userId}")
    except Exception as e:
        logger.error(f"Failed to initialize database: {str(e)}")
        raise
|
||
|
||
def _separateObjectFields(self, model_class, data: Dict[str, Any]) -> tuple[Dict[str, Any], Dict[str, Any]]:
|
||
"""Separate simple fields from object fields based on Pydantic model structure."""
|
||
simpleFields = {}
|
||
objectFields = {}
|
||
|
||
# Get field information from the Pydantic model
|
||
modelFields = model_class.model_fields
|
||
|
||
for fieldName, value in data.items():
|
||
# Check if this field should be stored as JSONB in the database
|
||
if fieldName in modelFields:
|
||
fieldInfo = modelFields[fieldName]
|
||
# Pydantic v2 only
|
||
fieldType = fieldInfo.annotation
|
||
|
||
# Check if this is a JSONB field (Dict, List, or complex types)
|
||
# Purely type-based detection - no hardcoded field names
|
||
if (fieldType == dict or
|
||
fieldType == list or
|
||
(hasattr(fieldType, '__origin__') and fieldType.__origin__ in (dict, list))):
|
||
# Store as JSONB - include in simple_fields for database storage
|
||
simpleFields[fieldName] = value
|
||
elif isinstance(value, (str, int, float, bool, type(None))):
|
||
# Simple scalar types
|
||
simpleFields[fieldName] = value
|
||
else:
|
||
# Complex objects that should be filtered out
|
||
objectFields[fieldName] = value
|
||
else:
|
||
# Field not in model - pass through scalars; nested objects go to objectFields
|
||
if isinstance(value, (str, int, float, bool, type(None))):
|
||
simpleFields[fieldName] = value
|
||
else:
|
||
objectFields[fieldName] = value
|
||
|
||
return simpleFields, objectFields
|
||
|
||
|
||
|
||
def checkRbacPermission(
    self,
    modelClass: type,
    operation: str,
    recordId: Optional[str] = None
) -> bool:
    """Check whether the current user may perform ``operation`` on a table.

    The table is identified by ``modelClass.__name__`` (mapped through
    ``buildDataObjectKey``), and permissions are evaluated in the DATA context
    for the current mandate.

    Args:
        modelClass: Pydantic model class of the table.
        operation: One of 'create', 'update', 'delete', 'read'.
        recordId: Accepted for API compatibility; currently not used.

    Returns:
        True if the permission level for the operation is not NONE. False when
        there is no RBAC engine / user context (including instances created
        without a user, where ``rbac`` may never have been set) or when the
        operation is unknown.
    """
    rbac = getattr(self, "rbac", None)
    currentUser = getattr(self, "currentUser", None)
    if not rbac or not currentUser:
        return False

    # Unknown operations are always denied — no need to query RBAC for them.
    if operation not in ("create", "update", "delete", "read"):
        return False

    # Semantic namespace lookup (local import as in the original code).
    from modules.interfaces.interfaceRbac import buildDataObjectKey

    objectKey = buildDataObjectKey(modelClass.__name__)
    permissions = rbac.getUserPermissions(
        currentUser,
        AccessRuleContext.DATA,
        objectKey,
        mandateId=self.mandateId
    )
    # The permission object exposes one attribute per operation name.
    return getattr(permissions, operation) != AccessLevel.NONE
|
||
|
||
def _applyFilters(self, records: List[Dict[str, Any]], filters: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||
"""
|
||
Apply filter criteria to records.
|
||
|
||
Supports:
|
||
- General search: {"search": "text"} - searches across all text fields
|
||
- Field-specific filters:
|
||
- Simple: {"status": "running"} - equals match
|
||
- With operator: {"status": {"operator": "equals", "value": "running"}}
|
||
- Operators: equals, contains, gt, gte, lt, lte, in, notIn, startsWith, endsWith
|
||
|
||
Args:
|
||
records: List of record dictionaries to filter
|
||
filters: Filter criteria dictionary
|
||
|
||
Returns:
|
||
Filtered list of records
|
||
"""
|
||
if not filters or not records:
|
||
return records
|
||
|
||
filtered = []
|
||
|
||
for record in records:
|
||
matches = True
|
||
|
||
# Handle general search across text fields
|
||
if "search" in filters:
|
||
search_term = str(filters["search"]).lower()
|
||
if search_term:
|
||
# Search in all string fields
|
||
found = False
|
||
for key, value in record.items():
|
||
if isinstance(value, str) and search_term in value.lower():
|
||
found = True
|
||
break
|
||
elif isinstance(value, (int, float)) and search_term in str(value):
|
||
found = True
|
||
break
|
||
if not found:
|
||
matches = False
|
||
|
||
# Handle field-specific filters
|
||
for field_name, filter_value in filters.items():
|
||
if field_name == "search":
|
||
continue # Already handled above
|
||
|
||
if field_name not in record:
|
||
matches = False
|
||
break
|
||
|
||
record_value = record.get(field_name)
|
||
|
||
# Handle simple value (equals operator)
|
||
# Compare as strings to handle type mismatches (frontend sends strings)
|
||
if not isinstance(filter_value, dict):
|
||
# Convert both to strings for comparison (case-insensitive for strings)
|
||
recordStr = str(record_value).lower() if record_value is not None else ""
|
||
filterStr = str(filter_value).lower() if filter_value is not None else ""
|
||
if recordStr != filterStr:
|
||
matches = False
|
||
break
|
||
continue
|
||
|
||
# Handle filter with operator
|
||
operator = filter_value.get("operator", "equals")
|
||
filter_val = filter_value.get("value")
|
||
|
||
if operator in ["equals", "eq"]:
|
||
# Convert both to strings for comparison
|
||
recordStr = str(record_value).lower() if record_value is not None else ""
|
||
filterStr = str(filter_val).lower() if filter_val is not None else ""
|
||
if recordStr != filterStr:
|
||
matches = False
|
||
break
|
||
|
||
elif operator == "contains":
|
||
record_str = str(record_value).lower() if record_value is not None else ""
|
||
filter_str = str(filter_val).lower() if filter_val is not None else ""
|
||
if filter_str not in record_str:
|
||
matches = False
|
||
break
|
||
|
||
elif operator == "startsWith":
|
||
record_str = str(record_value).lower() if record_value is not None else ""
|
||
filter_str = str(filter_val).lower() if filter_val is not None else ""
|
||
if not record_str.startswith(filter_str):
|
||
matches = False
|
||
break
|
||
|
||
elif operator == "endsWith":
|
||
record_str = str(record_value).lower() if record_value is not None else ""
|
||
filter_str = str(filter_val).lower() if filter_val is not None else ""
|
||
if not record_str.endswith(filter_str):
|
||
matches = False
|
||
break
|
||
|
||
elif operator == "gt":
|
||
try:
|
||
record_num = float(record_value) if record_value is not None else float('-inf')
|
||
filter_num = float(filter_val) if filter_val is not None else float('-inf')
|
||
if record_num <= filter_num:
|
||
matches = False
|
||
break
|
||
except (ValueError, TypeError):
|
||
matches = False
|
||
break
|
||
|
||
elif operator == "gte":
|
||
try:
|
||
record_num = float(record_value) if record_value is not None else float('-inf')
|
||
filter_num = float(filter_val) if filter_val is not None else float('-inf')
|
||
if record_num < filter_num:
|
||
matches = False
|
||
break
|
||
except (ValueError, TypeError):
|
||
matches = False
|
||
break
|
||
|
||
elif operator == "lt":
|
||
try:
|
||
record_num = float(record_value) if record_value is not None else float('inf')
|
||
filter_num = float(filter_val) if filter_val is not None else float('inf')
|
||
if record_num >= filter_num:
|
||
matches = False
|
||
break
|
||
except (ValueError, TypeError):
|
||
matches = False
|
||
break
|
||
|
||
elif operator == "lte":
|
||
try:
|
||
record_num = float(record_value) if record_value is not None else float('inf')
|
||
filter_num = float(filter_val) if filter_val is not None else float('inf')
|
||
if record_num > filter_num:
|
||
matches = False
|
||
break
|
||
except (ValueError, TypeError):
|
||
matches = False
|
||
break
|
||
|
||
elif operator == "in":
|
||
if not isinstance(filter_val, list):
|
||
filter_val = [filter_val]
|
||
if record_value not in filter_val:
|
||
matches = False
|
||
break
|
||
|
||
elif operator == "notIn":
|
||
if not isinstance(filter_val, list):
|
||
filter_val = [filter_val]
|
||
if record_value in filter_val:
|
||
matches = False
|
||
break
|
||
|
||
else:
|
||
# Unknown operator - default to equals
|
||
if str(record_value).lower() != str(filter_val).lower():
|
||
matches = False
|
||
break
|
||
|
||
if matches:
|
||
filtered.append(record)
|
||
|
||
return filtered
|
||
|
||
def _applySorting(self, records: List[Dict[str, Any]], sortFields: List[Any]) -> List[Dict[str, Any]]:
|
||
"""Apply multi-level sorting to records using stable sort (sorts from least to most significant field)."""
|
||
if not sortFields:
|
||
return records
|
||
|
||
# Start with a copy to avoid modifying original
|
||
sortedRecords = list(records)
|
||
|
||
# Sort from least significant to most significant field (reverse order)
|
||
# Python's sort is stable, so this creates proper multi-level sorting
|
||
for sortField in reversed(sortFields):
|
||
# Handle both dict and object formats
|
||
if isinstance(sortField, dict):
|
||
fieldName = sortField.get("field")
|
||
direction = sortField.get("direction", "asc")
|
||
else:
|
||
fieldName = getattr(sortField, "field", None)
|
||
direction = getattr(sortField, "direction", "asc")
|
||
|
||
if not fieldName:
|
||
continue
|
||
|
||
isDesc = (direction == "desc")
|
||
|
||
def sortKey(record):
|
||
value = record.get(fieldName)
|
||
# Handle None values - place them at the end for both directions
|
||
if value is None:
|
||
# Use a special value that sorts last
|
||
return (1, "") # (is_none_flag, empty_value) - sorts after (0, ...)
|
||
else:
|
||
# Return tuple with type indicator for proper comparison
|
||
if isinstance(value, (int, float)):
|
||
return (0, value)
|
||
elif isinstance(value, str):
|
||
return (0, value)
|
||
elif isinstance(value, bool):
|
||
return (0, value)
|
||
else:
|
||
return (0, str(value))
|
||
|
||
# Sort with reverse parameter for descending
|
||
sortedRecords.sort(key=sortKey, reverse=isDesc)
|
||
|
||
return sortedRecords
|
||
|
||
def getInitialId(self, model_class: type) -> Optional[str]:
    """Return the initial ID that the database connector reports for ``model_class``'s table."""
    connector = self.db
    initialId = connector.getInitialId(model_class)
    return initialId
|
||
|
||
def _getRootMandateId(self) -> Optional[str]:
    """Return the ID of the root mandate (name='root', isSystem=True), or None if absent.

    Reads the mandate table directly (no RBAC filtering); if several records
    match, the first one returned is used.
    """
    rootFilter = {"name": "root", "isSystem": True}
    matches = self.db.getRecordset(Mandate, recordFilter=rootFilter)
    return matches[0].get("id") if matches else None
|
||
|
||
def _getPasswordHash(self, password: str) -> str:
    """Hash ``password`` with the module-level passlib context (argon2 scheme)."""
    hashed = pwdContext.hash(password)
    return hashed
|
||
|
||
def _verifyPassword(self, plainPassword: str, hashedPassword: str) -> bool:
    """Return True when ``plainPassword`` matches ``hashedPassword`` under the passlib context."""
    isMatch = pwdContext.verify(plainPassword, hashedPassword)
    return isMatch
|
||
|
||
# User methods
|
||
|
||
def getUsersByMandate(self, mandateId: str, pagination: Optional[PaginationParams] = None) -> Union[List[User], PaginatedResult]:
    """Return the users that are members of ``mandateId``.

    Membership is resolved through the UserMandate junction table; the user
    rows are then loaded via ``getRecordsetPaginated``. No RBAC filtering is
    applied here.

    Args:
        mandateId: Mandate whose members are requested.
        pagination: Optional pagination parameters; None returns every member.

    Returns:
        ``List[User]`` when ``pagination`` is None, otherwise a
        ``PaginatedResult`` with the page items and total counts.
    """
    memberships = self.db.getRecordset(UserMandate, recordFilter={"mandateId": mandateId})
    memberIds = [m.get("userId") for m in memberships if m.get("userId")]

    if not memberIds:
        return [] if pagination is None else PaginatedResult(items=[], totalItems=0, totalPages=0)

    page = self.db.getRecordsetPaginated(
        UserInDB,
        pagination=pagination,
        recordFilter={"id": memberIds}
    )

    members = []
    for row in page["items"]:
        userData = dict(row)
        # roleLabels must be a list for the User model, never None.
        if userData.get("roleLabels") is None:
            userData["roleLabels"] = []
        members.append(User(**userData))

    if pagination is not None:
        return PaginatedResult(
            items=members,
            totalItems=page["totalItems"],
            totalPages=page["totalPages"]
        )
    return members
|
||
|
||
def getUserByUsername(self, username: str) -> Optional[User]:
    """Return the user with ``username`` if it is visible to the current user under RBAC.

    Returns:
        The first matching user, or None when no visible user matches or when
        an error occurs (the error is logged).
    """
    try:
        visibleUsers = getRecordsetWithRBAC(
            self.db,
            UserInDB,
            self.currentUser,
            recordFilter={"username": username},
            mandateId=self.mandateId,
        )
        if not visibleUsers:
            logger.info(f"No user found with username {username}")
            return None

        # Usernames are expected to be unique, so take the first hit.
        userData = dict(visibleUsers[0])
        if userData.get("roleLabels") is None:
            userData["roleLabels"] = []
        return User(**userData)
    except Exception as e:
        logger.error(f"Error getting user by username: {str(e)}")
        return None
|
||
|
||
def getUser(self, userId: str) -> Optional[User]:
    """Return the user with ``userId`` if it is visible to the current user under RBAC.

    Returns:
        The user, or None when it is not visible / does not exist or when an
        error occurs (the error is logged).
    """
    try:
        visibleUsers = getRecordsetWithRBAC(
            self.db,
            UserInDB,
            self.currentUser,
            recordFilter={"id": userId},
            mandateId=self.mandateId,
        )
        if not visibleUsers:
            return None

        # RBAC already filtered the rows; only normalize roleLabels.
        userData = dict(visibleUsers[0])
        if userData.get("roleLabels") is None:
            userData["roleLabels"] = []
        return User(**userData)
    except Exception as e:
        logger.error(f"Error getting user by ID: {str(e)}")
        return None
|
||
|
||
def getUsersByIds(self, userIds: list[str]) -> dict[str, User]:
    """Batch-load users by ID with a single recordset query.

    NOTE: reads the user table directly — no RBAC filtering is applied.
    Duplicate IDs are collapsed; IDs with no matching record are absent from
    the result.

    Returns:
        ``{userId: User}``; an empty dict for empty input or on error (logged).
    """
    if not userIds:
        return {}
    try:
        rows = self.db.getRecordset(UserInDB, recordFilter={"id": list(set(userIds))})
        usersById: dict[str, User] = {}
        for row in rows or []:
            userData = dict(row)
            if userData.get("roleLabels") is None:
                userData["roleLabels"] = []
            recordId = userData.get("id")
            if recordId:
                usersById[recordId] = User(**userData)
        return usersById
    except Exception as e:
        logger.error(f"Error batch-loading users: {e}")
        return {}
|
||
|
||
def getMandatesByIds(self, mandateIds: list[str]) -> dict[str, Mandate]:
    """Batch-load mandates by ID with a single recordset query (no RBAC filtering).

    Returns:
        ``{mandateId: Mandate}``; an empty dict for empty input or on error (logged).
    """
    if not mandateIds:
        return {}
    try:
        rows = self.db.getRecordset(Mandate, recordFilter={"id": list(set(mandateIds))})
        mandatesById: dict[str, Mandate] = {}
        for row in rows or []:
            mandateData = dict(row)
            recordId = mandateData.get("id")
            if recordId:
                mandatesById[recordId] = Mandate(**mandateData)
        return mandatesById
    except Exception as e:
        logger.error(f"Error batch-loading mandates: {e}")
        return {}
|
||
|
||
def _getUserForAuthentication(self, username: str) -> Optional[Dict[str, Any]]:
    """Load the raw user record for ``username`` for login purposes.

    SECURITY NOTE: intentionally bypasses RBAC. Users are not mandate-bound,
    and RBAC filtering of the user table needs a mandate context that does not
    exist at login time. Use only in authentication flows; everything else
    should go through getUserByUsername(), which applies RBAC.

    Returns:
        The full UserInDB record as a dict (including hashedPassword), or None
        when not found or on error (logged).
    """
    try:
        matches = self.db.getRecordset(UserInDB, recordFilter={"username": username})
    except Exception as e:
        logger.error(f"Error getting user for authentication: {str(e)}")
        return None
    return matches[0] if matches else None
|
||
|
||
def authenticateLocalUser(self, username: str, password: str) -> Optional[User]:
    """Authenticate ``username`` / ``password`` against the local password store.

    SECURITY NOTE: uses _getUserForAuthentication(), which bypasses RBAC on
    purpose because users are mandate-independent.

    Returns:
        The authenticated User, with ``roleLabels`` always a list.

    Raises:
        ValueError: "User not found", "User is disabled", "User does not have
            local authentication enabled", "User has no password set" or
            "Invalid password".

    NOTE(review): the distinct messages let a caller tell "unknown user" apart
    from "wrong password"; make sure the route layer does not expose them
    verbatim if username enumeration is a concern.
    """
    userRecord = self._getUserForAuthentication(username)
    if not userRecord:
        raise ValueError("User not found")

    if not userRecord.get("enabled", True):
        raise ValueError("User is disabled")

    # The stored authority may be the enum member or its raw value.
    authAuthority = userRecord.get("authenticationAuthority", AuthAuthority.LOCAL)
    if authAuthority not in (AuthAuthority.LOCAL, AuthAuthority.LOCAL.value):
        raise ValueError("User does not have local authentication enabled")

    hashedPassword = userRecord.get("hashedPassword")
    if not hashedPassword:
        raise ValueError("User has no password set")

    if not self._verifyPassword(password, hashedPassword):
        raise ValueError("Invalid password")

    # Normalize roleLabels BEFORE validation, as the other user loaders do, so
    # validation never sees None; copy to avoid mutating the fetched record.
    userData = dict(userRecord)
    if userData.get("roleLabels") is None:
        userData["roleLabels"] = []
    return User.model_validate(userData)
|
||
|
||
def createUser(
    self,
    username: str,
    password: Optional[str] = None,
    email: Optional[str] = None,
    fullName: Optional[str] = None,
    language: str = "de",
    enabled: bool = True,
    authenticationAuthority: AuthAuthority = AuthAuthority.LOCAL,
    externalId: Optional[str] = None,
    externalUsername: Optional[str] = None,
    externalEmail: Optional[str] = None,
    isSysAdmin: bool = False,
    isPlatformAdmin: bool = False,
    addExternalIdentityConnection: bool = True,
) -> User:
    """Create a new user record.

    Role assignment is not done here; use createUserMandate() (UserMandate +
    UserMandateRole). Users are not attached to any mandate by this method.

    Args:
        username: Login name; stripped and required to be non-empty and
            globally unique.
        password: Optional initial password. For LOCAL auth the password is
            normally set via magic link; this parameter serves internal/admin
            operations. If given it must be a non-blank string.
        email: Optional email; stored stripped and lower-cased, matching the
            normalization used by the find*ByEmailLocalAuth lookups.
        addExternalIdentityConnection: If True (default) and both externalId and
            externalUsername are set, also creates a UserConnection row. OAuth
            login-only flows should pass False (the data connection is created
            separately via /auth/connect).

    Returns:
        The created user as re-read from the database.

    Raises:
        ValueError: On validation failure, duplicate username, or any
            underlying error (wrapped, with the original chained as __cause__).
    """
    try:
        # str(None) would otherwise silently create a user literally named "None",
        # and a blank name would be reported as "already taken".
        if username is None or not str(username).strip():
            raise ValueError("Username cannot be empty")
        username = str(username).strip()

        # Usernames are unique across ALL mandates.
        if not self.isUsernameGloballyUnique(username):
            raise ValueError(f"Username '{username}' is already taken")

        if password:
            if not isinstance(password, str):
                raise ValueError("Password must be a string")
            if not password.strip():
                raise ValueError("Password cannot be empty")

        # Lookups query with email.lower().strip(); a mixed-case stored address
        # would never be found by them.
        if isinstance(email, str):
            email = email.strip().lower()

        # mandateId / roleLabels are intentionally absent — see UserMandate(+Role).
        userData = UserInDB(
            username=username,
            email=email,
            fullName=fullName,
            language=language,
            enabled=enabled,
            isSysAdmin=isSysAdmin,
            isPlatformAdmin=isPlatformAdmin,
            authenticationAuthority=authenticationAuthority,
            hashedPassword=self._getPasswordHash(password) if password else None,
        )

        createdRecord = self.db.recordCreate(UserInDB, userData)
        if not createdRecord or not createdRecord.get("id"):
            raise ValueError("Failed to create user record")

        # Optionally mirror the external IdP identity into UserConnections
        # (data/API OAuth). Auth-only logins must NOT create a connection.
        # NOTE(review): if this fails, the user record above already exists —
        # confirm whether a cleanup/transaction is needed.
        if addExternalIdentityConnection and externalId and externalUsername:
            self.addUserConnection(
                createdRecord["id"],
                authenticationAuthority,
                externalId,
                externalUsername,
                externalEmail,
            )

        createdUser = self.db.getRecordset(
            UserInDB, recordFilter={"id": createdRecord["id"]}
        )
        if not createdUser:
            raise ValueError("Failed to retrieve created user")

        # No root-mandate assignment: users get their own mandate via
        # _provisionMandateForUser during registration.
        return User(**createdUser[0])

    except ValueError as e:
        logger.error(f"Error creating user: {str(e)}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error creating user: {str(e)}")
        raise ValueError(f"Failed to create user: {str(e)}") from e
|
||
|
||
def updateUser(
    self,
    userId: str,
    updateData: Union[Dict[str, Any], User],
    allowAdminFlagChange: bool = False,
) -> User:
    """Apply a partial update to a user and return the re-read user.

    The target must be visible to the current user via getUser() (RBAC read).
    NOTE(review): no separate RBAC *update* permission check happens here —
    confirm callers enforce it.

    Args:
        userId: ID of the user to update.
        updateData: Fields to change, as a dict or a User model. For a User,
            only explicitly set fields are applied (``exclude_unset``). If none
            are marked as set (legacy fully-defaulted model), the full dump is
            applied, but the privileged flags are always dropped from it because
            they can only be Pydantic defaults there.
        allowAdminFlagChange: If True, ``isSysAdmin`` / ``isPlatformAdmin`` from
            ``updateData`` are honored. Only set this when a Platform Admin is
            explicitly changing another user's admin status.

    Returns:
        The updated user.

    Raises:
        ValueError: Wrapping any failure ("Failed to update user: ..."), with
            the original exception chained.
    """
    protectedFields = ("isSysAdmin", "isPlatformAdmin")
    try:
        user = self.getUser(userId)
        if not user:
            raise ValueError(f"User {userId} not found")

        # A full model_dump() would include every unsent field with its
        # Pydantic default (e.g. isSysAdmin=False), silently flipping flags.
        if isinstance(updateData, User):
            updateDict = updateData.model_dump(exclude_unset=True)
            if not updateDict:
                # Legacy caller built a fully-defaulted User: apply the dump, but
                # privileged flags there are defaults, never an explicit choice.
                updateDict = updateData.model_dump()
                for field in protectedFields:
                    updateDict.pop(field, None)
        else:
            # Internal callers (disableUser / enableUser / migrations) send partial dicts.
            updateDict = updateData.copy() if isinstance(updateData, dict) else dict(updateData)

        updateDict.pop("id", None)

        # SECURITY: protect privileged platform flags unless explicitly allowed.
        if not allowAdminFlagChange:
            for field in protectedFields:
                updateDict.pop(field, None)

        mergedData = user.model_dump()
        mergedData.update(updateDict)
        mergedData["id"] = userId  # the record ID always follows the userId argument
        self.db.recordModify(UserInDB, userId, User(**mergedData))

        updatedUser = self.getUser(userId)
        if not updatedUser:
            raise ValueError("Failed to retrieve updated user")
        return updatedUser

    except Exception as e:
        logger.error(f"Error updating user: {str(e)}")
        raise ValueError(f"Failed to update user: {str(e)}") from e
|
||
|
||
def disableUser(self, userId: str) -> User:
    """Set ``enabled=False`` on ``userId`` through updateUser() and return the updated user."""
    changes = {"enabled": False}
    return self.updateUser(userId, changes)
|
||
|
||
def enableUser(self, userId: str) -> User:
    """Set ``enabled=True`` on ``userId`` through updateUser() and return the updated user."""
    changes = {"enabled": True}
    return self.updateUser(userId, changes)
|
||
|
||
def resetUserPassword(self, userId: str, newPassword: str) -> bool:
    """Admin helper: store a new password hash for ``userId``.

    The password must be at least 8 characters. All failures, including the
    length check, are logged and reported as False rather than raised.

    Returns:
        True on success, False otherwise.
    """
    try:
        tooShort = not newPassword or len(newPassword) < 8
        if tooShort:
            raise ValueError("Password must be at least 8 characters long")

        passwordUpdate = {"hashedPassword": self._getPasswordHash(newPassword)}
        self.db.recordModify(UserInDB, userId, passwordUpdate)
        logger.info(f"Password reset for user {userId}")
        return True
    except Exception as e:
        logger.error(f"Error resetting password for user {userId}: {str(e)}")
        return False
|
||
|
||
def generateResetTokenAndExpiry(self) -> tuple:
    """Create a password-reset token and its expiry time.

    The token is a random UUID4 string. The expiry is getUtcTimestamp() plus
    ``Auth_RESET_TOKEN_EXPIRY_HOURS`` hours (config, default 24), expressed in
    the same unit as getUtcTimestamp() — presumably seconds, given the * 3600.

    Returns:
        tuple: ``(token, expiresTimestamp)``
    """
    resetToken = str(uuid.uuid4())
    validForHours = int(APP_CONFIG.get("Auth_RESET_TOKEN_EXPIRY_HOURS", "24"))
    return resetToken, getUtcTimestamp() + (validForHours * 3600)
|
||
|
||
def findUserByEmailLocalAuth(self, email: str) -> Optional[User]:
    """Return the first LOCAL-auth user whose stored email equals the normalized ``email``.

    Several accounts can share an email; this returns only the first match.
    Use findAllUsersByEmailLocalAuth() to get all of them.

    Returns:
        The first matching User, or None.
    """
    matches = self.findAllUsersByEmailLocalAuth(email)
    if not matches:
        return None
    return matches[0]
|
||
|
||
def findAllUsersByEmailLocalAuth(self, email: str) -> List[User]:
    """Return every LOCAL-auth user with the given email, with no RBAC filtering.

    The query email is lower-cased and stripped before the exact-match lookup,
    so matching is case-insensitive only if stored emails are normalized too.

    Args:
        email: Address to search for.

    Returns:
        The matching users; an empty list for an empty email, no matches, or
        an error (logged).
    """
    if not email:
        return []

    lookupEmail = email.lower().strip()

    try:
        rows = self.db.getRecordset(
            UserInDB,
            recordFilter={
                "email": lookupEmail,
                "authenticationAuthority": AuthAuthority.LOCAL.value
            }
        )
        matches = []
        for row in rows:
            userData = dict(row)
            if userData.get("roleLabels") is None:
                userData["roleLabels"] = []
            matches.append(User(**userData))
        return matches
    except Exception as e:
        logger.error(f"Error finding users by email: {str(e)}")
        return []
|
||
|
||
def findUserByEmailAndUsernameLocalAuth(self, email: str, username: str) -> Optional[User]:
    """Return the LOCAL-auth user matching both *email* and *username*.

    The pair identifies a single account even when the same email is used for
    accounts in several mandates. The query bypasses RBAC, so all mandates are
    searched.

    Args:
        email: Email address; lower-cased and stripped before the query.
        username: Username, compared exactly as given.

    Returns:
        The first matching ``User``. None if either argument is empty, nothing
        matches, or the query fails (the error is logged).
    """
    if not email or not username:
        return None

    lookupEmail = email.lower().strip()

    try:
        # Direct DB query (no RBAC) so that every mandate is covered.
        rows = self.db.getRecordset(
            UserInDB,
            recordFilter={
                "email": lookupEmail,
                "username": username,
                "authenticationAuthority": AuthAuthority.LOCAL.value,
            },
        )
        if not rows:
            return None

        record = dict(rows[0])
        if record.get("roleLabels") is None:
            record["roleLabels"] = []
        return User(**record)
    except Exception as e:
        logger.error(f"Error finding user by email and username: {str(e)}")
        return None
|
||
|
||
def isUsernameGloballyUnique(self, username: str) -> bool:
    """Tell whether *username* is unused across ALL mandates and authorities.

    Used during registration. The query bypasses RBAC and does not filter on
    authentication authority, so any existing user with that exact username
    counts as a clash.

    Args:
        username: Username to check.

    Returns:
        True if no user has this username. False if it is taken, if *username*
        is empty, or if the query fails (the error is logged).
    """
    if not username:
        return False

    try:
        # Unfiltered by RBAC: every mandate's users are considered.
        existing = self.db.getRecordset(UserInDB, recordFilter={"username": username})
        return len(existing) == 0
    except Exception as e:
        logger.error(f"Error checking username uniqueness: {str(e)}")
        return False  # Fail safe - assume not unique on error
|
||
|
||
def findUserByUsernameLocalAuth(self, username: str) -> Optional[User]:
    """Return the LOCAL-auth user with *username*, searching all mandates.

    Usernames are meant to be globally unique, so at most the first match is
    returned. The query bypasses RBAC.

    Args:
        username: Username to look up, compared exactly as given.

    Returns:
        The matching ``User``. None if *username* is empty, nothing matches, or
        the query fails (the error is logged).
    """
    if not username:
        return None

    try:
        # Direct DB query (no RBAC) across every mandate.
        rows = self.db.getRecordset(
            UserInDB,
            recordFilter={
                "username": username,
                "authenticationAuthority": AuthAuthority.LOCAL.value,
            },
        )
        if not rows:
            return None

        record = dict(rows[0])
        if record.get("roleLabels") is None:
            record["roleLabels"] = []
        return User(**record)
    except Exception as e:
        logger.error(f"Error finding user by username: {str(e)}")
        return None
|
||
|
||
def setResetToken(self, userId: str, token: str, expires: float, clearPassword: bool = True) -> bool:
    """Attach a password-reset token and its expiry to a user record.

    Args:
        userId: ID of the user to modify.
        token: Reset token (UUID string).
        expires: Expiration timestamp, stored as ``resetTokenExpires``.
        clearPassword: When True (the default), ``hashedPassword`` is also set to None.

    Returns:
        True after a successful ``recordModify``; False if it raised (the error is logged).
    """
    try:
        fields = {"resetToken": token, "resetTokenExpires": expires}
        if clearPassword:
            fields["hashedPassword"] = None
        self.db.recordModify(UserInDB, userId, fields)
    except Exception as e:
        logger.error(f"Error setting reset token for user {userId}: {str(e)}")
        return False
    return True
|
||
|
||
def verifyResetToken(self, token: str) -> Optional[User]:
    """Look up the user owning reset *token* and check that the token is still valid.

    A token is rejected in these cases (each is logged):
    - it matches no user;
    - its ``resetTokenExpires`` cannot be converted to float;
    - ``resetTokenExpires`` is missing or zero;
    - ``resetTokenExpires`` lies in the past.

    Returns:
        The owning ``User`` when the token is valid and unexpired, otherwise None.
        None is also returned on any unexpected error.
    """
    if not token:
        return None

    try:
        matches = self.db.getRecordset(UserInDB, recordFilter={"resetToken": token})
        if not matches:
            return None

        record = matches[0]

        # Stored expiry may come back as a non-float (e.g. string); normalize it.
        expiry = record.get("resetTokenExpires")
        if expiry is not None:
            try:
                expiry = float(expiry)
            except (ValueError, TypeError):
                logger.warning(f"Invalid resetTokenExpires value for user {record.get('id')}: {expiry}")
                return None

        if not expiry or getUtcTimestamp() > expiry:
            logger.warning(f"Reset token expired for user {record.get('id')}")
            return None

        userData = dict(record)
        if userData.get("roleLabels") is None:
            userData["roleLabels"] = []
        return User(**userData)
    except Exception as e:
        logger.error(f"Error verifying reset token: {str(e)}")
        return None
|
||
|
||
def resetPasswordWithToken(self, token: str, newPassword: str) -> bool:
    """Complete a password reset identified by *token*.

    Steps:
    1. Verify the token with ``verifyResetToken``.
    2. Require a new password of at least 8 characters.
    3. In a single ``recordModify`` call, store the new hash, clear both
       reset-token fields and set ``enabled`` to True.

    NOTE(review): verification and update are separate DB calls, so concurrent use
    of the same token is not prevented here.

    Returns:
        True on success. False if the token is invalid or expired, the password is
        too short, or any step raises (errors are logged).
    """
    try:
        targetUser = self.verifyResetToken(token)
        if not targetUser:
            return False

        if not newPassword or len(newPassword) < 8:
            # Handled by the except below, so the caller only sees False.
            raise ValueError("Password must be at least 8 characters long")

        newHash = self._getPasswordHash(newPassword)

        self.db.recordModify(UserInDB, targetUser.id, {
            "hashedPassword": newHash,
            "resetToken": None,
            "resetTokenExpires": None,
            "enabled": True,
        })

        logger.info(f"Password reset completed for user {targetUser.id}")
        return True
    except Exception as e:
        logger.error(f"Error in resetPasswordWithToken: {str(e)}")
        return False
|
||
|
||
def _deleteUserReferencedData(self, userId: str) -> None:
    """Delete all records that reference *userId* (full cascade, bypassing RBAC).

    Removal order:
        1. FeatureAccess rows of the user and their FeatureAccessRole links
        2. UserMandate rows of the user and their UserMandateRole links
        3. UserNotification rows of the user
        4. Invitation rows addressed to the user's email
        5. AuthEvent rows of the user
        6. Token rows of the user
        7. UserConnection rows of the user

    The UserInDB record itself is not deleted here.

    Args:
        userId: ID of the user whose dependent data is removed.

    Raises:
        Exception: any database error is logged and re-raised unchanged.
    """
    try:
        def _deleteRows(model, rows) -> None:
            # Delete each fetched row by id; rows without an id are skipped.
            for row in rows or []:
                rowId = row.get("id")
                if rowId:
                    self.db.recordDelete(model, rowId)

        # Resolve the email BEFORE memberships are removed, directly from the user
        # table. An RBAC-filtered lookup (getUser) made after step 2 may no longer
        # see this user, which would silently skip the invitation cleanup.
        userRows = self.db.getRecordset(UserInDB, recordFilter={"id": userId})
        userEmail = userRows[0].get("email") if userRows else None

        # 1. FeatureAccess + FeatureAccessRole
        accesses = self.db.getRecordset(FeatureAccess, recordFilter={"userId": userId})
        for acc in accesses:
            accId = acc.get("id")
            if not accId:
                continue
            _deleteRows(
                FeatureAccessRole,
                self.db.getRecordset(FeatureAccessRole, recordFilter={"featureAccessId": accId}),
            )
            self.db.recordDelete(FeatureAccess, accId)
        if accesses:
            logger.info(f"User cascade: deleted {len(accesses)} FeatureAccess records for user {userId}")

        # 2. UserMandate + UserMandateRole
        memberships = self.db.getRecordset(UserMandate, recordFilter={"userId": userId})
        for um in memberships:
            umId = um.get("id")
            if not umId:
                continue
            _deleteRows(
                UserMandateRole,
                self.db.getRecordset(UserMandateRole, recordFilter={"userMandateId": umId}),
            )
            self.db.recordDelete(UserMandate, umId)
        if memberships:
            logger.info(f"User cascade: deleted {len(memberships)} UserMandate records for user {userId}")

        # 3. UserNotifications
        notifications = self.db.getRecordset(UserNotification, recordFilter={"userId": userId})
        _deleteRows(UserNotification, notifications)
        if notifications:
            logger.info(f"User cascade: deleted {len(notifications)} notifications for user {userId}")

        # 4. Invitations (by email, resolved above)
        if userEmail:
            invitations = self.db.getRecordset(Invitation, recordFilter={"email": userEmail})
            _deleteRows(Invitation, invitations)
            if invitations:
                logger.info(f"User cascade: deleted {len(invitations)} invitations for {userEmail}")

        # 5. AuthEvents
        _deleteRows(AuthEvent, self.db.getRecordset(AuthEvent, recordFilter={"userId": userId}))

        # 6. Tokens
        _deleteRows(Token, self.db.getRecordset(Token, recordFilter={"userId": userId}))

        # 7. UserConnections
        _deleteRows(UserConnection, self.db.getRecordset(UserConnection, recordFilter={"userId": userId}))

        logger.info(f"All referenced data for user {userId} has been deleted")

    except Exception as e:
        logger.error(f"Error deleting referenced data for user {userId}: {str(e)}")
        raise
|
||
|
||
def deleteUser(self, userId: str) -> bool:
    """Delete a user, and every record referencing it, if the caller may delete it.

    Steps: load the user with ``getUser`` and require the RBAC "delete" right on
    UserInDB for this user. Then cascade-delete referenced data via
    ``_deleteUserReferencedData`` and finally delete the UserInDB record.

    Args:
        userId: ID of the user to delete.

    Returns:
        True when the user record was deleted.

    Raises:
        ValueError: for every failure -- user not found, missing permission, or a
            cascade/delete error. The original error is attached as ``__cause__``.
    """
    try:
        user = self.getUser(userId)
        if not user:
            raise ValueError(f"User {userId} not found")

        # Deleting requires the "delete" right; edit ("update") rights alone must
        # not be enough to remove a user.
        if not self.checkRbacPermission(UserInDB, "delete", userId):
            raise PermissionError(f"No permission to delete user {userId}")

        # Dependent data first, so nothing is left pointing at a missing user.
        self._deleteUserReferencedData(userId)

        success = self.db.recordDelete(UserInDB, userId)
        if not success:
            raise ValueError(f"Failed to delete user {userId}")

        logger.info(f"User {userId} successfully deleted")
        return True

    except Exception as e:
        logger.error(f"Error deleting user: {str(e)}")
        raise ValueError(f"Failed to delete user: {str(e)}") from e
|
||
|
||
def _getInitialUser(self) -> Optional[Dict[str, Any]]:
    """Return the initial user record as a raw dict, or None.

    The ID comes from ``getInitialId(UserInDB)``. The record is then loaded with
    ``getRecordsetWithRBAC`` for the current user and mandate context. Access
    control therefore IS applied: the record is returned only if the current user
    may see it.
    NOTE(review): if a true RBAC bypass is intended here, this should read via
    ``self.db.getRecordset`` instead -- confirm with callers.

    Returns:
        The user record dict. None if no initial ID exists, the record is not
        visible, or an error occurs (the error is logged).
    """
    try:
        initialUserId = self.getInitialId(UserInDB)
        if not initialUserId:
            return None

        users = getRecordsetWithRBAC(
            self.db,
            UserInDB,
            self.currentUser,
            recordFilter={"id": initialUserId},
            mandateId=self.mandateId,
        )
        return users[0] if users else None
    except Exception as e:
        logger.error(f"Error getting initial user: {str(e)}")
        return None
|
||
|
||
def checkUsernameAvailability(self, checkData: Dict[str, Any]) -> Dict[str, Any]:
    """Report whether the username in *checkData* can be used for registration.

    Only ``checkData["username"]`` is read. Any authentication authority is ignored
    because usernames must be unique system-wide (checked via
    ``isUsernameGloballyUnique``).

    Returns:
        A dict ``{"available": bool, "message": str}``. Errors are logged and
        reported as unavailable.
    """
    try:
        username = checkData.get("username")
        if not username:
            return {"available": False, "message": "Username is required"}

        isFree = self.isUsernameGloballyUnique(username)
        if isFree:
            return {"available": True, "message": "Username is available"}
        return {"available": False, "message": "Username is already taken"}

    except Exception as e:
        logger.error(f"Error checking username availability: {str(e)}")
        return {
            "available": False,
            "message": f"Error checking username availability: {str(e)}",
        }
|
||
|
||
# Connection methods
|
||
|
||
def getUserConnections(self, userId: str) -> List[UserConnection]:
    """Return all stored connections of *userId* as ``UserConnection`` objects.

    Rows that fail model validation are logged and skipped. Any other error is
    logged, and an empty list is returned.
    """
    try:
        rows = self.db.getRecordset(UserConnection, recordFilter={"userId": userId})

        parsed = []
        for row in rows:
            try:
                parsed.append(UserConnection.model_validate(row))
            except Exception as e:
                logger.error(
                    f"Error converting connection dict to object: {str(e)}"
                )
        return parsed

    except Exception as e:
        logger.error(f"Error getting user connections: {str(e)}")
        return []
|
||
|
||
def getActiveKnowledgeConnections(self) -> List[UserConnection]:
    """Return connections with knowledge ingestion enabled and ACTIVE status.

    Used by the daily re-sync scheduler to decide which connections to re-index.
    Dict rows are validated into ``UserConnection``; non-dict rows are passed
    through unchanged. Rows that fail validation are logged and skipped, and any
    other error yields an empty list.
    """
    try:
        rows = self.db.getRecordset(
            UserConnection,
            recordFilter={"knowledgeIngestionEnabled": True, "status": ConnectionStatus.ACTIVE.value},
        )

        connections = []
        for row in rows or []:
            try:
                connections.append(UserConnection.model_validate(row) if isinstance(row, dict) else row)
            except Exception as parseError:
                logger.warning(f"getActiveKnowledgeConnections: could not parse row: {parseError}")
        return connections

    except Exception as e:
        logger.error(f"getActiveKnowledgeConnections failed: {e}")
        return []
|
||
|
||
def getUserConnectionById(self, connectionId: str) -> Optional[UserConnection]:
    """Fetch one ``UserConnection`` by its ID or by a reference string.

    Lookup order:
    1. Match ``connectionId`` against the ``id`` column.
    2. If nothing matched and the value looks like
       ``connection:<authority>:<username>`` (a format used by the AI agent), return
       the first connection with that ``externalUsername`` whose authority value
       equals ``<authority>``.

    If model validation fails, the object is built field by field with defaults.

    Returns:
        The connection, or None if not found or on error (the error is logged).
    """
    try:
        found = self.db.getRecordset(UserConnection, recordFilter={"id": connectionId})

        if not found and connectionId.startswith("connection:"):
            refParts = connectionId.split(":", 2)
            if len(refParts) == 3:
                _, wantedAuthority, wantedUsername = refParts
                candidates = self.db.getRecordset(UserConnection, recordFilter={"externalUsername": wantedUsername})
                for candidate in candidates or []:
                    rawAuthority = candidate.get("authority", "")
                    # Authority may be stored as an enum or as a plain string.
                    authorityValue = rawAuthority.value if hasattr(rawAuthority, "value") else str(rawAuthority)
                    if authorityValue == wantedAuthority:
                        found = [candidate]
                        break

        if not found:
            return None

        record = found[0]
        try:
            return UserConnection.model_validate(record)
        except Exception:
            return UserConnection(
                id=record["id"],
                userId=record["userId"],
                authority=record.get("authority"),
                externalId=record.get("externalId", ""),
                externalUsername=record.get("externalUsername", ""),
                externalEmail=record.get("externalEmail"),
                status=record.get("status", "pending"),
                connectedAt=record.get("connectedAt"),
                lastChecked=record.get("lastChecked"),
                expiresAt=record.get("expiresAt"),
            )
    except Exception as e:
        logger.error(f"Error getting user connection by ID: {str(e)}")
        return None
|
||
|
||
def addUserConnection(
    self,
    userId: str,
    authority: AuthAuthority,
    externalId: str,
    externalUsername: str,
    externalEmail: Optional[str] = None,
    status: ConnectionStatus = ConnectionStatus.PENDING,
) -> UserConnection:
    """Create and persist a new external-service connection for a user.

    Args:
        userId: Owner of the connection.
        authority: Authentication authority (e.g. MSFT, GOOGLE).
        externalId: ID of the account at the authority.
        externalUsername: Username at the authority.
        externalEmail: Optional email reported by the authority.
        status: Initial status (defaults to PENDING).

    Returns:
        The created ``UserConnection``, with a fresh UUID, ``connectedAt`` and
        ``lastChecked`` set to now, and ``expiresAt`` unset.

    Raises:
        ValueError: if building or storing the record fails.
    """
    try:
        # The user is deliberately not looked up here:
        # - the calling route already has an authenticated currentUser;
        # - users may always create connections for themselves;
        # - RBAC-filtered getUser() can fail for users without UserInDB view rights.
        newConnection = UserConnection(
            id=str(uuid.uuid4()),
            userId=userId,
            authority=authority,
            externalId=externalId,
            externalUsername=externalUsername,
            externalEmail=externalEmail,
            status=status,
            connectedAt=getUtcTimestamp(),
            lastChecked=getUtcTimestamp(),
            expiresAt=None,
        )

        self.db.recordCreate(UserConnection, newConnection)
        return newConnection

    except Exception as e:
        logger.error(f"Error adding user connection: {str(e)}")
        raise ValueError(f"Failed to add user connection: {str(e)}")
|
||
|
||
def removeUserConnection(self, connectionId: str) -> None:
    """Delete the external-service connection with ID *connectionId*.

    Raises:
        ValueError: if the connection does not exist or deletion fails. Every
            error is logged and re-raised wrapped in ValueError.
    """
    try:
        existing = self.db.getRecordset(UserConnection, recordFilter={"id": connectionId})
        if not existing:
            raise ValueError(f"Connection {connectionId} not found")

        self.db.recordDelete(UserConnection, connectionId)

    except Exception as e:
        logger.error(f"Error removing user connection: {str(e)}")
        raise ValueError(f"Failed to remove user connection: {str(e)}")
|
||
|
||
# Mandate methods
|
||
|
||
def getAllMandates(self, pagination: Optional[PaginationParams] = None) -> Union[List[Mandate], PaginatedResult]:
    """Return the mandates visible to the current user (RBAC-filtered).

    Without *pagination*, every visible mandate is returned as a list. With
    *pagination*, the rows are filtered, then sorted, and the requested 1-based
    page is returned.

    Args:
        pagination: Optional paging/sorting/filtering parameters.

    Returns:
        ``List[Mandate]`` when *pagination* is None. Otherwise a ``PaginatedResult``
        whose ``totalItems`` / ``totalPages`` are counted after filtering.
    """
    records = getRecordsetWithRBAC(self.db, Mandate, self.currentUser, mandateId=self.mandateId)

    # Shallow plain-dict copies of the rows (no fields are removed).
    rows = [dict(record) for record in records]

    if pagination is None:
        return [Mandate(**row) for row in rows]

    if pagination.filters:
        rows = self._applyFilters(rows, pagination.filters)

    if pagination.sort:
        rows = self._applySorting(rows, pagination.sort)

    totalItems = len(rows)
    totalPages = math.ceil(totalItems / pagination.pageSize) if totalItems > 0 else 0

    offset = (pagination.page - 1) * pagination.pageSize
    pageRows = rows[offset:offset + pagination.pageSize]

    return PaginatedResult(
        items=[Mandate(**row) for row in pageRows],
        totalItems=totalItems,
        totalPages=totalPages,
    )
|
||
|
||
def getMandate(self, mandateId: str) -> Optional[Mandate]:
    """Return the mandate with *mandateId* if the current user may see it.

    Visibility is decided by ``getRecordsetWithRBAC`` for the current user and
    mandate context.

    Returns:
        A ``Mandate`` built from the first visible row, or None if none is visible.
    """
    visibleRows = getRecordsetWithRBAC(
        self.db,
        Mandate,
        self.currentUser,
        recordFilter={"id": mandateId},
        mandateId=self.mandateId,
    )
    if not visibleRows:
        return None

    # Plain-dict copies of the returned rows.
    rowCopies = [dict(row) for row in visibleRows]
    return Mandate(**rowCopies[0])
|
||
|
||
def _existingMandateNames(self, excludeId: Optional[str] = None) -> List[str]:
    """List the non-empty ``name`` values of every mandate in the DB.

    Reads the table directly (no RBAC). When *excludeId* is given, the row whose
    ``id`` equals it (compared as strings) is skipped.
    """
    skipId = str(excludeId) if excludeId else None
    names: List[str] = []
    for record in self.db.getRecordset(Mandate):
        if skipId is not None and str(record.get("id")) == skipId:
            continue
        mandateName = record.get("name")
        if mandateName:
            names.append(mandateName)
    return names
|
||
|
||
def _generateUniqueMandateName(self, label: str, excludeId: Optional[str] = None) -> str:
    """Derive a mandate ``name`` slug from *label* that no other mandate uses.

    The label is slugified first. ``allocateUniqueMandateSlug`` then resolves
    clashes against the names from ``_existingMandateNames`` (ignoring *excludeId*).
    """
    from modules.shared.mandateNameUtils import allocateUniqueMandateSlug, slugifyMandateName

    candidateSlug = slugifyMandateName(label or "")
    takenNames = self._existingMandateNames(excludeId=excludeId)
    return allocateUniqueMandateSlug(candidateSlug, takenNames)
|
||
|
||
def createMandate(self, name: str = None, label: str = None, enabled: bool = True) -> Mandate:
    """Create a mandate (RBAC "create" right required) and seed its roles.

    Label and name handling:
    - ``label`` (Voller Name) must be non-empty after stripping. If it is missing,
      the stripped ``name`` is used as the label.
    - If ``name`` (Kurzzeichen) is empty, a unique slug is generated from the label.
    - Otherwise ``name`` must pass ``isValidMandateName`` and must not already be in use.

    After the record is created, the system template roles (admin, user, viewer +
    AccessRules) are copied to the new mandate. A failure there is logged, not raised.

    Raises:
        PermissionError: without the "create" right on Mandate.
        ValueError: empty label, invalid or duplicate name, or failed record creation.
    """
    if not self.checkRbacPermission(Mandate, "create"):
        raise PermissionError("No permission to create mandates")

    from modules.shared.mandateNameUtils import isValidMandateName

    displayLabel = label.strip() if label else ""
    if not displayLabel and name:
        displayLabel = name.strip()
    if not displayLabel:
        raise ValueError("Mandate label (Voller Name) is required")

    shortName = name.strip() if name else ""
    if shortName:
        if not isValidMandateName(shortName):
            raise ValueError(
                "Mandate Kurzzeichen must be 2–32 characters: lowercase a–z, digits, "
                "hyphens only (single-hyphen segments)."
            )
        if shortName in self._existingMandateNames():
            raise ValueError(f"Mandate Kurzzeichen '{shortName}' is already in use")
    else:
        shortName = self._generateUniqueMandateName(displayLabel)

    createdRecord = self.db.recordCreate(Mandate, Mandate(name=shortName, label=displayLabel, enabled=enabled))
    if not createdRecord or not createdRecord.get("id"):
        raise ValueError("Failed to create mandate record")

    newMandateId = createdRecord.get("id")

    # Best effort: a mandate without copied roles is still returned.
    try:
        from modules.interfaces.interfaceBootstrap import copySystemRolesToMandate
        copiedCount = copySystemRolesToMandate(self.db, newMandateId)
        logger.info(f"Copied {copiedCount} system roles to new mandate {newMandateId}")
    except Exception as e:
        logger.error(f"Error copying system roles to mandate {newMandateId}: {e}")

    return Mandate(**createdRecord)
|
||
|
||
def _provisionMandateForUser(self, userId: str, mandateLabel: str, planKey: str) -> Dict[str, Any]:
    """Provision a new mandate for *userId* with subscription and feature instances.

    Internal method — bypasses RBAC (used during registration, when the user has
    no permissions yet).

    Steps:
    1. Create the Mandate; a unique slug ``name`` (Kurzzeichen) is derived from
       ``mandateLabel`` (Voller Name).
    2. Copy the system roles to the mandate.
    3. Create the plan's subscription: TRIALING with trial end when the plan has
       ``trialDays``, otherwise ACTIVE.
    4. Initialise billing (best effort; failures are logged).
    5. Add the user to the mandate with its admin role.
    6. Auto-create every feature instance flagged ``autoCreateInstance``, with the
       user as instance admin (best effort per feature; failures are logged).
    7. Sync the subscription quantity.

    NOTE(review): this is not atomic. On failure after step 1, only the Mandate row
    is deleted; roles, subscription, billing data, memberships and feature
    instances created so far may remain.

    Args:
        userId: User to provision the mandate for.
        mandateLabel: Display name of the mandate; must be non-empty.
        planKey: Key into ``BUILTIN_PLANS``.

    Returns:
        ``{"mandateId": ..., "planKey": ..., "featureInstances": [instance ids]}``.

    Raises:
        ValueError: unknown plan, empty label, failed mandate creation, or any
            failure after the mandate was created (the original error is chained).
    """
    from modules.datamodels.datamodelSubscription import MandateSubscription, SubscriptionStatusEnum, BUILTIN_PLANS
    from modules.interfaces.interfaceBootstrap import copySystemRolesToMandate
    from modules.interfaces.interfaceFeatures import getFeatureInterface
    from modules.system.registry import loadFeatureMainModules

    plan = BUILTIN_PLANS.get(planKey)
    if not plan:
        raise ValueError(f"Unknown plan: {planKey}")

    effLabel = (mandateLabel or "").strip()
    if not effLabel:
        raise ValueError("mandateLabel (Voller Name) is required for provisioning")

    uniqueName = self._generateUniqueMandateName(effLabel)

    mandateData = Mandate(
        name=uniqueName,
        label=effLabel,
        enabled=True,
        isSystem=False,
    )
    createdMandate = self.db.recordCreate(Mandate, mandateData)
    if not createdMandate or not createdMandate.get("id"):
        raise ValueError("Failed to create mandate")
    mandateId = createdMandate["id"]

    try:
        copySystemRolesToMandate(self.db, mandateId)

        # First mandate-level role (no feature instance) whose label contains "admin".
        adminRoleId = None
        mandateRoles = self.db.getRecordset(Role, recordFilter={"mandateId": mandateId, "featureInstanceId": None})
        for r in mandateRoles:
            if "admin" in (r.get("roleLabel") or "").lower():
                adminRoleId = r.get("id")
                break

        if not adminRoleId:
            raise ValueError(f"No admin role found for mandate {mandateId} — cannot assign user without role")

        from modules.interfaces.interfaceDbSubscription import getRootInterface as _getSubRoot
        from modules.interfaces.interfaceDbBilling import getRootInterface as _getBillingRoot
        from datetime import datetime, timezone, timedelta

        now = datetime.now(timezone.utc)
        nowTs = now.timestamp()
        targetStatus = SubscriptionStatusEnum.TRIALING if plan.trialDays else SubscriptionStatusEnum.ACTIVE
        subscription = MandateSubscription(
            mandateId=mandateId,
            planKey=planKey,
            status=targetStatus,
            startedAt=nowTs,
            currentPeriodStart=nowTs,
        )
        if plan.trialDays:
            trialEnd = now + timedelta(days=plan.trialDays)
            subscription.trialEndsAt = trialEnd.timestamp()
            subscription.currentPeriodEnd = trialEnd.timestamp()

        subInterface = _getSubRoot()
        subInterface.createSubscription(subscription)

        # Billing setup is best effort; provisioning continues if it fails.
        try:
            billingRoot = _getBillingRoot()
            billingRoot.getOrCreateSettings(mandateId)
            billingRoot.ensureActivationBudget(mandateId, planKey)
        except Exception as billingEx:
            logger.error(
                "Initial billing setup failed for mandate %s (plan=%s): %s",
                mandateId,
                planKey,
                billingEx,
            )

        self.createUserMandate(userId, mandateId, roleIds=[adminRoleId], skipCapacityCheck=True)

        featureInterface = getFeatureInterface(self.db)
        mainModules = loadFeatureMainModules()
        createdInstances = []
        for featureName, module in mainModules.items():
            if not hasattr(module, "getFeatureDefinition"):
                continue
            # Each feature is handled independently; a failure only skips that feature.
            try:
                featureDef = module.getFeatureDefinition()
                if not featureDef.get("autoCreateInstance", False):
                    continue
                featureCode = featureDef.get("code", featureName)
                featureLabel = resolveText(featureDef.get("label", featureName))
                instance = featureInterface.createFeatureInstance(
                    featureCode=featureCode,
                    mandateId=mandateId,
                    label=featureLabel,
                    enabled=True,
                    copyTemplateRoles=True,
                )
                if instance:
                    instanceId = instance.get("id") if isinstance(instance, dict) else instance.id
                    createdInstances.append(instanceId)
                    instanceRoles = self.db.getRecordset(Role, recordFilter={"featureInstanceId": instanceId})
                    adminInstRoleId = None
                    for ir in instanceRoles:
                        roleLabel = (ir.get("roleLabel") or "").lower()
                        if roleLabel.endswith("-admin"):
                            adminInstRoleId = ir.get("id")
                            break
                    if not adminInstRoleId:
                        raise ValueError(
                            f"No feature-specific admin role (e.g. {featureCode}-admin) for instance {instanceId}. "
                            f"Template roles not synced for feature '{featureCode}'."
                        )
                    self.createFeatureAccess(userId, instanceId, roleIds=[adminInstRoleId])
            except Exception as e:
                logger.error(f"Error auto-creating instance for '{featureName}': {e}")

        self._syncSubscriptionQuantity(mandateId)

        logger.info(f"Provisioned mandate {mandateId} (plan={planKey}) for user {userId}, instances={createdInstances}")
        return {
            "mandateId": mandateId,
            "planKey": planKey,
            "featureInstances": createdInstances,
        }
    except Exception as e:
        logger.error(f"Provisioning failed for user {userId}, cleaning up mandate {mandateId}: {e}")
        # Best-effort rollback of the mandate row; a failure here must not hide the
        # original error, but it must not vanish silently either.
        try:
            self.db.recordDelete(Mandate, mandateId)
        except Exception as cleanupEx:
            logger.error(f"Cleanup of mandate {mandateId} after failed provisioning also failed: {cleanupEx}")
        raise ValueError(f"Mandate provisioning failed: {e}") from e
|
||
|
||
def _activatePendingSubscriptions(self, userId: str) -> int:
    """Activate the PENDING subscriptions of every mandate the user is an enabled member of.

    Called on login, so a trial period begins at that moment rather than at
    registration. All subscription reads and writes go through the subscription
    interface (poweron_billing).

    Per PENDING subscription:
    - If the plan has ``trialDays``: the target is TRIALING, and ``trialEndsAt`` /
      ``currentPeriodEnd`` are set to now + trialDays.
    - Otherwise: the target is ACTIVE. For a MONTHLY or YEARLY ``billingPeriod``,
      ``currentPeriodEnd`` is set to now + 30 or 365 days.
    - In every case, ``currentPeriodStart`` is now.
    - A failed transition is logged and skipped.

    Args:
        userId: ID of the user who just logged in.

    Returns:
        Number of subscriptions successfully transitioned.
    """
    from modules.datamodels.datamodelSubscription import (
        BillingPeriodEnum, SubscriptionStatusEnum, BUILTIN_PLANS,
    )
    from modules.interfaces.interfaceDbSubscription import getRootInterface as _getSubRoot
    from datetime import datetime, timezone, timedelta

    subInterface = _getSubRoot()
    activatedCount = 0

    memberships = self.db.getRecordset(
        UserMandate, recordFilter={"userId": userId, "enabled": True}
    )

    for membership in memberships:
        mandateId = membership.get("mandateId")
        pendingSubs = [
            sub for sub in subInterface.listForMandate(mandateId)
            if sub.get("status") == SubscriptionStatusEnum.PENDING.value
        ]

        for sub in pendingSubs:
            subId = sub.get("id")
            planKey = sub.get("planKey")
            plan = BUILTIN_PLANS.get(planKey)
            now = datetime.now(timezone.utc)

            transitionData = {"currentPeriodStart": now.timestamp()}
            if plan and plan.trialDays:
                targetStatus = SubscriptionStatusEnum.TRIALING
                trialEnd = now + timedelta(days=plan.trialDays)
                transitionData["trialEndsAt"] = trialEnd.timestamp()
                transitionData["currentPeriodEnd"] = trialEnd.timestamp()
            else:
                targetStatus = SubscriptionStatusEnum.ACTIVE
                if plan and plan.billingPeriod:
                    if plan.billingPeriod == BillingPeriodEnum.MONTHLY:
                        transitionData["currentPeriodEnd"] = (now + timedelta(days=30)).timestamp()
                    elif plan.billingPeriod == BillingPeriodEnum.YEARLY:
                        transitionData["currentPeriodEnd"] = (now + timedelta(days=365)).timestamp()

            try:
                # Guarded transition: only applies if the subscription is still PENDING.
                subInterface.transitionStatus(
                    subId,
                    expectedFromStatus=SubscriptionStatusEnum.PENDING,
                    toStatus=targetStatus,
                    additionalData=transitionData,
                )
            except Exception as e:
                logger.error(f"Failed to activate subscription {subId}: {e}")
                continue

            activatedCount += 1
            logger.info(f"Activated subscription {subId} (plan={planKey}) for mandate {mandateId}: {targetStatus.value}")

    return activatedCount
|
||
|
||
def updateMandate(self, mandateId: str, updateData: Dict[str, Any]) -> Mandate:
    """Apply *updateData* to a mandate the current user may update.

    Field-level rules (disallowed fields are silently dropped):
    - ``id`` is never changed.
    - ``isSystem`` may only be changed by a sysadmin.
    - ``name`` (Kurzzeichen) may only be changed by a platform admin or sysadmin.
      It is stripped, format-checked, and must be unique.
    - ``label`` (Voller Name), if provided, is stripped and must be non-empty.

    Returns:
        The mandate as re-read after the update.

    Raises:
        ValueError: wraps every failure, including missing permission, an unknown
            mandate, and invalid name/label values.
    """
    from modules.shared.mandateNameUtils import isValidMandateName

    try:
        if not self.checkRbacPermission(Mandate, "update", mandateId):
            raise PermissionError(f"No permission to update mandate {mandateId}")

        current = self.getMandate(mandateId)
        if not current:
            raise ValueError(f"Mandate {mandateId} not found")

        isSysAdmin = bool(getattr(self.currentUser, "isSysAdmin", False))
        isPlatformAdmin = bool(getattr(self.currentUser, "isPlatformAdmin", False))

        blockedFields = {"id"}
        if not isSysAdmin:
            blockedFields.add("isSystem")
        if not isSysAdmin and not isPlatformAdmin:
            blockedFields.add("name")
        changes = {field: value for field, value in updateData.items() if field not in blockedFields}

        if "name" in changes:
            requestedName = (changes["name"] or "").strip()
            if not isValidMandateName(requestedName):
                raise ValueError(
                    "Mandate Kurzzeichen must be 2–32 characters: lowercase a–z, digits, "
                    "hyphens only (single-hyphen segments)."
                )
            if requestedName != current.name and requestedName in self._existingMandateNames(excludeId=mandateId):
                raise ValueError(f"Mandate Kurzzeichen '{requestedName}' is already in use")
            changes["name"] = requestedName

        if "label" in changes:
            requestedLabel = (changes["label"] or "").strip()
            if not requestedLabel:
                raise ValueError("Mandate Voller Name (label) must not be empty.")
            changes["label"] = requestedLabel

        # Merge onto the current state and re-validate through the model.
        merged = current.model_dump()
        merged.update(changes)
        self.db.recordModify(Mandate, mandateId, Mandate(**merged))

        refreshed = self.getMandate(mandateId)
        if not refreshed:
            raise ValueError("Failed to retrieve updated mandate")
        return refreshed

    except Exception as e:
        logger.error(f"Error updating mandate: {str(e)}")
        raise ValueError(f"Failed to update mandate: {str(e)}")
|
||
|
||
def deleteMandate(self, mandateId: str, force: bool = False) -> bool:
    """
    Delete a mandate, either softly or with a full cascade.

    Default (force=False): soft-delete — sets enabled=False and stamps deletedAt.
    With force=True: hard-delete — removes the mandate together with its
    feature-instance data (knowledge index + chunks, neutralizer attributes,
    data sources, files, chat workflows, graphical-editor workflows), feature
    accesses, memberships, subscriptions (Stripe cancellation is best-effort),
    billing data, invitations and mandate-level roles.
    System mandates (isSystem=True) cannot be deleted.

    Args:
        mandateId: ID of the mandate to delete.
        force: True for a hard delete with cascade, False for a soft delete.

    Returns:
        False if getMandate() does not return the mandate, True after a soft
        delete, otherwise the result of deleting the mandate record itself.

    Raises:
        ValueError: For every failure — system mandate, missing permission or
            any error during the cascade. The original exception is chained.
    """

    def _deleteWhere(db, model, recordFilter):
        # Delete every `model` row matching `recordFilter` in `db`; return the rows that were fetched.
        rows = db.getRecordset(model, recordFilter=recordFilter) or []
        for row in rows:
            db.recordDelete(model, row.get("id"))
        return rows

    try:
        mandate = self.getMandate(mandateId)
        if not mandate:
            return False

        if getattr(mandate, "isSystem", False):
            raise ValueError(f"System mandate '{mandate.name}' cannot be deleted")

        if not self.checkRbacPermission(Mandate, "delete", mandateId):
            raise PermissionError(f"No permission to delete mandate {mandateId}")

        if not force:
            from modules.shared.timeUtils import getUtcTimestamp
            self.db.recordModify(Mandate, mandateId, {"enabled": False, "deletedAt": getUtcTimestamp()})
            logger.info(f"Soft-deleted mandate {mandateId} (30-day retention)")
            return True

        # Hard delete with cascade. Resolve every import up front so that a missing
        # module aborts the operation before any record has been deleted.
        from modules.datamodels.datamodelSubscription import MandateSubscription
        from modules.datamodels.datamodelChat import ChatWorkflow, ChatMessage, ChatLog
        from modules.datamodels.datamodelFiles import FileItem
        from modules.datamodels.datamodelDataSource import DataSource
        from modules.datamodels.datamodelKnowledge import FileContentIndex, ContentChunk
        from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource
        from modules.datamodels.datamodelBilling import BillingSettings, BillingAccount, BillingTransaction
        from modules.features.neutralization.datamodelFeatureNeutralizer import DataNeutralizerAttributes
        from modules.datamodels.datamodelInvitation import Invitation
        from modules.datamodels.datamodelRbac import Role, AccessRule
        from modules.interfaces.interfaceDbSubscription import getRootInterface as _getSubRoot
        from modules.interfaces.interfaceDbBilling import getRootInterface as _getBillingRoot

        instances = self.db.getRecordset(FeatureInstance, recordFilter={"mandateId": mandateId})

        # 0-pre. AutoWorkflow data lives in the separate Greenfield DB (poweron_graphicaleditor)
        self._cascadeDeleteGraphicalEditorData(mandateId, instances)

        # 0. Instance-scoped data for each FeatureInstance
        for inst in instances:
            instId = inst.get("id")
            if not instId:
                continue

            # 0a. ContentChunk (embeddings) first, then their FileContentIndex (knowledge/RAG)
            fciRecords = self.db.getRecordset(FileContentIndex, recordFilter={"featureInstanceId": instId})
            for rec in fciRecords:
                _deleteWhere(self.db, ContentChunk, {"fileContentIndexId": rec.get("id")})
                self.db.recordDelete(FileContentIndex, rec.get("id"))
            if fciRecords:
                logger.info(f"Cascade: deleted {len(fciRecords)} FileContentIndex records (with chunks) for instance {instId}")

            # 0b. DataNeutralizerAttributes
            dnaRecords = _deleteWhere(self.db, DataNeutralizerAttributes, {"featureInstanceId": instId})
            if dnaRecords:
                logger.info(f"Cascade: deleted {len(dnaRecords)} DataNeutralizerAttributes for instance {instId}")

            # 0c. DataSource
            dsRecords = _deleteWhere(self.db, DataSource, {"featureInstanceId": instId})
            if dsRecords:
                logger.info(f"Cascade: deleted {len(dsRecords)} DataSource records for instance {instId}")

            # 0c2. FeatureDataSource
            fdsRecords = _deleteWhere(self.db, FeatureDataSource, {"featureInstanceId": instId})
            if fdsRecords:
                logger.info(f"Cascade: deleted {len(fdsRecords)} FeatureDataSource records for instance {instId}")

            # 0d. FileItem
            fileRecords = _deleteWhere(self.db, FileItem, {"featureInstanceId": instId})
            if fileRecords:
                logger.info(f"Cascade: deleted {len(fileRecords)} FileItem records for instance {instId}")

            # 0e. ChatWorkflow together with its ChatMessage and ChatLog rows
            workflows = self.db.getRecordset(ChatWorkflow, recordFilter={"featureInstanceId": instId})
            for wf in workflows:
                wfId = wf.get("id")
                if not wfId:
                    continue
                _deleteWhere(self.db, ChatMessage, {"workflowId": wfId})
                _deleteWhere(self.db, ChatLog, {"workflowId": wfId})
                self.db.recordDelete(ChatWorkflow, wfId)
            if workflows:
                logger.info(f"Cascade: deleted {len(workflows)} ChatWorkflows (with messages/logs) for instance {instId}")

        # 1. FeatureAccess (+ FeatureAccessRole) and the FeatureInstance itself.
        # Rows without an id are skipped: filtering FeatureAccess by featureInstanceId=None
        # may match rows that do not belong to this mandate.
        for inst in instances:
            instId = inst.get("id")
            if not instId:
                continue
            accesses = self.db.getRecordset(FeatureAccess, recordFilter={"featureInstanceId": instId})
            for access in accesses:
                _deleteWhere(self.db, FeatureAccessRole, {"featureAccessId": access.get("id")})
                self.db.recordDelete(FeatureAccess, access.get("id"))
            self.db.recordDelete(FeatureInstance, instId)
        logger.info(f"Cascade: deleted {len(instances)} FeatureInstances for mandate {mandateId}")

        # 2. UserMandate (+ UserMandateRole)
        memberships = self.db.getRecordset(UserMandate, recordFilter={"mandateId": mandateId})
        for um in memberships:
            _deleteWhere(self.db, UserMandateRole, {"userMandateId": um.get("id")})
            self.db.recordDelete(UserMandate, um.get("id"))
        logger.info(f"Cascade: deleted {len(memberships)} UserMandates for mandate {mandateId}")

        # 3. Cancel Stripe subscriptions (best-effort) + delete MandateSubscription records (poweron_billing)
        subInterface = _getSubRoot()
        subs = subInterface.listForMandate(mandateId)
        for sub in subs:
            subId = sub.get("id")
            stripeSubId = sub.get("stripeSubscriptionId")
            if stripeSubId:
                try:
                    from modules.shared.stripeClient import getStripeClient
                    stripe = getStripeClient()
                    stripe.Subscription.cancel(stripeSubId)
                    logger.info(f"Cancelled Stripe subscription {stripeSubId} for mandate {mandateId}")
                except Exception as stripeError:
                    logger.warning(f"Failed to cancel Stripe sub {stripeSubId}: {stripeError}")
            subInterface.db.recordDelete(MandateSubscription, subId)
        logger.info(f"Cascade: deleted {len(subs)} subscriptions for mandate {mandateId}")

        # 3b. Billing data (poweron_billing): transactions before their accounts, then settings
        billingDb = _getBillingRoot().db
        billingAccounts = billingDb.getRecordset(BillingAccount, recordFilter={"mandateId": mandateId})
        for acc in billingAccounts:
            _deleteWhere(billingDb, BillingTransaction, {"accountId": acc.get("id")})
            billingDb.recordDelete(BillingAccount, acc.get("id"))
        billingSettings = _deleteWhere(billingDb, BillingSettings, {"mandateId": mandateId})
        if billingAccounts or billingSettings:
            logger.info(f"Cascade: deleted billing data for mandate {mandateId}")

        # 3c. Invitations for this mandate
        invitations = _deleteWhere(self.db, Invitation, {"mandateId": mandateId})
        if invitations:
            logger.info(f"Cascade: deleted {len(invitations)} Invitations for mandate {mandateId}")

        # 4. Mandate-level Roles together with their AccessRules
        roles = self.db.getRecordset(Role, recordFilter={"mandateId": mandateId})
        for role in roles:
            _deleteWhere(self.db, AccessRule, {"roleId": role.get("id")})
            self.db.recordDelete(Role, role.get("id"))
        logger.info(f"Cascade: deleted {len(roles)} Roles for mandate {mandateId}")

        # 5. The mandate record itself
        success = self.db.recordDelete(Mandate, mandateId)
        if success:
            logger.info(f"Hard-deleted mandate {mandateId}")
        else:
            logger.warning(f"Cascade finished but deleting mandate record {mandateId} returned {success!r}")
        return success

    except Exception as e:
        logger.error(f"Error deleting mandate: {str(e)}")
        raise ValueError(f"Failed to delete mandate: {str(e)}") from e
|
||
|
||
def _cascadeDeleteGraphicalEditorData(self, mandateId: str, instances) -> None:
    """
    Best-effort deletion of graphical-editor data in the Greenfield database.

    For every FeatureInstance in ``instances`` whose featureCode is
    "graphicalEditor", deletes each AutoWorkflow of this mandate/instance
    together with its AutoVersion, AutoRun (and their AutoStepLog) and AutoTask
    rows. Any error is logged as a warning and not raised, so the caller's
    cascade continues.

    Args:
        mandateId: Mandate whose workflows are deleted.
        instances: FeatureInstance records, either dicts or objects exposing
            ``featureCode`` and ``id``.
    """

    def _field(inst, name):
        # Records may arrive as plain dicts or as model objects.
        return inst.get(name) if isinstance(inst, dict) else getattr(inst, name, None)

    try:
        # Select the relevant instances first so no database connection is opened when there are none.
        geInstanceIds = []
        for inst in instances or []:
            if _field(inst, "featureCode") == "graphicalEditor":
                instId = _field(inst, "id")
                if instId:
                    geInstanceIds.append(instId)
        if not geInstanceIds:
            return

        from modules.features.graphicalEditor.datamodelFeatureGraphicalEditor import (
            AutoWorkflow, AutoVersion, AutoRun, AutoStepLog, AutoTask,
        )
        from modules.features.graphicalEditor.interfaceFeatureGraphicalEditor import graphicalEditorDatabase
        from modules.connectors.connectorDbPostgre import DatabaseConnector

        geDb = DatabaseConnector(
            dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
            dbDatabase=graphicalEditorDatabase,
            dbUser=APP_CONFIG.get("DB_USER"),
            dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET") or APP_CONFIG.get("DB_PASSWORD"),
            dbPort=int(APP_CONFIG.get("DB_PORT", 5432)),
            userId=None,
        )

        if not geDb._ensureTableExists(AutoWorkflow):
            return

        totalDeleted = 0
        for instId in geInstanceIds:
            workflows = geDb.getRecordset(AutoWorkflow, recordFilter={
                "mandateId": mandateId,
                "featureInstanceId": instId,
            }) or []

            for wf in workflows:
                wfId = wf.get("id")
                if not wfId:
                    continue

                for v in geDb.getRecordset(AutoVersion, recordFilter={"workflowId": wfId}) or []:
                    geDb.recordDelete(AutoVersion, v.get("id"))

                for run in geDb.getRecordset(AutoRun, recordFilter={"workflowId": wfId}) or []:
                    runId = run.get("id")
                    if not runId:
                        # Never filter step logs by runId=None; that may match unrelated rows.
                        continue
                    for sl in geDb.getRecordset(AutoStepLog, recordFilter={"runId": runId}) or []:
                        geDb.recordDelete(AutoStepLog, sl.get("id"))
                    geDb.recordDelete(AutoRun, runId)

                for task in geDb.getRecordset(AutoTask, recordFilter={"workflowId": wfId}) or []:
                    geDb.recordDelete(AutoTask, task.get("id"))

                geDb.recordDelete(AutoWorkflow, wfId)
                totalDeleted += 1

        if totalDeleted:
            logger.info(f"Cascade: deleted {totalDeleted} AutoWorkflow(s) in Greenfield DB for mandate {mandateId}")
    except Exception as e:
        logger.warning(f"Failed to cascade-delete graphical editor data for mandate {mandateId}: {e}")
|
||
|
||
def restoreMandate(self, mandateId: str) -> bool:
    """
    Undo a soft delete by re-enabling the mandate and clearing deletedAt.

    Args:
        mandateId: ID of the mandate to restore.

    Returns:
        True once the mandate record has been modified; False when getMandate()
        does not return the mandate.
    """
    # NOTE(review): unlike deleteMandate, no checkRbacPermission call is made here —
    # confirm that callers enforce authorization before restoring.
    found = self.getMandate(mandateId)
    if found:
        self.db.recordModify(Mandate, mandateId, {"enabled": True, "deletedAt": None})
        logger.info(f"Restored soft-deleted mandate {mandateId}")
        return True
    return False
|
||
|
||
def purgeExpiredMandates(self, retentionDays: int = 30) -> int:
    """
    Hard-delete every soft-deleted mandate whose retention period has elapsed.

    A mandate qualifies when it is disabled, has a truthy ``deletedAt`` and that
    value is smaller than ``time.time() - retentionDays * 86400``. Each one is
    removed through deleteMandate(force=True); a failure for one mandate is
    logged and the loop continues with the next.

    Args:
        retentionDays: Retention window in days.

    Returns:
        Number of mandates that were purged without raising.
    """
    import time

    def _read(record, key, objectDefault=None):
        # Dict rows yield None for missing keys; object rows fall back to objectDefault.
        if isinstance(record, dict):
            return record.get(key)
        return getattr(record, key, objectDefault)

    cutoff = time.time() - (retentionDays * 86400)
    purgedCount = 0

    for record in self.db.getRecordset(Mandate):
        deletedAt = _read(record, "deletedAt")
        enabled = _read(record, "enabled", True)
        mandateId = _read(record, "id")

        if not (deletedAt and not enabled and deletedAt < cutoff and mandateId):
            continue

        try:
            self.deleteMandate(mandateId, force=True)
        except Exception as e:
            logger.error(f"Failed to purge expired mandate {mandateId}: {e}")
        else:
            purgedCount += 1

    if purgedCount:
        logger.info(f"Purged {purgedCount} expired mandate(s) beyond {retentionDays}-day retention")
    return purgedCount
|
||
|
||
# ============================================
|
||
# User-Mandate Membership Methods (Multi-Tenant)
|
||
# ============================================
|
||
|
||
def getUserMandate(self, userId: str, mandateId: str) -> Optional[UserMandate]:
    """
    Look up the membership record linking a user to one mandate.

    No ``enabled`` filter is applied, so disabled memberships are found too.

    Args:
        userId: User ID.
        mandateId: Mandate ID.

    Returns:
        The first matching UserMandate, or None when nothing matches or the
        lookup fails (the error is logged, not raised).
    """
    try:
        matches = self.db.getRecordset(
            UserMandate,
            recordFilter={"userId": userId, "mandateId": mandateId},
        )
        return UserMandate(**dict(matches[0])) if matches else None
    except Exception as e:
        logger.error(f"Error getting UserMandate: {e}")
        return None
|
||
|
||
def getUserMandates(self, userId: str) -> List[UserMandate]:
    """
    List the enabled mandate memberships of a user.

    Args:
        userId: User ID.

    Returns:
        UserMandate objects for every row with this userId and enabled=True;
        an empty list if the lookup fails (the error is logged).
    """
    try:
        rows = self.db.getRecordset(
            UserMandate,
            recordFilter={"userId": userId, "enabled": True},
        )
        return [UserMandate(**dict(row)) for row in rows]
    except Exception as e:
        logger.error(f"Error getting UserMandates: {e}")
        return []
|
||
|
||
def createUserMandate(self, userId: str, mandateId: str, roleIds: Optional[List[str]] = None, *, skipCapacityCheck: bool = False) -> UserMandate:
    """
    Create a UserMandate record (add a user to a mandate) together with its roles.

    After the membership and its UserMandateRole rows are written, a billing
    audit account is ensured for the user, the subscription quantity is
    re-synced and, unless skipCapacityCheck is set, the AI budget is adjusted
    for the additional user.

    INVARIANT: A UserMandate MUST have at least one UserMandateRole. If writing
    the roles fails, the new UserMandate (and any roles already written) is
    removed again before the error is raised.

    Args:
        userId: User ID.
        mandateId: Mandate ID.
        roleIds: Role IDs to assign (at least one required). Duplicate IDs are
            assigned only once.
        skipCapacityCheck: If True, skip the subscription capacity check and the
            AI-budget adjustment (used during initial provisioning when the
            subscription hasn't been created yet).

    Returns:
        The created UserMandate.

    Raises:
        ValueError: If roleIds is empty, the user is already a member, or any
            step fails (original exception chained).
        Exceptions whose type name contains "SubscriptionCapacityException" are
        re-raised unchanged.
    """
    if not roleIds:
        raise ValueError(f"Cannot create UserMandate without roles for user {userId} in mandate {mandateId}")

    try:
        existing = self.getUserMandate(userId, mandateId)
        if existing:
            raise ValueError(f"User {userId} is already member of mandate {mandateId}")

        if not skipCapacityCheck:
            self._checkSubscriptionCapacity(mandateId, "users", delta=1)

        userMandate = UserMandate(
            userId=userId,
            mandateId=mandateId,
            enabled=True
        )
        createdRecord = self.db.recordCreate(UserMandate, userMandate.model_dump())
        if not createdRecord:
            raise ValueError("Database failed to create UserMandate record")

        userMandateId = createdRecord.get("id")
        try:
            # dict.fromkeys drops duplicate role IDs while keeping their first-seen order.
            for roleId in dict.fromkeys(roleIds):
                userMandateRole = UserMandateRole(
                    userMandateId=userMandateId,
                    roleId=roleId
                )
                self.db.recordCreate(UserMandateRole, userMandateRole.model_dump())
        except Exception:
            # Undo the membership so the INVARIANT holds; rollback errors must not mask the original one.
            if userMandateId:
                try:
                    self.deleteUserMandateRoles(userMandateId)
                    self.db.recordDelete(UserMandate, userMandateId)
                except Exception as rollbackError:
                    logger.error(f"Rollback of UserMandate {userMandateId} failed: {rollbackError}")
            raise

        self._ensureUserBillingAccount(userId, mandateId)
        self._syncSubscriptionQuantity(mandateId)
        if not skipCapacityCheck:
            self._adjustAiBudgetForUserChange(mandateId, delta=+1)

        return UserMandate(**dict(createdRecord))
    except Exception as e:
        # Same capacity-exception detection as _checkSubscriptionCapacity.
        if "SubscriptionCapacityException" in type(e).__name__:
            raise
        logger.error(f"Error creating UserMandate: {e}")
        raise ValueError(f"Failed to create UserMandate: {e}") from e
|
||
|
||
def _ensureUserBillingAccount(self, userId: str, mandateId: str) -> None:
    """
    Make sure the user has a billing audit account in the mandate.

    Does nothing when the billing interface has no settings for the mandate.
    The account is created with an initial balance of 0.0; the balance itself
    lives on the mandate pool (PREPAY_MANDATE), user accounts only serve as an
    audit trail. Every error is logged as a warning and swallowed.

    Args:
        userId: User ID.
        mandateId: Mandate ID.
    """
    try:
        from modules.interfaces.interfaceDbBilling import getRootInterface as getBillingRootInterface

        billing = getBillingRootInterface()
        if not billing.getSettings(mandateId):
            # Billing is not configured for this mandate.
            return

        billing.getOrCreateUserAccount(mandateId, userId, initialBalance=0.0)
        logger.info(f"Ensured billing audit account for user {userId} in mandate {mandateId}")
    except Exception as e:
        logger.warning(f"Failed to create billing account for user {userId} (non-critical): {e}")
|
||
|
||
def _checkSubscriptionCapacity(self, mandateId: str, resourceType: str, delta: int = 1) -> None:
    """
    Ask the subscription interface whether ``delta`` more ``resourceType`` fit.

    Only exceptions whose type name contains "SubscriptionCapacityException"
    propagate. Any other failure (including import or lookup errors) is logged
    at debug level and the check is skipped, i.e. it fails open.

    Args:
        mandateId: Mandate whose subscription is checked.
        resourceType: Resource being added (e.g. "users").
        delta: Number of resources about to be added.
    """
    try:
        from modules.interfaces.interfaceDbSubscription import getInterface as getSubInterface
        from modules.security.rootAccess import getRootUser

        getSubInterface(getRootUser(), mandateId).assertCapacity(mandateId, resourceType, delta)
    except Exception as e:
        isCapacityViolation = "SubscriptionCapacityException" in type(e).__name__
        if isCapacityViolation:
            raise
        logger.debug(f"Subscription capacity check skipped: {e}")
|
||
|
||
def _syncSubscriptionQuantity(self, mandateId: str, *, raiseOnError: bool = False) -> None:
    """
    Push the mandate's current resource quantities to Stripe.

    Looks up the operative subscription of the mandate and calls
    syncQuantityToStripe on it.

    Args:
        mandateId: Mandate whose subscription is synced.
        raiseOnError: If True (billing-critical paths), a missing operative
            subscription raises ValueError and any other error propagates.
            If False, every problem is logged at debug level and ignored.
    """
    try:
        from modules.interfaces.interfaceDbSubscription import getInterface as getSubInterface
        from modules.security.rootAccess import getRootUser

        subIf = getSubInterface(getRootUser(), mandateId)
        operative = subIf.getOperativeForMandate(mandateId)

        if operative:
            subIf.syncQuantityToStripe(operative["id"], raiseOnError=raiseOnError)
        elif raiseOnError:
            raise ValueError(f"Kein operatives Abonnement für Mandant {mandateId}")
        else:
            logger.debug("No operative subscription for mandate %s — quantity sync skipped", mandateId)
    except Exception as e:
        if raiseOnError:
            raise
        logger.debug(f"Subscription quantity sync skipped: {e}")
|
||
|
||
def _adjustAiBudgetForUserChange(self, mandateId: str, delta: int) -> None:
    """
    Apply a pro-rata AI budget credit/debit after a user joined or left mid-cycle.

    Does nothing when the mandate has no operative subscription; otherwise
    forwards the subscription's planKey (default "") and ``delta`` to the
    billing interface. Every error is logged at debug level and swallowed.

    Args:
        mandateId: Mandate whose AI budget is adjusted.
        delta: Change in user count (+1 added, -1 removed).
    """
    try:
        from modules.interfaces.interfaceDbSubscription import getInterface as getSubInterface
        from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface
        from modules.security.rootAccess import getRootUser

        rootUser = getRootUser()
        operative = getSubInterface(rootUser, mandateId).getOperativeForMandate(mandateId)
        if operative:
            planKey = operative.get("planKey", "")
            getBillingInterface(rootUser).adjustAiBudgetForUserChange(mandateId, planKey, delta)
    except Exception as e:
        logger.debug(f"AI budget adjustment skipped: {e}")
|
||
|
||
def deleteUserMandate(self, userId: str, mandateId: str) -> bool:
    """
    Remove a user from a mandate.

    Deletes, in this order: the user's FeatureAccess rows (with their
    FeatureAccessRole rows) for every feature instance owned by the mandate,
    the UserMandateRole rows of the membership, and the UserMandate record.
    Junction rows are deleted explicitly, as deleteMandate does, instead of
    relying on a database-level cascade. Afterwards the subscription quantity
    is re-synced and the AI budget adjusted for one user less.

    Args:
        userId: User ID.
        mandateId: Mandate ID.

    Returns:
        False if the user is not a member of the mandate; otherwise the result
        of deleting the UserMandate record.

    Raises:
        ValueError: If a database operation fails (original exception chained).
    """
    try:
        existing = self.getUserMandate(userId, mandateId)
        if not existing:
            return False

        # IDs (normalised to str) of the feature instances owned by this mandate.
        mandateInstanceIds = {
            str(row.get("id"))
            for row in self.db.getRecordset(FeatureInstance, recordFilter={"mandateId": mandateId})
            if row.get("id")
        }

        if mandateInstanceIds:
            # One query for all of the user's accesses instead of one query per instance.
            userAccessRows = self.db.getRecordset(FeatureAccess, recordFilter={"userId": userId}) or []
            for acc in userAccessRows:
                accId = acc.get("id")
                if accId and str(acc.get("featureInstanceId")) in mandateInstanceIds:
                    self.deleteFeatureAccessRoles(accId)
                    self.db.recordDelete(FeatureAccess, accId)

        self.deleteUserMandateRoles(existing.id)
        result = self.db.recordDelete(UserMandate, existing.id)

        self._syncSubscriptionQuantity(mandateId)
        self._adjustAiBudgetForUserChange(mandateId, delta=-1)
        return result
    except Exception as e:
        logger.error(f"Error deleting UserMandate: {e}")
        raise ValueError(f"Failed to delete UserMandate: {e}") from e
|
||
|
||
def getUserMandatesByMandate(self, mandateId: str) -> List[UserMandate]:
    """
    List every membership record of one mandate.

    Unlike getUserMandates(), no ``enabled`` filter is applied.

    Args:
        mandateId: Mandate ID.

    Returns:
        UserMandate objects for all matching rows; an empty list if the lookup
        fails (the error is logged).
    """
    try:
        rows = self.db.getRecordset(UserMandate, recordFilter={"mandateId": mandateId})
        return [UserMandate(**dict(row)) for row in rows]
    except Exception as e:
        logger.error(f"Error getting UserMandates for mandate {mandateId}: {e}")
        return []
|
||
|
||
def getUserMandateRoles(self, userMandateId: str) -> List[UserMandateRole]:
    """
    List the role assignments (junction rows) of one UserMandate.

    Args:
        userMandateId: UserMandate ID.

    Returns:
        UserMandateRole objects; an empty list if the lookup fails (the error
        is logged).
    """
    try:
        rows = self.db.getRecordset(UserMandateRole, recordFilter={"userMandateId": userMandateId})
        return [UserMandateRole(**dict(row)) for row in rows]
    except Exception as e:
        logger.error(f"Error getting UserMandateRoles: {e}")
        return []
|
||
|
||
def deleteUserMandateRoles(self, userMandateId: str) -> int:
    """
    Remove every role assignment of a UserMandate (the UserMandate itself stays).

    Args:
        userMandateId: UserMandate ID.

    Returns:
        How many recordDelete calls reported success; 0 if any error occurs
        (the error is logged, not raised).
    """
    try:
        rows = self.db.getRecordset(UserMandateRole, recordFilter={"userMandateId": userMandateId})
        return sum(1 for row in rows if self.db.recordDelete(UserMandateRole, row.get("id")))
    except Exception as e:
        logger.error(f"Error deleting UserMandateRoles: {e}")
        return 0
|
||
|
||
def validateRoleForMandate(self, roleId: str, mandateId: str) -> Role:
    """
    Check that a role may be assigned at the level of the given mandate.

    A role qualifies when it exists, is either global (no mandateId) or scoped
    to exactly this mandate (IDs compared as strings), and is not tied to a
    feature instance.

    Args:
        roleId: ID of the role to check.
        mandateId: Mandate in which the role is about to be assigned.

    Returns:
        The Role returned by getRole().

    Raises:
        ValueError: If the role does not exist, belongs to another mandate, or
            is a feature-instance role (checked in that order).
    """
    role = self.getRole(roleId)
    if not role:
        raise ValueError(f"Role {roleId} not found")

    scopedToOtherMandate = bool(role.mandateId) and str(role.mandateId) != str(mandateId)
    if scopedToOtherMandate:
        raise ValueError(f"Role {roleId} belongs to a different mandate")

    if role.featureInstanceId:
        # Feature-instance roles are assigned through FeatureAccess, not UserMandate.
        raise ValueError(f"Role {roleId} is a feature-instance role and cannot be assigned at mandate level")

    return role
|
||
|
||
def getRoleIdsForUserMandate(self, userMandateId: str) -> List[str]:
    """
    Collect the role IDs assigned to a UserMandate.

    Args:
        userMandateId: UserMandate ID.

    Returns:
        The truthy ``roleId`` values of its UserMandateRole rows; an empty list
        if the lookup fails (the error is logged).
    """
    try:
        rows = self.db.getRecordset(UserMandateRole, recordFilter={"userMandateId": userMandateId})
        roleIds = []
        for row in rows:
            roleId = row.get("roleId")
            if roleId:
                roleIds.append(roleId)
        return roleIds
    except Exception as e:
        logger.error(f"Error getting role IDs for UserMandate: {e}")
        return []
|
||
|
||
def addRoleToUserMandate(self, userMandateId: str, roleId: str) -> UserMandateRole:
    """
    Assign a role to a UserMandate; idempotent.

    Args:
        userMandateId: UserMandate ID.
        roleId: Role ID to add.

    Returns:
        The existing UserMandateRole if the role was already assigned,
        otherwise the newly created one.

    Raises:
        ValueError: If the lookup or creation fails, including when the
            database returns no record (original exception chained).
    """
    try:
        existing = self.db.getRecordset(
            UserMandateRole,
            recordFilter={"userMandateId": userMandateId, "roleId": roleId}
        )
        if existing:
            return UserMandateRole(**dict(existing[0]))

        userMandateRole = UserMandateRole(
            userMandateId=userMandateId,
            roleId=roleId
        )
        createdRecord = self.db.recordCreate(UserMandateRole, userMandateRole.model_dump())
        if not createdRecord:
            # Same check createUserMandate performs; avoids an opaque dict(None) TypeError.
            raise ValueError("Database failed to create UserMandateRole record")
        return UserMandateRole(**dict(createdRecord))
    except Exception as e:
        logger.error(f"Error adding role to UserMandate: {e}")
        raise ValueError(f"Failed to add role: {e}") from e
|
||
|
||
def removeRoleFromUserMandate(self, userMandateId: str, roleId: str) -> bool:
    """
    Remove a role from a UserMandate.

    If no roles remain afterwards, the membership itself is removed
    (application-level cleanup). When the UserMandate row can be loaded, this
    goes through deleteUserMandate(). That way the user's FeatureAccess rows for
    the mandate are removed and the subscription quantity and AI budget are
    adjusted, exactly as for an explicit removal. Otherwise the row is deleted
    directly.

    Args:
        userMandateId: UserMandate ID.
        roleId: Role ID to remove.

    Returns:
        True if the role assignment existed and was removed, False if it did not exist.

    Raises:
        ValueError: If any database operation fails (original exception chained).
    """
    try:
        records = self.db.getRecordset(
            UserMandateRole,
            recordFilter={"userMandateId": userMandateId, "roleId": roleId}
        )
        if not records:
            return False

        self.db.recordDelete(UserMandateRole, records[0].get("id"))

        # Application-level cleanup: a UserMandate without roles must not survive.
        remainingRoles = self.db.getRecordset(
            UserMandateRole,
            recordFilter={"userMandateId": userMandateId}
        )
        if not remainingRoles:
            # Only trust rows whose id really matches, in case the id filter is not applied as expected.
            membership = [
                row for row in (self.db.getRecordset(UserMandate, recordFilter={"id": userMandateId}) or [])
                if str(row.get("id")) == str(userMandateId)
            ]
            if membership and membership[0].get("userId") and membership[0].get("mandateId"):
                self.deleteUserMandate(membership[0].get("userId"), membership[0].get("mandateId"))
            else:
                self.db.recordDelete(UserMandate, userMandateId)
            logger.info(f"Deleted empty UserMandate {userMandateId}")

        return True
    except Exception as e:
        logger.error(f"Error removing role from UserMandate: {e}")
        raise ValueError(f"Failed to remove role: {e}") from e
|
||
|
||
# ============================================
|
||
# Feature Access Methods (Multi-Tenant)
|
||
# ============================================
|
||
|
||
def getFeatureAccess(self, userId: str, featureInstanceId: str) -> Optional[FeatureAccess]:
    """
    Look up a user's access record for one feature instance.

    No ``enabled`` filter is applied.

    Args:
        userId: User ID.
        featureInstanceId: FeatureInstance ID.

    Returns:
        The first matching FeatureAccess, or None when nothing matches or the
        lookup fails (the error is logged, not raised).
    """
    try:
        matches = self.db.getRecordset(
            FeatureAccess,
            recordFilter={"userId": userId, "featureInstanceId": featureInstanceId},
        )
        if matches:
            return FeatureAccess(**dict(matches[0]))
        return None
    except Exception as e:
        logger.error(f"Error getting FeatureAccess: {e}")
        return None
|
||
|
||
def getFeatureAccessesForUser(self, userId: str) -> List[FeatureAccess]:
    """
    List the enabled feature accesses of a user across all instances.

    Args:
        userId: User ID.

    Returns:
        FeatureAccess objects for every row with this userId and enabled=True;
        an empty list if the lookup fails (the error is logged).
    """
    try:
        rows = self.db.getRecordset(
            FeatureAccess,
            recordFilter={"userId": userId, "enabled": True},
        )
        return [FeatureAccess(**dict(row)) for row in rows]
    except Exception as e:
        logger.error(f"Error getting FeatureAccesses: {e}")
        return []
|
||
|
||
def getFeatureAccessesByInstance(self, featureInstanceId: str) -> List[FeatureAccess]:
    """
    List every access record of one feature instance.

    No ``enabled`` filter is applied.

    Args:
        featureInstanceId: FeatureInstance ID.

    Returns:
        FeatureAccess objects for all matching rows; an empty list if the
        lookup fails (the error is logged).
    """
    try:
        rows = self.db.getRecordset(FeatureAccess, recordFilter={"featureInstanceId": featureInstanceId})
        return [FeatureAccess(**dict(row)) for row in rows]
    except Exception as e:
        logger.error(f"Error getting FeatureAccesses for instance {featureInstanceId}: {e}")
        return []
|
||
|
||
def createFeatureAccess(self, userId: str, featureInstanceId: str, roleIds: Optional[List[str]] = None) -> FeatureAccess:
    """
    Grant a user access to a feature instance.

    First makes sure the user is a member of the mandate that owns the
    instance (auto-assigning the mandate's 'user' role if needed). Then creates
    the FeatureAccess record and one FeatureAccessRole per distinct role ID.

    INVARIANT: A FeatureAccess MUST have at least one FeatureAccessRole. If
    writing the roles fails, the new FeatureAccess (and any roles already
    written) is removed again before the error is raised.

    Args:
        userId: User ID.
        featureInstanceId: FeatureInstance ID.
        roleIds: Role IDs to assign (at least one required). Duplicate IDs are
            assigned only once.

    Returns:
        The created FeatureAccess.

    Raises:
        ValueError: If roleIds is empty, the user already has access, or any
            step fails (original exception chained).
        Exceptions whose type name contains "SubscriptionCapacityException"
        (e.g. from auto-adding the user to the mandate) are re-raised
        unchanged, as createUserMandate does.
    """
    if not roleIds:
        raise ValueError(f"Cannot create FeatureAccess without roles for user {userId} on instance {featureInstanceId}")

    try:
        existing = self.getFeatureAccess(userId, featureInstanceId)
        if existing:
            raise ValueError(f"User {userId} already has access to feature instance {featureInstanceId}")

        self._ensureUserMandateMembership(userId, featureInstanceId)

        featureAccess = FeatureAccess(
            userId=userId,
            featureInstanceId=featureInstanceId,
            enabled=True
        )
        createdRecord = self.db.recordCreate(FeatureAccess, featureAccess.model_dump())
        if not createdRecord:
            raise ValueError("Database failed to create FeatureAccess record")

        featureAccessId = createdRecord.get("id")
        try:
            # dict.fromkeys drops duplicate role IDs while keeping their first-seen order.
            for roleId in dict.fromkeys(roleIds):
                featureAccessRole = FeatureAccessRole(
                    featureAccessId=featureAccessId,
                    roleId=roleId
                )
                self.db.recordCreate(FeatureAccessRole, featureAccessRole.model_dump())
        except Exception:
            # Undo the access so the INVARIANT holds; rollback errors must not mask the original one.
            if featureAccessId:
                try:
                    self.deleteFeatureAccessRoles(featureAccessId)
                    self.db.recordDelete(FeatureAccess, featureAccessId)
                except Exception as rollbackError:
                    logger.error(f"Rollback of FeatureAccess {featureAccessId} failed: {rollbackError}")
            raise

        return FeatureAccess(**dict(createdRecord))
    except Exception as e:
        # Keep capacity violations recognisable for callers instead of turning them into ValueError.
        if "SubscriptionCapacityException" in type(e).__name__:
            raise
        logger.error(f"Error creating FeatureAccess: {e}")
        raise ValueError(f"Failed to create FeatureAccess: {e}") from e
|
||
|
||
def _ensureUserMandateMembership(self, userId: str, featureInstanceId: str) -> None:
    """
    Make sure the user belongs to the mandate that owns a feature instance.

    If the instance cannot be found or has no mandateId, a warning is logged
    and nothing happens. If the user is not yet a member, they are added via
    createUserMandate() with the mandate's 'user' role, i.e. a Role with
    roleLabel "user", this mandateId and featureInstanceId None.

    Args:
        userId: User ID.
        featureInstanceId: FeatureInstance whose owning mandate is checked.

    Raises:
        ValueError: If no 'user' role exists for the mandate, or if
            createUserMandate() raises a ValueError that does not mention
            "already member" (that case is silently ignored).
        Exception: Any other error is logged and re-raised unchanged.
    """
    try:
        from modules.interfaces.interfaceFeatures import getFeatureInterface

        instance = getFeatureInterface(self.db).getFeatureInstance(featureInstanceId)
        if not instance or not instance.mandateId:
            logger.warning(f"Cannot auto-assign mandate: feature instance {featureInstanceId} not found or has no mandateId")
            return

        mandateId = str(instance.mandateId)

        if self.getUserMandate(userId, mandateId):
            logger.debug(f"User {userId} already member of mandate {mandateId}")
            return

        mandateUserRoles = self.db.getRecordset(
            Role,
            recordFilter={"roleLabel": "user", "mandateId": mandateId, "featureInstanceId": None},
        )
        userRoleId = None
        if mandateUserRoles:
            userRoleId = mandateUserRoles[0].get("id")
        if not userRoleId:
            raise ValueError(f"No 'user' role found for mandate {mandateId} — cannot assign user without role")

        self.createUserMandate(userId, mandateId, roleIds=[userRoleId])
        logger.info(f"Auto-assigned user {userId} to mandate {mandateId} with 'user' role (via feature instance {featureInstanceId})")

    except ValueError as ve:
        # Presumably tolerates a membership created concurrently after the check above.
        if "already member" not in str(ve):
            raise
    except Exception as e:
        logger.error(f"Error auto-assigning user {userId} to mandate: {e}")
        raise
|
||
|
||
def getRoleIdsForFeatureAccess(self, featureAccessId: str) -> List[str]:
    """
    Return the role IDs linked to a FeatureAccess via the FeatureAccessRole junction table.

    Args:
        featureAccessId: FeatureAccess ID

    Returns:
        Role IDs in recordset order. Junction rows with an empty roleId are
        skipped. On any error the error is logged and an empty list is returned.
    """
    try:
        junctionRows = self.db.getRecordset(
            FeatureAccessRole,
            recordFilter={"featureAccessId": featureAccessId},
        )
        roleIds = []
        for row in junctionRows:
            roleId = row.get("roleId")
            if roleId:
                roleIds.append(roleId)
        return roleIds
    except Exception as e:
        logger.error(f"Error getting role IDs for FeatureAccess: {e}")
        return []
|
||
|
||
def addRoleToFeatureAccess(self, featureAccessId: str, roleId: str) -> None:
    """
    Link a role to a FeatureAccess through the FeatureAccessRole junction table.

    Does nothing if the (featureAccessId, roleId) link already exists.

    Args:
        featureAccessId: FeatureAccess ID
        roleId: Role ID to add

    Raises:
        ValueError: Wraps any error raised while checking or creating the link
            (the original error is logged first).
    """
    try:
        linkCriteria = {"featureAccessId": featureAccessId, "roleId": roleId}
        # NOTE(review): check-then-insert is not atomic; concurrent callers could
        # both insert unless the database enforces uniqueness.
        if self.db.getRecordset(FeatureAccessRole, recordFilter=linkCriteria):
            return  # already linked

        link = FeatureAccessRole(featureAccessId=featureAccessId, roleId=roleId)
        self.db.recordCreate(FeatureAccessRole, link.model_dump())
        logger.debug(f"Added role {roleId} to FeatureAccess {featureAccessId}")
    except Exception as e:
        logger.error(f"Error adding role to FeatureAccess: {e}")
        raise ValueError(f"Failed to add role to FeatureAccess: {e}")
|
||
|
||
def deleteFeatureAccessRoles(self, featureAccessId: str) -> int:
    """
    Delete all FeatureAccessRole junction records of a FeatureAccess.

    Args:
        featureAccessId: FeatureAccess ID

    Returns:
        Number of records actually deleted. Records without an id are skipped.
        If an error occurs part-way, it is logged and the number of records
        deleted before the failure is returned, not 0.
    """
    # Count lives outside the try so a mid-loop failure still reports what
    # was really removed.
    deletedCount = 0
    try:
        records = self.db.getRecordset(
            FeatureAccessRole,
            recordFilter={"featureAccessId": featureAccessId},
        )
        for record in records:
            recordId = record.get("id")
            if recordId:
                self.db.recordDelete(FeatureAccessRole, recordId)
                deletedCount += 1
    except Exception as e:
        logger.error(f"Error deleting FeatureAccessRoles for {featureAccessId}: {e}")
    return deletedCount
|
||
|
||
# ============================================
|
||
# Invitation Methods
|
||
# ============================================
|
||
|
||
def getInvitation(self, invitationId: str) -> Optional[Invitation]:
    """
    Look up a single invitation by its ID.

    Args:
        invitationId: Invitation ID

    Returns:
        The first matching Invitation, or None if there is no match or an
        error occurs (errors are logged).
    """
    try:
        matches = self.db.getRecordset(Invitation, recordFilter={"id": invitationId})
        if not matches:
            return None
        return Invitation(**dict(matches[0]))
    except Exception as e:
        logger.error(f"Error getting invitation {invitationId}: {e}")
        return None
|
||
|
||
def getInvitationByToken(self, token: str) -> Optional[Invitation]:
    """
    Get an invitation by its token.

    Args:
        token: Invitation token

    Returns:
        The first Invitation whose token matches, or None if none matches, an
        error occurs (logged), or ``token`` is empty/None. An empty token never
        reaches the database.
    """
    # Guard: this interface uses None filter values to match NULL columns
    # (e.g. connectionId=None in the token queries), so an unguarded
    # token=None lookup could return an invitation that has no token at all.
    if not token:
        return None
    try:
        records = self.db.getRecordset(Invitation, recordFilter={"token": token})
        if records:
            return Invitation(**dict(records[0]))
        return None
    except Exception as e:
        logger.error(f"Error getting invitation by token: {e}")
        return None
|
||
|
||
def getInvitationsByMandate(self, mandateId: str) -> List[Invitation]:
    """
    List every invitation that belongs to a mandate.

    Args:
        mandateId: Mandate ID

    Returns:
        Invitation objects in recordset order. Returns an empty list (and logs
        the error) if loading or parsing any record fails.
    """
    try:
        rows = self.db.getRecordset(Invitation, recordFilter={"mandateId": mandateId})
        return [Invitation(**dict(row)) for row in rows]
    except Exception as e:
        logger.error(f"Error getting invitations for mandate {mandateId}: {e}")
        return []
|
||
|
||
def getInvitationsByCreator(self, creatorId: str) -> List[Invitation]:
    """
    List every invitation whose sysCreatedBy equals the given user.

    Args:
        creatorId: User ID who created the invitations

    Returns:
        Invitation objects in recordset order. Returns an empty list (and logs
        the error) if loading or parsing any record fails.
    """
    try:
        rows = self.db.getRecordset(Invitation, recordFilter={"sysCreatedBy": creatorId})
        return [Invitation(**dict(row)) for row in rows]
    except Exception as e:
        logger.error(f"Error getting invitations by creator {creatorId}: {e}")
        return []
|
||
|
||
def getInvitationsByUsedBy(self, usedById: str) -> List[Invitation]:
    """
    List every invitation whose usedBy field equals the given user.

    Args:
        usedById: User ID who used the invitations

    Returns:
        Invitation objects in recordset order. Returns an empty list (and logs
        the error) if loading or parsing any record fails.
    """
    try:
        rows = self.db.getRecordset(Invitation, recordFilter={"usedBy": usedById})
        return [Invitation(**dict(row)) for row in rows]
    except Exception as e:
        logger.error(f"Error getting invitations used by {usedById}: {e}")
        return []
|
||
|
||
def getInvitationsByTargetUsername(self, targetUsername: str) -> List[Invitation]:
    """
    List every invitation addressed to the given target username.

    Args:
        targetUsername: Target username for the invitations

    Returns:
        Invitation objects in recordset order. Returns an empty list (and logs
        the error) if loading or parsing any record fails.
    """
    try:
        rows = self.db.getRecordset(Invitation, recordFilter={"targetUsername": targetUsername})
        return [Invitation(**dict(row)) for row in rows]
    except Exception as e:
        logger.error(f"Error getting invitations for target username {targetUsername}: {e}")
        return []
|
||
|
||
def getInvitationsByEmail(self, email: str) -> List[Invitation]:
    """
    List every invitation whose ``email`` field equals the given address.

    Intended for email-only invitations. Returns an empty list (and logs the
    error) if loading or parsing any record fails.
    """
    try:
        rows = self.db.getRecordset(Invitation, recordFilter={"email": email})
        return [Invitation(**dict(row)) for row in rows]
    except Exception as e:
        logger.error(f"Error getting invitations for email {email}: {e}")
        return []
|
||
|
||
# ============================================
|
||
# Additional Helper Methods
|
||
# ============================================
|
||
|
||
def getAllUsers(self, pagination: Optional[PaginationParams] = None) -> Union[List[User], PaginatedResult]:
    """
    Load all users as public ``User`` models, optionally paginated.

    The UserInDB records are validated into ``User`` (the model intended to
    carry no sensitive fields), and a missing ``roleLabels`` is normalised to
    an empty list. This method performs no permission check itself; callers
    are expected to restrict it to SysAdmins.

    Args:
        pagination: Optional pagination parameters. If None, every user is returned.

    Returns:
        ``List[User]`` when pagination is None, otherwise a PaginatedResult
        with the page items and totals. On error the error is logged and an
        empty list / empty PaginatedResult is returned.
    """
    try:
        page = self.db.getRecordsetPaginated(UserInDB, pagination=pagination)

        users = []
        for record in page["items"]:
            user = User.model_validate(record)
            if user.roleLabels is None:
                user = user.model_copy(update={"roleLabels": []})
            users.append(user)

        if pagination is not None:
            return PaginatedResult(
                items=users,
                totalItems=page["totalItems"],
                totalPages=page["totalPages"],
            )
        return users
    except Exception as e:
        logger.error(f"Error getting all users: {e}")
        return [] if pagination is None else PaginatedResult(items=[], totalItems=0, totalPages=0)
|
||
|
||
def getUserMandateById(self, userMandateId: str) -> Optional[UserMandate]:
    """
    Look up a UserMandate membership record by its own ID.

    Args:
        userMandateId: UserMandate ID

    Returns:
        The first matching UserMandate, or None if there is no match or an
        error occurs (errors are logged).
    """
    try:
        matches = self.db.getRecordset(UserMandate, recordFilter={"id": userMandateId})
        if not matches:
            return None
        return UserMandate(**dict(matches[0]))
    except Exception as e:
        logger.error(f"Error getting UserMandate {userMandateId}: {e}")
        return None
|
||
|
||
def getUserMandateRolesByRole(self, roleId: str) -> List[UserMandateRole]:
    """
    List every UserMandateRole assignment that references a given role.

    Args:
        roleId: Role ID

    Returns:
        UserMandateRole objects in recordset order. Returns an empty list (and
        logs the error) if loading or parsing any record fails.
    """
    try:
        rows = self.db.getRecordset(UserMandateRole, recordFilter={"roleId": roleId})
        return [UserMandateRole(**dict(row)) for row in rows]
    except Exception as e:
        logger.error(f"Error getting UserMandateRoles for role {roleId}: {e}")
        return []
|
||
|
||
def getFeatureInstance(self, instanceId: str):
    """
    Look up a FeatureInstance by ID.

    Args:
        instanceId: FeatureInstance ID

    Returns:
        The first matching FeatureInstance, or None if there is no match or an
        error occurs (errors are logged).
    """
    try:
        matches = self.db.getRecordset(FeatureInstance, recordFilter={"id": instanceId})
        if not matches:
            return None
        return FeatureInstance(**dict(matches[0]))
    except Exception as e:
        logger.error(f"Error getting FeatureInstance {instanceId}: {e}")
        return None
|
||
|
||
def getFeatureByCode(self, featureCode: str) -> Optional[Feature]:
    """
    Look up a Feature by its ``code`` field.

    Args:
        featureCode: Feature code

    Returns:
        The first matching Feature, or None if there is no match or an error
        occurs (errors are logged).
    """
    try:
        matches = self.db.getRecordset(Feature, recordFilter={"code": featureCode})
        if not matches:
            return None
        return Feature(**dict(matches[0]))
    except Exception as e:
        logger.error(f"Error getting Feature by code {featureCode}: {e}")
        return None
|
||
|
||
def getFeatureInstancesByMandate(self, mandateId: str, enabledOnly: bool = False) -> List[FeatureInstance]:
    """
    List the FeatureInstances that belong to a mandate.

    Args:
        mandateId: Mandate ID
        enabledOnly: When True, restrict to records with ``enabled == True``.

    Returns:
        FeatureInstance objects in recordset order. Returns an empty list (and
        logs the error) if loading or parsing any record fails.
    """
    try:
        criteria = {"mandateId": mandateId}
        if enabledOnly:
            criteria["enabled"] = True
        rows = self.db.getRecordset(FeatureInstance, recordFilter=criteria)
        return [FeatureInstance(**dict(row)) for row in rows]
    except Exception as e:
        logger.error(f"Error getting FeatureInstances for mandate {mandateId}: {e}")
        return []
|
||
|
||
# ============================================
|
||
# Notification Methods
|
||
# ============================================
|
||
|
||
def getNotification(self, notificationId: str) -> Optional[UserNotification]:
    """
    Look up a single user notification by ID.

    Args:
        notificationId: Notification ID

    Returns:
        The first matching UserNotification, or None if there is no match or
        an error occurs (errors are logged).
    """
    try:
        matches = self.db.getRecordset(UserNotification, recordFilter={"id": notificationId})
        if not matches:
            return None
        return UserNotification(**dict(matches[0]))
    except Exception as e:
        logger.error(f"Error getting notification {notificationId}: {e}")
        return None
|
||
|
||
def getNotificationsByUser(
    self,
    userId: str,
    status: Optional[str] = None,
    limit: Optional[int] = None
) -> List[UserNotification]:
    """
    List a user's notifications, newest first.

    Args:
        userId: User ID
        status: If truthy, only notifications with this status (e.g. 'unread').
        limit: If truthy, keep only the first ``limit`` notifications after
            sorting. 0 or None means no limit.

    Returns:
        UserNotification objects sorted by sysCreatedAt descending (a missing
        sysCreatedAt sorts as 0). Returns an empty list (and logs the error)
        on failure.
    """
    try:
        criteria = {"userId": userId}
        if status:
            criteria["status"] = status
        rows = self.db.getRecordset(UserNotification, recordFilter=criteria)
        newestFirst = sorted(
            (UserNotification(**dict(row)) for row in rows),
            key=lambda notification: notification.sysCreatedAt or 0,
            reverse=True,
        )
        return newestFirst[:limit] if limit else newestFirst
    except Exception as e:
        logger.error(f"Error getting notifications for user {userId}: {e}")
        return []
|
||
|
||
# ============================================
|
||
# AccessRule Methods
|
||
# ============================================
|
||
|
||
def getAccessRule(self, ruleId: str) -> Optional[AccessRule]:
    """
    Look up an AccessRule by ID.

    NOTE(review): a second ``getAccessRule`` with the same signature is
    defined further down in this file (RBAC CRUD section). If both live in
    the same class, the later definition replaces this one and this body is
    dead code. Confirm and remove the duplicate.

    Args:
        ruleId: AccessRule ID

    Returns:
        The first matching AccessRule, or None if there is no match or an
        error occurs (errors are logged).
    """
    try:
        matches = self.db.getRecordset(AccessRule, recordFilter={"id": ruleId})
        if not matches:
            return None
        return AccessRule(**dict(matches[0]))
    except Exception as e:
        logger.error(f"Error getting AccessRule {ruleId}: {e}")
        return None
|
||
|
||
def getAccessRulesByRole(self, roleId: str) -> List[AccessRule]:
    """
    List every AccessRule attached to a role.

    Args:
        roleId: Role ID

    Returns:
        AccessRule objects in recordset order. Returns an empty list (and logs
        the error) if loading or parsing any record fails.
    """
    try:
        rows = self.db.getRecordset(AccessRule, recordFilter={"roleId": roleId})
        return [AccessRule(**dict(row)) for row in rows]
    except Exception as e:
        logger.error(f"Error getting AccessRules for role {roleId}: {e}")
        return []
|
||
|
||
def getRolesByFeatureInstance(self, featureInstanceId: str) -> List[Role]:
    """
    List the roles scoped to a feature instance.

    Args:
        featureInstanceId: FeatureInstance ID

    Returns:
        Role objects in recordset order. Returns an empty list (and logs the
        error) if loading or parsing any record fails.
    """
    try:
        rows = self.db.getRecordset(Role, recordFilter={"featureInstanceId": featureInstanceId})
        return [Role(**dict(row)) for row in rows]
    except Exception as e:
        logger.error(f"Error getting roles for feature instance {featureInstanceId}: {e}")
        return []
|
||
|
||
def getRolesByFeatureCode(self, featureCode: str, featureInstanceId: Optional[str] = None) -> List[Role]:
    """
    List the roles of a feature, optionally narrowed to one feature instance.

    Args:
        featureCode: Feature code
        featureInstanceId: If truthy, only roles of this FeatureInstance.

    Returns:
        Role objects in recordset order. Returns an empty list (and logs the
        error) if loading or parsing any record fails.
    """
    try:
        criteria = {"featureCode": featureCode}
        if featureInstanceId:
            criteria["featureInstanceId"] = featureInstanceId
        rows = self.db.getRecordset(Role, recordFilter=criteria)
        return [Role(**dict(row)) for row in rows]
    except Exception as e:
        logger.error(f"Error getting roles for feature code {featureCode}: {e}")
        return []
|
||
|
||
# Token methods
|
||
|
||
def saveAccessToken(self, token: Token, replace_existing: bool = True) -> None:
    """
    Persist a gateway session (access) token for the current user.

    The token must have no connectionId and tokenPurpose authSession. Its
    userId is forced to the current user, and id / sysCreatedAt are filled in
    when missing (the passed token object is mutated accordingly).

    Args:
        token: Token to store.
        replace_existing: When True, once the new token is stored, every other
            authSession token of the current user for the same authority
            without a connectionId is deleted.

    Raises:
        ValueError: If the token has a connectionId, the wrong tokenPurpose,
            or there is no current user.
        Exception: Errors while creating the record are logged and re-raised.
            Errors while pruning old tokens are only logged as warnings.
    """
    try:
        if token.connectionId:
            raise ValueError(
                "Access tokens cannot have connectionId - use saveConnectionToken instead"
            )

        _tp = (
            token.tokenPurpose.value
            if isinstance(token.tokenPurpose, TokenPurpose)
            else token.tokenPurpose
        )
        if _tp != TokenPurpose.AUTH_SESSION.value:
            raise ValueError(
                "saveAccessToken requires tokenPurpose=authSession (gateway session JWT)"
            )

        if not self.currentUser or not self.currentUser.id:
            raise ValueError("No valid user context available for token storage")

        # The token always belongs to the caller, whatever userId it came with.
        token.userId = self.currentUser.id

        if not token.id:
            token.id = str(uuid.uuid4())
        if not token.sysCreatedAt:
            token.sysCreatedAt = getUtcTimestamp()

        token_dict = token.model_dump()
        token_dict["userId"] = self.currentUser.id

        # Store the new token BEFORE removing old ones (same order as
        # saveConnectionToken): if the insert fails, the user's existing
        # session tokens remain intact instead of being deleted up front.
        self.db.recordCreate(Token, token_dict)

        if replace_existing:
            try:
                old_tokens = self.db.getRecordset(
                    Token,
                    recordFilter={
                        "userId": self.currentUser.id,
                        "authority": token.authority,
                        "connectionId": None,  # only access tokens, never connection tokens
                        "tokenPurpose": TokenPurpose.AUTH_SESSION.value,
                    },
                )
                deleted_count = 0
                for old_token in old_tokens:
                    # The freshly stored token matches the filter too; keep it.
                    if old_token["id"] != token.id:
                        self.db.recordDelete(Token, old_token["id"])
                        deleted_count += 1

                if deleted_count > 0:
                    logger.info(
                        f"Replaced {deleted_count} old access tokens for user {self.currentUser.id} and authority {token.authority}"
                    )
            except Exception as e:
                logger.warning(
                    f"Failed to delete old access tokens for user {self.currentUser.id} and authority {token.authority}: {str(e)}"
                )
                # The new token is already stored; stale ones can be cleaned up later.
    except Exception as e:
        logger.error(f"Error saving access token: {str(e)}")
        raise
|
||
|
||
def saveConnectionToken(self, token: Token, replace_existing: bool = True) -> None:
    """
    Persist a provider OAuth token that is bound to a user connection.

    The token must carry a connectionId and tokenPurpose dataConnection. Its
    userId is forced to the current user, and id / sysCreatedAt are filled in
    when missing (the passed token object is mutated accordingly).

    Args:
        token: Token to store.
        replace_existing: When True, after the new token is stored, every other
            token with the same connectionId is deleted. Failures during this
            cleanup are only logged.

    Raises:
        ValueError: If connectionId is missing, the tokenPurpose is wrong, or
            there is no current user.
        Exception: Errors while creating the record are logged and re-raised.
    """
    try:
        if not token.connectionId:
            raise ValueError(
                "Connection tokens must have connectionId - use saveAccessToken instead"
            )

        rawPurpose = token.tokenPurpose
        purpose = rawPurpose.value if isinstance(rawPurpose, TokenPurpose) else rawPurpose
        if purpose != TokenPurpose.DATA_CONNECTION.value:
            raise ValueError(
                "saveConnectionToken requires tokenPurpose=dataConnection (provider OAuth)"
            )

        if not self.currentUser or not self.currentUser.id:
            raise ValueError("No valid user context available for token storage")

        ownerId = self.currentUser.id
        token.userId = ownerId
        if not token.id:
            token.id = str(uuid.uuid4())
        if not token.sysCreatedAt:
            token.sysCreatedAt = getUtcTimestamp()

        record = token.model_dump()
        record["userId"] = ownerId
        self.db.recordCreate(Token, record)

        if not replace_existing:
            return

        # Cleanup runs only after a successful save, so the new token is never at risk.
        try:
            sameConnection = self.db.getRecordset(
                Token, recordFilter={"connectionId": token.connectionId}
            )
            removed = 0
            for stale in sameConnection:
                if stale["id"] == token.id:
                    continue
                self.db.recordDelete(Token, stale["id"])
                removed += 1
            if removed:
                logger.info(
                    f"Replaced {removed} old tokens for connectionId {token.connectionId}"
                )
        except Exception as e:
            logger.warning(
                f"Failed to delete old tokens for connectionId {token.connectionId}: {str(e)}"
            )
    except Exception as e:
        logger.error(f"Error saving connection token: {str(e)}")
        raise
|
||
|
||
def getConnectionToken(self, connectionId: str) -> Optional[Token]:
    """
    Return the stored token of a connection with the latest expiresAt.

    No refresh is performed; callers needing a fresh token must use a
    higher-level service. An empty connectionId, a connection without tokens,
    or any error yields None (errors are logged; the empty-id case is logged
    as an error too, since it raises internally).
    """
    try:
        if not connectionId:
            raise ValueError("connectionId is required for getConnectionToken")

        candidates = self.db.getRecordset(Token, recordFilter={"connectionId": connectionId})
        if not candidates:
            # Pending connections (PAT not submitted yet, OAuth callback not
            # completed) legitimately have no token; DEBUG keeps list
            # refreshes quiet.
            logger.debug(f"No connection token found for connectionId: {connectionId}")
            return None

        # Latest expiration first; parseTimestamp falls back to 0 for values it
        # cannot use (presumably missing/invalid expiresAt), so those sort last.
        candidates.sort(key=lambda rec: parseTimestamp(rec.get("expiresAt"), default=0), reverse=True)
        return Token(**candidates[0])
    except Exception as e:
        logger.error(
            f"Error getting connection token for connectionId {connectionId}: {str(e)}"
        )
        return None
|
||
|
||
def getTokensByConnectionIdAndAuthority(
    self, connectionId: str, authority: AuthAuthority
) -> List[Token]:
    """
    List the tokens of one connection that were issued by a given authority.

    ``authority`` may be an enum member (its ``.value`` is used) or any value
    convertible with ``str``. Returns an empty list (and logs) on error.
    """
    try:
        authorityKey = authority.value if hasattr(authority, 'value') else str(authority)
        rows = self.db.getRecordset(
            Token,
            recordFilter={"connectionId": connectionId, "authority": authorityKey},
        )
        return [Token(**dict(row)) for row in rows]
    except Exception as e:
        logger.error(f"Error getting tokens by connection and authority: {str(e)}")
        return []
|
||
|
||
def getTokensByUserIdNoConnection(
    self, userId: str, authority: AuthAuthority
) -> List[Token]:
    """
    List a user's tokens for an authority that have no connectionId (access tokens).

    ``authority`` may be an enum member (its ``.value`` is used) or any value
    convertible with ``str``. Returns an empty list (and logs) on error.
    """
    try:
        authorityKey = authority.value if hasattr(authority, 'value') else str(authority)
        rows = self.db.getRecordset(
            Token,
            recordFilter={"userId": userId, "connectionId": None, "authority": authorityKey},
        )
        return [Token(**dict(row)) for row in rows]
    except Exception as e:
        logger.error(f"Error getting tokens by user and authority: {str(e)}")
        return []
|
||
|
||
def getAllTokens(self, recordFilter: dict = None) -> List[dict]:
    """
    Return raw token records (dicts), optionally filtered.

    A falsy ``recordFilter`` is replaced by ``{}`` (no filtering). Returns an
    empty list (and logs) on error.
    """
    try:
        return self.db.getRecordset(Token, recordFilter=recordFilter or {})
    except Exception as e:
        logger.error(f"Error getting all tokens: {str(e)}")
        return []
|
||
|
||
def findActiveTokenById(
    self,
    tokenId: str,
    userId: str,
    authority: AuthAuthority,
    sessionId: str = None,
    mandateId: str = None,
    tokenPurpose: str = None,
) -> Optional[Token]:
    """
    Find an ACTIVE token by id (jti) for a user and authority.

    sessionId, mandateId and tokenPurpose further restrict the match when they
    are not None. Returns the first match as a Token, or None if there is no
    match or an error occurs (logged).
    """
    try:
        criteria = {
            "id": tokenId,
            "userId": userId,
            "authority": authority.value if hasattr(authority, "value") else str(authority),
            "status": TokenStatus.ACTIVE,
        }
        optionalScopes = (
            ("sessionId", sessionId),
            ("mandateId", mandateId),
            ("tokenPurpose", tokenPurpose),
        )
        for fieldName, fieldValue in optionalScopes:
            if fieldValue is not None:
                criteria[fieldName] = fieldValue

        matches = self.db.getRecordset(Token, recordFilter=criteria)
        return Token(**matches[0]) if matches else None
    except Exception as e:
        logger.error(f"Error finding active token by id {tokenId}: {str(e)}")
        return None
|
||
|
||
def revokeTokenById(self, tokenId: str, revokedBy: str, reason: str = None) -> bool:
    """
    Mark one token as revoked (the record is kept, not deleted).

    Returns:
        True if the token was revoked now or was already revoked; False if it
        does not exist or an error occurred (logged).
    """
    try:
        matches = self.db.getRecordset(Token, recordFilter={"id": tokenId})
        if not matches:
            return False
        if matches[0].get("status") == TokenStatus.REVOKED:
            return True  # already revoked: idempotent success
        self.db.recordModify(
            Token,
            tokenId,
            {
                "status": TokenStatus.REVOKED,
                "revokedAt": getUtcTimestamp(),
                "revokedBy": revokedBy,
                "reason": reason or "revoked",
            },
        )
        return True
    except Exception as e:
        logger.error(f"Error revoking token {tokenId}: {str(e)}")
        return False
|
||
|
||
def revokeTokensBySessionId(
    self,
    sessionId: str,
    userId: str,
    authority: AuthAuthority,
    revokedBy: str,
    reason: str = None,
) -> int:
    """
    Revoke all ACTIVE tokens of one session for a user/authority.

    Tokens are marked revoked (status, revokedAt, revokedBy, reason) rather
    than deleted. ``reason`` defaults to "session logout".

    Returns:
        Number of tokens revoked. If an error occurs part-way, it is logged and
        the number revoked before the failure is returned, not 0.
    """
    # Count lives outside the try so a mid-loop failure still reports the
    # tokens that were really revoked.
    count = 0
    try:
        tokens = self.db.getRecordset(
            Token,
            recordFilter={
                "userId": userId,
                "authority": authority.value
                if hasattr(authority, "value")
                else str(authority),
                "sessionId": sessionId,
                "status": TokenStatus.ACTIVE,
            },
        )
        for t in tokens:
            self.db.recordModify(
                Token,
                t["id"],
                {
                    "status": TokenStatus.REVOKED,
                    "revokedAt": getUtcTimestamp(),
                    "revokedBy": revokedBy,
                    "reason": reason or "session logout",
                },
            )
            count += 1
    except Exception as e:
        logger.error(f"Error revoking tokens for session {sessionId}: {str(e)}")
    return count
|
||
|
||
def revokeTokensByUser(
    self,
    userId: str,
    authority: AuthAuthority = None,
    mandateId: str = None,
    revokedBy: str = None,
    reason: str = None,
) -> int:
    """
    Revoke all ACTIVE tokens of a user, optionally only those of one authority.

    Tokens are marked revoked (status, revokedAt, revokedBy, reason) rather
    than deleted. ``reason`` defaults to "admin revoke".

    Args:
        userId: Owner of the tokens.
        authority: If not None, only tokens of this authority are revoked.
        mandateId: Accepted for API compatibility but currently NOT applied;
            tokens are revoked regardless of mandate.
            NOTE(review): decide whether revocation should be mandate-scoped.
            Tokens may not carry a mandateId (the mandate comes from the
            request context), so narrowing could leave tokens valid.
        revokedBy: Value stored in the revokedBy field.
        reason: Value stored in the reason field.

    Returns:
        Number of tokens revoked. If an error occurs part-way, it is logged and
        the number revoked before the failure is returned, not 0.
    """
    # Count lives outside the try so a mid-loop failure still reports the
    # tokens that were really revoked.
    count = 0
    try:
        recordFilter = {
            "userId": userId,
            "status": TokenStatus.ACTIVE,
        }
        if authority is not None:
            recordFilter["authority"] = (
                authority.value if hasattr(authority, "value") else str(authority)
            )
        tokens = self.db.getRecordset(Token, recordFilter=recordFilter)
        for t in tokens:
            self.db.recordModify(
                Token,
                t["id"],
                {
                    "status": TokenStatus.REVOKED,
                    "revokedAt": getUtcTimestamp(),
                    "revokedBy": revokedBy,
                    "reason": reason or "admin revoke",
                },
            )
            count += 1
    except Exception as e:
        logger.error(f"Error revoking tokens for user {userId}: {str(e)}")
    return count
|
||
|
||
def cleanupExpiredTokens(self) -> int:
    """
    Delete every stored token whose expiresAt lies before the current UTC timestamp.

    This applies to all tokens in the table (access and connection tokens
    alike). Tokens whose parsed expiresAt is falsy (presumably a missing
    expiresAt) are kept.

    Returns:
        Number of tokens deleted. If an error occurs part-way, it is logged and
        the number deleted before the failure is returned, not 0.
    """
    # Count lives outside the try so a mid-loop failure still reports the
    # tokens that were really deleted.
    cleaned_count = 0
    try:
        current_time = getUtcTimestamp()

        # NOTE(review): this loads the entire Token table into memory; consider
        # a database-side filter on expiresAt if the table grows large.
        all_tokens = self.db.getRecordset(Token, recordFilter={})

        for token_data in all_tokens:
            expiresAt = parseTimestamp(token_data.get("expiresAt"))
            if expiresAt and expiresAt < current_time:
                self.db.recordDelete(Token, token_data["id"])
                cleaned_count += 1

        if cleaned_count > 0:
            logger.info(f"Cleaned up {cleaned_count} expired tokens")
    except Exception as e:
        logger.error(f"Error cleaning up expired tokens: {str(e)}")
    return cleaned_count
|
||
|
||
def logout(self) -> None:
    """
    Drop the current user context held by this interface instance.

    Sets currentUser, userId, mandateId and rbac to None and, when a ``db``
    attribute exists, resets its context via ``updateContext("")``. Stored
    tokens are NOT touched here; revoke them separately (e.g. with
    revokeTokensBySessionId).

    Raises:
        Exception: Any error is logged and re-raised.
    """
    try:
        for attributeName in ("currentUser", "userId", "mandateId", "rbac"):
            setattr(self, attributeName, None)

        if hasattr(self, "db"):
            self.db.updateContext("")

        logger.info("User logged out successfully")
    except Exception as e:
        logger.error(f"Error during logout: {str(e)}")
        raise
|
||
|
||
# RBAC CRUD Methods
|
||
|
||
def createAccessRule(self, accessRule: AccessRule) -> AccessRule:
    """
    Store a new access rule and return it as persisted by the database.

    Args:
        accessRule: AccessRule object to create (passed to recordCreate as-is).

    Returns:
        AccessRule built from the record returned by recordCreate.

    Raises:
        Exception: Any error is logged and re-raised.
    """
    try:
        stored = self.db.recordCreate(AccessRule, accessRule)
        logger.info(f"Created access rule with ID {stored.get('id')}")
        return AccessRule(**stored)
    except Exception as e:
        logger.error(f"Error creating access rule: {str(e)}")
        raise
|
||
|
||
def getAccessRule(self, ruleId: str) -> Optional[AccessRule]:
    """
    Look up an access rule by ID.

    NOTE(review): an earlier ``getAccessRule`` with the same signature exists
    in this file. If both are in the same class, this later definition is the
    one in effect.

    Args:
        ruleId: Access rule ID

    Returns:
        The first matching AccessRule, or None if there is no match or an
        error occurs (errors are logged).
    """
    try:
        matches = self.db.getRecordset(AccessRule, recordFilter={"id": ruleId})
        return AccessRule(**matches[0]) if matches else None
    except Exception as e:
        logger.error(f"Error getting access rule {ruleId}: {str(e)}")
        return None
|
||
|
||
def updateAccessRule(self, ruleId: str, accessRule: AccessRule) -> AccessRule:
    """
    Overwrite an existing access rule with the fields of ``accessRule``.

    The ``id`` inside ``accessRule`` is ignored; ``ruleId`` (from the URL)
    decides which record is modified.

    Returns:
        AccessRule built from the record returned by recordModify.

    Raises:
        Exception: Any error is logged and re-raised.
    """
    try:
        changes = accessRule.model_dump(exclude={"id"})
        modified = self.db.recordModify(AccessRule, ruleId, changes)
        logger.info(f"Updated access rule with ID {ruleId}")
        return AccessRule(**modified)
    except Exception as e:
        logger.error(f"Error updating access rule {ruleId}: {str(e)}")
        raise
|
||
|
||
def deleteAccessRule(self, ruleId: str) -> bool:
    """
    Remove an access rule.

    Args:
        ruleId: Access rule ID

    Returns:
        True if recordDelete completed without raising, False if it raised
        (the error is logged).
    """
    try:
        self.db.recordDelete(AccessRule, ruleId)
    except Exception as e:
        logger.error(f"Error deleting access rule {ruleId}: {str(e)}")
        return False
    else:
        logger.info(f"Deleted access rule with ID {ruleId}")
        return True
|
||
|
||
def getAccessRules(
    self,
    roleLabel: Optional[str] = None,
    roleId: Optional[str] = None,
    context: Optional[AccessRuleContext] = None,
    item: Optional[str] = None,
    pagination: Optional[PaginationParams] = None
) -> Union[List[AccessRule], PaginatedResult]:
    """
    List access rules visible to the current user, with optional filters and pagination.

    Records are fetched through getRecordsetWithRBAC using the current user
    and mandate.

    Args:
        roleLabel: Legacy role-label filter; only used when roleId is not given.
        roleId: Role ID filter (takes precedence over roleLabel).
        context: Context filter (its ``.value`` is used).
        item: Item filter.
        pagination: If None, all matching rules are returned. Otherwise the
            pagination filters and sort are applied in memory, then the
            requested page is sliced out.

    Returns:
        ``List[AccessRule]`` without pagination, otherwise a PaginatedResult.
        On error the error is logged and an empty list / empty
        PaginatedResult is returned.
    """
    try:
        criteria = {}
        if roleId:
            criteria["roleId"] = roleId
        elif roleLabel:
            criteria["roleLabel"] = roleLabel
        if context:
            criteria["context"] = context.value
        if item:
            criteria["item"] = item

        visibleRows = getRecordsetWithRBAC(
            self.db,
            AccessRule,
            self.currentUser,
            recordFilter=criteria or None,
            mandateId=self.mandateId,
        )
        rows = [dict(row) for row in visibleRows]

        if pagination is None:
            return [AccessRule(**row) for row in rows]

        if pagination.filters:
            rows = self._applyFilters(rows, pagination.filters)
        if pagination.sort:
            rows = self._applySorting(rows, pagination.sort)

        # Totals are computed after filtering, before slicing the page.
        pageSize = pagination.pageSize
        totalItems = len(rows)
        totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0
        offset = (pagination.page - 1) * pageSize
        pageRows = rows[offset:offset + pageSize]

        return PaginatedResult(
            items=[AccessRule(**row) for row in pageRows],
            totalItems=totalItems,
            totalPages=totalPages,
        )
    except Exception as e:
        logger.error(f"Error getting access rules: {str(e)}")
        return [] if pagination is None else PaginatedResult(items=[], totalItems=0, totalPages=0)
|
||
|
||
def getAccessRulesForRoles(
|
||
self,
|
||
roleLabels: List[str],
|
||
context: AccessRuleContext,
|
||
item: str
|
||
) -> List[AccessRule]:
|
||
"""
|
||
Get access rules for multiple roles, context, and item.
|
||
Returns the most specific matching rules for each role.
|
||
|
||
Args:
|
||
roleLabels: List of role labels
|
||
context: Context type
|
||
item: Item identifier
|
||
|
||
Returns:
|
||
List of AccessRule objects (most specific for each role)
|
||
"""
|
||
try:
|
||
# Pass self.db as dbApp since this interface uses DbApp database
|
||
RbacInstance = RbacClass(self.db, dbApp=self.db)
|
||
allRules = []
|
||
|
||
for roleLabel in roleLabels:
|
||
# Get all rules for this role and context
|
||
roleRules = RbacInstance._getRulesForRole(roleLabel, context)
|
||
|
||
# Find most specific rule for this item
|
||
mostSpecificRule = RbacInstance.findMostSpecificRule(roleRules, item)
|
||
|
||
if mostSpecificRule:
|
||
allRules.append(mostSpecificRule)
|
||
|
||
return allRules
|
||
except Exception as e:
|
||
logger.error(f"Error getting access rules for roles: {str(e)}")
|
||
return []
|
||
|
||
def createRole(self, role: Role) -> Role:
|
||
"""
|
||
Create a new role.
|
||
|
||
Args:
|
||
role: Role object to create
|
||
|
||
Returns:
|
||
Created Role object
|
||
"""
|
||
try:
|
||
# Check if role label already exists
|
||
existingRoles = self.db.getRecordset(Role, recordFilter={"roleLabel": role.roleLabel})
|
||
if existingRoles:
|
||
raise ValueError(f"Role with label '{role.roleLabel}' already exists")
|
||
|
||
createdRole = self.db.recordCreate(Role, role)
|
||
logger.info(f"Created role with ID {createdRole.get('id')} and label {role.roleLabel}")
|
||
return Role(**createdRole)
|
||
except Exception as e:
|
||
logger.error(f"Error creating role: {str(e)}")
|
||
raise
|
||
|
||
def getRole(self, roleId: str) -> Optional[Role]:
|
||
"""
|
||
Get a role by ID.
|
||
|
||
Args:
|
||
roleId: Role ID
|
||
|
||
Returns:
|
||
Role object if found, None otherwise
|
||
"""
|
||
try:
|
||
roles = self.db.getRecordset(Role, recordFilter={"id": roleId})
|
||
if roles:
|
||
return Role(**roles[0])
|
||
return None
|
||
except Exception as e:
|
||
logger.error(f"Error getting role {roleId}: {str(e)}")
|
||
return None
|
||
|
||
def getRoleByLabel(self, roleLabel: str) -> Optional[Role]:
|
||
"""
|
||
Get a role by label.
|
||
|
||
Args:
|
||
roleLabel: Role label
|
||
|
||
Returns:
|
||
Role object if found, None otherwise
|
||
"""
|
||
try:
|
||
roles = self.db.getRecordset(Role, recordFilter={"roleLabel": roleLabel})
|
||
if roles:
|
||
return Role(**roles[0])
|
||
return None
|
||
except Exception as e:
|
||
logger.error(f"Error getting role by label {roleLabel}: {str(e)}")
|
||
return None
|
||
|
||
def getRoleByLabelAndScope(
|
||
self,
|
||
roleLabel: str,
|
||
mandateId: Optional[str] = None,
|
||
featureInstanceId: Optional[str] = None,
|
||
featureCode: Optional[str] = None
|
||
) -> Optional[Role]:
|
||
"""
|
||
Get a role by label with scope filtering.
|
||
|
||
Args:
|
||
roleLabel: Role label
|
||
mandateId: Mandate ID (use None for global roles)
|
||
featureInstanceId: Feature instance ID
|
||
featureCode: Feature code
|
||
|
||
Returns:
|
||
Role object if found, None otherwise
|
||
"""
|
||
try:
|
||
recordFilter = {"roleLabel": roleLabel}
|
||
if mandateId is not None:
|
||
recordFilter["mandateId"] = mandateId
|
||
if featureInstanceId is not None:
|
||
recordFilter["featureInstanceId"] = featureInstanceId
|
||
if featureCode is not None:
|
||
recordFilter["featureCode"] = featureCode
|
||
|
||
roles = self.db.getRecordset(Role, recordFilter=recordFilter)
|
||
if roles:
|
||
return Role(**roles[0])
|
||
return None
|
||
except Exception as e:
|
||
logger.error(f"Error getting role by label and scope {roleLabel}: {str(e)}")
|
||
return None
|
||
|
||
def getRolesForMandate(self, mandateId: str) -> List[Role]:
|
||
"""
|
||
Get mandate-level roles for a specific mandate (featureInstanceId=NULL).
|
||
These are the roles created by copySystemRolesToMandate during bootstrap.
|
||
"""
|
||
try:
|
||
roles = self.db.getRecordset(
|
||
Role,
|
||
recordFilter={"mandateId": mandateId, "featureInstanceId": None}
|
||
)
|
||
return [Role(**dict(r)) for r in roles]
|
||
except Exception as e:
|
||
logger.error(f"Error getting roles for mandate {mandateId}: {e}")
|
||
return []
|
||
|
||
def getAllRoles(self, pagination: Optional[PaginationParams] = None) -> Union[List[Role], PaginatedResult]:
|
||
"""
|
||
Get all roles with optional pagination, sorting, and filtering.
|
||
|
||
Args:
|
||
pagination: Optional pagination parameters. If None, returns all items.
|
||
|
||
Returns:
|
||
If pagination is None: List[Role]
|
||
If pagination is provided: PaginatedResult with items and metadata
|
||
"""
|
||
try:
|
||
result = self.db.getRecordsetPaginated(Role, pagination=pagination)
|
||
|
||
items = []
|
||
for record in result["items"]:
|
||
cleanedRole = dict(record)
|
||
items.append(Role(**cleanedRole))
|
||
|
||
if pagination is None:
|
||
return items
|
||
|
||
return PaginatedResult(
|
||
items=items,
|
||
totalItems=result["totalItems"],
|
||
totalPages=result["totalPages"]
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"Error getting all roles: {str(e)}")
|
||
if pagination is None:
|
||
return []
|
||
return PaginatedResult(items=[], totalItems=0, totalPages=0)
|
||
|
||
def countRoleAssignments(self) -> Dict[str, int]:
|
||
"""
|
||
Count the number of user assignments per role from UserMandateRole table.
|
||
|
||
Returns:
|
||
Dictionary mapping roleId to count of user assignments
|
||
"""
|
||
try:
|
||
# Get all UserMandateRole records
|
||
assignments = self.db.getRecordset(UserMandateRole)
|
||
|
||
# Count assignments per roleId
|
||
roleCounts: Dict[str, int] = {}
|
||
for assignment in assignments:
|
||
roleId = str(assignment.get("roleId", ""))
|
||
if roleId:
|
||
roleCounts[roleId] = roleCounts.get(roleId, 0) + 1
|
||
|
||
return roleCounts
|
||
except Exception as e:
|
||
logger.error(f"Error counting role assignments: {str(e)}")
|
||
return {}
|
||
|
||
def updateRole(self, roleId: str, role: Role) -> Role:
|
||
"""
|
||
Update an existing role.
|
||
|
||
Args:
|
||
roleId: Role ID
|
||
role: Updated Role object
|
||
|
||
Returns:
|
||
Updated Role object
|
||
"""
|
||
try:
|
||
# Check if role exists
|
||
existingRole = self.getRole(roleId)
|
||
if not existingRole:
|
||
raise ValueError(f"Role with ID {roleId} not found")
|
||
|
||
# If role label is being changed, check for conflicts
|
||
if role.roleLabel != existingRole.roleLabel:
|
||
conflictingRole = self.getRoleByLabel(role.roleLabel)
|
||
if conflictingRole and conflictingRole.id != roleId:
|
||
raise ValueError(f"Role with label '{role.roleLabel}' already exists")
|
||
|
||
_IMMUTABLE_ROLE_FIELDS = {"id", "mandateId", "featureInstanceId", "featureCode", "isSystemRole"}
|
||
updatedRole = self.db.recordModify(Role, roleId, role.model_dump(exclude=_IMMUTABLE_ROLE_FIELDS))
|
||
logger.info(f"Updated role with ID {roleId}")
|
||
return Role(**updatedRole)
|
||
except Exception as e:
|
||
logger.error(f"Error updating role {roleId}: {str(e)}")
|
||
raise
|
||
|
||
def deleteRole(self, roleId: str) -> bool:
|
||
"""
|
||
Delete a role.
|
||
|
||
Args:
|
||
roleId: Role ID
|
||
|
||
Returns:
|
||
True if deleted successfully, False otherwise
|
||
"""
|
||
try:
|
||
# Check if role exists
|
||
role = self.getRole(roleId)
|
||
if not role:
|
||
return False
|
||
|
||
# Prevent deletion of system roles
|
||
if role.isSystemRole:
|
||
raise ValueError(f"Cannot delete system role '{role.roleLabel}'")
|
||
|
||
# Check if role is assigned to any users via UserMandateRole
|
||
roleAssignments = self.db.getRecordset(UserMandateRole, recordFilter={"roleId": roleId})
|
||
if roleAssignments:
|
||
raise ValueError(f"Cannot delete role '{role.roleLabel}' - it is assigned to users")
|
||
|
||
# Check if role is used in any access rules
|
||
accessRules = self.getAccessRules(roleId=roleId)
|
||
if accessRules:
|
||
raise ValueError(f"Cannot delete role '{role.roleLabel}' - it is used in access rules")
|
||
|
||
self.db.recordDelete(Role, roleId)
|
||
logger.info(f"Deleted role with ID {roleId}")
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"Error deleting role {roleId}: {str(e)}")
|
||
raise
|
||
|
||
# -------------------------------------------------------------------------
|
||
# Table List Views (saved display presets: filters, sort, groupByLevels)
|
||
# -------------------------------------------------------------------------
|
||
|
||
def getTableListViews(self, contextKey: str) -> list:
|
||
"""Return all saved views for the current user and contextKey."""
|
||
from modules.datamodels.datamodelPagination import TableListView
|
||
try:
|
||
rows = self.db.getRecordset(
|
||
TableListView,
|
||
recordFilter={"userId": str(self.userId), "contextKey": contextKey},
|
||
)
|
||
result = []
|
||
for row in (rows or []):
|
||
try:
|
||
result.append(TableListView.model_validate(row) if isinstance(row, dict) else row)
|
||
except Exception:
|
||
pass
|
||
return result
|
||
except Exception as e:
|
||
logger.error(f"getTableListViews failed for user={self.userId} context={contextKey}: {e}")
|
||
return []
|
||
|
||
def getTableListView(self, contextKey: str, viewKey: str):
|
||
"""Return one view by viewKey or None if not found."""
|
||
from modules.datamodels.datamodelPagination import TableListView
|
||
try:
|
||
rows = self.db.getRecordset(
|
||
TableListView,
|
||
recordFilter={"userId": str(self.userId), "contextKey": contextKey, "viewKey": viewKey},
|
||
)
|
||
if not rows:
|
||
return None
|
||
row = rows[0]
|
||
return TableListView.model_validate(row) if isinstance(row, dict) else row
|
||
except Exception as e:
|
||
logger.error(f"getTableListView failed for user={self.userId} key={viewKey}: {e}")
|
||
return None
|
||
|
||
def createTableListView(self, contextKey: str, viewKey: str, displayName: str, config: dict):
|
||
"""Create a new view. Raises ValueError if viewKey already exists for this context."""
|
||
from modules.datamodels.datamodelPagination import TableListView
|
||
from modules.shared.timeUtils import getUtcTimestamp
|
||
if self.getTableListView(contextKey=contextKey, viewKey=viewKey) is not None:
|
||
raise ValueError(f"View '{viewKey}' already exists for context '{contextKey}'")
|
||
data = {
|
||
"id": str(uuid.uuid4()),
|
||
"userId": str(self.userId),
|
||
"contextKey": contextKey,
|
||
"viewKey": viewKey,
|
||
"displayName": displayName,
|
||
"config": config,
|
||
"updatedAt": getUtcTimestamp(),
|
||
}
|
||
try:
|
||
self.db.recordCreate(TableListView, data)
|
||
return TableListView.model_validate(data)
|
||
except Exception as e:
|
||
logger.error(f"createTableListView failed: {e}")
|
||
raise
|
||
|
||
def updateTableListView(self, viewId: str, updates: dict):
|
||
"""Update an existing view by its primary key id."""
|
||
from modules.datamodels.datamodelPagination import TableListView
|
||
from modules.shared.timeUtils import getUtcTimestamp
|
||
try:
|
||
updates = {**updates, "updatedAt": getUtcTimestamp()}
|
||
self.db.recordModify(TableListView, viewId, updates)
|
||
rows = self.db.getRecordset(TableListView, recordFilter={"id": viewId})
|
||
if rows:
|
||
row = rows[0]
|
||
return TableListView.model_validate(row) if isinstance(row, dict) else row
|
||
return None
|
||
except Exception as e:
|
||
logger.error(f"updateTableListView failed for id={viewId}: {e}")
|
||
raise
|
||
|
||
def deleteTableListView(self, viewId: str) -> bool:
|
||
"""Delete a view by primary key id. Returns True on success."""
|
||
from modules.datamodels.datamodelPagination import TableListView
|
||
try:
|
||
self.db.recordDelete(TableListView, viewId)
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"deleteTableListView failed for id={viewId}: {e}")
|
||
return False
|
||
|
||
|
||
# Public Methods
|
||
|
||
|
||
def getInterface(currentUser: User, mandateId: Optional[str] = None) -> AppObjects:
|
||
"""
|
||
Returns a AppObjects instance for the current user.
|
||
Handles initialization of database and records.
|
||
|
||
Multi-Tenant Design:
|
||
- mandateId wird explizit übergeben (aus Request-Context / X-Mandate-Id Header)
|
||
|
||
Args:
|
||
currentUser: User object
|
||
mandateId: Explicit mandate context (from request header). Required for non-sysadmin.
|
||
|
||
Returns:
|
||
AppObjects instance for the user context
|
||
"""
|
||
if not currentUser:
|
||
raise ValueError("Invalid user context: user is required")
|
||
|
||
effectiveMandateId = mandateId
|
||
|
||
# Create context key (user + mandate combination)
|
||
contextKey = f"{effectiveMandateId}_{currentUser.id}"
|
||
|
||
# Create new instance if not exists
|
||
if contextKey not in _gatewayInterfaces:
|
||
instance = AppObjects(currentUser)
|
||
instance.setUserContext(currentUser, mandateId=effectiveMandateId)
|
||
_gatewayInterfaces[contextKey] = instance
|
||
else:
|
||
# Re-apply user on every resolve: a prior code path (e.g. legacy logout) may have
|
||
# cleared currentUser on this cached singleton; OAuth/login must not see a stale context.
|
||
_gatewayInterfaces[contextKey].setUserContext(currentUser, mandateId=effectiveMandateId)
|
||
|
||
return _gatewayInterfaces[contextKey]
|
||
|
||
|
||
def getRootInterface() -> AppObjects:
|
||
"""
|
||
Returns a AppObjects instance with root privileges.
|
||
This is used for initial setup and user creation.
|
||
|
||
Note: This function uses security.rootAccess internally to avoid circular dependencies.
|
||
Routes can continue using this function, but connectors/interfaces should use
|
||
security.rootAccess.getRootDbAppConnector() or security.rootAccess.getRootUser() directly.
|
||
"""
|
||
global _rootAppObjects
|
||
|
||
if _rootAppObjects is None:
|
||
try:
|
||
# Use security.rootAccess to get root user (avoids circular dependencies)
|
||
from modules.security.rootAccess import getRootUser
|
||
rootUser = getRootUser()
|
||
|
||
# Create root interface with the root user
|
||
_rootAppObjects = AppObjects(rootUser)
|
||
|
||
except Exception as e:
|
||
logger.error(f"Error getting root user: {str(e)}")
|
||
raise ValueError(f"Failed to get root user: {str(e)}")
|
||
|
||
return _rootAppObjects
|