# Copyright (c) 2025 Patrick Motsch
# All rights reserved.

import contextvars
import re
import psycopg2
import psycopg2.extras
import logging
from typing import List, Dict, Any, Optional, Union, get_origin, get_args, Type
import uuid
from pydantic import BaseModel, Field
import threading

from modules.shared.timeUtils import getUtcTimestamp
from modules.shared.configuration import APP_CONFIG
from modules.datamodels.datamodelBase import PowerOnModel
from modules.datamodels.datamodelUam import User, AccessLevel, UserPermissions
from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext

logger = logging.getLogger(__name__)

# No mapping needed - table name = Pydantic model name exactly


class SystemTable(PowerOnModel):
    """Data model for system table entries"""

    table_name: str = Field(
        description="Name of the table",
        json_schema_extra={
            "frontend_type": "text",
            "frontend_readonly": True,
            "frontend_required": True,
        },
    )
    initial_id: Optional[str] = Field(
        default=None,
        description="Initial ID for the table",
        json_schema_extra={
            "frontend_type": "text",
            "frontend_readonly": True,
            "frontend_required": False,
        },
    )


def _isVectorType(sqlType: str) -> bool:
    """Check if a SQL type string represents a pgvector column."""
    return sqlType.upper().startswith("VECTOR")


def _isJsonbType(fieldType) -> bool:
    """Check if a type should be stored as JSONB in PostgreSQL."""
    # Direct dict or list
    if fieldType == dict or fieldType == list:
        return True
    # Generic List[X] or Dict[X, Y]
    origin = get_origin(fieldType)
    if origin in (dict, list):
        return True
    # Direct Pydantic BaseModel subclass
    if isinstance(fieldType, type) and issubclass(fieldType, BaseModel):
        return True
    # Optional[X] - check the inner type
    if origin is Union:
        args = get_args(fieldType)
        for arg in args:
            if arg is type(None):
                continue
            # Recursively check the inner type
            if _isJsonbType(arg):
                return True
    return False


def _get_model_fields(model_class) -> Dict[str, str]:
    """Get all fields from Pydantic model and map to SQL types.

    Supports explicit db_type override via json_schema_extra={"db_type": "vector(1536)"}.
    This enables pgvector columns without special-casing field names.
    """
    model_fields = model_class.model_fields
    fields = {}
    for field_name, field_info in model_fields.items():
        field_type = field_info.annotation
        # Explicit db_type override (e.g. vector columns)
        extra = field_info.json_schema_extra
        if extra and isinstance(extra, dict) and "db_type" in extra:
            fields[field_name] = extra["db_type"]
            continue

        # Unwrap Optional[X] -> X (handles both typing.Union and types.UnionType)
        origin = get_origin(field_type)
        if origin is Union:
            args = [a for a in get_args(field_type) if a is not type(None)]
            if len(args) == 1:
                field_type = args[0]
        elif hasattr(field_type, '__args__') and type(None) in getattr(field_type, '__args__', ()):
            args = [a for a in field_type.__args__ if a is not type(None)]
            if len(args) == 1:
                field_type = args[0]

        if _isJsonbType(field_type):
            fields[field_name] = "JSONB"
        elif field_type is bool:
            fields[field_name] = "BOOLEAN"
        elif field_type is int:
            fields[field_name] = "INTEGER"
        elif field_type is float:
            fields[field_name] = "DOUBLE PRECISION"
        elif field_type in (str, type(None)):
            fields[field_name] = "TEXT"
        else:
            fields[field_name] = "TEXT"
    return fields


def _get_fk_sort_meta(model_class) -> Dict[str, Dict[str, str]]:
    """Map FK field name -> {model, labelField} from json_schema_extra (fk_model + frontend_fk_display_field)."""
    result: Dict[str, Dict[str, str]] = {}
    for name, field_info in model_class.model_fields.items():
        extra = field_info.json_schema_extra
        if not extra or not isinstance(extra, dict):
            continue
        fk_model = extra.get("fk_model")
        label_field = extra.get("frontend_fk_display_field")
        if fk_model and label_field:
            result[name] = {"model": str(fk_model), "labelField": str(label_field)}
    return result


def _parseRecordFields(record: Dict[str, Any], fields: Dict[str, str], context: str = "") -> None:
    """Parse record fields in-place: numeric typing, vector parsing, JSONB deserialization."""
    import json as _json

    for fieldName, fieldType in fields.items():
        if fieldName not in record:
            continue
        value = record[fieldName]
        if fieldType in ("DOUBLE PRECISION", "INTEGER") and value is not None:
            try:
                record[fieldName] = float(value) if fieldType == "DOUBLE PRECISION" else int(value)
            except (ValueError, TypeError):
                logger.warning(f"Could not convert {fieldName} to {fieldType} ({context}): {value}")
        elif _isVectorType(fieldType) and value is not None:
            if isinstance(value, str):
                try:
                    record[fieldName] = [float(v) for v in value.strip("[]").split(",")]
                except (ValueError, TypeError):
                    logger.warning(f"Could not parse vector field {fieldName} ({context})")
            elif isinstance(value, list):
                pass  # already a list
        elif fieldType == "BOOLEAN":
            record[fieldName] = bool(value) if value is not None else False
        elif fieldType == "JSONB" and value is not None:
            try:
                if isinstance(value, str):
                    record[fieldName] = _json.loads(value)
                elif not isinstance(value, (dict, list)):
                    record[fieldName] = _json.loads(str(value))
            except (_json.JSONDecodeError, TypeError, ValueError):
                logger.warning(f"Could not parse JSONB field {fieldName}, keeping as string ({context})")


def _quotePgIdent(name: str) -> str:
    return '"' + str(name).replace('"', '""') + '"'
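
# --- Illustrative sketch (documentation only): how _get_model_fields maps a hypothetical
# Pydantic model to SQL column types. The model below is an assumption invented for this
# sketch and is not used anywhere else in the module.
def _sketch_model_field_mapping() -> Dict[str, str]:
    class _ExampleEmbeddingDoc(BaseModel):
        title: str = Field(description="Plain text column")                    # -> TEXT
        pageCount: Optional[int] = Field(default=None)                         # -> INTEGER
        tags: List[str] = Field(default_factory=list)                          # -> JSONB
        embedding: Optional[List[float]] = Field(
            default=None, json_schema_extra={"db_type": "vector(1536)"}        # -> vector(1536)
        )

    # Expected result:
    # {"title": "TEXT", "pageCount": "INTEGER", "tags": "JSONB", "embedding": "vector(1536)"}
    return _get_model_fields(_ExampleEmbeddingDoc)
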
# Cache connectors by (host, database, port) to avoid duplicate inits for same database.
# Thread safety: _connector_cache_lock protects cache access. userId is request-scoped via
# contextvars to avoid races when concurrent requests share the same connector.
_MAX_CACHED_CONNECTORS = 32
_connector_cache: Dict[tuple, "DatabaseConnector"] = {}
_connector_cache_order: List[tuple] = []  # FIFO order for eviction
_connector_cache_lock = threading.Lock()
_current_user_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "db_connector_user_id", default=None
)


def _get_cached_connector(
    dbHost: str,
    dbDatabase: str,
    dbUser: str = None,
    dbPassword: str = None,
    dbPort: int = None,
    userId: str = None,
) -> "DatabaseConnector":
    """Return cached DatabaseConnector for same (host, database, port) to avoid duplicate PostgreSQL inits.

    Uses contextvars for userId so concurrent requests sharing the same connector get correct
    sysCreatedBy/sysModifiedBy.
    """
    port = int(dbPort) if dbPort is not None else 5432
    key = (dbHost, dbDatabase, port)
    with _connector_cache_lock:
        if key not in _connector_cache:
            # Evict oldest if at capacity
            while len(_connector_cache) >= _MAX_CACHED_CONNECTORS and _connector_cache_order:
                oldest_key = _connector_cache_order.pop(0)
                if oldest_key in _connector_cache:
                    try:
                        _connector_cache[oldest_key].close(forceClose=True)
                    except Exception as e:
                        logger.warning(f"Error closing evicted connector: {e}")
                    del _connector_cache[oldest_key]
            _connector_cache[key] = DatabaseConnector(
                dbHost=dbHost,
                dbDatabase=dbDatabase,
                dbUser=dbUser,
                dbPassword=dbPassword,
                dbPort=dbPort,
                userId=userId,
            )
            _connector_cache[key]._isCachedShared = True
            _connector_cache_order.append(key)
        conn = _connector_cache[key]
    # Set request-scoped userId via contextvar (avoids mutating shared connector)
    if userId is not None:
        _current_user_id.set(userId)
    return conn


class DatabaseConnector:
    """
    A connector for PostgreSQL-based data storage.
    Provides generic database operations without user/mandate filtering.
    Uses PostgreSQL with JSONB columns for flexible data storage.
    """

    def __init__(
        self,
        dbHost: str,
        dbDatabase: str,
        dbUser: str = None,
        dbPassword: str = None,
        dbPort: int = None,
        userId: str = None,
    ):
        # Store the input parameters
        self.dbHost = dbHost
        self.dbDatabase = dbDatabase
        self.dbUser = dbUser
        self.dbPassword = dbPassword
        self.dbPort = dbPort

        # Set userId (default to empty string if None)
        self.userId = userId if userId is not None else ""

        # Initialize database system first (creates database if needed)
        self.connection = None
        self._isCachedShared = False
        self.initDbSystem()

        # No caching needed with proper database - PostgreSQL handles performance

        # Thread safety
        self._lock = threading.Lock()

        # pgvector extension state
        self._vectorExtensionEnabled = False

        # Initialize system table
        self._systemTableName = "_system"
        self._initializeSystemTable()

    def initDbSystem(self):
        """Initialize the database system - creates database and tables."""
        try:
            # Create database if it doesn't exist
            self._create_database_if_not_exists()
            # Create tables
            self._create_tables()
            # Establish connection to the database
            self._connect()
            logger.info("PostgreSQL database system initialized successfully")
        except Exception as e:
            logger.error(f"FATAL ERROR: Database system initialization failed: {e}")
            raise

    def _create_database_if_not_exists(self):
        """Create the database if it doesn't exist."""
        try:
            # Use the configured user for database creation
            conn = psycopg2.connect(
                host=self.dbHost,
                port=self.dbPort,
                database="postgres",
                user=self.dbUser,
                password=self.dbPassword,
                client_encoding="utf8",
            )
            conn.autocommit = True
            with conn.cursor() as cursor:
                # Check if database exists
                cursor.execute(
                    "SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,)
                )
                exists = cursor.fetchone()
                if not exists:
                    # Create database with proper quoting for names with hyphens
                    quoted_db_name = f'"{self.dbDatabase}"'
                    cursor.execute(f"CREATE DATABASE {quoted_db_name}")
                    logger.info(f"Created database: {self.dbDatabase}")
            conn.close()
        except Exception as e:
            logger.error(f"FATAL ERROR: Cannot create database: {e}")
            logger.error("Database connection failed - application cannot start")
            raise RuntimeError(
                f"FATAL ERROR: Cannot create database '{self.dbDatabase}': {e}"
            )

    def _create_tables(self):
        """Create only the system table - application tables are created by interfaces."""
        try:
            # Use the configured user for table creation
            conn = psycopg2.connect(
                host=self.dbHost,
                port=self.dbPort,
                database=self.dbDatabase,
                user=self.dbUser,
                password=self.dbPassword,
                client_encoding="utf8",
            )
            conn.autocommit = True
            with conn.cursor() as cursor:
                # Create only the system table
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS _system (
                        id SERIAL PRIMARY KEY,
                        table_name VARCHAR(255) UNIQUE NOT NULL,
                        initial_id VARCHAR(255) NOT NULL,
                        "sysCreatedAt" DOUBLE PRECISION,
                        "sysCreatedBy" VARCHAR(255),
                        "sysModifiedAt" DOUBLE PRECISION,
                        "sysModifiedBy" VARCHAR(255)
                    )
                """)
            conn.close()
        except Exception as e:
            logger.error(f"FATAL ERROR: Cannot create system table: {e}")
            logger.error(
                "Database system table creation failed - application cannot start"
            )
            raise RuntimeError(f"FATAL ERROR: Cannot create system table: {e}")

    def _connect(self):
        """Establish connection to PostgreSQL database."""
        try:
            # Use configured user for main connection with proper parameter handling
            self.connection = psycopg2.connect(
                host=self.dbHost,
                port=self.dbPort,
                database=self.dbDatabase,
                user=self.dbUser,
                password=self.dbPassword,
                client_encoding="utf8",
                cursor_factory=psycopg2.extras.RealDictCursor,
            )
            self.connection.autocommit = False  # Use transactions
        except Exception as e:
            logger.error(f"Failed to connect to PostgreSQL: {e}")
            raise

    def _ensure_connection(self):
        """Ensure database connection is alive, reconnect if necessary."""
        try:
            if self.connection is None or self.connection.closed:
                self._connect()
            else:
                # Test connection with a simple query
                with self.connection.cursor() as cursor:
                    cursor.execute("SELECT 1")
        except Exception as e:
            logger.warning(f"Connection lost, reconnecting: {e}")
            self._connect()

    def _initializeSystemTable(self):
        """Initializes the system table if it doesn't exist yet."""
        try:
            # First ensure the system table exists
            self._ensureTableExists(SystemTable)
            with self.connection.cursor() as cursor:
                # Check if system table has any data
                cursor.execute('SELECT COUNT(*) FROM "_system"')
                row = cursor.fetchone()
                count = row["count"] if row else 0
            self.connection.commit()
        except Exception as e:
            logger.error(f"Error initializing system table: {e}")
            self.connection.rollback()
            raise

    def _loadSystemTable(self) -> Dict[str, str]:
        """Loads the system table with the initial IDs."""
        try:
            with self.connection.cursor() as cursor:
                cursor.execute('SELECT "table_name", "initial_id" FROM "_system"')
                rows = cursor.fetchall()
                system_data = {}
                for row in rows:
                    system_data[row["table_name"]] = row["initial_id"]
                return system_data
        except Exception as e:
            logger.error(f"Error loading system table: {e}")
            return {}

    def _saveSystemTable(self, data: Dict[str, str]) -> bool:
        """Saves the system table with the initial IDs."""
        try:
            with self.connection.cursor() as cursor:
                # Clear existing data
                cursor.execute('DELETE FROM "_system"')
                # Insert new data
                for table_name, initial_id in data.items():
                    cursor.execute(
                        """
                        INSERT INTO "_system" ("table_name", "initial_id", "sysModifiedAt")
                        VALUES (%s, %s, %s)
                        """,
                        (table_name, initial_id, getUtcTimestamp()),
                    )
            self.connection.commit()
            return True
        except Exception as e:
            logger.error(f"Error saving system table: {e}")
            self.connection.rollback()
            return False

    def _ensureSystemTableExists(self) -> bool:
        """Ensures the system table exists, creates it if it doesn't."""
        try:
            self._ensure_connection()
            with self.connection.cursor() as cursor:
                # Check if system table exists
                cursor.execute(
                    "SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s",
                    (self._systemTableName,),
                )
                exists = cursor.fetchone()["count"] > 0
                if not exists:
                    # Create system table
                    cursor.execute(f"""
                        CREATE TABLE "{self._systemTableName}" (
                            "table_name" VARCHAR(255) PRIMARY KEY,
                            "initial_id" VARCHAR(255),
                            "sysCreatedAt" DOUBLE PRECISION,
                            "sysCreatedBy" VARCHAR(255),
                            "sysModifiedAt" DOUBLE PRECISION,
                            "sysModifiedBy" VARCHAR(255)
                        )
                    """)
                    logger.info("System table created successfully")
                else:
                    # Check if we need to add missing columns to existing table
                    cursor.execute(
                        """
                        SELECT column_name FROM information_schema.columns
                        WHERE table_name = %s AND table_schema = 'public'
                        """,
                        (self._systemTableName,),
                    )
                    existing_columns = [row["column_name"] for row in cursor.fetchall()]
                    for sys_col, sys_sql in [
                        ("sysCreatedAt", "DOUBLE PRECISION"),
                        ("sysCreatedBy", "VARCHAR(255)"),
                        ("sysModifiedAt", "DOUBLE PRECISION"),
                        ("sysModifiedBy", "VARCHAR(255)"),
                    ]:
                        if sys_col not in existing_columns:
                            cursor.execute(
                                f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "{sys_col}" {sys_sql}'
                            )
            return True
        except Exception as e:
            logger.error(f"Error ensuring system table exists: {e}")
            return False

    def _ensureTableExists(self, model_class: type) -> bool:
        """Ensures a table exists, creates it if it doesn't."""
        table = model_class.__name__
        if table == "SystemTable":
            # Handle system table specially - it uses _system as the actual table name
            return self._ensureSystemTableExists()

        try:
            self._ensure_connection()
            with self.connection.cursor() as cursor:
                # Check if table exists by querying information_schema with case-insensitive search
                cursor.execute(
                    """
                    SELECT COUNT(*) FROM information_schema.tables
                    WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public'
                    """,
                    (table,),
                )
                exists = cursor.fetchone()["count"] > 0
                if not exists:
                    # Create table from Pydantic model
                    self._create_table_from_model(cursor, table, model_class)
                    logger.info(
                        f"Created table '{table}' with columns from Pydantic model"
                    )
                else:
                    # Table exists: ensure all columns from model are present (simple additive migration)
                    try:
                        cursor.execute(
                            """
                            SELECT column_name, data_type FROM information_schema.columns
                            WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public'
                            """,
                            (table,),
                        )
                        existing_column_rows = cursor.fetchall()
                        existing_columns = {
                            row["column_name"] for row in existing_column_rows
                        }
                        existing_column_types = {
                            row["column_name"]: (row["data_type"] or "").lower()
                            for row in existing_column_rows
                        }
                        # Desired columns based on model
                        model_fields = _get_model_fields(model_class)
                        desired_columns = set(["id"]) | set(model_fields.keys())
                        # Add missing columns
                        for col in sorted(desired_columns - existing_columns):
                            # Determine SQL type
                            if col in ["id"]:
                                continue  # primary key exists already
                            sql_type = model_fields.get(col)
                            if not sql_type:
                                sql_type = "TEXT"
                            try:
                                cursor.execute(
                                    f'ALTER TABLE "{table}" ADD COLUMN "{col}" {sql_type}'
                                )
                                logger.info(
                                    f"Added missing column '{col}' ({sql_type}) to '{table}'"
                                )
                            except Exception as add_err:
                                logger.warning(
                                    f"Could not add column '{col}' to '{table}': {add_err}"
                                )
                        # Targeted type-downgrade: if a model field has been
                        # changed from a structured type (JSONB) to a plain
                        # TEXT field, alter the column so writes don't fail.
                        # JSONB -> TEXT is a safe, lossless cast (JSONB is
                        # rendered as its JSON-text representation; the
                        # corresponding Pydantic ``@field_validator`` is
                        # responsible for re-decoding legacy data on read).
                        for col in sorted(desired_columns & existing_columns):
                            if col == "id":
                                continue
                            desired_sql = (model_fields.get(col) or "").upper()
                            currentType = existing_column_types.get(col, "")
                            if desired_sql == "TEXT" and currentType == "jsonb":
                                try:
                                    cursor.execute(
                                        f'ALTER TABLE "{table}" ALTER COLUMN "{col}" TYPE TEXT USING "{col}"::text'
                                    )
                                    logger.info(
                                        f"Downgraded column '{col}' from JSONB to TEXT on '{table}'"
                                    )
                                except Exception as alter_err:
                                    logger.warning(
                                        f"Could not downgrade column '{col}' on '{table}': {alter_err}"
                                    )
                    except Exception as ensure_err:
                        logger.warning(
                            f"Could not ensure columns for existing table '{table}': {ensure_err}"
                        )
            self.connection.commit()
            return True
        except Exception as e:
            logger.error(f"Error ensuring table {table} exists: {e}")
            if hasattr(self, "connection") and self.connection:
                self.connection.rollback()
            return False
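
    # --- Illustrative sketch (documentation only): the additive migration performed by
    # _ensureTableExists for an existing table. "Project", "priority" and "notes" are
    # made-up names used only in this example; they are not real models of this codebase.
    def _sketch_additive_migration_ddl(self) -> List[str]:
        # If the hypothetical "Project" model gains a new Optional[int] field "priority",
        # the next _ensureTableExists(Project) call detects the missing column and issues
        # the first ALTER below; a model field changed from JSONB to plain text is
        # downgraded as in the second ALTER. Columns are never dropped or renamed.
        return [
            'ALTER TABLE "Project" ADD COLUMN "priority" INTEGER',
            'ALTER TABLE "Project" ALTER COLUMN "notes" TYPE TEXT USING "notes"::text',
        ]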
    def _ensureVectorExtension(self) -> bool:
        """Enable pgvector extension if not already enabled. Called lazily on first vector table."""
        if self._vectorExtensionEnabled:
            return True
        try:
            self._ensure_connection()
            with self.connection.cursor() as cursor:
                cursor.execute("CREATE EXTENSION IF NOT EXISTS vector")
            self.connection.commit()
            self._vectorExtensionEnabled = True
            logger.info("pgvector extension enabled")
            return True
        except Exception as e:
            logger.error(f"Failed to enable pgvector extension: {e}")
            if hasattr(self, "connection") and self.connection:
                self.connection.rollback()
            return False

    def _create_table_from_model(self, cursor, table: str, model_class: type) -> None:
        """Create table with columns matching Pydantic model fields."""
        fields = _get_model_fields(model_class)

        # Enable pgvector if any field uses vector type
        if any(_isVectorType(sqlType) for sqlType in fields.values()):
            self._ensureVectorExtension()

        # Build column definitions with quoted identifiers to preserve exact case
        columns = ['"id" VARCHAR(255) PRIMARY KEY']
        for field_name, sql_type in fields.items():
            if field_name != "id":  # Skip id, already defined
                columns.append(f'"{field_name}" {sql_type}')

        # Create table
        sql = f'CREATE TABLE IF NOT EXISTS "{table}" ({", ".join(columns)})'
        cursor.execute(sql)

        # Create indexes for foreign keys
        for field_name in fields:
            if field_name.endswith("Id") and field_name != "id":
                cursor.execute(
                    f'CREATE INDEX IF NOT EXISTS "idx_{table}_{field_name}" ON "{table}" ("{field_name}")'
                )

    def _save_record(
        self,
        cursor,
        table: str,
        recordId: str,
        record: Dict[str, Any],
        model_class: type,
    ) -> None:
        """Save record to normalized table with explicit columns."""
        import json

        # Get columns from Pydantic model instead of database schema
        fields = _get_model_fields(model_class)
        columns = ["id"] + [field for field in fields.keys() if field != "id"]
        if not columns:
            logger.error(f"No columns found for table {table}")
            return

        # Filter record data to only include columns that exist in the table
        filtered_record = {k: v for k, v in record.items() if k in columns}
        # Ensure id is set
        filtered_record["id"] = recordId

        # Prepare values in the correct order
        values = []
        for col in columns:
            value = filtered_record.get(col)

            # Handle timestamp fields - store as Unix timestamps (floats) for consistency
            if col in ["sysCreatedAt", "sysModifiedAt"] and value is not None:
                if isinstance(value, str):
                    # Try to parse string as timestamp
                    try:
                        value = float(value)
                    except (ValueError, TypeError):
                        pass  # Keep as string if parsing fails
            # Convert enum values to their string representation
            elif hasattr(value, "value"):
                value = value.value
            # Handle vector fields (pgvector) - convert List[float] to string
            elif col in fields and _isVectorType(fields[col]) and value is not None:
                if isinstance(value, list):
                    value = f"[{','.join(str(v) for v in value)}]"
            # Handle JSONB fields - ensure proper JSON format for PostgreSQL
            elif col in fields and fields[col] == "JSONB" and value is not None:
                if isinstance(value, (dict, list)):
                    value = json.dumps(value)
                elif isinstance(value, str):
                    try:
                        json.loads(value)
                    except (json.JSONDecodeError, TypeError):
                        value = json.dumps(value)
                elif hasattr(value, 'model_dump'):
                    value = json.dumps(value.model_dump())
                else:
                    value = json.dumps(value)

            values.append(value)

        # Build INSERT/UPDATE with quoted identifiers
        col_names = ", ".join([f'"{col}"' for col in columns])
        placeholders = ", ".join(["%s"] * len(columns))
        updates = ", ".join(
            [
                f'"{col}" = EXCLUDED."{col}"'
                for col in columns[1:]
                if col not in ["sysCreatedAt", "sysCreatedBy"]
            ]
        )
        sql = (
            f'INSERT INTO "{table}" ({col_names}) VALUES ({placeholders}) '
            f'ON CONFLICT ("id") DO UPDATE SET {updates}'
        )
        cursor.execute(sql, values)
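
    # --- Illustrative sketch (documentation only): the upsert statement _save_record builds
    # for a hypothetical two-field model named "Note" (fields "title" and "sysModifiedAt").
    # The table and column names are assumptions for this example; note that sysCreatedAt
    # and sysCreatedBy are deliberately excluded from the conflict-update list.
    def _sketch_upsert_sql_shape(self) -> str:
        return (
            'INSERT INTO "Note" ("id", "title", "sysModifiedAt") VALUES (%s, %s, %s) '
            'ON CONFLICT ("id") DO UPDATE SET "title" = EXCLUDED."title", '
            '"sysModifiedAt" = EXCLUDED."sysModifiedAt"'
        )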
    def _loadRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]:
        """Loads a single record from the normalized table."""
        table = model_class.__name__
        try:
            if not self._ensureTableExists(model_class):
                return None
            with self.connection.cursor() as cursor:
                cursor.execute(f'SELECT * FROM "{table}" WHERE "id" = %s', (recordId,))
                row = cursor.fetchone()
                if not row:
                    return None
                # Convert row to dict and handle JSONB fields
                record = dict(row)
                fields = _get_model_fields(model_class)
                _parseRecordFields(record, fields, f"record {recordId}")
                return record
        except Exception as e:
            logger.error(f"Error loading record {recordId} from table {table}: {e}")
            return None

    def getRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]:
        """Load one row by primary key (routes / services; wraps _loadRecord)."""
        return self._loadRecord(model_class, str(recordId))

    def _saveRecord(
        self, model_class: type, recordId: str, record: Dict[str, Any]
    ) -> bool:
        """Saves a single record to the table."""
        table = model_class.__name__
        try:
            if not self._ensureTableExists(model_class):
                return False
            recordId = str(recordId)
            if "id" in record and str(record["id"]) != recordId:
                raise ValueError(f"Record ID mismatch: {recordId} != {record['id']}")

            # Add metadata - use contextvar for request-scoped userId when sharing connector
            effective_user_id = _current_user_id.get()
            if effective_user_id is None:
                effective_user_id = self.userId
            currentTime = getUtcTimestamp()

            # Set sysCreatedAt/sysCreatedBy on first persist; always refresh modified fields.
            # Treat None and 0 as unset (empty / bad defaults); model_dump often has sysCreatedAt=None.
            createdTs = record.get("sysCreatedAt")
            if createdTs is None or createdTs == 0 or createdTs == 0.0:
                record["sysCreatedAt"] = currentTime
                if effective_user_id:
                    record["sysCreatedBy"] = effective_user_id
            elif not record.get("sysCreatedBy"):
                if effective_user_id:
                    record["sysCreatedBy"] = effective_user_id
            record["sysModifiedAt"] = currentTime
            if effective_user_id:
                record["sysModifiedBy"] = effective_user_id

            with self.connection.cursor() as cursor:
                self._save_record(cursor, table, recordId, record, model_class)
            self.connection.commit()
            return True
        except Exception as e:
            logger.error(f"Error saving record {recordId} to table {table}: {e}")
            self.connection.rollback()
            return False

    def _loadTable(self, model_class: type) -> List[Dict[str, Any]]:
        """Loads all records from a normalized table."""
        table = model_class.__name__
        if table == self._systemTableName:
            return self._loadSystemTable()
        try:
            if not self._ensureTableExists(model_class):
                return []
            with self.connection.cursor() as cursor:
                cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"')
                records = [dict(row) for row in cursor.fetchall()]

                fields = _get_model_fields(model_class)
                modelFields = model_class.model_fields
                for record in records:
                    _parseRecordFields(record, fields, f"table {table}")
                    # Set type-aware defaults for NULL JSONB fields
                    for fieldName, fieldType in fields.items():
                        if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
                            fieldInfo = modelFields.get(fieldName)
                            if fieldInfo:
                                fieldAnnotation = fieldInfo.annotation
                                if (fieldAnnotation == list or (hasattr(fieldAnnotation, "__origin__") and fieldAnnotation.__origin__ is list)):
                                    record[fieldName] = []
                                elif (fieldAnnotation == dict or (hasattr(fieldAnnotation, "__origin__") and fieldAnnotation.__origin__ is dict)):
                                    record[fieldName] = {}
                return records
        except Exception as e:
            logger.error(f"Error loading table {table}: {e}")
            return []
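
    # --- Illustrative sketch (documentation only): a save/load round trip including the
    # audit stamping applied by _saveRecord. The "_SketchNote" model and the record id are
    # assumptions invented for this example and are not used elsewhere in the application.
    def _sketch_record_round_trip(self) -> Optional[Dict[str, Any]]:
        class _SketchNote(BaseModel):
            id: Optional[str] = None
            title: Optional[str] = None

        rec = {"id": "sketch-note-1", "title": "hello"}
        # _saveRecord fills sysCreated*/sysModified* from the request-scoped user id
        # (contextvar) or self.userId; the table "_SketchNote" is created on demand.
        if self._saveRecord(_SketchNote, rec["id"], rec):
            return self.getRecord(_SketchNote, rec["id"])
        return None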
    def _registerInitialId(self, table: str, initialId: str) -> bool:
        """Registers the initial ID for a table."""
        try:
            systemData = self._loadSystemTable()
            if table not in systemData:
                systemData[table] = initialId
                success = self._saveSystemTable(systemData)
                if success:
                    logger.info(f"Initial ID {initialId} for table {table} registered")
                return success
            else:
                # Table already has an initial ID registered
                logger.debug(f"Table {table} already has initial ID {systemData[table]}")
                return True
        except Exception as e:
            logger.error(f"Error registering the initial ID for table {table}: {e}")
            return False

    def _removeInitialId(self, table: str) -> bool:
        """Removes the initial ID for a table from the system table."""
        try:
            systemData = self._loadSystemTable()
            if table in systemData:
                del systemData[table]
                success = self._saveSystemTable(systemData)
                if success:
                    logger.info(
                        f"Initial ID for table {table} removed from system table"
                    )
                return success
            return True  # If not present, this is not an error
        except Exception as e:
            logger.error(f"Error removing initial ID for table {table}: {e}")
            return False

    def buildRbacWhereClause(
        self,
        permissions: UserPermissions,
        currentUser: User,
        table: str,
        mandateId: Optional[str] = None,
        featureInstanceId: Optional[str] = None,
    ) -> Optional[Dict[str, Any]]:
        """Delegate to interfaceRbac.buildRbacWhereClause (tests and call sites use connector as entry)."""
        from modules.interfaces.interfaceRbac import buildRbacWhereClause as _buildRbacWhereClause

        return _buildRbacWhereClause(
            permissions,
            currentUser,
            table,
            self,
            mandateId=mandateId,
            featureInstanceId=featureInstanceId,
        )

    def updateContext(self, userId: str) -> None:
        """Updates the context of the database connector.

        Sets both instance userId and contextvar for request-scoped use when connector is shared.
        """
        if userId is None:
            raise ValueError("userId must be provided")
        self.userId = userId
        _current_user_id.set(userId)

    # Public API
    def getTables(self) -> List[str]:
        """Returns a list of all available tables."""
        tables = []
        try:
            # Ensure connection is alive
            self._ensure_connection()
            if not self.connection or self.connection.closed:
                logger.error("Database connection is not available")
                return tables
            with self.connection.cursor() as cursor:
                cursor.execute("""
                    SELECT table_name FROM information_schema.tables
                    WHERE table_schema = 'public'
                    ORDER BY table_name
                """)
                rows = cursor.fetchall()
                tables = [row["table_name"] for row in rows]
        except Exception as e:
            logger.error(f"Error reading the database {self.dbDatabase}: {e}")
        return tables

    def getFields(self, model_class: type) -> List[str]:
        """Returns a list of all fields in a table."""
        data = self._loadTable(model_class)
        if not data:
            return []
        fields = list(data[0].keys()) if data else []
        return fields

    def getSchema(
        self, model_class: type, language: str = None
    ) -> Dict[str, Dict[str, Any]]:
        """Returns a schema object for a table with data types and labels."""
        data = self._loadTable(model_class)
        schema = {}
        if not data:
            return schema
        firstRecord = data[0]
        for field, value in firstRecord.items():
            dataType = type(value).__name__
            label = field
            schema[field] = {"type": dataType, "label": label}
        return schema

    def getRecordset(
        self,
        model_class: type,
        fieldFilter: List[str] = None,
        recordFilter: Dict[str, Any] = None,
    ) -> List[Dict[str, Any]]:
        """Returns a list of records from a table, filtered by criteria."""
        table = model_class.__name__
        try:
            if not self._ensureTableExists(model_class):
                return []

            # Build WHERE clause from recordFilter
            where_conditions = []
            where_values = []
            if recordFilter:
                for field, value in recordFilter.items():
                    if value is None:
                        where_conditions.append(f'"{field}" IS NULL')
                    elif isinstance(value, list):
                        where_conditions.append(f'"{field}" = ANY(%s)')
                        where_values.append(value)
                    else:
                        where_conditions.append(f'"{field}" = %s')
                        where_values.append(value)

            if where_conditions:
                where_clause = " WHERE " + " AND ".join(where_conditions)
            else:
                where_clause = ""

            query = f'SELECT * FROM "{table}"{where_clause} ORDER BY "id"'
            with self.connection.cursor() as cursor:
                cursor.execute(query, where_values)
                records = [dict(row) for row in cursor.fetchall()]

                fields = _get_model_fields(model_class)
                modelFields = model_class.model_fields
                for record in records:
                    _parseRecordFields(record, fields, f"table {table}")
                    for fieldName, fieldType in fields.items():
                        if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
                            fieldInfo = modelFields.get(fieldName)
                            if fieldInfo:
                                fieldAnnotation = fieldInfo.annotation
                                if (fieldAnnotation == list or (hasattr(fieldAnnotation, "__origin__") and fieldAnnotation.__origin__ is list)):
                                    record[fieldName] = []
                                elif (fieldAnnotation == dict or (hasattr(fieldAnnotation, "__origin__") and fieldAnnotation.__origin__ is dict)):
                                    record[fieldName] = {}

            # If fieldFilter is available, reduce the fields
            if fieldFilter and isinstance(fieldFilter, list):
                result = []
                for record in records:
                    filteredRecord = {}
                    for field in fieldFilter:
                        if field in record:
                            filteredRecord[field] = record[field]
                    result.append(filteredRecord)
                return result

            return records
        except Exception as e:
            logger.error(f"Error loading records from table {table}: {e}")
            return []
    def _buildPaginationClauses(
        self,
        model_class: type,
        pagination,
        recordFilter: Dict[str, Any] = None,
    ):
        """
        Translate PaginationParams + recordFilter into SQL clauses.
        Returns (where_clause, order_clause, limit_clause, values, count_values).
        """
        fields = _get_model_fields(model_class)
        validColumns = set(fields.keys())

        where_parts: List[str] = []
        values: List[Any] = []

        if recordFilter:
            for field, value in recordFilter.items():
                if value is None:
                    where_parts.append(f'"{field}" IS NULL')
                elif isinstance(value, list):
                    where_parts.append(f'"{field}" = ANY(%s)')
                    values.append(value)
                else:
                    where_parts.append(f'"{field}" = %s')
                    values.append(value)

        if pagination and pagination.filters:
            for key, val in pagination.filters.items():
                if key == "search" and isinstance(val, str) and val.strip():
                    term = f"%{val.strip()}%"
                    textCols = [c for c, t in fields.items() if t == "TEXT"]
                    if textCols:
                        orParts = [f'COALESCE("{c}"::TEXT, \'\') ILIKE %s' for c in textCols]
                        where_parts.append(f"({' OR '.join(orParts)})")
                        values.extend([term] * len(textCols))
                    continue
                if key not in validColumns:
                    logger.debug(f"_buildPaginationClauses: key '{key}' NOT in validColumns {list(validColumns)[:10]}")
                    continue
                colType = fields.get(key, "TEXT")
                logger.debug(f"_buildPaginationClauses: filter key='{key}' val={val!r} type(val)={type(val).__name__} colType={colType}")
                if val is None:
                    where_parts.append(f'("{key}" IS NULL OR "{key}" = \'\')')
                    continue
                if isinstance(val, dict):
                    op = val.get("operator", "equals")
                    v = val.get("value", "")
                    if op in ("equals", "eq"):
                        if colType == "BOOLEAN":
                            where_parts.append(f'COALESCE("{key}", FALSE) = %s')
                            values.append(str(v).lower() == "true")
                        else:
                            where_parts.append(f'"{key}"::TEXT = %s')
                            values.append(str(v))
                    elif op == "contains":
                        where_parts.append(f'"{key}"::TEXT ILIKE %s')
                        values.append(f"%{v}%")
                    elif op == "startsWith":
                        where_parts.append(f'"{key}"::TEXT ILIKE %s')
                        values.append(f"{v}%")
                    elif op == "endsWith":
                        where_parts.append(f'"{key}"::TEXT ILIKE %s')
                        values.append(f"%{v}")
                    elif op in ("gt", "gte", "lt", "lte"):
                        sqlOp = {"gt": ">", "gte": ">=", "lt": "<", "lte": "<="}[op]
                        where_parts.append(f'"{key}"::TEXT {sqlOp} %s')
                        values.append(str(v))
                    elif op == "between":
                        fromVal = v.get("from", "") if isinstance(v, dict) else ""
                        toVal = v.get("to", "") if isinstance(v, dict) else ""
                        if not fromVal and not toVal:
                            continue
                        colType = fields.get(key, "TEXT")
                        isNumericCol = colType in ("INTEGER", "DOUBLE PRECISION")
                        isDateVal = bool(fromVal and re.match(r'^\d{4}-\d{2}-\d{2}$', str(fromVal))) or \
                            bool(toVal and re.match(r'^\d{4}-\d{2}-\d{2}$', str(toVal)))
                        if isNumericCol and isDateVal:
                            from datetime import datetime as _dt, timezone as _tz
                            if fromVal and toVal:
                                fromTs = _dt.strptime(str(fromVal), '%Y-%m-%d').replace(tzinfo=_tz.utc).timestamp()
                                toTs = _dt.strptime(str(toVal), '%Y-%m-%d').replace(hour=23, minute=59, second=59, tzinfo=_tz.utc).timestamp()
                                where_parts.append(f'"{key}" >= %s AND "{key}" <= %s')
                                values.extend([fromTs, toTs])
                            elif fromVal:
                                fromTs = _dt.strptime(str(fromVal), '%Y-%m-%d').replace(tzinfo=_tz.utc).timestamp()
                                where_parts.append(f'"{key}" >= %s')
                                values.append(fromTs)
                            else:
                                toTs = _dt.strptime(str(toVal), '%Y-%m-%d').replace(hour=23, minute=59, second=59, tzinfo=_tz.utc).timestamp()
                                where_parts.append(f'"{key}" <= %s')
                                values.append(toTs)
                        else:
                            if fromVal and toVal:
                                where_parts.append(f'"{key}"::TEXT >= %s AND "{key}"::TEXT <= %s')
                                values.extend([str(fromVal), str(toVal)])
                            elif fromVal:
                                where_parts.append(f'"{key}"::TEXT >= %s')
                                values.append(str(fromVal))
                            elif toVal:
                                where_parts.append(f'"{key}"::TEXT <= %s')
                                values.append(str(toVal))
                else:
                    if colType == "BOOLEAN":
                        where_parts.append(f'COALESCE("{key}", FALSE) = %s')
                        values.append(str(val).lower() == "true")
                    else:
                        where_parts.append(f'"{key}"::TEXT ILIKE %s')
                        values.append(str(val))

        where_clause = " WHERE " + " AND ".join(where_parts) if where_parts else ""
        count_values = list(values)

        orderParts: List[str] = []
        if pagination and pagination.sort:
            for sf in pagination.sort:
                sfField = sf.get("field") if isinstance(sf, dict) else getattr(sf, "field", None)
                sfDir = sf.get("direction", "asc") if isinstance(sf, dict) else getattr(sf, "direction", "asc")
                if sfField and sfField in validColumns:
                    direction = "DESC" if str(sfDir).lower() == "desc" else "ASC"
                    colType = fields.get(sfField, "TEXT")
                    if colType == "BOOLEAN":
                        orderParts.append(f'COALESCE("{sfField}", FALSE) {direction}')
                    else:
                        orderParts.append(f'"{sfField}" {direction} NULLS LAST')
        if not orderParts:
            orderParts.append('"id"')
        order_clause = " ORDER BY " + ", ".join(orderParts)

        limit_clause = ""
        if pagination:
            offset = (pagination.page - 1) * pagination.pageSize
            limit_clause = f" LIMIT {pagination.pageSize} OFFSET {offset}"

        return where_clause, order_clause, limit_clause, values, count_values
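
    # --- Illustrative sketch (documentation only): the filter shapes understood by
    # _buildPaginationClauses and getRecordset. "Task" and every field value below are
    # made up for this example; they are not real models or data of the application.
    def _sketch_filter_semantics(self, Task: type) -> List[Dict[str, Any]]:
        # recordFilter: None matches NULL, a list expands to "= ANY(%s)", scalars use "=".
        # pagination.filters additionally support operator dicts such as
        #   {"operator": "contains", "value": "report"} or
        #   {"operator": "between", "value": {"from": "2025-01-01", "to": "2025-01-31"}},
        # plus a "search" key that does an ILIKE across all TEXT columns.
        return self.getRecordset(
            Task,
            fieldFilter=["id", "title", "status"],
            recordFilter={
                "status": ["open", "in_progress"],   # "status" = ANY(...)
                "assigneeId": None,                  # "assigneeId" IS NULL
                "mandateId": "mandate-123",          # "mandateId" = %s
            },
        )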
    def getRecordsetPaginated(
        self,
        model_class: type,
        pagination=None,
        recordFilter: Dict[str, Any] = None,
        fieldFilter: List[str] = None,
    ) -> Dict[str, Any]:
        """
        Returns paginated records with filtering + sorting at the SQL level.
        Returns { "items": [...], "totalItems": int, "totalPages": int }.
        If pagination is None, returns all records (no LIMIT/OFFSET).
        """
        from modules.datamodels.datamodelPagination import PaginationParams
        import math

        table = model_class.__name__
        try:
            if not self._ensureTableExists(model_class):
                return {"items": [], "totalItems": 0, "totalPages": 0}

            where_clause, order_clause, limit_clause, values, count_values = \
                self._buildPaginationClauses(model_class, pagination, recordFilter)

            with self.connection.cursor() as cursor:
                countSql = f'SELECT COUNT(*) FROM "{table}"{where_clause}'
                dataSql = f'SELECT * FROM "{table}"{where_clause}{order_clause}{limit_clause}'

                cursor.execute(countSql, count_values)
                totalItems = cursor.fetchone()["count"]

                cursor.execute(dataSql, values)
                records = [dict(row) for row in cursor.fetchall()]

                fields = _get_model_fields(model_class)
                modelFields = model_class.model_fields
                for record in records:
                    _parseRecordFields(record, fields, f"table {table}")
                    for fieldName, fieldType in fields.items():
                        if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
                            fieldInfo = modelFields.get(fieldName)
                            if fieldInfo:
                                fieldAnnotation = fieldInfo.annotation
                                if (fieldAnnotation == list or (hasattr(fieldAnnotation, "__origin__") and fieldAnnotation.__origin__ is list)):
                                    record[fieldName] = []
                                elif (fieldAnnotation == dict or (hasattr(fieldAnnotation, "__origin__") and fieldAnnotation.__origin__ is dict)):
                                    record[fieldName] = {}

            if fieldFilter and isinstance(fieldFilter, list):
                records = [{f: r[f] for f in fieldFilter if f in r} for r in records]

            pageSize = pagination.pageSize if pagination else max(totalItems, 1)
            totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0
            return {"items": records, "totalItems": totalItems, "totalPages": totalPages}
        except Exception as e:
            logger.error(f"Error in getRecordsetPaginated for table {table}: {e}")
            return {"items": [], "totalItems": 0, "totalPages": 0}
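
    # --- Illustrative sketch (documentation only): the paginated query call pattern. The
    # PaginationParams attribute names (page, pageSize, filters, sort) are assumed from how
    # _buildPaginationClauses reads them; "Task" and the field names are made up.
    def _sketch_paginated_query(self, Task: type, pagination) -> Dict[str, Any]:
        # pagination is expected to carry roughly:
        #   page=1, pageSize=25,
        #   filters={"search": "report", "status": {"operator": "equals", "value": "open"}},
        #   sort=[{"field": "dueDate", "direction": "desc"}]
        # Filters on unknown columns are ignored; the result dict contains
        # "items", "totalItems" and "totalPages".
        return self.getRecordsetPaginated(
            Task,
            pagination=pagination,
            fieldFilter=["id", "title", "dueDate"],
        )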
    def getDistinctColumnValues(
        self,
        model_class: type,
        column: str,
        pagination=None,
        recordFilter: Dict[str, Any] = None,
    ) -> List[str]:
        """
        Returns sorted distinct non-null values for a column using SQL DISTINCT.
        Applies cross-filtering (all filters except the requested column).
        """
        table = model_class.__name__
        fields = _get_model_fields(model_class)
        if column not in fields:
            return []
        try:
            if not self._ensureTableExists(model_class):
                return []

            if pagination:
                import copy
                pagination = copy.deepcopy(pagination)
                if pagination.filters and column in pagination.filters:
                    pagination.filters.pop(column, None)
                pagination.sort = []

            where_clause, _, _, values, _ = \
                self._buildPaginationClauses(model_class, pagination, recordFilter)

            sql = (
                f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{where_clause} '
                f'WHERE "{column}" IS NOT NULL AND "{column}"::TEXT != \'\' '
                if not where_clause else
                f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{where_clause} '
                f'AND "{column}" IS NOT NULL AND "{column}"::TEXT != \'\' '
            )
            sql += 'ORDER BY val'

            with self.connection.cursor() as cursor:
                cursor.execute(sql, values)
                return [row["val"] for row in cursor.fetchall()]
        except Exception as e:
            logger.error(f"Error in getDistinctColumnValues for {table}.{column}: {e}")
            return []

    def recordCreate(
        self, model_class: type, record: Union[Dict[str, Any], BaseModel]
    ) -> Dict[str, Any]:
        """Creates a new record in a table based on Pydantic model class."""
        # If record is a Pydantic model, convert to dict
        if isinstance(record, BaseModel):
            record = record.model_dump()
        elif isinstance(record, dict):
            record = record.copy()
        else:
            raise ValueError("Record must be a Pydantic model or dictionary")

        # Ensure record has an ID
        if "id" not in record:
            record["id"] = str(uuid.uuid4())

        # Save record
        success = self._saveRecord(model_class, record["id"], record)
        if not success:
            table = model_class.__name__
            raise ValueError(f"Failed to save record {record['id']} to table {table}")

        # Check if this is the first record in the table and register as initial ID
        table = model_class.__name__
        existingInitialId = self.getInitialId(model_class)
        if existingInitialId is None:
            # This is the first record, register it as the initial ID
            self._registerInitialId(table, record["id"])
            logger.info(f"Registered initial ID {record['id']} for table {table}")

        return record

    def recordModify(
        self, model_class: type, recordId: str, record: Union[Dict[str, Any], BaseModel]
    ) -> Dict[str, Any]:
        """Modifies an existing record in a table based on Pydantic model class."""
        # Load existing record
        existingRecord = self._loadRecord(model_class, recordId)
        if not existingRecord:
            table = model_class.__name__
            raise ValueError(f"Record {recordId} not found in table {table}")

        # If record is a Pydantic model, convert to dict
        if isinstance(record, BaseModel):
            record = record.model_dump()
        elif isinstance(record, dict):
            record = record.copy()
        else:
            raise ValueError("Record must be a Pydantic model or dictionary")

        # CRITICAL: Ensure we never modify the ID
        if "id" in record and str(record["id"]) != recordId:
            logger.error(
                f"Attempted to modify record ID from {recordId} to {record['id']}"
            )
            raise ValueError(
                "Cannot modify record ID - it must match the provided recordId"
            )

        # Update existing record with new data
        existingRecord.update(record)

        # Save updated record
        saved = self._saveRecord(model_class, recordId, existingRecord)
        if not saved:
            table = model_class.__name__
            raise ValueError(f"Failed to save record {recordId} to table {table}")

        return existingRecord
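
    # --- Illustrative sketch (documentation only): the create/modify/delete round trip
    # exposed to interfaces. "Task" is a hypothetical model class and the field values are
    # made up for this example.
    def _sketch_crud_round_trip(self, Task: type) -> bool:
        created = self.recordCreate(Task, {"title": "Write report"})            # id generated via uuid4
        self.recordModify(Task, created["id"], {"title": "Write final report"})  # merges into the stored row
        return self.recordDelete(Task, created["id"])                            # True if the row existed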
    def recordDelete(self, model_class: type, recordId: str) -> bool:
        """Deletes a record from the table based on Pydantic model class."""
        table = model_class.__name__
        try:
            if not self._ensureTableExists(model_class):
                return False
            with self.connection.cursor() as cursor:
                # Check if record exists
                cursor.execute(
                    f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,)
                )
                if not cursor.fetchone():
                    return False

                # Check if it's an initial record
                initialId = self.getInitialId(model_class)
                if initialId is not None and initialId == recordId:
                    self._removeInitialId(table)
                    logger.info(
                        f"Initial ID {recordId} for table {table} has been removed from the system table"
                    )

                # Delete the record
                cursor.execute(f'DELETE FROM "{table}" WHERE "id" = %s', (recordId,))

            # No cache to update - database handles consistency
            self.connection.commit()
            return True
        except Exception as e:
            logger.error(f"Error deleting record {recordId} from table {table}: {e}")
            self.connection.rollback()
            return False

    def getInitialId(self, model_class: type) -> Optional[str]:
        """Returns the initial ID for a table."""
        table = model_class.__name__
        systemData = self._loadSystemTable()
        initialId = systemData.get(table)
        return initialId

    def semanticSearch(
        self,
        modelClass: type,
        vectorColumn: str,
        queryVector: List[float],
        limit: int = 10,
        recordFilter: Dict[str, Any] = None,
        minScore: float = None,
    ) -> List[Dict[str, Any]]:
        """Semantic search using pgvector cosine distance.

        Args:
            modelClass: Pydantic model class for the table.
            vectorColumn: Name of the vector column to search.
            queryVector: Query vector as List[float].
            limit: Maximum number of results.
            recordFilter: Additional WHERE filters (field: value).
            minScore: Minimum cosine similarity (0.0 - 1.0).

        Returns:
            List of records with an added '_score' field (cosine similarity),
            sorted by similarity descending.
        """
        table = modelClass.__name__
        try:
            if not self._ensureTableExists(modelClass):
                return []

            vectorStr = f"[{','.join(str(v) for v in queryVector)}]"

            whereConditions = []
            whereValues = []
            if recordFilter:
                for field, value in recordFilter.items():
                    if value is None:
                        whereConditions.append(f'"{field}" IS NULL')
                    elif isinstance(value, (list, tuple)):
                        if not value:
                            whereConditions.append("1 = 0")
                        else:
                            whereConditions.append(f'"{field}" = ANY(%s)')
                            whereValues.append(list(value))
                    else:
                        whereConditions.append(f'"{field}" = %s')
                        whereValues.append(value)

            if minScore is not None:
                whereConditions.append(
                    f'1 - ("{vectorColumn}" <=> %s::vector) >= %s'
                )
                whereValues.extend([vectorStr, minScore])

            whereClause = ""
            if whereConditions:
                whereClause = " WHERE " + " AND ".join(whereConditions)

            query = (
                f'SELECT *, 1 - ("{vectorColumn}" <=> %s::vector) AS "_score" '
                f'FROM "{table}"{whereClause} '
                f'ORDER BY "{vectorColumn}" <=> %s::vector '
                f'LIMIT %s'
            )
            params = [vectorStr] + whereValues + [vectorStr, limit]

            with self.connection.cursor() as cursor:
                cursor.execute(query, params)
                records = [dict(row) for row in cursor.fetchall()]

                fields = _get_model_fields(modelClass)
                for record in records:
                    _parseRecordFields(record, fields, f"semanticSearch {table}")
                return records
        except Exception as e:
            logger.error(f"Error in semantic search on {table}: {e}")
            return []

    def close(self, forceClose: bool = False):
        """Close the database connection.

        Shared cached connectors are intentionally kept open unless forceClose=True.
        This prevents accidental shutdown from interface __del__ methods while other
        requests are still using the same cached connector instance.
        """
        if self._isCachedShared and not forceClose:
            return
        if (
            hasattr(self, "connection")
            and self.connection
            and not self.connection.closed
        ):
            self.connection.close()

    def __del__(self):
        """Cleanup method to close connection."""
        try:
            self.close()
        except Exception:
            pass
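

# --- Illustrative sketch (documentation only): the intended call pattern for the connector
# cache. Host, database and credentials are placeholders, not real configuration; in the
# application they would typically come from APP_CONFIG.
def _sketch_cached_connector_usage(requestUserId: str) -> DatabaseConnector:
    # Repeated calls with the same (host, database, port) return the same shared instance;
    # the per-request userId is carried in a contextvar instead of mutating that instance.
    return _get_cached_connector(
        dbHost="localhost",      # placeholder
        dbDatabase="poweron",    # placeholder
        dbUser="postgres",       # placeholder
        dbPassword="secret",     # placeholder
        dbPort=5432,
        userId=requestUserId,
    )


# --- Illustrative sketch (documentation only): an end-to-end semantic search call. The
# document model, vector column name, filter field and threshold are assumptions for this
# example; in the real application the query vector would come from the embedding service.
def _sketch_semantic_search(connector: DatabaseConnector, documentModel: type, queryVector: List[float]) -> List[Dict[str, Any]]:
    # Requires the model to declare its vector column via json_schema_extra={"db_type": "vector(1536)"};
    # each result carries a "_score" key with the cosine similarity.
    return connector.semanticSearch(
        documentModel,
        vectorColumn="embedding",                     # assumed column name
        queryVector=queryVector,
        limit=5,
        recordFilter={"mandateId": "mandate-123"},    # made-up filter
        minScore=0.75,
    )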