gateway/modules/connectors/connectorDbPostgre.py

1455 lines
58 KiB
Python

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
import contextvars
import re
import psycopg2
import psycopg2.extras
import logging
from typing import List, Dict, Any, Optional, Union, get_origin, get_args, Type
import uuid
from pydantic import BaseModel, Field
import threading
from modules.shared.timeUtils import getUtcTimestamp
from modules.shared.configuration import APP_CONFIG
from modules.datamodels.datamodelBase import PowerOnModel
from modules.datamodels.datamodelUam import User, AccessLevel, UserPermissions
from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext
logger = logging.getLogger(__name__)
# No mapping needed - table name = Pydantic model name exactly
class SystemTable(PowerOnModel):
"""Data model for system table entries"""
table_name: str = Field(
description="Name of the table",
json_schema_extra={
"frontend_type": "text",
"frontend_readonly": True,
"frontend_required": True,
}
)
initial_id: Optional[str] = Field(
default=None,
description="Initial ID for the table",
json_schema_extra={
"frontend_type": "text",
"frontend_readonly": True,
"frontend_required": False,
}
)
def _isVectorType(sqlType: str) -> bool:
"""Check if a SQL type string represents a pgvector column."""
return sqlType.upper().startswith("VECTOR")
def _isJsonbType(fieldType) -> bool:
"""Check if a type should be stored as JSONB in PostgreSQL."""
# Direct dict or list
if fieldType == dict or fieldType == list:
return True
# Generic List[X] or Dict[X, Y]
origin = get_origin(fieldType)
if origin in (dict, list):
return True
# Direct Pydantic BaseModel subclass
if isinstance(fieldType, type) and issubclass(fieldType, BaseModel):
return True
# Optional[X] - check the inner type
if origin is Union:
args = get_args(fieldType)
for arg in args:
if arg is type(None):
continue
# Recursively check the inner type
if _isJsonbType(arg):
return True
return False
def _get_model_fields(model_class) -> Dict[str, str]:
"""Get all fields from Pydantic model and map to SQL types.
Supports explicit db_type override via json_schema_extra={"db_type": "vector(1536)"}.
This enables pgvector columns without special-casing field names.
"""
model_fields = model_class.model_fields
fields = {}
for field_name, field_info in model_fields.items():
field_type = field_info.annotation
# Explicit db_type override (e.g. vector columns)
extra = field_info.json_schema_extra
if extra and isinstance(extra, dict) and "db_type" in extra:
fields[field_name] = extra["db_type"]
continue
# Unwrap Optional[X] → X (handles both typing.Union and types.UnionType)
origin = get_origin(field_type)
if origin is Union:
args = [a for a in get_args(field_type) if a is not type(None)]
if len(args) == 1:
field_type = args[0]
elif hasattr(field_type, '__args__') and type(None) in getattr(field_type, '__args__', ()):
args = [a for a in field_type.__args__ if a is not type(None)]
if len(args) == 1:
field_type = args[0]
if _isJsonbType(field_type):
fields[field_name] = "JSONB"
elif field_type is bool:
fields[field_name] = "BOOLEAN"
elif field_type is int:
fields[field_name] = "INTEGER"
elif field_type is float:
fields[field_name] = "DOUBLE PRECISION"
elif field_type in (str, type(None)):
fields[field_name] = "TEXT"
else:
fields[field_name] = "TEXT"
return fields
def _parseRecordFields(record: Dict[str, Any], fields: Dict[str, str], context: str = "") -> None:
"""Parse record fields in-place: numeric typing, vector parsing, JSONB deserialization."""
import json as _json
for fieldName, fieldType in fields.items():
if fieldName not in record:
continue
value = record[fieldName]
if fieldType in ("DOUBLE PRECISION", "INTEGER") and value is not None:
try:
record[fieldName] = float(value) if fieldType == "DOUBLE PRECISION" else int(value)
except (ValueError, TypeError):
logger.warning(f"Could not convert {fieldName} to {fieldType} ({context}): {value}")
elif _isVectorType(fieldType) and value is not None:
if isinstance(value, str):
try:
record[fieldName] = [float(v) for v in value.strip("[]").split(",")]
except (ValueError, TypeError):
logger.warning(f"Could not parse vector field {fieldName} ({context})")
elif isinstance(value, list):
pass # already a list
elif fieldType == "BOOLEAN":
record[fieldName] = bool(value) if value is not None else False
elif fieldType == "JSONB" and value is not None:
try:
if isinstance(value, str):
record[fieldName] = _json.loads(value)
elif not isinstance(value, (dict, list)):
record[fieldName] = _json.loads(str(value))
except (_json.JSONDecodeError, TypeError, ValueError):
logger.warning(f"Could not parse JSONB field {fieldName}, keeping as string ({context})")
def _quotePgIdent(name: str) -> str:
return '"' + str(name).replace('"', '""') + '"'
# Cache connectors by (host, database, port) to avoid duplicate inits for same database.
# Thread safety: _connector_cache_lock protects cache access. userId is request-scoped via
# contextvars to avoid races when concurrent requests share the same connector.
_MAX_CACHED_CONNECTORS = 32
_connector_cache: Dict[tuple, "DatabaseConnector"] = {}
_connector_cache_order: List[tuple] = [] # FIFO order for eviction
_connector_cache_lock = threading.Lock()
_current_user_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
"db_connector_user_id", default=None
)
def _get_cached_connector(
dbHost: str,
dbDatabase: str,
dbUser: str = None,
dbPassword: str = None,
dbPort: int = None,
userId: str = None,
) -> "DatabaseConnector":
"""Return cached DatabaseConnector for same (host, database, port) to avoid duplicate PostgreSQL inits.
Uses contextvars for userId so concurrent requests sharing the same connector get correct sysCreatedBy/sysModifiedBy.
"""
port = int(dbPort) if dbPort is not None else 5432
key = (dbHost, dbDatabase, port)
with _connector_cache_lock:
if key not in _connector_cache:
# Evict oldest if at capacity
while len(_connector_cache) >= _MAX_CACHED_CONNECTORS and _connector_cache_order:
oldest_key = _connector_cache_order.pop(0)
if oldest_key in _connector_cache:
try:
_connector_cache[oldest_key].close(forceClose=True)
except Exception as e:
logger.warning(f"Error closing evicted connector: {e}")
del _connector_cache[oldest_key]
_connector_cache[key] = DatabaseConnector(
dbHost=dbHost,
dbDatabase=dbDatabase,
dbUser=dbUser,
dbPassword=dbPassword,
dbPort=dbPort,
userId=userId,
)
_connector_cache[key]._isCachedShared = True
_connector_cache_order.append(key)
conn = _connector_cache[key]
# Set request-scoped userId via contextvar (avoids mutating shared connector)
if userId is not None:
_current_user_id.set(userId)
return conn
class DatabaseConnector:
"""
A connector for PostgreSQL-based data storage.
Provides generic database operations without user/mandate filtering.
Uses PostgreSQL with JSONB columns for flexible data storage.
"""
def __init__(
self,
dbHost: str,
dbDatabase: str,
dbUser: str = None,
dbPassword: str = None,
dbPort: int = None,
userId: str = None,
):
# Store the input parameters
self.dbHost = dbHost
self.dbDatabase = dbDatabase
self.dbUser = dbUser
self.dbPassword = dbPassword
self.dbPort = dbPort
# Set userId (default to empty string if None)
self.userId = userId if userId is not None else ""
# Initialize database system first (creates database if needed)
self.connection = None
self._isCachedShared = False
self.initDbSystem()
# No caching needed with proper database - PostgreSQL handles performance
# Thread safety
self._lock = threading.Lock()
# pgvector extension state
self._vectorExtensionEnabled = False
# Initialize system table
self._systemTableName = "_system"
self._initializeSystemTable()
def initDbSystem(self):
"""Initialize the database system - creates database and tables."""
try:
# Create database if it doesn't exist
self._create_database_if_not_exists()
# Create tables
self._create_tables()
# Establish connection to the database
self._connect()
logger.info("PostgreSQL database system initialized successfully")
except Exception as e:
logger.error(f"FATAL ERROR: Database system initialization failed: {e}")
raise
def _create_database_if_not_exists(self):
"""Create the database if it doesn't exist."""
try:
# Use the configured user for database creation
conn = psycopg2.connect(
host=self.dbHost,
port=self.dbPort,
database="postgres",
user=self.dbUser,
password=self.dbPassword,
client_encoding="utf8",
)
conn.autocommit = True
with conn.cursor() as cursor:
# Check if database exists
cursor.execute(
"SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,)
)
exists = cursor.fetchone()
if not exists:
# Create database with proper quoting for names with hyphens
quoted_db_name = f'"{self.dbDatabase}"'
cursor.execute(f"CREATE DATABASE {quoted_db_name}")
logger.info(f"Created database: {self.dbDatabase}")
conn.close()
except Exception as e:
logger.error(f"FATAL ERROR: Cannot create database: {e}")
logger.error("Database connection failed - application cannot start")
raise RuntimeError(
f"FATAL ERROR: Cannot create database '{self.dbDatabase}': {e}"
)
def _create_tables(self):
"""Create only the system table - application tables are created by interfaces."""
try:
# Use the configured user for table creation
conn = psycopg2.connect(
host=self.dbHost,
port=self.dbPort,
database=self.dbDatabase,
user=self.dbUser,
password=self.dbPassword,
client_encoding="utf8",
)
conn.autocommit = True
with conn.cursor() as cursor:
# Create only the system table
cursor.execute("""
CREATE TABLE IF NOT EXISTS _system (
id SERIAL PRIMARY KEY,
table_name VARCHAR(255) UNIQUE NOT NULL,
initial_id VARCHAR(255) NOT NULL,
"sysCreatedAt" DOUBLE PRECISION,
"sysCreatedBy" VARCHAR(255),
"sysModifiedAt" DOUBLE PRECISION,
"sysModifiedBy" VARCHAR(255)
)
""")
conn.close()
except Exception as e:
logger.error(f"FATAL ERROR: Cannot create system table: {e}")
logger.error(
"Database system table creation failed - application cannot start"
)
raise RuntimeError(f"FATAL ERROR: Cannot create system table: {e}")
def _connect(self):
"""Establish connection to PostgreSQL database."""
try:
# Use configured user for main connection with proper parameter handling
self.connection = psycopg2.connect(
host=self.dbHost,
port=self.dbPort,
database=self.dbDatabase,
user=self.dbUser,
password=self.dbPassword,
client_encoding="utf8",
cursor_factory=psycopg2.extras.RealDictCursor,
)
self.connection.autocommit = False # Use transactions
except Exception as e:
logger.error(f"Failed to connect to PostgreSQL: {e}")
raise
def _ensure_connection(self):
"""Ensure database connection is alive, reconnect if necessary."""
try:
if self.connection is None or self.connection.closed:
self._connect()
else:
# Test connection with a simple query
with self.connection.cursor() as cursor:
cursor.execute("SELECT 1")
except Exception as e:
logger.warning(f"Connection lost, reconnecting: {e}")
self._connect()
def _initializeSystemTable(self):
"""Initializes the system table if it doesn't exist yet."""
try:
# First ensure the system table exists
self._ensureTableExists(SystemTable)
with self.connection.cursor() as cursor:
# Check if system table has any data
cursor.execute('SELECT COUNT(*) FROM "_system"')
row = cursor.fetchone()
count = row["count"] if row else 0
self.connection.commit()
except Exception as e:
logger.error(f"Error initializing system table: {e}")
self.connection.rollback()
raise
def _loadSystemTable(self) -> Dict[str, str]:
"""Loads the system table with the initial IDs."""
try:
with self.connection.cursor() as cursor:
cursor.execute('SELECT "table_name", "initial_id" FROM "_system"')
rows = cursor.fetchall()
system_data = {}
for row in rows:
system_data[row["table_name"]] = row["initial_id"]
return system_data
except Exception as e:
logger.error(f"Error loading system table: {e}")
return {}
def _saveSystemTable(self, data: Dict[str, str]) -> bool:
"""Saves the system table with the initial IDs."""
try:
with self.connection.cursor() as cursor:
# Clear existing data
cursor.execute('DELETE FROM "_system"')
# Insert new data
for table_name, initial_id in data.items():
cursor.execute(
"""
INSERT INTO "_system" ("table_name", "initial_id", "sysModifiedAt")
VALUES (%s, %s, %s)
""",
(table_name, initial_id, getUtcTimestamp()),
)
self.connection.commit()
return True
except Exception as e:
logger.error(f"Error saving system table: {e}")
self.connection.rollback()
return False
def _ensureSystemTableExists(self) -> bool:
"""Ensures the system table exists, creates it if it doesn't."""
try:
self._ensure_connection()
with self.connection.cursor() as cursor:
# Check if system table exists
cursor.execute(
"SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s",
(self._systemTableName,),
)
exists = cursor.fetchone()["count"] > 0
if not exists:
# Create system table
cursor.execute(f"""
CREATE TABLE "{self._systemTableName}" (
"table_name" VARCHAR(255) PRIMARY KEY,
"initial_id" VARCHAR(255),
"sysCreatedAt" DOUBLE PRECISION,
"sysCreatedBy" VARCHAR(255),
"sysModifiedAt" DOUBLE PRECISION,
"sysModifiedBy" VARCHAR(255)
)
""")
logger.info("System table created successfully")
else:
# Check if we need to add missing columns to existing table
cursor.execute(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = %s AND table_schema = 'public'
""",
(self._systemTableName,),
)
existing_columns = [row["column_name"] for row in cursor.fetchall()]
for sys_col, sys_sql in [
("sysCreatedAt", "DOUBLE PRECISION"),
("sysCreatedBy", "VARCHAR(255)"),
("sysModifiedAt", "DOUBLE PRECISION"),
("sysModifiedBy", "VARCHAR(255)"),
]:
if sys_col not in existing_columns:
cursor.execute(
f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "{sys_col}" {sys_sql}'
)
return True
except Exception as e:
logger.error(f"Error ensuring system table exists: {e}")
return False
def _ensureTableExists(self, model_class: type) -> bool:
"""Ensures a table exists, creates it if it doesn't."""
table = model_class.__name__
if table == "SystemTable":
# Handle system table specially - it uses _system as the actual table name
return self._ensureSystemTableExists()
try:
self._ensure_connection()
with self.connection.cursor() as cursor:
# Check if table exists by querying information_schema with case-insensitive search
cursor.execute(
"""
SELECT COUNT(*) FROM information_schema.tables
WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public'
""",
(table,),
)
exists = cursor.fetchone()["count"] > 0
if not exists:
# Create table from Pydantic model
self._create_table_from_model(cursor, table, model_class)
logger.info(
f"Created table '{table}' with columns from Pydantic model"
)
else:
# Table exists: ensure all columns from model are present (simple additive migration)
try:
cursor.execute(
"""
SELECT column_name FROM information_schema.columns
WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public'
""",
(table,),
)
existing_columns = {
row["column_name"] for row in cursor.fetchall()
}
# Desired columns based on model
model_fields = _get_model_fields(model_class)
desired_columns = set(["id"]) | set(model_fields.keys())
# Add missing columns
for col in sorted(desired_columns - existing_columns):
# Determine SQL type
if col in ["id"]:
continue # primary key exists already
sql_type = model_fields.get(col)
if not sql_type:
sql_type = "TEXT"
try:
cursor.execute(
f'ALTER TABLE "{table}" ADD COLUMN "{col}" {sql_type}'
)
logger.info(
f"Added missing column '{col}' ({sql_type}) to '{table}'"
)
except Exception as add_err:
logger.warning(
f"Could not add column '{col}' to '{table}': {add_err}"
)
except Exception as ensure_err:
logger.warning(
f"Could not ensure columns for existing table '{table}': {ensure_err}"
)
self.connection.commit()
return True
except Exception as e:
logger.error(f"Error ensuring table {table} exists: {e}")
if hasattr(self, "connection") and self.connection:
self.connection.rollback()
return False
def _ensureVectorExtension(self) -> bool:
"""Enable pgvector extension if not already enabled. Called lazily on first vector table."""
if self._vectorExtensionEnabled:
return True
try:
self._ensure_connection()
with self.connection.cursor() as cursor:
cursor.execute("CREATE EXTENSION IF NOT EXISTS vector")
self.connection.commit()
self._vectorExtensionEnabled = True
logger.info("pgvector extension enabled")
return True
except Exception as e:
logger.error(f"Failed to enable pgvector extension: {e}")
if hasattr(self, "connection") and self.connection:
self.connection.rollback()
return False
def _create_table_from_model(self, cursor, table: str, model_class: type) -> None:
"""Create table with columns matching Pydantic model fields."""
fields = _get_model_fields(model_class)
# Enable pgvector if any field uses vector type
if any(_isVectorType(sqlType) for sqlType in fields.values()):
self._ensureVectorExtension()
# Build column definitions with quoted identifiers to preserve exact case
columns = ['"id" VARCHAR(255) PRIMARY KEY']
for field_name, sql_type in fields.items():
if field_name != "id": # Skip id, already defined
columns.append(f'"{field_name}" {sql_type}')
# Create table
sql = f'CREATE TABLE IF NOT EXISTS "{table}" ({", ".join(columns)})'
cursor.execute(sql)
# Create indexes for foreign keys
for field_name in fields:
if field_name.endswith("Id") and field_name != "id":
cursor.execute(
f'CREATE INDEX IF NOT EXISTS "idx_{table}_{field_name}" ON "{table}" ("{field_name}")'
)
def _save_record(
self,
cursor,
table: str,
recordId: str,
record: Dict[str, Any],
model_class: type,
) -> None:
"""Save record to normalized table with explicit columns."""
# Get columns from Pydantic model instead of database schema
fields = _get_model_fields(model_class)
columns = ["id"] + [field for field in fields.keys() if field != "id"]
if not columns:
logger.error(f"No columns found for table {table}")
return
# Filter record data to only include columns that exist in the table
filtered_record = {k: v for k, v in record.items() if k in columns}
# Ensure id is set
filtered_record["id"] = recordId
# Prepare values in the correct order
values = []
for col in columns:
value = filtered_record.get(col)
# Handle timestamp fields - store as Unix timestamps (floats) for consistency
if col in ["sysCreatedAt", "sysModifiedAt"] and value is not None:
if isinstance(value, str):
# Try to parse string as timestamp
try:
value = float(value)
except:
pass # Keep as string if parsing fails
# Convert enum values to their string representation
elif hasattr(value, "value"):
value = value.value
# Handle vector fields (pgvector) - convert List[float] to string
elif col in fields and _isVectorType(fields[col]) and value is not None:
if isinstance(value, list):
value = f"[{','.join(str(v) for v in value)}]"
# Handle JSONB fields - ensure proper JSON format for PostgreSQL
elif col in fields and fields[col] == "JSONB" and value is not None:
import json
if isinstance(value, (dict, list)):
value = json.dumps(value)
elif isinstance(value, str):
try:
json.loads(value)
except (json.JSONDecodeError, TypeError):
value = json.dumps(value)
elif hasattr(value, 'model_dump'):
value = json.dumps(value.model_dump())
else:
value = json.dumps(value)
values.append(value)
# Build INSERT/UPDATE with quoted identifiers
col_names = ", ".join([f'"{col}"' for col in columns])
placeholders = ", ".join(["%s"] * len(columns))
updates = ", ".join(
[
f'"{col}" = EXCLUDED."{col}"'
for col in columns[1:]
if col not in ["sysCreatedAt", "sysCreatedBy"]
]
)
sql = f'INSERT INTO "{table}" ({col_names}) VALUES ({placeholders}) ON CONFLICT ("id") DO UPDATE SET {updates}'
cursor.execute(sql, values)
def _loadRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]:
"""Loads a single record from the normalized table."""
table = model_class.__name__
try:
if not self._ensureTableExists(model_class):
return None
with self.connection.cursor() as cursor:
cursor.execute(f'SELECT * FROM "{table}" WHERE "id" = %s', (recordId,))
row = cursor.fetchone()
if not row:
return None
# Convert row to dict and handle JSONB fields
record = dict(row)
fields = _get_model_fields(model_class)
_parseRecordFields(record, fields, f"record {recordId}")
return record
except Exception as e:
logger.error(f"Error loading record {recordId} from table {table}: {e}")
return None
def getRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]:
"""Load one row by primary key (routes / services; wraps _loadRecord)."""
return self._loadRecord(model_class, str(recordId))
def _saveRecord(
self, model_class: type, recordId: str, record: Dict[str, Any]
) -> bool:
"""Saves a single record to the table."""
table = model_class.__name__
try:
if not self._ensureTableExists(model_class):
return False
recordId = str(recordId)
if "id" in record and str(record["id"]) != recordId:
raise ValueError(f"Record ID mismatch: {recordId} != {record['id']}")
# Add metadata - use contextvar for request-scoped userId when sharing connector
effective_user_id = _current_user_id.get()
if effective_user_id is None:
effective_user_id = self.userId
currentTime = getUtcTimestamp()
# Set sysCreatedAt/sysCreatedBy on first persist; always refresh modified fields.
# Treat None and 0 as unset (empty / bad defaults); model_dump often has sysCreatedAt=None.
createdTs = record.get("sysCreatedAt")
if createdTs is None or createdTs == 0 or createdTs == 0.0:
record["sysCreatedAt"] = currentTime
if effective_user_id:
record["sysCreatedBy"] = effective_user_id
elif not record.get("sysCreatedBy"):
if effective_user_id:
record["sysCreatedBy"] = effective_user_id
record["sysModifiedAt"] = currentTime
if effective_user_id:
record["sysModifiedBy"] = effective_user_id
with self.connection.cursor() as cursor:
self._save_record(cursor, table, recordId, record, model_class)
self.connection.commit()
return True
except Exception as e:
logger.error(f"Error saving record {recordId} to table {table}: {e}")
self.connection.rollback()
return False
def _loadTable(self, model_class: type) -> List[Dict[str, Any]]:
"""Loads all records from a normalized table."""
table = model_class.__name__
if table == self._systemTableName:
return self._loadSystemTable()
try:
if not self._ensureTableExists(model_class):
return []
with self.connection.cursor() as cursor:
cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"')
records = [dict(row) for row in cursor.fetchall()]
fields = _get_model_fields(model_class)
modelFields = model_class.model_fields
for record in records:
_parseRecordFields(record, fields, f"table {table}")
# Set type-aware defaults for NULL JSONB fields
for fieldName, fieldType in fields.items():
if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
fieldInfo = modelFields.get(fieldName)
if fieldInfo:
fieldAnnotation = fieldInfo.annotation
if (fieldAnnotation == list or
(hasattr(fieldAnnotation, "__origin__") and
fieldAnnotation.__origin__ is list)):
record[fieldName] = []
elif (fieldAnnotation == dict or
(hasattr(fieldAnnotation, "__origin__") and
fieldAnnotation.__origin__ is dict)):
record[fieldName] = {}
return records
except Exception as e:
logger.error(f"Error loading table {table}: {e}")
return []
def _registerInitialId(self, table: str, initialId: str) -> bool:
"""Registers the initial ID for a table."""
try:
systemData = self._loadSystemTable()
if table not in systemData:
systemData[table] = initialId
success = self._saveSystemTable(systemData)
if success:
logger.info(f"Initial ID {initialId} for table {table} registered")
return success
else:
# Table already has an initial ID registered
logger.debug(f"Table {table} already has initial ID {systemData[table]}")
return True
except Exception as e:
logger.error(f"Error registering the initial ID for table {table}: {e}")
return False
def _removeInitialId(self, table: str) -> bool:
"""Removes the initial ID for a table from the system table."""
try:
systemData = self._loadSystemTable()
if table in systemData:
del systemData[table]
success = self._saveSystemTable(systemData)
if success:
logger.info(
f"Initial ID for table {table} removed from system table"
)
return success
return True # If not present, this is not an error
except Exception as e:
logger.error(f"Error removing initial ID for table {table}: {e}")
return False
def buildRbacWhereClause(
self,
permissions: UserPermissions,
currentUser: User,
table: str,
mandateId: Optional[str] = None,
featureInstanceId: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
"""Delegate to interfaceRbac.buildRbacWhereClause (tests and call sites use connector as entry)."""
from modules.interfaces.interfaceRbac import buildRbacWhereClause as _buildRbacWhereClause
return _buildRbacWhereClause(
permissions,
currentUser,
table,
self,
mandateId=mandateId,
featureInstanceId=featureInstanceId,
)
def updateContext(self, userId: str) -> None:
"""Updates the context of the database connector.
Sets both instance userId and contextvar for request-scoped use when connector is shared.
"""
if userId is None:
raise ValueError("userId must be provided")
self.userId = userId
_current_user_id.set(userId)
# Public API
def getTables(self) -> List[str]:
"""Returns a list of all available tables."""
tables = []
try:
# Ensure connection is alive
self._ensure_connection()
if not self.connection or self.connection.closed:
logger.error("Database connection is not available")
return tables
with self.connection.cursor() as cursor:
cursor.execute("""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'public'
ORDER BY table_name
""")
rows = cursor.fetchall()
tables = [row["table_name"] for row in rows]
except Exception as e:
logger.error(f"Error reading the database {self.dbDatabase}: {e}")
return tables
def getFields(self, model_class: type) -> List[str]:
"""Returns a list of all fields in a table."""
data = self._loadTable(model_class)
if not data:
return []
fields = list(data[0].keys()) if data else []
return fields
def getSchema(
self, model_class: type, language: str = None
) -> Dict[str, Dict[str, Any]]:
"""Returns a schema object for a table with data types and labels."""
data = self._loadTable(model_class)
schema = {}
if not data:
return schema
firstRecord = data[0]
for field, value in firstRecord.items():
dataType = type(value).__name__
label = field
schema[field] = {"type": dataType, "label": label}
return schema
def getRecordset(
self,
model_class: type,
fieldFilter: List[str] = None,
recordFilter: Dict[str, Any] = None,
) -> List[Dict[str, Any]]:
"""Returns a list of records from a table, filtered by criteria."""
table = model_class.__name__
try:
if not self._ensureTableExists(model_class):
return []
# Build WHERE clause from recordFilter
where_conditions = []
where_values = []
if recordFilter:
for field, value in recordFilter.items():
if value is None:
# Use IS NULL for None values (= NULL is always false in SQL)
where_conditions.append(f'"{field}" IS NULL')
else:
where_conditions.append(f'"{field}" = %s')
where_values.append(value)
# Build the query
if where_conditions:
where_clause = " WHERE " + " AND ".join(where_conditions)
else:
where_clause = ""
query = f'SELECT * FROM "{table}"{where_clause} ORDER BY "id"'
with self.connection.cursor() as cursor:
cursor.execute(query, where_values)
records = [dict(row) for row in cursor.fetchall()]
fields = _get_model_fields(model_class)
modelFields = model_class.model_fields
for record in records:
_parseRecordFields(record, fields, f"table {table}")
for fieldName, fieldType in fields.items():
if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
fieldInfo = modelFields.get(fieldName)
if fieldInfo:
fieldAnnotation = fieldInfo.annotation
if (fieldAnnotation == list or
(hasattr(fieldAnnotation, "__origin__") and
fieldAnnotation.__origin__ is list)):
record[fieldName] = []
elif (fieldAnnotation == dict or
(hasattr(fieldAnnotation, "__origin__") and
fieldAnnotation.__origin__ is dict)):
record[fieldName] = {}
# If fieldFilter is available, reduce the fields
if fieldFilter and isinstance(fieldFilter, list):
result = []
for record in records:
filteredRecord = {}
for field in fieldFilter:
if field in record:
filteredRecord[field] = record[field]
result.append(filteredRecord)
return result
return records
except Exception as e:
logger.error(f"Error loading records from table {table}: {e}")
return []
def _buildPaginationClauses(
self,
model_class: type,
pagination,
recordFilter: Dict[str, Any] = None,
):
"""
Translate PaginationParams + recordFilter into SQL clauses.
Returns (where_clause, order_clause, limit_clause, values, count_values).
"""
fields = _get_model_fields(model_class)
validColumns = set(fields.keys())
where_parts: List[str] = []
values: List[Any] = []
if recordFilter:
for field, value in recordFilter.items():
if value is None:
where_parts.append(f'"{field}" IS NULL')
elif isinstance(value, list):
where_parts.append(f'"{field}" = ANY(%s)')
values.append(value)
else:
where_parts.append(f'"{field}" = %s')
values.append(value)
if pagination and pagination.filters:
for key, val in pagination.filters.items():
if key == "search" and isinstance(val, str) and val.strip():
term = f"%{val.strip()}%"
textCols = [c for c, t in fields.items() if t == "TEXT"]
if textCols:
orParts = [f'COALESCE("{c}"::TEXT, \'\') ILIKE %s' for c in textCols]
where_parts.append(f"({' OR '.join(orParts)})")
values.extend([term] * len(textCols))
continue
if key not in validColumns:
logger.debug(f"_buildPaginationClauses: key '{key}' NOT in validColumns {list(validColumns)[:10]}")
continue
colType = fields.get(key, "TEXT")
logger.debug(f"_buildPaginationClauses: filter key='{key}' val={val!r} type(val)={type(val).__name__} colType={colType}")
if val is None:
where_parts.append(f'"{key}" IS NULL')
continue
if isinstance(val, dict):
op = val.get("operator", "equals")
v = val.get("value", "")
if op in ("equals", "eq"):
if colType == "BOOLEAN":
where_parts.append(f'COALESCE("{key}", FALSE) = %s')
values.append(str(v).lower() == "true")
else:
where_parts.append(f'"{key}"::TEXT = %s')
values.append(str(v))
elif op == "contains":
where_parts.append(f'"{key}"::TEXT ILIKE %s')
values.append(f"%{v}%")
elif op == "startsWith":
where_parts.append(f'"{key}"::TEXT ILIKE %s')
values.append(f"{v}%")
elif op == "endsWith":
where_parts.append(f'"{key}"::TEXT ILIKE %s')
values.append(f"%{v}")
elif op in ("gt", "gte", "lt", "lte"):
sqlOp = {"gt": ">", "gte": ">=", "lt": "<", "lte": "<="}[op]
where_parts.append(f'"{key}"::TEXT {sqlOp} %s')
values.append(str(v))
elif op == "between":
fromVal = v.get("from", "") if isinstance(v, dict) else ""
toVal = v.get("to", "") if isinstance(v, dict) else ""
if not fromVal and not toVal:
continue
colType = fields.get(key, "TEXT")
isNumericCol = colType in ("INTEGER", "DOUBLE PRECISION")
isDateVal = bool(fromVal and re.match(r'^\d{4}-\d{2}-\d{2}$', str(fromVal))) or \
bool(toVal and re.match(r'^\d{4}-\d{2}-\d{2}$', str(toVal)))
if isNumericCol and isDateVal:
from datetime import datetime as _dt, timezone as _tz
if fromVal and toVal:
fromTs = _dt.strptime(str(fromVal), '%Y-%m-%d').replace(tzinfo=_tz.utc).timestamp()
toTs = _dt.strptime(str(toVal), '%Y-%m-%d').replace(hour=23, minute=59, second=59, tzinfo=_tz.utc).timestamp()
where_parts.append(f'"{key}" >= %s AND "{key}" <= %s')
values.extend([fromTs, toTs])
elif fromVal:
fromTs = _dt.strptime(str(fromVal), '%Y-%m-%d').replace(tzinfo=_tz.utc).timestamp()
where_parts.append(f'"{key}" >= %s')
values.append(fromTs)
else:
toTs = _dt.strptime(str(toVal), '%Y-%m-%d').replace(hour=23, minute=59, second=59, tzinfo=_tz.utc).timestamp()
where_parts.append(f'"{key}" <= %s')
values.append(toTs)
else:
if fromVal and toVal:
where_parts.append(f'"{key}"::TEXT >= %s AND "{key}"::TEXT <= %s')
values.extend([str(fromVal), str(toVal)])
elif fromVal:
where_parts.append(f'"{key}"::TEXT >= %s')
values.append(str(fromVal))
elif toVal:
where_parts.append(f'"{key}"::TEXT <= %s')
values.append(str(toVal))
else:
if colType == "BOOLEAN":
where_parts.append(f'COALESCE("{key}", FALSE) = %s')
values.append(str(val).lower() == "true")
else:
where_parts.append(f'"{key}"::TEXT ILIKE %s')
values.append(str(val))
where_clause = " WHERE " + " AND ".join(where_parts) if where_parts else ""
count_values = list(values)
orderParts: List[str] = []
if pagination and pagination.sort:
for sf in pagination.sort:
if sf.field in validColumns:
direction = "DESC" if sf.direction.lower() == "desc" else "ASC"
colType = fields.get(sf.field, "TEXT")
if colType == "BOOLEAN":
orderParts.append(f'COALESCE("{sf.field}", FALSE) {direction}')
else:
orderParts.append(f'"{sf.field}" {direction} NULLS LAST')
if not orderParts:
orderParts.append('"id"')
order_clause = " ORDER BY " + ", ".join(orderParts)
limit_clause = ""
if pagination:
offset = (pagination.page - 1) * pagination.pageSize
limit_clause = f" LIMIT {pagination.pageSize} OFFSET {offset}"
return where_clause, order_clause, limit_clause, values, count_values
def getRecordsetPaginated(
self,
model_class: type,
pagination=None,
recordFilter: Dict[str, Any] = None,
fieldFilter: List[str] = None,
) -> Dict[str, Any]:
"""
Returns paginated records with filtering + sorting at the SQL level.
Returns { "items": [...], "totalItems": int, "totalPages": int }.
If pagination is None, returns all records (no LIMIT/OFFSET).
"""
from modules.datamodels.datamodelPagination import PaginationParams
import math
table = model_class.__name__
try:
if not self._ensureTableExists(model_class):
return {"items": [], "totalItems": 0, "totalPages": 0}
where_clause, order_clause, limit_clause, values, count_values = \
self._buildPaginationClauses(model_class, pagination, recordFilter)
with self.connection.cursor() as cursor:
countSql = f'SELECT COUNT(*) FROM "{table}"{where_clause}'
cursor.execute(countSql, count_values)
totalItems = cursor.fetchone()["count"]
dataSql = f'SELECT * FROM "{table}"{where_clause}{order_clause}{limit_clause}'
cursor.execute(dataSql, values)
records = [dict(row) for row in cursor.fetchall()]
fields = _get_model_fields(model_class)
modelFields = model_class.model_fields
for record in records:
_parseRecordFields(record, fields, f"table {table}")
for fieldName, fieldType in fields.items():
if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
fieldInfo = modelFields.get(fieldName)
if fieldInfo:
fieldAnnotation = fieldInfo.annotation
if (fieldAnnotation == list or
(hasattr(fieldAnnotation, "__origin__") and
fieldAnnotation.__origin__ is list)):
record[fieldName] = []
elif (fieldAnnotation == dict or
(hasattr(fieldAnnotation, "__origin__") and
fieldAnnotation.__origin__ is dict)):
record[fieldName] = {}
if fieldFilter and isinstance(fieldFilter, list):
records = [{f: r[f] for f in fieldFilter if f in r} for r in records]
pageSize = pagination.pageSize if pagination else max(totalItems, 1)
totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0
return {"items": records, "totalItems": totalItems, "totalPages": totalPages}
except Exception as e:
logger.error(f"Error in getRecordsetPaginated for table {table}: {e}")
return {"items": [], "totalItems": 0, "totalPages": 0}
def getDistinctColumnValues(
self,
model_class: type,
column: str,
pagination=None,
recordFilter: Dict[str, Any] = None,
) -> List[str]:
"""
Returns sorted distinct non-null values for a column using SQL DISTINCT.
Applies cross-filtering (all filters except the requested column).
"""
table = model_class.__name__
fields = _get_model_fields(model_class)
if column not in fields:
return []
try:
if not self._ensureTableExists(model_class):
return []
if pagination:
if pagination.filters and column in pagination.filters:
import copy
pagination = copy.deepcopy(pagination)
pagination.filters.pop(column, None)
where_clause, _, _, values, _ = \
self._buildPaginationClauses(model_class, pagination, recordFilter)
sql = (
f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{where_clause} '
f'WHERE "{column}" IS NOT NULL AND "{column}"::TEXT != \'\' '
if not where_clause else
f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{where_clause} '
f'AND "{column}" IS NOT NULL AND "{column}"::TEXT != \'\' '
)
sql += 'ORDER BY val'
with self.connection.cursor() as cursor:
cursor.execute(sql, values)
return [row["val"] for row in cursor.fetchall()]
except Exception as e:
logger.error(f"Error in getDistinctColumnValues for {table}.{column}: {e}")
return []
def recordCreate(
self, model_class: type, record: Union[Dict[str, Any], BaseModel]
) -> Dict[str, Any]:
"""Creates a new record in a table based on Pydantic model class."""
# If record is a Pydantic model, convert to dict
if isinstance(record, BaseModel):
record = record.model_dump()
elif isinstance(record, dict):
record = record.copy()
else:
raise ValueError("Record must be a Pydantic model or dictionary")
# Ensure record has an ID
if "id" not in record:
record["id"] = str(uuid.uuid4())
# Save record
success = self._saveRecord(model_class, record["id"], record)
if not success:
table = model_class.__name__
raise ValueError(f"Failed to save record {record['id']} to table {table}")
# Check if this is the first record in the table and register as initial ID
table = model_class.__name__
existingInitialId = self.getInitialId(model_class)
if existingInitialId is None:
# This is the first record, register it as the initial ID
self._registerInitialId(table, record["id"])
logger.info(f"Registered initial ID {record['id']} for table {table}")
return record
def recordModify(
self, model_class: type, recordId: str, record: Union[Dict[str, Any], BaseModel]
) -> Dict[str, Any]:
"""Modifies an existing record in a table based on Pydantic model class."""
# Load existing record
existingRecord = self._loadRecord(model_class, recordId)
if not existingRecord:
table = model_class.__name__
raise ValueError(f"Record {recordId} not found in table {table}")
# If record is a Pydantic model, convert to dict
if isinstance(record, BaseModel):
record = record.model_dump()
elif isinstance(record, dict):
record = record.copy()
else:
raise ValueError("Record must be a Pydantic model or dictionary")
# CRITICAL: Ensure we never modify the ID
if "id" in record and str(record["id"]) != recordId:
logger.error(
f"Attempted to modify record ID from {recordId} to {record['id']}"
)
raise ValueError(
"Cannot modify record ID - it must match the provided recordId"
)
# Update existing record with new data
existingRecord.update(record)
# Save updated record
saved = self._saveRecord(model_class, recordId, existingRecord)
if not saved:
table = model_class.__name__
raise ValueError(f"Failed to save record {recordId} to table {table}")
return existingRecord
def recordDelete(self, model_class: type, recordId: str) -> bool:
"""Deletes a record from the table based on Pydantic model class."""
table = model_class.__name__
try:
if not self._ensureTableExists(model_class):
return False
with self.connection.cursor() as cursor:
# Check if record exists
cursor.execute(
f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,)
)
if not cursor.fetchone():
return False
# Check if it's an initial record
initialId = self.getInitialId(model_class)
if initialId is not None and initialId == recordId:
self._removeInitialId(table)
logger.info(
f"Initial ID {recordId} for table {table} has been removed from the system table"
)
# Delete the record
cursor.execute(f'DELETE FROM "{table}" WHERE "id" = %s', (recordId,))
# No cache to update - database handles consistency
self.connection.commit()
return True
except Exception as e:
logger.error(f"Error deleting record {recordId} from table {table}: {e}")
self.connection.rollback()
return False
def getInitialId(self, model_class: type) -> Optional[str]:
"""Returns the initial ID for a table."""
table = model_class.__name__
systemData = self._loadSystemTable()
initialId = systemData.get(table)
return initialId
def semanticSearch(
self,
modelClass: type,
vectorColumn: str,
queryVector: List[float],
limit: int = 10,
recordFilter: Dict[str, Any] = None,
minScore: float = None,
) -> List[Dict[str, Any]]:
"""Semantic search using pgvector cosine distance.
Args:
modelClass: Pydantic model class for the table.
vectorColumn: Name of the vector column to search.
queryVector: Query vector as List[float].
limit: Maximum number of results.
recordFilter: Additional WHERE filters (field: value).
minScore: Minimum cosine similarity (0.0 - 1.0).
Returns:
List of records with an added '_score' field (cosine similarity),
sorted by similarity descending.
"""
table = modelClass.__name__
try:
if not self._ensureTableExists(modelClass):
return []
vectorStr = f"[{','.join(str(v) for v in queryVector)}]"
whereConditions = []
whereValues = []
if recordFilter:
for field, value in recordFilter.items():
if value is None:
whereConditions.append(f'"{field}" IS NULL')
elif isinstance(value, (list, tuple)):
if not value:
whereConditions.append("1 = 0")
else:
whereConditions.append(f'"{field}" = ANY(%s)')
whereValues.append(list(value))
else:
whereConditions.append(f'"{field}" = %s')
whereValues.append(value)
if minScore is not None:
whereConditions.append(
f'1 - ("{vectorColumn}" <=> %s::vector) >= %s'
)
whereValues.extend([vectorStr, minScore])
whereClause = ""
if whereConditions:
whereClause = " WHERE " + " AND ".join(whereConditions)
query = (
f'SELECT *, 1 - ("{vectorColumn}" <=> %s::vector) AS "_score" '
f'FROM "{table}"{whereClause} '
f'ORDER BY "{vectorColumn}" <=> %s::vector '
f'LIMIT %s'
)
params = [vectorStr] + whereValues + [vectorStr, limit]
with self.connection.cursor() as cursor:
cursor.execute(query, params)
records = [dict(row) for row in cursor.fetchall()]
fields = _get_model_fields(modelClass)
for record in records:
_parseRecordFields(record, fields, f"semanticSearch {table}")
return records
except Exception as e:
logger.error(f"Error in semantic search on {table}: {e}")
return []
def close(self, forceClose: bool = False):
"""Close the database connection.
Shared cached connectors are intentionally kept open unless forceClose=True.
This prevents accidental shutdown from interface __del__ methods while
other requests are still using the same cached connector instance.
"""
if self._isCachedShared and not forceClose:
return
if (
hasattr(self, "connection")
and self.connection
and not self.connection.closed
):
self.connection.close()
def __del__(self):
"""Cleanup method to close connection."""
try:
self.close()
except Exception:
pass