# Listing metadata (extraction residue): 2308 lines, no EOL, 98 KiB, Python
# Copyright (c) 2025 Patrick Motsch
|
||
# All rights reserved.
|
||
"""
Interface to Management database and AI Connectors.
Uses the database connector from modules.connectors.connectorDbPostgre for data access with added language support.
"""
|
||
|
||
import os
|
||
import logging
|
||
import base64
|
||
import hashlib
|
||
import math
|
||
import mimetypes
|
||
from typing import Dict, Any, List, Optional, Union
|
||
|
||
from modules.connectors.connectorDbPostgre import DatabaseConnector, getCachedConnector
|
||
from modules.shared.dbRegistry import registerDatabase
|
||
from modules.interfaces.interfaceRbac import getRecordsetWithRBAC, getRecordsetPaginatedWithRBAC
|
||
from modules.security.rbac import RbacClass
|
||
from modules.datamodels.datamodelRbac import AccessRuleContext
|
||
from modules.datamodels.datamodelUam import AccessLevel
|
||
from modules.datamodels.datamodelFiles import FilePreview, FileItem, FileData
|
||
from modules.datamodels.datamodelFileFolder import FileFolder
|
||
from modules.datamodels.datamodelUtils import Prompt
|
||
from modules.datamodels.datamodelMessaging import (
|
||
MessagingSubscription,
|
||
MessagingSubscriptionRegistration,
|
||
MessagingDelivery,
|
||
MessagingChannel
|
||
)
|
||
from modules.datamodels.datamodelUam import User, Mandate
|
||
from modules.shared.configuration import APP_CONFIG
|
||
from modules.shared.timeUtils import getUtcTimestamp
|
||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResult
|
||
|
||
# Module-level logger, named after this module
logger = logging.getLogger(__name__)

# Name of the management database; handed to registerDatabase() (modules.shared.dbRegistry) at import time
managementDatabase = "poweron_management"
registerDatabase(managementDatabase)

# Singleton factory for Management instances with AI service per context
_instancesManagement = {}
|
||
|
||
# Custom exceptions for file handling
|
||
class FileError(Exception):
    """Root of the file-handling exception hierarchy defined in this module."""


class FileNotFoundError(FileError):
    """Raised when a requested file cannot be found.

    NOTE(review): inside this module the name shadows the builtin
    ``FileNotFoundError``; this class derives from ``FileError``, not ``OSError``.
    """


class FileStorageError(FileError):
    """Raised when storing a file fails."""


class FilePermissionError(FileError):
    """Raised when a file operation is rejected for permission reasons."""


class FileDeletionError(FileError):
    """Raised when deleting a file fails."""
|
||
|
||
class ComponentObjects:
    """
    Interface to Management database and AI Connectors.

    Data access goes through the connector returned by getCachedConnector
    (modules.connectors.connectorDbPostgre), with added language support.
    A user context must be supplied via setUserContext() before the
    permission-checked operations can succeed.
    """
|
||
|
||
def __init__(self):
|
||
"""Initializes the Management Interface."""
|
||
# Initialize variables first
|
||
self.currentUser: Optional[User] = None
|
||
self.userId: Optional[str] = None
|
||
self.rbac: Optional[RbacClass] = None # RBAC interface
|
||
|
||
# Initialize database
|
||
self._initializeDatabase()
|
||
|
||
# Initialize standard records if needed
|
||
self._initRecords()
|
||
|
||
def setUserContext(self, currentUser: User, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None):
    """Sets the user context for the interface.

    Args:
        currentUser: The authenticated user. A falsy value leaves the
            interface untouched (only an info message is logged).
        mandateId: The mandate ID from RequestContext (X-Mandate-Id header)
        featureInstanceId: The feature instance ID from RequestContext (X-Feature-Instance-Id header)

    Raises:
        ValueError: If the user has no id. Nothing on the interface is
            changed in that case.
    """
    if not currentUser:
        logger.info("Initializing interface without user context")
        return

    # Validate before touching any state, so a rejected user cannot leave the
    # interface half-configured (new user, old RBAC / database context)
    userId = currentUser.id
    if not userId:
        raise ValueError("Invalid user context: id is required")

    self.currentUser = currentUser  # Store User object directly
    self.userId = userId
    # Use mandateId from parameter (Request-Context), not from user object
    self.mandateId = mandateId
    self.featureInstanceId = featureInstanceId

    # Add language settings
    self.userLanguage = currentUser.language  # Default user language

    # Initialize RBAC interface; it needs the DbApp connection for AccessRule queries
    from modules.security.rootAccess import getRootDbAppConnector
    dbApp = getRootDbAppConnector()
    self.rbac = RbacClass(self.db, dbApp=dbApp)

    # Update database context
    self.db.updateContext(self.userId)

    logger.debug(f"User context set: userId={self.userId}")
|
||
|
||
def __del__(self):
|
||
"""Cleanup method to close database connection."""
|
||
if hasattr(self, 'db') and self.db is not None:
|
||
try:
|
||
self.db.close()
|
||
except Exception as e:
|
||
logger.error(f"Error closing database connection: {e}")
|
||
|
||
logger.debug(f"User context set: userId={self.userId}")
|
||
|
||
def _initializeDatabase(self):
    """Open the management database connection and keep it on ``self.db``.

    Connection settings are read from APP_CONFIG; the database name is the
    module-level ``managementDatabase``. Any failure is logged and re-raised.
    """
    try:
        # Settings are gathered in one mapping and handed over as keyword arguments
        connectionSettings = {
            "dbHost": APP_CONFIG.get("DB_HOST", "_no_config_default_data"),
            "dbDatabase": managementDatabase,
            "dbUser": APP_CONFIG.get("DB_USER"),
            "dbPassword": APP_CONFIG.get("DB_PASSWORD_SECRET"),
            "dbPort": int(APP_CONFIG.get("DB_PORT", 5432)),
            # userId may not exist yet if this runs before the attribute is declared
            "userId": getattr(self, 'userId', None),
        }
        self.db = getCachedConnector(**connectionSettings)

        logger.info("Database initialized successfully")
    except Exception as e:
        logger.error(f"Failed to initialize database: {str(e)}")
        raise
|
||
|
||
def _separateObjectFields(self, model_class, data: Dict[str, Any]) -> tuple[Dict[str, Any], Dict[str, Any]]:
|
||
"""Separate simple fields from object fields based on Pydantic model structure."""
|
||
simpleFields = {}
|
||
objectFields = {}
|
||
|
||
# Get field information from the Pydantic model
|
||
modelFields = model_class.model_fields
|
||
|
||
for fieldName, value in data.items():
|
||
# Check if this field should be stored as JSONB in the database
|
||
if fieldName in modelFields:
|
||
fieldInfo = modelFields[fieldName]
|
||
# Pydantic v2 only
|
||
fieldType = fieldInfo.annotation
|
||
|
||
# Check if this is a JSONB field (Dict, List, or complex types)
|
||
# Purely type-based detection - no hardcoded field names
|
||
if (fieldType == dict or
|
||
fieldType == list or
|
||
(hasattr(fieldType, '__origin__') and fieldType.__origin__ in (dict, list))):
|
||
# Store as JSONB - include in simple_fields for database storage
|
||
simpleFields[fieldName] = value
|
||
elif isinstance(value, (str, int, float, bool, type(None))):
|
||
# Simple scalar types
|
||
simpleFields[fieldName] = value
|
||
else:
|
||
# Complex objects that should be filtered out
|
||
objectFields[fieldName] = value
|
||
else:
|
||
if isinstance(value, (str, int, float, bool, type(None))):
|
||
simpleFields[fieldName] = value
|
||
else:
|
||
objectFields[fieldName] = value
|
||
|
||
return simpleFields, objectFields
|
||
|
||
def _initRecords(self):
    """Seed the standard records (prompts, UI language sets) if they are missing.

    Best effort: a failure is logged and swallowed so that the interface can
    still be constructed when seeding does not work.
    """
    try:
        # Seeding steps, in order; further record initializations belong here
        self._initializeStandardPrompts()
        self._seedUiLanguageSetsIfEmpty()
    except Exception as e:
        # Intentionally not re-raised - see docstring
        logger.error(f"Failed to initialize standard records: {str(e)}")
        return

    logger.info("Standard records initialized successfully")
|
||
|
||
def _seedUiLanguageSetsIfEmpty(self) -> None:
    """Create UiLanguageSet rows from ``ui_language_seed.json`` when none exist.

    The seed file is looked up at ``<parent of this file's directory>/migration/seedData``.
    Rows without an ``entries`` list get one built from their ``keys`` mapping.
    Any error is logged and swallowed.
    """
    try:
        import json
        from pathlib import Path

        from modules.datamodels.datamodelUiLanguage import UiLanguageSet

        # Nothing to do when at least one row is already present
        if self.db.getRecordset(UiLanguageSet):
            return

        seedDir = Path(__file__).resolve().parent.parent / "migration" / "seedData"
        seedPath = seedDir / "ui_language_seed.json"
        if not seedPath.is_file():
            logger.warning("ui_language_seed.json not found, skipping UI i18n seed")
            return

        seedRows = json.loads(seedPath.read_text(encoding="utf-8"))
        timestamp = getUtcTimestamp()

        for seedRow in seedRows:
            entries = seedRow.get("entries")
            if not isinstance(entries, list):
                # Fall back to the flat key -> value mapping
                keyMap = seedRow.get("keys") or {}
                entries = [
                    {"context": "ui", "key": entryKey, "value": entryValue}
                    for entryKey, entryValue in keyMap.items()
                ]

            self.db.recordCreate(UiLanguageSet, {
                "id": seedRow["id"],
                "label": seedRow["label"],
                "entries": entries,
                "status": seedRow.get("status") or "complete",
                "isDefault": bool(seedRow.get("isDefault", False)),
                "sysCreatedAt": timestamp,
                "sysModifiedBy": None,
                "sysCreatedBy": None,
                "sysModifiedAt": timestamp,
            })

        logger.info("Seeded UiLanguageSet rows from ui_language_seed.json")
    except Exception as e:
        logger.error(f"UI i18n seed failed: {e}")
|
||
|
||
def _initializeStandardPrompts(self):
    """Initializes standard prompts if they don't exist yet.

    Does nothing when any Prompt record exists, when no initial mandate ID is
    found, or when the "admin" user cannot be loaded. Otherwise the interface
    temporarily switches to the admin user, creates the prompts and then
    restores the caller's context (user, mandate ID, feature instance ID).

    Errors are logged, never raised. The context is restored only if it was
    actually switched, so a failure before the switch leaves it untouched.
    """
    # Snapshot the caller's context up front so it can be restored verbatim
    previousUser = self.currentUser
    previousMandateId = getattr(self, "mandateId", None)
    previousFeatureInstanceId = getattr(self, "featureInstanceId", None)
    contextSwitched = False

    try:
        # Check if any prompts exist
        if self.db.getRecordset(Prompt):
            logger.info("Prompts already exist, skipping initialization")
            return

        # Get the root interface to access the initial mandate ID
        from modules.security.rootAccess import getRootUser
        from modules.interfaces.interfaceDbApp import getInterface
        rootInterface = getInterface(getRootUser())

        mandateId = rootInterface.getInitialId(Mandate)
        if not mandateId:
            logger.error("No initial mandate ID found")
            return

        # Get root user for initialization
        rootUser = rootInterface.getUserByUsername("admin")
        if not rootUser:
            logger.error("Root user not found for initialization")
            return

        # Flag first: setUserContext() mutates state before it can fail
        contextSwitched = True
        self.setUserContext(rootUser)

        # (name, content) of every standard prompt
        # NOTE(review): "scheiben" in the E-Mail prompt looks like a typo for
        # "verschieben"; left unchanged because it is seeded user-facing data.
        promptDefinitions = [
            (
                "Market Research",
                "Research the current market trends and developments in [TOPIC]. Collect information about leading companies, innovative products or services, and current challenges. Present the results in a structured overview with relevant data and sources.",
            ),
            (
                "Data Analysis",
                "Analyze the attached dataset on [TOPIC] and identify the most important trends, patterns, and anomalies. Perform statistical calculations to support your findings. Present the results in a clearly structured analysis and draw relevant conclusions.",
            ),
            (
                "Meeting Protocol",
                "Create a detailed protocol of our meeting on [TOPIC]. Capture all discussed points, decisions made, and agreed measures. Structure the protocol clearly with agenda items, participant list, and clear responsibilities for follow-up actions.",
            ),
            (
                "UI/UX Design",
                "Develop a UI/UX design concept for [APPLICATION/WEBSITE]. Consider the target audience, main functions, and brand identity. Describe the visual design, navigation, interaction patterns, and information architecture. Explain how the design optimizes user-friendliness and user experience.",
            ),
            (
                "Primzahlen",
                "Gib mir die ersten 1000 Primzahlen.",
            ),
            (
                "E-Mail",
                "Bereite mir eine formelle E-Mail an peter.muster@domain.com vor, um meinen Termin von 10 Uhr auf Freitag zu scheiben.",
            ),
        ]
        standardPrompts = [
            Prompt(name=promptName, content=promptContent, mandateId=mandateId)
            for promptName, promptContent in promptDefinitions
        ]

        # Create prompts
        for prompt in standardPrompts:
            self.db.recordCreate(Prompt, prompt)
            logger.info(f"Created standard prompt: {prompt.name}")

    except Exception as e:
        logger.error(f"Error initializing standard prompts: {str(e)}")

    finally:
        # Single restore path for success and failure alike
        if contextSwitched:
            if previousUser:
                self.setUserContext(
                    previousUser,
                    mandateId=previousMandateId,
                    featureInstanceId=previousFeatureInstanceId,
                )
            else:
                self.currentUser = None
                self.userId = None
                self.db.updateContext("")  # Reset database context
