2243 lines
No EOL
96 KiB
Python
2243 lines
No EOL
96 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""
|
|
Interface to Management database and AI Connectors.
|
|
Uses the PostgreSQL database connector for data access, with added language support.
|
|
"""
|
|
|
|
import os
|
|
import logging
|
|
import base64
|
|
import hashlib
|
|
import math
|
|
import mimetypes
|
|
from typing import Dict, Any, List, Optional, Union
|
|
|
|
from modules.connectors.connectorDbPostgre import DatabaseConnector, _get_cached_connector
|
|
from modules.interfaces.interfaceRbac import getRecordsetWithRBAC
|
|
from modules.security.rbac import RbacClass
|
|
from modules.datamodels.datamodelRbac import AccessRuleContext
|
|
from modules.datamodels.datamodelUam import AccessLevel
|
|
from modules.datamodels.datamodelFiles import FilePreview, FileItem, FileData
|
|
from modules.datamodels.datamodelFileFolder import FileFolder
|
|
from modules.datamodels.datamodelUtils import Prompt
|
|
from modules.datamodels.datamodelVoice import VoiceSettings
|
|
from modules.datamodels.datamodelMessaging import (
|
|
MessagingSubscription,
|
|
MessagingSubscriptionRegistration,
|
|
MessagingDelivery,
|
|
MessagingChannel
|
|
)
|
|
from modules.datamodels.datamodelUam import User, Mandate
|
|
from modules.shared.configuration import APP_CONFIG
|
|
from modules.shared.timeUtils import getUtcTimestamp
|
|
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResult
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Singleton factory for Management instances with AI service per context
|
|
_instancesManagement = {}
|
|
|
|
# Custom exceptions for file handling
|
|
class FileError(Exception):
    """Root of the file-handling exception hierarchy used by this interface."""


# NOTE(review): this name shadows the builtin FileNotFoundError inside this
# module; it is kept as-is because renaming would break importers.
class FileNotFoundError(FileError):
    """Raised when a requested file cannot be located."""


class FileStorageError(FileError):
    """Raised when storing a file fails."""


class FilePermissionError(FileError):
    """Raised when the caller lacks permission for a file operation."""


class FileDeletionError(FileError):
    """Raised when deleting a file fails."""
|
|
|
|
class ComponentObjects:
|
|
"""
|
|
Interface to Management database and AI Connectors.
|
|
Uses the PostgreSQL database connector for data access, with added language support.
|
|
"""
|
|
|
|
def __init__(self):
    """Initialize the Management interface.

    Opens the database connector and seeds standard records. All user-context
    attributes start out empty and are filled in by ``setUserContext``.
    """
    # Declare every context attribute up front: methods such as
    # checkRbacPermission read mandateId/featureInstanceId, and without these
    # defaults an instance that never received a user context would raise
    # AttributeError instead of behaving as "no context".
    self.currentUser: Optional[User] = None
    self.userId: Optional[str] = None
    self.mandateId: Optional[str] = None
    self.featureInstanceId: Optional[str] = None
    self.userLanguage = None
    self.rbac: Optional[RbacClass] = None  # RBAC interface, created in setUserContext

    self._initializeDatabase()

    # Seed standard records if they do not exist yet
    self._initRecords()
|
|
|
|
def setUserContext(self, currentUser: User, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None):
    """Bind the interface to an authenticated user and request context.

    Args:
        currentUser: The authenticated user. A falsy value leaves the current
            context untouched (only an info message is logged).
        mandateId: The mandate ID from RequestContext (X-Mandate-Id header).
        featureInstanceId: The feature instance ID from RequestContext
            (X-Feature-Instance-Id header).

    Raises:
        ValueError: If ``currentUser`` has no ``id``.
    """
    if not currentUser:
        logger.info("Initializing interface without user context")
        return

    # Validate before mutating anything so an invalid user object cannot
    # leave the interface in a half-updated state.
    if not currentUser.id:
        raise ValueError("Invalid user context: id is required")

    self.currentUser = currentUser  # Store User object directly
    self.userId = currentUser.id
    # mandateId comes from the request context, not from the user object
    self.mandateId = mandateId
    self.featureInstanceId = featureInstanceId
    self.userLanguage = currentUser.language  # Default user language

    # RBAC needs the DbApp connection for AccessRule queries
    from modules.security.rootAccess import getRootDbAppConnector
    dbApp = getRootDbAppConnector()
    self.rbac = RbacClass(self.db, dbApp=dbApp)

    # Propagate the acting user to the database connector
    self.db.updateContext(self.userId)
    logger.debug(f"User context set: userId={self.userId}")
|
|
|
|
def __del__(self):
    """Close the database connector when the interface is garbage-collected.

    Errors while closing are logged rather than raised; an exception escaping
    ``__del__`` would only be printed as a warning by the interpreter.
    """
    # NOTE(review): self.db is obtained from _get_cached_connector(); if that
    # helper shares one connector between instances, closing it here may
    # affect other live interfaces — confirm the connector's lifetime rules.
    db = getattr(self, "db", None)
    if db is None:
        return
    try:
        db.close()
    except Exception as e:
        logger.error(f"Error closing database connection: {e}")
|
|
|
|
def _initializeDatabase(self):
    """Obtain the PostgreSQL connector for the ``poweron_management`` database.

    Host, user, password and port are read from APP_CONFIG (port defaults to
    5432). Any failure is logged and then re-raised to the caller.
    """
    try:
        connectionSettings = {
            "dbHost": APP_CONFIG.get("DB_HOST", "_no_config_default_data"),
            "dbDatabase": "poweron_management",
            "dbUser": APP_CONFIG.get("DB_USER"),
            "dbPassword": APP_CONFIG.get("DB_PASSWORD_SECRET"),
            "dbPort": int(APP_CONFIG.get("DB_PORT", 5432)),
            "userId": getattr(self, "userId", None),
        }
        self.db = _get_cached_connector(**connectionSettings)
        logger.info("Database initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize database: {str(e)}")
        raise
|
|
|
|
def _separateObjectFields(self, model_class, data: Dict[str, Any]) -> tuple[Dict[str, Any], Dict[str, Any]]:
    """Split ``data`` into fields to store and fields to leave out.

    A field goes into the first ("simple") dict when it is:
      * declared on ``model_class`` with a dict/list annotation (stored as
        JSONB), also when wrapped in ``Optional[...]``, ``Union[...]`` or
        ``Annotated[...]``;
      * a scalar value (str, int, float, bool or None);
      * an undeclared metadata key starting with ``_`` (e.g. ``_createdBy``),
        which the connector handles itself.
    Everything else (nested objects, undeclared containers) goes into the
    second dict.

    Args:
        model_class: Pydantic v2 model class; only ``model_fields`` is read.
        data: Field name -> value mapping.

    Returns:
        ``(simpleFields, objectFields)``.
    """
    import types
    import typing

    scalarTypes = (str, int, float, bool, type(None))
    # PEP 604 unions (``X | None``) have their own origin type on Python 3.10+.
    unionOrigins = (typing.Union, getattr(types, "UnionType", typing.Union))

    def _isJsonContainer(annotation) -> bool:
        """True if the annotation is (or wraps) a dict/list type."""
        if annotation in (dict, list):
            return True
        origin = typing.get_origin(annotation)
        if origin in (dict, list):
            return True
        if origin is typing.Annotated:
            return _isJsonContainer(typing.get_args(annotation)[0])
        if origin in unionOrigins:
            # Optional[Dict[...]] is Union[Dict[...], None]: look inside it.
            return any(
                _isJsonContainer(arg)
                for arg in typing.get_args(annotation)
                if arg is not type(None)
            )
        return False

    modelFields = model_class.model_fields
    simpleFields: Dict[str, Any] = {}
    objectFields: Dict[str, Any] = {}

    for fieldName, value in data.items():
        if fieldName in modelFields:
            keep = _isJsonContainer(modelFields[fieldName].annotation) or isinstance(value, scalarTypes)
        else:
            # Metadata fields (_createdBy, _createdAt, ...) pass through to the connector
            keep = fieldName.startswith("_") or isinstance(value, scalarTypes)
        (simpleFields if keep else objectFields)[fieldName] = value

    return simpleFields, objectFields
|
|
|
|
def _initRecords(self):
    """Seed standard records without ever failing construction.

    Failures are logged and swallowed on purpose so the interface can still be
    created when seeding is not possible.
    """
    try:
        self._initializeStandardPrompts()
        # Further record initializations belong here.
    except Exception as e:
        logger.error(f"Failed to initialize standard records: {str(e)}")
        return
    logger.info("Standard records initialized successfully")
|
|
|
|
def _initializeStandardPrompts(self):
    """Seed the default prompt set when the Prompt table is still empty.

    The prompts are created while the interface acts as the ``admin`` user
    (looked up through the root DbApp interface). They are assigned to the
    initial mandate. The previously active context is snapshotted first and
    always restored afterwards, also when seeding fails. That context covers
    user, user id, mandate, feature instance, language and RBAC object.
    Errors are logged, never raised.
    """
    try:
        if self.db.getRecordset(Prompt):
            logger.info("Prompts already exist, skipping initialization")
            return

        # The root interface provides the initial mandate and the admin user
        from modules.security.rootAccess import getRootUser
        from modules.interfaces.interfaceDbApp import getInterface

        rootInterface = getInterface(getRootUser())

        mandateId = rootInterface.getInitialId(Mandate)
        if not mandateId:
            logger.error("No initial mandate ID found")
            return

        adminUser = rootInterface.getUserByUsername("admin")
        if not adminUser:
            logger.error("Root user not found for initialization")
            return
    except Exception as e:
        # No context has been switched yet, so there is nothing to restore.
        logger.error(f"Error initializing standard prompts: {str(e)}")
        return

    # (name, content) pairs of the prompts seeded into the initial mandate.
    # NOTE(review): these are not flagged isSystem. They would presumably be
    # owned by the admin, so non-admin users likely cannot see them under
    # _getPromptsForUser's visibility rules. Confirm this is intended.
    standardPromptDefinitions = [
        (
            "Market Research",
            "Research the current market trends and developments in [TOPIC]. Collect information about leading companies, innovative products or services, and current challenges. Present the results in a structured overview with relevant data and sources.",
        ),
        (
            "Data Analysis",
            "Analyze the attached dataset on [TOPIC] and identify the most important trends, patterns, and anomalies. Perform statistical calculations to support your findings. Present the results in a clearly structured analysis and draw relevant conclusions.",
        ),
        (
            "Meeting Protocol",
            "Create a detailed protocol of our meeting on [TOPIC]. Capture all discussed points, decisions made, and agreed measures. Structure the protocol clearly with agenda items, participant list, and clear responsibilities for follow-up actions.",
        ),
        (
            "UI/UX Design",
            "Develop a UI/UX design concept for [APPLICATION/WEBSITE]. Consider the target audience, main functions, and brand identity. Describe the visual design, navigation, interaction patterns, and information architecture. Explain how the design optimizes user-friendliness and user experience.",
        ),
        (
            "Primzahlen",
            "Gib mir die ersten 1000 Primzahlen.",
        ),
        (
            "E-Mail",
            "Bereite mir eine formelle E-Mail an peter.muster@domain.com vor, um meinen Termin von 10 Uhr auf Freitag zu verschieben.",
        ),
    ]

    # Snapshot the complete context so it can be restored exactly afterwards.
    contextAttributes = ("currentUser", "userId", "mandateId", "featureInstanceId", "userLanguage", "rbac")
    savedContext = {name: getattr(self, name, None) for name in contextAttributes}

    try:
        self.setUserContext(adminUser)
        standardPrompts = [
            Prompt(name=name, content=content, mandateId=mandateId)
            for name, content in standardPromptDefinitions
        ]
        for prompt in standardPrompts:
            self.db.recordCreate(Prompt, prompt)
            logger.info(f"Created standard prompt: {prompt.name}")
    except Exception as e:
        logger.error(f"Error initializing standard prompts: {str(e)}")
    finally:
        previousUser = savedContext["currentUser"]
        if previousUser:
            # Re-apply the caller's mandate/feature instance, not just the user
            self.setUserContext(previousUser, savedContext["mandateId"], savedContext["featureInstanceId"])
        else:
            for name, value in savedContext.items():
                setattr(self, name, value)
            self.db.updateContext("")  # Reset database context
|
|
|
|
|
|
def checkRbacPermission(
    self,
    modelClass: type,
    operation: str,
    recordId: Optional[str] = None
) -> bool:
    """
    Decide whether the current user may perform ``operation`` on a table.

    Args:
        modelClass: Pydantic model class; its ``__name__`` is the table name.
        operation: One of 'create', 'update', 'delete', 'read'. Any other
            value is denied.
        recordId: Accepted for API compatibility; not used by the check.

    Returns:
        True when the user's access level for the operation is anything other
        than ``AccessLevel.NONE``. False when no RBAC or user context is set.
    """
    if not (self.rbac and self.currentUser):
        return False

    tableName = modelClass.__name__
    from modules.interfaces.interfaceRbac import buildDataObjectKey

    permissions = self.rbac.getUserPermissions(
        self.currentUser,
        AccessRuleContext.DATA,
        buildDataObjectKey(tableName),
        mandateId=self.mandateId,
        featureInstanceId=self.featureInstanceId
    )

    if operation not in ("create", "update", "delete", "read"):
        return False
    # The permissions object exposes one access level per operation name.
    return getattr(permissions, operation) != AccessLevel.NONE
|
|
|
|
def _applyFilters(self, records: List[Dict[str, Any]], filters: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Return the records that satisfy every filter criterion.

    Supports:
    - General search: {"search": "text"}. This is a case-insensitive substring
      match against each string field, or against str() of int/float fields.
    - Field-specific filters (the record must contain the field):
        - Simple: {"status": "running"} - case-insensitive equals match
        - With operator: {"status": {"operator": "equals", "value": "running"}}
        - Operators:
          - equals/eq, contains, startsWith, endsWith: case-insensitive on
            str() values.
          - gt, gte, lt, lte: numeric comparison. None counts as -inf for
            gt/gte and +inf for lt/lte.
          - in, notIn: exact membership; a scalar value is treated as a
            one-item list.
          - An unknown operator falls back to the "equals" comparison.
    A value that cannot be compared (e.g. non-numeric for gt) is a non-match.

    Args:
        records: List of record dictionaries to filter (not modified)
        filters: Filter criteria dictionary

    Returns:
        A new list of the matching records in their original order, or
        ``records`` itself when either argument is empty.
    """
    if not filters or not records:
        return records

    def _text(value: Any) -> str:
        return str(value).lower() if value is not None else ""

    def _number(value: Any, default: float) -> float:
        return float(value) if value is not None else default

    def _asList(value: Any) -> list:
        return value if isinstance(value, list) else [value]

    def _equals(recordValue: Any, filterValue: Any) -> bool:
        return str(recordValue).lower() == str(filterValue).lower()

    negInf, posInf = float("-inf"), float("inf")
    comparators = {
        "equals": _equals,
        "eq": _equals,
        "contains": lambda r, f: _text(f) in _text(r),
        "startsWith": lambda r, f: _text(r).startswith(_text(f)),
        "endsWith": lambda r, f: _text(r).endswith(_text(f)),
        "gt": lambda r, f: _number(r, negInf) > _number(f, negInf),
        "gte": lambda r, f: _number(r, negInf) >= _number(f, negInf),
        "lt": lambda r, f: _number(r, posInf) < _number(f, posInf),
        "lte": lambda r, f: _number(r, posInf) <= _number(f, posInf),
        "in": lambda r, f: r in _asList(f),
        "notIn": lambda r, f: r not in _asList(f),
    }

    # Both are the same for every record, so compute them once.
    searchTerm = str(filters["search"]).lower() if "search" in filters else ""
    fieldFilters = [(name, spec) for name, spec in filters.items() if name != "search"]

    def _matchesSearch(record: Dict[str, Any]) -> bool:
        if not searchTerm:
            return True
        for value in record.values():
            if isinstance(value, str):
                if searchTerm in value.lower():
                    return True
            elif isinstance(value, (int, float)) and searchTerm in str(value):
                return True
        return False

    def _matchesField(record: Dict[str, Any], fieldName: str, spec: Any) -> bool:
        if fieldName not in record:
            return False
        recordValue = record[fieldName]
        if not isinstance(spec, dict):
            return _equals(recordValue, spec)
        try:
            compare = comparators.get(spec.get("operator", "equals"), _equals)
            return compare(recordValue, spec.get("value"))
        except (ValueError, TypeError):
            return False

    return [
        record
        for record in records
        if _matchesSearch(record)
        and all(_matchesField(record, name, spec) for name, spec in fieldFilters)
    ]
