# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
RBAC helper functions for interfaces.

Provides RBAC filtering for database queries, so that database connectors
do not have to import the security module themselves.
"""
|
|
|
|
import logging
|
|
import json
|
|
from typing import List, Dict, Any, Optional, Type
|
|
from pydantic import BaseModel
|
|
from modules.datamodels.datamodelRbac import AccessRuleContext
|
|
from modules.datamodels.datamodelUam import User, UserPermissions, AccessLevel
|
|
from modules.security.rbac import RbacClass
|
|
from modules.security.rootAccess import getRootDbAppConnector
|
|
|
|
# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def getRecordsetWithRBAC(
    connector,  # DatabaseConnector instance
    modelClass: Type[BaseModel],
    currentUser: User,
    recordFilter: Optional[Dict[str, Any]] = None,
    orderBy: Optional[str] = None,
    limit: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """
    Get records with RBAC filtering applied at database level.

    Runs a ``SELECT`` on the model's table wrapped in RBAC logic:

    - the user's DATA permissions for the table are loaded (the AccessRule
      table always lives in the DbApp database);
    - a user without ``view`` permission gets no rows;
    - the ``read`` access level is turned into a WHERE condition by
      ``buildRbacWhereClause``.

    Args:
        connector: DatabaseConnector instance; its ``connection`` runs the query.
        modelClass: Pydantic model class; its class name is the table name.
        currentUser: User object the permissions are evaluated for.
        recordFilter: Additional equality filters (column -> value), AND-ed
            with the RBAC condition. A ``None`` value matches SQL NULL.
        orderBy: Column to order by (defaults to "id").
        limit: Maximum number of records to return; ``None`` or ``0`` means
            no limit.

    Returns:
        List of filtered records as dicts. Numeric columns are coerced and
        JSONB columns are decoded. An empty list is returned in these cases:
        the table cannot be ensured, the user lacks view permission, or any
        error occurs. Errors are logged with their traceback.
    """
    table = modelClass.__name__

    try:
        if not connector._ensureTableExists(modelClass):
            return []

        # AccessRule table is always in DbApp database
        dbApp = getRootDbAppConnector()
        rbacInstance = RbacClass(connector, dbApp=dbApp)
        permissions = rbacInstance.getUserPermissions(
            currentUser, AccessRuleContext.DATA, table
        )

        # Check view permission first
        if not permissions.view:
            logger.debug(f"User {currentUser.id} has no view permission for table {table}")
            return []

        whereConditions: List[str] = []
        whereValues: List[Any] = []

        # RBAC condition derived from the read access level
        rbacWhereClause = buildRbacWhereClause(permissions, currentUser, table, connector)
        if rbacWhereClause:
            whereConditions.append(rbacWhereClause["condition"])
            whereValues.extend(rbacWhereClause["values"])

        # Caller-supplied filters: column names are escaped as identifiers,
        # values are passed as bound parameters.
        for field, value in (recordFilter or {}).items():
            column = _quoteIdentifier(field)
            if value is None:
                # "col = NULL" is never true in SQL; NULL needs IS NULL.
                whereConditions.append(f"{column} IS NULL")
            else:
                whereConditions.append(f"{column} = %s")
                whereValues.append(value)

        whereClause = " WHERE " + " AND ".join(whereConditions) if whereConditions else ""
        orderByClause = f" ORDER BY {_quoteIdentifier(orderBy or 'id')}"
        # int() guarantees only a number can reach the LIMIT clause.
        limitClause = f" LIMIT {int(limit)}" if limit else ""

        query = f"SELECT * FROM {_quoteIdentifier(table)}{whereClause}{orderByClause}{limitClause}"

        with connector.connection.cursor() as cursor:
            cursor.execute(query, whereValues)
            records = [dict(row) for row in cursor.fetchall()]

        _coerceRecordTypes(records, modelClass)
        return records

    except Exception as e:
        # Best-effort read: callers get [] on failure, but keep the traceback.
        logger.exception(f"Error loading records with RBAC from table {table}: {e}")
        return []


def _quoteIdentifier(name: Any) -> str:
    """Return *name* quoted as a PostgreSQL identifier.

    Embedded double quotes are doubled, which is the standard SQL escape.
    This stops a caller-supplied column name from closing the identifier
    and injecting SQL. Names without double quotes come out exactly as
    ``"name"``.
    """
    return '"' + str(name).replace('"', '""') + '"'


def _jsonbNullFactory(modelClass: Any, fieldName: str) -> Optional[type]:
    """Return the constructor for a NULL JSONB field's generic default.

    Returns ``list`` for list-annotated fields and ``dict`` for dict-annotated
    fields. Anything else, including fields missing from ``model_fields``,
    returns ``None``, meaning the value stays ``None``. A constructor is
    returned rather than a value, so each record gets its own fresh container.
    Interfaces handle domain-specific defaults.
    """
    fieldInfo = modelClass.model_fields.get(fieldName)
    if not fieldInfo:
        return None
    annotation = fieldInfo.annotation
    origin = getattr(annotation, "__origin__", None)
    if annotation == list or origin is list:
        return list
    if annotation == dict or origin is dict:
        return dict
    return None


def _coerceRecordTypes(records: List[Dict[str, Any]], modelClass: Any) -> None:
    """Normalize column values of *records* in place.

    - "DOUBLE PRECISION" / "INTEGER" columns: non-None values are converted
      with ``float`` / ``int``. On failure a warning is logged and the value
      is kept.
    - "JSONB" columns: ``None`` becomes the generic default from
      ``_jsonbNullFactory``. Dict/list values are kept. Any other value is
      decoded with ``json.loads`` (non-strings go through ``str`` first). On
      failure a warning is logged and the value is kept.
    """
    # Column name -> SQL type string, as provided by the Postgres connector.
    from modules.connectors.connectorDbPostgre import _get_model_fields

    fields = _get_model_fields(modelClass)
    numericConverters = {"DOUBLE PRECISION": float, "INTEGER": int}
    # Annotation inspection is per field, not per record, so do it once.
    jsonbFactories = {
        fieldName: _jsonbNullFactory(modelClass, fieldName)
        for fieldName, fieldType in fields.items()
        if fieldType == "JSONB"
    }

    for record in records:
        for fieldName, fieldType in fields.items():
            if fieldName not in record:
                continue
            value = record[fieldName]

            converter = numericConverters.get(fieldType)
            if converter is not None:
                if value is None:
                    continue
                try:
                    record[fieldName] = converter(value)
                except (ValueError, TypeError):
                    logger.warning(
                        f"Could not convert {fieldName} to {fieldType} for record {record.get('id', 'unknown')}: {value}"
                    )
            elif fieldType == "JSONB":
                if value is None:
                    factory = jsonbFactories[fieldName]
                    record[fieldName] = factory() if factory is not None else None
                elif not isinstance(value, (dict, list)):
                    try:
                        record[fieldName] = json.loads(value if isinstance(value, str) else str(value))
                    except (json.JSONDecodeError, TypeError, ValueError):
                        logger.warning(
                            f"Could not parse JSONB field {fieldName}, keeping as string: {value}"
                        )
|
|
|
|
|
|
def buildRbacWhereClause(
    permissions: UserPermissions,
    currentUser: User,
    table: str,
    connector,  # DatabaseConnector instance (not needed any more; kept for callers)
) -> Optional[Dict[str, Any]]:
    """
    Build an RBAC WHERE condition from the user's read access level.

    Moved from connector to interfaces. Mapping:

    - ``AccessLevel.ALL``: no filtering; returns ``None``.
    - ``AccessLevel.MY``: rows owned by the user, matched on ``id`` for
      ``UserInDB``, ``userId`` for ``UserConnection``, and ``_createdBy``
      for every other table.
    - ``AccessLevel.GROUP``: rows in the user's mandate, matched on
      ``mandateId``. For ``UserConnection``, it matches connections whose
      ``userId`` is a ``UserInDB`` id in that mandate. This is resolved in
      SQL by a subquery on the same connection as the main query.
    - Deny everything (``1 = 0``) for any of these:
      ``AccessLevel.NONE``; missing permissions or a missing ``read``
      level; an unrecognized level; GROUP access for a user without a
      ``mandateId``. Denying on unexpected input keeps the filter fail-closed.

    Args:
        permissions: UserPermissions object; its ``read`` level is used.
        currentUser: User object; ``id`` and ``mandateId`` are read.
        table: Table name.
        connector: DatabaseConnector instance. No longer used, because the
            UserConnection lookup is now a subquery. Kept so existing
            callers keep working.

    Returns:
        Dictionary with "condition" (SQL fragment with %s placeholders) and
        "values" (its parameters) keys, or None if no filtering is needed.
    """
    readLevel = getattr(permissions, "read", None) if permissions else None

    # All records - no filtering needed
    if readLevel == AccessLevel.ALL:
        return None

    # My records - filter by the table's owner column
    if readLevel == AccessLevel.MY:
        ownerField = {"UserInDB": "id", "UserConnection": "userId"}.get(table, "_createdBy")
        return {"condition": f'"{ownerField}" = %s', "values": [currentUser.id]}

    # Group records - filter by mandateId
    if readLevel == AccessLevel.GROUP:
        if not currentUser.mandateId:
            logger.warning(f"User {currentUser.id} has no mandateId for GROUP access")
            return {"condition": "1 = 0", "values": []}

        if table == "UserConnection":
            # UserConnection is scoped through the mandate of the owning user.
            condition = '"userId" IN (SELECT "id" FROM "UserInDB" WHERE "mandateId" = %s)'
        else:
            condition = '"mandateId" = %s'
        return {"condition": condition, "values": [currentUser.mandateId]}

    # NONE, or anything unexpected: match nothing (fail closed).
    if readLevel != AccessLevel.NONE:
        logger.warning(
            f"Unrecognized read access level {readLevel!r} for table {table}; denying access"
        )
    return {"condition": "1 = 0", "values": []}
|
|
|