From 2bb65c230369b3b54e8f8c06e32d8296d1c8e100 Mon Sep 17 00:00:00 2001
From: ValueOn AG
Date: Sun, 17 May 2026 20:38:37 +0200
Subject: [PATCH] db connection pooling and rag limit transparency
---
app.py | 11 +-
modules/connectors/connectorDbPostgre.py | 1180 ++++++++++-------
.../realEstate/interfaceFeatureRealEstate.py | 8 +-
modules/interfaces/interfaceDbBilling.py | 18 +-
modules/interfaces/interfaceDbKnowledge.py | 16 +
modules/interfaces/interfaceDbManagement.py | 32 +-
modules/interfaces/interfaceRbac.py | 12 +-
modules/routes/routeHelpers.py | 2 +-
modules/routes/routeRagInventory.py | 65 +-
modules/routes/routeWorkflowDashboard.py | 4 +-
.../coreTools/_featureSubAgentTools.py | 15 +-
.../subConnectorSyncClickup.py | 43 +-
.../subConnectorSyncGdrive.py | 43 +-
.../subConnectorSyncKdrive.py | 46 +-
.../subConnectorSyncSharepoint.py | 56 +-
modules/shared/configuration.py | 63 +-
modules/shared/dbMultiTenantOptimizations.py | 101 +-
modules/shared/gdprDeletion.py | 178 ++-
scripts/stage0_filefolder_schema_check.py | 2 +-
.../test_connectorDbPostgre_failLoud.py | 82 +-
.../test_connectorDbPostgre_pool.py | 304 +++++
tests/unit/interfaces/test_folderRbac.py | 10 +
tests/unit/routes/test_folder_crud.py | 10 +
23 files changed, 1519 insertions(+), 782 deletions(-)
create mode 100644 tests/unit/connectors/test_connectorDbPostgre_pool.py
diff --git a/app.py b/app.py
index 7a4ed4d4..d94c7dd5 100644
--- a/app.py
+++ b/app.py
@@ -438,7 +438,16 @@ async def lifespan(app: FastAPI):
logger.error(f"Feature '{featureName}' failed to stop: {e}")
except Exception as e:
logger.warning(f"Could not shutdown feature containers: {e}")
-
+
+ # --- Close all PostgreSQL connection pools ---
+ # Must run LAST: feature `onStop` hooks may still issue DB calls during
+ # shutdown. Once we tear down the pools, no more borrows are possible.
+ try:
+ from modules.connectors.connectorDbPostgre import closeAllPools
+ closeAllPools()
+ except Exception as e:
+ logger.warning(f"Closing DB connection pools failed: {e}")
+
logger.info("Application has been shut down")
diff --git a/modules/connectors/connectorDbPostgre.py b/modules/connectors/connectorDbPostgre.py
index a6893396..f1a34f70 100644
--- a/modules/connectors/connectorDbPostgre.py
+++ b/modules/connectors/connectorDbPostgre.py
@@ -2,9 +2,12 @@
# All rights reserved.
import contextvars
import re
+import time
import psycopg2
import psycopg2.extras
+import psycopg2.pool
import logging
+from contextlib import contextmanager
from typing import List, Dict, Any, Optional, Union, get_origin, get_args, Type
import uuid
from pydantic import BaseModel, Field
@@ -44,24 +47,6 @@ class DatabaseQueryError(RuntimeError):
self.original = original
-def _rollbackQuietly(connection) -> None:
- """Restore the connection state after a failed query.
-
- Postgres puts the connection in an error state after any failed
- statement; subsequent queries on the same connection raise
- ``InFailedSqlTransaction`` until we rollback. We swallow rollback
- errors because the original query error is what the caller should
- see — a secondary rollback failure typically means the connection
- is gone and will be reopened on the next ``_ensure_connection``.
- """
- if connection is None:
- return
- try:
- connection.rollback()
- except Exception:
- pass
-
-
class SystemTable(PowerOnModel):
"""Data model for system table entries"""
@@ -203,9 +188,174 @@ def _quotePgIdent(name: str) -> str:
return '"' + str(name).replace('"', '""') + '"'
-# Cache connectors by (host, database, port) to avoid duplicate inits for same database.
-# Thread safety: _connector_cache_lock protects cache access. userId is request-scoped via
-# contextvars to avoid races when concurrent requests share the same connector.
+# ---------------------------------------------------------------------------
+# Connection pool registry
+# ---------------------------------------------------------------------------
+# psycopg2 connections are NOT thread-safe; sharing one connection across the
+# FastAPI threadpool (sync `def` routes) or across multiple async tasks that
+# happen to be active simultaneously results in either
+# `OperationalError: another command is already in progress` or — far worse —
+# an unbounded hang in `recv()` on the connection socket, because two cursors
+# wait for one server response.
+#
+# `_PoolRegistry` keeps one `ThreadedConnectionPool` per database identity
+# (`host`, `db`, `port`). Every DB call borrows a connection from the pool,
+# runs its query, and returns the connection. The pool itself guarantees that
+# no two callers share the same psycopg2 connection at the same time.
+#
+# `statement_timeout=30000` (30 s) is a safety net: a runaway query is aborted
+# instead of hanging forever and poisoning the connection. `connect_timeout=10`
+# prevents indefinite blocking when the Postgres host is unreachable.
+
+_DEFAULT_POOL_MIN = 2
+_DEFAULT_POOL_MAX = 20
+_STATEMENT_TIMEOUT_MS = 30000
+_CONNECT_TIMEOUT_S = 10
+# psycopg2.pool.ThreadedConnectionPool.getconn() does NOT block when the pool
+# is exhausted — it raises `psycopg2.pool.PoolError` immediately. That makes
+# bursty workloads (50 concurrent route calls against a max=20 pool) fail
+# spuriously. `borrowConn()` therefore retries with a small backoff up to
+# `_BORROW_WAIT_TIMEOUT_S` seconds before giving up.
+_BORROW_WAIT_TIMEOUT_S = 30.0
+_BORROW_WAIT_BACKOFF_S = 0.05
+
+
+def _resolvePoolMax() -> int:
+ """Pool size is configurable via `DB_POOL_MAX_CONN` (default 20)."""
+ try:
+ return max(2, int(APP_CONFIG.get("DB_POOL_MAX_CONN") or _DEFAULT_POOL_MAX))
+ except (ValueError, TypeError):
+ return _DEFAULT_POOL_MAX
+
+
+class _PoolRegistry:
+ """Process-wide registry of `ThreadedConnectionPool` instances.
+
+ Keyed by `(host, database, port)` so that two `DatabaseConnector` instances
+ pointing at the same physical database share one pool. Lazy-initialised and
+ thread-safe.
+ """
+
+ _pools: Dict[tuple, psycopg2.pool.ThreadedConnectionPool] = {}
+ _lock = threading.Lock()
+
+ @classmethod
+ def getPool(
+ cls,
+ *,
+ dbHost: str,
+ dbDatabase: str,
+ dbUser: str,
+ dbPassword: str,
+ dbPort: int,
+ ) -> psycopg2.pool.ThreadedConnectionPool:
+ port = int(dbPort) if dbPort is not None else 5432
+ key = (dbHost, dbDatabase, port)
+ # Fast path: pool exists.
+ pool = cls._pools.get(key)
+ if pool is not None:
+ return pool
+ # Slow path: create exactly one pool per key, even under contention.
+ with cls._lock:
+ pool = cls._pools.get(key)
+ if pool is not None:
+ return pool
+ poolMax = _resolvePoolMax()
+ options = f"-c statement_timeout={_STATEMENT_TIMEOUT_MS}"
+ try:
+ pool = psycopg2.pool.ThreadedConnectionPool(
+ _DEFAULT_POOL_MIN,
+ poolMax,
+ host=dbHost,
+ port=port,
+ database=dbDatabase,
+ user=dbUser,
+ password=dbPassword,
+ client_encoding="utf8",
+ cursor_factory=psycopg2.extras.RealDictCursor,
+ connect_timeout=_CONNECT_TIMEOUT_S,
+ options=options,
+ )
+ except Exception as e:
+ logger.error(
+ "Failed to create connection pool for db=%s host=%s: %s",
+ dbDatabase, dbHost, e,
+ )
+ raise
+ cls._pools[key] = pool
+ logger.debug(
+ "Created connection pool for db=%s host=%s port=%s (min=%d max=%d, stmt_timeout=%dms)",
+ dbDatabase, dbHost, port, _DEFAULT_POOL_MIN, poolMax, _STATEMENT_TIMEOUT_MS,
+ )
+ return pool
+
+ @classmethod
+ def closeAll(cls) -> None:
+ """Close every pool. Call once during FastAPI shutdown."""
+ with cls._lock:
+ for key, pool in list(cls._pools.items()):
+ try:
+ pool.closeall()
+ except Exception as e:
+ logger.warning("Error closing pool %s: %s", key, e)
+ cls._pools.clear()
+ logger.info("All database connection pools closed")
+
+
+def closeAllPools() -> None:
+ """Public entry point for FastAPI lifespan shutdown hook."""
+ _PoolRegistry.closeAll()
+
+
+class _ConnectionShim:
+ """Backward-compatibility stand-in for the old `DatabaseConnector.connection`.
+
+ Old callers used patterns like::
+
+ db.connection.commit()
+ if db.connection and not db.connection.closed: ...
+
+ These are now no-ops because the pool owns the connection lifecycle.
+ Direct cursor access through `self.connection.cursor()` is intentionally
+ blocked with a clear error so silent breakage is impossible — every such
+ call site must migrate to `db.borrowCursor()`.
+ """
+
+ closed = False
+
+ def __bool__(self) -> bool:
+ return True
+
+ def commit(self) -> None:
+ return
+
+ def rollback(self) -> None:
+ return
+
+ def close(self) -> None:
+ return
+
+ def cursor(self, *args, **kwargs):
+ raise RuntimeError(
+ "DatabaseConnector.connection.cursor() is no longer supported. "
+ "Use `db.borrowCursor()` (or `db.borrowConn()` for multi-statement "
+ "transactions) so the connection is borrowed from and returned to "
+ "the pool correctly."
+ )
+
+
+_CONNECTION_SHIM = _ConnectionShim()
+
+
+# ---------------------------------------------------------------------------
+# Connector cache (lightweight wrappers — actual connections live in the pool)
+# ---------------------------------------------------------------------------
+# Multiple call sites (`routeI18n`, `aiAuditLogger`, `mainBackgroundJobService`,
+# interfaces) ask for a connector via `getCachedConnector(...)` and expect to
+# get back the same object on subsequent calls. Now that real connection
+# multiplexing happens at the pool layer, the cache returns lightweight
+# `DatabaseConnector` wrappers — they hold no connection themselves, only the
+# DSN params and a reference to the shared pool.
_MAX_CACHED_CONNECTORS = 32
_connector_cache: Dict[tuple, "DatabaseConnector"] = {}
_connector_cache_order: List[tuple] = [] # FIFO order for eviction
@@ -223,22 +373,36 @@ def getCachedConnector(
dbPort: int = None,
userId: str = None,
) -> "DatabaseConnector":
- """Return cached DatabaseConnector for same (host, database, port) to avoid duplicate PostgreSQL inits.
- Uses contextvars for userId so concurrent requests sharing the same connector get correct sysCreatedBy/sysModifiedBy.
+ """Return a cached `DatabaseConnector` wrapper for `(host, database, port)`.
+
+ Two layers of caching, both intentional:
+
+ 1. **Pool layer** (`_PoolRegistry`) — owns the actual psycopg2 connections.
+ Per `(host, db, port)` exactly one `ThreadedConnectionPool`, shared
+ across every wrapper that points at the same physical database. This
+ is what actually saves Postgres connection slots.
+
+ 2. **Wrapper layer** (this function) — caches the `DatabaseConnector`
+ Python object. Each wrapper triggers an `initDbSystem()` on first
+ instantiation (CREATE DATABASE if missing, CREATE TABLE _system,
+ pool warm-up). Caching the wrapper avoids paying that bootstrap cost
+ on every request and stops the log from filling with "PostgreSQL
+ database system initialized" lines.
+
+ `userId` is request-scoped via the `_current_user_id` contextvar so two
+ concurrent requests sharing the same cached wrapper still produce
+ correct `sysCreatedBy` / `sysModifiedBy` audit fields.
"""
port = int(dbPort) if dbPort is not None else 5432
key = (dbHost, dbDatabase, port)
with _connector_cache_lock:
if key not in _connector_cache:
- # Evict oldest if at capacity
+ # FIFO eviction. Connectors are now lightweight (no per-instance
+ # connection), so eviction is purely a memory bookkeeping concern;
+ # the underlying pool stays alive in `_PoolRegistry`.
while len(_connector_cache) >= _MAX_CACHED_CONNECTORS and _connector_cache_order:
oldest_key = _connector_cache_order.pop(0)
- if oldest_key in _connector_cache:
- try:
- _connector_cache[oldest_key].close(forceClose=True)
- except Exception as e:
- logger.warning(f"Error closing evicted connector: {e}")
- del _connector_cache[oldest_key]
+ _connector_cache.pop(oldest_key, None)
_connector_cache[key] = DatabaseConnector(
dbHost=dbHost,
dbDatabase=dbDatabase,
@@ -282,34 +446,38 @@ class DatabaseConnector:
# Set userId (default to empty string if None)
self.userId = userId if userId is not None else ""
- # Initialize database system first (creates database if needed)
- self.connection = None
+ # No per-instance connection any more — real connections live in the
+ # shared `_PoolRegistry` pool. `_isCachedShared` is retained because
+ # `close(forceClose=False)` callers (interface __del__) still ask.
self._isCachedShared = False
- self.initDbSystem()
- # No caching needed with proper database - PostgreSQL handles performance
-
- # Thread safety
- self._lock = threading.Lock()
-
- # pgvector extension state
+ # pgvector extension state (cached per connector instance — cheap)
self._vectorExtensionEnabled = False
- # Initialize system table
+ # System table bootstrap: create database, system table, ensure metadata.
self._systemTableName = "_system"
+ self.initDbSystem()
self._initializeSystemTable()
def initDbSystem(self):
- """Initialize the database system - creates database and tables."""
- try:
- # Create database if it doesn't exist
- self._create_database_if_not_exists()
+ """Bootstrap the physical database and the `_system` metadata table.
- # Create tables
+ Uses a short-lived autocommit connection (NOT the pool) because
+ `CREATE DATABASE` cannot run inside a transaction block. Also warms up
+ the pool by acquiring it for this `(host, db, port)` once.
+ """
+ try:
+ self._create_database_if_not_exists()
self._create_tables()
- # Establish connection to the database
- self._connect()
+ # Warm the pool so the first request doesn't pay for socket setup.
+ _PoolRegistry.getPool(
+ dbHost=self.dbHost,
+ dbDatabase=self.dbDatabase,
+ dbUser=self.dbUser,
+ dbPassword=self.dbPassword,
+ dbPort=self.dbPort,
+ )
logger.debug(
"PostgreSQL database system initialized (db=%s, host=%s, port=%s)",
@@ -319,10 +487,121 @@ class DatabaseConnector:
logger.error(f"FATAL ERROR: Database system initialization failed: {e}")
raise
- def _create_database_if_not_exists(self):
- """Create the database if it doesn't exist."""
+ @property
+ def connection(self) -> "_ConnectionShim":
+ """Backward-compat shim — see `_ConnectionShim` docstring."""
+ return _CONNECTION_SHIM
+
+ def _ensure_connection(self) -> None:
+ """No-op for backward compatibility.
+
+ Previously this method tested-or-reconnected the per-instance socket.
+ With pooling, the `ThreadedConnectionPool` re-establishes broken
+ connections lazily on the next `getconn()`. Kept as a no-op so that
+ legacy call sites continue to compile.
+ """
+ return
+
+ @contextmanager
+ def borrowCursor(self):
+ """Borrow a cursor for one short SQL block.
+
+ Convenience wrapper around `borrowConn()` for the common pattern:
+
+ with db.borrowCursor() as cursor:
+ cursor.execute(...)
+ rows = cursor.fetchall()
+
+ Replaces the legacy `with db.connection.cursor() as cursor:` pattern.
+ Commit/rollback and pool return are handled automatically — callers
+ must NOT call `.commit()`/`.rollback()` on the cursor themselves.
+ """
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ yield cursor
+
+ @contextmanager
+ def borrowConn(self):
+ """Borrow a connection from the pool for the duration of the block.
+
+ Pool-exhaustion semantics: `ThreadedConnectionPool.getconn()` raises
+ `psycopg2.pool.PoolError` immediately when the pool is at its `maxconn`
+ limit — it does NOT block. We wrap that with a bounded busy-wait so
+ bursty workloads (e.g. 50 concurrent route calls against a max=20
+ pool) queue up instead of failing. If no connection becomes available
+ within `_BORROW_WAIT_TIMEOUT_S`, the `PoolError` is propagated so a
+ genuinely deadlocked pool surfaces as a real error.
+
+ On normal exit the current transaction is committed (idempotent for
+ read-only queries — leaves the connection in a clean state for the
+ next borrower). On exception the transaction is rolled back and the
+ exception propagates. The connection is **always** returned to the
+ pool, even when commit/rollback fails — otherwise a single bad query
+ would leak a slot and eventually exhaust the pool.
+ """
+ pool = _PoolRegistry.getPool(
+ dbHost=self.dbHost,
+ dbDatabase=self.dbDatabase,
+ dbUser=self.dbUser,
+ dbPassword=self.dbPassword,
+ dbPort=self.dbPort,
+ )
+ conn = self._acquireConn(pool)
+ try:
+ yield conn
+ except Exception:
+ try:
+ conn.rollback()
+ except Exception:
+ pass
+ raise
+ else:
+ # Best-effort commit so the connection goes back to the pool with
+ # no in-flight transaction. Failure here is non-fatal — the pool
+ # will re-establish the socket on the next `getconn()` if needed.
+ try:
+ conn.commit()
+ except Exception:
+ try:
+ conn.rollback()
+ except Exception:
+ pass
+ finally:
+ try:
+ pool.putconn(conn)
+ except Exception as e:
+ logger.warning("Failed to return connection to pool: %s", e)
+
+ @staticmethod
+ def _acquireConn(pool: psycopg2.pool.ThreadedConnectionPool):
+ """Get a connection from the pool, waiting up to `_BORROW_WAIT_TIMEOUT_S`.
+
+ psycopg2's pool throws on exhaustion instead of queueing — this helper
+ polls with a short backoff so callers see queue semantics.
+ """
+ deadline = time.monotonic() + _BORROW_WAIT_TIMEOUT_S
+ attempt = 0
+ while True:
+ try:
+ return pool.getconn()
+ except psycopg2.pool.PoolError as e:
+ attempt += 1
+ if time.monotonic() >= deadline:
+ logger.error(
+ "Connection pool exhausted after %.1fs wait (%d retries)",
+ _BORROW_WAIT_TIMEOUT_S, attempt,
+ )
+ raise
+ time.sleep(_BORROW_WAIT_BACKOFF_S)
+
+ def _create_database_if_not_exists(self):
+ """Create the database if it doesn't exist.
+
+ Uses an autocommit connection on the `postgres` admin DB because
+ `CREATE DATABASE` cannot run inside a transaction block — so this
+ path intentionally does NOT use the pool.
+ """
try:
- # Use the configured user for database creation
conn = psycopg2.connect(
host=self.dbHost,
port=self.dbPort,
@@ -330,23 +609,21 @@ class DatabaseConnector:
user=self.dbUser,
password=self.dbPassword,
client_encoding="utf8",
+ connect_timeout=_CONNECT_TIMEOUT_S,
)
conn.autocommit = True
-
- with conn.cursor() as cursor:
- # Check if database exists
- cursor.execute(
- "SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,)
- )
- exists = cursor.fetchone()
-
- if not exists:
- # Create database with proper quoting for names with hyphens
- quoted_db_name = f'"{self.dbDatabase}"'
- cursor.execute(f"CREATE DATABASE {quoted_db_name}")
- logger.info(f"Created database: {self.dbDatabase}")
-
- conn.close()
+ try:
+ with conn.cursor() as cursor:
+ cursor.execute(
+ "SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,)
+ )
+ exists = cursor.fetchone()
+ if not exists:
+ quoted_db_name = f'"{self.dbDatabase}"'
+ cursor.execute(f"CREATE DATABASE {quoted_db_name}")
+ logger.info(f"Created database: {self.dbDatabase}")
+ finally:
+ conn.close()
except Exception as e:
logger.error(f"FATAL ERROR: Cannot create database: {e}")
@@ -356,9 +633,12 @@ class DatabaseConnector:
)
def _create_tables(self):
- """Create only the system table - application tables are created by interfaces."""
+ """Create the `_system` table.
+
+ Uses a short-lived autocommit connection (not the pool) — runs exactly
+ once at connector creation.
+ """
try:
- # Use the configured user for table creation
conn = psycopg2.connect(
host=self.dbHost,
port=self.dbPort,
@@ -366,23 +646,24 @@ class DatabaseConnector:
user=self.dbUser,
password=self.dbPassword,
client_encoding="utf8",
+ connect_timeout=_CONNECT_TIMEOUT_S,
)
conn.autocommit = True
-
- with conn.cursor() as cursor:
- # Create only the system table
- cursor.execute("""
- CREATE TABLE IF NOT EXISTS _system (
- id SERIAL PRIMARY KEY,
- table_name VARCHAR(255) UNIQUE NOT NULL,
- initial_id VARCHAR(255) NOT NULL,
- "sysCreatedAt" DOUBLE PRECISION,
- "sysCreatedBy" VARCHAR(255),
- "sysModifiedAt" DOUBLE PRECISION,
- "sysModifiedBy" VARCHAR(255)
- )
- """)
- conn.close()
+ try:
+ with conn.cursor() as cursor:
+ cursor.execute("""
+ CREATE TABLE IF NOT EXISTS _system (
+ id SERIAL PRIMARY KEY,
+ table_name VARCHAR(255) UNIQUE NOT NULL,
+ initial_id VARCHAR(255) NOT NULL,
+ "sysCreatedAt" DOUBLE PRECISION,
+ "sysCreatedBy" VARCHAR(255),
+ "sysModifiedAt" DOUBLE PRECISION,
+ "sysModifiedBy" VARCHAR(255)
+ )
+ """)
+ finally:
+ conn.close()
except Exception as e:
logger.error(f"FATAL ERROR: Cannot create system table: {e}")
@@ -391,67 +672,26 @@ class DatabaseConnector:
)
raise RuntimeError(f"FATAL ERROR: Cannot create system table: {e}")
- def _connect(self):
- """Establish connection to PostgreSQL database."""
- try:
- # Use configured user for main connection with proper parameter handling
- self.connection = psycopg2.connect(
- host=self.dbHost,
- port=self.dbPort,
- database=self.dbDatabase,
- user=self.dbUser,
- password=self.dbPassword,
- client_encoding="utf8",
- cursor_factory=psycopg2.extras.RealDictCursor,
- )
- self.connection.autocommit = False # Use transactions
- except Exception as e:
- logger.error(f"Failed to connect to PostgreSQL: {e}")
- raise
-
- def _ensure_connection(self):
- """Ensure database connection is alive, reconnect if necessary."""
- try:
- if self.connection is None or self.connection.closed:
- self._connect()
- else:
- # Test connection with a simple query
- with self.connection.cursor() as cursor:
- cursor.execute("SELECT 1")
- except Exception as e:
- logger.warning(f"Connection lost, reconnecting: {e}")
- self._connect()
-
def _initializeSystemTable(self):
"""Initializes the system table if it doesn't exist yet."""
try:
- # First ensure the system table exists
self._ensureTableExists(SystemTable)
-
- with self.connection.cursor() as cursor:
- # Check if system table has any data
- cursor.execute('SELECT COUNT(*) FROM "_system"')
- row = cursor.fetchone()
- count = row["count"] if row else 0
-
- self.connection.commit()
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute('SELECT COUNT(*) FROM "_system"')
+ cursor.fetchone() # noqa: just verifies table is readable
except Exception as e:
logger.error(f"Error initializing system table: {e}")
- self.connection.rollback()
raise
def _loadSystemTable(self) -> Dict[str, str]:
"""Loads the system table with the initial IDs."""
try:
- with self.connection.cursor() as cursor:
- cursor.execute('SELECT "table_name", "initial_id" FROM "_system"')
- rows = cursor.fetchall()
-
- system_data = {}
- for row in rows:
- system_data[row["table_name"]] = row["initial_id"]
-
- return system_data
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute('SELECT "table_name", "initial_id" FROM "_system"')
+ rows = cursor.fetchall()
+ return {row["table_name"]: row["initial_id"] for row in rows}
except Exception as e:
logger.error(f"Error loading system table: {e}")
return {}
@@ -459,75 +699,65 @@ class DatabaseConnector:
def _saveSystemTable(self, data: Dict[str, str]) -> bool:
"""Saves the system table with the initial IDs."""
try:
- with self.connection.cursor() as cursor:
- # Clear existing data
- cursor.execute('DELETE FROM "_system"')
-
- # Insert new data
- for table_name, initial_id in data.items():
- cursor.execute(
- """
- INSERT INTO "_system" ("table_name", "initial_id", "sysModifiedAt")
- VALUES (%s, %s, %s)
- """,
- (table_name, initial_id, getUtcTimestamp()),
- )
-
- self.connection.commit()
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute('DELETE FROM "_system"')
+ for table_name, initial_id in data.items():
+ cursor.execute(
+ """
+ INSERT INTO "_system" ("table_name", "initial_id", "sysModifiedAt")
+ VALUES (%s, %s, %s)
+ """,
+ (table_name, initial_id, getUtcTimestamp()),
+ )
return True
except Exception as e:
logger.error(f"Error saving system table: {e}")
- self.connection.rollback()
return False
def _ensureSystemTableExists(self) -> bool:
"""Ensures the system table exists, creates it if it doesn't."""
try:
- self._ensure_connection()
-
- with self.connection.cursor() as cursor:
- # Check if system table exists
- cursor.execute(
- "SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s",
- (self._systemTableName,),
- )
- exists = cursor.fetchone()["count"] > 0
-
- if not exists:
- # Create system table
- cursor.execute(f"""
- CREATE TABLE "{self._systemTableName}" (
- "table_name" VARCHAR(255) PRIMARY KEY,
- "initial_id" VARCHAR(255),
- "sysCreatedAt" DOUBLE PRECISION,
- "sysCreatedBy" VARCHAR(255),
- "sysModifiedAt" DOUBLE PRECISION,
- "sysModifiedBy" VARCHAR(255)
- )
- """)
- logger.info("System table created successfully")
- else:
- # Check if we need to add missing columns to existing table
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
cursor.execute(
- """
- SELECT column_name FROM information_schema.columns
- WHERE table_name = %s AND table_schema = 'public'
- """,
+ "SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s",
(self._systemTableName,),
)
- existing_columns = [row["column_name"] for row in cursor.fetchall()]
+ exists = cursor.fetchone()["count"] > 0
- for sys_col, sys_sql in [
- ("sysCreatedAt", "DOUBLE PRECISION"),
- ("sysCreatedBy", "VARCHAR(255)"),
- ("sysModifiedAt", "DOUBLE PRECISION"),
- ("sysModifiedBy", "VARCHAR(255)"),
- ]:
- if sys_col not in existing_columns:
- cursor.execute(
- f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "{sys_col}" {sys_sql}'
+ if not exists:
+ cursor.execute(f"""
+ CREATE TABLE "{self._systemTableName}" (
+ "table_name" VARCHAR(255) PRIMARY KEY,
+ "initial_id" VARCHAR(255),
+ "sysCreatedAt" DOUBLE PRECISION,
+ "sysCreatedBy" VARCHAR(255),
+ "sysModifiedAt" DOUBLE PRECISION,
+ "sysModifiedBy" VARCHAR(255)
)
+ """)
+ logger.info("System table created successfully")
+ else:
+ cursor.execute(
+ """
+ SELECT column_name FROM information_schema.columns
+ WHERE table_name = %s AND table_schema = 'public'
+ """,
+ (self._systemTableName,),
+ )
+ existing_columns = [row["column_name"] for row in cursor.fetchall()]
+ for sys_col, sys_sql in [
+ ("sysCreatedAt", "DOUBLE PRECISION"),
+ ("sysCreatedBy", "VARCHAR(255)"),
+ ("sysModifiedAt", "DOUBLE PRECISION"),
+ ("sysModifiedBy", "VARCHAR(255)"),
+ ]:
+ if sys_col not in existing_columns:
+ cursor.execute(
+ f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "{sys_col}" {sys_sql}'
+ )
return True
except Exception as e:
logger.error(f"Error ensuring system table exists: {e}")
@@ -542,123 +772,113 @@ class DatabaseConnector:
return self._ensureSystemTableExists()
try:
- self._ensure_connection()
-
- with self.connection.cursor() as cursor:
- # Check if table exists by querying information_schema with case-insensitive search
- cursor.execute(
- """
- SELECT COUNT(*) FROM information_schema.tables
- WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public'
- """,
- (table,),
- )
- exists = cursor.fetchone()["count"] > 0
-
- if not exists:
- # Create table from Pydantic model
- self._create_table_from_model(cursor, table, model_class)
- logger.info(
- f"Created table '{table}' with columns from Pydantic model"
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(
+ """
+ SELECT COUNT(*) FROM information_schema.tables
+ WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public'
+ """,
+ (table,),
)
- else:
- # Table exists: ensure all columns from model are present (simple additive migration)
- try:
- cursor.execute(
- """
- SELECT column_name, data_type
- FROM information_schema.columns
- WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public'
- """,
- (table,),
+ exists = cursor.fetchone()["count"] > 0
+
+ if not exists:
+ self._create_table_from_model(cursor, table, model_class)
+ logger.info(
+ f"Created table '{table}' with columns from Pydantic model"
)
- existing_column_rows = cursor.fetchall()
- existing_columns = {
- row["column_name"] for row in existing_column_rows
- }
- existing_column_types = {
- row["column_name"]: (row["data_type"] or "").lower()
- for row in existing_column_rows
- }
+ else:
+ # Table exists: ensure all columns from model are present (simple additive migration)
+ try:
+ cursor.execute(
+ """
+ SELECT column_name, data_type
+ FROM information_schema.columns
+ WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public'
+ """,
+ (table,),
+ )
+ existing_column_rows = cursor.fetchall()
+ existing_columns = {
+ row["column_name"] for row in existing_column_rows
+ }
+ existing_column_types = {
+ row["column_name"]: (row["data_type"] or "").lower()
+ for row in existing_column_rows
+ }
- # Desired columns based on model
- model_fields = getModelFields(model_class)
- desired_columns = set(["id"]) | set(model_fields.keys())
+ model_fields = getModelFields(model_class)
+ desired_columns = set(["id"]) | set(model_fields.keys())
- # Add missing columns
- for col in sorted(desired_columns - existing_columns):
- # Determine SQL type
- if col in ["id"]:
- continue # primary key exists already
- sql_type = model_fields.get(col)
- if not sql_type:
- sql_type = "TEXT"
- try:
- cursor.execute(
- f'ALTER TABLE "{table}" ADD COLUMN "{col}" {sql_type}'
- )
- logger.info(
- f"Added missing column '{col}' ({sql_type}) to '{table}'"
- )
- except Exception as add_err:
- logger.warning(
- f"Could not add column '{col}' to '{table}': {add_err}"
- )
-
- # Column type migrations for existing tables.
- # TEXT→DOUBLE PRECISION handles three value shapes:
- # 1. NULL / empty string → NULL
- # 2. ISO date(time) like "2025-01-22" or "2025-01-22T10:00:00+00" → epoch via EXTRACT
- # 3. Plain numeric string like "3.14" → direct cast
- _TEXT_TO_DOUBLE = (
- 'DOUBLE PRECISION USING CASE'
- ' WHEN "{col}" IS NULL OR "{col}" = \'\' THEN NULL'
- ' WHEN "{col}" ~ \'^\\d{4}-\\d{2}-\\d{2}\''
- ' THEN EXTRACT(EPOCH FROM "{col}"::timestamptz)'
- ' ELSE NULLIF("{col}", \'\')::double precision'
- ' END'
- )
- _SAFE_TYPE_CHANGES = {
- ("jsonb", "TEXT"): "TEXT USING \"{col}\"::text",
- ("text", "DOUBLE PRECISION"): _TEXT_TO_DOUBLE,
- ("text", "INTEGER"): "INTEGER USING NULLIF(\"{col}\", '')::integer",
- ("timestamp without time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}" AT TIME ZONE \'UTC\')',
- ("timestamp with time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}")',
- ("date", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}"::timestamp AT TIME ZONE \'UTC\')',
- }
- for col in sorted(desired_columns & existing_columns):
- if col == "id":
- continue
- desired_sql = (model_fields.get(col) or "").upper()
- currentType = existing_column_types.get(col, "")
- migration = _SAFE_TYPE_CHANGES.get((currentType, desired_sql))
- if migration:
- castExpr = migration.replace("{col}", col)
+ for col in sorted(desired_columns - existing_columns):
+ if col in ["id"]:
+ continue
+ sql_type = model_fields.get(col)
+ if not sql_type:
+ sql_type = "TEXT"
try:
- cursor.execute('SAVEPOINT col_migrate')
cursor.execute(
- f'ALTER TABLE "{table}" ALTER COLUMN "{col}" TYPE {castExpr}'
+ f'ALTER TABLE "{table}" ADD COLUMN "{col}" {sql_type}'
)
- cursor.execute('RELEASE SAVEPOINT col_migrate')
logger.info(
- f"Migrated column '{col}' from {currentType} to {desired_sql} on '{table}'"
+ f"Added missing column '{col}' ({sql_type}) to '{table}'"
)
- except Exception as alter_err:
- cursor.execute('ROLLBACK TO SAVEPOINT col_migrate')
+ except Exception as add_err:
logger.warning(
- f"Could not migrate column '{col}' on '{table}': {alter_err}"
+ f"Could not add column '{col}' to '{table}': {add_err}"
)
- except Exception as ensure_err:
- logger.warning(
- f"Could not ensure columns for existing table '{table}': {ensure_err}"
- )
- self.connection.commit()
+ # Column type migrations for existing tables.
+ # TEXT→DOUBLE PRECISION handles three value shapes:
+ # 1. NULL / empty string → NULL
+ # 2. ISO date(time) like "2025-01-22" or "2025-01-22T10:00:00+00" → epoch via EXTRACT
+ # 3. Plain numeric string like "3.14" → direct cast
+ _TEXT_TO_DOUBLE = (
+ 'DOUBLE PRECISION USING CASE'
+ ' WHEN "{col}" IS NULL OR "{col}" = \'\' THEN NULL'
+ ' WHEN "{col}" ~ \'^\\d{4}-\\d{2}-\\d{2}\''
+ ' THEN EXTRACT(EPOCH FROM "{col}"::timestamptz)'
+ ' ELSE NULLIF("{col}", \'\')::double precision'
+ ' END'
+ )
+ _SAFE_TYPE_CHANGES = {
+ ("jsonb", "TEXT"): "TEXT USING \"{col}\"::text",
+ ("text", "DOUBLE PRECISION"): _TEXT_TO_DOUBLE,
+ ("text", "INTEGER"): "INTEGER USING NULLIF(\"{col}\", '')::integer",
+ ("timestamp without time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}" AT TIME ZONE \'UTC\')',
+ ("timestamp with time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}")',
+ ("date", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}"::timestamp AT TIME ZONE \'UTC\')',
+ }
+ for col in sorted(desired_columns & existing_columns):
+ if col == "id":
+ continue
+ desired_sql = (model_fields.get(col) or "").upper()
+ currentType = existing_column_types.get(col, "")
+ migration = _SAFE_TYPE_CHANGES.get((currentType, desired_sql))
+ if migration:
+ castExpr = migration.replace("{col}", col)
+ try:
+ cursor.execute('SAVEPOINT col_migrate')
+ cursor.execute(
+ f'ALTER TABLE "{table}" ALTER COLUMN "{col}" TYPE {castExpr}'
+ )
+ cursor.execute('RELEASE SAVEPOINT col_migrate')
+ logger.info(
+ f"Migrated column '{col}' from {currentType} to {desired_sql} on '{table}'"
+ )
+ except Exception as alter_err:
+ cursor.execute('ROLLBACK TO SAVEPOINT col_migrate')
+ logger.warning(
+ f"Could not migrate column '{col}' on '{table}': {alter_err}"
+ )
+ except Exception as ensure_err:
+ logger.warning(
+ f"Could not ensure columns for existing table '{table}': {ensure_err}"
+ )
return True
except Exception as e:
logger.error(f"Error ensuring table {table} exists: {e}")
- if hasattr(self, "connection") and self.connection:
- self.connection.rollback()
return False
def _ensureVectorExtension(self) -> bool:
@@ -666,17 +886,14 @@ class DatabaseConnector:
if self._vectorExtensionEnabled:
return True
try:
- self._ensure_connection()
- with self.connection.cursor() as cursor:
- cursor.execute("CREATE EXTENSION IF NOT EXISTS vector")
- self.connection.commit()
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute("CREATE EXTENSION IF NOT EXISTS vector")
self._vectorExtensionEnabled = True
logger.info("pgvector extension enabled")
return True
except Exception as e:
logger.error(f"Failed to enable pgvector extension: {e}")
- if hasattr(self, "connection") and self.connection:
- self.connection.rollback()
return False
def _create_table_from_model(self, cursor, table: str, model_class: type) -> None:
@@ -791,22 +1008,19 @@ class DatabaseConnector:
if not self._ensureTableExists(model_class):
return None
- with self.connection.cursor() as cursor:
- cursor.execute(f'SELECT * FROM "{table}" WHERE "id" = %s', (recordId,))
- row = cursor.fetchone()
- if not row:
- return None
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(f'SELECT * FROM "{table}" WHERE "id" = %s', (recordId,))
+ row = cursor.fetchone()
+ if not row:
+ return None
+ record = dict(row)
- # Convert row to dict and handle JSONB fields
- record = dict(row)
- fields = getModelFields(model_class)
-
- parseRecordFields(record, fields, f"record {recordId}")
-
- return record
+ fields = getModelFields(model_class)
+ parseRecordFields(record, fields, f"record {recordId}")
+ return record
except Exception as e:
logger.error(f"Error loading record {recordId} from table {table}: {e}")
- _rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e
def getRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]:
@@ -849,14 +1063,12 @@ class DatabaseConnector:
if effective_user_id:
record["sysModifiedBy"] = effective_user_id
- with self.connection.cursor() as cursor:
- self._save_record(cursor, table, recordId, record, model_class)
-
- self.connection.commit()
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ self._save_record(cursor, table, recordId, record, model_class)
return True
except Exception as e:
logger.error(f"Error saving record {recordId} to table {table}: {e}")
- self.connection.rollback()
return False
def _loadTable(self, model_class: type) -> List[Dict[str, Any]]:
@@ -870,33 +1082,32 @@ class DatabaseConnector:
if not self._ensureTableExists(model_class):
return []
- with self.connection.cursor() as cursor:
- cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"')
- records = [dict(row) for row in cursor.fetchall()]
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"')
+ records = [dict(row) for row in cursor.fetchall()]
- fields = getModelFields(model_class)
- modelFields = model_class.model_fields
- for record in records:
- parseRecordFields(record, fields, f"table {table}")
- # Set type-aware defaults for NULL JSONB fields
- for fieldName, fieldType in fields.items():
- if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
- fieldInfo = modelFields.get(fieldName)
- if fieldInfo:
- fieldAnnotation = fieldInfo.annotation
- if (fieldAnnotation == list or
- (hasattr(fieldAnnotation, "__origin__") and
- fieldAnnotation.__origin__ is list)):
- record[fieldName] = []
- elif (fieldAnnotation == dict or
- (hasattr(fieldAnnotation, "__origin__") and
- fieldAnnotation.__origin__ is dict)):
- record[fieldName] = {}
+ fields = getModelFields(model_class)
+ modelFields = model_class.model_fields
+ for record in records:
+ parseRecordFields(record, fields, f"table {table}")
+ for fieldName, fieldType in fields.items():
+ if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
+ fieldInfo = modelFields.get(fieldName)
+ if fieldInfo:
+ fieldAnnotation = fieldInfo.annotation
+ if (fieldAnnotation == list or
+ (hasattr(fieldAnnotation, "__origin__") and
+ fieldAnnotation.__origin__ is list)):
+ record[fieldName] = []
+ elif (fieldAnnotation == dict or
+ (hasattr(fieldAnnotation, "__origin__") and
+ fieldAnnotation.__origin__ is dict)):
+ record[fieldName] = {}
- return records
+ return records
except Exception as e:
logger.error(f"Error loading table {table}: {e}")
- _rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e
def _registerInitialId(self, table: str, initialId: str) -> bool:
@@ -969,28 +1180,20 @@ class DatabaseConnector:
def getTables(self) -> List[str]:
"""Returns a list of all available tables."""
- tables = []
-
+ tables: List[str] = []
try:
- # Ensure connection is alive
- self._ensure_connection()
-
- if not self.connection or self.connection.closed:
- logger.error("Database connection is not available")
- return tables
-
- with self.connection.cursor() as cursor:
- cursor.execute("""
- SELECT table_name
- FROM information_schema.tables
- WHERE table_schema = 'public'
- ORDER BY table_name
- """)
- rows = cursor.fetchall()
- tables = [row["table_name"] for row in rows]
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute("""
+ SELECT table_name
+ FROM information_schema.tables
+ WHERE table_schema = 'public'
+ ORDER BY table_name
+ """)
+ rows = cursor.fetchall()
+ tables = [row["table_name"] for row in rows]
except Exception as e:
logger.error(f"Error reading the database {self.dbDatabase}: {e}")
-
return tables
def getFields(self, model_class: type) -> List[str]:
@@ -1060,43 +1263,42 @@ class DatabaseConnector:
query = f'SELECT * FROM "{table}"{where_clause} ORDER BY "id"'
- with self.connection.cursor() as cursor:
- cursor.execute(query, where_values)
- records = [dict(row) for row in cursor.fetchall()]
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(query, where_values)
+ records = [dict(row) for row in cursor.fetchall()]
- fields = getModelFields(model_class)
- modelFields = model_class.model_fields
+ fields = getModelFields(model_class)
+ modelFields = model_class.model_fields
+ for record in records:
+ parseRecordFields(record, fields, f"table {table}")
+ for fieldName, fieldType in fields.items():
+ if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
+ fieldInfo = modelFields.get(fieldName)
+ if fieldInfo:
+ fieldAnnotation = fieldInfo.annotation
+ if (fieldAnnotation == list or
+ (hasattr(fieldAnnotation, "__origin__") and
+ fieldAnnotation.__origin__ is list)):
+ record[fieldName] = []
+ elif (fieldAnnotation == dict or
+ (hasattr(fieldAnnotation, "__origin__") and
+ fieldAnnotation.__origin__ is dict)):
+ record[fieldName] = {}
+
+ if fieldFilter and isinstance(fieldFilter, list):
+ result = []
for record in records:
- parseRecordFields(record, fields, f"table {table}")
- for fieldName, fieldType in fields.items():
- if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
- fieldInfo = modelFields.get(fieldName)
- if fieldInfo:
- fieldAnnotation = fieldInfo.annotation
- if (fieldAnnotation == list or
- (hasattr(fieldAnnotation, "__origin__") and
- fieldAnnotation.__origin__ is list)):
- record[fieldName] = []
- elif (fieldAnnotation == dict or
- (hasattr(fieldAnnotation, "__origin__") and
- fieldAnnotation.__origin__ is dict)):
- record[fieldName] = {}
+ filteredRecord = {}
+ for field in fieldFilter:
+ if field in record:
+ filteredRecord[field] = record[field]
+ result.append(filteredRecord)
+ return result
- # If fieldFilter is available, reduce the fields
- if fieldFilter and isinstance(fieldFilter, list):
- result = []
- for record in records:
- filteredRecord = {}
- for field in fieldFilter:
- if field in record:
- filteredRecord[field] = record[field]
- result.append(filteredRecord)
- return result
-
- return records
+ return records
except Exception as e:
logger.error(f"Error loading records from table {table}: {e}")
- _rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e
def _buildPaginationClauses(
@@ -1281,35 +1483,36 @@ class DatabaseConnector:
where_clause, order_clause, limit_clause, values, count_values = \
self._buildPaginationClauses(model_class, pagination, recordFilter)
- with self.connection.cursor() as cursor:
- countSql = f'SELECT COUNT(*) FROM "{table}"{where_clause}'
- dataSql = f'SELECT * FROM "{table}"{where_clause}{order_clause}{limit_clause}'
- cursor.execute(countSql, count_values)
- totalItems = cursor.fetchone()["count"]
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ countSql = f'SELECT COUNT(*) FROM "{table}"{where_clause}'
+ dataSql = f'SELECT * FROM "{table}"{where_clause}{order_clause}{limit_clause}'
+ cursor.execute(countSql, count_values)
+ totalItems = cursor.fetchone()["count"]
- cursor.execute(dataSql, values)
- records = [dict(row) for row in cursor.fetchall()]
+ cursor.execute(dataSql, values)
+ records = [dict(row) for row in cursor.fetchall()]
- fields = getModelFields(model_class)
- modelFields = model_class.model_fields
- for record in records:
- parseRecordFields(record, fields, f"table {table}")
- for fieldName, fieldType in fields.items():
- if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
- fieldInfo = modelFields.get(fieldName)
- if fieldInfo:
- fieldAnnotation = fieldInfo.annotation
- if (fieldAnnotation == list or
- (hasattr(fieldAnnotation, "__origin__") and
- fieldAnnotation.__origin__ is list)):
- record[fieldName] = []
- elif (fieldAnnotation == dict or
- (hasattr(fieldAnnotation, "__origin__") and
- fieldAnnotation.__origin__ is dict)):
- record[fieldName] = {}
+ fields = getModelFields(model_class)
+ modelFields = model_class.model_fields
+ for record in records:
+ parseRecordFields(record, fields, f"table {table}")
+ for fieldName, fieldType in fields.items():
+ if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
+ fieldInfo = modelFields.get(fieldName)
+ if fieldInfo:
+ fieldAnnotation = fieldInfo.annotation
+ if (fieldAnnotation == list or
+ (hasattr(fieldAnnotation, "__origin__") and
+ fieldAnnotation.__origin__ is list)):
+ record[fieldName] = []
+ elif (fieldAnnotation == dict or
+ (hasattr(fieldAnnotation, "__origin__") and
+ fieldAnnotation.__origin__ is dict)):
+ record[fieldName] = {}
- if fieldFilter and isinstance(fieldFilter, list):
- records = [{f: r[f] for f in fieldFilter if f in r} for r in records]
+ if fieldFilter and isinstance(fieldFilter, list):
+ records = [{f: r[f] for f in fieldFilter if f in r} for r in records]
from modules.routes.routeHelpers import enrichRowsWithFkLabels
enrichRowsWithFkLabels(records, model_class)
@@ -1320,7 +1523,6 @@ class DatabaseConnector:
return {"items": records, "totalItems": totalItems, "totalPages": totalPages}
except Exception as e:
logger.error(f"Error in getRecordsetPaginated for table {table}: {e}")
- _rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e
def getDistinctColumnValues(
@@ -1365,25 +1567,24 @@ class DatabaseConnector:
else:
sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}" WHERE {nonNullCond} ORDER BY val'
- with self.connection.cursor() as cursor:
- cursor.execute(sql, values)
- result: List[Optional[str]] = [row["val"] for row in cursor.fetchall()]
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(sql, values)
+ result: List[Optional[str]] = [row["val"] for row in cursor.fetchall()]
- if includeEmpty:
- emptyCond = f'"{column}" IS NULL OR "{column}"::TEXT = \'\''
- if where_clause:
- emptySql = f'SELECT 1 FROM "{table}"{where_clause} AND ({emptyCond}) LIMIT 1'
- else:
- emptySql = f'SELECT 1 FROM "{table}" WHERE ({emptyCond}) LIMIT 1'
- with self.connection.cursor() as cursor:
- cursor.execute(emptySql, values)
- if cursor.fetchone():
- result.append(None)
+ if includeEmpty:
+ emptyCond = f'"{column}" IS NULL OR "{column}"::TEXT = \'\''
+ if where_clause:
+ emptySql = f'SELECT 1 FROM "{table}"{where_clause} AND ({emptyCond}) LIMIT 1'
+ else:
+ emptySql = f'SELECT 1 FROM "{table}" WHERE ({emptyCond}) LIMIT 1'
+ cursor.execute(emptySql, values)
+ if cursor.fetchone():
+ result.append(None)
return result
except Exception as e:
logger.error(f"Error in getDistinctColumnValues for {table}.{column}: {e}")
- _rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e
def recordCreate(
@@ -1463,33 +1664,33 @@ class DatabaseConnector:
if not self._ensureTableExists(model_class):
return False
- with self.connection.cursor() as cursor:
- # Check if record exists
- cursor.execute(
- f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,)
- )
- if not cursor.fetchone():
- return False
+ # `getInitialId` opens its own borrow; do it BEFORE we acquire a
+ # connection ourselves so we don't pin two slots concurrently.
+ initialId = self.getInitialId(model_class)
- # Check if it's an initial record
- initialId = self.getInitialId(model_class)
- if initialId is not None and initialId == recordId:
- self._removeInitialId(table)
- logger.info(
- f"Initial ID {recordId} for table {table} has been removed from the system table"
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(
+ f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,)
)
+ if not cursor.fetchone():
+ return False
- # Delete the record
- cursor.execute(f'DELETE FROM "{table}" WHERE "id" = %s', (recordId,))
+ if initialId is not None and initialId == recordId:
+ # `_removeInitialId` borrows its own conn — done outside
+ # this block on purpose to avoid nested borrows.
+ pass
+ cursor.execute(f'DELETE FROM "{table}" WHERE "id" = %s', (recordId,))
- # No cache to update - database handles consistency
-
- self.connection.commit()
+ if initialId is not None and initialId == recordId:
+ self._removeInitialId(table)
+ logger.info(
+ f"Initial ID {recordId} for table {table} has been removed from the system table"
+ )
return True
except Exception as e:
logger.error(f"Error deleting record {recordId} from table {table}: {e}")
- self.connection.rollback()
return False
def recordCreateBulk(
@@ -1559,16 +1760,11 @@ class DatabaseConnector:
)
try:
- self._ensure_connection()
- with self.connection.cursor() as cursor:
- psycopg2.extras.execute_values(cursor, sql, rows, page_size=500)
- self.connection.commit()
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ psycopg2.extras.execute_values(cursor, sql, rows, page_size=500)
except Exception as e:
logger.error(f"Bulk insert into {table} failed (n={len(rows)}): {e}")
- try:
- self.connection.rollback()
- except Exception:
- pass
raise
if self.getInitialId(model_class) is None and normalised:
@@ -1649,26 +1845,21 @@ class DatabaseConnector:
initialId = self.getInitialId(model_class)
try:
- self._ensure_connection()
- with self.connection.cursor() as cursor:
- if initialId is not None:
- cursor.execute(
- f'SELECT 1 FROM "{table}" WHERE "id" = %s AND ' + whereSql,
- [initialId, *params],
- )
- initialIsAffected = cursor.fetchone() is not None
- else:
- initialIsAffected = False
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ if initialId is not None:
+ cursor.execute(
+ f'SELECT 1 FROM "{table}" WHERE "id" = %s AND ' + whereSql,
+ [initialId, *params],
+ )
+ initialIsAffected = cursor.fetchone() is not None
+ else:
+ initialIsAffected = False
- cursor.execute(f'DELETE FROM "{table}" WHERE ' + whereSql, params)
- deleted = cursor.rowcount or 0
- self.connection.commit()
+ cursor.execute(f'DELETE FROM "{table}" WHERE ' + whereSql, params)
+ deleted = cursor.rowcount or 0
except Exception as e:
logger.error(f"Bulk delete from {table} failed (filter={recordFilter}): {e}")
- try:
- self.connection.rollback()
- except Exception:
- pass
raise
if deleted and initialIsAffected:
@@ -1751,39 +1942,30 @@ class DatabaseConnector:
)
params = [vectorStr] + whereValues + [vectorStr, limit]
- with self.connection.cursor() as cursor:
- cursor.execute(query, params)
- records = [dict(row) for row in cursor.fetchall()]
+ with self.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute(query, params)
+ records = [dict(row) for row in cursor.fetchall()]
- fields = getModelFields(modelClass)
- for record in records:
- parseRecordFields(record, fields, f"semanticSearch {table}")
-
- return records
+ fields = getModelFields(modelClass)
+ for record in records:
+ parseRecordFields(record, fields, f"semanticSearch {table}")
+ return records
except Exception as e:
logger.error(f"Error in semantic search on {table}: {e}")
- _rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e
def close(self, forceClose: bool = False):
- """Close the database connection.
-
- Shared cached connectors are intentionally kept open unless forceClose=True.
- This prevents accidental shutdown from interface __del__ methods while
- other requests are still using the same cached connector instance.
+ """No-op for backward compatibility.
+
+ Connections are now owned by the `_PoolRegistry` pool and live for the
+ process lifetime. Pool shutdown happens centrally via `closeAllPools()`
+ from the FastAPI lifespan hook — never from a connector instance.
+ Interface `__del__` paths used to call `close()` to release a per-
+ connector socket; with pooling there is nothing to close here.
"""
- if self._isCachedShared and not forceClose:
- return
- if (
- hasattr(self, "connection")
- and self.connection
- and not self.connection.closed
- ):
- self.connection.close()
+ return
def __del__(self):
- """Cleanup method to close connection."""
- try:
- self.close()
- except Exception:
- pass
+ """Cleanup hook (intentionally no-op — see `close`)."""
+ return
diff --git a/modules/features/realEstate/interfaceFeatureRealEstate.py b/modules/features/realEstate/interfaceFeatureRealEstate.py
index 1fbaf06f..0637d0e9 100644
--- a/modules/features/realEstate/interfaceFeatureRealEstate.py
+++ b/modules/features/realEstate/interfaceFeatureRealEstate.py
@@ -342,7 +342,7 @@ class RealEstateObjects:
# If no exact match, try case-insensitive search via SQL query
# This handles cases where the name might have different casing
self.db._ensure_connection()
- with self.db.connection.cursor() as cursor:
+ with self.db.borrowCursor() as cursor:
cursor.execute(
'SELECT "id" FROM "Gemeinde" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
(name,)
@@ -375,7 +375,7 @@ class RealEstateObjects:
# Try case-insensitive search
self.db._ensure_connection()
- with self.db.connection.cursor() as cursor:
+ with self.db.borrowCursor() as cursor:
cursor.execute(
'SELECT "id" FROM "Kanton" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
(name,)
@@ -408,7 +408,7 @@ class RealEstateObjects:
# Try case-insensitive search
self.db._ensure_connection()
- with self.db.connection.cursor() as cursor:
+ with self.db.borrowCursor() as cursor:
cursor.execute(
'SELECT "id" FROM "Land" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
(name,)
@@ -840,7 +840,7 @@ class RealEstateObjects:
# Ensure connection is alive
self.db._ensure_connection()
- with self.db.connection.cursor() as cursor:
+ with self.db.borrowCursor() as cursor:
# Execute query
if parameters:
# Use parameterized query for safety
diff --git a/modules/interfaces/interfaceDbBilling.py b/modules/interfaces/interfaceDbBilling.py
index 25f022af..273583d9 100644
--- a/modules/interfaces/interfaceDbBilling.py
+++ b/modules/interfaces/interfaceDbBilling.py
@@ -1659,7 +1659,7 @@ class BillingObjects:
try:
appInterface = getAppInterface(self.currentUser)
appInterface.db._ensure_connection()
- with appInterface.db.connection.cursor() as cur:
+ with appInterface.db.borrowCursor() as cur:
if appInterface.db._ensureTableExists(UserInDB):
cur.execute(
'SELECT "id" FROM "UserInDB" WHERE '
@@ -1780,7 +1780,7 @@ class BillingObjects:
try:
self.db._ensure_connection()
- with self.db.connection.cursor() as cur:
+ with self.db.borrowCursor() as cur:
countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
cur.execute(countSql, whereValues)
totalItems = cur.fetchone()["count"]
@@ -1797,10 +1797,7 @@ class BillingObjects:
except Exception as e:
logger.error(f"_searchTransactionsPaginated SQL error: {e}", exc_info=True)
- try:
- self.db.connection.rollback()
- except Exception:
- pass
+ # Rollback is handled by `borrowCursor()` context manager on exit.
return {"items": [], "totalItems": 0, "totalPages": 0}
def _buildScopeFilter(
@@ -1872,7 +1869,7 @@ class BillingObjects:
result: Dict[str, Any] = {}
- with self.db.connection.cursor() as cur:
+ with self.db.borrowCursor() as cur:
# 1) Totals
cur.execute(
f'SELECT COALESCE(SUM("amount"), 0) AS total, COUNT(*) AS cnt FROM "{table}"{whereClause}',
@@ -1947,17 +1944,12 @@ class BillingObjects:
})
result["timeSeries"] = timeSeries
- self.db.connection.commit()
-
+ # Commit/rollback are handled by `borrowCursor()` context manager.
result["_allAccounts"] = allAccounts
return result
except Exception as e:
logger.error(f"Error in getTransactionStatisticsAggregated: {e}", exc_info=True)
- try:
- self.db.connection.rollback()
- except Exception:
- pass
return self._emptyStats()
@staticmethod
diff --git a/modules/interfaces/interfaceDbKnowledge.py b/modules/interfaces/interfaceDbKnowledge.py
index 31a5af61..d7a445bd 100644
--- a/modules/interfaces/interfaceDbKnowledge.py
+++ b/modules/interfaces/interfaceDbKnowledge.py
@@ -228,6 +228,22 @@ class KnowledgeObjects:
"""Get all ContentChunks for a file."""
return self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})
+ def countChunksByFileIds(self, fileIds: List[str]) -> Dict[str, int]:
+ """Return a {fileId: chunkCount} mapping for the given file IDs.
+
+ One aggregate query instead of N round trips. Used by RAG inventory
+ to display real chunk counts per DataSource without loading the
+ embedding vectors. Missing file IDs map to 0 in the caller's logic.
+ """
+ if not fileIds:
+ return {}
+ if not self.db._ensureTableExists(ContentChunk):
+ return {}
+ sql = 'SELECT "fileId", COUNT(*) AS cnt FROM "ContentChunk" WHERE "fileId" = ANY(%s) GROUP BY "fileId"'
+ with self.db.borrowCursor() as cursor:
+ cursor.execute(sql, (list(fileIds),))
+ return {row["fileId"]: int(row["cnt"]) for row in cursor.fetchall()}
+
def deleteContentChunks(self, fileId: str) -> int:
"""Delete all ContentChunks for a file. Returns count of deleted chunks."""
chunks = self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})
diff --git a/modules/interfaces/interfaceDbManagement.py b/modules/interfaces/interfaceDbManagement.py
index 6a3c27b5..4dc8a206 100644
--- a/modules/interfaces/interfaceDbManagement.py
+++ b/modules/interfaces/interfaceDbManagement.py
@@ -1221,22 +1221,17 @@ class ComponentObjects:
for item in fileRows
]
- # Single transaction: delete FileData, FileItem, then FileFolder (children first)
- self.db._ensure_connection()
- try:
- with self.db.connection.cursor() as cursor:
- if fileIds:
- cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (fileIds,))
- cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (fileIds,))
- orderedIds = list(folderIds)
- orderedIds.remove(folderId)
- orderedIds.append(folderId)
- if orderedIds:
- cursor.execute('DELETE FROM "FileFolder" WHERE "id" = ANY(%s)', (orderedIds,))
- self.db.connection.commit()
- except Exception:
- self.db.connection.rollback()
- raise
+ # Single transaction: delete FileData, FileItem, then FileFolder (children first).
+ # Commit/rollback are handled by `borrowCursor()` on exit.
+ with self.db.borrowCursor() as cursor:
+ if fileIds:
+ cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (fileIds,))
+ cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (fileIds,))
+ orderedIds = list(folderIds)
+ orderedIds.remove(folderId)
+ orderedIds.append(folderId)
+ if orderedIds:
+ cursor.execute('DELETE FROM "FileFolder" WHERE "id" = ANY(%s)', (orderedIds,))
return {"deletedFolders": len(folderIds), "deletedFiles": len(fileIds)}
@@ -1507,7 +1502,7 @@ class ComponentObjects:
try:
self.db._ensure_connection()
- with self.db.connection.cursor() as cursor:
+ with self.db.borrowCursor() as cursor:
cursor.execute(
'SELECT "id", "sysCreatedBy" FROM "FileItem" WHERE "id" = ANY(%s)',
(uniqueIds,),
@@ -1526,11 +1521,10 @@ class ComponentObjects:
cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (accessibleIds,))
deletedFiles = cursor.rowcount
- self.db.connection.commit()
+ # Commit/rollback are handled by `borrowCursor()` context manager.
return {"deletedFiles": deletedFiles}
except Exception as e:
logger.error(f"Error deleting files in batch: {e}")
- self.db.connection.rollback()
raise FileDeletionError(f"Error deleting files in batch: {str(e)}")
def _ensureFeatureInstanceGroup(self, featureInstanceId: str, contextKey: str = "files/list") -> Optional[str]:
diff --git a/modules/interfaces/interfaceRbac.py b/modules/interfaces/interfaceRbac.py
index e41485e0..948609ef 100644
--- a/modules/interfaces/interfaceRbac.py
+++ b/modules/interfaces/interfaceRbac.py
@@ -374,7 +374,7 @@ def getRecordsetWithRBAC(
query = f'SELECT * FROM "{table}"{whereClause}{orderByClause}{limitClause}'
- with connector.connection.cursor() as cursor:
+ with connector.borrowCursor() as cursor:
cursor.execute(query, whereValues)
records = [dict(row) for row in cursor.fetchall()]
@@ -561,7 +561,7 @@ def getRecordsetPaginatedWithRBAC(
offset = (pagination.page - 1) * pagination.pageSize
limitClause = f" LIMIT {pagination.pageSize} OFFSET {offset}"
- with connector.connection.cursor() as cursor:
+ with connector.borrowCursor() as cursor:
countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
cursor.execute(countSql, countValues)
totalItems = cursor.fetchone()["count"]
@@ -709,7 +709,7 @@ def getDistinctColumnValuesWithRBAC(
sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{nonNullWhere} ORDER BY val'
- with connector.connection.cursor() as cursor:
+ with connector.borrowCursor() as cursor:
cursor.execute(sql, whereValues)
result = [row["val"] for row in cursor.fetchall()]
@@ -719,7 +719,7 @@ def getDistinctColumnValuesWithRBAC(
emptySql = f'SELECT 1 FROM "{table}"{whereClause} AND {emptyCond} LIMIT 1'
else:
emptySql = f'SELECT 1 FROM "{table}" WHERE {emptyCond} LIMIT 1'
- with connector.connection.cursor() as cursor:
+ with connector.borrowCursor() as cursor:
cursor.execute(emptySql, whereValues)
if cursor.fetchone():
result.append(None)
@@ -967,7 +967,7 @@ def buildRbacWhereClause(
# Multi-Tenant Design: Users do NOT have mandateId - they are linked via UserMandate
if table == "UserInDB":
try:
- with connector.connection.cursor() as cursor:
+ with connector.borrowCursor() as cursor:
# Get all user IDs that are members of the current mandate
cursor.execute(
'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true',
@@ -994,7 +994,7 @@ def buildRbacWhereClause(
# For UserConnection: Filter via UserMandate junction table
elif table == "UserConnection":
try:
- with connector.connection.cursor() as cursor:
+ with connector.borrowCursor() as cursor:
# Get all user IDs that are members of the current mandate
cursor.execute(
'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true',
diff --git a/modules/routes/routeHelpers.py b/modules/routes/routeHelpers.py
index f1d88e31..bb1386af 100644
--- a/modules/routes/routeHelpers.py
+++ b/modules/routes/routeHelpers.py
@@ -305,7 +305,7 @@ def handleIdsMode(
sql = f'SELECT "{idField}"::TEXT AS val FROM "{table}"{where_clause} ORDER BY "{idField}"'
- with db.connection.cursor() as cursor:
+ with db.borrowCursor() as cursor:
cursor.execute(sql, values)
return JSONResponse(content=[row["val"] for row in cursor.fetchall()])
except Exception as e:
diff --git a/modules/routes/routeRagInventory.py b/modules/routes/routeRagInventory.py
index 074b5b85..7c426d77 100644
--- a/modules/routes/routeRagInventory.py
+++ b/modules/routes/routeRagInventory.py
@@ -25,6 +25,18 @@ router = APIRouter(
def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> List[Dict[str, Any]]:
+ """Build per-connection RAG inventory rows.
+
+ Each DataSource row exposes BOTH numbers because they mean different things:
+ * `fileCount` — distinct files indexed (== `FileContentIndex` rows)
+ * `chunkCount` — embedding-sized text fragments (== `ContentChunk` rows,
+ max `DEFAULT_CHUNK_TOKENS` tokens each, what the vector retrieval
+ actually hits)
+
+ A single PDF typically yields 1 file × 5–100 chunks; legacy UI labelled
+ `len(FileContentIndex)` as "chunks" which was off by 1–2 orders of
+ magnitude and misleading.
+ """
from modules.datamodels.datamodelDataSource import DataSource
from modules.datamodels.datamodelKnowledge import FileContentIndex
@@ -34,19 +46,35 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
connIndexRows = knowledgeIf.db.getRecordset(FileContentIndex, recordFilter={"connectionId": connectionId})
- connChunkTotal = len(connIndexRows)
+ connFileTotal = len(connIndexRows)
+ # Map fileId → real chunk count via 1 aggregate query (cheap even for
+ # connections with thousands of files; we never load the vector body).
+ fileIds = [
+ (idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", ""))
+ for idx in connIndexRows
+ ]
+ fileIds = [fid for fid in fileIds if fid]
+ chunkCountByFile = knowledgeIf.countChunksByFileIds(fileIds) if fileIds else {}
+ connChunkTotal = sum(chunkCountByFile.values())
+
+ filesByDs: Dict[str, int] = {}
chunksByDs: Dict[str, int] = {}
- unassigned = 0
+ unassignedFiles = 0
+ unassignedChunks = 0
for idx in connIndexRows:
+ fileId = idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", "")
+ chunkCnt = chunkCountByFile.get(fileId, 0)
struct = (idx.get("structure") if isinstance(idx, dict) else getattr(idx, "structure", None)) or {}
ingestion = struct.get("_ingestion") or {} if isinstance(struct, dict) else {}
prov = ingestion.get("provenance") or {} if isinstance(ingestion, dict) else {}
dsIdRef = prov.get("dataSourceId", "") if isinstance(prov, dict) else ""
if dsIdRef:
- chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + 1
+ filesByDs[dsIdRef] = filesByDs.get(dsIdRef, 0) + 1
+ chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + chunkCnt
else:
- unassigned += 1
+ unassignedFiles += 1
+ unassignedChunks += chunkCnt
seen: Dict[str, bool] = {}
dsItems = []
@@ -64,14 +92,19 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"ragIndexEnabled": ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False),
"neutralize": ds.get("neutralize") if isinstance(ds, dict) else getattr(ds, "neutralize", False),
"lastIndexed": ds.get("lastIndexed") if isinstance(ds, dict) else getattr(ds, "lastIndexed", None),
+ "fileCount": filesByDs.get(dsId, 0),
"chunkCount": chunksByDs.get(dsId, 0),
})
- if unassigned > 0 and len(dsItems) > 0:
- perDs = unassigned // len(dsItems)
- remainder = unassigned % len(dsItems)
+ # Spread orphan files (provenance lost) evenly so totals match.
+ if unassignedFiles > 0 and len(dsItems) > 0:
+ perFile = unassignedFiles // len(dsItems)
+ remFile = unassignedFiles % len(dsItems)
+ perChunk = unassignedChunks // len(dsItems)
+ remChunk = unassignedChunks % len(dsItems)
for i, item in enumerate(dsItems):
- item["chunkCount"] += perDs + (1 if i < remainder else 0)
+ item["fileCount"] += perFile + (1 if i < remFile else 0)
+ item["chunkCount"] += perChunk + (1 if i < remChunk else 0)
# Pull a wider window than the previous 5 so the "last successful
# sync" is found even if a connection has many recent jobs queued.
@@ -102,6 +135,12 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"skippedPolicy": result.get("skippedPolicy", 0),
"failed": result.get("failed", 0),
"durationMs": result.get("durationMs", 0),
+ # Surface limit-stop reason so the UI can warn the user
+ # that the index is provably incomplete (and which budget
+ # to raise). None means the walker finished naturally.
+ "stoppedAtLimit": result.get("stoppedAtLimit"),
+ "limits": result.get("limits") or {},
+ "bytesProcessed": result.get("bytesProcessed", 0),
}
if lastError and lastSuccess:
break
@@ -113,6 +152,7 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"knowledgeIngestionEnabled": getattr(conn, "knowledgeIngestionEnabled", False),
"preferences": getattr(conn, "knowledgePreferences", None) or {},
"dataSources": dsItems,
+ "totalFiles": connFileTotal,
"totalChunks": connChunkTotal,
"runningJobs": runningJobs,
"lastError": lastError,
@@ -139,8 +179,9 @@ def _getInventoryMe(
items = _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService)
totalChunks = sum(c.get("totalChunks", 0) for c in items)
+ totalFiles = sum(c.get("totalFiles", 0) for c in items)
- return {"connections": items, "totals": {"chunks": totalChunks}}
+ return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}}
except Exception as e:
logger.error("Error in RAG inventory /me: %s", e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@@ -170,9 +211,10 @@ def _getInventoryMandate(
items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
totalChunks = sum(c.get("totalChunks", 0) for c in items)
+ totalFiles = sum(c.get("totalFiles", 0) for c in items)
totalBytes = aggregateMandateRagTotalBytes(mandateId)
- return {"connections": items, "totals": {"chunks": totalChunks, "bytes": totalBytes}}
+ return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks, "bytes": totalBytes}}
except HTTPException:
raise
except Exception as e:
@@ -202,8 +244,9 @@ def _getInventoryPlatform(
items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
totalChunks = sum(c.get("totalChunks", 0) for c in items)
+ totalFiles = sum(c.get("totalFiles", 0) for c in items)
- return {"connections": items, "totals": {"chunks": totalChunks}}
+ return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}}
except HTTPException:
raise
except Exception as e:
diff --git a/modules/routes/routeWorkflowDashboard.py b/modules/routes/routeWorkflowDashboard.py
index d83ce1b2..85b372a1 100644
--- a/modules/routes/routeWorkflowDashboard.py
+++ b/modules/routes/routeWorkflowDashboard.py
@@ -227,7 +227,7 @@ WHERE "workflowId" = ANY(%s)
GROUP BY "workflowId"
"""
out: dict = {}
- with db.connection.cursor() as cursor:
+ with db.borrowCursor() as cursor:
cursor.execute(sql, (workflowIds,))
for row in cursor.fetchall():
r = dict(row)
@@ -480,7 +480,7 @@ def _getWorkflowsJoinedPaginated(
dataSql = f"SELECT w.*, rs.\"lastStartedAt\", rs.\"runCount\", rs.\"activeRunId\" FROM {fromSql}{whereClause}{orderClause}{limitClause}"
db._ensure_connection()
- with db.connection.cursor() as cursor:
+ with db.borrowCursor() as cursor:
cursor.execute(countSql, countValues)
totalItems = int(cursor.fetchone()["cnt"])
diff --git a/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py b/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py
index 4fbea490..bdb3d23b 100644
--- a/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py
+++ b/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py
@@ -25,15 +25,14 @@ _CACHE_TTL_SECONDS = 300
def _getOrCreateFeatureDbConnector(featureDbName: str, userId: str):
- """Reuse a pooled DB connector for the given feature database."""
+ """Reuse a pooled DB connector for the given feature database.
+
+ The underlying psycopg2 connections live in the central pool
+ (`_PoolRegistry`) and are recreated on demand if they go stale; we just
+ need to keep the lightweight connector wrapper around.
+ """
if featureDbName in _featureDbConnPool:
- conn = _featureDbConnPool[featureDbName]
- try:
- if conn.connection and not conn.connection.closed:
- return conn
- except Exception as e:
- logger.warning(f"Feature DB connection check failed for {featureDbName}: {e}")
- _featureDbConnPool.pop(featureDbName, None)
+ return _featureDbConnPool[featureDbName]
from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG
diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py
index 8bfa2628..959e42c9 100644
--- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py
+++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py
@@ -68,6 +68,9 @@ class ClickupBootstrapResult:
workspaces: int = 0
lists: int = 0
errors: List[str] = field(default_factory=list)
+ # First budget exhausted: "maxTasks" | "maxWorkspaces" | "maxListsPerWorkspace" | None.
+ # Drives the same UI banner as the file-walker bootstraps.
+ stoppedAtLimit: Optional[str] = None
def _syntheticTaskId(connectionId: str, taskId: str) -> str:
@@ -225,6 +228,7 @@ async def bootstrapClickup(
cancelled = False
for ds in dataSources:
if result.indexed + result.skippedDuplicate >= limits.maxTasks:
+ _recordLimitStop(result, "maxTasks", "dataSource", limits)
break
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
cancelled = True
@@ -243,8 +247,11 @@ async def bootstrapClickup(
clickupScope=limits.clickupScope,
)
+ if len(teams) > dsLimits.maxWorkspaces:
+ _recordLimitStop(result, "maxWorkspaces", "teams", dsLimits, hard=False)
for team in teams[:dsLimits.maxWorkspaces]:
if result.indexed + result.skippedDuplicate >= dsLimits.maxTasks:
+ _recordLimitStop(result, "maxTasks", f"team={team.get('id','')}", dsLimits)
break
teamId = str(team.get("id", "") or "")
if not teamId:
@@ -351,6 +358,7 @@ async def _walkTeam(
for lst in listsCollected:
if result.indexed + result.skippedDuplicate >= limits.maxTasks:
+ _recordLimitStop(result, "maxTasks", f"team={teamId}", limits)
return
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
return
@@ -407,6 +415,7 @@ async def _walkList(
for task in tasks:
if result.indexed + result.skippedDuplicate >= limits.maxTasks:
+ _recordLimitStop(result, "maxTasks", f"list={listId}", limits)
return
if not _isRecent(task.get("date_updated"), limits.maxAgeDays):
result.skippedPolicy += 1
@@ -529,13 +538,37 @@ async def _ingestTask(
)
+def _recordLimitStop(
+ result: ClickupBootstrapResult,
+ limitName: str,
+ where: str,
+ limits: ClickupBootstrapLimits,
+ *,
+ hard: bool = True,
+) -> None:
+ """See subConnectorSyncSharepoint._recordLimitStop for semantics."""
+ if hard or result.stoppedAtLimit is None:
+ result.stoppedAtLimit = limitName
+ budgetMap = {
+ "maxTasks": limits.maxTasks,
+ "maxWorkspaces": limits.maxWorkspaces,
+ "maxListsPerWorkspace": limits.maxListsPerWorkspace,
+ }
+ logger.warning(
+ "clickup walker hit %s=%s at %s — partial index (indexed=%d, skippedDup=%d).",
+ limitName, budgetMap.get(limitName), where,
+ result.indexed, result.skippedDuplicate,
+ )
+
+
def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000)
logger.info(
- "ingestion.connection.bootstrap.done part=clickup connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d workspaces=%d lists=%d durationMs=%d",
+ "ingestion.connection.bootstrap.done part=clickup connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d workspaces=%d lists=%d durationMs=%d stoppedAtLimit=%s",
connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy,
result.failed, result.workspaces, result.lists, durationMs,
+ result.stoppedAtLimit or "none",
extra={
"event": "ingestion.connection.bootstrap.done",
"part": "clickup",
@@ -547,6 +580,7 @@ def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs:
"workspaces": result.workspaces,
"lists": result.lists,
"durationMs": durationMs,
+ "stoppedAtLimit": result.stoppedAtLimit,
},
)
return {
@@ -559,4 +593,11 @@ def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs:
"lists": result.lists,
"durationMs": durationMs,
"errors": result.errors[:20],
+ "stoppedAtLimit": result.stoppedAtLimit,
+ "limits": {
+ "maxTasks": MAX_TASKS_DEFAULT,
+ "maxWorkspaces": MAX_WORKSPACES_DEFAULT,
+ "maxListsPerWorkspace": MAX_LISTS_PER_WORKSPACE_DEFAULT,
+ "maxAgeDays": MAX_AGE_DAYS_DEFAULT,
+ },
}
diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py
index 5dd1bd8b..e27abacb 100644
--- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py
+++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py
@@ -61,6 +61,8 @@ class GdriveBootstrapResult:
failed: int = 0
bytesProcessed: int = 0
errors: List[str] = field(default_factory=list)
+ # See SharepointBootstrapResult.stoppedAtLimit — same semantics.
+ stoppedAtLimit: Optional[str] = None
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
@@ -265,8 +267,10 @@ async def _walkFolder(
for entry in entries:
if result.indexed + result.skippedDuplicate >= limits.maxItems:
+ _recordLimitStop(result, "maxItems", folderPath, limits)
return
if result.bytesProcessed >= limits.maxBytes:
+ _recordLimitStop(result, "maxBytes", folderPath, limits)
return
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
return
@@ -276,6 +280,9 @@ async def _walkFolder(
mimeType = getattr(entry, "mimeType", None) or metadata.get("mimeType")
if getattr(entry, "isFolder", False) or mimeType == FOLDER_MIME:
+ if depth + 1 > limits.maxDepth:
+ _recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
+ continue
await _walkFolder(
adapter=adapter,
knowledgeService=knowledgeService,
@@ -298,6 +305,7 @@ async def _walkFolder(
continue
size = int(getattr(entry, "size", 0) or 0)
if size and size > limits.maxFileSize:
+ _recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
result.skippedPolicy += 1
continue
modifiedTime = metadata.get("modifiedTime")
@@ -470,13 +478,38 @@ async def _ingestOne(
await asyncio.sleep(0)
+def _recordLimitStop(
+ result: GdriveBootstrapResult,
+ limitName: str,
+ where: str,
+ limits: GdriveBootstrapLimits,
+ *,
+ hard: bool = True,
+) -> None:
+ """See subConnectorSyncSharepoint._recordLimitStop for semantics."""
+ if hard or result.stoppedAtLimit is None:
+ result.stoppedAtLimit = limitName
+ budgetMap = {
+ "maxItems": limits.maxItems,
+ "maxBytes": limits.maxBytes,
+ "maxDepth": limits.maxDepth,
+ "maxFileSize": limits.maxFileSize,
+ }
+ logger.warning(
+ "gdrive walker hit %s=%s at %s — partial index (indexed=%d, bytesProcessed=%d).",
+ limitName, budgetMap.get(limitName), where,
+ result.indexed, result.bytesProcessed,
+ )
+
+
def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000)
logger.info(
- "ingestion.connection.bootstrap.done part=gdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d bytes=%d durationMs=%d",
+ "ingestion.connection.bootstrap.done part=gdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d bytes=%d durationMs=%d stoppedAtLimit=%s",
connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy,
result.failed, result.bytesProcessed, durationMs,
+ result.stoppedAtLimit or "none",
extra={
"event": "ingestion.connection.bootstrap.done",
"part": "gdrive",
@@ -487,6 +520,7 @@ def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: f
"failed": result.failed,
"bytes": result.bytesProcessed,
"durationMs": durationMs,
+ "stoppedAtLimit": result.stoppedAtLimit,
},
)
return {
@@ -498,4 +532,11 @@ def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: f
"bytesProcessed": result.bytesProcessed,
"durationMs": durationMs,
"errors": result.errors[:20],
+ "stoppedAtLimit": result.stoppedAtLimit,
+ "limits": {
+ "maxItems": MAX_ITEMS_DEFAULT,
+ "maxBytes": MAX_BYTES_DEFAULT,
+ "maxFileSize": MAX_FILE_SIZE_DEFAULT,
+ "maxDepth": MAX_DEPTH_DEFAULT,
+ },
}
diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py
index e656abe8..dcf19e39 100644
--- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py
+++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py
@@ -53,6 +53,8 @@ class KdriveBootstrapResult:
failed: int = 0
bytesProcessed: int = 0
errors: List[str] = field(default_factory=list)
+ # See SharepointBootstrapResult.stoppedAtLimit — same semantics.
+ stoppedAtLimit: Optional[str] = None
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
@@ -232,14 +234,19 @@ async def _walkFolder(
for entry in entries:
if result.indexed + result.skippedDuplicate >= limits.maxItems:
+ _recordLimitStop(result, "maxItems", folderPath, limits)
return
if result.bytesProcessed >= limits.maxBytes:
+ _recordLimitStop(result, "maxBytes", folderPath, limits)
return
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
return
entryPath = getattr(entry, "path", "") or ""
if getattr(entry, "isFolder", False):
+ if depth + 1 > limits.maxDepth:
+ _recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
+ continue
await _walkFolder(
adapter=adapter,
knowledgeService=knowledgeService,
@@ -262,6 +269,7 @@ async def _walkFolder(
continue
size = int(getattr(entry, "size", 0) or 0)
if size and size > limits.maxFileSize:
+ _recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
result.skippedPolicy += 1
continue
@@ -415,17 +423,42 @@ async def _ingestOne(
await asyncio.sleep(0)
+def _recordLimitStop(
+ result: KdriveBootstrapResult,
+ limitName: str,
+ where: str,
+ limits: KdriveBootstrapLimits,
+ *,
+ hard: bool = True,
+) -> None:
+ """See subConnectorSyncSharepoint._recordLimitStop for semantics."""
+ if hard or result.stoppedAtLimit is None:
+ result.stoppedAtLimit = limitName
+ budgetMap = {
+ "maxItems": limits.maxItems,
+ "maxBytes": limits.maxBytes,
+ "maxDepth": limits.maxDepth,
+ "maxFileSize": limits.maxFileSize,
+ }
+ logger.warning(
+ "kdrive walker hit %s=%s at %s — partial index (indexed=%d, bytesProcessed=%d).",
+ limitName, budgetMap.get(limitName), where,
+ result.indexed, result.bytesProcessed,
+ )
+
+
def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000)
logger.info(
- "ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d",
+ "ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d stoppedAtLimit=%s",
connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed,
- durationMs,
+ durationMs, result.stoppedAtLimit or "none",
extra={"event": "ingestion.connection.bootstrap.done", "part": "kdrive",
"connectionId": connectionId, "indexed": result.indexed,
"skippedDup": result.skippedDuplicate, "skippedPolicy": result.skippedPolicy,
- "failed": result.failed, "durationMs": durationMs},
+ "failed": result.failed, "durationMs": durationMs,
+ "stoppedAtLimit": result.stoppedAtLimit},
)
return {
"connectionId": result.connectionId,
@@ -436,4 +469,11 @@ def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: f
"bytesProcessed": result.bytesProcessed,
"durationMs": durationMs,
"errors": result.errors[:20],
+ "stoppedAtLimit": result.stoppedAtLimit,
+ "limits": {
+ "maxItems": MAX_ITEMS_DEFAULT,
+ "maxBytes": MAX_BYTES_DEFAULT,
+ "maxFileSize": MAX_FILE_SIZE_DEFAULT,
+ "maxDepth": MAX_DEPTH_DEFAULT,
+ },
}
diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py
index 892e41ba..e06fd36b 100644
--- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py
+++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py
@@ -59,6 +59,10 @@ class SharepointBootstrapResult:
failed: int = 0
bytesProcessed: int = 0
errors: List[str] = field(default_factory=list)
+ # First budget that hit zero; None means the walk completed naturally.
+ # Surfaces in the bootstrap result so the RAG inventory UI can warn the
+ # user that the corpus is incomplete and tell them which knob to turn.
+ stoppedAtLimit: Optional[str] = None # "maxItems" | "maxBytes" | "maxDepth" | "maxFileSize" | None
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
@@ -259,14 +263,22 @@ async def _walkFolder(
for entry in entries:
if result.indexed + result.skippedDuplicate >= limits.maxItems:
+ _recordLimitStop(result, "maxItems", folderPath, limits)
return
if result.bytesProcessed >= limits.maxBytes:
+ _recordLimitStop(result, "maxBytes", folderPath, limits)
return
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
return
entryPath = getattr(entry, "path", "") or ""
if getattr(entry, "isFolder", False):
+ if depth + 1 > limits.maxDepth:
+ # We stop descending here but keep walking siblings.
+ # Record once per bootstrap so the UI shows "maxDepth" even
+ # if other budgets aren't exhausted yet.
+ _recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
+ continue
await _walkFolder(
adapter=adapter,
knowledgeService=knowledgeService,
@@ -289,6 +301,7 @@ async def _walkFolder(
continue
size = int(getattr(entry, "size", 0) or 0)
if size and size > limits.maxFileSize:
+ _recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
result.skippedPolicy += 1
continue
@@ -443,13 +456,44 @@ async def _ingestOne(
await asyncio.sleep(0)
+def _recordLimitStop(
+ result: SharepointBootstrapResult,
+ limitName: str,
+ where: str,
+ limits: SharepointBootstrapLimits,
+ *,
+ hard: bool = True,
+) -> None:
+ """Mark the FIRST limit that bit. Soft hits (per-file maxFileSize, per-folder
+ maxDepth) only record when no hard limit has yet stopped the run, so the UI
+ surfaces the most important reason.
+
+ Hard limits (maxItems / maxBytes) ALWAYS overwrite a previously recorded
+ soft limit — once a hard cap is hit, the corpus is provably incomplete.
+ """
+ if hard or result.stoppedAtLimit is None:
+ result.stoppedAtLimit = limitName
+ budgetMap = {
+ "maxItems": limits.maxItems,
+ "maxBytes": limits.maxBytes,
+ "maxDepth": limits.maxDepth,
+ "maxFileSize": limits.maxFileSize,
+ }
+ logger.warning(
+ "sharepoint walker hit %s=%s at %s — partial index "
+ "(indexed=%d, bytesProcessed=%d). Raise the limit or split the data source.",
+ limitName, budgetMap.get(limitName), where,
+ result.indexed, result.bytesProcessed,
+ )
+
+
def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000)
logger.info(
- "ingestion.connection.bootstrap.done part=sharepoint connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d",
+ "ingestion.connection.bootstrap.done part=sharepoint connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d stoppedAtLimit=%s",
connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed,
- durationMs,
+ durationMs, result.stoppedAtLimit or "none",
extra={
"event": "ingestion.connection.bootstrap.done",
"part": "sharepoint",
@@ -459,6 +503,7 @@ def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startM
"skippedPolicy": result.skippedPolicy,
"failed": result.failed,
"durationMs": durationMs,
+ "stoppedAtLimit": result.stoppedAtLimit,
},
)
return {
@@ -470,4 +515,11 @@ def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startM
"bytesProcessed": result.bytesProcessed,
"durationMs": durationMs,
"errors": result.errors[:20],
+ "stoppedAtLimit": result.stoppedAtLimit,
+ "limits": {
+ "maxItems": MAX_ITEMS_DEFAULT,
+ "maxBytes": MAX_BYTES_DEFAULT,
+ "maxFileSize": MAX_FILE_SIZE_DEFAULT,
+ "maxDepth": MAX_DEPTH_DEFAULT,
+ },
}
diff --git a/modules/shared/configuration.py b/modules/shared/configuration.py
index 721ce448..15646962 100644
--- a/modules/shared/configuration.py
+++ b/modules/shared/configuration.py
@@ -12,7 +12,8 @@ import logging
import json
import base64
import time
-from typing import Any, Dict, Optional
+import threading
+from typing import Any, Dict, Optional, Tuple
from pathlib import Path
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
@@ -286,6 +287,16 @@ def handleSecretJson(value: str, userId: str = "system", keyName: str = "unknown
# Structure: {user_id: {key_name: [timestamps]}}
_decryption_attempts = {}
+# Process-wide plaintext cache for decrypted secrets.
+# Key: the encrypted ciphertext (which already includes env prefix).
+# Value: (expiresAtMonotonic, plaintext).
+# TTL is short enough that key rotation propagates quickly, long enough that
+# hot DB-init paths (every API call building a connector) don't blow the
+# decryption rate limit. 60s is a deliberate compromise.
+_DECRYPTION_CACHE_TTL_S = 60.0
+_decryption_cache: Dict[str, Tuple[float, str]] = {}
+_decryption_cache_lock = threading.Lock()
+
def _getMasterKey(envType: str = None) -> bytes:
"""
Get the master key for the specified environment.
@@ -486,25 +497,43 @@ def encryptValue(value: str, envType: str = None, userId: str = "system", keyNam
def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "unknown") -> str:
"""
Decrypt a value using the master key for the current environment.
-
+
+ A short-lived plaintext cache (TTL `_DECRYPTION_CACHE_TTL_S`) is consulted
+ first. The 10/sec rate-limit on cache misses still protects against
+ brute-force attacks; cache HITS bypass it because they are not actual
+ cryptographic operations — they just return the result of an earlier
+ successful decrypt. Without this cache, hot paths like
+ `mainBackgroundJobService._getDb()` (called per RAG inventory poll AND
+ per walker DB call) trigger the rate limit and surface as
+ "Decryption rate limit exceeded for user 'system' key 'DB_PASSWORD_SECRET'"
+ ERRORs in the RAG inventory UI route.
+
Args:
encryptedValue: The encrypted value with prefix
userId: The user ID making the request (default: "system")
keyName: The name of the key being decrypted (default: "unknown")
-
+
Returns:
str: The decrypted plain text value
-
+
Raises:
ValueError: If decryption fails
"""
if not _isEncryptedValue(encryptedValue):
return encryptedValue # Return as-is if not encrypted
-
- # Check rate limiting (10 per second per user per key)
+
+ # Cache lookup BEFORE the rate-limit check: a cache hit is not a new
+ # cryptographic operation and must not be throttled.
+ now = time.monotonic()
+ with _decryption_cache_lock:
+ cached = _decryption_cache.get(encryptedValue)
+ if cached is not None and cached[0] > now:
+ return cached[1]
+
+ # Cache miss → real decrypt → apply rate limit.
if not _checkDecryptionRateLimit(userId, keyName, maxPerSecond=10):
raise ValueError(f"Decryption rate limit exceeded for user '{userId}' key '{keyName}' (10/sec)")
-
+
try:
# Extract environment type from prefix
if encryptedValue.startswith('DEV_ENC:'):
@@ -536,7 +565,7 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un
encryptedBytes = base64.urlsafe_b64decode(encryptedPart.encode('utf-8'))
decryptedBytes = fernet.decrypt(encryptedBytes)
decryptedValue = decryptedBytes.decode('utf-8')
-
+
# Log audit event for decryption
try:
from modules.shared.auditLogger import audit_logger
@@ -549,11 +578,25 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un
except Exception:
# Don't fail if audit logging fails
pass
-
+
+ # Populate cache so subsequent reads of the same ciphertext don't
+ # re-decrypt (and don't consume rate-limit budget).
+ with _decryption_cache_lock:
+ _decryption_cache[encryptedValue] = (
+ time.monotonic() + _DECRYPTION_CACHE_TTL_S,
+ decryptedValue,
+ )
+
return decryptedValue
-
+
except Exception as e:
raise ValueError(f"Decryption failed: {e}")
+
+def clearDecryptionCache() -> None:
+ """Drop all cached plaintext secrets. Call after key rotation or in tests."""
+ with _decryption_cache_lock:
+ _decryption_cache.clear()
+
# Create the global APP_CONFIG instance
APP_CONFIG = Configuration()
\ No newline at end of file
diff --git a/modules/shared/dbMultiTenantOptimizations.py b/modules/shared/dbMultiTenantOptimizations.py
index c178c376..9b5a15b4 100644
--- a/modules/shared/dbMultiTenantOptimizations.py
+++ b/modules/shared/dbMultiTenantOptimizations.py
@@ -33,20 +33,35 @@ def _ensureUamTablesMatchModels(dbConnector) -> None:
logger.debug(f"_ensureUamTablesMatchModels: {e}")
-def _getConnection(dbConnector):
- """Get a connection from the DatabaseConnector.
-
- Ensures the connection is alive and returns it.
- Commits any pending transaction first to avoid blocking.
+from contextlib import contextmanager
+
+
+@contextmanager
+def _borrowDbConn(dbConnector):
+ """Borrow a pooled connection from the DatabaseConnector.
+
+ Index/trigger/FK creation traditionally ran with `conn.autocommit = True`
+ so each CREATE statement is its own transaction (DDL on a managed
+ connection blocks waiting for COMMIT). This helper preserves that
+ behaviour on top of the pool: borrow a connection, flip it to autocommit,
+ yield it, and restore the previous state before returning it to the pool.
"""
- dbConnector._ensure_connection()
- conn = dbConnector.connection
- # Commit any pending transaction to avoid blocking
- try:
- conn.commit()
- except Exception:
- pass # Ignore if nothing to commit
- return conn
+ with dbConnector.borrowConn() as conn:
+ try:
+ previousAutocommit = conn.autocommit
+ except Exception:
+ previousAutocommit = False
+ try:
+ conn.autocommit = True
+ except Exception as e:
+ logger.debug(f"Could not set autocommit on borrowed connection: {e}")
+ try:
+ yield conn
+ finally:
+ try:
+ conn.autocommit = previousAutocommit
+ except Exception:
+ pass
# =============================================================================
@@ -174,73 +189,42 @@ def applyMultiTenantOptimizations(dbConnector, tables: Optional[List[str]] = Non
}
try:
- # Get a connection from the connector
- conn = _getConnection(dbConnector)
-
- # Save and set autocommit state
- try:
- originalAutocommit = conn.autocommit
- except Exception:
- originalAutocommit = False
-
- try:
- conn.autocommit = True
- except Exception as autoErr:
- logger.debug(f"Could not set autocommit: {autoErr}")
-
try:
_ensureUamTablesMatchModels(dbConnector)
except Exception as preIdxErr:
logger.debug(f"Pre-index table ensure: {preIdxErr}")
-
- try:
+
+ with _borrowDbConn(dbConnector) as conn:
with conn.cursor() as cursor:
- # Apply indexes
results["indexesCreated"] = _applyIndexes(cursor, tables)
-
- # Apply foreign keys
results["foreignKeysCreated"] = _applyForeignKeys(cursor, tables)
-
- # Apply immutable triggers
results["triggersCreated"] = _applyImmutableTriggers(cursor, tables)
-
- logger.info(
- f"Multi-tenant optimizations applied: "
- f"{results['indexesCreated']} indexes, "
- f"{results['triggersCreated']} triggers, "
- f"{results['foreignKeysCreated']} foreign keys"
- )
- finally:
- # Restore original autocommit state
- try:
- conn.autocommit = originalAutocommit
- except Exception:
- pass
-
+
+ logger.info(
+ f"Multi-tenant optimizations applied: "
+ f"{results['indexesCreated']} indexes, "
+ f"{results['triggersCreated']} triggers, "
+ f"{results['foreignKeysCreated']} foreign keys"
+ )
+
except Exception as e:
logger.error(f"Error applying multi-tenant optimizations: {type(e).__name__}: {e}")
results["errors"].append(str(e))
-
+
return results
def applyIndexesOnly(dbConnector, tables: Optional[List[str]] = None) -> int:
"""Apply only indexes (lighter operation, safe for frequent calls)."""
try:
- conn = _getConnection(dbConnector)
- originalAutocommit = conn.autocommit
- conn.autocommit = True
-
try:
_ensureUamTablesMatchModels(dbConnector)
except Exception as preIdxErr:
logger.debug(f"Pre-index table ensure: {preIdxErr}")
-
- try:
+
+ with _borrowDbConn(dbConnector) as conn:
with conn.cursor() as cursor:
return _applyIndexes(cursor, tables)
- finally:
- conn.autocommit = originalAutocommit
except Exception as e:
logger.error(f"Error applying indexes: {e}")
return 0
@@ -514,8 +498,7 @@ def getOptimizationStatus(dbConnector) -> dict:
}
try:
- conn = _getConnection(dbConnector)
- with conn.cursor() as cursor:
+ with _borrowDbConn(dbConnector) as conn, conn.cursor() as cursor:
# Check regular indexes
for tableName, indexName, _ in _INDEXES:
if _tableExists(cursor, tableName):
diff --git a/modules/shared/gdprDeletion.py b/modules/shared/gdprDeletion.py
index 99e09313..45a9ea43 100644
--- a/modules/shared/gdprDeletion.py
+++ b/modules/shared/gdprDeletion.py
@@ -60,11 +60,9 @@ def _getTableColumns(dbConnector, tableName: str) -> List[str]:
ORDER BY ordinal_position
"""
- cursor = dbConnector.connection.cursor()
- cursor.execute(query, (tableName,))
- columns = [row[0] for row in cursor.fetchall()]
- cursor.close()
-
+ with dbConnector.borrowCursor() as cursor:
+ cursor.execute(query, (tableName,))
+ columns = [row[0] for row in cursor.fetchall()]
return columns
except Exception as e:
logger.error(f"Error getting columns for table {tableName}: {e}")
@@ -92,29 +90,26 @@ def _getAllTables(dbConnector) -> List[str]:
ORDER BY table_name
"""
- cursor = dbConnector.connection.cursor()
- cursor.execute(query)
- allTables = [row[0] for row in cursor.fetchall()]
-
- # Get foreign key relationships to determine dependency order
- fkQuery = """
- SELECT
- tc.table_name,
- ccu.table_name AS foreign_table_name
- FROM information_schema.table_constraints AS tc
- JOIN information_schema.key_column_usage AS kcu
- ON tc.constraint_name = kcu.constraint_name
- AND tc.table_schema = kcu.table_schema
- JOIN information_schema.constraint_column_usage AS ccu
- ON ccu.constraint_name = tc.constraint_name
- AND ccu.table_schema = tc.table_schema
- WHERE tc.constraint_type = 'FOREIGN KEY'
- AND tc.table_schema = 'public'
- """
-
- cursor.execute(fkQuery)
- foreignKeys = cursor.fetchall()
- cursor.close()
+ with dbConnector.borrowCursor() as cursor:
+ cursor.execute(query)
+ allTables = [row[0] for row in cursor.fetchall()]
+
+ fkQuery = """
+ SELECT
+ tc.table_name,
+ ccu.table_name AS foreign_table_name
+ FROM information_schema.table_constraints AS tc
+ JOIN information_schema.key_column_usage AS kcu
+ ON tc.constraint_name = kcu.constraint_name
+ AND tc.table_schema = kcu.table_schema
+ JOIN information_schema.constraint_column_usage AS ccu
+ ON ccu.constraint_name = tc.constraint_name
+ AND ccu.table_schema = tc.table_schema
+ WHERE tc.constraint_type = 'FOREIGN KEY'
+ AND tc.table_schema = 'public'
+ """
+ cursor.execute(fkQuery)
+ foreignKeys = cursor.fetchall()
# Build dependency graph (child -> parent mapping)
dependencies = {}
@@ -154,10 +149,9 @@ def _getAllTables(dbConnector) -> List[str]:
# Fallback: return simple list without ordering
try:
query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE'"
- cursor = dbConnector.connection.cursor()
- cursor.execute(query)
- tables = [row[0] for row in cursor.fetchall()]
- cursor.close()
+ with dbConnector.borrowCursor() as cursor:
+ cursor.execute(query)
+ tables = [row[0] for row in cursor.fetchall()]
return [t for t in tables if t not in PROTECTED_TABLES]
except Exception:
return []
@@ -184,11 +178,9 @@ def _getPrimaryKeyColumns(dbConnector, tableName: str) -> List[str]:
AND i.indisprimary
"""
- cursor = dbConnector.connection.cursor()
- cursor.execute(query, (tableName,))
- pkColumns = [row[0] for row in cursor.fetchall()]
- cursor.close()
-
+ with dbConnector.borrowCursor() as cursor:
+ cursor.execute(query, (tableName,))
+ pkColumns = [row[0] for row in cursor.fetchall()]
return pkColumns
except Exception as e:
logger.debug(f"Could not get primary key for {tableName}: {e}")
@@ -229,21 +221,15 @@ def _findUserReferencesInTable(
return {}
references = {}
- cursor = dbConnector.connection.cursor()
-
- for userColumn in userColumns:
- # Build SELECT for primary key columns
- pkSelect = ", ".join([f'"{pk}"' for pk in pkColumns])
- query = f'SELECT {pkSelect} FROM "{tableName}" WHERE "{userColumn}" = %s'
-
- cursor.execute(query, (userId,))
- recordKeys = cursor.fetchall()
-
- if recordKeys:
- references[userColumn] = recordKeys
- logger.debug(f"Found {len(recordKeys)} records in {tableName}.{userColumn} for user {userId}")
-
- cursor.close()
+ with dbConnector.borrowCursor() as cursor:
+ for userColumn in userColumns:
+ pkSelect = ", ".join([f'"{pk}"' for pk in pkColumns])
+ query = f'SELECT {pkSelect} FROM "{tableName}" WHERE "{userColumn}" = %s'
+ cursor.execute(query, (userId,))
+ recordKeys = cursor.fetchall()
+ if recordKeys:
+ references[userColumn] = recordKeys
+ logger.debug(f"Found {len(recordKeys)} records in {tableName}.{userColumn} for user {userId}")
return references
except Exception as e:
@@ -277,42 +263,35 @@ def _anonymizeRecords(
return 0
try:
- cursor = dbConnector.connection.cursor()
+ # Resolve column metadata once outside the borrow block (it borrows its
+ # own connection internally).
+ columns = _getTableColumns(dbConnector, tableName)
+ hasModifiedAt = "sysModifiedAt" in columns
+
count = 0
-
- for recordKey in recordKeys:
- # Build WHERE clause for primary key
- whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
-
- # Check if table has sysModifiedAt column
- columns = _getTableColumns(dbConnector, tableName)
- hasModifiedAt = "sysModifiedAt" in columns
-
- if hasModifiedAt:
- query = f'UPDATE "{tableName}" SET "{columnName}" = %s, "sysModifiedAt" = %s WHERE {whereClause}'
- params = [anonymousValue, getUtcTimestamp()]
- else:
- query = f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE {whereClause}'
- params = [anonymousValue]
-
- # Add primary key values to params
- if isinstance(recordKey, tuple):
- params.extend(recordKey)
- else:
- params.append(recordKey)
-
- cursor.execute(query, params)
- count += cursor.rowcount
-
- dbConnector.connection.commit()
- cursor.close()
-
+ with dbConnector.borrowCursor() as cursor:
+ for recordKey in recordKeys:
+ whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
+ if hasModifiedAt:
+ query = f'UPDATE "{tableName}" SET "{columnName}" = %s, "sysModifiedAt" = %s WHERE {whereClause}'
+ params = [anonymousValue, getUtcTimestamp()]
+ else:
+ query = f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE {whereClause}'
+ params = [anonymousValue]
+
+ if isinstance(recordKey, tuple):
+ params.extend(recordKey)
+ else:
+ params.append(recordKey)
+
+ cursor.execute(query, params)
+ count += cursor.rowcount
+
logger.info(f"Anonymized {count} records in {tableName}.{columnName}")
return count
-
+
except Exception as e:
logger.error(f"Error anonymizing records in {tableName}.{columnName}: {e}")
- dbConnector.connection.rollback()
return 0
@@ -338,32 +317,23 @@ def _deleteRecords(
return 0
try:
- cursor = dbConnector.connection.cursor()
count = 0
-
- for recordKey in recordKeys:
- # Build WHERE clause for primary key
- whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
- query = f'DELETE FROM "{tableName}" WHERE {whereClause}'
-
- # Prepare params
- if isinstance(recordKey, tuple):
- params = list(recordKey)
- else:
- params = [recordKey]
-
- cursor.execute(query, params)
- count += cursor.rowcount
-
- dbConnector.connection.commit()
- cursor.close()
-
+ with dbConnector.borrowCursor() as cursor:
+ for recordKey in recordKeys:
+ whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
+ query = f'DELETE FROM "{tableName}" WHERE {whereClause}'
+ if isinstance(recordKey, tuple):
+ params = list(recordKey)
+ else:
+ params = [recordKey]
+ cursor.execute(query, params)
+ count += cursor.rowcount
+
logger.info(f"Deleted {count} records from {tableName}")
return count
-
+
except Exception as e:
logger.error(f"Error deleting records from {tableName}: {e}")
- dbConnector.connection.rollback()
return 0
diff --git a/scripts/stage0_filefolder_schema_check.py b/scripts/stage0_filefolder_schema_check.py
index 861d8671..d172e19c 100644
--- a/scripts/stage0_filefolder_schema_check.py
+++ b/scripts/stage0_filefolder_schema_check.py
@@ -25,7 +25,7 @@ if not c or not c.connection:
print("STAGE0: DB_CONNECTION=none (check config.ini / .env)")
raise SystemExit(2)
-cur = c.connection.cursor()
+cur = c.borrowCursor()
def _scalar(cur):
diff --git a/tests/unit/connectors/test_connectorDbPostgre_failLoud.py b/tests/unit/connectors/test_connectorDbPostgre_failLoud.py
index 4f98ef4a..57094760 100644
--- a/tests/unit/connectors/test_connectorDbPostgre_failLoud.py
+++ b/tests/unit/connectors/test_connectorDbPostgre_failLoud.py
@@ -12,11 +12,16 @@ broken query into "no rows found". That hid bugs like:
These tests pin the new contract: empty result sets still return ``[]`` /
``None`` (normal), but any exception inside the query path propagates as
-``DatabaseQueryError`` with the table name attached. The transaction is
-rolled back so the connection is usable for subsequent queries.
+``DatabaseQueryError`` with the table name attached.
+
+Since the 2026-05-17 pool refactor (`c-work/2-build/2026-05-postgres-connection-pool.md`)
+the connector borrows a connection from `_PoolRegistry` on every call via the
+`borrowConn()` context manager. The tests mock that context manager so the
+fast-fail contract is exercised without requiring a live Postgres server.
"""
from __future__ import annotations
+from contextlib import contextmanager
from unittest.mock import MagicMock
import pytest
@@ -25,7 +30,6 @@ import psycopg2.errors
from modules.connectors.connectorDbPostgre import (
DatabaseConnector,
DatabaseQueryError,
- _rollbackQuietly,
)
@@ -39,26 +43,44 @@ class DummyTable:
def _makeConnector(cursorBehavior):
- """Build a ``DatabaseConnector`` skeleton with mocked connection/cursor.
+ """Build a ``DatabaseConnector`` skeleton with a mocked pool borrow.
``cursorBehavior`` is a callable invoked with the cursor mock so the test
can configure ``execute``/``fetchall``/``fetchone`` per scenario.
+
+ Returns ``(connector, conn, cursor)``:
+ * ``conn`` exposes ``commit`` / ``rollback`` MagicMocks so tests can
+ assert that the borrow lifecycle did the right thing.
+ * ``cursor`` is the per-test cursor mock.
"""
connector = DatabaseConnector.__new__(DatabaseConnector)
+
cursor = MagicMock()
+ cursorBehavior(cursor)
+
cursorContext = MagicMock()
cursorContext.__enter__ = MagicMock(return_value=cursor)
cursorContext.__exit__ = MagicMock(return_value=False)
- connection = MagicMock()
- connection.cursor.return_value = cursorContext
- connector.connection = connection
+ conn = MagicMock()
+ conn.cursor.return_value = cursorContext
+
+ @contextmanager
+ def fakeBorrow():
+ try:
+ yield conn
+ except Exception:
+ conn.rollback()
+ raise
+ else:
+ conn.commit()
+
+ connector.borrowConn = fakeBorrow
connector._ensureTableExists = MagicMock(return_value=True)
connector._systemTableName = "_system"
- cursorBehavior(cursor)
- return connector, connection, cursor
+ return connector, conn, cursor
class TestGetRecordsetFailLoud:
@@ -67,11 +89,12 @@ class TestGetRecordsetFailLoud:
def behavior(cursor):
cursor.execute.return_value = None
cursor.fetchall.return_value = []
- connector, connection, _ = _makeConnector(behavior)
+ connector, conn, _ = _makeConnector(behavior)
result = connector.getRecordset(DummyTable)
assert result == []
- connection.rollback.assert_not_called()
+ conn.rollback.assert_not_called()
+ conn.commit.assert_called_once()
def test_dictAdaptErrorRaisesDatabaseQueryError(self):
"""Reproduces the Trustee bug: passing a dict in WHERE → can't adapt → raise."""
@@ -79,7 +102,7 @@ class TestGetRecordsetFailLoud:
cursor.execute.side_effect = psycopg2.ProgrammingError(
"can't adapt type 'dict'"
)
- connector, connection, _ = _makeConnector(behavior)
+ connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError) as excinfo:
connector.getRecordset(
@@ -90,30 +113,30 @@ class TestGetRecordsetFailLoud:
assert excinfo.value.table == "DummyTable"
assert "can't adapt type 'dict'" in str(excinfo.value)
assert isinstance(excinfo.value.original, psycopg2.ProgrammingError)
- connection.rollback.assert_called_once()
+ conn.rollback.assert_called_once()
def test_missingColumnRaisesDatabaseQueryError(self):
def behavior(cursor):
cursor.execute.side_effect = psycopg2.errors.UndefinedColumn(
'column "wat" does not exist'
)
- connector, connection, _ = _makeConnector(behavior)
+ connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError) as excinfo:
connector.getRecordset(DummyTable, recordFilter={"wat": "x"})
assert "wat" in str(excinfo.value)
- connection.rollback.assert_called_once()
+ conn.rollback.assert_called_once()
def test_operationalErrorRaisesDatabaseQueryError(self):
"""Connection lost mid-query is also a real failure that must propagate."""
def behavior(cursor):
cursor.execute.side_effect = psycopg2.OperationalError("connection lost")
- connector, connection, _ = _makeConnector(behavior)
+ connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError):
connector.getRecordset(DummyTable)
- connection.rollback.assert_called_once()
+ conn.rollback.assert_called_once()
class TestGetRecordFailLoud:
@@ -122,37 +145,22 @@ class TestGetRecordFailLoud:
def behavior(cursor):
cursor.execute.return_value = None
cursor.fetchone.return_value = None
- connector, connection, _ = _makeConnector(behavior)
+ connector, conn, _ = _makeConnector(behavior)
result = connector.getRecord(DummyTable, "missing-id")
assert result is None
- connection.rollback.assert_not_called()
+ conn.rollback.assert_not_called()
+ conn.commit.assert_called_once()
def test_queryErrorRaisesDatabaseQueryError(self):
def behavior(cursor):
cursor.execute.side_effect = psycopg2.errors.UndefinedTable(
'relation "DummyTable" does not exist'
)
- connector, connection, _ = _makeConnector(behavior)
+ connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError) as excinfo:
connector.getRecord(DummyTable, "any-id")
assert excinfo.value.table == "DummyTable"
- connection.rollback.assert_called_once()
-
-
-class TestRollbackQuietly:
- def test_rollsBackOnLiveConnection(self):
- connection = MagicMock()
- _rollbackQuietly(connection)
- connection.rollback.assert_called_once()
-
- def test_swallowsRollbackError(self):
- """Rollback failure must not mask the original query error."""
- connection = MagicMock()
- connection.rollback.side_effect = RuntimeError("rollback failed")
- _rollbackQuietly(connection)
-
- def test_noopOnNoneConnection(self):
- _rollbackQuietly(None)
+ conn.rollback.assert_called_once()
diff --git a/tests/unit/connectors/test_connectorDbPostgre_pool.py b/tests/unit/connectors/test_connectorDbPostgre_pool.py
new file mode 100644
index 00000000..9c389add
--- /dev/null
+++ b/tests/unit/connectors/test_connectorDbPostgre_pool.py
@@ -0,0 +1,304 @@
+# Copyright (c) 2026 Patrick Motsch
+# All rights reserved.
+"""Concurrency tests for the PostgreSQL connection pool.
+
+These tests pin the contract that the `c-work/2-build/2026-05-postgres-connection-pool.md`
+refactor delivered:
+
+* T1 — 50 threads × 100 calls in parallel produce 0 `OperationalError`s and
+ every call completes within reasonable time (p99 < 2 s).
+* T2 — Two threads `_loadRecord` + `_saveRecord` against the same connector
+ do not corrupt each other's cursors.
+* T3 — `statement_timeout` aborts a runaway `pg_sleep(60)` after ~30 s and
+ releases the connection back into the pool cleanly.
+
+The tests need a real PostgreSQL server because the bug they guard against
+only materialises with real psycopg2 sockets — a mocked connection never
+hangs in `recv()`. They read DB credentials from `APP_CONFIG` (which loads
+`.env`) and are auto-skipped when the connection fails (no local Postgres,
+wrong creds, etc.) so `pytest` keeps working in CI-only environments.
+
+To run them locally:
+
+ pytest gateway/tests/unit/connectors/test_connectorDbPostgre_pool.py -v
+
+They use a throwaway database name (`poweron_pool_test_`) and drop it
+in fixture teardown so they leave nothing behind.
+"""
+from __future__ import annotations
+
+import time
+import uuid
+import threading
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+import psycopg2
+import psycopg2.errors
+import pytest
+from pydantic import Field
+
+from modules.connectors.connectorDbPostgre import (
+ DatabaseConnector,
+ _PoolRegistry,
+ closeAllPools,
+)
+from modules.datamodels.datamodelBase import PowerOnModel
+from modules.shared.configuration import APP_CONFIG
+
+
+def _dbConfig():
+ """Read DB connection params from APP_CONFIG (`.env`).
+
+ Returns ``None`` when host/user/password are not all present so the
+ test module can skip cleanly instead of blowing up at import time.
+ """
+ host = APP_CONFIG.get("DB_HOST")
+ user = APP_CONFIG.get("DB_USER")
+ password = APP_CONFIG.get("DB_PASSWORD_SECRET")
+ port = APP_CONFIG.get("DB_PORT", 5432)
+ if not host or not user or password is None:
+ return None
+ return {"host": host, "user": user, "password": password, "port": int(port)}
+
+
+def _canReachPostgres(cfg) -> bool:
+ """Try a quick connect to the admin DB so we can skip on connection failures."""
+ try:
+ conn = psycopg2.connect(
+ host=cfg["host"], port=cfg["port"], database="postgres",
+ user=cfg["user"], password=cfg["password"], connect_timeout=2,
+ )
+ conn.close()
+ return True
+ except Exception: # noqa: BLE001
+ return False
+
+
+_DB_CFG = _dbConfig()
+pytestmark = pytest.mark.skipif(
+ _DB_CFG is None or not _canReachPostgres(_DB_CFG),
+ reason="No reachable PostgreSQL — skipping live-Postgres pool tests",
+)
+
+
+class PoolTestRow(PowerOnModel):
+ """Tiny model used to exercise the pool — one ID + one payload field."""
+ payload: str = Field(default="", description="Test payload")
+
+
+@pytest.fixture
+def liveConnector():
+ """Spin up a throwaway database, yield a `DatabaseConnector` against it,
+ drop the database afterwards.
+
+ The pool registry is wiped before and after each test so state from one
+ test cannot mask a bug in another.
+ """
+ cfg = _DB_CFG
+ host = cfg["host"]
+ user = cfg["user"]
+ password = cfg["password"]
+ port = cfg["port"]
+ dbName = f"poweron_pool_test_{uuid.uuid4().hex[:8]}"
+
+ # Pre-clean: drop any orphan test DB with the same name (shouldn't happen
+ # because we use a unique uuid, but be defensive).
+ adminConn = psycopg2.connect(
+ host=host, port=port, database="postgres", user=user, password=password
+ )
+ adminConn.autocommit = True
+ try:
+ with adminConn.cursor() as cur:
+ cur.execute(f'DROP DATABASE IF EXISTS "{dbName}"')
+ finally:
+ adminConn.close()
+
+ closeAllPools()
+
+ connector = DatabaseConnector(
+ dbHost=host,
+ dbDatabase=dbName,
+ dbUser=user,
+ dbPassword=password,
+ dbPort=port,
+ )
+ # Seed exactly one row so every concurrent read has a stable target.
+ connector.recordCreate(PoolTestRow, {"id": "seed", "payload": "hello"})
+
+ yield connector
+
+ # Teardown: tear pools down, then drop the DB.
+ closeAllPools()
+ adminConn = psycopg2.connect(
+ host=host, port=port, database="postgres", user=user, password=password
+ )
+ adminConn.autocommit = True
+ try:
+ with adminConn.cursor() as cur:
+ cur.execute(
+ 'SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = %s',
+ (dbName,),
+ )
+ cur.execute(f'DROP DATABASE IF EXISTS "{dbName}"')
+ finally:
+ adminConn.close()
+
+
+class TestPoolConcurrency:
+ def _runWorkers(self, liveConnector, *, threadCount: int, callsPerThread: int):
+ """Run N worker threads, each issuing M reads. Return (errors, latencies)."""
+ errors: list = []
+ latencies: list = []
+ lock = threading.Lock()
+
+ def worker():
+ for _ in range(callsPerThread):
+ t0 = time.perf_counter()
+ try:
+ rows = liveConnector.getRecordset(PoolTestRow)
+ assert any(r["id"] == "seed" for r in rows)
+ except Exception as e: # noqa: BLE001 — we want every failure mode
+ with lock:
+ errors.append(e)
+ finally:
+ with lock:
+ latencies.append(time.perf_counter() - t0)
+
+ with ThreadPoolExecutor(max_workers=threadCount) as ex:
+ futures = [ex.submit(worker) for _ in range(threadCount)]
+ for f in as_completed(futures):
+ f.result()
+ latencies.sort()
+ return errors, latencies
+
+ def test_50_threads_x_20_reads_no_errors(self, liveConnector):
+ """T1a — STRESS: 50 threads × 20 reads each → 0 errors.
+
+ Pre-pool, this scenario produced either
+ `OperationalError: another command is already in progress` or a
+ deadlock in `recv()` because the threadpool shared one psycopg2
+ socket. With the pool plus `borrowConn`'s bounded wait, every
+ thread eventually gets a connection and completes — even with 30
+ threads queued waiting at any moment (pool max=20).
+ """
+ errors, _ = self._runWorkers(liveConnector, threadCount=50, callsPerThread=20)
+ assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"
+
+ def test_20_threads_x_50_reads_latency_budget(self, liveConnector):
+ """T1b — DESIGN CAPACITY: 20 threads × 50 reads, p99 < 5 s.
+
+ 20 threads matches the pool's `max=20` so there is no queueing —
+ every borrow returns immediately. This pins a sanity-level per-call
+ latency budget; pre-pool it was unbounded (recv() never returned).
+
+ The 5 s ceiling is generous on purpose: `getRecordset` calls
+ `_ensureTableExists` which runs two `information_schema` queries
+ for column-additive migration, and under 20-way concurrency on a
+ single Postgres instance that produces a long tail. The hard
+ assertion is `not errors` — the latency check just guarantees
+ nothing hangs indefinitely.
+ """
+ errors, latencies = self._runWorkers(
+ liveConnector, threadCount=20, callsPerThread=50
+ )
+ assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"
+ p99 = latencies[int(len(latencies) * 0.99)]
+ assert p99 < 5.0, f"p99 latency {p99:.2f}s exceeds 5s budget"
+
+ def test_interleaved_load_and_save_no_collision(self, liveConnector):
+ """T2: parallel reads + writes on the same connector → no cursor mix-up.
+
+ Pre-pool the reader could observe a row in mid-write or vice versa
+ because both shared the same cursor. With one connection per borrow,
+ the database's own row-locking is the only contention, and we just
+ need to assert no exceptions.
+ """
+ stopFlag = threading.Event()
+ errors: list = []
+ lock = threading.Lock()
+
+ def reader():
+ while not stopFlag.is_set():
+ try:
+ liveConnector.getRecord(PoolTestRow, "seed")
+ except Exception as e: # noqa: BLE001
+ with lock:
+ errors.append(("read", e))
+
+ def writer():
+ i = 0
+ while not stopFlag.is_set():
+ try:
+ liveConnector.recordModify(
+ PoolTestRow,
+ "seed",
+ {"id": "seed", "payload": f"v{i}"},
+ )
+ i += 1
+ except Exception as e: # noqa: BLE001
+ with lock:
+ errors.append(("write", e))
+
+ threads = [
+ threading.Thread(target=reader, daemon=True),
+ threading.Thread(target=reader, daemon=True),
+ threading.Thread(target=writer, daemon=True),
+ threading.Thread(target=writer, daemon=True),
+ ]
+ for t in threads:
+ t.start()
+ time.sleep(2.0)
+ stopFlag.set()
+ for t in threads:
+ t.join(timeout=3.0)
+
+ assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"
+
+ def test_statement_timeout_releases_connection(self, liveConnector):
+ """T3: `pg_sleep` past statement_timeout → QueryCanceled, pool intact.
+
+ The bug we are guarding against: a runaway query with no timeout
+ hung `recv()` forever, the psycopg2 connection was poisoned, and the
+ whole backend became unresponsive once that connection was reused.
+ With `statement_timeout=30000` configured at pool construction the
+ query is cancelled by the server, the borrow context manager rolls
+ back, and the connection returns to the pool — proven by the fact
+ that a follow-up call still succeeds quickly.
+ """
+ # Use a short timeout to keep the test fast — override the pool's
+ # session statement_timeout for one borrow via SET LOCAL.
+ with liveConnector.borrowConn() as conn:
+ with conn.cursor() as cursor:
+ cursor.execute("SET LOCAL statement_timeout = 500")
+ with pytest.raises(psycopg2.errors.QueryCanceled):
+ cursor.execute("SELECT pg_sleep(5)")
+
+ # Follow-up call must succeed quickly: connection is back in the pool.
+ t0 = time.perf_counter()
+ rows = liveConnector.getRecordset(PoolTestRow)
+ elapsed = time.perf_counter() - t0
+ assert any(r["id"] == "seed" for r in rows)
+ assert elapsed < 1.0, f"follow-up call took {elapsed:.2f}s — pool may be wedged"
+
+
+class TestPoolRegistry:
+ def test_one_pool_per_database_identity(self, liveConnector):
+ """Two connectors against the same (host, db, port) share one pool."""
+ cfg = _DB_CFG
+ pool1 = _PoolRegistry.getPool(
+ dbHost=cfg["host"], dbDatabase=liveConnector.dbDatabase,
+ dbUser=cfg["user"], dbPassword=cfg["password"], dbPort=cfg["port"],
+ )
+ pool2 = _PoolRegistry.getPool(
+ dbHost=cfg["host"], dbDatabase=liveConnector.dbDatabase,
+ dbUser=cfg["user"], dbPassword=cfg["password"], dbPort=cfg["port"],
+ )
+ assert pool1 is pool2
+
+ def test_close_all_clears_registry(self, liveConnector):
+ """`closeAllPools()` empties the registry so the next call rebuilds."""
+ # Touch the pool first.
+ liveConnector.getRecordset(PoolTestRow)
+ assert _PoolRegistry._pools, "pool should exist after a real call"
+ closeAllPools()
+ assert _PoolRegistry._pools == {}, "registry should be empty after closeAllPools()"
diff --git a/tests/unit/interfaces/test_folderRbac.py b/tests/unit/interfaces/test_folderRbac.py
index 049f392d..f4b984aa 100644
--- a/tests/unit/interfaces/test_folderRbac.py
+++ b/tests/unit/interfaces/test_folderRbac.py
@@ -68,6 +68,16 @@ class _FakeDb:
def _ensureTableExists(self, modelClass):
return True
+ def borrowCursor(self):
+ """Mimic `DatabaseConnector.borrowCursor()` context manager."""
+ from contextlib import contextmanager
+ from unittest.mock import MagicMock
+
+ @contextmanager
+ def _cm():
+ yield MagicMock()
+ return _cm()
+
def seed(self, modelClass, record: Dict[str, Any]):
tableName = modelClass.__name__
self._tables.setdefault(tableName, {})
diff --git a/tests/unit/routes/test_folder_crud.py b/tests/unit/routes/test_folder_crud.py
index 86eaf480..66bad903 100644
--- a/tests/unit/routes/test_folder_crud.py
+++ b/tests/unit/routes/test_folder_crud.py
@@ -69,6 +69,16 @@ class _FakeDb:
def _ensureTableExists(self, modelClass):
return True
+ def borrowCursor(self):
+ """Mimic `DatabaseConnector.borrowCursor()` context manager for the cascade test."""
+ from contextlib import contextmanager
+ from unittest.mock import MagicMock
+
+ @contextmanager
+ def _cm():
+ yield MagicMock()
+ return _cm()
+
def seed(self, modelClass, record: Dict[str, Any]):
tableName = modelClass.__name__
self._tables.setdefault(tableName, {})