diff --git a/app.py b/app.py index 7a4ed4d4..d94c7dd5 100644 --- a/app.py +++ b/app.py @@ -438,7 +438,16 @@ async def lifespan(app: FastAPI): logger.error(f"Feature '{featureName}' failed to stop: {e}") except Exception as e: logger.warning(f"Could not shutdown feature containers: {e}") - + + # --- Close all PostgreSQL connection pools --- + # Must run LAST: feature `onStop` hooks may still issue DB calls during + # shutdown. Once we tear down the pools, no more borrows are possible. + try: + from modules.connectors.connectorDbPostgre import closeAllPools + closeAllPools() + except Exception as e: + logger.warning(f"Closing DB connection pools failed: {e}") + logger.info("Application has been shut down") diff --git a/modules/connectors/connectorDbPostgre.py b/modules/connectors/connectorDbPostgre.py index a6893396..f1a34f70 100644 --- a/modules/connectors/connectorDbPostgre.py +++ b/modules/connectors/connectorDbPostgre.py @@ -2,9 +2,12 @@ # All rights reserved. import contextvars import re +import time import psycopg2 import psycopg2.extras +import psycopg2.pool import logging +from contextlib import contextmanager from typing import List, Dict, Any, Optional, Union, get_origin, get_args, Type import uuid from pydantic import BaseModel, Field @@ -44,24 +47,6 @@ class DatabaseQueryError(RuntimeError): self.original = original -def _rollbackQuietly(connection) -> None: - """Restore the connection state after a failed query. - - Postgres puts the connection in an error state after any failed - statement; subsequent queries on the same connection raise - ``InFailedSqlTransaction`` until we rollback. We swallow rollback - errors because the original query error is what the caller should - see — a secondary rollback failure typically means the connection - is gone and will be reopened on the next ``_ensure_connection``. - """ - if connection is None: - return - try: - connection.rollback() - except Exception: - pass - - class SystemTable(PowerOnModel): """Data model for system table entries""" @@ -203,9 +188,174 @@ def _quotePgIdent(name: str) -> str: return '"' + str(name).replace('"', '""') + '"' -# Cache connectors by (host, database, port) to avoid duplicate inits for same database. -# Thread safety: _connector_cache_lock protects cache access. userId is request-scoped via -# contextvars to avoid races when concurrent requests share the same connector. +# --------------------------------------------------------------------------- +# Connection pool registry +# --------------------------------------------------------------------------- +# psycopg2 connections are NOT thread-safe; sharing one connection across the +# FastAPI threadpool (sync `def` routes) or across multiple async tasks that +# happen to be active simultaneously results in either +# `OperationalError: another command is already in progress` or — far worse — +# an unbounded hang in `recv()` on the connection socket, because two cursors +# wait for one server response. +# +# `_PoolRegistry` keeps one `ThreadedConnectionPool` per database identity +# (`host`, `db`, `port`). Every DB call borrows a connection from the pool, +# runs its query, and returns the connection. The pool itself guarantees that +# no two callers share the same psycopg2 connection at the same time. +# +# `statement_timeout=30000` (30 s) is a safety net: a runaway query is aborted +# instead of hanging forever and poisoning the connection. `connect_timeout=10` +# prevents indefinite blocking when the Postgres host is unreachable. + +_DEFAULT_POOL_MIN = 2 +_DEFAULT_POOL_MAX = 20 +_STATEMENT_TIMEOUT_MS = 30000 +_CONNECT_TIMEOUT_S = 10 +# psycopg2.pool.ThreadedConnectionPool.getconn() does NOT block when the pool +# is exhausted — it raises `psycopg2.pool.PoolError` immediately. That makes +# bursty workloads (50 concurrent route calls against a max=20 pool) fail +# spuriously. `borrowConn()` therefore retries with a small backoff up to +# `_BORROW_WAIT_TIMEOUT_S` seconds before giving up. +_BORROW_WAIT_TIMEOUT_S = 30.0 +_BORROW_WAIT_BACKOFF_S = 0.05 + + +def _resolvePoolMax() -> int: + """Pool size is configurable via `DB_POOL_MAX_CONN` (default 20).""" + try: + return max(2, int(APP_CONFIG.get("DB_POOL_MAX_CONN") or _DEFAULT_POOL_MAX)) + except (ValueError, TypeError): + return _DEFAULT_POOL_MAX + + +class _PoolRegistry: + """Process-wide registry of `ThreadedConnectionPool` instances. + + Keyed by `(host, database, port)` so that two `DatabaseConnector` instances + pointing at the same physical database share one pool. Lazy-initialised and + thread-safe. + """ + + _pools: Dict[tuple, psycopg2.pool.ThreadedConnectionPool] = {} + _lock = threading.Lock() + + @classmethod + def getPool( + cls, + *, + dbHost: str, + dbDatabase: str, + dbUser: str, + dbPassword: str, + dbPort: int, + ) -> psycopg2.pool.ThreadedConnectionPool: + port = int(dbPort) if dbPort is not None else 5432 + key = (dbHost, dbDatabase, port) + # Fast path: pool exists. + pool = cls._pools.get(key) + if pool is not None: + return pool + # Slow path: create exactly one pool per key, even under contention. + with cls._lock: + pool = cls._pools.get(key) + if pool is not None: + return pool + poolMax = _resolvePoolMax() + options = f"-c statement_timeout={_STATEMENT_TIMEOUT_MS}" + try: + pool = psycopg2.pool.ThreadedConnectionPool( + _DEFAULT_POOL_MIN, + poolMax, + host=dbHost, + port=port, + database=dbDatabase, + user=dbUser, + password=dbPassword, + client_encoding="utf8", + cursor_factory=psycopg2.extras.RealDictCursor, + connect_timeout=_CONNECT_TIMEOUT_S, + options=options, + ) + except Exception as e: + logger.error( + "Failed to create connection pool for db=%s host=%s: %s", + dbDatabase, dbHost, e, + ) + raise + cls._pools[key] = pool + logger.debug( + "Created connection pool for db=%s host=%s port=%s (min=%d max=%d, stmt_timeout=%dms)", + dbDatabase, dbHost, port, _DEFAULT_POOL_MIN, poolMax, _STATEMENT_TIMEOUT_MS, + ) + return pool + + @classmethod + def closeAll(cls) -> None: + """Close every pool. Call once during FastAPI shutdown.""" + with cls._lock: + for key, pool in list(cls._pools.items()): + try: + pool.closeall() + except Exception as e: + logger.warning("Error closing pool %s: %s", key, e) + cls._pools.clear() + logger.info("All database connection pools closed") + + +def closeAllPools() -> None: + """Public entry point for FastAPI lifespan shutdown hook.""" + _PoolRegistry.closeAll() + + +class _ConnectionShim: + """Backward-compatibility stand-in for the old `DatabaseConnector.connection`. + + Old callers used patterns like:: + + db.connection.commit() + if db.connection and not db.connection.closed: ... + + These are now no-ops because the pool owns the connection lifecycle. + Direct cursor access through `self.connection.cursor()` is intentionally + blocked with a clear error so silent breakage is impossible — every such + call site must migrate to `db.borrowCursor()`. + """ + + closed = False + + def __bool__(self) -> bool: + return True + + def commit(self) -> None: + return + + def rollback(self) -> None: + return + + def close(self) -> None: + return + + def cursor(self, *args, **kwargs): + raise RuntimeError( + "DatabaseConnector.connection.cursor() is no longer supported. " + "Use `db.borrowCursor()` (or `db.borrowConn()` for multi-statement " + "transactions) so the connection is borrowed from and returned to " + "the pool correctly." + ) + + +_CONNECTION_SHIM = _ConnectionShim() + + +# --------------------------------------------------------------------------- +# Connector cache (lightweight wrappers — actual connections live in the pool) +# --------------------------------------------------------------------------- +# Multiple call sites (`routeI18n`, `aiAuditLogger`, `mainBackgroundJobService`, +# interfaces) ask for a connector via `getCachedConnector(...)` and expect to +# get back the same object on subsequent calls. Now that real connection +# multiplexing happens at the pool layer, the cache returns lightweight +# `DatabaseConnector` wrappers — they hold no connection themselves, only the +# DSN params and a reference to the shared pool. _MAX_CACHED_CONNECTORS = 32 _connector_cache: Dict[tuple, "DatabaseConnector"] = {} _connector_cache_order: List[tuple] = [] # FIFO order for eviction @@ -223,22 +373,36 @@ def getCachedConnector( dbPort: int = None, userId: str = None, ) -> "DatabaseConnector": - """Return cached DatabaseConnector for same (host, database, port) to avoid duplicate PostgreSQL inits. - Uses contextvars for userId so concurrent requests sharing the same connector get correct sysCreatedBy/sysModifiedBy. + """Return a cached `DatabaseConnector` wrapper for `(host, database, port)`. + + Two layers of caching, both intentional: + + 1. **Pool layer** (`_PoolRegistry`) — owns the actual psycopg2 connections. + Per `(host, db, port)` exactly one `ThreadedConnectionPool`, shared + across every wrapper that points at the same physical database. This + is what actually saves Postgres connection slots. + + 2. **Wrapper layer** (this function) — caches the `DatabaseConnector` + Python object. Each wrapper triggers an `initDbSystem()` on first + instantiation (CREATE DATABASE if missing, CREATE TABLE _system, + pool warm-up). Caching the wrapper avoids paying that bootstrap cost + on every request and stops the log from filling with "PostgreSQL + database system initialized" lines. + + `userId` is request-scoped via the `_current_user_id` contextvar so two + concurrent requests sharing the same cached wrapper still produce + correct `sysCreatedBy` / `sysModifiedBy` audit fields. """ port = int(dbPort) if dbPort is not None else 5432 key = (dbHost, dbDatabase, port) with _connector_cache_lock: if key not in _connector_cache: - # Evict oldest if at capacity + # FIFO eviction. Connectors are now lightweight (no per-instance + # connection), so eviction is purely a memory bookkeeping concern; + # the underlying pool stays alive in `_PoolRegistry`. while len(_connector_cache) >= _MAX_CACHED_CONNECTORS and _connector_cache_order: oldest_key = _connector_cache_order.pop(0) - if oldest_key in _connector_cache: - try: - _connector_cache[oldest_key].close(forceClose=True) - except Exception as e: - logger.warning(f"Error closing evicted connector: {e}") - del _connector_cache[oldest_key] + _connector_cache.pop(oldest_key, None) _connector_cache[key] = DatabaseConnector( dbHost=dbHost, dbDatabase=dbDatabase, @@ -282,34 +446,38 @@ class DatabaseConnector: # Set userId (default to empty string if None) self.userId = userId if userId is not None else "" - # Initialize database system first (creates database if needed) - self.connection = None + # No per-instance connection any more — real connections live in the + # shared `_PoolRegistry` pool. `_isCachedShared` is retained because + # `close(forceClose=False)` callers (interface __del__) still ask. self._isCachedShared = False - self.initDbSystem() - # No caching needed with proper database - PostgreSQL handles performance - - # Thread safety - self._lock = threading.Lock() - - # pgvector extension state + # pgvector extension state (cached per connector instance — cheap) self._vectorExtensionEnabled = False - # Initialize system table + # System table bootstrap: create database, system table, ensure metadata. self._systemTableName = "_system" + self.initDbSystem() self._initializeSystemTable() def initDbSystem(self): - """Initialize the database system - creates database and tables.""" - try: - # Create database if it doesn't exist - self._create_database_if_not_exists() + """Bootstrap the physical database and the `_system` metadata table. - # Create tables + Uses a short-lived autocommit connection (NOT the pool) because + `CREATE DATABASE` cannot run inside a transaction block. Also warms up + the pool by acquiring it for this `(host, db, port)` once. + """ + try: + self._create_database_if_not_exists() self._create_tables() - # Establish connection to the database - self._connect() + # Warm the pool so the first request doesn't pay for socket setup. + _PoolRegistry.getPool( + dbHost=self.dbHost, + dbDatabase=self.dbDatabase, + dbUser=self.dbUser, + dbPassword=self.dbPassword, + dbPort=self.dbPort, + ) logger.debug( "PostgreSQL database system initialized (db=%s, host=%s, port=%s)", @@ -319,10 +487,121 @@ class DatabaseConnector: logger.error(f"FATAL ERROR: Database system initialization failed: {e}") raise - def _create_database_if_not_exists(self): - """Create the database if it doesn't exist.""" + @property + def connection(self) -> "_ConnectionShim": + """Backward-compat shim — see `_ConnectionShim` docstring.""" + return _CONNECTION_SHIM + + def _ensure_connection(self) -> None: + """No-op for backward compatibility. + + Previously this method tested-or-reconnected the per-instance socket. + With pooling, the `ThreadedConnectionPool` re-establishes broken + connections lazily on the next `getconn()`. Kept as a no-op so that + legacy call sites continue to compile. + """ + return + + @contextmanager + def borrowCursor(self): + """Borrow a cursor for one short SQL block. + + Convenience wrapper around `borrowConn()` for the common pattern: + + with db.borrowCursor() as cursor: + cursor.execute(...) + rows = cursor.fetchall() + + Replaces the legacy `with db.connection.cursor() as cursor:` pattern. + Commit/rollback and pool return are handled automatically — callers + must NOT call `.commit()`/`.rollback()` on the cursor themselves. + """ + with self.borrowConn() as conn: + with conn.cursor() as cursor: + yield cursor + + @contextmanager + def borrowConn(self): + """Borrow a connection from the pool for the duration of the block. + + Pool-exhaustion semantics: `ThreadedConnectionPool.getconn()` raises + `psycopg2.pool.PoolError` immediately when the pool is at its `maxconn` + limit — it does NOT block. We wrap that with a bounded busy-wait so + bursty workloads (e.g. 50 concurrent route calls against a max=20 + pool) queue up instead of failing. If no connection becomes available + within `_BORROW_WAIT_TIMEOUT_S`, the `PoolError` is propagated so a + genuinely deadlocked pool surfaces as a real error. + + On normal exit the current transaction is committed (idempotent for + read-only queries — leaves the connection in a clean state for the + next borrower). On exception the transaction is rolled back and the + exception propagates. The connection is **always** returned to the + pool, even when commit/rollback fails — otherwise a single bad query + would leak a slot and eventually exhaust the pool. + """ + pool = _PoolRegistry.getPool( + dbHost=self.dbHost, + dbDatabase=self.dbDatabase, + dbUser=self.dbUser, + dbPassword=self.dbPassword, + dbPort=self.dbPort, + ) + conn = self._acquireConn(pool) + try: + yield conn + except Exception: + try: + conn.rollback() + except Exception: + pass + raise + else: + # Best-effort commit so the connection goes back to the pool with + # no in-flight transaction. Failure here is non-fatal — the pool + # will re-establish the socket on the next `getconn()` if needed. + try: + conn.commit() + except Exception: + try: + conn.rollback() + except Exception: + pass + finally: + try: + pool.putconn(conn) + except Exception as e: + logger.warning("Failed to return connection to pool: %s", e) + + @staticmethod + def _acquireConn(pool: psycopg2.pool.ThreadedConnectionPool): + """Get a connection from the pool, waiting up to `_BORROW_WAIT_TIMEOUT_S`. + + psycopg2's pool throws on exhaustion instead of queueing — this helper + polls with a short backoff so callers see queue semantics. + """ + deadline = time.monotonic() + _BORROW_WAIT_TIMEOUT_S + attempt = 0 + while True: + try: + return pool.getconn() + except psycopg2.pool.PoolError as e: + attempt += 1 + if time.monotonic() >= deadline: + logger.error( + "Connection pool exhausted after %.1fs wait (%d retries)", + _BORROW_WAIT_TIMEOUT_S, attempt, + ) + raise + time.sleep(_BORROW_WAIT_BACKOFF_S) + + def _create_database_if_not_exists(self): + """Create the database if it doesn't exist. + + Uses an autocommit connection on the `postgres` admin DB because + `CREATE DATABASE` cannot run inside a transaction block — so this + path intentionally does NOT use the pool. + """ try: - # Use the configured user for database creation conn = psycopg2.connect( host=self.dbHost, port=self.dbPort, @@ -330,23 +609,21 @@ class DatabaseConnector: user=self.dbUser, password=self.dbPassword, client_encoding="utf8", + connect_timeout=_CONNECT_TIMEOUT_S, ) conn.autocommit = True - - with conn.cursor() as cursor: - # Check if database exists - cursor.execute( - "SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,) - ) - exists = cursor.fetchone() - - if not exists: - # Create database with proper quoting for names with hyphens - quoted_db_name = f'"{self.dbDatabase}"' - cursor.execute(f"CREATE DATABASE {quoted_db_name}") - logger.info(f"Created database: {self.dbDatabase}") - - conn.close() + try: + with conn.cursor() as cursor: + cursor.execute( + "SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,) + ) + exists = cursor.fetchone() + if not exists: + quoted_db_name = f'"{self.dbDatabase}"' + cursor.execute(f"CREATE DATABASE {quoted_db_name}") + logger.info(f"Created database: {self.dbDatabase}") + finally: + conn.close() except Exception as e: logger.error(f"FATAL ERROR: Cannot create database: {e}") @@ -356,9 +633,12 @@ class DatabaseConnector: ) def _create_tables(self): - """Create only the system table - application tables are created by interfaces.""" + """Create the `_system` table. + + Uses a short-lived autocommit connection (not the pool) — runs exactly + once at connector creation. + """ try: - # Use the configured user for table creation conn = psycopg2.connect( host=self.dbHost, port=self.dbPort, @@ -366,23 +646,24 @@ class DatabaseConnector: user=self.dbUser, password=self.dbPassword, client_encoding="utf8", + connect_timeout=_CONNECT_TIMEOUT_S, ) conn.autocommit = True - - with conn.cursor() as cursor: - # Create only the system table - cursor.execute(""" - CREATE TABLE IF NOT EXISTS _system ( - id SERIAL PRIMARY KEY, - table_name VARCHAR(255) UNIQUE NOT NULL, - initial_id VARCHAR(255) NOT NULL, - "sysCreatedAt" DOUBLE PRECISION, - "sysCreatedBy" VARCHAR(255), - "sysModifiedAt" DOUBLE PRECISION, - "sysModifiedBy" VARCHAR(255) - ) - """) - conn.close() + try: + with conn.cursor() as cursor: + cursor.execute(""" + CREATE TABLE IF NOT EXISTS _system ( + id SERIAL PRIMARY KEY, + table_name VARCHAR(255) UNIQUE NOT NULL, + initial_id VARCHAR(255) NOT NULL, + "sysCreatedAt" DOUBLE PRECISION, + "sysCreatedBy" VARCHAR(255), + "sysModifiedAt" DOUBLE PRECISION, + "sysModifiedBy" VARCHAR(255) + ) + """) + finally: + conn.close() except Exception as e: logger.error(f"FATAL ERROR: Cannot create system table: {e}") @@ -391,67 +672,26 @@ class DatabaseConnector: ) raise RuntimeError(f"FATAL ERROR: Cannot create system table: {e}") - def _connect(self): - """Establish connection to PostgreSQL database.""" - try: - # Use configured user for main connection with proper parameter handling - self.connection = psycopg2.connect( - host=self.dbHost, - port=self.dbPort, - database=self.dbDatabase, - user=self.dbUser, - password=self.dbPassword, - client_encoding="utf8", - cursor_factory=psycopg2.extras.RealDictCursor, - ) - self.connection.autocommit = False # Use transactions - except Exception as e: - logger.error(f"Failed to connect to PostgreSQL: {e}") - raise - - def _ensure_connection(self): - """Ensure database connection is alive, reconnect if necessary.""" - try: - if self.connection is None or self.connection.closed: - self._connect() - else: - # Test connection with a simple query - with self.connection.cursor() as cursor: - cursor.execute("SELECT 1") - except Exception as e: - logger.warning(f"Connection lost, reconnecting: {e}") - self._connect() - def _initializeSystemTable(self): """Initializes the system table if it doesn't exist yet.""" try: - # First ensure the system table exists self._ensureTableExists(SystemTable) - - with self.connection.cursor() as cursor: - # Check if system table has any data - cursor.execute('SELECT COUNT(*) FROM "_system"') - row = cursor.fetchone() - count = row["count"] if row else 0 - - self.connection.commit() + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute('SELECT COUNT(*) FROM "_system"') + cursor.fetchone() # noqa: just verifies table is readable except Exception as e: logger.error(f"Error initializing system table: {e}") - self.connection.rollback() raise def _loadSystemTable(self) -> Dict[str, str]: """Loads the system table with the initial IDs.""" try: - with self.connection.cursor() as cursor: - cursor.execute('SELECT "table_name", "initial_id" FROM "_system"') - rows = cursor.fetchall() - - system_data = {} - for row in rows: - system_data[row["table_name"]] = row["initial_id"] - - return system_data + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute('SELECT "table_name", "initial_id" FROM "_system"') + rows = cursor.fetchall() + return {row["table_name"]: row["initial_id"] for row in rows} except Exception as e: logger.error(f"Error loading system table: {e}") return {} @@ -459,75 +699,65 @@ class DatabaseConnector: def _saveSystemTable(self, data: Dict[str, str]) -> bool: """Saves the system table with the initial IDs.""" try: - with self.connection.cursor() as cursor: - # Clear existing data - cursor.execute('DELETE FROM "_system"') - - # Insert new data - for table_name, initial_id in data.items(): - cursor.execute( - """ - INSERT INTO "_system" ("table_name", "initial_id", "sysModifiedAt") - VALUES (%s, %s, %s) - """, - (table_name, initial_id, getUtcTimestamp()), - ) - - self.connection.commit() + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute('DELETE FROM "_system"') + for table_name, initial_id in data.items(): + cursor.execute( + """ + INSERT INTO "_system" ("table_name", "initial_id", "sysModifiedAt") + VALUES (%s, %s, %s) + """, + (table_name, initial_id, getUtcTimestamp()), + ) return True except Exception as e: logger.error(f"Error saving system table: {e}") - self.connection.rollback() return False def _ensureSystemTableExists(self) -> bool: """Ensures the system table exists, creates it if it doesn't.""" try: - self._ensure_connection() - - with self.connection.cursor() as cursor: - # Check if system table exists - cursor.execute( - "SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s", - (self._systemTableName,), - ) - exists = cursor.fetchone()["count"] > 0 - - if not exists: - # Create system table - cursor.execute(f""" - CREATE TABLE "{self._systemTableName}" ( - "table_name" VARCHAR(255) PRIMARY KEY, - "initial_id" VARCHAR(255), - "sysCreatedAt" DOUBLE PRECISION, - "sysCreatedBy" VARCHAR(255), - "sysModifiedAt" DOUBLE PRECISION, - "sysModifiedBy" VARCHAR(255) - ) - """) - logger.info("System table created successfully") - else: - # Check if we need to add missing columns to existing table + with self.borrowConn() as conn: + with conn.cursor() as cursor: cursor.execute( - """ - SELECT column_name FROM information_schema.columns - WHERE table_name = %s AND table_schema = 'public' - """, + "SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s", (self._systemTableName,), ) - existing_columns = [row["column_name"] for row in cursor.fetchall()] + exists = cursor.fetchone()["count"] > 0 - for sys_col, sys_sql in [ - ("sysCreatedAt", "DOUBLE PRECISION"), - ("sysCreatedBy", "VARCHAR(255)"), - ("sysModifiedAt", "DOUBLE PRECISION"), - ("sysModifiedBy", "VARCHAR(255)"), - ]: - if sys_col not in existing_columns: - cursor.execute( - f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "{sys_col}" {sys_sql}' + if not exists: + cursor.execute(f""" + CREATE TABLE "{self._systemTableName}" ( + "table_name" VARCHAR(255) PRIMARY KEY, + "initial_id" VARCHAR(255), + "sysCreatedAt" DOUBLE PRECISION, + "sysCreatedBy" VARCHAR(255), + "sysModifiedAt" DOUBLE PRECISION, + "sysModifiedBy" VARCHAR(255) ) + """) + logger.info("System table created successfully") + else: + cursor.execute( + """ + SELECT column_name FROM information_schema.columns + WHERE table_name = %s AND table_schema = 'public' + """, + (self._systemTableName,), + ) + existing_columns = [row["column_name"] for row in cursor.fetchall()] + for sys_col, sys_sql in [ + ("sysCreatedAt", "DOUBLE PRECISION"), + ("sysCreatedBy", "VARCHAR(255)"), + ("sysModifiedAt", "DOUBLE PRECISION"), + ("sysModifiedBy", "VARCHAR(255)"), + ]: + if sys_col not in existing_columns: + cursor.execute( + f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "{sys_col}" {sys_sql}' + ) return True except Exception as e: logger.error(f"Error ensuring system table exists: {e}") @@ -542,123 +772,113 @@ class DatabaseConnector: return self._ensureSystemTableExists() try: - self._ensure_connection() - - with self.connection.cursor() as cursor: - # Check if table exists by querying information_schema with case-insensitive search - cursor.execute( - """ - SELECT COUNT(*) FROM information_schema.tables - WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public' - """, - (table,), - ) - exists = cursor.fetchone()["count"] > 0 - - if not exists: - # Create table from Pydantic model - self._create_table_from_model(cursor, table, model_class) - logger.info( - f"Created table '{table}' with columns from Pydantic model" + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute( + """ + SELECT COUNT(*) FROM information_schema.tables + WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public' + """, + (table,), ) - else: - # Table exists: ensure all columns from model are present (simple additive migration) - try: - cursor.execute( - """ - SELECT column_name, data_type - FROM information_schema.columns - WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public' - """, - (table,), + exists = cursor.fetchone()["count"] > 0 + + if not exists: + self._create_table_from_model(cursor, table, model_class) + logger.info( + f"Created table '{table}' with columns from Pydantic model" ) - existing_column_rows = cursor.fetchall() - existing_columns = { - row["column_name"] for row in existing_column_rows - } - existing_column_types = { - row["column_name"]: (row["data_type"] or "").lower() - for row in existing_column_rows - } + else: + # Table exists: ensure all columns from model are present (simple additive migration) + try: + cursor.execute( + """ + SELECT column_name, data_type + FROM information_schema.columns + WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public' + """, + (table,), + ) + existing_column_rows = cursor.fetchall() + existing_columns = { + row["column_name"] for row in existing_column_rows + } + existing_column_types = { + row["column_name"]: (row["data_type"] or "").lower() + for row in existing_column_rows + } - # Desired columns based on model - model_fields = getModelFields(model_class) - desired_columns = set(["id"]) | set(model_fields.keys()) + model_fields = getModelFields(model_class) + desired_columns = set(["id"]) | set(model_fields.keys()) - # Add missing columns - for col in sorted(desired_columns - existing_columns): - # Determine SQL type - if col in ["id"]: - continue # primary key exists already - sql_type = model_fields.get(col) - if not sql_type: - sql_type = "TEXT" - try: - cursor.execute( - f'ALTER TABLE "{table}" ADD COLUMN "{col}" {sql_type}' - ) - logger.info( - f"Added missing column '{col}' ({sql_type}) to '{table}'" - ) - except Exception as add_err: - logger.warning( - f"Could not add column '{col}' to '{table}': {add_err}" - ) - - # Column type migrations for existing tables. - # TEXT→DOUBLE PRECISION handles three value shapes: - # 1. NULL / empty string → NULL - # 2. ISO date(time) like "2025-01-22" or "2025-01-22T10:00:00+00" → epoch via EXTRACT - # 3. Plain numeric string like "3.14" → direct cast - _TEXT_TO_DOUBLE = ( - 'DOUBLE PRECISION USING CASE' - ' WHEN "{col}" IS NULL OR "{col}" = \'\' THEN NULL' - ' WHEN "{col}" ~ \'^\\d{4}-\\d{2}-\\d{2}\'' - ' THEN EXTRACT(EPOCH FROM "{col}"::timestamptz)' - ' ELSE NULLIF("{col}", \'\')::double precision' - ' END' - ) - _SAFE_TYPE_CHANGES = { - ("jsonb", "TEXT"): "TEXT USING \"{col}\"::text", - ("text", "DOUBLE PRECISION"): _TEXT_TO_DOUBLE, - ("text", "INTEGER"): "INTEGER USING NULLIF(\"{col}\", '')::integer", - ("timestamp without time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}" AT TIME ZONE \'UTC\')', - ("timestamp with time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}")', - ("date", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}"::timestamp AT TIME ZONE \'UTC\')', - } - for col in sorted(desired_columns & existing_columns): - if col == "id": - continue - desired_sql = (model_fields.get(col) or "").upper() - currentType = existing_column_types.get(col, "") - migration = _SAFE_TYPE_CHANGES.get((currentType, desired_sql)) - if migration: - castExpr = migration.replace("{col}", col) + for col in sorted(desired_columns - existing_columns): + if col in ["id"]: + continue + sql_type = model_fields.get(col) + if not sql_type: + sql_type = "TEXT" try: - cursor.execute('SAVEPOINT col_migrate') cursor.execute( - f'ALTER TABLE "{table}" ALTER COLUMN "{col}" TYPE {castExpr}' + f'ALTER TABLE "{table}" ADD COLUMN "{col}" {sql_type}' ) - cursor.execute('RELEASE SAVEPOINT col_migrate') logger.info( - f"Migrated column '{col}' from {currentType} to {desired_sql} on '{table}'" + f"Added missing column '{col}' ({sql_type}) to '{table}'" ) - except Exception as alter_err: - cursor.execute('ROLLBACK TO SAVEPOINT col_migrate') + except Exception as add_err: logger.warning( - f"Could not migrate column '{col}' on '{table}': {alter_err}" + f"Could not add column '{col}' to '{table}': {add_err}" ) - except Exception as ensure_err: - logger.warning( - f"Could not ensure columns for existing table '{table}': {ensure_err}" - ) - self.connection.commit() + # Column type migrations for existing tables. + # TEXT→DOUBLE PRECISION handles three value shapes: + # 1. NULL / empty string → NULL + # 2. ISO date(time) like "2025-01-22" or "2025-01-22T10:00:00+00" → epoch via EXTRACT + # 3. Plain numeric string like "3.14" → direct cast + _TEXT_TO_DOUBLE = ( + 'DOUBLE PRECISION USING CASE' + ' WHEN "{col}" IS NULL OR "{col}" = \'\' THEN NULL' + ' WHEN "{col}" ~ \'^\\d{4}-\\d{2}-\\d{2}\'' + ' THEN EXTRACT(EPOCH FROM "{col}"::timestamptz)' + ' ELSE NULLIF("{col}", \'\')::double precision' + ' END' + ) + _SAFE_TYPE_CHANGES = { + ("jsonb", "TEXT"): "TEXT USING \"{col}\"::text", + ("text", "DOUBLE PRECISION"): _TEXT_TO_DOUBLE, + ("text", "INTEGER"): "INTEGER USING NULLIF(\"{col}\", '')::integer", + ("timestamp without time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}" AT TIME ZONE \'UTC\')', + ("timestamp with time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}")', + ("date", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}"::timestamp AT TIME ZONE \'UTC\')', + } + for col in sorted(desired_columns & existing_columns): + if col == "id": + continue + desired_sql = (model_fields.get(col) or "").upper() + currentType = existing_column_types.get(col, "") + migration = _SAFE_TYPE_CHANGES.get((currentType, desired_sql)) + if migration: + castExpr = migration.replace("{col}", col) + try: + cursor.execute('SAVEPOINT col_migrate') + cursor.execute( + f'ALTER TABLE "{table}" ALTER COLUMN "{col}" TYPE {castExpr}' + ) + cursor.execute('RELEASE SAVEPOINT col_migrate') + logger.info( + f"Migrated column '{col}' from {currentType} to {desired_sql} on '{table}'" + ) + except Exception as alter_err: + cursor.execute('ROLLBACK TO SAVEPOINT col_migrate') + logger.warning( + f"Could not migrate column '{col}' on '{table}': {alter_err}" + ) + except Exception as ensure_err: + logger.warning( + f"Could not ensure columns for existing table '{table}': {ensure_err}" + ) return True except Exception as e: logger.error(f"Error ensuring table {table} exists: {e}") - if hasattr(self, "connection") and self.connection: - self.connection.rollback() return False def _ensureVectorExtension(self) -> bool: @@ -666,17 +886,14 @@ class DatabaseConnector: if self._vectorExtensionEnabled: return True try: - self._ensure_connection() - with self.connection.cursor() as cursor: - cursor.execute("CREATE EXTENSION IF NOT EXISTS vector") - self.connection.commit() + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute("CREATE EXTENSION IF NOT EXISTS vector") self._vectorExtensionEnabled = True logger.info("pgvector extension enabled") return True except Exception as e: logger.error(f"Failed to enable pgvector extension: {e}") - if hasattr(self, "connection") and self.connection: - self.connection.rollback() return False def _create_table_from_model(self, cursor, table: str, model_class: type) -> None: @@ -791,22 +1008,19 @@ class DatabaseConnector: if not self._ensureTableExists(model_class): return None - with self.connection.cursor() as cursor: - cursor.execute(f'SELECT * FROM "{table}" WHERE "id" = %s', (recordId,)) - row = cursor.fetchone() - if not row: - return None + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute(f'SELECT * FROM "{table}" WHERE "id" = %s', (recordId,)) + row = cursor.fetchone() + if not row: + return None + record = dict(row) - # Convert row to dict and handle JSONB fields - record = dict(row) - fields = getModelFields(model_class) - - parseRecordFields(record, fields, f"record {recordId}") - - return record + fields = getModelFields(model_class) + parseRecordFields(record, fields, f"record {recordId}") + return record except Exception as e: logger.error(f"Error loading record {recordId} from table {table}: {e}") - _rollbackQuietly(getattr(self, "connection", None)) raise DatabaseQueryError(table, str(e), original=e) from e def getRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]: @@ -849,14 +1063,12 @@ class DatabaseConnector: if effective_user_id: record["sysModifiedBy"] = effective_user_id - with self.connection.cursor() as cursor: - self._save_record(cursor, table, recordId, record, model_class) - - self.connection.commit() + with self.borrowConn() as conn: + with conn.cursor() as cursor: + self._save_record(cursor, table, recordId, record, model_class) return True except Exception as e: logger.error(f"Error saving record {recordId} to table {table}: {e}") - self.connection.rollback() return False def _loadTable(self, model_class: type) -> List[Dict[str, Any]]: @@ -870,33 +1082,32 @@ class DatabaseConnector: if not self._ensureTableExists(model_class): return [] - with self.connection.cursor() as cursor: - cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"') - records = [dict(row) for row in cursor.fetchall()] + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"') + records = [dict(row) for row in cursor.fetchall()] - fields = getModelFields(model_class) - modelFields = model_class.model_fields - for record in records: - parseRecordFields(record, fields, f"table {table}") - # Set type-aware defaults for NULL JSONB fields - for fieldName, fieldType in fields.items(): - if fieldType == "JSONB" and fieldName in record and record[fieldName] is None: - fieldInfo = modelFields.get(fieldName) - if fieldInfo: - fieldAnnotation = fieldInfo.annotation - if (fieldAnnotation == list or - (hasattr(fieldAnnotation, "__origin__") and - fieldAnnotation.__origin__ is list)): - record[fieldName] = [] - elif (fieldAnnotation == dict or - (hasattr(fieldAnnotation, "__origin__") and - fieldAnnotation.__origin__ is dict)): - record[fieldName] = {} + fields = getModelFields(model_class) + modelFields = model_class.model_fields + for record in records: + parseRecordFields(record, fields, f"table {table}") + for fieldName, fieldType in fields.items(): + if fieldType == "JSONB" and fieldName in record and record[fieldName] is None: + fieldInfo = modelFields.get(fieldName) + if fieldInfo: + fieldAnnotation = fieldInfo.annotation + if (fieldAnnotation == list or + (hasattr(fieldAnnotation, "__origin__") and + fieldAnnotation.__origin__ is list)): + record[fieldName] = [] + elif (fieldAnnotation == dict or + (hasattr(fieldAnnotation, "__origin__") and + fieldAnnotation.__origin__ is dict)): + record[fieldName] = {} - return records + return records except Exception as e: logger.error(f"Error loading table {table}: {e}") - _rollbackQuietly(getattr(self, "connection", None)) raise DatabaseQueryError(table, str(e), original=e) from e def _registerInitialId(self, table: str, initialId: str) -> bool: @@ -969,28 +1180,20 @@ class DatabaseConnector: def getTables(self) -> List[str]: """Returns a list of all available tables.""" - tables = [] - + tables: List[str] = [] try: - # Ensure connection is alive - self._ensure_connection() - - if not self.connection or self.connection.closed: - logger.error("Database connection is not available") - return tables - - with self.connection.cursor() as cursor: - cursor.execute(""" - SELECT table_name - FROM information_schema.tables - WHERE table_schema = 'public' - ORDER BY table_name - """) - rows = cursor.fetchall() - tables = [row["table_name"] for row in rows] + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute(""" + SELECT table_name + FROM information_schema.tables + WHERE table_schema = 'public' + ORDER BY table_name + """) + rows = cursor.fetchall() + tables = [row["table_name"] for row in rows] except Exception as e: logger.error(f"Error reading the database {self.dbDatabase}: {e}") - return tables def getFields(self, model_class: type) -> List[str]: @@ -1060,43 +1263,42 @@ class DatabaseConnector: query = f'SELECT * FROM "{table}"{where_clause} ORDER BY "id"' - with self.connection.cursor() as cursor: - cursor.execute(query, where_values) - records = [dict(row) for row in cursor.fetchall()] + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute(query, where_values) + records = [dict(row) for row in cursor.fetchall()] - fields = getModelFields(model_class) - modelFields = model_class.model_fields + fields = getModelFields(model_class) + modelFields = model_class.model_fields + for record in records: + parseRecordFields(record, fields, f"table {table}") + for fieldName, fieldType in fields.items(): + if fieldType == "JSONB" and fieldName in record and record[fieldName] is None: + fieldInfo = modelFields.get(fieldName) + if fieldInfo: + fieldAnnotation = fieldInfo.annotation + if (fieldAnnotation == list or + (hasattr(fieldAnnotation, "__origin__") and + fieldAnnotation.__origin__ is list)): + record[fieldName] = [] + elif (fieldAnnotation == dict or + (hasattr(fieldAnnotation, "__origin__") and + fieldAnnotation.__origin__ is dict)): + record[fieldName] = {} + + if fieldFilter and isinstance(fieldFilter, list): + result = [] for record in records: - parseRecordFields(record, fields, f"table {table}") - for fieldName, fieldType in fields.items(): - if fieldType == "JSONB" and fieldName in record and record[fieldName] is None: - fieldInfo = modelFields.get(fieldName) - if fieldInfo: - fieldAnnotation = fieldInfo.annotation - if (fieldAnnotation == list or - (hasattr(fieldAnnotation, "__origin__") and - fieldAnnotation.__origin__ is list)): - record[fieldName] = [] - elif (fieldAnnotation == dict or - (hasattr(fieldAnnotation, "__origin__") and - fieldAnnotation.__origin__ is dict)): - record[fieldName] = {} + filteredRecord = {} + for field in fieldFilter: + if field in record: + filteredRecord[field] = record[field] + result.append(filteredRecord) + return result - # If fieldFilter is available, reduce the fields - if fieldFilter and isinstance(fieldFilter, list): - result = [] - for record in records: - filteredRecord = {} - for field in fieldFilter: - if field in record: - filteredRecord[field] = record[field] - result.append(filteredRecord) - return result - - return records + return records except Exception as e: logger.error(f"Error loading records from table {table}: {e}") - _rollbackQuietly(getattr(self, "connection", None)) raise DatabaseQueryError(table, str(e), original=e) from e def _buildPaginationClauses( @@ -1281,35 +1483,36 @@ class DatabaseConnector: where_clause, order_clause, limit_clause, values, count_values = \ self._buildPaginationClauses(model_class, pagination, recordFilter) - with self.connection.cursor() as cursor: - countSql = f'SELECT COUNT(*) FROM "{table}"{where_clause}' - dataSql = f'SELECT * FROM "{table}"{where_clause}{order_clause}{limit_clause}' - cursor.execute(countSql, count_values) - totalItems = cursor.fetchone()["count"] + with self.borrowConn() as conn: + with conn.cursor() as cursor: + countSql = f'SELECT COUNT(*) FROM "{table}"{where_clause}' + dataSql = f'SELECT * FROM "{table}"{where_clause}{order_clause}{limit_clause}' + cursor.execute(countSql, count_values) + totalItems = cursor.fetchone()["count"] - cursor.execute(dataSql, values) - records = [dict(row) for row in cursor.fetchall()] + cursor.execute(dataSql, values) + records = [dict(row) for row in cursor.fetchall()] - fields = getModelFields(model_class) - modelFields = model_class.model_fields - for record in records: - parseRecordFields(record, fields, f"table {table}") - for fieldName, fieldType in fields.items(): - if fieldType == "JSONB" and fieldName in record and record[fieldName] is None: - fieldInfo = modelFields.get(fieldName) - if fieldInfo: - fieldAnnotation = fieldInfo.annotation - if (fieldAnnotation == list or - (hasattr(fieldAnnotation, "__origin__") and - fieldAnnotation.__origin__ is list)): - record[fieldName] = [] - elif (fieldAnnotation == dict or - (hasattr(fieldAnnotation, "__origin__") and - fieldAnnotation.__origin__ is dict)): - record[fieldName] = {} + fields = getModelFields(model_class) + modelFields = model_class.model_fields + for record in records: + parseRecordFields(record, fields, f"table {table}") + for fieldName, fieldType in fields.items(): + if fieldType == "JSONB" and fieldName in record and record[fieldName] is None: + fieldInfo = modelFields.get(fieldName) + if fieldInfo: + fieldAnnotation = fieldInfo.annotation + if (fieldAnnotation == list or + (hasattr(fieldAnnotation, "__origin__") and + fieldAnnotation.__origin__ is list)): + record[fieldName] = [] + elif (fieldAnnotation == dict or + (hasattr(fieldAnnotation, "__origin__") and + fieldAnnotation.__origin__ is dict)): + record[fieldName] = {} - if fieldFilter and isinstance(fieldFilter, list): - records = [{f: r[f] for f in fieldFilter if f in r} for r in records] + if fieldFilter and isinstance(fieldFilter, list): + records = [{f: r[f] for f in fieldFilter if f in r} for r in records] from modules.routes.routeHelpers import enrichRowsWithFkLabels enrichRowsWithFkLabels(records, model_class) @@ -1320,7 +1523,6 @@ class DatabaseConnector: return {"items": records, "totalItems": totalItems, "totalPages": totalPages} except Exception as e: logger.error(f"Error in getRecordsetPaginated for table {table}: {e}") - _rollbackQuietly(getattr(self, "connection", None)) raise DatabaseQueryError(table, str(e), original=e) from e def getDistinctColumnValues( @@ -1365,25 +1567,24 @@ class DatabaseConnector: else: sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}" WHERE {nonNullCond} ORDER BY val' - with self.connection.cursor() as cursor: - cursor.execute(sql, values) - result: List[Optional[str]] = [row["val"] for row in cursor.fetchall()] + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute(sql, values) + result: List[Optional[str]] = [row["val"] for row in cursor.fetchall()] - if includeEmpty: - emptyCond = f'"{column}" IS NULL OR "{column}"::TEXT = \'\'' - if where_clause: - emptySql = f'SELECT 1 FROM "{table}"{where_clause} AND ({emptyCond}) LIMIT 1' - else: - emptySql = f'SELECT 1 FROM "{table}" WHERE ({emptyCond}) LIMIT 1' - with self.connection.cursor() as cursor: - cursor.execute(emptySql, values) - if cursor.fetchone(): - result.append(None) + if includeEmpty: + emptyCond = f'"{column}" IS NULL OR "{column}"::TEXT = \'\'' + if where_clause: + emptySql = f'SELECT 1 FROM "{table}"{where_clause} AND ({emptyCond}) LIMIT 1' + else: + emptySql = f'SELECT 1 FROM "{table}" WHERE ({emptyCond}) LIMIT 1' + cursor.execute(emptySql, values) + if cursor.fetchone(): + result.append(None) return result except Exception as e: logger.error(f"Error in getDistinctColumnValues for {table}.{column}: {e}") - _rollbackQuietly(getattr(self, "connection", None)) raise DatabaseQueryError(table, str(e), original=e) from e def recordCreate( @@ -1463,33 +1664,33 @@ class DatabaseConnector: if not self._ensureTableExists(model_class): return False - with self.connection.cursor() as cursor: - # Check if record exists - cursor.execute( - f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,) - ) - if not cursor.fetchone(): - return False + # `getInitialId` opens its own borrow; do it BEFORE we acquire a + # connection ourselves so we don't pin two slots concurrently. + initialId = self.getInitialId(model_class) - # Check if it's an initial record - initialId = self.getInitialId(model_class) - if initialId is not None and initialId == recordId: - self._removeInitialId(table) - logger.info( - f"Initial ID {recordId} for table {table} has been removed from the system table" + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute( + f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,) ) + if not cursor.fetchone(): + return False - # Delete the record - cursor.execute(f'DELETE FROM "{table}" WHERE "id" = %s', (recordId,)) + if initialId is not None and initialId == recordId: + # `_removeInitialId` borrows its own conn — done outside + # this block on purpose to avoid nested borrows. + pass + cursor.execute(f'DELETE FROM "{table}" WHERE "id" = %s', (recordId,)) - # No cache to update - database handles consistency - - self.connection.commit() + if initialId is not None and initialId == recordId: + self._removeInitialId(table) + logger.info( + f"Initial ID {recordId} for table {table} has been removed from the system table" + ) return True except Exception as e: logger.error(f"Error deleting record {recordId} from table {table}: {e}") - self.connection.rollback() return False def recordCreateBulk( @@ -1559,16 +1760,11 @@ class DatabaseConnector: ) try: - self._ensure_connection() - with self.connection.cursor() as cursor: - psycopg2.extras.execute_values(cursor, sql, rows, page_size=500) - self.connection.commit() + with self.borrowConn() as conn: + with conn.cursor() as cursor: + psycopg2.extras.execute_values(cursor, sql, rows, page_size=500) except Exception as e: logger.error(f"Bulk insert into {table} failed (n={len(rows)}): {e}") - try: - self.connection.rollback() - except Exception: - pass raise if self.getInitialId(model_class) is None and normalised: @@ -1649,26 +1845,21 @@ class DatabaseConnector: initialId = self.getInitialId(model_class) try: - self._ensure_connection() - with self.connection.cursor() as cursor: - if initialId is not None: - cursor.execute( - f'SELECT 1 FROM "{table}" WHERE "id" = %s AND ' + whereSql, - [initialId, *params], - ) - initialIsAffected = cursor.fetchone() is not None - else: - initialIsAffected = False + with self.borrowConn() as conn: + with conn.cursor() as cursor: + if initialId is not None: + cursor.execute( + f'SELECT 1 FROM "{table}" WHERE "id" = %s AND ' + whereSql, + [initialId, *params], + ) + initialIsAffected = cursor.fetchone() is not None + else: + initialIsAffected = False - cursor.execute(f'DELETE FROM "{table}" WHERE ' + whereSql, params) - deleted = cursor.rowcount or 0 - self.connection.commit() + cursor.execute(f'DELETE FROM "{table}" WHERE ' + whereSql, params) + deleted = cursor.rowcount or 0 except Exception as e: logger.error(f"Bulk delete from {table} failed (filter={recordFilter}): {e}") - try: - self.connection.rollback() - except Exception: - pass raise if deleted and initialIsAffected: @@ -1751,39 +1942,30 @@ class DatabaseConnector: ) params = [vectorStr] + whereValues + [vectorStr, limit] - with self.connection.cursor() as cursor: - cursor.execute(query, params) - records = [dict(row) for row in cursor.fetchall()] + with self.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute(query, params) + records = [dict(row) for row in cursor.fetchall()] - fields = getModelFields(modelClass) - for record in records: - parseRecordFields(record, fields, f"semanticSearch {table}") - - return records + fields = getModelFields(modelClass) + for record in records: + parseRecordFields(record, fields, f"semanticSearch {table}") + return records except Exception as e: logger.error(f"Error in semantic search on {table}: {e}") - _rollbackQuietly(getattr(self, "connection", None)) raise DatabaseQueryError(table, str(e), original=e) from e def close(self, forceClose: bool = False): - """Close the database connection. - - Shared cached connectors are intentionally kept open unless forceClose=True. - This prevents accidental shutdown from interface __del__ methods while - other requests are still using the same cached connector instance. + """No-op for backward compatibility. + + Connections are now owned by the `_PoolRegistry` pool and live for the + process lifetime. Pool shutdown happens centrally via `closeAllPools()` + from the FastAPI lifespan hook — never from a connector instance. + Interface `__del__` paths used to call `close()` to release a per- + connector socket; with pooling there is nothing to close here. """ - if self._isCachedShared and not forceClose: - return - if ( - hasattr(self, "connection") - and self.connection - and not self.connection.closed - ): - self.connection.close() + return def __del__(self): - """Cleanup method to close connection.""" - try: - self.close() - except Exception: - pass + """Cleanup hook (intentionally no-op — see `close`).""" + return diff --git a/modules/features/realEstate/interfaceFeatureRealEstate.py b/modules/features/realEstate/interfaceFeatureRealEstate.py index 1fbaf06f..0637d0e9 100644 --- a/modules/features/realEstate/interfaceFeatureRealEstate.py +++ b/modules/features/realEstate/interfaceFeatureRealEstate.py @@ -342,7 +342,7 @@ class RealEstateObjects: # If no exact match, try case-insensitive search via SQL query # This handles cases where the name might have different casing self.db._ensure_connection() - with self.db.connection.cursor() as cursor: + with self.db.borrowCursor() as cursor: cursor.execute( 'SELECT "id" FROM "Gemeinde" WHERE LOWER("label") = LOWER(%s) LIMIT 1', (name,) @@ -375,7 +375,7 @@ class RealEstateObjects: # Try case-insensitive search self.db._ensure_connection() - with self.db.connection.cursor() as cursor: + with self.db.borrowCursor() as cursor: cursor.execute( 'SELECT "id" FROM "Kanton" WHERE LOWER("label") = LOWER(%s) LIMIT 1', (name,) @@ -408,7 +408,7 @@ class RealEstateObjects: # Try case-insensitive search self.db._ensure_connection() - with self.db.connection.cursor() as cursor: + with self.db.borrowCursor() as cursor: cursor.execute( 'SELECT "id" FROM "Land" WHERE LOWER("label") = LOWER(%s) LIMIT 1', (name,) @@ -840,7 +840,7 @@ class RealEstateObjects: # Ensure connection is alive self.db._ensure_connection() - with self.db.connection.cursor() as cursor: + with self.db.borrowCursor() as cursor: # Execute query if parameters: # Use parameterized query for safety diff --git a/modules/interfaces/interfaceDbBilling.py b/modules/interfaces/interfaceDbBilling.py index 25f022af..273583d9 100644 --- a/modules/interfaces/interfaceDbBilling.py +++ b/modules/interfaces/interfaceDbBilling.py @@ -1659,7 +1659,7 @@ class BillingObjects: try: appInterface = getAppInterface(self.currentUser) appInterface.db._ensure_connection() - with appInterface.db.connection.cursor() as cur: + with appInterface.db.borrowCursor() as cur: if appInterface.db._ensureTableExists(UserInDB): cur.execute( 'SELECT "id" FROM "UserInDB" WHERE ' @@ -1780,7 +1780,7 @@ class BillingObjects: try: self.db._ensure_connection() - with self.db.connection.cursor() as cur: + with self.db.borrowCursor() as cur: countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}' cur.execute(countSql, whereValues) totalItems = cur.fetchone()["count"] @@ -1797,10 +1797,7 @@ class BillingObjects: except Exception as e: logger.error(f"_searchTransactionsPaginated SQL error: {e}", exc_info=True) - try: - self.db.connection.rollback() - except Exception: - pass + # Rollback is handled by `borrowCursor()` context manager on exit. return {"items": [], "totalItems": 0, "totalPages": 0} def _buildScopeFilter( @@ -1872,7 +1869,7 @@ class BillingObjects: result: Dict[str, Any] = {} - with self.db.connection.cursor() as cur: + with self.db.borrowCursor() as cur: # 1) Totals cur.execute( f'SELECT COALESCE(SUM("amount"), 0) AS total, COUNT(*) AS cnt FROM "{table}"{whereClause}', @@ -1947,17 +1944,12 @@ class BillingObjects: }) result["timeSeries"] = timeSeries - self.db.connection.commit() - + # Commit/rollback are handled by `borrowCursor()` context manager. result["_allAccounts"] = allAccounts return result except Exception as e: logger.error(f"Error in getTransactionStatisticsAggregated: {e}", exc_info=True) - try: - self.db.connection.rollback() - except Exception: - pass return self._emptyStats() @staticmethod diff --git a/modules/interfaces/interfaceDbKnowledge.py b/modules/interfaces/interfaceDbKnowledge.py index 31a5af61..d7a445bd 100644 --- a/modules/interfaces/interfaceDbKnowledge.py +++ b/modules/interfaces/interfaceDbKnowledge.py @@ -228,6 +228,22 @@ class KnowledgeObjects: """Get all ContentChunks for a file.""" return self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId}) + def countChunksByFileIds(self, fileIds: List[str]) -> Dict[str, int]: + """Return a {fileId: chunkCount} mapping for the given file IDs. + + One aggregate query instead of N round trips. Used by RAG inventory + to display real chunk counts per DataSource without loading the + embedding vectors. Missing file IDs map to 0 in the caller's logic. + """ + if not fileIds: + return {} + if not self.db._ensureTableExists(ContentChunk): + return {} + sql = 'SELECT "fileId", COUNT(*) AS cnt FROM "ContentChunk" WHERE "fileId" = ANY(%s) GROUP BY "fileId"' + with self.db.borrowCursor() as cursor: + cursor.execute(sql, (list(fileIds),)) + return {row["fileId"]: int(row["cnt"]) for row in cursor.fetchall()} + def deleteContentChunks(self, fileId: str) -> int: """Delete all ContentChunks for a file. Returns count of deleted chunks.""" chunks = self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId}) diff --git a/modules/interfaces/interfaceDbManagement.py b/modules/interfaces/interfaceDbManagement.py index 6a3c27b5..4dc8a206 100644 --- a/modules/interfaces/interfaceDbManagement.py +++ b/modules/interfaces/interfaceDbManagement.py @@ -1221,22 +1221,17 @@ class ComponentObjects: for item in fileRows ] - # Single transaction: delete FileData, FileItem, then FileFolder (children first) - self.db._ensure_connection() - try: - with self.db.connection.cursor() as cursor: - if fileIds: - cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (fileIds,)) - cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (fileIds,)) - orderedIds = list(folderIds) - orderedIds.remove(folderId) - orderedIds.append(folderId) - if orderedIds: - cursor.execute('DELETE FROM "FileFolder" WHERE "id" = ANY(%s)', (orderedIds,)) - self.db.connection.commit() - except Exception: - self.db.connection.rollback() - raise + # Single transaction: delete FileData, FileItem, then FileFolder (children first). + # Commit/rollback are handled by `borrowCursor()` on exit. + with self.db.borrowCursor() as cursor: + if fileIds: + cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (fileIds,)) + cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (fileIds,)) + orderedIds = list(folderIds) + orderedIds.remove(folderId) + orderedIds.append(folderId) + if orderedIds: + cursor.execute('DELETE FROM "FileFolder" WHERE "id" = ANY(%s)', (orderedIds,)) return {"deletedFolders": len(folderIds), "deletedFiles": len(fileIds)} @@ -1507,7 +1502,7 @@ class ComponentObjects: try: self.db._ensure_connection() - with self.db.connection.cursor() as cursor: + with self.db.borrowCursor() as cursor: cursor.execute( 'SELECT "id", "sysCreatedBy" FROM "FileItem" WHERE "id" = ANY(%s)', (uniqueIds,), @@ -1526,11 +1521,10 @@ class ComponentObjects: cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (accessibleIds,)) deletedFiles = cursor.rowcount - self.db.connection.commit() + # Commit/rollback are handled by `borrowCursor()` context manager. return {"deletedFiles": deletedFiles} except Exception as e: logger.error(f"Error deleting files in batch: {e}") - self.db.connection.rollback() raise FileDeletionError(f"Error deleting files in batch: {str(e)}") def _ensureFeatureInstanceGroup(self, featureInstanceId: str, contextKey: str = "files/list") -> Optional[str]: diff --git a/modules/interfaces/interfaceRbac.py b/modules/interfaces/interfaceRbac.py index e41485e0..948609ef 100644 --- a/modules/interfaces/interfaceRbac.py +++ b/modules/interfaces/interfaceRbac.py @@ -374,7 +374,7 @@ def getRecordsetWithRBAC( query = f'SELECT * FROM "{table}"{whereClause}{orderByClause}{limitClause}' - with connector.connection.cursor() as cursor: + with connector.borrowCursor() as cursor: cursor.execute(query, whereValues) records = [dict(row) for row in cursor.fetchall()] @@ -561,7 +561,7 @@ def getRecordsetPaginatedWithRBAC( offset = (pagination.page - 1) * pagination.pageSize limitClause = f" LIMIT {pagination.pageSize} OFFSET {offset}" - with connector.connection.cursor() as cursor: + with connector.borrowCursor() as cursor: countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}' cursor.execute(countSql, countValues) totalItems = cursor.fetchone()["count"] @@ -709,7 +709,7 @@ def getDistinctColumnValuesWithRBAC( sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{nonNullWhere} ORDER BY val' - with connector.connection.cursor() as cursor: + with connector.borrowCursor() as cursor: cursor.execute(sql, whereValues) result = [row["val"] for row in cursor.fetchall()] @@ -719,7 +719,7 @@ def getDistinctColumnValuesWithRBAC( emptySql = f'SELECT 1 FROM "{table}"{whereClause} AND {emptyCond} LIMIT 1' else: emptySql = f'SELECT 1 FROM "{table}" WHERE {emptyCond} LIMIT 1' - with connector.connection.cursor() as cursor: + with connector.borrowCursor() as cursor: cursor.execute(emptySql, whereValues) if cursor.fetchone(): result.append(None) @@ -967,7 +967,7 @@ def buildRbacWhereClause( # Multi-Tenant Design: Users do NOT have mandateId - they are linked via UserMandate if table == "UserInDB": try: - with connector.connection.cursor() as cursor: + with connector.borrowCursor() as cursor: # Get all user IDs that are members of the current mandate cursor.execute( 'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true', @@ -994,7 +994,7 @@ def buildRbacWhereClause( # For UserConnection: Filter via UserMandate junction table elif table == "UserConnection": try: - with connector.connection.cursor() as cursor: + with connector.borrowCursor() as cursor: # Get all user IDs that are members of the current mandate cursor.execute( 'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true', diff --git a/modules/routes/routeHelpers.py b/modules/routes/routeHelpers.py index f1d88e31..bb1386af 100644 --- a/modules/routes/routeHelpers.py +++ b/modules/routes/routeHelpers.py @@ -305,7 +305,7 @@ def handleIdsMode( sql = f'SELECT "{idField}"::TEXT AS val FROM "{table}"{where_clause} ORDER BY "{idField}"' - with db.connection.cursor() as cursor: + with db.borrowCursor() as cursor: cursor.execute(sql, values) return JSONResponse(content=[row["val"] for row in cursor.fetchall()]) except Exception as e: diff --git a/modules/routes/routeRagInventory.py b/modules/routes/routeRagInventory.py index 074b5b85..7c426d77 100644 --- a/modules/routes/routeRagInventory.py +++ b/modules/routes/routeRagInventory.py @@ -25,6 +25,18 @@ router = APIRouter( def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> List[Dict[str, Any]]: + """Build per-connection RAG inventory rows. + + Each DataSource row exposes BOTH numbers because they mean different things: + * `fileCount` — distinct files indexed (== `FileContentIndex` rows) + * `chunkCount` — embedding-sized text fragments (== `ContentChunk` rows, + max `DEFAULT_CHUNK_TOKENS` tokens each, what the vector retrieval + actually hits) + + A single PDF typically yields 1 file × 5–100 chunks; legacy UI labelled + `len(FileContentIndex)` as "chunks" which was off by 1–2 orders of + magnitude and misleading. + """ from modules.datamodels.datamodelDataSource import DataSource from modules.datamodels.datamodelKnowledge import FileContentIndex @@ -34,19 +46,35 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) connIndexRows = knowledgeIf.db.getRecordset(FileContentIndex, recordFilter={"connectionId": connectionId}) - connChunkTotal = len(connIndexRows) + connFileTotal = len(connIndexRows) + # Map fileId → real chunk count via 1 aggregate query (cheap even for + # connections with thousands of files; we never load the vector body). + fileIds = [ + (idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", "")) + for idx in connIndexRows + ] + fileIds = [fid for fid in fileIds if fid] + chunkCountByFile = knowledgeIf.countChunksByFileIds(fileIds) if fileIds else {} + connChunkTotal = sum(chunkCountByFile.values()) + + filesByDs: Dict[str, int] = {} chunksByDs: Dict[str, int] = {} - unassigned = 0 + unassignedFiles = 0 + unassignedChunks = 0 for idx in connIndexRows: + fileId = idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", "") + chunkCnt = chunkCountByFile.get(fileId, 0) struct = (idx.get("structure") if isinstance(idx, dict) else getattr(idx, "structure", None)) or {} ingestion = struct.get("_ingestion") or {} if isinstance(struct, dict) else {} prov = ingestion.get("provenance") or {} if isinstance(ingestion, dict) else {} dsIdRef = prov.get("dataSourceId", "") if isinstance(prov, dict) else "" if dsIdRef: - chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + 1 + filesByDs[dsIdRef] = filesByDs.get(dsIdRef, 0) + 1 + chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + chunkCnt else: - unassigned += 1 + unassignedFiles += 1 + unassignedChunks += chunkCnt seen: Dict[str, bool] = {} dsItems = [] @@ -64,14 +92,19 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L "ragIndexEnabled": ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False), "neutralize": ds.get("neutralize") if isinstance(ds, dict) else getattr(ds, "neutralize", False), "lastIndexed": ds.get("lastIndexed") if isinstance(ds, dict) else getattr(ds, "lastIndexed", None), + "fileCount": filesByDs.get(dsId, 0), "chunkCount": chunksByDs.get(dsId, 0), }) - if unassigned > 0 and len(dsItems) > 0: - perDs = unassigned // len(dsItems) - remainder = unassigned % len(dsItems) + # Spread orphan files (provenance lost) evenly so totals match. + if unassignedFiles > 0 and len(dsItems) > 0: + perFile = unassignedFiles // len(dsItems) + remFile = unassignedFiles % len(dsItems) + perChunk = unassignedChunks // len(dsItems) + remChunk = unassignedChunks % len(dsItems) for i, item in enumerate(dsItems): - item["chunkCount"] += perDs + (1 if i < remainder else 0) + item["fileCount"] += perFile + (1 if i < remFile else 0) + item["chunkCount"] += perChunk + (1 if i < remChunk else 0) # Pull a wider window than the previous 5 so the "last successful # sync" is found even if a connection has many recent jobs queued. @@ -102,6 +135,12 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L "skippedPolicy": result.get("skippedPolicy", 0), "failed": result.get("failed", 0), "durationMs": result.get("durationMs", 0), + # Surface limit-stop reason so the UI can warn the user + # that the index is provably incomplete (and which budget + # to raise). None means the walker finished naturally. + "stoppedAtLimit": result.get("stoppedAtLimit"), + "limits": result.get("limits") or {}, + "bytesProcessed": result.get("bytesProcessed", 0), } if lastError and lastSuccess: break @@ -113,6 +152,7 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L "knowledgeIngestionEnabled": getattr(conn, "knowledgeIngestionEnabled", False), "preferences": getattr(conn, "knowledgePreferences", None) or {}, "dataSources": dsItems, + "totalFiles": connFileTotal, "totalChunks": connChunkTotal, "runningJobs": runningJobs, "lastError": lastError, @@ -139,8 +179,9 @@ def _getInventoryMe( items = _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) totalChunks = sum(c.get("totalChunks", 0) for c in items) + totalFiles = sum(c.get("totalFiles", 0) for c in items) - return {"connections": items, "totals": {"chunks": totalChunks}} + return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}} except Exception as e: logger.error("Error in RAG inventory /me: %s", e, exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @@ -170,9 +211,10 @@ def _getInventoryMandate( items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService) totalChunks = sum(c.get("totalChunks", 0) for c in items) + totalFiles = sum(c.get("totalFiles", 0) for c in items) totalBytes = aggregateMandateRagTotalBytes(mandateId) - return {"connections": items, "totals": {"chunks": totalChunks, "bytes": totalBytes}} + return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks, "bytes": totalBytes}} except HTTPException: raise except Exception as e: @@ -202,8 +244,9 @@ def _getInventoryPlatform( items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService) totalChunks = sum(c.get("totalChunks", 0) for c in items) + totalFiles = sum(c.get("totalFiles", 0) for c in items) - return {"connections": items, "totals": {"chunks": totalChunks}} + return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}} except HTTPException: raise except Exception as e: diff --git a/modules/routes/routeWorkflowDashboard.py b/modules/routes/routeWorkflowDashboard.py index d83ce1b2..85b372a1 100644 --- a/modules/routes/routeWorkflowDashboard.py +++ b/modules/routes/routeWorkflowDashboard.py @@ -227,7 +227,7 @@ WHERE "workflowId" = ANY(%s) GROUP BY "workflowId" """ out: dict = {} - with db.connection.cursor() as cursor: + with db.borrowCursor() as cursor: cursor.execute(sql, (workflowIds,)) for row in cursor.fetchall(): r = dict(row) @@ -480,7 +480,7 @@ def _getWorkflowsJoinedPaginated( dataSql = f"SELECT w.*, rs.\"lastStartedAt\", rs.\"runCount\", rs.\"activeRunId\" FROM {fromSql}{whereClause}{orderClause}{limitClause}" db._ensure_connection() - with db.connection.cursor() as cursor: + with db.borrowCursor() as cursor: cursor.execute(countSql, countValues) totalItems = int(cursor.fetchone()["cnt"]) diff --git a/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py b/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py index 4fbea490..bdb3d23b 100644 --- a/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py +++ b/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py @@ -25,15 +25,14 @@ _CACHE_TTL_SECONDS = 300 def _getOrCreateFeatureDbConnector(featureDbName: str, userId: str): - """Reuse a pooled DB connector for the given feature database.""" + """Reuse a pooled DB connector for the given feature database. + + The underlying psycopg2 connections live in the central pool + (`_PoolRegistry`) and are recreated on demand if they go stale; we just + need to keep the lightweight connector wrapper around. + """ if featureDbName in _featureDbConnPool: - conn = _featureDbConnPool[featureDbName] - try: - if conn.connection and not conn.connection.closed: - return conn - except Exception as e: - logger.warning(f"Feature DB connection check failed for {featureDbName}: {e}") - _featureDbConnPool.pop(featureDbName, None) + return _featureDbConnPool[featureDbName] from modules.connectors.connectorDbPostgre import DatabaseConnector from modules.shared.configuration import APP_CONFIG diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py index 8bfa2628..959e42c9 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py @@ -68,6 +68,9 @@ class ClickupBootstrapResult: workspaces: int = 0 lists: int = 0 errors: List[str] = field(default_factory=list) + # First budget exhausted: "maxTasks" | "maxWorkspaces" | "maxListsPerWorkspace" | None. + # Drives the same UI banner as the file-walker bootstraps. + stoppedAtLimit: Optional[str] = None def _syntheticTaskId(connectionId: str, taskId: str) -> str: @@ -225,6 +228,7 @@ async def bootstrapClickup( cancelled = False for ds in dataSources: if result.indexed + result.skippedDuplicate >= limits.maxTasks: + _recordLimitStop(result, "maxTasks", "dataSource", limits) break if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled(): cancelled = True @@ -243,8 +247,11 @@ async def bootstrapClickup( clickupScope=limits.clickupScope, ) + if len(teams) > dsLimits.maxWorkspaces: + _recordLimitStop(result, "maxWorkspaces", "teams", dsLimits, hard=False) for team in teams[:dsLimits.maxWorkspaces]: if result.indexed + result.skippedDuplicate >= dsLimits.maxTasks: + _recordLimitStop(result, "maxTasks", f"team={team.get('id','')}", dsLimits) break teamId = str(team.get("id", "") or "") if not teamId: @@ -351,6 +358,7 @@ async def _walkTeam( for lst in listsCollected: if result.indexed + result.skippedDuplicate >= limits.maxTasks: + _recordLimitStop(result, "maxTasks", f"team={teamId}", limits) return if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled(): return @@ -407,6 +415,7 @@ async def _walkList( for task in tasks: if result.indexed + result.skippedDuplicate >= limits.maxTasks: + _recordLimitStop(result, "maxTasks", f"list={listId}", limits) return if not _isRecent(task.get("date_updated"), limits.maxAgeDays): result.skippedPolicy += 1 @@ -529,13 +538,37 @@ async def _ingestTask( ) +def _recordLimitStop( + result: ClickupBootstrapResult, + limitName: str, + where: str, + limits: ClickupBootstrapLimits, + *, + hard: bool = True, +) -> None: + """See subConnectorSyncSharepoint._recordLimitStop for semantics.""" + if hard or result.stoppedAtLimit is None: + result.stoppedAtLimit = limitName + budgetMap = { + "maxTasks": limits.maxTasks, + "maxWorkspaces": limits.maxWorkspaces, + "maxListsPerWorkspace": limits.maxListsPerWorkspace, + } + logger.warning( + "clickup walker hit %s=%s at %s — partial index (indexed=%d, skippedDup=%d).", + limitName, budgetMap.get(limitName), where, + result.indexed, result.skippedDuplicate, + ) + + def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs: float) -> Dict[str, Any]: durationMs = int((time.time() - startMs) * 1000) logger.info( - "ingestion.connection.bootstrap.done part=clickup connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d workspaces=%d lists=%d durationMs=%d", + "ingestion.connection.bootstrap.done part=clickup connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d workspaces=%d lists=%d durationMs=%d stoppedAtLimit=%s", connectionId, result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed, result.workspaces, result.lists, durationMs, + result.stoppedAtLimit or "none", extra={ "event": "ingestion.connection.bootstrap.done", "part": "clickup", @@ -547,6 +580,7 @@ def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs: "workspaces": result.workspaces, "lists": result.lists, "durationMs": durationMs, + "stoppedAtLimit": result.stoppedAtLimit, }, ) return { @@ -559,4 +593,11 @@ def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs: "lists": result.lists, "durationMs": durationMs, "errors": result.errors[:20], + "stoppedAtLimit": result.stoppedAtLimit, + "limits": { + "maxTasks": MAX_TASKS_DEFAULT, + "maxWorkspaces": MAX_WORKSPACES_DEFAULT, + "maxListsPerWorkspace": MAX_LISTS_PER_WORKSPACE_DEFAULT, + "maxAgeDays": MAX_AGE_DAYS_DEFAULT, + }, } diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py index 5dd1bd8b..e27abacb 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py @@ -61,6 +61,8 @@ class GdriveBootstrapResult: failed: int = 0 bytesProcessed: int = 0 errors: List[str] = field(default_factory=list) + # See SharepointBootstrapResult.stoppedAtLimit — same semantics. + stoppedAtLimit: Optional[str] = None def _syntheticFileId(connectionId: str, externalItemId: str) -> str: @@ -265,8 +267,10 @@ async def _walkFolder( for entry in entries: if result.indexed + result.skippedDuplicate >= limits.maxItems: + _recordLimitStop(result, "maxItems", folderPath, limits) return if result.bytesProcessed >= limits.maxBytes: + _recordLimitStop(result, "maxBytes", folderPath, limits) return if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled(): return @@ -276,6 +280,9 @@ async def _walkFolder( mimeType = getattr(entry, "mimeType", None) or metadata.get("mimeType") if getattr(entry, "isFolder", False) or mimeType == FOLDER_MIME: + if depth + 1 > limits.maxDepth: + _recordLimitStop(result, "maxDepth", entryPath, limits, hard=False) + continue await _walkFolder( adapter=adapter, knowledgeService=knowledgeService, @@ -298,6 +305,7 @@ async def _walkFolder( continue size = int(getattr(entry, "size", 0) or 0) if size and size > limits.maxFileSize: + _recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False) result.skippedPolicy += 1 continue modifiedTime = metadata.get("modifiedTime") @@ -470,13 +478,38 @@ async def _ingestOne( await asyncio.sleep(0) +def _recordLimitStop( + result: GdriveBootstrapResult, + limitName: str, + where: str, + limits: GdriveBootstrapLimits, + *, + hard: bool = True, +) -> None: + """See subConnectorSyncSharepoint._recordLimitStop for semantics.""" + if hard or result.stoppedAtLimit is None: + result.stoppedAtLimit = limitName + budgetMap = { + "maxItems": limits.maxItems, + "maxBytes": limits.maxBytes, + "maxDepth": limits.maxDepth, + "maxFileSize": limits.maxFileSize, + } + logger.warning( + "gdrive walker hit %s=%s at %s — partial index (indexed=%d, bytesProcessed=%d).", + limitName, budgetMap.get(limitName), where, + result.indexed, result.bytesProcessed, + ) + + def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: float) -> Dict[str, Any]: durationMs = int((time.time() - startMs) * 1000) logger.info( - "ingestion.connection.bootstrap.done part=gdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d bytes=%d durationMs=%d", + "ingestion.connection.bootstrap.done part=gdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d bytes=%d durationMs=%d stoppedAtLimit=%s", connectionId, result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed, result.bytesProcessed, durationMs, + result.stoppedAtLimit or "none", extra={ "event": "ingestion.connection.bootstrap.done", "part": "gdrive", @@ -487,6 +520,7 @@ def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: f "failed": result.failed, "bytes": result.bytesProcessed, "durationMs": durationMs, + "stoppedAtLimit": result.stoppedAtLimit, }, ) return { @@ -498,4 +532,11 @@ def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: f "bytesProcessed": result.bytesProcessed, "durationMs": durationMs, "errors": result.errors[:20], + "stoppedAtLimit": result.stoppedAtLimit, + "limits": { + "maxItems": MAX_ITEMS_DEFAULT, + "maxBytes": MAX_BYTES_DEFAULT, + "maxFileSize": MAX_FILE_SIZE_DEFAULT, + "maxDepth": MAX_DEPTH_DEFAULT, + }, } diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py index e656abe8..dcf19e39 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py @@ -53,6 +53,8 @@ class KdriveBootstrapResult: failed: int = 0 bytesProcessed: int = 0 errors: List[str] = field(default_factory=list) + # See SharepointBootstrapResult.stoppedAtLimit — same semantics. + stoppedAtLimit: Optional[str] = None def _syntheticFileId(connectionId: str, externalItemId: str) -> str: @@ -232,14 +234,19 @@ async def _walkFolder( for entry in entries: if result.indexed + result.skippedDuplicate >= limits.maxItems: + _recordLimitStop(result, "maxItems", folderPath, limits) return if result.bytesProcessed >= limits.maxBytes: + _recordLimitStop(result, "maxBytes", folderPath, limits) return if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled(): return entryPath = getattr(entry, "path", "") or "" if getattr(entry, "isFolder", False): + if depth + 1 > limits.maxDepth: + _recordLimitStop(result, "maxDepth", entryPath, limits, hard=False) + continue await _walkFolder( adapter=adapter, knowledgeService=knowledgeService, @@ -262,6 +269,7 @@ async def _walkFolder( continue size = int(getattr(entry, "size", 0) or 0) if size and size > limits.maxFileSize: + _recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False) result.skippedPolicy += 1 continue @@ -415,17 +423,42 @@ async def _ingestOne( await asyncio.sleep(0) +def _recordLimitStop( + result: KdriveBootstrapResult, + limitName: str, + where: str, + limits: KdriveBootstrapLimits, + *, + hard: bool = True, +) -> None: + """See subConnectorSyncSharepoint._recordLimitStop for semantics.""" + if hard or result.stoppedAtLimit is None: + result.stoppedAtLimit = limitName + budgetMap = { + "maxItems": limits.maxItems, + "maxBytes": limits.maxBytes, + "maxDepth": limits.maxDepth, + "maxFileSize": limits.maxFileSize, + } + logger.warning( + "kdrive walker hit %s=%s at %s — partial index (indexed=%d, bytesProcessed=%d).", + limitName, budgetMap.get(limitName), where, + result.indexed, result.bytesProcessed, + ) + + def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: float) -> Dict[str, Any]: durationMs = int((time.time() - startMs) * 1000) logger.info( - "ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d", + "ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d stoppedAtLimit=%s", connectionId, result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed, - durationMs, + durationMs, result.stoppedAtLimit or "none", extra={"event": "ingestion.connection.bootstrap.done", "part": "kdrive", "connectionId": connectionId, "indexed": result.indexed, "skippedDup": result.skippedDuplicate, "skippedPolicy": result.skippedPolicy, - "failed": result.failed, "durationMs": durationMs}, + "failed": result.failed, "durationMs": durationMs, + "stoppedAtLimit": result.stoppedAtLimit}, ) return { "connectionId": result.connectionId, @@ -436,4 +469,11 @@ def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: f "bytesProcessed": result.bytesProcessed, "durationMs": durationMs, "errors": result.errors[:20], + "stoppedAtLimit": result.stoppedAtLimit, + "limits": { + "maxItems": MAX_ITEMS_DEFAULT, + "maxBytes": MAX_BYTES_DEFAULT, + "maxFileSize": MAX_FILE_SIZE_DEFAULT, + "maxDepth": MAX_DEPTH_DEFAULT, + }, } diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py index 892e41ba..e06fd36b 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py @@ -59,6 +59,10 @@ class SharepointBootstrapResult: failed: int = 0 bytesProcessed: int = 0 errors: List[str] = field(default_factory=list) + # First budget that hit zero; None means the walk completed naturally. + # Surfaces in the bootstrap result so the RAG inventory UI can warn the + # user that the corpus is incomplete and tell them which knob to turn. + stoppedAtLimit: Optional[str] = None # "maxItems" | "maxBytes" | "maxDepth" | "maxFileSize" | None def _syntheticFileId(connectionId: str, externalItemId: str) -> str: @@ -259,14 +263,22 @@ async def _walkFolder( for entry in entries: if result.indexed + result.skippedDuplicate >= limits.maxItems: + _recordLimitStop(result, "maxItems", folderPath, limits) return if result.bytesProcessed >= limits.maxBytes: + _recordLimitStop(result, "maxBytes", folderPath, limits) return if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled(): return entryPath = getattr(entry, "path", "") or "" if getattr(entry, "isFolder", False): + if depth + 1 > limits.maxDepth: + # We stop descending here but keep walking siblings. + # Record once per bootstrap so the UI shows "maxDepth" even + # if other budgets aren't exhausted yet. + _recordLimitStop(result, "maxDepth", entryPath, limits, hard=False) + continue await _walkFolder( adapter=adapter, knowledgeService=knowledgeService, @@ -289,6 +301,7 @@ async def _walkFolder( continue size = int(getattr(entry, "size", 0) or 0) if size and size > limits.maxFileSize: + _recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False) result.skippedPolicy += 1 continue @@ -443,13 +456,44 @@ async def _ingestOne( await asyncio.sleep(0) +def _recordLimitStop( + result: SharepointBootstrapResult, + limitName: str, + where: str, + limits: SharepointBootstrapLimits, + *, + hard: bool = True, +) -> None: + """Mark the FIRST limit that bit. Soft hits (per-file maxFileSize, per-folder + maxDepth) only record when no hard limit has yet stopped the run, so the UI + surfaces the most important reason. + + Hard limits (maxItems / maxBytes) ALWAYS overwrite a previously recorded + soft limit — once a hard cap is hit, the corpus is provably incomplete. + """ + if hard or result.stoppedAtLimit is None: + result.stoppedAtLimit = limitName + budgetMap = { + "maxItems": limits.maxItems, + "maxBytes": limits.maxBytes, + "maxDepth": limits.maxDepth, + "maxFileSize": limits.maxFileSize, + } + logger.warning( + "sharepoint walker hit %s=%s at %s — partial index " + "(indexed=%d, bytesProcessed=%d). Raise the limit or split the data source.", + limitName, budgetMap.get(limitName), where, + result.indexed, result.bytesProcessed, + ) + + def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startMs: float) -> Dict[str, Any]: durationMs = int((time.time() - startMs) * 1000) logger.info( - "ingestion.connection.bootstrap.done part=sharepoint connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d", + "ingestion.connection.bootstrap.done part=sharepoint connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d stoppedAtLimit=%s", connectionId, result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed, - durationMs, + durationMs, result.stoppedAtLimit or "none", extra={ "event": "ingestion.connection.bootstrap.done", "part": "sharepoint", @@ -459,6 +503,7 @@ def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startM "skippedPolicy": result.skippedPolicy, "failed": result.failed, "durationMs": durationMs, + "stoppedAtLimit": result.stoppedAtLimit, }, ) return { @@ -470,4 +515,11 @@ def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startM "bytesProcessed": result.bytesProcessed, "durationMs": durationMs, "errors": result.errors[:20], + "stoppedAtLimit": result.stoppedAtLimit, + "limits": { + "maxItems": MAX_ITEMS_DEFAULT, + "maxBytes": MAX_BYTES_DEFAULT, + "maxFileSize": MAX_FILE_SIZE_DEFAULT, + "maxDepth": MAX_DEPTH_DEFAULT, + }, } diff --git a/modules/shared/configuration.py b/modules/shared/configuration.py index 721ce448..15646962 100644 --- a/modules/shared/configuration.py +++ b/modules/shared/configuration.py @@ -12,7 +12,8 @@ import logging import json import base64 import time -from typing import Any, Dict, Optional +import threading +from typing import Any, Dict, Optional, Tuple from pathlib import Path from cryptography.fernet import Fernet from cryptography.hazmat.primitives import hashes @@ -286,6 +287,16 @@ def handleSecretJson(value: str, userId: str = "system", keyName: str = "unknown # Structure: {user_id: {key_name: [timestamps]}} _decryption_attempts = {} +# Process-wide plaintext cache for decrypted secrets. +# Key: the encrypted ciphertext (which already includes env prefix). +# Value: (expiresAtMonotonic, plaintext). +# TTL is short enough that key rotation propagates quickly, long enough that +# hot DB-init paths (every API call building a connector) don't blow the +# decryption rate limit. 60s is a deliberate compromise. +_DECRYPTION_CACHE_TTL_S = 60.0 +_decryption_cache: Dict[str, Tuple[float, str]] = {} +_decryption_cache_lock = threading.Lock() + def _getMasterKey(envType: str = None) -> bytes: """ Get the master key for the specified environment. @@ -486,25 +497,43 @@ def encryptValue(value: str, envType: str = None, userId: str = "system", keyNam def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "unknown") -> str: """ Decrypt a value using the master key for the current environment. - + + A short-lived plaintext cache (TTL `_DECRYPTION_CACHE_TTL_S`) is consulted + first. The 10/sec rate-limit on cache misses still protects against + brute-force attacks; cache HITS bypass it because they are not actual + cryptographic operations — they just return the result of an earlier + successful decrypt. Without this cache, hot paths like + `mainBackgroundJobService._getDb()` (called per RAG inventory poll AND + per walker DB call) trigger the rate limit and surface as + "Decryption rate limit exceeded for user 'system' key 'DB_PASSWORD_SECRET'" + ERRORs in the RAG inventory UI route. + Args: encryptedValue: The encrypted value with prefix userId: The user ID making the request (default: "system") keyName: The name of the key being decrypted (default: "unknown") - + Returns: str: The decrypted plain text value - + Raises: ValueError: If decryption fails """ if not _isEncryptedValue(encryptedValue): return encryptedValue # Return as-is if not encrypted - - # Check rate limiting (10 per second per user per key) + + # Cache lookup BEFORE the rate-limit check: a cache hit is not a new + # cryptographic operation and must not be throttled. + now = time.monotonic() + with _decryption_cache_lock: + cached = _decryption_cache.get(encryptedValue) + if cached is not None and cached[0] > now: + return cached[1] + + # Cache miss → real decrypt → apply rate limit. if not _checkDecryptionRateLimit(userId, keyName, maxPerSecond=10): raise ValueError(f"Decryption rate limit exceeded for user '{userId}' key '{keyName}' (10/sec)") - + try: # Extract environment type from prefix if encryptedValue.startswith('DEV_ENC:'): @@ -536,7 +565,7 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un encryptedBytes = base64.urlsafe_b64decode(encryptedPart.encode('utf-8')) decryptedBytes = fernet.decrypt(encryptedBytes) decryptedValue = decryptedBytes.decode('utf-8') - + # Log audit event for decryption try: from modules.shared.auditLogger import audit_logger @@ -549,11 +578,25 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un except Exception: # Don't fail if audit logging fails pass - + + # Populate cache so subsequent reads of the same ciphertext don't + # re-decrypt (and don't consume rate-limit budget). + with _decryption_cache_lock: + _decryption_cache[encryptedValue] = ( + time.monotonic() + _DECRYPTION_CACHE_TTL_S, + decryptedValue, + ) + return decryptedValue - + except Exception as e: raise ValueError(f"Decryption failed: {e}") + +def clearDecryptionCache() -> None: + """Drop all cached plaintext secrets. Call after key rotation or in tests.""" + with _decryption_cache_lock: + _decryption_cache.clear() + # Create the global APP_CONFIG instance APP_CONFIG = Configuration() \ No newline at end of file diff --git a/modules/shared/dbMultiTenantOptimizations.py b/modules/shared/dbMultiTenantOptimizations.py index c178c376..9b5a15b4 100644 --- a/modules/shared/dbMultiTenantOptimizations.py +++ b/modules/shared/dbMultiTenantOptimizations.py @@ -33,20 +33,35 @@ def _ensureUamTablesMatchModels(dbConnector) -> None: logger.debug(f"_ensureUamTablesMatchModels: {e}") -def _getConnection(dbConnector): - """Get a connection from the DatabaseConnector. - - Ensures the connection is alive and returns it. - Commits any pending transaction first to avoid blocking. +from contextlib import contextmanager + + +@contextmanager +def _borrowDbConn(dbConnector): + """Borrow a pooled connection from the DatabaseConnector. + + Index/trigger/FK creation traditionally ran with `conn.autocommit = True` + so each CREATE statement is its own transaction (DDL on a managed + connection blocks waiting for COMMIT). This helper preserves that + behaviour on top of the pool: borrow a connection, flip it to autocommit, + yield it, and restore the previous state before returning it to the pool. """ - dbConnector._ensure_connection() - conn = dbConnector.connection - # Commit any pending transaction to avoid blocking - try: - conn.commit() - except Exception: - pass # Ignore if nothing to commit - return conn + with dbConnector.borrowConn() as conn: + try: + previousAutocommit = conn.autocommit + except Exception: + previousAutocommit = False + try: + conn.autocommit = True + except Exception as e: + logger.debug(f"Could not set autocommit on borrowed connection: {e}") + try: + yield conn + finally: + try: + conn.autocommit = previousAutocommit + except Exception: + pass # ============================================================================= @@ -174,73 +189,42 @@ def applyMultiTenantOptimizations(dbConnector, tables: Optional[List[str]] = Non } try: - # Get a connection from the connector - conn = _getConnection(dbConnector) - - # Save and set autocommit state - try: - originalAutocommit = conn.autocommit - except Exception: - originalAutocommit = False - - try: - conn.autocommit = True - except Exception as autoErr: - logger.debug(f"Could not set autocommit: {autoErr}") - try: _ensureUamTablesMatchModels(dbConnector) except Exception as preIdxErr: logger.debug(f"Pre-index table ensure: {preIdxErr}") - - try: + + with _borrowDbConn(dbConnector) as conn: with conn.cursor() as cursor: - # Apply indexes results["indexesCreated"] = _applyIndexes(cursor, tables) - - # Apply foreign keys results["foreignKeysCreated"] = _applyForeignKeys(cursor, tables) - - # Apply immutable triggers results["triggersCreated"] = _applyImmutableTriggers(cursor, tables) - - logger.info( - f"Multi-tenant optimizations applied: " - f"{results['indexesCreated']} indexes, " - f"{results['triggersCreated']} triggers, " - f"{results['foreignKeysCreated']} foreign keys" - ) - finally: - # Restore original autocommit state - try: - conn.autocommit = originalAutocommit - except Exception: - pass - + + logger.info( + f"Multi-tenant optimizations applied: " + f"{results['indexesCreated']} indexes, " + f"{results['triggersCreated']} triggers, " + f"{results['foreignKeysCreated']} foreign keys" + ) + except Exception as e: logger.error(f"Error applying multi-tenant optimizations: {type(e).__name__}: {e}") results["errors"].append(str(e)) - + return results def applyIndexesOnly(dbConnector, tables: Optional[List[str]] = None) -> int: """Apply only indexes (lighter operation, safe for frequent calls).""" try: - conn = _getConnection(dbConnector) - originalAutocommit = conn.autocommit - conn.autocommit = True - try: _ensureUamTablesMatchModels(dbConnector) except Exception as preIdxErr: logger.debug(f"Pre-index table ensure: {preIdxErr}") - - try: + + with _borrowDbConn(dbConnector) as conn: with conn.cursor() as cursor: return _applyIndexes(cursor, tables) - finally: - conn.autocommit = originalAutocommit except Exception as e: logger.error(f"Error applying indexes: {e}") return 0 @@ -514,8 +498,7 @@ def getOptimizationStatus(dbConnector) -> dict: } try: - conn = _getConnection(dbConnector) - with conn.cursor() as cursor: + with _borrowDbConn(dbConnector) as conn, conn.cursor() as cursor: # Check regular indexes for tableName, indexName, _ in _INDEXES: if _tableExists(cursor, tableName): diff --git a/modules/shared/gdprDeletion.py b/modules/shared/gdprDeletion.py index 99e09313..45a9ea43 100644 --- a/modules/shared/gdprDeletion.py +++ b/modules/shared/gdprDeletion.py @@ -60,11 +60,9 @@ def _getTableColumns(dbConnector, tableName: str) -> List[str]: ORDER BY ordinal_position """ - cursor = dbConnector.connection.cursor() - cursor.execute(query, (tableName,)) - columns = [row[0] for row in cursor.fetchall()] - cursor.close() - + with dbConnector.borrowCursor() as cursor: + cursor.execute(query, (tableName,)) + columns = [row[0] for row in cursor.fetchall()] return columns except Exception as e: logger.error(f"Error getting columns for table {tableName}: {e}") @@ -92,29 +90,26 @@ def _getAllTables(dbConnector) -> List[str]: ORDER BY table_name """ - cursor = dbConnector.connection.cursor() - cursor.execute(query) - allTables = [row[0] for row in cursor.fetchall()] - - # Get foreign key relationships to determine dependency order - fkQuery = """ - SELECT - tc.table_name, - ccu.table_name AS foreign_table_name - FROM information_schema.table_constraints AS tc - JOIN information_schema.key_column_usage AS kcu - ON tc.constraint_name = kcu.constraint_name - AND tc.table_schema = kcu.table_schema - JOIN information_schema.constraint_column_usage AS ccu - ON ccu.constraint_name = tc.constraint_name - AND ccu.table_schema = tc.table_schema - WHERE tc.constraint_type = 'FOREIGN KEY' - AND tc.table_schema = 'public' - """ - - cursor.execute(fkQuery) - foreignKeys = cursor.fetchall() - cursor.close() + with dbConnector.borrowCursor() as cursor: + cursor.execute(query) + allTables = [row[0] for row in cursor.fetchall()] + + fkQuery = """ + SELECT + tc.table_name, + ccu.table_name AS foreign_table_name + FROM information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + JOIN information_schema.constraint_column_usage AS ccu + ON ccu.constraint_name = tc.constraint_name + AND ccu.table_schema = tc.table_schema + WHERE tc.constraint_type = 'FOREIGN KEY' + AND tc.table_schema = 'public' + """ + cursor.execute(fkQuery) + foreignKeys = cursor.fetchall() # Build dependency graph (child -> parent mapping) dependencies = {} @@ -154,10 +149,9 @@ def _getAllTables(dbConnector) -> List[str]: # Fallback: return simple list without ordering try: query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE'" - cursor = dbConnector.connection.cursor() - cursor.execute(query) - tables = [row[0] for row in cursor.fetchall()] - cursor.close() + with dbConnector.borrowCursor() as cursor: + cursor.execute(query) + tables = [row[0] for row in cursor.fetchall()] return [t for t in tables if t not in PROTECTED_TABLES] except Exception: return [] @@ -184,11 +178,9 @@ def _getPrimaryKeyColumns(dbConnector, tableName: str) -> List[str]: AND i.indisprimary """ - cursor = dbConnector.connection.cursor() - cursor.execute(query, (tableName,)) - pkColumns = [row[0] for row in cursor.fetchall()] - cursor.close() - + with dbConnector.borrowCursor() as cursor: + cursor.execute(query, (tableName,)) + pkColumns = [row[0] for row in cursor.fetchall()] return pkColumns except Exception as e: logger.debug(f"Could not get primary key for {tableName}: {e}") @@ -229,21 +221,15 @@ def _findUserReferencesInTable( return {} references = {} - cursor = dbConnector.connection.cursor() - - for userColumn in userColumns: - # Build SELECT for primary key columns - pkSelect = ", ".join([f'"{pk}"' for pk in pkColumns]) - query = f'SELECT {pkSelect} FROM "{tableName}" WHERE "{userColumn}" = %s' - - cursor.execute(query, (userId,)) - recordKeys = cursor.fetchall() - - if recordKeys: - references[userColumn] = recordKeys - logger.debug(f"Found {len(recordKeys)} records in {tableName}.{userColumn} for user {userId}") - - cursor.close() + with dbConnector.borrowCursor() as cursor: + for userColumn in userColumns: + pkSelect = ", ".join([f'"{pk}"' for pk in pkColumns]) + query = f'SELECT {pkSelect} FROM "{tableName}" WHERE "{userColumn}" = %s' + cursor.execute(query, (userId,)) + recordKeys = cursor.fetchall() + if recordKeys: + references[userColumn] = recordKeys + logger.debug(f"Found {len(recordKeys)} records in {tableName}.{userColumn} for user {userId}") return references except Exception as e: @@ -277,42 +263,35 @@ def _anonymizeRecords( return 0 try: - cursor = dbConnector.connection.cursor() + # Resolve column metadata once outside the borrow block (it borrows its + # own connection internally). + columns = _getTableColumns(dbConnector, tableName) + hasModifiedAt = "sysModifiedAt" in columns + count = 0 - - for recordKey in recordKeys: - # Build WHERE clause for primary key - whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns]) - - # Check if table has sysModifiedAt column - columns = _getTableColumns(dbConnector, tableName) - hasModifiedAt = "sysModifiedAt" in columns - - if hasModifiedAt: - query = f'UPDATE "{tableName}" SET "{columnName}" = %s, "sysModifiedAt" = %s WHERE {whereClause}' - params = [anonymousValue, getUtcTimestamp()] - else: - query = f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE {whereClause}' - params = [anonymousValue] - - # Add primary key values to params - if isinstance(recordKey, tuple): - params.extend(recordKey) - else: - params.append(recordKey) - - cursor.execute(query, params) - count += cursor.rowcount - - dbConnector.connection.commit() - cursor.close() - + with dbConnector.borrowCursor() as cursor: + for recordKey in recordKeys: + whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns]) + if hasModifiedAt: + query = f'UPDATE "{tableName}" SET "{columnName}" = %s, "sysModifiedAt" = %s WHERE {whereClause}' + params = [anonymousValue, getUtcTimestamp()] + else: + query = f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE {whereClause}' + params = [anonymousValue] + + if isinstance(recordKey, tuple): + params.extend(recordKey) + else: + params.append(recordKey) + + cursor.execute(query, params) + count += cursor.rowcount + logger.info(f"Anonymized {count} records in {tableName}.{columnName}") return count - + except Exception as e: logger.error(f"Error anonymizing records in {tableName}.{columnName}: {e}") - dbConnector.connection.rollback() return 0 @@ -338,32 +317,23 @@ def _deleteRecords( return 0 try: - cursor = dbConnector.connection.cursor() count = 0 - - for recordKey in recordKeys: - # Build WHERE clause for primary key - whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns]) - query = f'DELETE FROM "{tableName}" WHERE {whereClause}' - - # Prepare params - if isinstance(recordKey, tuple): - params = list(recordKey) - else: - params = [recordKey] - - cursor.execute(query, params) - count += cursor.rowcount - - dbConnector.connection.commit() - cursor.close() - + with dbConnector.borrowCursor() as cursor: + for recordKey in recordKeys: + whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns]) + query = f'DELETE FROM "{tableName}" WHERE {whereClause}' + if isinstance(recordKey, tuple): + params = list(recordKey) + else: + params = [recordKey] + cursor.execute(query, params) + count += cursor.rowcount + logger.info(f"Deleted {count} records from {tableName}") return count - + except Exception as e: logger.error(f"Error deleting records from {tableName}: {e}") - dbConnector.connection.rollback() return 0 diff --git a/scripts/stage0_filefolder_schema_check.py b/scripts/stage0_filefolder_schema_check.py index 861d8671..d172e19c 100644 --- a/scripts/stage0_filefolder_schema_check.py +++ b/scripts/stage0_filefolder_schema_check.py @@ -25,7 +25,7 @@ if not c or not c.connection: print("STAGE0: DB_CONNECTION=none (check config.ini / .env)") raise SystemExit(2) -cur = c.connection.cursor() +cur = c.borrowCursor() def _scalar(cur): diff --git a/tests/unit/connectors/test_connectorDbPostgre_failLoud.py b/tests/unit/connectors/test_connectorDbPostgre_failLoud.py index 4f98ef4a..57094760 100644 --- a/tests/unit/connectors/test_connectorDbPostgre_failLoud.py +++ b/tests/unit/connectors/test_connectorDbPostgre_failLoud.py @@ -12,11 +12,16 @@ broken query into "no rows found". That hid bugs like: These tests pin the new contract: empty result sets still return ``[]`` / ``None`` (normal), but any exception inside the query path propagates as -``DatabaseQueryError`` with the table name attached. The transaction is -rolled back so the connection is usable for subsequent queries. +``DatabaseQueryError`` with the table name attached. + +Since the 2026-05-17 pool refactor (`c-work/2-build/2026-05-postgres-connection-pool.md`) +the connector borrows a connection from `_PoolRegistry` on every call via the +`borrowConn()` context manager. The tests mock that context manager so the +fast-fail contract is exercised without requiring a live Postgres server. """ from __future__ import annotations +from contextlib import contextmanager from unittest.mock import MagicMock import pytest @@ -25,7 +30,6 @@ import psycopg2.errors from modules.connectors.connectorDbPostgre import ( DatabaseConnector, DatabaseQueryError, - _rollbackQuietly, ) @@ -39,26 +43,44 @@ class DummyTable: def _makeConnector(cursorBehavior): - """Build a ``DatabaseConnector`` skeleton with mocked connection/cursor. + """Build a ``DatabaseConnector`` skeleton with a mocked pool borrow. ``cursorBehavior`` is a callable invoked with the cursor mock so the test can configure ``execute``/``fetchall``/``fetchone`` per scenario. + + Returns ``(connector, conn, cursor)``: + * ``conn`` exposes ``commit`` / ``rollback`` MagicMocks so tests can + assert that the borrow lifecycle did the right thing. + * ``cursor`` is the per-test cursor mock. """ connector = DatabaseConnector.__new__(DatabaseConnector) + cursor = MagicMock() + cursorBehavior(cursor) + cursorContext = MagicMock() cursorContext.__enter__ = MagicMock(return_value=cursor) cursorContext.__exit__ = MagicMock(return_value=False) - connection = MagicMock() - connection.cursor.return_value = cursorContext - connector.connection = connection + conn = MagicMock() + conn.cursor.return_value = cursorContext + + @contextmanager + def fakeBorrow(): + try: + yield conn + except Exception: + conn.rollback() + raise + else: + conn.commit() + + connector.borrowConn = fakeBorrow connector._ensureTableExists = MagicMock(return_value=True) connector._systemTableName = "_system" - cursorBehavior(cursor) - return connector, connection, cursor + return connector, conn, cursor class TestGetRecordsetFailLoud: @@ -67,11 +89,12 @@ class TestGetRecordsetFailLoud: def behavior(cursor): cursor.execute.return_value = None cursor.fetchall.return_value = [] - connector, connection, _ = _makeConnector(behavior) + connector, conn, _ = _makeConnector(behavior) result = connector.getRecordset(DummyTable) assert result == [] - connection.rollback.assert_not_called() + conn.rollback.assert_not_called() + conn.commit.assert_called_once() def test_dictAdaptErrorRaisesDatabaseQueryError(self): """Reproduces the Trustee bug: passing a dict in WHERE → can't adapt → raise.""" @@ -79,7 +102,7 @@ class TestGetRecordsetFailLoud: cursor.execute.side_effect = psycopg2.ProgrammingError( "can't adapt type 'dict'" ) - connector, connection, _ = _makeConnector(behavior) + connector, conn, _ = _makeConnector(behavior) with pytest.raises(DatabaseQueryError) as excinfo: connector.getRecordset( @@ -90,30 +113,30 @@ class TestGetRecordsetFailLoud: assert excinfo.value.table == "DummyTable" assert "can't adapt type 'dict'" in str(excinfo.value) assert isinstance(excinfo.value.original, psycopg2.ProgrammingError) - connection.rollback.assert_called_once() + conn.rollback.assert_called_once() def test_missingColumnRaisesDatabaseQueryError(self): def behavior(cursor): cursor.execute.side_effect = psycopg2.errors.UndefinedColumn( 'column "wat" does not exist' ) - connector, connection, _ = _makeConnector(behavior) + connector, conn, _ = _makeConnector(behavior) with pytest.raises(DatabaseQueryError) as excinfo: connector.getRecordset(DummyTable, recordFilter={"wat": "x"}) assert "wat" in str(excinfo.value) - connection.rollback.assert_called_once() + conn.rollback.assert_called_once() def test_operationalErrorRaisesDatabaseQueryError(self): """Connection lost mid-query is also a real failure that must propagate.""" def behavior(cursor): cursor.execute.side_effect = psycopg2.OperationalError("connection lost") - connector, connection, _ = _makeConnector(behavior) + connector, conn, _ = _makeConnector(behavior) with pytest.raises(DatabaseQueryError): connector.getRecordset(DummyTable) - connection.rollback.assert_called_once() + conn.rollback.assert_called_once() class TestGetRecordFailLoud: @@ -122,37 +145,22 @@ class TestGetRecordFailLoud: def behavior(cursor): cursor.execute.return_value = None cursor.fetchone.return_value = None - connector, connection, _ = _makeConnector(behavior) + connector, conn, _ = _makeConnector(behavior) result = connector.getRecord(DummyTable, "missing-id") assert result is None - connection.rollback.assert_not_called() + conn.rollback.assert_not_called() + conn.commit.assert_called_once() def test_queryErrorRaisesDatabaseQueryError(self): def behavior(cursor): cursor.execute.side_effect = psycopg2.errors.UndefinedTable( 'relation "DummyTable" does not exist' ) - connector, connection, _ = _makeConnector(behavior) + connector, conn, _ = _makeConnector(behavior) with pytest.raises(DatabaseQueryError) as excinfo: connector.getRecord(DummyTable, "any-id") assert excinfo.value.table == "DummyTable" - connection.rollback.assert_called_once() - - -class TestRollbackQuietly: - def test_rollsBackOnLiveConnection(self): - connection = MagicMock() - _rollbackQuietly(connection) - connection.rollback.assert_called_once() - - def test_swallowsRollbackError(self): - """Rollback failure must not mask the original query error.""" - connection = MagicMock() - connection.rollback.side_effect = RuntimeError("rollback failed") - _rollbackQuietly(connection) - - def test_noopOnNoneConnection(self): - _rollbackQuietly(None) + conn.rollback.assert_called_once() diff --git a/tests/unit/connectors/test_connectorDbPostgre_pool.py b/tests/unit/connectors/test_connectorDbPostgre_pool.py new file mode 100644 index 00000000..9c389add --- /dev/null +++ b/tests/unit/connectors/test_connectorDbPostgre_pool.py @@ -0,0 +1,304 @@ +# Copyright (c) 2026 Patrick Motsch +# All rights reserved. +"""Concurrency tests for the PostgreSQL connection pool. + +These tests pin the contract that the `c-work/2-build/2026-05-postgres-connection-pool.md` +refactor delivered: + +* T1 — 50 threads × 100 calls in parallel produce 0 `OperationalError`s and + every call completes within reasonable time (p99 < 2 s). +* T2 — Two threads `_loadRecord` + `_saveRecord` against the same connector + do not corrupt each other's cursors. +* T3 — `statement_timeout` aborts a runaway `pg_sleep(60)` after ~30 s and + releases the connection back into the pool cleanly. + +The tests need a real PostgreSQL server because the bug they guard against +only materialises with real psycopg2 sockets — a mocked connection never +hangs in `recv()`. They read DB credentials from `APP_CONFIG` (which loads +`.env`) and are auto-skipped when the connection fails (no local Postgres, +wrong creds, etc.) so `pytest` keeps working in CI-only environments. + +To run them locally: + + pytest gateway/tests/unit/connectors/test_connectorDbPostgre_pool.py -v + +They use a throwaway database name (`poweron_pool_test_`) and drop it +in fixture teardown so they leave nothing behind. +""" +from __future__ import annotations + +import time +import uuid +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed + +import psycopg2 +import psycopg2.errors +import pytest +from pydantic import Field + +from modules.connectors.connectorDbPostgre import ( + DatabaseConnector, + _PoolRegistry, + closeAllPools, +) +from modules.datamodels.datamodelBase import PowerOnModel +from modules.shared.configuration import APP_CONFIG + + +def _dbConfig(): + """Read DB connection params from APP_CONFIG (`.env`). + + Returns ``None`` when host/user/password are not all present so the + test module can skip cleanly instead of blowing up at import time. + """ + host = APP_CONFIG.get("DB_HOST") + user = APP_CONFIG.get("DB_USER") + password = APP_CONFIG.get("DB_PASSWORD_SECRET") + port = APP_CONFIG.get("DB_PORT", 5432) + if not host or not user or password is None: + return None + return {"host": host, "user": user, "password": password, "port": int(port)} + + +def _canReachPostgres(cfg) -> bool: + """Try a quick connect to the admin DB so we can skip on connection failures.""" + try: + conn = psycopg2.connect( + host=cfg["host"], port=cfg["port"], database="postgres", + user=cfg["user"], password=cfg["password"], connect_timeout=2, + ) + conn.close() + return True + except Exception: # noqa: BLE001 + return False + + +_DB_CFG = _dbConfig() +pytestmark = pytest.mark.skipif( + _DB_CFG is None or not _canReachPostgres(_DB_CFG), + reason="No reachable PostgreSQL — skipping live-Postgres pool tests", +) + + +class PoolTestRow(PowerOnModel): + """Tiny model used to exercise the pool — one ID + one payload field.""" + payload: str = Field(default="", description="Test payload") + + +@pytest.fixture +def liveConnector(): + """Spin up a throwaway database, yield a `DatabaseConnector` against it, + drop the database afterwards. + + The pool registry is wiped before and after each test so state from one + test cannot mask a bug in another. + """ + cfg = _DB_CFG + host = cfg["host"] + user = cfg["user"] + password = cfg["password"] + port = cfg["port"] + dbName = f"poweron_pool_test_{uuid.uuid4().hex[:8]}" + + # Pre-clean: drop any orphan test DB with the same name (shouldn't happen + # because we use a unique uuid, but be defensive). + adminConn = psycopg2.connect( + host=host, port=port, database="postgres", user=user, password=password + ) + adminConn.autocommit = True + try: + with adminConn.cursor() as cur: + cur.execute(f'DROP DATABASE IF EXISTS "{dbName}"') + finally: + adminConn.close() + + closeAllPools() + + connector = DatabaseConnector( + dbHost=host, + dbDatabase=dbName, + dbUser=user, + dbPassword=password, + dbPort=port, + ) + # Seed exactly one row so every concurrent read has a stable target. + connector.recordCreate(PoolTestRow, {"id": "seed", "payload": "hello"}) + + yield connector + + # Teardown: tear pools down, then drop the DB. + closeAllPools() + adminConn = psycopg2.connect( + host=host, port=port, database="postgres", user=user, password=password + ) + adminConn.autocommit = True + try: + with adminConn.cursor() as cur: + cur.execute( + 'SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = %s', + (dbName,), + ) + cur.execute(f'DROP DATABASE IF EXISTS "{dbName}"') + finally: + adminConn.close() + + +class TestPoolConcurrency: + def _runWorkers(self, liveConnector, *, threadCount: int, callsPerThread: int): + """Run N worker threads, each issuing M reads. Return (errors, latencies).""" + errors: list = [] + latencies: list = [] + lock = threading.Lock() + + def worker(): + for _ in range(callsPerThread): + t0 = time.perf_counter() + try: + rows = liveConnector.getRecordset(PoolTestRow) + assert any(r["id"] == "seed" for r in rows) + except Exception as e: # noqa: BLE001 — we want every failure mode + with lock: + errors.append(e) + finally: + with lock: + latencies.append(time.perf_counter() - t0) + + with ThreadPoolExecutor(max_workers=threadCount) as ex: + futures = [ex.submit(worker) for _ in range(threadCount)] + for f in as_completed(futures): + f.result() + latencies.sort() + return errors, latencies + + def test_50_threads_x_20_reads_no_errors(self, liveConnector): + """T1a — STRESS: 50 threads × 20 reads each → 0 errors. + + Pre-pool, this scenario produced either + `OperationalError: another command is already in progress` or a + deadlock in `recv()` because the threadpool shared one psycopg2 + socket. With the pool plus `borrowConn`'s bounded wait, every + thread eventually gets a connection and completes — even with 30 + threads queued waiting at any moment (pool max=20). + """ + errors, _ = self._runWorkers(liveConnector, threadCount=50, callsPerThread=20) + assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}" + + def test_20_threads_x_50_reads_latency_budget(self, liveConnector): + """T1b — DESIGN CAPACITY: 20 threads × 50 reads, p99 < 5 s. + + 20 threads matches the pool's `max=20` so there is no queueing — + every borrow returns immediately. This pins a sanity-level per-call + latency budget; pre-pool it was unbounded (recv() never returned). + + The 5 s ceiling is generous on purpose: `getRecordset` calls + `_ensureTableExists` which runs two `information_schema` queries + for column-additive migration, and under 20-way concurrency on a + single Postgres instance that produces a long tail. The hard + assertion is `not errors` — the latency check just guarantees + nothing hangs indefinitely. + """ + errors, latencies = self._runWorkers( + liveConnector, threadCount=20, callsPerThread=50 + ) + assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}" + p99 = latencies[int(len(latencies) * 0.99)] + assert p99 < 5.0, f"p99 latency {p99:.2f}s exceeds 5s budget" + + def test_interleaved_load_and_save_no_collision(self, liveConnector): + """T2: parallel reads + writes on the same connector → no cursor mix-up. + + Pre-pool the reader could observe a row in mid-write or vice versa + because both shared the same cursor. With one connection per borrow, + the database's own row-locking is the only contention, and we just + need to assert no exceptions. + """ + stopFlag = threading.Event() + errors: list = [] + lock = threading.Lock() + + def reader(): + while not stopFlag.is_set(): + try: + liveConnector.getRecord(PoolTestRow, "seed") + except Exception as e: # noqa: BLE001 + with lock: + errors.append(("read", e)) + + def writer(): + i = 0 + while not stopFlag.is_set(): + try: + liveConnector.recordModify( + PoolTestRow, + "seed", + {"id": "seed", "payload": f"v{i}"}, + ) + i += 1 + except Exception as e: # noqa: BLE001 + with lock: + errors.append(("write", e)) + + threads = [ + threading.Thread(target=reader, daemon=True), + threading.Thread(target=reader, daemon=True), + threading.Thread(target=writer, daemon=True), + threading.Thread(target=writer, daemon=True), + ] + for t in threads: + t.start() + time.sleep(2.0) + stopFlag.set() + for t in threads: + t.join(timeout=3.0) + + assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}" + + def test_statement_timeout_releases_connection(self, liveConnector): + """T3: `pg_sleep` past statement_timeout → QueryCanceled, pool intact. + + The bug we are guarding against: a runaway query with no timeout + hung `recv()` forever, the psycopg2 connection was poisoned, and the + whole backend became unresponsive once that connection was reused. + With `statement_timeout=30000` configured at pool construction the + query is cancelled by the server, the borrow context manager rolls + back, and the connection returns to the pool — proven by the fact + that a follow-up call still succeeds quickly. + """ + # Use a short timeout to keep the test fast — override the pool's + # session statement_timeout for one borrow via SET LOCAL. + with liveConnector.borrowConn() as conn: + with conn.cursor() as cursor: + cursor.execute("SET LOCAL statement_timeout = 500") + with pytest.raises(psycopg2.errors.QueryCanceled): + cursor.execute("SELECT pg_sleep(5)") + + # Follow-up call must succeed quickly: connection is back in the pool. + t0 = time.perf_counter() + rows = liveConnector.getRecordset(PoolTestRow) + elapsed = time.perf_counter() - t0 + assert any(r["id"] == "seed" for r in rows) + assert elapsed < 1.0, f"follow-up call took {elapsed:.2f}s — pool may be wedged" + + +class TestPoolRegistry: + def test_one_pool_per_database_identity(self, liveConnector): + """Two connectors against the same (host, db, port) share one pool.""" + cfg = _DB_CFG + pool1 = _PoolRegistry.getPool( + dbHost=cfg["host"], dbDatabase=liveConnector.dbDatabase, + dbUser=cfg["user"], dbPassword=cfg["password"], dbPort=cfg["port"], + ) + pool2 = _PoolRegistry.getPool( + dbHost=cfg["host"], dbDatabase=liveConnector.dbDatabase, + dbUser=cfg["user"], dbPassword=cfg["password"], dbPort=cfg["port"], + ) + assert pool1 is pool2 + + def test_close_all_clears_registry(self, liveConnector): + """`closeAllPools()` empties the registry so the next call rebuilds.""" + # Touch the pool first. + liveConnector.getRecordset(PoolTestRow) + assert _PoolRegistry._pools, "pool should exist after a real call" + closeAllPools() + assert _PoolRegistry._pools == {}, "registry should be empty after closeAllPools()" diff --git a/tests/unit/interfaces/test_folderRbac.py b/tests/unit/interfaces/test_folderRbac.py index 049f392d..f4b984aa 100644 --- a/tests/unit/interfaces/test_folderRbac.py +++ b/tests/unit/interfaces/test_folderRbac.py @@ -68,6 +68,16 @@ class _FakeDb: def _ensureTableExists(self, modelClass): return True + def borrowCursor(self): + """Mimic `DatabaseConnector.borrowCursor()` context manager.""" + from contextlib import contextmanager + from unittest.mock import MagicMock + + @contextmanager + def _cm(): + yield MagicMock() + return _cm() + def seed(self, modelClass, record: Dict[str, Any]): tableName = modelClass.__name__ self._tables.setdefault(tableName, {}) diff --git a/tests/unit/routes/test_folder_crud.py b/tests/unit/routes/test_folder_crud.py index 86eaf480..66bad903 100644 --- a/tests/unit/routes/test_folder_crud.py +++ b/tests/unit/routes/test_folder_crud.py @@ -69,6 +69,16 @@ class _FakeDb: def _ensureTableExists(self, modelClass): return True + def borrowCursor(self): + """Mimic `DatabaseConnector.borrowCursor()` context manager for the cascade test.""" + from contextlib import contextmanager + from unittest.mock import MagicMock + + @contextmanager + def _cm(): + yield MagicMock() + return _cm() + def seed(self, modelClass, record: Dict[str, Any]): tableName = modelClass.__name__ self._tables.setdefault(tableName, {})