# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
import contextvars
import json
import logging
import re
import threading
import time
import types
import uuid
from contextlib import contextmanager
from typing import List, Dict, Any, Optional, Union, get_origin, get_args, Type

import psycopg2
import psycopg2.extras
import psycopg2.pool
from pydantic import BaseModel, Field

from modules.shared.timeUtils import getUtcTimestamp
from modules.shared.configuration import APP_CONFIG
from modules.datamodels.datamodelBase import PowerOnModel
from modules.datamodels.datamodelUam import User, AccessLevel, UserPermissions
from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext

logger = logging.getLogger(__name__)

# No mapping needed - table name = Pydantic model name exactly


class DatabaseQueryError(RuntimeError):
    """Raised by DB read methods when the underlying SQL query failed.

    Empty result sets do NOT raise this — they return ``[]`` / ``None`` /
    ``{"items": [], "totalItems": 0, "totalPages": 0}`` as before.

    This exception is reserved for **real** failures: psycopg2
    ProgrammingError, DataError, OperationalError, IntegrityError, plus any
    unexpected Python error raised inside a query path. Read methods used to
    silently swallow such errors and return empty collections, which made
    every caller incapable of distinguishing "no rows" from "broken query /
    type adapter / dropped column / lost connection". That hid concrete bugs
    (e.g. dict passed where Postgres expected a UUID string) behind misleading
    downstream "no record found" errors.
    """

    def __init__(
        self,
        table: str,
        message: str,
        original: Optional[BaseException] = None,
    ):
        """Build the error.

        Args:
            table: Name of the table the failing query targeted.
            message: Human-readable description; the exception text is
                ``"<table>: <message>"``.
            original: The underlying exception, if any, kept for inspection.
        """
        super().__init__(f"{table}: {message}")
        self.table = table
        self.original = original


class SystemTable(PowerOnModel):
    """Data model for system table entries"""

    table_name: str = Field(
        description="Name of the table",
        json_schema_extra={
            "frontend_type": "text",
            "frontend_readonly": True,
            "frontend_required": True,
        },
    )
    initial_id: Optional[str] = Field(
        default=None,
        description="Initial ID for the table",
        json_schema_extra={
            "frontend_type": "text",
            "frontend_readonly": True,
            "frontend_required": False,
        },
    )


# Origins that ``typing.get_origin`` reports for union annotations:
# ``typing.Union`` for ``Optional[X]`` / ``Union[X, Y]`` and, on Python 3.10+,
# ``types.UnionType`` for the PEP 604 ``X | Y`` spelling. Both must be treated
# the same, otherwise ``dict | None`` and ``Optional[dict]`` map to different
# SQL types.
_UNION_ORIGINS = (Union,) + (
    (types.UnionType,) if hasattr(types, "UnionType") else ()
)


def _isVectorType(sqlType: str) -> bool:
    """Check if a SQL type string represents a pgvector column."""
    return sqlType.upper().startswith("VECTOR")


def _unwrapOptional(fieldType):
    """Return ``X`` for ``Optional[X]`` / ``X | None``; otherwise *fieldType* unchanged.

    Unions with more than one non-None member are returned as-is so callers
    can inspect every member.
    """
    if get_origin(fieldType) in _UNION_ORIGINS:
        nonNoneArgs = [a for a in get_args(fieldType) if a is not type(None)]
        if len(nonNoneArgs) == 1:
            return nonNoneArgs[0]
    return fieldType


def _isJsonbType(fieldType) -> bool:
    """Check if a type should be stored as JSONB in PostgreSQL.

    JSONB is used for ``dict`` / ``list`` (bare or parametrized), Pydantic
    ``BaseModel`` subclasses, and any union (``typing.Union`` or PEP 604
    ``X | Y``) that contains at least one such member.
    """
    # Direct dict or list
    if fieldType is dict or fieldType is list:
        return True
    # Generic List[X] / Dict[X, Y] / list[X] / dict[X, Y]
    origin = get_origin(fieldType)
    if origin is dict or origin is list:
        return True
    # Direct Pydantic BaseModel subclass
    if isinstance(fieldType, type) and issubclass(fieldType, BaseModel):
        return True
    # Union / Optional (either spelling): JSONB if any non-None member is.
    if origin in _UNION_ORIGINS:
        return any(
            _isJsonbType(arg) for arg in get_args(fieldType) if arg is not type(None)
        )
    return False


# Python scalar type -> SQL column type; anything unmatched is stored as TEXT.
# Checked by identity so that ``bool`` is never mistaken for ``int``.
_SCALAR_SQL_TYPES = (
    (bool, "BOOLEAN"),
    (int, "INTEGER"),
    (float, "DOUBLE PRECISION"),
)


def getModelFields(model_class) -> Dict[str, str]:
    """Get all fields from Pydantic model and map to SQL types.

    Supports explicit db_type override via json_schema_extra={"db_type": "vector(1536)"}.
    This enables pgvector columns without special-casing field names.

    Args:
        model_class: A Pydantic model class (uses ``model_class.model_fields``).

    Returns:
        Mapping of field name to SQL type string (``JSONB``, ``BOOLEAN``,
        ``INTEGER``, ``DOUBLE PRECISION``, ``TEXT`` or the explicit override).
    """
    fields: Dict[str, str] = {}
    for field_name, field_info in model_class.model_fields.items():
        # Explicit db_type override (e.g. vector columns)
        extra = field_info.json_schema_extra
        if isinstance(extra, dict) and "db_type" in extra:
            fields[field_name] = extra["db_type"]
            continue

        # Unwrap Optional[X] / X | None -> X (only true unions, not arbitrary
        # generics that happen to carry NoneType in their args).
        field_type = _unwrapOptional(field_info.annotation)

        if _isJsonbType(field_type):
            fields[field_name] = "JSONB"
            continue
        fields[field_name] = next(
            (sqlType for pyType, sqlType in _SCALAR_SQL_TYPES if field_type is pyType),
            "TEXT",
        )
    return fields


def parseRecordFields(record: Dict[str, Any], fields: Dict[str, str], context: str = "") -> None:
    """Parse record fields in-place: numeric typing, vector parsing, JSONB deserialization.

    Args:
        record: Row dict as returned by the cursor; modified in place.
        fields: Field name -> SQL type mapping (see ``getModelFields``).
        context: Free-text label included in warning messages.

    Values that cannot be converted are left untouched and a warning is logged.
    """
    for fieldName, fieldType in fields.items():
        if fieldName not in record:
            continue
        value = record[fieldName]

        if fieldType in ("DOUBLE PRECISION", "INTEGER") and value is not None:
            try:
                record[fieldName] = float(value) if fieldType == "DOUBLE PRECISION" else int(value)
            except (ValueError, TypeError):
                logger.warning(f"Could not convert {fieldName} to {fieldType} ({context}): {value}")

        elif _isVectorType(fieldType) and value is not None:
            # pgvector values arrive as text like "[0.1,0.2]"; lists are already parsed.
            if isinstance(value, str):
                try:
                    record[fieldName] = [float(v) for v in value.strip("[]").split(",")]
                except (ValueError, TypeError):
                    logger.warning(f"Could not parse vector field {fieldName} ({context})")

        elif fieldType == "BOOLEAN":
            record[fieldName] = bool(value) if value is not None else None

        elif fieldType == "JSONB" and value is not None:
            # dict/list and JSON scalars (bool/int/float) are already decoded;
            # round-tripping them through json.loads(str(...)) is a no-op for
            # numbers and a guaranteed (spurious) failure for bools.
            if isinstance(value, (dict, list, int, float)):
                continue
            try:
                if isinstance(value, str):
                    record[fieldName] = json.loads(value)
                else:
                    record[fieldName] = json.loads(str(value))
            except (json.JSONDecodeError, TypeError, ValueError):
                logger.warning(f"Could not parse JSONB field {fieldName}, keeping as string ({context})")