|
||
|
||
|
||
def checkRbacPermission(
    self,
    modelClass: type,
    operation: str,
    recordId: Optional[str] = None
) -> bool:
    """
    Check RBAC permission for a specific operation on a table.

    Args:
        modelClass: Pydantic model class for the table (its ``__name__`` is the table name)
        operation: Operation to check ('create', 'update', 'delete', 'read')
        recordId: Optional record ID; accepted but not evaluated by this method

    Returns:
        True when the access level for the operation is anything other than
        AccessLevel.NONE. False for an unknown operation, or when no RBAC
        interface / current user is set.
    """
    if not (self.rbac and self.currentUser):
        return False

    tableName = modelClass.__name__
    from modules.interfaces.interfaceRbac import buildDataObjectKey
    permissions = self.rbac.getUserPermissions(
        self.currentUser,
        AccessRuleContext.DATA,
        buildDataObjectKey(tableName),
        mandateId=self.mandateId,
        featureInstanceId=self.featureInstanceId
    )

    # The operation name doubles as the attribute name on the permissions object
    if operation not in ("create", "update", "delete", "read"):
        return False
    return getattr(permissions, operation) != AccessLevel.NONE
|
||
|
||
def _applyFilters(self, records: List[Dict[str, Any]], filters: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||
"""
|
||
Apply filter criteria to records.
|
||
|
||
Supports:
|
||
- General search: {"search": "text"} - searches across all text fields
|
||
- Field-specific filters:
|
||
- Simple: {"status": "running"} - equals match
|
||
- With operator: {"status": {"operator": "equals", "value": "running"}}
|
||
- Operators: equals, contains, gt, gte, lt, lte, in, notIn, startsWith, endsWith
|
||
|
||
Args:
|
||
records: List of record dictionaries to filter
|
||
filters: Filter criteria dictionary
|
||
|
||
Returns:
|
||
Filtered list of records
|
||
"""
|
||
if not filters or not records:
|
||
return records
|
||
|
||
filtered = []
|
||
|
||
for record in records:
|
||
matches = True
|
||
|
||
# Handle general search across text fields
|
||
if "search" in filters:
|
||
search_term = str(filters["search"]).lower()
|
||
if search_term:
|
||
# Search in all string fields
|
||
found = False
|
||
for key, value in record.items():
|
||
if isinstance(value, str) and search_term in value.lower():
|
||
found = True
|
||
break
|
||
elif isinstance(value, (int, float)) and search_term in str(value):
|
||
found = True
|
||
break
|
||
if not found:
|
||
matches = False
|
||
|
||
# Handle field-specific filters
|
||
for field_name, filter_value in filters.items():
|
||
if field_name == "search":
|
||
continue # Already handled above
|
||
|
||
if field_name not in record:
|
||
matches = False
|
||
break
|
||
|
||
record_value = record.get(field_name)
|
||
|
||
# Handle simple value (equals operator)
|
||
if not isinstance(filter_value, dict):
|
||
if str(record_value).lower() != str(filter_value).lower():
|
||
matches = False
|
||
break
|
||
continue
|
||
|
||
# Handle filter with operator
|
||
operator = filter_value.get("operator", "equals")
|
||
filter_val = filter_value.get("value")
|
||
|
||
if operator in ["equals", "eq"]:
|
||
if str(record_value).lower() != str(filter_val).lower():
|
||
matches = False
|
||
break
|
||
|
||
elif operator == "contains":
|
||
record_str = str(record_value).lower() if record_value is not None else ""
|
||
filter_str = str(filter_val).lower() if filter_val is not None else ""
|
||
if filter_str not in record_str:
|
||
matches = False
|
||
break
|
||
|
||
elif operator == "startsWith":
|
||
record_str = str(record_value).lower() if record_value is not None else ""
|
||
filter_str = str(filter_val).lower() if filter_val is not None else ""
|
||
if not record_str.startswith(filter_str):
|
||
matches = False
|
||
break
|
||
|
||
elif operator == "endsWith":
|
||
record_str = str(record_value).lower() if record_value is not None else ""
|
||
filter_str = str(filter_val).lower() if filter_val is not None else ""
|
||
if not record_str.endswith(filter_str):
|
||
matches = False
|
||
break
|
||
|
||
elif operator == "gt":
|
||
try:
|
||
record_num = float(record_value) if record_value is not None else float('-inf')
|
||
filter_num = float(filter_val) if filter_val is not None else float('-inf')
|
||
if record_num <= filter_num:
|
||
matches = False
|
||
break
|
||
except (ValueError, TypeError):
|
||
matches = False
|
||
break
|
||
|
||
elif operator == "gte":
|
||
try:
|
||
record_num = float(record_value) if record_value is not None else float('-inf')
|
||
filter_num = float(filter_val) if filter_val is not None else float('-inf')
|
||
if record_num < filter_num:
|
||
matches = False
|
||
break
|
||
except (ValueError, TypeError):
|
||
matches = False
|
||
break
|
||
|
||
elif operator == "lt":
|
||
try:
|
||
record_num = float(record_value) if record_value is not None else float('inf')
|
||
filter_num = float(filter_val) if filter_val is not None else float('inf')
|
||
if record_num >= filter_num:
|
||
matches = False
|
||
break
|
||
except (ValueError, TypeError):
|
||
matches = False
|
||
break
|
||
|
||
elif operator == "lte":
|
||
try:
|
||
record_num = float(record_value) if record_value is not None else float('inf')
|
||
filter_num = float(filter_val) if filter_val is not None else float('inf')
|
||
if record_num > filter_num:
|
||
matches = False
|
||
break
|
||
except (ValueError, TypeError):
|
||
matches = False
|
||
break
|
||
|
||
elif operator == "in":
|
||
if not isinstance(filter_val, list):
|
||
filter_val = [filter_val]
|
||
if record_value not in filter_val:
|
||
matches = False
|
||
break
|
||
|
||
elif operator == "notIn":
|
||
if not isinstance(filter_val, list):
|
||
filter_val = [filter_val]
|
||
if record_value in filter_val:
|
||
matches = False
|
||
break
|
||
|
||
else:
|
||
# Unknown operator - default to equals
|
||
if record_value != filter_val:
|
||
matches = False
|
||
break
|
||
|
||
if matches:
|
||
filtered.append(record)
|
||
|
||
return filtered
|
||
|
||
def _applySorting(self, records: List[Dict[str, Any]], sortFields: List[Any]) -> List[Dict[str, Any]]:
|
||
"""Apply multi-level sorting to records using stable sort (sorts from least to most significant field)."""
|
||
if not sortFields:
|
||
return records
|
||
|
||
# Start with a copy to avoid modifying original
|
||
sortedRecords = list(records)
|
||
|
||
# Sort from least significant to most significant field (reverse order)
|
||
# Python's sort is stable, so this creates proper multi-level sorting
|
||
for sortField in reversed(sortFields):
|
||
# Handle both dict and object formats
|
||
if isinstance(sortField, dict):
|
||
fieldName = sortField.get("field")
|
||
direction = sortField.get("direction", "asc")
|
||
else:
|
||
fieldName = getattr(sortField, "field", None)
|
||
direction = getattr(sortField, "direction", "asc")
|
||
|
||
if not fieldName:
|
||
continue
|
||
|
||
isDesc = (direction == "desc")
|
||
|
||
def sortKey(record):
|
||
value = record.get(fieldName)
|
||
# Handle None values - place them at the end for both directions
|
||
if value is None:
|
||
# Use a special value that sorts last
|
||
return (1, "") # (is_none_flag, empty_value) - sorts after (0, ...)
|
||
else:
|
||
# Return tuple with type indicator for proper comparison
|
||
if isinstance(value, (int, float)):
|
||
return (0, value)
|
||
elif isinstance(value, str):
|
||
return (0, value)
|
||
elif isinstance(value, bool):
|
||
return (0, value)
|
||
else:
|
||
return (0, str(value))
|
||
|
||
# Sort with reverse parameter for descending
|
||
sortedRecords.sort(key=sortKey, reverse=isDesc)
|
||
|
||
return sortedRecords
|
||
|
||
# Utilities
|
||
|
||
def getInitialId(self, model_class: type) -> Optional[str]:
    """Return the initial ID for the table of ``model_class``, as reported by the database connector."""
    initialId = self.db.getInitialId(model_class)
    return initialId
|
||
|
||
def _parse_size_string(self, size_str: str) -> Optional[int]:
|
||
"""
|
||
Parse a formatted size string (e.g., "2.13 MB", "1.5 GB") to bytes.
|
||
|
||
Args:
|
||
size_str: Formatted size string like "2.13 MB", "1.5 GB", "500 KB"
|
||
|
||
Returns:
|
||
Size in bytes, or None if parsing fails
|
||
"""
|
||
try:
|
||
size_str = size_str.strip().upper()
|
||
# Remove common separators and spaces
|
||
size_str = size_str.replace(",", "").replace(" ", "")
|
||
|
||
# Extract number and unit - handle both "MB" and "M" formats
|
||
import re
|
||
# Match: number (with optional decimal) followed by optional unit (K/M/G/T with optional B)
|
||
match = re.match(r"^([\d.]+)([KMGT]?B?)$", size_str)
|
||
if not match:
|
||
return None
|
||
|
||
number = float(match.group(1))
|
||
unit = match.group(2) or "B"
|
||
|
||
# Normalize unit (handle "M" as "MB", "K" as "KB", etc.)
|
||
if len(unit) == 1 and unit in "KMGT":
|
||
unit = unit + "B"
|
||
|
||
# Convert to bytes
|
||
multipliers = {
|
||
"B": 1,
|
||
"KB": 1024,
|
||
"MB": 1024 * 1024,
|
||
"GB": 1024 * 1024 * 1024,
|
||
"TB": 1024 * 1024 * 1024 * 1024,
|
||
}
|
||
|
||
multiplier = multipliers.get(unit, 1)
|
||
return int(number * multiplier)
|
||
except Exception:
|
||
return None
|
||
|
||
|
||
|
||
# Prompt methods
|
||
|
||
def _isSysAdmin(self) -> bool:
|
||
"""Check if the current user has the isSysAdmin flag (infrastructure operator)."""
|
||
return bool(getattr(self.currentUser, 'isSysAdmin', False))
|
||
|
||
def _enrichPromptsWithPermissions(self, prompts: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||
"""Enrich prompts with row-level _permissions based on ownership and isSystem flag.
|
||
|
||
- SysAdmin: canUpdate=True, canDelete=True on all prompts
|
||
- Regular user on own prompts: canUpdate=True, canDelete=True
|
||
- Regular user on system prompts: canUpdate=False, canDelete=False (read-only)
|
||
"""
|
||
isSysAdmin = self._isSysAdmin()
|
||
for prompt in prompts:
|
||
isOwner = prompt.get("sysCreatedBy") == self.userId
|
||
prompt["_permissions"] = {
|
||
"canUpdate": isOwner or isSysAdmin,
|
||
"canDelete": isOwner or isSysAdmin
|
||
}
|
||
return prompts
|
||
|
||
def _getPromptsForUser(self) -> List[Dict[str, Any]]:
    """Return the prompt records visible to the current user.

    Visibility rules:
    - SysAdmin: every prompt
    - Regular user: prompts created by the user (``sysCreatedBy``) followed by
      system prompts (``isSystem=True``) that are not already in the result
    """
    if self._isSysAdmin():
        return self.db.getRecordset(Prompt)

    ownPrompts = self.db.getRecordset(Prompt, recordFilter={"sysCreatedBy": self.userId})
    systemPrompts = self.db.getRecordset(Prompt, recordFilter={"isSystem": True})

    # Keyed by id so a prompt that is both own and system appears only once;
    # the user's own record wins and insertion order is kept
    visibleById = {record["id"]: record for record in ownPrompts}
    for record in systemPrompts:
        visibleById.setdefault(record["id"], record)

    return list(visibleById.values())
|
||
|
||
def getAllPrompts(self, pagination: Optional[PaginationParams] = None) -> Union[List[Prompt], PaginatedResult]:
    """
    Return the prompts visible to the current user, optionally paginated.

    - SysAdmin: sees ALL prompts
    - Regular user: sees own prompts + system prompts (isSystem=True)
    - Every record carries row-level _permissions (canUpdate / canDelete)

    Filtering, sorting and paging happen in memory: the visibility rules of
    _getPromptsForUser and the per-row enrichment need the full record set,
    so db.getRecordsetPaginated() cannot be used.

    Args:
        pagination: Optional pagination parameters. If None, returns all items.

    Returns:
        If pagination is None: List[Prompt]
        If pagination is provided: PaginatedResult with items and metadata
        On any error the problem is logged and an empty list / empty
        PaginatedResult is returned instead of raising.
    """
    try:
        visiblePrompts = self._enrichPromptsWithPermissions(self._getPromptsForUser())

        # No pagination requested: hand back everything
        if pagination is None:
            return [Prompt(**record) for record in visiblePrompts]

        if pagination.filters:
            visiblePrompts = self._applyFilters(visiblePrompts, pagination.filters)

        if pagination.sort:
            visiblePrompts = self._applySorting(visiblePrompts, pagination.sort)

        # Totals are taken after filtering
        totalItems = len(visiblePrompts)
        totalPages = 0
        if totalItems > 0:
            totalPages = math.ceil(totalItems / pagination.pageSize)

        # Pages are 1-based
        pageStart = (pagination.page - 1) * pagination.pageSize
        pageRecords = visiblePrompts[pageStart:pageStart + pagination.pageSize]

        return PaginatedResult(
            items=[Prompt(**record) for record in pageRecords],
            totalItems=totalItems,
            totalPages=totalPages
        )

    except Exception as e:
        logger.error(f"Error getting prompts: {str(e)}")
        if pagination is not None:
            return PaginatedResult(items=[], totalItems=0, totalPages=0)
        return []
|
||
|
||
def getPrompt(self, promptId: str) -> Optional[Prompt]:
    """Return the prompt with the given ID if the current user may see it.

    Visibility: SysAdmin sees all; a regular user sees prompts they created
    and system prompts. Returns None when the prompt does not exist or is
    not visible.
    """
    matches = self.db.getRecordset(Prompt, recordFilter={"id": promptId})
    if not matches:
        return None

    record = matches[0]
    if self._isSysAdmin():
        return Prompt(**record)

    ownedByCaller = record.get("sysCreatedBy") == self.userId
    if ownedByCaller or record.get("isSystem", False):
        return Prompt(**record)

    return None
|
||
|
||
def createPrompt(self, promptData: Dict[str, Any]) -> Dict[str, Any]:
    """Create a new prompt record.

    Args:
        promptData: Field values for the new prompt.

    Returns:
        The created record as returned by the database connector.

    Raises:
        PermissionError: If the RBAC check for 'create' on Prompt fails.
        ValueError: If the connector returns no record or a record without an id.
    """
    mayCreate = self.checkRbacPermission(Prompt, "create")
    if not mayCreate:
        raise PermissionError("No permission to create prompts")

    createdRecord = self.db.recordCreate(Prompt, promptData)
    if createdRecord and createdRecord.get("id"):
        return createdRecord

    raise ValueError("Failed to create prompt record")
|
||
|
||
def updatePrompt(self, promptId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
    """Updates a prompt. Rules:
    - SysAdmin: can update any prompt (including system prompts)
    - Regular user: can only update prompts they created; an ``isSystem``
      entry in the update is ignored for them

    Args:
        promptId: ID of the prompt to update.
        updateData: Fields to change. The caller's dictionary is not modified.

    Returns:
        The updated prompt as a dictionary (``model_dump()``).

    Raises:
        PermissionError: If the user is neither owner nor SysAdmin.
        ValueError: If the prompt is not found / not visible, or the update
            fails for any other reason (original error attached as cause).
    """
    try:
        # Get prompt (visibility-checked)
        prompt = self.getPrompt(promptId)
        if not prompt:
            raise ValueError(f"Prompt {promptId} not found")

        # Permission check: owner or SysAdmin
        isSysAdmin = self._isSysAdmin()
        isOwner = (getattr(prompt, 'sysCreatedBy', None) == self.userId)
        if not isSysAdmin and not isOwner:
            raise PermissionError(f"No permission to update prompt {promptId}")

        # Work on a copy so the caller's dictionary stays untouched
        changes = dict(updateData)

        # Regular users cannot set isSystem flag
        if not isSysAdmin:
            changes.pop('isSystem', None)

        self.db.recordModify(Prompt, promptId, changes)

        # Re-read so the returned data reflects what is stored
        updatedPrompt = self.getPrompt(promptId)
        if not updatedPrompt:
            raise ValueError("Failed to retrieve updated prompt")

        return updatedPrompt.model_dump()

    except PermissionError:
        raise
    except Exception as e:
        logger.error(f"Error updating prompt: {str(e)}")
        raise ValueError(f"Failed to update prompt: {str(e)}") from e
|
||
|
||
def deletePrompt(self, promptId: str) -> bool:
    """Deletes a prompt. Rules:
    - SysAdmin: can delete any prompt (including system prompts)
    - Regular user: can only delete prompts they created

    Returns:
        False when the prompt does not exist or is not visible to the user,
        otherwise the result of the database delete.

    Raises:
        PermissionError: If the prompt is visible but the user is neither
            owner nor SysAdmin.
    """
    # Visibility-checked lookup
    target = self.getPrompt(promptId)
    if not target:
        return False

    ownedByCaller = getattr(target, 'sysCreatedBy', None) == self.userId
    if not (self._isSysAdmin() or ownedByCaller):
        raise PermissionError(f"No permission to delete prompt {promptId}")

    return self.db.recordDelete(Prompt, promptId)
|
||
|
||
# File Utilities
|
||
|
||
def checkForDuplicateFile(self, fileHash: str, fileName: str) -> Optional[FileItem]:
    """Look for an existing file of the current user with the same hash AND
    fileName **within the same scope**.

    Duplicate = same user + same fileHash + same fileName + same scope + RBAC-visible.
    The scope filter is featureInstanceId when one is set, otherwise
    mandateId when one is set. Same hash with a different name is allowed
    (intentional copy by user).

    RBAC parity contract: this method must NEVER return a FileItem that
    ``getFile()`` would not return for the current user. Otherwise callers
    (``saveUploadedFile`` / ``createFile``) hand back an id that the next
    ``updateFile`` / ``getFile`` rejects with ``File with ID ... not found``
    (the "ghost duplicate" symptom, seen when the interface is initialised
    without a ``featureInstanceId`` but a same-hash+name file exists in
    another featureInstance under the same mandate). The candidate is
    therefore cross-checked through ``getFile``; if that returns nothing,
    the result is "no duplicate" and the caller creates a per-scope copy.

    Returns:
        The RBAC-visible FileItem, or None when there is no user context, no
        match, no FileData for the match, or the match is not RBAC-visible.
    """
    if not self.userId:
        return None

    recordFilter: dict = {
        "sysCreatedBy": self.userId,
        "fileHash": fileHash,
        "fileName": fileName,
    }
    # Narrow to the current scope: feature instance takes precedence over mandate
    if self.featureInstanceId:
        recordFilter["featureInstanceId"] = self.featureInstanceId
    elif self.mandateId:
        recordFilter["mandateId"] = self.mandateId

    candidates = self.db.getRecordset(FileItem, recordFilter=recordFilter)
    if not candidates:
        return None

    # Only the first match is considered
    fileId = candidates[0]["id"]

    # A FileItem without its FileData row is not a usable duplicate
    if not self.db.getRecordset(FileData, recordFilter={"id": fileId}):
        logger.warning(f"Duplicate FileItem {fileId} found but FileData missing — treating as new file")
        return None

    rbacVisible = self.getFile(fileId)
    if rbacVisible is not None:
        return rbacVisible

    logger.info(
        f"Duplicate FileItem {fileId} ('{fileName}', hash {fileHash[:12]}...) found via "
        f"sysCreatedBy+hash+name match but is not RBAC-visible in current scope "
        f"(mandateId={self.mandateId or '-'}, featureInstanceId={self.featureInstanceId or '-'}). "
        f"Treating as no-duplicate so a fresh per-scope copy gets created."
    )
    return None
|
||
|
||
# Class-level cache — built once from the ExtractorRegistry
# Both stay None until _ensureMimeMaps() has run; afterwards they hold the
# extension -> MIME type map and the set of text MIME types (possibly empty).
_extensionToMime: Optional[Dict[str, str]] = None
_textMimeTypes: Optional[set] = None
|
||
|
||
@classmethod
def _ensureMimeMaps(cls):
    """Lazily build extension→MIME and text-MIME-set from the ExtractorRegistry.

    Runs once per class: later calls return immediately. If the registry
    cannot be loaded or queried, a warning is logged and both caches are set
    to empty containers, so lookups fall back to their non-registry paths.
    """
    if cls._extensionToMime is not None:
        return

    try:
        from modules.serviceCenter.services.serviceExtraction.subRegistry import ExtractorRegistry
        registry = ExtractorRegistry()
        extensionToMime = registry.getExtensionToMimeMap()

        # Collect the MIME types of every extractor that declares at least one
        # text/* type. Extractors are de-duplicated by identity while walking
        # the registry's values.
        textMimes: set = set()
        seenExtractors: set = set()
        for extractor in registry._map.values():
            extractorId = id(extractor)
            if extractorId in seenExtractors:
                continue
            seenExtractors.add(extractorId)
            mimes = extractor.getSupportedMimeTypes()
            if any(m.startswith("text/") for m in mimes):
                textMimes.update(mimes)

        # Always include common text types
        textMimes.update({
            "application/json", "application/xml", "application/javascript",
            "application/sql", "application/x-yaml", "application/x-toml",
        })
    except Exception as e:
        # Best effort: keep working without the registry, but leave a trace
        logger.warning(f"Could not build MIME maps from ExtractorRegistry, using empty maps: {e}")
        extensionToMime = {}
        textMimes = set()

    # _extensionToMime is the early-return guard, so it is assigned last:
    # the guard then never passes while _textMimeTypes is still None
    cls._textMimeTypes = textMimes
    cls._extensionToMime = extensionToMime
|
||
|
||
def getMimeType(self, fileName: str) -> str:
    """Return the MIME type for ``fileName`` based on its extension.

    Lookup order:
    1. The extension map derived from the ExtractorRegistry
    2. ``mimetypes.guess_type`` from the standard library (non-strict)
    3. ``application/octet-stream`` when nothing matches
    """
    self._ensureMimeMaps()

    # Normalise ".TXT" -> "txt" so the registry map is matched case-insensitively.
    _, dottedSuffix = os.path.splitext(fileName)
    suffix = dottedSuffix.lower().lstrip('.')

    registryMap = self._extensionToMime
    if suffix and suffix in registryMap:
        return registryMap[suffix]

    stdlibGuess = mimetypes.guess_type(fileName, strict=False)[0]
    if stdlibGuess:
        return stdlibGuess
    return "application/octet-stream"
|
||
|
||
def isTextMimeType(self, mimeType: str) -> bool:
    """Determine whether a MIME type represents a text-based format.

    A type is text-based when it starts with ``text/`` or is contained in the
    set collected by ``_ensureMimeMaps``. Parameters such as
    ``; charset=utf-8`` are ignored and the comparison is case-insensitive.

    Args:
        mimeType: MIME type string, optionally with parameters.

    Returns:
        True for text formats, False otherwise (including ``None``/empty input).
    """
    if not mimeType:
        # Unknown type: treat as binary instead of raising on ``None.lower()``.
        return False

    self._ensureMimeMaps()

    # "application/json; charset=utf-8" -> "application/json"
    essence = mimeType.split(";", 1)[0].strip().lower()
    if essence.startswith("text/"):
        return True
    return essence in self._textMimeTypes
|
||
|
||
# File methods - metadata-based operations
|
||
|
||
def _getFilesByCurrentUser(self, recordFilter: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]:
    """Return only the FileItem records owned by the current user.

    Files are always user-scoped here, regardless of role (including SysAdmin).
    This bypasses RBAC intentionally by querying the connector directly.

    Args:
        recordFilter: Optional additional equality filters. A ``sysCreatedBy``
            entry in it is ignored: the owner constraint always wins.

    Returns:
        The raw records returned by the database connector.
    """
    filterDict: Dict[str, Any] = dict(recordFilter) if recordFilter else {}
    # Applied last so a caller-supplied "sysCreatedBy" can never widen the
    # query beyond the current user's own files.
    filterDict["sysCreatedBy"] = self.userId
    return self.db.getRecordset(FileItem, recordFilter=filterDict)
|
||
|
||
def _isMandatelessFile(self, file: Dict[str, Any]) -> bool:
|
||
"""A file has no mandate context if mandateId is empty, None, or missing."""
|
||
mid = file.get("mandateId")
|
||
return not mid or (isinstance(mid, str) and not mid.strip())
|
||
|
||
def getAllFiles(self, pagination: Optional[PaginationParams] = None, recordFilter: Dict[str, Any] = None) -> Union[List[FileItem], PaginatedResult]:
    """
    List the files the current user may see according to the RBAC scope rules.

    The visibility filtering itself is delegated to ``getRecordsetWithRBAC`` /
    ``getRecordsetPaginatedWithRBAC``; this method forwards the current
    mandate and feature-instance context and normalises the returned records.

    Args:
        pagination: Optional pagination parameters. If None, all items are returned.
        recordFilter: Optional additional filter forwarded to the RBAC query.

    Returns:
        Without pagination: a list of normalised file dictionaries.
        With pagination: a PaginatedResult whose ``items`` are normalised the
        same way (``totalItems``/``totalPages`` are passed through).
    """
    def _normalizeRecords(rawRecords):
        # Turns raw DB rows into validated dictionaries; invalid rows are skipped.
        normalized = []
        for record in rawRecords:
            try:
                createdAt = record.get("sysCreatedAt")
                timestampOk = isinstance(createdAt, (int, float)) and not createdAt <= 0
                if not timestampOk:
                    # Missing or unusable creation time: substitute "now".
                    record["sysCreatedAt"] = getUtcTimestamp()

                recordName = record.get("fileName")
                if not recordName or recordName == "None":
                    continue

                for field, fallback in (("scope", "personal"), ("neutralize", False)):
                    if record.get(field) is None:
                        record[field] = fallback

                # "*Label" columns are re-attached after the model round trip.
                labelColumns = {key: value for key, value in record.items() if key.endswith("Label")}
                dumped = FileItem(**record).model_dump()
                dumped.update(labelColumns)
                normalized.append(dumped)
            except Exception as e:
                logger.warning(f"Skipping invalid file record: {str(e)}")
                continue
        return normalized

    scopeArgs = {
        "recordFilter": recordFilter,
        "mandateId": self.mandateId,
        "featureInstanceId": self.featureInstanceId,
    }

    if pagination is None:
        unpaged = getRecordsetWithRBAC(self.db, FileItem, self.currentUser, **scopeArgs)
        return _normalizeRecords(unpaged)

    page = getRecordsetPaginatedWithRBAC(
        self.db, FileItem, self.currentUser,
        pagination=pagination,
        **scopeArgs,
    )

    if isinstance(page, PaginatedResult):
        return PaginatedResult(
            items=_normalizeRecords(page.items),
            totalItems=page.totalItems,
            totalPages=page.totalPages,
        )

    # Anything that is neither a PaginatedResult nor a list yields an empty list.
    if isinstance(page, list):
        return _normalizeRecords(page)
    return _normalizeRecords([])
|
||
|
||
def getFile(self, fileId: str) -> Optional[FileItem]:
    """Fetch a single file by ID, subject to the current user's RBAC scope.

    Returns:
        The FileItem, or None when the file is not visible/does not exist or
        the stored record cannot be converted into a FileItem.
    """
    visibleRecords = getRecordsetWithRBAC(
        self.db, FileItem, self.currentUser,
        recordFilter={"id": fileId},
        mandateId=self.mandateId,
        featureInstanceId=self.featureInstanceId,
    )
    if not visibleRecords:
        return None

    record = visibleRecords[0]
    try:
        createdAt = record.get("sysCreatedAt")
        timestampOk = isinstance(createdAt, (int, float)) and not createdAt <= 0
        if not timestampOk:
            # Missing or unusable creation time: substitute "now".
            record["sysCreatedAt"] = getUtcTimestamp()

        # Defaults for columns that may be NULL in the stored record.
        for field, fallback in (("scope", "personal"), ("neutralize", False)):
            if record.get(field) is None:
                record[field] = fallback

        return FileItem(**record)
    except Exception as e:
        logger.error(f"Error converting file record: {str(e)}")
        return None
|
||
|
||
def _isfileNameUnique(self, fileName: str, excludeFileId: Optional[str] = None) -> bool:
    """Check whether ``fileName`` is unused among the files visible to the user.

    The lookup is RBAC-filtered for the current mandate and restricted to the
    requested name in the query itself, so only potential collisions are
    loaded instead of every visible file. This matters because
    ``_generateUniquefileName`` calls this method repeatedly.

    Args:
        fileName: Name to test.
        excludeFileId: ID of a file to ignore (the file being renamed).

    Returns:
        True when no other visible file carries exactly this name.
    """
    # NOTE(review): unlike getFile/getAllFiles, no featureInstanceId is passed
    # here. Kept as in the original; confirm whether that scope is intended.
    candidates = getRecordsetWithRBAC(
        self.db,
        FileItem,
        self.currentUser,
        recordFilter={"fileName": fileName},
        mandateId=self.mandateId,
    )

    for file in candidates:
        existingName = file.get("fileName")
        # Records without a usable name never count as a collision.
        if not existingName:
            continue
        if existingName == fileName and (excludeFileId is None or file["id"] != excludeFileId):
            return False
    return True
|
||
|
||
def _generateUniquefileName(self, fileName: str, excludeFileId: Optional[str] = None) -> str:
|
||
"""Generates a unique fileName by adding a number if necessary."""
|
||
if self._isfileNameUnique(fileName, excludeFileId):
|
||
return fileName
|
||
|
||
# Split fileName into name and extension
|
||
name, ext = os.path.splitext(fileName)
|
||
counter = 1
|
||
|
||
# Try fileNames with increasing numbers until we find a unique one
|
||
while True:
|
||
newfileName = f"{name}_{counter}{ext}"
|
||
if self._isfileNameUnique(newfileName, excludeFileId):
|
||
return newfileName
|
||
counter += 1
|
||
|
||
def createFile(self, name: str, mimeType: str, content: bytes, folderId: Optional[str] = None) -> FileItem:
    """Create the metadata record (FileItem) for a new file.

    File size and SHA-256 hash are derived from ``content``. The content
    itself is not stored here.

    Duplicate handling: when ``checkForDuplicateFile`` reports an existing
    file for this hash and name, that file is returned and nothing is created.
    The same content under a different name is treated as an intentional copy.

    Args:
        name: Requested file name; made unique via ``_generateUniquefileName``.
        mimeType: MIME type stored on the record.
        content: Raw file bytes, used for size and hash only.
        folderId: Optional parent folder ID. None/empty means the root folder.

    Raises:
        PermissionError: If the user may not create FileItem records.
    """
    if not self.checkRbacPermission(FileItem, "create"):
        raise PermissionError("No permission to create files")

    byteCount = len(content)
    digest = hashlib.sha256(content).hexdigest()

    # Same hash + same name already present -> hand back the existing record.
    alreadyStored = self.checkForDuplicateFile(digest, name)
    if alreadyStored:
        logger.info(f"Duplicate file detected in createFile: '{name}' (hash={digest[:12]}...) for user {self.userId} — returning existing file {alreadyStored.id}")
        return alreadyStored

    storedName = self._generateUniquefileName(name)

    owningMandate = self.mandateId or ""
    owningInstance = self.featureInstanceId or ""

    # The narrowest available context determines the initial scope.
    if owningInstance:
        initialScope = "featureInstance"
    else:
        initialScope = "mandate" if owningMandate else "personal"

    # Blank folder ids mean "root" and are stored as NULL.
    isBlankFolder = isinstance(folderId, str) and not folderId.strip()
    parentFolder: Optional[str] = None if isBlankFolder else folderId

    newItem = FileItem(
        mandateId=owningMandate,
        featureInstanceId=owningInstance,
        scope=initialScope,
        fileName=storedName,
        mimeType=mimeType,
        fileSize=byteCount,
        fileHash=digest,
        folderId=parentFolder,
    )
    self.db.recordCreate(FileItem, newItem)
    return newItem
|
||
|
||
def _isFileOwner(self, file) -> bool:
|
||
"""Check if the current user owns the file."""
|
||
createdBy = getattr(file, "sysCreatedBy", None) or (file.get("sysCreatedBy") if isinstance(file, dict) else None)
|
||
return createdBy == self.userId
|
||
|
||
def _requireFileWriteAccess(self, file, fileId: str, operation: str = "update"):
|
||
"""Raise PermissionError if the user cannot mutate this file.
|
||
Owners always can. Non-owners need RBAC ALL level."""
|
||
if self._isFileOwner(file):
|
||
return
|
||
from modules.interfaces.interfaceRbac import buildDataObjectKey
|
||
from modules.datamodels.datamodelRbac import AccessRuleContext
|
||
objectKey = buildDataObjectKey("FileItem")
|
||
permissions = self.rbac.getUserPermissions(
|
||
self.currentUser, AccessRuleContext.DATA, objectKey,
|
||
mandateId=self.mandateId, featureInstanceId=self.featureInstanceId,
|
||
)
|
||
level = getattr(permissions, operation, None)
|
||
if level != AccessLevel.ALL:
|
||
raise PermissionError(
|
||
f"No permission to {operation} file {fileId} (not owner, access level: {level})"
|
||
)
|
||
|
||
def updateFile(self, fileId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
    """Modify the metadata of a file the user may see and write.

    A ``fileName`` entry in ``updateData`` is replaced in place by a unique
    variant (the file itself is excluded from the uniqueness check).

    Returns:
        Whatever ``self.db.recordModify`` returns for the update.

    Raises:
        FileNotFoundError: If the file does not exist or is not visible.
        PermissionError: If the user is neither owner nor has ALL access.
    """
    target = self.getFile(fileId)
    if not target:
        raise FileNotFoundError(f"File with ID {fileId} not found")

    self._requireFileWriteAccess(target, fileId, "update")

    # Renames must not collide with another visible file.
    if "fileName" in updateData:
        requestedName = updateData["fileName"]
        updateData["fileName"] = self._generateUniquefileName(requestedName, fileId)

    return self.db.recordModify(FileItem, fileId, updateData)
|
||
|
||
def deleteFile(self, fileId: str) -> bool:
    """Delete a file (FileItem plus its FileData) if the user may do so.

    Owners can always delete; non-owners need RBAC ALL level.

    FileData rows are keyed by the FileItem id, so the data row of this file
    is never shared with another FileItem, even if the content hash matches.
    It is therefore always removed together with the FileItem, as in
    ``deleteFilesBatch``. Failure to remove the data row is logged and does
    not prevent the FileItem from being deleted.

    Returns:
        The result of ``self.db.recordDelete`` for the FileItem.

    Raises:
        FileNotFoundError: If the file does not exist or is not visible.
        PermissionError / FilePermissionError: If write access is denied.
        FileDeletionError: For any other failure during deletion.
    """
    try:
        file = self.getFile(fileId)
        if not file:
            raise FileNotFoundError(f"File with ID {fileId} not found")

        self._requireFileWriteAccess(file, fileId, "delete")

        # Best effort: remove the binary payload that belongs to this file id.
        try:
            fileDataEntries = self.db.getRecordset(FileData, recordFilter={"id": fileId})
            if fileDataEntries:
                self.db.recordDelete(FileData, fileId)
                logger.debug(f"FileData for file {fileId} deleted")
        except Exception as e:
            logger.warning(f"Error deleting FileData for file {fileId}: {str(e)}")

        # Delete the FileItem entry
        return self.db.recordDelete(FileItem, fileId)

    except (FileNotFoundError, FilePermissionError, PermissionError):
        # _requireFileWriteAccess raises the builtin PermissionError; it must
        # reach the caller unchanged instead of being reported as a deletion error.
        raise
    except Exception as e:
        logger.error(f"Error deleting file {fileId}: {str(e)}")
        raise FileDeletionError(f"Error deleting file: {str(e)}") from e
|
||
|
||
def deleteFilesBatch(self, fileIds: List[str]) -> Dict[str, Any]:
    """Delete multiple files (FileItem + FileData) in one transaction.

    Owner can always delete; non-owners need RBAC ALL level. The operation is
    all-or-nothing: a missing ID or a denied file aborts the whole batch.

    Args:
        fileIds: File IDs to delete. Falsy entries and duplicates are ignored.

    Returns:
        ``{"deletedFiles": <number of FileItem rows removed>}``

    Raises:
        FileDeletionError: On any failure (including unknown IDs and missing
            permission); the original exception is chained as ``__cause__``.
    """
    # dict.fromkeys de-duplicates while keeping the caller's order.
    uniqueIds = [str(fid) for fid in dict.fromkeys(fileIds or []) if fid]
    if not uniqueIds:
        return {"deletedFiles": 0}

    try:
        self.db._ensure_connection()
        with self.db.connection.cursor() as cursor:
            cursor.execute(
                'SELECT "id", "sysCreatedBy" FROM "FileItem" WHERE "id" = ANY(%s)',
                (uniqueIds,),
            )
            rows = cursor.fetchall()
            foundIds = {row["id"] for row in rows}
            missing = sorted(set(uniqueIds) - foundIds)
            if missing:
                raise FileNotFoundError(f"Files not found: {missing}")

            # Every single file must be writable before anything is deleted.
            for row in rows:
                self._requireFileWriteAccess(row, row["id"], "delete")

            accessibleIds = [row["id"] for row in rows]
            cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (accessibleIds,))
            cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (accessibleIds,))
            deletedFiles = cursor.rowcount

        self.db.connection.commit()
        return {"deletedFiles": deletedFiles}
    except Exception as e:
        logger.error(f"Error deleting files in batch: {e}")
        # The rollback itself may fail (e.g. no usable connection); that must
        # not hide the error that caused it.
        try:
            self.db.connection.rollback()
        except Exception as rollbackError:
            logger.warning(f"Rollback after failed batch file deletion failed: {rollbackError}")
        raise FileDeletionError(f"Error deleting files in batch: {str(e)}") from e
|
||
|
||
# ---- Folder methods ----
|
||
|
||
_RESERVED_FOLDER_NAMES = {"(Global)"}
|
||
|
||
def _validateFolderName(self, name: str, parentId: Optional[str], excludeFolderId: Optional[str] = None):
|
||
"""Ensures folder name is not reserved and is unique within parent."""
|
||
if name in self._RESERVED_FOLDER_NAMES:
|
||
raise ValueError(f"Folder name '{name}' is reserved")
|
||
if not name or not name.strip():
|
||
raise ValueError("Folder name cannot be empty")
|
||
existingFolders = self.db.getRecordset(FileFolder, recordFilter={"parentId": parentId or ""})
|
||
for f in existingFolders:
|
||
if f.get("name") == name and f.get("id") != excludeFolderId:
|
||
raise ValueError(f"Folder '{name}' already exists in this directory")
|
||
|
||
def _isDescendantOf(self, folderId: str, ancestorId: str) -> bool:
|
||
"""Checks if folderId is a descendant of ancestorId (circular reference check)."""
|
||
visited = set()
|
||
currentId = folderId
|
||
while currentId:
|
||
if currentId == ancestorId:
|
||
return True
|
||
if currentId in visited:
|
||
break
|
||
visited.add(currentId)
|
||
folders = self.db.getRecordset(FileFolder, recordFilter={"id": currentId})
|
||
if not folders:
|
||
break
|
||
currentId = folders[0].get("parentId")
|
||
return False
|
||
|
||
def _ensureFeatureInstanceFolder(self, featureInstanceId: str, mandateId: str = "") -> Optional[str]:
    """Look up the current user's folder for a feature instance, creating it if absent.

    A new folder is placed at root level and named after the feature
    instance's label; when the label cannot be resolved, the first eight
    characters of the instance ID are used instead.

    Returns:
        The folder ID, or None if the created record exposes no ``id``.
    """
    ownFolders = self.db.getRecordset(
        FileFolder,
        recordFilter={
            "featureInstanceId": featureInstanceId,
            "sysCreatedBy": self.userId or "",
        },
    )
    if ownFolders:
        return ownFolders[0].get("id")

    # Fallback name in case the label lookup below fails or yields nothing.
    resolvedName = featureInstanceId[:8]
    try:
        from modules.datamodels.datamodelFeatures import FeatureInstance
        from modules.security.rootAccess import getRootDbAppConnector

        rootConnector = getRootDbAppConnector()
        instanceRows = rootConnector.getRecordset(FeatureInstance, recordFilter={"id": featureInstanceId})
        if instanceRows:
            label = instanceRows[0].get("label")
            if label:
                resolvedName = label
    except Exception as e:
        logger.warning(f"Could not resolve feature instance label: {e}")

    newFolder = FileFolder(
        name=resolvedName,
        parentId=None,
        mandateId=mandateId,
        featureInstanceId=featureInstanceId,
    )
    created = self.db.recordCreate(FileFolder, newFolder)
    # recordCreate may hand back a dict or a model instance.
    if isinstance(created, dict):
        return created.get("id")
    return getattr(created, "id", None)
|
||
|
||
def getFolder(self, folderId: str) -> Optional[Dict[str, Any]]:
    """Fetch a folder by ID, restricted to folders created by the current user.

    Returns:
        The folder record, or None if no such folder exists for this user.
    """
    ownerId = self.userId or ""
    matches = self.db.getRecordset(
        FileFolder,
        recordFilter={"id": folderId, "sysCreatedBy": ownerId},
    )
    if matches:
        return matches[0]
    return None
|
||
|
||
def listFolders(self, parentId: Optional[str] = None) -> List[Dict[str, Any]]:
    """List folders visible to the current user.

    Own folders are always returned. Other users' folders are only
    returned when they contain files visible to the current user.
    Each folder is enriched with ``fileCount`` (number of visible files
    directly inside it).

    Args:
        parentId: When given, only folders with this parent are listed;
            None lists folders of all levels.

    Returns:
        The filtered folder records. If counting fails, every count is 0 and
        therefore only the user's own folders are returned.
    """
    recordFilter = {}
    if parentId is not None:
        recordFilter["parentId"] = parentId
    folders = self.db.getRecordset(FileFolder, recordFilter=recordFilter if recordFilter else None)

    if not folders:
        return folders

    folderIds = [f["id"] for f in folders if f.get("id")]
    fileCounts: Dict[str, int] = {}
    try:
        from modules.interfaces.interfaceRbac import buildFilesScopeWhereClause
        scopeClause = buildFilesScopeWhereClause(
            self.currentUser, "FileItem", self.db,
            self.mandateId, self.featureInstanceId,
            [], [],
        )

        self.db._ensure_connection()
        with self.db.connection.cursor() as cursor:
            baseQuery = (
                'SELECT "folderId", COUNT(*) AS cnt '
                'FROM "FileItem" '
                'WHERE "folderId" = ANY(%s)'
            )
            queryValues: list = [folderIds]

            # Restrict the count to files the user is allowed to see.
            if scopeClause:
                baseQuery += ' AND (' + scopeClause["condition"] + ')'
                queryValues.extend(scopeClause["values"])

            baseQuery += ' GROUP BY "folderId"'
            cursor.execute(baseQuery, queryValues)
            for row in cursor.fetchall():
                fileCounts[row["folderId"]] = row["cnt"]
    except Exception as e:
        logger.warning(f"Could not count files per folder: {e}")
        # A failed statement can leave the connection's transaction in an
        # aborted state; roll back so later queries on it keep working.
        try:
            self.db.connection.rollback()
        except Exception as rollbackError:
            logger.debug(f"Rollback after failed folder file count failed: {rollbackError}")

    userId = self.userId or ""
    result = []
    for folder in folders:
        fc = fileCounts.get(folder.get("id", ""), 0)
        folder["fileCount"] = fc
        isOwn = folder.get("sysCreatedBy") == userId
        if isOwn or fc > 0:
            result.append(folder)

    return result
|
||
|
||
def createFolder(self, name: str, parentId: Optional[str] = None) -> Dict[str, Any]:
    """Create a folder in the current mandate / feature-instance context.

    The name is validated first (not reserved, not blank, unique in parent).

    Returns:
        Whatever ``self.db.recordCreate`` returns for the new folder.

    Raises:
        ValueError: If ``_validateFolderName`` rejects the name.
    """
    self._validateFolderName(name, parentId)

    owningMandate = self.mandateId or ""
    owningInstance = self.featureInstanceId or ""
    newFolder = FileFolder(
        name=name,
        parentId=parentId,
        mandateId=owningMandate,
        featureInstanceId=owningInstance,
    )
    return self.db.recordCreate(FileFolder, newFolder)
|
||
|
||
def renameFolder(self, folderId: str, newName: str) -> bool:
    """Give one of the user's folders a new name.

    The new name is validated against the folder's current parent; the folder
    itself is excluded from the uniqueness check.

    Raises:
        FileNotFoundError: If the folder does not exist for this user.
        ValueError: If the new name is reserved, blank, or already taken.
    """
    current = self.getFolder(folderId)
    if not current:
        raise FileNotFoundError(f"Folder {folderId} not found")

    parentOfFolder = current.get("parentId")
    self._validateFolderName(newName, parentOfFolder, excludeFolderId=folderId)
    return self.db.recordModify(FileFolder, folderId, {"name": newName})
|
||
|
||
def updateFolder(self, folderId: str, updateData: Dict[str, Any]) -> bool:
    """
    Update folder metadata (e.g. ``scope``, ``neutralize``). Owner-only,
    same access model as renameFolder/moveFolder.

    Identity and audit columns are stripped from ``updateData``. ``name`` and
    ``parentId`` changes are accepted but go through the same validation as
    ``renameFolder`` / ``moveFolder`` (reserved/unique name, no move into the
    folder's own subtree).

    Args:
        folderId: ID of the folder to update.
        updateData: Fields to change; the dict itself is not modified.

    Returns:
        True when there is nothing to update, otherwise the result of
        ``self.db.recordModify``.

    Raises:
        FileNotFoundError: If the folder does not exist for this user.
        ValueError: If the new name or parent is invalid.
    """
    if not updateData:
        return True
    folder = self.getFolder(folderId)
    if not folder:
        raise FileNotFoundError(f"Folder {folderId} not found")

    # "sysModifiedAt"/"sysModifiedBy" are the audit columns written by the
    # batch SQL in this class, so they are protected as well.
    forbiddenKeys = {
        "id", "sysCreatedBy", "sysCreatedAt", "sysUpdatedAt",
        "sysModifiedAt", "sysModifiedBy",
    }
    cleaned: Dict[str, Any] = {k: v for k, v in updateData.items() if k not in forbiddenKeys}
    if not cleaned:
        # Only protected keys were supplied: nothing left to write.
        return True

    currentParentId = folder.get("parentId")
    # A parentId that merely repeats the current value is not a move.
    parentChanges = "parentId" in cleaned and (cleaned["parentId"] or None) != (currentParentId or None)
    effectiveParentId = cleaned["parentId"] if parentChanges else currentParentId

    if parentChanges and effectiveParentId and self._isDescendantOf(effectiveParentId, folderId):
        raise ValueError("Cannot move folder into its own subtree")

    if "name" in cleaned or parentChanges:
        effectiveName = cleaned.get("name", folder.get("name", ""))
        self._validateFolderName(effectiveName, effectiveParentId, excludeFolderId=folderId)

    return self.db.recordModify(FileFolder, folderId, cleaned)
|
||
|
||
def moveFolder(self, folderId: str, targetParentId: Optional[str] = None) -> bool:
    """Re-parent one of the user's folders.

    Args:
        folderId: Folder to move.
        targetParentId: New parent folder; None/empty moves it to the root.

    Raises:
        FileNotFoundError: If the folder does not exist for this user.
        ValueError: If the target lies inside the folder's own subtree, or the
            folder's name is not valid/unique in the target.
    """
    current = self.getFolder(folderId)
    if not current:
        raise FileNotFoundError(f"Folder {folderId} not found")

    # Moving a folder below itself would create a cycle.
    createsCycle = bool(targetParentId) and self._isDescendantOf(targetParentId, folderId)
    if createsCycle:
        raise ValueError("Cannot move folder into its own subtree")

    currentName = current.get("name", "")
    self._validateFolderName(currentName, targetParentId, excludeFolderId=folderId)
    return self.db.recordModify(FileFolder, folderId, {"parentId": targetParentId})
|
||
|
||
def moveFilesBatch(self, fileIds: List[str], targetFolderId: Optional[str] = None) -> Dict[str, Any]:
    """Move multiple files into a folder with one SQL update.

    Owner can always move; non-owners need RBAC ALL level. The operation is
    all-or-nothing: a missing ID or a denied file aborts the whole batch.

    Args:
        fileIds: File IDs to move. Falsy entries and duplicates are ignored.
        targetFolderId: Destination folder (must belong to the current user);
            None/empty moves the files to the root.

    Returns:
        ``{"movedFiles": <number of updated rows>}``

    Raises:
        FileNotFoundError: If the target folder does not exist for this user.
        FileError: On any failure inside the transaction (including unknown
            file IDs and missing permission); the cause is chained.
    """
    # dict.fromkeys de-duplicates while keeping the caller's order.
    uniqueIds = [str(fid) for fid in dict.fromkeys(fileIds or []) if fid]
    if not uniqueIds:
        return {"movedFiles": 0}

    if targetFolderId:
        targetFolder = self.getFolder(targetFolderId)
        if not targetFolder:
            raise FileNotFoundError(f"Target folder {targetFolderId} not found")

    try:
        self.db._ensure_connection()
        with self.db.connection.cursor() as cursor:
            cursor.execute(
                'SELECT "id", "sysCreatedBy" FROM "FileItem" WHERE "id" = ANY(%s)',
                (uniqueIds,),
            )
            rows = cursor.fetchall()
            foundIds = {row["id"] for row in rows}
            missing = sorted(set(uniqueIds) - foundIds)
            if missing:
                raise FileNotFoundError(f"Files not found: {missing}")

            # Every single file must be writable before anything is moved.
            for row in rows:
                self._requireFileWriteAccess(row, row["id"], "update")

            accessibleIds = [row["id"] for row in rows]
            cursor.execute(
                'UPDATE "FileItem" SET "folderId" = %s, "sysModifiedAt" = %s, "sysModifiedBy" = %s '
                'WHERE "id" = ANY(%s)',
                (targetFolderId, getUtcTimestamp(), self.userId or "", accessibleIds),
            )
            movedFiles = cursor.rowcount

        self.db.connection.commit()
        return {"movedFiles": movedFiles}
    except Exception as e:
        logger.error(f"Error moving files in batch: {e}")
        # The rollback itself may fail (e.g. no usable connection); that must
        # not hide the error that caused it.
        try:
            self.db.connection.rollback()
        except Exception as rollbackError:
            logger.warning(f"Rollback after failed batch file move failed: {rollbackError}")
        raise FileError(f"Error moving files in batch: {str(e)}") from e
|
||
|
||
def moveFoldersBatch(self, folderIds: List[str], targetParentId: Optional[str] = None) -> Dict[str, Any]:
    """Move multiple folders with one SQL update after validation.

    All folders must belong to the current user. Before anything is written,
    the batch is checked for cycles (moving a folder into its own subtree),
    for duplicate names inside the batch, and for name clashes with the
    user's folders already present in the target.

    Args:
        folderIds: Folder IDs to move. Falsy entries and duplicates are ignored.
        targetParentId: New parent folder; None/empty means the root.

    Returns:
        ``{"movedFolders": <number of updated rows>}``

    Raises:
        FileNotFoundError: If one of the folders does not exist for this user.
        ValueError: On cycles or name conflicts.
        FileError: If the SQL update fails; the cause is chained.
    """
    # dict.fromkeys de-duplicates while keeping the caller's order.
    uniqueIds = [str(fid) for fid in dict.fromkeys(folderIds or []) if fid]
    if not uniqueIds:
        return {"movedFolders": 0}

    foldersToMove: List[Dict[str, Any]] = []
    for folderId in uniqueIds:
        folder = self.getFolder(folderId)
        if not folder:
            raise FileNotFoundError(f"Folder {folderId} not found")
        if targetParentId and self._isDescendantOf(targetParentId, folderId):
            raise ValueError("Cannot move folder into its own subtree")
        foldersToMove.append(folder)

    existingInTarget = self.db.getRecordset(
        FileFolder,
        recordFilter={"parentId": targetParentId or "", "sysCreatedBy": self.userId or ""},
    )
    existingNames = {f.get("name"): f.get("id") for f in existingInTarget}
    movingNames: Dict[str, str] = {}
    movingIds = set(uniqueIds)

    for folder in foldersToMove:
        name = folder.get("name", "")
        folderId = folder.get("id")
        # Two different folders of the batch must not end up with the same name.
        if name in movingNames and movingNames[name] != folderId:
            raise ValueError(f"Folder '{name}' already exists in this move batch")
        movingNames[name] = folderId

        # A clash only counts if the existing folder is not itself part of the batch.
        existingId = existingNames.get(name)
        if existingId and existingId not in movingIds:
            raise ValueError(f"Folder '{name}' already exists in target directory")

    try:
        self.db._ensure_connection()
        with self.db.connection.cursor() as cursor:
            cursor.execute(
                'UPDATE "FileFolder" SET "parentId" = %s, "sysModifiedAt" = %s, "sysModifiedBy" = %s '
                'WHERE "id" = ANY(%s) AND "sysCreatedBy" = %s',
                (targetParentId, getUtcTimestamp(), self.userId or "", uniqueIds, self.userId or ""),
            )
            movedFolders = cursor.rowcount

        self.db.connection.commit()
        return {"movedFolders": movedFolders}
    except Exception as e:
        logger.error(f"Error moving folders in batch: {e}")
        # The rollback itself may fail (e.g. no usable connection); that must
        # not hide the error that caused it.
        try:
            self.db.connection.rollback()
        except Exception as rollbackError:
            logger.warning(f"Rollback after failed batch folder move failed: {rollbackError}")
        raise FileError(f"Error moving folders in batch: {str(e)}") from e
|
||
|
||
def deleteFolder(self, folderId: str, recursive: bool = False) -> Dict[str, Any]:
    """Remove one of the user's folders, optionally together with its content.

    Only sub-folders and files created by the current user are considered
    content. In recursive mode sub-folders are removed depth-first; a file
    that cannot be deleted is logged and skipped, and the folder is removed
    regardless.

    Returns:
        ``{"deletedFiles": n, "deletedFolders": m}``

    Raises:
        FileNotFoundError: If the folder does not exist for this user.
        ValueError: If the folder is not empty and ``recursive`` is False.
    """
    folder = self.getFolder(folderId)
    if not folder:
        raise FileNotFoundError(f"Folder {folderId} not found")

    ownerId = self.userId or ""
    childFolders = self.db.getRecordset(FileFolder, recordFilter={"parentId": folderId, "sysCreatedBy": ownerId})
    childFiles = self._getFilesByCurrentUser(recordFilter={"folderId": folderId})

    hasContent = bool(childFolders or childFiles)
    if hasContent and not recursive:
        raise ValueError(
            f"Folder '{folder.get('name')}' is not empty "
            f"({len(childFiles)} files, {len(childFolders)} subfolders). "
            f"Use recursive=true to delete contents."
        )

    fileTotal = 0
    folderTotal = 0

    if recursive:
        for child in childFolders:
            nested = self.deleteFolder(child["id"], recursive=True)
            fileTotal += nested.get("deletedFiles", 0)
            folderTotal += nested.get("deletedFolders", 0)

        for entry in childFiles:
            try:
                self.deleteFile(entry["id"])
                fileTotal += 1
            except Exception as e:
                # Best effort: one failing file must not stop the folder removal.
                logger.warning(f"Failed to delete file {entry['id']} during folder deletion: {e}")

    self.db.recordDelete(FileFolder, folderId)
    folderTotal += 1

    return {"deletedFiles": fileTotal, "deletedFolders": folderTotal}
|
||
|
||
def deleteFoldersBatch(self, folderIds: List[str], recursive: bool = True) -> Dict[str, Any]:
    """Delete multiple folders and their content in batched SQL calls.

    Only folders and files created by the current user are touched.

    Args:
        folderIds: Folder IDs to delete. Falsy entries and duplicates are ignored.
        recursive: When True (default) the complete sub-tree of each folder is
            removed in one transaction. When False each folder is passed to
            ``deleteFolder(..., recursive=False)``, which refuses non-empty
            folders.

    Returns:
        ``{"deletedFiles": n, "deletedFolders": m}``

    Raises:
        FileDeletionError: If the recursive batch fails (including unknown or
            foreign folder IDs); the cause is chained.
        FileNotFoundError / ValueError: From ``deleteFolder`` in non-recursive mode.
    """
    # dict.fromkeys de-duplicates while keeping the caller's order.
    uniqueIds = [str(fid) for fid in dict.fromkeys(folderIds or []) if fid]
    if not uniqueIds:
        return {"deletedFiles": 0, "deletedFolders": 0}

    if not recursive:
        deletedFiles = 0
        deletedFolders = 0
        for folderId in uniqueIds:
            result = self.deleteFolder(folderId, recursive=False)
            deletedFiles += result.get("deletedFiles", 0)
            deletedFolders += result.get("deletedFolders", 0)
        return {"deletedFiles": deletedFiles, "deletedFolders": deletedFolders}

    try:
        self.db._ensure_connection()
        with self.db.connection.cursor() as cursor:
            cursor.execute(
                'SELECT "id" FROM "FileFolder" WHERE "id" = ANY(%s) AND "sysCreatedBy" = %s',
                (uniqueIds, self.userId or ""),
            )
            rootAccessibleIds = [row["id"] for row in cursor.fetchall()]
            if len(rootAccessibleIds) != len(uniqueIds):
                missingIds = sorted(set(uniqueIds) - set(rootAccessibleIds))
                raise FileNotFoundError(f"Folders not found or not accessible: {missingIds}")

            # UNION (not UNION ALL) discards rows already produced, so the
            # recursion terminates even if the stored parentId chain is cyclic.
            cursor.execute(
                """
                WITH RECURSIVE folder_tree AS (
                    SELECT "id"
                    FROM "FileFolder"
                    WHERE "id" = ANY(%s) AND "sysCreatedBy" = %s
                    UNION
                    SELECT child."id"
                    FROM "FileFolder" child
                    INNER JOIN folder_tree ft ON child."parentId" = ft."id"
                    WHERE child."sysCreatedBy" = %s
                )
                SELECT DISTINCT "id" FROM folder_tree
                """,
                (rootAccessibleIds, self.userId or "", self.userId or ""),
            )
            allFolderIds = [row["id"] for row in cursor.fetchall()]

            cursor.execute(
                'SELECT "id" FROM "FileItem" WHERE "folderId" = ANY(%s) AND "sysCreatedBy" = %s',
                (allFolderIds, self.userId or ""),
            )
            allFileIds = [row["id"] for row in cursor.fetchall()]

            if allFileIds:
                # Payload first, then metadata (FileData rows share the FileItem id).
                cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (allFileIds,))
                cursor.execute(
                    'DELETE FROM "FileItem" WHERE "id" = ANY(%s) AND "sysCreatedBy" = %s',
                    (allFileIds, self.userId or ""),
                )
                deletedFiles = cursor.rowcount
            else:
                deletedFiles = 0

            cursor.execute(
                'DELETE FROM "FileFolder" WHERE "id" = ANY(%s) AND "sysCreatedBy" = %s',
                (allFolderIds, self.userId or ""),
            )
            deletedFolders = cursor.rowcount

        self.db.connection.commit()
        return {"deletedFiles": deletedFiles, "deletedFolders": deletedFolders}
    except Exception as e:
        logger.error(f"Error deleting folders in batch: {e}")
        # The rollback itself may fail (e.g. no usable connection); that must
        # not hide the error that caused it.
        try:
            self.db.connection.rollback()
        except Exception as rollbackError:
            logger.warning(f"Rollback after failed batch folder deletion failed: {rollbackError}")
        raise FileDeletionError(f"Error deleting folders in batch: {str(e)}") from e
|
||
|
||
def copyFile(self, sourceFileId: str, targetFolderId: Optional[str] = None, newFileName: Optional[str] = None) -> FileItem:
    """Create a full duplicate of a file (FileItem + FileData).

    The copy gets a name that is unique among the visible files. This is
    required for correctness: ``createFile`` returns an existing file when
    hash and name both match, so copying under the source's own name would
    hand back the source instead of creating a copy.

    Args:
        sourceFileId: File to copy; must be visible to the current user.
        targetFolderId: Destination folder. Defaults to the source's folder.
        newFileName: Name for the copy. Defaults to the source's name (which
            then receives a numeric suffix).

    Returns:
        The FileItem of the copy.

    Raises:
        FileNotFoundError: If the source file does not exist or is not visible.
        FileStorageError: If the source has no data, or the copy's data could
            not be stored.
    """
    sourceFile = self.getFile(sourceFileId)
    if not sourceFile:
        raise FileNotFoundError(f"File {sourceFileId} not found")

    sourceData = self.getFileData(sourceFileId)
    if sourceData is None:
        raise FileStorageError(f"No data found for file {sourceFileId}")

    requestedName = newFileName or sourceFile.fileName
    uniqueName = self._generateUniquefileName(requestedName)
    destinationFolderId = targetFolderId or sourceFile.folderId

    copiedFile = self.createFile(uniqueName, sourceFile.mimeType, sourceData, folderId=destinationFolderId)
    if copiedFile.id == sourceFileId:
        # createFile's duplicate detection resolved to the source itself.
        raise FileStorageError(f"Could not create a separate copy of file {sourceFileId}")

    if not self.createFileData(copiedFile.id, sourceData):
        raise FileStorageError(f"Could not store data for copied file {copiedFile.id}")
    return copiedFile
|
||
|
||
def updateFileData(self, fileId: str, data: bytes) -> bool:
    """Replace existing file data (delete + create) and refresh FileItem metadata.

    Overwriting content is a mutation, so the same rule as for
    ``updateFile`` applies: the owner may always do it, everyone else needs
    RBAC ALL level for "update".

    Args:
        fileId: File whose content is replaced.
        data: New raw content.

    Returns:
        True if the new data was stored (``fileSize``/``fileHash`` are then
        updated on the FileItem), False otherwise.

    Raises:
        FileNotFoundError: If the file does not exist or is not visible.
        PermissionError: If the user may not modify the file.
    """
    file = self.getFile(fileId)
    if not file:
        raise FileNotFoundError(f"File {fileId} not found")

    # Visibility alone (getFile) must not be enough to overwrite content.
    self._requireFileWriteAccess(file, fileId, "update")

    try:
        self.db.recordDelete(FileData, fileId)
        logger.debug(f"Deleted existing FileData for {fileId}")
    except Exception as e:
        # Treated as "nothing to delete"; createFileData below decides.
        logger.debug(f"No existing FileData to delete for {fileId}: {e}")

    success = self.createFileData(fileId, data)
    if success:
        newSize = len(data)
        newHash = hashlib.sha256(data).hexdigest()
        self.db.recordModify(FileItem, fileId, {"fileSize": newSize, "fileHash": newHash})
        logger.info(f"Updated file data for {fileId} ({newSize} bytes)")
    return success
|
||
|
||
# FileData methods - data operations
|
||
|
||
def createFileData(self, fileId: str, data: bytes) -> bool:
    """Store the content of a file in the FileData table.

    Text formats (per ``isTextMimeType``) are stored as UTF-8 text when they
    decode cleanly; everything else, and text that fails to decode, is
    stored base64-encoded with ``base64Encoded=True``.

    If a FileData row already exists for ``fileId`` (e.g. when createFile
    returned a duplicate), nothing is written and True is returned.

    Args:
        fileId: ID of the FileItem the data belongs to; must be visible to the user.
        data: Raw content.

    Returns:
        True if data is stored (or was already present), False if the file is
        not accessible or storing failed.
    """
    try:
        # Check file access
        file = self.getFile(fileId)
        if not file:
            logger.error(f"File with ID {fileId} not found when storing data")
            return False

        # Checked before any encoding so an existing payload costs no work.
        existingData = self.db.getRecordset(FileData, recordFilter={"id": fileId})
        if existingData:
            logger.debug(f"File data already exists for {fileId} — skipping duplicate storage")
            return True

        textContent = None
        if self.isTextMimeType(file.mimeType):
            try:
                textContent = data.decode('utf-8')
            except UnicodeDecodeError:
                # Declared as text but not valid UTF-8: store it like binary.
                logger.warning(f"Failed to decode text file {fileId}, falling back to base64")

        if textContent is not None:
            fileData = textContent
            base64Encoded = False
        else:
            fileData = base64.b64encode(data).decode('utf-8')
            base64Encoded = True

        # Create the fileData record with data and encoding flag
        self.db.recordCreate(FileData, {
            "id": fileId,
            "data": fileData,
            "base64Encoded": base64Encoded,
        })

        logger.debug(f"Successfully stored data for file {fileId} (base64Encoded: {base64Encoded})")
        return True
    except Exception as e:
        logger.error(f"Error storing data for file {fileId}: {str(e)}")
        return False
|
||
|
||
def getFileData(self, fileId: str) -> Optional[bytes]:
    """Load the stored content of a file and return it as bytes.

    Args:
        fileId: ID of the file to read.

    Returns:
        The file content as bytes, or None when the file lookup fails, no
        FileData record (or no ``data`` field) exists, or decoding raises
        an unexpected error. Every None case is logged.
    """
    # The metadata lookup also serves as the access check.
    file = self.getFile(fileId)
    if not file:
        logger.warning(f"No access to file ID {fileId}")
        return None

    rows = self.db.getRecordset(FileData, recordFilter={"id": fileId})
    if not rows:
        logger.warning(f"No data found for file ID {fileId}")
        return None

    row = rows[0]
    if "data" not in row:
        logger.warning(f"No data field in file data for ID {fileId}")
        return None

    data = row["data"]
    base64Encoded = row.get("base64Encoded", False)

    try:
        if base64Encoded:
            return base64.b64decode(data)

        mimeType = file.mimeType
        if self.isTextMimeType(mimeType):
            # Plain text was stored as a string; hand it back as UTF-8 bytes.
            return data.encode('utf-8')

        # A binary MIME type without the base64 flag: the record was stored
        # as text, so try a base64 decode before giving up.
        try:
            logger.warning(f"Binary file {fileId} ({mimeType}) was stored as text, attempting base64 decode")
            return base64.b64decode(data)
        except Exception as base64_error:
            logger.error(f"Failed to decode binary file {fileId} as base64: {str(base64_error)}")
            # Last resort: return whatever is stored, even if it may be corrupted.
            logger.warning(f"Returning raw data for file {fileId} - file may be corrupted")
            if isinstance(data, str):
                return data.encode('utf-8')
            return data
    except Exception as e:
        logger.error(f"Error processing file data for {fileId}: {str(e)}")
        return None
|
||
|
||
def getFileDataForPublicDocument(self, fileId: str) -> Optional[bytes]:
    """Read the content of a public document without RBAC filtering.

    Intended for official/mandate documents (e.g. BZO) that must be
    readable by all users. The FileData record is read directly from the
    database; no file metadata lookup or permission check is performed.

    Args:
        fileId: ID of the FileData record to read.

    Returns:
        The document content as bytes, or None when no record or no
        ``data`` field exists or an error occurs (logged with traceback).
    """
    try:
        rows = self.db.getRecordset(FileData, recordFilter={"id": fileId})
        if not rows:
            logger.warning(f"No file data found for public document ID {fileId}")
            return None

        row = rows[0]
        if "data" not in row:
            logger.warning(f"No data field in file data for ID {fileId}")
            return None

        data = row["data"]
        flaggedBase64 = row.get("base64Encoded", False)

        # A base64 decode is attempted whether or not the record is flagged,
        # because PDFs/binaries are commonly stored as text without the flag.
        # NOTE(review): b64decode is lenient by default (validate=False), so
        # unflagged plain text that happens to look like base64 would be
        # decoded instead of returned verbatim — confirm this is acceptable.
        try:
            return base64.b64decode(data)
        except Exception:
            if flaggedBase64:
                # A flagged record that fails to decode is a real error;
                # let the outer handler log it and return None.
                raise
            if isinstance(data, str):
                return data.encode("utf-8")
            return data
    except Exception as e:
        logger.error(f"Error retrieving public document {fileId}: {str(e)}", exc_info=True)
        return None
|
||
|
||
def getFileContent(self, fileId: str) -> Optional[FilePreview]:
    """Return the full content of a file wrapped in a FilePreview.

    Files whose MIME type starts with ``text/`` are returned as decoded
    text (UTF-8, falling back to latin-1). All other files are returned as
    a base64 string with ``isText`` False and ``encoding`` None.

    Args:
        fileId: ID of the file to read.

    Returns:
        A FilePreview, or None when the file lookup fails, no content
        could be loaded, or an error occurs (logged, never raised). An
        empty file yields a FilePreview with empty content, not None.
    """
    try:
        # The metadata lookup also serves as the access check.
        file = self.getFile(fileId)
        if not file:
            logger.warning(f"No access to file ID {fileId}")
            return None

        fileContent = self.getFileData(fileId)
        # Compare against None explicitly: b"" is falsy but is the valid
        # content of an empty file and must not be reported as missing.
        if fileContent is None:
            logger.warning(f"No content found for file ID {fileId}")
            return None

        encoding = None
        isText = file.mimeType.startswith("text/")
        if isText:
            try:
                content = fileContent.decode('utf-8')
                encoding = 'utf-8'
            except UnicodeDecodeError:
                # latin-1 maps every byte value, so this fallback cannot fail.
                content = fileContent.decode('latin-1')
                encoding = 'latin-1'
        else:
            # Images and every other non-text type are delivered as base64.
            content = base64.b64encode(fileContent).decode('utf-8')

        return FilePreview(
            content=content,
            mimeType=file.mimeType,
            fileName=file.fileName,
            isText=isText,
            encoding=encoding,
            size=file.fileSize
        )
    except Exception as e:
        logger.error(f"Error getting file content: {str(e)}")
        return None
|
||
|
||
def saveUploadedFile(self, fileContent: bytes, fileName: str, folderId: Optional[str] = None) -> tuple[FileItem, str]:
    """Store an uploaded file (metadata and content) if the user may create files.

    Args:
        fileContent: Raw bytes of the uploaded file.
        fileName: Name the file was uploaded under.
        folderId: Optional parent folder ID. None means root folder.

    Returns:
        A ``(fileItem, status)`` tuple where status is one of:
        ``"exact_duplicate"`` - same hash and name already exist; the
        existing file is returned and nothing is written.
        ``"name_conflict"`` - the file was stored under a different name
        than requested.
        ``"new_file"`` - the file was stored as requested.

    Raises:
        FileStorageError: On any failure, including a missing create
            permission, a non-bytes ``fileContent`` and a failed content
            write. The original exception is chained as the cause.
    """
    try:
        if not self.checkRbacPermission(FileItem, "create"):
            raise PermissionError("No permission to upload files")

        logger.debug(f"Starting upload process for file: {fileName} (folderId={folderId!r})")

        if not isinstance(fileContent, bytes):
            logger.error(f"Invalid fileContent type: {type(fileContent)}")
            raise ValueError(f"fileContent must be bytes, got {type(fileContent)}")

        # Hash first so duplicates are detected before any DB write.
        fileHash = hashlib.sha256(fileContent).hexdigest()

        # Same hash + same name -> return the existing file. The same hash
        # under a different name is allowed (intentional copy by the user).
        existingFile = self.checkForDuplicateFile(fileHash, fileName)
        if existingFile:
            logger.info(f"Duplicate detected for user {self.userId}: '{fileName}' with hash {fileHash[:12]}... — returning existing file {existingFile.id}")
            return existingFile, "exact_duplicate"

        mimeType = self.getMimeType(fileName)

        # createFile performs its own duplicate check for other callers;
        # here the check above already ran, so a new file is expected.
        logger.debug(f"Saving file metadata to database for file: {fileName}")
        fileItem = self.createFile(
            name=fileName,
            mimeType=mimeType,
            content=fileContent,
            folderId=folderId,
        )

        logger.debug(f"Saving file content to database for file: {fileName}")
        # createFileData reports failure through its return value instead of
        # raising. Ignoring it would report a successful upload for a file
        # that has metadata but no content.
        if not self.createFileData(fileItem.id, fileContent):
            raise FileStorageError(f"Failed to store content for file {fileItem.id}")

        logger.debug(f"File upload process completed for: {fileName}")

        # A changed file name indicates that a name conflict was resolved.
        status = "name_conflict" if fileItem.fileName != fileName else "new_file"
        return fileItem, status

    except Exception as e:
        logger.error(f"Error in saveUploadedFile for {fileName}: {str(e)}", exc_info=True)
        raise FileStorageError(f"Error saving file: {str(e)}") from e
|
||
|
||
# Messaging Subscription methods
|
||
|
||
def getAllSubscriptions(self, pagination: Optional[PaginationParams] = None) -> Union[List[MessagingSubscription], PaginatedResult]:
    """List the messaging subscriptions returned by the RBAC-filtered query.

    Args:
        pagination: Optional paging parameters. When omitted, all records
            are returned as a plain list. When given, ``filters`` and
            ``sort`` (if set) are applied in that order before the 1-based
            ``page`` of ``pageSize`` items is cut out.

    Returns:
        A list of MessagingSubscription without pagination, otherwise a
        PaginatedResult. On any error the problem is logged and an empty
        list / empty PaginatedResult is returned instead of raising.
    """
    try:
        records = getRecordsetWithRBAC(
            self.db,
            MessagingSubscription,
            self.currentUser,
            mandateId=self.mandateId
        )

        if pagination is None:
            return [MessagingSubscription(**record) for record in records]

        if pagination.filters:
            records = self._applyFilters(records, pagination.filters)
        if pagination.sort:
            records = self._applySorting(records, pagination.sort)

        totalItems = len(records)
        totalPages = 0
        if totalItems > 0:
            totalPages = math.ceil(totalItems / pagination.pageSize)

        # Pages are 1-based: page 1 starts at offset 0.
        offset = (pagination.page - 1) * pagination.pageSize
        window = records[offset:offset + pagination.pageSize]

        return PaginatedResult(
            items=[MessagingSubscription(**record) for record in window],
            totalItems=totalItems,
            totalPages=totalPages
        )
    except Exception as e:
        logger.error(f"Error getting subscriptions: {str(e)}")
        return [] if pagination is None else PaginatedResult(items=[], totalItems=0, totalPages=0)
|
||
|
||
def getSubscription(self, subscriptionId: str) -> Optional[MessagingSubscription]:
    """Look up a subscription by its ``subscriptionId`` field.

    Args:
        subscriptionId: Value of the record's ``subscriptionId`` field
            (not the record ``id``).

    Returns:
        The first record returned by the RBAC-filtered query as a
        MessagingSubscription, or None if the query returns nothing.
    """
    matches = getRecordsetWithRBAC(
        self.db,
        MessagingSubscription,
        self.currentUser,
        recordFilter={"subscriptionId": subscriptionId},
        mandateId=self.mandateId
    )
    if not matches:
        return None
    return MessagingSubscription(**matches[0])
|
||
|
||
def getSubscriptionById(self, id: str) -> Optional[MessagingSubscription]:
    """Look up a subscription by its record ``id`` (UUID).

    Args:
        id: Value of the record's ``id`` field.

    Returns:
        The first record returned by the RBAC-filtered query as a
        MessagingSubscription, or None if the query returns nothing.
    """
    matches = getRecordsetWithRBAC(
        self.db,
        MessagingSubscription,
        self.currentUser,
        recordFilter={"id": id},
        mandateId=self.mandateId
    )
    if not matches:
        return None
    return MessagingSubscription(**matches[0])
|
||
|
||
def createSubscription(self, subscriptionData: Dict[str, Any]) -> Dict[str, Any]:
    """Create a messaging subscription record.

    Args:
        subscriptionData: Field values for the new record. Must contain a
            non-empty ``subscriptionId`` made up of letters and underscores
            only. ``mandateId`` and ``featureInstanceId`` are filled in
            from the current context when the keys are absent; the dict is
            modified in place.

    Returns:
        The record as returned by the database layer.

    Raises:
        PermissionError: If the user lacks the create permission.
        ValueError: If ``subscriptionId`` is missing or malformed, or the
            database layer returns no record with an ``id``.
    """
    if not self.checkRbacPermission(MessagingSubscription, "create"):
        raise PermissionError("No permission to create subscriptions")

    subscriptionId = subscriptionData.get("subscriptionId", "")
    if not subscriptionId:
        raise ValueError("subscriptionId is required")
    # Digits and every other non-letter character except "_" are rejected.
    if any(not (char.isalpha() or char == "_") for char in subscriptionId):
        raise ValueError("subscriptionId must contain only letters and underscores")

    # Inherit the isolation context unless the caller supplied the key.
    for contextField in ("mandateId", "featureInstanceId"):
        if contextField not in subscriptionData:
            subscriptionData[contextField] = getattr(self, contextField)

    createdRecord = self.db.recordCreate(MessagingSubscription, subscriptionData)
    if not (createdRecord and createdRecord.get("id")):
        raise ValueError("Failed to create subscription record")

    return createdRecord
|
||
|
||
def updateSubscription(self, subscriptionId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
    """Apply ``updateData`` to the subscription with the given ``subscriptionId``.

    Args:
        subscriptionId: Value of the record's ``subscriptionId`` field.
        updateData: Field values passed to the database layer's modify call.

    Returns:
        The re-read subscription as a dict (``model_dump()``).

    Raises:
        ValueError: If the subscription cannot be found before the update
            or cannot be re-read afterwards.
    """
    # NOTE(review): unlike deleteSubscription, no checkRbacPermission call
    # guards the write here — confirm that read access is meant to suffice.
    current = self.getSubscription(subscriptionId)
    if not current:
        raise ValueError(f"Subscription {subscriptionId} not found")

    self.db.recordModify(MessagingSubscription, current.id, updateData)

    # Re-read by the same key so the caller gets the stored state.
    refreshed = self.getSubscription(subscriptionId)
    if refreshed:
        return refreshed.model_dump()
    raise ValueError("Failed to retrieve updated subscription")
|
||
|
||
def deleteSubscription(self, subscriptionId: str) -> bool:
    """Delete the subscription with the given ``subscriptionId``.

    Args:
        subscriptionId: Value of the record's ``subscriptionId`` field.

    Returns:
        False if the subscription is not found; otherwise the result of
        the database layer's delete call.

    Raises:
        PermissionError: If the permission check for the record fails.
    """
    target = self.getSubscription(subscriptionId)
    if not target:
        return False

    recordId = target.id
    # NOTE(review): deletion is gated on the "update" action — confirm this
    # is the intended permission for deletes.
    mayDelete = self.checkRbacPermission(MessagingSubscription, "update", recordId)
    if not mayDelete:
        raise PermissionError(f"No permission to delete subscription {subscriptionId}")

    return self.db.recordDelete(MessagingSubscription, recordId)
|
||
|
||
# Messaging Registration methods
|
||
|
||
def getAllRegistrations(
    self,
    subscriptionId: Optional[str] = None,
    userId: Optional[str] = None,
    pagination: Optional[PaginationParams] = None
) -> Union[List[MessagingSubscriptionRegistration], PaginatedResult]:
    """List subscription registrations returned by the RBAC-filtered query.

    Args:
        subscriptionId: If truthy, restrict to this ``subscriptionId``.
        userId: If truthy, restrict to this ``userId``.
        pagination: Optional paging parameters. When omitted, all records
            are returned as a plain list. When given, ``filters`` and
            ``sort`` (if set) are applied in that order before the 1-based
            ``page`` of ``pageSize`` items is cut out.

    Returns:
        A list of MessagingSubscriptionRegistration without pagination,
        otherwise a PaginatedResult. On any error the problem is logged
        and an empty list / empty PaginatedResult is returned.
    """
    try:
        # Only truthy criteria become part of the record filter.
        criteria = {
            field: value
            for field, value in (("subscriptionId", subscriptionId), ("userId", userId))
            if value
        }

        records = getRecordsetWithRBAC(
            self.db,
            MessagingSubscriptionRegistration,
            self.currentUser,
            recordFilter=criteria or None,
            mandateId=self.mandateId
        )

        if pagination is None:
            return [MessagingSubscriptionRegistration(**record) for record in records]

        if pagination.filters:
            records = self._applyFilters(records, pagination.filters)
        if pagination.sort:
            records = self._applySorting(records, pagination.sort)

        totalItems = len(records)
        totalPages = 0
        if totalItems > 0:
            totalPages = math.ceil(totalItems / pagination.pageSize)

        # Pages are 1-based: page 1 starts at offset 0.
        offset = (pagination.page - 1) * pagination.pageSize
        window = records[offset:offset + pagination.pageSize]

        return PaginatedResult(
            items=[MessagingSubscriptionRegistration(**record) for record in window],
            totalItems=totalItems,
            totalPages=totalPages
        )
    except Exception as e:
        logger.error(f"Error getting registrations: {str(e)}")
        return [] if pagination is None else PaginatedResult(items=[], totalItems=0, totalPages=0)
|
||
|
||
def getRegistration(self, registrationId: str) -> Optional[MessagingSubscriptionRegistration]:
    """Look up a registration by its record ``id``.

    Args:
        registrationId: Value of the record's ``id`` field.

    Returns:
        The first record returned by the RBAC-filtered query as a
        MessagingSubscriptionRegistration, or None if nothing is returned.
    """
    matches = getRecordsetWithRBAC(
        self.db,
        MessagingSubscriptionRegistration,
        self.currentUser,
        recordFilter={"id": registrationId},
        mandateId=self.mandateId
    )
    if not matches:
        return None
    return MessagingSubscriptionRegistration(**matches[0])
|
||
|
||
def createRegistration(self, registrationData: Dict[str, Any]) -> Dict[str, Any]:
    """Create a subscription registration record.

    Args:
        registrationData: Field values for the new record. ``userId``,
            ``mandateId`` and ``featureInstanceId`` are filled in from the
            current context when the keys are absent; the dict is modified
            in place.

    Returns:
        The record as returned by the database layer.

    Raises:
        PermissionError: If the user lacks the create permission.
        ValueError: If the database layer returns no record with an ``id``.
    """
    if not self.checkRbacPermission(MessagingSubscriptionRegistration, "create"):
        raise PermissionError("No permission to create registrations")

    # Default the owner and the isolation context from the current context,
    # but never overwrite a key the caller supplied.
    for contextField in ("userId", "mandateId", "featureInstanceId"):
        if contextField not in registrationData:
            registrationData[contextField] = getattr(self, contextField)

    createdRecord = self.db.recordCreate(MessagingSubscriptionRegistration, registrationData)
    if not (createdRecord and createdRecord.get("id")):
        raise ValueError("Failed to create registration record")

    return createdRecord
|
||
|
||
def updateRegistration(self, registrationId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
    """Apply ``updateData`` to the registration with the given record ID.

    Args:
        registrationId: Record ``id`` of the registration.
        updateData: Field values passed to the database layer's modify call.

    Returns:
        The re-read registration as a dict (``model_dump()``).

    Raises:
        ValueError: If the registration cannot be found before the update
            or cannot be re-read afterwards.
    """
    if not self.getRegistration(registrationId):
        raise ValueError(f"Registration {registrationId} not found")

    self.db.recordModify(MessagingSubscriptionRegistration, registrationId, updateData)

    # Re-read so the caller gets the stored state, not just the patch.
    refreshed = self.getRegistration(registrationId)
    if refreshed:
        return refreshed.model_dump()
    raise ValueError("Failed to retrieve updated registration")
|
||
|
||
def deleteRegistration(self, registrationId: str) -> bool:
    """Delete the registration with the given record ID.

    Args:
        registrationId: Record ``id`` of the registration.

    Returns:
        False if the registration is not found; otherwise the result of
        the database layer's delete call.

    Raises:
        PermissionError: If the permission check for the record fails.
    """
    if not self.getRegistration(registrationId):
        return False

    # NOTE(review): deletion is gated on the "update" action — confirm this
    # is the intended permission for deletes.
    mayDelete = self.checkRbacPermission(MessagingSubscriptionRegistration, "update", registrationId)
    if not mayDelete:
        raise PermissionError(f"No permission to delete registration {registrationId}")

    return self.db.recordDelete(MessagingSubscriptionRegistration, registrationId)
|
||
|
||
def subscribeUser(
    self,
    subscriptionId: str,
    userId: str,
    channel: MessagingChannel,
    channelConfig: str
) -> Dict[str, Any]:
    """Register a user on a subscription for one channel.

    If the user already has a registration for ``channel`` on this
    subscription, it is re-enabled and its ``channelConfig`` replaced;
    otherwise a new, enabled registration is created.

    Args:
        subscriptionId: ``subscriptionId`` of the target subscription.
        userId: ID of the user to subscribe.
        channel: Channel to register; its ``value`` is stored on new records.
        channelConfig: Channel configuration stored on the registration.

    Returns:
        The updated or newly created registration as a dict.

    Raises:
        ValueError: If the subscription is not found.
    """
    if not self.getSubscription(subscriptionId):
        raise ValueError(f"Subscription {subscriptionId} not found")

    # First registration for the requested channel, if any.
    sameChannel = next(
        (
            registration
            for registration in self.getAllRegistrations(subscriptionId=subscriptionId, userId=userId)
            if registration.channel == channel
        ),
        None,
    )
    if sameChannel is not None:
        return self.updateRegistration(sameChannel.id, {"enabled": True, "channelConfig": channelConfig})

    return self.createRegistration({
        "subscriptionId": subscriptionId,
        "userId": userId,
        "channel": channel.value,
        "channelConfig": channelConfig,
        "enabled": True
    })
|
||
|
||
def unsubscribeUser(self, subscriptionId: str, userId: str, channel: MessagingChannel) -> bool:
    """Remove a user's registration for one channel of a subscription.

    Args:
        subscriptionId: ``subscriptionId`` of the subscription.
        userId: ID of the user to unsubscribe.
        channel: Channel whose registration is removed.

    Returns:
        The result of ``deleteRegistration`` for the first registration on
        ``channel``, or False when the user has none.
    """
    sameChannel = next(
        (
            registration
            for registration in self.getAllRegistrations(subscriptionId=subscriptionId, userId=userId)
            if registration.channel == channel
        ),
        None,
    )
    if sameChannel is None:
        return False
    return self.deleteRegistration(sameChannel.id)
|
||
|
||
# Messaging Delivery methods
|
||
|
||
def createDelivery(self, delivery: MessagingDelivery) -> Dict[str, Any]:
    """Create a delivery record.

    Args:
        delivery: A MessagingDelivery (dumped to a dict first) or any
            other value, which is used as the record data as-is.
            ``mandateId`` and ``featureInstanceId`` are filled in from the
            current context when missing or falsy.

    Returns:
        The record as returned by the database layer.

    Raises:
        ValueError: If the database layer returns no record with an ``id``.
    """
    if isinstance(delivery, MessagingDelivery):
        deliveryData = delivery.model_dump()
    else:
        deliveryData = delivery

    # Unlike the subscription/registration creators, a present-but-empty
    # value is also replaced by the context value.
    for contextField in ("mandateId", "featureInstanceId"):
        if not (contextField in deliveryData and deliveryData[contextField]):
            deliveryData[contextField] = getattr(self, contextField)

    createdRecord = self.db.recordCreate(MessagingDelivery, deliveryData)
    if not (createdRecord and createdRecord.get("id")):
        raise ValueError("Failed to create delivery record")
    return createdRecord
|
||
|
||
def updateDelivery(self, deliveryId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
    """Apply ``updateData`` to the delivery record with the given ID.

    No lookup or permission check is performed in this method.

    Args:
        deliveryId: Record ``id`` of the delivery.
        updateData: Field values passed to the database layer's modify call.

    Returns:
        The ``updateData`` dict that was passed in, not the stored record.
    """
    self.db.recordModify(MessagingDelivery, deliveryId, updateData)
    # The patch itself is echoed back; the record is not re-read.
    return updateData
|
||
|
||
def getDeliveries(
    self,
    subscriptionId: Optional[str] = None,
    userId: Optional[str] = None,
    pagination: Optional[PaginationParams] = None
) -> Union[List[MessagingDelivery], PaginatedResult]:
    """List delivery records returned by the RBAC-filtered query.

    Args:
        subscriptionId: If truthy, restrict to this ``subscriptionId``.
        userId: If truthy, restrict to this ``userId``.
        pagination: Optional paging parameters. When omitted, all records
            are returned as a plain list. When given, ``filters`` and
            ``sort`` (if set) are applied in that order before the 1-based
            ``page`` of ``pageSize`` items is cut out.

    Returns:
        A list of MessagingDelivery without pagination, otherwise a
        PaginatedResult. On any error the problem is logged and an empty
        list / empty PaginatedResult is returned.
    """
    try:
        # Only truthy criteria become part of the record filter.
        criteria = {
            field: value
            for field, value in (("subscriptionId", subscriptionId), ("userId", userId))
            if value
        }

        # NOTE(review): no mandateId is passed here, unlike the subscription
        # and registration queries — confirm this is intentional.
        records = getRecordsetWithRBAC(
            self.db,
            MessagingDelivery,
            self.currentUser,
            recordFilter=criteria or None
        )

        if pagination is None:
            return [MessagingDelivery(**record) for record in records]

        if pagination.filters:
            records = self._applyFilters(records, pagination.filters)
        if pagination.sort:
            records = self._applySorting(records, pagination.sort)

        totalItems = len(records)
        totalPages = 0
        if totalItems > 0:
            totalPages = math.ceil(totalItems / pagination.pageSize)

        # Pages are 1-based: page 1 starts at offset 0.
        offset = (pagination.page - 1) * pagination.pageSize
        window = records[offset:offset + pagination.pageSize]

        return PaginatedResult(
            items=[MessagingDelivery(**record) for record in window],
            totalItems=totalItems,
            totalPages=totalPages
        )
    except Exception as e:
        logger.error(f"Error getting deliveries: {str(e)}")
        return [] if pagination is None else PaginatedResult(items=[], totalItems=0, totalPages=0)
|
||
|
||
def getDelivery(self, deliveryId: str) -> Optional[MessagingDelivery]:
    """Look up a delivery by its record ``id``.

    Args:
        deliveryId: Value of the record's ``id`` field.

    Returns:
        The first record returned by the RBAC-filtered query as a
        MessagingDelivery, or None if nothing is returned.
    """
    # NOTE(review): no mandateId is passed here, unlike the subscription
    # and registration lookups — confirm this is intentional.
    matches = getRecordsetWithRBAC(
        self.db,
        MessagingDelivery,
        self.currentUser,
        recordFilter={"id": deliveryId}
    )
    if not matches:
        return None
    return MessagingDelivery(**matches[0])
|
||
|
||
|
||
def getInterface(currentUser: Optional[User] = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> 'ComponentObjects':
    """Return the shared ComponentObjects instance, optionally bound to a user.

    A single instance is created on first use and cached under the key
    ``"default"`` in ``_instancesManagement``; every call returns that
    same object.

    Args:
        currentUser: The authenticated user. When given, the user context
            is set on the shared instance; when omitted, the instance is
            returned without changing its context.
        mandateId: The mandate ID from RequestContext (X-Mandate-Id
            header). Converted to ``str``; falsy values become None.
        featureInstanceId: The feature instance ID from RequestContext
            (X-Feature-Instance-Id header). Converted to ``str``; falsy
            values become None.

    Returns:
        The shared ComponentObjects instance.
    """
    # NOTE(review): because the instance is shared, a call without
    # currentUser returns it with whatever context an earlier call set, and
    # concurrent callers overwrite each other's context — confirm callers
    # are safe with that.
    effectiveMandateId = None
    if mandateId:
        effectiveMandateId = str(mandateId)

    effectiveFeatureInstanceId = None
    if featureInstanceId:
        effectiveFeatureInstanceId = str(featureInstanceId)

    # Create the instance on first use only.
    try:
        interface = _instancesManagement["default"]
    except KeyError:
        interface = _instancesManagement["default"] = ComponentObjects()

    if not currentUser:
        logger.info("Returning interface without user context")
        return interface

    interface.setUserContext(currentUser, mandateId=effectiveMandateId, featureInstanceId=effectiveFeatureInstanceId)
    return interface