From 13b7c4fdbe6fe11c40fda7c627bbdbaf81bd1b25 Mon Sep 17 00:00:00 2001
From: ValueOn AG
Date: Sun, 7 Dec 2025 08:48:49 +0100
Subject: [PATCH 1/6] hot fixes: sharepoint folders and stats
---
modules/routes/routeDataAutomation.py | 3 +-
.../services/serviceChat/mainServiceChat.py | 3 +-
.../mainServiceSharepoint.py | 410 +++-
modules/workflows/methods/methodAi.py | 8 +-
modules/workflows/methods/methodContext.py | 4 +-
modules/workflows/methods/methodOutlook.py | 8 +-
modules/workflows/methods/methodSharepoint.py | 1732 ++++++++---------
.../processing/core/actionExecutor.py | 29 +
8 files changed, 1229 insertions(+), 968 deletions(-)
diff --git a/modules/routes/routeDataAutomation.py b/modules/routes/routeDataAutomation.py
index 903d0d53..ee13915c 100644
--- a/modules/routes/routeDataAutomation.py
+++ b/modules/routes/routeDataAutomation.py
@@ -15,6 +15,7 @@ from modules.security.auth import getCurrentUser, limiter
from modules.datamodels.datamodelChat import AutomationDefinition, ChatWorkflow
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata
from modules.shared.attributeUtils import getModelAttributeDefinitions
+from modules.features.automation import executeAutomation
# Configure logger
logger = logging.getLogger(__name__)
@@ -217,7 +218,7 @@ async def execute_automation(
"""Execute an automation immediately (test mode)"""
try:
chatInterface = getChatInterface(currentUser)
- workflow = await chatInterface.executeAutomation(automationId)
+ workflow = await executeAutomation(automationId, chatInterface)
return workflow
except HTTPException:
raise
diff --git a/modules/services/serviceChat/mainServiceChat.py b/modules/services/serviceChat/mainServiceChat.py
index cb05279f..7848cb29 100644
--- a/modules/services/serviceChat/mainServiceChat.py
+++ b/modules/services/serviceChat/mainServiceChat.py
@@ -1013,7 +1013,8 @@ class ChatService:
return self._progressLogger
def createProgressLogger(self) -> ProgressLogger:
- return ProgressLogger(self.services)
+ """Get or create the progress logger instance (singleton)"""
+ return self._getProgressLogger()
def progressLogStart(self, operationId: str, serviceName: str, actionName: str, context: str = "", parentOperationId: Optional[str] = None):
"""Wrapper for ProgressLogger.startOperation
diff --git a/modules/services/serviceSharepoint/mainServiceSharepoint.py b/modules/services/serviceSharepoint/mainServiceSharepoint.py
index e7f24648..6c6c266e 100644
--- a/modules/services/serviceSharepoint/mainServiceSharepoint.py
+++ b/modules/services/serviceSharepoint/mainServiceSharepoint.py
@@ -287,7 +287,12 @@ class SharepointService:
try:
# Clean the path
cleanPath = folderPath.lstrip('/')
- endpoint = f"sites/{siteId}/drive/root:/{cleanPath}"
+
+ # If path is empty, get root directly
+ if not cleanPath:
+ endpoint = f"sites/{siteId}/drive/root"
+ else:
+ endpoint = f"sites/{siteId}/drive/root:/{cleanPath}"
result = await self._makeGraphApiCall(endpoint)
@@ -499,4 +504,407 @@ class SharepointService:
except Exception as e:
logger.error(f"Error downloading file by path: {str(e)}")
return None
+
+ async def _getItemById(self, siteId: str, driveId: str, itemId: str) -> Optional[Dict[str, Any]]:
+ """Verify that an item exists by getting it by ID.
+
+ Args:
+ siteId: SharePoint site ID
+ driveId: Drive ID (document library)
+ itemId: Item ID to verify
+
+ Returns:
+ Item dictionary if found, None otherwise
+ """
+ try:
+ endpoint = f"sites/{siteId}/drives/{driveId}/items/{itemId}"
+ result = await self._makeGraphApiCall(endpoint)
+
+ if "error" in result:
+ logger.warning(f"Item {itemId} not found: {result['error']}")
+ return None
+
+ return result
+
+ except Exception as e:
+ logger.warning(f"Error verifying item {itemId}: {str(e)}")
+ return None
+
+ async def _findDriveForItem(self, siteId: str, itemId: str) -> Optional[str]:
+ """Find which drive contains a specific item by trying to get it from all drives.
+
+ Args:
+ siteId: SharePoint site ID
+ itemId: Item ID to find
+
+ Returns:
+ Drive ID if found, None otherwise
+ """
+ try:
+ # Get all drives for the site
+ endpoint = f"sites/{siteId}/drives"
+ drivesResult = await self._makeGraphApiCall(endpoint)
+
+ if "error" in drivesResult:
+ logger.warning(f"Could not get drives for site {siteId}: {drivesResult['error']}")
+ return None
+
+ drives = drivesResult.get("value", [])
+ if not drives:
+ logger.warning(f"No drives found for site {siteId}")
+ return None
+
+ # Try to find the item in each drive
+ for drive in drives:
+ driveId = drive.get("id")
+ if not driveId:
+ continue
+
+ itemInfo = await self._getItemById(siteId, driveId, itemId)
+ if itemInfo:
+ logger.info(f"Found item {itemId} in drive {drive.get('name', driveId)}")
+ return driveId
+
+ logger.warning(f"Item {itemId} not found in any drive for site {siteId}")
+ return None
+
+ except Exception as e:
+ logger.warning(f"Error finding drive for item {itemId}: {str(e)}")
+ return None
+
+ async def getFolderUsageAnalytics(self, siteId: str, driveId: str, itemId: str, startDateTime: Optional[str] = None, endDateTime: Optional[str] = None, interval: str = "day") -> Dict[str, Any]:
+ """Get usage analytics for a folder or file.
+
+ Args:
+ siteId: SharePoint site ID
+ driveId: Drive ID (document library)
+ itemId: Folder or file item ID
+ startDateTime: Start date/time in ISO format (e.g., "2025-11-01T00:00:00Z"). If None, uses 30 days ago.
+ endDateTime: End date/time in ISO format (e.g., "2025-11-30T23:59:59Z"). If None, uses current time.
+ interval: Time interval for grouping activities. Options: "day", "week", "month". Default: "day"
+
+ Returns:
+ Dictionary containing analytics data with activities grouped by interval.
+ If analytics are not available (404), returns empty analytics structure instead of error.
+ """
+ try:
+ from datetime import datetime, timedelta, timezone
+
+ # Set default time range if not provided (last 30 days)
+ if not endDateTime:
+ endDateTime = datetime.now(timezone.utc).isoformat().replace('+00:00', 'Z')
+ if not startDateTime:
+ startDate = datetime.now(timezone.utc) - timedelta(days=30)
+ startDateTime = startDate.isoformat().replace('+00:00', 'Z')
+
+ # Build endpoint with query parameters
+ endpoint = f"sites/{siteId}/drives/{driveId}/items/{itemId}/getActivitiesByInterval"
+ endpoint += f"?startDateTime={startDateTime}&endDateTime={endDateTime}&interval={interval}"
+
+ result = await self._makeGraphApiCall(endpoint)
+
+ if "error" in result:
+ errorMsg = result.get('error', '')
+ # Check if it's a 404 error
+ if isinstance(errorMsg, str) and '404' in errorMsg:
+ # Verify if the item exists - first try with current driveId
+ itemInfo = await self._getItemById(siteId, driveId, itemId)
+
+ # If not found, try to find the correct drive for this item
+ if not itemInfo:
+ logger.info(f"Item {itemId} not found in drive {driveId}, searching for correct drive")
+ correctDriveId = await self._findDriveForItem(siteId, itemId)
+ if correctDriveId and correctDriveId != driveId:
+ logger.info(f"Found item in different drive {correctDriveId}, retrying analytics call")
+ # Retry with correct drive
+ endpoint = f"sites/{siteId}/drives/{correctDriveId}/items/{itemId}/getActivitiesByInterval"
+ endpoint += f"?startDateTime={startDateTime}&endDateTime={endDateTime}&interval={interval}"
+ result = await self._makeGraphApiCall(endpoint)
+
+ if "error" not in result:
+ logger.info(f"Successfully retrieved analytics using correct drive {correctDriveId}")
+ return result
+ # If still error, continue with original error handling
+ itemInfo = await self._getItemById(siteId, correctDriveId, itemId)
+
+ if itemInfo:
+ # Item exists but analytics are not available - return empty analytics
+ logger.warning(f"Usage analytics not available for item {itemId} (item exists but has no activity data or analytics not supported)")
+ return {
+ "value": [],
+ "note": "No analytics data available for this item. The item exists but may not have activity data or analytics may not be supported for this item type."
+ }
+ else:
+ # Item doesn't exist
+ logger.error(f"Item {itemId} not found when trying to get usage analytics")
+ return result
+ else:
+ # Other error
+ logger.error(f"Error getting usage analytics: {result['error']}")
+ return result
+
+ logger.info(f"Retrieved usage analytics for item {itemId} with interval {interval}")
+ return result
+
+ except Exception as e:
+ logger.error(f"Error getting folder usage analytics: {str(e)}")
+ return {"error": f"Error getting folder usage analytics: {str(e)}"}
+
+ async def getDriveId(self, siteId: str, driveName: Optional[str] = None) -> Optional[str]:
+ """Get drive ID for a site. If driveName is provided, finds the specific drive, otherwise returns the default drive.
+
+ Args:
+ siteId: SharePoint site ID
+ driveName: Optional drive name (document library name). If None, returns default drive.
+
+ Returns:
+ Drive ID string or None if not found
+ """
+ try:
+ endpoint = f"sites/{siteId}/drives"
+ result = await self._makeGraphApiCall(endpoint)
+
+ if "error" in result:
+ logger.error(f"Error getting drives: {result['error']}")
+ return None
+
+ drives = result.get("value", [])
+
+ if not driveName:
+ # Return default drive (usually the first one or the one named "Documents")
+ for drive in drives:
+ if drive.get("name") == "Documents" or drive.get("name") == "Shared Documents":
+ logger.info(f"Found default drive: {drive.get('name')} (ID: {drive.get('id')})")
+ return drive.get("id")
+ # If no Documents drive found, return first drive
+ if drives:
+ logger.info(f"Using first drive: {drives[0].get('name')} (ID: {drives[0].get('id')})")
+ return drives[0].get("id")
+ return None
+
+ # Find specific drive by name
+ for drive in drives:
+ if drive.get("name", "").lower() == driveName.lower():
+ logger.info(f"Found drive '{driveName}': {drive.get('id')}")
+ return drive.get("id")
+
+ logger.warning(f"Drive '{driveName}' not found")
+ return None
+
+ except Exception as e:
+ logger.error(f"Error getting drive ID: {str(e)}")
+ return None
+
+ def extractSiteFromStandardPath(self, pathQuery: str) -> Optional[Dict[str, str]]:
+ """
+ Extract site name from Microsoft-standard server-relative path:
+ /sites/company-share/Freigegebene Dokumente/...
+
+ Returns dict with keys: siteName, innerPath (no leading slash) on success, else None.
+ """
+ try:
+ if not pathQuery or not pathQuery.startswith('/sites/'):
+ return None
+
+ # Remove leading /sites/ prefix
+ remainder = pathQuery[7:] # len('/sites/') = 7
+
+ # Split on first '/' to get site name
+ if '/' not in remainder:
+ # Only site name, no inner path
+ return {"siteName": remainder, "innerPath": ""}
+
+ siteName, inner = remainder.split('/', 1)
+ siteName = siteName.strip()
+ innerPath = inner.strip()
+
+ if not siteName:
+ return None
+
+ return {"siteName": siteName, "innerPath": innerPath}
+ except Exception as e:
+ logger.error(f"Error extracting site from standard path '{pathQuery}': {str(e)}")
+ return None
+
+ async def getSiteByStandardPath(self, sitePath: str, allSites: Optional[List[Dict[str, Any]]] = None) -> Optional[Dict[str, Any]]:
+ """
+ Get SharePoint site directly by Microsoft-standard path (/sites/SiteName)
+ without loading all sites. Uses hostname from first available site.
+
+ Parameters:
+ sitePath (str): Site path like 'company-share' (without /sites/ prefix)
+ allSites (Optional[List[Dict]]): Pre-discovered sites list (optional, for optimization)
+
+ Returns:
+ Optional[Dict[str, Any]]: Site information if found, None otherwise
+ """
+ try:
+ # Get hostname from first available site (minimal load - only 1 site)
+ if allSites and len(allSites) > 0:
+ from urllib.parse import urlparse
+ webUrl = allSites[0].get("webUrl", "")
+ hostname = urlparse(webUrl).hostname if webUrl else None
+ else:
+ # Discover minimal sites to get hostname
+ minimalSites = await self.discoverSites()
+ if not minimalSites:
+ logger.warning("No sites available to extract hostname")
+ return None
+ from urllib.parse import urlparse
+ hostname = urlparse(minimalSites[0].get("webUrl", "")).hostname
+
+ if not hostname:
+ logger.warning("Could not extract hostname from site")
+ return None
+
+ logger.info(f"Extracted hostname '{hostname}' from first site, now getting site by path: {sitePath}")
+
+ # Get site directly using hostname + path
+ endpoint = f"sites/{hostname}:/sites/{sitePath}"
+ result = await self._makeGraphApiCall(endpoint)
+
+ if "error" in result:
+ logger.warning(f"Could not get site directly by path '{sitePath}': {result['error']}")
+ return None
+
+ siteInfo = {
+ "id": result.get("id"),
+ "displayName": result.get("displayName"),
+ "name": result.get("name"),
+ "webUrl": result.get("webUrl"),
+ "description": result.get("description"),
+ "createdDateTime": result.get("createdDateTime"),
+ "lastModifiedDateTime": result.get("lastModifiedDateTime")
+ }
+
+ logger.info(f"Successfully got site by standard path: {siteInfo['displayName']} (ID: {siteInfo['id']})")
+ return siteInfo
+
+ except Exception as e:
+ logger.error(f"Error getting site by standard path '{sitePath}': {str(e)}")
+ return None
+
+ def filterSitesByHint(self, sites: List[Dict[str, Any]], siteHint: str) -> List[Dict[str, Any]]:
+ """Filter discovered sites by a human-entered site hint (case-insensitive substring)."""
+ try:
+ if not siteHint:
+ return sites
+ hint = siteHint.strip().lower()
+ filtered: List[Dict[str, Any]] = []
+ for site in sites:
+ name = (site.get("displayName") or "").lower()
+ webUrl = (site.get("webUrl") or "").lower()
+ if hint in name or hint in webUrl:
+ filtered.append(site)
+ return filtered if filtered else sites
+ except Exception as e:
+ logger.error(f"Error filtering sites by hint '{siteHint}': {str(e)}")
+ return sites
+
+ async def resolveSitesFromPathQuery(self, pathQuery: str, allSites: Optional[List[Dict[str, Any]]] = None) -> List[Dict[str, Any]]:
+ """
+ Resolve sites from pathQuery. Handles both Microsoft-standard paths (/sites/SiteName/...)
+ and regular paths. Returns list of matching sites.
+
+ Parameters:
+ pathQuery (str): Path query string (e.g., /sites/SiteName/FolderPath)
+ allSites (Optional[List[Dict]]): Pre-discovered sites list (optional, for optimization)
+
+ Returns:
+ List[Dict[str, Any]]: List of matching sites
+ """
+ try:
+ # If pathQuery starts with Microsoft-standard /sites/, try to get site directly
+ if pathQuery.startswith('/sites/'):
+ parsedPath = self.extractSiteFromStandardPath(pathQuery)
+ if parsedPath:
+ siteName = parsedPath.get("siteName")
+ directSite = await self.getSiteByStandardPath(siteName, allSites)
+ if directSite:
+ logger.info(f"Got site directly by standard path - no need to discover all sites")
+ return [directSite]
+ else:
+ logger.warning(f"Could not get site directly, falling back to site discovery")
+
+ # If we didn't get the site directly, use discovery and filtering
+ if not allSites:
+ allSites = await self.discoverSites()
+ if not allSites:
+ logger.warning("No SharePoint sites found or accessible")
+ return []
+
+ # If pathQuery starts with Microsoft-standard /sites/, extract site name and filter
+ if pathQuery.startswith('/sites/'):
+ parsedPath = self.extractSiteFromStandardPath(pathQuery)
+ if parsedPath:
+ siteName = parsedPath.get("siteName")
+ sites = self.filterSitesByHint(allSites, siteName)
+ if not sites:
+ logger.warning(f"No SharePoint site found matching '{siteName}'")
+ return []
+ logger.info(f"Filtered to site(s) matching '{siteName}': {[s['displayName'] for s in sites]}")
+ return sites
+ else:
+ return allSites
+ else:
+ return allSites
+
+ except Exception as e:
+ logger.error(f"Error resolving sites from pathQuery '{pathQuery}': {str(e)}")
+ return []
+
+ def validatePathQuery(self, pathQuery: str) -> tuple[bool, Optional[str]]:
+ """
+ Validate pathQuery format. Returns (isValid, errorMessage).
+
+ Parameters:
+ pathQuery (str): Path query to validate
+
+ Returns:
+ tuple[bool, Optional[str]]: (True, None) if valid, (False, errorMessage) if invalid
+ """
+ try:
+ if not pathQuery or pathQuery.strip() == "" or pathQuery.strip() == "*":
+ return False, "pathQuery cannot be empty or '*'"
+
+ if not pathQuery.startswith('/'):
+ return False, "pathQuery must start with '/' and include site name with Microsoft-standard syntax /sites//... e.g. /sites/company-share/Freigegebene Dokumente/Work"
+
+ # Check if pathQuery contains search terms (words without proper path structure)
+ validPathPrefixes = ['/sites/', '/Documents', '/documents', '/Shared Documents', '/shared documents']
+ if not any(pathQuery.startswith(prefix) for prefix in validPathPrefixes):
+ return False, f"Invalid pathQuery '{pathQuery}'. This appears to be search terms, not a valid SharePoint path. Use findDocumentPath action first to search for folders, then use the returned folder path as pathQuery."
+
+ return True, None
+ except Exception as e:
+ logger.error(f"Error validating pathQuery '{pathQuery}': {str(e)}")
+ return False, f"Error validating pathQuery: {str(e)}"
+
+ def detectFolderType(self, item: Dict[str, Any]) -> bool:
+ """
+ Detect if an item is a folder using improved detection logic.
+
+ Parameters:
+ item (Dict[str, Any]): Item from SharePoint API response
+
+ Returns:
+ bool: True if item is a folder, False otherwise
+ """
+ try:
+ # Use improved folder detection logic
+ if 'folder' in item:
+ return True
+
+ # Try to detect by URL pattern or other indicators
+ webUrl = item.get('webUrl', '')
+ name = item.get('name', '')
+
+ # Check if URL has no file extension and looks like a folder path
+ if '.' not in name and ('/' in webUrl or '\\' in webUrl):
+ return True
+
+ return False
+ except Exception as e:
+ logger.error(f"Error detecting folder type: {str(e)}")
+ return False
diff --git a/modules/workflows/methods/methodAi.py b/modules/workflows/methods/methodAi.py
index eee848f7..ba6bb9b3 100644
--- a/modules/workflows/methods/methodAi.py
+++ b/modules/workflows/methods/methodAi.py
@@ -49,11 +49,13 @@ class MethodAi(MethodBase):
operationId = f"ai_process_{workflowId}_{int(time.time())}"
# Start progress tracking
+ parentOperationId = parameters.get('parentOperationId')
self.services.chat.progressLogStart(
operationId,
"Generate",
"AI Processing",
- f"Format: {parameters.get('resultType', 'txt')}"
+ f"Format: {parameters.get('resultType', 'txt')}",
+ parentOperationId=parentOperationId
)
aiPrompt = parameters.get("aiPrompt")
@@ -256,11 +258,13 @@ class MethodAi(MethodBase):
operationId = f"web_research_{workflowId}_{int(time.time())}"
# Start progress tracking
+ parentOperationId = parameters.get('parentOperationId')
self.services.chat.progressLogStart(
operationId,
"Web Research",
"Searching and Crawling",
- "Extracting URLs and Content"
+ "Extracting URLs and Content",
+ parentOperationId=parentOperationId
)
# Call webcrawl service - service handles all AI intention analysis and processing
diff --git a/modules/workflows/methods/methodContext.py b/modules/workflows/methods/methodContext.py
index 8bd16f9b..20485612 100644
--- a/modules/workflows/methods/methodContext.py
+++ b/modules/workflows/methods/methodContext.py
@@ -250,11 +250,13 @@ class MethodContext(MethodBase):
return ActionResult.isFailure(error=f"Invalid documentList type: {type(documentListParam)}")
# Start progress tracking
+ parentOperationId = parameters.get('parentOperationId')
self.services.chat.progressLogStart(
operationId,
"Extracting content from documents",
"Content Extraction",
- f"Documents: {len(documentList.references)}"
+ f"Documents: {len(documentList.references)}",
+ parentOperationId=parentOperationId
)
# Get ChatDocuments from documentList
diff --git a/modules/workflows/methods/methodOutlook.py b/modules/workflows/methods/methodOutlook.py
index 033b5283..16030fcc 100644
--- a/modules/workflows/methods/methodOutlook.py
+++ b/modules/workflows/methods/methodOutlook.py
@@ -334,11 +334,13 @@ class MethodOutlook(MethodBase):
operationId = f"outlook_read_{workflowId}_{int(time.time())}"
# Start progress tracking
+ parentOperationId = parameters.get('parentOperationId')
self.services.chat.progressLogStart(
operationId,
"Read Emails",
"Outlook Email Reading",
- f"Folder: {parameters.get('folder', 'Inbox')}"
+ f"Folder: {parameters.get('folder', 'Inbox')}",
+ parentOperationId=parentOperationId
)
connectionReference = parameters.get("connectionReference")
@@ -1546,11 +1548,13 @@ Return JSON:
operationId = f"outlook_send_{workflowId}_{int(time.time())}"
# Start progress tracking
+ parentOperationId = parameters.get('parentOperationId')
self.services.chat.progressLogStart(
operationId,
"Send Draft Email",
"Outlook Email Sending",
- f"Processing {len(parameters.get('documentList', []))} draft(s)"
+ f"Processing {len(parameters.get('documentList', []))} draft(s)",
+ parentOperationId=parentOperationId
)
connectionReference = parameters.get("connectionReference")
diff --git a/modules/workflows/methods/methodSharepoint.py b/modules/workflows/methods/methodSharepoint.py
index da3db26b..d5109251 100644
--- a/modules/workflows/methods/methodSharepoint.py
+++ b/modules/workflows/methods/methodSharepoint.py
@@ -7,7 +7,7 @@ import logging
import re
import json
from typing import Dict, Any, List, Optional
-from datetime import datetime, UTC
+from datetime import datetime, UTC, timedelta, timezone
import urllib
import aiohttp
import asyncio
@@ -122,103 +122,26 @@ class MethodSharepoint(MethodBase):
logger.error(f"Error extracting hostname from webUrl '{webUrl}': {str(e)}")
return None
- async def _getSiteByStandardPath(self, sitePath: str) -> Optional[Dict[str, Any]]:
- """
- Get SharePoint site directly by Microsoft-standard path (/sites/SiteName)
- without loading all sites. Uses hostname from first available site.
-
- Parameters:
- sitePath (str): Site path like 'company-share' (without /sites/ prefix)
-
- Returns:
- Optional[Dict[str, Any]]: Site information if found, None otherwise
- """
- try:
- # Get hostname from first available site (minimal load - only 1 site)
- minimalSites = await self._discoverSharePointSites(limit=1)
- if not minimalSites:
- logger.warning("No sites available to extract hostname")
- return None
-
- hostname = self._extractHostnameFromWebUrl(minimalSites[0].get("webUrl"))
- if not hostname:
- logger.warning("Could not extract hostname from site")
- return None
-
- logger.info(f"Extracted hostname '{hostname}' from first site, now getting site by path: {sitePath}")
-
- # Get site directly using hostname + path
- endpoint = f"sites/{hostname}:/sites/{sitePath}"
- result = await self._makeGraphApiCall(endpoint)
-
- if "error" in result:
- logger.warning(f"Could not get site directly by path '{sitePath}': {result['error']}")
- return None
-
- siteInfo = {
- "id": result.get("id"),
- "displayName": result.get("displayName"),
- "name": result.get("name"),
- "webUrl": result.get("webUrl"),
- "description": result.get("description"),
- "createdDateTime": result.get("createdDateTime"),
- "lastModifiedDateTime": result.get("lastModifiedDateTime")
- }
-
- logger.info(f"Successfully got site by standard path: {siteInfo['displayName']} (ID: {siteInfo['id']})")
- return siteInfo
-
- except Exception as e:
- logger.error(f"Error getting site by standard path '{sitePath}': {str(e)}")
- return None
-
- def _filterSitesByHint(self, sites: List[Dict[str, Any]], siteHint: str) -> List[Dict[str, Any]]:
- """Filter discovered sites by a human-entered site hint (case-insensitive substring)."""
- try:
- if not siteHint:
- return sites
- hint = siteHint.strip().lower()
- filtered: List[Dict[str, Any]] = []
- for site in sites:
- name = (site.get("displayName") or "").lower()
- webUrl = (site.get("webUrl") or "").lower()
- if hint in name or hint in webUrl:
- filtered.append(site)
- return filtered if filtered else sites
- except Exception as e:
- logger.error(f"Error filtering sites by hint '{siteHint}': {str(e)}")
- return sites
-
def _extractSiteFromStandardPath(self, pathQuery: str) -> Optional[Dict[str, str]]:
"""
- Extract site name from Microsoft-standard server-relative path:
- /sites/company-share/Freigegebene Dokumente/...
-
- Returns dict with keys: siteName, innerPath (no leading slash) on success, else None.
+ Extract site name from Microsoft-standard server-relative path.
+ Delegates to SharePoint service.
"""
- try:
- if not pathQuery or not pathQuery.startswith('/sites/'):
- return None
-
- # Remove leading /sites/ prefix
- remainder = pathQuery[7:] # len('/sites/') = 7
-
- # Split on first '/' to get site name
- if '/' not in remainder:
- # Only site name, no inner path
- return {"siteName": remainder, "innerPath": ""}
-
- siteName, inner = remainder.split('/', 1)
- siteName = siteName.strip()
- innerPath = inner.strip()
-
- if not siteName:
- return None
-
- return {"siteName": siteName, "innerPath": innerPath}
- except Exception as e:
- logger.error(f"Error extracting site from standard path '{pathQuery}': {str(e)}")
- return None
+ return self.services.sharepoint.extractSiteFromStandardPath(pathQuery)
+
+ async def _getSiteByStandardPath(self, sitePath: str) -> Optional[Dict[str, Any]]:
+ """
+ Get SharePoint site directly by Microsoft-standard path.
+ Delegates to SharePoint service.
+ """
+ return await self.services.sharepoint.getSiteByStandardPath(sitePath)
+
+ def _filterSitesByHint(self, sites: List[Dict[str, Any]], siteHint: str) -> List[Dict[str, Any]]:
+ """
+ Filter discovered sites by a human-entered site hint.
+ Delegates to SharePoint service.
+ """
+ return self.services.sharepoint.filterSitesByHint(sites, siteHint)
def _parseSearchQuery(self, searchQuery: str) -> tuple[str, str, str, dict]:
"""
@@ -624,6 +547,170 @@ class MethodSharepoint(MethodBase):
except Exception as e:
logger.error(f"Error getting site ID: {str(e)}")
return ""
+
+ async def _parseDocumentListForFoundDocuments(self, documentList: Any) -> tuple[Optional[List[Dict[str, Any]]], Optional[List[Dict[str, Any]]], Optional[str]]:
+ """
+ Parse documentList to extract foundDocuments and site information.
+
+ Parameters:
+ documentList: Document list (can be list, DocumentReferenceList, or string)
+
+ Returns:
+ tuple: (foundDocuments, sites, errorMessage)
+ - foundDocuments: List of found documents from findDocumentPath result
+ - sites: List of site dictionaries with id, displayName, webUrl
+ - errorMessage: Error message if parsing failed, None otherwise
+ """
+ try:
+ if isinstance(documentList, str):
+ documentList = [documentList]
+
+ # Resolve documentList to get actual documents
+ from modules.datamodels.datamodelDocref import DocumentReferenceList
+ if isinstance(documentList, DocumentReferenceList):
+ docRefList = documentList
+ elif isinstance(documentList, list):
+ docRefList = DocumentReferenceList.from_string_list(documentList)
+ else:
+ docRefList = DocumentReferenceList(references=[])
+
+ chatDocuments = self.services.chat.getChatDocumentsFromDocumentList(docRefList)
+ if not chatDocuments:
+ return None, None, "No documents found for the provided document list"
+
+ firstDocument = chatDocuments[0]
+ fileData = self.services.chat.getFileData(firstDocument.fileId)
+ if not fileData:
+ return None, None, None # No fileData, but not an error (might be regular file)
+
+ try:
+ resultData = json.loads(fileData)
+ foundDocuments = resultData.get("foundDocuments", [])
+
+ # If no foundDocuments, check if it's a listDocuments result (has listResults)
+ if not foundDocuments and "listResults" in resultData:
+ logger.info(f"documentList contains listResults from listDocuments, converting to foundDocuments format")
+ listResults = resultData.get("listResults", [])
+ foundDocuments = []
+ siteIdFromList = None
+ siteNameFromList = None
+
+ for listResult in listResults:
+ siteResults = listResult.get("siteResults", [])
+ for siteResult in siteResults:
+ items = siteResult.get("items", [])
+ # Extract site info from first item if available
+ if items and not siteIdFromList:
+ siteNameFromList = items[0].get("siteName")
+
+ for item in items:
+ # Convert listDocuments item format to foundDocuments format
+ if item.get("type") == "file":
+ foundDoc = {
+ "id": item.get("id"),
+ "name": item.get("name"),
+ "type": "file",
+ "siteName": item.get("siteName"),
+ "siteId": None, # Will be determined from site discovery
+ "webUrl": item.get("webUrl"),
+ "fullPath": item.get("webUrl", ""),
+ "parentPath": item.get("parentPath", "")
+ }
+ foundDocuments.append(foundDoc)
+
+ # Discover sites to get siteId if we have siteName
+ if foundDocuments and siteNameFromList and not siteIdFromList:
+ logger.info(f"Discovering sites to find siteId for '{siteNameFromList}'")
+ allSites = await self._discoverSharePointSites()
+ matchingSites = self._filterSitesByHint(allSites, siteNameFromList)
+ if matchingSites:
+ siteIdFromList = matchingSites[0].get("id")
+ # Update all foundDocuments with siteId
+ for doc in foundDocuments:
+ doc["siteId"] = siteIdFromList
+ logger.info(f"Found siteId '{siteIdFromList}' for site '{siteNameFromList}'")
+
+ logger.info(f"Converted {len(foundDocuments)} files from listResults format")
+
+ if not foundDocuments:
+ return None, None, None # No foundDocuments, but not an error
+
+ # Extract site information from foundDocuments
+ firstDoc = foundDocuments[0]
+ siteName = firstDoc.get("siteName")
+ siteId = firstDoc.get("siteId")
+
+ # If siteId is missing (from listDocuments conversion), discover sites to find it
+ if siteName and not siteId:
+ logger.info(f"Site ID missing, discovering sites to find siteId for '{siteName}'")
+ allSites = await self._discoverSharePointSites()
+ matchingSites = self._filterSitesByHint(allSites, siteName)
+ if matchingSites:
+ siteId = matchingSites[0].get("id")
+ logger.info(f"Found siteId '{siteId}' for site '{siteName}'")
+
+ sites = None
+ if siteName and siteId:
+ sites = [{
+ "id": siteId,
+ "displayName": siteName,
+ "webUrl": firstDoc.get("webUrl", "")
+ }]
+ logger.info(f"Using specific site from documentList: {siteName} (ID: {siteId})")
+ elif siteName:
+ # Try to get site by name
+ allSites = await self._discoverSharePointSites()
+ matchingSites = self._filterSitesByHint(allSites, siteName)
+ if matchingSites:
+ sites = [{
+ "id": matchingSites[0].get("id"),
+ "displayName": siteName,
+ "webUrl": matchingSites[0].get("webUrl", "")
+ }]
+ logger.info(f"Found site by name: {siteName} (ID: {sites[0]['id']})")
+ else:
+ return None, None, f"Site '{siteName}' not found. Cannot determine target site."
+ else:
+ return None, None, "Site information missing from documentList. Cannot determine target site."
+
+ return foundDocuments, sites, None
+
+ except json.JSONDecodeError as e:
+ return None, None, f"Invalid JSON in documentList: {str(e)}"
+ except Exception as e:
+ return None, None, f"Error processing documentList: {str(e)}"
+
+ except Exception as e:
+ logger.error(f"Error parsing documentList: {str(e)}")
+ return None, None, f"Error parsing documentList: {str(e)}"
+
+ async def _resolveSitesFromPathQuery(self, pathQuery: str) -> tuple[List[Dict[str, Any]], Optional[str]]:
+ """
+ Resolve sites from pathQuery using SharePoint service helper methods.
+
+ Parameters:
+ pathQuery (str): Path query string
+
+ Returns:
+ tuple: (sites, errorMessage)
+ - sites: List of site dictionaries
+ - errorMessage: Error message if resolution failed, None otherwise
+ """
+ try:
+ # Validate pathQuery format
+ isValid, errorMsg = self.services.sharepoint.validatePathQuery(pathQuery)
+ if not isValid:
+ return [], errorMsg
+
+ # Resolve sites using service helper
+ sites = await self.services.sharepoint.resolveSitesFromPathQuery(pathQuery)
+ if not sites:
+ return [], "No SharePoint sites found or accessible"
+
+ return sites, None
+ except Exception as e:
+ logger.error(f"Error resolving sites from pathQuery '{pathQuery}': {str(e)}")
+ return [], f"Error resolving sites from pathQuery: {str(e)}"
@action
@@ -638,23 +725,44 @@ class MethodSharepoint(MethodBase):
- connectionReference (str, required): Microsoft connection label.
- site (str, optional): Site hint.
- searchQuery (str, required): Search terms or path.
- - maxResults (int, optional): Maximum items to return. Default: 100.
+ - maxResults (int, optional): Maximum items to return. Default: 1000.
"""
+ import time
+ operationId = None
try:
+ # Init progress logger
+ workflowId = self.services.workflow.id if self.services.workflow else f"no-workflow-{int(time.time())}"
+ operationId = f"sharepoint_find_{workflowId}_{int(time.time())}"
+
+ # Start progress tracking
+ parentOperationId = parameters.get('parentOperationId')
+ self.services.chat.progressLogStart(
+ operationId,
+ "Find Document Path",
+ "SharePoint Search",
+ f"Query: {parameters.get('searchQuery', '*')}",
+ parentOperationId=parentOperationId
+ )
+
connectionReference = parameters.get("connectionReference")
site = parameters.get("site")
searchQuery = parameters.get("searchQuery", "*")
- maxResults = parameters.get("maxResults", 100)
+ maxResults = parameters.get("maxResults", 1000)
if not connectionReference:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
return ActionResult.isFailure(error="Connection reference is required")
# Parse searchQuery to extract path, search terms, search type, and options
pathQuery, fileQuery, searchType, searchOptions = self._parseSearchQuery(searchQuery)
logger.debug(f"Parsed searchQuery '{searchQuery}' -> pathQuery='{pathQuery}', fileQuery='{fileQuery}', searchType='{searchType}'")
+ self.services.chat.progressLogUpdate(operationId, 0.2, "Getting Microsoft connection")
connection = self._getMicrosoftConnection(connectionReference)
if not connection:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
return ActionResult.isFailure(error="No valid Microsoft connection found for the provided connection reference")
# Extract site name from pathQuery if it contains Microsoft-standard path (/sites/SiteName/...)
@@ -683,25 +791,34 @@ class MethodSharepoint(MethodBase):
siteHintToUse = site or siteFromPath or searchOptions.get("site_hint")
# Discover SharePoint sites - use targeted approach when site hint is available
+ self.services.chat.progressLogUpdate(operationId, 0.3, "Discovering SharePoint sites")
if siteHintToUse:
# When site hint is available, discover all sites first, then filter
allSites = await self._discoverSharePointSites()
if not allSites:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
return ActionResult.isFailure(error="No SharePoint sites found or accessible")
sites = self._filterSitesByHint(allSites, siteHintToUse)
logger.info(f"Filtered sites by site hint '{siteHintToUse}' -> {len(sites)} sites")
if not sites:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
return ActionResult.isFailure(error=f"No SharePoint sites found matching '{siteHintToUse}'")
else:
# No site hint - discover all sites
sites = await self._discoverSharePointSites()
if not sites:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
return ActionResult.isFailure(error="No SharePoint sites found or accessible")
# Resolve path query into search paths
searchPaths = self._resolvePathQuery(pathQuery)
+ self.services.chat.progressLogUpdate(operationId, 0.5, f"Searching across {len(sites)} site(s)")
+
try:
# Search across all discovered sites
foundDocuments = []
@@ -763,17 +880,7 @@ class MethodSharepoint(MethodBase):
resource = item
# Use the same detection logic as our test
- isFolder = False
- if 'folder' in resource:
- isFolder = True
- else:
- # Try to detect by URL pattern or other indicators
- webUrl = resource.get('webUrl', '')
- name = resource.get('name', '')
-
- # Check if URL has no file extension and looks like a folder path
- if '.' not in name and ('/' in webUrl or '\\' in webUrl):
- isFolder = True
+ isFolder = self.services.sharepoint.detectFolderType(resource)
if isFolder:
folderItems.append(item)
@@ -823,17 +930,7 @@ class MethodSharepoint(MethodBase):
logger.warning(f"Error extracting site info from URL {webUrl}: {e}")
# Use improved folder detection logic
- isFolder = False
- if 'folder' in item:
- isFolder = True
- else:
- # Try to detect by URL pattern or other indicators
- name = item.get('name', '')
-
- # Check if URL has no file extension and looks like a folder path
- if '.' not in name and ('/' in webUrl or '\\' in webUrl):
- isFolder = True
-
+ isFolder = self.services.sharepoint.detectFolderType(item)
itemType = "folder" if isFolder else "file"
itemPath = item.get("parentReference", {}).get("path", "")
logger.debug(f"Processing {itemType}: '{itemName}' at path: '{itemPath}'")
@@ -986,17 +1083,7 @@ class MethodSharepoint(MethodBase):
itemName = item.get("name", "")
# Use improved folder detection logic
- isFolder = False
- if 'folder' in item:
- isFolder = True
- else:
- # Try to detect by URL pattern or other indicators
- webUrl = item.get('webUrl', '')
- name = item.get('name', '')
-
- # Check if URL has no file extension and looks like a folder path
- if '.' not in name and ('/' in webUrl or '\\' in webUrl):
- isFolder = True
+ isFolder = self.services.sharepoint.detectFolderType(item)
itemType = "folder" if isFolder else "file"
itemPath = item.get("parentReference", {}).get("path", "")
@@ -1056,6 +1143,8 @@ class MethodSharepoint(MethodBase):
foundDocuments = foundDocuments[:maxResults]
logger.info(f"Limited results to {maxResults} items")
+ self.services.chat.progressLogUpdate(operationId, 0.9, f"Found {len(foundDocuments)} document(s)")
+
resultData = {
"searchQuery": searchQuery,
"totalResults": len(foundDocuments),
@@ -1066,6 +1155,8 @@ class MethodSharepoint(MethodBase):
except Exception as e:
logger.error(f"Error searching SharePoint: {str(e)}")
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
return ActionResult.isFailure(error=str(e))
# Use default JSON format for output
@@ -1080,6 +1171,7 @@ class MethodSharepoint(MethodBase):
"hasResults": len(foundDocuments) > 0
}
+ self.services.chat.progressLogFinish(operationId, True)
return ActionResult(
success=True,
documents=[
@@ -1094,6 +1186,11 @@ class MethodSharepoint(MethodBase):
except Exception as e:
logger.error(f"Error finding document path: {str(e)}")
+ if operationId:
+ try:
+ self.services.chat.progressLogFinish(operationId, False)
+ except:
+ pass
return ActionResult.isFailure(error=str(e))
@action
@@ -1101,7 +1198,7 @@ class MethodSharepoint(MethodBase):
"""
GENERAL:
- Purpose: Read documents from SharePoint and extract content/metadata.
- - Input requirements: connectionReference (required); optional documentList, pathObject, or pathQuery; includeMetadata.
+ - Input requirements: connectionReference (required); documentList or pathQuery (required); includeMetadata (optional).
- Output format: Standardized ActionDocument format (documentName, documentData, mimeType).
- Binary files (PDFs, etc.) are Base64-encoded in documentData.
- Text files are stored as plain text in documentData.
@@ -1109,9 +1206,8 @@ class MethodSharepoint(MethodBase):
Parameters:
- connectionReference (str, required): Microsoft connection label.
- - pathObject (str, optional): Reference to a previous path result (from findDocumentPath).
- - documentList (list, optional): Document list reference(s) to read (backward compatibility).
- - pathQuery (str, optional): Path query if no pathObject (backward compatibility).
+ - documentList (list, optional): Document list reference(s) containing findDocumentPath result.
+ - pathQuery (str, optional): Direct path query if no documentList (e.g., /sites/SiteName/FolderPath).
- includeMetadata (bool, optional): Include metadata. Default: True.
Returns:
@@ -1128,19 +1224,18 @@ class MethodSharepoint(MethodBase):
operationId = f"sharepoint_read_{workflowId}_{int(time.time())}"
# Start progress tracking
+ parentOperationId = parameters.get('parentOperationId')
self.services.chat.progressLogStart(
operationId,
"Read Documents",
"SharePoint Document Reading",
- f"Path: {parameters.get('pathQuery', parameters.get('pathObject', '*'))}"
+ "Processing document list",
+ parentOperationId=parentOperationId
)
documentList = parameters.get("documentList")
- if isinstance(documentList, str):
- documentList = [documentList]
- connectionReference = parameters.get("connectionReference")
pathQuery = parameters.get("pathQuery", "*")
- pathObject = parameters.get("pathObject")
+ connectionReference = parameters.get("connectionReference")
includeMetadata = parameters.get("includeMetadata", True)
# Validate connection reference
@@ -1149,7 +1244,13 @@ class MethodSharepoint(MethodBase):
self.services.chat.progressLogFinish(operationId, False)
return ActionResult.isFailure(error="Connection reference is required")
- # Get connection first - needed for both pathObject and documentList approaches
+ # Require either documentList or pathQuery
+ if not documentList and (not pathQuery or pathQuery.strip() == "" or pathQuery.strip() == "*"):
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="Either documentList or pathQuery is required")
+
+ # Get connection first
self.services.chat.progressLogUpdate(operationId, 0.2, "Getting Microsoft connection")
connection = self._getMicrosoftConnection(connectionReference)
if not connection:
@@ -1157,132 +1258,27 @@ class MethodSharepoint(MethodBase):
self.services.chat.progressLogFinish(operationId, False)
return ActionResult.isFailure(error="No valid Microsoft connection found for the provided connection reference")
- # If pathObject is provided, extract SharePoint file IDs and read them directly
- # pathObject contains the result from findDocumentPath with foundDocuments array
+ # Parse documentList to extract foundDocuments and site information
sharePointFileIds = None
sites = None
- if pathObject:
- if pathQuery and pathQuery != "*":
- logger.debug(f"Both pathObject and pathQuery provided - using pathObject (pathQuery '{pathQuery}' will be ignored)")
- try:
- # Resolve the reference label to get the actual document list
- from modules.datamodels.datamodelDocref import DocumentReferenceList
- pathObjectDocuments = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list([pathObject]))
- if not pathObjectDocuments or len(pathObjectDocuments) == 0:
+
+ if documentList:
+ foundDocuments, sites, errorMsg = await self._parseDocumentListForFoundDocuments(documentList)
+ if errorMsg:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error=errorMsg)
+
+ if foundDocuments:
+ # Extract SharePoint file IDs from foundDocuments
+ sharePointFileIds = [doc.get("id") for doc in foundDocuments if doc.get("type") == "file"]
+ if not sharePointFileIds:
if operationId:
self.services.chat.progressLogFinish(operationId, False)
- return ActionResult.isFailure(error=f"No document list found for reference: {pathObject}")
-
- # Get the first document's content (which should be the JSON from findDocumentPath)
- firstDocument = pathObjectDocuments[0]
- fileData = self.services.chat.getFileData(firstDocument.fileId)
- if not fileData:
- return ActionResult.isFailure(error=f"No file data found for document: {pathObject}")
-
- # Parse the JSON content
- resultData = json.loads(fileData)
- foundDocuments = resultData.get("foundDocuments", [])
-
- # If no foundDocuments, check if it's a listDocuments result (has listResults)
- if not foundDocuments and "listResults" in resultData:
- logger.info(f"pathObject contains listResults from listDocuments, converting to foundDocuments format")
- listResults = resultData.get("listResults", [])
- foundDocuments = []
- siteIdFromList = None
- siteNameFromList = None
-
- for listResult in listResults:
- siteResults = listResult.get("siteResults", [])
- for siteResult in siteResults:
- items = siteResult.get("items", [])
- # Extract site info from first item if available
- if items and not siteIdFromList:
- # Try to get site info from the siteResult structure
- # We need to discover sites to get the siteId
- siteNameFromList = items[0].get("siteName")
-
- for item in items:
- # Convert listDocuments item format to foundDocuments format
- if item.get("type") == "file":
- foundDoc = {
- "id": item.get("id"),
- "name": item.get("name"),
- "type": "file",
- "siteName": item.get("siteName"),
- "siteId": None, # Will be determined from site discovery
- "webUrl": item.get("webUrl"),
- "fullPath": item.get("webUrl", ""),
- "parentPath": item.get("parentPath", "")
- }
- foundDocuments.append(foundDoc)
-
- # Discover sites to get siteId if we have siteName
- if foundDocuments and siteNameFromList and not siteIdFromList:
- logger.info(f"Discovering sites to find siteId for '{siteNameFromList}'")
- allSites = await self._discoverSharePointSites()
- matchingSites = self._filterSitesByHint(allSites, siteNameFromList)
- if matchingSites:
- siteIdFromList = matchingSites[0].get("id")
- # Update all foundDocuments with siteId
- for doc in foundDocuments:
- doc["siteId"] = siteIdFromList
- logger.info(f"Found siteId '{siteIdFromList}' for site '{siteNameFromList}'")
-
- logger.info(f"Converted {len(foundDocuments)} files from listResults format")
-
- if foundDocuments:
- # Extract SharePoint file IDs from foundDocuments
- sharePointFileIds = [doc.get("id") for doc in foundDocuments if doc.get("type") == "file"]
- if not sharePointFileIds:
- return ActionResult.isFailure(error=f"No files found in pathObject '{pathObject}'")
- logger.info(f"Extracted {len(sharePointFileIds)} SharePoint file IDs from pathObject '{pathObject}'")
-
- # Extract site information from foundDocuments
- if foundDocuments:
- firstDoc = foundDocuments[0]
- siteName = firstDoc.get("siteName")
- siteId = firstDoc.get("siteId")
-
- # If siteId is missing (from listDocuments conversion), discover sites to find it
- if siteName and not siteId:
- logger.info(f"Site ID missing, discovering sites to find siteId for '{siteName}'")
- allSites = await self._discoverSharePointSites()
- matchingSites = self._filterSitesByHint(allSites, siteName)
- if matchingSites:
- siteId = matchingSites[0].get("id")
- logger.info(f"Found siteId '{siteId}' for site '{siteName}'")
-
- if siteName and siteId:
- sites = [{
- "id": siteId,
- "displayName": siteName,
- "webUrl": firstDoc.get("webUrl", "")
- }]
- logger.info(f"Using specific site from pathObject: {siteName} (ID: {siteId})")
- elif siteName:
- # Try to get site by name
- allSites = await self._discoverSharePointSites()
- matchingSites = self._filterSitesByHint(allSites, siteName)
- if matchingSites:
- sites = [{
- "id": matchingSites[0].get("id"),
- "displayName": siteName,
- "webUrl": matchingSites[0].get("webUrl", "")
- }]
- logger.info(f"Found site by name: {siteName} (ID: {sites[0]['id']})")
- else:
- return ActionResult.isFailure(error=f"Site '{siteName}' not found. Cannot determine target site for read operation.")
- else:
- return ActionResult.isFailure(error="Site information missing from pathObject. Cannot determine target site for read operation.")
- else:
- return ActionResult.isFailure(error=f"No documents found in pathObject '{pathObject}'")
-
- except json.JSONDecodeError as e:
- return ActionResult.isFailure(error=f"Invalid JSON in pathObject: {str(e)}")
- except Exception as e:
- return ActionResult.isFailure(error=f"Error resolving pathObject reference: {str(e)}")
+ return ActionResult.isFailure(error="No files found in documentList from findDocumentPath result")
+ logger.info(f"Extracted {len(sharePointFileIds)} SharePoint file IDs from documentList")
- # If we have SharePoint file IDs from pathObject, read them directly
+ # If we have SharePoint file IDs from documentList (findDocumentPath result), read them directly
if sharePointFileIds and sites:
# Read SharePoint files directly using their IDs
readResults = []
@@ -1338,7 +1334,7 @@ class MethodSharepoint(MethodBase):
if not readResults:
self.services.chat.progressLogFinish(operationId, False)
- return ActionResult.isFailure(error="No files could be read from pathObject")
+ return ActionResult.isFailure(error="No files could be read from documentList")
# Convert read results to ActionDocument objects
# IMPORTANT: For binary files (PDFs), store Base64-encoded content directly in documentData
@@ -1442,232 +1438,24 @@ class MethodSharepoint(MethodBase):
self.services.chat.progressLogFinish(operationId, True)
return ActionResult.isSuccess(documents=actionDocuments)
- # Fallback: Use documentList parameter (for backward compatibility)
- # Validate documentList
- if not documentList:
- return ActionResult.isFailure(error="Document list reference is required. Either provide documentList parameter or use pathObject that contains files.")
+ # If no sites from documentList, try pathQuery fallback
+ if not sites and pathQuery and pathQuery.strip() != "" and pathQuery.strip() != "*":
+ sites, errorMsg = await self._resolveSitesFromPathQuery(pathQuery)
+ if errorMsg:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error=errorMsg)
- # Get documents from reference - ensure documentList is a list, not a string
- # documentList is already normalized above
- from modules.datamodels.datamodelDocref import DocumentReferenceList
- # Convert to DocumentReferenceList if needed
- if isinstance(documentList, DocumentReferenceList):
- docRefList = documentList
- elif isinstance(documentList, list):
- docRefList = DocumentReferenceList.from_string_list(documentList)
- elif isinstance(documentList, str):
- docRefList = DocumentReferenceList.from_string_list([documentList])
- else:
- docRefList = DocumentReferenceList(references=[])
- chatDocuments = self.services.chat.getChatDocumentsFromDocumentList(docRefList)
-
- if not chatDocuments:
- return ActionResult.isFailure(error="No documents found for the provided reference")
-
- # Determine sites to use - strict validation: pathObject → pathQuery → ERROR
+ # If still no sites, return error
if not sites:
- # Step 2: If no pathObject, check pathQuery
- if pathQuery and pathQuery.strip() != "" and pathQuery.strip() != "*":
- # Validate pathQuery format
- if not pathQuery.startswith('/'):
- return ActionResult.isFailure(error="pathQuery must start with '/' and include site name with Microsoft-standard syntax /sites//... e.g. /sites/company-share/Freigegebene Dokumente/Work")
-
- # Check if pathQuery contains search terms (words without proper path structure)
- validPathPrefixes = ['/sites/', '/Documents', '/documents', '/Shared Documents', '/shared documents']
- if not any(pathQuery.startswith(prefix) for prefix in validPathPrefixes):
- return ActionResult.isFailure(error=f"Invalid pathQuery '{pathQuery}'. This appears to be search terms, not a valid SharePoint path. Use findDocumentPath action first to search for folders, then use the returned folder path as pathQuery.")
-
- # If pathQuery starts with Microsoft-standard /sites/, try to get site directly
- directSite = None
- if pathQuery.startswith('/sites/'):
- parsedPath = self._extractSiteFromStandardPath(pathQuery)
- if parsedPath:
- siteName = parsedPath.get("siteName")
- # Try to get site directly by path (optimization - no need to load all 60 sites)
- directSite = await self._getSiteByStandardPath(siteName)
- if directSite:
- logger.info(f"Got site directly by standard path - no need to discover all sites")
- sites = [directSite]
- else:
- logger.warning(f"Could not get site directly, falling back to site discovery")
-
- # If we didn't get the site directly, use discovery and filtering
- if not directSite:
- # For pathQuery, we need to discover sites to find the specific one
- allSites = await self._discoverSharePointSites()
- if not allSites:
- return ActionResult.isFailure(error="No SharePoint sites found or accessible")
-
- # If pathQuery starts with Microsoft-standard /sites/, extract site name and filter
- if pathQuery.startswith('/sites/'):
- parsedPath = self._extractSiteFromStandardPath(pathQuery)
- if parsedPath:
- siteName = parsedPath.get("siteName")
- # Filter sites by name (case-insensitive substring match)
- sites = self._filterSitesByHint(allSites, siteName)
- if not sites:
- return ActionResult.isFailure(error=f"No SharePoint site found matching '{siteName}'")
- logger.info(f"Filtered to site(s) matching '{siteName}': {[s['displayName'] for s in sites]}")
- else:
- sites = allSites
- else:
- sites = allSites
- else:
- # Step 3: Both pathObject and pathQuery failed - ERROR, NO FALLBACK
- return ActionResult.isFailure(error="No valid read path provided. Either provide pathObject (from findDocumentPath) or a valid pathQuery with specific site information.")
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="Either documentList must contain findDocumentPath result with file information, or pathQuery must be provided. Use findDocumentPath first to get file paths, or provide pathQuery directly.")
- if not sites:
- return ActionResult.isFailure(error="No valid target site determined for read operation")
-
- # Resolve path query into search paths
- searchPaths = self._resolvePathQuery(pathQuery)
-
- # Process each chat document across all sites
- readResults = []
-
- for i, chatDocument in enumerate(chatDocuments):
- try:
- fileId = chatDocument.fileId
- fileName = chatDocument.fileName
-
- # Search for this file across all sites
- fileFound = False
-
- for site in sites:
- siteId = site["id"]
- siteName = site["displayName"]
- siteUrl = site["webUrl"]
-
- # Try to find the file by name in this site
- searchQuery = fileName.replace("'", "''") # Escape single quotes for OData
- endpoint = f"sites/{siteId}/drive/root/search(q='{searchQuery}')"
-
- searchResult = await self._makeGraphApiCall(endpoint)
-
- if "error" in searchResult:
- continue
-
- items = searchResult.get("value", [])
- for item in items:
- if item.get("name") == fileName:
- # Found the file, get its details
- fileId = item.get("id")
- fileEndpoint = f"sites/{siteId}/drive/items/{fileId}"
-
- # Get file metadata
- fileInfoResult = await self._makeGraphApiCall(fileEndpoint)
-
- if "error" in fileInfoResult:
- continue
-
- # Build result with metadata
- resultItem = {
- "fileId": fileId,
- "fileName": fileName,
- "sharepointFileId": fileId,
- "siteName": siteName,
- "siteUrl": siteUrl,
- "size": fileInfoResult.get("size", 0),
- "createdDateTime": fileInfoResult.get("createdDateTime"),
- "lastModifiedDateTime": fileInfoResult.get("lastModifiedDateTime"),
- "webUrl": fileInfoResult.get("webUrl")
- }
-
- # Add metadata if requested
- if includeMetadata:
- resultItem["metadata"] = {
- "mimeType": fileInfoResult.get("file", {}).get("mimeType"),
- "downloadUrl": fileInfoResult.get("@microsoft.graph.downloadUrl"),
- "createdBy": fileInfoResult.get("createdBy", {}),
- "lastModifiedBy": fileInfoResult.get("lastModifiedBy", {}),
- "parentReference": fileInfoResult.get("parentReference", {})
- }
-
- # Get file content if it's a readable format
- mimeType = fileInfoResult.get("file", {}).get("mimeType", "")
- if mimeType.startswith("text/") or mimeType in [
- "application/json", "application/xml", "application/javascript"
- ]:
- # Download the file content
- contentEndpoint = f"sites/{siteId}/drive/items/{fileId}/content"
-
- # For content download, we need to handle binary data
- try:
- async with aiohttp.ClientSession() as session:
- headers = {"Authorization": f"Bearer {self.services.sharepoint._target.accessToken}"}
- async with session.get(f"https://graph.microsoft.com/v1.0/{contentEndpoint}", headers=headers) as response:
- if response.status == 200:
- content = await response.text()
- resultItem["content"] = content
- else:
- resultItem["content"] = f"Could not download content: HTTP {response.status}"
- except Exception as e:
- resultItem["content"] = f"Error downloading content: {str(e)}"
- else:
- resultItem["content"] = f"Binary file type ({mimeType}) - content not retrieved"
-
- readResults.append(resultItem)
- fileFound = True
- break
-
- if fileFound:
- break
-
- if not fileFound:
- readResults.append({
- "fileId": fileId,
- "fileName": fileName,
- "error": "File not found in any accessible SharePoint site",
- "content": None
- })
-
- except Exception as e:
- logger.error(f"Error reading document {chatDocument.fileName}: {str(e)}")
- readResults.append({
- "fileId": chatDocument.fileId,
- "fileName": chatDocument.fileName,
- "error": str(e),
- "content": None
- })
-
- resultData = {
- "connectionReference": connectionReference,
- "pathQuery": pathQuery,
- "documentList": documentList,
- "includeMetadata": includeMetadata,
- "sitesSearched": len(sites),
- "readResults": readResults,
- "connection": {
- "id": connection["id"],
- "authority": "microsoft",
- "reference": connectionReference
- },
- "timestamp": self.services.utils.timestampGetUtc()
- }
-
- # Use default JSON format for output
- outputExtension = ".json" # Default
- outputMimeType = "application/json" # Default
-
- validationMetadata = {
- "actionType": "sharepoint.readDocuments",
- "connectionReference": connectionReference,
- "documentCount": len(readResults),
- "includeMetadata": includeMetadata,
- "sitesSearched": len(sites)
- }
-
- return ActionResult(
- success=True,
- documents=[
- ActionDocument(
- documentName=f"sharepoint_documents_{self._format_timestamp_for_filename()}{outputExtension}",
- documentData=json.dumps(resultData, indent=2),
- mimeType=outputMimeType,
- validationMetadata=validationMetadata
- )
- ]
- )
+ # This should never be reached if logic above is correct
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="Unexpected error: could not process documentList or pathQuery")
except Exception as e:
logger.error(f"Error reading SharePoint documents: {str(e)}")
if operationId:
@@ -1685,286 +1473,120 @@ class MethodSharepoint(MethodBase):
"""
GENERAL:
- Purpose: Upload documents to SharePoint. Only to choose this action with a connectionReference
- - Input requirements: connectionReference (required); documentList (required); optional pathObject or pathQuery.
+ - Input requirements: connectionReference (required); documentList (required); pathQuery (optional).
- Output format: JSON with upload status and file info.
Parameters:
- connectionReference (str, required): Microsoft connection label.
- - pathObject (str, optional): Reference to a previous path result.
- - pathQuery (str, optional): Upload target path if no pathObject.
- documentList (list, required): Document reference(s) to upload. File names are taken from the documents.
+ - pathQuery (str, optional): Direct upload target path if documentList doesn't contain findDocumentPath result (e.g., /sites/SiteName/FolderPath).
"""
+ import time
+ operationId = None
try:
+ # Init progress logger
+ workflowId = self.services.workflow.id if self.services.workflow else f"no-workflow-{int(time.time())}"
+ operationId = f"sharepoint_upload_{workflowId}_{int(time.time())}"
+
+ # Start progress tracking
+ parentOperationId = parameters.get('parentOperationId')
+ self.services.chat.progressLogStart(
+ operationId,
+ "Upload Document",
+ "SharePoint Upload",
+ "Processing document list",
+ parentOperationId=parentOperationId
+ )
+
connectionReference = parameters.get("connectionReference")
- pathQuery = parameters.get("pathQuery")
documentList = parameters.get("documentList")
+ pathQuery = parameters.get("pathQuery")
if isinstance(documentList, str):
documentList = [documentList]
- pathObject = parameters.get("pathObject")
- uploadPath = pathQuery
- logger.debug(f"Using pathQuery: {pathQuery}")
+ if not connectionReference:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="Connection reference is required")
- if not connectionReference or not documentList:
- return ActionResult.isFailure(error="Connection reference and document list are required")
+ if not documentList:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="Document list is required")
- # If pathObject is provided, extract folder IDs from it
- if pathObject:
- try:
- # Resolve the reference label to get the actual document list
- from modules.datamodels.datamodelDocref import DocumentReferenceList
- documentList = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list([pathObject]))
- if not documentList or len(documentList) == 0:
- return ActionResult.isFailure(error=f"No document list found for reference: {pathObject}")
-
- # Get the first document's content (which should be the JSON)
- firstDocument = documentList[0]
- fileData = self.services.chat.getFileData(firstDocument.fileId)
- if not fileData:
- return ActionResult.isFailure(error=f"No file data found for document: {pathObject}")
-
- # Parse the JSON content
- resultData = json.loads(fileData)
-
- # Debug: Log the structure of the result document
- logger.info(f"Result document keys: {list(resultData.keys())}")
-
- # Handle different result document formats
- foundDocuments = []
-
- # Check if it's a direct SharePoint result (has foundDocuments)
- if "foundDocuments" in resultData:
- foundDocuments = resultData.get("foundDocuments", [])
- logger.info(f"Found {len(foundDocuments)} documents in foundDocuments array")
- # Check if it's an AI validation result (has result string with validationReport)
- elif "result" in resultData and "validationReport" in resultData["result"]:
- try:
- # Parse the nested JSON in the result field
- nestedResult = json.loads(resultData["result"])
- validationReport = nestedResult.get("validationReport", {})
- documentDetails = validationReport.get("documentDetails", {})
-
- if documentDetails:
- # Convert the single document details to the expected format
- doc = {
- "id": documentDetails.get("id"),
- "name": documentDetails.get("name"),
- "type": documentDetails.get("type", "").lower(), # Convert "Folder" to "folder"
- "siteName": documentDetails.get("siteName"),
- "siteId": documentDetails.get("siteId"),
- "fullPath": documentDetails.get("fullPath"),
- "webUrl": documentDetails.get("webUrl", ""),
- "parentPath": documentDetails.get("parentPath", "")
- }
- foundDocuments = [doc]
- logger.info(f"Extracted 1 document from validation report")
- except json.JSONDecodeError as e:
- logger.error(f"Failed to parse nested JSON in result field: {e}")
- return ActionResult.isFailure(error=f"Invalid nested JSON in pathObject: {str(e)}")
-
- # Debug: Log what we found in the result document
- logger.info(f"Result document contains {len(foundDocuments)} documents")
- for i, doc in enumerate(foundDocuments):
- logger.info(f" Document {i+1}: name='{doc.get('name')}', type='{doc.get('type')}', id='{doc.get('id')}'")
-
- # Extract folder information from the result
- folders = []
- for doc in foundDocuments:
- if doc.get("type") == "folder":
- folders.append(doc)
-
- logger.info(f"Found {len(folders)} folders in result document")
-
- if folders:
- # Use the first folder found - prefer folder ID for direct API calls
- firstFolder = folders[0]
- if firstFolder.get("id"):
- # Use folder ID directly for most reliable API calls
- uploadPath = firstFolder.get("id")
- logger.info(f"Using folder ID from pathObject: {uploadPath}")
- elif firstFolder.get("fullPath"):
- # Extract the correct path portion from fullPath by removing site name
- fullPath = firstFolder.get("fullPath")
- # fullPath format: \\SiteName\\Library\\Folder\\SubFolder
- # We need to remove the first two parts (\\SiteName\\) to get the actual folder path
- pathParts = fullPath.lstrip('\\').split('\\')
- if len(pathParts) > 1:
- # Remove the first part (site name) and reconstruct the path
- actualPath = '\\'.join(pathParts[1:])
- uploadPath = actualPath
- logger.info(f"Extracted path from fullPath: {uploadPath}")
- else:
- uploadPath = fullPath
- logger.info(f"Using full path from pathObject (no site name to remove): {uploadPath}")
- else:
- return ActionResult.isFailure(error="No valid folder information found in pathObject")
- else:
- return ActionResult.isFailure(error="No folders found in pathObject")
-
- except json.JSONDecodeError as e:
- return ActionResult.isFailure(error=f"Invalid JSON in pathObject: {str(e)}")
- except Exception as e:
- return ActionResult.isFailure(error=f"Error resolving pathObject reference: {str(e)}")
+ # Parse documentList to extract folder path and site information
+ uploadPath, sites, filesToUpload, errorMsg = await self._parseDocumentListForFolder(documentList)
+ if errorMsg:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error=errorMsg)
- # Get Microsoft connection
- connection = self._getMicrosoftConnection(connectionReference)
- if not connection:
- return ActionResult.isFailure(error="No valid Microsoft connection found for the provided connection reference")
+ # If no folder path found from documentList, use pathQuery if provided
+ if not uploadPath and pathQuery and pathQuery.strip() != "" and pathQuery.strip() != "*":
+ uploadPath = pathQuery
+ logger.info(f"Using pathQuery for upload path: {uploadPath}")
+ # Resolve sites from pathQuery
+ sites, errorMsg = await self._resolveSitesFromPathQuery(pathQuery)
+ if errorMsg:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error=errorMsg)
- # Get documents from reference - ensure documentList is a list, not a string
- if isinstance(documentList, str):
- documentList = [documentList] # Convert string to list
- from modules.datamodels.datamodelDocref import DocumentReferenceList
- # Convert to DocumentReferenceList if needed
- if isinstance(documentList, DocumentReferenceList):
- docRefList = documentList
- elif isinstance(documentList, list):
- docRefList = DocumentReferenceList.from_string_list(documentList)
- elif isinstance(documentList, str):
- docRefList = DocumentReferenceList.from_string_list([documentList])
- else:
- docRefList = DocumentReferenceList(references=[])
- chatDocuments = self.services.chat.getChatDocumentsFromDocumentList(docRefList)
- if not chatDocuments:
- return ActionResult.isFailure(error="No documents found for the provided reference")
-
- # Determine sites to use based on whether pathObject was provided
- sites = None
- if pathObject:
- # When pathObject is provided, we should have specific site information
- # Extract site information from the pathObject result
- try:
- # Get the site information from the first folder in pathObject
- if 'foundDocuments' in locals() and foundDocuments:
- firstFolder = foundDocuments[0]
- siteName = firstFolder.get("siteName")
- siteId = firstFolder.get("siteId")
-
- if siteName and siteId:
- # Use the specific site from pathObject instead of discovering all sites
- sites = [{
- "id": siteId,
- "displayName": siteName,
- "webUrl": firstFolder.get("webUrl", "")
- }]
- logger.info(f"Using specific site from pathObject: {siteName} (ID: {siteId})")
- else:
- # Site info missing from pathObject - this is an error, not a fallback
- return ActionResult.isFailure(error="Site information missing from pathObject. Cannot determine target site for upload.")
- else:
- # No documents found in pathObject - this is an error
- return ActionResult.isFailure(error="No valid folder information found in pathObject. Cannot determine target site for upload.")
- except Exception as e:
- # Error processing pathObject - this is an error, not a fallback
- return ActionResult.isFailure(error=f"Error processing pathObject: {str(e)}. Cannot determine target site for upload.")
- else:
- # No pathObject provided - check if pathQuery is valid
- if not uploadPath or uploadPath.strip() == "" or uploadPath.strip() == "*":
- return ActionResult.isFailure(error="No valid upload path provided. Either provide pathObject (from findDocumentPath) or a valid pathQuery with specific site information.")
-
- # Validate pathQuery format
- if not uploadPath.startswith('/'):
- return ActionResult.isFailure(error="pathQuery must start with '/' and include site name with Microsoft-standard syntax /sites//... e.g. /sites/company-share/Freigegebene Dokumente/Work")
-
- # Check if uploadPath contains search terms (words without proper path structure)
- validPathPrefixes = ['/sites/', '/Documents', '/documents', '/Shared Documents', '/shared documents']
- if not any(uploadPath.startswith(prefix) for prefix in validPathPrefixes):
- return ActionResult.isFailure(error=f"Invalid pathQuery '{uploadPath}'. This appears to be search terms, not a valid SharePoint path. Use findDocumentPath action first to search for folders, then use the returned folder path as pathQuery.")
-
- # If uploadPath starts with Microsoft-standard /sites/, try to get site directly
- directSite = None
- if uploadPath.startswith('/sites/'):
- parsedPath = self._extractSiteFromStandardPath(uploadPath)
- if parsedPath:
- siteName = parsedPath.get("siteName")
- # Try to get site directly by path (optimization - no need to load all 60 sites)
- directSite = await self._getSiteByStandardPath(siteName)
- if directSite:
- logger.info(f"Got site directly by standard path - no need to discover all sites")
- sites = [directSite]
- else:
- logger.warning(f"Could not get site directly, falling back to site discovery")
-
- # If we didn't get the site directly, use discovery and filtering
- if not directSite:
- # For pathQuery, we need to discover sites to find the specific one
- allSites = await self._discoverSharePointSites()
- if not allSites:
- return ActionResult.isFailure(error="No SharePoint sites found or accessible")
-
- # If uploadPath starts with Microsoft-standard /sites/, extract site name and filter
- if uploadPath.startswith('/sites/'):
- parsedPath = self._extractSiteFromStandardPath(uploadPath)
- if parsedPath:
- siteName = parsedPath.get("siteName")
- # Filter sites by name (case-insensitive substring match)
- sites = self._filterSitesByHint(allSites, siteName)
- if not sites:
- return ActionResult.isFailure(error=f"No SharePoint site found matching '{siteName}'")
- logger.info(f"Filtered to site(s) matching '{siteName}': {[s['displayName'] for s in sites]}")
- else:
- sites = allSites
- else:
- sites = allSites
+ # Validate required parameters
+ if not uploadPath:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="Either documentList must contain findDocumentPath result with folder information, or pathQuery must be provided. Use findDocumentPath first to get upload folder, or provide pathQuery directly.")
if not sites:
- return ActionResult.isFailure(error="No valid target site determined for upload")
-
- # Process upload paths based on whether pathObject was provided
- uploadSiteScope = None
- if not pathObject:
- # Parse the validated pathQuery to extract site and path information
- parsed = self._extractSiteFromStandardPath(uploadPath)
-
- if not parsed:
- return ActionResult.isFailure(error="Invalid uploadPath. Use Microsoft-standard /sites//")
-
- # Find matching site (already filtered above, but ensure we have the right one)
- candidateSites = self._filterSitesByHint(sites, parsed["siteName"]) # substring match
- # Choose exact displayName match if available
- exact = [s for s in candidateSites if (s.get("displayName") or "").strip().lower() == parsed["siteName"].strip().lower()]
- selectedSite = exact[0] if exact else (candidateSites[0] if candidateSites else None)
- if not selectedSite:
- return ActionResult.isFailure(error=f"SharePoint site '{parsed['siteName']}' not found or not accessible")
-
- uploadSiteScope = selectedSite
- # Use the inner path portion as the actual upload target path
- # Remove document library name from path (same logic as listDocuments)
- innerPath = parsed.get('innerPath', '').lstrip('/')
- pathSegments = [s for s in innerPath.split('/') if s.strip()]
- if len(pathSegments) > 1:
- # Path has multiple segments - first might be a library name
- # Try without first segment (assuming it's a library name)
- innerPath = '/'.join(pathSegments[1:])
- logger.info(f"Removed first path segment (potential library name), path changed from '{parsed['innerPath']}' to '{innerPath}'")
- elif len(pathSegments) == 1:
- # Only one segment - if it's a common library-like name, use empty path (root)
- firstSegmentLower = pathSegments[0].lower()
- libraryIndicators = ['document', 'dokument', 'shared', 'freigegeben', 'library', 'bibliothek']
- if any(indicator in firstSegmentLower for indicator in libraryIndicators):
- innerPath = ''
- logger.info(f"First segment '{pathSegments[0]}' appears to be a library name, using root")
-
- uploadPaths = [f"/{innerPath}" if innerPath else "/"]
- sites = [selectedSite]
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="Site information missing. Cannot determine target site for upload.")
+
+ if not filesToUpload:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="No files to upload found in documentList.")
+
+ # Get connection
+ self.services.chat.progressLogUpdate(operationId, 0.3, "Getting Microsoft connection")
+ connection = self._getMicrosoftConnection(connectionReference)
+ if not connection:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="No valid Microsoft connection found for the provided connection reference")
+
+ # Process upload paths
+ uploadPaths = []
+ if uploadPath.startswith('01PPXICCB') or uploadPath.startswith('01'):
+ # It's a folder ID - use it directly
+ uploadPaths = [uploadPath]
+ logger.info(f"Using folder ID directly for upload: {uploadPath}")
else:
- # When using pathObject, check if uploadPath is a folder ID or a path
- if uploadPath.startswith('01PPXICCB') or uploadPath.startswith('01'):
- # It's a folder ID - use it directly
- uploadPaths = [uploadPath]
- logger.info(f"Using folder ID directly for upload: {uploadPath}")
- else:
- # It's a path - resolve it normally
- uploadPaths = self._resolvePathQuery(uploadPath)
+ # It's a path - resolve it normally
+ uploadPaths = self._resolvePathQuery(uploadPath)
# Process each document upload
uploadResults = []
# Extract file names from documents
- fileNames = [doc.fileName for doc in chatDocuments]
+ fileNames = [doc.fileName for doc in filesToUpload]
logger.info(f"Using file names from documentList: {fileNames}")
- for i, (chatDocument, fileName) in enumerate(zip(chatDocuments, fileNames)):
+ self.services.chat.progressLogUpdate(operationId, 0.5, f"Uploading {len(filesToUpload)} document(s)")
+
+ # Process upload paths
+
+ # Process each document upload
+ uploadResults = []
+
+ # Extract file names from documents
+ fileNames = [doc.fileName for doc in filesToUpload]
+ logger.info(f"Using file names from documentList: {fileNames}")
+
+ self.services.chat.progressLogUpdate(operationId, 0.5, f"Uploading {len(filesToUpload)} document(s)")
+
+ for i, (chatDocument, fileName) in enumerate(zip(filesToUpload, fileNames)):
try:
fileId = chatDocument.fileId
fileData = self.services.chat.getFileData(fileId)
@@ -2056,11 +1678,14 @@ class MethodSharepoint(MethodBase):
"error": str(e),
"uploadStatus": "failed"
})
+
+ # Update progress for each file
+ self.services.chat.progressLogUpdate(operationId, 0.5 + (i * 0.4 / len(filesToUpload)), f"Uploaded {i + 1}/{len(filesToUpload)} file(s)")
# Create result data
resultData = {
"connectionReference": connectionReference,
- "pathQuery": uploadPath,
+ "uploadPath": uploadPath,
"documentList": documentList,
"fileNames": fileNames,
"sitesAvailable": len(sites),
@@ -2087,6 +1712,10 @@ class MethodSharepoint(MethodBase):
"failedUploads": len([r for r in uploadResults if r.get("uploadStatus") == "failed"])
}
+ successfulUploads = len([r for r in uploadResults if r.get("uploadStatus") == "success"])
+ self.services.chat.progressLogUpdate(operationId, 0.9, f"Uploaded {successfulUploads}/{len(uploadResults)} file(s)")
+ self.services.chat.progressLogFinish(operationId, successfulUploads > 0)
+
return ActionResult(
success=True,
documents=[
@@ -2101,6 +1730,11 @@ class MethodSharepoint(MethodBase):
except Exception as e:
logger.error(f"Error uploading to SharePoint: {str(e)}")
+ if operationId:
+ try:
+ self.services.chat.progressLogFinish(operationId, False)
+ except:
+ pass
return ActionResult(
success=False,
error=str(e)
@@ -2111,226 +1745,94 @@ class MethodSharepoint(MethodBase):
"""
GENERAL:
- Purpose: List documents and folders in SharePoint paths across sites.
- - Input requirements: connectionReference (required); optional pathObject or pathQuery; includeSubfolders.
+ - Input requirements: connectionReference (required); documentList (required); includeSubfolders (optional).
- Output format: JSON with folder items and metadata.
Parameters:
- connectionReference (str, required): Microsoft connection label.
- - pathObject (str, optional): Reference to a previous path result.
- - pathQuery (str, optional): Path query if no pathObject.
+ - documentList (list, required): Document list reference(s) containing findDocumentPath result.
- includeSubfolders (bool, optional): Include one level of subfolders. Default: False.
"""
+ import time
+ operationId = None
try:
+ # Init progress logger
+ workflowId = self.services.workflow.id if self.services.workflow else f"no-workflow-{int(time.time())}"
+ operationId = f"sharepoint_list_{workflowId}_{int(time.time())}"
+
+ # Start progress tracking
+ parentOperationId = parameters.get('parentOperationId')
+ self.services.chat.progressLogStart(
+ operationId,
+ "List Documents",
+ "SharePoint Listing",
+ "Processing document list",
+ parentOperationId=parentOperationId
+ )
+
connectionReference = parameters.get("connectionReference")
- pathObject = parameters.get("pathObject")
- pathQuery = parameters.get("pathQuery")
+ documentList = parameters.get("documentList")
+ pathQuery = parameters.get("pathQuery", "*")
+ if isinstance(documentList, str):
+ documentList = [documentList]
includeSubfolders = parameters.get("includeSubfolders", False) # Default to False for better UX
- listQuery = pathQuery
- logger.info(f"Using pathQuery: {pathQuery}")
-
if not connectionReference:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
return ActionResult.isFailure(error="Connection reference is required")
- # If pathObject is provided, resolve the reference and extract folder IDs from it
- # Note: pathObject takes precedence over pathQuery when both are provided
- if pathObject:
- if pathQuery and pathQuery != "*":
- logger.debug(f"Both pathObject and pathQuery provided - using pathObject (pathQuery '{pathQuery}' will be ignored)")
- try:
- # Resolve the reference label to get the actual document list
- from modules.datamodels.datamodelDocref import DocumentReferenceList
- documentList = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list([pathObject]))
- if not documentList or len(documentList) == 0:
- return ActionResult.isFailure(error=f"No document list found for reference: {pathObject}")
-
- # Get the first document's content (which should be the JSON)
- firstDocument = documentList[0]
- logger.info(f"Document fileId: {firstDocument.fileId}, fileName: {firstDocument.fileName}")
- fileData = self.services.chat.getFileData(firstDocument.fileId)
- if not fileData:
- return ActionResult.isFailure(error=f"No file data found for document: {pathObject} (fileId: {firstDocument.fileId})")
- logger.info(f"File data length: {len(fileData) if fileData else 0}")
-
- # Parse the JSON content
- resultData = json.loads(fileData)
-
- # Debug: Log the structure of the result document
- logger.info(f"Result document keys: {list(resultData.keys())}")
-
- # Handle different result document formats
- foundDocuments = []
-
- # Check if it's a direct SharePoint result (has foundDocuments)
- if "foundDocuments" in resultData:
- foundDocuments = resultData.get("foundDocuments", [])
- logger.info(f"Found {len(foundDocuments)} documents in foundDocuments array")
- # Check if it's an AI validation result (has result string with validationReport)
- elif "result" in resultData and "validationReport" in resultData["result"]:
- try:
- # Parse the nested JSON in the result field
- nestedResult = json.loads(resultData["result"])
- validationReport = nestedResult.get("validationReport", {})
- documentDetails = validationReport.get("documentDetails", {})
-
- if documentDetails:
- # Convert the single document details to the expected format
- doc = {
- "id": documentDetails.get("id"),
- "name": documentDetails.get("name"),
- "type": documentDetails.get("type", "").lower(), # Convert "Folder" to "folder"
- "siteName": documentDetails.get("siteName"),
- "siteId": documentDetails.get("siteId"),
- "fullPath": documentDetails.get("fullPath"),
- "webUrl": documentDetails.get("webUrl", ""),
- "parentPath": documentDetails.get("parentPath", "")
- }
- foundDocuments = [doc]
- logger.info(f"Extracted 1 document from validation report")
- except ValueError as e:
- logger.error(f"Failed to parse nested JSON in result field: {e}")
- return ActionResult.isFailure(error=f"Invalid nested JSON in pathObject: {str(e)}")
-
- # Debug: Log what we found in the result document
- logger.info(f"Result document contains {len(foundDocuments)} documents")
- for i, doc in enumerate(foundDocuments):
- logger.info(f" Document {i+1}: name='{doc.get('name')}', type='{doc.get('type')}', id='{doc.get('id')}'")
-
- # Extract folder information from the result
- folders = []
- for doc in foundDocuments:
- if doc.get("type") == "folder":
- folders.append(doc)
-
- logger.info(f"Found {len(folders)} folders in result document")
-
- if folders:
- # Use the first folder found - prefer folder ID for direct API calls
- firstFolder = folders[0]
- if firstFolder.get("id"):
- # Use folder ID directly for most reliable API calls
- listQuery = firstFolder.get("id")
- logger.info(f"Using folder ID from pathObject: {listQuery}")
- elif firstFolder.get("fullPath"):
- # Extract the correct path portion from fullPath by removing site name
- fullPath = firstFolder.get("fullPath")
- # fullPath format: \\SiteName\\Library\\Folder\\SubFolder
- # We need to remove the first two parts (\\SiteName\\) to get the actual folder path
- pathParts = fullPath.lstrip('\\').split('\\')
- if len(pathParts) > 1:
- # Remove the first part (site name) and reconstruct the path
- actualPath = '\\'.join(pathParts[1:])
- listQuery = actualPath
- logger.info(f"Extracted path from fullPath: {listQuery}")
- else:
- listQuery = fullPath
- logger.info(f"Using full path from pathObject (no site name to remove): {listQuery}")
- else:
- return ActionResult.isFailure(error="No valid folder information found in pathObject")
- else:
- return ActionResult.isFailure(error="No folders found in pathObject")
-
- except ValueError as e:
- return ActionResult.isFailure(error=f"Invalid JSON in pathObject: {str(e)}")
- except Exception as e:
- return ActionResult.isFailure(error=f"Error resolving pathObject reference: {str(e)}")
+ # Require either documentList or pathQuery
+ if not documentList and (not pathQuery or pathQuery.strip() == "" or pathQuery.strip() == "*"):
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="Either documentList or pathQuery is required")
- # Get Microsoft connection
+ # Parse documentList to extract folder path and site information
+ listQuery, sites, _, errorMsg = await self._parseDocumentListForFolder(documentList)
+ if errorMsg:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error=errorMsg)
+
+ # If no folder path found from documentList, use pathQuery if provided
+ if not listQuery and pathQuery and pathQuery.strip() != "" and pathQuery.strip() != "*":
+ listQuery = pathQuery
+ logger.info(f"Using pathQuery for list query: {listQuery}")
+ # Resolve sites from pathQuery
+ sites, errorMsg = await self._resolveSitesFromPathQuery(pathQuery)
+ if errorMsg:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error=errorMsg)
+
+ # Validate required parameters
+ if not listQuery:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="Either documentList must contain findDocumentPath result with folder information, or pathQuery must be provided. Use findDocumentPath first to get folder path, or provide pathQuery directly.")
+
+ if not sites:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="Site information missing. Cannot determine target site for list operation.")
+
+ # Get connection
+ self.services.chat.progressLogUpdate(operationId, 0.2, "Getting Microsoft connection")
connection = self._getMicrosoftConnection(connectionReference)
if not connection:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
return ActionResult.isFailure(error="No valid Microsoft connection found for the provided connection reference")
logger.info(f"Starting SharePoint listDocuments for listQuery: {listQuery}")
logger.debug(f"Connection ID: {connection['id']}")
+ self.services.chat.progressLogUpdate(operationId, 0.3, "Processing folder path")
+
# Parse listQuery to extract path, search terms, search type, and options
pathQuery, fileQuery, searchType, searchOptions = self._parseSearchQuery(listQuery)
- # Determine sites to use - strict validation: pathObject → pathQuery → ERROR
- sites = None
-
- # Step 1: Check pathObject first
- if pathObject:
- # When pathObject is provided, we should have specific site information
- # Extract site information from the pathObject result
- try:
- # Get the site information from the first folder in pathObject
- if 'foundDocuments' in locals() and foundDocuments:
- firstFolder = foundDocuments[0]
- siteName = firstFolder.get("siteName")
- siteId = firstFolder.get("siteId")
-
- if siteName and siteId:
- # Use the specific site from pathObject instead of discovering all sites
- sites = [{
- "id": siteId,
- "displayName": siteName,
- "webUrl": firstFolder.get("webUrl", "")
- }]
- logger.info(f"Using specific site from pathObject: {siteName} (ID: {siteId})")
- else:
- # Site info missing from pathObject - this is an error
- return ActionResult.isFailure(error="Site information missing from pathObject. Cannot determine target site for list operation.")
- else:
- # No documents found in pathObject - this is an error
- return ActionResult.isFailure(error="No valid folder information found in pathObject. Cannot determine target site for list operation.")
- except Exception as e:
- # Error processing pathObject - this is an error
- return ActionResult.isFailure(error=f"Error processing pathObject: {str(e)}. Cannot determine target site for list operation.")
-
- # Step 2: If no pathObject, check pathQuery
- elif pathQuery and pathQuery.strip() != "" and pathQuery.strip() != "*":
- # Validate pathQuery format
- if not pathQuery.startswith('/'):
- return ActionResult.isFailure(error="pathQuery must start with '/' and include site name with Microsoft-standard syntax /sites//... e.g. /sites/company-share/Freigegebene Dokumente/Work")
-
- # Check if pathQuery contains search terms (words without proper path structure)
- validPathPrefixes = ['/sites/', '/Documents', '/documents', '/Shared Documents', '/shared documents']
- if not any(pathQuery.startswith(prefix) for prefix in validPathPrefixes):
- return ActionResult.isFailure(error=f"Invalid pathQuery '{pathQuery}'. This appears to be search terms, not a valid SharePoint path. Use findDocumentPath action first to search for folders, then use the returned folder path as pathQuery.")
-
- # If pathQuery starts with Microsoft-standard /sites/, try to get site directly
- directSite = None
- if pathQuery.startswith('/sites/'):
- parsedPath = self._extractSiteFromStandardPath(pathQuery)
- if parsedPath:
- siteName = parsedPath.get("siteName")
- # Try to get site directly by path (optimization - no need to load all 60 sites)
- directSite = await self._getSiteByStandardPath(siteName)
- if directSite:
- logger.info(f"Got site directly by standard path - no need to discover all sites")
- sites = [directSite]
- else:
- logger.warning(f"Could not get site directly, falling back to site discovery")
-
- # If we didn't get the site directly, use discovery and filtering
- if not directSite:
- # For pathQuery, we need to discover sites to find the specific one
- allSites = await self._discoverSharePointSites()
- if not allSites:
- return ActionResult.isFailure(error="No SharePoint sites found or accessible")
-
- # If pathQuery starts with Microsoft-standard /sites/, extract site name and filter
- if pathQuery.startswith('/sites/'):
- parsedPath = self._extractSiteFromStandardPath(pathQuery)
- if parsedPath:
- siteName = parsedPath.get("siteName")
- # Filter sites by name (case-insensitive substring match)
- sites = self._filterSitesByHint(allSites, siteName)
- if not sites:
- return ActionResult.isFailure(error=f"No SharePoint site found matching '{siteName}'")
- logger.info(f"Filtered to site(s) matching '{siteName}': {[s['displayName'] for s in sites]}")
- else:
- sites = allSites
- else:
- sites = allSites
- else:
- # Step 3: Both pathObject and pathQuery failed - ERROR, NO FALLBACK
- return ActionResult.isFailure(error="No valid list path provided. Either provide pathObject (from findDocumentPath) or a valid pathQuery with specific site information.")
-
- if not sites:
- return ActionResult.isFailure(error="No valid target site determined for list operation")
-
# Check if listQuery is a folder ID (starts with 01PPXICCB...)
if listQuery.startswith('01PPXICCB') or listQuery.startswith('01'):
# Direct folder ID - use it directly
@@ -2375,6 +1877,8 @@ class MethodSharepoint(MethodBase):
# Process each folder path across all sites
listResults = []
+ self.services.chat.progressLogUpdate(operationId, 0.5, f"Listing {len(folderPaths)} folder(s) across {len(sites)} site(s)")
+
for folderPath in folderPaths:
try:
folderResults = []
@@ -2413,17 +1917,7 @@ class MethodSharepoint(MethodBase):
for item in items:
# Use improved folder detection logic
- isFolder = False
- if 'folder' in item:
- isFolder = True
- else:
- # Try to detect by URL pattern or other indicators
- webUrl = item.get('webUrl', '')
- name = item.get('name', '')
-
- # Check if URL has no file extension and looks like a folder path
- if '.' not in name and ('/' in webUrl or '\\' in webUrl):
- isFolder = True
+ isFolder = self.services.sharepoint.detectFolderType(item)
itemInfo = {
"id": item.get("id"),
@@ -2473,17 +1967,7 @@ class MethodSharepoint(MethodBase):
for subfolderItem in subfolderItems:
# Use improved folder detection logic for subfolder items
- subfolderIsFolder = False
- if 'folder' in subfolderItem:
- subfolderIsFolder = True
- else:
- # Try to detect by URL pattern or other indicators
- subfolderWebUrl = subfolderItem.get('webUrl', '')
- subfolderName = subfolderItem.get('name', '')
-
- # Check if URL has no file extension and looks like a folder path
- if '.' not in subfolderName and ('/' in subfolderWebUrl or '\\' in subfolderWebUrl):
- subfolderIsFolder = True
+ subfolderIsFolder = self.services.sharepoint.detectFolderType(subfolderItem)
# Only add files and direct subfolders, NO RECURSION
subfolderItemInfo = {
@@ -2535,6 +2019,9 @@ class MethodSharepoint(MethodBase):
"siteResults": []
})
+ totalItems = sum(len(result.get("siteResults", [])) for result in listResults)
+ self.services.chat.progressLogUpdate(operationId, 0.9, f"Found {totalItems} item(s)")
+
# Create result data
resultData = {
"pathQuery": listQuery,
@@ -2554,9 +2041,10 @@ class MethodSharepoint(MethodBase):
"includeSubfolders": includeSubfolders,
"sitesSearched": len(sites),
"folderCount": len(listResults),
- "totalItems": sum(len(result.get("siteResults", [])) for result in listResults)
+ "totalItems": totalItems
}
+ self.services.chat.progressLogFinish(operationId, True)
return ActionResult(
success=True,
documents=[
@@ -2571,7 +2059,331 @@ class MethodSharepoint(MethodBase):
except Exception as e:
logger.error(f"Error listing SharePoint documents: {str(e)}")
+ if operationId:
+ try:
+ self.services.chat.progressLogFinish(operationId, False)
+ except:
+ pass
return ActionResult(
success=False,
error=str(e)
- )
\ No newline at end of file
+ )
+
+ @action
+ async def analyzeFolderUsage(self, parameters: Dict[str, Any]) -> ActionResult:
+ """
+ GENERAL:
+ - Purpose: Analyze usage intensity of folders and files in SharePoint.
+ - Input requirements: connectionReference (required); documentList (required); optional startDateTime, endDateTime, interval.
+ - Output format: JSON with usage analytics grouped by time intervals.
+
+ Parameters:
+ - connectionReference (str, required): Microsoft connection label.
+ - documentList (list, required): Document list reference(s) containing findDocumentPath result.
+ - startDateTime (str, optional): Start date/time in ISO format (e.g., "2025-11-01T00:00:00Z"). Default: 30 days ago.
+ - endDateTime (str, optional): End date/time in ISO format (e.g., "2025-11-30T23:59:59Z"). Default: current time.
+ - interval (str, optional): Time interval for grouping activities. Options: "day", "week", "month". Default: "day".
+ """
+ import time
+ operationId = None
+ try:
+ # Init progress logger
+ workflowId = self.services.workflow.id if self.services.workflow else f"no-workflow-{int(time.time())}"
+ operationId = f"sharepoint_usage_{workflowId}_{int(time.time())}"
+
+ # Start progress tracking
+ parentOperationId = parameters.get('parentOperationId')
+ self.services.chat.progressLogStart(
+ operationId,
+ "Analyze Folder Usage",
+ "SharePoint Analytics",
+ "Processing document list",
+ parentOperationId=parentOperationId
+ )
+
+ connectionReference = parameters.get("connectionReference")
+ documentList = parameters.get("documentList")
+ pathQuery = parameters.get("pathQuery")
+ if isinstance(documentList, str):
+ documentList = [documentList]
+ startDateTime = parameters.get("startDateTime")
+ endDateTime = parameters.get("endDateTime")
+ interval = parameters.get("interval", "day")
+
+ if not connectionReference:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="Connection reference is required")
+
+ # Require either documentList or pathQuery
+ if not documentList and (not pathQuery or pathQuery.strip() == "" or pathQuery.strip() == "*"):
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="Either documentList or pathQuery is required")
+
+ # Resolve folder/item information from documentList or pathQuery
+ siteId = None
+ driveId = None
+ itemId = None
+ folderPath = None
+ folderName = None
+
+ if documentList:
+ foundDocuments, sites, errorMsg = await self._parseDocumentListForFoundDocuments(documentList)
+ if errorMsg:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error=errorMsg)
+
+ if not foundDocuments:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="No documents found in documentList")
+
+ # Get siteId from first document (all should be from same site)
+ firstItem = foundDocuments[0]
+ siteId = firstItem.get("siteId")
+ if not siteId:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="Site ID missing from documentList")
+
+ # Get drive ID (needed for analytics)
+ driveId = await self.services.sharepoint.getDriveId(siteId)
+ if not driveId:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="Could not determine drive ID for the site")
+
+ # If no items from documentList, try pathQuery fallback
+ if not foundDocuments and pathQuery and pathQuery.strip() != "" and pathQuery.strip() != "*":
+ sites, errorMsg = await self._resolveSitesFromPathQuery(pathQuery)
+ if errorMsg:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error=errorMsg)
+
+ if sites:
+ siteId = sites[0].get("id")
+ # Parse pathQuery to find the folder/item
+ pathQueryParsed, fileQuery, searchType, searchOptions = self._parseSearchQuery(pathQuery)
+
+ # Extract folder path from pathQuery
+ folderPath = '/'
+ if pathQueryParsed and pathQueryParsed.startswith('/sites/'):
+ parsedPath = self._extractSiteFromStandardPath(pathQueryParsed)
+ if parsedPath:
+ innerPath = parsedPath.get("innerPath", "")
+ folderPath = '/' + innerPath if innerPath else '/'
+ elif pathQueryParsed:
+ folderPath = pathQueryParsed
+
+ # Get drive ID
+ driveId = await self.services.sharepoint.getDriveId(siteId)
+ if not driveId:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="Could not determine drive ID for the site")
+
+ # Get folder/item by path
+ folderInfo = await self.services.sharepoint.getFolderByPath(siteId, folderPath.lstrip('/'))
+ if not folderInfo:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error=f"Folder or file not found at path: {folderPath}")
+
+ # Add pathQuery item to foundDocuments for processing
+ foundDocuments = [{
+ "id": folderInfo.get("id"),
+ "name": folderInfo.get("name", ""),
+ "type": "folder" if folderInfo.get("folder") else "file",
+ "siteId": siteId,
+ "fullPath": folderPath,
+ "webUrl": folderInfo.get("webUrl", "")
+ }]
+
+ if not siteId or not driveId:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="Either documentList must contain findDocumentPath result with folder information, or pathQuery must be provided. Use findDocumentPath first to get folder path, or provide pathQuery directly.")
+
+ self.services.chat.progressLogUpdate(operationId, 0.2, "Getting Microsoft connection")
+ # Get Microsoft connection
+ connection = self._getMicrosoftConnection(connectionReference)
+ if not connection:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="No valid Microsoft connection found for the provided connection reference")
+
+ # Set access token
+ if not self.services.sharepoint.setAccessTokenFromConnection(connection):
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="Failed to set SharePoint access token")
+
+ # Process all items from documentList or pathQuery
+ # IMPORTANT: Only analyze FOLDERS, not files (action is "analyzeFolderUsage")
+ itemsToAnalyze = []
+ if foundDocuments:
+ for item in foundDocuments:
+ itemId = item.get("id")
+ itemType = item.get("type", "").lower()
+
+ # Only process folders, skip files and site-level items
+ if itemId and itemType == "folder":
+ itemsToAnalyze.append({
+ "id": itemId,
+ "name": item.get("name", ""),
+ "type": itemType,
+ "path": item.get("fullPath", ""),
+ "webUrl": item.get("webUrl", "")
+ })
+
+ if not itemsToAnalyze:
+ if operationId:
+ self.services.chat.progressLogFinish(operationId, False)
+ return ActionResult.isFailure(error="No valid folders found in documentList to analyze. Note: This action only analyzes folders, not files.")
+
+ self.services.chat.progressLogUpdate(operationId, 0.4, f"Analyzing {len(itemsToAnalyze)} folder(s)")
+
+ # Analyze each item
+ allAnalytics = []
+ totalActivities = 0
+ uniqueUsers = set()
+ activityTypes = {}
+
+ # Compute actual date range values (getFolderUsageAnalytics will set defaults if None)
+ # We need to compute them here to store in output, since getFolderUsageAnalytics modifies them
+ actualStartDateTime = startDateTime
+ actualEndDateTime = endDateTime
+ if not actualEndDateTime:
+ actualEndDateTime = datetime.now(timezone.utc).isoformat().replace('+00:00', 'Z')
+ if not actualStartDateTime:
+ startDate = datetime.now(timezone.utc) - timedelta(days=30)
+ actualStartDateTime = startDate.isoformat().replace('+00:00', 'Z')
+
+ for idx, item in enumerate(itemsToAnalyze):
+ progress = 0.4 + (idx / len(itemsToAnalyze)) * 0.5
+ self.services.chat.progressLogUpdate(operationId, progress, f"Analyzing folder {item['name']} ({idx+1}/{len(itemsToAnalyze)})")
+
+ # Get usage analytics for this folder
+ analyticsResult = await self.services.sharepoint.getFolderUsageAnalytics(
+ siteId=siteId,
+ driveId=driveId,
+ itemId=item["id"],
+ startDateTime=startDateTime,
+ endDateTime=endDateTime,
+ interval=interval
+ )
+
+ if "error" in analyticsResult:
+ logger.warning(f"Failed to get analytics for item {item['name']} ({item['id']}): {analyticsResult['error']}")
+ # Continue with other items even if one fails
+ itemAnalytics = {
+ "itemId": item["id"],
+ "itemName": item["name"],
+ "itemType": item["type"],
+ "itemPath": item["path"],
+ "error": analyticsResult.get("error", "Unknown error")
+ }
+ else:
+ # Process analytics for this item
+ itemActivities = 0
+ itemUsers = set()
+ itemActivityTypes = {}
+
+ if "value" in analyticsResult:
+ for intervalData in analyticsResult["value"]:
+ activities = intervalData.get("activities", [])
+ for activity in activities:
+ itemActivities += 1
+ totalActivities += 1
+
+ action = activity.get("action", {})
+ actionType = action.get("verb", "unknown")
+ itemActivityTypes[actionType] = itemActivityTypes.get(actionType, 0) + 1
+ activityTypes[actionType] = activityTypes.get(actionType, 0) + 1
+
+ actor = activity.get("actor", {})
+ userPrincipalName = actor.get("userPrincipalName", "")
+ if userPrincipalName:
+ itemUsers.add(userPrincipalName)
+ uniqueUsers.add(userPrincipalName)
+
+ itemAnalytics = {
+ "itemId": item["id"],
+ "itemName": item["name"],
+ "itemType": item["type"],
+ "itemPath": item["path"],
+ "webUrl": item["webUrl"],
+ "analytics": analyticsResult,
+ "summary": {
+ "totalActivities": itemActivities,
+ "uniqueUsers": len(itemUsers),
+ "activityTypes": itemActivityTypes
+ }
+ }
+
+ # Include note if analytics are not available
+ if "note" in analyticsResult:
+ itemAnalytics["note"] = analyticsResult["note"]
+
+ allAnalytics.append(itemAnalytics)
+
+ self.services.chat.progressLogUpdate(operationId, 0.9, "Processing analytics data")
+
+ # Process and format analytics data
+ resultData = {
+ "siteId": siteId,
+ "driveId": driveId,
+ "startDateTime": actualStartDateTime, # Store computed date range (not None)
+ "endDateTime": actualEndDateTime, # Store computed date range (not None)
+ "interval": interval,
+ "itemsAnalyzed": len(itemsToAnalyze),
+ "foldersAnalyzed": len([item for item in allAnalytics if item.get("itemType") == "folder"]),
+ "items": allAnalytics,
+ "summary": {
+ "totalActivities": totalActivities,
+ "uniqueUsers": len(uniqueUsers),
+ "activityTypes": activityTypes
+ },
+ "note": f"Analyzed {len(itemsToAnalyze)} folder(s) from {actualStartDateTime} to {actualEndDateTime}. " +
+ f"Found {totalActivities} total activities across {len(uniqueUsers)} unique user(s)." +
+ (f" Note: {len([item for item in allAnalytics if 'error' in item])} folder(s) had errors or no analytics data available." if any('error' in item for item in allAnalytics) else ""),
+ "timestamp": self.services.utils.timestampGetUtc()
+ }
+
+ self.services.chat.progressLogUpdate(operationId, 0.95, f"Found {totalActivities} total activities across {len(itemsToAnalyze)} folder(s)")
+
+ validationMetadata = {
+ "actionType": "sharepoint.analyzeFolderUsage",
+ "itemsAnalyzed": len(itemsToAnalyze),
+ "interval": interval,
+ "totalActivities": totalActivities,
+ "uniqueUsers": len(uniqueUsers)
+ }
+
+ self.services.chat.progressLogFinish(operationId, True)
+ return ActionResult(
+ success=True,
+ documents=[
+ ActionDocument(
+ documentName=f"sharepoint_usage_analysis_{self._format_timestamp_for_filename()}.json",
+ documentData=json.dumps(resultData, indent=2),
+ mimeType="application/json",
+ validationMetadata=validationMetadata
+ )
+ ]
+ )
+
+ except Exception as e:
+ logger.error(f"Error analyzing folder usage: {str(e)}")
+ if operationId:
+ try:
+ self.services.chat.progressLogFinish(operationId, False)
+ except:
+ pass
+ return ActionResult(
+ success=False,
+ error=str(e)
+ )
\ No newline at end of file
diff --git a/modules/workflows/processing/core/actionExecutor.py b/modules/workflows/processing/core/actionExecutor.py
index f9af58e7..f183c0e4 100644
--- a/modules/workflows/processing/core/actionExecutor.py
+++ b/modules/workflows/processing/core/actionExecutor.py
@@ -82,6 +82,35 @@ class ActionExecutor:
enhancedParameters['expectedDocumentFormats'] = action.expectedDocumentFormats
logger.info(f"Expected formats: {action.expectedDocumentFormats}")
+ # Get current task execution operationId to pass as parent to action methods
+ # This MUST be the "Service Workflow Execution" operation ID (taskExec_*)
+ parentOperationId = None
+ try:
+ progressLogger = self.services.chat.createProgressLogger()
+ activeOperations = progressLogger.getActiveOperations()
+ logger.debug(f"Looking for parent operation ID. Active operations: {list(activeOperations.keys())}")
+
+ # Look for task execution operation (starts with "taskExec_")
+ # This is the "Service Workflow Execution" level that should be parent of ALL actions
+ for opId in activeOperations.keys():
+ if opId.startswith("taskExec_"):
+ parentOperationId = opId
+ logger.info(f"Found parent operation ID: {parentOperationId} for action {action.execMethod}.{action.execAction}")
+ break
+
+ if not parentOperationId:
+ logger.warning(f"No taskExec_ operation found in active operations. Active operations: {list(activeOperations.keys())}")
+ except Exception as e:
+ logger.error(f"Error getting parent operation ID: {str(e)}")
+
+ # Add parentOperationId to parameters so action methods can use it
+ # This is critical for UI dashboard hierarchical display
+ if parentOperationId:
+ enhancedParameters['parentOperationId'] = parentOperationId
+ logger.info(f"Passing parentOperationId '{parentOperationId}' to action {action.execMethod}.{action.execAction}")
+ else:
+ logger.warning(f"WARNING: No parentOperationId found for action {action.execMethod}.{action.execAction}. Action logs will appear at root level!")
+
# Check workflow status before executing the action
checkWorkflowStopped(self.services)
From 54246745a98ec6c81e5cec9783ed674a651e6d78 Mon Sep 17 00:00:00 2001
From: ValueOn AG
Date: Sun, 7 Dec 2025 13:48:39 +0100
Subject: [PATCH 2/6] refactored uam to rbac
---
app.py | 3 +
import_map_analysis.md | 247 +++++++
modules/connectors/connectorDbJson.py | 678 ------------------
modules/connectors/connectorDbPostgre.py | 207 +++++-
modules/datamodels/datamodelRbac.py | 102 +++
modules/datamodels/datamodelUam.py | 46 +-
modules/features/automation/mainAutomation.py | 8 +-
modules/interfaces/interfaceBootstrap.py | 548 ++++++++++++++
modules/interfaces/interfaceDbAppAccess.py | 254 -------
modules/interfaces/interfaceDbAppObjects.py | 408 +++++++----
modules/interfaces/interfaceDbChatAccess.py | 140 ----
modules/interfaces/interfaceDbChatObjects.py | 187 +++--
.../interfaces/interfaceDbComponentAccess.py | 203 ------
.../interfaces/interfaceDbComponentObjects.py | 143 ++--
modules/migration/__init__.py | 1 +
modules/migration/migrateUamToRbac.py | 212 ++++++
modules/routes/routeDataFiles.py | 4 +-
modules/routes/routeRbac.py | 161 +++++
modules/routes/routeWorkflows.py | 8 +-
modules/security/rbac.py | 194 +++++
modules/shared/rbacHelpers.py | 178 +++++
pytest.ini | 2 +-
tests/integration/rbac/README.md | 42 ++
tests/integration/rbac/__init__.py | 1 +
tests/integration/rbac/test_rbac_database.py | 209 ++++++
tests/integration/rbac/test_rbac_migration.py | 282 ++++++++
tests/unit/rbac/README.md | 47 ++
tests/unit/rbac/__init__.py | 1 +
tests/unit/rbac/test_rbac_bootstrap.py | 162 +++++
tests/unit/rbac/test_rbac_permissions.py | 403 +++++++++++
30 files changed, 3507 insertions(+), 1574 deletions(-)
create mode 100644 import_map_analysis.md
delete mode 100644 modules/connectors/connectorDbJson.py
create mode 100644 modules/datamodels/datamodelRbac.py
create mode 100644 modules/interfaces/interfaceBootstrap.py
delete mode 100644 modules/interfaces/interfaceDbAppAccess.py
delete mode 100644 modules/interfaces/interfaceDbChatAccess.py
delete mode 100644 modules/interfaces/interfaceDbComponentAccess.py
create mode 100644 modules/migration/__init__.py
create mode 100644 modules/migration/migrateUamToRbac.py
create mode 100644 modules/routes/routeRbac.py
create mode 100644 modules/security/rbac.py
create mode 100644 modules/shared/rbacHelpers.py
create mode 100644 tests/integration/rbac/README.md
create mode 100644 tests/integration/rbac/__init__.py
create mode 100644 tests/integration/rbac/test_rbac_database.py
create mode 100644 tests/integration/rbac/test_rbac_migration.py
create mode 100644 tests/unit/rbac/README.md
create mode 100644 tests/unit/rbac/__init__.py
create mode 100644 tests/unit/rbac/test_rbac_bootstrap.py
create mode 100644 tests/unit/rbac/test_rbac_permissions.py
diff --git a/app.py b/app.py
index 9ace64b5..23a8cb5c 100644
--- a/app.py
+++ b/app.py
@@ -437,3 +437,6 @@ app.include_router(automationRouter)
from modules.routes.routeAdminAutomationEvents import router as adminAutomationEventsRouter
app.include_router(adminAutomationEventsRouter)
+from modules.routes.routeRbac import router as rbacRouter
+app.include_router(rbacRouter)
+
diff --git a/import_map_analysis.md b/import_map_analysis.md
new file mode 100644
index 00000000..4074d1a7
--- /dev/null
+++ b/import_map_analysis.md
@@ -0,0 +1,247 @@
+# Import Map Analysis: interfaces ↔ connectors ↔ security
+
+## Overview
+This document maps all imports between `modules/interfaces`, `modules/connectors`, and `modules/security` to identify structural issues, circular dependencies, and architectural concerns.
+
+**Architectural Principle:**
+- ✅ Connectors (infrastructure) can import from Security (infrastructure)
+- ✅ Interfaces (business logic) can import from Security (infrastructure)
+- ✅ Interfaces (business logic) can import from Connectors (infrastructure)
+- ❌ Connectors should NOT import from Interfaces (business logic)
+
+---
+
+## Import Dependencies Map
+
+### **CONNECTORS → SECURITY**
+
+#### `connectorDbPostgre.py`
+- **Imports from security:**
+ - `from modules.security.rbac import RbacClass` (line 13)
+ - **Usage:**
+ - **Runtime instantiation:** `RbacClass(self)` in `getRecordsetWithRBAC()` (line 1073)
+ - Creates `RbacClass` instance to get user permissions
+ - **Status:** ✅ **ARCHITECTURALLY CORRECT** - Connectors can import from security module
+
+---
+
+### **SECURITY → CONNECTORS**
+
+#### `security/rbac.py` (moved from `interfaces/interfaceRbac.py`)
+- **Imports from connectors:**
+ - `from modules.connectors.connectorDbPostgre import DatabaseConnector` (line 11, inside TYPE_CHECKING)
+ - **Usage:** Type hint only (`db: "DatabaseConnector"`)
+ - **Status:** ✅ Fixed with TYPE_CHECKING to avoid circular import
+ - **Architecture:** ✅ Correct - Security module can import from connectors (infrastructure layer)
+
+### **INTERFACES → CONNECTORS**
+
+#### `interfaceBootstrap.py`
+- **Imports from connectors:**
+ - `from modules.connectors.connectorDbPostgre import DatabaseConnector` (line 9)
+ - **Usage:** Function parameter types (`initBootstrap(db: DatabaseConnector)`)
+
+#### `interfaceDbAppObjects.py`
+- **Imports from connectors:**
+ - `from modules.connectors.connectorDbPostgre import DatabaseConnector` (line 12)
+ - **Usage:** Class initialization (`self.db: DatabaseConnector`)
+- **Imports from security:**
+ - `from modules.security.rbac import RbacClass` (line 17)
+ - **Usage:** RBAC permission checking
+ - **Architecture:** ✅ Correct - Interfaces can import from security (infrastructure layer)
+
+#### `interfaceDbChatObjects.py`
+- **Imports from connectors:**
+ - `from modules.connectors.connectorDbPostgre import DatabaseConnector` (line 29)
+ - **Usage:** Class initialization
+
+#### `interfaceDbComponentObjects.py`
+- **Imports from connectors:**
+ - `from modules.connectors.connectorDbPostgre import DatabaseConnector` (line 13)
+ - **Usage:** Class initialization
+
+#### `interfaceVoiceObjects.py`
+- **Imports from connectors:**
+ - `from modules.connectors.connectorVoiceGoogle import ConnectorGoogleSpeech` (line 10)
+ - **Usage:** Class initialization
+
+---
+
+## Circular Dependency Analysis
+
+### **CIRCULAR DEPENDENCY #1: RESOLVED** ✅
+```
+connectorDbPostgre.py (line 13)
+ └─> imports RbacClass from security.rbac
+ └─> Uses: RbacClass(self) at runtime (line 1073)
+
+security/rbac.py (line 11, inside TYPE_CHECKING)
+ └─> imports DatabaseConnector (type hint only)
+```
+
+**Status:** ✅ **RESOLVED** by moving RBAC to security module + `TYPE_CHECKING`
+
+**Architectural Fix:**
+- Moved `interfaceRbac.py` → `security/rbac.py`
+- Connectors can import from security (infrastructure layer)
+- Interfaces can import from security (business logic layer)
+- No architectural violation: security is shared infrastructure
+
+**Solution Applied:**
+```python
+# security/rbac.py
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from modules.connectors.connectorDbPostgre import DatabaseConnector
+
+class RbacClass:
+ def __init__(self, db: "DatabaseConnector"): # String annotation
+ self.db = db # Uses db at runtime, but import is deferred
+```
+
+**Why This Works:**
+1. At **import time**: `connectorDbPostgre` imports `RbacClass` ✅
+2. `RbacClass` tries to import `DatabaseConnector` but it's inside `TYPE_CHECKING`, so **no actual import occurs** ✅
+3. At **runtime**: When `getRecordsetWithRBAC()` calls `RbacClass(self)`, `DatabaseConnector` is already fully loaded ✅
+4. Runtime circular reference is safe because Python objects can reference each other once loaded
+
+---
+
+## Architecture Analysis
+
+### **Current Structure**
+
+```
+┌─────────────────────────────────────────────────────────────┐
+│ CONNECTORS │
+│ (Database, External Services) │
+│ │
+│ connectorDbPostgre.py │
+│ └─> Uses: RbacClass (runtime instantiation) ⚠️ │
+│ │
+│ connectorVoiceGoogle.py │
+│ connectorTicketsClickup.py │
+│ connectorTicketsJira.py │
+└─────────────────────────────────────────────────────────────┘
+ ▲
+ │ imports
+ │
+┌─────────────────────────────────────────────────────────────┐
+│ INTERFACES │
+│ (Business Logic, Data Access Layer) │
+│ │
+│ security/rbac.py (moved from interfaces) │
+│ └─> Uses: DatabaseConnector (type hint only) ✅ │
+│ └─> Can be imported by both connectors and interfaces │
+│ │
+│ interfaceBootstrap.py │
+│ └─> Uses: DatabaseConnector │
+│ │
+│ interfaceDbAppObjects.py │
+│ └─> Uses: DatabaseConnector │
+│ └─> Uses: security.rbac.RbacClass │
+│ └─> Uses: interfaceBootstrap.initBootstrap │
+│ │
+│ interfaceDbChatObjects.py │
+│ └─> Uses: DatabaseConnector │
+│ │
+│ interfaceDbComponentObjects.py │
+│ └─> Uses: DatabaseConnector │
+│ │
+│ interfaceVoiceObjects.py │
+│ └─> Uses: connectorVoiceGoogle.ConnectorGoogleSpeech │
+└─────────────────────────────────────────────────────────────┘
+```
+
+---
+
+## Potential Issues & Recommendations
+
+### ✅ **RESOLVED ISSUES**
+
+1. **Circular Import: security.rbac ↔ connectorDbPostgre**
+ - **Status:** ✅ Resolved by moving to security module + TYPE_CHECKING
+ - **Impact:** None - Proper architectural layering maintained
+
+### ⚠️ **POTENTIAL ISSUES**
+
+1. **Tight Coupling: Interfaces depend on specific connectors**
+ - **Issue:** `interfaceDbAppObjects.py` directly imports `DatabaseConnector`
+ - **Impact:** Makes it harder to swap database implementations
+ - **Recommendation:** Consider dependency injection or abstract base class
+
+2. **Connector importing from Security (connectorDbPostgre → security.rbac)** ✅
+ - **Status:** ✅ **RESOLVED** - Moved RBAC to security module
+ - **Current Usage:** Runtime instantiation in `getRecordsetWithRBAC()` (line 1073)
+ - **Code:**
+ ```python
+ RbacInstance = RbacClass(self)
+ permissions = RbacInstance.getUserPermissions(...)
+ ```
+ - **Architecture:** ✅ Correct - Connectors can import from security (infrastructure layer)
+ - **Rationale:** Security is shared infrastructure, not business logic
+
+3. **Multiple interfaces importing same connector**
+ - **Files importing DatabaseConnector:**
+ - `interfaceBootstrap.py`
+ - `interfaceDbAppObjects.py`
+ - `interfaceDbChatObjects.py`
+ - `interfaceDbComponentObjects.py`
+ - **Impact:** Medium - creates coupling
+ - **Recommendation:** Consider a shared database interface abstraction
+
+---
+
+## Recommendations
+
+### **1. Move RBAC Logic Out of Connector**
+**Current:** `connectorDbPostgre.getRecordsetWithRBAC()` instantiates `RbacClass(self)` at runtime
+**Recommendation:**
+- ~~Move `getRecordsetWithRBAC()` to `interfaceRbac.py` or `interfaceDbAppObjects.py`~~ ✅ **RESOLVED** - RBAC moved to security module
+- Connector should only handle raw database operations
+- Interface layer handles RBAC filtering
+
+### **2. Use Dependency Injection**
+**Current:** Interfaces directly import `DatabaseConnector`
+**Recommendation:**
+- Create abstract base class `DatabaseConnectorBase`
+- Interfaces depend on abstraction, not concrete implementation
+- Allows easier testing and swapping implementations
+
+### **3. Consider Layered Architecture**
+```
+┌─────────────────────────────────────┐
+│ Interfaces (Business Logic) │
+│ - Uses connectors via abstraction │
+└─────────────────────────────────────┘
+ ▲
+ │
+┌─────────────────────────────────────┐
+│ Connectors (Infrastructure) │
+│ - No knowledge of interfaces │
+└─────────────────────────────────────┘
+```
+
+### **4. Use TYPE_CHECKING for All Type-Only Imports**
+**Current:** `security/rbac.py` uses TYPE_CHECKING (moved from interfaces)
+**Recommendation:** Use TYPE_CHECKING for all type-only imports between layers
+
+---
+
+## Summary
+
+### **Current State:**
+- ✅ 1 circular dependency **RESOLVED** (moved to security module)
+- ✅ Architectural violation **FIXED** (RBAC moved to security)
+- ⚠️ Multiple tight couplings to `DatabaseConnector` (acceptable for now)
+
+### **Architectural Health:**
+- **Overall:** 🟢 **Good** - Proper layering maintained
+- **Architecture:** ✅ Connectors → Security (infrastructure) ✅ Interfaces → Security (infrastructure)
+- **Risk Level:** Low - Clean separation of concerns
+
+### **Completed Actions:**
+1. ✅ **DONE:** Fixed circular import with TYPE_CHECKING
+2. ✅ **DONE:** Moved RBAC to security module (proper architectural layering)
+3. 🔄 **OPTIONAL:** Introduce abstraction layer for database connector (future improvement)
diff --git a/modules/connectors/connectorDbJson.py b/modules/connectors/connectorDbJson.py
deleted file mode 100644
index 0b44e6df..00000000
--- a/modules/connectors/connectorDbJson.py
+++ /dev/null
@@ -1,678 +0,0 @@
-import json
-import os
-from typing import List, Dict, Any, Optional, TypedDict
-import logging
-import uuid
-from pydantic import BaseModel
-import threading
-import time
-
-from modules.shared.timeUtils import getUtcTimestamp
-
-logger = logging.getLogger(__name__)
-
-class TableCache(TypedDict):
- """Type definition for table cache entries"""
- recordIds: List[str]
-
-class DatabaseConnector:
- """
- A connector for JSON-based data storage.
- Provides generic database operations without user/mandate filtering.
- Stores tables as folders and records as individual files.
- """
- def __init__(self, dbHost: str, dbDatabase: str, dbUser: str = None, dbPassword: str = None, userId: str = None):
- # Store the input parameters
- self.dbHost = dbHost
- self.dbDatabase = dbDatabase
- self.dbUser = dbUser
- self.dbPassword = dbPassword
-
- # Set userId (default to empty string if None)
- self.userId = userId if userId is not None else ""
-
- # Initialize database system
- self.initDbSystem()
-
- # Set up database folder path
- self.dbFolder = os.path.join(self.dbHost, self.dbDatabase)
-
- # Cache for loaded data
- self._tablesCache: Dict[str, List[Dict[str, Any]]] = {}
- self._tableMetadataCache: Dict[str, TableCache] = {} # Cache for table metadata (record IDs, etc.)
-
- # File locks with timeout protection
- self._file_locks = {}
- self._lock_manager = threading.Lock()
- self._lock_timeouts = {} # Track when locks were acquired
-
- # Initialize system table
- self._systemTableName = "_system"
- self._initializeSystemTable()
-
- logger.debug(f"Context: userId={self.userId}")
-
- def initDbSystem(self):
- """Initialize the database system - creates necessary directories and structure."""
- try:
- # Ensure the database directory exists
- self.dbFolder = os.path.join(self.dbHost, self.dbDatabase)
- os.makedirs(self.dbFolder, exist_ok=True)
- logger.info(f"Database system initialized: {self.dbFolder}")
- except Exception as e:
- logger.error(f"Error initializing database system: {e}")
- raise
-
- def _initializeSystemTable(self):
- """Initializes the system table if it doesn't exist yet."""
- systemTablePath = self._getTablePath(self._systemTableName)
- if not os.path.exists(systemTablePath):
- emptySystemTable = {}
- self._saveSystemTable(emptySystemTable)
- logger.info(f"System table initialized in {systemTablePath}")
- else:
- # Load existing system table to ensure it's available
- self._loadSystemTable()
- logger.debug(f"Existing system table loaded from {systemTablePath}")
-
- def _loadSystemTable(self) -> Dict[str, str]:
- """Loads the system table with the initial IDs."""
- # Check if system table is in cache
- if f"_{self._systemTableName}" in self._tablesCache:
- return self._tablesCache[f"_{self._systemTableName}"]
-
- systemTablePath = self._getTablePath(self._systemTableName)
- try:
- if os.path.exists(systemTablePath):
- with open(systemTablePath, 'r', encoding='utf-8') as f:
- data = json.load(f)
- # Store in cache with special prefix to avoid collision with regular tables
- self._tablesCache[f"_{self._systemTableName}"] = data
- return data
- else:
- self._tablesCache[f"_{self._systemTableName}"] = {}
- return {}
- except Exception as e:
- logger.error(f"Error loading the system table: {e}")
- self._tablesCache[f"_{self._systemTableName}"] = {}
- return {}
-
- def _saveSystemTable(self, data: Dict[str, str]) -> bool:
- """Saves the system table with the initial IDs."""
- systemTablePath = self._getTablePath(self._systemTableName)
- try:
- with open(systemTablePath, 'w', encoding='utf-8') as f:
- json.dump(data, f, indent=2, ensure_ascii=False)
- # Update cache
- self._tablesCache[f"_{self._systemTableName}"] = data
- return True
- except Exception as e:
- logger.error(f"Error saving the system table: {e}")
- return False
-
- def _getTablePath(self, table: str) -> str:
- """Returns the full path to a table folder"""
- return os.path.join(self.dbFolder, table)
-
- def _getRecordPath(self, table: str, recordId: str) -> str:
- """Returns the full path to a record file"""
- return os.path.join(self._getTablePath(table), f"{recordId}.json")
-
- def _get_file_lock(self, filepath: str, timeout_seconds: int = 30):
- """Get file lock with timeout protection"""
- with self._lock_manager:
- if filepath not in self._file_locks:
- self._file_locks[filepath] = threading.Lock()
-
- lock = self._file_locks[filepath]
-
- # Check if lock is stale (held too long)
- if filepath in self._lock_timeouts:
- lock_age = time.time() - self._lock_timeouts[filepath]
- if lock_age > timeout_seconds:
- logger.warning(f"Stale lock detected for {filepath}, age: {lock_age}s")
- # Force release stale lock
- try:
- lock.release()
- except:
- pass
- # Create new lock
- self._file_locks[filepath] = threading.Lock()
- lock = self._file_locks[filepath]
-
- return lock
-
- def _get_table_lock(self, table: str, timeout_seconds: int = 30):
- """Get table-level lock for metadata operations"""
- table_lock_key = f"table_{table}"
- return self._get_file_lock(table_lock_key, timeout_seconds)
-
- def _ensureTableDirectory(self, table: str) -> bool:
- """Ensures the table directory exists."""
- if table == self._systemTableName:
- return True
-
- tablePath = self._getTablePath(table)
- try:
- os.makedirs(tablePath, exist_ok=True)
- return True
- except Exception as e:
- logger.error(f"Error creating table directory {tablePath}: {e}")
- return False
-
- def _loadTableMetadata(self, table: str) -> Dict[str, Any]:
- """Loads table metadata (list of record IDs) without loading actual records.
- NOTE: This method is safe to call without additional locking.
- """
- if table in self._tableMetadataCache:
- return self._tableMetadataCache[table]
-
- # Ensure table directory exists
- if not self._ensureTableDirectory(table):
- return {"recordIds": []}
-
- tablePath = self._getTablePath(table)
- metadata = {"recordIds": []}
-
- try:
- if os.path.exists(tablePath):
- for fileName in os.listdir(tablePath):
- if fileName.endswith('.json') and fileName != '_metadata.json':
- recordId = fileName[:-5] # Remove .json extension
- metadata["recordIds"].append(recordId)
-
- metadata["recordIds"].sort()
- self._tableMetadataCache[table] = metadata
- except Exception as e:
- logger.error(f"Error loading table metadata for {table}: {e}")
-
- return metadata
-
- def _loadRecord(self, table: str, recordId: str) -> Optional[Dict[str, Any]]:
- """Loads a single record from the table."""
- recordPath = self._getRecordPath(table, recordId)
- try:
- if os.path.exists(recordPath):
- with open(recordPath, 'r', encoding='utf-8') as f:
- record = json.load(f)
- return record
- except Exception as e:
- logger.error(f"Error loading record {recordId} from table {table}: {e}")
- return None
-
- def _saveRecord(self, table: str, recordId: str, record: Dict[str, Any]) -> bool:
- """Saves a single record to the table with atomic metadata operations."""
- recordPath = self._getRecordPath(table, recordId)
- record_lock = self._get_file_lock(recordPath)
- table_lock = self._get_table_lock(table)
-
- try:
- # Acquire both locks with timeout - record lock first, then table lock
- if not record_lock.acquire(timeout=30):
- raise TimeoutError(f"Could not acquire record lock for {recordPath} within 30 seconds")
-
- if not table_lock.acquire(timeout=30):
- record_lock.release()
- raise TimeoutError(f"Could not acquire table lock for {table} within 30 seconds")
-
- # Record lock acquisition time
- self._lock_timeouts[recordPath] = time.time()
- self._lock_timeouts[f"table_{table}"] = time.time()
-
- # Ensure table directory exists
- if not self._ensureTableDirectory(table):
- raise ValueError(f"Error creating table directory for {table}")
-
- # Ensure recordId is a string
- recordId = str(recordId)
-
- # CRITICAL: Ensure record ID matches the file name
- if "id" in record and str(record["id"]) != recordId:
- logger.error(f"Record ID mismatch: file name ID ({recordId}) does not match record ID ({record['id']})")
- raise ValueError(f"Record ID mismatch: file name ID ({recordId}) does not match record ID ({record['id']})")
-
- # Add metadata
- currentTime = getUtcTimestamp()
- if "_createdAt" not in record:
- record["_createdAt"] = currentTime
- record["_createdBy"] = self.userId
- record["_modifiedAt"] = currentTime
- record["_modifiedBy"] = self.userId
-
- # Save the record file using atomic write
- tempPath = recordPath + '.tmp'
-
- # Ensure directory exists
- os.makedirs(os.path.dirname(recordPath), exist_ok=True)
-
- # Write to temporary file first
- with open(tempPath, 'w', encoding='utf-8') as f:
- json.dump(record, f, indent=2, ensure_ascii=False)
-
- # Verify the temporary file can be read back (validation)
- try:
- with open(tempPath, 'r', encoding='utf-8') as f:
- json.load(f) # This will fail if file is corrupted
- except Exception as e:
- logger.error(f"Validation failed for record {recordId}: {e}")
- # Clean up temp file
- if os.path.exists(tempPath):
- os.remove(tempPath)
- raise ValueError(f"Record validation failed: {e}")
-
- # Atomic move from temp to final location
- os.replace(tempPath, recordPath)
-
- # ATOMIC: Update metadata while holding both locks
- metadata = self._loadTableMetadata(table)
- if recordId not in metadata["recordIds"]:
- metadata["recordIds"].append(recordId)
- metadata["recordIds"].sort()
- self._saveTableMetadata(table, metadata)
-
- # Update cache if it exists (also protected by table lock)
- if table in self._tablesCache:
- # Find and update existing record or append new one
- found = False
- for i, existing_record in enumerate(self._tablesCache[table]):
- if str(existing_record.get("id")) == recordId:
- self._tablesCache[table][i] = record
- found = True
- break
- if not found:
- self._tablesCache[table].append(record)
-
- return True
-
- except Exception as e:
- logger.error(f"Error saving record {recordId} to table {table}: {e}")
- # Clean up temp file if it exists
- tempPath = self._getRecordPath(table, recordId) + '.tmp'
- if os.path.exists(tempPath):
- try:
- os.remove(tempPath)
- except:
- pass
- return False
-
- finally:
- # ALWAYS release both locks, even on error
- try:
- if table_lock.locked():
- table_lock.release()
- if f"table_{table}" in self._lock_timeouts:
- del self._lock_timeouts[f"table_{table}"]
- except Exception as release_error:
- logger.error(f"Error releasing table lock for {table}: {release_error}")
-
- try:
- if record_lock.locked():
- record_lock.release()
- if recordPath in self._lock_timeouts:
- del self._lock_timeouts[recordPath]
- except Exception as release_error:
- logger.error(f"Error releasing record lock for {recordPath}: {release_error}")
-
- def _loadTable(self, table: str) -> List[Dict[str, Any]]:
- """Loads all records from a table folder."""
- # If the table is the system table, load it directly
- if table == self._systemTableName:
- return self._loadSystemTable()
-
- # If the table is already in the cache, use the cache
- if table in self._tablesCache:
- return self._tablesCache[table]
-
- # Load metadata first
- metadata = self._loadTableMetadata(table)
- records = []
-
- # Load each record
- for recordId in metadata["recordIds"]:
- # Skip metadata file
- if recordId == "_metadata":
- continue
- record = self._loadRecord(table, recordId)
- if record:
- records.append(record)
-
- self._tablesCache[table] = records
- return records
-
- def _saveTable(self, table: str, data: List[Dict[str, Any]]) -> bool:
- """Saves all records to a table folder"""
- # The system table is handled specially
- if table == self._systemTableName:
- return self._saveSystemTable(data)
-
- tablePath = self._getTablePath(table)
- try:
- # Ensure table directory exists
- os.makedirs(tablePath, exist_ok=True)
-
- # Save each record as a separate file
- for record in data:
- if "id" not in record:
- logger.error(f"Record missing ID in table {table}")
- continue
-
- recordPath = self._getRecordPath(table, record["id"])
- with open(recordPath, 'w', encoding='utf-8') as f:
- json.dump(record, f, indent=2, ensure_ascii=False)
-
- # Update the cache
- self._tablesCache[table] = data
- logger.debug(f"Successfully saved table {table}")
- return True
- except Exception as e:
- logger.error(f"Error saving table {table}: {str(e)}")
- logger.error(f"Error type: {type(e).__name__}")
- logger.error(f"Error details: {e.__dict__ if hasattr(e, '__dict__') else 'No details available'}")
- return False
-
- def _applyRecordFilter(self, records: List[Dict[str, Any]], recordFilter: Dict[str, Any] = None) -> List[Dict[str, Any]]:
- """Applies a record filter to the records"""
- if not recordFilter:
- return records
-
- filteredRecords = []
-
- for record in records:
- match = True
-
- for field, value in recordFilter.items():
- # Check if the field exists
- if field not in record:
- match = False
- break
-
- # Convert both values to strings for comparison
- recordValue = str(record[field])
- filterValue = str(value)
-
- # Direct string comparison
- if recordValue != filterValue:
- match = False
- break
-
- if match:
- filteredRecords.append(record)
-
- return filteredRecords
-
- def _registerInitialId(self, table: str, initialId: str) -> bool:
- """Registers the initial ID for a table."""
- try:
- systemData = self._loadSystemTable()
-
- if table not in systemData:
- systemData[table] = initialId
- success = self._saveSystemTable(systemData)
- if success:
- logger.info(f"Initial ID {initialId} for table {table} registered")
- return success
- return True # If already present, this is not an error
- except Exception as e:
- logger.error(f"Error registering the initial ID for table {table}: {e}")
- return False
-
- def _removeInitialId(self, table: str) -> bool:
- """Removes the initial ID for a table from the system table."""
- try:
- systemData = self._loadSystemTable()
-
- if table in systemData:
- del systemData[table]
- success = self._saveSystemTable(systemData)
- if success:
- logger.info(f"Initial ID for table {table} removed from system table")
- return success
- return True # If not present, this is not an error
- except Exception as e:
- logger.error(f"Error removing initial ID for table {table}: {e}")
- return False
-
-
-
- def _saveTableMetadata(self, table: str, metadata: Dict[str, Any]) -> bool:
- """Saves table metadata to a metadata file.
- NOTE: This method assumes the caller already holds the table lock.
- """
- try:
- # Create metadata file path
- metadataPath = os.path.join(self._getTablePath(table), "_metadata.json")
-
- # Save metadata (caller should already hold table lock)
- with open(metadataPath, 'w', encoding='utf-8') as f:
- json.dump(metadata, f, indent=2, ensure_ascii=False)
-
- # Update cache
- self._tableMetadataCache[table] = metadata
-
- return True
-
- except Exception as e:
- logger.error(f"Error saving metadata for table {table}: {e}")
- return False
-
- def updateContext(self, userId: str) -> None:
- """Updates the context of the database connector."""
- if userId is None:
- raise ValueError("userId must be provided")
-
- self.userId = userId
- logger.info(f"Updated database context: userId={self.userId}")
-
- # Clear cache to ensure fresh data with new context
- self._tablesCache = {}
- self._tableMetadataCache = {}
-
- def clearTableCache(self, table: str) -> None:
- """Clears cache for a specific table to ensure fresh data."""
- if table in self._tablesCache:
- del self._tablesCache[table]
- logger.debug(f"Cleared cache for table: {table}")
-
- if table in self._tableMetadataCache:
- del self._tableMetadataCache[table]
- logger.debug(f"Cleared metadata cache for table: {table}")
-
- # Public API
-
- def getTables(self) -> List[str]:
- """Returns a list of all available tables."""
- tables = []
-
- try:
- for item in os.listdir(self.dbFolder):
- itemPath = os.path.join(self.dbFolder, item)
- if os.path.isdir(itemPath) and not item.startswith('_'):
- tables.append(item)
- except Exception as e:
- logger.error(f"Error reading the database directory: {e}")
-
- return tables
-
- def getFields(self, table: str) -> List[str]:
- """Returns a list of all fields in a table."""
- data = self._loadTable(table)
-
- if not data:
- return []
-
- fields = list(data[0].keys()) if data else []
-
- return fields
-
- def getSchema(self, table: str, language: str = None) -> Dict[str, Dict[str, Any]]:
- """Returns a schema object for a table with data types and labels."""
- data = self._loadTable(table)
-
- schema = {}
-
- if not data:
- return schema
-
- firstRecord = data[0]
-
- for field, value in firstRecord.items():
- dataType = type(value).__name__
- label = field
-
- schema[field] = {
- "type": dataType,
- "label": label
- }
-
- return schema
-
- def getRecordset(self, table: str, fieldFilter: List[str] = None, recordFilter: Dict[str, Any] = None) -> List[Dict[str, Any]]:
- """Returns a list of records from a table, filtered by criteria."""
- # If we have specific record IDs in the filter, only load those records
- if recordFilter and "id" in recordFilter:
- recordId = recordFilter["id"]
- record = self._loadRecord(table, recordId)
- if record:
- records = [record]
- else:
- return []
- else:
- # Load all records if no specific ID filter
- records = self._loadTable(table)
-
- # Apply recordFilter if available
- if recordFilter:
- records = self._applyRecordFilter(records, recordFilter)
-
- # If fieldFilter is available, reduce the fields
- if fieldFilter and isinstance(fieldFilter, list):
- result = []
- for record in records:
- filteredRecord = {}
- for field in fieldFilter:
- if field in record:
- filteredRecord[field] = record[field]
- result.append(filteredRecord)
- return result
-
- return records
-
- def recordCreate(self, table: str, record: Dict[str, Any]) -> Dict[str, Any]:
- """Creates a new record in a table."""
- # Ensure record has an ID
- if "id" not in record:
- record["id"] = str(uuid.uuid4())
-
- # If record is a Pydantic model, convert to dict
- if isinstance(record, BaseModel):
- record = record.model_dump()
-
- # Save record
- self._saveRecord(table, record["id"], record)
- return record
-
- def recordModify(self, table: str, recordId: str, record: Dict[str, Any]) -> Dict[str, Any]:
- """Modifies an existing record in a table."""
- # Load existing record
- existingRecord = self._loadRecord(table, recordId)
- if not existingRecord:
- raise ValueError(f"Record {recordId} not found in table {table}")
-
- # If record is a Pydantic model, convert to dict
- if isinstance(record, BaseModel):
- record = record.model_dump()
-
- # CRITICAL: Ensure we never modify the ID
- if "id" in record and str(record["id"]) != recordId:
- logger.error(f"Attempted to modify record ID from {recordId} to {record['id']}")
- raise ValueError("Cannot modify record ID - it must match the file name")
-
- # Update existing record with new data
- existingRecord.update(record)
-
- # Save updated record
- self._saveRecord(table, recordId, existingRecord)
- return existingRecord
-
- def recordDelete(self, table: str, recordId: str) -> bool:
- """Deletes a record from the table with atomic metadata operations."""
- recordPath = self._getRecordPath(table, recordId)
- record_lock = self._get_file_lock(recordPath)
- table_lock = self._get_table_lock(table)
-
- try:
- # Acquire both locks with timeout - record lock first, then table lock
- if not record_lock.acquire(timeout=30):
- raise TimeoutError(f"Could not acquire record lock for {recordPath} within 30 seconds")
-
- if not table_lock.acquire(timeout=30):
- record_lock.release()
- raise TimeoutError(f"Could not acquire table lock for {table} within 30 seconds")
-
- # Record lock acquisition time
- self._lock_timeouts[recordPath] = time.time()
- self._lock_timeouts[f"table_{table}"] = time.time()
-
- # Load metadata
- metadata = self._loadTableMetadata(table)
-
- if recordId not in metadata["recordIds"]:
- return False
-
- # Check if it's an initial record
- initialId = self.getInitialId(table)
- if initialId is not None and initialId == recordId:
- self._removeInitialId(table)
- logger.info(f"Initial ID {recordId} for table {table} has been removed from the system table")
-
- # Delete the record file
- if os.path.exists(recordPath):
- os.remove(recordPath)
-
- # ATOMIC: Update metadata while holding both locks
- metadata["recordIds"].remove(recordId)
- self._saveTableMetadata(table, metadata)
-
- # Update table cache if it exists (also protected by table lock)
- if table in self._tablesCache:
- self._tablesCache[table] = [r for r in self._tablesCache[table] if r.get("id") != recordId]
-
- return True
- else:
- return False
-
- except Exception as e:
- logger.error(f"Error deleting record {recordId} from table {table}: {e}")
- return False
-
- finally:
- # ALWAYS release both locks, even on error
- try:
- if table_lock.locked():
- table_lock.release()
- if f"table_{table}" in self._lock_timeouts:
- del self._lock_timeouts[f"table_{table}"]
- except Exception as release_error:
- logger.error(f"Error releasing table lock for {table}: {release_error}")
-
- try:
- if record_lock.locked():
- record_lock.release()
- if recordPath in self._lock_timeouts:
- del self._lock_timeouts[recordPath]
- except Exception as release_error:
- logger.error(f"Error releasing record lock for {recordPath}: {release_error}")
-
- def getInitialId(self, table_or_model) -> Optional[str]:
- """Returns the initial ID for a table."""
- # Handle both string table names (legacy) and model classes (new)
- if isinstance(table_or_model, str):
- table = table_or_model
- else:
- table = table_or_model.__name__
-
- systemData = self._loadSystemTable()
- initialId = systemData.get(table)
- logger.debug(f"Initial ID for table '{table}': {initialId}")
- return initialId
-
\ No newline at end of file
diff --git a/modules/connectors/connectorDbPostgre.py b/modules/connectors/connectorDbPostgre.py
index c9206c8d..828fa703 100644
--- a/modules/connectors/connectorDbPostgre.py
+++ b/modules/connectors/connectorDbPostgre.py
@@ -1,13 +1,16 @@
import psycopg2
import psycopg2.extras
import logging
-from typing import List, Dict, Any, Optional, Union, get_origin, get_args
+from typing import List, Dict, Any, Optional, Union, get_origin, get_args, Type
import uuid
from pydantic import BaseModel, Field
import threading
from modules.shared.timeUtils import getUtcTimestamp
from modules.shared.configuration import APP_CONFIG
+from modules.datamodels.datamodelUam import User, AccessLevel, UserPermissions
+from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext
+from modules.security.rbac import RbacClass
logger = logging.getLogger(__name__)
@@ -1039,6 +1042,208 @@ class DatabaseConnector:
initialId = systemData.get(table)
return initialId
+ def getRecordsetWithRBAC(
+ self,
+ modelClass: Type[BaseModel],
+ currentUser: User,
+ recordFilter: Dict[str, Any] = None,
+ orderBy: str = None,
+ limit: int = None,
+ ) -> List[Dict[str, Any]]:
+ """
+ Get records with RBAC filtering applied at database level.
+
+ Args:
+ modelClass: Pydantic model class for the table
+ currentUser: User object with roleLabels
+ recordFilter: Additional record filters
+ orderBy: Field to order by (defaults to "id")
+ limit: Maximum number of records to return
+
+ Returns:
+ List of filtered records
+ """
+ table = modelClass.__name__
+
+ try:
+ if not self._ensureTableExists(modelClass):
+ return []
+
+ # Get RBAC permissions for this table
+ RbacInstance = RbacClass(self)
+ permissions = RbacInstance.getUserPermissions(
+ currentUser,
+ AccessRuleContext.DATA,
+ table
+ )
+
+ # Check view permission first
+ if not permissions.view:
+ logger.debug(f"User {currentUser.id} has no view permission for table {table}")
+ return []
+
+ # Build WHERE clause with RBAC filtering
+ whereConditions = []
+ whereValues = []
+
+ # Add RBAC WHERE clause based on read permission
+ rbacWhereClause = self.buildRbacWhereClause(permissions, currentUser, table)
+ if rbacWhereClause:
+ whereConditions.append(rbacWhereClause["condition"])
+ whereValues.extend(rbacWhereClause["values"])
+
+ # Add additional record filters
+ if recordFilter:
+ for field, value in recordFilter.items():
+ whereConditions.append(f'"{field}" = %s')
+ whereValues.append(value)
+
+ # Build the query
+ whereClause = ""
+ if whereConditions:
+ whereClause = " WHERE " + " AND ".join(whereConditions)
+
+ orderByClause = f' ORDER BY "{orderBy}"' if orderBy else ' ORDER BY "id"'
+ limitClause = f" LIMIT {limit}" if limit else ""
+
+ query = f'SELECT * FROM "{table}"{whereClause}{orderByClause}{limitClause}'
+
+ with self.connection.cursor() as cursor:
+ cursor.execute(query, whereValues)
+ records = [dict(row) for row in cursor.fetchall()]
+
+ # Handle JSONB fields and ensure numeric types are correct
+ fields = _get_model_fields(modelClass)
+ for record in records:
+ for fieldName, fieldType in fields.items():
+ # Ensure numeric fields are properly typed
+ if fieldType in ("DOUBLE PRECISION", "INTEGER") and fieldName in record:
+ value = record[fieldName]
+ if value is not None:
+ try:
+ if fieldType == "DOUBLE PRECISION":
+ record[fieldName] = float(value)
+ elif fieldType == "INTEGER":
+ record[fieldName] = int(value)
+ except (ValueError, TypeError):
+ logger.warning(
+ f"Could not convert {fieldName} to {fieldType} for record {record.get('id', 'unknown')}: {value}"
+ )
+ elif fieldType == "JSONB" and fieldName in record:
+ if record[fieldName] is None:
+ if fieldName in ["logs", "messages", "tasks", "expectedDocumentFormats", "resultDocuments"]:
+ record[fieldName] = []
+ elif fieldName in ["execParameters", "stats"]:
+ record[fieldName] = {}
+ else:
+ record[fieldName] = None
+ else:
+ import json
+ try:
+ if isinstance(record[fieldName], str):
+ record[fieldName] = json.loads(record[fieldName])
+ elif isinstance(record[fieldName], (dict, list)):
+ pass
+ else:
+ record[fieldName] = json.loads(str(record[fieldName]))
+ except (json.JSONDecodeError, TypeError, ValueError):
+ logger.warning(
+ f"Could not parse JSONB field {fieldName}, keeping as string: {record[fieldName]}"
+ )
+
+ return records
+ except Exception as e:
+ logger.error(f"Error loading records with RBAC from table {table}: {e}")
+ return []
+
+ def buildRbacWhereClause(
+ self,
+ permissions: UserPermissions,
+ currentUser: User,
+ table: str
+ ) -> Optional[Dict[str, Any]]:
+ """
+ Build RBAC WHERE clause based on permissions and access level.
+
+ Args:
+ permissions: UserPermissions object
+ currentUser: User object
+ table: Table name
+
+ Returns:
+ Dictionary with "condition" and "values" keys, or None if no filtering needed
+ """
+ if not permissions or not hasattr(permissions, "read"):
+ return None
+
+ readLevel = permissions.read
+
+ # No access - return empty result condition
+ if readLevel == AccessLevel.NONE:
+ return {"condition": "1 = 0", "values": []}
+
+ # All records - no filtering needed
+ if readLevel == AccessLevel.ALL:
+ return None
+
+ # My records - filter by _createdBy or userId field
+ if readLevel == AccessLevel.MY:
+ # Try common field names for creator
+ userIdField = None
+ if table == "UserInDB":
+ userIdField = "id"
+ elif table == "UserConnection":
+ userIdField = "userId"
+ else:
+ userIdField = "_createdBy"
+
+ return {
+ "condition": f'"{userIdField}" = %s',
+ "values": [currentUser.id]
+ }
+
+ # Group records - filter by mandateId
+ if readLevel == AccessLevel.GROUP:
+ if not currentUser.mandateId:
+ logger.warning(f"User {currentUser.id} has no mandateId for GROUP access")
+ return {"condition": "1 = 0", "values": []}
+
+ # For UserInDB, filter by mandateId directly
+ if table == "UserInDB":
+ return {
+ "condition": '"mandateId" = %s',
+ "values": [currentUser.mandateId]
+ }
+ # For UserConnection, need to join with UserInDB or filter by mandateId in user
+ elif table == "UserConnection":
+ # Get all user IDs in the same mandate using direct SQL query
+ try:
+ with self.connection.cursor() as cursor:
+ cursor.execute(
+ 'SELECT "id" FROM "UserInDB" WHERE "mandateId" = %s',
+ (currentUser.mandateId,)
+ )
+ users = cursor.fetchall()
+ userIds = [u["id"] for u in users]
+ if not userIds:
+ return {"condition": "1 = 0", "values": []}
+ placeholders = ",".join(["%s"] * len(userIds))
+ return {
+ "condition": f'"userId" IN ({placeholders})',
+ "values": userIds
+ }
+ except Exception as e:
+ logger.error(f"Error building GROUP filter for UserConnection: {e}")
+ return {"condition": "1 = 0", "values": []}
+ # For other tables, filter by mandateId
+ else:
+ return {
+ "condition": '"mandateId" = %s',
+ "values": [currentUser.mandateId]
+ }
+
+ return None
+
def close(self):
"""Close the database connection."""
if (
diff --git a/modules/datamodels/datamodelRbac.py b/modules/datamodels/datamodelRbac.py
new file mode 100644
index 00000000..c2ba90d8
--- /dev/null
+++ b/modules/datamodels/datamodelRbac.py
@@ -0,0 +1,102 @@
+"""RBAC models: AccessRule, AccessRuleContext."""
+
+import uuid
+from typing import Optional
+from enum import Enum
+from pydantic import BaseModel, Field
+from modules.shared.attributeUtils import registerModelLabels
+from modules.datamodels.datamodelUam import AccessLevel
+
+
+class AccessRuleContext(str, Enum):
+ """Context type enumeration"""
+ DATA = "DATA" # Database tables and fields
+ UI = "UI" # UI elements and features
+ RESOURCE = "RESOURCE" # System resources (AI models, actions, etc.)
+
+
+class AccessRule(BaseModel):
+ """Data model for access control rules"""
+ id: str = Field(
+ default_factory=lambda: str(uuid.uuid4()),
+ description="Unique ID of the access rule",
+ json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}
+ )
+ roleLabel: str = Field(
+ description="Role label this rule applies to",
+ json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True, "frontend_options": "user.role"}
+ )
+ context: AccessRuleContext = Field(
+ description="Context type: DATA (database), UI (interface), RESOURCE (system resources)",
+ json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True, "frontend_options": [
+ {"value": "DATA", "label": {"en": "Data", "fr": "Données"}},
+ {"value": "UI", "label": {"en": "UI", "fr": "Interface"}},
+ {"value": "RESOURCE", "label": {"en": "Resource", "fr": "Ressource"}}
+ ]}
+ )
+ item: Optional[str] = Field(
+ None,
+ description="Item identifier (null = all items in context). Format: DATA: '' or '.', UI: cascading string (e.g., 'playground.voice.settings'), RESOURCE: cascading string (e.g., 'ai.model.anthropic')",
+ json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": False}
+ )
+ view: bool = Field(
+ False,
+ description="View permission: if true, item is visible/enabled. Only objects with view=true are shown.",
+ json_schema_extra={"frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": True}
+ )
+ read: Optional[AccessLevel] = Field(
+ None,
+ description="Read permission level (only for DATA context)",
+ json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": [
+ {"value": "a", "label": {"en": "All Records", "fr": "Tous les enregistrements"}},
+ {"value": "m", "label": {"en": "My Records", "fr": "Mes enregistrements"}},
+ {"value": "g", "label": {"en": "Group Records", "fr": "Enregistrements du groupe"}},
+ {"value": "n", "label": {"en": "No Access", "fr": "Aucun accès"}}
+ ]}
+ )
+ create: Optional[AccessLevel] = Field(
+ None,
+ description="Create permission level (only for DATA context)",
+ json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": [
+ {"value": "a", "label": {"en": "All Records", "fr": "Tous les enregistrements"}},
+ {"value": "m", "label": {"en": "My Records", "fr": "Mes enregistrements"}},
+ {"value": "g", "label": {"en": "Group Records", "fr": "Enregistrements du groupe"}},
+ {"value": "n", "label": {"en": "No Access", "fr": "Aucun accès"}}
+ ]}
+ )
+ update: Optional[AccessLevel] = Field(
+ None,
+ description="Update permission level (only for DATA context)",
+ json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": [
+ {"value": "a", "label": {"en": "All Records", "fr": "Tous les enregistrements"}},
+ {"value": "m", "label": {"en": "My Records", "fr": "Mes enregistrements"}},
+ {"value": "g", "label": {"en": "Group Records", "fr": "Enregistrements du groupe"}},
+ {"value": "n", "label": {"en": "No Access", "fr": "Aucun accès"}}
+ ]}
+ )
+ delete: Optional[AccessLevel] = Field(
+ None,
+ description="Delete permission level (only for DATA context)",
+ json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": [
+ {"value": "a", "label": {"en": "All Records", "fr": "Tous les enregistrements"}},
+ {"value": "m", "label": {"en": "My Records", "fr": "Mes enregistrements"}},
+ {"value": "g", "label": {"en": "Group Records", "fr": "Enregistrements du groupe"}},
+ {"value": "n", "label": {"en": "No Access", "fr": "Aucun accès"}}
+ ]}
+ )
+
+registerModelLabels(
+ "AccessRule",
+ {"en": "Access Rule", "fr": "Règle d'accès"},
+ {
+ "id": {"en": "ID", "fr": "ID"},
+ "roleLabel": {"en": "Role Label", "fr": "Label du rôle"},
+ "context": {"en": "Context", "fr": "Contexte"},
+ "item": {"en": "Item", "fr": "Élément"},
+ "view": {"en": "View", "fr": "Vue"},
+ "read": {"en": "Read", "fr": "Lecture"},
+ "create": {"en": "Create", "fr": "Créer"},
+ "update": {"en": "Update", "fr": "Mettre à jour"},
+ "delete": {"en": "Delete", "fr": "Supprimer"},
+ },
+)
diff --git a/modules/datamodels/datamodelUam.py b/modules/datamodels/datamodelUam.py
index 4a9c10aa..4c9e0a84 100644
--- a/modules/datamodels/datamodelUam.py
+++ b/modules/datamodels/datamodelUam.py
@@ -1,7 +1,7 @@
"""UAM models: User, Mandate, UserConnection."""
import uuid
-from typing import Optional
+from typing import Optional, List
from enum import Enum
from pydantic import BaseModel, Field, EmailStr
from modules.shared.attributeUtils import registerModelLabels
@@ -13,7 +13,7 @@ class AuthAuthority(str, Enum):
GOOGLE = "google"
MSFT = "msft"
-class UserPrivilege(str, Enum):
+class UserPrivilege(str, Enum): # TODO: TO remove, one new RBAC System is in place!
SYSADMIN = "sysadmin"
ADMIN = "admin"
USER = "user"
@@ -24,6 +24,36 @@ class ConnectionStatus(str, Enum):
REVOKED = "revoked"
PENDING = "pending"
+class AccessLevel(str, Enum):
+ """Access level enumeration for RBAC"""
+ ALL = "a" # All records
+ MY = "m" # My records (created by me)
+ GROUP = "g" # Group records (group context is the mandate)
+ NONE = "n" # No access
+
+class UserPermissions(BaseModel):
+ """User permissions model for RBAC"""
+ view: bool = Field(
+ default=False,
+ description="View permission: if true, item is visible/enabled"
+ )
+ read: AccessLevel = Field(
+ default=AccessLevel.NONE,
+ description="Read permission level"
+ )
+ create: AccessLevel = Field(
+ default=AccessLevel.NONE,
+ description="Create permission level"
+ )
+ update: AccessLevel = Field(
+ default=AccessLevel.NONE,
+ description="Update permission level"
+ )
+ delete: AccessLevel = Field(
+ default=AccessLevel.NONE,
+ description="Delete permission level"
+ )
+
class Mandate(BaseModel):
id: str = Field(
default_factory=lambda: str(uuid.uuid4()),
@@ -122,11 +152,12 @@ class User(BaseModel):
{"value": "it", "label": {"en": "Italiano", "fr": "Italien"}},
]})
enabled: bool = Field(default=True, description="Indicates whether the user is enabled", json_schema_extra={"frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False})
- privilege: UserPrivilege = Field(default=UserPrivilege.USER, description="Permission level", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True, "frontend_options": [
- {"value": "user", "label": {"en": "User", "fr": "Utilisateur"}},
- {"value": "admin", "label": {"en": "Admin", "fr": "Administrateur"}},
- {"value": "sysadmin", "label": {"en": "SysAdmin", "fr": "Administrateur système"}},
- ]})
+ privilege: UserPrivilege = Field(default=UserPrivilege.USER, description="Permission level (DEPRECATED: use roleLabels instead)", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": "user.role"})
+ roleLabels: List[str] = Field(
+ default_factory=list,
+ description="List of role labels assigned to this user. All roles are opening roles (union) - if one role enables something, it is enabled.",
+ json_schema_extra={"frontend_type": "multiselect", "frontend_readonly": False, "frontend_required": True, "frontend_options": "user.role"}
+ )
authenticationAuthority: AuthAuthority = Field(default=AuthAuthority.LOCAL, description="Primary authentication authority", json_schema_extra={"frontend_type": "select", "frontend_readonly": True, "frontend_required": False, "frontend_options": [
{"value": "local", "label": {"en": "Local", "fr": "Local"}},
{"value": "google", "label": {"en": "Google", "fr": "Google"}},
@@ -144,6 +175,7 @@ registerModelLabels(
"language": {"en": "Language", "fr": "Langue"},
"enabled": {"en": "Enabled", "fr": "Activé"},
"privilege": {"en": "Privilege", "fr": "Privilège"},
+ "roleLabels": {"en": "Role Labels", "fr": "Labels de rôle"},
"authenticationAuthority": {"en": "Auth Authority", "fr": "Autorité d'authentification"},
"mandateId": {"en": "Mandate ID", "fr": "ID de mandat"},
},
diff --git a/modules/features/automation/mainAutomation.py b/modules/features/automation/mainAutomation.py
index c0534229..768ca2e0 100644
--- a/modules/features/automation/mainAutomation.py
+++ b/modules/features/automation/mainAutomation.py
@@ -163,9 +163,11 @@ async def syncAutomationEvents(chatInterface, eventUser) -> Dict[str, Any]:
Returns:
Dictionary with sync results (synced count and event IDs)
"""
- # Get all automation definitions (for current mandate)
- allAutomations = chatInterface.db.getRecordset(AutomationDefinition)
- filtered = chatInterface._uam(AutomationDefinition, allAutomations)
+ # Get all automation definitions filtered by RBAC (for current mandate)
+ filtered = chatInterface.db.getRecordsetWithRBAC(
+ AutomationDefinition,
+ eventUser
+ )
registeredEvents = {}
diff --git a/modules/interfaces/interfaceBootstrap.py b/modules/interfaces/interfaceBootstrap.py
new file mode 100644
index 00000000..5c4a90a1
--- /dev/null
+++ b/modules/interfaces/interfaceBootstrap.py
@@ -0,0 +1,548 @@
+"""
+Centralized bootstrap interface for system initialization.
+Contains all bootstrap logic including mandate, users, and RBAC rules.
+"""
+
+import logging
+from typing import Optional
+from passlib.context import CryptContext
+from modules.connectors.connectorDbPostgre import DatabaseConnector
+from modules.shared.configuration import APP_CONFIG
+from modules.datamodels.datamodelUam import (
+ Mandate,
+ UserInDB,
+ UserPrivilege,
+ AuthAuthority,
+)
+from modules.datamodels.datamodelRbac import (
+ AccessRule,
+ AccessRuleContext,
+)
+from modules.datamodels.datamodelUam import AccessLevel
+
+logger = logging.getLogger(__name__)
+
+# Password-Hashing
+pwdContext = CryptContext(schemes=["argon2"], deprecated="auto")
+
+
+def initBootstrap(db: DatabaseConnector) -> None:
+ """
+ Main bootstrap entry point - initializes all system components.
+
+ Args:
+ db: Database connector instance
+ """
+ logger.info("Starting system bootstrap")
+
+ # Initialize root mandate
+ mandateId = initRootMandate(db)
+
+ # Initialize admin user
+ adminUserId = initAdminUser(db, mandateId)
+
+ # Initialize event user
+ eventUserId = initEventUser(db, mandateId)
+
+ # Initialize RBAC rules
+ initRbacRules(db)
+
+ # Assign initial user roles
+ if adminUserId and eventUserId:
+ assignInitialUserRoles(db, adminUserId, eventUserId)
+
+ logger.info("System bootstrap completed")
+
+
+def initRootMandate(db: DatabaseConnector) -> Optional[str]:
+ """
+ Creates the Root mandate if it doesn't exist.
+
+ Args:
+ db: Database connector instance
+
+ Returns:
+ Mandate ID if created or found, None otherwise
+ """
+ existingMandates = db.getRecordset(Mandate)
+ if existingMandates:
+ mandateId = existingMandates[0].get("id")
+ logger.info(f"Root mandate already exists with ID {mandateId}")
+ return mandateId
+
+ logger.info("Creating Root mandate")
+ rootMandate = Mandate(name="Root", language="en", enabled=True)
+ createdMandate = db.recordCreate(Mandate, rootMandate)
+ mandateId = createdMandate.get("id")
+ logger.info(f"Root mandate created with ID {mandateId}")
+ return mandateId
+
+
+def initAdminUser(db: DatabaseConnector, mandateId: Optional[str]) -> Optional[str]:
+ """
+ Creates the Admin user if it doesn't exist.
+
+ Args:
+ db: Database connector instance
+ mandateId: Root mandate ID
+
+ Returns:
+ User ID if created or found, None otherwise
+ """
+ existingUsers = db.getRecordset(UserInDB, recordFilter={"username": "admin"})
+ if existingUsers:
+ userId = existingUsers[0].get("id")
+ logger.info(f"Admin user already exists with ID {userId}")
+ return userId
+
+ logger.info("Creating Admin user")
+ adminUser = UserInDB(
+ mandateId=mandateId,
+ username="admin",
+ email="admin@example.com",
+ fullName="Administrator",
+ enabled=True,
+ language="en",
+ privilege=UserPrivilege.SYSADMIN,
+ roleLabels=["sysadmin"],
+ authenticationAuthority=AuthAuthority.LOCAL,
+ hashedPassword=_getPasswordHash(APP_CONFIG.get("APP_INIT_PASS_ADMIN_SECRET")),
+ connections=[],
+ )
+ createdUser = db.recordCreate(UserInDB, adminUser)
+ userId = createdUser.get("id")
+ logger.info(f"Admin user created with ID {userId}")
+ return userId
+
+
+def initEventUser(db: DatabaseConnector, mandateId: Optional[str]) -> Optional[str]:
+ """
+ Creates the Event user if it doesn't exist.
+
+ Args:
+ db: Database connector instance
+ mandateId: Root mandate ID
+
+ Returns:
+ User ID if created or found, None otherwise
+ """
+ existingUsers = db.getRecordset(UserInDB, recordFilter={"username": "event"})
+ if existingUsers:
+ userId = existingUsers[0].get("id")
+ logger.info(f"Event user already exists with ID {userId}")
+ return userId
+
+ logger.info("Creating Event user")
+ eventUser = UserInDB(
+ mandateId=mandateId,
+ username="event",
+ email="event@example.com",
+ fullName="Event",
+ enabled=True,
+ language="en",
+ privilege=UserPrivilege.SYSADMIN,
+ roleLabels=["sysadmin"],
+ authenticationAuthority=AuthAuthority.LOCAL,
+ hashedPassword=_getPasswordHash(APP_CONFIG.get("APP_INIT_PASS_EVENT_SECRET")),
+ connections=[],
+ )
+ createdUser = db.recordCreate(UserInDB, eventUser)
+ userId = createdUser.get("id")
+ logger.info(f"Event user created with ID {userId}")
+ return userId
+
+
+def initRbacRules(db: DatabaseConnector) -> None:
+ """
+ Initialize RBAC rules if they don't exist.
+ Converts all UAM logic from interface*Access.py modules to RBAC rules.
+
+ Args:
+ db: Database connector instance
+ """
+ existingRules = db.getRecordset(AccessRule)
+ if existingRules:
+ logger.info(f"RBAC rules already exist ({len(existingRules)} rules)")
+ return
+
+ logger.info("Initializing RBAC rules")
+
+ # Create default role rules
+ createDefaultRoleRules(db)
+
+ # Create table-specific rules (converted from UAM logic)
+ createTableSpecificRules(db)
+
+ logger.info("RBAC rules initialization completed")
+
+
+def createDefaultRoleRules(db: DatabaseConnector) -> None:
+ """
+ Create default role rules for generic access (item = null).
+
+ Args:
+ db: Database connector instance
+ """
+ defaultRules = [
+ # SysAdmin Role - Full access to all
+ AccessRule(
+ roleLabel="sysadmin",
+ context=AccessRuleContext.DATA,
+ item=None,
+ view=True,
+ read=AccessLevel.ALL,
+ create=AccessLevel.ALL,
+ update=AccessLevel.ALL,
+ delete=AccessLevel.ALL,
+ ),
+ # Admin Role - Group-level access
+ AccessRule(
+ roleLabel="admin",
+ context=AccessRuleContext.DATA,
+ item=None,
+ view=True,
+ read=AccessLevel.GROUP,
+ create=AccessLevel.GROUP,
+ update=AccessLevel.GROUP,
+ delete=AccessLevel.NONE,
+ ),
+ # User Role - My records only
+ AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.DATA,
+ item=None,
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.MY,
+ update=AccessLevel.MY,
+ delete=AccessLevel.MY,
+ ),
+ # Viewer Role - Read-only group access
+ AccessRule(
+ roleLabel="viewer",
+ context=AccessRuleContext.DATA,
+ item=None,
+ view=True,
+ read=AccessLevel.GROUP,
+ create=AccessLevel.NONE,
+ update=AccessLevel.NONE,
+ delete=AccessLevel.NONE,
+ ),
+ ]
+
+ for rule in defaultRules:
+ db.recordCreate(AccessRule, rule)
+
+ logger.info(f"Created {len(defaultRules)} default role rules")
+
+
+def createTableSpecificRules(db: DatabaseConnector) -> None:
+ """
+ Create table-specific rules converted from UAM logic.
+ These rules override generic rules for specific tables.
+
+ Args:
+ db: Database connector instance
+ """
+ tableRules = []
+
+ # Mandate table - Only sysadmin can access
+ tableRules.append(AccessRule(
+ roleLabel="sysadmin",
+ context=AccessRuleContext.DATA,
+ item="Mandate",
+ view=True,
+ read=AccessLevel.ALL,
+ create=AccessLevel.ALL,
+ update=AccessLevel.ALL,
+ delete=AccessLevel.ALL,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="admin",
+ context=AccessRuleContext.DATA,
+ item="Mandate",
+ view=False,
+ read=AccessLevel.NONE,
+ create=AccessLevel.NONE,
+ update=AccessLevel.NONE,
+ delete=AccessLevel.NONE,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.DATA,
+ item="Mandate",
+ view=False,
+ read=AccessLevel.NONE,
+ create=AccessLevel.NONE,
+ update=AccessLevel.NONE,
+ delete=AccessLevel.NONE,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="viewer",
+ context=AccessRuleContext.DATA,
+ item="Mandate",
+ view=False,
+ read=AccessLevel.NONE,
+ create=AccessLevel.NONE,
+ update=AccessLevel.NONE,
+ delete=AccessLevel.NONE,
+ ))
+
+ # UserInDB table
+ tableRules.append(AccessRule(
+ roleLabel="sysadmin",
+ context=AccessRuleContext.DATA,
+ item="UserInDB",
+ view=True,
+ read=AccessLevel.ALL,
+ create=AccessLevel.ALL,
+ update=AccessLevel.ALL,
+ delete=AccessLevel.ALL,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="admin",
+ context=AccessRuleContext.DATA,
+ item="UserInDB",
+ view=True,
+ read=AccessLevel.GROUP,
+ create=AccessLevel.GROUP,
+ update=AccessLevel.GROUP,
+ delete=AccessLevel.GROUP,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.DATA,
+ item="UserInDB",
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.NONE,
+ update=AccessLevel.MY,
+ delete=AccessLevel.NONE,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="viewer",
+ context=AccessRuleContext.DATA,
+ item="UserInDB",
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.NONE,
+ update=AccessLevel.NONE,
+ delete=AccessLevel.NONE,
+ ))
+
+ # UserConnection table
+ tableRules.append(AccessRule(
+ roleLabel="sysadmin",
+ context=AccessRuleContext.DATA,
+ item="UserConnection",
+ view=True,
+ read=AccessLevel.ALL,
+ create=AccessLevel.ALL,
+ update=AccessLevel.ALL,
+ delete=AccessLevel.ALL,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="admin",
+ context=AccessRuleContext.DATA,
+ item="UserConnection",
+ view=True,
+ read=AccessLevel.GROUP,
+ create=AccessLevel.GROUP,
+ update=AccessLevel.GROUP,
+ delete=AccessLevel.GROUP,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.DATA,
+ item="UserConnection",
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.MY,
+ update=AccessLevel.MY,
+ delete=AccessLevel.MY,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="viewer",
+ context=AccessRuleContext.DATA,
+ item="UserConnection",
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.NONE,
+ update=AccessLevel.NONE,
+ delete=AccessLevel.NONE,
+ ))
+
+ # DataNeutraliserConfig table
+ tableRules.append(AccessRule(
+ roleLabel="sysadmin",
+ context=AccessRuleContext.DATA,
+ item="DataNeutraliserConfig",
+ view=True,
+ read=AccessLevel.ALL,
+ create=AccessLevel.ALL,
+ update=AccessLevel.ALL,
+ delete=AccessLevel.ALL,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="admin",
+ context=AccessRuleContext.DATA,
+ item="DataNeutraliserConfig",
+ view=True,
+ read=AccessLevel.GROUP,
+ create=AccessLevel.GROUP,
+ update=AccessLevel.GROUP,
+ delete=AccessLevel.GROUP,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.DATA,
+ item="DataNeutraliserConfig",
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.MY,
+ update=AccessLevel.MY,
+ delete=AccessLevel.MY,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="viewer",
+ context=AccessRuleContext.DATA,
+ item="DataNeutraliserConfig",
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.NONE,
+ update=AccessLevel.NONE,
+ delete=AccessLevel.NONE,
+ ))
+
+ # DataNeutralizerAttributes table
+ tableRules.append(AccessRule(
+ roleLabel="sysadmin",
+ context=AccessRuleContext.DATA,
+ item="DataNeutralizerAttributes",
+ view=True,
+ read=AccessLevel.ALL,
+ create=AccessLevel.ALL,
+ update=AccessLevel.ALL,
+ delete=AccessLevel.ALL,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="admin",
+ context=AccessRuleContext.DATA,
+ item="DataNeutralizerAttributes",
+ view=True,
+ read=AccessLevel.GROUP,
+ create=AccessLevel.GROUP,
+ update=AccessLevel.GROUP,
+ delete=AccessLevel.GROUP,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.DATA,
+ item="DataNeutralizerAttributes",
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.MY,
+ update=AccessLevel.MY,
+ delete=AccessLevel.MY,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="viewer",
+ context=AccessRuleContext.DATA,
+ item="DataNeutralizerAttributes",
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.NONE,
+ update=AccessLevel.NONE,
+ delete=AccessLevel.NONE,
+ ))
+
+ # AuthEvent table
+ tableRules.append(AccessRule(
+ roleLabel="sysadmin",
+ context=AccessRuleContext.DATA,
+ item="AuthEvent",
+ view=True,
+ read=AccessLevel.ALL,
+ create=AccessLevel.NONE,
+ update=AccessLevel.NONE,
+ delete=AccessLevel.ALL,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="admin",
+ context=AccessRuleContext.DATA,
+ item="AuthEvent",
+ view=True,
+ read=AccessLevel.ALL,
+ create=AccessLevel.NONE,
+ update=AccessLevel.NONE,
+ delete=AccessLevel.ALL,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.DATA,
+ item="AuthEvent",
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.NONE,
+ update=AccessLevel.NONE,
+ delete=AccessLevel.NONE,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="viewer",
+ context=AccessRuleContext.DATA,
+ item="AuthEvent",
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.NONE,
+ update=AccessLevel.NONE,
+ delete=AccessLevel.NONE,
+ ))
+
+ # Create all table-specific rules
+ for rule in tableRules:
+ db.recordCreate(AccessRule, rule)
+
+ logger.info(f"Created {len(tableRules)} table-specific rules")
+
+
+def assignInitialUserRoles(db: DatabaseConnector, adminUserId: str, eventUserId: str) -> None:
+ """
+ Assign initial roles to admin and event users.
+
+ Args:
+ db: Database connector instance
+ adminUserId: Admin user ID
+ eventUserId: Event user ID
+ """
+ # Update admin user with sysadmin role
+ adminUser = db.getRecordset(UserInDB, recordFilter={"id": adminUserId})
+ if adminUser:
+ adminUserData = adminUser[0]
+ if "sysadmin" not in adminUserData.get("roleLabels", []):
+ adminUserData["roleLabels"] = adminUserData.get("roleLabels", []) + ["sysadmin"]
+ db.recordUpdate(UserInDB, adminUserId, adminUserData)
+ logger.info(f"Assigned sysadmin role to admin user {adminUserId}")
+
+ # Update event user with sysadmin role
+ eventUser = db.getRecordset(UserInDB, recordFilter={"id": eventUserId})
+ if eventUser:
+ eventUserData = eventUser[0]
+ if "sysadmin" not in eventUserData.get("roleLabels", []):
+ eventUserData["roleLabels"] = eventUserData.get("roleLabels", []) + ["sysadmin"]
+ db.recordUpdate(UserInDB, eventUserId, eventUserData)
+ logger.info(f"Assigned sysadmin role to event user {eventUserId}")
+
+
+def _getPasswordHash(password: Optional[str]) -> Optional[str]:
+ """
+ Hash a password using Argon2.
+
+ Args:
+ password: Plain text password
+
+ Returns:
+ Hashed password or None if password is None
+ """
+ if password is None:
+ return None
+ return pwdContext.hash(password)
diff --git a/modules/interfaces/interfaceDbAppAccess.py b/modules/interfaces/interfaceDbAppAccess.py
deleted file mode 100644
index 1bb9126c..00000000
--- a/modules/interfaces/interfaceDbAppAccess.py
+++ /dev/null
@@ -1,254 +0,0 @@
-"""
-Access control for the Application.
-"""
-
-import logging
-from typing import Dict, Any, List, Optional
-from modules.datamodels.datamodelUam import UserPrivilege, User, UserInDB, Mandate
-from modules.datamodels.datamodelSecurity import AuthEvent
-
-# Configure logger
-logger = logging.getLogger(__name__)
-
-class AppAccess:
- """
- Access control class for Application interface.
- Handles user access management and permission checks.
- """
-
- def __init__(self, currentUser: User, db):
- """Initialize with user context."""
- self.currentUser = currentUser
- self.userId = currentUser.id
- self.mandateId = currentUser.mandateId
- self.privilege = currentUser.privilege
-
- if not self.mandateId or not self.userId:
- raise ValueError("Invalid user context: mandateId and userId are required")
-
- self.db = db
-
- def uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
- """
- Unified user access management function that filters data based on user privileges
- and adds access control attributes.
-
- Args:
- model_class: Pydantic model class for the table
- recordset: Recordset to filter based on access rules
-
- Returns:
- Filtered recordset with access control attributes
- """
- filtered_records = []
- table_name = model_class.__name__
-
- # Only SYSADMIN can see mandates
- if table_name == "Mandate":
- if self.privilege == UserPrivilege.SYSADMIN:
- filtered_records = recordset
- else:
- filtered_records = []
- # Special handling for users table
- elif table_name == "UserInDB":
- if self.privilege == UserPrivilege.SYSADMIN:
- # SysAdmin sees all users
- filtered_records = recordset
- elif self.privilege == UserPrivilege.ADMIN:
- # Admin sees all users in their mandate
- filtered_records = [r for r in recordset if r.get("mandateId","-") == self.mandateId]
- else:
- # Regular users only see themselves
- filtered_records = [r for r in recordset if r.get("id") == self.userId]
- # Special handling for connections table
- elif table_name == "UserConnection":
- if self.privilege == UserPrivilege.SYSADMIN:
- # SysAdmin sees all connections
- filtered_records = recordset
- elif self.privilege == UserPrivilege.ADMIN:
- # Admin sees connections for users in their mandate
- users: List[Dict[str, Any]] = self.db.getRecordset(UserInDB, recordFilter={"mandateId": self.mandateId})
- user_ids: List[str] = [str(u["id"]) for u in users]
- filtered_records = [r for r in recordset if r.get("userId") in user_ids]
- else:
- # Regular users only see their own connections
- filtered_records = [r for r in recordset if r.get("userId") == self.userId]
- # Special handling for data neutralization config table
- elif table_name == "DataNeutraliserConfig":
- if self.privilege == UserPrivilege.SYSADMIN:
- # SysAdmin sees all configs
- filtered_records = recordset
- elif self.privilege == UserPrivilege.ADMIN:
- # Admin sees configs in their mandate
- filtered_records = [r for r in recordset if r.get("mandateId","-") == self.mandateId]
- else:
- # Regular users only see their own configs
- filtered_records = [r for r in recordset if r.get("mandateId","-") == self.mandateId and r.get("userId") == self.userId]
- # Special handling for data neutralizer attributes table
- elif table_name == "DataNeutralizerAttributes":
- if self.privilege == UserPrivilege.SYSADMIN:
- # SysAdmin sees all attributes
- filtered_records = recordset
- elif self.privilege == UserPrivilege.ADMIN:
- # Admin sees attributes in their mandate
- filtered_records = [r for r in recordset if r.get("mandateId","-") == self.mandateId]
- else:
- # Regular users only see their own attributes
- filtered_records = [r for r in recordset if r.get("mandateId","-") == self.mandateId and r.get("userId") == self.userId]
- # System admins see all other records
- elif self.privilege == UserPrivilege.SYSADMIN:
- filtered_records = recordset
- # For other records, admins see records in their mandate
- elif self.privilege == UserPrivilege.ADMIN:
- filtered_records = [r for r in recordset if r.get("mandateId","-") == self.mandateId]
- # Regular users only see records they own within their mandate
- else:
- filtered_records = [r for r in recordset
- if r.get("mandateId","-") == self.mandateId and r.get("createdBy") == self.userId]
-
- # Add access control attributes to each record
- for record in filtered_records:
- record_id = record.get("id")
-
- # Set access control flags based on user permissions
- if table_name == "Mandate":
- record["_hideView"] = False # SYSADMIN can view
- record["_hideEdit"] = not self.canModify(Mandate, record_id)
- record["_hideDelete"] = not self.canModify(Mandate, record_id)
- elif table_name == "UserInDB":
- record["_hideView"] = False # Everyone can view users they have access to
- # SysAdmin can edit/delete any user
- if self.privilege == UserPrivilege.SYSADMIN:
- record["_hideEdit"] = False
- record["_hideDelete"] = False
- # Admin can edit/delete users in their mandate
- elif self.privilege == UserPrivilege.ADMIN:
- record["_hideEdit"] = record.get("mandateId","-") != self.mandateId
- record["_hideDelete"] = record.get("mandateId","-") != self.mandateId
- # Regular users can only edit themselves
- else:
- record["_hideEdit"] = record.get("id") != self.userId
- record["_hideDelete"] = True # Regular users cannot delete users
- elif table_name == "UserConnection":
- # Everyone can view connections they have access to
- record["_hideView"] = False
- # SysAdmin can edit/delete any connection
- if self.privilege == UserPrivilege.SYSADMIN:
- record["_hideEdit"] = False
- record["_hideDelete"] = False
- # Admin can edit/delete connections for users in their mandate
- elif self.privilege == UserPrivilege.ADMIN:
- users: List[Dict[str, Any]] = self.db.getRecordset(UserInDB, recordFilter={"mandateId": self.mandateId})
- user_ids: List[str] = [str(u["id"]) for u in users]
- record["_hideEdit"] = record.get("userId") not in user_ids
- record["_hideDelete"] = record.get("userId") not in user_ids
- # Regular users can only edit/delete their own connections
- else:
- record["_hideEdit"] = record.get("userId") != self.userId
- record["_hideDelete"] = record.get("userId") != self.userId
-
- elif table_name == "DataNeutraliserConfig":
- # Everyone can view configs they have access to
- record["_hideView"] = False
- # SysAdmin can edit/delete any config
- if self.privilege == UserPrivilege.SYSADMIN:
- record["_hideEdit"] = False
- record["_hideDelete"] = False
- # Admin can edit/delete configs in their mandate
- elif self.privilege == UserPrivilege.ADMIN:
- record["_hideEdit"] = record.get("mandateId","-") != self.mandateId
- record["_hideDelete"] = record.get("mandateId","-") != self.mandateId
- # Regular users can only edit/delete their own configs
- else:
- record["_hideEdit"] = record.get("userId") != self.userId
- record["_hideDelete"] = record.get("userId") != self.userId
- elif table_name == "DataNeutralizerAttributes":
- # Everyone can view attributes they have access to
- record["_hideView"] = False
- # SysAdmin can edit/delete any attributes
- if self.privilege == UserPrivilege.SYSADMIN:
- record["_hideEdit"] = False
- record["_hideDelete"] = False
- # Admin can edit/delete attributes in their mandate
- elif self.privilege == UserPrivilege.ADMIN:
- record["_hideEdit"] = record.get("mandateId","-") != self.mandateId
- record["_hideDelete"] = record.get("mandateId","-") != self.mandateId
- # Regular users can only edit/delete their own attributes
- else:
- record["_hideEdit"] = record.get("userId") != self.userId
- record["_hideDelete"] = record.get("userId") != self.userId
-
- elif table_name == "AuthEvent":
- # Only show auth events for the current user or if admin
- if self.privilege in [UserPrivilege.SYSADMIN, UserPrivilege.ADMIN]:
- record["_hideView"] = False
- else:
- record["_hideView"] = record.get("userId") != self.userId
- record["_hideEdit"] = True # Auth events can't be edited
- record["_hideDelete"] = not self.canModify(AuthEvent, record_id)
- else:
- # Default access control for other tables
- record["_hideView"] = False
- record["_hideEdit"] = not self.canModify(model_class, record_id)
- record["_hideDelete"] = not self.canModify(model_class, record_id)
-
- return filtered_records
-
- def canModify(self, model_class: type, recordId: Optional[str] = None) -> bool:
- """
- Checks if the current user can modify (create/update/delete) records in a table.
-
- Args:
- model_class: Pydantic model class for the table
- recordId: Optional record ID for specific record check
-
- Returns:
- Boolean indicating permission
- """
- table_name = model_class.__name__
-
- # For mandates, only SYSADMIN can modify
- if table_name == "Mandate":
- return self.privilege == UserPrivilege.SYSADMIN
-
- # System admins can modify anything else
- if self.privilege == UserPrivilege.SYSADMIN:
- return True
-
- # Check specific record permissions
- if recordId is not None:
- # Get the record to check ownership
- records: List[Dict[str, Any]] = self.db.getRecordset(model_class, recordFilter={"id": str(recordId)})
- if not records:
- return False
-
- record = records[0]
-
- # Special handling for connections
- if table_name == "UserConnection":
- # Admin can modify connections for users in their mandate
- if self.privilege == UserPrivilege.ADMIN:
- users: List[Dict[str, Any]] = self.db.getRecordset(UserInDB, recordFilter={"mandateId": self.mandateId})
- user_ids: List[str] = [str(u["id"]) for u in users]
- return record.get("userId") in user_ids
- # Users can only modify their own connections
- return record.get("userId") == self.userId
-
- # Admins can modify anything in their mandate
- if self.privilege == UserPrivilege.ADMIN and record.get("mandateId","-") == self.mandateId:
- return True
-
- # Users can only modify their own records
- if (record.get("mandateId","-") == self.mandateId and
- record.get("createdBy") == self.userId):
- return True
-
- return False
- else:
- # For general table modify permission (e.g., create)
- # Admins can create anything in their mandate
- if self.privilege == UserPrivilege.ADMIN:
- return True
-
- # Regular users can create most entities
- return True
diff --git a/modules/interfaces/interfaceDbAppObjects.py b/modules/interfaces/interfaceDbAppObjects.py
index 91d7bda4..900f7328 100644
--- a/modules/interfaces/interfaceDbAppObjects.py
+++ b/modules/interfaces/interfaceDbAppObjects.py
@@ -12,7 +12,8 @@ import uuid
from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG
from modules.shared.timeUtils import getUtcTimestamp, parseTimestamp
-from modules.interfaces.interfaceDbAppAccess import AppAccess
+from modules.interfaces.interfaceBootstrap import initBootstrap
+from modules.security.rbac import RbacClass
from modules.datamodels.datamodelUam import (
User,
Mandate,
@@ -22,6 +23,11 @@ from modules.datamodels.datamodelUam import (
UserPrivilege,
ConnectionStatus,
)
+from modules.datamodels.datamodelRbac import (
+ AccessRule,
+ AccessRuleContext,
+)
+from modules.datamodels.datamodelUam import AccessLevel
from modules.datamodels.datamodelSecurity import Token, AuthEvent, TokenStatus
from modules.datamodels.datamodelNeutralizer import (
DataNeutraliserConfig,
@@ -53,7 +59,6 @@ class AppObjects:
self.currentUser = currentUser # Store User object directly
self.userId = currentUser.id if currentUser else None
self.mandateId = currentUser.mandateId if currentUser else None
- self.access = None # Will be set when user context is provided
# Initialize database
self._initializeDatabase()
@@ -81,10 +86,10 @@ class AppObjects:
# Add language settings
self.userLanguage = currentUser.language # Default user language
- # Initialize access control with user context
- self.access = AppAccess(
- self.currentUser, self.db
- ) # Convert to dict only when needed
+ # Initialize RBAC interface
+ if not currentUser:
+ raise ValueError("User context is required for RBAC")
+ self.rbac = RbacClass(self.db)
# Update database context
self.db.updateContext(self.userId)
@@ -127,113 +132,46 @@ class AppObjects:
def _initRecords(self):
"""Initialize standard records if they don't exist."""
- self._initRootMandate()
- self._initAdminUser()
- self._initEventUser()
+ initBootstrap(self.db)
- def _initRootMandate(self):
- """Creates the Root mandate if it doesn't exist."""
- existingMandateId = self.getInitialId(Mandate)
- mandates = self.db.getRecordset(Mandate)
- if existingMandateId is None or not mandates:
- logger.info("Creating Root mandate")
- rootMandate = Mandate(name="Root", language="en", enabled=True)
- createdMandate = self.db.recordCreate(Mandate, rootMandate)
- logger.info(f"Root mandate created with ID {createdMandate['id']}")
- # Update mandate context
- self.mandateId = createdMandate["id"]
-
- def _initAdminUser(self):
- """Creates the Admin user if it doesn't exist."""
- existingUserId = self.getInitialId(UserInDB)
- users = self.db.getRecordset(UserInDB)
- if existingUserId is None or not users:
- logger.info("Creating Admin user")
- adminUser = UserInDB(
- mandateId=self.getInitialId(Mandate),
- username="admin",
- email="admin@example.com",
- fullName="Administrator",
- enabled=True,
- language="en",
- privilege=UserPrivilege.SYSADMIN,
- authenticationAuthority="local", # Using lowercase value directly
- hashedPassword=self._getPasswordHash(
- APP_CONFIG.get("APP_INIT_PASS_ADMIN_SECRET")
- ),
- connections=[],
- )
- createdUser = self.db.recordCreate(UserInDB, adminUser)
- logger.info(f"Admin user created with ID {createdUser['id']}")
-
- # Update user context
- self.currentUser = createdUser
- self.userId = createdUser.get("id")
-
- def _initEventUser(self):
- """Creates the Event user if it doesn't exist."""
- # Check if event user already exists
- existingUsers = self.db.getRecordset(
- UserInDB, recordFilter={"username": "event"}
- )
- if not existingUsers:
- logger.info("Creating Event user")
- eventUser = UserInDB(
- mandateId=self.getInitialId(Mandate),
- username="event",
- email="event@example.com",
- fullName="Event",
- enabled=True,
- language="en",
- privilege=UserPrivilege.SYSADMIN,
- authenticationAuthority="local", # Using lowercase value directly
- hashedPassword=self._getPasswordHash(
- APP_CONFIG.get("APP_INIT_PASS_EVENT_SECRET")
- ),
- connections=[],
- )
- createdUser = self.db.recordCreate(UserInDB, eventUser)
- logger.info(f"Event user created with ID {createdUser['id']}")
-
- def _uam(
- self, model_class: type, recordset: List[Dict[str, Any]]
- ) -> List[Dict[str, Any]]:
+ def checkRbacPermission(
+ self,
+ modelClass: type,
+ operation: str,
+ recordId: Optional[str] = None
+ ) -> bool:
"""
- Unified user access management function that filters data based on user privileges
- and adds access control attributes.
+ Check RBAC permission for a specific operation on a table.
Args:
- model_class: Pydantic model class for the table
- recordset: Recordset to filter based on access rules
-
- Returns:
- Filtered recordset with access control attributes
- """
- # First apply access control
- filteredRecords = self.access.uam(model_class, recordset)
-
- # Then filter out database-specific fields
- cleanedRecords = []
- for record in filteredRecords:
- # Create a new dict with only non-database fields
- cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_")}
- cleanedRecords.append(cleanedRecord)
-
- return cleanedRecords
-
- def _canModify(self, model_class: type, recordId: Optional[str] = None) -> bool:
- """
- Checks if the current user can modify (create/update/delete) records in a table.
-
- Args:
- model_class: Pydantic model class for the table
+ modelClass: Pydantic model class for the table
+ operation: Operation to check ('create', 'update', 'delete', 'read')
recordId: Optional record ID for specific record check
Returns:
Boolean indicating permission
"""
- return self.access.canModify(model_class, recordId)
+ if not self.rbac or not self.currentUser:
+ return False
+
+ tableName = modelClass.__name__
+ permissions = self.rbac.getUserPermissions(
+ self.currentUser,
+ AccessRuleContext.DATA,
+ tableName
+ )
+
+ if operation == "create":
+ return permissions.create != AccessLevel.NONE
+ elif operation == "update":
+ return permissions.update != AccessLevel.NONE
+ elif operation == "delete":
+ return permissions.delete != AccessLevel.NONE
+ elif operation == "read":
+ return permissions.read != AccessLevel.NONE
+ else:
+ return False
def _applyFilters(self, records: List[Dict[str, Any]], filters: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
@@ -480,13 +418,18 @@ class AppObjects:
If pagination is None: List[User]
If pagination is provided: PaginatedResult with items and metadata
"""
- # For SYSADMIN, get all users regardless of mandate
- # For others, filter by mandate
- if self.currentUser and self.currentUser.privilege == UserPrivilege.SYSADMIN:
- users = self.db.getRecordset(UserInDB)
- else:
- users = self.db.getRecordset(UserInDB, recordFilter={"mandateId": mandateId})
- filteredUsers = self._uam(UserInDB, users)
+ # Use RBAC filtering
+ users = self.db.getRecordsetWithRBAC(
+ UserInDB,
+ self.currentUser,
+ recordFilter={"mandateId": mandateId} if mandateId else None
+ )
+
+ # Filter out database-specific fields
+ filteredUsers = []
+ for user in users:
+ cleanedUser = {k: v for k, v in user.items() if not k.startswith("_")}
+ filteredUsers.append(cleanedUser)
# If no pagination requested, return all items
if pagination is None:
@@ -521,18 +464,22 @@ class AppObjects:
def getUserByUsername(self, username: str) -> Optional[User]:
"""Returns a user by username."""
try:
- # Get users table
- users = self.db.getRecordset(UserInDB)
+ # Use RBAC filtering
+ users = self.db.getRecordsetWithRBAC(
+ UserInDB,
+ self.currentUser,
+ recordFilter={"username": username}
+ )
+
if not users:
+ logger.info(f"No user found with username {username}")
return None
- # Find user by username
- for user_dict in users:
- if user_dict.get("username") == username:
- return User(**user_dict)
-
- logger.info(f"No user found with username {username}")
- return None
+ # Return first matching user (should be unique)
+ userDict = users[0]
+ # Filter out database-specific fields
+ cleanedUser = {k: v for k, v in userDict.items() if not k.startswith("_")}
+ return User(**cleanedUser)
except Exception as e:
logger.error(f"Error getting user by username: {str(e)}")
@@ -549,11 +496,9 @@ class AppObjects:
# Find user by ID
for user_dict in users:
if user_dict.get("id") == userId:
- # Apply access control
- filteredUsers = self._uam(UserInDB, [user_dict])
- if filteredUsers:
- return User(**filteredUsers[0])
- return None
+ # User already filtered by RBAC, just clean fields
+ cleanedUser = {k: v for k, v in user_dict.items() if not k.startswith("_")}
+ return User(**cleanedUser)
return None
@@ -764,7 +709,7 @@ class AppObjects:
if not user:
raise ValueError(f"User {userId} not found")
- if not self._canModify(UserInDB, userId):
+ if not self.checkRbacPermission(UserInDB, "update", userId):
raise PermissionError(f"No permission to delete user {userId}")
# Delete all referenced data first
@@ -943,8 +888,14 @@ class AppObjects:
If pagination is None: List[Mandate]
If pagination is provided: PaginatedResult with items and metadata
"""
- allMandates = self.db.getRecordset(Mandate)
- filteredMandates = self._uam(Mandate, allMandates)
+ # Use RBAC filtering
+ allMandates = self.db.getRecordsetWithRBAC(Mandate, self.currentUser)
+
+ # Filter out database-specific fields
+ filteredMandates = []
+ for mandate in allMandates:
+ cleanedMandate = {k: v for k, v in mandate.items() if not k.startswith("_")}
+ filteredMandates.append(cleanedMandate)
# If no pagination requested, return all items
if pagination is None:
@@ -978,11 +929,21 @@ class AppObjects:
def getMandate(self, mandateId: str) -> Optional[Mandate]:
"""Returns a mandate by ID if user has access."""
- mandates = self.db.getRecordset(Mandate, recordFilter={"id": mandateId})
+ # Use RBAC filtering
+ mandates = self.db.getRecordsetWithRBAC(
+ Mandate,
+ self.currentUser,
+ recordFilter={"id": mandateId}
+ )
+
if not mandates:
return None
-
- filteredMandates = self._uam(Mandate, mandates)
+
+ # Filter out database-specific fields
+ filteredMandates = []
+ for mandate in mandates:
+ cleanedMandate = {k: v for k, v in mandate.items() if not k.startswith("_")}
+ filteredMandates.append(cleanedMandate)
if not filteredMandates:
return None
@@ -990,7 +951,7 @@ class AppObjects:
def createMandate(self, name: str, language: str = "en") -> Mandate:
"""Creates a new mandate if user has permission."""
- if not self._canModify(Mandate):
+ if not self.checkRbacPermission(Mandate, "create"):
raise PermissionError("No permission to create mandates")
# Create mandate data using model
@@ -1007,7 +968,7 @@ class AppObjects:
"""Updates a mandate if user has access."""
try:
# First check if user has permission to modify mandates
- if not self._canModify(Mandate, mandateId):
+ if not self.checkRbacPermission(Mandate, "update", mandateId):
raise PermissionError(f"No permission to update mandate {mandateId}")
# Get mandate with access control
@@ -1044,7 +1005,7 @@ class AppObjects:
if not mandate:
return False
- if not self._canModify(Mandate, mandateId):
+ if not self.checkRbacPermission(Mandate, "delete", mandateId):
raise PermissionError(f"No permission to delete mandate {mandateId}")
# Check if mandate has users
@@ -1384,7 +1345,7 @@ class AppObjects:
self.currentUser = None
self.userId = None
self.mandateId = None
- self.access = None
+ self.rbac = None
# Clear database context
if hasattr(self, "db"):
@@ -1401,18 +1362,20 @@ class AppObjects:
def getNeutralizationConfig(self) -> Optional[DataNeutraliserConfig]:
"""Get the data neutralization configuration for the current user's mandate"""
try:
- configs = self.db.getRecordset(
- DataNeutraliserConfig, recordFilter={"mandateId": self.mandateId}
+ # Use RBAC filtering
+ filtered_configs = self.db.getRecordsetWithRBAC(
+ DataNeutraliserConfig,
+ self.currentUser,
+ recordFilter={"mandateId": self.mandateId}
)
- if not configs:
- return None
-
- # Apply access control
- filtered_configs = self._uam(DataNeutraliserConfig, configs)
+
if not filtered_configs:
return None
- return DataNeutraliserConfig(**filtered_configs[0])
+ # Filter out database-specific fields
+ configDict = filtered_configs[0]
+ cleanedConfig = {k: v for k, v in configDict.items() if not k.startswith("_")}
+ return DataNeutraliserConfig(**cleanedConfig)
except Exception as e:
logger.error(f"Error getting neutralization config: {str(e)}")
@@ -1461,14 +1424,22 @@ class AppObjects:
if file_id:
filter_dict["fileId"] = file_id
- attributes = self.db.getRecordset(
- DataNeutralizerAttributes, recordFilter=filter_dict
+ # Use RBAC filtering
+ filtered_attributes = self.db.getRecordsetWithRBAC(
+ DataNeutralizerAttributes,
+ self.currentUser,
+ recordFilter=filter_dict
)
- filtered_attributes = self._uam(DataNeutralizerAttributes, attributes)
+ # Filter out database-specific fields
+ cleaned_attributes = []
+ for attr in filtered_attributes:
+ cleanedAttr = {k: v for k, v in attr.items() if not k.startswith("_")}
+ cleaned_attributes.append(cleanedAttr)
+
return [
DataNeutralizerAttributes(**attr)
- for attr in filtered_attributes
+ for attr in cleaned_attributes
]
except Exception as e:
@@ -1495,6 +1466,151 @@ class AppObjects:
logger.error(f"Error deleting neutralization attributes: {str(e)}")
return False
+ # RBAC CRUD Methods
+
+ def createAccessRule(self, accessRule: AccessRule) -> AccessRule:
+ """
+ Create a new access rule.
+
+ Args:
+ accessRule: AccessRule object to create
+
+ Returns:
+ Created AccessRule object
+ """
+ try:
+ createdRule = self.db.recordCreate(AccessRule, accessRule)
+ logger.info(f"Created access rule with ID {createdRule.get('id')}")
+ return AccessRule(**createdRule)
+ except Exception as e:
+ logger.error(f"Error creating access rule: {str(e)}")
+ raise
+
+ def getAccessRule(self, ruleId: str) -> Optional[AccessRule]:
+ """
+ Get an access rule by ID.
+
+ Args:
+ ruleId: Access rule ID
+
+ Returns:
+ AccessRule object if found, None otherwise
+ """
+ try:
+ rules = self.db.getRecordset(AccessRule, recordFilter={"id": ruleId})
+ if rules:
+ return AccessRule(**rules[0])
+ return None
+ except Exception as e:
+ logger.error(f"Error getting access rule {ruleId}: {str(e)}")
+ return None
+
+ def updateAccessRule(self, ruleId: str, accessRule: AccessRule) -> AccessRule:
+ """
+ Update an existing access rule.
+
+ Args:
+ ruleId: Access rule ID
+ accessRule: Updated AccessRule object
+
+ Returns:
+ Updated AccessRule object
+ """
+ try:
+ updatedRule = self.db.recordUpdate(AccessRule, ruleId, accessRule.model_dump())
+ logger.info(f"Updated access rule with ID {ruleId}")
+ return AccessRule(**updatedRule)
+ except Exception as e:
+ logger.error(f"Error updating access rule {ruleId}: {str(e)}")
+ raise
+
+ def deleteAccessRule(self, ruleId: str) -> bool:
+ """
+ Delete an access rule.
+
+ Args:
+ ruleId: Access rule ID
+
+ Returns:
+ True if deleted successfully, False otherwise
+ """
+ try:
+ self.db.recordDelete(AccessRule, ruleId)
+ logger.info(f"Deleted access rule with ID {ruleId}")
+ return True
+ except Exception as e:
+ logger.error(f"Error deleting access rule {ruleId}: {str(e)}")
+ return False
+
+ def getAccessRules(
+ self,
+ roleLabel: Optional[str] = None,
+ context: Optional[AccessRuleContext] = None,
+ item: Optional[str] = None
+ ) -> List[AccessRule]:
+ """
+ Get access rules with optional filters.
+
+ Args:
+ roleLabel: Optional role label filter
+ context: Optional context filter
+ item: Optional item filter
+
+ Returns:
+ List of AccessRule objects
+ """
+ try:
+ recordFilter = {}
+ if roleLabel:
+ recordFilter["roleLabel"] = roleLabel
+ if context:
+ recordFilter["context"] = context.value
+ if item:
+ recordFilter["item"] = item
+
+ rules = self.db.getRecordset(AccessRule, recordFilter=recordFilter if recordFilter else None)
+ return [AccessRule(**rule) for rule in rules]
+ except Exception as e:
+ logger.error(f"Error getting access rules: {str(e)}")
+ return []
+
+ def getAccessRulesForRoles(
+ self,
+ roleLabels: List[str],
+ context: AccessRuleContext,
+ item: str
+ ) -> List[AccessRule]:
+ """
+ Get access rules for multiple roles, context, and item.
+ Returns the most specific matching rules for each role.
+
+ Args:
+ roleLabels: List of role labels
+ context: Context type
+ item: Item identifier
+
+ Returns:
+ List of AccessRule objects (most specific for each role)
+ """
+ try:
+ RbacInstance = RbacClass(self.db)
+ allRules = []
+
+ for roleLabel in roleLabels:
+ # Get all rules for this role and context
+ roleRules = RbacInstance._getRulesForRole(roleLabel, context)
+
+ # Find most specific rule for this item
+ mostSpecificRule = RbacInstance.findMostSpecificRule(roleRules, item)
+
+ if mostSpecificRule:
+ allRules.append(mostSpecificRule)
+
+ return allRules
+ except Exception as e:
+ logger.error(f"Error getting access rules for roles: {str(e)}")
+ return []
+
# Public Methods
diff --git a/modules/interfaces/interfaceDbChatAccess.py b/modules/interfaces/interfaceDbChatAccess.py
deleted file mode 100644
index 37e96d84..00000000
--- a/modules/interfaces/interfaceDbChatAccess.py
+++ /dev/null
@@ -1,140 +0,0 @@
-"""
-Access control module for Chat interface.
-Handles user access management and permission checks.
-"""
-
-from typing import Dict, Any, List, Optional
-from modules.datamodels.datamodelUam import User, UserPrivilege
-from modules.datamodels.datamodelChat import ChatWorkflow, AutomationDefinition
-
-class ChatAccess:
- """
- Access control class for Chat interface.
- Handles user access management and permission checks.
- """
-
- def __init__(self, currentUser: User, db):
- """Initialize with user context."""
- self.currentUser = currentUser
- self.mandateId = currentUser.mandateId
- self.userId = currentUser.id
-
- if not self.mandateId or not self.userId:
- raise ValueError("Invalid user context: mandateId and userId are required")
-
- self.db = db
-
- def uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
- """
- Unified user access management function that filters data based on user privileges
- and adds access control attributes.
-
- Args:
- model_class: Pydantic model class for the table
- recordset: Recordset to filter based on access rules
-
- Returns:
- Filtered recordset with access control attributes
- """
- userPrivilege = self.currentUser.privilege
- table_name = model_class.__name__
- filtered_records = []
-
- # Apply filtering based on privilege
- if table_name == "AutomationDefinition":
- # Filter automations based on user privilege
- if userPrivilege == UserPrivilege.SYSADMIN:
- # System admins see all automations
- filtered_records = recordset
- elif userPrivilege == UserPrivilege.ADMIN:
- # Admins see all automations in their mandate
- filtered_records = [r for r in recordset if r.get("mandateId","-") == self.mandateId]
- else:
- # Regular users see only their own automations
- filtered_records = [
- r for r in recordset
- if r.get("mandateId","-") == self.mandateId and r.get("_createdBy") == self.userId
- ]
- elif userPrivilege == UserPrivilege.SYSADMIN:
- filtered_records = recordset # System admins see all records
- elif userPrivilege == UserPrivilege.ADMIN:
- # Admins see records in their mandate
- filtered_records = [r for r in recordset if r.get("mandateId","-") == self.mandateId]
- else: # Regular users
- # Users see only their records for other tables
- filtered_records = [r for r in recordset
- if r.get("mandateId","-") == self.mandateId and r.get("_createdBy") == self.userId]
-
- # Add access control attributes to each record
- for record in filtered_records:
- record_id = record.get("id")
-
- # Set access control flags based on user permissions
- if table_name == "ChatWorkflow":
- record["_hideView"] = False # Everyone can view
- record["_hideEdit"] = not self.canModify(ChatWorkflow, record_id)
- record["_hideDelete"] = not self.canModify(ChatWorkflow, record_id)
- elif table_name == "ChatMessage":
- record["_hideView"] = False # Everyone can view
- record["_hideEdit"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
- record["_hideDelete"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
- elif table_name == "ChatLog":
- record["_hideView"] = False # Everyone can view
- record["_hideEdit"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
- record["_hideDelete"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
- elif table_name == "AutomationDefinition":
- record["_hideView"] = False # Everyone can view
- record["_hideEdit"] = not self.canModify(AutomationDefinition, record_id)
- record["_hideDelete"] = not self.canModify(AutomationDefinition, record_id)
- else:
- # Default access control for other tables
- record["_hideView"] = False
- record["_hideEdit"] = not self.canModify(model_class, record_id)
- record["_hideDelete"] = not self.canModify(model_class, record_id)
-
- return filtered_records
-
- def canModify(self, model_class: type, recordId: Optional[str] = None) -> bool:
- """
- Checks if the current user can modify (create/update/delete) records in a table.
-
- Args:
- model_class: Pydantic model class for the table
- recordId: Optional record ID for specific record check
-
- Returns:
- Boolean indicating permission
- """
- userPrivilege = self.currentUser.privilege
-
- # System admins can modify anything
- if userPrivilege == UserPrivilege.SYSADMIN:
- return True
-
- # For regular users and admins, check specific cases
- if recordId is not None:
- # Get the record to check ownership
- records: List[Dict[str, Any]] = self.db.getRecordset(model_class, recordFilter={"id": recordId})
- if not records:
- return False
-
- record = records[0]
-
- # Admins can modify anything in their mandate, if mandate is specified for a record
- if userPrivilege == UserPrivilege.ADMIN and record.get("mandateId","-") == self.mandateId:
- return True
-
- # Regular users can only modify their own records
- if (record.get("mandateId","-") == self.mandateId and
- record.get("_createdBy") == self.userId):
- return True
-
- return False
- else:
- # For general modification permission (e.g., create)
- # Admins can create anything in their mandate
- if userPrivilege == UserPrivilege.ADMIN:
- return True
-
- # Regular users can create in most tables
- return True
\ No newline at end of file
diff --git a/modules/interfaces/interfaceDbChatObjects.py b/modules/interfaces/interfaceDbChatObjects.py
index de4abc7e..6093eb78 100644
--- a/modules/interfaces/interfaceDbChatObjects.py
+++ b/modules/interfaces/interfaceDbChatObjects.py
@@ -10,7 +10,9 @@ from typing import Dict, Any, List, Optional, Union
import asyncio
-from modules.interfaces.interfaceDbChatAccess import ChatAccess
+from modules.security.rbac import RbacClass
+from modules.datamodels.datamodelRbac import AccessRuleContext
+from modules.datamodels.datamodelUam import AccessLevel
from modules.datamodels.datamodelChat import (
ChatDocument,
@@ -179,7 +181,7 @@ class ChatObjects:
self.currentUser = currentUser # Store User object directly
self.userId = currentUser.id if currentUser else None
self.mandateId = currentUser.mandateId if currentUser else None
- self.access = None # Will be set when user context is provided
+ self.rbac = None # RBAC interface
# Initialize services
self._initializeServices()
@@ -263,8 +265,10 @@ class ChatObjects:
# Add language settings
self.userLanguage = currentUser.language # Default user language
- # Initialize access control with user context
- self.access = ChatAccess(self.currentUser, self.db) # Convert to dict only when needed
+ # Initialize RBAC interface
+ if not self.currentUser:
+ raise ValueError("User context is required for RBAC")
+ self.rbac = RbacClass(self.db)
# Update database context
self.db.updateContext(self.userId)
@@ -310,35 +314,44 @@ class ChatObjects:
"""Initializes standard records in the database if they don't exist."""
pass
- def _uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
- """Delegate to access control module."""
- # First apply access control
- filteredRecords = self.access.uam(model_class, recordset)
-
- # For AutomationDefinition, keep _createdBy and mandateId for enrichment purposes
- # Other fields starting with _ are filtered out as they're database-specific
- if model_class.__name__ == "AutomationDefinition":
- # Keep _createdBy and mandateId for enrichment, filter out other _ fields
- cleanedRecords = []
- for record in filteredRecords:
- cleanedRecord = {}
- for k, v in record.items():
- # Keep _createdBy and mandateId, filter out other _ fields
- if k == "_createdBy" or k == "mandateId" or not k.startswith('_'):
- cleanedRecord[k] = v
- cleanedRecords.append(cleanedRecord)
- return cleanedRecords
- else:
- # For other models, filter out all database-specific fields
- cleanedRecords = []
- for record in filteredRecords:
- cleanedRecord = {k: v for k, v in record.items() if not k.startswith('_')}
- cleanedRecords.append(cleanedRecord)
- return cleanedRecords
+
+ def checkRbacPermission(
+ self,
+ modelClass: type,
+ operation: str,
+ recordId: Optional[str] = None
+ ) -> bool:
+ """
+ Check RBAC permission for a specific operation on a table.
- def _canModify(self, model_class: type, recordId: Optional[str] = None) -> bool:
- """Delegate to access control module."""
- return self.access.canModify(model_class, recordId)
+ Args:
+ modelClass: Pydantic model class for the table
+ operation: Operation to check ('create', 'update', 'delete', 'read')
+ recordId: Optional record ID for specific record check
+
+ Returns:
+ Boolean indicating permission
+ """
+ if not self.rbac or not self.currentUser:
+ return False
+
+ tableName = modelClass.__name__
+ permissions = self.rbac.getUserPermissions(
+ self.currentUser,
+ AccessRuleContext.DATA,
+ tableName
+ )
+
+ if operation == "create":
+ return permissions.create != AccessLevel.NONE
+ elif operation == "update":
+ return permissions.update != AccessLevel.NONE
+ elif operation == "delete":
+ return permissions.delete != AccessLevel.NONE
+ elif operation == "read":
+ return permissions.read != AccessLevel.NONE
+ else:
+ return False
def _applyFilters(self, records: List[Dict[str, Any]], filters: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
@@ -567,8 +580,11 @@ class ChatObjects:
If pagination is None: List[Dict[str, Any]]
If pagination is provided: PaginatedResult with items and metadata
"""
- allWorkflows = self.db.getRecordset(ChatWorkflow)
- filteredWorkflows = self._uam(ChatWorkflow, allWorkflows)
+ # Use RBAC filtering
+ filteredWorkflows = self.db.getRecordsetWithRBAC(
+ ChatWorkflow,
+ self.currentUser
+ )
# If no pagination requested, return all items (no sorting - frontend handles it)
if pagination is None:
@@ -599,15 +615,17 @@ class ChatObjects:
def getWorkflow(self, workflowId: str) -> Optional[ChatWorkflow]:
"""Returns a workflow by ID if user has access."""
- workflows = self.db.getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
+ # Use RBAC filtering
+ workflows = self.db.getRecordsetWithRBAC(
+ ChatWorkflow,
+ self.currentUser,
+ recordFilter={"id": workflowId}
+ )
+
if not workflows:
return None
- filteredWorkflows = self._uam(ChatWorkflow, workflows)
- if not filteredWorkflows:
- return None
-
- workflow = filteredWorkflows[0]
+ workflow = workflows[0]
try:
# Load related data from normalized tables
logs = self.getLogs(workflowId)
@@ -637,7 +655,7 @@ class ChatObjects:
def createWorkflow(self, workflowData: Dict[str, Any]) -> ChatWorkflow:
"""Creates a new workflow if user has permission."""
- if not self._canModify(ChatWorkflow):
+ if not self.checkRbacPermission(ChatWorkflow, "create"):
raise PermissionError("No permission to create workflows")
# Set timestamp if not present
@@ -682,7 +700,7 @@ class ChatObjects:
if not workflow:
return None
- if not self._canModify(ChatWorkflow, workflowId):
+ if not self.checkRbacPermission(ChatWorkflow, "update", workflowId):
raise PermissionError(f"No permission to update workflow {workflowId}")
# Use generic field separation based on ChatWorkflow model
@@ -728,7 +746,7 @@ class ChatObjects:
if not workflow:
return False
- if not self._canModify(ChatWorkflow, workflowId):
+ if not self.checkRbacPermission(ChatWorkflow, "delete", workflowId):
raise PermissionError(f"No permission to delete workflow {workflowId}")
# CASCADE DELETE: Delete all related data first
@@ -787,18 +805,18 @@ class ChatObjects:
If pagination is provided: PaginatedResult with items and metadata
"""
# Check workflow access first (without calling getWorkflow to avoid circular reference)
- workflows = self.db.getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
+ # Use RBAC filtering
+ workflows = self.db.getRecordsetWithRBAC(
+ ChatWorkflow,
+ self.currentUser,
+ recordFilter={"id": workflowId}
+ )
+
if not workflows:
if pagination is None:
return []
return PaginatedResult(items=[], totalItems=0, totalPages=0)
- filteredWorkflows = self._uam(ChatWorkflow, workflows)
- if not filteredWorkflows:
- if pagination is None:
- return []
- return PaginatedResult(items=[], totalItems=0, totalPages=0)
-
# Get messages for this workflow from normalized table
messages = self.db.getRecordset(ChatMessage, recordFilter={"workflowId": workflowId})
@@ -938,7 +956,7 @@ class ChatObjects:
if not workflow:
raise PermissionError(f"No access to workflow {workflowId}")
- if not self._canModify(ChatWorkflow, workflowId):
+ if not self.checkRbacPermission(ChatWorkflow, "update", workflowId):
raise PermissionError(f"No permission to modify workflow {workflowId}")
# Validate that ID is not None
@@ -1054,7 +1072,7 @@ class ChatObjects:
if not workflow:
raise PermissionError(f"No access to workflow {workflowId}")
- if not self._canModify(ChatWorkflow, workflowId):
+ if not self.checkRbacPermission(ChatWorkflow, "update", workflowId):
raise PermissionError(f"No permission to modify workflow {workflowId}")
logger.info(f"Creating new message with ID {messageId} for workflow {workflowId}")
@@ -1072,7 +1090,7 @@ class ChatObjects:
if not workflow:
raise PermissionError(f"No access to workflow {workflowId}")
- if not self._canModify(ChatWorkflow, workflowId):
+ if not self.checkRbacPermission(ChatWorkflow, "update", workflowId):
raise PermissionError(f"No permission to modify workflow {workflowId}")
# Use generic field separation based on ChatMessage model
@@ -1132,7 +1150,7 @@ class ChatObjects:
logger.warning(f"No access to workflow {workflowId}")
return False
- if not self._canModify(ChatWorkflow, workflowId):
+ if not self.checkRbacPermission(ChatWorkflow, "update", workflowId):
raise PermissionError(f"No permission to modify workflow {workflowId}")
# Check if the message exists
@@ -1173,7 +1191,7 @@ class ChatObjects:
logger.warning(f"No access to workflow {workflowId}")
return False
- if not self._canModify(ChatWorkflow, workflowId):
+ if not self.checkRbacPermission(ChatWorkflow, "update", workflowId):
raise PermissionError(f"No permission to modify workflow {workflowId}")
@@ -1257,18 +1275,18 @@ class ChatObjects:
If pagination is provided: PaginatedResult with items and metadata
"""
# Check workflow access first (without calling getWorkflow to avoid circular reference)
- workflows = self.db.getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
+ # Use RBAC filtering
+ workflows = self.db.getRecordsetWithRBAC(
+ ChatWorkflow,
+ self.currentUser,
+ recordFilter={"id": workflowId}
+ )
+
if not workflows:
if pagination is None:
return []
return PaginatedResult(items=[], totalItems=0, totalPages=0)
- filteredWorkflows = self._uam(ChatWorkflow, workflows)
- if not filteredWorkflows:
- if pagination is None:
- return []
- return PaginatedResult(items=[], totalItems=0, totalPages=0)
-
# Get logs for this workflow from normalized table
logs = self.db.getRecordset(ChatLog, recordFilter={"workflowId": workflowId})
@@ -1335,7 +1353,7 @@ class ChatObjects:
logger.warning(f"No access to workflow {workflowId}")
return None
- if not self._canModify(ChatWorkflow, workflowId):
+ if not self.checkRbacPermission(ChatWorkflow, "update", workflowId):
logger.warning(f"No permission to modify workflow {workflowId}")
return None
@@ -1378,14 +1396,16 @@ class ChatObjects:
def getStats(self, workflowId: str) -> List[ChatStat]:
"""Returns list of statistics for a workflow if user has access."""
# Check workflow access first (without calling getWorkflow to avoid circular reference)
- workflows = self.db.getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
+ # Use RBAC filtering
+ workflows = self.db.getRecordsetWithRBAC(
+ ChatWorkflow,
+ self.currentUser,
+ recordFilter={"id": workflowId}
+ )
+
if not workflows:
return []
- filteredWorkflows = self._uam(ChatWorkflow, workflows)
- if not filteredWorkflows:
- return []
-
# Get stats for this workflow from normalized table
stats = self.db.getRecordset(ChatStat, recordFilter={"workflowId": workflowId})
@@ -1423,13 +1443,15 @@ class ChatObjects:
Uses timestamp-based selective data transfer for efficient polling.
"""
# Check workflow access first
- workflows = self.db.getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
+ # Use RBAC filtering
+ workflows = self.db.getRecordsetWithRBAC(
+ ChatWorkflow,
+ self.currentUser,
+ recordFilter={"id": workflowId}
+ )
+
if not workflows:
return {"items": []}
-
- filteredWorkflows = self._uam(ChatWorkflow, workflows)
- if not filteredWorkflows:
- return {"items": []}
# Get all data types and filter in Python (PostgreSQL connector doesn't support $gt operators)
items = []
@@ -1585,8 +1607,11 @@ class ChatObjects:
Supports optional pagination, sorting, and filtering.
Computes status field for each automation.
"""
- allAutomations = self.db.getRecordset(AutomationDefinition)
- filteredAutomations = self._uam(AutomationDefinition, allAutomations)
+ # Use RBAC filtering
+ filteredAutomations = self.db.getRecordsetWithRBAC(
+ AutomationDefinition,
+ self.currentUser
+ )
# Compute status for each automation and normalize executionLogs
for automation in filteredAutomations:
@@ -1628,8 +1653,12 @@ class ChatObjects:
def getAutomationDefinition(self, automationId: str) -> Optional[Dict[str, Any]]:
"""Returns an automation definition by ID if user has access, with computed status."""
try:
- automations = self.db.getRecordset(AutomationDefinition, recordFilter={"id": automationId})
- filtered = self._uam(AutomationDefinition, automations)
+ # Use RBAC filtering
+ filtered = self.db.getRecordsetWithRBAC(
+ AutomationDefinition,
+ self.currentUser,
+ recordFilter={"id": automationId}
+ )
if not filtered:
return None
@@ -1695,7 +1724,7 @@ class ChatObjects:
if not existing:
raise PermissionError(f"No access to automation {automationId}")
- if not self._canModify(AutomationDefinition, automationId):
+ if not self.checkRbacPermission(AutomationDefinition, "update", automationId):
raise PermissionError(f"No permission to modify automation {automationId}")
# Use generic field separation
@@ -1726,7 +1755,7 @@ class ChatObjects:
if not existing:
raise PermissionError(f"No access to automation {automationId}")
- if not self._canModify(AutomationDefinition, automationId):
+ if not self.checkRbacPermission(AutomationDefinition, "delete", automationId):
raise PermissionError(f"No permission to delete automation {automationId}")
# Remove event if exists
diff --git a/modules/interfaces/interfaceDbComponentAccess.py b/modules/interfaces/interfaceDbComponentAccess.py
deleted file mode 100644
index 36c3cfff..00000000
--- a/modules/interfaces/interfaceDbComponentAccess.py
+++ /dev/null
@@ -1,203 +0,0 @@
-"""
-Access control module for Management interface.
-Handles user access management and permission checks.
-"""
-
-import logging
-from typing import Dict, Any, List, Optional
-from modules.datamodels.datamodelUam import User
-from modules.datamodels.datamodelUtils import Prompt
-from modules.datamodels.datamodelFiles import FileItem
-from modules.datamodels.datamodelChat import ChatWorkflow
-
-# Configure logger
-logger = logging.getLogger(__name__)
-
-class ComponentAccess:
- """
- Access control class for Management interface.
- Handles user access management and permission checks.
- """
-
- def __init__(self, currentUser: User, db):
- """Initialize with user context."""
- self.currentUser = currentUser
- self.userId = currentUser.id
- self.mandateId = currentUser.mandateId
- self.privilege = currentUser.privilege
- self.db = db
-
- def getInitialUserid(self):
- return "----"
- # return self.db.getInitialUserId() --> to get from AdminDB !
-
- def canModifyAttribute(self, table: str, attribute: str) -> bool:
- """
- Checks if the current user can modify a specific attribute in a table.
-
- Args:
- table: Name of the table
- attribute: Name of the attribute
-
- Returns:
- Boolean indicating permission
- """
- userPrivilege = self.privilege
-
- # Special case for mandateId in prompts table
- if table == "prompts" and attribute == "mandateId":
- return userPrivilege == "sysadmin"
-
- return True
-
- def uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
- """
- Unified user access management function that filters data based on user privileges
- and adds access control attributes.
-
- Args:
- model_class: Pydantic model class for the table
- recordset: Recordset to filter based on access rules
-
- Returns:
- Filtered recordset with access control attributes
- """
- userPrivilege = self.privilege
- table_name = model_class.__name__
-
- filtered_records = []
-
- initialid = self.getInitialUserid()
-
- # Apply filtering based on privilege
- if userPrivilege == "sysadmin":
- filtered_records = recordset # System admins see all records
- elif userPrivilege == "admin":
- # Admins see records in their mandate
- filtered_records = [r for r in recordset if r.get("mandateId") == self.mandateId]
- else: # Regular users
- # For prompts, users can see all prompts from their mandate
- if table_name == "Prompt":
- filtered_records = [r for r in recordset if r.get("mandateId") == self.mandateId]
- elif table_name == "UserInDB":
- # For users table, users can only see their own record
- filtered_records = [r for r in recordset if r.get("id") == self.userId]
- elif table_name == "VoiceSettings":
- # For voice settings, users can only see their own settings
- filtered_records = [r for r in recordset if r.get("userId") == self.userId]
- else:
- # Users see only their records for other tables
- filtered_records = [
- r for r in recordset
- if r.get("mandateId") == self.mandateId and r.get("_createdBy") == self.userId
- ]
-
- # Add access control attributes to each record
- for record in filtered_records:
- record_id = record.get("id")
-
- # Set access control flags based on user permissions
- if table_name == "Prompt":
- record["_hideView"] = False # Everyone can view
- record["_hideEdit"] = not self.canModify(Prompt, record_id)
- record["_hideDelete"] = not self.canModify(Prompt, record_id)
-
- # Add attribute-level permissions for mandateId
- if "mandateId" in record:
- record["_hideEdit_mandateId"] = not self.canModifyAttribute(Prompt, "mandateId")
- elif table_name == "FileItem":
- record["_hideView"] = False # Everyone can view
- record["_hideEdit"] = not self.canModify(FileItem, record_id)
- record["_hideDelete"] = not self.canModify(FileItem, record_id)
- record["_hideDownload"] = not self.canModify(FileItem, record_id)
- elif table_name == "ChatWorkflow":
- record["_hideView"] = False # Everyone can view
- record["_hideEdit"] = not self.canModify(ChatWorkflow, record_id)
- record["_hideDelete"] = not self.canModify(ChatWorkflow, record_id)
- elif table_name == "ChatMessage":
- record["_hideView"] = False # Everyone can view
- record["_hideEdit"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
- record["_hideDelete"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
- elif table_name == "ChatLog":
- record["_hideView"] = False # Everyone can view
- record["_hideEdit"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
- record["_hideDelete"] = not self.canModify(ChatWorkflow, record.get("workflowId"))
- elif table_name == "UserInDB":
- # For users table, users can only modify their own connections
- record["_hideView"] = False
- record["_hideEdit"] = record_id != self.userId
- record["_hideDelete"] = record_id != self.userId
- # Add connection-specific permissions
- if "connections" in record:
- for conn in record["connections"]:
- conn["_hideEdit"] = record_id != self.userId
- conn["_hideDelete"] = record_id != self.userId
- elif table_name == "VoiceSettings":
- # For voice settings, users can only access their own settings
- record["_hideView"] = False
- record["_hideEdit"] = record.get("userId") != self.userId
- record["_hideDelete"] = record.get("userId") != self.userId
- else:
- # Default access control for other tables
- record["_hideView"] = False
- record["_hideEdit"] = not self.canModify(model_class, record_id)
- record["_hideDelete"] = not self.canModify(model_class, record_id)
-
- return filtered_records
-
- def canModify(self, model_class: type, recordId: Optional[int] = None) -> bool:
- """
- Checks if the current user can modify (create/update/delete) records in a table.
-
- Args:
- model_class: Pydantic model class for the table
- recordId: Optional record ID for specific record check
-
- Returns:
- Boolean indicating permission
- """
- userPrivilege = self.privilege
-
- # System admins can modify anything
- if userPrivilege == "sysadmin":
- return True
-
- # For regular users and admins, check specific cases
- if recordId is not None:
- # Get the record to check ownership
- records: List[Dict[str, Any]] = self.db.getRecordset(model_class, recordFilter={"id": recordId})
- if not records:
- return False
-
- record = records[0]
-
- # Special case for users table - users can modify their own connections
- if model_class.__name__ == "UserInDB":
- if record.get("id") == self.userId:
- return True
- return False
-
- # Special case for voice settings - users can modify their own settings
- if model_class.__name__ == "VoiceSettings":
- if record.get("userId") == self.userId:
- return True
- return False
-
- # Admins can modify anything in their mandate, if mandate is specified for a record
- if userPrivilege == "admin" and record.get("mandateId","-") == self.mandateId:
- return True
-
- # Regular users can only modify their own records
- if (record.get("mandateId","-") == self.mandateId and
- record.get("_createdBy") == self.userId):
- return True
-
- return False
- else:
- # For general modification permission (e.g., create)
- # Admins can create anything in their mandate
- if userPrivilege == "admin":
- return True
-
- # Regular users can create in most tables
- return True
\ No newline at end of file
diff --git a/modules/interfaces/interfaceDbComponentObjects.py b/modules/interfaces/interfaceDbComponentObjects.py
index 225f8ad5..0e1be949 100644
--- a/modules/interfaces/interfaceDbComponentObjects.py
+++ b/modules/interfaces/interfaceDbComponentObjects.py
@@ -11,7 +11,9 @@ import math
from typing import Dict, Any, List, Optional, Union
from modules.connectors.connectorDbPostgre import DatabaseConnector
-from modules.interfaces.interfaceDbComponentAccess import ComponentAccess
+from modules.security.rbac import RbacClass
+from modules.datamodels.datamodelRbac import AccessRuleContext
+from modules.datamodels.datamodelUam import AccessLevel
from modules.datamodels.datamodelFiles import FilePreview, FileItem, FileData
from modules.datamodels.datamodelUtils import Prompt
from modules.datamodels.datamodelVoice import VoiceSettings
@@ -57,7 +59,7 @@ class ComponentObjects:
# Initialize variables first
self.currentUser: Optional[User] = None
self.userId: Optional[str] = None
- self.access: Optional[ComponentAccess] = None # Will be set when user context is provided
+ self.rbac: Optional[RbacClass] = None # RBAC interface
# Initialize database
self._initializeDatabase()
@@ -80,8 +82,10 @@ class ComponentObjects:
# Add language settings
self.userLanguage = currentUser.language # Default user language
- # Initialize access control with user context
- self.access = ComponentAccess(self.currentUser, self.db)
+ # Initialize RBAC interface
+ if not self.currentUser:
+ raise ValueError("User context is required for RBAC")
+ self.rbac = RbacClass(self.db)
# Update database context
self.db.updateContext(self.userId)
@@ -214,7 +218,6 @@ class ComponentObjects:
else:
self.currentUser = None
self.userId = None
- self.access = None
self.db.updateContext("") # Reset database context
except Exception as e:
@@ -225,26 +228,46 @@ class ComponentObjects:
else:
self.currentUser = None
self.userId = None
- self.access = None
self.db.updateContext("") # Reset database context
- def _uam(self, model_class: type, recordset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
- """Delegate to access control module."""
- # First apply access control
- filteredRecords = self.access.uam(model_class, recordset)
-
- # Then filter out database-specific fields
- cleanedRecords = []
- for record in filteredRecords:
- # Create a new dict with only non-database fields
- cleanedRecord = {k: v for k, v in record.items() if not k.startswith('_')}
- cleanedRecords.append(cleanedRecord)
-
- return cleanedRecords
+
+ def checkRbacPermission(
+ self,
+ modelClass: type,
+ operation: str,
+ recordId: Optional[str] = None
+ ) -> bool:
+ """
+ Check RBAC permission for a specific operation on a table.
- def _canModify(self, model_class: type, recordId: Optional[str] = None) -> bool:
- """Delegate to access control module."""
- return self.access.canModify(model_class, recordId)
+ Args:
+ modelClass: Pydantic model class for the table
+ operation: Operation to check ('create', 'update', 'delete', 'read')
+ recordId: Optional record ID for specific record check
+
+ Returns:
+ Boolean indicating permission
+ """
+ if not self.rbac or not self.currentUser:
+ return False
+
+ tableName = modelClass.__name__
+ permissions = self.rbac.getUserPermissions(
+ self.currentUser,
+ AccessRuleContext.DATA,
+ tableName
+ )
+
+ if operation == "create":
+ return permissions.create != AccessLevel.NONE
+ elif operation == "update":
+ return permissions.update != AccessLevel.NONE
+ elif operation == "delete":
+ return permissions.delete != AccessLevel.NONE
+ elif operation == "read":
+ return permissions.read != AccessLevel.NONE
+ else:
+ return False
def _applyFilters(self, records: List[Dict[str, Any]], filters: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
@@ -474,8 +497,11 @@ class ComponentObjects:
If pagination is provided: PaginatedResult with items and metadata
"""
try:
- allPrompts = self.db.getRecordset(Prompt)
- filteredPrompts = self._uam(Prompt, allPrompts)
+ # Use RBAC filtering
+ filteredPrompts = self.db.getRecordsetWithRBAC(
+ Prompt,
+ self.currentUser
+ )
# If no pagination requested, return all items
if pagination is None:
@@ -515,16 +541,18 @@ class ComponentObjects:
def getPrompt(self, promptId: str) -> Optional[Prompt]:
"""Returns a prompt by ID if user has access."""
- prompts = self.db.getRecordset(Prompt, recordFilter={"id": promptId})
- if not prompts:
- return None
+ # Use RBAC filtering
+ filteredPrompts = self.db.getRecordsetWithRBAC(
+ Prompt,
+ self.currentUser,
+ recordFilter={"id": promptId}
+ )
- filteredPrompts = self._uam(Prompt, prompts)
return Prompt(**filteredPrompts[0]) if filteredPrompts else None
def createPrompt(self, promptData: Dict[str, Any]) -> Dict[str, Any]:
"""Creates a new prompt if user has permission."""
- if not self._canModify(Prompt):
+ if not self.checkRbacPermission(Prompt, "create"):
raise PermissionError("No permission to create prompts")
# Create prompt record
@@ -565,7 +593,7 @@ class ComponentObjects:
if not prompt:
return False
- if not self._canModify(Prompt, promptId):
+ if not self.checkRbacPermission(Prompt, "update", promptId):
raise PermissionError(f"No permission to delete prompt {promptId}")
# Delete prompt
@@ -580,13 +608,12 @@ class ComponentObjects:
"""Checks if a file with the same hash already exists for the current user and mandate.
If fileName is provided, also checks for exact name+hash match.
Only returns files the current user has access to."""
- # First get all files with the hash
- allFilesWithHash = self.db.getRecordset(FileItem, recordFilter={
- "fileHash": fileHash
- })
-
- # Filter by user access using UAM
- accessibleFiles = self._uam(FileItem, allFilesWithHash)
+ # Get files with the hash, filtered by RBAC
+ accessibleFiles = self.db.getRecordsetWithRBAC(
+ FileItem,
+ self.currentUser,
+ recordFilter={"fileHash": fileHash}
+ )
if not accessibleFiles:
return None
@@ -711,8 +738,11 @@ class ComponentObjects:
If pagination is None: List[FileItem]
If pagination is provided: PaginatedResult with items and metadata
"""
- allFiles = self.db.getRecordset(FileItem)
- filteredFiles = self._uam(FileItem, allFiles)
+ # Use RBAC filtering
+ filteredFiles = self.db.getRecordsetWithRBAC(
+ FileItem,
+ self.currentUser
+ )
# Convert database records to FileItem instances (for both paginated and non-paginated)
def convertFileItems(files):
@@ -775,11 +805,13 @@ class ComponentObjects:
def getFile(self, fileId: str) -> Optional[FileItem]:
"""Returns a file by ID if user has access."""
- files = self.db.getRecordset(FileItem, recordFilter={"id": fileId})
- if not files:
- return None
-
- filteredFiles = self._uam(FileItem, files)
+ # Use RBAC filtering
+ filteredFiles = self.db.getRecordsetWithRBAC(
+ FileItem,
+ self.currentUser,
+ recordFilter={"id": fileId}
+ )
+
if not filteredFiles:
return None
@@ -838,7 +870,7 @@ class ComponentObjects:
def createFile(self, name: str, mimeType: str, content: bytes) -> FileItem:
"""Creates a new file entry if user has permission. Computes fileHash and fileSize from content."""
- if not self._canModify(FileItem):
+ if not self.checkRbacPermission(FileItem, "create"):
raise PermissionError("No permission to create files")
# Ensure fileName is unique
@@ -873,7 +905,7 @@ class ComponentObjects:
if not file:
raise FileNotFoundError(f"File with ID {fileId} not found")
- if not self._canModify(FileItem, fileId):
+ if not self.checkRbacPermission(FileItem, "update", fileId):
raise PermissionError(f"No permission to update file {fileId}")
# If fileName is being updated, ensure it's unique
@@ -895,7 +927,7 @@ class ComponentObjects:
if not file:
raise FileNotFoundError(f"File with ID {fileId} not found")
- if not self._canModify(FileItem, fileId):
+ if not self.checkRbacPermission(FileItem, "update", fileId):
raise PermissionError(f"No permission to delete file {fileId}")
# Check for other references to this file (by hash)
@@ -1090,7 +1122,7 @@ class ComponentObjects:
"""Saves an uploaded file if user has permission."""
try:
# Check file creation permission
- if not self._canModify(FileItem):
+ if not self.checkRbacPermission(FileItem, "create"):
raise PermissionError("No permission to upload files")
logger.debug(f"Starting upload process for file: {fileName}")
@@ -1151,14 +1183,13 @@ class ComponentObjects:
logger.error("No user ID provided for voice settings")
return None
- # Get voice settings for the user
- settings = self.db.getRecordset(VoiceSettings, recordFilter={"userId": targetUserId})
- if not settings:
- logger.debug(f"No voice settings found for user {targetUserId}")
- return None
+ # Get voice settings for the user, filtered by RBAC
+ filteredSettings = self.db.getRecordsetWithRBAC(
+ VoiceSettings,
+ self.currentUser,
+ recordFilter={"userId": targetUserId}
+ )
- # Apply access control
- filteredSettings = self._uam(VoiceSettings, settings)
if not filteredSettings:
logger.warning(f"No access to voice settings for user {targetUserId}")
return None
@@ -1179,7 +1210,7 @@ class ComponentObjects:
def createVoiceSettings(self, settingsData: Dict[str, Any]) -> Dict[str, Any]:
"""Creates voice settings for a user if user has permission."""
try:
- if not self._canModify(VoiceSettings):
+ if not self.checkRbacPermission(VoiceSettings, "update"):
raise PermissionError("No permission to create voice settings")
# Ensure userId is set
diff --git a/modules/migration/__init__.py b/modules/migration/__init__.py
new file mode 100644
index 00000000..49056d7c
--- /dev/null
+++ b/modules/migration/__init__.py
@@ -0,0 +1 @@
+"""Migration modules for database schema and data migrations."""
diff --git a/modules/migration/migrateUamToRbac.py b/modules/migration/migrateUamToRbac.py
new file mode 100644
index 00000000..688bf8e7
--- /dev/null
+++ b/modules/migration/migrateUamToRbac.py
@@ -0,0 +1,212 @@
+"""
+Migration script to convert UAM (User Access Management) to RBAC (Role-Based Access Control).
+
+This script:
+1. Creates AccessRule table if it doesn't exist
+2. Adds roleLabels column to User table if it doesn't exist
+3. Converts User.privilege to User.roleLabels
+4. Creates initial RBAC rules based on bootstrap logic
+"""
+
+import logging
+from typing import List, Dict, Any
+from modules.connectors.connectorDbPostgre import DatabaseConnector
+from modules.shared.configuration import APP_CONFIG
+from modules.datamodels.datamodelUam import UserInDB, UserPrivilege
+from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext
+from modules.datamodels.datamodelUam import AccessLevel
+from modules.interfaces.interfaceBootstrap import initRbacRules
+
+logger = logging.getLogger(__name__)
+
+
+def migrateUamToRbac(db: DatabaseConnector, dryRun: bool = False) -> Dict[str, Any]:
+ """
+ Migrate from UAM to RBAC system.
+
+ Args:
+ db: Database connector instance
+ dryRun: If True, only report what would be done without making changes
+
+ Returns:
+ Dictionary with migration results
+ """
+ results = {
+ "schemaChanges": [],
+ "dataMigrations": [],
+ "rulesCreated": 0,
+ "usersUpdated": 0,
+ "errors": []
+ }
+
+ try:
+ # Step 1: Ensure AccessRule table exists
+ logger.info("Step 1: Ensuring AccessRule table exists")
+ if not dryRun:
+ db._ensureTableExists(AccessRule)
+ results["schemaChanges"].append("AccessRule table ensured")
+ else:
+ results["schemaChanges"].append("Would ensure AccessRule table")
+
+ # Step 2: Add roleLabels column to UserInDB table if it doesn't exist
+ logger.info("Step 2: Adding roleLabels column to UserInDB table")
+ if not dryRun:
+ try:
+ with db.connection.cursor() as cursor:
+ # Check if column exists
+ cursor.execute("""
+ SELECT column_name
+ FROM information_schema.columns
+ WHERE table_name = 'UserInDB' AND column_name = 'roleLabels'
+ """)
+ columnExists = cursor.fetchone() is not None
+
+ if not columnExists:
+ cursor.execute('ALTER TABLE "UserInDB" ADD COLUMN "roleLabels" JSONB DEFAULT \'[]\'::jsonb')
+ db.connection.commit()
+ results["schemaChanges"].append("Added roleLabels column to UserInDB")
+ logger.info("Added roleLabels column to UserInDB table")
+ else:
+ results["schemaChanges"].append("roleLabels column already exists")
+ logger.info("roleLabels column already exists in UserInDB table")
+ except Exception as e:
+ logger.error(f"Error adding roleLabels column: {e}")
+ results["errors"].append(f"Error adding roleLabels column: {e}")
+ db.connection.rollback()
+ else:
+ results["schemaChanges"].append("Would add roleLabels column to UserInDB")
+
+ # Step 3: Convert User.privilege to User.roleLabels
+ logger.info("Step 3: Converting User.privilege to User.roleLabels")
+ if not dryRun:
+ try:
+ users = db.getRecordset(UserInDB)
+ updatedCount = 0
+
+ for user in users:
+ privilege = user.get("privilege")
+ roleLabels = user.get("roleLabels", [])
+
+ # Skip if already has roleLabels
+ if roleLabels and isinstance(roleLabels, list) and len(roleLabels) > 0:
+ logger.debug(f"User {user.get('id')} already has roleLabels: {roleLabels}")
+ continue
+
+ # Convert privilege to roleLabels
+ if privilege == UserPrivilege.SYSADMIN.value:
+ newRoleLabels = ["sysadmin"]
+ elif privilege == UserPrivilege.ADMIN.value:
+ newRoleLabels = ["admin"]
+ elif privilege == UserPrivilege.USER.value:
+ newRoleLabels = ["user"]
+ else:
+ # Default to user if privilege is unknown
+ newRoleLabels = ["user"]
+ logger.warning(f"Unknown privilege '{privilege}' for user {user.get('id')}, defaulting to 'user'")
+
+ # Update user
+ user["roleLabels"] = newRoleLabels
+ db.recordModify(UserInDB, user["id"], user)
+ updatedCount += 1
+ logger.info(f"Updated user {user.get('id')} ({user.get('username')}): {privilege} -> {newRoleLabels}")
+
+ results["usersUpdated"] = updatedCount
+ logger.info(f"Updated {updatedCount} users with roleLabels")
+ except Exception as e:
+ logger.error(f"Error converting user privileges: {e}")
+ results["errors"].append(f"Error converting user privileges: {e}")
+ else:
+ # Dry run: count users that would be updated
+ users = db.getRecordset(UserInDB)
+ wouldUpdate = 0
+ for user in users:
+ roleLabels = user.get("roleLabels", [])
+ if not roleLabels or not isinstance(roleLabels, list) or len(roleLabels) == 0:
+ wouldUpdate += 1
+ results["usersUpdated"] = wouldUpdate
+ logger.info(f"Would update {wouldUpdate} users with roleLabels")
+
+ # Step 4: Create RBAC rules if they don't exist
+ logger.info("Step 4: Creating RBAC rules")
+ if not dryRun:
+ try:
+ existingRules = db.getRecordset(AccessRule)
+ if existingRules:
+ results["rulesCreated"] = len(existingRules)
+ results["dataMigrations"].append(f"RBAC rules already exist ({len(existingRules)} rules)")
+ logger.info(f"RBAC rules already exist ({len(existingRules)} rules)")
+ else:
+ # Initialize RBAC rules using bootstrap logic
+ initRbacRules(db)
+ newRules = db.getRecordset(AccessRule)
+ results["rulesCreated"] = len(newRules)
+ results["dataMigrations"].append(f"Created {len(newRules)} RBAC rules")
+ logger.info(f"Created {len(newRules)} RBAC rules")
+ except Exception as e:
+ logger.error(f"Error creating RBAC rules: {e}")
+ results["errors"].append(f"Error creating RBAC rules: {e}")
+ else:
+ existingRules = db.getRecordset(AccessRule)
+ if existingRules:
+ results["rulesCreated"] = len(existingRules)
+ results["dataMigrations"].append(f"RBAC rules already exist ({len(existingRules)} rules)")
+ else:
+ results["dataMigrations"].append("Would create RBAC rules")
+
+ logger.info("Migration completed successfully")
+ return results
+
+ except Exception as e:
+ logger.error(f"Migration failed: {e}")
+ results["errors"].append(f"Migration failed: {e}")
+ return results
+
+
+def validateMigration(db: DatabaseConnector) -> Dict[str, Any]:
+ """
+ Validate that migration was successful.
+
+ Args:
+ db: Database connector instance
+
+ Returns:
+ Dictionary with validation results
+ """
+ validation = {
+ "valid": True,
+ "issues": []
+ }
+
+ try:
+ # Check that AccessRule table exists
+ try:
+ rules = db.getRecordset(AccessRule)
+ if not rules:
+ validation["valid"] = False
+ validation["issues"].append("AccessRule table exists but has no rules")
+ except Exception as e:
+ validation["valid"] = False
+ validation["issues"].append(f"AccessRule table does not exist or is not accessible: {e}")
+
+ # Check that all users have roleLabels
+ users = db.getRecordset(UserInDB)
+ usersWithoutRoles = []
+ for user in users:
+ roleLabels = user.get("roleLabels", [])
+ if not roleLabels or not isinstance(roleLabels, list) or len(roleLabels) == 0:
+ usersWithoutRoles.append({
+ "id": user.get("id"),
+ "username": user.get("username"),
+ "privilege": user.get("privilege")
+ })
+
+ if usersWithoutRoles:
+ validation["valid"] = False
+ validation["issues"].append(f"{len(usersWithoutRoles)} users without roleLabels: {[u['username'] for u in usersWithoutRoles]}")
+
+ return validation
+
+ except Exception as e:
+ validation["valid"] = False
+ validation["issues"].append(f"Validation error: {e}")
+ return validation
diff --git a/modules/routes/routeDataFiles.py b/modules/routes/routeDataFiles.py
index 7c0f60c0..5cdfcfc5 100644
--- a/modules/routes/routeDataFiles.py
+++ b/modules/routes/routeDataFiles.py
@@ -229,8 +229,8 @@ async def update_file(
detail=f"File with ID {fileId} not found"
)
- # Check if user has access to the file using the interface's permission system
- if not managementInterface._canModify("files", fileId):
+ # Check if user has access to the file using RBAC
+ if not managementInterface.checkRbacPermission(FileItem, "update", fileId):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authorized to update this file"
diff --git a/modules/routes/routeRbac.py b/modules/routes/routeRbac.py
new file mode 100644
index 00000000..95184779
--- /dev/null
+++ b/modules/routes/routeRbac.py
@@ -0,0 +1,161 @@
+"""
+RBAC routes for the backend API.
+Implements endpoints for role-based access control permissions.
+"""
+
+from fastapi import APIRouter, HTTPException, Depends, Query, Request
+from typing import Optional
+import logging
+
+from modules.security.auth import getCurrentUser, limiter
+from modules.datamodels.datamodelUam import User, UserPermissions, AccessLevel
+from modules.datamodels.datamodelRbac import AccessRuleContext
+from modules.interfaces.interfaceDbAppObjects import getInterface
+
+# Configure logger
+logger = logging.getLogger(__name__)
+
+router = APIRouter(
+ prefix="/api/rbac",
+ tags=["RBAC"],
+ responses={404: {"description": "Not found"}}
+ )
+
+
+@router.get("/permissions", response_model=UserPermissions)
+@limiter.limit("60/minute")
+async def getPermissions(
+ request: Request,
+ context: str = Query(..., description="Context type: DATA, UI, or RESOURCE"),
+ item: Optional[str] = Query(None, description="Item identifier (table name, UI path, or resource path)"),
+ currentUser: User = Depends(getCurrentUser)
+ ) -> UserPermissions:
+ """
+ Get RBAC permissions for the current user for a specific context and item.
+
+ Query Parameters:
+ - context: Context type (DATA, UI, or RESOURCE)
+ - item: Optional item identifier. For DATA: table name (e.g., "UserInDB"),
+ For UI: cascading string (e.g., "playground.voice.settings"),
+ For RESOURCE: cascading string (e.g., "ai.model.anthropic")
+
+ Returns:
+ - UserPermissions object with view, read, create, update, delete permissions
+
+ Examples:
+ - GET /api/rbac/permissions?context=DATA&item=UserInDB
+ - GET /api/rbac/permissions?context=UI&item=playground.voice.settings
+ - GET /api/rbac/permissions?context=RESOURCE&item=ai.model.anthropic
+ """
+ try:
+ # Validate context
+ try:
+ accessContext = AccessRuleContext(context.upper())
+ except ValueError:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid context '{context}'. Must be one of: DATA, UI, RESOURCE"
+ )
+
+ # Get interface and RBAC permissions
+ interface = getInterface(currentUser)
+ if not interface.rbac:
+ raise HTTPException(
+ status_code=500,
+ detail="RBAC interface not available"
+ )
+
+ # Get permissions
+ permissions = interface.rbac.getUserPermissions(
+ currentUser,
+ accessContext,
+ item or ""
+ )
+
+ return permissions
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error getting RBAC permissions: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to get permissions: {str(e)}"
+ )
+
+
+@router.get("/rules", response_model=list)
+@limiter.limit("30/minute")
+async def getAccessRules(
+ request: Request,
+ roleLabel: Optional[str] = Query(None, description="Filter by role label"),
+ context: Optional[str] = Query(None, description="Filter by context (DATA, UI, RESOURCE)"),
+ item: Optional[str] = Query(None, description="Filter by item identifier"),
+ currentUser: User = Depends(getCurrentUser)
+ ) -> list:
+ """
+ Get access rules with optional filters.
+ Only returns rules that the current user has permission to view.
+
+ Query Parameters:
+ - roleLabel: Optional role label filter
+ - context: Optional context filter (DATA, UI, RESOURCE)
+ - item: Optional item filter
+
+ Returns:
+ - List of AccessRule objects
+ """
+ try:
+ # Get interface
+ interface = getInterface(currentUser)
+
+ # Check if user has permission to view access rules
+ # For now, only sysadmin can view rules
+ if not interface.rbac:
+ raise HTTPException(
+ status_code=500,
+ detail="RBAC interface not available"
+ )
+
+ # Check permission - only sysadmin can view rules
+ permissions = interface.rbac.getUserPermissions(
+ currentUser,
+ AccessRuleContext.DATA,
+ "AccessRule"
+ )
+
+ if not permissions.view or permissions.read == AccessLevel.NONE:
+ raise HTTPException(
+ status_code=403,
+ detail="No permission to view access rules"
+ )
+
+ # Parse context if provided
+ accessContext = None
+ if context:
+ try:
+ accessContext = AccessRuleContext(context.upper())
+ except ValueError:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid context '{context}'. Must be one of: DATA, UI, RESOURCE"
+ )
+
+ # Get rules
+ rules = interface.getAccessRules(
+ roleLabel=roleLabel,
+ context=accessContext,
+ item=item
+ )
+
+ # Convert to dict for JSON serialization
+ return [rule.model_dump() for rule in rules]
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error getting access rules: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to get access rules: {str(e)}"
+ )
diff --git a/modules/routes/routeWorkflows.py b/modules/routes/routeWorkflows.py
index ea52a067..080e8077 100644
--- a/modules/routes/routeWorkflows.py
+++ b/modules/routes/routeWorkflows.py
@@ -180,8 +180,8 @@ async def update_workflow(
workflow_data = workflows[0]
- # Check if user has permission to update using the interface's permission system
- if not workflowInterface._canModify("workflows", workflowId):
+ # Check if user has permission to update using RBAC
+ if not workflowInterface.checkRbacPermission(ChatWorkflow, "update", workflowId):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have permission to update this workflow"
@@ -437,8 +437,8 @@ async def delete_workflow(
workflow_data = workflows[0]
- # Check if user has permission to delete using the interface's permission system
- if not interfaceDbChat._canModify("workflows", workflowId):
+ # Check if user has permission to delete using RBAC
+ if not interfaceDbChat.checkRbacPermission(ChatWorkflow, "delete", workflowId):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You don't have permission to delete this workflow"
diff --git a/modules/security/rbac.py b/modules/security/rbac.py
new file mode 100644
index 00000000..ca2050de
--- /dev/null
+++ b/modules/security/rbac.py
@@ -0,0 +1,194 @@
+"""
+RBAC interface: Core RBAC logic and permission resolution.
+Moved from interfaces to security module to maintain proper architectural layering.
+Connectors can import from security, but not from interfaces.
+"""
+
+import logging
+from typing import List, Optional, Dict, Any, TYPE_CHECKING
+from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext
+from modules.datamodels.datamodelUam import User, UserPermissions, AccessLevel
+
+if TYPE_CHECKING:
+ from modules.connectors.connectorDbPostgre import DatabaseConnector
+
+logger = logging.getLogger(__name__)
+
+
+class RbacClass:
+ """
+ RBAC interface for permission resolution and rule validation.
+ """
+
+ def __init__(self, db: "DatabaseConnector"):
+ """Initialize RBAC interface with database connector."""
+ self.db = db
+
+ def getUserPermissions(self, user: User, context: AccessRuleContext, item: str) -> UserPermissions:
+ """
+ Get combined permissions for a user across all their roles.
+
+ Args:
+ user: User object with roleLabels
+ context: Access rule context (DATA, UI, RESOURCE)
+ item: Item identifier (table name, UI path, resource path)
+
+ Returns:
+ UserPermissions object with combined permissions
+ """
+ permissions = UserPermissions(
+ view=False,
+ read=AccessLevel.NONE,
+ create=AccessLevel.NONE,
+ update=AccessLevel.NONE,
+ delete=AccessLevel.NONE
+ )
+
+ if not user.roleLabels:
+ logger.warning(f"User {user.id} has no roleLabels assigned")
+ return permissions
+
+ # Step 1: For each role, find the most specific matching rule (most specific wins within role)
+ rolePermissions = {}
+ for roleLabel in user.roleLabels:
+ # Get all rules for this role and context
+ allRules = self._getRulesForRole(roleLabel, context)
+
+ # Find most specific rule for this item (longest matching prefix)
+ mostSpecificRule = self.findMostSpecificRule(allRules, item)
+
+ if mostSpecificRule:
+ rolePermissions[roleLabel] = mostSpecificRule
+
+ # Step 2: Combine permissions across roles using opening (union) logic
+ for roleLabel, rule in rolePermissions.items():
+ # View: union logic - if ANY role has view=true, then view=true
+ if rule.view:
+ permissions.view = True
+
+ if context == AccessRuleContext.DATA:
+ # For DATA context, use most permissive access level across roles
+ if rule.read and self._isMorePermissive(rule.read, permissions.read):
+ permissions.read = rule.read
+ if rule.create and self._isMorePermissive(rule.create, permissions.create):
+ permissions.create = rule.create
+ if rule.update and self._isMorePermissive(rule.update, permissions.update):
+ permissions.update = rule.update
+ if rule.delete and self._isMorePermissive(rule.delete, permissions.delete):
+ permissions.delete = rule.delete
+
+ return permissions
+
+ def findMostSpecificRule(self, rules: List[AccessRule], item: str) -> Optional[AccessRule]:
+ """
+ Find the most specific rule for an item (longest matching prefix wins).
+
+ Args:
+ rules: List of access rules to search
+ item: Item identifier to match
+
+ Returns:
+ Most specific matching rule, or None if no match
+ """
+ if not item:
+ # If no item specified, return generic rule (item = null)
+ genericRules = [r for r in rules if r.item is None]
+ return genericRules[0] if genericRules else None
+
+ # Find longest matching prefix
+ itemParts = item.split(".")
+ bestMatch = None
+ bestMatchLength = -1
+
+ for rule in rules:
+ if rule.item is None:
+ # Generic rule - use as fallback if no specific match found
+ if bestMatch is None:
+ bestMatch = rule
+ elif rule.item == item:
+ # Exact match - most specific
+ return rule
+ elif item.startswith(rule.item + "."):
+ # Prefix match - check if it's longer than current best
+ matchLength = len(rule.item.split("."))
+ if matchLength > bestMatchLength:
+ bestMatch = rule
+ bestMatchLength = matchLength
+
+ return bestMatch
+
+ def validateAccessRule(self, rule: AccessRule) -> bool:
+ """
+ Validate that CUD permissions are allowed by read permission level (only for DATA context).
+
+ Args:
+ rule: AccessRule to validate
+
+ Returns:
+ True if rule is valid, False otherwise
+ """
+ if rule.context != AccessRuleContext.DATA:
+ # For UI and RESOURCE contexts, only view is relevant
+ return True
+
+ if rule.read is None:
+ return False # DATA context requires read permission
+
+ readLevel = AccessLevel(rule.read)
+
+ # CUD operations are only allowed if read permission exists
+ for operation in [rule.create, rule.update, rule.delete]:
+ if operation is None or operation == AccessLevel.NONE.value:
+ continue # No access is always valid
+ if readLevel == AccessLevel.NONE:
+ return False # No CUD allowed if no read access
+ if readLevel == AccessLevel.MY and operation not in [AccessLevel.NONE.value, AccessLevel.MY.value]:
+ return False
+ if readLevel == AccessLevel.GROUP and operation not in [AccessLevel.NONE.value, AccessLevel.MY.value, AccessLevel.GROUP.value]:
+ return False
+
+ return True
+
+ def _isMorePermissive(self, level1: AccessLevel, level2: AccessLevel) -> bool:
+ """
+ Check if level1 is more permissive than level2.
+
+ Args:
+ level1: First access level
+ level2: Second access level
+
+ Returns:
+ True if level1 is more permissive than level2
+ """
+ hierarchy = {
+ AccessLevel.NONE: 0,
+ AccessLevel.MY: 1,
+ AccessLevel.GROUP: 2,
+ AccessLevel.ALL: 3
+ }
+ return hierarchy.get(level1, 0) > hierarchy.get(level2, 0)
+
+ def _getRulesForRole(self, roleLabel: str, context: AccessRuleContext) -> List[AccessRule]:
+ """
+ Get all access rules for a specific role and context.
+
+ Args:
+ roleLabel: Role label to get rules for
+ context: Context type
+
+ Returns:
+ List of AccessRule objects
+ """
+ try:
+ rules = self.db.getRecordset(
+ AccessRule,
+ recordFilter={
+ "roleLabel": roleLabel,
+ "context": context.value
+ }
+ )
+ # Convert dict records to AccessRule objects
+ return [AccessRule(**record) for record in rules]
+ except Exception as e:
+ logger.error(f"Error getting rules for role {roleLabel} and context {context.value}: {e}")
+ return []
diff --git a/modules/shared/rbacHelpers.py b/modules/shared/rbacHelpers.py
new file mode 100644
index 00000000..843a588a
--- /dev/null
+++ b/modules/shared/rbacHelpers.py
@@ -0,0 +1,178 @@
+"""
+RBAC helper functions for resource access control.
+Provides convenient functions for checking permissions in feature modules.
+"""
+
+import logging
+from typing import Optional
+from modules.datamodels.datamodelUam import User, AccessLevel
+from modules.datamodels.datamodelRbac import AccessRuleContext
+from modules.security.rbac import RbacClass
+from modules.connectors.connectorDbPostgre import DatabaseConnector
+
+logger = logging.getLogger(__name__)
+
+
+def checkResourceAccess(
+ RbacInstance: RbacClass,
+ currentUser: User,
+ resourcePath: str
+) -> bool:
+ """
+ Check if user has access to a resource.
+
+ Args:
+ RbacInstance: RbacClass instance
+ currentUser: Current user object
+ resourcePath: Resource path (e.g., "ai.model.anthropic", "ai.action.jira")
+
+ Returns:
+ True if user has view permission for the resource, False otherwise
+ """
+ try:
+ permissions = RbacInstance.getUserPermissions(
+ currentUser,
+ AccessRuleContext.RESOURCE,
+ resourcePath
+ )
+ return permissions.view
+ except Exception as e:
+ logger.error(f"Error checking resource access for {resourcePath}: {e}")
+ return False
+
+
+def checkUiAccess(
+ RbacInstance: RbacClass,
+ currentUser: User,
+ uiPath: str
+) -> bool:
+ """
+ Check if user has access to a UI element.
+
+ Args:
+ RbacInstance: RbacClass instance
+ currentUser: Current user object
+ uiPath: UI path (e.g., "playground.voice.settings", "chatbot.search")
+
+ Returns:
+ True if user has view permission for the UI element, False otherwise
+ """
+ try:
+ permissions = RbacInstance.getUserPermissions(
+ currentUser,
+ AccessRuleContext.UI,
+ uiPath
+ )
+ return permissions.view
+ except Exception as e:
+ logger.error(f"Error checking UI access for {uiPath}: {e}")
+ return False
+
+
+def checkDataAccess(
+ RbacInstance: RbacClass,
+ currentUser: User,
+ tableName: str,
+ operation: str = "read"
+) -> bool:
+ """
+ Check if user has access to a data table for a specific operation.
+
+ Args:
+ RbacInstance: RbacClass instance
+ currentUser: Current user object
+ tableName: Table name (e.g., "UserInDB", "Mandate")
+ operation: Operation to check ("read", "create", "update", "delete")
+
+ Returns:
+ True if user has permission for the operation, False otherwise
+ """
+ try:
+ permissions = RbacInstance.getUserPermissions(
+ currentUser,
+ AccessRuleContext.DATA,
+ tableName
+ )
+
+ if operation == "read":
+ return permissions.read != AccessLevel.NONE
+ elif operation == "create":
+ return permissions.create != AccessLevel.NONE
+ elif operation == "update":
+ return permissions.update != AccessLevel.NONE
+ elif operation == "delete":
+ return permissions.delete != AccessLevel.NONE
+ else:
+ logger.warning(f"Unknown operation: {operation}")
+ return False
+ except Exception as e:
+ logger.error(f"Error checking data access for {tableName}: {e}")
+ return False
+
+
+def getResourcePermissions(
+ RbacInstance: RbacClass,
+ currentUser: User,
+ resourcePath: str
+) -> dict:
+ """
+ Get full permissions for a resource.
+
+ Args:
+ RbacInstance: RbacClass instance
+ currentUser: Current user object
+ resourcePath: Resource path (e.g., "ai.model.anthropic")
+
+ Returns:
+ Dictionary with permission information
+ """
+ try:
+ permissions = RbacInstance.getUserPermissions(
+ currentUser,
+ AccessRuleContext.RESOURCE,
+ resourcePath
+ )
+ return {
+ "view": permissions.view,
+ "hasAccess": permissions.view
+ }
+ except Exception as e:
+ logger.error(f"Error getting resource permissions for {resourcePath}: {e}")
+ return {
+ "view": False,
+ "hasAccess": False
+ }
+
+
+def getUiPermissions(
+ RbacInstance: RbacClass,
+ currentUser: User,
+ uiPath: str
+) -> dict:
+ """
+ Get full permissions for a UI element.
+
+ Args:
+ RbacInstance: RbacClass instance
+ currentUser: Current user object
+ uiPath: UI path (e.g., "playground.voice.settings")
+
+ Returns:
+ Dictionary with permission information
+ """
+ try:
+ permissions = RbacInstance.getUserPermissions(
+ currentUser,
+ AccessRuleContext.UI,
+ uiPath
+ )
+ return {
+ "view": permissions.view,
+ "hasAccess": permissions.view
+ }
+ except Exception as e:
+ logger.error(f"Error getting UI permissions for {uiPath}: {e}")
+ return {
+ "view": False,
+ "hasAccess": False
+ }
diff --git a/pytest.ini b/pytest.ini
index ae59338f..ad1e22f2 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -3,7 +3,7 @@ testpaths = tests
pythonpath = .
python_files = test_*.py
python_classes = Test*
-python_functions = test_*
+python_functions = test*
log_file = logs/test_logs.log
log_file_level = INFO
log_file_format = %(asctime)s %(levelname)s %(message)s
diff --git a/tests/integration/rbac/README.md b/tests/integration/rbac/README.md
new file mode 100644
index 00000000..0c866c1d
--- /dev/null
+++ b/tests/integration/rbac/README.md
@@ -0,0 +1,42 @@
+# RBAC Integration Tests
+
+Integration tests for the Role-Based Access Control (RBAC) system.
+
+## Test Files
+
+### `test_rbac_database.py`
+Tests RBAC database filtering:
+- WHERE clause building for ALL access level
+- WHERE clause building for MY access level
+- WHERE clause building for GROUP access level
+- WHERE clause building for NONE access level
+- Special handling for UserInDB table
+- Special handling for UserConnection table
+
+### `test_rbac_migration.py`
+Tests UAM to RBAC migration:
+- User privilege to roleLabels conversion
+- Skipping users with existing roleLabels
+- Dry run mode
+- Migration validation
+- Validation failure scenarios
+
+## Running Tests
+
+```bash
+# Run all RBAC integration tests
+pytest tests/integration/rbac/
+
+# Run specific test file
+pytest tests/integration/rbac/test_rbac_database.py
+
+# Run with verbose output
+pytest tests/integration/rbac/ -v
+```
+
+## Test Coverage
+
+- Database query filtering with RBAC
+- SQL WHERE clause generation
+- Migration script functionality
+- Data validation after migration
diff --git a/tests/integration/rbac/__init__.py b/tests/integration/rbac/__init__.py
new file mode 100644
index 00000000..32a3a0b9
--- /dev/null
+++ b/tests/integration/rbac/__init__.py
@@ -0,0 +1 @@
+"""Integration tests for RBAC system."""
diff --git a/tests/integration/rbac/test_rbac_database.py b/tests/integration/rbac/test_rbac_database.py
new file mode 100644
index 00000000..34a51c30
--- /dev/null
+++ b/tests/integration/rbac/test_rbac_database.py
@@ -0,0 +1,209 @@
+"""
+Integration tests for RBAC database filtering.
+Tests that database queries correctly filter records based on RBAC rules.
+Uses real database connection for integration testing.
+"""
+
+import pytest
+from modules.connectors.connectorDbPostgre import DatabaseConnector
+from modules.datamodels.datamodelUam import User, AccessLevel, UserPermissions
+from modules.shared.configuration import APP_CONFIG
+
+
+@pytest.fixture(scope="class")
+def db():
+ """Create real database connector for integration tests."""
+ dbHost = APP_CONFIG.get("DB_HOST", "localhost")
+ dbDatabase = APP_CONFIG.get("DB_DATABASE", "poweron_test")
+ dbUser = APP_CONFIG.get("DB_USER", "postgres")
+ dbPassword = APP_CONFIG.get("DB_PASSWORD", "")
+ dbPort = APP_CONFIG.get("DB_PORT", 5432)
+
+ db = DatabaseConnector(
+ dbHost=dbHost,
+ dbDatabase=dbDatabase,
+ dbUser=dbUser,
+ dbPassword=dbPassword,
+ dbPort=dbPort
+ )
+ yield db
+ db.close()
+
+
+class TestRbacDatabaseFiltering:
+ """Test RBAC database filtering."""
+
+ def testBuildRbacWhereClauseAllAccess(self, db):
+ """Test WHERE clause building for ALL access level."""
+
+ permissions = UserPermissions(
+ view=True,
+ read=AccessLevel.ALL,
+ create=AccessLevel.ALL,
+ update=AccessLevel.ALL,
+ delete=AccessLevel.ALL
+ )
+
+ user = User(
+ id="test_user_all",
+ username="testuser",
+ roleLabels=["sysadmin"],
+ mandateId="test_mandate_all"
+ )
+
+ whereClause = db.buildRbacWhereClause(permissions, user, "SomeTable")
+
+ # ALL access should return None (no filtering)
+ assert whereClause is None
+
+ def testBuildRbacWhereClauseMyAccess(self, db):
+ """Test WHERE clause building for MY access level."""
+
+ permissions = UserPermissions(
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.MY,
+ update=AccessLevel.MY,
+ delete=AccessLevel.MY
+ )
+
+ user = User(
+ id="test_user_my",
+ username="testuser",
+ roleLabels=["user"],
+ mandateId="test_mandate_my"
+ )
+
+ whereClause = db.buildRbacWhereClause(permissions, user, "SomeTable")
+
+ assert whereClause is not None
+ assert whereClause["condition"] == '"_createdBy" = %s'
+ assert whereClause["values"] == ["test_user_my"]
+
+ def testBuildRbacWhereClauseGroupAccess(self, db):
+ """Test WHERE clause building for GROUP access level."""
+
+ permissions = UserPermissions(
+ view=True,
+ read=AccessLevel.GROUP,
+ create=AccessLevel.GROUP,
+ update=AccessLevel.GROUP,
+ delete=AccessLevel.GROUP
+ )
+
+ user = User(
+ id="test_user_group",
+ username="testuser",
+ roleLabels=["admin"],
+ mandateId="test_mandate_group"
+ )
+
+ whereClause = db.buildRbacWhereClause(permissions, user, "SomeTable")
+
+ assert whereClause is not None
+ assert whereClause["condition"] == '"mandateId" = %s'
+ assert whereClause["values"] == ["test_mandate_group"]
+
+ def testBuildRbacWhereClauseNoAccess(self, db):
+ """Test WHERE clause building for NONE access level."""
+
+ permissions = UserPermissions(
+ view=True,
+ read=AccessLevel.NONE,
+ create=AccessLevel.NONE,
+ update=AccessLevel.NONE,
+ delete=AccessLevel.NONE
+ )
+
+ user = User(
+ id="test_user_none",
+ username="testuser",
+ roleLabels=["viewer"],
+ mandateId="test_mandate_none"
+ )
+
+ whereClause = db.buildRbacWhereClause(permissions, user, "SomeTable")
+
+ assert whereClause is not None
+ assert whereClause["condition"] == "1 = 0" # Always false
+ assert whereClause["values"] == []
+
+ def testBuildRbacWhereClauseUserInDBTable(self, db):
+ """Test WHERE clause building for UserInDB table with MY access."""
+
+ permissions = UserPermissions(
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.MY,
+ update=AccessLevel.MY,
+ delete=AccessLevel.MY
+ )
+
+ user = User(
+ id="test_user_in_db",
+ username="testuser",
+ roleLabels=["user"],
+ mandateId="test_mandate_in_db"
+ )
+
+ whereClause = db.buildRbacWhereClause(permissions, user, "UserInDB")
+
+ # UserInDB with MY access should filter by id field
+ assert whereClause is not None
+ assert whereClause["condition"] == '"id" = %s'
+ assert whereClause["values"] == ["test_user_in_db"]
+
+ def testBuildRbacWhereClauseUserConnectionTable(self, db):
+ """Test WHERE clause building for UserConnection table with GROUP access."""
+ # Create test users in the same mandate for GROUP access testing
+ from modules.datamodels.datamodelUam import UserInDB
+ testMandateId = "test_mandate_group"
+
+ # Create test users
+ user1 = UserInDB(
+ id="test_user1",
+ username="testuser1",
+ mandateId=testMandateId
+ )
+ user2 = UserInDB(
+ id="test_user2",
+ username="testuser2",
+ mandateId=testMandateId
+ )
+
+ try:
+ user1Data = user1.model_dump()
+ user1Data["id"] = user1.id
+ user2Data = user2.model_dump()
+ user2Data["id"] = user2.id
+ db.recordCreate(UserInDB, user1Data)
+ db.recordCreate(UserInDB, user2Data)
+
+ permissions = UserPermissions(
+ view=True,
+ read=AccessLevel.GROUP,
+ create=AccessLevel.GROUP,
+ update=AccessLevel.GROUP,
+ delete=AccessLevel.GROUP
+ )
+
+ user = User(
+ id="test_user1",
+ username="testuser1",
+ roleLabels=["admin"],
+ mandateId=testMandateId
+ )
+
+ whereClause = db.buildRbacWhereClause(permissions, user, "UserConnection")
+
+ assert whereClause is not None
+ assert "userId" in whereClause["condition"]
+ assert "IN" in whereClause["condition"]
+ assert len(whereClause["values"]) >= 2
+ finally:
+ # Cleanup test users
+ try:
+ db.recordDelete(UserInDB, "test_user1")
+ db.recordDelete(UserInDB, "test_user2")
+ except:
+ pass
diff --git a/tests/integration/rbac/test_rbac_migration.py b/tests/integration/rbac/test_rbac_migration.py
new file mode 100644
index 00000000..86f3eb6d
--- /dev/null
+++ b/tests/integration/rbac/test_rbac_migration.py
@@ -0,0 +1,282 @@
+"""
+Integration tests for UAM to RBAC migration.
+Tests that migration correctly converts user privileges to roleLabels.
+Uses real database connection for integration testing.
+"""
+
+import pytest
+from modules.migration.migrateUamToRbac import migrateUamToRbac, validateMigration
+from modules.datamodels.datamodelUam import UserInDB, UserPrivilege
+from modules.connectors.connectorDbPostgre import DatabaseConnector
+from modules.shared.configuration import APP_CONFIG
+
+
+@pytest.fixture(scope="class")
+def db():
+ """Create real database connector for integration tests."""
+ dbHost = APP_CONFIG.get("DB_HOST", "localhost")
+ dbDatabase = APP_CONFIG.get("DB_DATABASE", "poweron_test")
+ dbUser = APP_CONFIG.get("DB_USER", "postgres")
+ dbPassword = APP_CONFIG.get("DB_PASSWORD", "")
+ dbPort = APP_CONFIG.get("DB_PORT", 5432)
+
+ db = DatabaseConnector(
+ dbHost=dbHost,
+ dbDatabase=dbDatabase,
+ dbUser=dbUser,
+ dbPassword=dbPassword,
+ dbPort=dbPort
+ )
+ yield db
+ db.close()
+
+
+class TestRbacMigration:
+ """Test RBAC migration from UAM."""
+
+ def testMigrateUserPrivilegeToRoleLabels(self, db):
+ """Test that user privileges are correctly converted to roleLabels."""
+ # Create test users with privileges but no roleLabels
+ testUsers = [
+ UserInDB(
+ id="migrate_test_user1",
+ username="migrate_admin",
+ privilege=UserPrivilege.SYSADMIN.value
+ ),
+ UserInDB(
+ id="migrate_test_user2",
+ username="migrate_admin2",
+ privilege=UserPrivilege.ADMIN.value
+ ),
+ UserInDB(
+ id="migrate_test_user3",
+ username="migrate_user1",
+ privilege=UserPrivilege.USER.value
+ )
+ ]
+
+ try:
+ # Create test users in database
+ for user in testUsers:
+ userData = user.model_dump()
+ # Ensure roleLabels is None/empty for migration test
+ userData["roleLabels"] = []
+ userData["id"] = user.id
+ db.recordCreate(UserInDB, userData)
+
+ # Run migration
+ results = migrateUamToRbac(db, dryRun=False)
+
+ # Check that users were updated
+ assert results["usersUpdated"] == 3
+
+ # Verify users were actually updated in database
+ users1 = db.getRecordset(UserInDB, recordFilter={"id": "migrate_test_user1"})
+ users2 = db.getRecordset(UserInDB, recordFilter={"id": "migrate_test_user2"})
+ users3 = db.getRecordset(UserInDB, recordFilter={"id": "migrate_test_user3"})
+ user1 = users1[0] if users1 else None
+ user2 = users2[0] if users2 else None
+ user3 = users3[0] if users3 else None
+
+ assert user1 is not None
+ assert "sysadmin" in user1.get("roleLabels", [])
+
+ assert user2 is not None
+ assert "admin" in user2.get("roleLabels", [])
+
+ assert user3 is not None
+ assert "user" in user3.get("roleLabels", [])
+ finally:
+ # Cleanup test users
+ for user in testUsers:
+ try:
+ db.recordDelete(UserInDB, user.id)
+ except:
+ pass
+
+ def testMigrationSkipsUsersWithExistingRoleLabels(self, db):
+ """Test that migration skips users who already have roleLabels."""
+ # Create test users: one with roleLabels, one without
+ user1 = UserInDB(
+ id="skip_test_user1",
+ username="skip_admin",
+ privilege=UserPrivilege.SYSADMIN.value,
+ roleLabels=["sysadmin"] # Already migrated
+ )
+ user2 = UserInDB(
+ id="skip_test_user2",
+ username="skip_user1",
+ privilege=UserPrivilege.USER.value,
+ roleLabels=[] # Needs migration
+ )
+
+ try:
+ # Create test users in database
+ user1Data = user1.model_dump()
+ user1Data["id"] = user1.id
+ user2Data = user2.model_dump()
+ user2Data["id"] = user2.id
+ db.recordCreate(UserInDB, user1Data)
+ db.recordCreate(UserInDB, user2Data)
+
+ # Run migration
+ results = migrateUamToRbac(db, dryRun=False)
+
+ # Only one user should be updated (user2)
+ assert results["usersUpdated"] == 1
+
+ # Verify user1 still has original roleLabels
+ users1 = db.getRecordset(UserInDB, recordFilter={"id": "skip_test_user1"})
+ updatedUser1 = users1[0] if users1 else None
+ assert updatedUser1 is not None
+ assert "sysadmin" in updatedUser1.get("roleLabels", [])
+
+ # Verify user2 was updated
+ users2 = db.getRecordset(UserInDB, recordFilter={"id": "skip_test_user2"})
+ updatedUser2 = users2[0] if users2 else None
+ assert updatedUser2 is not None
+ assert "user" in updatedUser2.get("roleLabels", [])
+ finally:
+ # Cleanup test users
+ try:
+ db.recordDelete(UserInDB, "skip_test_user1")
+ db.recordDelete(UserInDB, "skip_test_user2")
+ except:
+ pass
+
+ def testDryRunMode(self, db):
+ """Test that dry run mode doesn't make changes."""
+ # Create test user without roleLabels
+ testUser = UserInDB(
+ id="dryrun_test_user1",
+ username="dryrun_admin",
+ privilege=UserPrivilege.SYSADMIN.value,
+ roleLabels=[] # Needs migration
+ )
+
+ try:
+ # Create test user in database
+ userData = testUser.model_dump()
+ userData["id"] = testUser.id
+ db.recordCreate(UserInDB, userData)
+
+ # Get original state
+ originalUsers = db.getRecordset(UserInDB, recordFilter={"id": "dryrun_test_user1"})
+ originalUser = originalUsers[0] if originalUsers else None
+ assert originalUser is not None
+ originalRoleLabels = originalUser.get("roleLabels", [])
+
+ # Run migration in dry run mode
+ results = migrateUamToRbac(db, dryRun=True)
+
+ # Should report what would be done
+ assert results["usersUpdated"] == 1
+
+ # Verify user was NOT actually updated
+ unchangedUsers = db.getRecordset(UserInDB, recordFilter={"id": "dryrun_test_user1"})
+ unchangedUser = unchangedUsers[0] if unchangedUsers else None
+ assert unchangedUser is not None
+ assert unchangedUser.get("roleLabels", []) == originalRoleLabels
+ finally:
+ # Cleanup test user
+ try:
+ db.recordDelete(UserInDB, "dryrun_test_user1")
+ except:
+ pass
+
+ def testValidateMigrationSuccess(self, db):
+ """Test validation passes when migration is successful."""
+ # Create test users with roleLabels (already migrated)
+ testUsers = [
+ UserInDB(
+ id="validate_test_user1",
+ username="validate_admin",
+ privilege=UserPrivilege.SYSADMIN.value,
+ roleLabels=["sysadmin"]
+ ),
+ UserInDB(
+ id="validate_test_user2",
+ username="validate_admin2",
+ privilege=UserPrivilege.ADMIN.value,
+ roleLabels=["admin"]
+ )
+ ]
+
+ try:
+ # Create test users in database
+ for user in testUsers:
+ userData = user.model_dump()
+ userData["id"] = user.id
+ db.recordCreate(UserInDB, userData)
+
+ # Ensure AccessRule table exists (migration should have created it)
+ from modules.datamodels.datamodelRbac import AccessRule
+ db._ensureTableExists(AccessRule)
+
+ # Run validation
+ validation = validateMigration(db)
+
+ assert validation["valid"] == True
+ assert len(validation["issues"]) == 0
+ finally:
+ # Cleanup test users
+ for user in testUsers:
+ try:
+ db.recordDelete(UserInDB, user.id)
+ except:
+ pass
+
+ def testValidateMigrationFailsWithoutRoleLabels(self, db):
+ """Test validation fails when users don't have roleLabels."""
+ # Create test users: one with roleLabels, one without, one with empty roleLabels
+ testUsers = [
+ UserInDB(
+ id="validate_fail_user1",
+ username="validate_fail_admin",
+ privilege=UserPrivilege.SYSADMIN.value,
+ roleLabels=["sysadmin"] # Has roleLabels
+ ),
+ UserInDB(
+ id="validate_fail_user2",
+ username="validate_fail_user",
+ privilege=UserPrivilege.USER.value,
+ roleLabels=[] # Empty roleLabels
+ ),
+ UserInDB(
+ id="validate_fail_user3",
+ username="validate_fail_user2",
+ privilege=UserPrivilege.USER.value
+ # Missing roleLabels field (will be None)
+ )
+ ]
+
+ try:
+ # Create test users in database
+ for user in testUsers:
+ userData = user.model_dump()
+ userData["id"] = user.id
+ # For user3, explicitly set roleLabels to None or remove it
+ if user.id == "validate_fail_user3":
+ if "roleLabels" in userData:
+ del userData["roleLabels"]
+ db.recordCreate(UserInDB, userData)
+
+ # Ensure AccessRule table exists
+ from modules.datamodels.datamodelRbac import AccessRule
+ db._ensureTableExists(AccessRule)
+
+ # Run validation
+ validation = validateMigration(db)
+
+ assert validation["valid"] == False
+ assert len(validation["issues"]) > 0
+ # Check that validation found users without roleLabels
+ issuesStr = " ".join(validation["issues"])
+ assert "users without roleLabels" in issuesStr or "without roleLabels" in issuesStr
+ finally:
+ # Cleanup test users
+ for user in testUsers:
+ try:
+ db.recordDelete(UserInDB, user.id)
+ except:
+ pass
diff --git a/tests/unit/rbac/README.md b/tests/unit/rbac/README.md
new file mode 100644
index 00000000..3666ef2a
--- /dev/null
+++ b/tests/unit/rbac/README.md
@@ -0,0 +1,47 @@
+# RBAC Unit Tests
+
+Unit tests for the Role-Based Access Control (RBAC) system.
+
+## Test Files
+
+### `test_rbac_permissions.py`
+Tests RBAC permission resolution logic:
+- Single role with generic rules
+- Rule specificity (most specific wins)
+- Multiple roles with union logic
+- View permission overrides
+- No roles scenario
+- Finding most specific rules
+- Opening rights validation
+- UI and RESOURCE context handling
+
+### `test_rbac_bootstrap.py`
+Tests RBAC bootstrap initialization:
+- Root mandate creation
+- Admin user creation with sysadmin role
+- Event user creation with sysadmin role
+- Default role rules creation
+- Table-specific rules creation
+- Rule initialization skipping when rules exist
+
+## Running Tests
+
+```bash
+# Run all RBAC unit tests
+pytest tests/unit/rbac/
+
+# Run specific test file
+pytest tests/unit/rbac/test_rbac_permissions.py
+
+# Run with verbose output
+pytest tests/unit/rbac/ -v
+```
+
+## Test Coverage
+
+- Permission resolution algorithms
+- Rule specificity logic
+- Multiple role combination (union logic)
+- Access rule validation
+- Bootstrap initialization
+- Default rule creation
diff --git a/tests/unit/rbac/__init__.py b/tests/unit/rbac/__init__.py
new file mode 100644
index 00000000..5d55b3ca
--- /dev/null
+++ b/tests/unit/rbac/__init__.py
@@ -0,0 +1 @@
+"""Unit tests for RBAC system."""
diff --git a/tests/unit/rbac/test_rbac_bootstrap.py b/tests/unit/rbac/test_rbac_bootstrap.py
new file mode 100644
index 00000000..573a4fd1
--- /dev/null
+++ b/tests/unit/rbac/test_rbac_bootstrap.py
@@ -0,0 +1,162 @@
+"""
+Unit tests for RBAC bootstrap initialization.
+Tests that bootstrap creates correct rules and initial data.
+"""
+
+import pytest
+from unittest.mock import Mock, MagicMock, patch
+from modules.interfaces.interfaceBootstrap import (
+ initBootstrap,
+ initRootMandate,
+ initAdminUser,
+ initEventUser,
+ initRbacRules,
+ createDefaultRoleRules,
+ createTableSpecificRules
+)
+from modules.datamodels.datamodelUam import UserInDB, Mandate, UserPrivilege, AuthAuthority
+from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext
+from modules.datamodels.datamodelUam import AccessLevel
+
+
+class TestRbacBootstrap:
+ """Test RBAC bootstrap initialization."""
+
+ def testInitRootMandateCreatesIfNotExists(self):
+ """Test that initRootMandate creates mandate if it doesn't exist."""
+ db = Mock()
+ db.getRecordset = Mock(return_value=[]) # No existing mandates
+ db.recordCreate = Mock(return_value={"id": "mandate1", "name": "Root"})
+
+ mandateId = initRootMandate(db)
+
+ assert mandateId == "mandate1"
+ db.recordCreate.assert_called_once()
+ callArgs = db.recordCreate.call_args
+ assert isinstance(callArgs[0][1], Mandate)
+ assert callArgs[0][1].name == "Root"
+
+ def testInitRootMandateReturnsExisting(self):
+ """Test that initRootMandate returns existing mandate ID."""
+ db = Mock()
+ db.getRecordset = Mock(return_value=[{"id": "existing_mandate"}])
+
+ mandateId = initRootMandate(db)
+
+ assert mandateId == "existing_mandate"
+ db.recordCreate.assert_not_called()
+
+ def testInitAdminUserCreatesWithSysadminRole(self):
+ """Test that initAdminUser creates user with sysadmin role."""
+ db = Mock()
+ db.getRecordset = Mock(return_value=[]) # No existing users
+ db.recordCreate = Mock(return_value={"id": "admin1", "username": "admin"})
+
+ with patch('modules.interfaces.interfaceBootstrap._getPasswordHash', return_value="hashed"):
+ userId = initAdminUser(db, "mandate1")
+
+ assert userId == "admin1"
+ db.recordCreate.assert_called_once()
+ callArgs = db.recordCreate.call_args
+ user = callArgs[0][1]
+ assert isinstance(user, UserInDB)
+ assert user.username == "admin"
+ assert "sysadmin" in user.roleLabels
+ assert user.privilege == UserPrivilege.SYSADMIN
+
+ def testInitEventUserCreatesWithSysadminRole(self):
+ """Test that initEventUser creates user with sysadmin role."""
+ db = Mock()
+ db.getRecordset = Mock(return_value=[]) # No existing users
+ db.recordCreate = Mock(return_value={"id": "event1", "username": "event"})
+
+ with patch('modules.interfaces.interfaceBootstrap._getPasswordHash', return_value="hashed"):
+ userId = initEventUser(db, "mandate1")
+
+ assert userId == "event1"
+ db.recordCreate.assert_called_once()
+ callArgs = db.recordCreate.call_args
+ user = callArgs[0][1]
+ assert isinstance(user, UserInDB)
+ assert user.username == "event"
+ assert "sysadmin" in user.roleLabels
+
+ def testCreateDefaultRoleRules(self):
+ """Test that createDefaultRoleRules creates correct default rules."""
+ db = Mock()
+ db.recordCreate = Mock()
+
+ createDefaultRoleRules(db)
+
+ # Should create 4 default rules (sysadmin, admin, user, viewer)
+ assert db.recordCreate.call_count == 4
+
+ # Check sysadmin rule
+ sysadminCall = [call for call in db.recordCreate.call_args_list
+ if call[0][1].roleLabel == "sysadmin"][0]
+ sysadminRule = sysadminCall[0][1]
+ assert sysadminRule.context == AccessRuleContext.DATA
+ assert sysadminRule.item is None
+ assert sysadminRule.view == True
+ assert sysadminRule.read == AccessLevel.ALL
+ assert sysadminRule.create == AccessLevel.ALL
+
+ # Check user rule
+ userCall = [call for call in db.recordCreate.call_args_list
+ if call[0][1].roleLabel == "user"][0]
+ userRule = userCall[0][1]
+ assert userRule.read == AccessLevel.MY
+ assert userRule.create == AccessLevel.MY
+
+ def testCreateTableSpecificRules(self):
+ """Test that createTableSpecificRules creates table-specific rules."""
+ db = Mock()
+ db.recordCreate = Mock()
+
+ createTableSpecificRules(db)
+
+ # Should create multiple rules for different tables
+ assert db.recordCreate.call_count > 0
+
+ # Check that Mandate table rules are created
+ mandateCalls = [call for call in db.recordCreate.call_args_list
+ if call[0][1].item == "Mandate"]
+ assert len(mandateCalls) > 0
+
+ # Check sysadmin rule for Mandate
+ sysadminMandateCall = [call for call in mandateCalls
+ if call[0][1].roleLabel == "sysadmin"][0]
+ sysadminRule = sysadminMandateCall[0][1]
+ assert sysadminRule.view == True
+ assert sysadminRule.read == AccessLevel.ALL
+
+ # Check that other roles have view=False for Mandate
+ otherMandateCalls = [call for call in mandateCalls
+ if call[0][1].roleLabel != "sysadmin"]
+ for call in otherMandateCalls:
+ rule = call[0][1]
+ assert rule.view == False
+
+ def testInitRbacRulesSkipsIfExists(self):
+ """Test that initRbacRules skips creation if rules already exist."""
+ db = Mock()
+ db.getRecordset = Mock(return_value=[{"id": "rule1"}]) # Rules exist
+
+ initRbacRules(db)
+
+ # Should not create new rules
+ db.recordCreate.assert_not_called()
+
+ def testInitRbacRulesCreatesIfNotExists(self):
+ """Test that initRbacRules creates rules if they don't exist."""
+ db = Mock()
+ db.getRecordset = Mock(side_effect=[
+ [], # No existing rules
+ [] # After creating default rules
+ ])
+ db.recordCreate = Mock()
+
+ initRbacRules(db)
+
+ # Should create rules
+ assert db.recordCreate.call_count > 0
diff --git a/tests/unit/rbac/test_rbac_permissions.py b/tests/unit/rbac/test_rbac_permissions.py
new file mode 100644
index 00000000..d180f5b8
--- /dev/null
+++ b/tests/unit/rbac/test_rbac_permissions.py
@@ -0,0 +1,403 @@
+"""
+Unit tests for RBAC permission resolution.
+Tests rule specificity, multiple roles, and permission combination logic.
+"""
+
+import pytest
+from modules.datamodels.datamodelUam import User, AccessLevel, UserPermissions
+from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext
+from modules.security.rbac import RbacClass
+from modules.connectors.connectorDbPostgre import DatabaseConnector
+from unittest.mock import Mock, MagicMock
+
+
+class TestRbacPermissionResolution:
+ """Test RBAC permission resolution logic."""
+
+ def testSingleRoleGenericRule(self):
+ """Test permission resolution with a single role and generic rule."""
+ # Mock database connector
+ db = Mock(spec=DatabaseConnector)
+
+ # Create RBAC interface
+ rbac = RbacClass(db)
+
+ # Create user with single role
+ user = User(
+ id="user1",
+ username="testuser",
+ roleLabels=["user"],
+ mandateId="mandate1"
+ )
+
+ # Mock rules for "user" role
+ def mockGetRulesForRole(roleLabel, context):
+ if roleLabel == "user" and context == AccessRuleContext.DATA:
+ return [
+ AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.DATA,
+ item=None, # Generic rule
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.MY,
+ update=AccessLevel.MY,
+ delete=AccessLevel.MY
+ )
+ ]
+ return []
+
+ rbac._getRulesForRole = mockGetRulesForRole
+
+ # Get permissions for generic table
+ permissions = rbac.getUserPermissions(
+ user,
+ AccessRuleContext.DATA,
+ "SomeTable"
+ )
+
+ assert permissions.view == True
+ assert permissions.read == AccessLevel.MY
+ assert permissions.create == AccessLevel.MY
+ assert permissions.update == AccessLevel.MY
+ assert permissions.delete == AccessLevel.MY
+
+ def testRuleSpecificityMostSpecificWins(self):
+ """Test that most specific rule wins within a single role."""
+ db = Mock(spec=DatabaseConnector)
+ rbac = RbacClass(db)
+
+ user = User(
+ id="user1",
+ username="testuser",
+ roleLabels=["user"],
+ mandateId="mandate1"
+ )
+
+ def mockGetRulesForRole(roleLabel, context):
+ if roleLabel == "user" and context == AccessRuleContext.DATA:
+ return [
+ AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.DATA,
+ item=None, # Generic rule
+ view=True,
+ read=AccessLevel.GROUP,
+ create=AccessLevel.GROUP,
+ update=AccessLevel.GROUP,
+ delete=AccessLevel.GROUP
+ ),
+ AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.DATA,
+ item="UserInDB", # Specific rule
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.NONE,
+ update=AccessLevel.MY,
+ delete=AccessLevel.NONE
+ )
+ ]
+ return []
+
+ rbac._getRulesForRole = mockGetRulesForRole
+
+ # Get permissions for UserInDB table - should use specific rule
+ permissions = rbac.getUserPermissions(
+ user,
+ AccessRuleContext.DATA,
+ "UserInDB"
+ )
+
+ # Most specific rule should win
+ assert permissions.read == AccessLevel.MY
+ assert permissions.create == AccessLevel.NONE
+ assert permissions.update == AccessLevel.MY
+ assert permissions.delete == AccessLevel.NONE
+
+ def testMultipleRolesUnionLogic(self):
+ """Test that multiple roles use union (opening) logic."""
+ db = Mock(spec=DatabaseConnector)
+ rbac = RbacClass(db)
+
+ # User with multiple roles
+ user = User(
+ id="user1",
+ username="testuser",
+ roleLabels=["user", "viewer"],
+ mandateId="mandate1"
+ )
+
+ def mockGetRulesForRole(roleLabel, context):
+ if context == AccessRuleContext.UI:
+ if roleLabel == "user":
+ return [
+ AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.UI,
+ item="playground",
+ view=False # User role hides playground
+ )
+ ]
+ elif roleLabel == "viewer":
+ return [
+ AccessRule(
+ roleLabel="viewer",
+ context=AccessRuleContext.UI,
+ item="playground",
+ view=True # Viewer role shows playground
+ )
+ ]
+ return []
+
+ rbac._getRulesForRole = mockGetRulesForRole
+
+ # Get permissions - union logic should make playground visible
+ permissions = rbac.getUserPermissions(
+ user,
+ AccessRuleContext.UI,
+ "playground"
+ )
+
+ # Union logic: if ANY role has view=true, then view=true
+ assert permissions.view == True
+
+ def testViewFalseOverridesGeneric(self):
+ """Test that specific view=false overrides generic view=true."""
+ db = Mock(spec=DatabaseConnector)
+ rbac = RbacClass(db)
+
+ user = User(
+ id="user1",
+ username="testuser",
+ roleLabels=["user"],
+ mandateId="mandate1"
+ )
+
+ def mockGetRulesForRole(roleLabel, context):
+ if roleLabel == "user" and context == AccessRuleContext.UI:
+ return [
+ AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.UI,
+ item=None, # Generic: view all UI
+ view=True
+ ),
+ AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.UI,
+ item="playground.voice.settings", # Specific: hide this
+ view=False
+ )
+ ]
+ return []
+
+ rbac._getRulesForRole = mockGetRulesForRole
+
+ # Get permissions for specific UI element
+ permissions = rbac.getUserPermissions(
+ user,
+ AccessRuleContext.UI,
+ "playground.voice.settings"
+ )
+
+ # Specific rule should override generic
+ assert permissions.view == False
+
+ def testNoRolesReturnsNoAccess(self):
+ """Test that user with no roles gets no access."""
+ db = Mock(spec=DatabaseConnector)
+ rbac = RbacClass(db)
+
+ user = User(
+ id="user1",
+ username="testuser",
+ roleLabels=[], # No roles
+ mandateId="mandate1"
+ )
+
+ permissions = rbac.getUserPermissions(
+ user,
+ AccessRuleContext.DATA,
+ "SomeTable"
+ )
+
+ assert permissions.view == False
+ assert permissions.read == AccessLevel.NONE
+ assert permissions.create == AccessLevel.NONE
+ assert permissions.update == AccessLevel.NONE
+ assert permissions.delete == AccessLevel.NONE
+
+ def testFindMostSpecificRule(self):
+ """Test findMostSpecificRule method."""
+ db = Mock(spec=DatabaseConnector)
+ rbac = RbacClass(db)
+
+ rules = [
+ AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.DATA,
+ item=None, # Generic
+ view=True,
+ read=AccessLevel.GROUP
+ ),
+ AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.DATA,
+ item="UserInDB", # Table-level
+ view=True,
+ read=AccessLevel.MY
+ ),
+ AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.DATA,
+ item="UserInDB.email", # Field-level - most specific
+ view=True,
+ read=AccessLevel.NONE
+ )
+ ]
+
+ # Test exact match
+ rule = rbac.findMostSpecificRule(rules, "UserInDB.email")
+ assert rule is not None
+ assert rule.item == "UserInDB.email"
+ assert rule.read == AccessLevel.NONE
+
+ # Test table-level match
+ rule = rbac.findMostSpecificRule(rules, "UserInDB")
+ assert rule is not None
+ assert rule.item == "UserInDB"
+ assert rule.read == AccessLevel.MY
+
+ # Test generic fallback
+ rule = rbac.findMostSpecificRule(rules, "OtherTable")
+ assert rule is not None
+ assert rule.item is None
+ assert rule.read == AccessLevel.GROUP
+
+ def testValidateAccessRuleOpeningRights(self):
+ """Test that CUD permissions respect read permission level."""
+ db = Mock(spec=DatabaseConnector)
+ rbac = RbacClass(db)
+
+ # Valid: Read=MY, Create=MY (allowed)
+ rule1 = AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.DATA,
+ item="UserInDB",
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.MY,
+ update=AccessLevel.MY,
+ delete=AccessLevel.MY
+ )
+ assert rbac.validateAccessRule(rule1) == True
+
+ # Invalid: Read=MY, Create=GROUP (not allowed - GROUP > MY)
+ rule2 = AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.DATA,
+ item="UserInDB",
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.GROUP, # Not allowed
+ update=AccessLevel.MY,
+ delete=AccessLevel.MY
+ )
+ assert rbac.validateAccessRule(rule2) == False
+
+ # Valid: Read=GROUP, Create=GROUP (allowed)
+ rule3 = AccessRule(
+ roleLabel="admin",
+ context=AccessRuleContext.DATA,
+ item="UserInDB",
+ view=True,
+ read=AccessLevel.GROUP,
+ create=AccessLevel.GROUP,
+ update=AccessLevel.GROUP,
+ delete=AccessLevel.GROUP
+ )
+ assert rbac.validateAccessRule(rule3) == True
+
+ # Invalid: Read=NONE, Create=MY (not allowed - no read access)
+ rule4 = AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.DATA,
+ item="UserInDB",
+ view=True,
+ read=AccessLevel.NONE,
+ create=AccessLevel.MY, # Not allowed without read
+ update=AccessLevel.MY,
+ delete=AccessLevel.MY
+ )
+ assert rbac.validateAccessRule(rule4) == False
+
+ def testUiContextOnlyViewMatters(self):
+ """Test that UI context only checks view permission."""
+ db = Mock(spec=DatabaseConnector)
+ rbac = RbacClass(db)
+
+ user = User(
+ id="user1",
+ username="testuser",
+ roleLabels=["user"],
+ mandateId="mandate1"
+ )
+
+ def mockGetRulesForRole(roleLabel, context):
+ if roleLabel == "user" and context == AccessRuleContext.UI:
+ return [
+ AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.UI,
+ item="playground",
+ view=True
+ # No read/create/update/delete for UI context
+ )
+ ]
+ return []
+
+ rbac._getRulesForRole = mockGetRulesForRole
+
+ permissions = rbac.getUserPermissions(
+ user,
+ AccessRuleContext.UI,
+ "playground"
+ )
+
+ assert permissions.view == True
+ # Other permissions don't matter for UI context
+
+ def testResourceContextOnlyViewMatters(self):
+ """Test that RESOURCE context only checks view permission."""
+ db = Mock(spec=DatabaseConnector)
+ rbac = RbacClass(db)
+
+ user = User(
+ id="user1",
+ username="testuser",
+ roleLabels=["user"],
+ mandateId="mandate1"
+ )
+
+ def mockGetRulesForRole(roleLabel, context):
+ if roleLabel == "user" and context == AccessRuleContext.RESOURCE:
+ return [
+ AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.RESOURCE,
+ item="ai.model.anthropic",
+ view=True
+ )
+ ]
+ return []
+
+ rbac._getRulesForRole = mockGetRulesForRole
+
+ permissions = rbac.getUserPermissions(
+ user,
+ AccessRuleContext.RESOURCE,
+ "ai.model.anthropic"
+ )
+
+ assert permissions.view == True
From 6e6cf7012b4d74b2057ae2852c911a62915688b7 Mon Sep 17 00:00:00 2001
From: ValueOn AG
Date: Sun, 7 Dec 2025 22:00:55 +0100
Subject: [PATCH 3/6] rbac module testing done
---
docs/rbac_getrecordset_review.md | 135 +++++++++
modules/datamodels/datamodelUam.py | 7 -
modules/interfaces/interfaceBootstrap.py | 3 -
modules/interfaces/interfaceDbAppObjects.py | 36 ++-
modules/interfaces/interfaceDbChatObjects.py | 28 +-
.../interfaces/interfaceDbComponentObjects.py | 23 +-
modules/migration/__init__.py | 1 -
modules/migration/migrateUamToRbac.py | 212 -------------
modules/routes/routeAdminAutomationEvents.py | 8 +-
modules/routes/routeDataUsers.py | 6 +-
modules/routes/routeSecurityAdmin.py | 11 +-
modules/routes/routeSecurityLocal.py | 7 +-
modules/routes/routeWorkflows.py | 8 +-
tests/integration/rbac/test_rbac_migration.py | 282 ------------------
tests/unit/rbac/test_rbac_bootstrap.py | 3 +-
15 files changed, 209 insertions(+), 561 deletions(-)
create mode 100644 docs/rbac_getrecordset_review.md
delete mode 100644 modules/migration/__init__.py
delete mode 100644 modules/migration/migrateUamToRbac.py
delete mode 100644 tests/integration/rbac/test_rbac_migration.py
diff --git a/docs/rbac_getrecordset_review.md b/docs/rbac_getrecordset_review.md
new file mode 100644
index 00000000..d2c06524
--- /dev/null
+++ b/docs/rbac_getrecordset_review.md
@@ -0,0 +1,135 @@
+# RBAC getRecordset() Review
+
+## Overview
+Review of all `getRecordset()` calls in `interfaceDbChatObjects.py` and `interfaceDbComponentObjects.py` to determine which should be converted to `getRecordsetWithRBAC()`.
+
+## Analysis Criteria
+- **Convert to RBAC**: User-facing data that should respect access control
+- **Keep as-is**: Internal/technical operations that don't need RBAC filtering
+
+---
+
+## interfaceDbChatObjects.py
+
+### Summary: **14 calls found - ALL should be converted to `getRecordsetWithRBAC()`**
+
+All calls access user-facing data (ChatMessage, ChatDocument, ChatStat, ChatLog) and should respect RBAC even when:
+- Used in cascade delete operations (after parent access is verified)
+- Used to fetch child records (after parent access is verified)
+- Used for existence checks
+
+**Rationale**: RBAC should be applied at every data access point to ensure consistent security and prevent potential bypass scenarios.
+
+### Detailed List:
+
+1. **Line 760** - `deleteWorkflow()` - Cascade delete ChatStat
+ - **Action**: Convert to `getRecordsetWithRBAC(ChatStat, self.currentUser, recordFilter={"messageId": messageId})`
+ - **Reason**: Deleting related data should respect RBAC
+
+2. **Line 765** - `deleteWorkflow()` - Cascade delete ChatDocument
+ - **Action**: Convert to `getRecordsetWithRBAC(ChatDocument, self.currentUser, recordFilter={"messageId": messageId})`
+ - **Reason**: Deleting related data should respect RBAC
+
+3. **Line 773** - `deleteWorkflow()` - Cascade delete ChatStat (workflow level)
+ - **Action**: Convert to `getRecordsetWithRBAC(ChatStat, self.currentUser, recordFilter={"workflowId": workflowId})`
+ - **Reason**: Deleting related data should respect RBAC
+
+4. **Line 778** - `deleteWorkflow()` - Cascade delete ChatLog
+ - **Action**: Convert to `getRecordsetWithRBAC(ChatLog, self.currentUser, recordFilter={"workflowId": workflowId})`
+ - **Reason**: Deleting related data should respect RBAC
+
+5. **Line 821** - `getMessages()` - Fetch messages for workflow
+ - **Action**: Convert to `getRecordsetWithRBAC(ChatMessage, self.currentUser, recordFilter={"workflowId": workflowId})`
+ - **Reason**: Child records should still respect RBAC even if parent access is verified
+
+6. **Line 1062** - `updateMessage()` - Check if message exists
+ - **Action**: Convert to `getRecordsetWithRBAC(ChatMessage, self.currentUser, recordFilter={"id": messageId})`
+ - **Reason**: Existence checks should respect RBAC
+
+7. **Line 1167** - `deleteMessage()` - Cascade delete ChatStat
+ - **Action**: Convert to `getRecordsetWithRBAC(ChatStat, self.currentUser, recordFilter={"messageId": messageId})`
+ - **Reason**: Deleting related data should respect RBAC
+
+8. **Line 1172** - `deleteMessage()` - Cascade delete ChatDocument
+ - **Action**: Convert to `getRecordsetWithRBAC(ChatDocument, self.currentUser, recordFilter={"messageId": messageId})`
+ - **Reason**: Deleting related data should respect RBAC
+
+9. **Line 1199** - `deleteFileFromMessage()` - Get documents for message
+ - **Action**: Convert to `getRecordsetWithRBAC(ChatDocument, self.currentUser, recordFilter={"messageId": messageId})`
+ - **Reason**: Accessing related data should respect RBAC
+
+10. **Line 1242** - `getDocuments()` - Get documents for message
+ - **Action**: Convert to `getRecordsetWithRBAC(ChatDocument, self.currentUser, recordFilter={"messageId": messageId})`
+ - **Reason**: Public method accessing user data should respect RBAC
+
+11. **Line 1291** - `getLogs()` - Fetch logs for workflow
+ - **Action**: Convert to `getRecordsetWithRBAC(ChatLog, self.currentUser, recordFilter={"workflowId": workflowId})`
+ - **Reason**: Child records should still respect RBAC even if parent access is verified
+
+12. **Line 1410** - `getStats()` - Fetch stats for workflow
+ - **Action**: Convert to `getRecordsetWithRBAC(ChatStat, self.currentUser, recordFilter={"workflowId": workflowId})`
+ - **Reason**: Child records should still respect RBAC even if parent access is verified
+
+13. **Line 1460** - `getUnifiedChatData()` - Fetch messages for workflow
+ - **Action**: Convert to `getRecordsetWithRBAC(ChatMessage, self.currentUser, recordFilter={"workflowId": workflowId})`
+ - **Reason**: Child records should still respect RBAC even if parent access is verified
+
+14. **Line 1501** - `getUnifiedChatData()` - Fetch logs for workflow
+ - **Action**: Convert to `getRecordsetWithRBAC(ChatLog, self.currentUser, recordFilter={"workflowId": workflowId})`
+ - **Reason**: Child records should still respect RBAC even if parent access is verified
+
+---
+
+## interfaceDbComponentObjects.py
+
+### Summary: **3 calls found - 1 keep as-is, 2 should be converted**
+
+### Detailed List:
+
+1. **Line 149** - `_initializeStandardPrompts()` - Check if prompts exist
+ - **Action**: **KEEP AS-IS** ✅
+ - **Reason**: This is initialization code that runs during bootstrap. It checks if any prompts exist to avoid re-initialization. Since this runs with root user context and is a system-level check, RBAC is not needed here.
+
+2. **Line 947** - `deleteFile()` - Get FileData for deletion
+ - **Action**: **CONVERT** to `getRecordsetWithRBAC(FileData, self.currentUser, recordFilter={"id": fileId})`
+ - **Reason**: FileData stores binary data associated with FileItem. While it's a technical table, we should still respect RBAC for consistency and security. The file access was already checked via `getFile()`, but FileData access should also be RBAC-filtered.
+
+3. **Line 1032** - `getFileData()` - Get FileData for reading
+ - **Action**: **CONVERT** to `getRecordsetWithRBAC(FileData, self.currentUser, recordFilter={"id": fileId})`
+ - **Reason**: FileData access should respect RBAC. The file access was already checked via `getFile()`, but FileData access should also be RBAC-filtered for consistency.
+
+**Note on FileData**: FileData is a technical table storing binary file content. However, for consistency and security, RBAC should still be applied. If FileData doesn't have RBAC rules defined, the RBAC filter will effectively be a no-op (allowing access), but the pattern is consistent.
+
+---
+
+## Implementation Priority
+
+### High Priority (User-facing data access)
+- All `interfaceDbChatObjects.py` calls (14 calls)
+- `interfaceDbComponentObjects.py` FileData calls (2 calls)
+
+### Low Priority (System initialization)
+- `interfaceDbComponentObjects.py` Prompt initialization check (1 call) - Keep as-is
+
+---
+
+## Next Steps
+
+1. Convert all 14 calls in `interfaceDbChatObjects.py` to `getRecordsetWithRBAC()`
+2. Convert 2 FileData calls in `interfaceDbComponentObjects.py` to `getRecordsetWithRBAC()`
+3. Keep 1 Prompt initialization check as-is
+4. Test all changes to ensure RBAC filtering works correctly
+5. Verify cascade delete operations still work correctly with RBAC
+
+---
+
+## Testing Checklist
+
+After conversion, verify:
+- [ ] Workflow deletion still works (cascade deletes)
+- [ ] Message deletion still works (cascade deletes)
+- [ ] File deletion still works (FileData cleanup)
+- [ ] File reading still works (FileData access)
+- [ ] Child record access (messages, logs, stats, documents) respects RBAC
+- [ ] Users can only access data they have permission for
+- [ ] No performance degradation from RBAC filtering
diff --git a/modules/datamodels/datamodelUam.py b/modules/datamodels/datamodelUam.py
index 4c9e0a84..49e62beb 100644
--- a/modules/datamodels/datamodelUam.py
+++ b/modules/datamodels/datamodelUam.py
@@ -13,11 +13,6 @@ class AuthAuthority(str, Enum):
GOOGLE = "google"
MSFT = "msft"
-class UserPrivilege(str, Enum): # TODO: TO remove, one new RBAC System is in place!
- SYSADMIN = "sysadmin"
- ADMIN = "admin"
- USER = "user"
-
class ConnectionStatus(str, Enum):
ACTIVE = "active"
EXPIRED = "expired"
@@ -152,7 +147,6 @@ class User(BaseModel):
{"value": "it", "label": {"en": "Italiano", "fr": "Italien"}},
]})
enabled: bool = Field(default=True, description="Indicates whether the user is enabled", json_schema_extra={"frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False})
- privilege: UserPrivilege = Field(default=UserPrivilege.USER, description="Permission level (DEPRECATED: use roleLabels instead)", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": "user.role"})
roleLabels: List[str] = Field(
default_factory=list,
description="List of role labels assigned to this user. All roles are opening roles (union) - if one role enables something, it is enabled.",
@@ -174,7 +168,6 @@ registerModelLabels(
"fullName": {"en": "Full Name", "fr": "Nom complet"},
"language": {"en": "Language", "fr": "Langue"},
"enabled": {"en": "Enabled", "fr": "Activé"},
- "privilege": {"en": "Privilege", "fr": "Privilège"},
"roleLabels": {"en": "Role Labels", "fr": "Labels de rôle"},
"authenticationAuthority": {"en": "Auth Authority", "fr": "Autorité d'authentification"},
"mandateId": {"en": "Mandate ID", "fr": "ID de mandat"},
diff --git a/modules/interfaces/interfaceBootstrap.py b/modules/interfaces/interfaceBootstrap.py
index 5c4a90a1..55d94c3c 100644
--- a/modules/interfaces/interfaceBootstrap.py
+++ b/modules/interfaces/interfaceBootstrap.py
@@ -11,7 +11,6 @@ from modules.shared.configuration import APP_CONFIG
from modules.datamodels.datamodelUam import (
Mandate,
UserInDB,
- UserPrivilege,
AuthAuthority,
)
from modules.datamodels.datamodelRbac import (
@@ -103,7 +102,6 @@ def initAdminUser(db: DatabaseConnector, mandateId: Optional[str]) -> Optional[s
fullName="Administrator",
enabled=True,
language="en",
- privilege=UserPrivilege.SYSADMIN,
roleLabels=["sysadmin"],
authenticationAuthority=AuthAuthority.LOCAL,
hashedPassword=_getPasswordHash(APP_CONFIG.get("APP_INIT_PASS_ADMIN_SECRET")),
@@ -140,7 +138,6 @@ def initEventUser(db: DatabaseConnector, mandateId: Optional[str]) -> Optional[s
fullName="Event",
enabled=True,
language="en",
- privilege=UserPrivilege.SYSADMIN,
roleLabels=["sysadmin"],
authenticationAuthority=AuthAuthority.LOCAL,
hashedPassword=_getPasswordHash(APP_CONFIG.get("APP_INIT_PASS_EVENT_SECRET")),
diff --git a/modules/interfaces/interfaceDbAppObjects.py b/modules/interfaces/interfaceDbAppObjects.py
index 900f7328..cf582fa2 100644
--- a/modules/interfaces/interfaceDbAppObjects.py
+++ b/modules/interfaces/interfaceDbAppObjects.py
@@ -20,7 +20,6 @@ from modules.datamodels.datamodelUam import (
UserInDB,
UserConnection,
AuthAuthority,
- UserPrivilege,
ConnectionStatus,
)
from modules.datamodels.datamodelRbac import (
@@ -488,19 +487,20 @@ class AppObjects:
def getUser(self, userId: str) -> Optional[User]:
"""Returns a user by ID if user has access."""
try:
- # Get all users
- users = self.db.getRecordset(UserInDB)
+ # Get users filtered by RBAC
+ users = self.db.getRecordsetWithRBAC(
+ UserInDB,
+ self.currentUser,
+ recordFilter={"id": userId}
+ )
+
if not users:
return None
- # Find user by ID
- for user_dict in users:
- if user_dict.get("id") == userId:
- # User already filtered by RBAC, just clean fields
- cleanedUser = {k: v for k, v in user_dict.items() if not k.startswith("_")}
- return User(**cleanedUser)
-
- return None
+ # User already filtered by RBAC, just clean fields
+ user_dict = users[0]
+ cleanedUser = {k: v for k, v in user_dict.items() if not k.startswith("_")}
+ return User(**cleanedUser)
except Exception as e:
logger.error(f"Error getting user by ID: {str(e)}")
@@ -542,7 +542,7 @@ class AppObjects:
fullName: str = None,
language: str = "en",
enabled: bool = True,
- privilege: UserPrivilege = UserPrivilege.USER,
+ roleLabels: List[str] = None,
authenticationAuthority: AuthAuthority = AuthAuthority.LOCAL,
externalId: str = None,
externalUsername: str = None,
@@ -568,6 +568,10 @@ class AppObjects:
mandateId = self._getDefaultMandateId()
logger.warning(f"Using default mandate ID {mandateId} for new user {username}")
+ # Default roleLabels to ["user"] if not provided
+ if roleLabels is None or not roleLabels:
+ roleLabels = ["user"]
+
# Create user data using UserInDB model
userData = UserInDB(
username=username,
@@ -576,7 +580,7 @@ class AppObjects:
language=language,
mandateId=mandateId,
enabled=enabled,
- privilege=privilege,
+ roleLabels=roleLabels,
authenticationAuthority=authenticationAuthority,
hashedPassword=self._getPasswordHash(password) if password else None,
connections=[],
@@ -734,7 +738,11 @@ class AppObjects:
if not initialUserId:
return None
- users = self.db.getRecordset(UserInDB, recordFilter={"id": initialUserId})
+ users = self.db.getRecordsetWithRBAC(
+ UserInDB,
+ self.currentUser,
+ recordFilter={"id": initialUserId}
+ )
return users[0] if users else None
except Exception as e:
logger.error(f"Error getting initial user: {str(e)}")
diff --git a/modules/interfaces/interfaceDbChatObjects.py b/modules/interfaces/interfaceDbChatObjects.py
index 6093eb78..ac6df640 100644
--- a/modules/interfaces/interfaceDbChatObjects.py
+++ b/modules/interfaces/interfaceDbChatObjects.py
@@ -757,12 +757,12 @@ class ChatObjects:
messageId = message.id
if messageId:
# Delete message stats
- existing_stats = self.db.getRecordset(ChatStat, recordFilter={"messageId": messageId})
+ existing_stats = self.db.getRecordsetWithRBAC(ChatStat, self.currentUser, recordFilter={"messageId": messageId})
for stat in existing_stats:
self.db.recordDelete(ChatStat, stat["id"])
# Delete message documents (but NOT the files!)
- existing_docs = self.db.getRecordset(ChatDocument, recordFilter={"messageId": messageId})
+ existing_docs = self.db.getRecordsetWithRBAC(ChatDocument, self.currentUser, recordFilter={"messageId": messageId})
for doc in existing_docs:
self.db.recordDelete(ChatDocument, doc["id"])
@@ -770,12 +770,12 @@ class ChatObjects:
self.db.recordDelete(ChatMessage, messageId)
# 2. Delete workflow stats
- existing_stats = self.db.getRecordset(ChatStat, recordFilter={"workflowId": workflowId})
+ existing_stats = self.db.getRecordsetWithRBAC(ChatStat, self.currentUser, recordFilter={"workflowId": workflowId})
for stat in existing_stats:
self.db.recordDelete(ChatStat, stat["id"])
# 3. Delete workflow logs
- existing_logs = self.db.getRecordset(ChatLog, recordFilter={"workflowId": workflowId})
+ existing_logs = self.db.getRecordsetWithRBAC(ChatLog, self.currentUser, recordFilter={"workflowId": workflowId})
for log in existing_logs:
self.db.recordDelete(ChatLog, log["id"])
@@ -818,7 +818,7 @@ class ChatObjects:
return PaginatedResult(items=[], totalItems=0, totalPages=0)
# Get messages for this workflow from normalized table
- messages = self.db.getRecordset(ChatMessage, recordFilter={"workflowId": workflowId})
+ messages = self.db.getRecordsetWithRBAC(ChatMessage, self.currentUser, recordFilter={"workflowId": workflowId})
# Convert raw messages to dict format for sorting/filtering
messageDicts = []
@@ -1059,7 +1059,7 @@ class ChatObjects:
raise ValueError("messageId cannot be empty")
# Check if message exists in database
- messages = self.db.getRecordset(ChatMessage, recordFilter={"id": messageId})
+ messages = self.db.getRecordsetWithRBAC(ChatMessage, self.currentUser, recordFilter={"id": messageId})
if not messages:
logger.warning(f"Message with ID {messageId} does not exist in database")
@@ -1164,12 +1164,12 @@ class ChatObjects:
# CASCADE DELETE: Delete all related data first
# 1. Delete message stats
- existing_stats = self.db.getRecordset(ChatStat, recordFilter={"messageId": messageId})
+ existing_stats = self.db.getRecordsetWithRBAC(ChatStat, self.currentUser, recordFilter={"messageId": messageId})
for stat in existing_stats:
self.db.recordDelete(ChatStat, stat["id"])
# 2. Delete message documents (but NOT the files!)
- existing_docs = self.db.getRecordset(ChatDocument, recordFilter={"messageId": messageId})
+ existing_docs = self.db.getRecordsetWithRBAC(ChatDocument, self.currentUser, recordFilter={"messageId": messageId})
for doc in existing_docs:
self.db.recordDelete(ChatDocument, doc["id"])
@@ -1196,7 +1196,7 @@ class ChatObjects:
# Get documents for this message from normalized table
- documents = self.db.getRecordset(ChatDocument, recordFilter={"messageId": messageId})
+ documents = self.db.getRecordsetWithRBAC(ChatDocument, self.currentUser, recordFilter={"messageId": messageId})
if not documents:
logger.warning(f"No documents found for message {messageId}")
@@ -1239,7 +1239,7 @@ class ChatObjects:
def getDocuments(self, messageId: str) -> List[ChatDocument]:
"""Returns documents for a message from normalized table."""
try:
- documents = self.db.getRecordset(ChatDocument, recordFilter={"messageId": messageId})
+ documents = self.db.getRecordsetWithRBAC(ChatDocument, self.currentUser, recordFilter={"messageId": messageId})
return [ChatDocument(**doc) for doc in documents]
except Exception as e:
logger.error(f"Error getting message documents: {str(e)}")
@@ -1288,7 +1288,7 @@ class ChatObjects:
return PaginatedResult(items=[], totalItems=0, totalPages=0)
# Get logs for this workflow from normalized table
- logs = self.db.getRecordset(ChatLog, recordFilter={"workflowId": workflowId})
+ logs = self.db.getRecordsetWithRBAC(ChatLog, self.currentUser, recordFilter={"workflowId": workflowId})
# Convert raw logs to dict format for sorting/filtering
logDicts = []
@@ -1407,7 +1407,7 @@ class ChatObjects:
return []
# Get stats for this workflow from normalized table
- stats = self.db.getRecordset(ChatStat, recordFilter={"workflowId": workflowId})
+ stats = self.db.getRecordsetWithRBAC(ChatStat, self.currentUser, recordFilter={"workflowId": workflowId})
if not stats:
return []
@@ -1457,7 +1457,7 @@ class ChatObjects:
items = []
# Get messages
- messages = self.db.getRecordset(ChatMessage, recordFilter={"workflowId": workflowId})
+ messages = self.db.getRecordsetWithRBAC(ChatMessage, self.currentUser, recordFilter={"workflowId": workflowId})
for msg in messages:
# Apply timestamp filtering in Python
msgTimestamp = parseTimestamp(msg.get("publishedAt"), default=getUtcTimestamp())
@@ -1498,7 +1498,7 @@ class ChatObjects:
})
# Get logs
- logs = self.db.getRecordset(ChatLog, recordFilter={"workflowId": workflowId})
+ logs = self.db.getRecordsetWithRBAC(ChatLog, self.currentUser, recordFilter={"workflowId": workflowId})
for log in logs:
# Apply timestamp filtering in Python
logTimestamp = parseTimestamp(log.get("timestamp"), default=getUtcTimestamp())
diff --git a/modules/interfaces/interfaceDbComponentObjects.py b/modules/interfaces/interfaceDbComponentObjects.py
index 0e1be949..cedc1fec 100644
--- a/modules/interfaces/interfaceDbComponentObjects.py
+++ b/modules/interfaces/interfaceDbComponentObjects.py
@@ -838,10 +838,11 @@ class ComponentObjects:
def _isfileNameUnique(self, fileName: str, excludeFileId: Optional[str] = None) -> bool:
"""Checks if a fileName is unique for the current user."""
- # Get all files for current user
- files = self.db.getRecordset(FileItem, recordFilter={
- "_createdBy": self.currentUser.id
- })
+ # Get all files filtered by RBAC (will be filtered by user's access level)
+ files = self.db.getRecordsetWithRBAC(
+ FileItem,
+ self.currentUser
+ )
# Check if fileName exists (excluding the current file if updating)
for file in files:
@@ -930,16 +931,20 @@ class ComponentObjects:
if not self.checkRbacPermission(FileItem, "update", fileId):
raise PermissionError(f"No permission to delete file {fileId}")
- # Check for other references to this file (by hash)
+ # Check for other references to this file (by hash) - use RBAC to only check files user has access to
fileHash = file.fileHash
if fileHash:
- otherReferences = [f for f in self.db.getRecordset(FileItem, recordFilter={"fileHash": fileHash})
- if f["id"] != fileId]
+ allReferences = self.db.getRecordsetWithRBAC(
+ FileItem,
+ self.currentUser,
+ recordFilter={"fileHash": fileHash}
+ )
+ otherReferences = [f for f in allReferences if f["id"] != fileId]
# Only delete associated fileData if no other references exist
if not otherReferences:
try:
- fileDataEntries = self.db.getRecordset(FileData, recordFilter={"id": fileId})
+ fileDataEntries = self.db.getRecordsetWithRBAC(FileData, self.currentUser, recordFilter={"id": fileId})
if fileDataEntries:
self.db.recordDelete(FileData, fileId)
logger.debug(f"FileData for file {fileId} deleted")
@@ -1024,7 +1029,7 @@ class ComponentObjects:
logger.warning(f"No access to file ID {fileId}")
return None
- fileDataEntries = self.db.getRecordset(FileData, recordFilter={"id": fileId})
+ fileDataEntries = self.db.getRecordsetWithRBAC(FileData, self.currentUser, recordFilter={"id": fileId})
if not fileDataEntries:
logger.warning(f"No data found for file ID {fileId}")
return None
diff --git a/modules/migration/__init__.py b/modules/migration/__init__.py
deleted file mode 100644
index 49056d7c..00000000
--- a/modules/migration/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Migration modules for database schema and data migrations."""
diff --git a/modules/migration/migrateUamToRbac.py b/modules/migration/migrateUamToRbac.py
deleted file mode 100644
index 688bf8e7..00000000
--- a/modules/migration/migrateUamToRbac.py
+++ /dev/null
@@ -1,212 +0,0 @@
-"""
-Migration script to convert UAM (User Access Management) to RBAC (Role-Based Access Control).
-
-This script:
-1. Creates AccessRule table if it doesn't exist
-2. Adds roleLabels column to User table if it doesn't exist
-3. Converts User.privilege to User.roleLabels
-4. Creates initial RBAC rules based on bootstrap logic
-"""
-
-import logging
-from typing import List, Dict, Any
-from modules.connectors.connectorDbPostgre import DatabaseConnector
-from modules.shared.configuration import APP_CONFIG
-from modules.datamodels.datamodelUam import UserInDB, UserPrivilege
-from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext
-from modules.datamodels.datamodelUam import AccessLevel
-from modules.interfaces.interfaceBootstrap import initRbacRules
-
-logger = logging.getLogger(__name__)
-
-
-def migrateUamToRbac(db: DatabaseConnector, dryRun: bool = False) -> Dict[str, Any]:
- """
- Migrate from UAM to RBAC system.
-
- Args:
- db: Database connector instance
- dryRun: If True, only report what would be done without making changes
-
- Returns:
- Dictionary with migration results
- """
- results = {
- "schemaChanges": [],
- "dataMigrations": [],
- "rulesCreated": 0,
- "usersUpdated": 0,
- "errors": []
- }
-
- try:
- # Step 1: Ensure AccessRule table exists
- logger.info("Step 1: Ensuring AccessRule table exists")
- if not dryRun:
- db._ensureTableExists(AccessRule)
- results["schemaChanges"].append("AccessRule table ensured")
- else:
- results["schemaChanges"].append("Would ensure AccessRule table")
-
- # Step 2: Add roleLabels column to UserInDB table if it doesn't exist
- logger.info("Step 2: Adding roleLabels column to UserInDB table")
- if not dryRun:
- try:
- with db.connection.cursor() as cursor:
- # Check if column exists
- cursor.execute("""
- SELECT column_name
- FROM information_schema.columns
- WHERE table_name = 'UserInDB' AND column_name = 'roleLabels'
- """)
- columnExists = cursor.fetchone() is not None
-
- if not columnExists:
- cursor.execute('ALTER TABLE "UserInDB" ADD COLUMN "roleLabels" JSONB DEFAULT \'[]\'::jsonb')
- db.connection.commit()
- results["schemaChanges"].append("Added roleLabels column to UserInDB")
- logger.info("Added roleLabels column to UserInDB table")
- else:
- results["schemaChanges"].append("roleLabels column already exists")
- logger.info("roleLabels column already exists in UserInDB table")
- except Exception as e:
- logger.error(f"Error adding roleLabels column: {e}")
- results["errors"].append(f"Error adding roleLabels column: {e}")
- db.connection.rollback()
- else:
- results["schemaChanges"].append("Would add roleLabels column to UserInDB")
-
- # Step 3: Convert User.privilege to User.roleLabels
- logger.info("Step 3: Converting User.privilege to User.roleLabels")
- if not dryRun:
- try:
- users = db.getRecordset(UserInDB)
- updatedCount = 0
-
- for user in users:
- privilege = user.get("privilege")
- roleLabels = user.get("roleLabels", [])
-
- # Skip if already has roleLabels
- if roleLabels and isinstance(roleLabels, list) and len(roleLabels) > 0:
- logger.debug(f"User {user.get('id')} already has roleLabels: {roleLabels}")
- continue
-
- # Convert privilege to roleLabels
- if privilege == UserPrivilege.SYSADMIN.value:
- newRoleLabels = ["sysadmin"]
- elif privilege == UserPrivilege.ADMIN.value:
- newRoleLabels = ["admin"]
- elif privilege == UserPrivilege.USER.value:
- newRoleLabels = ["user"]
- else:
- # Default to user if privilege is unknown
- newRoleLabels = ["user"]
- logger.warning(f"Unknown privilege '{privilege}' for user {user.get('id')}, defaulting to 'user'")
-
- # Update user
- user["roleLabels"] = newRoleLabels
- db.recordModify(UserInDB, user["id"], user)
- updatedCount += 1
- logger.info(f"Updated user {user.get('id')} ({user.get('username')}): {privilege} -> {newRoleLabels}")
-
- results["usersUpdated"] = updatedCount
- logger.info(f"Updated {updatedCount} users with roleLabels")
- except Exception as e:
- logger.error(f"Error converting user privileges: {e}")
- results["errors"].append(f"Error converting user privileges: {e}")
- else:
- # Dry run: count users that would be updated
- users = db.getRecordset(UserInDB)
- wouldUpdate = 0
- for user in users:
- roleLabels = user.get("roleLabels", [])
- if not roleLabels or not isinstance(roleLabels, list) or len(roleLabels) == 0:
- wouldUpdate += 1
- results["usersUpdated"] = wouldUpdate
- logger.info(f"Would update {wouldUpdate} users with roleLabels")
-
- # Step 4: Create RBAC rules if they don't exist
- logger.info("Step 4: Creating RBAC rules")
- if not dryRun:
- try:
- existingRules = db.getRecordset(AccessRule)
- if existingRules:
- results["rulesCreated"] = len(existingRules)
- results["dataMigrations"].append(f"RBAC rules already exist ({len(existingRules)} rules)")
- logger.info(f"RBAC rules already exist ({len(existingRules)} rules)")
- else:
- # Initialize RBAC rules using bootstrap logic
- initRbacRules(db)
- newRules = db.getRecordset(AccessRule)
- results["rulesCreated"] = len(newRules)
- results["dataMigrations"].append(f"Created {len(newRules)} RBAC rules")
- logger.info(f"Created {len(newRules)} RBAC rules")
- except Exception as e:
- logger.error(f"Error creating RBAC rules: {e}")
- results["errors"].append(f"Error creating RBAC rules: {e}")
- else:
- existingRules = db.getRecordset(AccessRule)
- if existingRules:
- results["rulesCreated"] = len(existingRules)
- results["dataMigrations"].append(f"RBAC rules already exist ({len(existingRules)} rules)")
- else:
- results["dataMigrations"].append("Would create RBAC rules")
-
- logger.info("Migration completed successfully")
- return results
-
- except Exception as e:
- logger.error(f"Migration failed: {e}")
- results["errors"].append(f"Migration failed: {e}")
- return results
-
-
-def validateMigration(db: DatabaseConnector) -> Dict[str, Any]:
- """
- Validate that migration was successful.
-
- Args:
- db: Database connector instance
-
- Returns:
- Dictionary with validation results
- """
- validation = {
- "valid": True,
- "issues": []
- }
-
- try:
- # Check that AccessRule table exists
- try:
- rules = db.getRecordset(AccessRule)
- if not rules:
- validation["valid"] = False
- validation["issues"].append("AccessRule table exists but has no rules")
- except Exception as e:
- validation["valid"] = False
- validation["issues"].append(f"AccessRule table does not exist or is not accessible: {e}")
-
- # Check that all users have roleLabels
- users = db.getRecordset(UserInDB)
- usersWithoutRoles = []
- for user in users:
- roleLabels = user.get("roleLabels", [])
- if not roleLabels or not isinstance(roleLabels, list) or len(roleLabels) == 0:
- usersWithoutRoles.append({
- "id": user.get("id"),
- "username": user.get("username"),
- "privilege": user.get("privilege")
- })
-
- if usersWithoutRoles:
- validation["valid"] = False
- validation["issues"].append(f"{len(usersWithoutRoles)} users without roleLabels: {[u['username'] for u in usersWithoutRoles]}")
-
- return validation
-
- except Exception as e:
- validation["valid"] = False
- validation["issues"].append(f"Validation error: {e}")
- return validation
diff --git a/modules/routes/routeAdminAutomationEvents.py b/modules/routes/routeAdminAutomationEvents.py
index dcac4f27..8eaa0ca7 100644
--- a/modules/routes/routeAdminAutomationEvents.py
+++ b/modules/routes/routeAdminAutomationEvents.py
@@ -11,7 +11,7 @@ import logging
# Import interfaces and models
import modules.interfaces.interfaceDbChatObjects as interfaceDbChatObjects
from modules.security.auth import getCurrentUser, limiter
-from modules.datamodels.datamodelUam import User, UserPrivilege
+from modules.datamodels.datamodelUam import User
# Configure logger
logger = logging.getLogger(__name__)
@@ -30,11 +30,11 @@ router = APIRouter(
)
def requireSysadmin(currentUser: User):
- """Require sysadmin privilege"""
- if currentUser.privilege != UserPrivilege.SYSADMIN:
+ """Require sysadmin role"""
+ if "sysadmin" not in (currentUser.roleLabels or []):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
- detail="Sysadmin privilege required"
+ detail="Sysadmin role required"
)
@router.get("")
diff --git a/modules/routes/routeDataUsers.py b/modules/routes/routeDataUsers.py
index 2f219b5c..017acb17 100644
--- a/modules/routes/routeDataUsers.py
+++ b/modules/routes/routeDataUsers.py
@@ -14,7 +14,7 @@ import modules.interfaces.interfaceDbAppObjects as interfaceDbAppObjects
from modules.security.auth import getCurrentUser, limiter, getCurrentUser
# Import the attribute definition and helper functions
-from modules.datamodels.datamodelUam import User, UserPrivilege
+from modules.datamodels.datamodelUam import User
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata
# Configure logger
@@ -141,7 +141,7 @@ async def create_user(
fullName=user_data.fullName,
language=user_data.language,
enabled=user_data.enabled,
- privilege=user_data.privilege,
+ roleLabels=user_data.roleLabels if user_data.roleLabels else ["user"],
authenticationAuthority=user_data.authenticationAuthority
)
@@ -188,7 +188,7 @@ async def reset_user_password(
"""Reset user password (Admin only)"""
try:
# Check if current user is admin
- if currentUser.privilege != UserPrivilege.ADMIN:
+ if "admin" not in (currentUser.roleLabels or []) and "sysadmin" not in (currentUser.roleLabels or []):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only administrators can reset passwords"
diff --git a/modules/routes/routeSecurityAdmin.py b/modules/routes/routeSecurityAdmin.py
index c0513ac0..4899d03a 100644
--- a/modules/routes/routeSecurityAdmin.py
+++ b/modules/routes/routeSecurityAdmin.py
@@ -25,9 +25,10 @@ router = APIRouter(
)
def _ensure_admin_scope(current_user: User, target_mandate_id: Optional[str] = None) -> None:
- if current_user.privilege not in ("admin", "sysadmin"):
+ roleLabels = current_user.roleLabels or []
+ if "admin" not in roleLabels and "sysadmin" not in roleLabels:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin privileges required")
- if current_user.privilege == "admin":
+ if "admin" in roleLabels and "sysadmin" not in roleLabels:
if target_mandate_id and str(target_mandate_id) != str(current_user.mandateId):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden for target mandate")
@@ -63,7 +64,8 @@ async def list_tokens(
recordFilter["connectionId"] = connectionId
if statusFilter:
recordFilter["status"] = statusFilter
- if currentUser.privilege == "admin":
+ roleLabels = currentUser.roleLabels or []
+ if "admin" in roleLabels and "sysadmin" not in roleLabels:
recordFilter["mandateId"] = str(currentUser.mandateId)
tokens = appInterface.db.getRecordset(Token, recordFilter=recordFilter)
@@ -95,10 +97,11 @@ async def revoke_tokens_by_user(
target_mandate = target_user[0].get("mandateId") if target_user else None
_ensure_admin_scope(currentUser, target_mandate)
+ roleLabels = currentUser.roleLabels or []
count = appInterface.revokeTokensByUser(
userId=userId,
authority=AuthAuthority(authority) if authority else None,
- mandateId=None if currentUser.privilege == "sysadmin" else str(currentUser.mandateId),
+ mandateId=None if "sysadmin" in roleLabels else str(currentUser.mandateId),
revokedBy=currentUser.id,
reason=reason
)
diff --git a/modules/routes/routeSecurityLocal.py b/modules/routes/routeSecurityLocal.py
index 7b08ceed..858cf3c6 100644
--- a/modules/routes/routeSecurityLocal.py
+++ b/modules/routes/routeSecurityLocal.py
@@ -15,7 +15,7 @@ from jose import jwt
from modules.security.auth import getCurrentUser, limiter, SECRET_KEY, ALGORITHM
from modules.security.jwtService import createAccessToken, createRefreshToken, setAccessTokenCookie, setRefreshTokenCookie, clearAccessTokenCookie, clearRefreshTokenCookie
from modules.interfaces.interfaceDbAppObjects import getInterface, getRootInterface
-from modules.datamodels.datamodelUam import User, UserInDB, AuthAuthority, UserPrivilege
+from modules.datamodels.datamodelUam import User, UserInDB, AuthAuthority
from modules.datamodels.datamodelSecurity import Token
# Configure logger
@@ -212,9 +212,8 @@ async def register_user(
appInterface.mandateId = defaultMandateId
# Create user with local authentication
- # Set safe default privilege level for new registrations
+ # Set safe default role for new registrations
# New users are disabled by default and require admin approval
- from modules.datamodels.datamodelUam import UserPrivilege
user = appInterface.createUser(
username=userData.username,
password=password,
@@ -222,7 +221,7 @@ async def register_user(
fullName=userData.fullName,
language=userData.language,
enabled=False, # New users are disabled by default
- privilege=UserPrivilege.USER, # Always set to USER for new registrations
+ roleLabels=["user"], # Default role for new registrations
authenticationAuthority=AuthAuthority.LOCAL
)
diff --git a/modules/routes/routeWorkflows.py b/modules/routes/routeWorkflows.py
index 080e8077..6ab0598a 100644
--- a/modules/routes/routeWorkflows.py
+++ b/modules/routes/routeWorkflows.py
@@ -427,8 +427,12 @@ async def delete_workflow(
# Get service center
interfaceDbChat = getServiceChat(currentUser)
- # Get raw workflow data from database to check permissions
- workflows = interfaceDbChat.db.getRecordset(ChatWorkflow, recordFilter={"id": workflowId})
+ # Check workflow access and permission using RBAC
+ workflows = interfaceDbChat.db.getRecordsetWithRBAC(
+ ChatWorkflow,
+ currentUser,
+ recordFilter={"id": workflowId}
+ )
if not workflows:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
diff --git a/tests/integration/rbac/test_rbac_migration.py b/tests/integration/rbac/test_rbac_migration.py
deleted file mode 100644
index 86f3eb6d..00000000
--- a/tests/integration/rbac/test_rbac_migration.py
+++ /dev/null
@@ -1,282 +0,0 @@
-"""
-Integration tests for UAM to RBAC migration.
-Tests that migration correctly converts user privileges to roleLabels.
-Uses real database connection for integration testing.
-"""
-
-import pytest
-from modules.migration.migrateUamToRbac import migrateUamToRbac, validateMigration
-from modules.datamodels.datamodelUam import UserInDB, UserPrivilege
-from modules.connectors.connectorDbPostgre import DatabaseConnector
-from modules.shared.configuration import APP_CONFIG
-
-
-@pytest.fixture(scope="class")
-def db():
- """Create real database connector for integration tests."""
- dbHost = APP_CONFIG.get("DB_HOST", "localhost")
- dbDatabase = APP_CONFIG.get("DB_DATABASE", "poweron_test")
- dbUser = APP_CONFIG.get("DB_USER", "postgres")
- dbPassword = APP_CONFIG.get("DB_PASSWORD", "")
- dbPort = APP_CONFIG.get("DB_PORT", 5432)
-
- db = DatabaseConnector(
- dbHost=dbHost,
- dbDatabase=dbDatabase,
- dbUser=dbUser,
- dbPassword=dbPassword,
- dbPort=dbPort
- )
- yield db
- db.close()
-
-
-class TestRbacMigration:
- """Test RBAC migration from UAM."""
-
- def testMigrateUserPrivilegeToRoleLabels(self, db):
- """Test that user privileges are correctly converted to roleLabels."""
- # Create test users with privileges but no roleLabels
- testUsers = [
- UserInDB(
- id="migrate_test_user1",
- username="migrate_admin",
- privilege=UserPrivilege.SYSADMIN.value
- ),
- UserInDB(
- id="migrate_test_user2",
- username="migrate_admin2",
- privilege=UserPrivilege.ADMIN.value
- ),
- UserInDB(
- id="migrate_test_user3",
- username="migrate_user1",
- privilege=UserPrivilege.USER.value
- )
- ]
-
- try:
- # Create test users in database
- for user in testUsers:
- userData = user.model_dump()
- # Ensure roleLabels is None/empty for migration test
- userData["roleLabels"] = []
- userData["id"] = user.id
- db.recordCreate(UserInDB, userData)
-
- # Run migration
- results = migrateUamToRbac(db, dryRun=False)
-
- # Check that users were updated
- assert results["usersUpdated"] == 3
-
- # Verify users were actually updated in database
- users1 = db.getRecordset(UserInDB, recordFilter={"id": "migrate_test_user1"})
- users2 = db.getRecordset(UserInDB, recordFilter={"id": "migrate_test_user2"})
- users3 = db.getRecordset(UserInDB, recordFilter={"id": "migrate_test_user3"})
- user1 = users1[0] if users1 else None
- user2 = users2[0] if users2 else None
- user3 = users3[0] if users3 else None
-
- assert user1 is not None
- assert "sysadmin" in user1.get("roleLabels", [])
-
- assert user2 is not None
- assert "admin" in user2.get("roleLabels", [])
-
- assert user3 is not None
- assert "user" in user3.get("roleLabels", [])
- finally:
- # Cleanup test users
- for user in testUsers:
- try:
- db.recordDelete(UserInDB, user.id)
- except:
- pass
-
- def testMigrationSkipsUsersWithExistingRoleLabels(self, db):
- """Test that migration skips users who already have roleLabels."""
- # Create test users: one with roleLabels, one without
- user1 = UserInDB(
- id="skip_test_user1",
- username="skip_admin",
- privilege=UserPrivilege.SYSADMIN.value,
- roleLabels=["sysadmin"] # Already migrated
- )
- user2 = UserInDB(
- id="skip_test_user2",
- username="skip_user1",
- privilege=UserPrivilege.USER.value,
- roleLabels=[] # Needs migration
- )
-
- try:
- # Create test users in database
- user1Data = user1.model_dump()
- user1Data["id"] = user1.id
- user2Data = user2.model_dump()
- user2Data["id"] = user2.id
- db.recordCreate(UserInDB, user1Data)
- db.recordCreate(UserInDB, user2Data)
-
- # Run migration
- results = migrateUamToRbac(db, dryRun=False)
-
- # Only one user should be updated (user2)
- assert results["usersUpdated"] == 1
-
- # Verify user1 still has original roleLabels
- users1 = db.getRecordset(UserInDB, recordFilter={"id": "skip_test_user1"})
- updatedUser1 = users1[0] if users1 else None
- assert updatedUser1 is not None
- assert "sysadmin" in updatedUser1.get("roleLabels", [])
-
- # Verify user2 was updated
- users2 = db.getRecordset(UserInDB, recordFilter={"id": "skip_test_user2"})
- updatedUser2 = users2[0] if users2 else None
- assert updatedUser2 is not None
- assert "user" in updatedUser2.get("roleLabels", [])
- finally:
- # Cleanup test users
- try:
- db.recordDelete(UserInDB, "skip_test_user1")
- db.recordDelete(UserInDB, "skip_test_user2")
- except:
- pass
-
- def testDryRunMode(self, db):
- """Test that dry run mode doesn't make changes."""
- # Create test user without roleLabels
- testUser = UserInDB(
- id="dryrun_test_user1",
- username="dryrun_admin",
- privilege=UserPrivilege.SYSADMIN.value,
- roleLabels=[] # Needs migration
- )
-
- try:
- # Create test user in database
- userData = testUser.model_dump()
- userData["id"] = testUser.id
- db.recordCreate(UserInDB, userData)
-
- # Get original state
- originalUsers = db.getRecordset(UserInDB, recordFilter={"id": "dryrun_test_user1"})
- originalUser = originalUsers[0] if originalUsers else None
- assert originalUser is not None
- originalRoleLabels = originalUser.get("roleLabels", [])
-
- # Run migration in dry run mode
- results = migrateUamToRbac(db, dryRun=True)
-
- # Should report what would be done
- assert results["usersUpdated"] == 1
-
- # Verify user was NOT actually updated
- unchangedUsers = db.getRecordset(UserInDB, recordFilter={"id": "dryrun_test_user1"})
- unchangedUser = unchangedUsers[0] if unchangedUsers else None
- assert unchangedUser is not None
- assert unchangedUser.get("roleLabels", []) == originalRoleLabels
- finally:
- # Cleanup test user
- try:
- db.recordDelete(UserInDB, "dryrun_test_user1")
- except:
- pass
-
- def testValidateMigrationSuccess(self, db):
- """Test validation passes when migration is successful."""
- # Create test users with roleLabels (already migrated)
- testUsers = [
- UserInDB(
- id="validate_test_user1",
- username="validate_admin",
- privilege=UserPrivilege.SYSADMIN.value,
- roleLabels=["sysadmin"]
- ),
- UserInDB(
- id="validate_test_user2",
- username="validate_admin2",
- privilege=UserPrivilege.ADMIN.value,
- roleLabels=["admin"]
- )
- ]
-
- try:
- # Create test users in database
- for user in testUsers:
- userData = user.model_dump()
- userData["id"] = user.id
- db.recordCreate(UserInDB, userData)
-
- # Ensure AccessRule table exists (migration should have created it)
- from modules.datamodels.datamodelRbac import AccessRule
- db._ensureTableExists(AccessRule)
-
- # Run validation
- validation = validateMigration(db)
-
- assert validation["valid"] == True
- assert len(validation["issues"]) == 0
- finally:
- # Cleanup test users
- for user in testUsers:
- try:
- db.recordDelete(UserInDB, user.id)
- except:
- pass
-
- def testValidateMigrationFailsWithoutRoleLabels(self, db):
- """Test validation fails when users don't have roleLabels."""
- # Create test users: one with roleLabels, one without, one with empty roleLabels
- testUsers = [
- UserInDB(
- id="validate_fail_user1",
- username="validate_fail_admin",
- privilege=UserPrivilege.SYSADMIN.value,
- roleLabels=["sysadmin"] # Has roleLabels
- ),
- UserInDB(
- id="validate_fail_user2",
- username="validate_fail_user",
- privilege=UserPrivilege.USER.value,
- roleLabels=[] # Empty roleLabels
- ),
- UserInDB(
- id="validate_fail_user3",
- username="validate_fail_user2",
- privilege=UserPrivilege.USER.value
- # Missing roleLabels field (will be None)
- )
- ]
-
- try:
- # Create test users in database
- for user in testUsers:
- userData = user.model_dump()
- userData["id"] = user.id
- # For user3, explicitly set roleLabels to None or remove it
- if user.id == "validate_fail_user3":
- if "roleLabels" in userData:
- del userData["roleLabels"]
- db.recordCreate(UserInDB, userData)
-
- # Ensure AccessRule table exists
- from modules.datamodels.datamodelRbac import AccessRule
- db._ensureTableExists(AccessRule)
-
- # Run validation
- validation = validateMigration(db)
-
- assert validation["valid"] == False
- assert len(validation["issues"]) > 0
- # Check that validation found users without roleLabels
- issuesStr = " ".join(validation["issues"])
- assert "users without roleLabels" in issuesStr or "without roleLabels" in issuesStr
- finally:
- # Cleanup test users
- for user in testUsers:
- try:
- db.recordDelete(UserInDB, user.id)
- except:
- pass
diff --git a/tests/unit/rbac/test_rbac_bootstrap.py b/tests/unit/rbac/test_rbac_bootstrap.py
index 573a4fd1..e12592a1 100644
--- a/tests/unit/rbac/test_rbac_bootstrap.py
+++ b/tests/unit/rbac/test_rbac_bootstrap.py
@@ -14,7 +14,7 @@ from modules.interfaces.interfaceBootstrap import (
createDefaultRoleRules,
createTableSpecificRules
)
-from modules.datamodels.datamodelUam import UserInDB, Mandate, UserPrivilege, AuthAuthority
+from modules.datamodels.datamodelUam import UserInDB, Mandate, AuthAuthority
from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext
from modules.datamodels.datamodelUam import AccessLevel
@@ -62,7 +62,6 @@ class TestRbacBootstrap:
assert isinstance(user, UserInDB)
assert user.username == "admin"
assert "sysadmin" in user.roleLabels
- assert user.privilege == UserPrivilege.SYSADMIN
def testInitEventUserCreatesWithSysadminRole(self):
"""Test that initEventUser creates user with sysadmin role."""
From d009f93dba798b857ff63b728085d6bb017f8b5c Mon Sep 17 00:00:00 2001
From: ValueOn AG
Date: Sun, 7 Dec 2025 23:51:05 +0100
Subject: [PATCH 4/6] rbac roles and rules integration tests passed
---
app.py | 3 +
docs/frontend_options_usage.md | 229 ++++++
docs/rbac_admin_roles_and_options_api.md | 372 +++++++++
docs/rbac_getrecordset_review.md | 135 ----
modules/aicore/aicoreModelRegistry.py | 80 +-
modules/connectors/connectorDbPostgre.py | 21 +-
modules/datamodels/datamodelRbac.py | 37 +-
modules/datamodels/datamodelUam.py | 19 +-
modules/features/options/mainOptions.py | 127 ++++
modules/interfaces/interfaceBootstrap.py | 455 ++++++++++-
modules/interfaces/interfaceDbAppObjects.py | 168 +++-
modules/interfaces/interfaceDbChatObjects.py | 5 +-
.../interfaces/interfaceDbComponentObjects.py | 5 +-
modules/routes/routeAdminRbacRoles.py | 716 ++++++++++++++++++
modules/routes/routeAttributes.py | 26 +-
modules/routes/routeOptions.py | 81 ++
modules/routes/routeRbac.py | 626 ++++++++++++++-
modules/security/rbac.py | 32 +-
modules/shared/attributeUtils.py | 40 +-
modules/shared/frontendOptionsTypes.py | 136 ++++
pytest.ini | 9 +
tests/functional/test_kpi_fix.py | 86 ---
tests/functional/test_kpi_full.py | 4 +-
tests/functional/test_kpi_incomplete.py | 9 +-
tests/functional/test_repair_debug.py | 58 --
tests/integration/options/test_options_api.py | 241 ++++++
.../options/test_frontend_options_types.py | 115 +++
tests/unit/options/test_main_options.py | 181 +++++
tests/unit/rbac/test_rbac_bootstrap.py | 18 +-
tests/unit/rbac/test_rbac_permissions.py | 27 +-
tests/unit/services/test_ai_service.py | 146 ----
31 files changed, 3681 insertions(+), 526 deletions(-)
create mode 100644 docs/frontend_options_usage.md
create mode 100644 docs/rbac_admin_roles_and_options_api.md
delete mode 100644 docs/rbac_getrecordset_review.md
create mode 100644 modules/features/options/mainOptions.py
create mode 100644 modules/routes/routeAdminRbacRoles.py
create mode 100644 modules/routes/routeOptions.py
create mode 100644 modules/shared/frontendOptionsTypes.py
delete mode 100644 tests/functional/test_kpi_fix.py
delete mode 100644 tests/functional/test_repair_debug.py
create mode 100644 tests/integration/options/test_options_api.py
create mode 100644 tests/unit/options/test_frontend_options_types.py
create mode 100644 tests/unit/options/test_main_options.py
delete mode 100644 tests/unit/services/test_ai_service.py
diff --git a/app.py b/app.py
index 23a8cb5c..61ec677c 100644
--- a/app.py
+++ b/app.py
@@ -440,3 +440,6 @@ app.include_router(adminAutomationEventsRouter)
from modules.routes.routeRbac import router as rbacRouter
app.include_router(rbacRouter)
+from modules.routes.routeOptions import router as optionsRouter
+app.include_router(optionsRouter)
+
diff --git a/docs/frontend_options_usage.md b/docs/frontend_options_usage.md
new file mode 100644
index 00000000..60489118
--- /dev/null
+++ b/docs/frontend_options_usage.md
@@ -0,0 +1,229 @@
+# Frontend Options Usage Guide
+
+## Overview
+
+The `frontend_options` attribute in Pydantic `Field` definitions supports **two formats** for providing options to frontend select/multiselect fields:
+
+1. **Static List**: Predefined list of options
+2. **String Reference**: Dynamic options fetched from the Options API
+
+## Type System
+
+The type system is defined in `gateway/modules/shared/frontendOptionsTypes.py`:
+
+```python
+from modules.shared.frontendOptionsTypes import FrontendOptions, OptionItem
+
+# FrontendOptions is Union[List[OptionItem], str]
+# OptionItem is Dict[str, Any] with "value" and "label" keys
+```
+
+## Format 1: Static List
+
+Use static lists for fixed, predefined options that don't change based on user context.
+
+### Example
+
+```python
+from pydantic import Field
+from typing import List
+
+language: str = Field(
+ default="en",
+ description="Preferred language",
+ json_schema_extra={
+ "frontend_type": "select",
+ "frontend_readonly": False,
+ "frontend_required": True,
+ "frontend_options": [
+ {"value": "en", "label": {"en": "English", "fr": "Anglais"}},
+ {"value": "fr", "label": {"en": "Français", "fr": "Français"}},
+ {"value": "de", "label": {"en": "Deutsch", "fr": "Allemand"}},
+ ]
+ }
+)
+```
+
+### When to Use Static Lists
+
+- Options are fixed constants (e.g., enum values)
+- Options don't require database queries
+- Options are the same for all users
+- Options are simple and don't change frequently
+
+## Format 2: String Reference
+
+Use string references for dynamic options that come from the database or are context-aware.
+
+### Example
+
+```python
+from pydantic import Field
+from typing import List
+
+roleLabels: List[str] = Field(
+ default_factory=list,
+ description="List of role labels",
+ json_schema_extra={
+ "frontend_type": "multiselect",
+ "frontend_readonly": False,
+ "frontend_required": True,
+ "frontend_options": "user.role" # String reference
+ }
+)
+```
+
+### When to Use String References
+
+- Options come from the database (e.g., user connections)
+- Options are context-aware (filtered by current user's permissions)
+- Options need centralized management
+- Options may change frequently
+- Options depend on user context or permissions
+
+### Frontend Integration
+
+When the frontend encounters a string reference:
+
+1. **Detect**: Check if `frontend_options` is a string (not a list)
+2. **Fetch**: Call `GET /api/options/{optionsName}` (e.g., `/api/options/user.role`)
+3. **Use**: Use the returned options for the select/multiselect field
+
+**Example Frontend Code**:
+```typescript
+// Pseudocode
+if (typeof field.frontend_options === 'string') {
+ // Dynamic options - fetch from API
+ const options = await fetch(`/api/options/${field.frontend_options}`);
+ return options;
+} else {
+ // Static options - use directly
+ return field.frontend_options;
+}
+```
+
+## Available Option Names
+
+| Option Name | Description | Context-Aware |
+|-------------|-------------|---------------|
+| `user.role` | Standard role options (sysadmin, admin, user, viewer) | No |
+| `auth.authority` | Authentication authority options (local, google, msft) | No |
+| `connection.status` | Connection status options (active, inactive, expired, error) | No |
+| `user.connection` | User's connections (fetched from database) | Yes (requires currentUser) |
+
+## Utility Functions
+
+The `frontendOptionsTypes` module provides utility functions:
+
+```python
+from modules.shared.frontendOptionsTypes import (
+ isStringReference,
+ isStaticList,
+ validateFrontendOptions,
+ getOptionsName,
+ getStaticOptions
+)
+
+# Check format
+if isStringReference(frontend_options):
+ optionsName = getOptionsName(frontend_options)
+ # Fetch from API: /api/options/{optionsName}
+elif isStaticList(frontend_options):
+ options = getStaticOptions(frontend_options)
+ # Use directly
+
+# Validate format
+if not validateFrontendOptions(frontend_options):
+ raise ValueError("Invalid frontend_options format")
+```
+
+## Validation
+
+The `validateFrontendOptions()` function ensures:
+
+1. **String References**: Non-empty string
+2. **Static Lists**:
+ - List of dictionaries
+ - Each dictionary has `"value"` and `"label"` keys
+ - `"label"` is a dictionary (multilingual labels)
+
+## Examples in Codebase
+
+### Static List Example
+```python
+# datamodelUam.py - Language field
+language: str = Field(
+ default="en",
+ json_schema_extra={
+ "frontend_options": [
+ {"value": "en", "label": {"en": "English", "fr": "Anglais"}},
+ {"value": "fr", "label": {"en": "Français", "fr": "Français"}},
+ ]
+ }
+)
+```
+
+### String Reference Example
+```python
+# datamodelUam.py - Role labels field
+roleLabels: List[str] = Field(
+ default_factory=list,
+ json_schema_extra={
+ "frontend_options": "user.role" # Dynamic - fetched from API
+ }
+)
+```
+
+### Mixed Example
+```python
+# datamodelRbac.py - AccessRule model
+roleLabel: str = Field(
+ json_schema_extra={
+ "frontend_options": "user.role" # String reference
+ }
+)
+
+context: AccessRuleContext = Field(
+ json_schema_extra={
+ "frontend_options": [ # Static list
+ {"value": "DATA", "label": {"en": "Data", "fr": "Données"}},
+ {"value": "UI", "label": {"en": "UI", "fr": "Interface"}},
+ {"value": "RESOURCE", "label": {"en": "Resource", "fr": "Ressource"}}
+ ]
+ }
+)
+```
+
+## Best Practices
+
+1. **Use Static Lists** for:
+ - Enum values
+ - Fixed constants
+ - Simple options that don't change
+
+2. **Use String References** for:
+ - Database-driven options
+ - Context-aware options
+ - Options that need centralized management
+
+3. **Always validate** frontend_options format when processing
+
+4. **Document** which format is used and why in field descriptions
+
+5. **Frontend**: Always check the type before using options
+
+## Migration Guide
+
+If you have existing static lists that should become dynamic:
+
+1. **Create Options Provider**: Add option logic to `gateway/modules/features/options/mainOptions.py`
+2. **Register Option Name**: Add to `getAvailableOptionsNames()` function
+3. **Update Field**: Change `frontend_options` from list to string reference
+4. **Update Frontend**: Ensure frontend handles string references correctly
+
+## See Also
+
+- `gateway/modules/shared/frontendOptionsTypes.py` - Type definitions and utilities
+- `gateway/modules/features/options/mainOptions.py` - Options API implementation
+- `gateway/modules/routes/routeOptions.py` - Options API endpoints
+- `wiki/appdoc/doc_security_role_based_access.md` - RBAC documentation with frontend_options examples
diff --git a/docs/rbac_admin_roles_and_options_api.md b/docs/rbac_admin_roles_and_options_api.md
new file mode 100644
index 00000000..9265961d
--- /dev/null
+++ b/docs/rbac_admin_roles_and_options_api.md
@@ -0,0 +1,372 @@
+# RBAC Admin Roles Management & Options API
+
+## Overview
+
+This document describes two new features added to support RBAC management:
+
+1. **Options API**: Dynamic options endpoint for frontend select/multiselect fields
+2. **Admin RBAC Roles Module**: Comprehensive role and role assignment management
+
+---
+
+## 1. Options API
+
+### Purpose
+
+The Options API provides dynamic options for frontend form fields that use `frontend_options` as a string reference (e.g., `"user.role"`). This allows the frontend to fetch options from the backend, enabling:
+- Database-driven options (e.g., user connections)
+- Context-aware options (filtered by current user's permissions)
+- Centralized option management
+
+### Frontend Options Format
+
+The `frontend_options` attribute in Pydantic `Field` definitions supports **two formats**:
+
+#### 1. Static List (for basic data types)
+```python
+frontend_options=[
+ {"value": "a", "label": {"en": "All Records", "fr": "Tous les enregistrements"}},
+ {"value": "m", "label": {"en": "My Records", "fr": "Mes enregistrements"}}
+]
+```
+
+#### 2. String Reference (for dynamic/custom types)
+```python
+frontend_options="user.role" # Frontend fetches from /api/options/user.role
+```
+
+### API Endpoints
+
+#### Get Options
+```
+GET /api/options/{optionsName}
+```
+
+**Path Parameters:**
+- `optionsName`: Name of the options set (e.g., "user.role", "user.connection")
+
+**Response:**
+```json
+[
+ {
+ "value": "sysadmin",
+ "label": {
+ "en": "System Administrator",
+ "fr": "Administrateur système"
+ }
+ },
+ {
+ "value": "admin",
+ "label": {
+ "en": "Administrator",
+ "fr": "Administrateur"
+ }
+ }
+]
+```
+
+**Examples:**
+- `GET /api/options/user.role` - Get available role options
+- `GET /api/options/user.connection` - Get user's connections (context-aware)
+- `GET /api/options/auth.authority` - Get authentication authority options
+- `GET /api/options/connection.status` - Get connection status options
+
+#### List Available Options
+```
+GET /api/options/
+```
+
+**Response:**
+```json
+[
+ "user.role",
+ "auth.authority",
+ "connection.status",
+ "user.connection"
+]
+```
+
+### Available Options
+
+| Options Name | Description | Context-Aware |
+|-------------|------------|---------------|
+| `user.role` | Standard role options (sysadmin, admin, user, viewer) | No |
+| `auth.authority` | Authentication authority options (local, google, msft) | No |
+| `connection.status` | Connection status options (active, inactive, expired, error) | No |
+| `user.connection` | User's connections (fetched from database) | Yes (requires currentUser) |
+
+### Implementation
+
+**Files:**
+- `gateway/modules/features/options/mainOptions.py` - Options logic
+- `gateway/modules/routes/routeOptions.py` - Options API endpoints
+
+**Usage in Pydantic Models:**
+```python
+roleLabels: List[str] = Field(
+ default_factory=list,
+ description="List of role labels",
+ json_schema_extra={
+ "frontend_type": "multiselect",
+ "frontend_readonly": False,
+ "frontend_required": True,
+ "frontend_options": "user.role" # String reference
+ }
+)
+```
+
+---
+
+## 2. Admin RBAC Roles Module
+
+### Purpose
+
+The Admin RBAC Roles module provides comprehensive management of roles and role assignments to users. This module allows administrators to:
+- View all available roles with metadata
+- List users with their role assignments
+- Assign/remove roles to/from users
+- Filter users by role or mandate
+- View role statistics (user counts per role)
+
+### Access Control
+
+**Required Permissions:**
+- User must have `admin` or `sysadmin` role
+- RBAC permission check for `UserInDB` table update operations
+
+### API Endpoints
+
+#### List All Roles
+```
+GET /api/admin/rbac/roles/
+```
+
+**Response:**
+```json
+[
+ {
+ "roleLabel": "sysadmin",
+ "description": {
+ "en": "System Administrator - Full access to all system resources",
+ "fr": "Administrateur système - Accès complet à toutes les ressources"
+ },
+ "userCount": 2,
+ "isSystemRole": true
+ },
+ {
+ "roleLabel": "admin",
+ "description": {
+ "en": "Administrator - Manage users and resources within mandate scope",
+ "fr": "Administrateur - Gérer les utilisateurs et ressources dans le périmètre du mandat"
+ },
+ "userCount": 5,
+ "isSystemRole": true
+ }
+]
+```
+
+#### List Users with Roles
+```
+GET /api/admin/rbac/roles/users?roleLabel=admin&mandateId=mandate-123
+```
+
+**Query Parameters:**
+- `roleLabel` (optional): Filter by role label
+- `mandateId` (optional): Filter by mandate ID
+
+**Response:**
+```json
+[
+ {
+ "id": "user-123",
+ "username": "john.doe",
+ "email": "john@example.com",
+ "fullName": "John Doe",
+ "mandateId": "mandate-123",
+ "enabled": true,
+ "roleLabels": ["admin", "user"],
+ "roleCount": 2
+ }
+]
+```
+
+#### Get User Roles
+```
+GET /api/admin/rbac/roles/users/{userId}
+```
+
+**Response:**
+```json
+{
+ "id": "user-123",
+ "username": "john.doe",
+ "email": "john@example.com",
+ "fullName": "John Doe",
+ "mandateId": "mandate-123",
+ "enabled": true,
+ "roleLabels": ["admin", "user"],
+ "roleCount": 2
+}
+```
+
+#### Update User Roles
+```
+PUT /api/admin/rbac/roles/users/{userId}/roles
+```
+
+**Request Body:**
+```json
+{
+ "roleLabels": ["admin", "user"]
+}
+```
+
+**Response:**
+Updated user object with new role assignments
+
+#### Add Role to User
+```
+POST /api/admin/rbac/roles/users/{userId}/roles/{roleLabel}
+```
+
+**Response:**
+Updated user object with role added (if not already present)
+
+#### Remove Role from User
+```
+DELETE /api/admin/rbac/roles/users/{userId}/roles/{roleLabel}
+```
+
+**Response:**
+Updated user object with role removed
+
+**Note:** If all roles are removed, user defaults to `"user"` role
+
+#### Get Users with Specific Role
+```
+GET /api/admin/rbac/roles/roles/{roleLabel}/users?mandateId=mandate-123
+```
+
+**Query Parameters:**
+- `mandateId` (optional): Filter by mandate ID
+
+**Response:**
+List of users with the specified role
+
+### Standard Roles
+
+| Role Label | Description | System Role |
+|-----------|-------------|-------------|
+| `sysadmin` | System Administrator - Full access to all system resources | Yes |
+| `admin` | Administrator - Manage users and resources within mandate scope | Yes |
+| `user` | User - Standard user with access to own records | Yes |
+| `viewer` | Viewer - Read-only access to group records | Yes |
+
+**Custom Roles:** The system also supports custom role labels. These are detected when users are assigned non-standard roles and are marked with `isSystemRole: false`.
+
+### Implementation
+
+**Files:**
+- `gateway/modules/routes/routeAdminRbacRoles.py` - Admin RBAC Roles API endpoints
+
+**Dependencies:**
+- `gateway/modules/interfaces/interfaceDbAppObjects.py` - User management interface
+- `gateway/modules/security/auth.py` - Authentication and authorization
+
+### Usage Examples
+
+#### Assign Multiple Roles to User
+```bash
+curl -X PUT "http://localhost:8000/api/admin/rbac/roles/users/user-123/roles" \
+ -H "Authorization: Bearer " \
+ -H "Content-Type: application/json" \
+ -d '{"roleLabels": ["admin", "user"]}'
+```
+
+#### Add Single Role
+```bash
+curl -X POST "http://localhost:8000/api/admin/rbac/roles/users/user-123/roles/admin" \
+ -H "Authorization: Bearer "
+```
+
+#### Remove Role
+```bash
+curl -X DELETE "http://localhost:8000/api/admin/rbac/roles/users/user-123/roles/viewer" \
+ -H "Authorization: Bearer "
+```
+
+#### List All Admins
+```bash
+curl "http://localhost:8000/api/admin/rbac/roles/roles/admin/users" \
+ -H "Authorization: Bearer "
+```
+
+---
+
+## Integration
+
+### Route Registration
+
+Both modules are registered in `gateway/app.py`:
+
+```python
+from modules.routes.routeOptions import router as optionsRouter
+app.include_router(optionsRouter)
+
+from modules.routes.routeAdminRbacRoles import router as adminRbacRolesRouter
+app.include_router(adminRbacRolesRouter)
+```
+
+### Frontend Integration
+
+#### Using Dynamic Options
+
+When a Pydantic model field uses `frontend_options` as a string reference:
+
+```python
+roleLabels: List[str] = Field(
+ frontend_options="user.role"
+)
+```
+
+The frontend should:
+1. Detect the string reference (not a list)
+2. Fetch options from `/api/options/user.role`
+3. Use the returned options for the select/multiselect field
+
+#### Using Admin RBAC Roles Module
+
+The frontend can use the Admin RBAC Roles endpoints to:
+- Display role management UI
+- Show role assignments in user management
+- Provide role assignment controls
+- Display role statistics
+
+---
+
+## Security Considerations
+
+1. **Options API**:
+ - Requires authentication (currentUser dependency)
+ - Context-aware options (e.g., `user.connection`) are filtered by current user
+ - Rate limited: 120 requests/minute
+
+2. **Admin RBAC Roles Module**:
+ - Requires `admin` or `sysadmin` role
+ - All endpoints are rate limited: 30-60 requests/minute
+ - RBAC permission checks ensure users can only manage roles if they have permission
+
+---
+
+## Future Enhancements
+
+1. **Options API**:
+ - Add more option types (e.g., mandate options, workflow options)
+ - Support for filtered options based on RBAC permissions
+ - Caching for frequently accessed options
+
+2. **Admin RBAC Roles Module**:
+ - Role metadata management (descriptions, permissions summary)
+ - Bulk role assignment operations
+ - Role usage analytics
+ - Role templates/presets
diff --git a/docs/rbac_getrecordset_review.md b/docs/rbac_getrecordset_review.md
deleted file mode 100644
index d2c06524..00000000
--- a/docs/rbac_getrecordset_review.md
+++ /dev/null
@@ -1,135 +0,0 @@
-# RBAC getRecordset() Review
-
-## Overview
-Review of all `getRecordset()` calls in `interfaceDbChatObjects.py` and `interfaceDbComponentObjects.py` to determine which should be converted to `getRecordsetWithRBAC()`.
-
-## Analysis Criteria
-- **Convert to RBAC**: User-facing data that should respect access control
-- **Keep as-is**: Internal/technical operations that don't need RBAC filtering
-
----
-
-## interfaceDbChatObjects.py
-
-### Summary: **14 calls found - ALL should be converted to `getRecordsetWithRBAC()`**
-
-All calls access user-facing data (ChatMessage, ChatDocument, ChatStat, ChatLog) and should respect RBAC even when:
-- Used in cascade delete operations (after parent access is verified)
-- Used to fetch child records (after parent access is verified)
-- Used for existence checks
-
-**Rationale**: RBAC should be applied at every data access point to ensure consistent security and prevent potential bypass scenarios.
-
-### Detailed List:
-
-1. **Line 760** - `deleteWorkflow()` - Cascade delete ChatStat
- - **Action**: Convert to `getRecordsetWithRBAC(ChatStat, self.currentUser, recordFilter={"messageId": messageId})`
- - **Reason**: Deleting related data should respect RBAC
-
-2. **Line 765** - `deleteWorkflow()` - Cascade delete ChatDocument
- - **Action**: Convert to `getRecordsetWithRBAC(ChatDocument, self.currentUser, recordFilter={"messageId": messageId})`
- - **Reason**: Deleting related data should respect RBAC
-
-3. **Line 773** - `deleteWorkflow()` - Cascade delete ChatStat (workflow level)
- - **Action**: Convert to `getRecordsetWithRBAC(ChatStat, self.currentUser, recordFilter={"workflowId": workflowId})`
- - **Reason**: Deleting related data should respect RBAC
-
-4. **Line 778** - `deleteWorkflow()` - Cascade delete ChatLog
- - **Action**: Convert to `getRecordsetWithRBAC(ChatLog, self.currentUser, recordFilter={"workflowId": workflowId})`
- - **Reason**: Deleting related data should respect RBAC
-
-5. **Line 821** - `getMessages()` - Fetch messages for workflow
- - **Action**: Convert to `getRecordsetWithRBAC(ChatMessage, self.currentUser, recordFilter={"workflowId": workflowId})`
- - **Reason**: Child records should still respect RBAC even if parent access is verified
-
-6. **Line 1062** - `updateMessage()` - Check if message exists
- - **Action**: Convert to `getRecordsetWithRBAC(ChatMessage, self.currentUser, recordFilter={"id": messageId})`
- - **Reason**: Existence checks should respect RBAC
-
-7. **Line 1167** - `deleteMessage()` - Cascade delete ChatStat
- - **Action**: Convert to `getRecordsetWithRBAC(ChatStat, self.currentUser, recordFilter={"messageId": messageId})`
- - **Reason**: Deleting related data should respect RBAC
-
-8. **Line 1172** - `deleteMessage()` - Cascade delete ChatDocument
- - **Action**: Convert to `getRecordsetWithRBAC(ChatDocument, self.currentUser, recordFilter={"messageId": messageId})`
- - **Reason**: Deleting related data should respect RBAC
-
-9. **Line 1199** - `deleteFileFromMessage()` - Get documents for message
- - **Action**: Convert to `getRecordsetWithRBAC(ChatDocument, self.currentUser, recordFilter={"messageId": messageId})`
- - **Reason**: Accessing related data should respect RBAC
-
-10. **Line 1242** - `getDocuments()` - Get documents for message
- - **Action**: Convert to `getRecordsetWithRBAC(ChatDocument, self.currentUser, recordFilter={"messageId": messageId})`
- - **Reason**: Public method accessing user data should respect RBAC
-
-11. **Line 1291** - `getLogs()` - Fetch logs for workflow
- - **Action**: Convert to `getRecordsetWithRBAC(ChatLog, self.currentUser, recordFilter={"workflowId": workflowId})`
- - **Reason**: Child records should still respect RBAC even if parent access is verified
-
-12. **Line 1410** - `getStats()` - Fetch stats for workflow
- - **Action**: Convert to `getRecordsetWithRBAC(ChatStat, self.currentUser, recordFilter={"workflowId": workflowId})`
- - **Reason**: Child records should still respect RBAC even if parent access is verified
-
-13. **Line 1460** - `getUnifiedChatData()` - Fetch messages for workflow
- - **Action**: Convert to `getRecordsetWithRBAC(ChatMessage, self.currentUser, recordFilter={"workflowId": workflowId})`
- - **Reason**: Child records should still respect RBAC even if parent access is verified
-
-14. **Line 1501** - `getUnifiedChatData()` - Fetch logs for workflow
- - **Action**: Convert to `getRecordsetWithRBAC(ChatLog, self.currentUser, recordFilter={"workflowId": workflowId})`
- - **Reason**: Child records should still respect RBAC even if parent access is verified
-
----
-
-## interfaceDbComponentObjects.py
-
-### Summary: **3 calls found - 1 keep as-is, 2 should be converted**
-
-### Detailed List:
-
-1. **Line 149** - `_initializeStandardPrompts()` - Check if prompts exist
- - **Action**: **KEEP AS-IS** ✅
- - **Reason**: This is initialization code that runs during bootstrap. It checks if any prompts exist to avoid re-initialization. Since this runs with root user context and is a system-level check, RBAC is not needed here.
-
-2. **Line 947** - `deleteFile()` - Get FileData for deletion
- - **Action**: **CONVERT** to `getRecordsetWithRBAC(FileData, self.currentUser, recordFilter={"id": fileId})`
- - **Reason**: FileData stores binary data associated with FileItem. While it's a technical table, we should still respect RBAC for consistency and security. The file access was already checked via `getFile()`, but FileData access should also be RBAC-filtered.
-
-3. **Line 1032** - `getFileData()` - Get FileData for reading
- - **Action**: **CONVERT** to `getRecordsetWithRBAC(FileData, self.currentUser, recordFilter={"id": fileId})`
- - **Reason**: FileData access should respect RBAC. The file access was already checked via `getFile()`, but FileData access should also be RBAC-filtered for consistency.
-
-**Note on FileData**: FileData is a technical table storing binary file content. However, for consistency and security, RBAC should still be applied. If FileData doesn't have RBAC rules defined, the RBAC filter will effectively be a no-op (allowing access), but the pattern is consistent.
-
----
-
-## Implementation Priority
-
-### High Priority (User-facing data access)
-- All `interfaceDbChatObjects.py` calls (14 calls)
-- `interfaceDbComponentObjects.py` FileData calls (2 calls)
-
-### Low Priority (System initialization)
-- `interfaceDbComponentObjects.py` Prompt initialization check (1 call) - Keep as-is
-
----
-
-## Next Steps
-
-1. Convert all 14 calls in `interfaceDbChatObjects.py` to `getRecordsetWithRBAC()`
-2. Convert 2 FileData calls in `interfaceDbComponentObjects.py` to `getRecordsetWithRBAC()`
-3. Keep 1 Prompt initialization check as-is
-4. Test all changes to ensure RBAC filtering works correctly
-5. Verify cascade delete operations still work correctly with RBAC
-
----
-
-## Testing Checklist
-
-After conversion, verify:
-- [ ] Workflow deletion still works (cascade deletes)
-- [ ] Message deletion still works (cascade deletes)
-- [ ] File deletion still works (FileData cleanup)
-- [ ] File reading still works (FileData access)
-- [ ] Child record access (messages, logs, stats, documents) respects RBAC
-- [ ] Users can only access data they have permission for
-- [ ] No performance degradation from RBAC filtering
diff --git a/modules/aicore/aicoreModelRegistry.py b/modules/aicore/aicoreModelRegistry.py
index 54027a26..8370aaea 100644
--- a/modules/aicore/aicoreModelRegistry.py
+++ b/modules/aicore/aicoreModelRegistry.py
@@ -9,6 +9,10 @@ import os
from typing import Dict, List, Optional, Any
from modules.datamodels.datamodelAi import AiModel
from modules.aicore.aicoreBase import BaseConnectorAi
+from modules.datamodels.datamodelUam import User
+from modules.shared.rbacHelpers import checkResourceAccess
+from modules.security.rbac import RbacClass
+from modules.connectors.connectorDbPostgre import DatabaseConnector
logger = logging.getLogger(__name__)
@@ -142,11 +146,24 @@ class ModelRegistry:
self.refreshModels()
return [model for model in self._models.values() if model.priority == priority]
- def getAvailableModels(self) -> List[AiModel]:
- """Get only available models."""
+ def getAvailableModels(self, currentUser: Optional[User] = None, rbacInstance: Optional[RbacClass] = None) -> List[AiModel]:
+ """Get only available models, optionally filtered by RBAC permissions.
+
+ Args:
+ currentUser: Optional user object for RBAC filtering
+ rbacInstance: Optional RBAC instance for permission checks
+
+ Returns:
+ List of available models (filtered by RBAC if user provided)
+ """
self.refreshModels()
allModels = list(self._models.values())
availableModels = [model for model in allModels if model.isAvailable]
+
+ # Apply RBAC filtering if user and RBAC instance provided
+ if currentUser and rbacInstance:
+ availableModels = self._filterModelsByRbac(availableModels, currentUser, rbacInstance)
+
unavailableCount = len(allModels) - len(availableModels)
if unavailableCount > 0:
unavailableModels = [m.name for m in allModels if not m.isAvailable]
@@ -154,6 +171,65 @@ class ModelRegistry:
logger.debug(f"getAvailableModels: Returning {len(availableModels)} models: {[m.name for m in availableModels]}")
return availableModels
+ def _filterModelsByRbac(self, models: List[AiModel], currentUser: User, rbacInstance: RbacClass) -> List[AiModel]:
+ """Filter models based on RBAC permissions.
+
+ Args:
+ models: List of models to filter
+ currentUser: Current user object
+ rbacInstance: RBAC instance for permission checks
+
+ Returns:
+ Filtered list of models that user has access to
+ """
+ filteredModels = []
+ for model in models:
+ # Check access at both connector level and model level
+ connectorResourcePath = f"ai.model.{model.connectorType}"
+ modelResourcePath = f"ai.model.{model.connectorType}.{model.displayName}"
+
+ # User needs access to either connector (all models) or specific model
+ hasConnectorAccess = checkResourceAccess(rbacInstance, currentUser, connectorResourcePath)
+ hasModelAccess = checkResourceAccess(rbacInstance, currentUser, modelResourcePath)
+
+ if hasConnectorAccess or hasModelAccess:
+ filteredModels.append(model)
+ else:
+ logger.debug(f"User {currentUser.username} does not have access to model {model.displayName} (connector: {model.connectorType})")
+
+ return filteredModels
+
+ def getModel(self, displayName: str, currentUser: Optional[User] = None, rbacInstance: Optional[RbacClass] = None) -> Optional[AiModel]:
+ """Get a specific model by displayName, optionally checking RBAC permissions.
+
+ Args:
+ displayName: Model display name
+ currentUser: Optional user object for RBAC check
+ rbacInstance: Optional RBAC instance for permission check
+
+ Returns:
+ Model if found and user has access (or if no user provided), None otherwise
+ """
+ self.refreshModels()
+ model = self._models.get(displayName)
+
+ if not model:
+ return None
+
+ # Check RBAC permission if user provided
+ if currentUser and rbacInstance:
+ connectorResourcePath = f"ai.model.{model.connectorType}"
+ modelResourcePath = f"ai.model.{model.connectorType}.{model.displayName}"
+
+ hasConnectorAccess = checkResourceAccess(rbacInstance, currentUser, connectorResourcePath)
+ hasModelAccess = checkResourceAccess(rbacInstance, currentUser, modelResourcePath)
+
+ if not (hasConnectorAccess or hasModelAccess):
+ logger.warning(f"User {currentUser.username} does not have access to model {displayName}")
+ return None
+
+ return model
+
def getConnectorForModel(self, displayName: str) -> Optional[BaseConnectorAi]:
"""Get the connector instance for a specific model by displayName."""
model = self.getModel(displayName)
diff --git a/modules/connectors/connectorDbPostgre.py b/modules/connectors/connectorDbPostgre.py
index 828fa703..d41d868e 100644
--- a/modules/connectors/connectorDbPostgre.py
+++ b/modules/connectors/connectorDbPostgre.py
@@ -22,16 +22,20 @@ class SystemTable(BaseModel):
table_name: str = Field(
description="Name of the table",
- frontend_type="text",
- frontend_readonly=True,
- frontend_required=True,
+ json_schema_extra={
+ "frontend_type": "text",
+ "frontend_readonly": True,
+ "frontend_required": True,
+ }
)
initial_id: Optional[str] = Field(
default=None,
description="Initial ID for the table",
- frontend_type="text",
- frontend_readonly=True,
- frontend_required=False,
+ json_schema_extra={
+ "frontend_type": "text",
+ "frontend_readonly": True,
+ "frontend_required": False,
+ }
)
@@ -1070,7 +1074,10 @@ class DatabaseConnector:
return []
# Get RBAC permissions for this table
- RbacInstance = RbacClass(self)
+ # AccessRule table is always in DbApp database
+ from modules.interfaces.interfaceDbAppObjects import getRootInterface
+ dbApp = getRootInterface().db
+ RbacInstance = RbacClass(self, dbApp=dbApp)
permissions = RbacInstance.getUserPermissions(
currentUser,
AccessRuleContext.DATA,
diff --git a/modules/datamodels/datamodelRbac.py b/modules/datamodels/datamodelRbac.py
index c2ba90d8..7fcfb6c4 100644
--- a/modules/datamodels/datamodelRbac.py
+++ b/modules/datamodels/datamodelRbac.py
@@ -1,7 +1,7 @@
-"""RBAC models: AccessRule, AccessRuleContext."""
+"""RBAC models: AccessRule, AccessRuleContext, Role."""
import uuid
-from typing import Optional
+from typing import Optional, Dict
from enum import Enum
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import registerModelLabels
@@ -15,6 +15,39 @@ class AccessRuleContext(str, Enum):
RESOURCE = "RESOURCE" # System resources (AI models, actions, etc.)
+class Role(BaseModel):
+ """Data model for RBAC roles"""
+ id: str = Field(
+ default_factory=lambda: str(uuid.uuid4()),
+ description="Unique ID of the role",
+ json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}
+ )
+ roleLabel: str = Field(
+ description="Unique role label identifier (e.g., 'admin', 'user', 'viewer')",
+ json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": True}
+ )
+ description: Dict[str, str] = Field(
+ description="Role description in multiple languages",
+ json_schema_extra={"frontend_type": "object", "frontend_readonly": False, "frontend_required": True}
+ )
+ isSystemRole: bool = Field(
+ False,
+ description="Whether this is a system role that cannot be deleted",
+ json_schema_extra={"frontend_type": "checkbox", "frontend_readonly": True, "frontend_required": False}
+ )
+
+registerModelLabels(
+ "Role",
+ {"en": "Role", "fr": "Rôle"},
+ {
+ "id": {"en": "ID", "fr": "ID"},
+ "roleLabel": {"en": "Role Label", "fr": "Label du rôle"},
+ "description": {"en": "Description", "fr": "Description"},
+ "isSystemRole": {"en": "System Role", "fr": "Rôle système"},
+ },
+)
+
+
class AccessRule(BaseModel):
"""Data model for access control rules"""
id: str = Field(
diff --git a/modules/datamodels/datamodelUam.py b/modules/datamodels/datamodelUam.py
index 49e62beb..90068f1b 100644
--- a/modules/datamodels/datamodelUam.py
+++ b/modules/datamodels/datamodelUam.py
@@ -93,20 +93,11 @@ registerModelLabels(
class UserConnection(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Unique ID of the connection", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
userId: str = Field(description="ID of the user this connection belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
- authority: AuthAuthority = Field(description="Authentication authority", json_schema_extra={"frontend_type": "select", "frontend_readonly": True, "frontend_required": False, "frontend_options": [
- {"value": "local", "label": {"en": "Local", "fr": "Local"}},
- {"value": "google", "label": {"en": "Google", "fr": "Google"}},
- {"value": "msft", "label": {"en": "Microsoft", "fr": "Microsoft"}},
- ]})
+ authority: AuthAuthority = Field(description="Authentication authority", json_schema_extra={"frontend_type": "select", "frontend_readonly": True, "frontend_required": False, "frontend_options": "auth.authority"})
externalId: str = Field(description="User ID in the external system", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
externalUsername: str = Field(description="Username in the external system", json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": False})
externalEmail: Optional[EmailStr] = Field(None, description="Email in the external system", json_schema_extra={"frontend_type": "email", "frontend_readonly": False, "frontend_required": False})
- status: ConnectionStatus = Field(default=ConnectionStatus.ACTIVE, description="Connection status", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": [
- {"value": "active", "label": {"en": "Active", "fr": "Actif"}},
- {"value": "inactive", "label": {"en": "Inactive", "fr": "Inactif"}},
- {"value": "expired", "label": {"en": "Expired", "fr": "Expiré"}},
- {"value": "pending", "label": {"en": "Pending", "fr": "En attente"}},
- ]})
+ status: ConnectionStatus = Field(default=ConnectionStatus.ACTIVE, description="Connection status", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False, "frontend_options": "connection.status"})
connectedAt: float = Field(default_factory=getUtcTimestamp, description="When the connection was established (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})
lastChecked: float = Field(default_factory=getUtcTimestamp, description="When the connection was last verified (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})
expiresAt: Optional[float] = Field(None, description="When the connection expires (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})
@@ -152,11 +143,7 @@ class User(BaseModel):
description="List of role labels assigned to this user. All roles are opening roles (union) - if one role enables something, it is enabled.",
json_schema_extra={"frontend_type": "multiselect", "frontend_readonly": False, "frontend_required": True, "frontend_options": "user.role"}
)
- authenticationAuthority: AuthAuthority = Field(default=AuthAuthority.LOCAL, description="Primary authentication authority", json_schema_extra={"frontend_type": "select", "frontend_readonly": True, "frontend_required": False, "frontend_options": [
- {"value": "local", "label": {"en": "Local", "fr": "Local"}},
- {"value": "google", "label": {"en": "Google", "fr": "Google"}},
- {"value": "msft", "label": {"en": "Microsoft", "fr": "Microsoft"}},
- ]})
+ authenticationAuthority: AuthAuthority = Field(default=AuthAuthority.LOCAL, description="Primary authentication authority", json_schema_extra={"frontend_type": "select", "frontend_readonly": True, "frontend_required": False, "frontend_options": "auth.authority"})
mandateId: Optional[str] = Field(None, description="ID of the mandate this user belongs to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
registerModelLabels(
"User",
diff --git a/modules/features/options/mainOptions.py b/modules/features/options/mainOptions.py
new file mode 100644
index 00000000..41ef5db2
--- /dev/null
+++ b/modules/features/options/mainOptions.py
@@ -0,0 +1,127 @@
+"""
+Options API feature module.
+Provides dynamic options for frontend select/multiselect fields.
+"""
+
+import logging
+from typing import List, Dict, Any, Optional
+from modules.datamodels.datamodelUam import User, AuthAuthority, ConnectionStatus
+from modules.interfaces.interfaceDbAppObjects import getInterface
+
+logger = logging.getLogger(__name__)
+
+# Standard role definitions (fallback if database is not available)
+STANDARD_ROLES = [
+ {"value": "sysadmin", "label": {"en": "System Administrator", "fr": "Administrateur système"}},
+ {"value": "admin", "label": {"en": "Administrator", "fr": "Administrateur"}},
+ {"value": "user", "label": {"en": "User", "fr": "Utilisateur"}},
+ {"value": "viewer", "label": {"en": "Viewer", "fr": "Visualiseur"}},
+]
+
+# Authentication authority options
+AUTH_AUTHORITY_OPTIONS = [
+ {"value": "local", "label": {"en": "Local", "fr": "Local"}},
+ {"value": "google", "label": {"en": "Google", "fr": "Google"}},
+ {"value": "msft", "label": {"en": "Microsoft", "fr": "Microsoft"}},
+]
+
+# Connection status options
+# Note: Matches ConnectionStatus enum values (active, expired, revoked, pending)
+# Plus "error" for error states (not in enum but used in UI)
+CONNECTION_STATUS_OPTIONS = [
+ {"value": "active", "label": {"en": "Active", "fr": "Actif"}},
+ {"value": "expired", "label": {"en": "Expired", "fr": "Expiré"}},
+ {"value": "revoked", "label": {"en": "Revoked", "fr": "Révoqué"}},
+ {"value": "pending", "label": {"en": "Pending", "fr": "En attente"}},
+ {"value": "error", "label": {"en": "Error", "fr": "Erreur"}},
+]
+
+
+def getOptions(optionsName: str, currentUser: Optional[User] = None) -> List[Dict[str, Any]]:
+ """
+ Get options for a given options name.
+
+ Args:
+ optionsName: Name of the options set to retrieve (e.g., "user.role", "user.connection")
+ currentUser: Optional current user for context-aware options
+
+ Returns:
+ List of option dictionaries with "value" and "label" keys
+
+ Raises:
+ ValueError: If optionsName is not recognized
+ """
+ optionsNameLower = optionsName.lower()
+
+ if optionsNameLower == "user.role":
+ # Fetch roles from database
+ if currentUser:
+ try:
+ interface = getInterface(currentUser)
+ roles = interface.getAllRoles()
+
+ # Convert Role objects to options format
+ options = []
+ for role in roles:
+ # Use English description as label, fallback to roleLabel
+ label = role.description.get("en", role.roleLabel) if isinstance(role.description, dict) else role.roleLabel
+ options.append({
+ "value": role.roleLabel,
+ "label": label
+ })
+
+ # If no roles in database, return standard roles as fallback
+ if options:
+ return options
+ except Exception as e:
+ logger.warning(f"Error fetching roles from database, using fallback: {e}")
+
+ # Fallback to standard roles if database fetch fails or no user context
+ return STANDARD_ROLES
+
+ elif optionsNameLower == "auth.authority":
+ return AUTH_AUTHORITY_OPTIONS
+
+ elif optionsNameLower == "connection.status":
+ return CONNECTION_STATUS_OPTIONS
+
+ elif optionsNameLower == "user.connection":
+ # Dynamic options: Get user connections from database
+ if not currentUser:
+ return []
+
+ try:
+ interface = getInterface(currentUser)
+ connections = interface.getUserConnections(currentUser.id)
+
+ return [
+ {
+ "value": conn.id,
+ "label": {
+ "en": f"{conn.authority.value} - {conn.externalUsername or conn.externalId}",
+ "fr": f"{conn.authority.value} - {conn.externalUsername or conn.externalId}"
+ }
+ }
+ for conn in connections
+ ]
+ except Exception as e:
+ logger.error(f"Error fetching user connections for options: {e}")
+ return []
+
+ else:
+ raise ValueError(f"Unknown options name: {optionsName}")
+
+
+def getAvailableOptionsNames() -> List[str]:
+ """
+ Get list of all available options names.
+
+ Returns:
+ List of available options names
+ """
+ return [
+ "user.role",
+ "auth.authority",
+ "connection.status",
+ "user.connection",
+ ]
diff --git a/modules/interfaces/interfaceBootstrap.py b/modules/interfaces/interfaceBootstrap.py
index 55d94c3c..54129c7c 100644
--- a/modules/interfaces/interfaceBootstrap.py
+++ b/modules/interfaces/interfaceBootstrap.py
@@ -4,7 +4,7 @@ Contains all bootstrap logic including mandate, users, and RBAC rules.
"""
import logging
-from typing import Optional
+from typing import Optional, List, Dict, Any
from passlib.context import CryptContext
from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG
@@ -16,6 +16,7 @@ from modules.datamodels.datamodelUam import (
from modules.datamodels.datamodelRbac import (
AccessRule,
AccessRuleContext,
+ Role,
)
from modules.datamodels.datamodelUam import AccessLevel
@@ -43,6 +44,9 @@ def initBootstrap(db: DatabaseConnector) -> None:
# Initialize event user
eventUserId = initEventUser(db, mandateId)
+ # Initialize roles
+ initRoles(db)
+
# Initialize RBAC rules
initRbacRules(db)
@@ -149,10 +153,59 @@ def initEventUser(db: DatabaseConnector, mandateId: Optional[str]) -> Optional[s
return userId
+def initRoles(db: DatabaseConnector) -> None:
+ """
+ Initialize standard roles if they don't exist.
+
+ Args:
+ db: Database connector instance
+ """
+ logger.info("Initializing roles")
+
+ standardRoles = [
+ Role(
+ roleLabel="sysadmin",
+ description={"en": "System Administrator - Full access to all system resources", "fr": "Administrateur système - Accès complet à toutes les ressources"},
+ isSystemRole=True
+ ),
+ Role(
+ roleLabel="admin",
+ description={"en": "Administrator - Manage users and resources within mandate scope", "fr": "Administrateur - Gérer les utilisateurs et ressources dans le périmètre du mandat"},
+ isSystemRole=True
+ ),
+ Role(
+ roleLabel="user",
+ description={"en": "User - Standard user with access to own records", "fr": "Utilisateur - Utilisateur standard avec accès à ses propres enregistrements"},
+ isSystemRole=True
+ ),
+ Role(
+ roleLabel="viewer",
+ description={"en": "Viewer - Read-only access to group records", "fr": "Visualiseur - Accès en lecture seule aux enregistrements du groupe"},
+ isSystemRole=True
+ ),
+ ]
+
+ existingRoles = db.getRecordset(Role)
+ existingRoleLabels = {role.get("roleLabel") for role in existingRoles}
+
+ for role in standardRoles:
+ if role.roleLabel not in existingRoleLabels:
+ try:
+ db.recordCreate(Role, role)
+ logger.info(f"Created role: {role.roleLabel}")
+ except Exception as e:
+ logger.warning(f"Error creating role {role.roleLabel}: {e}")
+ else:
+ logger.debug(f"Role {role.roleLabel} already exists")
+
+ logger.info("Roles initialization completed")
+
+
def initRbacRules(db: DatabaseConnector) -> None:
"""
Initialize RBAC rules if they don't exist.
Converts all UAM logic from interface*Access.py modules to RBAC rules.
+ Also checks for and adds missing rules for new tables.
Args:
db: Database connector instance
@@ -160,6 +213,8 @@ def initRbacRules(db: DatabaseConnector) -> None:
existingRules = db.getRecordset(AccessRule)
if existingRules:
logger.info(f"RBAC rules already exist ({len(existingRules)} rules)")
+ # Check for missing rules for ChatWorkflow and Prompt tables
+ _addMissingTableRules(db, existingRules)
return
logger.info("Initializing RBAC rules")
@@ -170,6 +225,12 @@ def initRbacRules(db: DatabaseConnector) -> None:
# Create table-specific rules (converted from UAM logic)
createTableSpecificRules(db)
+ # Create UI context rules
+ createUiContextRules(db)
+
+ # Create RESOURCE context rules
+ createResourceContextRules(db)
+
logger.info("RBAC rules initialization completed")
@@ -495,6 +556,90 @@ def createTableSpecificRules(db: DatabaseConnector) -> None:
delete=AccessLevel.NONE,
))
+ # ChatWorkflow table - Users can access their own workflows
+ tableRules.append(AccessRule(
+ roleLabel="sysadmin",
+ context=AccessRuleContext.DATA,
+ item="ChatWorkflow",
+ view=True,
+ read=AccessLevel.ALL,
+ create=AccessLevel.ALL,
+ update=AccessLevel.ALL,
+ delete=AccessLevel.ALL,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="admin",
+ context=AccessRuleContext.DATA,
+ item="ChatWorkflow",
+ view=True,
+ read=AccessLevel.GROUP,
+ create=AccessLevel.GROUP,
+ update=AccessLevel.GROUP,
+ delete=AccessLevel.GROUP,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.DATA,
+ item="ChatWorkflow",
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.MY,
+ update=AccessLevel.MY,
+ delete=AccessLevel.MY,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="viewer",
+ context=AccessRuleContext.DATA,
+ item="ChatWorkflow",
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.NONE,
+ update=AccessLevel.NONE,
+ delete=AccessLevel.NONE,
+ ))
+
+ # Prompt table - Users can access their own prompts
+ tableRules.append(AccessRule(
+ roleLabel="sysadmin",
+ context=AccessRuleContext.DATA,
+ item="Prompt",
+ view=True,
+ read=AccessLevel.ALL,
+ create=AccessLevel.ALL,
+ update=AccessLevel.ALL,
+ delete=AccessLevel.ALL,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="admin",
+ context=AccessRuleContext.DATA,
+ item="Prompt",
+ view=True,
+ read=AccessLevel.GROUP,
+ create=AccessLevel.GROUP,
+ update=AccessLevel.GROUP,
+ delete=AccessLevel.GROUP,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.DATA,
+ item="Prompt",
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.MY,
+ update=AccessLevel.MY,
+ delete=AccessLevel.MY,
+ ))
+ tableRules.append(AccessRule(
+ roleLabel="viewer",
+ context=AccessRuleContext.DATA,
+ item="Prompt",
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.NONE,
+ update=AccessLevel.NONE,
+ delete=AccessLevel.NONE,
+ ))
+
# Create all table-specific rules
for rule in tableRules:
db.recordCreate(AccessRule, rule)
@@ -502,6 +647,265 @@ def createTableSpecificRules(db: DatabaseConnector) -> None:
logger.info(f"Created {len(tableRules)} table-specific rules")
+def createUiContextRules(db: DatabaseConnector) -> None:
+ """
+ Create UI context rules for controlling UI element visibility.
+ These rules control which UI components users can see based on their roles.
+
+ Args:
+ db: Database connector instance
+ """
+ uiRules = []
+
+ # Generic UI rules - all roles can view UI by default
+ # Specific UI elements can override these with more restrictive rules
+
+ # Sysadmin - full UI access
+ uiRules.append(AccessRule(
+ roleLabel="sysadmin",
+ context=AccessRuleContext.UI,
+ item=None,
+ view=True,
+ read=None,
+ create=None,
+ update=None,
+ delete=None,
+ ))
+
+ # Admin - full UI access
+ uiRules.append(AccessRule(
+ roleLabel="admin",
+ context=AccessRuleContext.UI,
+ item=None,
+ view=True,
+ read=None,
+ create=None,
+ update=None,
+ delete=None,
+ ))
+
+ # User - full UI access
+ uiRules.append(AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.UI,
+ item=None,
+ view=True,
+ read=None,
+ create=None,
+ update=None,
+ delete=None,
+ ))
+
+ # Viewer - full UI access (can view but may have restricted actions)
+ uiRules.append(AccessRule(
+ roleLabel="viewer",
+ context=AccessRuleContext.UI,
+ item=None,
+ view=True,
+ read=None,
+ create=None,
+ update=None,
+ delete=None,
+ ))
+
+ # Create all UI context rules
+ for rule in uiRules:
+ db.recordCreate(AccessRule, rule)
+
+ logger.info(f"Created {len(uiRules)} UI context rules")
+
+
+def createResourceContextRules(db: DatabaseConnector) -> None:
+ """
+ Create RESOURCE context rules for controlling resource access (AI models, actions, etc.).
+ These rules control which resources users can access based on their roles.
+
+ Args:
+ db: Database connector instance
+ """
+ resourceRules = []
+
+ # Generic resource rules - all roles can access resources by default
+ # Specific resources can override these with more restrictive rules
+
+ # Sysadmin - full resource access
+ resourceRules.append(AccessRule(
+ roleLabel="sysadmin",
+ context=AccessRuleContext.RESOURCE,
+ item=None,
+ view=True,
+ read=None,
+ create=None,
+ update=None,
+ delete=None,
+ ))
+
+ # Admin - full resource access
+ resourceRules.append(AccessRule(
+ roleLabel="admin",
+ context=AccessRuleContext.RESOURCE,
+ item=None,
+ view=True,
+ read=None,
+ create=None,
+ update=None,
+ delete=None,
+ ))
+
+ # User - full resource access
+ resourceRules.append(AccessRule(
+ roleLabel="user",
+ context=AccessRuleContext.RESOURCE,
+ item=None,
+ view=True,
+ read=None,
+ create=None,
+ update=None,
+ delete=None,
+ ))
+
+ # Viewer - full resource access (can view but may have restricted actions)
+ resourceRules.append(AccessRule(
+ roleLabel="viewer",
+ context=AccessRuleContext.RESOURCE,
+ item=None,
+ view=True,
+ read=None,
+ create=None,
+ update=None,
+ delete=None,
+ ))
+
+ # Create all RESOURCE context rules
+ for rule in resourceRules:
+ db.recordCreate(AccessRule, rule)
+
+ logger.info(f"Created {len(resourceRules)} RESOURCE context rules")
+
+
+def _addMissingTableRules(db: DatabaseConnector, existingRules: List[Dict[str, Any]]) -> None:
+ """
+ Add missing RBAC rules for tables that were added after initial bootstrap.
+
+ Args:
+ db: Database connector instance
+ existingRules: List of existing AccessRule records
+ """
+ # Check which tables already have rules
+ existingItems = {rule.get("item") for rule in existingRules if rule.get("context") == AccessRuleContext.DATA}
+ existingRoles = {rule.get("roleLabel") for rule in existingRules}
+
+ # Tables that need rules
+ requiredTables = ["ChatWorkflow", "Prompt"]
+ requiredRoles = ["sysadmin", "admin", "user", "viewer"]
+
+ newRules = []
+
+ for table in requiredTables:
+ if table not in existingItems:
+ logger.info(f"Adding missing RBAC rules for table {table}")
+ # ChatWorkflow rules
+ if table == "ChatWorkflow":
+ for roleLabel in requiredRoles:
+ if roleLabel == "sysadmin":
+ newRules.append(AccessRule(
+ roleLabel=roleLabel,
+ context=AccessRuleContext.DATA,
+ item=table,
+ view=True,
+ read=AccessLevel.ALL,
+ create=AccessLevel.ALL,
+ update=AccessLevel.ALL,
+ delete=AccessLevel.ALL,
+ ))
+ elif roleLabel == "admin":
+ newRules.append(AccessRule(
+ roleLabel=roleLabel,
+ context=AccessRuleContext.DATA,
+ item=table,
+ view=True,
+ read=AccessLevel.GROUP,
+ create=AccessLevel.GROUP,
+ update=AccessLevel.GROUP,
+ delete=AccessLevel.GROUP,
+ ))
+ elif roleLabel == "user":
+ newRules.append(AccessRule(
+ roleLabel=roleLabel,
+ context=AccessRuleContext.DATA,
+ item=table,
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.MY,
+ update=AccessLevel.MY,
+ delete=AccessLevel.MY,
+ ))
+ elif roleLabel == "viewer":
+ newRules.append(AccessRule(
+ roleLabel=roleLabel,
+ context=AccessRuleContext.DATA,
+ item=table,
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.NONE,
+ update=AccessLevel.NONE,
+ delete=AccessLevel.NONE,
+ ))
+ # Prompt rules (same as ChatWorkflow)
+ elif table == "Prompt":
+ for roleLabel in requiredRoles:
+ if roleLabel == "sysadmin":
+ newRules.append(AccessRule(
+ roleLabel=roleLabel,
+ context=AccessRuleContext.DATA,
+ item=table,
+ view=True,
+ read=AccessLevel.ALL,
+ create=AccessLevel.ALL,
+ update=AccessLevel.ALL,
+ delete=AccessLevel.ALL,
+ ))
+ elif roleLabel == "admin":
+ newRules.append(AccessRule(
+ roleLabel=roleLabel,
+ context=AccessRuleContext.DATA,
+ item=table,
+ view=True,
+ read=AccessLevel.GROUP,
+ create=AccessLevel.GROUP,
+ update=AccessLevel.GROUP,
+ delete=AccessLevel.GROUP,
+ ))
+ elif roleLabel == "user":
+ newRules.append(AccessRule(
+ roleLabel=roleLabel,
+ context=AccessRuleContext.DATA,
+ item=table,
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.MY,
+ update=AccessLevel.MY,
+ delete=AccessLevel.MY,
+ ))
+ elif roleLabel == "viewer":
+ newRules.append(AccessRule(
+ roleLabel=roleLabel,
+ context=AccessRuleContext.DATA,
+ item=table,
+ view=True,
+ read=AccessLevel.MY,
+ create=AccessLevel.NONE,
+ update=AccessLevel.NONE,
+ delete=AccessLevel.NONE,
+ ))
+
+ # Create missing rules
+ if newRules:
+ for rule in newRules:
+ db.recordCreate(AccessRule, rule)
+ logger.info(f"Added {len(newRules)} missing RBAC rules")
+
+
def assignInitialUserRoles(db: DatabaseConnector, adminUserId: str, eventUserId: str) -> None:
"""
Assign initial roles to admin and event users.
@@ -511,23 +915,38 @@ def assignInitialUserRoles(db: DatabaseConnector, adminUserId: str, eventUserId:
adminUserId: Admin user ID
eventUserId: Event user ID
"""
- # Update admin user with sysadmin role
- adminUser = db.getRecordset(UserInDB, recordFilter={"id": adminUserId})
- if adminUser:
- adminUserData = adminUser[0]
- if "sysadmin" not in adminUserData.get("roleLabels", []):
- adminUserData["roleLabels"] = adminUserData.get("roleLabels", []) + ["sysadmin"]
- db.recordUpdate(UserInDB, adminUserId, adminUserData)
- logger.info(f"Assigned sysadmin role to admin user {adminUserId}")
-
- # Update event user with sysadmin role
- eventUser = db.getRecordset(UserInDB, recordFilter={"id": eventUserId})
- if eventUser:
- eventUserData = eventUser[0]
- if "sysadmin" not in eventUserData.get("roleLabels", []):
- eventUserData["roleLabels"] = eventUserData.get("roleLabels", []) + ["sysadmin"]
- db.recordUpdate(UserInDB, eventUserId, eventUserData)
- logger.info(f"Assigned sysadmin role to event user {eventUserId}")
+ # Set context to admin user for bootstrap operations
+ originalUserId = db.userId if hasattr(db, 'userId') else None
+ try:
+ if adminUserId:
+ db.updateContext(adminUserId)
+
+ # Update admin user with sysadmin role
+ adminUser = db.getRecordset(UserInDB, recordFilter={"id": adminUserId})
+ if adminUser:
+ adminUserData = adminUser[0]
+ roleLabels = adminUserData.get("roleLabels") or []
+ if "sysadmin" not in roleLabels:
+ adminUserData["roleLabels"] = roleLabels + ["sysadmin"]
+ db.recordModify(UserInDB, adminUserId, adminUserData)
+ logger.info(f"Assigned sysadmin role to admin user {adminUserId}")
+
+ # Update event user with sysadmin role
+ eventUser = db.getRecordset(UserInDB, recordFilter={"id": eventUserId})
+ if eventUser:
+ eventUserData = eventUser[0]
+ roleLabels = eventUserData.get("roleLabels") or []
+ if "sysadmin" not in roleLabels:
+ eventUserData["roleLabels"] = roleLabels + ["sysadmin"]
+ db.recordModify(UserInDB, eventUserId, eventUserData)
+ logger.info(f"Assigned sysadmin role to event user {eventUserId}")
+ finally:
+ # Restore original context if it existed
+ if originalUserId:
+ db.updateContext(originalUserId)
+ elif hasattr(db, 'userId'):
+ # If original was None/empty, just set it directly
+ db.userId = originalUserId
def _getPasswordHash(password: Optional[str]) -> Optional[str]:
diff --git a/modules/interfaces/interfaceDbAppObjects.py b/modules/interfaces/interfaceDbAppObjects.py
index cf582fa2..8be2f7dd 100644
--- a/modules/interfaces/interfaceDbAppObjects.py
+++ b/modules/interfaces/interfaceDbAppObjects.py
@@ -25,6 +25,7 @@ from modules.datamodels.datamodelUam import (
from modules.datamodels.datamodelRbac import (
AccessRule,
AccessRuleContext,
+ Role,
)
from modules.datamodels.datamodelUam import AccessLevel
from modules.datamodels.datamodelSecurity import Token, AuthEvent, TokenStatus
@@ -88,7 +89,8 @@ class AppObjects:
# Initialize RBAC interface
if not currentUser:
raise ValueError("User context is required for RBAC")
- self.rbac = RbacClass(self.db)
+ # Pass self.db as dbApp since this interface uses DbApp database
+ self.rbac = RbacClass(self.db, dbApp=self.db)
# Update database context
self.db.updateContext(self.userId)
@@ -424,10 +426,13 @@ class AppObjects:
recordFilter={"mandateId": mandateId} if mandateId else None
)
- # Filter out database-specific fields
+ # Filter out database-specific fields and normalize data
filteredUsers = []
for user in users:
cleanedUser = {k: v for k, v in user.items() if not k.startswith("_")}
+ # Ensure roleLabels is always a list, not None
+ if cleanedUser.get("roleLabels") is None:
+ cleanedUser["roleLabels"] = []
filteredUsers.append(cleanedUser)
# If no pagination requested, return all items
@@ -451,6 +456,11 @@ class AppObjects:
endIdx = startIdx + pagination.pageSize
pagedUsers = filteredUsers[startIdx:endIdx]
+ # Ensure roleLabels is always a list for paginated results too
+ for user in pagedUsers:
+ if user.get("roleLabels") is None:
+ user["roleLabels"] = []
+
# Convert to model objects
items = [User(**user) for user in pagedUsers]
@@ -478,6 +488,9 @@ class AppObjects:
userDict = users[0]
# Filter out database-specific fields
cleanedUser = {k: v for k, v in userDict.items() if not k.startswith("_")}
+ # Ensure roleLabels is always a list, not None
+ if cleanedUser.get("roleLabels") is None:
+ cleanedUser["roleLabels"] = []
return User(**cleanedUser)
except Exception as e:
@@ -500,6 +513,9 @@ class AppObjects:
# User already filtered by RBAC, just clean fields
user_dict = users[0]
cleanedUser = {k: v for k, v in user_dict.items() if not k.startswith("_")}
+ # Ensure roleLabels is always a list, not None
+ if cleanedUser.get("roleLabels") is None:
+ cleanedUser["roleLabels"] = []
return User(**cleanedUser)
except Exception as e:
@@ -1525,7 +1541,7 @@ class AppObjects:
Updated AccessRule object
"""
try:
- updatedRule = self.db.recordUpdate(AccessRule, ruleId, accessRule.model_dump())
+ updatedRule = self.db.recordModify(AccessRule, ruleId, accessRule.model_dump())
logger.info(f"Updated access rule with ID {ruleId}")
return AccessRule(**updatedRule)
except Exception as e:
@@ -1601,7 +1617,8 @@ class AppObjects:
List of AccessRule objects (most specific for each role)
"""
try:
- RbacInstance = RbacClass(self.db)
+ # Pass self.db as dbApp since this interface uses DbApp database
+ RbacInstance = RbacClass(self.db, dbApp=self.db)
allRules = []
for roleLabel in roleLabels:
@@ -1619,6 +1636,149 @@ class AppObjects:
logger.error(f"Error getting access rules for roles: {str(e)}")
return []
+ def createRole(self, role: Role) -> Role:
+ """
+ Create a new role.
+
+ Args:
+ role: Role object to create
+
+ Returns:
+ Created Role object
+ """
+ try:
+ # Check if role label already exists
+ existingRoles = self.db.getRecordset(Role, recordFilter={"roleLabel": role.roleLabel})
+ if existingRoles:
+ raise ValueError(f"Role with label '{role.roleLabel}' already exists")
+
+ createdRole = self.db.recordCreate(Role, role)
+ logger.info(f"Created role with ID {createdRole.get('id')} and label {role.roleLabel}")
+ return Role(**createdRole)
+ except Exception as e:
+ logger.error(f"Error creating role: {str(e)}")
+ raise
+
+ def getRole(self, roleId: str) -> Optional[Role]:
+ """
+ Get a role by ID.
+
+ Args:
+ roleId: Role ID
+
+ Returns:
+ Role object if found, None otherwise
+ """
+ try:
+ roles = self.db.getRecordset(Role, recordFilter={"id": roleId})
+ if roles:
+ return Role(**roles[0])
+ return None
+ except Exception as e:
+ logger.error(f"Error getting role {roleId}: {str(e)}")
+ return None
+
+ def getRoleByLabel(self, roleLabel: str) -> Optional[Role]:
+ """
+ Get a role by label.
+
+ Args:
+ roleLabel: Role label
+
+ Returns:
+ Role object if found, None otherwise
+ """
+ try:
+ roles = self.db.getRecordset(Role, recordFilter={"roleLabel": roleLabel})
+ if roles:
+ return Role(**roles[0])
+ return None
+ except Exception as e:
+ logger.error(f"Error getting role by label {roleLabel}: {str(e)}")
+ return None
+
+ def getAllRoles(self) -> List[Role]:
+ """
+ Get all roles.
+
+ Returns:
+ List of Role objects
+ """
+ try:
+ roles = self.db.getRecordset(Role)
+ return [Role(**role) for role in roles]
+ except Exception as e:
+ logger.error(f"Error getting all roles: {str(e)}")
+ return []
+
+ def updateRole(self, roleId: str, role: Role) -> Role:
+ """
+ Update an existing role.
+
+ Args:
+ roleId: Role ID
+ role: Updated Role object
+
+ Returns:
+ Updated Role object
+ """
+ try:
+ # Check if role exists
+ existingRole = self.getRole(roleId)
+ if not existingRole:
+ raise ValueError(f"Role with ID {roleId} not found")
+
+ # If role label is being changed, check for conflicts
+ if role.roleLabel != existingRole.roleLabel:
+ conflictingRole = self.getRoleByLabel(role.roleLabel)
+ if conflictingRole and conflictingRole.id != roleId:
+ raise ValueError(f"Role with label '{role.roleLabel}' already exists")
+
+ updatedRole = self.db.recordModify(Role, roleId, role.model_dump())
+ logger.info(f"Updated role with ID {roleId}")
+ return Role(**updatedRole)
+ except Exception as e:
+ logger.error(f"Error updating role {roleId}: {str(e)}")
+ raise
+
+ def deleteRole(self, roleId: str) -> bool:
+ """
+ Delete a role.
+
+ Args:
+ roleId: Role ID
+
+ Returns:
+ True if deleted successfully, False otherwise
+ """
+ try:
+ # Check if role exists
+ role = self.getRole(roleId)
+ if not role:
+ return False
+
+ # Prevent deletion of system roles
+ if role.isSystemRole:
+ raise ValueError(f"Cannot delete system role '{role.roleLabel}'")
+
+ # Check if role is assigned to any users
+ allUsers = self.getUsers()
+ for user in allUsers:
+ if role.roleLabel in (user.roleLabels or []):
+ raise ValueError(f"Cannot delete role '{role.roleLabel}' - it is assigned to users")
+
+ # Check if role is used in any access rules
+ accessRules = self.getAccessRules(roleLabel=role.roleLabel)
+ if accessRules:
+ raise ValueError(f"Cannot delete role '{role.roleLabel}' - it is used in access rules")
+
+ self.db.recordDelete(Role, roleId)
+ logger.info(f"Deleted role with ID {roleId}")
+ return True
+ except Exception as e:
+ logger.error(f"Error deleting role {roleId}: {str(e)}")
+ raise
+
# Public Methods
diff --git a/modules/interfaces/interfaceDbChatObjects.py b/modules/interfaces/interfaceDbChatObjects.py
index ac6df640..fba9ee88 100644
--- a/modules/interfaces/interfaceDbChatObjects.py
+++ b/modules/interfaces/interfaceDbChatObjects.py
@@ -268,7 +268,10 @@ class ChatObjects:
# Initialize RBAC interface
if not self.currentUser:
raise ValueError("User context is required for RBAC")
- self.rbac = RbacClass(self.db)
+ # Get DbApp connection for RBAC AccessRule queries
+ from modules.interfaces.interfaceDbAppObjects import getRootInterface
+ dbApp = getRootInterface().db
+ self.rbac = RbacClass(self.db, dbApp=dbApp)
# Update database context
self.db.updateContext(self.userId)
diff --git a/modules/interfaces/interfaceDbComponentObjects.py b/modules/interfaces/interfaceDbComponentObjects.py
index cedc1fec..98ad0886 100644
--- a/modules/interfaces/interfaceDbComponentObjects.py
+++ b/modules/interfaces/interfaceDbComponentObjects.py
@@ -85,7 +85,10 @@ class ComponentObjects:
# Initialize RBAC interface
if not self.currentUser:
raise ValueError("User context is required for RBAC")
- self.rbac = RbacClass(self.db)
+ # Get DbApp connection for RBAC AccessRule queries
+ from modules.interfaces.interfaceDbAppObjects import getRootInterface
+ dbApp = getRootInterface().db
+ self.rbac = RbacClass(self.db, dbApp=dbApp)
# Update database context
self.db.updateContext(self.userId)
diff --git a/modules/routes/routeAdminRbacRoles.py b/modules/routes/routeAdminRbacRoles.py
new file mode 100644
index 00000000..38e92e04
--- /dev/null
+++ b/modules/routes/routeAdminRbacRoles.py
@@ -0,0 +1,716 @@
+"""
+Admin RBAC Roles Management routes.
+Provides endpoints for managing roles and role assignments to users.
+"""
+
+from fastapi import APIRouter, HTTPException, Depends, Query, Body, Path, Request
+from typing import List, Dict, Any, Optional
+import logging
+
+from modules.security.auth import getCurrentUser, limiter
+from modules.datamodels.datamodelUam import User, UserInDB
+from modules.datamodels.datamodelRbac import Role
+from modules.interfaces.interfaceDbAppObjects import getInterface
+
+# Configure logger
+logger = logging.getLogger(__name__)
+
+router = APIRouter(
+ prefix="/api/admin/rbac/roles",
+ tags=["Admin RBAC Roles"],
+ responses={404: {"description": "Not found"}}
+)
+
+
+def _ensureAdminAccess(currentUser: User) -> None:
+ """Ensure current user has admin access to RBAC roles management."""
+ interface = getInterface(currentUser)
+
+ # Check if user has admin or sysadmin role
+ roleLabels = currentUser.roleLabels or []
+ if "sysadmin" not in roleLabels and "admin" not in roleLabels:
+ raise HTTPException(
+ status_code=403,
+ detail="Admin or sysadmin role required to manage RBAC roles"
+ )
+
+ # Additional RBAC check: verify user has permission to update UserInDB
+ # This is already covered by admin/sysadmin role check above, but we can add explicit RBAC check if needed
+ # For now, admin/sysadmin role check is sufficient
+
+
+@router.get("/", response_model=List[Dict[str, Any]])
+@limiter.limit("60/minute")
+async def listRoles(
+ request: Request,
+ currentUser: User = Depends(getCurrentUser)
+) -> List[Dict[str, Any]]:
+ """
+ Get list of all available roles with metadata.
+
+ Returns:
+ - List of role dictionaries with role label, description, and user count
+ """
+ try:
+ _ensureAdminAccess(currentUser)
+
+ interface = getInterface(currentUser)
+
+ # Get all roles from database
+ dbRoles = interface.getAllRoles()
+
+ # Get all users to count role assignments
+ allUsers = interface.getUsers()
+
+ # Count users per role
+ roleCounts: Dict[str, int] = {}
+ for user in allUsers:
+ for roleLabel in (user.roleLabels or []):
+ roleCounts[roleLabel] = roleCounts.get(roleLabel, 0) + 1
+
+ # Convert Role objects to dictionaries and add user counts
+ result = []
+ for role in dbRoles:
+ result.append({
+ "id": role.id,
+ "roleLabel": role.roleLabel,
+ "description": role.description,
+ "userCount": roleCounts.get(role.roleLabel, 0),
+ "isSystemRole": role.isSystemRole
+ })
+
+ # Add any roles found in user assignments that don't exist in database
+ dbRoleLabels = {role.roleLabel for role in dbRoles}
+ for roleLabel, count in roleCounts.items():
+ if roleLabel not in dbRoleLabels:
+ result.append({
+ "id": None,
+ "roleLabel": roleLabel,
+ "description": {"en": f"Custom role: {roleLabel}", "fr": f"Rôle personnalisé : {roleLabel}"},
+ "userCount": count,
+ "isSystemRole": False
+ })
+
+ return result
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error listing roles: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to list roles: {str(e)}"
+ )
+
+
+@router.get("/options", response_model=List[Dict[str, Any]])
+@limiter.limit("60/minute")
+async def getRoleOptions(
+ request: Request,
+ currentUser: User = Depends(getCurrentUser)
+) -> List[Dict[str, Any]]:
+ """
+ Get role options for select dropdowns.
+ Returns roles in format suitable for frontend select components.
+
+ Returns:
+ - List of role option dictionaries with value and label
+ """
+ try:
+ _ensureAdminAccess(currentUser)
+
+ interface = getInterface(currentUser)
+
+ # Get all roles from database
+ dbRoles = interface.getAllRoles()
+
+ # Convert to options format
+ options = []
+ for role in dbRoles:
+ # Use English description as label, fallback to roleLabel
+ label = role.description.get("en", role.roleLabel) if isinstance(role.description, dict) else role.roleLabel
+ options.append({
+ "value": role.roleLabel,
+ "label": label
+ })
+
+ return options
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error getting role options: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to get role options: {str(e)}"
+ )
+
+
+@router.post("/", response_model=Dict[str, Any])
+@limiter.limit("30/minute")
+async def createRole(
+ request: Request,
+ role: Role = Body(...),
+ currentUser: User = Depends(getCurrentUser)
+) -> Dict[str, Any]:
+ """
+ Create a new role.
+
+ Request Body:
+ - role: Role object to create
+
+ Returns:
+ - Created role dictionary
+ """
+ try:
+ _ensureAdminAccess(currentUser)
+
+ interface = getInterface(currentUser)
+
+ createdRole = interface.createRole(role)
+
+ return {
+ "id": createdRole.id,
+ "roleLabel": createdRole.roleLabel,
+ "description": createdRole.description,
+ "isSystemRole": createdRole.isSystemRole
+ }
+
+ except HTTPException:
+ raise
+ except ValueError as e:
+ raise HTTPException(
+ status_code=400,
+ detail=str(e)
+ )
+ except Exception as e:
+ logger.error(f"Error creating role: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to create role: {str(e)}"
+ )
+
+
+@router.get("/{roleId}", response_model=Dict[str, Any])
+@limiter.limit("60/minute")
+async def getRole(
+ request: Request,
+ roleId: str = Path(..., description="Role ID"),
+ currentUser: User = Depends(getCurrentUser)
+) -> Dict[str, Any]:
+ """
+ Get a role by ID.
+
+ Path Parameters:
+ - roleId: Role ID
+
+ Returns:
+ - Role dictionary
+ """
+ try:
+ _ensureAdminAccess(currentUser)
+
+ interface = getInterface(currentUser)
+
+ role = interface.getRole(roleId)
+ if not role:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Role {roleId} not found"
+ )
+
+ return {
+ "id": role.id,
+ "roleLabel": role.roleLabel,
+ "description": role.description,
+ "isSystemRole": role.isSystemRole
+ }
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error getting role: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to get role: {str(e)}"
+ )
+
+
+@router.put("/{roleId}", response_model=Dict[str, Any])
+@limiter.limit("30/minute")
+async def updateRole(
+ request: Request,
+ roleId: str = Path(..., description="Role ID"),
+ role: Role = Body(...),
+ currentUser: User = Depends(getCurrentUser)
+) -> Dict[str, Any]:
+ """
+ Update an existing role.
+
+ Path Parameters:
+ - roleId: Role ID
+
+ Request Body:
+ - role: Updated Role object
+
+ Returns:
+ - Updated role dictionary
+ """
+ try:
+ _ensureAdminAccess(currentUser)
+
+ interface = getInterface(currentUser)
+
+ updatedRole = interface.updateRole(roleId, role)
+
+ return {
+ "id": updatedRole.id,
+ "roleLabel": updatedRole.roleLabel,
+ "description": updatedRole.description,
+ "isSystemRole": updatedRole.isSystemRole
+ }
+
+ except HTTPException:
+ raise
+ except ValueError as e:
+ raise HTTPException(
+ status_code=400,
+ detail=str(e)
+ )
+ except Exception as e:
+ logger.error(f"Error updating role: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to update role: {str(e)}"
+ )
+
+
+@router.delete("/{roleId}", response_model=Dict[str, str])
+@limiter.limit("30/minute")
+async def deleteRole(
+ request: Request,
+ roleId: str = Path(..., description="Role ID"),
+ currentUser: User = Depends(getCurrentUser)
+) -> Dict[str, str]:
+ """
+ Delete a role.
+
+ Path Parameters:
+ - roleId: Role ID
+
+ Returns:
+ - Success message
+ """
+ try:
+ _ensureAdminAccess(currentUser)
+
+ interface = getInterface(currentUser)
+
+ success = interface.deleteRole(roleId)
+ if not success:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Role {roleId} not found"
+ )
+
+ return {"message": f"Role {roleId} deleted successfully"}
+
+ except HTTPException:
+ raise
+ except ValueError as e:
+ raise HTTPException(
+ status_code=400,
+ detail=str(e)
+ )
+ except Exception as e:
+ logger.error(f"Error deleting role: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to delete role: {str(e)}"
+ )
+
+
+@router.get("/users", response_model=List[Dict[str, Any]])
+@limiter.limit("60/minute")
+async def listUsersWithRoles(
+ request: Request,
+ roleLabel: Optional[str] = Query(None, description="Filter by role label"),
+ mandateId: Optional[str] = Query(None, description="Filter by mandate ID"),
+ currentUser: User = Depends(getCurrentUser)
+) -> List[Dict[str, Any]]:
+ """
+ Get list of users with their role assignments.
+
+ Query Parameters:
+ - roleLabel: Optional filter by role label
+ - mandateId: Optional filter by mandate ID
+
+ Returns:
+ - List of user dictionaries with role assignments
+ """
+ try:
+ _ensureAdminAccess(currentUser)
+
+ interface = getInterface(currentUser)
+
+ # Get users based on filters
+ if mandateId:
+ # Filter by mandate (if user has permission)
+ users = interface.getUsers()
+ users = [u for u in users if u.mandateId == mandateId]
+ else:
+ users = interface.getUsers()
+
+ # Filter by role if specified
+ if roleLabel:
+ users = [u for u in users if roleLabel in (u.roleLabels or [])]
+
+ # Format response
+ result = []
+ for user in users:
+ result.append({
+ "id": user.id,
+ "username": user.username,
+ "email": user.email,
+ "fullName": user.fullName,
+ "mandateId": user.mandateId,
+ "enabled": user.enabled,
+ "roleLabels": user.roleLabels or [],
+ "roleCount": len(user.roleLabels or [])
+ })
+
+ return result
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error listing users with roles: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to list users with roles: {str(e)}"
+ )
+
+
+@router.get("/users/{userId}", response_model=Dict[str, Any])
+@limiter.limit("60/minute")
+async def getUserRoles(
+ request: Request,
+ userId: str = Path(..., description="User ID"),
+ currentUser: User = Depends(getCurrentUser)
+) -> Dict[str, Any]:
+ """
+ Get role assignments for a specific user.
+
+ Path Parameters:
+ - userId: User ID
+
+ Returns:
+ - User dictionary with role assignments
+ """
+ try:
+ _ensureAdminAccess(currentUser)
+
+ interface = getInterface(currentUser)
+
+ # Get user
+ user = interface.getUser(userId)
+ if not user:
+ raise HTTPException(
+ status_code=404,
+ detail=f"User {userId} not found"
+ )
+
+ return {
+ "id": user.id,
+ "username": user.username,
+ "email": user.email,
+ "fullName": user.fullName,
+ "mandateId": user.mandateId,
+ "enabled": user.enabled,
+ "roleLabels": user.roleLabels or [],
+ "roleCount": len(user.roleLabels or [])
+ }
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error getting user roles: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to get user roles: {str(e)}"
+ )
+
+
+@router.put("/users/{userId}/roles", response_model=Dict[str, Any])
+@limiter.limit("30/minute")
+async def updateUserRoles(
+ request: Request,
+ userId: str = Path(..., description="User ID"),
+ roleLabels: List[str] = Body(..., description="List of role labels to assign"),
+ currentUser: User = Depends(getCurrentUser)
+) -> Dict[str, Any]:
+ """
+ Update role assignments for a specific user.
+
+ Path Parameters:
+ - userId: User ID
+
+ Request Body:
+ - roleLabels: List of role labels to assign (e.g., ["admin", "user"])
+
+ Returns:
+ - Updated user dictionary with role assignments
+ """
+ try:
+ _ensureAdminAccess(currentUser)
+
+ interface = getInterface(currentUser)
+
+ # Get user
+ user = interface.getUser(userId)
+ if not user:
+ raise HTTPException(
+ status_code=404,
+ detail=f"User {userId} not found"
+ )
+
+ # Validate role labels (basic validation - check against standard roles)
+ standardRoles = ["sysadmin", "admin", "user", "viewer"]
+ for roleLabel in roleLabels:
+ if roleLabel not in standardRoles:
+ logger.warning(f"Non-standard role label assigned: {roleLabel}")
+
+ # Update user roles
+ userData = {
+ "roleLabels": roleLabels
+ }
+
+ updatedUser = interface.updateUser(userId, userData)
+
+ logger.info(f"Updated roles for user {userId}: {roleLabels}")
+
+ return {
+ "id": updatedUser.id,
+ "username": updatedUser.username,
+ "email": updatedUser.email,
+ "fullName": updatedUser.fullName,
+ "mandateId": updatedUser.mandateId,
+ "enabled": updatedUser.enabled,
+ "roleLabels": updatedUser.roleLabels or [],
+ "roleCount": len(updatedUser.roleLabels or [])
+ }
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error updating user roles: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to update user roles: {str(e)}"
+ )
+
+
+@router.post("/users/{userId}/roles/{roleLabel}", response_model=Dict[str, Any])
+@limiter.limit("30/minute")
+async def addUserRole(
+ request: Request,
+ userId: str = Path(..., description="User ID"),
+ roleLabel: str = Path(..., description="Role label to add"),
+ currentUser: User = Depends(getCurrentUser)
+) -> Dict[str, Any]:
+ """
+ Add a role to a user (if not already assigned).
+
+ Path Parameters:
+ - userId: User ID
+ - roleLabel: Role label to add
+
+ Returns:
+ - Updated user dictionary with role assignments
+ """
+ try:
+ _ensureAdminAccess(currentUser)
+
+ interface = getInterface(currentUser)
+
+ # Get user
+ user = interface.getUser(userId)
+ if not user:
+ raise HTTPException(
+ status_code=404,
+ detail=f"User {userId} not found"
+ )
+
+ # Get current roles
+ currentRoles = list(user.roleLabels or [])
+
+ # Add role if not already present
+ if roleLabel not in currentRoles:
+ currentRoles.append(roleLabel)
+
+ # Update user roles
+ userData = {
+ "roleLabels": currentRoles
+ }
+
+ updatedUser = interface.updateUser(userId, userData)
+
+ logger.info(f"Added role {roleLabel} to user {userId}")
+ else:
+ updatedUser = user
+
+ return {
+ "id": updatedUser.id,
+ "username": updatedUser.username,
+ "email": updatedUser.email,
+ "fullName": updatedUser.fullName,
+ "mandateId": updatedUser.mandateId,
+ "enabled": updatedUser.enabled,
+ "roleLabels": updatedUser.roleLabels or [],
+ "roleCount": len(updatedUser.roleLabels or [])
+ }
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error adding role to user: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to add role to user: {str(e)}"
+ )
+
+
+@router.delete("/users/{userId}/roles/{roleLabel}", response_model=Dict[str, Any])
+@limiter.limit("30/minute")
+async def removeUserRole(
+ request: Request,
+ userId: str = Path(..., description="User ID"),
+ roleLabel: str = Path(..., description="Role label to remove"),
+ currentUser: User = Depends(getCurrentUser)
+) -> Dict[str, Any]:
+ """
+ Remove a role from a user.
+
+ Path Parameters:
+ - userId: User ID
+ - roleLabel: Role label to remove
+
+ Returns:
+ - Updated user dictionary with role assignments
+ """
+ try:
+ _ensureAdminAccess(currentUser)
+
+ interface = getInterface(currentUser)
+
+ # Get user
+ user = interface.getUser(userId)
+ if not user:
+ raise HTTPException(
+ status_code=404,
+ detail=f"User {userId} not found"
+ )
+
+ # Get current roles
+ currentRoles = list(user.roleLabels or [])
+
+ # Remove role if present
+ if roleLabel in currentRoles:
+ currentRoles.remove(roleLabel)
+
+ # Ensure user has at least one role (default to "user")
+ if not currentRoles:
+ currentRoles = ["user"]
+ logger.warning(f"User {userId} had all roles removed, defaulting to 'user' role")
+
+ # Update user roles
+ userData = {
+ "roleLabels": currentRoles
+ }
+
+ updatedUser = interface.updateUser(userId, userData)
+
+ logger.info(f"Removed role {roleLabel} from user {userId}")
+ else:
+ updatedUser = user
+
+ return {
+ "id": updatedUser.id,
+ "username": updatedUser.username,
+ "email": updatedUser.email,
+ "fullName": updatedUser.fullName,
+ "mandateId": updatedUser.mandateId,
+ "enabled": updatedUser.enabled,
+ "roleLabels": updatedUser.roleLabels or [],
+ "roleCount": len(updatedUser.roleLabels or [])
+ }
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error removing role from user: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to remove role from user: {str(e)}"
+ )
+
+
+@router.get("/roles/{roleLabel}/users", response_model=List[Dict[str, Any]])
+@limiter.limit("60/minute")
+async def getUsersWithRole(
+ request: Request,
+ roleLabel: str = Path(..., description="Role label"),
+ mandateId: Optional[str] = Query(None, description="Filter by mandate ID"),
+ currentUser: User = Depends(getCurrentUser)
+) -> List[Dict[str, Any]]:
+ """
+ Get all users with a specific role.
+
+ Path Parameters:
+ - roleLabel: Role label
+
+ Query Parameters:
+ - mandateId: Optional filter by mandate ID
+
+ Returns:
+ - List of users with the specified role
+ """
+ try:
+ _ensureAdminAccess(currentUser)
+
+ interface = getInterface(currentUser)
+
+ # Get all users
+ users = interface.getUsers()
+
+ # Filter by role
+ users = [u for u in users if roleLabel in (u.roleLabels or [])]
+
+ # Filter by mandate if specified
+ if mandateId:
+ users = [u for u in users if u.mandateId == mandateId]
+
+ # Format response
+ result = []
+ for user in users:
+ result.append({
+ "id": user.id,
+ "username": user.username,
+ "email": user.email,
+ "fullName": user.fullName,
+ "mandateId": user.mandateId,
+ "enabled": user.enabled,
+ "roleLabels": user.roleLabels or [],
+ "roleCount": len(user.roleLabels or [])
+ })
+
+ return result
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error getting users with role: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to get users with role: {str(e)}"
+ )
diff --git a/modules/routes/routeAttributes.py b/modules/routes/routeAttributes.py
index 5ada9a4e..59c5e0d5 100644
--- a/modules/routes/routeAttributes.py
+++ b/modules/routes/routeAttributes.py
@@ -46,15 +46,29 @@ async def get_entity_attributes(
# Get model class and derive attributes from it
modelClass = modelClasses[entityType]
- attribute_defs = getModelAttributeDefinitions(modelClass)
+ try:
+ attribute_defs = getModelAttributeDefinitions(modelClass)
+ except Exception as e:
+ logger.error(f"Error getting attribute definitions for {entityType}: {str(e)}", exc_info=True)
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Error getting attribute definitions for {entityType}: {str(e)}"
+ )
# Convert dictionary attributes to AttributeDefinition objects
attribute_definitions = []
- for attr in attribute_defs["attributes"]:
- if isinstance(attr, dict) and attr.get('visible', True):
- attribute_definitions.append(AttributeDefinition(**attr))
- elif hasattr(attr, 'visible') and attr.visible:
- attribute_definitions.append(attr)
+ try:
+ for attr in attribute_defs["attributes"]:
+ if isinstance(attr, dict) and attr.get('visible', True):
+ attribute_definitions.append(AttributeDefinition(**attr))
+ elif hasattr(attr, 'visible') and attr.visible:
+ attribute_definitions.append(attr)
+ except Exception as e:
+ logger.error(f"Error converting attribute definitions for {entityType}: {str(e)}", exc_info=True)
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Error converting attribute definitions for {entityType}: {str(e)}"
+ )
return AttributeResponse(attributes=attribute_definitions)
diff --git a/modules/routes/routeOptions.py b/modules/routes/routeOptions.py
new file mode 100644
index 00000000..86d53c0f
--- /dev/null
+++ b/modules/routes/routeOptions.py
@@ -0,0 +1,81 @@
+"""
+Options API routes for dynamic frontend options.
+Provides endpoints for fetching options for select/multiselect fields.
+"""
+
+from fastapi import APIRouter, HTTPException, Depends, Query, Request
+from typing import List, Dict, Any
+import logging
+
+from modules.security.auth import getCurrentUser, limiter
+from modules.datamodels.datamodelUam import User
+from modules.features.options.mainOptions import getOptions, getAvailableOptionsNames
+
+# Configure logger
+logger = logging.getLogger(__name__)
+
+router = APIRouter(
+ prefix="/api/options",
+ tags=["Options"],
+ responses={404: {"description": "Not found"}}
+)
+
+
+@router.get("/{optionsName}", response_model=List[Dict[str, Any]])
+@limiter.limit("120/minute")
+async def getOptionsEndpoint(
+ request: Request,
+ optionsName: str,
+ currentUser: User = Depends(getCurrentUser)
+) -> List[Dict[str, Any]]:
+ """
+ Get options for a given options name.
+
+ Path Parameters:
+ - optionsName: Name of the options set (e.g., "user.role", "user.connection")
+
+ Returns:
+ - List of option dictionaries with "value" and "label" keys
+
+ Examples:
+ - GET /api/options/user.role
+ - GET /api/options/user.connection
+ - GET /api/options/auth.authority
+ - GET /api/options/connection.status
+ """
+ try:
+ options = getOptions(optionsName, currentUser)
+ return options
+ except ValueError as e:
+ raise HTTPException(
+ status_code=400,
+ detail=str(e)
+ )
+ except Exception as e:
+ logger.error(f"Error getting options for {optionsName}: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to get options: {str(e)}"
+ )
+
+
+@router.get("/", response_model=List[str])
+@limiter.limit("30/minute")
+async def listAvailableOptions(
+ request: Request,
+ currentUser: User = Depends(getCurrentUser)
+) -> List[str]:
+ """
+ Get list of all available options names.
+
+ Returns:
+ - List of available options names
+ """
+ try:
+ return getAvailableOptionsNames()
+ except Exception as e:
+ logger.error(f"Error listing available options: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to list options: {str(e)}"
+ )
diff --git a/modules/routes/routeRbac.py b/modules/routes/routeRbac.py
index 95184779..975f23b9 100644
--- a/modules/routes/routeRbac.py
+++ b/modules/routes/routeRbac.py
@@ -3,13 +3,13 @@ RBAC routes for the backend API.
Implements endpoints for role-based access control permissions.
"""
-from fastapi import APIRouter, HTTPException, Depends, Query, Request
-from typing import Optional
+from fastapi import APIRouter, HTTPException, Depends, Query, Body, Path, Request
+from typing import Optional, List, Dict, Any
import logging
from modules.security.auth import getCurrentUser, limiter
from modules.datamodels.datamodelUam import User, UserPermissions, AccessLevel
-from modules.datamodels.datamodelRbac import AccessRuleContext
+from modules.datamodels.datamodelRbac import AccessRuleContext, AccessRule, Role
from modules.interfaces.interfaceDbAppObjects import getInterface
# Configure logger
@@ -159,3 +159,623 @@ async def getAccessRules(
status_code=500,
detail=f"Failed to get access rules: {str(e)}"
)
+
+
+@router.get("/rules/{ruleId}", response_model=dict)
+@limiter.limit("30/minute")
+async def getAccessRule(
+ request: Request,
+ ruleId: str = Path(..., description="Access rule ID"),
+ currentUser: User = Depends(getCurrentUser)
+) -> dict:
+ """
+ Get a specific access rule by ID.
+ Only returns rule if the current user has permission to view it.
+
+ Path Parameters:
+ - ruleId: Access rule ID
+
+ Returns:
+ - AccessRule object
+ """
+ try:
+ # Get interface
+ interface = getInterface(currentUser)
+
+ # Check if user has permission to view access rules
+ if not interface.rbac:
+ raise HTTPException(
+ status_code=500,
+ detail="RBAC interface not available"
+ )
+
+ # Check permission - only sysadmin can view rules
+ permissions = interface.rbac.getUserPermissions(
+ currentUser,
+ AccessRuleContext.DATA,
+ "AccessRule"
+ )
+
+ if not permissions.view or permissions.read == AccessLevel.NONE:
+ raise HTTPException(
+ status_code=403,
+ detail="No permission to view access rules"
+ )
+
+ # Get rule
+ rule = interface.getAccessRule(ruleId)
+ if not rule:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Access rule {ruleId} not found"
+ )
+
+ # Convert to dict for JSON serialization
+ return rule.model_dump()
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error getting access rule {ruleId}: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to get access rule: {str(e)}"
+ )
+
+
+@router.post("/rules", response_model=dict)
+@limiter.limit("30/minute")
+async def createAccessRule(
+ request: Request,
+ accessRuleData: dict = Body(..., description="Access rule data"),
+ currentUser: User = Depends(getCurrentUser)
+) -> dict:
+ """
+ Create a new access rule.
+ Only sysadmin can create access rules.
+
+ Request Body:
+ - AccessRule object data (roleLabel, context, item, view, read, create, update, delete)
+
+ Returns:
+ - Created AccessRule object
+ """
+ try:
+ # Get interface
+ interface = getInterface(currentUser)
+
+ # Check if user has permission to create access rules
+ if not interface.rbac:
+ raise HTTPException(
+ status_code=500,
+ detail="RBAC interface not available"
+ )
+
+ # Check permission - only sysadmin can create rules
+ permissions = interface.rbac.getUserPermissions(
+ currentUser,
+ AccessRuleContext.DATA,
+ "AccessRule"
+ )
+
+ if not permissions.create or permissions.create == AccessLevel.NONE:
+ raise HTTPException(
+ status_code=403,
+ detail="No permission to create access rules"
+ )
+
+ # Validate and parse access rule data
+ try:
+ # Parse context if provided as string
+ if "context" in accessRuleData and isinstance(accessRuleData["context"], str):
+ accessRuleData["context"] = AccessRuleContext(accessRuleData["context"].upper())
+
+ # Parse AccessLevel fields if provided as strings
+ for field in ["read", "create", "update", "delete"]:
+ if field in accessRuleData and isinstance(accessRuleData[field], str):
+ accessRuleData[field] = AccessLevel(accessRuleData[field])
+
+ # Create AccessRule object
+ accessRule = AccessRule(**accessRuleData)
+ except ValueError as e:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid access rule data: {str(e)}"
+ )
+
+ # Create rule
+ createdRule = interface.createAccessRule(accessRule)
+
+ logger.info(f"Created access rule {createdRule.id} by user {currentUser.id}")
+
+ # Convert to dict for JSON serialization
+ return createdRule.model_dump()
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error creating access rule: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to create access rule: {str(e)}"
+ )
+
+
+@router.put("/rules/{ruleId}", response_model=dict)
+@limiter.limit("30/minute")
+async def updateAccessRule(
+ request: Request,
+ ruleId: str = Path(..., description="Access rule ID"),
+ accessRuleData: dict = Body(..., description="Updated access rule data"),
+ currentUser: User = Depends(getCurrentUser)
+) -> dict:
+ """
+ Update an existing access rule.
+ Only sysadmin can update access rules.
+
+ Path Parameters:
+ - ruleId: Access rule ID
+
+ Request Body:
+ - AccessRule object data (roleLabel, context, item, view, read, create, update, delete)
+
+ Returns:
+ - Updated AccessRule object
+ """
+ try:
+ # Get interface
+ interface = getInterface(currentUser)
+
+ # Check if user has permission to update access rules
+ if not interface.rbac:
+ raise HTTPException(
+ status_code=500,
+ detail="RBAC interface not available"
+ )
+
+ # Check permission - only sysadmin can update rules
+ permissions = interface.rbac.getUserPermissions(
+ currentUser,
+ AccessRuleContext.DATA,
+ "AccessRule"
+ )
+
+ if not permissions.update or permissions.update == AccessLevel.NONE:
+ raise HTTPException(
+ status_code=403,
+ detail="No permission to update access rules"
+ )
+
+ # Get existing rule to ensure it exists
+ existingRule = interface.getAccessRule(ruleId)
+ if not existingRule:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Access rule {ruleId} not found"
+ )
+
+ # Validate and parse access rule data
+ try:
+ # Merge with existing rule data
+ updateData = existingRule.model_dump()
+ updateData.update(accessRuleData)
+
+ # Parse context if provided as string
+ if "context" in updateData and isinstance(updateData["context"], str):
+ updateData["context"] = AccessRuleContext(updateData["context"].upper())
+
+ # Parse AccessLevel fields if provided as strings
+ for field in ["read", "create", "update", "delete"]:
+ if field in updateData and isinstance(updateData[field], str):
+ updateData[field] = AccessLevel(updateData[field])
+
+ # Ensure ID is set correctly
+ updateData["id"] = ruleId
+
+ # Create AccessRule object
+ accessRule = AccessRule(**updateData)
+ except ValueError as e:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid access rule data: {str(e)}"
+ )
+
+ # Update rule
+ updatedRule = interface.updateAccessRule(ruleId, accessRule)
+
+ logger.info(f"Updated access rule {ruleId} by user {currentUser.id}")
+
+ # Convert to dict for JSON serialization
+ return updatedRule.model_dump()
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error updating access rule {ruleId}: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to update access rule: {str(e)}"
+ )
+
+
+@router.delete("/rules/{ruleId}")
+@limiter.limit("30/minute")
+async def deleteAccessRule(
+ request: Request,
+ ruleId: str = Path(..., description="Access rule ID"),
+ currentUser: User = Depends(getCurrentUser)
+) -> dict:
+ """
+ Delete an access rule.
+ Only sysadmin can delete access rules.
+
+ Path Parameters:
+ - ruleId: Access rule ID
+
+ Returns:
+ - Success message
+ """
+ try:
+ # Get interface
+ interface = getInterface(currentUser)
+
+ # Check if user has permission to delete access rules
+ if not interface.rbac:
+ raise HTTPException(
+ status_code=500,
+ detail="RBAC interface not available"
+ )
+
+ # Check permission - only sysadmin can delete rules
+ permissions = interface.rbac.getUserPermissions(
+ currentUser,
+ AccessRuleContext.DATA,
+ "AccessRule"
+ )
+
+ if not permissions.delete or permissions.delete == AccessLevel.NONE:
+ raise HTTPException(
+ status_code=403,
+ detail="No permission to delete access rules"
+ )
+
+ # Get existing rule to ensure it exists
+ existingRule = interface.getAccessRule(ruleId)
+ if not existingRule:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Access rule {ruleId} not found"
+ )
+
+ # Delete rule
+ success = interface.deleteAccessRule(ruleId)
+
+ if not success:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to delete access rule {ruleId}"
+ )
+
+ logger.info(f"Deleted access rule {ruleId} by user {currentUser.id}")
+
+ return {"success": True, "message": f"Access rule {ruleId} deleted successfully"}
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error deleting access rule {ruleId}: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to delete access rule: {str(e)}"
+ )
+
+
+# ============================================================================
+# Role Management Endpoints
+# ============================================================================
+
+def _ensureAdminAccess(currentUser: User) -> None:
+ """Ensure current user has admin access to RBAC roles management."""
+ interface = getInterface(currentUser)
+
+ # Check if user has admin or sysadmin role
+ roleLabels = currentUser.roleLabels or []
+ if "sysadmin" not in roleLabels and "admin" not in roleLabels:
+ raise HTTPException(
+ status_code=403,
+ detail="Admin or sysadmin role required to manage RBAC roles"
+ )
+
+
+@router.get("/roles", response_model=List[Dict[str, Any]])
+@limiter.limit("60/minute")
+async def listRoles(
+ request: Request,
+ currentUser: User = Depends(getCurrentUser)
+) -> List[Dict[str, Any]]:
+ """
+ Get list of all available roles with metadata.
+
+ Returns:
+ - List of role dictionaries with role label, description, and user count
+ """
+ try:
+ _ensureAdminAccess(currentUser)
+
+ interface = getInterface(currentUser)
+
+ # Get all roles from database
+ dbRoles = interface.getAllRoles()
+
+ # Get all users to count role assignments
+ # Since _ensureAdminAccess ensures user is sysadmin or admin,
+ # and getUsersByMandate returns all users for sysadmin regardless of mandateId,
+ # we can pass the current user's mandateId (for sysadmin it will be ignored by RBAC)
+ allUsers = interface.getUsersByMandate(currentUser.mandateId or "")
+
+ # Count users per role
+ roleCounts: Dict[str, int] = {}
+ for user in allUsers:
+ for roleLabel in (user.roleLabels or []):
+ roleCounts[roleLabel] = roleCounts.get(roleLabel, 0) + 1
+
+ # Convert Role objects to dictionaries and add user counts
+ result = []
+ for role in dbRoles:
+ result.append({
+ "id": role.id,
+ "roleLabel": role.roleLabel,
+ "description": role.description,
+ "userCount": roleCounts.get(role.roleLabel, 0),
+ "isSystemRole": role.isSystemRole
+ })
+
+ # Add any roles found in user assignments that don't exist in database
+ dbRoleLabels = {role.roleLabel for role in dbRoles}
+ for roleLabel, count in roleCounts.items():
+ if roleLabel not in dbRoleLabels:
+ result.append({
+ "id": None,
+ "roleLabel": roleLabel,
+ "description": {"en": f"Custom role: {roleLabel}", "fr": f"Rôle personnalisé : {roleLabel}"},
+ "userCount": count,
+ "isSystemRole": False
+ })
+
+ return result
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error listing roles: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to list roles: {str(e)}"
+ )
+
+
+@router.get("/roles/options", response_model=List[Dict[str, Any]])
+@limiter.limit("60/minute")
+async def getRoleOptions(
+ request: Request,
+ currentUser: User = Depends(getCurrentUser)
+) -> List[Dict[str, Any]]:
+ """
+ Get role options for select dropdowns.
+ Returns roles in format suitable for frontend select components.
+
+ Returns:
+ - List of role option dictionaries with value and label
+ """
+ try:
+ _ensureAdminAccess(currentUser)
+
+ interface = getInterface(currentUser)
+
+ # Get all roles from database
+ dbRoles = interface.getAllRoles()
+
+ # Convert to options format
+ options = []
+ for role in dbRoles:
+ # Use English description as label, fallback to roleLabel
+ label = role.description.get("en", role.roleLabel) if isinstance(role.description, dict) else role.roleLabel
+ options.append({
+ "value": role.roleLabel,
+ "label": label
+ })
+
+ return options
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error getting role options: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to get role options: {str(e)}"
+ )
+
+
+@router.post("/roles", response_model=Dict[str, Any])
+@limiter.limit("30/minute")
+async def createRole(
+ request: Request,
+ role: Role = Body(...),
+ currentUser: User = Depends(getCurrentUser)
+) -> Dict[str, Any]:
+ """
+ Create a new role.
+
+ Request Body:
+ - role: Role object to create
+
+ Returns:
+ - Created role dictionary
+ """
+ try:
+ _ensureAdminAccess(currentUser)
+
+ interface = getInterface(currentUser)
+
+ createdRole = interface.createRole(role)
+
+ return {
+ "id": createdRole.id,
+ "roleLabel": createdRole.roleLabel,
+ "description": createdRole.description,
+ "isSystemRole": createdRole.isSystemRole
+ }
+
+ except HTTPException:
+ raise
+ except ValueError as e:
+ raise HTTPException(
+ status_code=400,
+ detail=str(e)
+ )
+ except Exception as e:
+ logger.error(f"Error creating role: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to create role: {str(e)}"
+ )
+
+
+@router.get("/roles/{roleId}", response_model=Dict[str, Any])
+@limiter.limit("60/minute")
+async def getRole(
+ request: Request,
+ roleId: str = Path(..., description="Role ID"),
+ currentUser: User = Depends(getCurrentUser)
+) -> Dict[str, Any]:
+ """
+ Get a role by ID.
+
+ Path Parameters:
+ - roleId: Role ID
+
+ Returns:
+ - Role dictionary
+ """
+ try:
+ _ensureAdminAccess(currentUser)
+
+ interface = getInterface(currentUser)
+
+ role = interface.getRole(roleId)
+ if not role:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Role {roleId} not found"
+ )
+
+ return {
+ "id": role.id,
+ "roleLabel": role.roleLabel,
+ "description": role.description,
+ "isSystemRole": role.isSystemRole
+ }
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error getting role: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to get role: {str(e)}"
+ )
+
+
+@router.put("/roles/{roleId}", response_model=Dict[str, Any])
+@limiter.limit("30/minute")
+async def updateRole(
+ request: Request,
+ roleId: str = Path(..., description="Role ID"),
+ role: Role = Body(...),
+ currentUser: User = Depends(getCurrentUser)
+) -> Dict[str, Any]:
+ """
+ Update an existing role.
+
+ Path Parameters:
+ - roleId: Role ID
+
+ Request Body:
+ - role: Updated Role object
+
+ Returns:
+ - Updated role dictionary
+ """
+ try:
+ _ensureAdminAccess(currentUser)
+
+ interface = getInterface(currentUser)
+
+ updatedRole = interface.updateRole(roleId, role)
+
+ return {
+ "id": updatedRole.id,
+ "roleLabel": updatedRole.roleLabel,
+ "description": updatedRole.description,
+ "isSystemRole": updatedRole.isSystemRole
+ }
+
+ except HTTPException:
+ raise
+ except ValueError as e:
+ raise HTTPException(
+ status_code=400,
+ detail=str(e)
+ )
+ except Exception as e:
+ logger.error(f"Error updating role: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to update role: {str(e)}"
+ )
+
+
+@router.delete("/roles/{roleId}", response_model=Dict[str, str])
+@limiter.limit("30/minute")
+async def deleteRole(
+ request: Request,
+ roleId: str = Path(..., description="Role ID"),
+ currentUser: User = Depends(getCurrentUser)
+) -> Dict[str, str]:
+ """
+ Delete a role.
+
+ Path Parameters:
+ - roleId: Role ID
+
+ Returns:
+ - Success message
+ """
+ try:
+ _ensureAdminAccess(currentUser)
+
+ interface = getInterface(currentUser)
+
+ success = interface.deleteRole(roleId)
+ if not success:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Role {roleId} not found"
+ )
+
+ return {"message": f"Role {roleId} deleted successfully"}
+
+ except HTTPException:
+ raise
+ except ValueError as e:
+ raise HTTPException(
+ status_code=400,
+ detail=str(e)
+ )
+ except Exception as e:
+ logger.error(f"Error deleting role: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to delete role: {str(e)}"
+ )
diff --git a/modules/security/rbac.py b/modules/security/rbac.py
index ca2050de..c783172b 100644
--- a/modules/security/rbac.py
+++ b/modules/security/rbac.py
@@ -20,9 +20,17 @@ class RbacClass:
RBAC interface for permission resolution and rule validation.
"""
- def __init__(self, db: "DatabaseConnector"):
- """Initialize RBAC interface with database connector."""
+ def __init__(self, db: "DatabaseConnector", dbApp: "DatabaseConnector"):
+ """
+ Initialize RBAC interface with database connector.
+
+ Args:
+ db: Database connector for general operations (may be from any database)
+ dbApp: DbApp database connector for AccessRule queries.
+ AccessRule table is always in the DbApp database.
+ """
self.db = db
+ self.dbApp = dbApp
def getUserPermissions(self, user: User, context: AccessRuleContext, item: str) -> UserPermissions:
"""
@@ -44,8 +52,7 @@ class RbacClass:
delete=AccessLevel.NONE
)
- if not user.roleLabels:
- logger.warning(f"User {user.id} has no roleLabels assigned")
+ if not hasattr(user, 'roleLabels') or not user.roleLabels:
return permissions
# Step 1: For each role, find the most specific matching rule (most specific wins within role)
@@ -171,6 +178,7 @@ class RbacClass:
def _getRulesForRole(self, roleLabel: str, context: AccessRuleContext) -> List[AccessRule]:
"""
Get all access rules for a specific role and context.
+ Always queries from DbApp database, not the current database.
Args:
roleLabel: Role label to get rules for
@@ -180,15 +188,25 @@ class RbacClass:
List of AccessRule objects
"""
try:
- rules = self.db.getRecordset(
+ # Always use DbApp database for AccessRule queries
+ rules = self.dbApp.getRecordset(
AccessRule,
recordFilter={
"roleLabel": roleLabel,
"context": context.value
}
)
+
# Convert dict records to AccessRule objects
- return [AccessRule(**record) for record in rules]
+ accessRules = []
+ for record in rules:
+ try:
+ accessRule = AccessRule(**record)
+ accessRules.append(accessRule)
+ except Exception as e:
+ logger.error(f"Error converting rule record to AccessRule: {e}, record={record}")
+
+ return accessRules
except Exception as e:
- logger.error(f"Error getting rules for role {roleLabel} and context {context.value}: {e}")
+ logger.error(f"Error getting rules for role {roleLabel} and context {context.value}: {e}", exc_info=True)
return []
diff --git a/modules/shared/attributeUtils.py b/modules/shared/attributeUtils.py
index b88a94e7..9116d330 100644
--- a/modules/shared/attributeUtils.py
+++ b/modules/shared/attributeUtils.py
@@ -3,7 +3,7 @@ Shared utilities for model attributes and labels.
"""
from pydantic import BaseModel, Field, ConfigDict
-from typing import Dict, Any, List, Type, Optional
+from typing import Dict, Any, List, Type, Optional, Union
import inspect
import importlib
import os
@@ -22,7 +22,7 @@ class AttributeDefinition(BaseModel):
description: Optional[str] = None
required: bool = False
default: Any = None
- options: Optional[List[Any]] = None
+ options: Optional[Union[str, List[Any]]] = None # Can be a string reference (e.g., "user.role") or a list of options
validation: Optional[Dict[str, Any]] = None
ui: Optional[Dict[str, Any]] = None
# New frontend metadata fields
@@ -194,14 +194,20 @@ def getModelAttributeDefinitions(modelClass: Type[BaseModel] = None, userLanguag
else:
field_default = default_value
+ # Safely get description
+ description = ""
+ try:
+ if hasattr(field_info, "description") and field_info.description:
+ description = str(field_info.description)
+ except Exception:
+ pass
+
attributes.append(
{
"name": name,
"type": field_type,
"required": frontend_required,
- "description": field.description
- if hasattr(field, "description")
- else "",
+ "description": description,
"label": labels.get(name, name),
"placeholder": f"Please enter {labels.get(name, name)}",
"editable": not frontend_readonly,
@@ -259,17 +265,21 @@ def getModelClasses() -> Dict[str, Type[BaseModel]]:
# Convert fileName to module name (e.g., datamodelUtils.py -> datamodelUtils)
module_name = fileName[:-3]
- # Import the module dynamically
- module = importlib.import_module(f"modules.datamodels.{module_name}")
+ try:
+ # Import the module dynamically
+ module = importlib.import_module(f"modules.datamodels.{module_name}")
- # Get all classes from the module
- for name, obj in inspect.getmembers(module):
- if (
- inspect.isclass(obj)
- and issubclass(obj, BaseModel)
- and obj != BaseModel
- ):
- modelClasses[name] = obj
+ # Get all classes from the module
+ for name, obj in inspect.getmembers(module):
+ if (
+ inspect.isclass(obj)
+ and issubclass(obj, BaseModel)
+ and obj != BaseModel
+ ):
+ modelClasses[name] = obj
+ except Exception as e:
+ logger.warning(f"Error importing module {module_name}: {str(e)}", exc_info=True)
+ # Continue with other modules even if one fails
return modelClasses
diff --git a/modules/shared/frontendOptionsTypes.py b/modules/shared/frontendOptionsTypes.py
new file mode 100644
index 00000000..d31ff558
--- /dev/null
+++ b/modules/shared/frontendOptionsTypes.py
@@ -0,0 +1,136 @@
+"""
+Type definitions and utilities for frontend_options attribute.
+
+The frontend_options attribute supports two formats:
+1. Static List: A list of option dictionaries for static options
+2. String Reference: A string identifier that references dynamic options from /api/options/{optionsName}
+"""
+
+from typing import List, Dict, Any, Union
+
+try:
+ from typing import TypeAlias # Python 3.10+
+except ImportError:
+ from typing_extensions import TypeAlias # Python < 3.10
+
+# Type definition for a single option item
+OptionItem: TypeAlias = Dict[str, Any]
+"""
+Single option item format:
+{
+ "value": str, # The value to be stored/returned
+ "label": { # Multilingual labels
+ "en": str,
+ "fr": str,
+ ...
+ }
+}
+"""
+
+# Type definition for frontend_options - can be either a list or string reference
+FrontendOptions: TypeAlias = Union[List[OptionItem], str]
+"""
+frontend_options can be either:
+1. List[OptionItem]: Static list of options
+ Example: [{"value": "a", "label": {"en": "All", "fr": "Tous"}}]
+
+2. str: String reference to dynamic options API
+ Example: "user.role" -> Frontend fetches from /api/options/user.role
+"""
+
+
+def isStringReference(frontendOptions: FrontendOptions) -> bool:
+ """
+ Check if frontend_options is a string reference (dynamic) or a list (static).
+
+ Args:
+ frontendOptions: The frontend_options value to check
+
+ Returns:
+ True if it's a string reference, False if it's a list
+ """
+ return isinstance(frontendOptions, str)
+
+
+def isStaticList(frontendOptions: FrontendOptions) -> bool:
+ """
+ Check if frontend_options is a static list or a string reference.
+
+ Args:
+ frontendOptions: The frontend_options value to check
+
+ Returns:
+ True if it's a static list, False if it's a string reference
+ """
+ return isinstance(frontendOptions, list)
+
+
+def validateFrontendOptions(frontendOptions: FrontendOptions) -> bool:
+ """
+ Validate that frontend_options is in the correct format.
+
+ Args:
+ frontendOptions: The frontend_options value to validate
+
+ Returns:
+ True if valid, False otherwise
+ """
+ if isinstance(frontendOptions, str):
+ # String reference: should be a non-empty string
+ return bool(frontendOptions.strip())
+
+ elif isinstance(frontendOptions, list):
+ # Static list: should contain option dictionaries
+ if not frontendOptions:
+ return True # Empty list is valid (no options)
+
+ for option in frontendOptions:
+ if not isinstance(option, dict):
+ return False
+ if "value" not in option:
+ return False
+ if "label" not in option:
+ return False
+ if not isinstance(option["label"], dict):
+ return False
+
+ return True
+
+ else:
+ return False
+
+
+def getOptionsName(frontendOptions: FrontendOptions) -> str:
+ """
+ Get the options name from a string reference.
+
+ Args:
+ frontendOptions: The frontend_options value (must be a string reference)
+
+ Returns:
+ The options name (e.g., "user.role")
+
+ Raises:
+ ValueError: If frontendOptions is not a string reference
+ """
+ if not isStringReference(frontendOptions):
+ raise ValueError(f"frontend_options is not a string reference: {type(frontendOptions)}")
+ return frontendOptions
+
+
+def getStaticOptions(frontendOptions: FrontendOptions) -> List[OptionItem]:
+ """
+ Get the static options list.
+
+ Args:
+ frontendOptions: The frontend_options value (must be a static list)
+
+ Returns:
+ The list of option items
+
+ Raises:
+ ValueError: If frontendOptions is not a static list
+ """
+ if not isStaticList(frontendOptions):
+ raise ValueError(f"frontend_options is not a static list: {type(frontendOptions)}")
+ return frontendOptions
diff --git a/pytest.ini b/pytest.ini
index ad1e22f2..0a8eb39c 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -11,3 +11,12 @@ log_file_date_format = %Y-%m-%d %H:%M:%S
# Only run non-expensive tests by default, verbose log, short traceback
# Use 'pytest -m ""' to run ALL tests.
addopts = -v --tb=short -m 'not expensive'
+
+# Suppress deprecation warnings from third-party libraries
+filterwarnings =
+ ignore::DeprecationWarning:pkg_resources
+ ignore::DeprecationWarning:google.cloud.translate_v2
+ ignore::DeprecationWarning:passlib.handlers.argon2
+ ignore:pkg_resources is deprecated:DeprecationWarning
+ ignore:Deprecated call to.*pkg_resources.declare_namespace:DeprecationWarning
+ ignore:Accessing argon2.__version__ is deprecated:DeprecationWarning
diff --git a/tests/functional/test_kpi_fix.py b/tests/functional/test_kpi_fix.py
deleted file mode 100644
index 1e864815..00000000
--- a/tests/functional/test_kpi_fix.py
+++ /dev/null
@@ -1,86 +0,0 @@
-"""Test KPI extraction fix with incomplete JSON"""
-import json
-import sys
-import os
-
-# Add gateway directory to path
-_gateway_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
-if _gateway_path not in sys.path:
- sys.path.insert(0, _gateway_path)
-
-from modules.services.serviceAi.subJsonResponseHandling import JsonResponseHandler
-from modules.datamodels.datamodelAi import JsonAccumulationState
-
-# Load actual incomplete JSON response
-json_file = os.path.join(
- os.path.dirname(__file__),
- "..", "..", "..", "local", "debug", "prompts",
- "20251130-211706-078-document_generation_response.txt"
-)
-
-with open(json_file, 'r', encoding='utf-8') as f:
- incompleteJsonString = f.read()
-
-# KPI definition
-kpiDefinitions = [{
- "id": "prime_numbers_count",
- "description": "Number of prime numbers generated and organized in the table",
- "jsonPath": "documents[0].sections[0].elements[0].rows",
- "targetValue": 4000
-}]
-
-print("="*60)
-print("KPI EXTRACTION FIX TEST")
-print("="*60)
-
-# Test 1: Extract from incomplete JSON string
-print(f"\nTest 1: Extracting from incomplete JSON string...")
-updatedKpis = JsonResponseHandler.extractKpiValuesFromIncompleteJson(
- incompleteJsonString,
- [{**kpi, "currentValue": 0} for kpi in kpiDefinitions]
-)
-
-print(f" Result: {updatedKpis[0].get('currentValue', 'N/A')} rows")
-print(f" Expected: ~400 rows (incomplete JSON)")
-
-# Test 2: Compare with repaired JSON
-print(f"\nTest 2: Comparing with repaired JSON...")
-from modules.shared.jsonUtils import extractJsonString, repairBrokenJson
-
-extracted = extractJsonString(incompleteJsonString)
-repaired = repairBrokenJson(extracted)
-
-if repaired:
- repairedKpis = JsonResponseHandler.extractKpiValuesFromJson(
- repaired,
- [{**kpi, "currentValue": 0} for kpi in kpiDefinitions]
- )
- print(f" Repaired JSON: {repairedKpis[0].get('currentValue', 'N/A')} rows")
- print(f" Incomplete JSON string: {updatedKpis[0].get('currentValue', 'N/A')} rows")
-
- if updatedKpis[0].get('currentValue', 0) > repairedKpis[0].get('currentValue', 0):
- print(f" ✅ Fix works! Incomplete JSON string extraction found more data")
- else:
- print(f" ⚠️ Both methods found same or less data")
-
-# Test 3: Validate progression
-print(f"\nTest 3: Testing KPI validation...")
-accumulationState = JsonAccumulationState(
- accumulatedJsonString=incompleteJsonString,
- isAccumulationMode=True,
- lastParsedResult=repaired,
- allSections=[],
- kpis=[{**kpi, "currentValue": 0} for kpi in kpiDefinitions]
-)
-
-shouldProceed, reason = JsonResponseHandler.validateKpiProgression(
- accumulationState,
- updatedKpis
-)
-
-print(f" Result: shouldProceed={shouldProceed}, reason={reason}")
-if shouldProceed:
- print(f" ✅ Validation passes - KPIs will progress correctly")
-else:
- print(f" ❌ Validation fails - {reason}")
-
diff --git a/tests/functional/test_kpi_full.py b/tests/functional/test_kpi_full.py
index 2d73f4be..e8cf1ec1 100644
--- a/tests/functional/test_kpi_full.py
+++ b/tests/functional/test_kpi_full.py
@@ -2,6 +2,7 @@
import json
import sys
import os
+import pytest
# Add gateway directory to path
_gateway_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
@@ -19,8 +20,7 @@ json_file = os.path.join(
)
if not os.path.exists(json_file):
- print(f"File not found: {json_file}")
- sys.exit(1)
+ pytest.skip(f"Test data file not found: {json_file}", allow_module_level=True)
with open(json_file, 'r', encoding='utf-8') as f:
content = f.read()
diff --git a/tests/functional/test_kpi_incomplete.py b/tests/functional/test_kpi_incomplete.py
index e308246f..a6d724e9 100644
--- a/tests/functional/test_kpi_incomplete.py
+++ b/tests/functional/test_kpi_incomplete.py
@@ -2,6 +2,7 @@
import json
import sys
import os
+import pytest
# Add gateway directory to path
_gateway_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
@@ -20,8 +21,7 @@ json_file = os.path.join(
)
if not os.path.exists(json_file):
- print(f"File not found: {json_file}")
- sys.exit(1)
+ pytest.skip(f"Test data file not found: {json_file}", allow_module_level=True)
with open(json_file, 'r', encoding='utf-8') as f:
content = f.read()
@@ -54,8 +54,7 @@ except json.JSONDecodeError as e:
print(f" ❌ Repair error: {e2}")
if not parsedJson:
- print("\n❌ Cannot proceed - JSON cannot be parsed or repaired")
- sys.exit(1)
+ pytest.skip("Cannot proceed - JSON cannot be parsed or repaired", allow_module_level=True)
# Step 3: Check if path exists
print(f"\nStep 3: Checking if KPI path exists...")
@@ -73,7 +72,7 @@ except Exception as e:
print(f" ❌ Path extraction failed: {e}")
import traceback
traceback.print_exc()
- sys.exit(1)
+ pytest.skip(f"Path extraction failed: {e}", allow_module_level=True)
# Step 4: Test KPI extraction
print(f"\nStep 4: Testing KPI extraction...")
diff --git a/tests/functional/test_repair_debug.py b/tests/functional/test_repair_debug.py
deleted file mode 100644
index 1e60d725..00000000
--- a/tests/functional/test_repair_debug.py
+++ /dev/null
@@ -1,58 +0,0 @@
-"""Debug what repairBrokenJson returns"""
-import json
-import sys
-import os
-
-# Add gateway directory to path
-_gateway_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
-if _gateway_path not in sys.path:
- sys.path.insert(0, _gateway_path)
-
-from modules.shared.jsonUtils import extractJsonString, repairBrokenJson
-
-# Load actual incomplete JSON response
-json_file = os.path.join(
- os.path.dirname(__file__),
- "..", "..", "..", "local", "debug", "prompts",
- "20251130-211706-078-document_generation_response.txt"
-)
-
-with open(json_file, 'r', encoding='utf-8') as f:
- content = f.read()
-
-extracted = extractJsonString(content)
-print(f"Extracted JSON length: {len(extracted)} chars")
-print(f"Last 200 chars: {extracted[-200:]}")
-
-repaired = repairBrokenJson(extracted)
-if repaired:
- print(f"\nRepaired JSON structure:")
- print(f" Has 'documents': {'documents' in repaired}")
- if 'documents' in repaired and isinstance(repaired['documents'], list) and len(repaired['documents']) > 0:
- doc = repaired['documents'][0]
- print(f" Has 'sections': {'sections' in doc}")
- if 'sections' in doc and isinstance(doc['sections'], list) and len(doc['sections']) > 0:
- section = doc['sections'][0]
- print(f" Has 'elements': {'elements' in section}")
- if 'elements' in section and isinstance(section['elements'], list) and len(section['elements']) > 0:
- element = section['elements'][0]
- print(f" Has 'rows': {'rows' in element}")
- if 'rows' in element:
- rows = element['rows']
- print(f" Rows type: {type(rows)}")
- if isinstance(rows, list):
- print(f" Rows count: {len(rows)}")
- if len(rows) > 0:
- print(f" First row: {rows[0]}")
- print(f" Last row: {rows[-1]}")
- else:
- print(f" Rows value: {rows}")
-
- # Save to file for inspection
- output_file = os.path.join(os.path.dirname(__file__), "repaired_debug.json")
- with open(output_file, 'w', encoding='utf-8') as f:
- json.dump(repaired, f, indent=2, ensure_ascii=False)
- print(f"\nSaved repaired JSON to: {output_file}")
-else:
- print("Repair failed")
-
diff --git a/tests/integration/options/test_options_api.py b/tests/integration/options/test_options_api.py
new file mode 100644
index 00000000..ac9b5468
--- /dev/null
+++ b/tests/integration/options/test_options_api.py
@@ -0,0 +1,241 @@
+"""
+Integration tests for Options API endpoints.
+Tests the actual API endpoints with real database connections.
+"""
+
+import pytest
+import secrets
+from fastapi.testclient import TestClient
+from modules.datamodels.datamodelUam import User
+from modules.interfaces.interfaceDbAppObjects import getRootInterface
+
+
+@pytest.fixture
+def app():
+ """Create FastAPI app instance for testing."""
+ from app import app as fastapi_app
+ return fastapi_app
+
+
+@pytest.fixture
+def testClient(app):
+ """Create test client for API testing."""
+ return TestClient(app)
+
+
+@pytest.fixture
+def csrfToken():
+ """Generate a valid CSRF token for testing."""
+ # Generate a hex string between 16-64 characters (CSRF validation requirement)
+ return secrets.token_hex(16) # 32 character hex string
+
+
+@pytest.fixture
+def testUser() -> User:
+ """Create a test user for API testing."""
+ # Use getRootInterface for system operations like user creation
+ # The root interface automatically uses the root mandate
+ rootInterface = getRootInterface()
+ user = rootInterface.createUser(
+ username="testuser_options",
+ email="testuser_options@example.com",
+ password="testpass123",
+ roleLabels=["user"]
+ )
+ return user
+
+
+class TestOptionsAPI:
+ """Test Options API endpoints."""
+
+ def testGetOptionsUserRole(self, testClient, testUser, csrfToken):
+ """Test GET /api/options/user.role endpoint."""
+ # Get auth token (stored in cookie)
+ response = testClient.post(
+ "/api/local/login",
+ data={"username": testUser.username, "password": "testpass123"},
+ headers={"X-CSRF-Token": csrfToken}
+ )
+ assert response.status_code == 200
+
+ # Extract token from cookie for Bearer header
+ token = response.cookies.get("auth_token")
+ assert token is not None
+
+ # Get options
+ response = testClient.get(
+ "/api/options/user.role",
+ headers={"Authorization": f"Bearer {token}"}
+ )
+
+ assert response.status_code == 200
+ options = response.json()
+
+ assert isinstance(options, list)
+ assert len(options) >= 4 # At least sysadmin, admin, user, viewer
+
+ # Check structure
+ for option in options:
+ assert "value" in option
+ assert "label" in option
+ assert isinstance(option["label"], dict)
+
+ # Check specific values
+ values = [opt["value"] for opt in options]
+ assert "sysadmin" in values
+ assert "admin" in values
+ assert "user" in values
+ assert "viewer" in values
+
+ def testGetOptionsAuthAuthority(self, testClient, testUser, csrfToken):
+ """Test GET /api/options/auth.authority endpoint."""
+ # Get auth token (stored in cookie)
+ response = testClient.post(
+ "/api/local/login",
+ data={"username": testUser.username, "password": "testpass123"},
+ headers={"X-CSRF-Token": csrfToken}
+ )
+ assert response.status_code == 200
+
+ # Extract token from cookie for Bearer header
+ token = response.cookies.get("auth_token")
+ assert token is not None
+
+ # Get options
+ response = testClient.get(
+ "/api/options/auth.authority",
+ headers={"Authorization": f"Bearer {token}"}
+ )
+
+ assert response.status_code == 200
+ options = response.json()
+
+ assert isinstance(options, list)
+ assert len(options) == 3 # local, google, msft
+
+ # Check structure
+ for option in options:
+ assert "value" in option
+ assert "label" in option
+
+ # Check specific values
+ values = [opt["value"] for opt in options]
+ assert "local" in values
+ assert "google" in values
+ assert "msft" in values
+
+ def testGetOptionsConnectionStatus(self, testClient, testUser, csrfToken):
+ """Test GET /api/options/connection.status endpoint."""
+ # Get auth token (stored in cookie)
+ response = testClient.post(
+ "/api/local/login",
+ data={"username": testUser.username, "password": "testpass123"},
+ headers={"X-CSRF-Token": csrfToken}
+ )
+ assert response.status_code == 200
+
+ # Extract token from cookie for Bearer header
+ token = response.cookies.get("auth_token")
+ assert token is not None
+
+ # Get options
+ response = testClient.get(
+ "/api/options/connection.status",
+ headers={"Authorization": f"Bearer {token}"}
+ )
+
+ assert response.status_code == 200
+ options = response.json()
+
+ assert isinstance(options, list)
+ assert len(options) >= 4 # active, inactive, expired, pending, revoked, error
+
+ # Check structure
+ for option in options:
+ assert "value" in option
+ assert "label" in option
+
+ def testGetOptionsUserConnection(self, testClient, testUser, csrfToken):
+ """Test GET /api/options/user.connection endpoint (context-aware)."""
+ # Get auth token (stored in cookie)
+ response = testClient.post(
+ "/api/local/login",
+ data={"username": testUser.username, "password": "testpass123"},
+ headers={"X-CSRF-Token": csrfToken}
+ )
+ assert response.status_code == 200
+
+ # Extract token from cookie for Bearer header
+ token = response.cookies.get("auth_token")
+ assert token is not None
+
+ # Get options (should return empty list if no connections)
+ response = testClient.get(
+ "/api/options/user.connection",
+ headers={"Authorization": f"Bearer {token}"}
+ )
+
+ assert response.status_code == 200
+ options = response.json()
+
+ # Should return a list (may be empty)
+ assert isinstance(options, list)
+
+ def testGetOptionsList(self, testClient, testUser, csrfToken):
+ """Test GET /api/options/ endpoint (list all available options)."""
+ # Get auth token (stored in cookie)
+ response = testClient.post(
+ "/api/local/login",
+ data={"username": testUser.username, "password": "testpass123"},
+ headers={"X-CSRF-Token": csrfToken}
+ )
+ assert response.status_code == 200
+
+ # Extract token from cookie for Bearer header
+ token = response.cookies.get("auth_token")
+ assert token is not None
+
+ # Get available options names
+ response = testClient.get(
+ "/api/options/",
+ headers={"Authorization": f"Bearer {token}"}
+ )
+
+ assert response.status_code == 200
+ optionsNames = response.json()
+
+ assert isinstance(optionsNames, list)
+ assert "user.role" in optionsNames
+ assert "auth.authority" in optionsNames
+ assert "connection.status" in optionsNames
+ assert "user.connection" in optionsNames
+
+ def testGetOptionsUnknown(self, testClient, testUser, csrfToken):
+ """Test GET /api/options/unknown.options endpoint (should return 400)."""
+ # Get auth token (stored in cookie)
+ response = testClient.post(
+ "/api/local/login",
+ data={"username": testUser.username, "password": "testpass123"},
+ headers={"X-CSRF-Token": csrfToken}
+ )
+ assert response.status_code == 200
+
+ # Extract token from cookie for Bearer header
+ token = response.cookies.get("auth_token")
+ assert token is not None
+
+ # Get unknown options (should return error)
+ response = testClient.get(
+ "/api/options/unknown.options",
+ headers={"Authorization": f"Bearer {token}"}
+ )
+
+ assert response.status_code == 400
+
+ def testGetOptionsUnauthorized(self, testClient):
+ """Test GET /api/options/user.role without authentication."""
+ # Try to get options without auth token
+ response = testClient.get("/api/options/user.role")
+
+ # Should require authentication
+ assert response.status_code == 401
diff --git a/tests/unit/options/test_frontend_options_types.py b/tests/unit/options/test_frontend_options_types.py
new file mode 100644
index 00000000..544587f9
--- /dev/null
+++ b/tests/unit/options/test_frontend_options_types.py
@@ -0,0 +1,115 @@
+"""
+Unit tests for frontend_options type system and utilities.
+Tests type validation, format detection, and utility functions.
+"""
+
+import pytest
+from modules.shared.frontendOptionsTypes import (
+ FrontendOptions,
+ OptionItem,
+ isStringReference,
+ isStaticList,
+ validateFrontendOptions,
+ getOptionsName,
+ getStaticOptions
+)
+
+
+class TestFrontendOptionsTypes:
+ """Test frontend_options type system."""
+
+ def testIsStringReference(self):
+ """Test string reference detection."""
+ assert isStringReference("user.role") is True
+ assert isStringReference("auth.authority") is True
+ assert isStringReference("") is True # Empty string is still a string
+
+ assert isStringReference([]) is False
+ assert isStringReference([{"value": "a"}]) is False
+ assert isStringReference(None) is False
+
+ def testIsStaticList(self):
+ """Test static list detection."""
+ assert isStaticList([]) is True
+ assert isStaticList([{"value": "a", "label": {"en": "A"}}]) is True
+
+ assert isStaticList("user.role") is False
+ assert isStaticList(None) is False
+
+ def testValidateFrontendOptionsString(self):
+ """Test validation of string references."""
+ assert validateFrontendOptions("user.role") is True
+ assert validateFrontendOptions("auth.authority") is True
+ assert validateFrontendOptions("") is False # Empty string is invalid
+ assert validateFrontendOptions(" ") is False # Whitespace-only is invalid
+
+ def testValidateFrontendOptionsStaticList(self):
+ """Test validation of static lists."""
+ # Valid static list
+ validList = [
+ {"value": "a", "label": {"en": "All", "fr": "Tous"}},
+ {"value": "m", "label": {"en": "My", "fr": "Mes"}}
+ ]
+ assert validateFrontendOptions(validList) is True
+
+ # Empty list is valid
+ assert validateFrontendOptions([]) is True
+
+ # Missing value key
+ invalidList1 = [{"label": {"en": "Test"}}]
+ assert validateFrontendOptions(invalidList1) is False
+
+ # Missing label key
+ invalidList2 = [{"value": "a"}]
+ assert validateFrontendOptions(invalidList2) is False
+
+ # Label is not a dict
+ invalidList3 = [{"value": "a", "label": "not a dict"}]
+ assert validateFrontendOptions(invalidList3) is False
+
+ # Not a list or string
+ assert validateFrontendOptions(None) is False
+ assert validateFrontendOptions(123) is False
+ assert validateFrontendOptions({}) is False
+
+ def testGetOptionsName(self):
+ """Test getting options name from string reference."""
+ assert getOptionsName("user.role") == "user.role"
+ assert getOptionsName("auth.authority") == "auth.authority"
+
+ # Should raise ValueError for non-string
+ with pytest.raises(ValueError):
+ getOptionsName([])
+
+ with pytest.raises(ValueError):
+ getOptionsName(None)
+
+ def testGetStaticOptions(self):
+ """Test getting static options list."""
+ options = [
+ {"value": "a", "label": {"en": "All"}},
+ {"value": "m", "label": {"en": "My"}}
+ ]
+ assert getStaticOptions(options) == options
+
+ # Should raise ValueError for non-list
+ with pytest.raises(ValueError):
+ getStaticOptions("user.role")
+
+ with pytest.raises(ValueError):
+ getStaticOptions(None)
+
+ def testTypeAliases(self):
+ """Test that type aliases are properly defined."""
+ # FrontendOptions should accept both str and List[OptionItem]
+ stringRef: FrontendOptions = "user.role"
+ staticList: FrontendOptions = [{"value": "a", "label": {"en": "A"}}]
+
+ assert isinstance(stringRef, str)
+ assert isinstance(staticList, list)
+
+ # OptionItem should be Dict[str, Any]
+ optionItem: OptionItem = {"value": "test", "label": {"en": "Test"}}
+ assert isinstance(optionItem, dict)
+ assert "value" in optionItem
+ assert "label" in optionItem
diff --git a/tests/unit/options/test_main_options.py b/tests/unit/options/test_main_options.py
new file mode 100644
index 00000000..172e64e5
--- /dev/null
+++ b/tests/unit/options/test_main_options.py
@@ -0,0 +1,181 @@
+"""
+Unit tests for Options API (mainOptions.py).
+Tests option retrieval, validation, and context-aware options.
+"""
+
+import pytest
+from unittest.mock import Mock, patch
+from modules.features.options.mainOptions import (
+ getOptions,
+ getAvailableOptionsNames,
+ STANDARD_ROLES,
+ AUTH_AUTHORITY_OPTIONS,
+ CONNECTION_STATUS_OPTIONS
+)
+from modules.datamodels.datamodelUam import User, UserConnection, AuthAuthority
+
+
+class TestMainOptions:
+ """Test Options API functionality."""
+
+ def testGetOptionsUserRole(self):
+ """Test getting user role options."""
+ options = getOptions("user.role")
+
+ assert isinstance(options, list)
+ assert len(options) == 4 # sysadmin, admin, user, viewer
+
+ # Check structure
+ for option in options:
+ assert "value" in option
+ assert "label" in option
+ assert isinstance(option["label"], dict)
+ assert "en" in option["label"]
+ assert "fr" in option["label"]
+
+ # Check specific values
+ values = [opt["value"] for opt in options]
+ assert "sysadmin" in values
+ assert "admin" in values
+ assert "user" in values
+ assert "viewer" in values
+
+ def testGetOptionsAuthAuthority(self):
+ """Test getting auth authority options."""
+ options = getOptions("auth.authority")
+
+ assert isinstance(options, list)
+ assert len(options) == 3 # local, google, msft
+
+ # Check structure
+ for option in options:
+ assert "value" in option
+ assert "label" in option
+
+ # Check specific values
+ values = [opt["value"] for opt in options]
+ assert "local" in values
+ assert "google" in values
+ assert "msft" in values
+
+ def testGetOptionsConnectionStatus(self):
+ """Test getting connection status options."""
+ options = getOptions("connection.status")
+
+ assert isinstance(options, list)
+ assert len(options) == 5 # active, expired, revoked, pending, error
+
+ # Check structure
+ for option in options:
+ assert "value" in option
+ assert "label" in option
+
+ # Check specific values
+ values = [opt["value"] for opt in options]
+ assert "active" in values
+ assert "expired" in values
+ assert "revoked" in values
+ assert "pending" in values
+ assert "error" in values
+
+ def testGetOptionsUserConnection(self):
+ """Test getting user connection options (context-aware)."""
+ # Without currentUser, should return empty list
+ options = getOptions("user.connection")
+ assert options == []
+
+ # With currentUser but no connections
+ user = User(
+ id="user1",
+ username="testuser",
+ roleLabels=["user"],
+ mandateId="mandate1"
+ )
+
+ with patch('modules.features.options.mainOptions.getInterface') as mockGetInterface:
+ mockInterface = Mock()
+ mockInterface.getUserConnections.return_value = []
+ mockGetInterface.return_value = mockInterface
+
+ options = getOptions("user.connection", currentUser=user)
+ assert options == []
+
+ def testGetOptionsUserConnectionWithData(self):
+ """Test getting user connection options with actual connections."""
+ user = User(
+ id="user1",
+ username="testuser",
+ roleLabels=["user"],
+ mandateId="mandate1"
+ )
+
+ # Mock connections
+ mockConn1 = Mock(spec=UserConnection)
+ mockConn1.id = "conn1"
+ mockConn1.authority = AuthAuthority.GOOGLE
+ mockConn1.externalUsername = "user@example.com"
+ mockConn1.externalId = None
+
+ mockConn2 = Mock(spec=UserConnection)
+ mockConn2.id = "conn2"
+ mockConn2.authority = AuthAuthority.MSFT
+ mockConn2.externalUsername = None
+ mockConn2.externalId = "external-id-123"
+
+ with patch('modules.features.options.mainOptions.getInterface') as mockGetInterface:
+ mockInterface = Mock()
+ mockInterface.getUserConnections.return_value = [mockConn1, mockConn2]
+ mockGetInterface.return_value = mockInterface
+
+ options = getOptions("user.connection", currentUser=user)
+
+ assert len(options) == 2
+ assert options[0]["value"] == "conn1"
+ assert options[1]["value"] == "conn2"
+
+ # Check labels contain authority and username/id
+ assert "google" in options[0]["label"]["en"].lower()
+ assert "user@example.com" in options[0]["label"]["en"]
+
+ def testGetOptionsCaseInsensitive(self):
+ """Test that options name matching is case-insensitive."""
+ options1 = getOptions("user.role")
+ options2 = getOptions("USER.ROLE")
+ options3 = getOptions("User.Role")
+
+ assert options1 == options2 == options3
+
+ def testGetOptionsUnknown(self):
+ """Test that unknown options name raises ValueError."""
+ with pytest.raises(ValueError, match="Unknown options name"):
+ getOptions("unknown.options")
+
+ def testGetAvailableOptionsNames(self):
+ """Test getting list of available options names."""
+ names = getAvailableOptionsNames()
+
+ assert isinstance(names, list)
+ assert "user.role" in names
+ assert "auth.authority" in names
+ assert "connection.status" in names
+ assert "user.connection" in names
+ assert len(names) == 4
+
+ def testStandardRolesConstant(self):
+ """Test that STANDARD_ROLES constant is properly defined."""
+ assert isinstance(STANDARD_ROLES, list)
+ assert len(STANDARD_ROLES) == 4
+
+ for role in STANDARD_ROLES:
+ assert "value" in role
+ assert "label" in role
+
+ def testAuthAuthorityOptionsConstant(self):
+ """Test that AUTH_AUTHORITY_OPTIONS constant is properly defined."""
+ assert isinstance(AUTH_AUTHORITY_OPTIONS, list)
+ assert len(AUTH_AUTHORITY_OPTIONS) == 3
+
+ def testConnectionStatusOptionsConstant(self):
+ """Test that CONNECTION_STATUS_OPTIONS constant is properly defined."""
+ assert isinstance(CONNECTION_STATUS_OPTIONS, list)
+ assert len(CONNECTION_STATUS_OPTIONS) == 5 # active, expired, revoked, pending, error
diff --git a/tests/unit/rbac/test_rbac_bootstrap.py b/tests/unit/rbac/test_rbac_bootstrap.py
index e12592a1..37be1185 100644
--- a/tests/unit/rbac/test_rbac_bootstrap.py
+++ b/tests/unit/rbac/test_rbac_bootstrap.py
@@ -137,13 +137,25 @@ class TestRbacBootstrap:
assert rule.view == False
def testInitRbacRulesSkipsIfExists(self):
- """Test that initRbacRules skips creation if rules already exist."""
+ """Test that initRbacRules skips default rule creation if rules already exist, but adds missing table-specific rules."""
db = Mock()
- db.getRecordset = Mock(return_value=[{"id": "rule1"}]) # Rules exist
+ # Mock existing rules - include rules for ChatWorkflow and Prompt to prevent adding missing rules
+ # Need rules for all required roles to fully prevent creation
+ existingRules = []
+ for table in ["ChatWorkflow", "Prompt"]:
+ for role in ["sysadmin", "admin", "user", "viewer"]:
+ existingRules.append({
+ "id": f"rule_{table}_{role}",
+ "item": table,
+ "context": AccessRuleContext.DATA.value,
+ "roleLabel": role
+ })
+ db.getRecordset = Mock(return_value=existingRules)
+ db.recordCreate = Mock()
initRbacRules(db)
- # Should not create new rules
+ # Should not create new rules since all required tables already have rules for all roles
db.recordCreate.assert_not_called()
def testInitRbacRulesCreatesIfNotExists(self):
diff --git a/tests/unit/rbac/test_rbac_permissions.py b/tests/unit/rbac/test_rbac_permissions.py
index d180f5b8..1b814137 100644
--- a/tests/unit/rbac/test_rbac_permissions.py
+++ b/tests/unit/rbac/test_rbac_permissions.py
@@ -18,9 +18,10 @@ class TestRbacPermissionResolution:
"""Test permission resolution with a single role and generic rule."""
# Mock database connector
db = Mock(spec=DatabaseConnector)
+ dbApp = Mock(spec=DatabaseConnector)
# Create RBAC interface
- rbac = RbacClass(db)
+ rbac = RbacClass(db, dbApp=dbApp)
# Create user with single role
user = User(
@@ -65,7 +66,8 @@ class TestRbacPermissionResolution:
def testRuleSpecificityMostSpecificWins(self):
"""Test that most specific rule wins within a single role."""
db = Mock(spec=DatabaseConnector)
- rbac = RbacClass(db)
+ dbApp = Mock(spec=DatabaseConnector)
+ rbac = RbacClass(db, dbApp=dbApp)
user = User(
id="user1",
@@ -118,7 +120,8 @@ class TestRbacPermissionResolution:
def testMultipleRolesUnionLogic(self):
"""Test that multiple roles use union (opening) logic."""
db = Mock(spec=DatabaseConnector)
- rbac = RbacClass(db)
+ dbApp = Mock(spec=DatabaseConnector)
+ rbac = RbacClass(db, dbApp=dbApp)
# User with multiple roles
user = User(
@@ -165,7 +168,8 @@ class TestRbacPermissionResolution:
def testViewFalseOverridesGeneric(self):
"""Test that specific view=false overrides generic view=true."""
db = Mock(spec=DatabaseConnector)
- rbac = RbacClass(db)
+ dbApp = Mock(spec=DatabaseConnector)
+ rbac = RbacClass(db, dbApp=dbApp)
user = User(
id="user1",
@@ -207,7 +211,8 @@ class TestRbacPermissionResolution:
def testNoRolesReturnsNoAccess(self):
"""Test that user with no roles gets no access."""
db = Mock(spec=DatabaseConnector)
- rbac = RbacClass(db)
+ dbApp = Mock(spec=DatabaseConnector)
+ rbac = RbacClass(db, dbApp=dbApp)
user = User(
id="user1",
@@ -231,7 +236,8 @@ class TestRbacPermissionResolution:
def testFindMostSpecificRule(self):
"""Test findMostSpecificRule method."""
db = Mock(spec=DatabaseConnector)
- rbac = RbacClass(db)
+ dbApp = Mock(spec=DatabaseConnector)
+ rbac = RbacClass(db, dbApp=dbApp)
rules = [
AccessRule(
@@ -278,7 +284,8 @@ class TestRbacPermissionResolution:
def testValidateAccessRuleOpeningRights(self):
"""Test that CUD permissions respect read permission level."""
db = Mock(spec=DatabaseConnector)
- rbac = RbacClass(db)
+ dbApp = Mock(spec=DatabaseConnector)
+ rbac = RbacClass(db, dbApp=dbApp)
# Valid: Read=MY, Create=MY (allowed)
rule1 = AccessRule(
@@ -335,7 +342,8 @@ class TestRbacPermissionResolution:
def testUiContextOnlyViewMatters(self):
"""Test that UI context only checks view permission."""
db = Mock(spec=DatabaseConnector)
- rbac = RbacClass(db)
+ dbApp = Mock(spec=DatabaseConnector)
+ rbac = RbacClass(db, dbApp=dbApp)
user = User(
id="user1",
@@ -371,7 +379,8 @@ class TestRbacPermissionResolution:
def testResourceContextOnlyViewMatters(self):
"""Test that RESOURCE context only checks view permission."""
db = Mock(spec=DatabaseConnector)
- rbac = RbacClass(db)
+ dbApp = Mock(spec=DatabaseConnector)
+ rbac = RbacClass(db, dbApp=dbApp)
user = User(
id="user1",
diff --git a/tests/unit/services/test_ai_service.py b/tests/unit/services/test_ai_service.py
deleted file mode 100644
index e665fef7..00000000
--- a/tests/unit/services/test_ai_service.py
+++ /dev/null
@@ -1,146 +0,0 @@
-#!/usr/bin/env python3
-"""
-Unit tests for AI service (mainServiceAi.py)
-Tests callAiContent, callAiPlanning, and related functionality.
-"""
-
-import pytest
-from unittest.mock import Mock, AsyncMock, patch
-
-from modules.datamodels.datamodelAi import AiCallOptions, OperationTypeEnum, PriorityEnum, ProcessingModeEnum
-from modules.datamodels.datamodelExtraction import ContentPart
-from modules.datamodels.datamodelWorkflow import AiResponse
-
-
-class TestAiServiceCallAiContent:
- """Test callAiContent method (mocked)"""
-
- @pytest.mark.asyncio
- async def test_callAiContent_requires_operationType(self):
- """Test that callAiContent requires operationType to be set"""
- from modules.services.serviceAi.mainServiceAi import AiService
-
- # Create mock services
- mockServices = Mock()
- mockServices.workflow = None
- mockServices.chat = Mock()
- mockServices.chat.progressLogStart = Mock()
- mockServices.chat.progressLogUpdate = Mock()
- mockServices.chat.progressLogFinish = Mock()
- mockServices.chat.storeWorkflowStat = Mock()
-
- aiService = AiService(mockServices)
-
- # Mock aiObjects initialization
- aiService.aiObjects = Mock()
- aiService._ensureAiObjectsInitialized = AsyncMock()
-
- # Test with missing operationType - should analyze prompt
- options = AiCallOptions() # operationType not set
- options.operationType = None
-
- # Mock _analyzePromptAndCreateOptions
- analyzedOptions = AiCallOptions()
- analyzedOptions.operationType = OperationTypeEnum.DATA_ANALYSE
- aiService._analyzePromptAndCreateOptions = AsyncMock(return_value=analyzedOptions)
-
- # Mock _callAiWithLooping
- aiService._callAiWithLooping = AsyncMock(return_value="Test response")
-
- # Mock aiObjects.call
- mockResponse = Mock()
- mockResponse.content = "Test response"
- aiService.aiObjects.call = AsyncMock(return_value=mockResponse)
-
- # Call should work (will analyze prompt if operationType not set)
- result = await aiService.callAiContent(
- prompt="Test prompt",
- options=options
- )
-
- # Should have analyzed prompt and set operationType
- assert result is not None
- assert isinstance(result, AiResponse)
-
-
-class TestAiServiceCallAiPlanning:
- """Test callAiPlanning method (mocked)"""
-
- @pytest.mark.asyncio
- async def test_callAiPlanning_basic(self):
- """Test basic callAiPlanning call"""
- from modules.services.serviceAi.mainServiceAi import AiService
-
- # Create mock services
- mockServices = Mock()
- mockServices.workflow = None
- mockServices.utils = Mock()
- mockServices.utils.writeDebugFile = Mock()
-
- aiService = AiService(mockServices)
-
- # Mock aiObjects
- aiService.aiObjects = Mock()
- mockResponse = Mock()
- mockResponse.content = '{"result": "plan"}'
- aiService.aiObjects.call = AsyncMock(return_value=mockResponse)
- aiService._ensureAiObjectsInitialized = AsyncMock()
-
- # Call planning
- result = await aiService.callAiPlanning(
- prompt="Test planning prompt"
- )
-
- assert result == '{"result": "plan"}'
-
-
-class TestAiServiceOperationTypeHandling:
- """Test operationType handling in callAiContent"""
-
- @pytest.mark.asyncio
- async def test_callAiContent_with_outputFormat_sets_documentGenerate(self):
- """Test that outputFormat sets operationType to DOCUMENT_GENERATE"""
- from modules.services.serviceAi.mainServiceAi import AiService
-
- mockServices = Mock()
- mockServices.workflow = None
- mockServices.chat = Mock()
- mockServices.chat.progressLogStart = Mock()
- mockServices.chat.progressLogUpdate = Mock()
- mockServices.chat.progressLogFinish = Mock()
- mockServices.utils = Mock()
- mockServices.utils.jsonExtractString = Mock(return_value='{"documents": []}')
-
- aiService = AiService(mockServices)
- aiService.aiObjects = Mock()
- aiService._ensureAiObjectsInitialized = AsyncMock()
-
- # Mock _callAiWithLooping
- aiService._callAiWithLooping = AsyncMock(return_value='{"documents": []}')
-
- # Mock generation service
- with patch('modules.services.serviceGeneration.mainServiceGeneration.GenerationService') as mockGenService:
- mockGenInstance = Mock()
- mockGenInstance.renderReport = AsyncMock(return_value=(b"content", "application/pdf"))
- mockGenService.return_value = mockGenInstance
-
- options = AiCallOptions() # operationType not set
- options.operationType = None
-
- # Should set operationType to DOCUMENT_GENERATE when outputFormat is provided
- try:
- result = await aiService.callAiContent(
- prompt="Generate document",
- options=options,
- outputFormat="pdf"
- )
- # If it gets here, operationType was set correctly
- assert options.operationType == OperationTypeEnum.DOCUMENT_GENERATE
- except Exception:
- # If it fails, that's okay for unit test - we're testing the logic
- pass
-
-
-if __name__ == "__main__":
- pytest.main([__file__, "-v"])
-
From 72f5fbde4690141458fc7f55fd8fc7b8b23079dc Mon Sep 17 00:00:00 2001
From: ValueOn AG
Date: Mon, 8 Dec 2025 00:13:26 +0100
Subject: [PATCH 5/6] added attribute types: TextMultilingual, multiselect
---
modules/datamodels/datamodelRbac.py | 5 ++-
modules/datamodels/datamodelUtils.py | 51 ++++++++++++++++++++++++-
modules/features/options/mainOptions.py | 12 +++++-
modules/shared/attributeUtils.py | 31 ++++++++++-----
4 files changed, 84 insertions(+), 15 deletions(-)
diff --git a/modules/datamodels/datamodelRbac.py b/modules/datamodels/datamodelRbac.py
index 7fcfb6c4..96f7ef55 100644
--- a/modules/datamodels/datamodelRbac.py
+++ b/modules/datamodels/datamodelRbac.py
@@ -5,6 +5,7 @@ from typing import Optional, Dict
from enum import Enum
from pydantic import BaseModel, Field
from modules.shared.attributeUtils import registerModelLabels
+from modules.datamodels.datamodelUtils import TextMultilingual
from modules.datamodels.datamodelUam import AccessLevel
@@ -26,9 +27,9 @@ class Role(BaseModel):
description="Unique role label identifier (e.g., 'admin', 'user', 'viewer')",
json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": True}
)
- description: Dict[str, str] = Field(
+ description: TextMultilingual = Field(
description="Role description in multiple languages",
- json_schema_extra={"frontend_type": "object", "frontend_readonly": False, "frontend_required": True}
+ json_schema_extra={"frontend_type": "multilingual", "frontend_readonly": False, "frontend_required": True}
)
isSystemRole: bool = Field(
False,
diff --git a/modules/datamodels/datamodelUtils.py b/modules/datamodels/datamodelUtils.py
index 4f1c69c2..3ff5d3fa 100644
--- a/modules/datamodels/datamodelUtils.py
+++ b/modules/datamodels/datamodelUtils.py
@@ -1,6 +1,7 @@
-"""Utility datamodels: Prompt."""
+"""Utility datamodels: Prompt, TextMultilingual."""
-from pydantic import BaseModel, Field
+from typing import Dict, Optional
+from pydantic import BaseModel, Field, field_validator
from modules.shared.attributeUtils import registerModelLabels
import uuid
@@ -22,3 +23,49 @@ registerModelLabels(
)
+class TextMultilingual(BaseModel):
+ """
+ Multilingual text field supporting multiple languages.
+ Default languages: en (English), ge (German), fr (French), it (Italian)
+ English (en) is the default/required language.
+ """
+ en: str = Field(description="English text (default language, required)")
+ ge: Optional[str] = Field(None, description="German text")
+ fr: Optional[str] = Field(None, description="French text")
+ it: Optional[str] = Field(None, description="Italian text")
+
+ @field_validator('en')
+ @classmethod
+ def validate_en_required(cls, v):
+ """Ensure English text is not empty"""
+ if not v or not v.strip():
+ raise ValueError("English text (en) is required and cannot be empty")
+ return v
+
+ def model_dump(self, **kwargs) -> Dict[str, str]:
+ """Return as dictionary, filtering out None values"""
+ result = {}
+ for lang in ['en', 'ge', 'fr', 'it']:
+ value = getattr(self, lang, None)
+ if value is not None:
+ result[lang] = value
+ return result
+
+ @classmethod
+ def from_dict(cls, data: Dict[str, str]) -> 'TextMultilingual':
+ """Create TextMultilingual from dictionary"""
+ return cls(
+ en=data.get('en', ''),
+ ge=data.get('ge'),
+ fr=data.get('fr'),
+ it=data.get('it')
+ )
+
+ def get_text(self, lang: str = 'en') -> str:
+ """Get text for a specific language, fallback to English if not available"""
+ value = getattr(self, lang, None)
+ if value:
+ return value
+ return self.en # Fallback to English
+
+
diff --git a/modules/features/options/mainOptions.py b/modules/features/options/mainOptions.py
index 41ef5db2..d05b3bc1 100644
--- a/modules/features/options/mainOptions.py
+++ b/modules/features/options/mainOptions.py
@@ -64,7 +64,17 @@ def getOptions(optionsName: str, currentUser: Optional[User] = None) -> List[Dic
options = []
for role in roles:
# Use English description as label, fallback to roleLabel
- label = role.description.get("en", role.roleLabel) if isinstance(role.description, dict) else role.roleLabel
+ # Handle TextMultilingual object
+ if hasattr(role.description, 'get_text'):
+ # TextMultilingual object
+ label = role.description.get_text('en')
+ elif isinstance(role.description, dict):
+ # Dict format (backward compatibility)
+ label = role.description.get("en", role.roleLabel)
+ else:
+ # Fallback to roleLabel
+ label = role.roleLabel
+
options.append({
"value": role.roleLabel,
"label": label
diff --git a/modules/shared/attributeUtils.py b/modules/shared/attributeUtils.py
index 9116d330..74aeee10 100644
--- a/modules/shared/attributeUtils.py
+++ b/modules/shared/attributeUtils.py
@@ -166,16 +166,27 @@ def getModelAttributeDefinitions(modelClass: Type[BaseModel] = None, userLanguag
if frontend_options is None and "frontend_options" in json_extra:
frontend_options = json_extra.get("frontend_options")
- # Use frontend type if available, otherwise fall back to Python type
- field_type = (
- frontend_type
- if frontend_type
- else (
- field.annotation.__name__
- if hasattr(field.annotation, "__name__")
- else str(field.annotation)
- )
- )
+ # Use frontend type if available, otherwise detect from Python type
+ if frontend_type:
+ field_type = frontend_type
+ else:
+ # Check if it's TextMultilingual type
+ annotation_str = str(field.annotation)
+ # Check both the module path and class name for TextMultilingual
+ if ('TextMultilingual' in annotation_str or
+ (hasattr(field.annotation, '__name__') and field.annotation.__name__ == 'TextMultilingual') or
+ 'datamodelUtils.TextMultilingual' in annotation_str or
+ 'datamodels.datamodelUtils.TextMultilingual' in annotation_str):
+ field_type = 'multilingual'
+ elif hasattr(field.annotation, "__name__"):
+ annotation_name = field.annotation.__name__
+ # Check if it's a Dict type (for JSON/object fields)
+ if annotation_name == 'Dict' or annotation_str.startswith('typing.Dict') or annotation_str.startswith('Dict['):
+ field_type = 'object' # Will be rendered as textarea for JSON editing
+ else:
+ field_type = annotation_name
+ else:
+ field_type = str(field.annotation)
# Extract default value from field
# In Pydantic v2, FieldInfo has a 'default' attribute
From 1d9a1d7613cb6df5652fa9db4cabf059c4aeb10d Mon Sep 17 00:00:00 2001
From: ValueOn AG
Date: Mon, 8 Dec 2025 07:38:39 +0100
Subject: [PATCH 6/6] fix
---
modules/features/options/mainOptions.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/modules/features/options/mainOptions.py b/modules/features/options/mainOptions.py
index d05b3bc1..75f1d6f2 100644
--- a/modules/features/options/mainOptions.py
+++ b/modules/features/options/mainOptions.py
@@ -5,7 +5,7 @@ Provides dynamic options for frontend select/multiselect fields.
import logging
from typing import List, Dict, Any, Optional
-from modules.datamodels.datamodelUam import User, AuthAuthority, ConnectionStatus
+from modules.datamodels.datamodelUam import User
from modules.interfaces.interfaceDbAppObjects import getInterface
logger = logging.getLogger(__name__)