diff --git a/app.py b/app.py index 9ace64b5..23a8cb5c 100644 --- a/app.py +++ b/app.py @@ -437,3 +437,6 @@ app.include_router(automationRouter) from modules.routes.routeAdminAutomationEvents import router as adminAutomationEventsRouter app.include_router(adminAutomationEventsRouter) +from modules.routes.routeRbac import router as rbacRouter +app.include_router(rbacRouter) + diff --git a/import_map_analysis.md b/import_map_analysis.md new file mode 100644 index 00000000..4074d1a7 --- /dev/null +++ b/import_map_analysis.md @@ -0,0 +1,247 @@ +# Import Map Analysis: interfaces ↔ connectors ↔ security + +## Overview +This document maps all imports between `modules/interfaces`, `modules/connectors`, and `modules/security` to identify structural issues, circular dependencies, and architectural concerns. + +**Architectural Principle:** +- ✅ Connectors (infrastructure) can import from Security (infrastructure) +- ✅ Interfaces (business logic) can import from Security (infrastructure) +- ✅ Interfaces (business logic) can import from Connectors (infrastructure) +- ❌ Connectors should NOT import from Interfaces (business logic) + +--- + +## Import Dependencies Map + +### **CONNECTORS → SECURITY** + +#### `connectorDbPostgre.py` +- **Imports from security:** + - `from modules.security.rbac import RbacClass` (line 13) + - **Usage:** + - **Runtime instantiation:** `RbacClass(self)` in `getRecordsetWithRBAC()` (line 1073) + - Creates `RbacClass` instance to get user permissions + - **Status:** ✅ **ARCHITECTURALLY CORRECT** - Connectors can import from security module + +--- + +### **SECURITY → CONNECTORS** + +#### `security/rbac.py` (moved from `interfaces/interfaceRbac.py`) +- **Imports from connectors:** + - `from modules.connectors.connectorDbPostgre import DatabaseConnector` (line 11, inside TYPE_CHECKING) + - **Usage:** Type hint only (`db: "DatabaseConnector"`) + - **Status:** ✅ Fixed with TYPE_CHECKING to avoid circular import + - **Architecture:** ✅ Correct - Security module can import from connectors (infrastructure layer) + +### **INTERFACES → CONNECTORS** + +#### `interfaceBootstrap.py` +- **Imports from connectors:** + - `from modules.connectors.connectorDbPostgre import DatabaseConnector` (line 9) + - **Usage:** Function parameter types (`initBootstrap(db: DatabaseConnector)`) + +#### `interfaceDbAppObjects.py` +- **Imports from connectors:** + - `from modules.connectors.connectorDbPostgre import DatabaseConnector` (line 12) + - **Usage:** Class initialization (`self.db: DatabaseConnector`) +- **Imports from security:** + - `from modules.security.rbac import RbacClass` (line 17) + - **Usage:** RBAC permission checking + - **Architecture:** ✅ Correct - Interfaces can import from security (infrastructure layer) + +#### `interfaceDbChatObjects.py` +- **Imports from connectors:** + - `from modules.connectors.connectorDbPostgre import DatabaseConnector` (line 29) + - **Usage:** Class initialization + +#### `interfaceDbComponentObjects.py` +- **Imports from connectors:** + - `from modules.connectors.connectorDbPostgre import DatabaseConnector` (line 13) + - **Usage:** Class initialization + +#### `interfaceVoiceObjects.py` +- **Imports from connectors:** + - `from modules.connectors.connectorVoiceGoogle import ConnectorGoogleSpeech` (line 10) + - **Usage:** Class initialization + +--- + +## Circular Dependency Analysis + +### **CIRCULAR DEPENDENCY #1: RESOLVED** ✅ +``` +connectorDbPostgre.py (line 13) + └─> imports RbacClass from security.rbac + └─> Uses: RbacClass(self) at runtime (line 1073) + +security/rbac.py (line 11, inside TYPE_CHECKING) + └─> imports DatabaseConnector (type hint only) +``` + +**Status:** ✅ **RESOLVED** by moving RBAC to security module + `TYPE_CHECKING` + +**Architectural Fix:** +- Moved `interfaceRbac.py` → `security/rbac.py` +- Connectors can import from security (infrastructure layer) +- Interfaces can import from security (business logic layer) +- No architectural violation: security is shared infrastructure + +**Solution Applied:** +```python +# security/rbac.py +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from modules.connectors.connectorDbPostgre import DatabaseConnector + +class RbacClass: + def __init__(self, db: "DatabaseConnector"): # String annotation + self.db = db # Uses db at runtime, but import is deferred +``` + +**Why This Works:** +1. At **import time**: `connectorDbPostgre` imports `RbacClass` ✅ +2. `RbacClass` tries to import `DatabaseConnector` but it's inside `TYPE_CHECKING`, so **no actual import occurs** ✅ +3. At **runtime**: When `getRecordsetWithRBAC()` calls `RbacClass(self)`, `DatabaseConnector` is already fully loaded ✅ +4. Runtime circular reference is safe because Python objects can reference each other once loaded + +--- + +## Architecture Analysis + +### **Current Structure** + +``` +┌─────────────────────────────────────────────────────────────┐ +│ CONNECTORS │ +│ (Database, External Services) │ +│ │ +│ connectorDbPostgre.py │ +│ └─> Uses: RbacClass (runtime instantiation) ⚠️ │ +│ │ +│ connectorVoiceGoogle.py │ +│ connectorTicketsClickup.py │ +│ connectorTicketsJira.py │ +└─────────────────────────────────────────────────────────────┘ + ▲ + │ imports + │ +┌─────────────────────────────────────────────────────────────┐ +│ INTERFACES │ +│ (Business Logic, Data Access Layer) │ +│ │ +│ security/rbac.py (moved from interfaces) │ +│ └─> Uses: DatabaseConnector (type hint only) ✅ │ +│ └─> Can be imported by both connectors and interfaces │ +│ │ +│ interfaceBootstrap.py │ +│ └─> Uses: DatabaseConnector │ +│ │ +│ interfaceDbAppObjects.py │ +│ └─> Uses: DatabaseConnector │ +│ └─> Uses: security.rbac.RbacClass │ +│ └─> Uses: interfaceBootstrap.initBootstrap │ +│ │ +│ interfaceDbChatObjects.py │ +│ └─> Uses: DatabaseConnector │ +│ │ +│ interfaceDbComponentObjects.py │ +│ └─> Uses: DatabaseConnector │ +│ │ +│ interfaceVoiceObjects.py │ +│ └─> Uses: connectorVoiceGoogle.ConnectorGoogleSpeech │ +└─────────────────────────────────────────────────────────────┘ +``` + +--- + +## Potential Issues & Recommendations + +### ✅ **RESOLVED ISSUES** + +1. **Circular Import: security.rbac ↔ connectorDbPostgre** + - **Status:** ✅ Resolved by moving to security module + TYPE_CHECKING + - **Impact:** None - Proper architectural layering maintained + +### ⚠️ **POTENTIAL ISSUES** + +1. **Tight Coupling: Interfaces depend on specific connectors** + - **Issue:** `interfaceDbAppObjects.py` directly imports `DatabaseConnector` + - **Impact:** Makes it harder to swap database implementations + - **Recommendation:** Consider dependency injection or abstract base class + +2. **Connector importing from Security (connectorDbPostgre → security.rbac)** ✅ + - **Status:** ✅ **RESOLVED** - Moved RBAC to security module + - **Current Usage:** Runtime instantiation in `getRecordsetWithRBAC()` (line 1073) + - **Code:** + ```python + RbacInstance = RbacClass(self) + permissions = RbacInstance.getUserPermissions(...) + ``` + - **Architecture:** ✅ Correct - Connectors can import from security (infrastructure layer) + - **Rationale:** Security is shared infrastructure, not business logic + +3. **Multiple interfaces importing same connector** + - **Files importing DatabaseConnector:** + - `interfaceBootstrap.py` + - `interfaceDbAppObjects.py` + - `interfaceDbChatObjects.py` + - `interfaceDbComponentObjects.py` + - **Impact:** Medium - creates coupling + - **Recommendation:** Consider a shared database interface abstraction + +--- + +## Recommendations + +### **1. Move RBAC Logic Out of Connector** +**Current:** `connectorDbPostgre.getRecordsetWithRBAC()` instantiates `RbacClass(self)` at runtime +**Recommendation:** +- ~~Move `getRecordsetWithRBAC()` to `interfaceRbac.py` or `interfaceDbAppObjects.py`~~ ✅ **RESOLVED** - RBAC moved to security module +- Connector should only handle raw database operations +- Interface layer handles RBAC filtering + +### **2. Use Dependency Injection** +**Current:** Interfaces directly import `DatabaseConnector` +**Recommendation:** +- Create abstract base class `DatabaseConnectorBase` +- Interfaces depend on abstraction, not concrete implementation +- Allows easier testing and swapping implementations + +### **3. Consider Layered Architecture** +``` +┌─────────────────────────────────────┐ +│ Interfaces (Business Logic) │ +│ - Uses connectors via abstraction │ +└─────────────────────────────────────┘ + ▲ + │ +┌─────────────────────────────────────┐ +│ Connectors (Infrastructure) │ +│ - No knowledge of interfaces │ +└─────────────────────────────────────┘ +``` + +### **4. Use TYPE_CHECKING for All Type-Only Imports** +**Current:** `security/rbac.py` uses TYPE_CHECKING (moved from interfaces) +**Recommendation:** Use TYPE_CHECKING for all type-only imports between layers + +--- + +## Summary + +### **Current State:** +- ✅ 1 circular dependency **RESOLVED** (moved to security module) +- ✅ Architectural violation **FIXED** (RBAC moved to security) +- ⚠️ Multiple tight couplings to `DatabaseConnector` (acceptable for now) + +### **Architectural Health:** +- **Overall:** 🟢 **Good** - Proper layering maintained +- **Architecture:** ✅ Connectors → Security (infrastructure) ✅ Interfaces → Security (infrastructure) +- **Risk Level:** Low - Clean separation of concerns + +### **Completed Actions:** +1. ✅ **DONE:** Fixed circular import with TYPE_CHECKING +2. ✅ **DONE:** Moved RBAC to security module (proper architectural layering) +3. 🔄 **OPTIONAL:** Introduce abstraction layer for database connector (future improvement) diff --git a/modules/connectors/connectorDbJson.py b/modules/connectors/connectorDbJson.py deleted file mode 100644 index 0b44e6df..00000000 --- a/modules/connectors/connectorDbJson.py +++ /dev/null @@ -1,678 +0,0 @@ -import json -import os -from typing import List, Dict, Any, Optional, TypedDict -import logging -import uuid -from pydantic import BaseModel -import threading -import time - -from modules.shared.timeUtils import getUtcTimestamp - -logger = logging.getLogger(__name__) - -class TableCache(TypedDict): - """Type definition for table cache entries""" - recordIds: List[str] - -class DatabaseConnector: - """ - A connector for JSON-based data storage. - Provides generic database operations without user/mandate filtering. - Stores tables as folders and records as individual files. - """ - def __init__(self, dbHost: str, dbDatabase: str, dbUser: str = None, dbPassword: str = None, userId: str = None): - # Store the input parameters - self.dbHost = dbHost - self.dbDatabase = dbDatabase - self.dbUser = dbUser - self.dbPassword = dbPassword - - # Set userId (default to empty string if None) - self.userId = userId if userId is not None else "" - - # Initialize database system - self.initDbSystem() - - # Set up database folder path - self.dbFolder = os.path.join(self.dbHost, self.dbDatabase) - - # Cache for loaded data - self._tablesCache: Dict[str, List[Dict[str, Any]]] = {} - self._tableMetadataCache: Dict[str, TableCache] = {} # Cache for table metadata (record IDs, etc.) - - # File locks with timeout protection - self._file_locks = {} - self._lock_manager = threading.Lock() - self._lock_timeouts = {} # Track when locks were acquired - - # Initialize system table - self._systemTableName = "_system" - self._initializeSystemTable() - - logger.debug(f"Context: userId={self.userId}") - - def initDbSystem(self): - """Initialize the database system - creates necessary directories and structure.""" - try: - # Ensure the database directory exists - self.dbFolder = os.path.join(self.dbHost, self.dbDatabase) - os.makedirs(self.dbFolder, exist_ok=True) - logger.info(f"Database system initialized: {self.dbFolder}") - except Exception as e: - logger.error(f"Error initializing database system: {e}") - raise - - def _initializeSystemTable(self): - """Initializes the system table if it doesn't exist yet.""" - systemTablePath = self._getTablePath(self._systemTableName) - if not os.path.exists(systemTablePath): - emptySystemTable = {} - self._saveSystemTable(emptySystemTable) - logger.info(f"System table initialized in {systemTablePath}") - else: - # Load existing system table to ensure it's available - self._loadSystemTable() - logger.debug(f"Existing system table loaded from {systemTablePath}") - - def _loadSystemTable(self) -> Dict[str, str]: - """Loads the system table with the initial IDs.""" - # Check if system table is in cache - if f"_{self._systemTableName}" in self._tablesCache: - return self._tablesCache[f"_{self._systemTableName}"] - - systemTablePath = self._getTablePath(self._systemTableName) - try: - if os.path.exists(systemTablePath): - with open(systemTablePath, 'r', encoding='utf-8') as f: - data = json.load(f) - # Store in cache with special prefix to avoid collision with regular tables - self._tablesCache[f"_{self._systemTableName}"] = data - return data - else: - self._tablesCache[f"_{self._systemTableName}"] = {} - return {} - except Exception as e: - logger.error(f"Error loading the system table: {e}") - self._tablesCache[f"_{self._systemTableName}"] = {} - return {} - - def _saveSystemTable(self, data: Dict[str, str]) -> bool: - """Saves the system table with the initial IDs.""" - systemTablePath = self._getTablePath(self._systemTableName) - try: - with open(systemTablePath, 'w', encoding='utf-8') as f: - json.dump(data, f, indent=2, ensure_ascii=False) - # Update cache - self._tablesCache[f"_{self._systemTableName}"] = data - return True - except Exception as e: - logger.error(f"Error saving the system table: {e}") - return False - - def _getTablePath(self, table: str) -> str: - """Returns the full path to a table folder""" - return os.path.join(self.dbFolder, table) - - def _getRecordPath(self, table: str, recordId: str) -> str: - """Returns the full path to a record file""" - return os.path.join(self._getTablePath(table), f"{recordId}.json") - - def _get_file_lock(self, filepath: str, timeout_seconds: int = 30): - """Get file lock with timeout protection""" - with self._lock_manager: - if filepath not in self._file_locks: - self._file_locks[filepath] = threading.Lock() - - lock = self._file_locks[filepath] - - # Check if lock is stale (held too long) - if filepath in self._lock_timeouts: - lock_age = time.time() - self._lock_timeouts[filepath] - if lock_age > timeout_seconds: - logger.warning(f"Stale lock detected for {filepath}, age: {lock_age}s") - # Force release stale lock - try: - lock.release() - except: - pass - # Create new lock - self._file_locks[filepath] = threading.Lock() - lock = self._file_locks[filepath] - - return lock - - def _get_table_lock(self, table: str, timeout_seconds: int = 30): - """Get table-level lock for metadata operations""" - table_lock_key = f"table_{table}" - return self._get_file_lock(table_lock_key, timeout_seconds) - - def _ensureTableDirectory(self, table: str) -> bool: - """Ensures the table directory exists.""" - if table == self._systemTableName: - return True - - tablePath = self._getTablePath(table) - try: - os.makedirs(tablePath, exist_ok=True) - return True - except Exception as e: - logger.error(f"Error creating table directory {tablePath}: {e}") - return False - - def _loadTableMetadata(self, table: str) -> Dict[str, Any]: - """Loads table metadata (list of record IDs) without loading actual records. - NOTE: This method is safe to call without additional locking. - """ - if table in self._tableMetadataCache: - return self._tableMetadataCache[table] - - # Ensure table directory exists - if not self._ensureTableDirectory(table): - return {"recordIds": []} - - tablePath = self._getTablePath(table) - metadata = {"recordIds": []} - - try: - if os.path.exists(tablePath): - for fileName in os.listdir(tablePath): - if fileName.endswith('.json') and fileName != '_metadata.json': - recordId = fileName[:-5] # Remove .json extension - metadata["recordIds"].append(recordId) - - metadata["recordIds"].sort() - self._tableMetadataCache[table] = metadata - except Exception as e: - logger.error(f"Error loading table metadata for {table}: {e}") - - return metadata - - def _loadRecord(self, table: str, recordId: str) -> Optional[Dict[str, Any]]: - """Loads a single record from the table.""" - recordPath = self._getRecordPath(table, recordId) - try: - if os.path.exists(recordPath): - with open(recordPath, 'r', encoding='utf-8') as f: - record = json.load(f) - return record - except Exception as e: - logger.error(f"Error loading record {recordId} from table {table}: {e}") - return None - - def _saveRecord(self, table: str, recordId: str, record: Dict[str, Any]) -> bool: - """Saves a single record to the table with atomic metadata operations.""" - recordPath = self._getRecordPath(table, recordId) - record_lock = self._get_file_lock(recordPath) - table_lock = self._get_table_lock(table) - - try: - # Acquire both locks with timeout - record lock first, then table lock - if not record_lock.acquire(timeout=30): - raise TimeoutError(f"Could not acquire record lock for {recordPath} within 30 seconds") - - if not table_lock.acquire(timeout=30): - record_lock.release() - raise TimeoutError(f"Could not acquire table lock for {table} within 30 seconds") - - # Record lock acquisition time - self._lock_timeouts[recordPath] = time.time() - self._lock_timeouts[f"table_{table}"] = time.time() - - # Ensure table directory exists - if not self._ensureTableDirectory(table): - raise ValueError(f"Error creating table directory for {table}") - - # Ensure recordId is a string - recordId = str(recordId) - - # CRITICAL: Ensure record ID matches the file name - if "id" in record and str(record["id"]) != recordId: - logger.error(f"Record ID mismatch: file name ID ({recordId}) does not match record ID ({record['id']})") - raise ValueError(f"Record ID mismatch: file name ID ({recordId}) does not match record ID ({record['id']})") - - # Add metadata - currentTime = getUtcTimestamp() - if "_createdAt" not in record: - record["_createdAt"] = currentTime - record["_createdBy"] = self.userId - record["_modifiedAt"] = currentTime - record["_modifiedBy"] = self.userId - - # Save the record file using atomic write - tempPath = recordPath + '.tmp' - - # Ensure directory exists - os.makedirs(os.path.dirname(recordPath), exist_ok=True) - - # Write to temporary file first - with open(tempPath, 'w', encoding='utf-8') as f: - json.dump(record, f, indent=2, ensure_ascii=False) - - # Verify the temporary file can be read back (validation) - try: - with open(tempPath, 'r', encoding='utf-8') as f: - json.load(f) # This will fail if file is corrupted - except Exception as e: - logger.error(f"Validation failed for record {recordId}: {e}") - # Clean up temp file - if os.path.exists(tempPath): - os.remove(tempPath) - raise ValueError(f"Record validation failed: {e}") - - # Atomic move from temp to final location - os.replace(tempPath, recordPath) - - # ATOMIC: Update metadata while holding both locks - metadata = self._loadTableMetadata(table) - if recordId not in metadata["recordIds"]: - metadata["recordIds"].append(recordId) - metadata["recordIds"].sort() - self._saveTableMetadata(table, metadata) - - # Update cache if it exists (also protected by table lock) - if table in self._tablesCache: - # Find and update existing record or append new one - found = False - for i, existing_record in enumerate(self._tablesCache[table]): - if str(existing_record.get("id")) == recordId: - self._tablesCache[table][i] = record - found = True - break - if not found: - self._tablesCache[table].append(record) - - return True - - except Exception as e: - logger.error(f"Error saving record {recordId} to table {table}: {e}") - # Clean up temp file if it exists - tempPath = self._getRecordPath(table, recordId) + '.tmp' - if os.path.exists(tempPath): - try: - os.remove(tempPath) - except: - pass - return False - - finally: - # ALWAYS release both locks, even on error - try: - if table_lock.locked(): - table_lock.release() - if f"table_{table}" in self._lock_timeouts: - del self._lock_timeouts[f"table_{table}"] - except Exception as release_error: - logger.error(f"Error releasing table lock for {table}: {release_error}") - - try: - if record_lock.locked(): - record_lock.release() - if recordPath in self._lock_timeouts: - del self._lock_timeouts[recordPath] - except Exception as release_error: - logger.error(f"Error releasing record lock for {recordPath}: {release_error}") - - def _loadTable(self, table: str) -> List[Dict[str, Any]]: - """Loads all records from a table folder.""" - # If the table is the system table, load it directly - if table == self._systemTableName: - return self._loadSystemTable() - - # If the table is already in the cache, use the cache - if table in self._tablesCache: - return self._tablesCache[table] - - # Load metadata first - metadata = self._loadTableMetadata(table) - records = [] - - # Load each record - for recordId in metadata["recordIds"]: - # Skip metadata file - if recordId == "_metadata": - continue - record = self._loadRecord(table, recordId) - if record: - records.append(record) - - self._tablesCache[table] = records - return records - - def _saveTable(self, table: str, data: List[Dict[str, Any]]) -> bool: - """Saves all records to a table folder""" - # The system table is handled specially - if table == self._systemTableName: - return self._saveSystemTable(data) - - tablePath = self._getTablePath(table) - try: - # Ensure table directory exists - os.makedirs(tablePath, exist_ok=True) - - # Save each record as a separate file - for record in data: - if "id" not in record: - logger.error(f"Record missing ID in table {table}") - continue - - recordPath = self._getRecordPath(table, record["id"]) - with open(recordPath, 'w', encoding='utf-8') as f: - json.dump(record, f, indent=2, ensure_ascii=False) - - # Update the cache - self._tablesCache[table] = data - logger.debug(f"Successfully saved table {table}") - return True - except Exception as e: - logger.error(f"Error saving table {table}: {str(e)}") - logger.error(f"Error type: {type(e).__name__}") - logger.error(f"Error details: {e.__dict__ if hasattr(e, '__dict__') else 'No details available'}") - return False - - def _applyRecordFilter(self, records: List[Dict[str, Any]], recordFilter: Dict[str, Any] = None) -> List[Dict[str, Any]]: - """Applies a record filter to the records""" - if not recordFilter: - return records - - filteredRecords = [] - - for record in records: - match = True - - for field, value in recordFilter.items(): - # Check if the field exists - if field not in record: - match = False - break - - # Convert both values to strings for comparison - recordValue = str(record[field]) - filterValue = str(value) - - # Direct string comparison - if recordValue != filterValue: - match = False - break - - if match: - filteredRecords.append(record) - - return filteredRecords - - def _registerInitialId(self, table: str, initialId: str) -> bool: - """Registers the initial ID for a table.""" - try: - systemData = self._loadSystemTable() - - if table not in systemData: - systemData[table] = initialId - success = self._saveSystemTable(systemData) - if success: - logger.info(f"Initial ID {initialId} for table {table} registered") - return success - return True # If already present, this is not an error - except Exception as e: - logger.error(f"Error registering the initial ID for table {table}: {e}") - return False - - def _removeInitialId(self, table: str) -> bool: - """Removes the initial ID for a table from the system table.""" - try: - systemData = self._loadSystemTable() - - if table in systemData: - del systemData[table] - success = self._saveSystemTable(systemData) - if success: - logger.info(f"Initial ID for table {table} removed from system table") - return success - return True # If not present, this is not an error - except Exception as e: - logger.error(f"Error removing initial ID for table {table}: {e}") - return False - - - - def _saveTableMetadata(self, table: str, metadata: Dict[str, Any]) -> bool: - """Saves table metadata to a metadata file. - NOTE: This method assumes the caller already holds the table lock. - """ - try: - # Create metadata file path - metadataPath = os.path.join(self._getTablePath(table), "_metadata.json") - - # Save metadata (caller should already hold table lock) - with open(metadataPath, 'w', encoding='utf-8') as f: - json.dump(metadata, f, indent=2, ensure_ascii=False) - - # Update cache - self._tableMetadataCache[table] = metadata - - return True - - except Exception as e: - logger.error(f"Error saving metadata for table {table}: {e}") - return False - - def updateContext(self, userId: str) -> None: - """Updates the context of the database connector.""" - if userId is None: - raise ValueError("userId must be provided") - - self.userId = userId - logger.info(f"Updated database context: userId={self.userId}") - - # Clear cache to ensure fresh data with new context - self._tablesCache = {} - self._tableMetadataCache = {} - - def clearTableCache(self, table: str) -> None: - """Clears cache for a specific table to ensure fresh data.""" - if table in self._tablesCache: - del self._tablesCache[table] - logger.debug(f"Cleared cache for table: {table}") - - if table in self._tableMetadataCache: - del self._tableMetadataCache[table] - logger.debug(f"Cleared metadata cache for table: {table}") - - # Public API - - def getTables(self) -> List[str]: - """Returns a list of all available tables.""" - tables = [] - - try: - for item in os.listdir(self.dbFolder): - itemPath = os.path.join(self.dbFolder, item) - if os.path.isdir(itemPath) and not item.startswith('_'): - tables.append(item) - except Exception as e: - logger.error(f"Error reading the database directory: {e}") - - return tables - - def getFields(self, table: str) -> List[str]: - """Returns a list of all fields in a table.""" - data = self._loadTable(table) - - if not data: - return [] - - fields = list(data[0].keys()) if data else [] - - return fields - - def getSchema(self, table: str, language: str = None) -> Dict[str, Dict[str, Any]]: - """Returns a schema object for a table with data types and labels.""" - data = self._loadTable(table) - - schema = {} - - if not data: - return schema - - firstRecord = data[0] - - for field, value in firstRecord.items(): - dataType = type(value).__name__ - label = field - - schema[field] = { - "type": dataType, - "label": label - } - - return schema - - def getRecordset(self, table: str, fieldFilter: List[str] = None, recordFilter: Dict[str, Any] = None) -> List[Dict[str, Any]]: - """Returns a list of records from a table, filtered by criteria.""" - # If we have specific record IDs in the filter, only load those records - if recordFilter and "id" in recordFilter: - recordId = recordFilter["id"] - record = self._loadRecord(table, recordId) - if record: - records = [record] - else: - return [] - else: - # Load all records if no specific ID filter - records = self._loadTable(table) - - # Apply recordFilter if available - if recordFilter: - records = self._applyRecordFilter(records, recordFilter) - - # If fieldFilter is available, reduce the fields - if fieldFilter and isinstance(fieldFilter, list): - result = [] - for record in records: - filteredRecord = {} - for field in fieldFilter: - if field in record: - filteredRecord[field] = record[field] - result.append(filteredRecord) - return result - - return records - - def recordCreate(self, table: str, record: Dict[str, Any]) -> Dict[str, Any]: - """Creates a new record in a table.""" - # Ensure record has an ID - if "id" not in record: - record["id"] = str(uuid.uuid4()) - - # If record is a Pydantic model, convert to dict - if isinstance(record, BaseModel): - record = record.model_dump() - - # Save record - self._saveRecord(table, record["id"], record) - return record - - def recordModify(self, table: str, recordId: str, record: Dict[str, Any]) -> Dict[str, Any]: - """Modifies an existing record in a table.""" - # Load existing record - existingRecord = self._loadRecord(table, recordId) - if not existingRecord: - raise ValueError(f"Record {recordId} not found in table {table}") - - # If record is a Pydantic model, convert to dict - if isinstance(record, BaseModel): - record = record.model_dump() - - # CRITICAL: Ensure we never modify the ID - if "id" in record and str(record["id"]) != recordId: - logger.error(f"Attempted to modify record ID from {recordId} to {record['id']}") - raise ValueError("Cannot modify record ID - it must match the file name") - - # Update existing record with new data - existingRecord.update(record) - - # Save updated record - self._saveRecord(table, recordId, existingRecord) - return existingRecord - - def recordDelete(self, table: str, recordId: str) -> bool: - """Deletes a record from the table with atomic metadata operations.""" - recordPath = self._getRecordPath(table, recordId) - record_lock = self._get_file_lock(recordPath) - table_lock = self._get_table_lock(table) - - try: - # Acquire both locks with timeout - record lock first, then table lock - if not record_lock.acquire(timeout=30): - raise TimeoutError(f"Could not acquire record lock for {recordPath} within 30 seconds") - - if not table_lock.acquire(timeout=30): - record_lock.release() - raise TimeoutError(f"Could not acquire table lock for {table} within 30 seconds") - - # Record lock acquisition time - self._lock_timeouts[recordPath] = time.time() - self._lock_timeouts[f"table_{table}"] = time.time() - - # Load metadata - metadata = self._loadTableMetadata(table) - - if recordId not in metadata["recordIds"]: - return False - - # Check if it's an initial record - initialId = self.getInitialId(table) - if initialId is not None and initialId == recordId: - self._removeInitialId(table) - logger.info(f"Initial ID {recordId} for table {table} has been removed from the system table") - - # Delete the record file - if os.path.exists(recordPath): - os.remove(recordPath) - - # ATOMIC: Update metadata while holding both locks - metadata["recordIds"].remove(recordId) - self._saveTableMetadata(table, metadata) - - # Update table cache if it exists (also protected by table lock) - if table in self._tablesCache: - self._tablesCache[table] = [r for r in self._tablesCache[table] if r.get("id") != recordId] - - return True - else: - return False - - except Exception as e: - logger.error(f"Error deleting record {recordId} from table {table}: {e}") - return False - - finally: - # ALWAYS release both locks, even on error - try: - if table_lock.locked(): - table_lock.release() - if f"table_{table}" in self._lock_timeouts: - del self._lock_timeouts[f"table_{table}"] - except Exception as release_error: - logger.error(f"Error releasing table lock for {table}: {release_error}") - - try: - if record_lock.locked(): - record_lock.release() - if recordPath in self._lock_timeouts: - del self._lock_timeouts[recordPath] - except Exception as release_error: - logger.error(f"Error releasing record lock for {recordPath}: {release_error}") - - def getInitialId(self, table_or_model) -> Optional[str]: - """Returns the initial ID for a table.""" - # Handle both string table names (legacy) and model classes (new) - if isinstance(table_or_model, str): - table = table_or_model - else: - table = table_or_model.__name__ - - systemData = self._loadSystemTable() - initialId = systemData.get(table) - logger.debug(f"Initial ID for table '{table}': {initialId}") - return initialId - \ No newline at end of file diff --git a/modules/connectors/connectorDbPostgre.py b/modules/connectors/connectorDbPostgre.py index c9206c8d..828fa703 100644 --- a/modules/connectors/connectorDbPostgre.py +++ b/modules/connectors/connectorDbPostgre.py @@ -1,13 +1,16 @@ import psycopg2 import psycopg2.extras import logging -from typing import List, Dict, Any, Optional, Union, get_origin, get_args +from typing import List, Dict, Any, Optional, Union, get_origin, get_args, Type import uuid from pydantic import BaseModel, Field import threading from modules.shared.timeUtils import getUtcTimestamp from modules.shared.configuration import APP_CONFIG +from modules.datamodels.datamodelUam import User, AccessLevel, UserPermissions +from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext +from modules.security.rbac import RbacClass logger = logging.getLogger(__name__) @@ -1039,6 +1042,208 @@ class DatabaseConnector: initialId = systemData.get(table) return initialId + def getRecordsetWithRBAC( + self, + modelClass: Type[BaseModel], + currentUser: User, + recordFilter: Dict[str, Any] = None, + orderBy: str = None, + limit: int = None, + ) -> List[Dict[str, Any]]: + """ + Get records with RBAC filtering applied at database level. + + Args: + modelClass: Pydantic model class for the table + currentUser: User object with roleLabels + recordFilter: Additional record filters + orderBy: Field to order by (defaults to "id") + limit: Maximum number of records to return + + Returns: + List of filtered records + """ + table = modelClass.__name__ + + try: + if not self._ensureTableExists(modelClass): + return [] + + # Get RBAC permissions for this table + RbacInstance = RbacClass(self) + permissions = RbacInstance.getUserPermissions( + currentUser, + AccessRuleContext.DATA, + table + ) + + # Check view permission first + if not permissions.view: + logger.debug(f"User {currentUser.id} has no view permission for table {table}") + return [] + + # Build WHERE clause with RBAC filtering + whereConditions = [] + whereValues = [] + + # Add RBAC WHERE clause based on read permission + rbacWhereClause = self.buildRbacWhereClause(permissions, currentUser, table) + if rbacWhereClause: + whereConditions.append(rbacWhereClause["condition"]) + whereValues.extend(rbacWhereClause["values"]) + + # Add additional record filters + if recordFilter: + for field, value in recordFilter.items(): + whereConditions.append(f'"{field}" = %s') + whereValues.append(value) + + # Build the query + whereClause = "" + if whereConditions: + whereClause = " WHERE " + " AND ".join(whereConditions) + + orderByClause = f' ORDER BY "{orderBy}"' if orderBy else ' ORDER BY "id"' + limitClause = f" LIMIT {limit}" if limit else "" + + query = f'SELECT * FROM "{table}"{whereClause}{orderByClause}{limitClause}' + + with self.connection.cursor() as cursor: + cursor.execute(query, whereValues) + records = [dict(row) for row in cursor.fetchall()] + + # Handle JSONB fields and ensure numeric types are correct + fields = _get_model_fields(modelClass) + for record in records: + for fieldName, fieldType in fields.items(): + # Ensure numeric fields are properly typed + if fieldType in ("DOUBLE PRECISION", "INTEGER") and fieldName in record: + value = record[fieldName] + if value is not None: + try: + if fieldType == "DOUBLE PRECISION": + record[fieldName] = float(value) + elif fieldType == "INTEGER": + record[fieldName] = int(value) + except (ValueError, TypeError): + logger.warning( + f"Could not convert {fieldName} to {fieldType} for record {record.get('id', 'unknown')}: {value}" + ) + elif fieldType == "JSONB" and fieldName in record: + if record[fieldName] is None: + if fieldName in ["logs", "messages", "tasks", "expectedDocumentFormats", "resultDocuments"]: + record[fieldName] = [] + elif fieldName in ["execParameters", "stats"]: + record[fieldName] = {} + else: + record[fieldName] = None + else: + import json + try: + if isinstance(record[fieldName], str): + record[fieldName] = json.loads(record[fieldName]) + elif isinstance(record[fieldName], (dict, list)): + pass + else: + record[fieldName] = json.loads(str(record[fieldName])) + except (json.JSONDecodeError, TypeError, ValueError): + logger.warning( + f"Could not parse JSONB field {fieldName}, keeping as string: {record[fieldName]}" + ) + + return records + except Exception as e: + logger.error(f"Error loading records with RBAC from table {table}: {e}") + return [] + + def buildRbacWhereClause( + self, + permissions: UserPermissions, + currentUser: User, + table: str + ) -> Optional[Dict[str, Any]]: + """ + Build RBAC WHERE clause based on permissions and access level. + + Args: + permissions: UserPermissions object + currentUser: User object + table: Table name + + Returns: + Dictionary with "condition" and "values" keys, or None if no filtering needed + """ + if not permissions or not hasattr(permissions, "read"): + return None + + readLevel = permissions.read + + # No access - return empty result condition + if readLevel == AccessLevel.NONE: + return {"condition": "1 = 0", "values": []} + + # All records - no filtering needed + if readLevel == AccessLevel.ALL: + return None + + # My records - filter by _createdBy or userId field + if readLevel == AccessLevel.MY: + # Try common field names for creator + userIdField = None + if table == "UserInDB": + userIdField = "id" + elif table == "UserConnection": + userIdField = "userId" + else: + userIdField = "_createdBy" + + return { + "condition": f'"{userIdField}" = %s', + "values": [currentUser.id] + } + + # Group records - filter by mandateId + if readLevel == AccessLevel.GROUP: + if not currentUser.mandateId: + logger.warning(f"User {currentUser.id} has no mandateId for GROUP access") + return {"condition": "1 = 0", "values": []} + + # For UserInDB, filter by mandateId directly + if table == "UserInDB": + return { + "condition": '"mandateId" = %s', + "values": [currentUser.mandateId] + } + # For UserConnection, need to join with UserInDB or filter by mandateId in user + elif table == "UserConnection": + # Get all user IDs in the same mandate using direct SQL query + try: + with self.connection.cursor() as cursor: + cursor.execute( + 'SELECT "id" FROM "UserInDB" WHERE "mandateId" = %s', + (currentUser.mandateId,) + ) + users = cursor.fetchall() + userIds = [u["id"] for u in users] + if not userIds: + return {"condition": "1 = 0", "values": []} + placeholders = ",".join(["%s"] * len(userIds)) + return { + "condition": f'"userId" IN ({placeholders})', + "values": userIds + } + except Exception as e: + logger.error(f"Error building GROUP filter for UserConnection: {e}") + return {"condition": "1 = 0", "values": []} + # For other tables, filter by mandateId + else: + return { + "condition": '"mandateId" = %s', + "values": [currentUser.mandateId] + } + + return None + def close(self): """Close the database connection.""" if ( diff --git a/modules/datamodels/datamodelRbac.py b/modules/datamodels/datamodelRbac.py new file mode 100644 index 00000000..c2ba90d8 --- /dev/null +++ b/modules/datamodels/datamodelRbac.py @@ -0,0 +1,102 @@ +"""RBAC models: AccessRule, AccessRuleContext.""" + +import uuid +from typing import Optional +from enum import Enum +from pydantic import BaseModel, Field +from modules.shared.attributeUtils import registerModelLabels +from modules.datamodels.datamodelUam import AccessLevel + + +class AccessRuleContext(str, Enum): + """Context type enumeration""" + DATA = "DATA" # Database tables and fields + UI = "UI" # UI elements and features + RESOURCE = "RESOURCE" # System resources (AI models, actions, etc.) + + +class AccessRule(BaseModel): + """Data model for access control rules""" + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), + description="Unique ID of the access rule", + json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False} + ) + roleLabel: str = Field( + description="Role label this rule applies to", + json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True, "frontend_options": "user.role"} + ) + context: AccessRuleContext = Field( + description="Context type: DATA (database), UI (interface), RESOURCE (system resources)", + json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True, "frontend_options": [ + {"value": "DATA", "label": {"en": "Data", "fr": "Données"}}, + {"value": "UI", "label": {"en": "UI", "fr": "Interface"}}, + {"value": "RESOURCE", "label": {"en": "Resource", "fr": "Ressource"}} + ]} + ) + item: Optional[str] = Field( + None, + description="Item identifier (null = all items in context). Format: DATA: '' or '
.', UI: cascading string (e.g., 'playground.voice.settings'), RESOURCE: cascading string (e.g., 'ai.model.anthropic')", + json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": False} + ) + view: bool = Field( + False, + description="View permission: if true, item is visible/enabled. Only objects with view=true are shown.", + json_schema_extra={"frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": True} + ) + read: Optional[AccessLevel] = Field( + None, + description="Read permission level (only for DATA context)", + json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": [ + {"value": "a", "label": {"en": "All Records", "fr": "Tous les enregistrements"}}, + {"value": "m", "label": {"en": "My Records", "fr": "Mes enregistrements"}}, + {"value": "g", "label": {"en": "Group Records", "fr": "Enregistrements du groupe"}}, + {"value": "n", "label": {"en": "No Access", "fr": "Aucun accès"}} + ]} + ) + create: Optional[AccessLevel] = Field( + None, + description="Create permission level (only for DATA context)", + json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": [ + {"value": "a", "label": {"en": "All Records", "fr": "Tous les enregistrements"}}, + {"value": "m", "label": {"en": "My Records", "fr": "Mes enregistrements"}}, + {"value": "g", "label": {"en": "Group Records", "fr": "Enregistrements du groupe"}}, + {"value": "n", "label": {"en": "No Access", "fr": "Aucun accès"}} + ]} + ) + update: Optional[AccessLevel] = Field( + None, + description="Update permission level (only for DATA context)", + json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": [ + {"value": "a", "label": {"en": "All Records", "fr": "Tous les enregistrements"}}, + {"value": "m", "label": {"en": "My Records", "fr": "Mes enregistrements"}}, + {"value": "g", "label": {"en": "Group Records", "fr": "Enregistrements du groupe"}}, + {"value": "n", "label": {"en": "No Access", "fr": "Aucun accès"}} + ]} + ) + delete: Optional[AccessLevel] = Field( + None, + description="Delete permission level (only for DATA context)", + json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": [ + {"value": "a", "label": {"en": "All Records", "fr": "Tous les enregistrements"}}, + {"value": "m", "label": {"en": "My Records", "fr": "Mes enregistrements"}}, + {"value": "g", "label": {"en": "Group Records", "fr": "Enregistrements du groupe"}}, + {"value": "n", "label": {"en": "No Access", "fr": "Aucun accès"}} + ]} + ) + +registerModelLabels( + "AccessRule", + {"en": "Access Rule", "fr": "Règle d'accès"}, + { + "id": {"en": "ID", "fr": "ID"}, + "roleLabel": {"en": "Role Label", "fr": "Label du rôle"}, + "context": {"en": "Context", "fr": "Contexte"}, + "item": {"en": "Item", "fr": "Élément"}, + "view": {"en": "View", "fr": "Vue"}, + "read": {"en": "Read", "fr": "Lecture"}, + "create": {"en": "Create", "fr": "Créer"}, + "update": {"en": "Update", "fr": "Mettre à jour"}, + "delete": {"en": "Delete", "fr": "Supprimer"}, + }, +) diff --git a/modules/datamodels/datamodelUam.py b/modules/datamodels/datamodelUam.py index 4a9c10aa..4c9e0a84 100644 --- a/modules/datamodels/datamodelUam.py +++ b/modules/datamodels/datamodelUam.py @@ -1,7 +1,7 @@ """UAM models: User, Mandate, UserConnection.""" import uuid -from typing import Optional +from typing import Optional, List from enum import Enum from pydantic import BaseModel, Field, EmailStr from modules.shared.attributeUtils import registerModelLabels @@ -13,7 +13,7 @@ class AuthAuthority(str, Enum): GOOGLE = "google" MSFT = "msft" -class UserPrivilege(str, Enum): +class UserPrivilege(str, Enum): # TODO: TO remove, one new RBAC System is in place! SYSADMIN = "sysadmin" ADMIN = "admin" USER = "user" @@ -24,6 +24,36 @@ class ConnectionStatus(str, Enum): REVOKED = "revoked" PENDING = "pending" +class AccessLevel(str, Enum): + """Access level enumeration for RBAC""" + ALL = "a" # All records + MY = "m" # My records (created by me) + GROUP = "g" # Group records (group context is the mandate) + NONE = "n" # No access + +class UserPermissions(BaseModel): + """User permissions model for RBAC""" + view: bool = Field( + default=False, + description="View permission: if true, item is visible/enabled" + ) + read: AccessLevel = Field( + default=AccessLevel.NONE, + description="Read permission level" + ) + create: AccessLevel = Field( + default=AccessLevel.NONE, + description="Create permission level" + ) + update: AccessLevel = Field( + default=AccessLevel.NONE, + description="Update permission level" + ) + delete: AccessLevel = Field( + default=AccessLevel.NONE, + description="Delete permission level" + ) + class Mandate(BaseModel): id: str = Field( default_factory=lambda: str(uuid.uuid4()), @@ -122,11 +152,12 @@ class User(BaseModel): {"value": "it", "label": {"en": "Italiano", "fr": "Italien"}}, ]}) enabled: bool = Field(default=True, description="Indicates whether the user is enabled", json_schema_extra={"frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False}) - privilege: UserPrivilege = Field(default=UserPrivilege.USER, description="Permission level", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True, "frontend_options": [ - {"value": "user", "label": {"en": "User", "fr": "Utilisateur"}}, - {"value": "admin", "label": {"en": "Admin", "fr": "Administrateur"}}, - {"value": "sysadmin", "label": {"en": "SysAdmin", "fr": "Administrateur système"}}, - ]}) + privilege: UserPrivilege = Field(default=UserPrivilege.USER, description="Permission level (DEPRECATED: use roleLabels instead)", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": "user.role"}) + roleLabels: List[str] = Field( + default_factory=list, + description="List of role labels assigned to this user. All roles are opening roles (union) - if one role enables something, it is enabled.", + json_schema_extra={"frontend_type": "multiselect", "frontend_readonly": False, "frontend_required": True, "frontend_options": "user.role"} + ) authenticationAuthority: AuthAuthority = Field(default=AuthAuthority.LOCAL, description="Primary authentication authority", json_schema_extra={"frontend_type": "select", "frontend_readonly": True, "frontend_required": False, "frontend_options": [ {"value": "local", "label": {"en": "Local", "fr": "Local"}}, {"value": "google", "label": {"en": "Google", "fr": "Google"}}, @@ -144,6 +175,7 @@ registerModelLabels( "language": {"en": "Language", "fr": "Langue"}, "enabled": {"en": "Enabled", "fr": "Activé"}, "privilege": {"en": "Privilege", "fr": "Privilège"}, + "roleLabels": {"en": "Role Labels", "fr": "Labels de rôle"}, "authenticationAuthority": {"en": "Auth Authority", "fr": "Autorité d'authentification"}, "mandateId": {"en": "Mandate ID", "fr": "ID de mandat"}, }, diff --git a/modules/features/automation/mainAutomation.py b/modules/features/automation/mainAutomation.py index c0534229..768ca2e0 100644 --- a/modules/features/automation/mainAutomation.py +++ b/modules/features/automation/mainAutomation.py @@ -163,9 +163,11 @@ async def syncAutomationEvents(chatInterface, eventUser) -> Dict[str, Any]: Returns: Dictionary with sync results (synced count and event IDs) """ - # Get all automation definitions (for current mandate) - allAutomations = chatInterface.db.getRecordset(AutomationDefinition) - filtered = chatInterface._uam(AutomationDefinition, allAutomations) + # Get all automation definitions filtered by RBAC (for current mandate) + filtered = chatInterface.db.getRecordsetWithRBAC( + AutomationDefinition, + eventUser + ) registeredEvents = {} diff --git a/modules/interfaces/interfaceBootstrap.py b/modules/interfaces/interfaceBootstrap.py new file mode 100644 index 00000000..5c4a90a1 --- /dev/null +++ b/modules/interfaces/interfaceBootstrap.py @@ -0,0 +1,548 @@ +""" +Centralized bootstrap interface for system initialization. +Contains all bootstrap logic including mandate, users, and RBAC rules. +""" + +import logging +from typing import Optional +from passlib.context import CryptContext +from modules.connectors.connectorDbPostgre import DatabaseConnector +from modules.shared.configuration import APP_CONFIG +from modules.datamodels.datamodelUam import ( + Mandate, + UserInDB, + UserPrivilege, + AuthAuthority, +) +from modules.datamodels.datamodelRbac import ( + AccessRule, + AccessRuleContext, +) +from modules.datamodels.datamodelUam import AccessLevel + +logger = logging.getLogger(__name__) + +# Password-Hashing +pwdContext = CryptContext(schemes=["argon2"], deprecated="auto") + + +def initBootstrap(db: DatabaseConnector) -> None: + """ + Main bootstrap entry point - initializes all system components. + + Args: + db: Database connector instance + """ + logger.info("Starting system bootstrap") + + # Initialize root mandate + mandateId = initRootMandate(db) + + # Initialize admin user + adminUserId = initAdminUser(db, mandateId) + + # Initialize event user + eventUserId = initEventUser(db, mandateId) + + # Initialize RBAC rules + initRbacRules(db) + + # Assign initial user roles + if adminUserId and eventUserId: + assignInitialUserRoles(db, adminUserId, eventUserId) + + logger.info("System bootstrap completed") + + +def initRootMandate(db: DatabaseConnector) -> Optional[str]: + """ + Creates the Root mandate if it doesn't exist. + + Args: + db: Database connector instance + + Returns: + Mandate ID if created or found, None otherwise + """ + existingMandates = db.getRecordset(Mandate) + if existingMandates: + mandateId = existingMandates[0].get("id") + logger.info(f"Root mandate already exists with ID {mandateId}") + return mandateId + + logger.info("Creating Root mandate") + rootMandate = Mandate(name="Root", language="en", enabled=True) + createdMandate = db.recordCreate(Mandate, rootMandate) + mandateId = createdMandate.get("id") + logger.info(f"Root mandate created with ID {mandateId}") + return mandateId + + +def initAdminUser(db: DatabaseConnector, mandateId: Optional[str]) -> Optional[str]: + """ + Creates the Admin user if it doesn't exist. + + Args: + db: Database connector instance + mandateId: Root mandate ID + + Returns: + User ID if created or found, None otherwise + """ + existingUsers = db.getRecordset(UserInDB, recordFilter={"username": "admin"}) + if existingUsers: + userId = existingUsers[0].get("id") + logger.info(f"Admin user already exists with ID {userId}") + return userId + + logger.info("Creating Admin user") + adminUser = UserInDB( + mandateId=mandateId, + username="admin", + email="admin@example.com", + fullName="Administrator", + enabled=True, + language="en", + privilege=UserPrivilege.SYSADMIN, + roleLabels=["sysadmin"], + authenticationAuthority=AuthAuthority.LOCAL, + hashedPassword=_getPasswordHash(APP_CONFIG.get("APP_INIT_PASS_ADMIN_SECRET")), + connections=[], + ) + createdUser = db.recordCreate(UserInDB, adminUser) + userId = createdUser.get("id") + logger.info(f"Admin user created with ID {userId}") + return userId + + +def initEventUser(db: DatabaseConnector, mandateId: Optional[str]) -> Optional[str]: + """ + Creates the Event user if it doesn't exist. + + Args: + db: Database connector instance + mandateId: Root mandate ID + + Returns: + User ID if created or found, None otherwise + """ + existingUsers = db.getRecordset(UserInDB, recordFilter={"username": "event"}) + if existingUsers: + userId = existingUsers[0].get("id") + logger.info(f"Event user already exists with ID {userId}") + return userId + + logger.info("Creating Event user") + eventUser = UserInDB( + mandateId=mandateId, + username="event", + email="event@example.com", + fullName="Event", + enabled=True, + language="en", + privilege=UserPrivilege.SYSADMIN, + roleLabels=["sysadmin"], + authenticationAuthority=AuthAuthority.LOCAL, + hashedPassword=_getPasswordHash(APP_CONFIG.get("APP_INIT_PASS_EVENT_SECRET")), + connections=[], + ) + createdUser = db.recordCreate(UserInDB, eventUser) + userId = createdUser.get("id") + logger.info(f"Event user created with ID {userId}") + return userId + + +def initRbacRules(db: DatabaseConnector) -> None: + """ + Initialize RBAC rules if they don't exist. + Converts all UAM logic from interface*Access.py modules to RBAC rules. + + Args: + db: Database connector instance + """ + existingRules = db.getRecordset(AccessRule) + if existingRules: + logger.info(f"RBAC rules already exist ({len(existingRules)} rules)") + return + + logger.info("Initializing RBAC rules") + + # Create default role rules + createDefaultRoleRules(db) + + # Create table-specific rules (converted from UAM logic) + createTableSpecificRules(db) + + logger.info("RBAC rules initialization completed") + + +def createDefaultRoleRules(db: DatabaseConnector) -> None: + """ + Create default role rules for generic access (item = null). + + Args: + db: Database connector instance + """ + defaultRules = [ + # SysAdmin Role - Full access to all + AccessRule( + roleLabel="sysadmin", + context=AccessRuleContext.DATA, + item=None, + view=True, + read=AccessLevel.ALL, + create=AccessLevel.ALL, + update=AccessLevel.ALL, + delete=AccessLevel.ALL, + ), + # Admin Role - Group-level access + AccessRule( + roleLabel="admin", + context=AccessRuleContext.DATA, + item=None, + view=True, + read=AccessLevel.GROUP, + create=AccessLevel.GROUP, + update=AccessLevel.GROUP, + delete=AccessLevel.NONE, + ), + # User Role - My records only + AccessRule( + roleLabel="user", + context=AccessRuleContext.DATA, + item=None, + view=True, + read=AccessLevel.MY, + create=AccessLevel.MY, + update=AccessLevel.MY, + delete=AccessLevel.MY, + ), + # Viewer Role - Read-only group access + AccessRule( + roleLabel="viewer", + context=AccessRuleContext.DATA, + item=None, + view=True, + read=AccessLevel.GROUP, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + ), + ] + + for rule in defaultRules: + db.recordCreate(AccessRule, rule) + + logger.info(f"Created {len(defaultRules)} default role rules") + + +def createTableSpecificRules(db: DatabaseConnector) -> None: + """ + Create table-specific rules converted from UAM logic. + These rules override generic rules for specific tables. + + Args: + db: Database connector instance + """ + tableRules = [] + + # Mandate table - Only sysadmin can access + tableRules.append(AccessRule( + roleLabel="sysadmin", + context=AccessRuleContext.DATA, + item="Mandate", + view=True, + read=AccessLevel.ALL, + create=AccessLevel.ALL, + update=AccessLevel.ALL, + delete=AccessLevel.ALL, + )) + tableRules.append(AccessRule( + roleLabel="admin", + context=AccessRuleContext.DATA, + item="Mandate", + view=False, + read=AccessLevel.NONE, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + tableRules.append(AccessRule( + roleLabel="user", + context=AccessRuleContext.DATA, + item="Mandate", + view=False, + read=AccessLevel.NONE, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + tableRules.append(AccessRule( + roleLabel="viewer", + context=AccessRuleContext.DATA, + item="Mandate", + view=False, + read=AccessLevel.NONE, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + + # UserInDB table + tableRules.append(AccessRule( + roleLabel="sysadmin", + context=AccessRuleContext.DATA, + item="UserInDB", + view=True, + read=AccessLevel.ALL, + create=AccessLevel.ALL, + update=AccessLevel.ALL, + delete=AccessLevel.ALL, + )) + tableRules.append(AccessRule( + roleLabel="admin", + context=AccessRuleContext.DATA, + item="UserInDB", + view=True, + read=AccessLevel.GROUP, + create=AccessLevel.GROUP, + update=AccessLevel.GROUP, + delete=AccessLevel.GROUP, + )) + tableRules.append(AccessRule( + roleLabel="user", + context=AccessRuleContext.DATA, + item="UserInDB", + view=True, + read=AccessLevel.MY, + create=AccessLevel.NONE, + update=AccessLevel.MY, + delete=AccessLevel.NONE, + )) + tableRules.append(AccessRule( + roleLabel="viewer", + context=AccessRuleContext.DATA, + item="UserInDB", + view=True, + read=AccessLevel.MY, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + + # UserConnection table + tableRules.append(AccessRule( + roleLabel="sysadmin", + context=AccessRuleContext.DATA, + item="UserConnection", + view=True, + read=AccessLevel.ALL, + create=AccessLevel.ALL, + update=AccessLevel.ALL, + delete=AccessLevel.ALL, + )) + tableRules.append(AccessRule( + roleLabel="admin", + context=AccessRuleContext.DATA, + item="UserConnection", + view=True, + read=AccessLevel.GROUP, + create=AccessLevel.GROUP, + update=AccessLevel.GROUP, + delete=AccessLevel.GROUP, + )) + tableRules.append(AccessRule( + roleLabel="user", + context=AccessRuleContext.DATA, + item="UserConnection", + view=True, + read=AccessLevel.MY, + create=AccessLevel.MY, + update=AccessLevel.MY, + delete=AccessLevel.MY, + )) + tableRules.append(AccessRule( + roleLabel="viewer", + context=AccessRuleContext.DATA, + item="UserConnection", + view=True, + read=AccessLevel.MY, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + + # DataNeutraliserConfig table + tableRules.append(AccessRule( + roleLabel="sysadmin", + context=AccessRuleContext.DATA, + item="DataNeutraliserConfig", + view=True, + read=AccessLevel.ALL, + create=AccessLevel.ALL, + update=AccessLevel.ALL, + delete=AccessLevel.ALL, + )) + tableRules.append(AccessRule( + roleLabel="admin", + context=AccessRuleContext.DATA, + item="DataNeutraliserConfig", + view=True, + read=AccessLevel.GROUP, + create=AccessLevel.GROUP, + update=AccessLevel.GROUP, + delete=AccessLevel.GROUP, + )) + tableRules.append(AccessRule( + roleLabel="user", + context=AccessRuleContext.DATA, + item="DataNeutraliserConfig", + view=True, + read=AccessLevel.MY, + create=AccessLevel.MY, + update=AccessLevel.MY, + delete=AccessLevel.MY, + )) + tableRules.append(AccessRule( + roleLabel="viewer", + context=AccessRuleContext.DATA, + item="DataNeutraliserConfig", + view=True, + read=AccessLevel.MY, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + + # DataNeutralizerAttributes table + tableRules.append(AccessRule( + roleLabel="sysadmin", + context=AccessRuleContext.DATA, + item="DataNeutralizerAttributes", + view=True, + read=AccessLevel.ALL, + create=AccessLevel.ALL, + update=AccessLevel.ALL, + delete=AccessLevel.ALL, + )) + tableRules.append(AccessRule( + roleLabel="admin", + context=AccessRuleContext.DATA, + item="DataNeutralizerAttributes", + view=True, + read=AccessLevel.GROUP, + create=AccessLevel.GROUP, + update=AccessLevel.GROUP, + delete=AccessLevel.GROUP, + )) + tableRules.append(AccessRule( + roleLabel="user", + context=AccessRuleContext.DATA, + item="DataNeutralizerAttributes", + view=True, + read=AccessLevel.MY, + create=AccessLevel.MY, + update=AccessLevel.MY, + delete=AccessLevel.MY, + )) + tableRules.append(AccessRule( + roleLabel="viewer", + context=AccessRuleContext.DATA, + item="DataNeutralizerAttributes", + view=True, + read=AccessLevel.MY, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + + # AuthEvent table + tableRules.append(AccessRule( + roleLabel="sysadmin", + context=AccessRuleContext.DATA, + item="AuthEvent", + view=True, + read=AccessLevel.ALL, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.ALL, + )) + tableRules.append(AccessRule( + roleLabel="admin", + context=AccessRuleContext.DATA, + item="AuthEvent", + view=True, + read=AccessLevel.ALL, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.ALL, + )) + tableRules.append(AccessRule( + roleLabel="user", + context=AccessRuleContext.DATA, + item="AuthEvent", + view=True, + read=AccessLevel.MY, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + tableRules.append(AccessRule( + roleLabel="viewer", + context=AccessRuleContext.DATA, + item="AuthEvent", + view=True, + read=AccessLevel.MY, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE, + )) + + # Create all table-specific rules + for rule in tableRules: + db.recordCreate(AccessRule, rule) + + logger.info(f"Created {len(tableRules)} table-specific rules") + + +def assignInitialUserRoles(db: DatabaseConnector, adminUserId: str, eventUserId: str) -> None: + """ + Assign initial roles to admin and event users. + + Args: + db: Database connector instance + adminUserId: Admin user ID + eventUserId: Event user ID + """ + # Update admin user with sysadmin role + adminUser = db.getRecordset(UserInDB, recordFilter={"id": adminUserId}) + if adminUser: + adminUserData = adminUser[0] + if "sysadmin" not in adminUserData.get("roleLabels", []): + adminUserData["roleLabels"] = adminUserData.get("roleLabels", []) + ["sysadmin"] + db.recordUpdate(UserInDB, adminUserId, adminUserData) + logger.info(f"Assigned sysadmin role to admin user {adminUserId}") + + # Update event user with sysadmin role + eventUser = db.getRecordset(UserInDB, recordFilter={"id": eventUserId}) + if eventUser: + eventUserData = eventUser[0] + if "sysadmin" not in eventUserData.get("roleLabels", []): + eventUserData["roleLabels"] = eventUserData.get("roleLabels", []) + ["sysadmin"] + db.recordUpdate(UserInDB, eventUserId, eventUserData) + logger.info(f"Assigned sysadmin role to event user {eventUserId}") + + +def _getPasswordHash(password: Optional[str]) -> Optional[str]: + """ + Hash a password using Argon2. + + Args: + password: Plain text password + + Returns: + Hashed password or None if password is None + """ + if password is None: + return None + return pwdContext.hash(password) diff --git a/modules/interfaces/interfaceDbAppAccess.py b/modules/interfaces/interfaceDbAppAccess.py deleted file mode 100644 index 1bb9126c..00000000 --- a/modules/interfaces/interfaceDbAppAccess.py +++ /dev/null @@ -1,254 +0,0 @@ -""" -Access control for the Application. -""" - -import logging -from typing import Dict, Any, List, Optional -from modules.datamodels.datamodelUam import UserPrivilege, User, UserInDB, Mandate -from modules.datamodels.datamodelSecurity import AuthEvent - -# Configure logger -logger = logging.getLogger(__name__) - -class AppAccess: - """ - Access control class for Application interface. - Handles user access management and permission checks. - """ - - def __init__(self, currentUser: User, db): - """Initialize with user context.""" - self.currentUser = currentUser - self.userId = currentUser.id - self.mandateId = currentUser.mandateId - self.privilege = currentUser.privilege - - if not self.mandateId or not self.userId: - raise ValueError("Invalid user context: mandateId and userId are required") - - self.db = db - - def uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """ - Unified user access management function that filters data based on user privileges - and adds access control attributes. - - Args: - model_class: Pydantic model class for the table - recordset: Recordset to filter based on access rules - - Returns: - Filtered recordset with access control attributes - """ - filtered_records = [] - table_name = model_class.__name__ - - # Only SYSADMIN can see mandates - if table_name == "Mandate": - if self.privilege == UserPrivilege.SYSADMIN: - filtered_records = recordset - else: - filtered_records = [] - # Special handling for users table - elif table_name == "UserInDB": - if self.privilege == UserPrivilege.SYSADMIN: - # SysAdmin sees all users - filtered_records = recordset - elif self.privilege == UserPrivilege.ADMIN: - # Admin sees all users in their mandate - filtered_records = [r for r in recordset if r.get("mandateId","-") == self.mandateId] - else: - # Regular users only see themselves - filtered_records = [r for r in recordset if r.get("id") == self.userId] - # Special handling for connections table - elif table_name == "UserConnection": - if self.privilege == UserPrivilege.SYSADMIN: - # SysAdmin sees all connections - filtered_records = recordset - elif self.privilege == UserPrivilege.ADMIN: - # Admin sees connections for users in their mandate - users: List[Dict[str, Any]] = self.db.getRecordset(UserInDB, recordFilter={"mandateId": self.mandateId}) - user_ids: List[str] = [str(u["id"]) for u in users] - filtered_records = [r for r in recordset if r.get("userId") in user_ids] - else: - # Regular users only see their own connections - filtered_records = [r for r in recordset if r.get("userId") == self.userId] - # Special handling for data neutralization config table - elif table_name == "DataNeutraliserConfig": - if self.privilege == UserPrivilege.SYSADMIN: - # SysAdmin sees all configs - filtered_records = recordset - elif self.privilege == UserPrivilege.ADMIN: - # Admin sees configs in their mandate - filtered_records = [r for r in recordset if r.get("mandateId","-") == self.mandateId] - else: - # Regular users only see their own configs - filtered_records = [r for r in recordset if r.get("mandateId","-") == self.mandateId and r.get("userId") == self.userId] - # Special handling for data neutralizer attributes table - elif table_name == "DataNeutralizerAttributes": - if self.privilege == UserPrivilege.SYSADMIN: - # SysAdmin sees all attributes - filtered_records = recordset - elif self.privilege == UserPrivilege.ADMIN: - # Admin sees attributes in their mandate - filtered_records = [r for r in recordset if r.get("mandateId","-") == self.mandateId] - else: - # Regular users only see their own attributes - filtered_records = [r for r in recordset if r.get("mandateId","-") == self.mandateId and r.get("userId") == self.userId] - # System admins see all other records - elif self.privilege == UserPrivilege.SYSADMIN: - filtered_records = recordset - # For other records, admins see records in their mandate - elif self.privilege == UserPrivilege.ADMIN: - filtered_records = [r for r in recordset if r.get("mandateId","-") == self.mandateId] - # Regular users only see records they own within their mandate - else: - filtered_records = [r for r in recordset - if r.get("mandateId","-") == self.mandateId and r.get("createdBy") == self.userId] - - # Add access control attributes to each record - for record in filtered_records: - record_id = record.get("id") - - # Set access control flags based on user permissions - if table_name == "Mandate": - record["_hideView"] = False # SYSADMIN can view - record["_hideEdit"] = not self.canModify(Mandate, record_id) - record["_hideDelete"] = not self.canModify(Mandate, record_id) - elif table_name == "UserInDB": - record["_hideView"] = False # Everyone can view users they have access to - # SysAdmin can edit/delete any user - if self.privilege == UserPrivilege.SYSADMIN: - record["_hideEdit"] = False - record["_hideDelete"] = False - # Admin can edit/delete users in their mandate - elif self.privilege == UserPrivilege.ADMIN: - record["_hideEdit"] = record.get("mandateId","-") != self.mandateId - record["_hideDelete"] = record.get("mandateId","-") != self.mandateId - # Regular users can only edit themselves - else: - record["_hideEdit"] = record.get("id") != self.userId - record["_hideDelete"] = True # Regular users cannot delete users - elif table_name == "UserConnection": - # Everyone can view connections they have access to - record["_hideView"] = False - # SysAdmin can edit/delete any connection - if self.privilege == UserPrivilege.SYSADMIN: - record["_hideEdit"] = False - record["_hideDelete"] = False - # Admin can edit/delete connections for users in their mandate - elif self.privilege == UserPrivilege.ADMIN: - users: List[Dict[str, Any]] = self.db.getRecordset(UserInDB, recordFilter={"mandateId": self.mandateId}) - user_ids: List[str] = [str(u["id"]) for u in users] - record["_hideEdit"] = record.get("userId") not in user_ids - record["_hideDelete"] = record.get("userId") not in user_ids - # Regular users can only edit/delete their own connections - else: - record["_hideEdit"] = record.get("userId") != self.userId - record["_hideDelete"] = record.get("userId") != self.userId - - elif table_name == "DataNeutraliserConfig": - # Everyone can view configs they have access to - record["_hideView"] = False - # SysAdmin can edit/delete any config - if self.privilege == UserPrivilege.SYSADMIN: - record["_hideEdit"] = False - record["_hideDelete"] = False - # Admin can edit/delete configs in their mandate - elif self.privilege == UserPrivilege.ADMIN: - record["_hideEdit"] = record.get("mandateId","-") != self.mandateId - record["_hideDelete"] = record.get("mandateId","-") != self.mandateId - # Regular users can only edit/delete their own configs - else: - record["_hideEdit"] = record.get("userId") != self.userId - record["_hideDelete"] = record.get("userId") != self.userId - elif table_name == "DataNeutralizerAttributes": - # Everyone can view attributes they have access to - record["_hideView"] = False - # SysAdmin can edit/delete any attributes - if self.privilege == UserPrivilege.SYSADMIN: - record["_hideEdit"] = False - record["_hideDelete"] = False - # Admin can edit/delete attributes in their mandate - elif self.privilege == UserPrivilege.ADMIN: - record["_hideEdit"] = record.get("mandateId","-") != self.mandateId - record["_hideDelete"] = record.get("mandateId","-") != self.mandateId - # Regular users can only edit/delete their own attributes - else: - record["_hideEdit"] = record.get("userId") != self.userId - record["_hideDelete"] = record.get("userId") != self.userId - - elif table_name == "AuthEvent": - # Only show auth events for the current user or if admin - if self.privilege in [UserPrivilege.SYSADMIN, UserPrivilege.ADMIN]: - record["_hideView"] = False - else: - record["_hideView"] = record.get("userId") != self.userId - record["_hideEdit"] = True # Auth events can't be edited - record["_hideDelete"] = not self.canModify(AuthEvent, record_id) - else: - # Default access control for other tables - record["_hideView"] = False - record["_hideEdit"] = not self.canModify(model_class, record_id) - record["_hideDelete"] = not self.canModify(model_class, record_id) - - return filtered_records - - def canModify(self, model_class: type, recordId: Optional[str] = None) -> bool: - """ - Checks if the current user can modify (create/update/delete) records in a table. - - Args: - model_class: Pydantic model class for the table - recordId: Optional record ID for specific record check - - Returns: - Boolean indicating permission - """ - table_name = model_class.__name__ - - # For mandates, only SYSADMIN can modify - if table_name == "Mandate": - return self.privilege == UserPrivilege.SYSADMIN - - # System admins can modify anything else - if self.privilege == UserPrivilege.SYSADMIN: - return True - - # Check specific record permissions - if recordId is not None: - # Get the record to check ownership - records: List[Dict[str, Any]] = self.db.getRecordset(model_class, recordFilter={"id": str(recordId)}) - if not records: - return False - - record = records[0] - - # Special handling for connections - if table_name == "UserConnection": - # Admin can modify connections for users in their mandate - if self.privilege == UserPrivilege.ADMIN: - users: List[Dict[str, Any]] = self.db.getRecordset(UserInDB, recordFilter={"mandateId": self.mandateId}) - user_ids: List[str] = [str(u["id"]) for u in users] - return record.get("userId") in user_ids - # Users can only modify their own connections - return record.get("userId") == self.userId - - # Admins can modify anything in their mandate - if self.privilege == UserPrivilege.ADMIN and record.get("mandateId","-") == self.mandateId: - return True - - # Users can only modify their own records - if (record.get("mandateId","-") == self.mandateId and - record.get("createdBy") == self.userId): - return True - - return False - else: - # For general table modify permission (e.g., create) - # Admins can create anything in their mandate - if self.privilege == UserPrivilege.ADMIN: - return True - - # Regular users can create most entities - return True diff --git a/modules/interfaces/interfaceDbAppObjects.py b/modules/interfaces/interfaceDbAppObjects.py index 91d7bda4..900f7328 100644 --- a/modules/interfaces/interfaceDbAppObjects.py +++ b/modules/interfaces/interfaceDbAppObjects.py @@ -12,7 +12,8 @@ import uuid from modules.connectors.connectorDbPostgre import DatabaseConnector from modules.shared.configuration import APP_CONFIG from modules.shared.timeUtils import getUtcTimestamp, parseTimestamp -from modules.interfaces.interfaceDbAppAccess import AppAccess +from modules.interfaces.interfaceBootstrap import initBootstrap +from modules.security.rbac import RbacClass from modules.datamodels.datamodelUam import ( User, Mandate, @@ -22,6 +23,11 @@ from modules.datamodels.datamodelUam import ( UserPrivilege, ConnectionStatus, ) +from modules.datamodels.datamodelRbac import ( + AccessRule, + AccessRuleContext, +) +from modules.datamodels.datamodelUam import AccessLevel from modules.datamodels.datamodelSecurity import Token, AuthEvent, TokenStatus from modules.datamodels.datamodelNeutralizer import ( DataNeutraliserConfig, @@ -53,7 +59,6 @@ class AppObjects: self.currentUser = currentUser # Store User object directly self.userId = currentUser.id if currentUser else None self.mandateId = currentUser.mandateId if currentUser else None - self.access = None # Will be set when user context is provided # Initialize database self._initializeDatabase() @@ -81,10 +86,10 @@ class AppObjects: # Add language settings self.userLanguage = currentUser.language # Default user language - # Initialize access control with user context - self.access = AppAccess( - self.currentUser, self.db - ) # Convert to dict only when needed + # Initialize RBAC interface + if not currentUser: + raise ValueError("User context is required for RBAC") + self.rbac = RbacClass(self.db) # Update database context self.db.updateContext(self.userId) @@ -127,113 +132,46 @@ class AppObjects: def _initRecords(self): """Initialize standard records if they don't exist.""" - self._initRootMandate() - self._initAdminUser() - self._initEventUser() + initBootstrap(self.db) - def _initRootMandate(self): - """Creates the Root mandate if it doesn't exist.""" - existingMandateId = self.getInitialId(Mandate) - mandates = self.db.getRecordset(Mandate) - if existingMandateId is None or not mandates: - logger.info("Creating Root mandate") - rootMandate = Mandate(name="Root", language="en", enabled=True) - createdMandate = self.db.recordCreate(Mandate, rootMandate) - logger.info(f"Root mandate created with ID {createdMandate['id']}") - # Update mandate context - self.mandateId = createdMandate["id"] - - def _initAdminUser(self): - """Creates the Admin user if it doesn't exist.""" - existingUserId = self.getInitialId(UserInDB) - users = self.db.getRecordset(UserInDB) - if existingUserId is None or not users: - logger.info("Creating Admin user") - adminUser = UserInDB( - mandateId=self.getInitialId(Mandate), - username="admin", - email="admin@example.com", - fullName="Administrator", - enabled=True, - language="en", - privilege=UserPrivilege.SYSADMIN, - authenticationAuthority="local", # Using lowercase value directly - hashedPassword=self._getPasswordHash( - APP_CONFIG.get("APP_INIT_PASS_ADMIN_SECRET") - ), - connections=[], - ) - createdUser = self.db.recordCreate(UserInDB, adminUser) - logger.info(f"Admin user created with ID {createdUser['id']}") - - # Update user context - self.currentUser = createdUser - self.userId = createdUser.get("id") - - def _initEventUser(self): - """Creates the Event user if it doesn't exist.""" - # Check if event user already exists - existingUsers = self.db.getRecordset( - UserInDB, recordFilter={"username": "event"} - ) - if not existingUsers: - logger.info("Creating Event user") - eventUser = UserInDB( - mandateId=self.getInitialId(Mandate), - username="event", - email="event@example.com", - fullName="Event", - enabled=True, - language="en", - privilege=UserPrivilege.SYSADMIN, - authenticationAuthority="local", # Using lowercase value directly - hashedPassword=self._getPasswordHash( - APP_CONFIG.get("APP_INIT_PASS_EVENT_SECRET") - ), - connections=[], - ) - createdUser = self.db.recordCreate(UserInDB, eventUser) - logger.info(f"Event user created with ID {createdUser['id']}") - - def _uam( - self, model_class: type, recordset: List[Dict[str, Any]] - ) -> List[Dict[str, Any]]: + def checkRbacPermission( + self, + modelClass: type, + operation: str, + recordId: Optional[str] = None + ) -> bool: """ - Unified user access management function that filters data based on user privileges - and adds access control attributes. + Check RBAC permission for a specific operation on a table. Args: - model_class: Pydantic model class for the table - recordset: Recordset to filter based on access rules - - Returns: - Filtered recordset with access control attributes - """ - # First apply access control - filteredRecords = self.access.uam(model_class, recordset) - - # Then filter out database-specific fields - cleanedRecords = [] - for record in filteredRecords: - # Create a new dict with only non-database fields - cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")} - cleanedRecords.append(cleanedRecord) - - return cleanedRecords - - def _canModify(self, model_class: type, recordId: Optional[str] = None) -> bool: - """ - Checks if the current user can modify (create/update/delete) records in a table. - - Args: - model_class: Pydantic model class for the table + modelClass: Pydantic model class for the table + operation: Operation to check ('create', 'update', 'delete', 'read') recordId: Optional record ID for specific record check Returns: Boolean indicating permission """ - return self.access.canModify(model_class, recordId) + if not self.rbac or not self.currentUser: + return False + + tableName = modelClass.__name__ + permissions = self.rbac.getUserPermissions( + self.currentUser, + AccessRuleContext.DATA, + tableName + ) + + if operation == "create": + return permissions.create != AccessLevel.NONE + elif operation == "update": + return permissions.update != AccessLevel.NONE + elif operation == "delete": + return permissions.delete != AccessLevel.NONE + elif operation == "read": + return permissions.read != AccessLevel.NONE + else: + return False def _applyFilters(self, records: List[Dict[str, Any]], filters: Dict[str, Any]) -> List[Dict[str, Any]]: """ @@ -480,13 +418,18 @@ class AppObjects: If pagination is None: List[User] If pagination is provided: PaginatedResult with items and metadata """ - # For SYSADMIN, get all users regardless of mandate - # For others, filter by mandate - if self.currentUser and self.currentUser.privilege == UserPrivilege.SYSADMIN: - users = self.db.getRecordset(UserInDB) - else: - users = self.db.getRecordset(UserInDB, recordFilter={"mandateId": mandateId}) - filteredUsers = self._uam(UserInDB, users) + # Use RBAC filtering + users = self.db.getRecordsetWithRBAC( + UserInDB, + self.currentUser, + recordFilter={"mandateId": mandateId} if mandateId else None + ) + + # Filter out database-specific fields + filteredUsers = [] + for user in users: + cleanedUser = {k: v for k, v in user.items() if not k.startswith("_")} + filteredUsers.append(cleanedUser) # If no pagination requested, return all items if pagination is None: @@ -521,18 +464,22 @@ class AppObjects: def getUserByUsername(self, username: str) -> Optional[User]: """Returns a user by username.""" try: - # Get users table - users = self.db.getRecordset(UserInDB) + # Use RBAC filtering + users = self.db.getRecordsetWithRBAC( + UserInDB, + self.currentUser, + recordFilter={"username": username} + ) + if not users: + logger.info(f"No user found with username {username}") return None - # Find user by username - for user_dict in users: - if user_dict.get("username") == username: - return User(**user_dict) - - logger.info(f"No user found with username {username}") - return None + # Return first matching user (should be unique) + userDict = users[0] + # Filter out database-specific fields + cleanedUser = {k: v for k, v in userDict.items() if not k.startswith("_")} + return User(**cleanedUser) except Exception as e: logger.error(f"Error getting user by username: {str(e)}") @@ -549,11 +496,9 @@ class AppObjects: # Find user by ID for user_dict in users: if user_dict.get("id") == userId: - # Apply access control - filteredUsers = self._uam(UserInDB, [user_dict]) - if filteredUsers: - return User(**filteredUsers[0]) - return None + # User already filtered by RBAC, just clean fields + cleanedUser = {k: v for k, v in user_dict.items() if not k.startswith("_")} + return User(**cleanedUser) return None @@ -764,7 +709,7 @@ class AppObjects: if not user: raise ValueError(f"User {userId} not found") - if not self._canModify(UserInDB, userId): + if not self.checkRbacPermission(UserInDB, "update", userId): raise PermissionError(f"No permission to delete user {userId}") # Delete all referenced data first @@ -943,8 +888,14 @@ class AppObjects: If pagination is None: List[Mandate] If pagination is provided: PaginatedResult with items and metadata """ - allMandates = self.db.getRecordset(Mandate) - filteredMandates = self._uam(Mandate, allMandates) + # Use RBAC filtering + allMandates = self.db.getRecordsetWithRBAC(Mandate, self.currentUser) + + # Filter out database-specific fields + filteredMandates = [] + for mandate in allMandates: + cleanedMandate = {k: v for k, v in mandate.items() if not k.startswith("_")} + filteredMandates.append(cleanedMandate) # If no pagination requested, return all items if pagination is None: @@ -978,11 +929,21 @@ class AppObjects: def getMandate(self, mandateId: str) -> Optional[Mandate]: """Returns a mandate by ID if user has access.""" - mandates = self.db.getRecordset(Mandate, recordFilter={"id": mandateId}) + # Use RBAC filtering + mandates = self.db.getRecordsetWithRBAC( + Mandate, + self.currentUser, + recordFilter={"id": mandateId} + ) + if not mandates: return None - - filteredMandates = self._uam(Mandate, mandates) + + # Filter out database-specific fields + filteredMandates = [] + for mandate in mandates: + cleanedMandate = {k: v for k, v in mandate.items() if not k.startswith("_")} + filteredMandates.append(cleanedMandate) if not filteredMandates: return None @@ -990,7 +951,7 @@ class AppObjects: def createMandate(self, name: str, language: str = "en") -> Mandate: """Creates a new mandate if user has permission.""" - if not self._canModify(Mandate): + if not self.checkRbacPermission(Mandate, "create"): raise PermissionError("No permission to create mandates") # Create mandate data using model @@ -1007,7 +968,7 @@ class AppObjects: """Updates a mandate if user has access.""" try: # First check if user has permission to modify mandates - if not self._canModify(Mandate, mandateId): + if not self.checkRbacPermission(Mandate, "update", mandateId): raise PermissionError(f"No permission to update mandate {mandateId}") # Get mandate with access control @@ -1044,7 +1005,7 @@ class AppObjects: if not mandate: return False - if not self._canModify(Mandate, mandateId): + if not self.checkRbacPermission(Mandate, "delete", mandateId): raise PermissionError(f"No permission to delete mandate {mandateId}") # Check if mandate has users @@ -1384,7 +1345,7 @@ class AppObjects: self.currentUser = None self.userId = None self.mandateId = None - self.access = None + self.rbac = None # Clear database context if hasattr(self, "db"): @@ -1401,18 +1362,20 @@ class AppObjects: def getNeutralizationConfig(self) -> Optional[DataNeutraliserConfig]: """Get the data neutralization configuration for the current user's mandate""" try: - configs = self.db.getRecordset( - DataNeutraliserConfig, recordFilter={"mandateId": self.mandateId} + # Use RBAC filtering + filtered_configs = self.db.getRecordsetWithRBAC( + DataNeutraliserConfig, + self.currentUser, + recordFilter={"mandateId": self.mandateId} ) - if not configs: - return None - - # Apply access control - filtered_configs = self._uam(DataNeutraliserConfig, configs) + if not filtered_configs: return None - return DataNeutraliserConfig(**filtered_configs[0]) + # Filter out database-specific fields + configDict = filtered_configs[0] + cleanedConfig = {k: v for k, v in configDict.items() if not k.startswith("_")} + return DataNeutraliserConfig(**cleanedConfig) except Exception as e: logger.error(f"Error getting neutralization config: {str(e)}") @@ -1461,14 +1424,22 @@ class AppObjects: if file_id: filter_dict["fileId"] = file_id - attributes = self.db.getRecordset( - DataNeutralizerAttributes, recordFilter=filter_dict + # Use RBAC filtering + filtered_attributes = self.db.getRecordsetWithRBAC( + DataNeutralizerAttributes, + self.currentUser, + recordFilter=filter_dict ) - filtered_attributes = self._uam(DataNeutralizerAttributes, attributes) + # Filter out database-specific fields + cleaned_attributes = [] + for attr in filtered_attributes: + cleanedAttr = {k: v for k, v in attr.items() if not k.startswith("_")} + cleaned_attributes.append(cleanedAttr) + return [ DataNeutralizerAttributes(**attr) - for attr in filtered_attributes + for attr in cleaned_attributes ] except Exception as e: @@ -1495,6 +1466,151 @@ class AppObjects: logger.error(f"Error deleting neutralization attributes: {str(e)}") return False + # RBAC CRUD Methods + + def createAccessRule(self, accessRule: AccessRule) -> AccessRule: + """ + Create a new access rule. + + Args: + accessRule: AccessRule object to create + + Returns: + Created AccessRule object + """ + try: + createdRule = self.db.recordCreate(AccessRule, accessRule) + logger.info(f"Created access rule with ID {createdRule.get('id')}") + return AccessRule(**createdRule) + except Exception as e: + logger.error(f"Error creating access rule: {str(e)}") + raise + + def getAccessRule(self, ruleId: str) -> Optional[AccessRule]: + """ + Get an access rule by ID. + + Args: + ruleId: Access rule ID + + Returns: + AccessRule object if found, None otherwise + """ + try: + rules = self.db.getRecordset(AccessRule, recordFilter={"id": ruleId}) + if rules: + return AccessRule(**rules[0]) + return None + except Exception as e: + logger.error(f"Error getting access rule {ruleId}: {str(e)}") + return None + + def updateAccessRule(self, ruleId: str, accessRule: AccessRule) -> AccessRule: + """ + Update an existing access rule. + + Args: + ruleId: Access rule ID + accessRule: Updated AccessRule object + + Returns: + Updated AccessRule object + """ + try: + updatedRule = self.db.recordUpdate(AccessRule, ruleId, accessRule.model_dump()) + logger.info(f"Updated access rule with ID {ruleId}") + return AccessRule(**updatedRule) + except Exception as e: + logger.error(f"Error updating access rule {ruleId}: {str(e)}") + raise + + def deleteAccessRule(self, ruleId: str) -> bool: + """ + Delete an access rule. + + Args: + ruleId: Access rule ID + + Returns: + True if deleted successfully, False otherwise + """ + try: + self.db.recordDelete(AccessRule, ruleId) + logger.info(f"Deleted access rule with ID {ruleId}") + return True + except Exception as e: + logger.error(f"Error deleting access rule {ruleId}: {str(e)}") + return False + + def getAccessRules( + self, + roleLabel: Optional[str] = None, + context: Optional[AccessRuleContext] = None, + item: Optional[str] = None + ) -> List[AccessRule]: + """ + Get access rules with optional filters. + + Args: + roleLabel: Optional role label filter + context: Optional context filter + item: Optional item filter + + Returns: + List of AccessRule objects + """ + try: + recordFilter = {} + if roleLabel: + recordFilter["roleLabel"] = roleLabel + if context: + recordFilter["context"] = context.value + if item: + recordFilter["item"] = item + + rules = self.db.getRecordset(AccessRule, recordFilter=recordFilter if recordFilter else None) + return [AccessRule(**rule) for rule in rules] + except Exception as e: + logger.error(f"Error getting access rules: {str(e)}") + return [] + + def getAccessRulesForRoles( + self, + roleLabels: List[str], + context: AccessRuleContext, + item: str + ) -> List[AccessRule]: + """ + Get access rules for multiple roles, context, and item. + Returns the most specific matching rules for each role. + + Args: + roleLabels: List of role labels + context: Context type + item: Item identifier + + Returns: + List of AccessRule objects (most specific for each role) + """ + try: + RbacInstance = RbacClass(self.db) + allRules = [] + + for roleLabel in roleLabels: + # Get all rules for this role and context + roleRules = RbacInstance._getRulesForRole(roleLabel, context) + + # Find most specific rule for this item + mostSpecificRule = RbacInstance.findMostSpecificRule(roleRules, item) + + if mostSpecificRule: + allRules.append(mostSpecificRule) + + return allRules + except Exception as e: + logger.error(f"Error getting access rules for roles: {str(e)}") + return [] + # Public Methods diff --git a/modules/interfaces/interfaceDbChatAccess.py b/modules/interfaces/interfaceDbChatAccess.py deleted file mode 100644 index 37e96d84..00000000 --- a/modules/interfaces/interfaceDbChatAccess.py +++ /dev/null @@ -1,140 +0,0 @@ -""" -Access control module for Chat interface. -Handles user access management and permission checks. -""" - -from typing import Dict, Any, List, Optional -from modules.datamodels.datamodelUam import User, UserPrivilege -from modules.datamodels.datamodelChat import ChatWorkflow, AutomationDefinition - -class ChatAccess: - """ - Access control class for Chat interface. - Handles user access management and permission checks. - """ - - def __init__(self, currentUser: User, db): - """Initialize with user context.""" - self.currentUser = currentUser - self.mandateId = currentUser.mandateId - self.userId = currentUser.id - - if not self.mandateId or not self.userId: - raise ValueError("Invalid user context: mandateId and userId are required") - - self.db = db - - def uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """ - Unified user access management function that filters data based on user privileges - and adds access control attributes. - - Args: - model_class: Pydantic model class for the table - recordset: Recordset to filter based on access rules - - Returns: - Filtered recordset with access control attributes - """ - userPrivilege = self.currentUser.privilege - table_name = model_class.__name__ - filtered_records = [] - - # Apply filtering based on privilege - if table_name == "AutomationDefinition": - # Filter automations based on user privilege - if userPrivilege == UserPrivilege.SYSADMIN: - # System admins see all automations - filtered_records = recordset - elif userPrivilege == UserPrivilege.ADMIN: - # Admins see all automations in their mandate - filtered_records = [r for r in recordset if r.get("mandateId","-") == self.mandateId] - else: - # Regular users see only their own automations - filtered_records = [ - r for r in recordset - if r.get("mandateId","-") == self.mandateId and r.get("_createdBy") == self.userId - ] - elif userPrivilege == UserPrivilege.SYSADMIN: - filtered_records = recordset # System admins see all records - elif userPrivilege == UserPrivilege.ADMIN: - # Admins see records in their mandate - filtered_records = [r for r in recordset if r.get("mandateId","-") == self.mandateId] - else: # Regular users - # Users see only their records for other tables - filtered_records = [r for r in recordset - if r.get("mandateId","-") == self.mandateId and r.get("_createdBy") == self.userId] - - # Add access control attributes to each record - for record in filtered_records: - record_id = record.get("id") - - # Set access control flags based on user permissions - if table_name == "ChatWorkflow": - record["_hideView"] = False # Everyone can view - record["_hideEdit"] = not self.canModify(ChatWorkflow, record_id) - record["_hideDelete"] = not self.canModify(ChatWorkflow, record_id) - elif table_name == "ChatMessage": - record["_hideView"] = False # Everyone can view - record["_hideEdit"] = not self.canModify(ChatWorkflow, record.get("workflowId")) - record["_hideDelete"] = not self.canModify(ChatWorkflow, record.get("workflowId")) - elif table_name == "ChatLog": - record["_hideView"] = False # Everyone can view - record["_hideEdit"] = not self.canModify(ChatWorkflow, record.get("workflowId")) - record["_hideDelete"] = not self.canModify(ChatWorkflow, record.get("workflowId")) - elif table_name == "AutomationDefinition": - record["_hideView"] = False # Everyone can view - record["_hideEdit"] = not self.canModify(AutomationDefinition, record_id) - record["_hideDelete"] = not self.canModify(AutomationDefinition, record_id) - else: - # Default access control for other tables - record["_hideView"] = False - record["_hideEdit"] = not self.canModify(model_class, record_id) - record["_hideDelete"] = not self.canModify(model_class, record_id) - - return filtered_records - - def canModify(self, model_class: type, recordId: Optional[str] = None) -> bool: - """ - Checks if the current user can modify (create/update/delete) records in a table. - - Args: - model_class: Pydantic model class for the table - recordId: Optional record ID for specific record check - - Returns: - Boolean indicating permission - """ - userPrivilege = self.currentUser.privilege - - # System admins can modify anything - if userPrivilege == UserPrivilege.SYSADMIN: - return True - - # For regular users and admins, check specific cases - if recordId is not None: - # Get the record to check ownership - records: List[Dict[str, Any]] = self.db.getRecordset(model_class, recordFilter={"id": recordId}) - if not records: - return False - - record = records[0] - - # Admins can modify anything in their mandate, if mandate is specified for a record - if userPrivilege == UserPrivilege.ADMIN and record.get("mandateId","-") == self.mandateId: - return True - - # Regular users can only modify their own records - if (record.get("mandateId","-") == self.mandateId and - record.get("_createdBy") == self.userId): - return True - - return False - else: - # For general modification permission (e.g., create) - # Admins can create anything in their mandate - if userPrivilege == UserPrivilege.ADMIN: - return True - - # Regular users can create in most tables - return True \ No newline at end of file diff --git a/modules/interfaces/interfaceDbChatObjects.py b/modules/interfaces/interfaceDbChatObjects.py index de4abc7e..6093eb78 100644 --- a/modules/interfaces/interfaceDbChatObjects.py +++ b/modules/interfaces/interfaceDbChatObjects.py @@ -10,7 +10,9 @@ from typing import Dict, Any, List, Optional, Union import asyncio -from modules.interfaces.interfaceDbChatAccess import ChatAccess +from modules.security.rbac import RbacClass +from modules.datamodels.datamodelRbac import AccessRuleContext +from modules.datamodels.datamodelUam import AccessLevel from modules.datamodels.datamodelChat import ( ChatDocument, @@ -179,7 +181,7 @@ class ChatObjects: self.currentUser = currentUser # Store User object directly self.userId = currentUser.id if currentUser else None self.mandateId = currentUser.mandateId if currentUser else None - self.access = None # Will be set when user context is provided + self.rbac = None # RBAC interface # Initialize services self._initializeServices() @@ -263,8 +265,10 @@ class ChatObjects: # Add language settings self.userLanguage = currentUser.language # Default user language - # Initialize access control with user context - self.access = ChatAccess(self.currentUser, self.db) # Convert to dict only when needed + # Initialize RBAC interface + if not self.currentUser: + raise ValueError("User context is required for RBAC") + self.rbac = RbacClass(self.db) # Update database context self.db.updateContext(self.userId) @@ -310,35 +314,44 @@ class ChatObjects: """Initializes standard records in the database if they don't exist.""" pass - def _uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Delegate to access control module.""" - # First apply access control - filteredRecords = self.access.uam(model_class, recordset) - - # For AutomationDefinition, keep _createdBy and mandateId for enrichment purposes - # Other fields starting with _ are filtered out as they're database-specific - if model_class.__name__ == "AutomationDefinition": - # Keep _createdBy and mandateId for enrichment, filter out other _ fields - cleanedRecords = [] - for record in filteredRecords: - cleanedRecord = {} - for k, v in record.items(): - # Keep _createdBy and mandateId, filter out other _ fields - if k == "_createdBy" or k == "mandateId" or not k.startswith('_'): - cleanedRecord[k] = v - cleanedRecords.append(cleanedRecord) - return cleanedRecords - else: - # For other models, filter out all database-specific fields - cleanedRecords = [] - for record in filteredRecords: - cleanedRecord = {k: v for k, v in record.items() if not k.startswith('_')} - cleanedRecords.append(cleanedRecord) - return cleanedRecords + + def checkRbacPermission( + self, + modelClass: type, + operation: str, + recordId: Optional[str] = None + ) -> bool: + """ + Check RBAC permission for a specific operation on a table. - def _canModify(self, model_class: type, recordId: Optional[str] = None) -> bool: - """Delegate to access control module.""" - return self.access.canModify(model_class, recordId) + Args: + modelClass: Pydantic model class for the table + operation: Operation to check ('create', 'update', 'delete', 'read') + recordId: Optional record ID for specific record check + + Returns: + Boolean indicating permission + """ + if not self.rbac or not self.currentUser: + return False + + tableName = modelClass.__name__ + permissions = self.rbac.getUserPermissions( + self.currentUser, + AccessRuleContext.DATA, + tableName + ) + + if operation == "create": + return permissions.create != AccessLevel.NONE + elif operation == "update": + return permissions.update != AccessLevel.NONE + elif operation == "delete": + return permissions.delete != AccessLevel.NONE + elif operation == "read": + return permissions.read != AccessLevel.NONE + else: + return False def _applyFilters(self, records: List[Dict[str, Any]], filters: Dict[str, Any]) -> List[Dict[str, Any]]: """ @@ -567,8 +580,11 @@ class ChatObjects: If pagination is None: List[Dict[str, Any]] If pagination is provided: PaginatedResult with items and metadata """ - allWorkflows = self.db.getRecordset(ChatWorkflow) - filteredWorkflows = self._uam(ChatWorkflow, allWorkflows) + # Use RBAC filtering + filteredWorkflows = self.db.getRecordsetWithRBAC( + ChatWorkflow, + self.currentUser + ) # If no pagination requested, return all items (no sorting - frontend handles it) if pagination is None: @@ -599,15 +615,17 @@ class ChatObjects: def getWorkflow(self, workflowId: str) -> Optional[ChatWorkflow]: """Returns a workflow by ID if user has access.""" - workflows = self.db.getRecordset(ChatWorkflow, recordFilter={"id": workflowId}) + # Use RBAC filtering + workflows = self.db.getRecordsetWithRBAC( + ChatWorkflow, + self.currentUser, + recordFilter={"id": workflowId} + ) + if not workflows: return None - filteredWorkflows = self._uam(ChatWorkflow, workflows) - if not filteredWorkflows: - return None - - workflow = filteredWorkflows[0] + workflow = workflows[0] try: # Load related data from normalized tables logs = self.getLogs(workflowId) @@ -637,7 +655,7 @@ class ChatObjects: def createWorkflow(self, workflowData: Dict[str, Any]) -> ChatWorkflow: """Creates a new workflow if user has permission.""" - if not self._canModify(ChatWorkflow): + if not self.checkRbacPermission(ChatWorkflow, "create"): raise PermissionError("No permission to create workflows") # Set timestamp if not present @@ -682,7 +700,7 @@ class ChatObjects: if not workflow: return None - if not self._canModify(ChatWorkflow, workflowId): + if not self.checkRbacPermission(ChatWorkflow, "update", workflowId): raise PermissionError(f"No permission to update workflow {workflowId}") # Use generic field separation based on ChatWorkflow model @@ -728,7 +746,7 @@ class ChatObjects: if not workflow: return False - if not self._canModify(ChatWorkflow, workflowId): + if not self.checkRbacPermission(ChatWorkflow, "delete", workflowId): raise PermissionError(f"No permission to delete workflow {workflowId}") # CASCADE DELETE: Delete all related data first @@ -787,18 +805,18 @@ class ChatObjects: If pagination is provided: PaginatedResult with items and metadata """ # Check workflow access first (without calling getWorkflow to avoid circular reference) - workflows = self.db.getRecordset(ChatWorkflow, recordFilter={"id": workflowId}) + # Use RBAC filtering + workflows = self.db.getRecordsetWithRBAC( + ChatWorkflow, + self.currentUser, + recordFilter={"id": workflowId} + ) + if not workflows: if pagination is None: return [] return PaginatedResult(items=[], totalItems=0, totalPages=0) - filteredWorkflows = self._uam(ChatWorkflow, workflows) - if not filteredWorkflows: - if pagination is None: - return [] - return PaginatedResult(items=[], totalItems=0, totalPages=0) - # Get messages for this workflow from normalized table messages = self.db.getRecordset(ChatMessage, recordFilter={"workflowId": workflowId}) @@ -938,7 +956,7 @@ class ChatObjects: if not workflow: raise PermissionError(f"No access to workflow {workflowId}") - if not self._canModify(ChatWorkflow, workflowId): + if not self.checkRbacPermission(ChatWorkflow, "update", workflowId): raise PermissionError(f"No permission to modify workflow {workflowId}") # Validate that ID is not None @@ -1054,7 +1072,7 @@ class ChatObjects: if not workflow: raise PermissionError(f"No access to workflow {workflowId}") - if not self._canModify(ChatWorkflow, workflowId): + if not self.checkRbacPermission(ChatWorkflow, "update", workflowId): raise PermissionError(f"No permission to modify workflow {workflowId}") logger.info(f"Creating new message with ID {messageId} for workflow {workflowId}") @@ -1072,7 +1090,7 @@ class ChatObjects: if not workflow: raise PermissionError(f"No access to workflow {workflowId}") - if not self._canModify(ChatWorkflow, workflowId): + if not self.checkRbacPermission(ChatWorkflow, "update", workflowId): raise PermissionError(f"No permission to modify workflow {workflowId}") # Use generic field separation based on ChatMessage model @@ -1132,7 +1150,7 @@ class ChatObjects: logger.warning(f"No access to workflow {workflowId}") return False - if not self._canModify(ChatWorkflow, workflowId): + if not self.checkRbacPermission(ChatWorkflow, "update", workflowId): raise PermissionError(f"No permission to modify workflow {workflowId}") # Check if the message exists @@ -1173,7 +1191,7 @@ class ChatObjects: logger.warning(f"No access to workflow {workflowId}") return False - if not self._canModify(ChatWorkflow, workflowId): + if not self.checkRbacPermission(ChatWorkflow, "update", workflowId): raise PermissionError(f"No permission to modify workflow {workflowId}") @@ -1257,18 +1275,18 @@ class ChatObjects: If pagination is provided: PaginatedResult with items and metadata """ # Check workflow access first (without calling getWorkflow to avoid circular reference) - workflows = self.db.getRecordset(ChatWorkflow, recordFilter={"id": workflowId}) + # Use RBAC filtering + workflows = self.db.getRecordsetWithRBAC( + ChatWorkflow, + self.currentUser, + recordFilter={"id": workflowId} + ) + if not workflows: if pagination is None: return [] return PaginatedResult(items=[], totalItems=0, totalPages=0) - filteredWorkflows = self._uam(ChatWorkflow, workflows) - if not filteredWorkflows: - if pagination is None: - return [] - return PaginatedResult(items=[], totalItems=0, totalPages=0) - # Get logs for this workflow from normalized table logs = self.db.getRecordset(ChatLog, recordFilter={"workflowId": workflowId}) @@ -1335,7 +1353,7 @@ class ChatObjects: logger.warning(f"No access to workflow {workflowId}") return None - if not self._canModify(ChatWorkflow, workflowId): + if not self.checkRbacPermission(ChatWorkflow, "update", workflowId): logger.warning(f"No permission to modify workflow {workflowId}") return None @@ -1378,14 +1396,16 @@ class ChatObjects: def getStats(self, workflowId: str) -> List[ChatStat]: """Returns list of statistics for a workflow if user has access.""" # Check workflow access first (without calling getWorkflow to avoid circular reference) - workflows = self.db.getRecordset(ChatWorkflow, recordFilter={"id": workflowId}) + # Use RBAC filtering + workflows = self.db.getRecordsetWithRBAC( + ChatWorkflow, + self.currentUser, + recordFilter={"id": workflowId} + ) + if not workflows: return [] - filteredWorkflows = self._uam(ChatWorkflow, workflows) - if not filteredWorkflows: - return [] - # Get stats for this workflow from normalized table stats = self.db.getRecordset(ChatStat, recordFilter={"workflowId": workflowId}) @@ -1423,13 +1443,15 @@ class ChatObjects: Uses timestamp-based selective data transfer for efficient polling. """ # Check workflow access first - workflows = self.db.getRecordset(ChatWorkflow, recordFilter={"id": workflowId}) + # Use RBAC filtering + workflows = self.db.getRecordsetWithRBAC( + ChatWorkflow, + self.currentUser, + recordFilter={"id": workflowId} + ) + if not workflows: return {"items": []} - - filteredWorkflows = self._uam(ChatWorkflow, workflows) - if not filteredWorkflows: - return {"items": []} # Get all data types and filter in Python (PostgreSQL connector doesn't support $gt operators) items = [] @@ -1585,8 +1607,11 @@ class ChatObjects: Supports optional pagination, sorting, and filtering. Computes status field for each automation. """ - allAutomations = self.db.getRecordset(AutomationDefinition) - filteredAutomations = self._uam(AutomationDefinition, allAutomations) + # Use RBAC filtering + filteredAutomations = self.db.getRecordsetWithRBAC( + AutomationDefinition, + self.currentUser + ) # Compute status for each automation and normalize executionLogs for automation in filteredAutomations: @@ -1628,8 +1653,12 @@ class ChatObjects: def getAutomationDefinition(self, automationId: str) -> Optional[Dict[str, Any]]: """Returns an automation definition by ID if user has access, with computed status.""" try: - automations = self.db.getRecordset(AutomationDefinition, recordFilter={"id": automationId}) - filtered = self._uam(AutomationDefinition, automations) + # Use RBAC filtering + filtered = self.db.getRecordsetWithRBAC( + AutomationDefinition, + self.currentUser, + recordFilter={"id": automationId} + ) if not filtered: return None @@ -1695,7 +1724,7 @@ class ChatObjects: if not existing: raise PermissionError(f"No access to automation {automationId}") - if not self._canModify(AutomationDefinition, automationId): + if not self.checkRbacPermission(AutomationDefinition, "update", automationId): raise PermissionError(f"No permission to modify automation {automationId}") # Use generic field separation @@ -1726,7 +1755,7 @@ class ChatObjects: if not existing: raise PermissionError(f"No access to automation {automationId}") - if not self._canModify(AutomationDefinition, automationId): + if not self.checkRbacPermission(AutomationDefinition, "delete", automationId): raise PermissionError(f"No permission to delete automation {automationId}") # Remove event if exists diff --git a/modules/interfaces/interfaceDbComponentAccess.py b/modules/interfaces/interfaceDbComponentAccess.py deleted file mode 100644 index 36c3cfff..00000000 --- a/modules/interfaces/interfaceDbComponentAccess.py +++ /dev/null @@ -1,203 +0,0 @@ -""" -Access control module for Management interface. -Handles user access management and permission checks. -""" - -import logging -from typing import Dict, Any, List, Optional -from modules.datamodels.datamodelUam import User -from modules.datamodels.datamodelUtils import Prompt -from modules.datamodels.datamodelFiles import FileItem -from modules.datamodels.datamodelChat import ChatWorkflow - -# Configure logger -logger = logging.getLogger(__name__) - -class ComponentAccess: - """ - Access control class for Management interface. - Handles user access management and permission checks. - """ - - def __init__(self, currentUser: User, db): - """Initialize with user context.""" - self.currentUser = currentUser - self.userId = currentUser.id - self.mandateId = currentUser.mandateId - self.privilege = currentUser.privilege - self.db = db - - def getInitialUserid(self): - return "----" - # return self.db.getInitialUserId() --> to get from AdminDB ! - - def canModifyAttribute(self, table: str, attribute: str) -> bool: - """ - Checks if the current user can modify a specific attribute in a table. - - Args: - table: Name of the table - attribute: Name of the attribute - - Returns: - Boolean indicating permission - """ - userPrivilege = self.privilege - - # Special case for mandateId in prompts table - if table == "prompts" and attribute == "mandateId": - return userPrivilege == "sysadmin" - - return True - - def uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """ - Unified user access management function that filters data based on user privileges - and adds access control attributes. - - Args: - model_class: Pydantic model class for the table - recordset: Recordset to filter based on access rules - - Returns: - Filtered recordset with access control attributes - """ - userPrivilege = self.privilege - table_name = model_class.__name__ - - filtered_records = [] - - initialid = self.getInitialUserid() - - # Apply filtering based on privilege - if userPrivilege == "sysadmin": - filtered_records = recordset # System admins see all records - elif userPrivilege == "admin": - # Admins see records in their mandate - filtered_records = [r for r in recordset if r.get("mandateId") == self.mandateId] - else: # Regular users - # For prompts, users can see all prompts from their mandate - if table_name == "Prompt": - filtered_records = [r for r in recordset if r.get("mandateId") == self.mandateId] - elif table_name == "UserInDB": - # For users table, users can only see their own record - filtered_records = [r for r in recordset if r.get("id") == self.userId] - elif table_name == "VoiceSettings": - # For voice settings, users can only see their own settings - filtered_records = [r for r in recordset if r.get("userId") == self.userId] - else: - # Users see only their records for other tables - filtered_records = [ - r for r in recordset - if r.get("mandateId") == self.mandateId and r.get("_createdBy") == self.userId - ] - - # Add access control attributes to each record - for record in filtered_records: - record_id = record.get("id") - - # Set access control flags based on user permissions - if table_name == "Prompt": - record["_hideView"] = False # Everyone can view - record["_hideEdit"] = not self.canModify(Prompt, record_id) - record["_hideDelete"] = not self.canModify(Prompt, record_id) - - # Add attribute-level permissions for mandateId - if "mandateId" in record: - record["_hideEdit_mandateId"] = not self.canModifyAttribute(Prompt, "mandateId") - elif table_name == "FileItem": - record["_hideView"] = False # Everyone can view - record["_hideEdit"] = not self.canModify(FileItem, record_id) - record["_hideDelete"] = not self.canModify(FileItem, record_id) - record["_hideDownload"] = not self.canModify(FileItem, record_id) - elif table_name == "ChatWorkflow": - record["_hideView"] = False # Everyone can view - record["_hideEdit"] = not self.canModify(ChatWorkflow, record_id) - record["_hideDelete"] = not self.canModify(ChatWorkflow, record_id) - elif table_name == "ChatMessage": - record["_hideView"] = False # Everyone can view - record["_hideEdit"] = not self.canModify(ChatWorkflow, record.get("workflowId")) - record["_hideDelete"] = not self.canModify(ChatWorkflow, record.get("workflowId")) - elif table_name == "ChatLog": - record["_hideView"] = False # Everyone can view - record["_hideEdit"] = not self.canModify(ChatWorkflow, record.get("workflowId")) - record["_hideDelete"] = not self.canModify(ChatWorkflow, record.get("workflowId")) - elif table_name == "UserInDB": - # For users table, users can only modify their own connections - record["_hideView"] = False - record["_hideEdit"] = record_id != self.userId - record["_hideDelete"] = record_id != self.userId - # Add connection-specific permissions - if "connections" in record: - for conn in record["connections"]: - conn["_hideEdit"] = record_id != self.userId - conn["_hideDelete"] = record_id != self.userId - elif table_name == "VoiceSettings": - # For voice settings, users can only access their own settings - record["_hideView"] = False - record["_hideEdit"] = record.get("userId") != self.userId - record["_hideDelete"] = record.get("userId") != self.userId - else: - # Default access control for other tables - record["_hideView"] = False - record["_hideEdit"] = not self.canModify(model_class, record_id) - record["_hideDelete"] = not self.canModify(model_class, record_id) - - return filtered_records - - def canModify(self, model_class: type, recordId: Optional[int] = None) -> bool: - """ - Checks if the current user can modify (create/update/delete) records in a table. - - Args: - model_class: Pydantic model class for the table - recordId: Optional record ID for specific record check - - Returns: - Boolean indicating permission - """ - userPrivilege = self.privilege - - # System admins can modify anything - if userPrivilege == "sysadmin": - return True - - # For regular users and admins, check specific cases - if recordId is not None: - # Get the record to check ownership - records: List[Dict[str, Any]] = self.db.getRecordset(model_class, recordFilter={"id": recordId}) - if not records: - return False - - record = records[0] - - # Special case for users table - users can modify their own connections - if model_class.__name__ == "UserInDB": - if record.get("id") == self.userId: - return True - return False - - # Special case for voice settings - users can modify their own settings - if model_class.__name__ == "VoiceSettings": - if record.get("userId") == self.userId: - return True - return False - - # Admins can modify anything in their mandate, if mandate is specified for a record - if userPrivilege == "admin" and record.get("mandateId","-") == self.mandateId: - return True - - # Regular users can only modify their own records - if (record.get("mandateId","-") == self.mandateId and - record.get("_createdBy") == self.userId): - return True - - return False - else: - # For general modification permission (e.g., create) - # Admins can create anything in their mandate - if userPrivilege == "admin": - return True - - # Regular users can create in most tables - return True \ No newline at end of file diff --git a/modules/interfaces/interfaceDbComponentObjects.py b/modules/interfaces/interfaceDbComponentObjects.py index 225f8ad5..0e1be949 100644 --- a/modules/interfaces/interfaceDbComponentObjects.py +++ b/modules/interfaces/interfaceDbComponentObjects.py @@ -11,7 +11,9 @@ import math from typing import Dict, Any, List, Optional, Union from modules.connectors.connectorDbPostgre import DatabaseConnector -from modules.interfaces.interfaceDbComponentAccess import ComponentAccess +from modules.security.rbac import RbacClass +from modules.datamodels.datamodelRbac import AccessRuleContext +from modules.datamodels.datamodelUam import AccessLevel from modules.datamodels.datamodelFiles import FilePreview, FileItem, FileData from modules.datamodels.datamodelUtils import Prompt from modules.datamodels.datamodelVoice import VoiceSettings @@ -57,7 +59,7 @@ class ComponentObjects: # Initialize variables first self.currentUser: Optional[User] = None self.userId: Optional[str] = None - self.access: Optional[ComponentAccess] = None # Will be set when user context is provided + self.rbac: Optional[RbacClass] = None # RBAC interface # Initialize database self._initializeDatabase() @@ -80,8 +82,10 @@ class ComponentObjects: # Add language settings self.userLanguage = currentUser.language # Default user language - # Initialize access control with user context - self.access = ComponentAccess(self.currentUser, self.db) + # Initialize RBAC interface + if not self.currentUser: + raise ValueError("User context is required for RBAC") + self.rbac = RbacClass(self.db) # Update database context self.db.updateContext(self.userId) @@ -214,7 +218,6 @@ class ComponentObjects: else: self.currentUser = None self.userId = None - self.access = None self.db.updateContext("") # Reset database context except Exception as e: @@ -225,26 +228,46 @@ class ComponentObjects: else: self.currentUser = None self.userId = None - self.access = None self.db.updateContext("") # Reset database context - def _uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Delegate to access control module.""" - # First apply access control - filteredRecords = self.access.uam(model_class, recordset) - - # Then filter out database-specific fields - cleanedRecords = [] - for record in filteredRecords: - # Create a new dict with only non-database fields - cleanedRecord = {k: v for k, v in record.items() if not k.startswith('_')} - cleanedRecords.append(cleanedRecord) - - return cleanedRecords + + def checkRbacPermission( + self, + modelClass: type, + operation: str, + recordId: Optional[str] = None + ) -> bool: + """ + Check RBAC permission for a specific operation on a table. - def _canModify(self, model_class: type, recordId: Optional[str] = None) -> bool: - """Delegate to access control module.""" - return self.access.canModify(model_class, recordId) + Args: + modelClass: Pydantic model class for the table + operation: Operation to check ('create', 'update', 'delete', 'read') + recordId: Optional record ID for specific record check + + Returns: + Boolean indicating permission + """ + if not self.rbac or not self.currentUser: + return False + + tableName = modelClass.__name__ + permissions = self.rbac.getUserPermissions( + self.currentUser, + AccessRuleContext.DATA, + tableName + ) + + if operation == "create": + return permissions.create != AccessLevel.NONE + elif operation == "update": + return permissions.update != AccessLevel.NONE + elif operation == "delete": + return permissions.delete != AccessLevel.NONE + elif operation == "read": + return permissions.read != AccessLevel.NONE + else: + return False def _applyFilters(self, records: List[Dict[str, Any]], filters: Dict[str, Any]) -> List[Dict[str, Any]]: """ @@ -474,8 +497,11 @@ class ComponentObjects: If pagination is provided: PaginatedResult with items and metadata """ try: - allPrompts = self.db.getRecordset(Prompt) - filteredPrompts = self._uam(Prompt, allPrompts) + # Use RBAC filtering + filteredPrompts = self.db.getRecordsetWithRBAC( + Prompt, + self.currentUser + ) # If no pagination requested, return all items if pagination is None: @@ -515,16 +541,18 @@ class ComponentObjects: def getPrompt(self, promptId: str) -> Optional[Prompt]: """Returns a prompt by ID if user has access.""" - prompts = self.db.getRecordset(Prompt, recordFilter={"id": promptId}) - if not prompts: - return None + # Use RBAC filtering + filteredPrompts = self.db.getRecordsetWithRBAC( + Prompt, + self.currentUser, + recordFilter={"id": promptId} + ) - filteredPrompts = self._uam(Prompt, prompts) return Prompt(**filteredPrompts[0]) if filteredPrompts else None def createPrompt(self, promptData: Dict[str, Any]) -> Dict[str, Any]: """Creates a new prompt if user has permission.""" - if not self._canModify(Prompt): + if not self.checkRbacPermission(Prompt, "create"): raise PermissionError("No permission to create prompts") # Create prompt record @@ -565,7 +593,7 @@ class ComponentObjects: if not prompt: return False - if not self._canModify(Prompt, promptId): + if not self.checkRbacPermission(Prompt, "update", promptId): raise PermissionError(f"No permission to delete prompt {promptId}") # Delete prompt @@ -580,13 +608,12 @@ class ComponentObjects: """Checks if a file with the same hash already exists for the current user and mandate. If fileName is provided, also checks for exact name+hash match. Only returns files the current user has access to.""" - # First get all files with the hash - allFilesWithHash = self.db.getRecordset(FileItem, recordFilter={ - "fileHash": fileHash - }) - - # Filter by user access using UAM - accessibleFiles = self._uam(FileItem, allFilesWithHash) + # Get files with the hash, filtered by RBAC + accessibleFiles = self.db.getRecordsetWithRBAC( + FileItem, + self.currentUser, + recordFilter={"fileHash": fileHash} + ) if not accessibleFiles: return None @@ -711,8 +738,11 @@ class ComponentObjects: If pagination is None: List[FileItem] If pagination is provided: PaginatedResult with items and metadata """ - allFiles = self.db.getRecordset(FileItem) - filteredFiles = self._uam(FileItem, allFiles) + # Use RBAC filtering + filteredFiles = self.db.getRecordsetWithRBAC( + FileItem, + self.currentUser + ) # Convert database records to FileItem instances (for both paginated and non-paginated) def convertFileItems(files): @@ -775,11 +805,13 @@ class ComponentObjects: def getFile(self, fileId: str) -> Optional[FileItem]: """Returns a file by ID if user has access.""" - files = self.db.getRecordset(FileItem, recordFilter={"id": fileId}) - if not files: - return None - - filteredFiles = self._uam(FileItem, files) + # Use RBAC filtering + filteredFiles = self.db.getRecordsetWithRBAC( + FileItem, + self.currentUser, + recordFilter={"id": fileId} + ) + if not filteredFiles: return None @@ -838,7 +870,7 @@ class ComponentObjects: def createFile(self, name: str, mimeType: str, content: bytes) -> FileItem: """Creates a new file entry if user has permission. Computes fileHash and fileSize from content.""" - if not self._canModify(FileItem): + if not self.checkRbacPermission(FileItem, "create"): raise PermissionError("No permission to create files") # Ensure fileName is unique @@ -873,7 +905,7 @@ class ComponentObjects: if not file: raise FileNotFoundError(f"File with ID {fileId} not found") - if not self._canModify(FileItem, fileId): + if not self.checkRbacPermission(FileItem, "update", fileId): raise PermissionError(f"No permission to update file {fileId}") # If fileName is being updated, ensure it's unique @@ -895,7 +927,7 @@ class ComponentObjects: if not file: raise FileNotFoundError(f"File with ID {fileId} not found") - if not self._canModify(FileItem, fileId): + if not self.checkRbacPermission(FileItem, "update", fileId): raise PermissionError(f"No permission to delete file {fileId}") # Check for other references to this file (by hash) @@ -1090,7 +1122,7 @@ class ComponentObjects: """Saves an uploaded file if user has permission.""" try: # Check file creation permission - if not self._canModify(FileItem): + if not self.checkRbacPermission(FileItem, "create"): raise PermissionError("No permission to upload files") logger.debug(f"Starting upload process for file: {fileName}") @@ -1151,14 +1183,13 @@ class ComponentObjects: logger.error("No user ID provided for voice settings") return None - # Get voice settings for the user - settings = self.db.getRecordset(VoiceSettings, recordFilter={"userId": targetUserId}) - if not settings: - logger.debug(f"No voice settings found for user {targetUserId}") - return None + # Get voice settings for the user, filtered by RBAC + filteredSettings = self.db.getRecordsetWithRBAC( + VoiceSettings, + self.currentUser, + recordFilter={"userId": targetUserId} + ) - # Apply access control - filteredSettings = self._uam(VoiceSettings, settings) if not filteredSettings: logger.warning(f"No access to voice settings for user {targetUserId}") return None @@ -1179,7 +1210,7 @@ class ComponentObjects: def createVoiceSettings(self, settingsData: Dict[str, Any]) -> Dict[str, Any]: """Creates voice settings for a user if user has permission.""" try: - if not self._canModify(VoiceSettings): + if not self.checkRbacPermission(VoiceSettings, "update"): raise PermissionError("No permission to create voice settings") # Ensure userId is set diff --git a/modules/migration/__init__.py b/modules/migration/__init__.py new file mode 100644 index 00000000..49056d7c --- /dev/null +++ b/modules/migration/__init__.py @@ -0,0 +1 @@ +"""Migration modules for database schema and data migrations.""" diff --git a/modules/migration/migrateUamToRbac.py b/modules/migration/migrateUamToRbac.py new file mode 100644 index 00000000..688bf8e7 --- /dev/null +++ b/modules/migration/migrateUamToRbac.py @@ -0,0 +1,212 @@ +""" +Migration script to convert UAM (User Access Management) to RBAC (Role-Based Access Control). + +This script: +1. Creates AccessRule table if it doesn't exist +2. Adds roleLabels column to User table if it doesn't exist +3. Converts User.privilege to User.roleLabels +4. Creates initial RBAC rules based on bootstrap logic +""" + +import logging +from typing import List, Dict, Any +from modules.connectors.connectorDbPostgre import DatabaseConnector +from modules.shared.configuration import APP_CONFIG +from modules.datamodels.datamodelUam import UserInDB, UserPrivilege +from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext +from modules.datamodels.datamodelUam import AccessLevel +from modules.interfaces.interfaceBootstrap import initRbacRules + +logger = logging.getLogger(__name__) + + +def migrateUamToRbac(db: DatabaseConnector, dryRun: bool = False) -> Dict[str, Any]: + """ + Migrate from UAM to RBAC system. + + Args: + db: Database connector instance + dryRun: If True, only report what would be done without making changes + + Returns: + Dictionary with migration results + """ + results = { + "schemaChanges": [], + "dataMigrations": [], + "rulesCreated": 0, + "usersUpdated": 0, + "errors": [] + } + + try: + # Step 1: Ensure AccessRule table exists + logger.info("Step 1: Ensuring AccessRule table exists") + if not dryRun: + db._ensureTableExists(AccessRule) + results["schemaChanges"].append("AccessRule table ensured") + else: + results["schemaChanges"].append("Would ensure AccessRule table") + + # Step 2: Add roleLabels column to UserInDB table if it doesn't exist + logger.info("Step 2: Adding roleLabels column to UserInDB table") + if not dryRun: + try: + with db.connection.cursor() as cursor: + # Check if column exists + cursor.execute(""" + SELECT column_name + FROM information_schema.columns + WHERE table_name = 'UserInDB' AND column_name = 'roleLabels' + """) + columnExists = cursor.fetchone() is not None + + if not columnExists: + cursor.execute('ALTER TABLE "UserInDB" ADD COLUMN "roleLabels" JSONB DEFAULT \'[]\'::jsonb') + db.connection.commit() + results["schemaChanges"].append("Added roleLabels column to UserInDB") + logger.info("Added roleLabels column to UserInDB table") + else: + results["schemaChanges"].append("roleLabels column already exists") + logger.info("roleLabels column already exists in UserInDB table") + except Exception as e: + logger.error(f"Error adding roleLabels column: {e}") + results["errors"].append(f"Error adding roleLabels column: {e}") + db.connection.rollback() + else: + results["schemaChanges"].append("Would add roleLabels column to UserInDB") + + # Step 3: Convert User.privilege to User.roleLabels + logger.info("Step 3: Converting User.privilege to User.roleLabels") + if not dryRun: + try: + users = db.getRecordset(UserInDB) + updatedCount = 0 + + for user in users: + privilege = user.get("privilege") + roleLabels = user.get("roleLabels", []) + + # Skip if already has roleLabels + if roleLabels and isinstance(roleLabels, list) and len(roleLabels) > 0: + logger.debug(f"User {user.get('id')} already has roleLabels: {roleLabels}") + continue + + # Convert privilege to roleLabels + if privilege == UserPrivilege.SYSADMIN.value: + newRoleLabels = ["sysadmin"] + elif privilege == UserPrivilege.ADMIN.value: + newRoleLabels = ["admin"] + elif privilege == UserPrivilege.USER.value: + newRoleLabels = ["user"] + else: + # Default to user if privilege is unknown + newRoleLabels = ["user"] + logger.warning(f"Unknown privilege '{privilege}' for user {user.get('id')}, defaulting to 'user'") + + # Update user + user["roleLabels"] = newRoleLabels + db.recordModify(UserInDB, user["id"], user) + updatedCount += 1 + logger.info(f"Updated user {user.get('id')} ({user.get('username')}): {privilege} -> {newRoleLabels}") + + results["usersUpdated"] = updatedCount + logger.info(f"Updated {updatedCount} users with roleLabels") + except Exception as e: + logger.error(f"Error converting user privileges: {e}") + results["errors"].append(f"Error converting user privileges: {e}") + else: + # Dry run: count users that would be updated + users = db.getRecordset(UserInDB) + wouldUpdate = 0 + for user in users: + roleLabels = user.get("roleLabels", []) + if not roleLabels or not isinstance(roleLabels, list) or len(roleLabels) == 0: + wouldUpdate += 1 + results["usersUpdated"] = wouldUpdate + logger.info(f"Would update {wouldUpdate} users with roleLabels") + + # Step 4: Create RBAC rules if they don't exist + logger.info("Step 4: Creating RBAC rules") + if not dryRun: + try: + existingRules = db.getRecordset(AccessRule) + if existingRules: + results["rulesCreated"] = len(existingRules) + results["dataMigrations"].append(f"RBAC rules already exist ({len(existingRules)} rules)") + logger.info(f"RBAC rules already exist ({len(existingRules)} rules)") + else: + # Initialize RBAC rules using bootstrap logic + initRbacRules(db) + newRules = db.getRecordset(AccessRule) + results["rulesCreated"] = len(newRules) + results["dataMigrations"].append(f"Created {len(newRules)} RBAC rules") + logger.info(f"Created {len(newRules)} RBAC rules") + except Exception as e: + logger.error(f"Error creating RBAC rules: {e}") + results["errors"].append(f"Error creating RBAC rules: {e}") + else: + existingRules = db.getRecordset(AccessRule) + if existingRules: + results["rulesCreated"] = len(existingRules) + results["dataMigrations"].append(f"RBAC rules already exist ({len(existingRules)} rules)") + else: + results["dataMigrations"].append("Would create RBAC rules") + + logger.info("Migration completed successfully") + return results + + except Exception as e: + logger.error(f"Migration failed: {e}") + results["errors"].append(f"Migration failed: {e}") + return results + + +def validateMigration(db: DatabaseConnector) -> Dict[str, Any]: + """ + Validate that migration was successful. + + Args: + db: Database connector instance + + Returns: + Dictionary with validation results + """ + validation = { + "valid": True, + "issues": [] + } + + try: + # Check that AccessRule table exists + try: + rules = db.getRecordset(AccessRule) + if not rules: + validation["valid"] = False + validation["issues"].append("AccessRule table exists but has no rules") + except Exception as e: + validation["valid"] = False + validation["issues"].append(f"AccessRule table does not exist or is not accessible: {e}") + + # Check that all users have roleLabels + users = db.getRecordset(UserInDB) + usersWithoutRoles = [] + for user in users: + roleLabels = user.get("roleLabels", []) + if not roleLabels or not isinstance(roleLabels, list) or len(roleLabels) == 0: + usersWithoutRoles.append({ + "id": user.get("id"), + "username": user.get("username"), + "privilege": user.get("privilege") + }) + + if usersWithoutRoles: + validation["valid"] = False + validation["issues"].append(f"{len(usersWithoutRoles)} users without roleLabels: {[u['username'] for u in usersWithoutRoles]}") + + return validation + + except Exception as e: + validation["valid"] = False + validation["issues"].append(f"Validation error: {e}") + return validation diff --git a/modules/routes/routeDataFiles.py b/modules/routes/routeDataFiles.py index 7c0f60c0..5cdfcfc5 100644 --- a/modules/routes/routeDataFiles.py +++ b/modules/routes/routeDataFiles.py @@ -229,8 +229,8 @@ async def update_file( detail=f"File with ID {fileId} not found" ) - # Check if user has access to the file using the interface's permission system - if not managementInterface._canModify("files", fileId): + # Check if user has access to the file using RBAC + if not managementInterface.checkRbacPermission(FileItem, "update", fileId): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to update this file" diff --git a/modules/routes/routeRbac.py b/modules/routes/routeRbac.py new file mode 100644 index 00000000..95184779 --- /dev/null +++ b/modules/routes/routeRbac.py @@ -0,0 +1,161 @@ +""" +RBAC routes for the backend API. +Implements endpoints for role-based access control permissions. +""" + +from fastapi import APIRouter, HTTPException, Depends, Query, Request +from typing import Optional +import logging + +from modules.security.auth import getCurrentUser, limiter +from modules.datamodels.datamodelUam import User, UserPermissions, AccessLevel +from modules.datamodels.datamodelRbac import AccessRuleContext +from modules.interfaces.interfaceDbAppObjects import getInterface + +# Configure logger +logger = logging.getLogger(__name__) + +router = APIRouter( + prefix="/api/rbac", + tags=["RBAC"], + responses={404: {"description": "Not found"}} + ) + + +@router.get("/permissions", response_model=UserPermissions) +@limiter.limit("60/minute") +async def getPermissions( + request: Request, + context: str = Query(..., description="Context type: DATA, UI, or RESOURCE"), + item: Optional[str] = Query(None, description="Item identifier (table name, UI path, or resource path)"), + currentUser: User = Depends(getCurrentUser) + ) -> UserPermissions: + """ + Get RBAC permissions for the current user for a specific context and item. + + Query Parameters: + - context: Context type (DATA, UI, or RESOURCE) + - item: Optional item identifier. For DATA: table name (e.g., "UserInDB"), + For UI: cascading string (e.g., "playground.voice.settings"), + For RESOURCE: cascading string (e.g., "ai.model.anthropic") + + Returns: + - UserPermissions object with view, read, create, update, delete permissions + + Examples: + - GET /api/rbac/permissions?context=DATA&item=UserInDB + - GET /api/rbac/permissions?context=UI&item=playground.voice.settings + - GET /api/rbac/permissions?context=RESOURCE&item=ai.model.anthropic + """ + try: + # Validate context + try: + accessContext = AccessRuleContext(context.upper()) + except ValueError: + raise HTTPException( + status_code=400, + detail=f"Invalid context '{context}'. Must be one of: DATA, UI, RESOURCE" + ) + + # Get interface and RBAC permissions + interface = getInterface(currentUser) + if not interface.rbac: + raise HTTPException( + status_code=500, + detail="RBAC interface not available" + ) + + # Get permissions + permissions = interface.rbac.getUserPermissions( + currentUser, + accessContext, + item or "" + ) + + return permissions + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting RBAC permissions: {str(e)}") + raise HTTPException( + status_code=500, + detail=f"Failed to get permissions: {str(e)}" + ) + + +@router.get("/rules", response_model=list) +@limiter.limit("30/minute") +async def getAccessRules( + request: Request, + roleLabel: Optional[str] = Query(None, description="Filter by role label"), + context: Optional[str] = Query(None, description="Filter by context (DATA, UI, RESOURCE)"), + item: Optional[str] = Query(None, description="Filter by item identifier"), + currentUser: User = Depends(getCurrentUser) + ) -> list: + """ + Get access rules with optional filters. + Only returns rules that the current user has permission to view. + + Query Parameters: + - roleLabel: Optional role label filter + - context: Optional context filter (DATA, UI, RESOURCE) + - item: Optional item filter + + Returns: + - List of AccessRule objects + """ + try: + # Get interface + interface = getInterface(currentUser) + + # Check if user has permission to view access rules + # For now, only sysadmin can view rules + if not interface.rbac: + raise HTTPException( + status_code=500, + detail="RBAC interface not available" + ) + + # Check permission - only sysadmin can view rules + permissions = interface.rbac.getUserPermissions( + currentUser, + AccessRuleContext.DATA, + "AccessRule" + ) + + if not permissions.view or permissions.read == AccessLevel.NONE: + raise HTTPException( + status_code=403, + detail="No permission to view access rules" + ) + + # Parse context if provided + accessContext = None + if context: + try: + accessContext = AccessRuleContext(context.upper()) + except ValueError: + raise HTTPException( + status_code=400, + detail=f"Invalid context '{context}'. Must be one of: DATA, UI, RESOURCE" + ) + + # Get rules + rules = interface.getAccessRules( + roleLabel=roleLabel, + context=accessContext, + item=item + ) + + # Convert to dict for JSON serialization + return [rule.model_dump() for rule in rules] + + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting access rules: {str(e)}") + raise HTTPException( + status_code=500, + detail=f"Failed to get access rules: {str(e)}" + ) diff --git a/modules/routes/routeWorkflows.py b/modules/routes/routeWorkflows.py index ea52a067..080e8077 100644 --- a/modules/routes/routeWorkflows.py +++ b/modules/routes/routeWorkflows.py @@ -180,8 +180,8 @@ async def update_workflow( workflow_data = workflows[0] - # Check if user has permission to update using the interface's permission system - if not workflowInterface._canModify("workflows", workflowId): + # Check if user has permission to update using RBAC + if not workflowInterface.checkRbacPermission(ChatWorkflow, "update", workflowId): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="You don't have permission to update this workflow" @@ -437,8 +437,8 @@ async def delete_workflow( workflow_data = workflows[0] - # Check if user has permission to delete using the interface's permission system - if not interfaceDbChat._canModify("workflows", workflowId): + # Check if user has permission to delete using RBAC + if not interfaceDbChat.checkRbacPermission(ChatWorkflow, "delete", workflowId): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="You don't have permission to delete this workflow" diff --git a/modules/security/rbac.py b/modules/security/rbac.py new file mode 100644 index 00000000..ca2050de --- /dev/null +++ b/modules/security/rbac.py @@ -0,0 +1,194 @@ +""" +RBAC interface: Core RBAC logic and permission resolution. +Moved from interfaces to security module to maintain proper architectural layering. +Connectors can import from security, but not from interfaces. +""" + +import logging +from typing import List, Optional, Dict, Any, TYPE_CHECKING +from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext +from modules.datamodels.datamodelUam import User, UserPermissions, AccessLevel + +if TYPE_CHECKING: + from modules.connectors.connectorDbPostgre import DatabaseConnector + +logger = logging.getLogger(__name__) + + +class RbacClass: + """ + RBAC interface for permission resolution and rule validation. + """ + + def __init__(self, db: "DatabaseConnector"): + """Initialize RBAC interface with database connector.""" + self.db = db + + def getUserPermissions(self, user: User, context: AccessRuleContext, item: str) -> UserPermissions: + """ + Get combined permissions for a user across all their roles. + + Args: + user: User object with roleLabels + context: Access rule context (DATA, UI, RESOURCE) + item: Item identifier (table name, UI path, resource path) + + Returns: + UserPermissions object with combined permissions + """ + permissions = UserPermissions( + view=False, + read=AccessLevel.NONE, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE + ) + + if not user.roleLabels: + logger.warning(f"User {user.id} has no roleLabels assigned") + return permissions + + # Step 1: For each role, find the most specific matching rule (most specific wins within role) + rolePermissions = {} + for roleLabel in user.roleLabels: + # Get all rules for this role and context + allRules = self._getRulesForRole(roleLabel, context) + + # Find most specific rule for this item (longest matching prefix) + mostSpecificRule = self.findMostSpecificRule(allRules, item) + + if mostSpecificRule: + rolePermissions[roleLabel] = mostSpecificRule + + # Step 2: Combine permissions across roles using opening (union) logic + for roleLabel, rule in rolePermissions.items(): + # View: union logic - if ANY role has view=true, then view=true + if rule.view: + permissions.view = True + + if context == AccessRuleContext.DATA: + # For DATA context, use most permissive access level across roles + if rule.read and self._isMorePermissive(rule.read, permissions.read): + permissions.read = rule.read + if rule.create and self._isMorePermissive(rule.create, permissions.create): + permissions.create = rule.create + if rule.update and self._isMorePermissive(rule.update, permissions.update): + permissions.update = rule.update + if rule.delete and self._isMorePermissive(rule.delete, permissions.delete): + permissions.delete = rule.delete + + return permissions + + def findMostSpecificRule(self, rules: List[AccessRule], item: str) -> Optional[AccessRule]: + """ + Find the most specific rule for an item (longest matching prefix wins). + + Args: + rules: List of access rules to search + item: Item identifier to match + + Returns: + Most specific matching rule, or None if no match + """ + if not item: + # If no item specified, return generic rule (item = null) + genericRules = [r for r in rules if r.item is None] + return genericRules[0] if genericRules else None + + # Find longest matching prefix + itemParts = item.split(".") + bestMatch = None + bestMatchLength = -1 + + for rule in rules: + if rule.item is None: + # Generic rule - use as fallback if no specific match found + if bestMatch is None: + bestMatch = rule + elif rule.item == item: + # Exact match - most specific + return rule + elif item.startswith(rule.item + "."): + # Prefix match - check if it's longer than current best + matchLength = len(rule.item.split(".")) + if matchLength > bestMatchLength: + bestMatch = rule + bestMatchLength = matchLength + + return bestMatch + + def validateAccessRule(self, rule: AccessRule) -> bool: + """ + Validate that CUD permissions are allowed by read permission level (only for DATA context). + + Args: + rule: AccessRule to validate + + Returns: + True if rule is valid, False otherwise + """ + if rule.context != AccessRuleContext.DATA: + # For UI and RESOURCE contexts, only view is relevant + return True + + if rule.read is None: + return False # DATA context requires read permission + + readLevel = AccessLevel(rule.read) + + # CUD operations are only allowed if read permission exists + for operation in [rule.create, rule.update, rule.delete]: + if operation is None or operation == AccessLevel.NONE.value: + continue # No access is always valid + if readLevel == AccessLevel.NONE: + return False # No CUD allowed if no read access + if readLevel == AccessLevel.MY and operation not in [AccessLevel.NONE.value, AccessLevel.MY.value]: + return False + if readLevel == AccessLevel.GROUP and operation not in [AccessLevel.NONE.value, AccessLevel.MY.value, AccessLevel.GROUP.value]: + return False + + return True + + def _isMorePermissive(self, level1: AccessLevel, level2: AccessLevel) -> bool: + """ + Check if level1 is more permissive than level2. + + Args: + level1: First access level + level2: Second access level + + Returns: + True if level1 is more permissive than level2 + """ + hierarchy = { + AccessLevel.NONE: 0, + AccessLevel.MY: 1, + AccessLevel.GROUP: 2, + AccessLevel.ALL: 3 + } + return hierarchy.get(level1, 0) > hierarchy.get(level2, 0) + + def _getRulesForRole(self, roleLabel: str, context: AccessRuleContext) -> List[AccessRule]: + """ + Get all access rules for a specific role and context. + + Args: + roleLabel: Role label to get rules for + context: Context type + + Returns: + List of AccessRule objects + """ + try: + rules = self.db.getRecordset( + AccessRule, + recordFilter={ + "roleLabel": roleLabel, + "context": context.value + } + ) + # Convert dict records to AccessRule objects + return [AccessRule(**record) for record in rules] + except Exception as e: + logger.error(f"Error getting rules for role {roleLabel} and context {context.value}: {e}") + return [] diff --git a/modules/shared/rbacHelpers.py b/modules/shared/rbacHelpers.py new file mode 100644 index 00000000..843a588a --- /dev/null +++ b/modules/shared/rbacHelpers.py @@ -0,0 +1,178 @@ +""" +RBAC helper functions for resource access control. +Provides convenient functions for checking permissions in feature modules. +""" + +import logging +from typing import Optional +from modules.datamodels.datamodelUam import User, AccessLevel +from modules.datamodels.datamodelRbac import AccessRuleContext +from modules.security.rbac import RbacClass +from modules.connectors.connectorDbPostgre import DatabaseConnector + +logger = logging.getLogger(__name__) + + +def checkResourceAccess( + RbacInstance: RbacClass, + currentUser: User, + resourcePath: str +) -> bool: + """ + Check if user has access to a resource. + + Args: + RbacInstance: RbacClass instance + currentUser: Current user object + resourcePath: Resource path (e.g., "ai.model.anthropic", "ai.action.jira") + + Returns: + True if user has view permission for the resource, False otherwise + """ + try: + permissions = RbacInstance.getUserPermissions( + currentUser, + AccessRuleContext.RESOURCE, + resourcePath + ) + return permissions.view + except Exception as e: + logger.error(f"Error checking resource access for {resourcePath}: {e}") + return False + + +def checkUiAccess( + RbacInstance: RbacClass, + currentUser: User, + uiPath: str +) -> bool: + """ + Check if user has access to a UI element. + + Args: + RbacInstance: RbacClass instance + currentUser: Current user object + uiPath: UI path (e.g., "playground.voice.settings", "chatbot.search") + + Returns: + True if user has view permission for the UI element, False otherwise + """ + try: + permissions = RbacInstance.getUserPermissions( + currentUser, + AccessRuleContext.UI, + uiPath + ) + return permissions.view + except Exception as e: + logger.error(f"Error checking UI access for {uiPath}: {e}") + return False + + +def checkDataAccess( + RbacInstance: RbacClass, + currentUser: User, + tableName: str, + operation: str = "read" +) -> bool: + """ + Check if user has access to a data table for a specific operation. + + Args: + RbacInstance: RbacClass instance + currentUser: Current user object + tableName: Table name (e.g., "UserInDB", "Mandate") + operation: Operation to check ("read", "create", "update", "delete") + + Returns: + True if user has permission for the operation, False otherwise + """ + try: + permissions = RbacInstance.getUserPermissions( + currentUser, + AccessRuleContext.DATA, + tableName + ) + + if operation == "read": + return permissions.read != AccessLevel.NONE + elif operation == "create": + return permissions.create != AccessLevel.NONE + elif operation == "update": + return permissions.update != AccessLevel.NONE + elif operation == "delete": + return permissions.delete != AccessLevel.NONE + else: + logger.warning(f"Unknown operation: {operation}") + return False + except Exception as e: + logger.error(f"Error checking data access for {tableName}: {e}") + return False + + +def getResourcePermissions( + RbacInstance: RbacClass, + currentUser: User, + resourcePath: str +) -> dict: + """ + Get full permissions for a resource. + + Args: + RbacInstance: RbacClass instance + currentUser: Current user object + resourcePath: Resource path (e.g., "ai.model.anthropic") + + Returns: + Dictionary with permission information + """ + try: + permissions = RbacInstance.getUserPermissions( + currentUser, + AccessRuleContext.RESOURCE, + resourcePath + ) + return { + "view": permissions.view, + "hasAccess": permissions.view + } + except Exception as e: + logger.error(f"Error getting resource permissions for {resourcePath}: {e}") + return { + "view": False, + "hasAccess": False + } + + +def getUiPermissions( + RbacInstance: RbacClass, + currentUser: User, + uiPath: str +) -> dict: + """ + Get full permissions for a UI element. + + Args: + RbacInstance: RbacClass instance + currentUser: Current user object + uiPath: UI path (e.g., "playground.voice.settings") + + Returns: + Dictionary with permission information + """ + try: + permissions = RbacInstance.getUserPermissions( + currentUser, + AccessRuleContext.UI, + uiPath + ) + return { + "view": permissions.view, + "hasAccess": permissions.view + } + except Exception as e: + logger.error(f"Error getting UI permissions for {uiPath}: {e}") + return { + "view": False, + "hasAccess": False + } diff --git a/pytest.ini b/pytest.ini index ae59338f..ad1e22f2 100644 --- a/pytest.ini +++ b/pytest.ini @@ -3,7 +3,7 @@ testpaths = tests pythonpath = . python_files = test_*.py python_classes = Test* -python_functions = test_* +python_functions = test* log_file = logs/test_logs.log log_file_level = INFO log_file_format = %(asctime)s %(levelname)s %(message)s diff --git a/tests/integration/rbac/README.md b/tests/integration/rbac/README.md new file mode 100644 index 00000000..0c866c1d --- /dev/null +++ b/tests/integration/rbac/README.md @@ -0,0 +1,42 @@ +# RBAC Integration Tests + +Integration tests for the Role-Based Access Control (RBAC) system. + +## Test Files + +### `test_rbac_database.py` +Tests RBAC database filtering: +- WHERE clause building for ALL access level +- WHERE clause building for MY access level +- WHERE clause building for GROUP access level +- WHERE clause building for NONE access level +- Special handling for UserInDB table +- Special handling for UserConnection table + +### `test_rbac_migration.py` +Tests UAM to RBAC migration: +- User privilege to roleLabels conversion +- Skipping users with existing roleLabels +- Dry run mode +- Migration validation +- Validation failure scenarios + +## Running Tests + +```bash +# Run all RBAC integration tests +pytest tests/integration/rbac/ + +# Run specific test file +pytest tests/integration/rbac/test_rbac_database.py + +# Run with verbose output +pytest tests/integration/rbac/ -v +``` + +## Test Coverage + +- Database query filtering with RBAC +- SQL WHERE clause generation +- Migration script functionality +- Data validation after migration diff --git a/tests/integration/rbac/__init__.py b/tests/integration/rbac/__init__.py new file mode 100644 index 00000000..32a3a0b9 --- /dev/null +++ b/tests/integration/rbac/__init__.py @@ -0,0 +1 @@ +"""Integration tests for RBAC system.""" diff --git a/tests/integration/rbac/test_rbac_database.py b/tests/integration/rbac/test_rbac_database.py new file mode 100644 index 00000000..34a51c30 --- /dev/null +++ b/tests/integration/rbac/test_rbac_database.py @@ -0,0 +1,209 @@ +""" +Integration tests for RBAC database filtering. +Tests that database queries correctly filter records based on RBAC rules. +Uses real database connection for integration testing. +""" + +import pytest +from modules.connectors.connectorDbPostgre import DatabaseConnector +from modules.datamodels.datamodelUam import User, AccessLevel, UserPermissions +from modules.shared.configuration import APP_CONFIG + + +@pytest.fixture(scope="class") +def db(): + """Create real database connector for integration tests.""" + dbHost = APP_CONFIG.get("DB_HOST", "localhost") + dbDatabase = APP_CONFIG.get("DB_DATABASE", "poweron_test") + dbUser = APP_CONFIG.get("DB_USER", "postgres") + dbPassword = APP_CONFIG.get("DB_PASSWORD", "") + dbPort = APP_CONFIG.get("DB_PORT", 5432) + + db = DatabaseConnector( + dbHost=dbHost, + dbDatabase=dbDatabase, + dbUser=dbUser, + dbPassword=dbPassword, + dbPort=dbPort + ) + yield db + db.close() + + +class TestRbacDatabaseFiltering: + """Test RBAC database filtering.""" + + def testBuildRbacWhereClauseAllAccess(self, db): + """Test WHERE clause building for ALL access level.""" + + permissions = UserPermissions( + view=True, + read=AccessLevel.ALL, + create=AccessLevel.ALL, + update=AccessLevel.ALL, + delete=AccessLevel.ALL + ) + + user = User( + id="test_user_all", + username="testuser", + roleLabels=["sysadmin"], + mandateId="test_mandate_all" + ) + + whereClause = db.buildRbacWhereClause(permissions, user, "SomeTable") + + # ALL access should return None (no filtering) + assert whereClause is None + + def testBuildRbacWhereClauseMyAccess(self, db): + """Test WHERE clause building for MY access level.""" + + permissions = UserPermissions( + view=True, + read=AccessLevel.MY, + create=AccessLevel.MY, + update=AccessLevel.MY, + delete=AccessLevel.MY + ) + + user = User( + id="test_user_my", + username="testuser", + roleLabels=["user"], + mandateId="test_mandate_my" + ) + + whereClause = db.buildRbacWhereClause(permissions, user, "SomeTable") + + assert whereClause is not None + assert whereClause["condition"] == '"_createdBy" = %s' + assert whereClause["values"] == ["test_user_my"] + + def testBuildRbacWhereClauseGroupAccess(self, db): + """Test WHERE clause building for GROUP access level.""" + + permissions = UserPermissions( + view=True, + read=AccessLevel.GROUP, + create=AccessLevel.GROUP, + update=AccessLevel.GROUP, + delete=AccessLevel.GROUP + ) + + user = User( + id="test_user_group", + username="testuser", + roleLabels=["admin"], + mandateId="test_mandate_group" + ) + + whereClause = db.buildRbacWhereClause(permissions, user, "SomeTable") + + assert whereClause is not None + assert whereClause["condition"] == '"mandateId" = %s' + assert whereClause["values"] == ["test_mandate_group"] + + def testBuildRbacWhereClauseNoAccess(self, db): + """Test WHERE clause building for NONE access level.""" + + permissions = UserPermissions( + view=True, + read=AccessLevel.NONE, + create=AccessLevel.NONE, + update=AccessLevel.NONE, + delete=AccessLevel.NONE + ) + + user = User( + id="test_user_none", + username="testuser", + roleLabels=["viewer"], + mandateId="test_mandate_none" + ) + + whereClause = db.buildRbacWhereClause(permissions, user, "SomeTable") + + assert whereClause is not None + assert whereClause["condition"] == "1 = 0" # Always false + assert whereClause["values"] == [] + + def testBuildRbacWhereClauseUserInDBTable(self, db): + """Test WHERE clause building for UserInDB table with MY access.""" + + permissions = UserPermissions( + view=True, + read=AccessLevel.MY, + create=AccessLevel.MY, + update=AccessLevel.MY, + delete=AccessLevel.MY + ) + + user = User( + id="test_user_in_db", + username="testuser", + roleLabels=["user"], + mandateId="test_mandate_in_db" + ) + + whereClause = db.buildRbacWhereClause(permissions, user, "UserInDB") + + # UserInDB with MY access should filter by id field + assert whereClause is not None + assert whereClause["condition"] == '"id" = %s' + assert whereClause["values"] == ["test_user_in_db"] + + def testBuildRbacWhereClauseUserConnectionTable(self, db): + """Test WHERE clause building for UserConnection table with GROUP access.""" + # Create test users in the same mandate for GROUP access testing + from modules.datamodels.datamodelUam import UserInDB + testMandateId = "test_mandate_group" + + # Create test users + user1 = UserInDB( + id="test_user1", + username="testuser1", + mandateId=testMandateId + ) + user2 = UserInDB( + id="test_user2", + username="testuser2", + mandateId=testMandateId + ) + + try: + user1Data = user1.model_dump() + user1Data["id"] = user1.id + user2Data = user2.model_dump() + user2Data["id"] = user2.id + db.recordCreate(UserInDB, user1Data) + db.recordCreate(UserInDB, user2Data) + + permissions = UserPermissions( + view=True, + read=AccessLevel.GROUP, + create=AccessLevel.GROUP, + update=AccessLevel.GROUP, + delete=AccessLevel.GROUP + ) + + user = User( + id="test_user1", + username="testuser1", + roleLabels=["admin"], + mandateId=testMandateId + ) + + whereClause = db.buildRbacWhereClause(permissions, user, "UserConnection") + + assert whereClause is not None + assert "userId" in whereClause["condition"] + assert "IN" in whereClause["condition"] + assert len(whereClause["values"]) >= 2 + finally: + # Cleanup test users + try: + db.recordDelete(UserInDB, "test_user1") + db.recordDelete(UserInDB, "test_user2") + except: + pass diff --git a/tests/integration/rbac/test_rbac_migration.py b/tests/integration/rbac/test_rbac_migration.py new file mode 100644 index 00000000..86f3eb6d --- /dev/null +++ b/tests/integration/rbac/test_rbac_migration.py @@ -0,0 +1,282 @@ +""" +Integration tests for UAM to RBAC migration. +Tests that migration correctly converts user privileges to roleLabels. +Uses real database connection for integration testing. +""" + +import pytest +from modules.migration.migrateUamToRbac import migrateUamToRbac, validateMigration +from modules.datamodels.datamodelUam import UserInDB, UserPrivilege +from modules.connectors.connectorDbPostgre import DatabaseConnector +from modules.shared.configuration import APP_CONFIG + + +@pytest.fixture(scope="class") +def db(): + """Create real database connector for integration tests.""" + dbHost = APP_CONFIG.get("DB_HOST", "localhost") + dbDatabase = APP_CONFIG.get("DB_DATABASE", "poweron_test") + dbUser = APP_CONFIG.get("DB_USER", "postgres") + dbPassword = APP_CONFIG.get("DB_PASSWORD", "") + dbPort = APP_CONFIG.get("DB_PORT", 5432) + + db = DatabaseConnector( + dbHost=dbHost, + dbDatabase=dbDatabase, + dbUser=dbUser, + dbPassword=dbPassword, + dbPort=dbPort + ) + yield db + db.close() + + +class TestRbacMigration: + """Test RBAC migration from UAM.""" + + def testMigrateUserPrivilegeToRoleLabels(self, db): + """Test that user privileges are correctly converted to roleLabels.""" + # Create test users with privileges but no roleLabels + testUsers = [ + UserInDB( + id="migrate_test_user1", + username="migrate_admin", + privilege=UserPrivilege.SYSADMIN.value + ), + UserInDB( + id="migrate_test_user2", + username="migrate_admin2", + privilege=UserPrivilege.ADMIN.value + ), + UserInDB( + id="migrate_test_user3", + username="migrate_user1", + privilege=UserPrivilege.USER.value + ) + ] + + try: + # Create test users in database + for user in testUsers: + userData = user.model_dump() + # Ensure roleLabels is None/empty for migration test + userData["roleLabels"] = [] + userData["id"] = user.id + db.recordCreate(UserInDB, userData) + + # Run migration + results = migrateUamToRbac(db, dryRun=False) + + # Check that users were updated + assert results["usersUpdated"] == 3 + + # Verify users were actually updated in database + users1 = db.getRecordset(UserInDB, recordFilter={"id": "migrate_test_user1"}) + users2 = db.getRecordset(UserInDB, recordFilter={"id": "migrate_test_user2"}) + users3 = db.getRecordset(UserInDB, recordFilter={"id": "migrate_test_user3"}) + user1 = users1[0] if users1 else None + user2 = users2[0] if users2 else None + user3 = users3[0] if users3 else None + + assert user1 is not None + assert "sysadmin" in user1.get("roleLabels", []) + + assert user2 is not None + assert "admin" in user2.get("roleLabels", []) + + assert user3 is not None + assert "user" in user3.get("roleLabels", []) + finally: + # Cleanup test users + for user in testUsers: + try: + db.recordDelete(UserInDB, user.id) + except: + pass + + def testMigrationSkipsUsersWithExistingRoleLabels(self, db): + """Test that migration skips users who already have roleLabels.""" + # Create test users: one with roleLabels, one without + user1 = UserInDB( + id="skip_test_user1", + username="skip_admin", + privilege=UserPrivilege.SYSADMIN.value, + roleLabels=["sysadmin"] # Already migrated + ) + user2 = UserInDB( + id="skip_test_user2", + username="skip_user1", + privilege=UserPrivilege.USER.value, + roleLabels=[] # Needs migration + ) + + try: + # Create test users in database + user1Data = user1.model_dump() + user1Data["id"] = user1.id + user2Data = user2.model_dump() + user2Data["id"] = user2.id + db.recordCreate(UserInDB, user1Data) + db.recordCreate(UserInDB, user2Data) + + # Run migration + results = migrateUamToRbac(db, dryRun=False) + + # Only one user should be updated (user2) + assert results["usersUpdated"] == 1 + + # Verify user1 still has original roleLabels + users1 = db.getRecordset(UserInDB, recordFilter={"id": "skip_test_user1"}) + updatedUser1 = users1[0] if users1 else None + assert updatedUser1 is not None + assert "sysadmin" in updatedUser1.get("roleLabels", []) + + # Verify user2 was updated + users2 = db.getRecordset(UserInDB, recordFilter={"id": "skip_test_user2"}) + updatedUser2 = users2[0] if users2 else None + assert updatedUser2 is not None + assert "user" in updatedUser2.get("roleLabels", []) + finally: + # Cleanup test users + try: + db.recordDelete(UserInDB, "skip_test_user1") + db.recordDelete(UserInDB, "skip_test_user2") + except: + pass + + def testDryRunMode(self, db): + """Test that dry run mode doesn't make changes.""" + # Create test user without roleLabels + testUser = UserInDB( + id="dryrun_test_user1", + username="dryrun_admin", + privilege=UserPrivilege.SYSADMIN.value, + roleLabels=[] # Needs migration + ) + + try: + # Create test user in database + userData = testUser.model_dump() + userData["id"] = testUser.id + db.recordCreate(UserInDB, userData) + + # Get original state + originalUsers = db.getRecordset(UserInDB, recordFilter={"id": "dryrun_test_user1"}) + originalUser = originalUsers[0] if originalUsers else None + assert originalUser is not None + originalRoleLabels = originalUser.get("roleLabels", []) + + # Run migration in dry run mode + results = migrateUamToRbac(db, dryRun=True) + + # Should report what would be done + assert results["usersUpdated"] == 1 + + # Verify user was NOT actually updated + unchangedUsers = db.getRecordset(UserInDB, recordFilter={"id": "dryrun_test_user1"}) + unchangedUser = unchangedUsers[0] if unchangedUsers else None + assert unchangedUser is not None + assert unchangedUser.get("roleLabels", []) == originalRoleLabels + finally: + # Cleanup test user + try: + db.recordDelete(UserInDB, "dryrun_test_user1") + except: + pass + + def testValidateMigrationSuccess(self, db): + """Test validation passes when migration is successful.""" + # Create test users with roleLabels (already migrated) + testUsers = [ + UserInDB( + id="validate_test_user1", + username="validate_admin", + privilege=UserPrivilege.SYSADMIN.value, + roleLabels=["sysadmin"] + ), + UserInDB( + id="validate_test_user2", + username="validate_admin2", + privilege=UserPrivilege.ADMIN.value, + roleLabels=["admin"] + ) + ] + + try: + # Create test users in database + for user in testUsers: + userData = user.model_dump() + userData["id"] = user.id + db.recordCreate(UserInDB, userData) + + # Ensure AccessRule table exists (migration should have created it) + from modules.datamodels.datamodelRbac import AccessRule + db._ensureTableExists(AccessRule) + + # Run validation + validation = validateMigration(db) + + assert validation["valid"] == True + assert len(validation["issues"]) == 0 + finally: + # Cleanup test users + for user in testUsers: + try: + db.recordDelete(UserInDB, user.id) + except: + pass + + def testValidateMigrationFailsWithoutRoleLabels(self, db): + """Test validation fails when users don't have roleLabels.""" + # Create test users: one with roleLabels, one without, one with empty roleLabels + testUsers = [ + UserInDB( + id="validate_fail_user1", + username="validate_fail_admin", + privilege=UserPrivilege.SYSADMIN.value, + roleLabels=["sysadmin"] # Has roleLabels + ), + UserInDB( + id="validate_fail_user2", + username="validate_fail_user", + privilege=UserPrivilege.USER.value, + roleLabels=[] # Empty roleLabels + ), + UserInDB( + id="validate_fail_user3", + username="validate_fail_user2", + privilege=UserPrivilege.USER.value + # Missing roleLabels field (will be None) + ) + ] + + try: + # Create test users in database + for user in testUsers: + userData = user.model_dump() + userData["id"] = user.id + # For user3, explicitly set roleLabels to None or remove it + if user.id == "validate_fail_user3": + if "roleLabels" in userData: + del userData["roleLabels"] + db.recordCreate(UserInDB, userData) + + # Ensure AccessRule table exists + from modules.datamodels.datamodelRbac import AccessRule + db._ensureTableExists(AccessRule) + + # Run validation + validation = validateMigration(db) + + assert validation["valid"] == False + assert len(validation["issues"]) > 0 + # Check that validation found users without roleLabels + issuesStr = " ".join(validation["issues"]) + assert "users without roleLabels" in issuesStr or "without roleLabels" in issuesStr + finally: + # Cleanup test users + for user in testUsers: + try: + db.recordDelete(UserInDB, user.id) + except: + pass diff --git a/tests/unit/rbac/README.md b/tests/unit/rbac/README.md new file mode 100644 index 00000000..3666ef2a --- /dev/null +++ b/tests/unit/rbac/README.md @@ -0,0 +1,47 @@ +# RBAC Unit Tests + +Unit tests for the Role-Based Access Control (RBAC) system. + +## Test Files + +### `test_rbac_permissions.py` +Tests RBAC permission resolution logic: +- Single role with generic rules +- Rule specificity (most specific wins) +- Multiple roles with union logic +- View permission overrides +- No roles scenario +- Finding most specific rules +- Opening rights validation +- UI and RESOURCE context handling + +### `test_rbac_bootstrap.py` +Tests RBAC bootstrap initialization: +- Root mandate creation +- Admin user creation with sysadmin role +- Event user creation with sysadmin role +- Default role rules creation +- Table-specific rules creation +- Rule initialization skipping when rules exist + +## Running Tests + +```bash +# Run all RBAC unit tests +pytest tests/unit/rbac/ + +# Run specific test file +pytest tests/unit/rbac/test_rbac_permissions.py + +# Run with verbose output +pytest tests/unit/rbac/ -v +``` + +## Test Coverage + +- Permission resolution algorithms +- Rule specificity logic +- Multiple role combination (union logic) +- Access rule validation +- Bootstrap initialization +- Default rule creation diff --git a/tests/unit/rbac/__init__.py b/tests/unit/rbac/__init__.py new file mode 100644 index 00000000..5d55b3ca --- /dev/null +++ b/tests/unit/rbac/__init__.py @@ -0,0 +1 @@ +"""Unit tests for RBAC system.""" diff --git a/tests/unit/rbac/test_rbac_bootstrap.py b/tests/unit/rbac/test_rbac_bootstrap.py new file mode 100644 index 00000000..573a4fd1 --- /dev/null +++ b/tests/unit/rbac/test_rbac_bootstrap.py @@ -0,0 +1,162 @@ +""" +Unit tests for RBAC bootstrap initialization. +Tests that bootstrap creates correct rules and initial data. +""" + +import pytest +from unittest.mock import Mock, MagicMock, patch +from modules.interfaces.interfaceBootstrap import ( + initBootstrap, + initRootMandate, + initAdminUser, + initEventUser, + initRbacRules, + createDefaultRoleRules, + createTableSpecificRules +) +from modules.datamodels.datamodelUam import UserInDB, Mandate, UserPrivilege, AuthAuthority +from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext +from modules.datamodels.datamodelUam import AccessLevel + + +class TestRbacBootstrap: + """Test RBAC bootstrap initialization.""" + + def testInitRootMandateCreatesIfNotExists(self): + """Test that initRootMandate creates mandate if it doesn't exist.""" + db = Mock() + db.getRecordset = Mock(return_value=[]) # No existing mandates + db.recordCreate = Mock(return_value={"id": "mandate1", "name": "Root"}) + + mandateId = initRootMandate(db) + + assert mandateId == "mandate1" + db.recordCreate.assert_called_once() + callArgs = db.recordCreate.call_args + assert isinstance(callArgs[0][1], Mandate) + assert callArgs[0][1].name == "Root" + + def testInitRootMandateReturnsExisting(self): + """Test that initRootMandate returns existing mandate ID.""" + db = Mock() + db.getRecordset = Mock(return_value=[{"id": "existing_mandate"}]) + + mandateId = initRootMandate(db) + + assert mandateId == "existing_mandate" + db.recordCreate.assert_not_called() + + def testInitAdminUserCreatesWithSysadminRole(self): + """Test that initAdminUser creates user with sysadmin role.""" + db = Mock() + db.getRecordset = Mock(return_value=[]) # No existing users + db.recordCreate = Mock(return_value={"id": "admin1", "username": "admin"}) + + with patch('modules.interfaces.interfaceBootstrap._getPasswordHash', return_value="hashed"): + userId = initAdminUser(db, "mandate1") + + assert userId == "admin1" + db.recordCreate.assert_called_once() + callArgs = db.recordCreate.call_args + user = callArgs[0][1] + assert isinstance(user, UserInDB) + assert user.username == "admin" + assert "sysadmin" in user.roleLabels + assert user.privilege == UserPrivilege.SYSADMIN + + def testInitEventUserCreatesWithSysadminRole(self): + """Test that initEventUser creates user with sysadmin role.""" + db = Mock() + db.getRecordset = Mock(return_value=[]) # No existing users + db.recordCreate = Mock(return_value={"id": "event1", "username": "event"}) + + with patch('modules.interfaces.interfaceBootstrap._getPasswordHash', return_value="hashed"): + userId = initEventUser(db, "mandate1") + + assert userId == "event1" + db.recordCreate.assert_called_once() + callArgs = db.recordCreate.call_args + user = callArgs[0][1] + assert isinstance(user, UserInDB) + assert user.username == "event" + assert "sysadmin" in user.roleLabels + + def testCreateDefaultRoleRules(self): + """Test that createDefaultRoleRules creates correct default rules.""" + db = Mock() + db.recordCreate = Mock() + + createDefaultRoleRules(db) + + # Should create 4 default rules (sysadmin, admin, user, viewer) + assert db.recordCreate.call_count == 4 + + # Check sysadmin rule + sysadminCall = [call for call in db.recordCreate.call_args_list + if call[0][1].roleLabel == "sysadmin"][0] + sysadminRule = sysadminCall[0][1] + assert sysadminRule.context == AccessRuleContext.DATA + assert sysadminRule.item is None + assert sysadminRule.view == True + assert sysadminRule.read == AccessLevel.ALL + assert sysadminRule.create == AccessLevel.ALL + + # Check user rule + userCall = [call for call in db.recordCreate.call_args_list + if call[0][1].roleLabel == "user"][0] + userRule = userCall[0][1] + assert userRule.read == AccessLevel.MY + assert userRule.create == AccessLevel.MY + + def testCreateTableSpecificRules(self): + """Test that createTableSpecificRules creates table-specific rules.""" + db = Mock() + db.recordCreate = Mock() + + createTableSpecificRules(db) + + # Should create multiple rules for different tables + assert db.recordCreate.call_count > 0 + + # Check that Mandate table rules are created + mandateCalls = [call for call in db.recordCreate.call_args_list + if call[0][1].item == "Mandate"] + assert len(mandateCalls) > 0 + + # Check sysadmin rule for Mandate + sysadminMandateCall = [call for call in mandateCalls + if call[0][1].roleLabel == "sysadmin"][0] + sysadminRule = sysadminMandateCall[0][1] + assert sysadminRule.view == True + assert sysadminRule.read == AccessLevel.ALL + + # Check that other roles have view=False for Mandate + otherMandateCalls = [call for call in mandateCalls + if call[0][1].roleLabel != "sysadmin"] + for call in otherMandateCalls: + rule = call[0][1] + assert rule.view == False + + def testInitRbacRulesSkipsIfExists(self): + """Test that initRbacRules skips creation if rules already exist.""" + db = Mock() + db.getRecordset = Mock(return_value=[{"id": "rule1"}]) # Rules exist + + initRbacRules(db) + + # Should not create new rules + db.recordCreate.assert_not_called() + + def testInitRbacRulesCreatesIfNotExists(self): + """Test that initRbacRules creates rules if they don't exist.""" + db = Mock() + db.getRecordset = Mock(side_effect=[ + [], # No existing rules + [] # After creating default rules + ]) + db.recordCreate = Mock() + + initRbacRules(db) + + # Should create rules + assert db.recordCreate.call_count > 0 diff --git a/tests/unit/rbac/test_rbac_permissions.py b/tests/unit/rbac/test_rbac_permissions.py new file mode 100644 index 00000000..d180f5b8 --- /dev/null +++ b/tests/unit/rbac/test_rbac_permissions.py @@ -0,0 +1,403 @@ +""" +Unit tests for RBAC permission resolution. +Tests rule specificity, multiple roles, and permission combination logic. +""" + +import pytest +from modules.datamodels.datamodelUam import User, AccessLevel, UserPermissions +from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext +from modules.security.rbac import RbacClass +from modules.connectors.connectorDbPostgre import DatabaseConnector +from unittest.mock import Mock, MagicMock + + +class TestRbacPermissionResolution: + """Test RBAC permission resolution logic.""" + + def testSingleRoleGenericRule(self): + """Test permission resolution with a single role and generic rule.""" + # Mock database connector + db = Mock(spec=DatabaseConnector) + + # Create RBAC interface + rbac = RbacClass(db) + + # Create user with single role + user = User( + id="user1", + username="testuser", + roleLabels=["user"], + mandateId="mandate1" + ) + + # Mock rules for "user" role + def mockGetRulesForRole(roleLabel, context): + if roleLabel == "user" and context == AccessRuleContext.DATA: + return [ + AccessRule( + roleLabel="user", + context=AccessRuleContext.DATA, + item=None, # Generic rule + view=True, + read=AccessLevel.MY, + create=AccessLevel.MY, + update=AccessLevel.MY, + delete=AccessLevel.MY + ) + ] + return [] + + rbac._getRulesForRole = mockGetRulesForRole + + # Get permissions for generic table + permissions = rbac.getUserPermissions( + user, + AccessRuleContext.DATA, + "SomeTable" + ) + + assert permissions.view == True + assert permissions.read == AccessLevel.MY + assert permissions.create == AccessLevel.MY + assert permissions.update == AccessLevel.MY + assert permissions.delete == AccessLevel.MY + + def testRuleSpecificityMostSpecificWins(self): + """Test that most specific rule wins within a single role.""" + db = Mock(spec=DatabaseConnector) + rbac = RbacClass(db) + + user = User( + id="user1", + username="testuser", + roleLabels=["user"], + mandateId="mandate1" + ) + + def mockGetRulesForRole(roleLabel, context): + if roleLabel == "user" and context == AccessRuleContext.DATA: + return [ + AccessRule( + roleLabel="user", + context=AccessRuleContext.DATA, + item=None, # Generic rule + view=True, + read=AccessLevel.GROUP, + create=AccessLevel.GROUP, + update=AccessLevel.GROUP, + delete=AccessLevel.GROUP + ), + AccessRule( + roleLabel="user", + context=AccessRuleContext.DATA, + item="UserInDB", # Specific rule + view=True, + read=AccessLevel.MY, + create=AccessLevel.NONE, + update=AccessLevel.MY, + delete=AccessLevel.NONE + ) + ] + return [] + + rbac._getRulesForRole = mockGetRulesForRole + + # Get permissions for UserInDB table - should use specific rule + permissions = rbac.getUserPermissions( + user, + AccessRuleContext.DATA, + "UserInDB" + ) + + # Most specific rule should win + assert permissions.read == AccessLevel.MY + assert permissions.create == AccessLevel.NONE + assert permissions.update == AccessLevel.MY + assert permissions.delete == AccessLevel.NONE + + def testMultipleRolesUnionLogic(self): + """Test that multiple roles use union (opening) logic.""" + db = Mock(spec=DatabaseConnector) + rbac = RbacClass(db) + + # User with multiple roles + user = User( + id="user1", + username="testuser", + roleLabels=["user", "viewer"], + mandateId="mandate1" + ) + + def mockGetRulesForRole(roleLabel, context): + if context == AccessRuleContext.UI: + if roleLabel == "user": + return [ + AccessRule( + roleLabel="user", + context=AccessRuleContext.UI, + item="playground", + view=False # User role hides playground + ) + ] + elif roleLabel == "viewer": + return [ + AccessRule( + roleLabel="viewer", + context=AccessRuleContext.UI, + item="playground", + view=True # Viewer role shows playground + ) + ] + return [] + + rbac._getRulesForRole = mockGetRulesForRole + + # Get permissions - union logic should make playground visible + permissions = rbac.getUserPermissions( + user, + AccessRuleContext.UI, + "playground" + ) + + # Union logic: if ANY role has view=true, then view=true + assert permissions.view == True + + def testViewFalseOverridesGeneric(self): + """Test that specific view=false overrides generic view=true.""" + db = Mock(spec=DatabaseConnector) + rbac = RbacClass(db) + + user = User( + id="user1", + username="testuser", + roleLabels=["user"], + mandateId="mandate1" + ) + + def mockGetRulesForRole(roleLabel, context): + if roleLabel == "user" and context == AccessRuleContext.UI: + return [ + AccessRule( + roleLabel="user", + context=AccessRuleContext.UI, + item=None, # Generic: view all UI + view=True + ), + AccessRule( + roleLabel="user", + context=AccessRuleContext.UI, + item="playground.voice.settings", # Specific: hide this + view=False + ) + ] + return [] + + rbac._getRulesForRole = mockGetRulesForRole + + # Get permissions for specific UI element + permissions = rbac.getUserPermissions( + user, + AccessRuleContext.UI, + "playground.voice.settings" + ) + + # Specific rule should override generic + assert permissions.view == False + + def testNoRolesReturnsNoAccess(self): + """Test that user with no roles gets no access.""" + db = Mock(spec=DatabaseConnector) + rbac = RbacClass(db) + + user = User( + id="user1", + username="testuser", + roleLabels=[], # No roles + mandateId="mandate1" + ) + + permissions = rbac.getUserPermissions( + user, + AccessRuleContext.DATA, + "SomeTable" + ) + + assert permissions.view == False + assert permissions.read == AccessLevel.NONE + assert permissions.create == AccessLevel.NONE + assert permissions.update == AccessLevel.NONE + assert permissions.delete == AccessLevel.NONE + + def testFindMostSpecificRule(self): + """Test findMostSpecificRule method.""" + db = Mock(spec=DatabaseConnector) + rbac = RbacClass(db) + + rules = [ + AccessRule( + roleLabel="user", + context=AccessRuleContext.DATA, + item=None, # Generic + view=True, + read=AccessLevel.GROUP + ), + AccessRule( + roleLabel="user", + context=AccessRuleContext.DATA, + item="UserInDB", # Table-level + view=True, + read=AccessLevel.MY + ), + AccessRule( + roleLabel="user", + context=AccessRuleContext.DATA, + item="UserInDB.email", # Field-level - most specific + view=True, + read=AccessLevel.NONE + ) + ] + + # Test exact match + rule = rbac.findMostSpecificRule(rules, "UserInDB.email") + assert rule is not None + assert rule.item == "UserInDB.email" + assert rule.read == AccessLevel.NONE + + # Test table-level match + rule = rbac.findMostSpecificRule(rules, "UserInDB") + assert rule is not None + assert rule.item == "UserInDB" + assert rule.read == AccessLevel.MY + + # Test generic fallback + rule = rbac.findMostSpecificRule(rules, "OtherTable") + assert rule is not None + assert rule.item is None + assert rule.read == AccessLevel.GROUP + + def testValidateAccessRuleOpeningRights(self): + """Test that CUD permissions respect read permission level.""" + db = Mock(spec=DatabaseConnector) + rbac = RbacClass(db) + + # Valid: Read=MY, Create=MY (allowed) + rule1 = AccessRule( + roleLabel="user", + context=AccessRuleContext.DATA, + item="UserInDB", + view=True, + read=AccessLevel.MY, + create=AccessLevel.MY, + update=AccessLevel.MY, + delete=AccessLevel.MY + ) + assert rbac.validateAccessRule(rule1) == True + + # Invalid: Read=MY, Create=GROUP (not allowed - GROUP > MY) + rule2 = AccessRule( + roleLabel="user", + context=AccessRuleContext.DATA, + item="UserInDB", + view=True, + read=AccessLevel.MY, + create=AccessLevel.GROUP, # Not allowed + update=AccessLevel.MY, + delete=AccessLevel.MY + ) + assert rbac.validateAccessRule(rule2) == False + + # Valid: Read=GROUP, Create=GROUP (allowed) + rule3 = AccessRule( + roleLabel="admin", + context=AccessRuleContext.DATA, + item="UserInDB", + view=True, + read=AccessLevel.GROUP, + create=AccessLevel.GROUP, + update=AccessLevel.GROUP, + delete=AccessLevel.GROUP + ) + assert rbac.validateAccessRule(rule3) == True + + # Invalid: Read=NONE, Create=MY (not allowed - no read access) + rule4 = AccessRule( + roleLabel="user", + context=AccessRuleContext.DATA, + item="UserInDB", + view=True, + read=AccessLevel.NONE, + create=AccessLevel.MY, # Not allowed without read + update=AccessLevel.MY, + delete=AccessLevel.MY + ) + assert rbac.validateAccessRule(rule4) == False + + def testUiContextOnlyViewMatters(self): + """Test that UI context only checks view permission.""" + db = Mock(spec=DatabaseConnector) + rbac = RbacClass(db) + + user = User( + id="user1", + username="testuser", + roleLabels=["user"], + mandateId="mandate1" + ) + + def mockGetRulesForRole(roleLabel, context): + if roleLabel == "user" and context == AccessRuleContext.UI: + return [ + AccessRule( + roleLabel="user", + context=AccessRuleContext.UI, + item="playground", + view=True + # No read/create/update/delete for UI context + ) + ] + return [] + + rbac._getRulesForRole = mockGetRulesForRole + + permissions = rbac.getUserPermissions( + user, + AccessRuleContext.UI, + "playground" + ) + + assert permissions.view == True + # Other permissions don't matter for UI context + + def testResourceContextOnlyViewMatters(self): + """Test that RESOURCE context only checks view permission.""" + db = Mock(spec=DatabaseConnector) + rbac = RbacClass(db) + + user = User( + id="user1", + username="testuser", + roleLabels=["user"], + mandateId="mandate1" + ) + + def mockGetRulesForRole(roleLabel, context): + if roleLabel == "user" and context == AccessRuleContext.RESOURCE: + return [ + AccessRule( + roleLabel="user", + context=AccessRuleContext.RESOURCE, + item="ai.model.anthropic", + view=True + ) + ] + return [] + + rbac._getRulesForRole = mockGetRulesForRole + + permissions = rbac.getUserPermissions( + user, + AccessRuleContext.RESOURCE, + "ai.model.anthropic" + ) + + assert permissions.view == True