1670 lines
62 KiB
Python
1670 lines
62 KiB
Python
"""
Interface to the Gateway system.

Manages users, mandates, user connections and tokens for authentication.
"""
|
|
|
|
import logging
|
|
import math
|
|
from typing import Dict, Any, List, Optional, Union
|
|
from passlib.context import CryptContext
|
|
import uuid
|
|
|
|
from modules.connectors.connectorDbPostgre import DatabaseConnector
|
|
from modules.shared.configuration import APP_CONFIG
|
|
from modules.shared.timeUtils import getUtcTimestamp, parseTimestamp
|
|
from modules.interfaces.interfaceBootstrap import initBootstrap
|
|
from modules.security.rbac import RbacClass
|
|
from modules.datamodels.datamodelUam import (
|
|
User,
|
|
Mandate,
|
|
UserInDB,
|
|
UserConnection,
|
|
AuthAuthority,
|
|
UserPrivilege,
|
|
ConnectionStatus,
|
|
)
|
|
from modules.datamodels.datamodelRbac import (
|
|
AccessRule,
|
|
AccessRuleContext,
|
|
)
|
|
from modules.datamodels.datamodelUam import AccessLevel
|
|
from modules.datamodels.datamodelSecurity import Token, AuthEvent, TokenStatus
|
|
from modules.datamodels.datamodelNeutralizer import (
|
|
DataNeutraliserConfig,
|
|
DataNeutralizerAttributes,
|
|
)
|
|
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResult
|
|
|
|
# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)

# Storage for AppObjects instances per context (singleton factory cache).
# NOTE(review): not read within the visible part of this file; presumably
# populated by a factory further down — confirm the key type there.
_gatewayInterfaces: Dict[Any, Any] = {}

# Root-level AppObjects instance; starts as None.
# NOTE(review): the assignment presumably happens elsewhere in the file.
_rootAppObjects: Optional[Any] = None

# Password hashing context (argon2; deprecated schemes handled automatically).
pwdContext = CryptContext(deprecated="auto", schemes=["argon2"])
|
|
|
|
|
|
class AppObjects:
|
|
"""
Interface to the Gateway system.

Manages users and mandates.
"""
|
|
|
|
def __init__(self, currentUser: Optional[User] = None):
    """Initialize the Gateway interface.

    Opens the database connection, runs bootstrap record initialization and,
    when ``currentUser`` is given, binds the interface to that user.

    Args:
        currentUser: Optional user whose context (id, mandate, language) the
            interface operates in. Without it, checkRbacPermission denies.

    Raises:
        ValueError: If ``currentUser`` lacks an id or mandateId (raised by
            setUserContext).
        Exception: Any error from database initialization is re-raised.
    """
    # Store the user context first: _initializeDatabase reads self.userId.
    self.currentUser = currentUser
    self.userId = currentUser.id if currentUser else None
    self.mandateId = currentUser.mandateId if currentUser else None

    # Define context-dependent attributes up front so an instance created
    # without a user does not raise AttributeError later (checkRbacPermission
    # reads self.rbac). setUserContext overwrites both.
    self.rbac = None
    self.userLanguage = None

    self._initializeDatabase()

    # Create standard records if they do not exist yet.
    self._initRecords()

    if currentUser:
        self.setUserContext(currentUser)
|
|
|
|
def setUserContext(self, currentUser: User):
    """Bind the interface to ``currentUser``.

    Validates the user, stores id/mandate/language, creates the RBAC helper
    over this interface's database connector, and pushes the user id into
    the connector's context.

    Args:
        currentUser: The user to act as. A falsy value is logged and ignored;
            the existing context is left untouched.

    Raises:
        ValueError: If the user has no id or no mandateId. Validation runs
            before any attribute is assigned, so a rejected user does not
            leave the instance half-updated.
    """
    if not currentUser:
        logger.info("Initializing interface without user context")
        return

    if not currentUser.id or not currentUser.mandateId:
        raise ValueError("Invalid user context: id and mandateId are required")

    self.currentUser = currentUser
    self.userId = currentUser.id
    self.mandateId = currentUser.mandateId
    self.userLanguage = currentUser.language

    self.rbac = RbacClass(self.db)

    self.db.updateContext(self.userId)
|
|
|
|
def __del__(self):
    """Close the database connection when the interface is garbage-collected.

    Does nothing if no connector was ever created; errors during close are
    logged rather than propagated.
    """
    connector = getattr(self, "db", None)
    if connector is None:
        return
    try:
        connector.close()
    except Exception as e:
        logger.error(f"Error closing database connection: {e}")
|
|
|
|
def _initializeDatabase(self):
    """Create the database connector for this interface and initialize it.

    Connection settings are read from APP_CONFIG (DB_APP_* keys, with defaults
    for host, database name and port). The connector is bound to
    ``self.userId`` and stored on ``self.db``.

    Raises:
        Exception: Any configuration or connection error is logged and re-raised.
    """
    try:
        # Dict literals evaluate left to right, so the config lookups happen
        # in the same order as before.
        connectionSettings = {
            "dbHost": APP_CONFIG.get("DB_APP_HOST", "_no_config_default_data"),
            "dbDatabase": APP_CONFIG.get("DB_APP_DATABASE", "app"),
            "dbUser": APP_CONFIG.get("DB_APP_USER"),
            "dbPassword": APP_CONFIG.get("DB_APP_PASSWORD_SECRET"),
            "dbPort": int(APP_CONFIG.get("DB_APP_PORT", 5432)),
        }
        self.db = DatabaseConnector(userId=self.userId, **connectionSettings)
        self.db.initDbSystem()
        logger.info(f"Database initialized successfully for user {self.userId}")
    except Exception as e:
        logger.error(f"Failed to initialize database: {str(e)}")
        raise
|
|
|
|
def _initRecords(self):
    """Ensure standard records exist by delegating to initBootstrap on this connector."""
    database = self.db
    initBootstrap(database)
|
|
|
|
|
|
def checkRbacPermission(
    self,
    modelClass: type,
    operation: str,
    recordId: Optional[str] = None
) -> bool:
    """
    Check whether the current user may perform ``operation`` on a table.

    The table name is ``modelClass.__name__``; permissions are looked up in the
    DATA context through the RBAC helper.

    Args:
        modelClass: Pydantic model class for the table
        operation: One of 'create', 'update', 'delete', 'read'
        recordId: Accepted for API compatibility but currently unused, so the
            check is table-level rather than record-level.

    Returns:
        True if the permission level for ``operation`` is anything other than
        AccessLevel.NONE. False for unknown operations, and False when there is
        no RBAC helper or no user context.
    """
    # getattr: an instance created without a user context may never have had
    # self.rbac assigned; deny instead of raising AttributeError.
    rbac = getattr(self, "rbac", None)
    currentUser = getattr(self, "currentUser", None)
    if not rbac or not currentUser:
        return False

    if operation not in ("create", "update", "delete", "read"):
        return False

    permissions = rbac.getUserPermissions(
        currentUser,
        AccessRuleContext.DATA,
        modelClass.__name__
    )
    # The permissions object exposes one attribute per operation name.
    return getattr(permissions, operation) != AccessLevel.NONE