def _stripNulBytesFromStr(value: Any) -> Any:
    """psycopg2 rejects bound parameters whose Python str contains NUL (0x00).

    Some extracted files (e.g. SQL dumps, mixed binary treated as text) still
    carry those bytes; PostgreSQL TEXT could store them via other paths, but
    the client protocol path used here cannot. Non-str values are returned
    unchanged.
    """
    if isinstance(value, str) and "\x00" in value:
        return value.replace("\x00", "")
    return value


def _quotePgIdent(name: str) -> str:
    """Quote *name* as a PostgreSQL identifier, doubling embedded double quotes."""
    return '"' + str(name).replace('"', '""') + '"'


# ---------------------------------------------------------------------------
# Connection pool registry
# ---------------------------------------------------------------------------
# psycopg2 connections are NOT thread-safe; sharing one connection across the
# FastAPI threadpool (sync `def` routes) or across multiple async tasks that
# happen to be active simultaneously results in either
# `OperationalError: another command is already in progress` or — far worse —
# an unbounded hang in `recv()` on the connection socket, because two cursors
# wait for one server response.
#
# `_PoolRegistry` keeps one `ThreadedConnectionPool` per database identity
# (`host`, `db`, `port`). Every DB call borrows a connection from the pool,
# runs its query, and returns the connection. The pool itself guarantees that
# no two callers share the same psycopg2 connection at the same time.
#
# `statement_timeout=30000` (30 s) is a safety net: a runaway query is aborted
# instead of hanging forever and poisoning the connection. `connect_timeout=10`
# prevents indefinite blocking when the Postgres host is unreachable.
_DEFAULT_POOL_MIN = 2 _DEFAULT_POOL_MAX = 20 _STATEMENT_TIMEOUT_MS = 30000 _CONNECT_TIMEOUT_S = 10 # psycopg2.pool.ThreadedConnectionPool.getconn() does NOT block when the pool # is exhausted — it raises `psycopg2.pool.PoolError` immediately. That makes # bursty workloads (50 concurrent route calls against a max=20 pool) fail # spuriously. `borrowConn()` therefore retries with a small backoff up to # `_BORROW_WAIT_TIMEOUT_S` seconds before giving up. _BORROW_WAIT_TIMEOUT_S = 30.0 _BORROW_WAIT_BACKOFF_S = 0.05 _shuttingDown = False def _resolvePoolMax() -> int: """Pool size is configurable via `DB_POOL_MAX_CONN` (default 20).""" try: return max(2, int(APP_CONFIG.get("DB_POOL_MAX_CONN") or _DEFAULT_POOL_MAX)) except (ValueError, TypeError): return _DEFAULT_POOL_MAX class _PoolRegistry: """Process-wide registry of `ThreadedConnectionPool` instances. Keyed by `(host, database, port)` so that two `DatabaseConnector` instances pointing at the same physical database share one pool. Lazy-initialised and thread-safe. """ _pools: Dict[tuple, psycopg2.pool.ThreadedConnectionPool] = {} _lock = threading.Lock() @classmethod def getPool( cls, *, dbHost: str, dbDatabase: str, dbUser: str, dbPassword: str, dbPort: int, ) -> psycopg2.pool.ThreadedConnectionPool: port = int(dbPort) if dbPort is not None else 5432 key = (dbHost, dbDatabase, port) # Fast path: pool exists. pool = cls._pools.get(key) if pool is not None: return pool # Slow path: create exactly one pool per key, even under contention. 
with cls._lock: pool = cls._pools.get(key) if pool is not None: return pool poolMax = _resolvePoolMax() options = f"-c statement_timeout={_STATEMENT_TIMEOUT_MS}" try: pool = psycopg2.pool.ThreadedConnectionPool( _DEFAULT_POOL_MIN, poolMax, host=dbHost, port=port, database=dbDatabase, user=dbUser, password=dbPassword, client_encoding="utf8", cursor_factory=psycopg2.extras.RealDictCursor, connect_timeout=_CONNECT_TIMEOUT_S, options=options, ) except Exception as e: logger.error( "Failed to create connection pool for db=%s host=%s: %s", dbDatabase, dbHost, e, ) raise cls._pools[key] = pool logger.debug( "Created connection pool for db=%s host=%s port=%s (min=%d max=%d, stmt_timeout=%dms)", dbDatabase, dbHost, port, _DEFAULT_POOL_MIN, poolMax, _STATEMENT_TIMEOUT_MS, ) return pool @classmethod def closeAll(cls) -> None: """Close every pool. Call once during FastAPI shutdown.""" with cls._lock: for key, pool in list(cls._pools.items()): try: pool.closeall() except Exception as e: logger.warning("Error closing pool %s: %s", key, e) cls._pools.clear() logger.info("All database connection pools closed") def closeAllPools() -> None: """Public entry point for FastAPI lifespan shutdown hook. Sets the shutdown flag first so that any in-flight ``_acquireConn`` loops abort immediately instead of polling for up to 30 s. """ global _shuttingDown _shuttingDown = True _PoolRegistry.closeAll() class _ConnectionShim: """Backward-compatibility stand-in for the old `DatabaseConnector.connection`. Old callers used patterns like:: db.connection.commit() if db.connection and not db.connection.closed: ... These are now no-ops because the pool owns the connection lifecycle. Direct cursor access through `self.connection.cursor()` is intentionally blocked with a clear error so silent breakage is impossible — every such call site must migrate to `db.borrowCursor()`. 
""" closed = False def __bool__(self) -> bool: return True def commit(self) -> None: return def rollback(self) -> None: return def close(self) -> None: return def cursor(self, *args, **kwargs): raise RuntimeError( "DatabaseConnector.connection.cursor() is no longer supported. " "Use `db.borrowCursor()` (or `db.borrowConn()` for multi-statement " "transactions) so the connection is borrowed from and returned to " "the pool correctly." ) _CONNECTION_SHIM = _ConnectionShim() # --------------------------------------------------------------------------- # Connector cache (lightweight wrappers — actual connections live in the pool) # --------------------------------------------------------------------------- # Multiple call sites (`routeI18n`, `aiAuditLogger`, `mainBackgroundJobService`, # interfaces) ask for a connector via `getCachedConnector(...)` and expect to # get back the same object on subsequent calls. Now that real connection # multiplexing happens at the pool layer, the cache returns lightweight # `DatabaseConnector` wrappers — they hold no connection themselves, only the # DSN params and a reference to the shared pool. _MAX_CACHED_CONNECTORS = 32 _connector_cache: Dict[tuple, "DatabaseConnector"] = {} _connector_cache_order: List[tuple] = [] # FIFO order for eviction _connector_cache_lock = threading.Lock() _current_user_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar( "db_connector_user_id", default=None ) def getCachedConnector( dbHost: str, dbDatabase: str, dbUser: str = None, dbPassword: str = None, dbPort: int = None, userId: str = None, ) -> "DatabaseConnector": """Return a cached `DatabaseConnector` wrapper for `(host, database, port)`. Two layers of caching, both intentional: 1. **Pool layer** (`_PoolRegistry`) — owns the actual psycopg2 connections. Per `(host, db, port)` exactly one `ThreadedConnectionPool`, shared across every wrapper that points at the same physical database. 
This is what actually saves Postgres connection slots. 2. **Wrapper layer** (this function) — caches the `DatabaseConnector` Python object. Each wrapper triggers an `initDbSystem()` on first instantiation (CREATE DATABASE if missing, CREATE TABLE _system, pool warm-up). Caching the wrapper avoids paying that bootstrap cost on every request and stops the log from filling with "PostgreSQL database system initialized" lines. `userId` is request-scoped via the `_current_user_id` contextvar so two concurrent requests sharing the same cached wrapper still produce correct `sysCreatedBy` / `sysModifiedBy` audit fields. """ port = int(dbPort) if dbPort is not None else 5432 key = (dbHost, dbDatabase, port) with _connector_cache_lock: if key not in _connector_cache: # FIFO eviction. Connectors are now lightweight (no per-instance # connection), so eviction is purely a memory bookkeeping concern; # the underlying pool stays alive in `_PoolRegistry`. while len(_connector_cache) >= _MAX_CACHED_CONNECTORS and _connector_cache_order: oldest_key = _connector_cache_order.pop(0) _connector_cache.pop(oldest_key, None) _connector_cache[key] = DatabaseConnector( dbHost=dbHost, dbDatabase=dbDatabase, dbUser=dbUser, dbPassword=dbPassword, dbPort=dbPort, userId=userId, ) _connector_cache[key]._isCachedShared = True _connector_cache_order.append(key) conn = _connector_cache[key] # Set request-scoped userId via contextvar (avoids mutating shared connector) if userId is not None: _current_user_id.set(userId) return conn class DatabaseConnector: """ A connector for PostgreSQL-based data storage. Provides generic database operations without user/mandate filtering. Uses PostgreSQL with JSONB columns for flexible data storage. 
""" def __init__( self, dbHost: str, dbDatabase: str, dbUser: str = None, dbPassword: str = None, dbPort: int = None, userId: str = None, ): # Store the input parameters self.dbHost = dbHost self.dbDatabase = dbDatabase self.dbUser = dbUser self.dbPassword = dbPassword self.dbPort = dbPort # Set userId (default to empty string if None) self.userId = userId if userId is not None else "" # No per-instance connection any more — real connections live in the # shared `_PoolRegistry` pool. `_isCachedShared` is retained because # `close(forceClose=False)` callers (interface __del__) still ask. self._isCachedShared = False # pgvector extension state (cached per connector instance — cheap) self._vectorExtensionEnabled = False # System table bootstrap: create database, system table, ensure metadata. self._systemTableName = "_system" self.initDbSystem() self._initializeSystemTable() def initDbSystem(self): """Bootstrap the physical database and the `_system` metadata table. Uses a short-lived autocommit connection (NOT the pool) because `CREATE DATABASE` cannot run inside a transaction block. Also warms up the pool by acquiring it for this `(host, db, port)` once. """ try: self._create_database_if_not_exists() self._create_tables() # Warm the pool so the first request doesn't pay for socket setup. _PoolRegistry.getPool( dbHost=self.dbHost, dbDatabase=self.dbDatabase, dbUser=self.dbUser, dbPassword=self.dbPassword, dbPort=self.dbPort, ) logger.debug( "PostgreSQL database system initialized (db=%s, host=%s, port=%s)", self.dbDatabase, self.dbHost, self.dbPort, ) except Exception as e: logger.error(f"FATAL ERROR: Database system initialization failed: {e}") raise @property def connection(self) -> "_ConnectionShim": """Backward-compat shim — see `_ConnectionShim` docstring.""" return _CONNECTION_SHIM def _ensure_connection(self) -> None: """No-op for backward compatibility. Previously this method tested-or-reconnected the per-instance socket. 
With pooling, the `ThreadedConnectionPool` re-establishes broken connections lazily on the next `getconn()`. Kept as a no-op so that legacy call sites continue to compile. """ return @contextmanager def borrowCursor(self): """Borrow a cursor for one short SQL block. Convenience wrapper around `borrowConn()` for the common pattern: with db.borrowCursor() as cursor: cursor.execute(...) rows = cursor.fetchall() Replaces the legacy `with db.connection.cursor() as cursor:` pattern. Commit/rollback and pool return are handled automatically — callers must NOT call `.commit()`/`.rollback()` on the cursor themselves. """ with self.borrowConn() as conn: with conn.cursor() as cursor: yield cursor @contextmanager def borrowConn(self): """Borrow a connection from the pool for the duration of the block. Pool-exhaustion semantics: `ThreadedConnectionPool.getconn()` raises `psycopg2.pool.PoolError` immediately when the pool is at its `maxconn` limit — it does NOT block. We wrap that with a bounded busy-wait so bursty workloads (e.g. 50 concurrent route calls against a max=20 pool) queue up instead of failing. If no connection becomes available within `_BORROW_WAIT_TIMEOUT_S`, the `PoolError` is propagated so a genuinely deadlocked pool surfaces as a real error. On normal exit the current transaction is committed (idempotent for read-only queries — leaves the connection in a clean state for the next borrower). On exception the transaction is rolled back and the exception propagates. The connection is **always** returned to the pool, even when commit/rollback fails — otherwise a single bad query would leak a slot and eventually exhaust the pool. 
""" pool = _PoolRegistry.getPool( dbHost=self.dbHost, dbDatabase=self.dbDatabase, dbUser=self.dbUser, dbPassword=self.dbPassword, dbPort=self.dbPort, ) conn = self._acquireConn(pool) try: yield conn except Exception: try: conn.rollback() except Exception: pass raise else: # Best-effort commit so the connection goes back to the pool with # no in-flight transaction. Failure here is non-fatal — the pool # will re-establish the socket on the next `getconn()` if needed. try: conn.commit() except Exception: try: conn.rollback() except Exception: pass finally: try: pool.putconn(conn) except Exception as e: logger.warning("Failed to return connection to pool: %s", e) @staticmethod def _acquireConn(pool: psycopg2.pool.ThreadedConnectionPool): """Get a connection from the pool, waiting up to `_BORROW_WAIT_TIMEOUT_S`. psycopg2's pool throws on exhaustion instead of queueing — this helper polls with a short backoff so callers see queue semantics. Aborts immediately when the application is shutting down. """ if _shuttingDown: raise psycopg2.pool.PoolError("Application is shutting down") deadline = time.monotonic() + _BORROW_WAIT_TIMEOUT_S attempt = 0 while True: try: return pool.getconn() except psycopg2.pool.PoolError as e: attempt += 1 if _shuttingDown: raise psycopg2.pool.PoolError("Application is shutting down") if time.monotonic() >= deadline: logger.error( "Connection pool exhausted after %.1fs wait (%d retries)", _BORROW_WAIT_TIMEOUT_S, attempt, ) raise time.sleep(_BORROW_WAIT_BACKOFF_S) def _create_database_if_not_exists(self): """Create the database if it doesn't exist. Uses an autocommit connection on the `postgres` admin DB because `CREATE DATABASE` cannot run inside a transaction block — so this path intentionally does NOT use the pool. 
""" try: conn = psycopg2.connect( host=self.dbHost, port=self.dbPort, database="postgres", user=self.dbUser, password=self.dbPassword, client_encoding="utf8", connect_timeout=_CONNECT_TIMEOUT_S, ) conn.autocommit = True try: with conn.cursor() as cursor: cursor.execute( "SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,) ) exists = cursor.fetchone() if not exists: quoted_db_name = f'"{self.dbDatabase}"' cursor.execute(f"CREATE DATABASE {quoted_db_name}") logger.info(f"Created database: {self.dbDatabase}") finally: conn.close() except Exception as e: logger.error(f"FATAL ERROR: Cannot create database: {e}") logger.error("Database connection failed - application cannot start") raise RuntimeError( f"FATAL ERROR: Cannot create database '{self.dbDatabase}': {e}" ) def _create_tables(self): """Create the `_system` table. Uses a short-lived autocommit connection (not the pool) — runs exactly once at connector creation. """ try: conn = psycopg2.connect( host=self.dbHost, port=self.dbPort, database=self.dbDatabase, user=self.dbUser, password=self.dbPassword, client_encoding="utf8", connect_timeout=_CONNECT_TIMEOUT_S, ) conn.autocommit = True try: with conn.cursor() as cursor: cursor.execute(""" CREATE TABLE IF NOT EXISTS _system ( id SERIAL PRIMARY KEY, table_name VARCHAR(255) UNIQUE NOT NULL, initial_id VARCHAR(255) NOT NULL, "sysCreatedAt" DOUBLE PRECISION, "sysCreatedBy" VARCHAR(255), "sysModifiedAt" DOUBLE PRECISION, "sysModifiedBy" VARCHAR(255) ) """) finally: conn.close() except Exception as e: logger.error(f"FATAL ERROR: Cannot create system table: {e}") logger.error( "Database system table creation failed - application cannot start" ) raise RuntimeError(f"FATAL ERROR: Cannot create system table: {e}") def _initializeSystemTable(self): """Initializes the system table if it doesn't exist yet.""" try: self._ensureTableExists(SystemTable) with self.borrowConn() as conn: with conn.cursor() as cursor: cursor.execute('SELECT COUNT(*) FROM "_system"') 
cursor.fetchone() # noqa: just verifies table is readable except Exception as e: logger.error(f"Error initializing system table: {e}") raise def _loadSystemTable(self) -> Dict[str, str]: """Loads the system table with the initial IDs.""" try: with self.borrowConn() as conn: with conn.cursor() as cursor: cursor.execute('SELECT "table_name", "initial_id" FROM "_system"') rows = cursor.fetchall() return {row["table_name"]: row["initial_id"] for row in rows} except Exception as e: logger.error(f"Error loading system table: {e}") return {} def _saveSystemTable(self, data: Dict[str, str]) -> bool: """Saves the system table with the initial IDs.""" try: with self.borrowConn() as conn: with conn.cursor() as cursor: cursor.execute('DELETE FROM "_system"') for table_name, initial_id in data.items(): cursor.execute( """ INSERT INTO "_system" ("table_name", "initial_id", "sysModifiedAt") VALUES (%s, %s, %s) """, (table_name, initial_id, getUtcTimestamp()), ) return True except Exception as e: logger.error(f"Error saving system table: {e}") return False def _ensureSystemTableExists(self) -> bool: """Ensures the system table exists, creates it if it doesn't.""" try: with self.borrowConn() as conn: with conn.cursor() as cursor: cursor.execute( "SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s", (self._systemTableName,), ) exists = cursor.fetchone()["count"] > 0 if not exists: cursor.execute(f""" CREATE TABLE "{self._systemTableName}" ( "table_name" VARCHAR(255) PRIMARY KEY, "initial_id" VARCHAR(255), "sysCreatedAt" DOUBLE PRECISION, "sysCreatedBy" VARCHAR(255), "sysModifiedAt" DOUBLE PRECISION, "sysModifiedBy" VARCHAR(255) ) """) logger.info("System table created successfully") else: cursor.execute( """ SELECT column_name FROM information_schema.columns WHERE table_name = %s AND table_schema = 'public' """, (self._systemTableName,), ) existing_columns = [row["column_name"] for row in cursor.fetchall()] for sys_col, sys_sql in [ ("sysCreatedAt", "DOUBLE PRECISION"), 
("sysCreatedBy", "VARCHAR(255)"), ("sysModifiedAt", "DOUBLE PRECISION"), ("sysModifiedBy", "VARCHAR(255)"), ]: if sys_col not in existing_columns: cursor.execute( f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "{sys_col}" {sys_sql}' ) return True except Exception as e: logger.error(f"Error ensuring system table exists: {e}") return False def _ensureTableExists(self, model_class: type) -> bool: """Ensures a table exists, creates it if it doesn't.""" table = model_class.__name__ if table == "SystemTable": # Handle system table specially - it uses _system as the actual table name return self._ensureSystemTableExists() try: with self.borrowConn() as conn: with conn.cursor() as cursor: cursor.execute( """ SELECT COUNT(*) FROM information_schema.tables WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public' """, (table,), ) exists = cursor.fetchone()["count"] > 0 if not exists: self._create_table_from_model(cursor, table, model_class) logger.info( f"Created table '{table}' with columns from Pydantic model" ) else: # Table exists: ensure all columns from model are present (simple additive migration) try: cursor.execute( """ SELECT column_name, data_type FROM information_schema.columns WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public' """, (table,), ) existing_column_rows = cursor.fetchall() existing_columns = { row["column_name"] for row in existing_column_rows } existing_column_types = { row["column_name"]: (row["data_type"] or "").lower() for row in existing_column_rows } model_fields = getModelFields(model_class) desired_columns = set(["id"]) | set(model_fields.keys()) for col in sorted(desired_columns - existing_columns): if col in ["id"]: continue sql_type = model_fields.get(col) if not sql_type: sql_type = "TEXT" try: cursor.execute( f'ALTER TABLE "{table}" ADD COLUMN "{col}" {sql_type}' ) logger.info( f"Added missing column '{col}' ({sql_type}) to '{table}'" ) except Exception as add_err: logger.warning( f"Could not add column '{col}' to 
'{table}': {add_err}" ) # Column type migrations for existing tables. # TEXT→DOUBLE PRECISION handles three value shapes: # 1. NULL / empty string → NULL # 2. ISO date(time) like "2025-01-22" or "2025-01-22T10:00:00+00" → epoch via EXTRACT # 3. Plain numeric string like "3.14" → direct cast _TEXT_TO_DOUBLE = ( 'DOUBLE PRECISION USING CASE' ' WHEN "{col}" IS NULL OR "{col}" = \'\' THEN NULL' ' WHEN "{col}" ~ \'^\\d{4}-\\d{2}-\\d{2}\'' ' THEN EXTRACT(EPOCH FROM "{col}"::timestamptz)' ' ELSE NULLIF("{col}", \'\')::double precision' ' END' ) _SAFE_TYPE_CHANGES = { ("jsonb", "TEXT"): "TEXT USING \"{col}\"::text", ("text", "DOUBLE PRECISION"): _TEXT_TO_DOUBLE, ("text", "INTEGER"): "INTEGER USING NULLIF(\"{col}\", '')::integer", ("timestamp without time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}" AT TIME ZONE \'UTC\')', ("timestamp with time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}")', ("date", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}"::timestamp AT TIME ZONE \'UTC\')', } for col in sorted(desired_columns & existing_columns): if col == "id": continue desired_sql = (model_fields.get(col) or "").upper() currentType = existing_column_types.get(col, "") migration = _SAFE_TYPE_CHANGES.get((currentType, desired_sql)) if not migration and desired_sql.startswith("VECTOR") and currentType == "text": migration = f'{desired_sql} USING CASE WHEN "{col}" IS NULL OR "{col}" = \'\' THEN NULL ELSE "{col}"::vector END' if migration: castExpr = migration.replace("{col}", col) try: cursor.execute('SAVEPOINT col_migrate') cursor.execute( f'ALTER TABLE "{table}" ALTER COLUMN "{col}" TYPE {castExpr}' ) cursor.execute('RELEASE SAVEPOINT col_migrate') logger.info( f"Migrated column '{col}' from {currentType} to {desired_sql} on '{table}'" ) except Exception as alter_err: cursor.execute('ROLLBACK TO SAVEPOINT col_migrate') logger.warning( f"Could not migrate column '{col}' on '{table}': 
{alter_err}" ) except Exception as ensure_err: logger.warning( f"Could not ensure columns for existing table '{table}': {ensure_err}" ) return True except Exception as e: logger.error(f"Error ensuring table {table} exists: {e}") return False def _ensureVectorExtension(self) -> bool: """Enable pgvector extension if not already enabled. Called lazily on first vector table.""" if self._vectorExtensionEnabled: return True try: with self.borrowConn() as conn: with conn.cursor() as cursor: cursor.execute("CREATE EXTENSION IF NOT EXISTS vector") self._vectorExtensionEnabled = True logger.info("pgvector extension enabled") return True except Exception as e: logger.error(f"Failed to enable pgvector extension: {e}") return False def _create_table_from_model(self, cursor, table: str, model_class: type) -> None: """Create table with columns matching Pydantic model fields.""" fields = getModelFields(model_class) # Enable pgvector if any field uses vector type if any(_isVectorType(sqlType) for sqlType in fields.values()): self._ensureVectorExtension() # Build column definitions with quoted identifiers to preserve exact case columns = ['"id" VARCHAR(255) PRIMARY KEY'] for field_name, sql_type in fields.items(): if field_name != "id": # Skip id, already defined columns.append(f'"{field_name}" {sql_type}') # Create table sql = f'CREATE TABLE IF NOT EXISTS "{table}" ({", ".join(columns)})' cursor.execute(sql) # Create indexes for foreign keys for field_name in fields: if field_name.endswith("Id") and field_name != "id": cursor.execute( f'CREATE INDEX IF NOT EXISTS "idx_{table}_{field_name}" ON "{table}" ("{field_name}")' ) def _save_record( self, cursor, table: str, recordId: str, record: Dict[str, Any], model_class: type, ) -> None: """Save record to normalized table with explicit columns.""" # Get columns from Pydantic model instead of database schema fields = getModelFields(model_class) columns = ["id"] + [field for field in fields.keys() if field != "id"] if not columns: 
logger.error(f"No columns found for table {table}") return # Filter record data to only include columns that exist in the table filtered_record = {k: v for k, v in record.items() if k in columns} # Ensure id is set filtered_record["id"] = recordId # Prepare values in the correct order values = [] for col in columns: value = filtered_record.get(col) # Handle timestamp fields - store as Unix timestamps (floats) for consistency if col in ["sysCreatedAt", "sysModifiedAt"] and value is not None: if isinstance(value, str): # Try to parse string as timestamp try: value = float(value) except: pass # Keep as string if parsing fails # Convert enum values to their string representation elif hasattr(value, "value"): value = value.value # Handle vector fields (pgvector) - convert List[float] to string elif col in fields and _isVectorType(fields[col]) and value is not None: if isinstance(value, list): value = f"[{','.join(str(v) for v in value)}]" # Handle JSONB fields - ensure proper JSON format for PostgreSQL elif col in fields and fields[col] == "JSONB" and value is not None: import json if isinstance(value, (dict, list)): value = json.dumps(value) elif isinstance(value, str): try: json.loads(value) except (json.JSONDecodeError, TypeError): value = json.dumps(value) elif hasattr(value, 'model_dump'): value = json.dumps(value.model_dump()) else: value = json.dumps(value) values.append(_stripNulBytesFromStr(value)) # Build INSERT/UPDATE with quoted identifiers col_names = ", ".join([f'"{col}"' for col in columns]) placeholders = ", ".join(["%s"] * len(columns)) updates = ", ".join( [ f'"{col}" = EXCLUDED."{col}"' for col in columns[1:] if col not in ["sysCreatedAt", "sysCreatedBy"] ] ) sql = f'INSERT INTO "{table}" ({col_names}) VALUES ({placeholders}) ON CONFLICT ("id") DO UPDATE SET {updates}' cursor.execute(sql, values) def _loadRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]: """Loads a single record from the normalized table.""" table = 
model_class.__name__ try: if not self._ensureTableExists(model_class): return None with self.borrowConn() as conn: with conn.cursor() as cursor: cursor.execute(f'SELECT * FROM "{table}" WHERE "id" = %s', (recordId,)) row = cursor.fetchone() if not row: return None record = dict(row) fields = getModelFields(model_class) parseRecordFields(record, fields, f"record {recordId}") return record except Exception as e: logger.error(f"Error loading record {recordId} from table {table}: {e}") raise DatabaseQueryError(table, str(e), original=e) from e def getRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]: """Load one row by primary key (routes / services; wraps _loadRecord).""" return self._loadRecord(model_class, str(recordId)) def _saveRecord( self, model_class: type, recordId: str, record: Dict[str, Any] ) -> bool: """Saves a single record to the table.""" table = model_class.__name__ try: if not self._ensureTableExists(model_class): return False recordId = str(recordId) if "id" in record and str(record["id"]) != recordId: raise ValueError(f"Record ID mismatch: {recordId} != {record['id']}") # Add metadata - use contextvar for request-scoped userId when sharing connector effective_user_id = _current_user_id.get() if effective_user_id is None: effective_user_id = self.userId currentTime = getUtcTimestamp() # Set sysCreatedAt/sysCreatedBy on first persist; always refresh modified fields. # Treat None and 0 as unset (empty / bad defaults); model_dump often has sysCreatedAt=None. createdTs = record.get("sysCreatedAt") if createdTs is None or createdTs == 0 or createdTs == 0.0: record["sysCreatedAt"] = currentTime # Do not wipe caller-provided sysCreatedBy (e.g. FileItem from createFile with # real user). ContextVar can be "system" for the DB pool while the business # user is set on the record from model_dump(). 
if effective_user_id and not record.get("sysCreatedBy"): record["sysCreatedBy"] = effective_user_id elif not record.get("sysCreatedBy"): if effective_user_id: record["sysCreatedBy"] = effective_user_id record["sysModifiedAt"] = currentTime if effective_user_id: record["sysModifiedBy"] = effective_user_id with self.borrowConn() as conn: with conn.cursor() as cursor: self._save_record(cursor, table, recordId, record, model_class) return True except Exception as e: logger.error(f"Error saving record {recordId} to table {table}: {e}") return False def _loadTable(self, model_class: type) -> List[Dict[str, Any]]: """Loads all records from a normalized table.""" table = model_class.__name__ if table == self._systemTableName: return self._loadSystemTable() try: if not self._ensureTableExists(model_class): return [] with self.borrowConn() as conn: with conn.cursor() as cursor: cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"') records = [dict(row) for row in cursor.fetchall()] fields = getModelFields(model_class) modelFields = model_class.model_fields for record in records: parseRecordFields(record, fields, f"table {table}") for fieldName, fieldType in fields.items(): if fieldType == "JSONB" and fieldName in record and record[fieldName] is None: fieldInfo = modelFields.get(fieldName) if fieldInfo: fieldAnnotation = fieldInfo.annotation if (fieldAnnotation == list or (hasattr(fieldAnnotation, "__origin__") and fieldAnnotation.__origin__ is list)): record[fieldName] = [] elif (fieldAnnotation == dict or (hasattr(fieldAnnotation, "__origin__") and fieldAnnotation.__origin__ is dict)): record[fieldName] = {} return records except Exception as e: logger.error(f"Error loading table {table}: {e}") raise DatabaseQueryError(table, str(e), original=e) from e def _registerInitialId(self, table: str, initialId: str) -> bool: """Registers the initial ID for a table.""" try: systemData = self._loadSystemTable() if table not in systemData: systemData[table] = initialId success = 
self._saveSystemTable(systemData) if success: logger.info(f"Initial ID {initialId} for table {table} registered") return success else: # Table already has an initial ID registered logger.debug(f"Table {table} already has initial ID {systemData[table]}") return True except Exception as e: logger.error(f"Error registering the initial ID for table {table}: {e}") return False def _removeInitialId(self, table: str) -> bool: """Removes the initial ID for a table from the system table.""" try: systemData = self._loadSystemTable() if table in systemData: del systemData[table] success = self._saveSystemTable(systemData) if success: logger.info( f"Initial ID for table {table} removed from system table" ) return success return True # If not present, this is not an error except Exception as e: logger.error(f"Error removing initial ID for table {table}: {e}") return False def buildRbacWhereClause( self, permissions: UserPermissions, currentUser: User, table: str, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None, ) -> Optional[Dict[str, Any]]: """Delegate to interfaceRbac.buildRbacWhereClause (tests and call sites use connector as entry).""" from modules.interfaces.interfaceRbac import buildRbacWhereClause as _buildRbacWhereClause return _buildRbacWhereClause( permissions, currentUser, table, self, mandateId=mandateId, featureInstanceId=featureInstanceId, ) def updateContext(self, userId: str) -> None: """Updates the context of the database connector. Sets both instance userId and contextvar for request-scoped use when connector is shared. 
""" if userId is None: raise ValueError("userId must be provided") self.userId = userId _current_user_id.set(userId) # Public API def getTables(self) -> List[str]: """Returns a list of all available tables.""" tables: List[str] = [] try: with self.borrowConn() as conn: with conn.cursor() as cursor: cursor.execute(""" SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' ORDER BY table_name """) rows = cursor.fetchall() tables = [row["table_name"] for row in rows] except Exception as e: logger.error(f"Error reading the database {self.dbDatabase}: {e}") return tables def getFields(self, model_class: type) -> List[str]: """Returns a list of all fields in a table.""" data = self._loadTable(model_class) if not data: return [] fields = list(data[0].keys()) if data else [] return fields def getSchema( self, model_class: type, language: str = None ) -> Dict[str, Dict[str, Any]]: """Returns a schema object for a table with data types and labels.""" data = self._loadTable(model_class) schema = {} if not data: return schema firstRecord = data[0] for field, value in firstRecord.items(): dataType = type(value).__name__ label = field schema[field] = {"type": dataType, "label": label} return schema def getRecordset( self, model_class: type, fieldFilter: List[str] = None, recordFilter: Dict[str, Any] = None, ) -> List[Dict[str, Any]]: """Returns a list of records from a table, filtered by criteria.""" table = model_class.__name__ try: if not self._ensureTableExists(model_class): return [] # Build WHERE clause from recordFilter where_conditions = [] where_values = [] if recordFilter: for field, value in recordFilter.items(): if value is None: where_conditions.append(f'"{field}" IS NULL') elif isinstance(value, list): where_conditions.append(f'"{field}" = ANY(%s)') where_values.append(value) else: where_conditions.append(f'"{field}" = %s') where_values.append(value) if where_conditions: where_clause = " WHERE " + " AND ".join(where_conditions) else: 
where_clause = "" query = f'SELECT * FROM "{table}"{where_clause} ORDER BY "id"' with self.borrowConn() as conn: with conn.cursor() as cursor: cursor.execute(query, where_values) records = [dict(row) for row in cursor.fetchall()] fields = getModelFields(model_class) modelFields = model_class.model_fields for record in records: parseRecordFields(record, fields, f"table {table}") for fieldName, fieldType in fields.items(): if fieldType == "JSONB" and fieldName in record and record[fieldName] is None: fieldInfo = modelFields.get(fieldName) if fieldInfo: fieldAnnotation = fieldInfo.annotation if (fieldAnnotation == list or (hasattr(fieldAnnotation, "__origin__") and fieldAnnotation.__origin__ is list)): record[fieldName] = [] elif (fieldAnnotation == dict or (hasattr(fieldAnnotation, "__origin__") and fieldAnnotation.__origin__ is dict)): record[fieldName] = {} if fieldFilter and isinstance(fieldFilter, list): result = [] for record in records: filteredRecord = {} for field in fieldFilter: if field in record: filteredRecord[field] = record[field] result.append(filteredRecord) return result return records except Exception as e: logger.error(f"Error loading records from table {table}: {e}") raise DatabaseQueryError(table, str(e), original=e) from e def _buildPaginationClauses( self, model_class: type, pagination, recordFilter: Dict[str, Any] = None, ): """ Translate PaginationParams + recordFilter into SQL clauses. Returns (where_clause, order_clause, limit_clause, values, count_values). 
""" fields = getModelFields(model_class) validColumns = set(fields.keys()) where_parts: List[str] = [] values: List[Any] = [] if recordFilter: for field, value in recordFilter.items(): if value is None: where_parts.append(f'"{field}" IS NULL') elif isinstance(value, list): where_parts.append(f'"{field}" = ANY(%s)') values.append(value) else: where_parts.append(f'"{field}" = %s') values.append(value) if pagination and pagination.filters: for key, val in pagination.filters.items(): if key == "search" and isinstance(val, str) and val.strip(): term = f"%{val.strip()}%" textCols = [c for c, t in fields.items() if t == "TEXT"] if textCols: orParts = [f'COALESCE("{c}"::TEXT, \'\') ILIKE %s' for c in textCols] where_parts.append(f"({' OR '.join(orParts)})") values.extend([term] * len(textCols)) continue if key not in validColumns: logger.debug(f"_buildPaginationClauses: key '{key}' NOT in validColumns {list(validColumns)[:10]}") continue colType = fields.get(key, "TEXT") logger.debug(f"_buildPaginationClauses: filter key='{key}' val={val!r} type(val)={type(val).__name__} colType={colType}") if val is None: where_parts.append(f'("{key}" IS NULL OR "{key}" = \'\')') continue if isinstance(val, dict): op = val.get("operator", "equals") v = val.get("value", "") if op in ("equals", "eq"): if colType == "BOOLEAN": where_parts.append(f'COALESCE("{key}", FALSE) = %s') values.append(str(v).lower() == "true") else: where_parts.append(f'"{key}"::TEXT = %s') values.append(str(v)) elif op == "contains": where_parts.append(f'"{key}"::TEXT ILIKE %s') values.append(f"%{v}%") elif op == "startsWith": where_parts.append(f'"{key}"::TEXT ILIKE %s') values.append(f"{v}%") elif op == "endsWith": where_parts.append(f'"{key}"::TEXT ILIKE %s') values.append(f"%{v}") elif op in ("gt", "gte", "lt", "lte"): sqlOp = {"gt": ">", "gte": ">=", "lt": "<", "lte": "<="}[op] if colType in ("INTEGER", "DOUBLE PRECISION"): try: where_parts.append(f'"{key}"::double precision {sqlOp} %s') values.append(float(v)) 
except (ValueError, TypeError): continue else: where_parts.append(f'"{key}"::TEXT {sqlOp} %s') values.append(str(v)) elif op == "between": fromVal = v.get("from", "") if isinstance(v, dict) else "" toVal = v.get("to", "") if isinstance(v, dict) else "" if not fromVal and not toVal: continue colType = fields.get(key, "TEXT") isNumericCol = colType in ("INTEGER", "DOUBLE PRECISION") isDateVal = bool(fromVal and re.match(r'^\d{4}-\d{2}-\d{2}$', str(fromVal))) or \ bool(toVal and re.match(r'^\d{4}-\d{2}-\d{2}$', str(toVal))) if isNumericCol and isDateVal: from datetime import datetime as _dt, timezone as _tz if fromVal and toVal: fromTs = _dt.strptime(str(fromVal), '%Y-%m-%d').replace(tzinfo=_tz.utc).timestamp() toTs = _dt.strptime(str(toVal), '%Y-%m-%d').replace(hour=23, minute=59, second=59, tzinfo=_tz.utc).timestamp() where_parts.append(f'"{key}" >= %s AND "{key}" <= %s') values.extend([fromTs, toTs]) elif fromVal: fromTs = _dt.strptime(str(fromVal), '%Y-%m-%d').replace(tzinfo=_tz.utc).timestamp() where_parts.append(f'"{key}" >= %s') values.append(fromTs) else: toTs = _dt.strptime(str(toVal), '%Y-%m-%d').replace(hour=23, minute=59, second=59, tzinfo=_tz.utc).timestamp() where_parts.append(f'"{key}" <= %s') values.append(toTs) elif isNumericCol: try: if fromVal and toVal: where_parts.append( f'"{key}"::double precision >= %s AND "{key}"::double precision <= %s' ) values.extend([float(fromVal), float(toVal)]) elif fromVal: where_parts.append(f'"{key}"::double precision >= %s') values.append(float(fromVal)) elif toVal: where_parts.append(f'"{key}"::double precision <= %s') values.append(float(toVal)) except (ValueError, TypeError): continue else: if fromVal and toVal: where_parts.append(f'"{key}"::TEXT >= %s AND "{key}"::TEXT <= %s') values.extend([str(fromVal), str(toVal)]) elif fromVal: where_parts.append(f'"{key}"::TEXT >= %s') values.append(str(fromVal)) elif toVal: where_parts.append(f'"{key}"::TEXT <= %s') values.append(str(toVal)) else: if colType == "BOOLEAN": 
where_parts.append(f'COALESCE("{key}", FALSE) = %s') values.append(str(val).lower() == "true") else: where_parts.append(f'"{key}"::TEXT ILIKE %s') values.append(str(val)) where_clause = " WHERE " + " AND ".join(where_parts) if where_parts else "" count_values = list(values) orderParts: List[str] = [] if pagination and pagination.sort: for sf in pagination.sort: sfField = sf.get("field") if isinstance(sf, dict) else getattr(sf, "field", None) sfDir = sf.get("direction", "asc") if isinstance(sf, dict) else getattr(sf, "direction", "asc") if sfField and sfField in validColumns: direction = "DESC" if str(sfDir).lower() == "desc" else "ASC" colType = fields.get(sfField, "TEXT") if colType == "BOOLEAN": orderParts.append(f'COALESCE("{sfField}", FALSE) {direction}') else: orderParts.append(f'"{sfField}" {direction} NULLS LAST') if not orderParts: orderParts.append('"id"') order_clause = " ORDER BY " + ", ".join(orderParts) limit_clause = "" if pagination: offset = (pagination.page - 1) * pagination.pageSize limit_clause = f" LIMIT {pagination.pageSize} OFFSET {offset}" return where_clause, order_clause, limit_clause, values, count_values def getRecordsetPaginated( self, model_class: type, pagination=None, recordFilter: Dict[str, Any] = None, fieldFilter: List[str] = None, ) -> Dict[str, Any]: """ Returns paginated records with filtering + sorting at the SQL level. Returns { "items": [...], "totalItems": int, "totalPages": int }. If pagination is None, returns all records (no LIMIT/OFFSET). 
""" from modules.datamodels.datamodelPagination import PaginationParams import math table = model_class.__name__ try: if not self._ensureTableExists(model_class): return {"items": [], "totalItems": 0, "totalPages": 0} where_clause, order_clause, limit_clause, values, count_values = \ self._buildPaginationClauses(model_class, pagination, recordFilter) with self.borrowConn() as conn: with conn.cursor() as cursor: countSql = f'SELECT COUNT(*) FROM "{table}"{where_clause}' dataSql = f'SELECT * FROM "{table}"{where_clause}{order_clause}{limit_clause}' cursor.execute(countSql, count_values) totalItems = cursor.fetchone()["count"] cursor.execute(dataSql, values) records = [dict(row) for row in cursor.fetchall()] fields = getModelFields(model_class) modelFields = model_class.model_fields for record in records: parseRecordFields(record, fields, f"table {table}") for fieldName, fieldType in fields.items(): if fieldType == "JSONB" and fieldName in record and record[fieldName] is None: fieldInfo = modelFields.get(fieldName) if fieldInfo: fieldAnnotation = fieldInfo.annotation if (fieldAnnotation == list or (hasattr(fieldAnnotation, "__origin__") and fieldAnnotation.__origin__ is list)): record[fieldName] = [] elif (fieldAnnotation == dict or (hasattr(fieldAnnotation, "__origin__") and fieldAnnotation.__origin__ is dict)): record[fieldName] = {} if fieldFilter and isinstance(fieldFilter, list): records = [{f: r[f] for f in fieldFilter if f in r} for r in records] from modules.routes.routeHelpers import enrichRowsWithFkLabels enrichRowsWithFkLabels(records, model_class) pageSize = pagination.pageSize if pagination else max(totalItems, 1) totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0 return {"items": records, "totalItems": totalItems, "totalPages": totalPages} except Exception as e: logger.error(f"Error in getRecordsetPaginated for table {table}: {e}") raise DatabaseQueryError(table, str(e), original=e) from e def getDistinctColumnValues( self, 
model_class: type, column: str, pagination=None, recordFilter: Dict[str, Any] = None, includeEmpty: bool = True, ) -> List[Optional[str]]: """Return sorted distinct values for a column using SQL DISTINCT. When ``includeEmpty`` is True (default), NULL and empty-string rows are represented as a single ``None`` entry at the end of the list — this allows the frontend to offer a "(Leer)" filter option. Applies cross-filtering (all filters except the requested column). """ table = model_class.__name__ fields = getModelFields(model_class) if column not in fields: return [] try: if not self._ensureTableExists(model_class): return [] if pagination: import copy pagination = copy.deepcopy(pagination) if pagination.filters and column in pagination.filters: pagination.filters.pop(column, None) pagination.sort = [] where_clause, _, _, values, _ = \ self._buildPaginationClauses(model_class, pagination, recordFilter) nonNullCond = f'"{column}" IS NOT NULL AND "{column}"::TEXT != \'\'' if where_clause: sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{where_clause} AND {nonNullCond} ORDER BY val' else: sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}" WHERE {nonNullCond} ORDER BY val' with self.borrowConn() as conn: with conn.cursor() as cursor: cursor.execute(sql, values) result: List[Optional[str]] = [row["val"] for row in cursor.fetchall()] if includeEmpty: emptyCond = f'"{column}" IS NULL OR "{column}"::TEXT = \'\'' if where_clause: emptySql = f'SELECT 1 FROM "{table}"{where_clause} AND ({emptyCond}) LIMIT 1' else: emptySql = f'SELECT 1 FROM "{table}" WHERE ({emptyCond}) LIMIT 1' cursor.execute(emptySql, values) if cursor.fetchone(): result.append(None) return result except Exception as e: logger.error(f"Error in getDistinctColumnValues for {table}.{column}: {e}") raise DatabaseQueryError(table, str(e), original=e) from e def recordCreate( self, model_class: type, record: Union[Dict[str, Any], BaseModel] ) -> Dict[str, Any]: """Creates a new record in 
a table based on Pydantic model class.""" # If record is a Pydantic model, convert to dict if isinstance(record, BaseModel): record = record.model_dump() elif isinstance(record, dict): record = record.copy() else: raise ValueError("Record must be a Pydantic model or dictionary") # Ensure record has an ID if "id" not in record: record["id"] = str(uuid.uuid4()) # Save record success = self._saveRecord(model_class, record["id"], record) if not success: table = model_class.__name__ raise ValueError(f"Failed to save record {record['id']} to table {table}") # Check if this is the first record in the table and register as initial ID table = model_class.__name__ existingInitialId = self.getInitialId(model_class) if existingInitialId is None: # This is the first record, register it as the initial ID self._registerInitialId(table, record["id"]) logger.info(f"Registered initial ID {record['id']} for table {table}") return record def recordModify( self, model_class: type, recordId: str, record: Union[Dict[str, Any], BaseModel] ) -> Dict[str, Any]: """Modifies an existing record in a table based on Pydantic model class.""" # Load existing record existingRecord = self._loadRecord(model_class, recordId) if not existingRecord: table = model_class.__name__ raise ValueError(f"Record {recordId} not found in table {table}") # If record is a Pydantic model, convert to dict if isinstance(record, BaseModel): record = record.model_dump() elif isinstance(record, dict): record = record.copy() else: raise ValueError("Record must be a Pydantic model or dictionary") # CRITICAL: Ensure we never modify the ID if "id" in record and str(record["id"]) != recordId: logger.error( f"Attempted to modify record ID from {recordId} to {record['id']}" ) raise ValueError( "Cannot modify record ID - it must match the provided recordId" ) # Update existing record with new data existingRecord.update(record) # Save updated record saved = self._saveRecord(model_class, recordId, existingRecord) if not saved: 
table = model_class.__name__ raise ValueError(f"Failed to save record {recordId} to table {table}") return existingRecord def recordDelete(self, model_class: type, recordId: str) -> bool: """Deletes a record from the table based on Pydantic model class.""" table = model_class.__name__ try: if not self._ensureTableExists(model_class): return False # `getInitialId` opens its own borrow; do it BEFORE we acquire a # connection ourselves so we don't pin two slots concurrently. initialId = self.getInitialId(model_class) with self.borrowConn() as conn: with conn.cursor() as cursor: cursor.execute( f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,) ) if not cursor.fetchone(): return False if initialId is not None and initialId == recordId: # `_removeInitialId` borrows its own conn — done outside # this block on purpose to avoid nested borrows. pass cursor.execute(f'DELETE FROM "{table}" WHERE "id" = %s', (recordId,)) if initialId is not None and initialId == recordId: self._removeInitialId(table) logger.info( f"Initial ID {recordId} for table {table} has been removed from the system table" ) return True except Exception as e: logger.error(f"Error deleting record {recordId} from table {table}: {e}") return False def recordCreateBulk( self, model_class: type, records: List[Union[Dict[str, Any], BaseModel]] ) -> int: """Bulk-insert many records in a single transaction. Use this instead of calling recordCreate() in a tight loop when importing large datasets (>100 rows). Performance gain is roughly two orders of magnitude because: - one network round-trip via execute_values() instead of N - one COMMIT instead of N - initial ID is registered once for the whole batch instead of every row Returns the number of rows successfully inserted. Caller is responsible for catching exceptions; on any error the transaction is rolled back so the table stays consistent (all-or-nothing). 
""" if not records: return 0 table = model_class.__name__ if not self._ensureTableExists(model_class): raise ValueError(f"Table {table} does not exist") fields = getModelFields(model_class) columns = ["id"] + [f for f in fields.keys() if f != "id"] modelFields = model_class.model_fields effectiveUserId = _current_user_id.get() if effectiveUserId is None: effectiveUserId = self.userId currentTime = getUtcTimestamp() normalised: List[Dict[str, Any]] = [] for raw in records: if isinstance(raw, BaseModel): rec = raw.model_dump() elif isinstance(raw, dict): rec = raw.copy() else: raise ValueError("Bulk record must be a Pydantic model or dictionary") if "id" not in rec or not rec["id"]: rec["id"] = str(uuid.uuid4()) createdTs = rec.get("sysCreatedAt") if createdTs is None or createdTs == 0 or createdTs == 0.0: rec["sysCreatedAt"] = currentTime if effectiveUserId and not rec.get("sysCreatedBy"): rec["sysCreatedBy"] = effectiveUserId elif not rec.get("sysCreatedBy") and effectiveUserId: rec["sysCreatedBy"] = effectiveUserId rec["sysModifiedAt"] = currentTime if effectiveUserId: rec["sysModifiedBy"] = effectiveUserId normalised.append(rec) rows = [self._coerceRowForInsert(rec, columns, fields, modelFields) for rec in normalised] col_names = ", ".join([f'"{c}"' for c in columns]) updates = ", ".join( [f'"{c}" = EXCLUDED."{c}"' for c in columns[1:] if c not in ("sysCreatedAt", "sysCreatedBy")] ) sql = ( f'INSERT INTO "{table}" ({col_names}) VALUES %s ' f'ON CONFLICT ("id") DO UPDATE SET {updates}' ) try: with self.borrowConn() as conn: with conn.cursor() as cursor: psycopg2.extras.execute_values(cursor, sql, rows, page_size=500) except Exception as e: logger.error(f"Bulk insert into {table} failed (n={len(rows)}): {e}") raise if self.getInitialId(model_class) is None and normalised: self._registerInitialId(table, normalised[0]["id"]) logger.info(f"Registered initial ID {normalised[0]['id']} for table {table}") return len(rows) def _coerceRowForInsert( self, record: Dict[str, 
Any], columns: List[str], fields: Dict[str, str], modelFields: Dict[str, Any], ) -> tuple: """Convert one record dict to a positional tuple matching `columns`. Mirrors the per-column coercion logic in `_save_record` so that bulk and single inserts produce identical on-disk values (timestamps as floats, enums as strings, vectors as pgvector text, JSONB as JSON strings). """ import json as _json out = [] for col in columns: value = record.get(col) if col in ("sysCreatedAt", "sysModifiedAt") and value is not None: if isinstance(value, str): try: value = float(value) except Exception: pass elif hasattr(value, "value"): value = value.value elif col in fields and _isVectorType(fields[col]) and value is not None: if isinstance(value, list): value = f"[{','.join(str(v) for v in value)}]" elif col in fields and fields[col] == "JSONB" and value is not None: if isinstance(value, (dict, list)): value = _json.dumps(value) elif isinstance(value, str): try: _json.loads(value) except (ValueError, TypeError): value = _json.dumps(value) elif hasattr(value, "model_dump"): value = _json.dumps(value.model_dump()) else: value = _json.dumps(value) out.append(value) return tuple(out) def recordDeleteWhere( self, model_class: type, recordFilter: Dict[str, Any] ) -> int: """Delete all records matching a simple equality filter, in one statement. Replaces the N+1 pattern `for r in getRecordset(...): recordDelete(r.id)`. Returns the number of rows actually deleted. If the table holds the initial ID and that row gets deleted, the initial ID registration is cleared so the next insert can re-register a fresh one. 
""" if not recordFilter: raise ValueError("recordDeleteWhere requires a non-empty recordFilter (refusing to truncate)") table = model_class.__name__ if not self._ensureTableExists(model_class): return 0 fields = getModelFields(model_class) clauses: List[str] = [] params: List[Any] = [] for key, val in recordFilter.items(): if key not in fields and key != "id": raise ValueError(f"recordDeleteWhere: unknown column {table}.{key}") clauses.append(f'"{key}" = %s') params.append(val) whereSql = " AND ".join(clauses) initialId = self.getInitialId(model_class) try: with self.borrowConn() as conn: with conn.cursor() as cursor: if initialId is not None: cursor.execute( f'SELECT 1 FROM "{table}" WHERE "id" = %s AND ' + whereSql, [initialId, *params], ) initialIsAffected = cursor.fetchone() is not None else: initialIsAffected = False cursor.execute(f'DELETE FROM "{table}" WHERE ' + whereSql, params) deleted = cursor.rowcount or 0 except Exception as e: logger.error(f"Bulk delete from {table} failed (filter={recordFilter}): {e}") raise if deleted and initialIsAffected: self._removeInitialId(table) logger.info(f"Initial ID for table {table} cleared (bulk-delete removed it)") if deleted: logger.info(f"recordDeleteWhere: deleted {deleted} rows from {table} where {recordFilter}") return deleted def getInitialId(self, model_class: type) -> Optional[str]: """Returns the initial ID for a table.""" table = model_class.__name__ systemData = self._loadSystemTable() initialId = systemData.get(table) return initialId def semanticSearch( self, modelClass: type, vectorColumn: str, queryVector: List[float], limit: int = 10, recordFilter: Dict[str, Any] = None, minScore: float = None, ) -> List[Dict[str, Any]]: """Semantic search using pgvector cosine distance. Args: modelClass: Pydantic model class for the table. vectorColumn: Name of the vector column to search. queryVector: Query vector as List[float]. limit: Maximum number of results. recordFilter: Additional WHERE filters (field: value). 
minScore: Minimum cosine similarity (0.0 - 1.0). Returns: List of records with an added '_score' field (cosine similarity), sorted by similarity descending. """ table = modelClass.__name__ try: if not self._ensureTableExists(modelClass): return [] vectorStr = f"[{','.join(str(v) for v in queryVector)}]" whereConditions = [] whereValues = [] if recordFilter: for field, value in recordFilter.items(): if value is None: whereConditions.append(f'"{field}" IS NULL') elif isinstance(value, (list, tuple)): if not value: whereConditions.append("1 = 0") else: whereConditions.append(f'"{field}" = ANY(%s)') whereValues.append(list(value)) else: whereConditions.append(f'"{field}" = %s') whereValues.append(value) if minScore is not None: whereConditions.append( f'1 - ("{vectorColumn}" <=> %s::vector) >= %s' ) whereValues.extend([vectorStr, minScore]) whereClause = "" if whereConditions: whereClause = " WHERE " + " AND ".join(whereConditions) query = ( f'SELECT *, 1 - ("{vectorColumn}" <=> %s::vector) AS "_score" ' f'FROM "{table}"{whereClause} ' f'ORDER BY "{vectorColumn}" <=> %s::vector ' f'LIMIT %s' ) params = [vectorStr] + whereValues + [vectorStr, limit] with self.borrowConn() as conn: with conn.cursor() as cursor: cursor.execute(query, params) records = [dict(row) for row in cursor.fetchall()] fields = getModelFields(modelClass) for record in records: parseRecordFields(record, fields, f"semanticSearch {table}") return records except Exception as e: logger.error(f"Error in semantic search on {table}: {e}") raise DatabaseQueryError(table, str(e), original=e) from e def close(self, forceClose: bool = False): """No-op for backward compatibility. Connections are now owned by the `_PoolRegistry` pool and live for the process lifetime. Pool shutdown happens centrally via `closeAllPools()` from the FastAPI lifespan hook — never from a connector instance. Interface `__del__` paths used to call `close()` to release a per- connector socket; with pooling there is nothing to close here. 
""" return def __del__(self): """Cleanup hook (intentionally no-op — see `close`).""" return