|
|
|
|
def _applySorting(self, records: List[Dict[str, Any]], sortFields: List[Any]) -> List[Dict[str, Any]]:
    """Return a new list of ``records`` sorted by one or more fields.

    ``sortFields`` is ordered from most to least significant. Each entry is
    a dict or an object with ``field`` and an optional ``direction``
    ("desc" means descending; anything else means ascending). Entries
    without a field are ignored.

    Ordering per field:
    - Records whose value is None (or missing) always come last, in both
      directions.
    - Numbers (bools included) sort before strings. Any other value is
      compared by its ``str()`` together with the strings, so columns with
      mixed types never raise TypeError.
    - Ties keep their previous relative order (Python's sort is stable).
      Sorting from least to most significant field therefore yields a
      proper multi-level sort.

    Returns ``records`` unchanged when ``sortFields`` is empty.
    """
    if not sortFields:
        return records

    def _orderKey(value: Any) -> tuple:
        # Rank first so numbers and strings are never compared directly.
        if isinstance(value, (int, float)):
            return (0, value)
        return (1, value if isinstance(value, str) else str(value))

    # Work on a copy to avoid modifying the caller's list
    sortedRecords = list(records)

    for sortField in reversed(sortFields):
        # Handle both dict and object formats
        if isinstance(sortField, dict):
            fieldName = sortField.get("field")
            direction = sortField.get("direction", "asc")
        else:
            fieldName = getattr(sortField, "field", None)
            direction = getattr(sortField, "direction", "asc")

        if not fieldName:
            continue

        # Sort only records that have a value, then append the None/missing
        # ones, so reverse=True cannot move them to the front.
        present = [r for r in sortedRecords if r.get(fieldName) is not None]
        missing = [r for r in sortedRecords if r.get(fieldName) is None]
        present.sort(key=lambda r, f=fieldName: _orderKey(r.get(f)), reverse=(direction == "desc"))
        sortedRecords = present + missing

    return sortedRecords
|
|
|
|
# Utilities
|
|
|
|
def getInitialId(self, model_class: type) -> Optional[str]:
    """Return the initial ID recorded for ``model_class``'s table.

    Pure delegation to the database connector's ``getInitialId``.
    """
    connector = self.db
    return connector.getInitialId(model_class)
|
|
|
|
def _parse_size_string(self, size_str: str) -> Optional[int]:
    """
    Parse a formatted size string (e.g. "2.13 MB", "1.5 GB") to bytes.

    Units are case-insensitive and binary (1 KB = 1024 bytes). Accepted unit
    spellings:
      - none or "B";
      - a bare prefix ("K", "M", "G", "T");
      - the prefix plus "B" ("MB");
      - the IEC form ("MiB").
    Spaces and commas are removed first, so commas act as thousands
    separators ("1,024 KB" -> 1048576).

    Args:
        size_str: Formatted size string like "2.13 MB", "1.5 GB", "500 KB"

    Returns:
        Size in bytes (truncated to int), or None if parsing fails
    """
    import re

    try:
        normalized = size_str.strip().upper().replace(",", "").replace(" ", "")
    except (AttributeError, TypeError):
        # Not a str (e.g. None or bytes)
        return None

    # number, then an optional K/M/G/T prefix (optionally IEC "I"), then optional "B"
    match = re.fullmatch(r"([\d.]+)(?:([KMGT])I?)?B?", normalized)
    if match is None:
        return None

    exponent = {"K": 1, "M": 2, "G": 3, "T": 4}.get(match.group(2) or "", 0)
    try:
        return int(float(match.group(1)) * 1024 ** exponent)
    except (ValueError, OverflowError):
        # Malformed number such as "1.2.3", or a value too large to represent
        return None
|
|
|
|
|
|
|
|
# Prompt methods
|
|
|
|
def _isSysAdmin(self) -> bool:
    """Return True if the current user is a system administrator.

    The sysadmin role assignment (``_hasSysAdminRole``) is checked first. The
    user's ``isSysAdmin`` attribute is only a fallback. Without a user
    context the result is False.
    """
    from modules.auth.authentication import _hasSysAdminRole

    userId = getattr(self.currentUser, 'id', None)
    if userId and _hasSysAdminRole(str(userId)):
        return True
    # Coerce to a real bool: the attribute may be None, and callers put this
    # value straight into the _permissions flags of prompt records.
    return bool(getattr(self.currentUser, 'isSysAdmin', False))