|
|
|
|
def _applyFilters(self, records: List[Dict[str, Any]], filters: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Apply filter criteria to records.

    Supports:
    - General search: {"search": "text"} - case-insensitive substring match
      against every str, int and float field (bools included, compared as
      "true"/"false"). A None or empty search term is ignored.
    - Field-specific filters (all must match; a record without the field
      never matches):
      - Simple: {"status": "running"} - equals match
      - With operator: {"status": {"operator": "equals", "value": "running"}}
      - Operators:
        - equals/eq
        - contains, startsWith, endsWith: case-insensitive; None is treated as "".
        - gt, gte, lt, lte: numeric. None maps to -inf for gt/gte and +inf
          for lt/lte; values that cannot be converted to float never match.
        - in, notIn: the value may be a list, tuple, set or a single scalar.
        - Unknown operators fall back to equals.

    Args:
        records: List of record dictionaries to filter
        filters: Filter criteria dictionary

    Returns:
        A new list of matching records, or ``records`` itself when either
        argument is empty.
    """
    if not filters or not records:
        return records

    rawSearch = filters.get("search")
    # None means "no search"; str(None) would otherwise search for "none".
    searchTerm = str(rawSearch).lower() if rawSearch is not None else ""
    fieldFilters = [(name, spec) for name, spec in filters.items() if name != "search"]

    def toText(value: Any) -> str:
        return str(value).lower() if value is not None else ""

    def asList(value: Any) -> List[Any]:
        # Accept any common collection; a scalar becomes a one-element list.
        # Copying to a list (not a set) keeps unhashable record values usable with `in`.
        if isinstance(value, (list, tuple, set, frozenset)):
            return list(value)
        return [value]

    def numeric(predicate, missing: float):
        # Build a comparator converting both sides to float; None -> ``missing``.
        def compare(recordValue: Any, filterValue: Any) -> bool:
            try:
                left = float(recordValue) if recordValue is not None else missing
                right = float(filterValue) if filterValue is not None else missing
            except (ValueError, TypeError):
                return False
            return predicate(left, right)
        return compare

    def equals(recordValue: Any, filterValue: Any) -> bool:
        return recordValue == filterValue

    negInf, posInf = float("-inf"), float("inf")
    comparators = {
        "equals": equals,
        "eq": equals,
        "contains": lambda rv, fv: toText(fv) in toText(rv),
        "startsWith": lambda rv, fv: toText(rv).startswith(toText(fv)),
        "endsWith": lambda rv, fv: toText(rv).endswith(toText(fv)),
        "gt": numeric(lambda a, b: a > b, negInf),
        "gte": numeric(lambda a, b: a >= b, negInf),
        "lt": numeric(lambda a, b: a < b, posInf),
        "lte": numeric(lambda a, b: a <= b, posInf),
        "in": lambda rv, fv: rv in asList(fv),
        "notIn": lambda rv, fv: rv not in asList(fv),
    }

    def matchesSearch(record: Dict[str, Any]) -> bool:
        if not searchTerm:
            return True
        for value in record.values():
            if isinstance(value, str):
                if searchTerm in value.lower():
                    return True
            elif isinstance(value, (int, float)) and searchTerm in str(value).lower():
                return True
        return False

    def matchesFields(record: Dict[str, Any]) -> bool:
        for fieldName, spec in fieldFilters:
            if fieldName not in record:
                return False
            if isinstance(spec, dict):
                compare = comparators.get(spec.get("operator", "equals"), equals)
                filterValue = spec.get("value")
            else:
                compare, filterValue = equals, spec
            if not compare(record[fieldName], filterValue):
                return False
        return True

    return [record for record in records if matchesSearch(record) and matchesFields(record)]
|
|
|
|
def _applySorting(self, records: List[Dict[str, Any]], sortFields: List[Any]) -> List[Dict[str, Any]]:
    """
    Apply multi-level sorting to records.

    Each sort field is a dict or an object with ``field`` and ``direction``.
    The direction is "asc" or "desc", case-insensitive; objects exposing
    ``.value`` (e.g. enum members) are compared by that value. The first
    entry is the most significant key.

    Missing/None values always sort last, in both directions. Within a field,
    numbers order before strings when ascending (the reverse when descending).
    Other types are compared by their ``str()`` together with strings.

    Args:
        records: Records to sort (the input list is not modified).
        sortFields: Sort specifications, most significant first.

    Returns:
        A new sorted list, or ``records`` itself if ``sortFields`` is empty.
    """
    if not sortFields:
        return records

    def sortKey(value: Any, isDesc: bool) -> tuple:
        # reverse=True flips the whole key, so the None flag must be inverted
        # for descending order to keep None values at the end.
        presentFlag = 1 if isDesc else 0
        if value is None:
            return (1 - presentFlag, 0, "")
        # The type rank keeps numbers and strings from being compared to each
        # other (TypeError). bool is an int subclass and ranks with numbers.
        if isinstance(value, (int, float)):
            return (presentFlag, 0, value)
        if isinstance(value, str):
            return (presentFlag, 1, value)
        return (presentFlag, 1, str(value))

    sortedRecords = list(records)

    # Python's sort is stable (also with reverse=True), so sorting from the
    # least to the most significant field yields a correct multi-level order.
    for sortField in reversed(sortFields):
        if isinstance(sortField, dict):
            fieldName = sortField.get("field")
            direction = sortField.get("direction", "asc")
        else:
            fieldName = getattr(sortField, "field", None)
            direction = getattr(sortField, "direction", "asc")

        if not fieldName:
            continue

        isDesc = str(getattr(direction, "value", direction)).lower() == "desc"
        sortedRecords.sort(
            key=lambda record, name=fieldName, desc=isDesc: sortKey(record.get(name), desc),
            reverse=isDesc,
        )

    return sortedRecords
|
|
|
|
def getInitialId(self, model_class: type) -> Optional[str]:
    """Return the initial record ID for ``model_class``'s table (delegates to the connector; may be None)."""
    initialId = self.db.getInitialId(model_class)
    return initialId
|
|
|
|
def _getDefaultMandateId(self) -> str:
    """Return the initial (default) mandate ID, trying to create the Root mandate if none exists.

    Raises:
        ValueError: If no mandate ID is available even after the creation attempt.
    """
    mandateId = self.getInitialId(Mandate)
    if mandateId:
        return mandateId

    logger.warning("No default mandate found, creating Root mandate")
    # NOTE(review): _initRootMandate is not defined in the visible part of this
    # file — confirm it exists further down, otherwise this raises AttributeError.
    self._initRootMandate()

    mandateId = self.getInitialId(Mandate)
    if not mandateId:
        raise ValueError("Failed to get or create default mandate")
    return mandateId
|
|
|
|
def _getPasswordHash(self, password: str) -> str:
    """Hash ``password`` with the module-level CryptContext (argon2)."""
    hashed = pwdContext.hash(password)
    return hashed
|
|
|
|
def _verifyPassword(self, plainPassword: str, hashedPassword: str) -> bool:
    """Check ``plainPassword`` against a stored hash.

    Returns:
        True on a match. False on a mismatch, and also when the stored hash is
        malformed/unrecognised or an argument has an unsupported type. A
        corrupt hash is thus treated as a failed login rather than surfacing
        a library error.
    """
    try:
        return pwdContext.verify(plainPassword, hashedPassword)
    except (ValueError, TypeError) as e:
        # passlib raises ValueError subclasses (e.g. UnknownHashError) for
        # unidentifiable hashes and TypeError for non-str/bytes input.
        logger.warning(f"Password verification failed due to invalid input: {e}")
        return False
|
|
|
|
# User methods
|
|
|
|
def getUsersByMandate(self, mandateId: str, pagination: Optional[PaginationParams] = None) -> Union[List[User], PaginatedResult]:
    """
    Returns the users visible to the current user (RBAC-filtered) within a mandate.

    Supports optional pagination, sorting, and filtering.

    NOTE(review): an earlier docstring said SYSADMIN gets all users regardless
    of mandate. This method always passes the mandate filter, so that would
    have to happen inside getRecordsetWithRBAC — confirm.

    Args:
        mandateId: Mandate to restrict to. A falsy value applies no mandate
            filter (RBAC filtering only).
        pagination: Optional pagination parameters. If None, returns all items.
            ``page`` and ``pageSize`` values below 1 are treated as 1, so the
            slice start is never negative and the page count never divides by zero.

    Returns:
        If pagination is None: List[User]
        If pagination is provided: PaginatedResult with items and metadata
    """
    users = self.db.getRecordsetWithRBAC(
        UserInDB,
        self.currentUser,
        recordFilter={"mandateId": mandateId} if mandateId else None
    )

    # Drop database-internal fields (keys starting with "_").
    cleanedUsers = [
        {k: v for k, v in user.items() if not k.startswith("_")}
        for user in users
    ]

    if pagination is None:
        return [User(**user) for user in cleanedUsers]

    if pagination.filters:
        cleanedUsers = self._applyFilters(cleanedUsers, pagination.filters)

    # Multi-level sort in the order of the sort fields.
    if pagination.sort:
        cleanedUsers = self._applySorting(cleanedUsers, pagination.sort)

    page = max(int(pagination.page), 1)
    pageSize = max(int(pagination.pageSize), 1)

    # Totals are computed after filtering, before slicing.
    totalItems = len(cleanedUsers)
    totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0

    startIdx = (page - 1) * pageSize
    pagedUsers = cleanedUsers[startIdx:startIdx + pageSize]

    return PaginatedResult(
        items=[User(**user) for user in pagedUsers],
        totalItems=totalItems,
        totalPages=totalPages
    )
|
|
|
|
def getUserByUsername(self, username: str) -> Optional[User]:
    """Return the RBAC-visible user with ``username``, or None.

    None is returned both when nothing matches and when any error occurs (the
    error is logged, not raised). If several records match, the first is used.

    NOTE(review): callers cannot distinguish "not found" from "lookup failed".
    """
    try:
        matches = self.db.getRecordsetWithRBAC(
            UserInDB,
            self.currentUser,
            recordFilter={"username": username}
        )
        if matches:
            # Usernames are expected to be unique; drop internal "_" fields.
            publicFields = {
                key: value
                for key, value in matches[0].items()
                if not key.startswith("_")
            }
            return User(**publicFields)

        logger.info(f"No user found with username {username}")
        return None
    except Exception as e:
        logger.error(f"Error getting user by username: {str(e)}")
        return None
|
|
|
|
def getUser(self, userId: str) -> Optional[User]:
    """Return the user with ``userId``, or None if absent or on error.

    This lookup reads the table directly and does NOT apply RBAC filtering
    (unlike getUserByUsername and getUsersByMandate). Callers that need access
    control must check permissions themselves. Errors are logged, not raised.
    """
    try:
        # Filter in the query instead of loading every user and scanning.
        users = self.db.getRecordset(UserInDB, recordFilter={"id": userId})
        for userDict in users or []:
            # Keep the explicit id comparison so results are identical even if
            # the connector applied the filter loosely.
            if userDict.get("id") == userId:
                cleanedUser = {k: v for k, v in userDict.items() if not k.startswith("_")}
                return User(**cleanedUser)
        return None
    except Exception as e:
        logger.error(f"Error getting user by ID: {str(e)}")
        return None
|
|
|
|
def authenticateLocalUser(self, username: str, password: str) -> Optional[User]:
    """Authenticate a user by username and password (LOCAL authority only).

    Returns:
        The authenticated user.

    Raises:
        ValueError: "User not found", "User is disabled", "User does not have
            local authentication enabled", "User has no password set" or
            "Invalid password".

    NOTE(review): the distinct messages, and returning before any hash is
    checked for unknown users, can reveal which usernames exist. Make sure
    callers map all of them to one generic client-facing error.
    """
    user = self.getUserByUsername(username)
    if not user:
        raise ValueError("User not found")

    if not user.enabled:
        raise ValueError("User is disabled")

    if user.authenticationAuthority != AuthAuthority.LOCAL:
        raise ValueError("User does not have local authentication enabled")

    # Read the raw UserInDB record to get the stored password hash. An empty
    # result previously raised IndexError; treat it like a missing hash.
    records = self.db.getRecordset(UserInDB, recordFilter={"id": user.id})
    hashedPassword = records[0].get("hashedPassword") if records else None
    if not hashedPassword:
        raise ValueError("User has no password set")

    if not self._verifyPassword(password, hashedPassword):
        raise ValueError("Invalid password")

    return user
|
|
|
|
def createUser(
    self,
    username: str,
    password: Optional[str] = None,
    email: Optional[str] = None,
    fullName: Optional[str] = None,
    language: str = "en",
    enabled: bool = True,
    privilege: UserPrivilege = UserPrivilege.USER,
    authenticationAuthority: AuthAuthority = AuthAuthority.LOCAL,
    externalId: Optional[str] = None,
    externalUsername: Optional[str] = None,
    externalEmail: Optional[str] = None,
) -> User:
    """Create a new user, optionally with an external connection.

    Args:
        username: Login name; converted to str and stripped. Must not be
            empty or None.
        password: Required (non-blank str) for LOCAL authentication. Hashed
            via _getPasswordHash before storage when given.
        email, fullName, language, enabled, privilege: Stored as given.
        authenticationAuthority: Authority of the account (default LOCAL).
        externalId, externalUsername, externalEmail: When both externalId and
            externalUsername are given, a connection is added for
            ``authenticationAuthority``.

    The user is assigned to ``self.mandateId``, or to the default mandate when
    the interface has none.

    Returns:
        The created user as read back from the database, with internal "_"
        fields removed.

    Raises:
        ValueError: On validation failure or any error while creating the user.
    """
    try:
        # None would otherwise become the literal username "None".
        username = str(username).strip() if username is not None else ""
        if not username:
            raise ValueError("Username is required")

        if authenticationAuthority == AuthAuthority.LOCAL:
            if not password:
                raise ValueError("Password is required for local authentication")
            if not isinstance(password, str):
                raise ValueError("Password must be a string")
            if not password.strip():
                raise ValueError("Password cannot be empty")

        mandateId = self.mandateId
        if not mandateId:
            mandateId = self._getDefaultMandateId()
            logger.warning(f"Using default mandate ID {mandateId} for new user {username}")

        userData = UserInDB(
            username=username,
            email=email,
            fullName=fullName,
            language=language,
            mandateId=mandateId,
            enabled=enabled,
            privilege=privilege,
            authenticationAuthority=authenticationAuthority,
            hashedPassword=self._getPasswordHash(password) if password else None,
            connections=[],
        )

        createdRecord = self.db.recordCreate(UserInDB, userData)
        if not createdRecord or not createdRecord.get("id"):
            raise ValueError("Failed to create user record")
        newUserId = createdRecord["id"]

        # NOTE(review): if adding the connection fails, the user record created
        # above is not rolled back.
        if externalId and externalUsername:
            self.addUserConnection(
                newUserId,
                authenticationAuthority,
                externalId,
                externalUsername,
                externalEmail,
            )

        createdUsers = self.db.getRecordset(UserInDB, recordFilter={"id": newUserId})
        if not createdUsers:
            raise ValueError("Failed to retrieve created user")

        # Strip database-internal fields, consistent with the other getters.
        cleanedUser = {k: v for k, v in createdUsers[0].items() if not k.startswith("_")}
        return User(**cleanedUser)

    except ValueError as e:
        logger.error(f"Error creating user: {str(e)}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error creating user: {str(e)}")
        raise ValueError(f"Failed to create user: {str(e)}") from e
|
|
|
|
def updateUser(self, userId: str, updateData: Union[Dict[str, Any], User]) -> User:
    """Update a user's stored fields and return the refreshed user.

    ``updateData`` (a dict or User) is merged over the current user. Any ``id``
    in it is ignored, because ``userId`` is authoritative. A missing or falsy
    ``mandateId`` keeps the user's current mandate, or falls back to the
    default mandate if the user has none.

    Raises:
        ValueError: If the user does not exist, ``updateData`` has an
            unsupported type, or the update fails for any other reason.

    NOTE(review): no RBAC check is performed here, so callers can change any
    field the User model accepts (e.g. enabled, mandateId). The record is also
    written from a ``User`` model — confirm that recordModify does not drop
    UserInDB-only fields such as hashedPassword.
    """
    try:
        user = self.getUser(userId)
        if not user:
            raise ValueError(f"User {userId} not found")

        if isinstance(updateData, User):
            updateDict = updateData.model_dump()
        elif isinstance(updateData, dict):
            updateDict = dict(updateData)
        else:
            raise ValueError(f"Unsupported update data type: {type(updateData).__name__}")

        updateDict.pop("id", None)

        if not updateDict.get("mandateId"):
            if user.mandateId:
                updateDict["mandateId"] = user.mandateId
            else:
                defaultMandateId = self._getDefaultMandateId()
                updateDict["mandateId"] = defaultMandateId
                logger.warning(f"Setting default mandate ID {defaultMandateId} for user {userId}")

        updatedData = {**user.model_dump(), **updateDict, "id": userId}
        self.db.recordModify(UserInDB, userId, User(**updatedData))

        refreshedUser = self.getUser(userId)
        if not refreshedUser:
            raise ValueError("Failed to retrieve updated user")
        return refreshedUser

    except Exception as e:
        logger.error(f"Error updating user: {str(e)}")
        raise ValueError(f"Failed to update user: {str(e)}") from e
|
|
|
|
def disableUser(self, userId: str) -> User:
    """Set ``enabled`` to False for ``userId`` via updateUser and return the updated user.

    NOTE(review): no permission check happens on this path; updateUser
    performs none itself.
    """
    payload = dict(enabled=False)
    return self.updateUser(userId, payload)
|
|
|
|
def enableUser(self, userId: str) -> User:
    """Set ``enabled`` to True for ``userId`` via updateUser and return the updated user.

    NOTE(review): no permission check happens on this path; updateUser
    performs none itself.
    """
    payload = dict(enabled=True)
    return self.updateUser(userId, payload)
|
|
|
|
def _deleteUserReferencedData(self, userId: str) -> None:
    """Delete the auth events, tokens and connections whose ``userId`` matches.

    Raises:
        Exception: Any database error is logged and re-raised. No transaction
            is visible here, so records deleted before a failure presumably
            stay deleted.
    """
    try:
        # Same order as before: auth events, then tokens, then connections.
        for modelClass in (AuthEvent, Token, UserConnection):
            for record in self.db.getRecordset(modelClass, recordFilter={"userId": userId}):
                self.db.recordDelete(modelClass, record["id"])

        logger.info(f"All referenced data for user {userId} has been deleted")

    except Exception as e:
        logger.error(f"Error deleting referenced data for user {userId}: {str(e)}")
        raise
|
|
|
|
def deleteUser(self, userId: str) -> bool:
    """Delete a user and its referenced data if the current user has delete permission.

    Returns:
        True on success.

    Raises:
        ValueError: If the user does not exist, the permission is missing, or
            deletion fails. The internal PermissionError is wrapped into
            ValueError to keep the method's existing exception contract.
    """
    try:
        if not self.getUser(userId):
            raise ValueError(f"User {userId} not found")

        # Check the "delete" permission (previously "update" was checked, which
        # let users with update-but-not-delete rights delete accounts).
        if not self.checkRbacPermission(UserInDB, "delete", userId):
            raise PermissionError(f"No permission to delete user {userId}")

        # Referenced data goes first. NOTE(review): no transaction is visible,
        # so a failure deleting the user record leaves that data already gone.
        self._deleteUserReferencedData(userId)

        if not self.db.recordDelete(UserInDB, userId):
            raise ValueError(f"Failed to delete user {userId}")

        logger.info(f"User {userId} successfully deleted")
        return True

    except Exception as e:
        logger.error(f"Error deleting user: {str(e)}")
        raise ValueError(f"Failed to delete user: {str(e)}") from e
|
|
|
|
def _getInitialUser(self) -> Optional[Dict[str, Any]]:
    """Return the raw initial user record (no RBAC, internal fields kept), or None.

    Errors are logged and reported as None.
    """
    try:
        initialUserId = self.getInitialId(UserInDB)
        if initialUserId:
            matches = self.db.getRecordset(UserInDB, recordFilter={"id": initialUserId})
            if matches:
                return matches[0]
        return None
    except Exception as e:
        logger.error(f"Error getting initial user: {str(e)}")
        return None
|
|
|
|
def checkUsernameAvailability(self, checkData: Dict[str, Any]) -> Dict[str, Any]:
    """Check whether a username is free for registration.

    String usernames are stripped before the lookup, matching how createUser
    normalizes them, so " alice " is reported as taken when "alice" exists.

    Args:
        checkData: Dict with a "username" key. Other keys (e.g.
            "authenticationAuthority") are currently ignored.

    Returns:
        {"available": bool, "message": str}. Errors raised during the check
        yield "available": False with the error message.

    NOTE(review): getUserByUsername is RBAC-filtered and returns None on
    internal errors, so a username invisible to the current user, or a failed
    lookup inside it, is reported as available.
    """
    try:
        username = checkData.get("username")
        if isinstance(username, str):
            username = username.strip()

        if not username:
            return {"available": False, "message": "Username is required"}

        if self.getUserByUsername(username) is not None:
            return {"available": False, "message": "Username is already taken"}

        return {"available": True, "message": "Username is available"}

    except Exception as e:
        logger.error(f"Error checking username availability: {str(e)}")
        return {
            "available": False,
            "message": f"Error checking username availability: {str(e)}",
        }
|
|
|
|
# Connection methods
|
|
|
|
def getUserConnections(self, userId: str) -> List[UserConnection]:
    """Return all stored connections of ``userId`` as UserConnection objects.

    Rows that fail conversion are logged and skipped. Any other error is
    logged and yields an empty list. No RBAC filtering is applied here.
    """
    try:
        rows = self.db.getRecordset(UserConnection, recordFilter={"userId": userId})
        connections = []
        for row in rows:
            try:
                connections.append(
                    UserConnection(
                        id=row["id"],
                        userId=row["userId"],
                        authority=row.get("authority"),
                        externalId=row.get("externalId", ""),
                        externalUsername=row.get("externalUsername", ""),
                        externalEmail=row.get("externalEmail"),
                        status=row.get("status", "pending"),
                        connectedAt=row.get("connectedAt"),
                        lastChecked=row.get("lastChecked"),
                        expiresAt=row.get("expiresAt"),
                    )
                )
            except Exception as e:
                logger.error(
                    f"Error converting connection dict to object: {str(e)}"
                )
        return connections

    except Exception as e:
        logger.error(f"Error getting user connections: {str(e)}")
        return []
|
|
|
|
def addUserConnection(
    self,
    userId: str,
    authority: AuthAuthority,
    externalId: str,
    externalUsername: str,
    externalEmail: Optional[str] = None,
    status: ConnectionStatus = ConnectionStatus.PENDING,
) -> UserConnection:
    """
    Adds a new connection for a user.

    Args:
        userId: The ID of the user (must exist)
        authority: The authentication authority (e.g., MSFT, GOOGLE)
        externalId: The external ID from the authority
        externalUsername: The username from the authority
        externalEmail: Optional email from the authority
        status: The connection status (defaults to PENDING)

    Returns:
        The created UserConnection object (new UUID id; connectedAt and
        lastChecked set to the same current timestamp; expiresAt None).

    Raises:
        ValueError: If the user does not exist or the record cannot be created.
    """
    try:
        if not self.getUser(userId):
            raise ValueError(f"User not found: {userId}")

        # One timestamp so connectedAt and lastChecked agree exactly.
        now = getUtcTimestamp()
        connection = UserConnection(
            id=str(uuid.uuid4()),
            userId=userId,
            authority=authority,
            externalId=externalId,
            externalUsername=externalUsername,
            externalEmail=externalEmail,
            status=status,
            connectedAt=now,
            lastChecked=now,
            expiresAt=None,
        )

        self.db.recordCreate(UserConnection, connection)
        return connection

    except Exception as e:
        logger.error(f"Error adding user connection: {str(e)}")
        raise ValueError(f"Failed to add user connection: {str(e)}") from e
|
|
|
|
def removeUserConnection(self, connectionId: str) -> None:
    """Delete the UserConnection with ``connectionId``.

    Raises:
        ValueError: If the connection does not exist or any error occurs
            during lookup or deletion.

    NOTE(review): no ownership or RBAC check is done here; callers must verify
    that the connection belongs to the acting user.
    """
    try:
        existing = self.db.getRecordset(UserConnection, recordFilter={"id": connectionId})
        if not existing:
            raise ValueError(f"Connection {connectionId} not found")
        self.db.recordDelete(UserConnection, connectionId)
    except Exception as e:
        logger.error(f"Error removing user connection: {str(e)}")
        raise ValueError(f"Failed to remove user connection: {str(e)}")
|
|
|
|
# Mandate methods
|
|
|
|
def getAllMandates(self, pagination: Optional[PaginationParams] = None) -> Union[List[Mandate], PaginatedResult]:
    """
    Returns all mandates visible to the current user (RBAC-filtered).

    Supports optional pagination, sorting, and filtering.

    Args:
        pagination: Optional pagination parameters. If None, returns all items.
            ``page`` and ``pageSize`` values below 1 are treated as 1, so the
            slice start is never negative and the page count never divides by zero.

    Returns:
        If pagination is None: List[Mandate]
        If pagination is provided: PaginatedResult with items and metadata
    """
    allMandates = self.db.getRecordsetWithRBAC(Mandate, self.currentUser)

    # Drop database-internal fields (keys starting with "_").
    cleanedMandates = [
        {k: v for k, v in mandate.items() if not k.startswith("_")}
        for mandate in allMandates
    ]

    if pagination is None:
        return [Mandate(**mandate) for mandate in cleanedMandates]

    if pagination.filters:
        cleanedMandates = self._applyFilters(cleanedMandates, pagination.filters)

    # Multi-level sort in the order of the sort fields.
    if pagination.sort:
        cleanedMandates = self._applySorting(cleanedMandates, pagination.sort)

    page = max(int(pagination.page), 1)
    pageSize = max(int(pagination.pageSize), 1)

    # Totals are computed after filtering, before slicing.
    totalItems = len(cleanedMandates)
    totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0

    startIdx = (page - 1) * pageSize
    pagedMandates = cleanedMandates[startIdx:startIdx + pageSize]

    return PaginatedResult(
        items=[Mandate(**mandate) for mandate in pagedMandates],
        totalItems=totalItems,
        totalPages=totalPages
    )
|
|
|
|
def getMandate(self, mandateId: str) -> Optional[Mandate]:
    """Return the RBAC-visible mandate with ``mandateId``, or None.

    Database-internal fields (keys starting with "_") are stripped; if several
    records match, the first is returned.
    """
    mandates = self.db.getRecordsetWithRBAC(
        Mandate,
        self.currentUser,
        recordFilter={"id": mandateId}
    )
    if not mandates:
        return None

    cleaned = [
        {key: value for key, value in row.items() if not key.startswith("_")}
        for row in mandates
    ]
    return Mandate(**cleaned[0])
|
|
|
|
def createMandate(self, name: str, language: str = "en") -> Mandate:
    """Create a new mandate if the current user has RBAC create permission.

    Returns:
        The created mandate, with database-internal "_" fields stripped
        (consistent with getMandate).

    Raises:
        PermissionError: If the current user may not create mandates.
        ValueError: If the record could not be created.
    """
    if not self.checkRbacPermission(Mandate, "create"):
        raise PermissionError("No permission to create mandates")

    mandateData = Mandate(name=name, language=language)

    createdRecord = self.db.recordCreate(Mandate, mandateData)
    if not createdRecord or not createdRecord.get("id"):
        raise ValueError("Failed to create mandate record")

    cleanedRecord = {k: v for k, v in createdRecord.items() if not k.startswith("_")}
    return Mandate(**cleanedRecord)
|
|
|
|
def updateMandate(self, mandateId: str, updateData: Dict[str, Any]) -> Mandate:
    """Update a mandate if the current user has RBAC update permission.

    ``updateData`` is merged over the current mandate. Any ``id`` key in it is
    ignored so the payload cannot change the record identity (the same rule
    updateUser applies).

    Returns:
        The refreshed mandate.

    Raises:
        ValueError: If permission is missing (the PermissionError is wrapped),
            the mandate is not visible/found, or the update fails.
    """
    try:
        if not self.checkRbacPermission(Mandate, "update", mandateId):
            raise PermissionError(f"No permission to update mandate {mandateId}")

        mandate = self.getMandate(mandateId)
        if not mandate:
            raise ValueError(f"Mandate {mandateId} not found")

        changes = dict(updateData)
        changes.pop("id", None)

        updatedData = mandate.model_dump()
        updatedData.update(changes)
        self.db.recordModify(Mandate, mandateId, Mandate(**updatedData))

        refreshedMandate = self.getMandate(mandateId)
        if not refreshedMandate:
            raise ValueError("Failed to retrieve updated mandate")
        return refreshedMandate

    except Exception as e:
        logger.error(f"Error updating mandate: {str(e)}")
        raise ValueError(f"Failed to update mandate: {str(e)}") from e
|
|
|
|
def deleteMandate(self, mandateId: str) -> bool:
    """Delete a mandate that no longer has any users.

    Returns:
        False when getMandate does not find the mandate; otherwise the value
        returned by the database delete.

    Raises:
        ValueError: Wrapping any failure, including a failed "delete" permission
            check and the mandate still having users.
    """
    try:
        existing = self.getMandate(mandateId)
        if not existing:
            return False

        permitted = self.checkRbacPermission(Mandate, "delete", mandateId)
        if not permitted:
            raise PermissionError(f"No permission to delete mandate {mandateId}")

        # A mandate that still has users cannot be deleted.
        if self.getUsersByMandate(mandateId):
            raise ValueError(
                f"Cannot delete mandate {mandateId} with existing users"
            )

        return self.db.recordDelete(Mandate, mandateId)

    except Exception as e:
        logger.error(f"Error deleting mandate: {str(e)}")
        raise ValueError(f"Failed to delete mandate: {str(e)}")
|
|
|
|
# Token methods
|
|
|
|
def saveAccessToken(self, token: Token, replace_existing: bool = True) -> None:
    """Persist an access token (a token WITHOUT connectionId) for the current user.

    The token is always stored under the current user's id. ``id`` and
    ``createdAt`` are filled in when missing. When ``replace_existing`` is True,
    the user's other access tokens for the same authority are deleted only
    AFTER the new token has been saved. A failed insert therefore never leaves
    the user without a token. Failures during that cleanup are logged as
    warnings and do not fail the call.

    Args:
        token: Token to store.
        replace_existing: Delete older access tokens for the same user/authority.

    Raises:
        ValueError: If the token has a connectionId or there is no current user.
        Exception: Any error raised while saving is logged and re-raised.
    """
    try:
        if token.connectionId:
            raise ValueError(
                "Access tokens cannot have connectionId - use saveConnectionToken instead"
            )

        if not self.currentUser or not self.currentUser.id:
            raise ValueError("No valid user context available for token storage")

        # Access tokens always belong to the calling user.
        token.userId = self.currentUser.id

        if not token.id:
            token.id = str(uuid.uuid4())
        if not token.createdAt:
            token.createdAt = getUtcTimestamp()

        token_dict = token.model_dump()
        token_dict["userId"] = self.currentUser.id

        # Save the new token first...
        self.db.recordCreate(Token, token_dict)

        # ...then remove superseded access tokens for this user and authority.
        if replace_existing:
            try:
                old_tokens = self.db.getRecordset(
                    Token,
                    recordFilter={
                        "userId": self.currentUser.id,
                        "authority": token.authority,
                        "connectionId": None,  # only access tokens, never connection tokens
                    },
                )
                deleted_count = 0
                for old_token in old_tokens:
                    if old_token["id"] != token.id:  # keep the token just saved
                        self.db.recordDelete(Token, old_token["id"])
                        deleted_count += 1

                if deleted_count > 0:
                    logger.info(
                        f"Replaced {deleted_count} old access tokens for user {self.currentUser.id} and authority {token.authority}"
                    )
            except Exception as e:
                logger.warning(
                    f"Failed to delete old access tokens for user {self.currentUser.id} and authority {token.authority}: {str(e)}"
                )
                # The new token is already stored; cleanup can be retried later.

    except Exception as e:
        logger.error(f"Error saving access token: {str(e)}")
        raise
|
|
|
|
def saveConnectionToken(self, token: Token, replace_existing: bool = True) -> None:
    """Persist a connection-bound token (one that HAS a connectionId).

    The token is stored under the current user's id, with ``id`` and
    ``createdAt`` filled in when missing. Once the insert has succeeded and
    ``replace_existing`` is True, every other token stored for the same
    connectionId is deleted. Failures in that cleanup are logged as warnings
    only.

    Raises:
        ValueError: If connectionId is missing or there is no current user.
        Exception: Errors from the insert are logged and re-raised.
    """
    try:
        if not token.connectionId:
            raise ValueError(
                "Connection tokens must have connectionId - use saveAccessToken instead"
            )
        if not self.currentUser or not self.currentUser.id:
            raise ValueError("No valid user context available for token storage")

        ownerId = self.currentUser.id
        token.userId = ownerId
        if not token.id:
            token.id = str(uuid.uuid4())
        if not token.createdAt:
            token.createdAt = getUtcTimestamp()

        record = token.model_dump()
        record["userId"] = ownerId
        self.db.recordCreate(Token, record)

        if not replace_existing:
            return

        # Cleanup runs only after the new token is safely stored.
        try:
            siblings = self.db.getRecordset(
                Token, recordFilter={"connectionId": token.connectionId}
            )
            removed = 0
            for row in siblings:
                if row["id"] == token.id:
                    continue
                self.db.recordDelete(Token, row["id"])
                removed += 1
            if removed:
                logger.info(
                    f"Replaced {removed} old tokens for connectionId {token.connectionId}"
                )
        except Exception as e:
            # The new token stays in place; stale ones can be removed later.
            logger.warning(
                f"Failed to delete old tokens for connectionId {token.connectionId}: {str(e)}"
            )

    except Exception as e:
        logger.error(f"Error saving connection token: {str(e)}")
        raise
|
|
|
|
def getConnectionToken(self, connectionId: str) -> Optional[Token]:
    """Return the stored token for ``connectionId`` with the latest expiry.

    No refresh is attempted; callers that need a fresh token must go through a
    higher-level refresh service. Returns None when ``connectionId`` is empty,
    when no token exists, or when any error occurs. Each case is logged.
    """
    try:
        if not connectionId:
            raise ValueError("connectionId is required for getConnectionToken")

        candidates = self.db.getRecordset(
            Token, recordFilter={"connectionId": connectionId}
        )
        if not candidates:
            logger.warning(
                f"No connection token found for connectionId: {connectionId}"
            )
            return None

        # Latest expiry first. A missing expiresAt presumably parses to the
        # default 0 and therefore sorts last.
        candidates.sort(
            key=lambda row: parseTimestamp(row.get("expiresAt"), default=0),
            reverse=True,
        )
        return Token(**candidates[0])

    except Exception as e:
        logger.error(
            f"Error getting connection token for connectionId {connectionId}: {str(e)}"
        )
        return None
|
|
|
|
def findActiveTokenById(
    self,
    tokenId: str,
    userId: str,
    authority: AuthAuthority,
    sessionId: str = None,
    mandateId: str = None,
) -> Optional[Token]:
    """Look up an ACTIVE token by its id (jti) for a user and authority.

    ``sessionId`` and ``mandateId`` narrow the match only when they are not
    None. Returns the first matching Token, or None when nothing matches or the
    query fails (the error is logged).
    """
    try:
        authorityKey = authority.value if hasattr(authority, "value") else str(authority)
        criteria = {
            "id": tokenId,
            "userId": userId,
            "authority": authorityKey,
            "status": TokenStatus.ACTIVE,
        }
        optionalScopes = {"sessionId": sessionId, "mandateId": mandateId}
        criteria.update(
            {key: value for key, value in optionalScopes.items() if value is not None}
        )
        matches = self.db.getRecordset(Token, recordFilter=criteria)
        return Token(**matches[0]) if matches else None
    except Exception as e:
        logger.error(f"Error finding active token by id {tokenId}: {str(e)}")
        return None
|
|
|
|
def revokeTokenById(self, tokenId: str, revokedBy: str, reason: str = None) -> bool:
    """Mark a single token as REVOKED; the record is updated, not deleted.

    Returns True when the token is revoked afterwards, including when it already
    was. Returns False when no token has this id or an error occurs (logged).
    """
    try:
        found = self.db.getRecordset(Token, recordFilter={"id": tokenId})
        if not found:
            return False

        # Revoking an already revoked token writes nothing.
        if found[0].get("status") != TokenStatus.REVOKED:
            self.db.recordModify(
                Token,
                tokenId,
                {
                    "status": TokenStatus.REVOKED,
                    "revokedAt": getUtcTimestamp(),
                    "revokedBy": revokedBy,
                    "reason": reason or "revoked",
                },
            )
        return True
    except Exception as e:
        logger.error(f"Error revoking token {tokenId}: {str(e)}")
        return False
|
|
|
|
def revokeTokensBySessionId(
    self,
    sessionId: str,
    userId: str,
    authority: AuthAuthority,
    revokedBy: str,
    reason: str = None,
) -> int:
    """Revoke every ACTIVE token of one session for a user/authority.

    Each token is updated independently. If updating one token fails, the error
    is logged and the remaining tokens are still revoked, so a single bad record
    cannot leave the rest of the session usable.

    Args:
        sessionId: Session whose tokens are revoked.
        userId: Owner of the tokens.
        authority: Authority of the tokens (enum or plain value).
        revokedBy: Stored on each revoked token.
        reason: Stored on each revoked token; defaults to "session logout".

    Returns:
        Number of tokens actually revoked (0 if the lookup itself fails).
    """
    try:
        tokens = self.db.getRecordset(
            Token,
            recordFilter={
                "userId": userId,
                "authority": authority.value
                if hasattr(authority, "value")
                else str(authority),
                "sessionId": sessionId,
                "status": TokenStatus.ACTIVE,
            },
        )
    except Exception as e:
        logger.error(f"Error revoking tokens for session {sessionId}: {str(e)}")
        return 0

    count = 0
    for t in tokens or []:
        try:
            self.db.recordModify(
                Token,
                t["id"],
                {
                    "status": TokenStatus.REVOKED,
                    "revokedAt": getUtcTimestamp(),
                    "revokedBy": revokedBy,
                    "reason": reason or "session logout",
                },
            )
            count += 1
        except Exception as e:
            # Keep going: the remaining tokens of the session must still be revoked.
            logger.error(
                f"Error revoking token {t.get('id')} for session {sessionId}: {str(e)}"
            )
    return count
|
|
|
|
def revokeTokensByUser(
    self,
    userId: str,
    authority: AuthAuthority = None,
    mandateId: str = None,
    revokedBy: str = None,
    reason: str = None,
) -> int:
    """Revoke all ACTIVE tokens of a user, optionally narrowed by authority/mandate.

    Args:
        userId: Owner of the tokens.
        authority: If not None, only tokens of this authority are revoked.
        mandateId: If not None, only tokens whose mandateId matches are revoked.
        revokedBy: Stored on each revoked token.
        reason: Stored on each revoked token; defaults to "admin revoke".

    Each token is updated independently. A failure on one token is logged and
    the remaining tokens are still revoked.

    Returns:
        Number of tokens actually revoked (0 if the lookup itself fails).
    """
    try:
        recordFilter = {
            "userId": userId,
            "status": TokenStatus.ACTIVE,
        }
        if authority is not None:
            recordFilter["authority"] = (
                authority.value if hasattr(authority, "value") else str(authority)
            )
        if mandateId is not None:
            # Token records carry mandateId; honour the documented mandate scope
            # instead of revoking the user's tokens in every mandate.
            recordFilter["mandateId"] = mandateId
        tokens = self.db.getRecordset(Token, recordFilter=recordFilter)
    except Exception as e:
        logger.error(f"Error revoking tokens for user {userId}: {str(e)}")
        return 0

    count = 0
    for t in tokens or []:
        try:
            self.db.recordModify(
                Token,
                t["id"],
                {
                    "status": TokenStatus.REVOKED,
                    "revokedAt": getUtcTimestamp(),
                    "revokedBy": revokedBy,
                    "reason": reason or "admin revoke",
                },
            )
            count += 1
        except Exception as e:
            # Keep going so one bad record does not leave other tokens active.
            logger.error(
                f"Error revoking token {t.get('id')} for user {userId}: {str(e)}"
            )
    return count
|
|
|
|
def cleanupExpiredTokens(self) -> int:
    """Delete every stored token whose expiresAt lies before the current time.

    Tokens are examined and deleted one by one. If one token cannot be parsed
    or deleted, the error is logged and the sweep continues, so the returned
    count reflects the tokens actually removed. Tokens whose expiresAt parses
    to a falsy value are kept.

    Returns:
        Number of tokens deleted (0 if listing the tokens fails).
    """
    try:
        current_time = getUtcTimestamp()
        all_tokens = self.db.getRecordset(Token, recordFilter={})
    except Exception as e:
        logger.error(f"Error cleaning up expired tokens: {str(e)}")
        return 0

    cleaned_count = 0
    for token_data in all_tokens or []:
        try:
            expiresAt = parseTimestamp(token_data.get("expiresAt"))
            if expiresAt and expiresAt < current_time:
                self.db.recordDelete(Token, token_data["id"])
                cleaned_count += 1
        except Exception as e:
            # One bad record must not abort the whole sweep.
            logger.error(
                f"Error deleting expired token {token_data.get('id')}: {str(e)}"
            )

    if cleaned_count > 0:
        logger.info(f"Cleaned up {cleaned_count} expired tokens")

    return cleaned_count
|
|
|
|
def logout(self) -> None:
    """Drop the user context held by this interface instance.

    Sets currentUser, userId, mandateId and rbac to None. If the instance has a
    ``db`` attribute, the connector's context is reset to "". Stored tokens are
    not touched here. Errors are logged and re-raised.

    NOTE(review): getInterface caches instances per mandate/user, so calling
    this on a cached instance also clears it for later getInterface callers.
    Confirm that this is intended.
    """
    try:
        for attribute in ("currentUser", "userId", "mandateId", "rbac"):
            setattr(self, attribute, None)

        if hasattr(self, "db"):
            self.db.updateContext("")

        logger.info("User logged out successfully")
    except Exception as e:
        logger.error(f"Error during logout: {str(e)}")
        raise
|
|
|
|
# Neutralization methods
|
|
|
|
def getNeutralizationConfig(self) -> Optional[DataNeutraliserConfig]:
    """Return the data-neutralization config of the current mandate, or None.

    Reads through the RBAC-filtered recordset for the current user, takes the
    first matching record and strips keys starting with "_" before building the
    model. Returns None when nothing matches or an error occurs (logged).
    """
    try:
        rows = self.db.getRecordsetWithRBAC(
            DataNeutraliserConfig,
            self.currentUser,
            recordFilter={"mandateId": self.mandateId},
        )
        if not rows:
            return None

        # Underscore-prefixed keys are database-internal, not model fields.
        publicFields = {
            key: value for key, value in rows[0].items() if not key.startswith("_")
        }
        return DataNeutraliserConfig(**publicFields)
    except Exception as e:
        logger.error(f"Error getting neutralization config: {str(e)}")
        return None
|
|
|
|
def createOrUpdateNeutralizationConfig(
    self, config_data: Dict[str, Any]
) -> DataNeutraliserConfig:
    """Create the current mandate's neutralization config, or update the existing one.

    ``config_data`` is copied, never mutated. On update, the stored ``id`` and
    ``mandateId`` are kept even if ``config_data`` carries other values, so an
    update can neither re-key the config nor move it to another mandate. On
    create, ``mandateId`` and ``userId`` come from the current user context.

    NOTE(review): getNeutralizationConfig also returns None when its lookup
    fails, in which case this method takes the create path.

    Raises:
        ValueError: Wrapping any validation or database error, with the original
            exception chained as ``__cause__``.
    """
    try:
        # Work on a copy so the caller's dict is left untouched.
        payload = dict(config_data or {})

        existing_config = self.getNeutralizationConfig()

        if existing_config:
            stored = existing_config.model_dump()
            update_data = {**stored, **payload}
            # Identity and tenant fields always come from the stored record.
            for pinned in ("id", "mandateId"):
                if pinned in stored:
                    update_data[pinned] = stored[pinned]
            update_data["updatedAt"] = getUtcTimestamp()

            updated_config = DataNeutraliserConfig(**update_data)
            self.db.recordModify(
                DataNeutraliserConfig, existing_config.id, updated_config
            )
            return updated_config

        # No config yet: create one bound to the current mandate and user.
        payload["mandateId"] = self.mandateId
        payload["userId"] = self.userId
        new_config = DataNeutraliserConfig(**payload)
        created_record = self.db.recordCreate(DataNeutraliserConfig, new_config)
        return DataNeutraliserConfig(**created_record)

    except Exception as e:
        logger.error(f"Error creating/updating neutralization config: {str(e)}")
        raise ValueError(
            f"Failed to create/update neutralization config: {str(e)}"
        ) from e
|
|
|
|
def getNeutralizationAttributes(
    self, file_id: Optional[str] = None
) -> List[DataNeutralizerAttributes]:
    """Return the current mandate's neutralization attributes.

    Args:
        file_id: When truthy, only attributes of this file are returned.

    Returns:
        Models built from the RBAC-filtered records with "_"-prefixed keys
        removed, or an empty list if any error occurs (logged).
    """
    try:
        criteria = {"mandateId": self.mandateId}
        if file_id:
            criteria["fileId"] = file_id

        rows = self.db.getRecordsetWithRBAC(
            DataNeutralizerAttributes,
            self.currentUser,
            recordFilter=criteria,
        )

        # Underscore-prefixed keys are database-internal, not model fields.
        visibleRows = [
            {key: value for key, value in row.items() if not key.startswith("_")}
            for row in rows
        ]
        return [DataNeutralizerAttributes(**row) for row in visibleRows]
    except Exception as e:
        logger.error(f"Error getting neutralization attributes: {str(e)}")
        return []
|
|
|
|
def deleteNeutralizationAttributes(self, file_id: str) -> bool:
    """Delete every neutralization attribute of ``file_id`` in the current mandate.

    Returns True on success, including when nothing matched. Returns False,
    after logging, if the lookup or any delete fails.

    NOTE(review): unlike getNeutralizationAttributes, the lookup here uses
    plain getRecordset (no RBAC filtering); only mandateId scopes it. Confirm
    that this is intended.
    """
    try:
        doomed = self.db.getRecordset(
            DataNeutralizerAttributes,
            recordFilter={"mandateId": self.mandateId, "fileId": file_id},
        )
        for entry in doomed:
            self.db.recordDelete(DataNeutralizerAttributes, entry["id"])

        logger.info(
            f"Deleted {len(doomed)} neutralization attributes for file {file_id}"
        )
        return True
    except Exception as e:
        logger.error(f"Error deleting neutralization attributes: {str(e)}")
        return False
|
|
|
|
# RBAC CRUD Methods
|
|
|
|
def createAccessRule(self, accessRule: AccessRule) -> AccessRule:
    """Persist a new access rule.

    Args:
        accessRule: The rule to store.

    Returns:
        The rule rebuilt from the record the database returned.

    Raises:
        Exception: Any database or validation error is logged and re-raised.
    """
    try:
        stored = self.db.recordCreate(AccessRule, accessRule)
        storedId = stored.get("id")
        logger.info(f"Created access rule with ID {storedId}")
        return AccessRule(**stored)
    except Exception as e:
        logger.error(f"Error creating access rule: {str(e)}")
        raise
|
|
|
|
def getAccessRule(self, ruleId: str) -> Optional[AccessRule]:
    """Fetch one access rule by id.

    Args:
        ruleId: Id of the rule.

    Returns:
        The first matching AccessRule, or None if none matches or the lookup
        fails (the error is logged).
    """
    try:
        matches = self.db.getRecordset(AccessRule, recordFilter={"id": ruleId})
        return AccessRule(**matches[0]) if matches else None
    except Exception as e:
        logger.error(f"Error getting access rule {ruleId}: {str(e)}")
        return None
|
|
|
|
def updateAccessRule(self, ruleId: str, accessRule: AccessRule) -> AccessRule:
    """Overwrite the stored access rule ``ruleId`` with ``accessRule``.

    Writes through ``recordModify``, the update call used by every other method
    of this interface. It then re-reads the rule so the returned object reflects
    what was persisted. If ``accessRule`` carries an ``id`` field, it is forced
    to ``ruleId`` so the update cannot re-key the record.

    Args:
        ruleId: Id of the rule to update.
        accessRule: New rule contents.

    Returns:
        The updated AccessRule as read back from the database.

    Raises:
        ValueError: If the rule cannot be read back after the update.
        Exception: Database or validation errors are logged and re-raised.
    """
    try:
        ruleData = accessRule.model_dump()
        if "id" in ruleData:
            ruleData["id"] = ruleId

        self.db.recordModify(AccessRule, ruleId, ruleData)

        # recordModify's return value is not relied on anywhere; re-read instead.
        stored = self.db.getRecordset(AccessRule, recordFilter={"id": ruleId})
        if not stored:
            raise ValueError(f"Access rule {ruleId} not found after update")

        logger.info(f"Updated access rule with ID {ruleId}")
        return AccessRule(**stored[0])
    except Exception as e:
        logger.error(f"Error updating access rule {ruleId}: {str(e)}")
        raise
|
|
|
|
def deleteAccessRule(self, ruleId: str) -> bool:
    """Delete an access rule.

    The result of ``recordDelete`` is propagated, consistent with
    deleteMandate, instead of always reporting success.

    Args:
        ruleId: Id of the rule to delete.

    Returns:
        True if the database reports the rule as deleted, False if it does not
        or an error occurs (logged).
    """
    try:
        deleted = self.db.recordDelete(AccessRule, ruleId)
        if not deleted:
            logger.warning(f"Access rule {ruleId} was not deleted")
            return False
        logger.info(f"Deleted access rule with ID {ruleId}")
        return True
    except Exception as e:
        logger.error(f"Error deleting access rule {ruleId}: {str(e)}")
        return False
|
|
|
|
def getAccessRules(
    self,
    roleLabel: Optional[str] = None,
    context: Optional[AccessRuleContext] = None,
    item: Optional[str] = None
) -> List[AccessRule]:
    """List access rules, filtered by any truthy argument.

    Args:
        roleLabel: Restrict to this role label.
        context: Restrict to this context (its ``.value`` is used as the filter).
        item: Restrict to this item.

    Returns:
        Matching AccessRule objects. When no filter applies, ``recordFilter=None``
        is passed to the database. Returns an empty list on error (logged).
    """
    try:
        # (filter key, argument deciding inclusion, value stored in the filter)
        recordFilter = {
            key: value
            for key, present, value in (
                ("roleLabel", roleLabel, roleLabel),
                ("context", context, context.value if context else None),
                ("item", item, item),
            )
            if present
        }

        rows = self.db.getRecordset(AccessRule, recordFilter=recordFilter or None)
        return [AccessRule(**row) for row in rows]
    except Exception as e:
        logger.error(f"Error getting access rules: {str(e)}")
        return []
|
|
|
|
def getAccessRulesForRoles(
    self,
    roleLabels: List[str],
    context: AccessRuleContext,
    item: str
) -> List[AccessRule]:
    """Collect, per role, the most specific access rule for ``item``.

    A fresh RbacClass over this interface's database loads each role's rules
    for ``context`` (via its private ``_getRulesForRole``) and picks the most
    specific one for ``item``. Roles without a match contribute nothing.

    Args:
        roleLabels: Role labels to evaluate, in order.
        context: Rule context to consider.
        item: Item identifier to match.

    Returns:
        One rule per role that has a match, in role order; an empty list on
        error (logged).
    """
    try:
        engine = RbacClass(self.db)
        picked = []
        for label in roleLabels:
            candidates = engine._getRulesForRole(label, context)
            best = engine.findMostSpecificRule(candidates, item)
            if best:
                picked.append(best)
        return picked
    except Exception as e:
        logger.error(f"Error getting access rules for roles: {str(e)}")
        return []
|
|
|
|
|
|
# Public Methods
|
|
|
|
|
|
def getInterface(currentUser: User) -> AppObjects:
    """Return the cached AppObjects for ``currentUser``, creating it on first use.

    Instances are cached in the module-level ``_gatewayInterfaces`` dict under
    the key "<mandateId>_<userId>". Nothing here evicts or refreshes an entry,
    so the cached instance keeps the User object passed when it was created.

    Raises:
        ValueError: If ``currentUser`` is falsy.
    """
    if not currentUser:
        raise ValueError("Invalid user context: user is required")

    contextKey = f"{currentUser.mandateId}_{currentUser.id}"
    interface = _gatewayInterfaces.get(contextKey)
    if interface is None:
        interface = AppObjects(currentUser)
        _gatewayInterfaces[contextKey] = interface
    return interface
|
|
|
|
|
|
def getRootInterface() -> AppObjects:
    """Return the module-level AppObjects bound to the initial (root) user.

    It is built lazily on the first call. A context-less AppObjects looks up
    the initial UserInDB id and its record. That record is converted to a User
    and used to construct the cached root interface.

    Raises:
        ValueError: If the initial user id or record cannot be found, or if
            construction fails for any other reason.
    """
    global _rootAppObjects

    if _rootAppObjects is not None:
        return _rootAppObjects

    try:
        # Bootstrap instance without a user, used only for the root lookup.
        bootstrap = AppObjects()

        rootId = bootstrap.getInitialId(UserInDB)
        if not rootId:
            raise ValueError("No initial user ID found in database")

        rows = bootstrap.db.getRecordset(UserInDB, recordFilter={"id": rootId})
        if not rows:
            raise ValueError("Initial user not found in database")

        _rootAppObjects = AppObjects(User(**rows[0]))
    except Exception as e:
        logger.error(f"Error getting root user: {str(e)}")
        raise ValueError(f"Failed to get root user: {str(e)}")

    return _rootAppObjects
|