|
|
|
|
def _enrichPromptsWithPermissions(self, prompts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Attach row-level ``_permissions`` to each prompt dict, in place.

    - SysAdmin: canUpdate=True, canDelete=True on all prompts
    - Regular user on own prompts (``_createdBy`` == userId): canUpdate=True, canDelete=True
    - Regular user on anything else, including system prompts: read-only

    Without a user id nobody counts as owner. Otherwise ``None == None``
    would make records lacking ``_createdBy`` editable by an anonymous context.
    Both flags are always real booleans.

    Returns:
        The same list, mutated in place.
    """
    isSysAdmin = bool(self._isSysAdmin())
    for prompt in prompts:
        isOwner = bool(self.userId) and prompt.get("_createdBy") == self.userId
        canModify = isSysAdmin or isOwner
        prompt["_permissions"] = {
            "canUpdate": canModify,
            "canDelete": canModify
        }
    return prompts
|
|
|
|
def _getPromptsForUser(self) -> List[Dict[str, Any]]:
    """Return the prompt records visible to the current user.

    Visibility rules:
    - SysAdmin: ALL prompts
    - Regular user: own prompts (_createdBy) + system prompts (isSystem=True),
      de-duplicated by id with the user's own record taking precedence.
      Without a user id only system prompts are returned.
    """
    if self._isSysAdmin():
        return self.db.getRecordset(Prompt)

    # Skip the ownership query for an anonymous context: a _createdBy=None
    # filter would presumably match records that have no owner.
    ownPrompts = self.db.getRecordset(Prompt, recordFilter={"_createdBy": self.userId}) if self.userId else []
    systemPrompts = self.db.getRecordset(Prompt, recordFilter={"isSystem": True})

    # A user's own prompt could also be flagged isSystem
    visible = {p["id"]: p for p in ownPrompts}
    for p in systemPrompts:
        visible.setdefault(p["id"], p)

    return list(visible.values())
|
|
|
|
def getAllPrompts(self, pagination: Optional[PaginationParams] = None) -> Union[List[Prompt], PaginatedResult]:
    """
    Return the prompts visible to the current user, optionally paginated.

    Visibility follows _getPromptsForUser:
      - SysAdmin: all prompts.
      - Everyone else: own prompts plus system prompts.
    Each row carries ``_permissions`` from _enrichPromptsWithPermissions, so
    the UI can enable or disable edit/delete buttons.

    NOTE: Filtering, sorting and paging happen in memory instead of through
    db.getRecordsetPaginated(). The visibility rules and the per-row
    permission enrichment both need the full record set first.

    Args:
        pagination: Optional pagination/filter/sort parameters. When None,
            every visible prompt is returned unfiltered and unsorted.

    Returns:
        List[Prompt] when pagination is None, otherwise a PaginatedResult.
        On error the problem is logged and an empty result of the matching
        shape is returned.
    """
    try:
        visiblePrompts = self._enrichPromptsWithPermissions(self._getPromptsForUser())

        if pagination is None:
            return [Prompt(**record) for record in visiblePrompts]

        if pagination.filters:
            visiblePrompts = self._applyFilters(visiblePrompts, pagination.filters)
        if pagination.sort:
            visiblePrompts = self._applySorting(visiblePrompts, pagination.sort)

        pageSize = pagination.pageSize
        totalItems = len(visiblePrompts)
        totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0
        offset = (pagination.page - 1) * pageSize

        # extra='allow' on Prompt keeps system fields such as _permissions
        pageItems = [Prompt(**record) for record in visiblePrompts[offset:offset + pageSize]]

        return PaginatedResult(items=pageItems, totalItems=totalItems, totalPages=totalPages)
    except Exception as e:
        logger.error(f"Error getting prompts: {str(e)}")
        return [] if pagination is None else PaginatedResult(items=[], totalItems=0, totalPages=0)
|
|
|
|
def getPrompt(self, promptId: str) -> Optional[Prompt]:
    """Return a prompt by ID if the current user may see it.

    Visibility:
      - SysAdmin: sees all prompts.
      - Regular user: sees own prompts (``_createdBy`` equals their user id)
        and system prompts (``isSystem``).
    Returns None when the record does not exist or is not visible.
    """
    records = self.db.getRecordset(Prompt, recordFilter={"id": promptId})
    if not records:
        return None

    record = records[0]

    if not self._isSysAdmin():
        # Require a real user id; otherwise None == None would make an
        # anonymous context the owner of records without _createdBy.
        isOwner = bool(self.userId) and record.get("_createdBy") == self.userId
        if not isOwner and not record.get("isSystem", False):
            return None

    return Prompt(**record)
|
|
|
|
def createPrompt(self, promptData: Dict[str, Any]) -> Dict[str, Any]:
    """Create a new prompt if the user has RBAC 'create' permission.

    Only SysAdmins may create system prompts. For everyone else an
    ``isSystem`` key is dropped, as updatePrompt already does, because a
    system prompt becomes visible to all users. The caller's dict is not
    mutated.

    Returns:
        The created record as returned by the connector.

    Raises:
        PermissionError: The user has no create permission for prompts.
        ValueError: The connector did not return a record with an id.
    """
    if not self.checkRbacPermission(Prompt, "create"):
        raise PermissionError("No permission to create prompts")

    if isinstance(promptData, dict) and "isSystem" in promptData and not self._isSysAdmin():
        promptData = {key: value for key, value in promptData.items() if key != "isSystem"}

    createdRecord = self.db.recordCreate(Prompt, promptData)
    if not createdRecord or not createdRecord.get("id"):
        raise ValueError("Failed to create prompt record")

    return createdRecord
|
|
|
|
def updatePrompt(self, promptId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
    """Update a prompt and return the refreshed record as a dict.

    Rules:
    - SysAdmin: can update any prompt (including system prompts) and may
      change the ``isSystem`` flag.
    - Regular user: can only update own prompts. An ``isSystem`` key in
      ``updateData`` is ignored.
    The caller's ``updateData`` dict is never mutated.

    Raises:
        PermissionError: The prompt is visible but not owned by the user.
        ValueError: The prompt is missing or hidden, or the update failed.
    """
    try:
        prompt = self.getPrompt(promptId)  # visibility-checked
        if not prompt:
            raise ValueError(f"Prompt {promptId} not found")

        isSysAdmin = self._isSysAdmin()
        # A real user id is required, so None == None never counts as ownership
        isOwner = bool(self.userId) and getattr(prompt, '_createdBy', None) == self.userId
        if not isSysAdmin and not isOwner:
            raise PermissionError(f"No permission to update prompt {promptId}")

        changes = dict(updateData)
        if not isSysAdmin:
            # Regular users cannot set the isSystem flag
            changes.pop('isSystem', None)

        self.db.recordModify(Prompt, promptId, changes)

        updatedPrompt = self.getPrompt(promptId)
        if not updatedPrompt:
            raise ValueError("Failed to retrieve updated prompt")

        return updatedPrompt.model_dump()

    except PermissionError:
        raise
    except Exception as e:
        logger.error(f"Error updating prompt: {str(e)}")
        raise ValueError(f"Failed to update prompt: {str(e)}") from e
|
|
|
|
def deletePrompt(self, promptId: str) -> bool:
    """Delete a prompt.

    Rules:
    - SysAdmin: can delete any prompt (including system prompts)
    - Regular user: can only delete own prompts (not system prompts)

    Returns:
        False if the prompt does not exist or is not visible, otherwise the
        result of the connector's ``recordDelete``.

    Raises:
        PermissionError: The prompt is visible but not owned by the user.
    """
    prompt = self.getPrompt(promptId)  # visibility-checked
    if not prompt:
        return False

    # A real user id is required, so None == None never counts as ownership
    isOwner = bool(self.userId) and getattr(prompt, '_createdBy', None) == self.userId
    if not isOwner and not self._isSysAdmin():
        raise PermissionError(f"No permission to delete prompt {promptId}")

    return self.db.recordDelete(Prompt, promptId)
|
|
|
|
# File Utilities
|
|
|
|
def checkForDuplicateFile(self, fileHash: str, fileName: str) -> Optional[FileItem]:
    """Return the current user's existing file with this hash and name, if any.

    A duplicate means the same owner (_createdBy), the same fileHash and the
    same fileName. Identical content under a different name is treated as a
    deliberate copy and is not reported. The lookup queries the database
    directly, without RBAC, because files are isolated per user. Returns
    None without a user context or when nothing matches.
    """
    if not self.userId:
        return None

    duplicates = self.db.getRecordset(
        FileItem,
        recordFilter={
            "_createdBy": self.userId,
            "fileHash": fileHash,
            "fileName": fileName,
        },
    )
    if not duplicates:
        return None

    first = duplicates[0]
    return FileItem(
        id=first["id"],
        mandateId=first.get("mandateId", ""),
        featureInstanceId=first.get("featureInstanceId", ""),
        fileName=first["fileName"],
        mimeType=first["mimeType"],
        fileHash=first["fileHash"],
        fileSize=first["fileSize"],
        creationDate=first["creationDate"],
    )
|
|
|
|
def getMimeType(self, fileName: str) -> str:
    """Map a file name's extension (case-insensitive) to a MIME type.

    Only the last extension is considered (``archive.tar.gz`` -> ``gz``).
    Unknown or missing extensions yield ``application/octet-stream``.
    """
    _, suffix = os.path.splitext(fileName)
    extension = suffix[1:].lower()

    mimeByExtension = {
        "pdf": "application/pdf",
        "docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "doc": "application/msword",
        "xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        "xls": "application/vnd.ms-excel",
        "pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
        "ppt": "application/vnd.ms-powerpoint",
        "csv": "text/csv",
        "txt": "text/plain",
        "json": "application/json",
        "xml": "application/xml",
        "html": "text/html",
        "htm": "text/html",
        "jpg": "image/jpeg",
        "jpeg": "image/jpeg",
        "png": "image/png",
        "gif": "image/gif",
        "webp": "image/webp",
        "svg": "image/svg+xml",
        "py": "text/x-python",
        "js": "application/javascript",
        "css": "text/css",
        "eml": "message/rfc822",
        "msg": "application/vnd.ms-outlook",
    }

    if extension in mimeByExtension:
        return mimeByExtension[extension]
    return "application/octet-stream"
|
|
|
|
def isTextMimeType(self, mimeType: str) -> bool:
    """Return True if ``mimeType`` denotes a text-based format.

    Matching is case-insensitive and ignores MIME parameters, so
    ``"text/plain; charset=utf-8"`` counts as text. Every ``text/*`` type is
    text. Selected ``application/*`` types are text as well: JSON, XML,
    JavaScript, and scripting/markup formats. Empty or missing values are
    not text.
    """
    if not mimeType:
        return False

    # Drop parameters such as "; charset=utf-8" before comparing
    essence = mimeType.split(";", 1)[0].strip().lower()
    if essence.startswith("text/"):
        return True

    textApplicationTypes = {
        'application/json',
        'application/xml',
        'application/javascript',
        'application/x-python',
        'application/x-sh',
        'application/x-shellscript',
        'application/x-yaml',
        'application/x-toml',
        'application/x-markdown',
        'application/x-latex',
        'application/x-tex',
        'application/x-rst',
        'application/x-asciidoc',
        'application/x-httpd-php',
        'application/x-httpd-php-source',
        'application/x-httpd-php3',
        'application/x-httpd-php4',
        'application/x-httpd-php5',
        'application/x-httpd-php7',
        'application/x-httpd-php8',
        'application/x-httpd-php3-source',
        'application/x-httpd-php4-source',
        'application/x-httpd-php5-source',
        'application/x-httpd-php7-source',
        'application/x-httpd-php8-source',
    }
    return essence in textApplicationTypes
|
|
|
|
# File methods - metadata-based operations
|
|
|
|
def _getFilesByCurrentUser(self, recordFilter: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
    """Return the raw file records owned by the current user.

    Files are always user-scoped, regardless of role (SysAdmin included), so
    RBAC is bypassed intentionally. ``recordFilter`` narrows the query
    further but can never override the ``_createdBy`` scope. Without a user
    context an empty list is returned.
    """
    if not self.userId:
        return []
    # Apply the ownership constraint last so a caller-supplied "_createdBy"
    # key cannot widen the query to another user's files.
    scopedFilter = {**(recordFilter or {}), "_createdBy": self.userId}
    return self.db.getRecordset(FileItem, recordFilter=scopedFilter)
|
|
|
|
def getAllFiles(self, pagination: Optional[PaginationParams] = None) -> Union[List[FileItem], PaginatedResult]:
    """
    Return the current user's files (user-scoped, not RBAC-based).

    Every user, SysAdmin included, sees only the files they created.
    Conversion rules for each record:
      - A missing, non-numeric or non-positive creationDate is replaced by
        the current UTC timestamp.
      - Records without a usable fileName are skipped.
      - Records that fail FileItem construction are skipped with a warning.

    Args:
        pagination: Optional pagination parameters. When None, all files are
            returned. Otherwise filtering, sorting and LIMIT/OFFSET run in
            the database via getRecordsetPaginated.

    Returns:
        List[FileItem] when pagination is None, otherwise a PaginatedResult
        whose totals come from the database query.
    """
    def _toFileItems(records):
        converted = []
        for record in records:
            try:
                stamp = record.get("creationDate")
                needsStamp = stamp is None or not isinstance(stamp, (int, float)) or stamp <= 0
                if needsStamp:
                    record["creationDate"] = getUtcTimestamp()

                name = record.get("fileName")
                if not name or name == "None":
                    continue

                converted.append(FileItem(**record))
            except Exception as e:
                logger.warning(f"Skipping invalid file record: {str(e)}")
        return converted

    if pagination is None:
        return _toFileItems(self._getFilesByCurrentUser())

    # Ownership filter: deliberately bypasses the RBAC SysAdmin override
    page = self.db.getRecordsetPaginated(
        FileItem,
        pagination=pagination,
        recordFilter={"_createdBy": self.userId},
    )

    return PaginatedResult(
        items=_toFileItems(page["items"]),
        totalItems=page["totalItems"],
        totalPages=page["totalPages"],
    )
|
|
|
|
def getFile(self, fileId: str) -> Optional[FileItem]:
    """Return the file with ``fileId`` if it belongs to the current user.

    Files are always user-scoped (filtered by ``_createdBy``), bypassing the
    RBAC SysAdmin override.

    Args:
        fileId: ID of the FileItem record.

    Returns:
        The FileItem, or None when it does not exist, belongs to another user,
        or the stored record cannot be converted (logged).
    """
    filteredFiles = self._getFilesByCurrentUser(recordFilter={"id": fileId})
    if not filteredFiles:
        return None

    # Copy so the timestamp normalisation does not mutate the DB layer's record
    record = dict(filteredFiles[0])
    try:
        # Same creationDate normalisation as getAllFiles: fall back to "now"
        creationDate = record.get("creationDate")
        if creationDate is None or not isinstance(creationDate, (int, float)) or creationDate <= 0:
            record["creationDate"] = getUtcTimestamp()

        # Build from the full record (as getAllFiles does). The previous
        # hand-picked field list silently dropped fields such as folderId,
        # which callers like copyFile rely on.
        return FileItem(**record)
    except Exception as e:
        logger.error(f"Error converting file record: {str(e)}")
        return None
|
|
|
|
def _isfileNameUnique(self, fileName: str, excludeFileId: Optional[str] = None) -> bool:
    """Check whether ``fileName`` is unused among the current user's files.

    Files are user-scoped everywhere else in this interface, so only the
    current user's files are considered. The previous RBAC/mandate-based
    lookup could include other users' files for elevated roles and report
    false conflicts. It also loaded every visible file to compare in Python.

    Args:
        fileName: Candidate name (compared exactly, case-sensitive).
        excludeFileId: ID of a file to ignore, e.g. the file being renamed.

    Returns:
        True if no other file of the current user carries this name.
    """
    candidates = self._getFilesByCurrentUser(recordFilter={"fileName": fileName})
    for file in candidates:
        # Skip records without a name; re-check equality in Python regardless
        # of how the DB layer compares values.
        if not file.get("fileName"):
            continue
        if file["fileName"] == fileName and (excludeFileId is None or file.get("id") != excludeFileId):
            return False
    return True
|
|
|
|
def _generateUniquefileName(self, fileName: str, excludeFileId: Optional[str] = None) -> str:
    """Return ``fileName`` or, if taken, the first free ``<stem>_<n><ext>`` variant.

    Args:
        fileName: Desired file name.
        excludeFileId: File to ignore in the uniqueness check (used when renaming).

    Returns:
        ``fileName`` itself when ``_isfileNameUnique`` accepts it, otherwise the
        first of ``stem_1.ext``, ``stem_2.ext``, ... that it accepts.
    """
    if self._isfileNameUnique(fileName, excludeFileId):
        return fileName

    # The counter goes before the extension: "report.pdf" -> "report_1.pdf"
    stem, extension = os.path.splitext(fileName)
    suffix = 1
    candidate = f"{stem}_{suffix}{extension}"
    while not self._isfileNameUnique(candidate, excludeFileId):
        suffix += 1
        candidate = f"{stem}_{suffix}{extension}"
    return candidate
|
|
|
|
def createFile(self, name: str, mimeType: str, content: bytes) -> FileItem:
    """Create the FileItem metadata record for ``content`` (binary data is stored separately).

    ``fileSize`` and the SHA-256 ``fileHash`` are derived from ``content``.
    When ``checkForDuplicateFile`` finds an existing file for this hash and
    name, that file is returned instead of creating a new one. The same
    content under a different name is allowed (an intentional copy). The
    stored name is made unique via ``_generateUniquefileName``.

    Raises:
        PermissionError: if RBAC denies "create" on FileItem.
    """
    if not self.checkRbacPermission(FileItem, "create"):
        raise PermissionError("No permission to create files")

    byteCount = len(content)
    digest = hashlib.sha256(content).hexdigest()

    duplicate = self.checkForDuplicateFile(digest, name)
    if duplicate:
        logger.info(f"Duplicate file detected in createFile: '{name}' (hash={digest[:12]}...) for user {self.userId} — returning existing file {duplicate.id}")
        return duplicate

    storedName = self._generateUniquefileName(name)

    # Context IDs isolate the data; None becomes "" to satisfy Pydantic validation
    newItem = FileItem(
        mandateId=self.mandateId or "",
        featureInstanceId=self.featureInstanceId or "",
        fileName=storedName,
        mimeType=mimeType,
        fileSize=byteCount,
        fileHash=digest,
    )
    self.db.recordCreate(FileItem, newItem)
    return newItem
|
|
|
|
def updateFile(self, fileId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
    """Apply ``updateData`` to the metadata of one of the current user's files.

    A new ``fileName`` is made unique for the user, with the file itself
    excluded from the check. ``updateData`` is modified in place when it
    contains ``fileName``.

    Returns:
        Whatever ``self.db.recordModify`` returns.

    Raises:
        FileNotFoundError: if the file does not exist or is not owned by the user.
        PermissionError: if RBAC denies "update" on the file.
    """
    if not self.getFile(fileId):
        raise FileNotFoundError(f"File with ID {fileId} not found")

    if not self.checkRbacPermission(FileItem, "update", fileId):
        raise PermissionError(f"No permission to update file {fileId}")

    if "fileName" in updateData:
        requestedName = updateData["fileName"]
        updateData["fileName"] = self._generateUniquefileName(requestedName, fileId)

    return self.db.recordModify(FileItem, fileId, updateData)
|
|
|
|
def deleteFile(self, fileId: str) -> bool:
    """Delete one of the current user's files together with its binary data.

    Args:
        fileId: ID of the FileItem to delete.

    Returns:
        The result of ``self.db.recordDelete`` for the FileItem.

    Raises:
        FileNotFoundError: if the file does not exist or is not owned by the user.
        PermissionError: if the RBAC check fails.
        FileDeletionError: for any other failure.
    """
    try:
        file = self.getFile(fileId)
        if not file:
            raise FileNotFoundError(f"File with ID {fileId} not found")

        # NOTE(review): deletion is gated on the "update" permission; confirm
        # whether a dedicated "delete" check is intended.
        if not self.checkRbacPermission(FileItem, "update", fileId):
            raise PermissionError(f"No permission to delete file {fileId}")

        # FileData is stored per file: its id is the FileItem id (see
        # createFileData / copyFile), so it is never shared with other files
        # that happen to have the same hash. Always remove it. Skipping it when
        # another file had the same hash left orphaned FileData rows behind.
        try:
            if self.db.getRecordset(FileData, recordFilter={"id": fileId}):
                self.db.recordDelete(FileData, fileId)
                logger.debug(f"FileData for file {fileId} deleted")
        except Exception as e:
            # Best effort: the metadata record is still removed below
            logger.warning(f"Error deleting FileData for file {fileId}: {str(e)}")

        return self.db.recordDelete(FileItem, fileId)

    except (FileNotFoundError, FilePermissionError, PermissionError):
        # Propagate "not found" / "not allowed" unchanged. Previously the
        # builtin PermissionError raised above was wrapped as FileDeletionError.
        raise
    except Exception as e:
        logger.error(f"Error deleting file {fileId}: {str(e)}")
        raise FileDeletionError(f"Error deleting file: {str(e)}") from e
|
|
|
|
def deleteFilesBatch(self, fileIds: List[str]) -> Dict[str, Any]:
    """Delete several of the current user's files (FileData + FileItem) in one transaction.

    IDs are de-duplicated (first occurrence kept) and falsy entries dropped.
    The operation is all-or-nothing: if any ID is missing or owned by another
    user, nothing is deleted.

    Returns:
        ``{"deletedFiles": <number of FileItem rows deleted>}``.

    Raises:
        FileDeletionError: on any failure, including inaccessible IDs; the
            transaction is rolled back first.
    """
    requestedIds: List[str] = []
    for candidate in dict.fromkeys(fileIds or []):
        if candidate:
            requestedIds.append(str(candidate))
    if not requestedIds:
        return {"deletedFiles": 0}

    try:
        self.db._ensure_connection()
        ownerId = self.userId or ""
        with self.db.connection.cursor() as cursor:
            cursor.execute(
                'SELECT "id" FROM "FileItem" WHERE "id" = ANY(%s) AND "_createdBy" = %s',
                (requestedIds, ownerId),
            )
            ownedIds = [row["id"] for row in cursor.fetchall()]

            if len(ownedIds) != len(requestedIds):
                unresolved = sorted(set(requestedIds) - set(ownedIds))
                raise FileNotFoundError(f"Files not found or not accessible: {unresolved}")

            # Binary payloads first, then the metadata rows
            cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (ownedIds,))
            cursor.execute(
                'DELETE FROM "FileItem" WHERE "id" = ANY(%s) AND "_createdBy" = %s',
                (ownedIds, ownerId),
            )
            deletedCount = cursor.rowcount

        self.db.connection.commit()
        return {"deletedFiles": deletedCount}
    except Exception as e:
        logger.error(f"Error deleting files in batch: {e}")
        self.db.connection.rollback()
        raise FileDeletionError(f"Error deleting files in batch: {str(e)}")
|
|
|
|
# ---- Folder methods ----
|
|
|
|
# Folder names users may not create or rename to; rejected by _validateFolderName.
# NOTE(review): "(Global)" is presumably reserved for a UI pseudo-folder — confirm.
_RESERVED_FOLDER_NAMES = {"(Global)"}
|
|
|
|
def _validateFolderName(self, name: str, parentId: Optional[str], excludeFolderId: Optional[str] = None):
    """Validate a folder name for the current user.

    Args:
        name: Proposed folder name.
        parentId: Parent folder ID; None/"" means the root level.
        excludeFolderId: Folder to ignore in the uniqueness check (rename/move).

    Raises:
        ValueError: if the name is reserved, empty/whitespace-only, or already
            used by another of the current user's folders under the same parent.
    """
    if name in self._RESERVED_FOLDER_NAMES:
        raise ValueError(f"Folder name '{name}' is reserved")
    if not name or not name.strip():
        raise ValueError("Folder name cannot be empty")

    # Folders are user-scoped (getFolder/listFolders filter by _createdBy), so
    # only the current user's siblings can conflict. Without the owner filter,
    # a folder of any other user under the same parent (e.g. the root) blocked the name.
    siblings = self.db.getRecordset(
        FileFolder,
        recordFilter={"parentId": parentId or "", "_createdBy": self.userId or ""},
    )
    for sibling in siblings:
        if sibling.get("name") == name and sibling.get("id") != excludeFolderId:
            raise ValueError(f"Folder '{name}' already exists in this directory")
|
|
|
|
def _isDescendantOf(self, folderId: str, ancestorId: str) -> bool:
    """Return True if ``folderId`` equals ``ancestorId`` or lies beneath it.

    Walks the ``parentId`` chain upward from ``folderId``. The walk stops at
    the root, at a folder that cannot be found, or when a folder repeats
    (cycle in the stored chain). Used to reject moving a folder into its own
    subtree.
    """
    seen = set()
    node = folderId
    while node and node not in seen:
        if node == ancestorId:
            return True
        seen.add(node)
        matches = self.db.getRecordset(FileFolder, recordFilter={"id": node})
        if not matches:
            return False
        node = matches[0].get("parentId")
    return False
|
|
|
|
def getFolder(self, folderId: str) -> Optional[Dict[str, Any]]:
    """Return the folder record with ``folderId`` if the current user created it, else None."""
    ownerFilter = {"id": folderId, "_createdBy": self.userId or ""}
    matches = self.db.getRecordset(FileFolder, recordFilter=ownerFilter)
    if not matches:
        return None
    return matches[0]
|
|
|
|
def listFolders(self, parentId: Optional[str] = None) -> List[Dict[str, Any]]:
    """Return the current user's folder records.

    Args:
        parentId: When not None, only folders whose ``parentId`` equals it are
            returned; None returns all of the user's folders.
    """
    criteria: Dict[str, Any] = {"_createdBy": self.userId or ""}
    if parentId is not None:
        criteria["parentId"] = parentId
    return self.db.getRecordset(FileFolder, recordFilter=criteria)
|
|
|
|
def createFolder(self, name: str, parentId: Optional[str] = None) -> Dict[str, Any]:
    """Create a folder for the current user after validating its name.

    Args:
        name: Folder name; must pass ``_validateFolderName`` under ``parentId``.
        parentId: Parent folder ID; None creates the folder at the root level.

    Returns:
        Whatever ``self.db.recordCreate`` returns for the new folder.

    Raises:
        ValueError: from the name validation.
    """
    self._validateFolderName(name, parentId)

    # Unset context IDs are stored as ""
    folderFields = {
        "name": name,
        "parentId": parentId,
        "mandateId": self.mandateId or "",
        "featureInstanceId": self.featureInstanceId or "",
    }
    return self.db.recordCreate(FileFolder, FileFolder(**folderFields))
|
|
|
|
def renameFolder(self, folderId: str, newName: str) -> bool:
    """Rename one of the current user's folders.

    Returns:
        Whatever ``self.db.recordModify`` returns.

    Raises:
        FileNotFoundError: if the folder does not exist or is not owned by the user.
        ValueError: if ``newName`` is reserved, empty, or used by a sibling.
    """
    existing = self.getFolder(folderId)
    if not existing:
        raise FileNotFoundError(f"Folder {folderId} not found")

    # Uniqueness is checked among siblings under the same parent, ignoring this folder
    parentOfFolder = existing.get("parentId")
    self._validateFolderName(newName, parentOfFolder, excludeFolderId=folderId)
    return self.db.recordModify(FileFolder, folderId, {"name": newName})
|
|
|
|
def moveFolder(self, folderId: str, targetParentId: Optional[str] = None) -> bool:
    """Move one of the current user's folders under ``targetParentId``.

    Args:
        folderId: Folder to move.
        targetParentId: New parent folder ID; None/"" moves it to the root level.

    Returns:
        Whatever ``self.db.recordModify`` returns.

    Raises:
        FileNotFoundError: if the folder or the target parent does not exist
            or is not owned by the current user.
        ValueError: if the target lies in the folder's own subtree (including
            the folder itself) or a sibling with the same name exists there.
    """
    folder = self.getFolder(folderId)
    if not folder:
        raise FileNotFoundError(f"Folder {folderId} not found")

    # Same check as moveFilesBatch: never attach a folder to a missing or
    # another user's parent folder.
    if targetParentId and not self.getFolder(targetParentId):
        raise FileNotFoundError(f"Target folder {targetParentId} not found")

    if targetParentId and self._isDescendantOf(targetParentId, folderId):
        raise ValueError("Cannot move folder into its own subtree")

    self._validateFolderName(folder.get("name", ""), targetParentId, excludeFolderId=folderId)
    return self.db.recordModify(FileFolder, folderId, {"parentId": targetParentId})
|
|
|
|
def moveFilesBatch(self, fileIds: List[str], targetFolderId: Optional[str] = None) -> Dict[str, Any]:
    """Move several of the current user's files into ``targetFolderId`` with one UPDATE.

    IDs are de-duplicated (first occurrence kept) and falsy entries dropped.
    ``targetFolderId`` is written as-is to ``folderId``, so a falsy value
    moves the files out of their folder. The operation is all-or-nothing: if
    any file is missing or not owned by the user, nothing is moved.

    Returns:
        ``{"movedFiles": <number of rows updated>}``.

    Raises:
        FileNotFoundError: if a non-empty ``targetFolderId`` is not one of the user's folders.
        FileError: on any failure during the update, including inaccessible
            file IDs; the transaction is rolled back first.
    """
    idsToMove = [str(fid) for fid in dict.fromkeys(fileIds or []) if fid]
    if not idsToMove:
        return {"movedFiles": 0}

    if targetFolderId and not self.getFolder(targetFolderId):
        raise FileNotFoundError(f"Target folder {targetFolderId} not found")

    try:
        self.db._ensure_connection()
        ownerId = self.userId or ""
        with self.db.connection.cursor() as cursor:
            cursor.execute(
                'SELECT "id" FROM "FileItem" WHERE "id" = ANY(%s) AND "_createdBy" = %s',
                (idsToMove, ownerId),
            )
            ownedIds = [row["id"] for row in cursor.fetchall()]
            if len(ownedIds) != len(idsToMove):
                unresolved = sorted(set(idsToMove) - set(ownedIds))
                raise FileNotFoundError(f"Files not found or not accessible: {unresolved}")

            cursor.execute(
                'UPDATE "FileItem" SET "folderId" = %s, "_modifiedAt" = %s, "_modifiedBy" = %s '
                'WHERE "id" = ANY(%s) AND "_createdBy" = %s',
                (targetFolderId, getUtcTimestamp(), ownerId, ownedIds, ownerId),
            )
            movedCount = cursor.rowcount

        self.db.connection.commit()
        return {"movedFiles": movedCount}
    except Exception as e:
        logger.error(f"Error moving files in batch: {e}")
        self.db.connection.rollback()
        raise FileError(f"Error moving files in batch: {str(e)}")
|
|
|
|
def moveFoldersBatch(self, folderIds: List[str], targetParentId: Optional[str] = None) -> Dict[str, Any]:
    """Move several of the current user's folders under ``targetParentId`` with one UPDATE.

    IDs are de-duplicated (first occurrence kept) and falsy entries dropped.
    All validation happens before any write:
      * every folder must belong to the user;
      * the target must be one of the user's folders, or falsy for the root level;
      * no folder may move into its own subtree;
      * moved names must not clash with each other or with folders already in the target.

    Returns:
        ``{"movedFolders": <number of rows updated>}``.

    Raises:
        FileNotFoundError: if a folder or the target folder is not accessible.
        ValueError: on a move into a subtree or a name clash.
        FileError: if the UPDATE fails (the transaction is rolled back first).
    """
    uniqueIds = [str(fid) for fid in dict.fromkeys(folderIds or []) if fid]
    if not uniqueIds:
        return {"movedFolders": 0}

    # Same check as moveFilesBatch: never attach folders to a missing or
    # another user's parent folder.
    if targetParentId and not self.getFolder(targetParentId):
        raise FileNotFoundError(f"Target folder {targetParentId} not found")

    foldersToMove: List[Dict[str, Any]] = []
    for folderId in uniqueIds:
        folder = self.getFolder(folderId)
        if not folder:
            raise FileNotFoundError(f"Folder {folderId} not found")
        # Also rejects moving a folder into itself (targetParentId == folderId)
        if targetParentId and self._isDescendantOf(targetParentId, folderId):
            raise ValueError("Cannot move folder into its own subtree")
        foldersToMove.append(folder)

    existingInTarget = self.db.getRecordset(
        FileFolder,
        recordFilter={"parentId": targetParentId or "", "_createdBy": self.userId or ""},
    )
    existingNames = {f.get("name"): f.get("id") for f in existingInTarget}
    movingNames: Dict[str, str] = {}
    movingIds = set(uniqueIds)

    for folder in foldersToMove:
        name = folder.get("name", "")
        folderId = folder.get("id")
        if name in movingNames and movingNames[name] != folderId:
            raise ValueError(f"Folder '{name}' already exists in this move batch")
        movingNames[name] = folderId

        # A same-named folder already in the target conflicts unless it is part of this batch
        existingId = existingNames.get(name)
        if existingId and existingId not in movingIds:
            raise ValueError(f"Folder '{name}' already exists in target directory")

    try:
        self.db._ensure_connection()
        with self.db.connection.cursor() as cursor:
            cursor.execute(
                'UPDATE "FileFolder" SET "parentId" = %s, "_modifiedAt" = %s, "_modifiedBy" = %s '
                'WHERE "id" = ANY(%s) AND "_createdBy" = %s',
                (targetParentId, getUtcTimestamp(), self.userId or "", uniqueIds, self.userId or ""),
            )
            movedFolders = cursor.rowcount

        self.db.connection.commit()
        return {"movedFolders": movedFolders}
    except Exception as e:
        logger.error(f"Error moving folders in batch: {e}")
        self.db.connection.rollback()
        raise FileError(f"Error moving folders in batch: {str(e)}")
|
|
|
|
def deleteFolder(self, folderId: str, recursive: bool = False) -> Dict[str, Any]:
    """Delete one of the current user's folders.

    Without ``recursive`` the folder must be empty (no subfolders, no files).
    With ``recursive`` the work proceeds in this order:
      1. subfolders are deleted depth-first;
      2. the contained files are deleted via ``deleteFile``; a file that fails
         to delete is logged and skipped;
      3. the folder itself is removed.

    Returns:
        ``{"deletedFiles": int, "deletedFolders": int}`` covering this folder
        and everything removed beneath it.

    Raises:
        FileNotFoundError: if the folder does not exist or is not owned by the user.
        ValueError: if the folder is not empty and ``recursive`` is False.
    """
    folder = self.getFolder(folderId)
    if not folder:
        raise FileNotFoundError(f"Folder {folderId} not found")

    subFolders = self.db.getRecordset(FileFolder, recordFilter={"parentId": folderId, "_createdBy": self.userId or ""})
    containedFiles = self._getFilesByCurrentUser(recordFilter={"folderId": folderId})

    if not recursive and (subFolders or containedFiles):
        raise ValueError(
            f"Folder '{folder.get('name')}' is not empty "
            f"({len(containedFiles)} files, {len(subFolders)} subfolders). "
            f"Use recursive=true to delete contents."
        )

    summary = {"deletedFiles": 0, "deletedFolders": 0}

    if recursive:
        for subFolder in subFolders:
            nested = self.deleteFolder(subFolder["id"], recursive=True)
            summary["deletedFiles"] += nested.get("deletedFiles", 0)
            summary["deletedFolders"] += nested.get("deletedFolders", 0)

        for fileRecord in containedFiles:
            try:
                self.deleteFile(fileRecord["id"])
            except Exception as e:
                logger.warning(f"Failed to delete file {fileRecord['id']} during folder deletion: {e}")
            else:
                summary["deletedFiles"] += 1

    self.db.recordDelete(FileFolder, folderId)
    summary["deletedFolders"] += 1
    return summary
|
|
|
|
def deleteFoldersBatch(self, folderIds: List[str], recursive: bool = True) -> Dict[str, Any]:
    """Delete several of the current user's folders.

    IDs are de-duplicated (first occurrence kept) and falsy entries dropped.
    Non-recursive mode delegates to ``deleteFolder(..., recursive=False)`` per
    folder, so each one must be empty. Recursive mode deletes the following
    in one transaction:
      * the given folders;
      * all of the user's descendant folders;
      * the user's files (FileData + FileItem) inside any of them.

    Returns:
        ``{"deletedFiles": int, "deletedFolders": int}``.

    Raises:
        FileDeletionError: in recursive mode, on any failure including
            inaccessible folder IDs (the transaction is rolled back first).
        In non-recursive mode the exceptions of ``deleteFolder`` propagate.
    """
    uniqueIds = [str(fid) for fid in dict.fromkeys(folderIds or []) if fid]
    if not uniqueIds:
        return {"deletedFiles": 0, "deletedFolders": 0}

    if not recursive:
        deletedFiles = 0
        deletedFolders = 0
        for folderId in uniqueIds:
            result = self.deleteFolder(folderId, recursive=False)
            deletedFiles += result.get("deletedFiles", 0)
            deletedFolders += result.get("deletedFolders", 0)
        return {"deletedFiles": deletedFiles, "deletedFolders": deletedFolders}

    # UNION (not UNION ALL) discards rows that were already produced, so the
    # recursion terminates even if the stored parentId chain contains a cycle.
    # With UNION ALL such a cycle made the query recurse forever.
    folderTreeQuery = """
        WITH RECURSIVE folder_tree AS (
            SELECT "id"
            FROM "FileFolder"
            WHERE "id" = ANY(%s) AND "_createdBy" = %s
            UNION
            SELECT child."id"
            FROM "FileFolder" child
            INNER JOIN folder_tree ft ON child."parentId" = ft."id"
            WHERE child."_createdBy" = %s
        )
        SELECT DISTINCT "id" FROM folder_tree
    """

    try:
        self.db._ensure_connection()
        ownerId = self.userId or ""
        with self.db.connection.cursor() as cursor:
            cursor.execute(
                'SELECT "id" FROM "FileFolder" WHERE "id" = ANY(%s) AND "_createdBy" = %s',
                (uniqueIds, ownerId),
            )
            rootAccessibleIds = [row["id"] for row in cursor.fetchall()]
            if len(rootAccessibleIds) != len(uniqueIds):
                missingIds = sorted(set(uniqueIds) - set(rootAccessibleIds))
                raise FileNotFoundError(f"Folders not found or not accessible: {missingIds}")

            # Roots plus every descendant folder owned by the user
            cursor.execute(folderTreeQuery, (rootAccessibleIds, ownerId, ownerId))
            allFolderIds = [row["id"] for row in cursor.fetchall()]

            cursor.execute(
                'SELECT "id" FROM "FileItem" WHERE "folderId" = ANY(%s) AND "_createdBy" = %s',
                (allFolderIds, ownerId),
            )
            allFileIds = [row["id"] for row in cursor.fetchall()]

            deletedFiles = 0
            if allFileIds:
                cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (allFileIds,))
                cursor.execute(
                    'DELETE FROM "FileItem" WHERE "id" = ANY(%s) AND "_createdBy" = %s',
                    (allFileIds, ownerId),
                )
                deletedFiles = cursor.rowcount

            cursor.execute(
                'DELETE FROM "FileFolder" WHERE "id" = ANY(%s) AND "_createdBy" = %s',
                (allFolderIds, ownerId),
            )
            deletedFolders = cursor.rowcount

        self.db.connection.commit()
        return {"deletedFiles": deletedFiles, "deletedFolders": deletedFolders}
    except Exception as e:
        logger.error(f"Error deleting folders in batch: {e}")
        self.db.connection.rollback()
        raise FileDeletionError(f"Error deleting folders in batch: {str(e)}")
|
|
|
|
def copyFile(self, sourceFileId: str, targetFolderId: Optional[str] = None, newFileName: Optional[str] = None) -> FileItem:
    """Create a full duplicate (FileItem + FileData) of one of the current user's files.

    Args:
        sourceFileId: File to copy.
        targetFolderId: Folder for the copy; defaults to the source file's folder.
        newFileName: Name for the copy; defaults to the source name. If the
            name is taken, it is made unique (``name_1.ext``, ...).

    Returns:
        The FileItem returned by ``createFile`` for the copy. Its folder is
        assigned afterwards via ``updateFile``.

    Raises:
        FileNotFoundError: if the source file is not accessible.
        FileStorageError: if the source file has no stored data.
    """
    sourceFile = self.getFile(sourceFileId)
    if not sourceFile:
        raise FileNotFoundError(f"File {sourceFileId} not found")

    sourceData = self.getFileData(sourceFileId)
    if sourceData is None:
        raise FileStorageError(f"No data found for file {sourceFileId}")

    # createFile returns an EXISTING file when the user already has one with
    # the same hash and name. That is always true for the source itself when
    # no new name is given, so the source was returned and then moved by the
    # folder update below instead of being copied. Resolve a free name first
    # so a real copy is created.
    fileName = self._generateUniquefileName(newFileName or sourceFile.fileName)
    copiedFile = self.createFile(fileName, sourceFile.mimeType, sourceData)

    folderId = targetFolderId or sourceFile.folderId
    if folderId:
        self.updateFile(copiedFile.id, {"folderId": folderId})

    self.createFileData(copiedFile.id, sourceData)
    return copiedFile
|
|
|
|
def updateFileData(self, fileId: str, data: bytes) -> bool:
    """Replace the stored binary data of one of the current user's files.

    The steps are:
      1. delete the existing FileData; a failing delete is only logged at debug level;
      2. store ``data`` via ``createFileData``;
      3. on success, refresh ``fileSize`` and the SHA-256 ``fileHash`` on the FileItem.

    NOTE(review): createFileData skips the write when a FileData record still
    exists, so if the delete silently fails the old data is kept while the
    metadata is updated — confirm recordDelete raises on failure.

    Returns:
        The result of ``createFileData``.

    Raises:
        FileNotFoundError: if the file does not exist or is not owned by the user.
    """
    if not self.getFile(fileId):
        raise FileNotFoundError(f"File {fileId} not found")

    try:
        self.db.recordDelete(FileData, fileId)
    except Exception as e:
        logger.debug(f"No existing FileData to delete for {fileId}: {e}")
    else:
        logger.debug(f"Deleted existing FileData for {fileId}")

    stored = self.createFileData(fileId, data)
    if stored:
        byteCount = len(data)
        digest = hashlib.sha256(data).hexdigest()
        self.db.recordModify(FileItem, fileId, {"fileSize": byteCount, "fileHash": digest})
        logger.info(f"Updated file data for {fileId} ({byteCount} bytes)")
    return stored
|
|
|
|
# FileData methods - data operations
|
|
|
|
def createFileData(self, fileId: str, data: bytes) -> bool:
    """Store the binary content of one of the current user's files.

    Content is stored in one of two forms, recorded by the ``base64Encoded`` flag:
      * text MIME types (per ``isTextMimeType``) that decode as UTF-8 are
        stored as plain text;
      * everything else is stored base64-encoded.
    If a FileData record already exists for ``fileId`` (e.g. ``createFile``
    returned an existing duplicate), nothing is written.

    Returns:
        True if the data was stored or already existed; False if the file is
        not accessible or any error occurred (errors are logged, not raised).
    """
    try:
        file = self.getFile(fileId)
        if not file:
            logger.error(f"File with ID {fileId} not found when storing data")
            return False

        # Check for existing data BEFORE encoding: the payload used to be
        # base64-encoded first and then discarded, which is wasted work for
        # large files.
        existingData = self.db.getRecordset(FileData, recordFilter={"id": fileId})
        if existingData:
            logger.debug(f"File data already exists for {fileId} — skipping duplicate storage")
            return True

        mimeType = file.mimeType
        if self.isTextMimeType(mimeType):
            try:
                fileData = data.decode('utf-8')
                base64Encoded = False
                logger.debug(f"Stored file {fileId} as text")
            except UnicodeDecodeError:
                # Declared as text but not valid UTF-8: keep the bytes intact via base64
                fileData = base64.b64encode(data).decode('utf-8')
                base64Encoded = True
                logger.warning(f"Failed to decode text file {fileId}, falling back to base64")
        else:
            fileData = base64.b64encode(data).decode('utf-8')
            base64Encoded = True
            logger.debug(f"Stored file {fileId} as base64")

        fileDataObj = {
            "id": fileId,
            "data": fileData,
            "base64Encoded": base64Encoded,
        }
        self.db.recordCreate(FileData, fileDataObj)

        logger.debug(f"Successfully stored data for file {fileId} (base64Encoded: {base64Encoded})")
        return True

    except Exception as e:
        logger.error(f"Error storing data for file {fileId}: {str(e)}")
        return False
|
|
|
|
def getFileData(self, fileId: str) -> Optional[bytes]:
    """Return the decoded binary content of one of the current user's files.

    Decoding rules:
      * records flagged ``base64Encoded`` are base64-decoded;
      * unflagged records of a text MIME type are UTF-8 encoded back to bytes;
      * unflagged records of a non-text MIME type are assumed to hold base64
        anyway and are decoded as such, falling back to the raw value.

    Returns:
        The bytes, or None if the file or its data is not accessible, the
        record has no ``data`` field, or processing fails (all logged).
    """
    file = self.getFile(fileId)
    if not file:
        logger.warning(f"No access to file ID {fileId}")
        return None

    entries = getRecordsetWithRBAC(self.db, FileData, self.currentUser, recordFilter={"id": fileId}, mandateId=self.mandateId)
    if not entries:
        logger.warning(f"No data found for file ID {fileId}")
        return None

    entry = entries[0]
    if "data" not in entry:
        logger.warning(f"No data field in file data for ID {fileId}")
        return None

    payload = entry["data"]
    isBase64 = entry.get("base64Encoded", False)

    try:
        if isBase64:
            return base64.b64decode(payload)

        mimeType = file.mimeType
        if self.isTextMimeType(mimeType):
            return payload.encode('utf-8')

        # Binary MIME type stored without the flag: most likely base64 anyway
        try:
            logger.warning(f"Binary file {fileId} ({mimeType}) was stored as text, attempting base64 decode")
            return base64.b64decode(payload)
        except Exception as decodeError:
            logger.error(f"Failed to decode binary file {fileId} as base64: {str(decodeError)}")
            logger.warning(f"Returning raw data for file {fileId} - file may be corrupted")
            return payload.encode('utf-8') if isinstance(payload, str) else payload
    except Exception as e:
        logger.error(f"Error processing file data for {fileId}: {str(e)}")
        return None
|
|
|
|
def getFileDataForPublicDocument(self, fileId: str) -> Optional[bytes]:
    """
    Return the binary content of a public document (e.g. BZO) WITHOUT RBAC filtering.

    Intended for official/mandate documents that every user must be able to
    read. FileData (and, when needed, the FileItem's MIME type) is read
    directly from the database.

    Decoding rules:
      * records flagged ``base64Encoded`` are base64-decoded;
      * unflagged records of a text MIME type are returned UTF-8 encoded;
      * unflagged records of a binary or unknown MIME type are base64-decoded,
        falling back to the raw value.

    Returns:
        The bytes, or None if no data exists or an error occurs (logged).
    """
    try:
        fileDataEntries = self.db.getRecordset(FileData, recordFilter={"id": fileId})
        if not fileDataEntries:
            logger.warning(f"No file data found for public document ID {fileId}")
            return None

        fileDataEntry = fileDataEntries[0]
        if "data" not in fileDataEntry:
            logger.warning(f"No data field in file data for ID {fileId}")
            return None

        data = fileDataEntry["data"]
        if fileDataEntry.get("base64Encoded", False):
            return base64.b64decode(data)

        # createFileData stores UTF-8 text of text MIME types without the
        # base64 flag. Blindly base64-decoding such text corrupts it whenever
        # it happens to be valid base64 (e.g. "test"), so check the MIME type
        # first, as getFileData does.
        fileItems = self.db.getRecordset(FileItem, recordFilter={"id": fileId})
        mimeType = fileItems[0].get("mimeType") if fileItems else None
        if mimeType and self.isTextMimeType(mimeType):
            return data.encode("utf-8") if isinstance(data, str) else data

        # PDF/binary (or unknown type) stored without the flag: try base64 first
        try:
            return base64.b64decode(data)
        except Exception:
            return data.encode("utf-8") if isinstance(data, str) else data
    except Exception as e:
        logger.error(f"Error retrieving public document {fileId}: {str(e)}", exc_info=True)
        return None
|
|
|
|
def getFileContent(self, fileId: str) -> Optional[FilePreview]:
    """Return the full content of one of the current user's files as a FilePreview.

    ``text/*`` files are decoded as UTF-8, falling back to Latin-1, and
    returned as text. All other types, images included, are returned
    base64-encoded.

    Returns:
        The FilePreview, or None if the file or its data is not accessible or
        an error occurs (logged). An empty file yields an empty preview.
    """
    try:
        file = self.getFile(fileId)
        if not file:
            logger.warning(f"No access to file ID {fileId}")
            return None

        fileContent = self.getFileData(fileId)
        # Compare against None: b"" is a valid empty file, not missing content
        if fileContent is None:
            logger.warning(f"No content found for file ID {fileId}")
            return None

        encoding = None
        if (file.mimeType or "").startswith("text/"):
            isText = True
            try:
                content = fileContent.decode('utf-8')
                encoding = 'utf-8'
            except UnicodeDecodeError:
                # Latin-1 maps every byte value, so this decode cannot fail
                content = fileContent.decode('latin-1')
                encoding = 'latin-1'
        else:
            # Images and all other types are returned base64-encoded
            isText = False
            content = base64.b64encode(fileContent).decode('utf-8')

        return FilePreview(
            content=content,
            mimeType=file.mimeType,
            fileName=file.fileName,
            isText=isText,
            encoding=encoding,
            size=file.fileSize,
        )

    except Exception as e:
        logger.error(f"Error getting file content: {str(e)}")
        return None
|
|
|
|
def saveUploadedFile(self, fileContent: bytes, fileName: str) -> tuple[FileItem, str]:
    """Store an uploaded file (metadata + binary data) for the current user.

    Returns:
        ``(fileItem, status)`` where status is one of:
          * ``"exact_duplicate"``: ``checkForDuplicateFile`` found an existing
            file for this content hash and name, which is returned unchanged;
          * ``"name_conflict"``: stored under a modified unique name;
          * ``"new_file"``: stored under the requested name.

    Raises:
        FileStorageError: wraps every failure, including missing permission,
            non-bytes input, and failure to store the binary data.
    """
    try:
        if not self.checkRbacPermission(FileItem, "create"):
            raise PermissionError("No permission to upload files")

        logger.debug(f"Starting upload process for file: {fileName}")

        if not isinstance(fileContent, bytes):
            logger.error(f"Invalid fileContent type: {type(fileContent)}")
            raise ValueError(f"fileContent must be bytes, got {type(fileContent)}")

        # Hash first so duplicates are detected before any DB write
        fileHash = hashlib.sha256(fileContent).hexdigest()

        # Same user + same hash + same name -> return the existing file.
        # Same hash under a different name is an intentional copy.
        existingFile = self.checkForDuplicateFile(fileHash, fileName)
        if existingFile:
            logger.info(f"Duplicate detected for user {self.userId}: '{fileName}' with hash {fileHash[:12]}... — returning existing file {existingFile.id}")
            return existingFile, "exact_duplicate"

        mimeType = self.getMimeType(fileName)

        logger.debug(f"Saving file metadata to database for file: {fileName}")
        fileItem = self.createFile(
            name=fileName,
            mimeType=mimeType,
            content=fileContent,
        )

        logger.debug(f"Saving file content to database for file: {fileName}")
        if not self.createFileData(fileItem.id, fileContent):
            # createFileData logs and returns False instead of raising. Without
            # this check the upload was reported as successful although no
            # content was stored. A retry would then presumably match the
            # content-less record as an "exact_duplicate". Remove the metadata
            # record (best effort) and fail.
            try:
                self.db.recordDelete(FileItem, fileItem.id)
            except Exception as cleanupError:
                logger.warning(f"Could not remove metadata for failed upload {fileItem.id}: {cleanupError}")
            raise FileStorageError(f"Failed to store content for file '{fileName}'")

        logger.debug(f"File upload process completed for: {fileName}")

        # A changed name means _generateUniquefileName resolved a conflict
        if fileItem.fileName != fileName:
            return fileItem, "name_conflict"
        return fileItem, "new_file"

    except Exception as e:
        logger.error(f"Error in saveUploadedFile for {fileName}: {str(e)}", exc_info=True)
        raise FileStorageError(f"Error saving file: {str(e)}") from e
|
|
|
|
# VoiceSettings methods
|
|
|
|
def getVoiceSettings(self, userId: Optional[str] = None) -> Optional[VoiceSettings]:
    """Return the voice settings of ``userId`` (default: current user), RBAC-filtered.

    The lookup is limited to the current feature instance when one is set.
    Missing ``creationDate`` / ``lastModified`` values are filled with the
    current time so the model validates.

    Returns:
        VoiceSettings, or None when no user ID is available, nothing is
        visible to the caller, or an error occurs (all logged).
    """
    try:
        ownerId = userId or self.userId
        if not ownerId:
            logger.error("No user ID provided for voice settings")
            return None

        criteria: Dict[str, Any] = {"userId": ownerId}
        if self.featureInstanceId:
            criteria["featureInstanceId"] = self.featureInstanceId

        visible = getRecordsetWithRBAC(
            self.db,
            VoiceSettings,
            self.currentUser,
            recordFilter=criteria,
            mandateId=self.mandateId,
        )
        if not visible:
            logger.warning(f"No access to voice settings for user {ownerId}")
            return None

        record = visible[0]
        for timestampField in ("creationDate", "lastModified"):
            if not record.get(timestampField):
                record[timestampField] = getUtcTimestamp()

        return VoiceSettings(**record)

    except Exception as e:
        logger.error(f"Error getting voice settings: {str(e)}")
        return None
|
|
|
|
def createVoiceSettings(self, settingsData: Dict[str, Any]) -> Dict[str, Any]:
    """Create voice settings, filling missing userId/mandateId/featureInstanceId from the context.

    ``settingsData`` is completed in place. Creation is gated on the RBAC
    "update" permission for VoiceSettings.

    Returns:
        The record returned by ``self.db.recordCreate``.

    Raises:
        PermissionError: if the permission check fails.
        ValueError: if settings already exist for the user, or the insert
            returned no record with an ``id``.
        Any error is logged before being re-raised.
    """
    try:
        if not self.checkRbacPermission(VoiceSettings, "update"):
            raise PermissionError("No permission to create voice settings")

        # Context values only fill keys the caller did not provide
        settingsData.setdefault("userId", self.userId)
        settingsData.setdefault("mandateId", self.mandateId)
        settingsData.setdefault("featureInstanceId", self.featureInstanceId)

        if self.getVoiceSettings(settingsData["userId"]):
            raise ValueError(f"Voice settings already exist for user {settingsData['userId']}")

        createdRecord = self.db.recordCreate(VoiceSettings, settingsData)
        if not createdRecord or not createdRecord.get("id"):
            raise ValueError("Failed to create voice settings record")

        logger.info(f"Created voice settings for user {settingsData['userId']}")
        return createdRecord

    except Exception as e:
        logger.error(f"Error creating voice settings: {str(e)}")
        raise
|
|
|
|
def updateVoiceSettings(self, userId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
    """Update the voice settings of *userId* if the current user can see them.

    Access is governed by ``getVoiceSettings`` (an RBAC-filtered read). No
    separate permission check is made here. ``lastModified`` is always
    overwritten with ``getUtcTimestamp()``, replacing any value supplied in
    *updateData*. The caller's *updateData* dict is not modified.

    Args:
        userId: Owner of the settings to update.
        updateData: Fields to change.

    Returns:
        The re-read settings as a dict (``model_dump()``).

    Raises:
        ValueError: If no visible settings exist, ``recordModify`` reports
            failure, or the updated record cannot be re-read.
    """
    try:
        existingSettings = self.getVoiceSettings(userId)
        if not existingSettings:
            raise ValueError(f"Voice settings not found for user {userId}")

        # Stamp a copy instead of writing lastModified into the caller's dict.
        changes = {**updateData, "lastModified": getUtcTimestamp()}

        success = self.db.recordModify(VoiceSettings, existingSettings.id, changes)
        if not success:
            raise ValueError("Failed to update voice settings record")

        updatedSettings = self.getVoiceSettings(userId)
        if not updatedSettings:
            raise ValueError("Failed to retrieve updated voice settings")

        logger.info(f"Updated voice settings for user {userId}")
        return updatedSettings.model_dump()

    except Exception as e:
        logger.error(f"Error updating voice settings: {str(e)}")
        raise
|
|
|
|
def deleteVoiceSettings(self, userId: str) -> bool:
    """Remove the voice settings record belonging to *userId*.

    The record is located with ``getVoiceSettings`` (an RBAC-filtered read).
    The result of ``self.db.recordDelete`` is passed back unchanged.
    ``False`` is returned when nothing visible is found. It is also returned
    when any exception occurs; the exception is logged, not re-raised.
    """
    try:
        current = self.getVoiceSettings(userId)
        if not current:
            logger.warning(f"Voice settings not found for user {userId}")
            return False

        outcome = self.db.recordDelete(VoiceSettings, current.id)
        if not outcome:
            logger.error(f"Failed to delete voice settings for user {userId}")
            return outcome

        logger.info(f"Deleted voice settings for user {userId}")
        return outcome

    except Exception as exc:
        logger.error(f"Error deleting voice settings: {str(exc)}")
        return False
|
|
|
|
def getOrCreateVoiceSettings(self, userId: Optional[str] = None) -> VoiceSettings:
    """Return a user's voice settings, creating a default record when none is visible.

    *userId* falls back to ``self.userId``. The default record uses these values:
    - ``de-DE`` for speech-to-text and text-to-speech;
    - voice ``de-DE-KatjaNeural``;
    - translation enabled, with target ``en-US``.

    Creation goes through ``createVoiceSettings`` and is subject to its
    permission check.

    Raises:
        ValueError: If neither *userId* nor ``self.userId`` is set, or
            ``createVoiceSettings`` fails.
        Any other exception from the lookup/creation path, after logging.
    """
    try:
        ownerId = userId or self.userId
        if not ownerId:
            raise ValueError("No user ID provided for voice settings")

        found = self.getVoiceSettings(ownerId)
        if found:
            return found

        created = self.createVoiceSettings(dict(
            userId=ownerId,
            mandateId=self.mandateId,
            sttLanguage="de-DE",
            ttsLanguage="de-DE",
            ttsVoice="de-DE-KatjaNeural",
            translationEnabled=True,
            targetLanguage="en-US",
        ))
        return VoiceSettings(**created)

    except Exception as exc:
        logger.error(f"Error getting or creating voice settings: {str(exc)}")
        raise
|
|
|
|
# Messaging Subscription methods
|
|
|
|
def getAllSubscriptions(self, pagination: Optional[PaginationParams] = None) -> Union[List[MessagingSubscription], PaginatedResult]:
    """List the messaging subscriptions visible to the current user.

    Records are read through ``getRecordsetWithRBAC``, scoped to the current
    mandate.

    Args:
        pagination: When ``None``, every visible subscription is returned.
            Otherwise ``pagination.filters`` and ``pagination.sort`` are
            applied when set, and the 1-based ``pagination.page`` of
            ``pagination.pageSize`` items is returned.

    Returns:
        A list of ``MessagingSubscription`` models, or a ``PaginatedResult``
        when *pagination* is given. Any exception is logged and turned into
        an empty list or an empty ``PaginatedResult``.
    """
    try:
        rows = getRecordsetWithRBAC(
            self.db,
            MessagingSubscription,
            self.currentUser,
            mandateId=self.mandateId,
        )

        if pagination is None:
            return [MessagingSubscription(**row) for row in rows]

        if pagination.filters:
            rows = self._applyFilters(rows, pagination.filters)
        if pagination.sort:
            rows = self._applySorting(rows, pagination.sort)

        pageSize = pagination.pageSize
        total = len(rows)
        pageCount = math.ceil(total / pageSize) if total > 0 else 0
        first = (pagination.page - 1) * pageSize
        window = rows[first:first + pageSize]

        return PaginatedResult(
            items=[MessagingSubscription(**row) for row in window],
            totalItems=total,
            totalPages=pageCount,
        )
    except Exception as exc:
        logger.error(f"Error getting subscriptions: {str(exc)}")
        return [] if pagination is None else PaginatedResult(items=[], totalItems=0, totalPages=0)
|
|
|
|
def getSubscription(self, subscriptionId: str) -> Optional[MessagingSubscription]:
    """Look up a subscription by its ``subscriptionId`` field (not the record ``id``).

    The read is RBAC-filtered and scoped to the current mandate. Returns
    ``None`` when nothing visible matches. The first match is used when
    several match.
    """
    matches = getRecordsetWithRBAC(
        self.db,
        MessagingSubscription,
        self.currentUser,
        recordFilter={"subscriptionId": subscriptionId},
        mandateId=self.mandateId,
    )
    if not matches:
        return None
    return MessagingSubscription(**matches[0])
|
|
|
|
def getSubscriptionById(self, id: str) -> Optional[MessagingSubscription]:
    """Look up a subscription by its record ``id`` (UUID).

    The read is RBAC-filtered and scoped to the current mandate. Returns
    ``None`` when the record is missing or not visible.
    """
    # ``id`` shadows the builtin but is part of the public keyword interface.
    matches = getRecordsetWithRBAC(
        self.db,
        MessagingSubscription,
        self.currentUser,
        recordFilter={"id": id},
        mandateId=self.mandateId,
    )
    if not matches:
        return None
    return MessagingSubscription(**matches[0])
|
|
|
|
def createSubscription(self, subscriptionData: Dict[str, Any]) -> Dict[str, Any]:
    """Create a messaging subscription if the current user has permission.

    ``subscriptionId`` is mandatory. It may contain only characters for which
    ``str.isalpha()`` is true, plus underscores, so digits are rejected.
    Missing ``mandateId`` / ``featureInstanceId`` keys are filled from the
    current context by writing into *subscriptionData* itself.

    Returns:
        The record dict returned by ``self.db.recordCreate``.

    Raises:
        PermissionError: If the RBAC "create" check fails.
        ValueError: On a missing/invalid ``subscriptionId`` or when the
            database returns no record with an ``id``.
    """
    if not self.checkRbacPermission(MessagingSubscription, "create"):
        raise PermissionError("No permission to create subscriptions")

    slug = subscriptionData.get("subscriptionId", "")
    if not slug:
        raise ValueError("subscriptionId is required")
    if any(not ch.isalpha() and ch != "_" for ch in slug):
        raise ValueError("subscriptionId must contain only letters and underscores")

    # Context defaults for data isolation. The dict keys match the context
    # attribute names, so getattr reads self.mandateId / self.featureInstanceId,
    # and only when the key is absent.
    for key in ("mandateId", "featureInstanceId"):
        if key not in subscriptionData:
            subscriptionData[key] = getattr(self, key)

    record = self.db.recordCreate(MessagingSubscription, subscriptionData)
    if not (record and record.get("id")):
        raise ValueError("Failed to create subscription record")
    return record
|
|
|
|
def updateSubscription(self, subscriptionId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
    """Apply *updateData* to the subscription identified by *subscriptionId*.

    The subscription must be visible through ``getSubscription`` (an
    RBAC-filtered read). The connector's ``recordModify`` result is not
    inspected. The record is re-read afterwards by its primary key.

    Args:
        subscriptionId: The subscription's ``subscriptionId`` field.
        updateData: Fields to change. This may include a new ``subscriptionId``.

    Returns:
        The updated subscription as a dict (``model_dump()``).

    Raises:
        ValueError: If the subscription is not found, or cannot be re-read
            after the update.
    """
    subscription = self.getSubscription(subscriptionId)
    if not subscription:
        raise ValueError(f"Subscription {subscriptionId} not found")

    self.db.recordModify(MessagingSubscription, subscription.id, updateData)

    # Re-read by primary key rather than by subscriptionId. If updateData
    # renamed the subscriptionId, a lookup by the old value would miss the
    # record that was just updated and report a spurious failure.
    updatedSubscription = self.getSubscriptionById(subscription.id)
    if not updatedSubscription:
        raise ValueError("Failed to retrieve updated subscription")

    return updatedSubscription.model_dump()
|
|
|
|
def deleteSubscription(self, subscriptionId: str) -> bool:
    """Delete the subscription identified by its ``subscriptionId`` field.

    Returns ``False`` when the subscription is not visible. Otherwise the
    result of ``self.db.recordDelete`` is returned.

    Raises:
        PermissionError: If the RBAC check on the record fails.
    """
    target = self.getSubscription(subscriptionId)
    if not target:
        return False

    # NOTE(review): deletion is gated on the "update" action, not "delete";
    # confirm this matches the intended RBAC model.
    allowed = self.checkRbacPermission(MessagingSubscription, "update", target.id)
    if not allowed:
        raise PermissionError(f"No permission to delete subscription {subscriptionId}")

    return self.db.recordDelete(MessagingSubscription, target.id)
|
|
|
|
# Messaging Registration methods
|
|
|
|
def getAllRegistrations(
    self,
    subscriptionId: Optional[str] = None,
    userId: Optional[str] = None,
    pagination: Optional[PaginationParams] = None
) -> Union[List[MessagingSubscriptionRegistration], PaginatedResult]:
    """List subscription registrations visible to the current user.

    Args:
        subscriptionId: If truthy, only registrations for this subscription.
        userId: If truthy, only registrations of this user.
        pagination: When ``None``, a plain list is returned. Otherwise its
            ``filters`` / ``sort`` are applied when set, and the 1-based
            ``page`` of ``pageSize`` items is returned.

    Returns:
        A list of ``MessagingSubscriptionRegistration`` models, or a
        ``PaginatedResult``. Any exception is logged and yields an empty
        list or an empty ``PaginatedResult``.
    """
    try:
        criteria = {
            field: value
            for field, value in (("subscriptionId", subscriptionId), ("userId", userId))
            if value
        }

        rows = getRecordsetWithRBAC(
            self.db,
            MessagingSubscriptionRegistration,
            self.currentUser,
            recordFilter=criteria or None,
            mandateId=self.mandateId,
        )

        if pagination is None:
            return [MessagingSubscriptionRegistration(**row) for row in rows]

        if pagination.filters:
            rows = self._applyFilters(rows, pagination.filters)
        if pagination.sort:
            rows = self._applySorting(rows, pagination.sort)

        pageSize = pagination.pageSize
        total = len(rows)
        pageCount = math.ceil(total / pageSize) if total > 0 else 0
        first = (pagination.page - 1) * pageSize
        window = rows[first:first + pageSize]

        return PaginatedResult(
            items=[MessagingSubscriptionRegistration(**row) for row in window],
            totalItems=total,
            totalPages=pageCount,
        )
    except Exception as exc:
        logger.error(f"Error getting registrations: {str(exc)}")
        return [] if pagination is None else PaginatedResult(items=[], totalItems=0, totalPages=0)
|
|
|
|
def getRegistration(self, registrationId: str) -> Optional[MessagingSubscriptionRegistration]:
    """Return the registration whose record ``id`` equals *registrationId*.

    The read is RBAC-filtered and scoped to the current mandate. Returns
    ``None`` when it is missing or not visible.
    """
    matches = getRecordsetWithRBAC(
        self.db,
        MessagingSubscriptionRegistration,
        self.currentUser,
        recordFilter={"id": registrationId},
        mandateId=self.mandateId,
    )
    if not matches:
        return None
    return MessagingSubscriptionRegistration(**matches[0])
|
|
|
|
def createRegistration(self, registrationData: Dict[str, Any]) -> Dict[str, Any]:
    """Create a subscription registration if the current user has permission.

    Missing ``userId``, ``mandateId`` and ``featureInstanceId`` keys are
    filled from the current context by writing into *registrationData*
    itself.

    Returns:
        The record dict returned by ``self.db.recordCreate``.

    Raises:
        PermissionError: If the RBAC "create" check fails.
        ValueError: If the database returns no record with an ``id``.
    """
    if not self.checkRbacPermission(MessagingSubscriptionRegistration, "create"):
        raise PermissionError("No permission to create registrations")

    # Each key has a same-named context attribute (self.userId, self.mandateId,
    # self.featureInstanceId). getattr is only evaluated when the key is absent.
    for key in ("userId", "mandateId", "featureInstanceId"):
        if key not in registrationData:
            registrationData[key] = getattr(self, key)

    record = self.db.recordCreate(MessagingSubscriptionRegistration, registrationData)
    if not (record and record.get("id")):
        raise ValueError("Failed to create registration record")
    return record
|
|
|
|
def updateRegistration(self, registrationId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
    """Apply *updateData* to a registration visible to the current user.

    Visibility is checked with ``getRegistration`` (an RBAC-filtered read).
    The connector's ``recordModify`` result is not inspected. The record is
    re-read afterwards and returned as a dict.

    Raises:
        ValueError: If the registration is not visible before or after the update.
    """
    if not self.getRegistration(registrationId):
        raise ValueError(f"Registration {registrationId} not found")

    self.db.recordModify(MessagingSubscriptionRegistration, registrationId, updateData)

    refreshed = self.getRegistration(registrationId)
    if not refreshed:
        raise ValueError("Failed to retrieve updated registration")
    return refreshed.model_dump()
|
|
|
|
def deleteRegistration(self, registrationId: str) -> bool:
    """Delete the registration with record ``id`` *registrationId*.

    Returns ``False`` when the registration is not visible. Otherwise the
    result of ``self.db.recordDelete`` is returned.

    Raises:
        PermissionError: If the RBAC check on the record fails.
    """
    if not self.getRegistration(registrationId):
        return False

    # NOTE(review): deletion is gated on the "update" action, not "delete";
    # confirm this matches the intended RBAC model.
    allowed = self.checkRbacPermission(MessagingSubscriptionRegistration, "update", registrationId)
    if not allowed:
        raise PermissionError(f"No permission to delete registration {registrationId}")

    return self.db.recordDelete(MessagingSubscriptionRegistration, registrationId)
|
|
|
|
def subscribeUser(
    self,
    subscriptionId: str,
    userId: str,
    channel: MessagingChannel,
    channelConfig: str
) -> Dict[str, Any]:
    """Register *userId* for *subscriptionId* on *channel*.

    If the user already has a registration on that channel, the first such
    registration is re-enabled and its ``channelConfig`` replaced.
    Otherwise a new enabled registration is created, storing ``channel.value``.

    Returns:
        The updated or newly created registration as a dict.

    Raises:
        ValueError: If the subscription does not exist or is not visible.
    """
    if not self.getSubscription(subscriptionId):
        raise ValueError(f"Subscription {subscriptionId} not found")

    current = self.getAllRegistrations(subscriptionId=subscriptionId, userId=userId)
    sameChannel = next((reg for reg in current if reg.channel == channel), None)
    if sameChannel is not None:
        return self.updateRegistration(sameChannel.id, {"enabled": True, "channelConfig": channelConfig})

    return self.createRegistration({
        "subscriptionId": subscriptionId,
        "userId": userId,
        "channel": channel.value,
        "channelConfig": channelConfig,
        "enabled": True,
    })
|
|
|
|
def unsubscribeUser(self, subscriptionId: str, userId: str, channel: MessagingChannel) -> bool:
    """Remove *userId*'s registrations for *subscriptionId* on *channel*.

    Every matching registration is deleted, not just the first one.
    Duplicates for one channel can exist because:
    - ``subscribeUser`` performs a non-atomic check-then-create;
    - ``getAllRegistrations`` returns an empty list on errors.

    Returns:
        ``True`` if at least one matching registration was deleted.
        ``False`` if none matched, or if every delete reported failure.

    Raises:
        PermissionError: Propagated from ``deleteRegistration``.
    """
    registrations = self.getAllRegistrations(subscriptionId=subscriptionId, userId=userId)
    matching = [reg for reg in registrations if reg.channel == channel]
    # Collect every result before calling any(). A generator would short-circuit
    # after the first successful delete and leave the remaining duplicates behind.
    results = [self.deleteRegistration(reg.id) for reg in matching]
    return any(results)
|
|
|
|
# Messaging Delivery methods
|
|
|
|
def createDelivery(self, delivery: MessagingDelivery) -> Dict[str, Any]:
    """Persist a delivery record.

    *delivery* may be a ``MessagingDelivery`` model, which is dumped to a
    dict. Despite the annotation it may also be a plain dict, which is then
    used and modified in place. Missing or falsy ``mandateId`` /
    ``featureInstanceId`` values are replaced from the current context.

    Returns:
        The record dict returned by ``self.db.recordCreate``.

    Raises:
        ValueError: If the database returns no record with an ``id``.
    """
    if isinstance(delivery, MessagingDelivery):
        payload = delivery.model_dump()
    else:
        payload = delivery

    # Keys share their names with the context attributes. getattr is only
    # evaluated when the stored value is missing or falsy.
    for key in ("mandateId", "featureInstanceId"):
        if not payload.get(key):
            payload[key] = getattr(self, key)

    record = self.db.recordCreate(MessagingDelivery, payload)
    if not (record and record.get("id")):
        raise ValueError("Failed to create delivery record")
    return record
|
|
|
|
def updateDelivery(self, deliveryId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
    """Apply *updateData* to the delivery record *deliveryId*.

    No RBAC check is performed here. The same *updateData* dict object is
    returned, not a re-read record.
    """
    changes = updateData
    # NOTE(review): the connector's return value is discarded, so a failed
    # modification is not reported to the caller.
    self.db.recordModify(MessagingDelivery, deliveryId, changes)
    return changes
|
|
|
|
def getDeliveries(
    self,
    subscriptionId: Optional[str] = None,
    userId: Optional[str] = None,
    pagination: Optional[PaginationParams] = None
) -> Union[List[MessagingDelivery], PaginatedResult]:
    """List delivery records visible to the current user.

    The read is RBAC-filtered and scoped to the current mandate. This matches
    ``createDelivery``, which stamps ``mandateId`` from the same context, and
    the subscription/registration getters, which pass ``mandateId``.

    Args:
        subscriptionId: If truthy, only deliveries for this subscription.
        userId: If truthy, only deliveries for this user.
        pagination: When ``None``, a plain list is returned. Otherwise its
            ``filters`` / ``sort`` are applied when set, and the 1-based
            ``page`` of ``pageSize`` items is returned.

    Returns:
        A list of ``MessagingDelivery`` models, or a ``PaginatedResult``.
        Any exception is logged and yields an empty list or an empty
        ``PaginatedResult``.
    """
    try:
        recordFilter = {}
        if subscriptionId:
            recordFilter["subscriptionId"] = subscriptionId
        if userId:
            recordFilter["userId"] = userId

        filteredDeliveries = getRecordsetWithRBAC(self.db,
            MessagingDelivery,
            self.currentUser,
            recordFilter=recordFilter if recordFilter else None,
            # Previously omitted; without it reads were not mandate-scoped
            # like the sibling messaging getters.
            mandateId=self.mandateId
        )

        if pagination is None:
            return [MessagingDelivery(**delivery) for delivery in filteredDeliveries]

        if pagination.filters:
            filteredDeliveries = self._applyFilters(filteredDeliveries, pagination.filters)

        if pagination.sort:
            filteredDeliveries = self._applySorting(filteredDeliveries, pagination.sort)

        totalItems = len(filteredDeliveries)
        totalPages = math.ceil(totalItems / pagination.pageSize) if totalItems > 0 else 0

        startIdx = (pagination.page - 1) * pagination.pageSize
        endIdx = startIdx + pagination.pageSize
        pagedDeliveries = filteredDeliveries[startIdx:endIdx]

        items = [MessagingDelivery(**delivery) for delivery in pagedDeliveries]

        return PaginatedResult(
            items=items,
            totalItems=totalItems,
            totalPages=totalPages
        )
    except Exception as e:
        logger.error(f"Error getting deliveries: {str(e)}")
        if pagination is None:
            return []
        return PaginatedResult(items=[], totalItems=0, totalPages=0)
|
|
|
|
def getDelivery(self, deliveryId: str) -> Optional[MessagingDelivery]:
    """Return the delivery with record ``id`` *deliveryId* if the current user may see it.

    The read is RBAC-filtered and scoped to the current mandate. This matches
    ``createDelivery``, which stamps ``mandateId`` from the same context, and
    the subscription/registration getters. Returns ``None`` when the
    delivery is missing or not visible.
    """
    filteredDeliveries = getRecordsetWithRBAC(self.db,
        MessagingDelivery,
        self.currentUser,
        recordFilter={"id": deliveryId},
        # Previously omitted; added for mandate isolation consistent with
        # the other messaging getters.
        mandateId=self.mandateId
    )
    return MessagingDelivery(**filteredDeliveries[0]) if filteredDeliveries else None
|
|
|
|
|
|
def getInterface(currentUser: Optional[User] = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> 'ComponentObjects':
    """Return the module-wide ``ComponentObjects`` instance, optionally bound to a user.

    The instance is created lazily on first use and cached in
    ``_instancesManagement["default"]``. When *currentUser* is truthy, its
    context is applied via ``setUserContext``. Otherwise the cached instance
    is returned as-is.

    Args:
        currentUser: The authenticated user, or ``None``.
        mandateId: Mandate ID from the request context (X-Mandate-Id header).
            Converted to ``str``; falsy values become ``None``.
        featureInstanceId: Feature instance ID from the request context
            (X-Feature-Instance-Id header). Converted to ``str``; falsy
            values become ``None``.

    NOTE(review): every caller receives the same object, and
    ``setUserContext`` mutates it. Concurrent requests for different users
    may therefore observe each other's context. A call without
    *currentUser* returns the instance still carrying whatever context was
    set last. Confirm both are intended.
    """
    normalizedMandateId = str(mandateId) if mandateId else None
    normalizedFeatureInstanceId = str(featureInstanceId) if featureInstanceId else None

    try:
        interface = _instancesManagement["default"]
    except KeyError:
        interface = _instancesManagement["default"] = ComponentObjects()

    if not currentUser:
        logger.info("Returning interface without user context")
        return interface

    interface.setUserContext(
        currentUser,
        mandateId=normalizedMandateId,
        featureInstanceId=normalizedFeatureInstanceId,
    )
    return interface