db connection pooling and rag limit transparency

This commit is contained in:
ValueOn AG 2026-05-17 20:38:37 +02:00
parent f5aba4bf99
commit 2bb65c2303
23 changed files with 1519 additions and 782 deletions

9
app.py
View file

@ -439,6 +439,15 @@ async def lifespan(app: FastAPI):
except Exception as e:
logger.warning(f"Could not shutdown feature containers: {e}")
# --- Close all PostgreSQL connection pools ---
# Must run LAST: feature `onStop` hooks may still issue DB calls during
# shutdown. Once we tear down the pools, no more borrows are possible.
try:
from modules.connectors.connectorDbPostgre import closeAllPools
closeAllPools()
except Exception as e:
logger.warning(f"Closing DB connection pools failed: {e}")
logger.info("Application has been shut down")

View file

@ -2,9 +2,12 @@
# All rights reserved.
import contextvars
import re
import time
import psycopg2
import psycopg2.extras
import psycopg2.pool
import logging
from contextlib import contextmanager
from typing import List, Dict, Any, Optional, Union, get_origin, get_args, Type
import uuid
from pydantic import BaseModel, Field
@ -44,24 +47,6 @@ class DatabaseQueryError(RuntimeError):
self.original = original
def _rollbackQuietly(connection) -> None:
    """Roll back `connection` on a best-effort basis; never raises.

    After any failed statement Postgres leaves the connection in an error
    state, and every further query on it raises ``InFailedSqlTransaction``
    until a rollback is issued. A failure of the rollback itself is
    swallowed on purpose: the original query error is what the caller
    should see, and a secondary rollback failure typically means the
    connection is gone.

    Args:
        connection: Object exposing ``rollback()``, or ``None`` (no-op).
    """
    if connection is not None:
        try:
            connection.rollback()
        except Exception:
            # Deliberate best-effort: the primary error must not be masked.
            pass
class SystemTable(PowerOnModel):
"""Data model for system table entries"""
@ -203,9 +188,174 @@ def _quotePgIdent(name: str) -> str:
return '"' + str(name).replace('"', '""') + '"'
# Cache connectors by (host, database, port) to avoid duplicate inits for same database.
# Thread safety: _connector_cache_lock protects cache access. userId is request-scoped via
# contextvars to avoid races when concurrent requests share the same connector.
# ---------------------------------------------------------------------------
# Connection pool registry
# ---------------------------------------------------------------------------
# psycopg2 connections are NOT thread-safe; sharing one connection across the
# FastAPI threadpool (sync `def` routes) or across multiple async tasks that
# happen to be active simultaneously results in either
# `OperationalError: another command is already in progress` or — far worse —
# an unbounded hang in `recv()` on the connection socket, because two cursors
# wait for one server response.
#
# `_PoolRegistry` keeps one `ThreadedConnectionPool` per database identity
# (`host`, `db`, `port`). Every DB call borrows a connection from the pool,
# runs its query, and returns the connection. The pool itself guarantees that
# no two callers share the same psycopg2 connection at the same time.
#
# `statement_timeout=30000` (30 s) is a safety net: a runaway query is aborted
# instead of hanging forever and poisoning the connection. `connect_timeout=10`
# prevents indefinite blocking when the Postgres host is unreachable.
_DEFAULT_POOL_MIN = 2
_DEFAULT_POOL_MAX = 20
_STATEMENT_TIMEOUT_MS = 30000
_CONNECT_TIMEOUT_S = 10
# psycopg2.pool.ThreadedConnectionPool.getconn() does NOT block when the pool
# is exhausted — it raises `psycopg2.pool.PoolError` immediately. That makes
# bursty workloads (50 concurrent route calls against a max=20 pool) fail
# spuriously. `borrowConn()` therefore retries with a small backoff up to
# `_BORROW_WAIT_TIMEOUT_S` seconds before giving up.
_BORROW_WAIT_TIMEOUT_S = 30.0
_BORROW_WAIT_BACKOFF_S = 0.05
def _resolvePoolMax() -> int:
    """Return the maximum number of connections per pool.

    The value is read from the `DB_POOL_MAX_CONN` entry of `APP_CONFIG`.
    A missing, empty or non-numeric entry falls back to
    `_DEFAULT_POOL_MAX`. The result is never smaller than
    `_DEFAULT_POOL_MIN`, the minimum size the pools are created with, so
    the two limits cannot contradict each other if the constant changes.

    Returns:
        The effective `maxconn` value for new connection pools.
    """
    try:
        requested = int(APP_CONFIG.get("DB_POOL_MAX_CONN") or _DEFAULT_POOL_MAX)
    except (ValueError, TypeError):
        return _DEFAULT_POOL_MAX
    # Clamp against the shared constant instead of a duplicated literal.
    return max(_DEFAULT_POOL_MIN, requested)
class _PoolRegistry:
    """Process-wide registry of `ThreadedConnectionPool` instances.

    Pools are stored under `(host, database, port)`, so every
    `DatabaseConnector` that targets the same physical database shares a
    single pool. Pools are created on first request; creation and shutdown
    are serialised through a class-level lock.

    NOTE(review): `dbUser` / `dbPassword` are not part of the key. A later
    caller that passes different credentials for the same
    `(host, database, port)` receives the pool built with the first
    caller's credentials — confirm this is intended.
    """

    # Shared by all callers; mutated only while `_lock` is held.
    _pools: Dict[tuple, psycopg2.pool.ThreadedConnectionPool] = {}
    _lock = threading.Lock()

    @classmethod
    def getPool(
        cls,
        *,
        dbHost: str,
        dbDatabase: str,
        dbUser: str,
        dbPassword: str,
        dbPort: int,
    ) -> psycopg2.pool.ThreadedConnectionPool:
        """Return the pool for the given database, creating it if needed.

        Args:
            dbHost: Postgres host name.
            dbDatabase: Database name.
            dbUser: Login role used when the pool has to be created.
            dbPassword: Password used when the pool has to be created.
            dbPort: Server port; `None` is treated as 5432.

        Returns:
            The shared `ThreadedConnectionPool` for `(host, database, port)`.

        Raises:
            Exception: Whatever `ThreadedConnectionPool(...)` raises while
                opening its initial connections; logged, then re-raised.
        """
        resolvedPort = 5432 if dbPort is None else int(dbPort)
        poolKey = (dbHost, dbDatabase, resolvedPort)

        # Unlocked lookup first: the common case is that the pool exists.
        existing = cls._pools.get(poolKey)
        if existing is not None:
            return existing

        with cls._lock:
            # Re-check under the lock so concurrent callers create one pool.
            existing = cls._pools.get(poolKey)
            if existing is not None:
                return existing

            maxConnections = _resolvePoolMax()
            connectKwargs = {
                "host": dbHost,
                "port": resolvedPort,
                "database": dbDatabase,
                "user": dbUser,
                "password": dbPassword,
                "client_encoding": "utf8",
                "cursor_factory": psycopg2.extras.RealDictCursor,
                "connect_timeout": _CONNECT_TIMEOUT_S,
                # Server-side safety net applied to every pooled session.
                "options": f"-c statement_timeout={_STATEMENT_TIMEOUT_MS}",
            }
            try:
                created = psycopg2.pool.ThreadedConnectionPool(
                    _DEFAULT_POOL_MIN,
                    maxConnections,
                    **connectKwargs,
                )
            except Exception as e:
                logger.error(
                    "Failed to create connection pool for db=%s host=%s: %s",
                    dbDatabase, dbHost, e,
                )
                raise

            cls._pools[poolKey] = created
            logger.debug(
                "Created connection pool for db=%s host=%s port=%s (min=%d max=%d, stmt_timeout=%dms)",
                dbDatabase, dbHost, resolvedPort, _DEFAULT_POOL_MIN, maxConnections, _STATEMENT_TIMEOUT_MS,
            )
            return created

    @classmethod
    def closeAll(cls) -> None:
        """Close every registered pool and empty the registry.

        A failure to close one pool is logged as a warning and does not
        prevent the remaining pools from being closed.
        """
        with cls._lock:
            # Iterate over a snapshot; the dict is cleared afterwards.
            for poolKey, openPool in tuple(cls._pools.items()):
                try:
                    openPool.closeall()
                except Exception as e:
                    logger.warning("Error closing pool %s: %s", poolKey, e)
            cls._pools.clear()
        logger.info("All database connection pools closed")
def closeAllPools() -> None:
    """Close every connection pool held by `_PoolRegistry`.

    Public, module-level entry point for the FastAPI lifespan shutdown
    hook; it only delegates to `_PoolRegistry.closeAll()`.
    """
    _PoolRegistry.closeAll()
class _ConnectionShim:
    """Backward-compatibility stand-in for the old `DatabaseConnector.connection`.

    Legacy callers used patterns such as::

        db.connection.commit()
        if db.connection and not db.connection.closed: ...

    Those keep working but do nothing, because the pool owns the
    connection lifecycle. Direct cursor access through
    `db.connection.cursor()` is intentionally blocked with an explicit
    error so that silent breakage is impossible — every such call site
    must migrate to `db.borrowCursor()`.
    """

    # Legacy guards read `connection.closed`; the shim always reports "open".
    closed = False

    def __bool__(self) -> bool:
        """Always truthy, so `if db.connection:` guards still pass."""
        return True

    def _noop(self) -> None:
        """Accept the call and do nothing."""
        return None

    # Transaction and lifecycle calls from legacy code are silently ignored.
    commit = _noop
    rollback = _noop
    close = _noop
    del _noop

    def cursor(self, *args, **kwargs):
        """Reject direct cursor access.

        Raises:
            RuntimeError: Always; points the caller to the pool-based API.
        """
        message = (
            "DatabaseConnector.connection.cursor() is no longer supported. Use "
            "`db.borrowCursor()` (or `db.borrowConn()` for multi-statement transactions) "
            "so the connection is borrowed from and returned to the pool correctly."
        )
        raise RuntimeError(message)


# Single shared instance; the shim carries no per-connector state.
_CONNECTION_SHIM = _ConnectionShim()
# ---------------------------------------------------------------------------
# Connector cache (lightweight wrappers — actual connections live in the pool)
# ---------------------------------------------------------------------------
# Multiple call sites (`routeI18n`, `aiAuditLogger`, `mainBackgroundJobService`,
# interfaces) ask for a connector via `getCachedConnector(...)` and expect to
# get back the same object on subsequent calls. Now that real connection
# multiplexing happens at the pool layer, the cache returns lightweight
# `DatabaseConnector` wrappers — they hold no connection themselves, only the
# DSN params and a reference to the shared pool.
_MAX_CACHED_CONNECTORS = 32
_connector_cache: Dict[tuple, "DatabaseConnector"] = {}
_connector_cache_order: List[tuple] = [] # FIFO order for eviction
@ -223,22 +373,36 @@ def getCachedConnector(
dbPort: int = None,
userId: str = None,
) -> "DatabaseConnector":
"""Return cached DatabaseConnector for same (host, database, port) to avoid duplicate PostgreSQL inits.
Uses contextvars for userId so concurrent requests sharing the same connector get correct sysCreatedBy/sysModifiedBy.
"""Return a cached `DatabaseConnector` wrapper for `(host, database, port)`.
Two layers of caching, both intentional:
1. **Pool layer** (`_PoolRegistry`) owns the actual psycopg2 connections.
Per `(host, db, port)` exactly one `ThreadedConnectionPool`, shared
across every wrapper that points at the same physical database. This
is what actually saves Postgres connection slots.
2. **Wrapper layer** (this function) caches the `DatabaseConnector`
Python object. Each wrapper triggers an `initDbSystem()` on first
instantiation (CREATE DATABASE if missing, CREATE TABLE _system,
pool warm-up). Caching the wrapper avoids paying that bootstrap cost
on every request and stops the log from filling with "PostgreSQL
database system initialized" lines.
`userId` is request-scoped via the `_current_user_id` contextvar so two
concurrent requests sharing the same cached wrapper still produce
correct `sysCreatedBy` / `sysModifiedBy` audit fields.
"""
port = int(dbPort) if dbPort is not None else 5432
key = (dbHost, dbDatabase, port)
with _connector_cache_lock:
if key not in _connector_cache:
# Evict oldest if at capacity
# FIFO eviction. Connectors are now lightweight (no per-instance
# connection), so eviction is purely a memory bookkeeping concern;
# the underlying pool stays alive in `_PoolRegistry`.
while len(_connector_cache) >= _MAX_CACHED_CONNECTORS and _connector_cache_order:
oldest_key = _connector_cache_order.pop(0)
if oldest_key in _connector_cache:
try:
_connector_cache[oldest_key].close(forceClose=True)
except Exception as e:
logger.warning(f"Error closing evicted connector: {e}")
del _connector_cache[oldest_key]
_connector_cache.pop(oldest_key, None)
_connector_cache[key] = DatabaseConnector(
dbHost=dbHost,
dbDatabase=dbDatabase,
@ -282,34 +446,38 @@ class DatabaseConnector:
# Set userId (default to empty string if None)
self.userId = userId if userId is not None else ""
# Initialize database system first (creates database if needed)
self.connection = None
# No per-instance connection any more — real connections live in the
# shared `_PoolRegistry` pool. `_isCachedShared` is retained because
# `close(forceClose=False)` callers (interface __del__) still ask.
self._isCachedShared = False
self.initDbSystem()
# No caching needed with proper database - PostgreSQL handles performance
# Thread safety
self._lock = threading.Lock()
# pgvector extension state
# pgvector extension state (cached per connector instance — cheap)
self._vectorExtensionEnabled = False
# Initialize system table
# System table bootstrap: create database, system table, ensure metadata.
self._systemTableName = "_system"
self.initDbSystem()
self._initializeSystemTable()
def initDbSystem(self):
"""Initialize the database system - creates database and tables."""
try:
# Create database if it doesn't exist
self._create_database_if_not_exists()
"""Bootstrap the physical database and the `_system` metadata table.
# Create tables
Uses a short-lived autocommit connection (NOT the pool) because
`CREATE DATABASE` cannot run inside a transaction block. Also warms up
the pool by acquiring it for this `(host, db, port)` once.
"""
try:
self._create_database_if_not_exists()
self._create_tables()
# Establish connection to the database
self._connect()
# Warm the pool so the first request doesn't pay for socket setup.
_PoolRegistry.getPool(
dbHost=self.dbHost,
dbDatabase=self.dbDatabase,
dbUser=self.dbUser,
dbPassword=self.dbPassword,
dbPort=self.dbPort,
)
logger.debug(
"PostgreSQL database system initialized (db=%s, host=%s, port=%s)",
@ -319,10 +487,121 @@ class DatabaseConnector:
logger.error(f"FATAL ERROR: Database system initialization failed: {e}")
raise
def _create_database_if_not_exists(self):
"""Create the database if it doesn't exist."""
@property
def connection(self) -> "_ConnectionShim":
    """Return the shared `_CONNECTION_SHIM` stand-in.

    Kept so that callers which still touch `db.connection` do not break.
    Every access returns the same module-level `_CONNECTION_SHIM` object.
    """
    return _CONNECTION_SHIM
def _ensure_connection(self) -> None:
    """Do nothing; retained so legacy call sites keep working.

    Previously this method tested-or-reconnected the per-instance socket.
    Connections are now borrowed from the shared pool for each call, so
    there is no per-instance socket left to check.

    NOTE(review): an earlier wording stated that the pool re-establishes
    broken connections on the next `getconn()`. Nothing in this module
    health-checks a connection when it is handed out; presumably a
    connection dropped by the server fails on first use and is only
    discarded when it is returned to the pool. Confirm against the
    psycopg2 pool documentation.
    """
    return
@contextmanager
def borrowCursor(self):
    """Borrow a cursor for one short SQL block.

    Convenience wrapper around `borrowConn()` for the common pattern::

        with db.borrowCursor() as cursor:
            cursor.execute(...)
            rows = cursor.fetchall()

    Replaces the legacy `with db.connection.cursor() as cursor:` pattern.
    Commit/rollback and returning the connection to the pool are handled
    by `borrowConn()` — callers must NOT commit or roll back themselves.

    Yields:
        A cursor opened on the borrowed connection; it is closed when the
        block exits.
    """
    with self.borrowConn() as conn:
        with conn.cursor() as cursor:
            yield cursor
@contextmanager
def borrowConn(self):
    """Borrow a connection from the pool for the duration of the block.

    The connection is obtained through `_acquireConn`, which waits a
    bounded time when the pool is exhausted instead of failing at once.

    On normal exit the current transaction is committed, so the
    connection goes back to the pool with no transaction in flight. If
    the block raises, or if the commit itself fails, the transaction is
    rolled back and the exception propagates. A failed commit is never
    swallowed: callers that report success after the block (for example
    by returning ``True``) must not do so when nothing was persisted.

    The connection is **always** returned to the pool, even when
    commit/rollback fails — otherwise a single bad query would leak a
    slot and eventually exhaust the pool.

    Yields:
        The borrowed psycopg2 connection.

    Raises:
        psycopg2.pool.PoolError: No connection became available in time.
        Exception: Any error raised inside the block or by the commit.
    """
    pool = _PoolRegistry.getPool(
        dbHost=self.dbHost,
        dbDatabase=self.dbDatabase,
        dbUser=self.dbUser,
        dbPassword=self.dbPassword,
        dbPort=self.dbPort,
    )
    conn = self._acquireConn(pool)
    try:
        yield conn
        # Reached only when the block finished without raising. A commit
        # failure falls through to the handler below and reaches the caller.
        conn.commit()
    except Exception:
        try:
            conn.rollback()
        except Exception as rollbackError:
            # Best-effort: the primary error is the one worth reporting.
            logger.debug("Rollback after failed DB block also failed: %s", rollbackError)
        raise
    finally:
        try:
            pool.putconn(conn)
        except Exception as e:
            logger.warning("Failed to return connection to pool: %s", e)
@staticmethod
def _acquireConn(pool: psycopg2.pool.ThreadedConnectionPool):
    """Get a connection from the pool, waiting up to `_BORROW_WAIT_TIMEOUT_S`.

    `getconn()` raises `PoolError` instead of queueing when no connection
    is free, so this helper polls with a short sleep of
    `_BORROW_WAIT_BACKOFF_S` between attempts to give callers queue-like
    behaviour. A pool that has already been closed also raises
    `PoolError`; waiting cannot help in that case, so the error is
    re-raised immediately instead of stalling for the full timeout.

    Args:
        pool: The pool to borrow from.

    Returns:
        A connection borrowed from `pool`.

    Raises:
        psycopg2.pool.PoolError: The pool is closed, or no connection
            became available before the deadline.
    """
    deadline = time.monotonic() + _BORROW_WAIT_TIMEOUT_S
    attempt = 0
    while True:
        try:
            return pool.getconn()
        except psycopg2.pool.PoolError:
            # `closed` is read defensively so pool-like stand-ins still work.
            if getattr(pool, "closed", False):
                raise
            attempt += 1
            if time.monotonic() >= deadline:
                logger.error(
                    "Connection pool exhausted after %.1fs wait (%d retries)",
                    _BORROW_WAIT_TIMEOUT_S, attempt,
                )
                raise
            time.sleep(_BORROW_WAIT_BACKOFF_S)
def _create_database_if_not_exists(self):
"""Create the database if it doesn't exist.
Uses an autocommit connection on the `postgres` admin DB because
`CREATE DATABASE` cannot run inside a transaction block — so this
path intentionally does NOT use the pool.
"""
try:
# Use the configured user for database creation
conn = psycopg2.connect(
host=self.dbHost,
port=self.dbPort,
@ -330,22 +609,20 @@ class DatabaseConnector:
user=self.dbUser,
password=self.dbPassword,
client_encoding="utf8",
connect_timeout=_CONNECT_TIMEOUT_S,
)
conn.autocommit = True
try:
with conn.cursor() as cursor:
# Check if database exists
cursor.execute(
"SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,)
)
exists = cursor.fetchone()
if not exists:
# Create database with proper quoting for names with hyphens
quoted_db_name = f'"{self.dbDatabase}"'
cursor.execute(f"CREATE DATABASE {quoted_db_name}")
logger.info(f"Created database: {self.dbDatabase}")
finally:
conn.close()
except Exception as e:
@ -356,9 +633,12 @@ class DatabaseConnector:
)
def _create_tables(self):
"""Create only the system table - application tables are created by interfaces."""
"""Create the `_system` table.
Uses a short-lived autocommit connection (not the pool); runs exactly
once at connector creation.
"""
try:
# Use the configured user for table creation
conn = psycopg2.connect(
host=self.dbHost,
port=self.dbPort,
@ -366,11 +646,11 @@ class DatabaseConnector:
user=self.dbUser,
password=self.dbPassword,
client_encoding="utf8",
connect_timeout=_CONNECT_TIMEOUT_S,
)
conn.autocommit = True
try:
with conn.cursor() as cursor:
# Create only the system table
cursor.execute("""
CREATE TABLE IF NOT EXISTS _system (
id SERIAL PRIMARY KEY,
@ -382,6 +662,7 @@ class DatabaseConnector:
"sysModifiedBy" VARCHAR(255)
)
""")
finally:
conn.close()
except Exception as e:
@ -391,67 +672,26 @@ class DatabaseConnector:
)
raise RuntimeError(f"FATAL ERROR: Cannot create system table: {e}")
def _connect(self):
    """Open the per-instance PostgreSQL connection.

    Connects with the configured host, port, database and credentials,
    uses `RealDictCursor` as cursor factory, stores the result on
    `self.connection` and switches autocommit off so statements run
    inside transactions.

    Raises:
        Exception: Any connection error; logged, then re-raised.
    """
    try:
        connectArgs = {
            "host": self.dbHost,
            "port": self.dbPort,
            "database": self.dbDatabase,
            "user": self.dbUser,
            "password": self.dbPassword,
            "client_encoding": "utf8",
            "cursor_factory": psycopg2.extras.RealDictCursor,
        }
        self.connection = psycopg2.connect(**connectArgs)
        # Explicit transactions: callers commit or roll back themselves.
        self.connection.autocommit = False
    except Exception as e:
        logger.error(f"Failed to connect to PostgreSQL: {e}")
        raise
def _ensure_connection(self):
    """Make sure `self.connection` is usable, reconnecting when it is not.

    A missing or closed connection is replaced through `_connect()`. An
    open one is probed with `SELECT 1`; if the probe (or anything else in
    the check) raises, a warning is logged and `_connect()` is called.
    """
    try:
        if self.connection is not None and not self.connection.closed:
            # Cheap round trip to detect a dead socket.
            with self.connection.cursor() as probe:
                probe.execute("SELECT 1")
        else:
            self._connect()
    except Exception as e:
        logger.warning(f"Connection lost, reconnecting: {e}")
        self._connect()
def _initializeSystemTable(self):
"""Initializes the system table if it doesn't exist yet."""
try:
# First ensure the system table exists
self._ensureTableExists(SystemTable)
with self.connection.cursor() as cursor:
# Check if system table has any data
with self.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute('SELECT COUNT(*) FROM "_system"')
row = cursor.fetchone()
count = row["count"] if row else 0
self.connection.commit()
cursor.fetchone() # noqa: just verifies table is readable
except Exception as e:
logger.error(f"Error initializing system table: {e}")
self.connection.rollback()
raise
def _loadSystemTable(self) -> Dict[str, str]:
"""Loads the system table with the initial IDs."""
try:
with self.connection.cursor() as cursor:
with self.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute('SELECT "table_name", "initial_id" FROM "_system"')
rows = cursor.fetchall()
system_data = {}
for row in rows:
system_data[row["table_name"]] = row["initial_id"]
return system_data
return {row["table_name"]: row["initial_id"] for row in rows}
except Exception as e:
logger.error(f"Error loading system table: {e}")
return {}
@ -459,11 +699,9 @@ class DatabaseConnector:
def _saveSystemTable(self, data: Dict[str, str]) -> bool:
"""Saves the system table with the initial IDs."""
try:
with self.connection.cursor() as cursor:
# Clear existing data
with self.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute('DELETE FROM "_system"')
# Insert new data
for table_name, initial_id in data.items():
cursor.execute(
"""
@ -472,21 +710,16 @@ class DatabaseConnector:
""",
(table_name, initial_id, getUtcTimestamp()),
)
self.connection.commit()
return True
except Exception as e:
logger.error(f"Error saving system table: {e}")
self.connection.rollback()
return False
def _ensureSystemTableExists(self) -> bool:
"""Ensures the system table exists, creates it if it doesn't."""
try:
self._ensure_connection()
with self.connection.cursor() as cursor:
# Check if system table exists
with self.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute(
"SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s",
(self._systemTableName,),
@ -494,7 +727,6 @@ class DatabaseConnector:
exists = cursor.fetchone()["count"] > 0
if not exists:
# Create system table
cursor.execute(f"""
CREATE TABLE "{self._systemTableName}" (
"table_name" VARCHAR(255) PRIMARY KEY,
@ -507,7 +739,6 @@ class DatabaseConnector:
""")
logger.info("System table created successfully")
else:
# Check if we need to add missing columns to existing table
cursor.execute(
"""
SELECT column_name FROM information_schema.columns
@ -527,7 +758,6 @@ class DatabaseConnector:
cursor.execute(
f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "{sys_col}" {sys_sql}'
)
return True
except Exception as e:
logger.error(f"Error ensuring system table exists: {e}")
@ -542,10 +772,8 @@ class DatabaseConnector:
return self._ensureSystemTableExists()
try:
self._ensure_connection()
with self.connection.cursor() as cursor:
# Check if table exists by querying information_schema with case-insensitive search
with self.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute(
"""
SELECT COUNT(*) FROM information_schema.tables
@ -556,7 +784,6 @@ class DatabaseConnector:
exists = cursor.fetchone()["count"] > 0
if not exists:
# Create table from Pydantic model
self._create_table_from_model(cursor, table, model_class)
logger.info(
f"Created table '{table}' with columns from Pydantic model"
@ -581,15 +808,12 @@ class DatabaseConnector:
for row in existing_column_rows
}
# Desired columns based on model
model_fields = getModelFields(model_class)
desired_columns = set(["id"]) | set(model_fields.keys())
# Add missing columns
for col in sorted(desired_columns - existing_columns):
# Determine SQL type
if col in ["id"]:
continue # primary key exists already
continue
sql_type = model_fields.get(col)
if not sql_type:
sql_type = "TEXT"
@ -652,13 +876,9 @@ class DatabaseConnector:
logger.warning(
f"Could not ensure columns for existing table '{table}': {ensure_err}"
)
self.connection.commit()
return True
except Exception as e:
logger.error(f"Error ensuring table {table} exists: {e}")
if hasattr(self, "connection") and self.connection:
self.connection.rollback()
return False
def _ensureVectorExtension(self) -> bool:
@ -666,17 +886,14 @@ class DatabaseConnector:
if self._vectorExtensionEnabled:
return True
try:
self._ensure_connection()
with self.connection.cursor() as cursor:
with self.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute("CREATE EXTENSION IF NOT EXISTS vector")
self.connection.commit()
self._vectorExtensionEnabled = True
logger.info("pgvector extension enabled")
return True
except Exception as e:
logger.error(f"Failed to enable pgvector extension: {e}")
if hasattr(self, "connection") and self.connection:
self.connection.rollback()
return False
def _create_table_from_model(self, cursor, table: str, model_class: type) -> None:
@ -791,22 +1008,19 @@ class DatabaseConnector:
if not self._ensureTableExists(model_class):
return None
with self.connection.cursor() as cursor:
with self.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute(f'SELECT * FROM "{table}" WHERE "id" = %s', (recordId,))
row = cursor.fetchone()
if not row:
return None
# Convert row to dict and handle JSONB fields
record = dict(row)
fields = getModelFields(model_class)
parseRecordFields(record, fields, f"record {recordId}")
return record
except Exception as e:
logger.error(f"Error loading record {recordId} from table {table}: {e}")
_rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e
def getRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]:
@ -849,14 +1063,12 @@ class DatabaseConnector:
if effective_user_id:
record["sysModifiedBy"] = effective_user_id
with self.connection.cursor() as cursor:
with self.borrowConn() as conn:
with conn.cursor() as cursor:
self._save_record(cursor, table, recordId, record, model_class)
self.connection.commit()
return True
except Exception as e:
logger.error(f"Error saving record {recordId} to table {table}: {e}")
self.connection.rollback()
return False
def _loadTable(self, model_class: type) -> List[Dict[str, Any]]:
@ -870,7 +1082,8 @@ class DatabaseConnector:
if not self._ensureTableExists(model_class):
return []
with self.connection.cursor() as cursor:
with self.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"')
records = [dict(row) for row in cursor.fetchall()]
@ -878,7 +1091,6 @@ class DatabaseConnector:
modelFields = model_class.model_fields
for record in records:
parseRecordFields(record, fields, f"table {table}")
# Set type-aware defaults for NULL JSONB fields
for fieldName, fieldType in fields.items():
if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
fieldInfo = modelFields.get(fieldName)
@ -896,7 +1108,6 @@ class DatabaseConnector:
return records
except Exception as e:
logger.error(f"Error loading table {table}: {e}")
_rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e
def _registerInitialId(self, table: str, initialId: str) -> bool:
@ -969,17 +1180,10 @@ class DatabaseConnector:
def getTables(self) -> List[str]:
"""Returns a list of all available tables."""
tables = []
tables: List[str] = []
try:
# Ensure connection is alive
self._ensure_connection()
if not self.connection or self.connection.closed:
logger.error("Database connection is not available")
return tables
with self.connection.cursor() as cursor:
with self.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute("""
SELECT table_name
FROM information_schema.tables
@ -990,7 +1194,6 @@ class DatabaseConnector:
tables = [row["table_name"] for row in rows]
except Exception as e:
logger.error(f"Error reading the database {self.dbDatabase}: {e}")
return tables
def getFields(self, model_class: type) -> List[str]:
@ -1060,7 +1263,8 @@ class DatabaseConnector:
query = f'SELECT * FROM "{table}"{where_clause} ORDER BY "id"'
with self.connection.cursor() as cursor:
with self.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute(query, where_values)
records = [dict(row) for row in cursor.fetchall()]
@ -1082,7 +1286,6 @@ class DatabaseConnector:
fieldAnnotation.__origin__ is dict)):
record[fieldName] = {}
# If fieldFilter is available, reduce the fields
if fieldFilter and isinstance(fieldFilter, list):
result = []
for record in records:
@ -1096,7 +1299,6 @@ class DatabaseConnector:
return records
except Exception as e:
logger.error(f"Error loading records from table {table}: {e}")
_rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e
def _buildPaginationClauses(
@ -1281,7 +1483,8 @@ class DatabaseConnector:
where_clause, order_clause, limit_clause, values, count_values = \
self._buildPaginationClauses(model_class, pagination, recordFilter)
with self.connection.cursor() as cursor:
with self.borrowConn() as conn:
with conn.cursor() as cursor:
countSql = f'SELECT COUNT(*) FROM "{table}"{where_clause}'
dataSql = f'SELECT * FROM "{table}"{where_clause}{order_clause}{limit_clause}'
cursor.execute(countSql, count_values)
@ -1320,7 +1523,6 @@ class DatabaseConnector:
return {"items": records, "totalItems": totalItems, "totalPages": totalPages}
except Exception as e:
logger.error(f"Error in getRecordsetPaginated for table {table}: {e}")
_rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e
def getDistinctColumnValues(
@ -1365,7 +1567,8 @@ class DatabaseConnector:
else:
sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}" WHERE {nonNullCond} ORDER BY val'
with self.connection.cursor() as cursor:
with self.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute(sql, values)
result: List[Optional[str]] = [row["val"] for row in cursor.fetchall()]
@ -1375,7 +1578,6 @@ class DatabaseConnector:
emptySql = f'SELECT 1 FROM "{table}"{where_clause} AND ({emptyCond}) LIMIT 1'
else:
emptySql = f'SELECT 1 FROM "{table}" WHERE ({emptyCond}) LIMIT 1'
with self.connection.cursor() as cursor:
cursor.execute(emptySql, values)
if cursor.fetchone():
result.append(None)
@ -1383,7 +1585,6 @@ class DatabaseConnector:
return result
except Exception as e:
logger.error(f"Error in getDistinctColumnValues for {table}.{column}: {e}")
_rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e
def recordCreate(
@ -1463,33 +1664,33 @@ class DatabaseConnector:
if not self._ensureTableExists(model_class):
return False
with self.connection.cursor() as cursor:
# Check if record exists
# `getInitialId` opens its own borrow; do it BEFORE we acquire a
# connection ourselves so we don't pin two slots concurrently.
initialId = self.getInitialId(model_class)
with self.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute(
f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,)
)
if not cursor.fetchone():
return False
# Check if it's an initial record
initialId = self.getInitialId(model_class)
if initialId is not None and initialId == recordId:
# `_removeInitialId` borrows its own conn — done outside
# this block on purpose to avoid nested borrows.
pass
cursor.execute(f'DELETE FROM "{table}" WHERE "id" = %s', (recordId,))
if initialId is not None and initialId == recordId:
self._removeInitialId(table)
logger.info(
f"Initial ID {recordId} for table {table} has been removed from the system table"
)
# Delete the record
cursor.execute(f'DELETE FROM "{table}" WHERE "id" = %s', (recordId,))
# No cache to update - database handles consistency
self.connection.commit()
return True
except Exception as e:
logger.error(f"Error deleting record {recordId} from table {table}: {e}")
self.connection.rollback()
return False
def recordCreateBulk(
@ -1559,16 +1760,11 @@ class DatabaseConnector:
)
try:
self._ensure_connection()
with self.connection.cursor() as cursor:
with self.borrowConn() as conn:
with conn.cursor() as cursor:
psycopg2.extras.execute_values(cursor, sql, rows, page_size=500)
self.connection.commit()
except Exception as e:
logger.error(f"Bulk insert into {table} failed (n={len(rows)}): {e}")
try:
self.connection.rollback()
except Exception:
pass
raise
if self.getInitialId(model_class) is None and normalised:
@ -1649,8 +1845,8 @@ class DatabaseConnector:
initialId = self.getInitialId(model_class)
try:
self._ensure_connection()
with self.connection.cursor() as cursor:
with self.borrowConn() as conn:
with conn.cursor() as cursor:
if initialId is not None:
cursor.execute(
f'SELECT 1 FROM "{table}" WHERE "id" = %s AND ' + whereSql,
@ -1662,13 +1858,8 @@ class DatabaseConnector:
cursor.execute(f'DELETE FROM "{table}" WHERE ' + whereSql, params)
deleted = cursor.rowcount or 0
self.connection.commit()
except Exception as e:
logger.error(f"Bulk delete from {table} failed (filter={recordFilter}): {e}")
try:
self.connection.rollback()
except Exception:
pass
raise
if deleted and initialIsAffected:
@ -1751,39 +1942,30 @@ class DatabaseConnector:
)
params = [vectorStr] + whereValues + [vectorStr, limit]
with self.connection.cursor() as cursor:
with self.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute(query, params)
records = [dict(row) for row in cursor.fetchall()]
fields = getModelFields(modelClass)
for record in records:
parseRecordFields(record, fields, f"semanticSearch {table}")
return records
except Exception as e:
logger.error(f"Error in semantic search on {table}: {e}")
_rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e
def close(self, forceClose: bool = False):
"""Close the database connection.
"""No-op for backward compatibility.
Shared cached connectors are intentionally kept open unless forceClose=True.
This prevents accidental shutdown from interface __del__ methods while
other requests are still using the same cached connector instance.
Connections are now owned by the `_PoolRegistry` pool and live for the
process lifetime. Pool shutdown happens centrally via `closeAllPools()`
from the FastAPI lifespan hook — never from a connector instance.
Interface `__del__` paths used to call `close()` to release a per-
connector socket; with pooling there is nothing to close here.
"""
if self._isCachedShared and not forceClose:
return
if (
hasattr(self, "connection")
and self.connection
and not self.connection.closed
):
self.connection.close()
def __del__(self):
"""Cleanup method to close connection."""
try:
self.close()
except Exception:
pass
"""Cleanup hook (intentionally no-op — see `close`)."""
return

View file

@ -342,7 +342,7 @@ class RealEstateObjects:
# If no exact match, try case-insensitive search via SQL query
# This handles cases where the name might have different casing
self.db._ensure_connection()
with self.db.connection.cursor() as cursor:
with self.db.borrowCursor() as cursor:
cursor.execute(
'SELECT "id" FROM "Gemeinde" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
(name,)
@ -375,7 +375,7 @@ class RealEstateObjects:
# Try case-insensitive search
self.db._ensure_connection()
with self.db.connection.cursor() as cursor:
with self.db.borrowCursor() as cursor:
cursor.execute(
'SELECT "id" FROM "Kanton" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
(name,)
@ -408,7 +408,7 @@ class RealEstateObjects:
# Try case-insensitive search
self.db._ensure_connection()
with self.db.connection.cursor() as cursor:
with self.db.borrowCursor() as cursor:
cursor.execute(
'SELECT "id" FROM "Land" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
(name,)
@ -840,7 +840,7 @@ class RealEstateObjects:
# Ensure connection is alive
self.db._ensure_connection()
with self.db.connection.cursor() as cursor:
with self.db.borrowCursor() as cursor:
# Execute query
if parameters:
# Use parameterized query for safety

View file

@ -1659,7 +1659,7 @@ class BillingObjects:
try:
appInterface = getAppInterface(self.currentUser)
appInterface.db._ensure_connection()
with appInterface.db.connection.cursor() as cur:
with appInterface.db.borrowCursor() as cur:
if appInterface.db._ensureTableExists(UserInDB):
cur.execute(
'SELECT "id" FROM "UserInDB" WHERE '
@ -1780,7 +1780,7 @@ class BillingObjects:
try:
self.db._ensure_connection()
with self.db.connection.cursor() as cur:
with self.db.borrowCursor() as cur:
countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
cur.execute(countSql, whereValues)
totalItems = cur.fetchone()["count"]
@ -1797,10 +1797,7 @@ class BillingObjects:
except Exception as e:
logger.error(f"_searchTransactionsPaginated SQL error: {e}", exc_info=True)
try:
self.db.connection.rollback()
except Exception:
pass
# Rollback is handled by `borrowCursor()` context manager on exit.
return {"items": [], "totalItems": 0, "totalPages": 0}
def _buildScopeFilter(
@ -1872,7 +1869,7 @@ class BillingObjects:
result: Dict[str, Any] = {}
with self.db.connection.cursor() as cur:
with self.db.borrowCursor() as cur:
# 1) Totals
cur.execute(
f'SELECT COALESCE(SUM("amount"), 0) AS total, COUNT(*) AS cnt FROM "{table}"{whereClause}',
@ -1947,17 +1944,12 @@ class BillingObjects:
})
result["timeSeries"] = timeSeries
self.db.connection.commit()
# Commit/rollback are handled by `borrowCursor()` context manager.
result["_allAccounts"] = allAccounts
return result
except Exception as e:
logger.error(f"Error in getTransactionStatisticsAggregated: {e}", exc_info=True)
try:
self.db.connection.rollback()
except Exception:
pass
return self._emptyStats()
@staticmethod

View file

@ -228,6 +228,22 @@ class KnowledgeObjects:
"""Get all ContentChunks for a file."""
return self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})
def countChunksByFileIds(self, fileIds: List[str]) -> Dict[str, int]:
    """Count ContentChunk rows per file ID using one grouped query.

    A single ``GROUP BY`` aggregate replaces one round trip per file and
    selects only the ``fileId`` column plus the row count.

    Args:
        fileIds: File IDs whose chunks should be counted.

    Returns:
        ``{fileId: chunkCount}`` for every requested ID that has at least one
        chunk row. IDs without rows are absent from the mapping. An empty
        dict is returned when ``fileIds`` is empty or when
        ``_ensureTableExists(ContentChunk)`` reports falsy.
    """
    # Short-circuit keeps the table check from running for an empty request.
    if not fileIds or not self.db._ensureTableExists(ContentChunk):
        return {}

    with self.db.borrowCursor() as cur:
        cur.execute(
            'SELECT "fileId", COUNT(*) AS cnt FROM "ContentChunk" WHERE "fileId" = ANY(%s) GROUP BY "fileId"',
            (list(fileIds),),
        )
        countsByFile: Dict[str, int] = {}
        for record in cur.fetchall():
            countsByFile[record["fileId"]] = int(record["cnt"])
        return countsByFile
def deleteContentChunks(self, fileId: str) -> int:
"""Delete all ContentChunks for a file. Returns count of deleted chunks."""
chunks = self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})

View file

@ -1221,10 +1221,9 @@ class ComponentObjects:
for item in fileRows
]
# Single transaction: delete FileData, FileItem, then FileFolder (children first)
self.db._ensure_connection()
try:
with self.db.connection.cursor() as cursor:
# Single transaction: delete FileData, FileItem, then FileFolder (children first).
# Commit/rollback are handled by `borrowCursor()` on exit.
with self.db.borrowCursor() as cursor:
if fileIds:
cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (fileIds,))
cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (fileIds,))
@ -1233,10 +1232,6 @@ class ComponentObjects:
orderedIds.append(folderId)
if orderedIds:
cursor.execute('DELETE FROM "FileFolder" WHERE "id" = ANY(%s)', (orderedIds,))
self.db.connection.commit()
except Exception:
self.db.connection.rollback()
raise
return {"deletedFolders": len(folderIds), "deletedFiles": len(fileIds)}
@ -1507,7 +1502,7 @@ class ComponentObjects:
try:
self.db._ensure_connection()
with self.db.connection.cursor() as cursor:
with self.db.borrowCursor() as cursor:
cursor.execute(
'SELECT "id", "sysCreatedBy" FROM "FileItem" WHERE "id" = ANY(%s)',
(uniqueIds,),
@ -1526,11 +1521,10 @@ class ComponentObjects:
cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (accessibleIds,))
deletedFiles = cursor.rowcount
self.db.connection.commit()
# Commit/rollback are handled by `borrowCursor()` context manager.
return {"deletedFiles": deletedFiles}
except Exception as e:
logger.error(f"Error deleting files in batch: {e}")
self.db.connection.rollback()
raise FileDeletionError(f"Error deleting files in batch: {str(e)}")
def _ensureFeatureInstanceGroup(self, featureInstanceId: str, contextKey: str = "files/list") -> Optional[str]:

View file

@ -374,7 +374,7 @@ def getRecordsetWithRBAC(
query = f'SELECT * FROM "{table}"{whereClause}{orderByClause}{limitClause}'
with connector.connection.cursor() as cursor:
with connector.borrowCursor() as cursor:
cursor.execute(query, whereValues)
records = [dict(row) for row in cursor.fetchall()]
@ -561,7 +561,7 @@ def getRecordsetPaginatedWithRBAC(
offset = (pagination.page - 1) * pagination.pageSize
limitClause = f" LIMIT {pagination.pageSize} OFFSET {offset}"
with connector.connection.cursor() as cursor:
with connector.borrowCursor() as cursor:
countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
cursor.execute(countSql, countValues)
totalItems = cursor.fetchone()["count"]
@ -709,7 +709,7 @@ def getDistinctColumnValuesWithRBAC(
sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{nonNullWhere} ORDER BY val'
with connector.connection.cursor() as cursor:
with connector.borrowCursor() as cursor:
cursor.execute(sql, whereValues)
result = [row["val"] for row in cursor.fetchall()]
@ -719,7 +719,7 @@ def getDistinctColumnValuesWithRBAC(
emptySql = f'SELECT 1 FROM "{table}"{whereClause} AND {emptyCond} LIMIT 1'
else:
emptySql = f'SELECT 1 FROM "{table}" WHERE {emptyCond} LIMIT 1'
with connector.connection.cursor() as cursor:
with connector.borrowCursor() as cursor:
cursor.execute(emptySql, whereValues)
if cursor.fetchone():
result.append(None)
@ -967,7 +967,7 @@ def buildRbacWhereClause(
# Multi-Tenant Design: Users do NOT have mandateId - they are linked via UserMandate
if table == "UserInDB":
try:
with connector.connection.cursor() as cursor:
with connector.borrowCursor() as cursor:
# Get all user IDs that are members of the current mandate
cursor.execute(
'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true',
@ -994,7 +994,7 @@ def buildRbacWhereClause(
# For UserConnection: Filter via UserMandate junction table
elif table == "UserConnection":
try:
with connector.connection.cursor() as cursor:
with connector.borrowCursor() as cursor:
# Get all user IDs that are members of the current mandate
cursor.execute(
'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true',

View file

@ -305,7 +305,7 @@ def handleIdsMode(
sql = f'SELECT "{idField}"::TEXT AS val FROM "{table}"{where_clause} ORDER BY "{idField}"'
with db.connection.cursor() as cursor:
with db.borrowCursor() as cursor:
cursor.execute(sql, values)
return JSONResponse(content=[row["val"] for row in cursor.fetchall()])
except Exception as e:

View file

@ -25,6 +25,18 @@ router = APIRouter(
def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> List[Dict[str, Any]]:
"""Build per-connection RAG inventory rows.
Each DataSource row exposes BOTH numbers because they mean different things:
* `fileCount` — distinct files indexed (== `FileContentIndex` rows)
* `chunkCount` — embedding-sized text fragments (== `ContentChunk` rows,
max `DEFAULT_CHUNK_TOKENS` tokens each, what the vector retrieval
actually hits)
A single PDF typically yields 1 file × 5-100 chunks; legacy UI labelled
`len(FileContentIndex)` as "chunks" which was off by 1-2 orders of
magnitude and misleading.
"""
from modules.datamodels.datamodelDataSource import DataSource
from modules.datamodels.datamodelKnowledge import FileContentIndex
@ -34,19 +46,35 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
connIndexRows = knowledgeIf.db.getRecordset(FileContentIndex, recordFilter={"connectionId": connectionId})
connChunkTotal = len(connIndexRows)
connFileTotal = len(connIndexRows)
# Map fileId → real chunk count via 1 aggregate query (cheap even for
# connections with thousands of files; we never load the vector body).
fileIds = [
(idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", ""))
for idx in connIndexRows
]
fileIds = [fid for fid in fileIds if fid]
chunkCountByFile = knowledgeIf.countChunksByFileIds(fileIds) if fileIds else {}
connChunkTotal = sum(chunkCountByFile.values())
filesByDs: Dict[str, int] = {}
chunksByDs: Dict[str, int] = {}
unassigned = 0
unassignedFiles = 0
unassignedChunks = 0
for idx in connIndexRows:
fileId = idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", "")
chunkCnt = chunkCountByFile.get(fileId, 0)
struct = (idx.get("structure") if isinstance(idx, dict) else getattr(idx, "structure", None)) or {}
ingestion = struct.get("_ingestion") or {} if isinstance(struct, dict) else {}
prov = ingestion.get("provenance") or {} if isinstance(ingestion, dict) else {}
dsIdRef = prov.get("dataSourceId", "") if isinstance(prov, dict) else ""
if dsIdRef:
chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + 1
filesByDs[dsIdRef] = filesByDs.get(dsIdRef, 0) + 1
chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + chunkCnt
else:
unassigned += 1
unassignedFiles += 1
unassignedChunks += chunkCnt
seen: Dict[str, bool] = {}
dsItems = []
@ -64,14 +92,19 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"ragIndexEnabled": ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False),
"neutralize": ds.get("neutralize") if isinstance(ds, dict) else getattr(ds, "neutralize", False),
"lastIndexed": ds.get("lastIndexed") if isinstance(ds, dict) else getattr(ds, "lastIndexed", None),
"fileCount": filesByDs.get(dsId, 0),
"chunkCount": chunksByDs.get(dsId, 0),
})
if unassigned > 0 and len(dsItems) > 0:
perDs = unassigned // len(dsItems)
remainder = unassigned % len(dsItems)
# Spread orphan files (provenance lost) evenly so totals match.
if unassignedFiles > 0 and len(dsItems) > 0:
perFile = unassignedFiles // len(dsItems)
remFile = unassignedFiles % len(dsItems)
perChunk = unassignedChunks // len(dsItems)
remChunk = unassignedChunks % len(dsItems)
for i, item in enumerate(dsItems):
item["chunkCount"] += perDs + (1 if i < remainder else 0)
item["fileCount"] += perFile + (1 if i < remFile else 0)
item["chunkCount"] += perChunk + (1 if i < remChunk else 0)
# Pull a wider window than the previous 5 so the "last successful
# sync" is found even if a connection has many recent jobs queued.
@ -102,6 +135,12 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"skippedPolicy": result.get("skippedPolicy", 0),
"failed": result.get("failed", 0),
"durationMs": result.get("durationMs", 0),
# Surface limit-stop reason so the UI can warn the user
# that the index is provably incomplete (and which budget
# to raise). None means the walker finished naturally.
"stoppedAtLimit": result.get("stoppedAtLimit"),
"limits": result.get("limits") or {},
"bytesProcessed": result.get("bytesProcessed", 0),
}
if lastError and lastSuccess:
break
@ -113,6 +152,7 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"knowledgeIngestionEnabled": getattr(conn, "knowledgeIngestionEnabled", False),
"preferences": getattr(conn, "knowledgePreferences", None) or {},
"dataSources": dsItems,
"totalFiles": connFileTotal,
"totalChunks": connChunkTotal,
"runningJobs": runningJobs,
"lastError": lastError,
@ -139,8 +179,9 @@ def _getInventoryMe(
items = _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService)
totalChunks = sum(c.get("totalChunks", 0) for c in items)
totalFiles = sum(c.get("totalFiles", 0) for c in items)
return {"connections": items, "totals": {"chunks": totalChunks}}
return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}}
except Exception as e:
logger.error("Error in RAG inventory /me: %s", e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@ -170,9 +211,10 @@ def _getInventoryMandate(
items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
totalChunks = sum(c.get("totalChunks", 0) for c in items)
totalFiles = sum(c.get("totalFiles", 0) for c in items)
totalBytes = aggregateMandateRagTotalBytes(mandateId)
return {"connections": items, "totals": {"chunks": totalChunks, "bytes": totalBytes}}
return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks, "bytes": totalBytes}}
except HTTPException:
raise
except Exception as e:
@ -202,8 +244,9 @@ def _getInventoryPlatform(
items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
totalChunks = sum(c.get("totalChunks", 0) for c in items)
totalFiles = sum(c.get("totalFiles", 0) for c in items)
return {"connections": items, "totals": {"chunks": totalChunks}}
return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}}
except HTTPException:
raise
except Exception as e:

View file

@ -227,7 +227,7 @@ WHERE "workflowId" = ANY(%s)
GROUP BY "workflowId"
"""
out: dict = {}
with db.connection.cursor() as cursor:
with db.borrowCursor() as cursor:
cursor.execute(sql, (workflowIds,))
for row in cursor.fetchall():
r = dict(row)
@ -480,7 +480,7 @@ def _getWorkflowsJoinedPaginated(
dataSql = f"SELECT w.*, rs.\"lastStartedAt\", rs.\"runCount\", rs.\"activeRunId\" FROM {fromSql}{whereClause}{orderClause}{limitClause}"
db._ensure_connection()
with db.connection.cursor() as cursor:
with db.borrowCursor() as cursor:
cursor.execute(countSql, countValues)
totalItems = int(cursor.fetchone()["cnt"])

View file

@ -25,15 +25,14 @@ _CACHE_TTL_SECONDS = 300
def _getOrCreateFeatureDbConnector(featureDbName: str, userId: str):
"""Reuse a pooled DB connector for the given feature database."""
"""Reuse a pooled DB connector for the given feature database.
The underlying psycopg2 connections live in the central pool
(`_PoolRegistry`) and are recreated on demand if they go stale; we just
need to keep the lightweight connector wrapper around.
"""
if featureDbName in _featureDbConnPool:
conn = _featureDbConnPool[featureDbName]
try:
if conn.connection and not conn.connection.closed:
return conn
except Exception as e:
logger.warning(f"Feature DB connection check failed for {featureDbName}: {e}")
_featureDbConnPool.pop(featureDbName, None)
return _featureDbConnPool[featureDbName]
from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG

View file

@ -68,6 +68,9 @@ class ClickupBootstrapResult:
workspaces: int = 0
lists: int = 0
errors: List[str] = field(default_factory=list)
# First budget exhausted: "maxTasks" | "maxWorkspaces" | "maxListsPerWorkspace" | None.
# Drives the same UI banner as the file-walker bootstraps.
stoppedAtLimit: Optional[str] = None
def _syntheticTaskId(connectionId: str, taskId: str) -> str:
@ -225,6 +228,7 @@ async def bootstrapClickup(
cancelled = False
for ds in dataSources:
if result.indexed + result.skippedDuplicate >= limits.maxTasks:
_recordLimitStop(result, "maxTasks", "dataSource", limits)
break
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
cancelled = True
@ -243,8 +247,11 @@ async def bootstrapClickup(
clickupScope=limits.clickupScope,
)
if len(teams) > dsLimits.maxWorkspaces:
_recordLimitStop(result, "maxWorkspaces", "teams", dsLimits, hard=False)
for team in teams[:dsLimits.maxWorkspaces]:
if result.indexed + result.skippedDuplicate >= dsLimits.maxTasks:
_recordLimitStop(result, "maxTasks", f"team={team.get('id','')}", dsLimits)
break
teamId = str(team.get("id", "") or "")
if not teamId:
@ -351,6 +358,7 @@ async def _walkTeam(
for lst in listsCollected:
if result.indexed + result.skippedDuplicate >= limits.maxTasks:
_recordLimitStop(result, "maxTasks", f"team={teamId}", limits)
return
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
return
@ -407,6 +415,7 @@ async def _walkList(
for task in tasks:
if result.indexed + result.skippedDuplicate >= limits.maxTasks:
_recordLimitStop(result, "maxTasks", f"list={listId}", limits)
return
if not _isRecent(task.get("date_updated"), limits.maxAgeDays):
result.skippedPolicy += 1
@ -529,13 +538,37 @@ async def _ingestTask(
)
def _recordLimitStop(
    result: ClickupBootstrapResult,
    limitName: str,
    where: str,
    limits: ClickupBootstrapLimits,
    *,
    hard: bool = True,
) -> None:
    """Note on ``result`` that the ClickUp walker ran into a budget, and log it.

    A hard hit always stores ``limitName`` in ``result.stoppedAtLimit``. A soft
    hit (``hard=False``) stores it only while no limit has been recorded yet.
    A warning is logged on every call, whether or not the field changed.

    Args:
        result: Bootstrap result whose ``stoppedAtLimit`` may be updated.
        limitName: Name of the budget that was hit, e.g. ``"maxTasks"``.
        where: Free-text location that is included in the log line.
        limits: Limits object the configured budget values are read from.
        hard: When true, overwrite any previously recorded limit name.
    """
    if hard:
        result.stoppedAtLimit = limitName
    elif result.stoppedAtLimit is None:
        result.stoppedAtLimit = limitName

    # Only these budgets are reported with their configured value; any other
    # name is logged with ``None``.
    knownBudgets = {
        "maxTasks": limits.maxTasks,
        "maxWorkspaces": limits.maxWorkspaces,
        "maxListsPerWorkspace": limits.maxListsPerWorkspace,
    }
    configuredValue = knownBudgets.get(limitName)
    logger.warning(
        "clickup walker hit %s=%s at %s — partial index (indexed=%d, skippedDup=%d).",
        limitName,
        configuredValue,
        where,
        result.indexed,
        result.skippedDuplicate,
    )
def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000)
logger.info(
"ingestion.connection.bootstrap.done part=clickup connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d workspaces=%d lists=%d durationMs=%d",
"ingestion.connection.bootstrap.done part=clickup connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d workspaces=%d lists=%d durationMs=%d stoppedAtLimit=%s",
connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy,
result.failed, result.workspaces, result.lists, durationMs,
result.stoppedAtLimit or "none",
extra={
"event": "ingestion.connection.bootstrap.done",
"part": "clickup",
@ -547,6 +580,7 @@ def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs:
"workspaces": result.workspaces,
"lists": result.lists,
"durationMs": durationMs,
"stoppedAtLimit": result.stoppedAtLimit,
},
)
return {
@ -559,4 +593,11 @@ def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs:
"lists": result.lists,
"durationMs": durationMs,
"errors": result.errors[:20],
"stoppedAtLimit": result.stoppedAtLimit,
"limits": {
"maxTasks": MAX_TASKS_DEFAULT,
"maxWorkspaces": MAX_WORKSPACES_DEFAULT,
"maxListsPerWorkspace": MAX_LISTS_PER_WORKSPACE_DEFAULT,
"maxAgeDays": MAX_AGE_DAYS_DEFAULT,
},
}

View file

@ -61,6 +61,8 @@ class GdriveBootstrapResult:
failed: int = 0
bytesProcessed: int = 0
errors: List[str] = field(default_factory=list)
# See SharepointBootstrapResult.stoppedAtLimit — same semantics.
stoppedAtLimit: Optional[str] = None
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
@ -265,8 +267,10 @@ async def _walkFolder(
for entry in entries:
if result.indexed + result.skippedDuplicate >= limits.maxItems:
_recordLimitStop(result, "maxItems", folderPath, limits)
return
if result.bytesProcessed >= limits.maxBytes:
_recordLimitStop(result, "maxBytes", folderPath, limits)
return
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
return
@ -276,6 +280,9 @@ async def _walkFolder(
mimeType = getattr(entry, "mimeType", None) or metadata.get("mimeType")
if getattr(entry, "isFolder", False) or mimeType == FOLDER_MIME:
if depth + 1 > limits.maxDepth:
_recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
continue
await _walkFolder(
adapter=adapter,
knowledgeService=knowledgeService,
@ -298,6 +305,7 @@ async def _walkFolder(
continue
size = int(getattr(entry, "size", 0) or 0)
if size and size > limits.maxFileSize:
_recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
result.skippedPolicy += 1
continue
modifiedTime = metadata.get("modifiedTime")
@ -470,13 +478,38 @@ async def _ingestOne(
await asyncio.sleep(0)
def _recordLimitStop(
    result: GdriveBootstrapResult,
    limitName: str,
    where: str,
    limits: GdriveBootstrapLimits,
    *,
    hard: bool = True,
) -> None:
    """Note on ``result`` that the Google Drive walker hit a budget, and log it.

    A hard hit always stores ``limitName`` in ``result.stoppedAtLimit``. A soft
    hit (``hard=False``) stores it only while no limit has been recorded yet.
    A warning is logged on every call, whether or not the field changed.

    Args:
        result: Bootstrap result whose ``stoppedAtLimit`` may be updated.
        limitName: Name of the budget that was hit, e.g. ``"maxItems"``.
        where: Path or location that is included in the log line.
        limits: Limits object the configured budget values are read from.
        hard: When true, overwrite any previously recorded limit name.
    """
    shouldRecord = True if hard else result.stoppedAtLimit is None
    if shouldRecord:
        result.stoppedAtLimit = limitName

    # Budgets whose configured value is echoed in the warning; any other name
    # is logged with ``None``.
    knownBudgets = {
        "maxItems": limits.maxItems,
        "maxBytes": limits.maxBytes,
        "maxDepth": limits.maxDepth,
        "maxFileSize": limits.maxFileSize,
    }
    configuredValue = knownBudgets.get(limitName)
    logger.warning(
        "gdrive walker hit %s=%s at %s — partial index (indexed=%d, bytesProcessed=%d).",
        limitName,
        configuredValue,
        where,
        result.indexed,
        result.bytesProcessed,
    )
def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000)
logger.info(
"ingestion.connection.bootstrap.done part=gdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d bytes=%d durationMs=%d",
"ingestion.connection.bootstrap.done part=gdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d bytes=%d durationMs=%d stoppedAtLimit=%s",
connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy,
result.failed, result.bytesProcessed, durationMs,
result.stoppedAtLimit or "none",
extra={
"event": "ingestion.connection.bootstrap.done",
"part": "gdrive",
@ -487,6 +520,7 @@ def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: f
"failed": result.failed,
"bytes": result.bytesProcessed,
"durationMs": durationMs,
"stoppedAtLimit": result.stoppedAtLimit,
},
)
return {
@ -498,4 +532,11 @@ def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: f
"bytesProcessed": result.bytesProcessed,
"durationMs": durationMs,
"errors": result.errors[:20],
"stoppedAtLimit": result.stoppedAtLimit,
"limits": {
"maxItems": MAX_ITEMS_DEFAULT,
"maxBytes": MAX_BYTES_DEFAULT,
"maxFileSize": MAX_FILE_SIZE_DEFAULT,
"maxDepth": MAX_DEPTH_DEFAULT,
},
}

View file

@ -53,6 +53,8 @@ class KdriveBootstrapResult:
failed: int = 0
bytesProcessed: int = 0
errors: List[str] = field(default_factory=list)
# See SharepointBootstrapResult.stoppedAtLimit — same semantics.
stoppedAtLimit: Optional[str] = None
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
@ -232,14 +234,19 @@ async def _walkFolder(
for entry in entries:
if result.indexed + result.skippedDuplicate >= limits.maxItems:
_recordLimitStop(result, "maxItems", folderPath, limits)
return
if result.bytesProcessed >= limits.maxBytes:
_recordLimitStop(result, "maxBytes", folderPath, limits)
return
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
return
entryPath = getattr(entry, "path", "") or ""
if getattr(entry, "isFolder", False):
if depth + 1 > limits.maxDepth:
_recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
continue
await _walkFolder(
adapter=adapter,
knowledgeService=knowledgeService,
@ -262,6 +269,7 @@ async def _walkFolder(
continue
size = int(getattr(entry, "size", 0) or 0)
if size and size > limits.maxFileSize:
_recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
result.skippedPolicy += 1
continue
@ -415,17 +423,42 @@ async def _ingestOne(
await asyncio.sleep(0)
def _recordLimitStop(
    result: KdriveBootstrapResult,
    limitName: str,
    where: str,
    limits: KdriveBootstrapLimits,
    *,
    hard: bool = True,
) -> None:
    """Note on ``result`` that the kDrive walker hit a budget, and log it.

    A hard hit always stores ``limitName`` in ``result.stoppedAtLimit``. A soft
    hit (``hard=False``) stores it only while no limit has been recorded yet.
    A warning is logged on every call, whether or not the field changed.

    Args:
        result: Bootstrap result whose ``stoppedAtLimit`` may be updated.
        limitName: Name of the budget that was hit, e.g. ``"maxBytes"``.
        where: Path or location that is included in the log line.
        limits: Limits object the configured budget values are read from.
        hard: When true, overwrite any previously recorded limit name.
    """
    if hard:
        result.stoppedAtLimit = limitName
    elif result.stoppedAtLimit is None:
        result.stoppedAtLimit = limitName

    # Budgets whose configured value is echoed in the warning; any other name
    # is logged with ``None``.
    knownBudgets = {
        "maxItems": limits.maxItems,
        "maxBytes": limits.maxBytes,
        "maxDepth": limits.maxDepth,
        "maxFileSize": limits.maxFileSize,
    }
    logArgs = (
        limitName,
        knownBudgets.get(limitName),
        where,
        result.indexed,
        result.bytesProcessed,
    )
    logger.warning(
        "kdrive walker hit %s=%s at %s — partial index (indexed=%d, bytesProcessed=%d).",
        *logArgs,
    )
def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000)
logger.info(
"ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d",
"ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d stoppedAtLimit=%s",
connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed,
durationMs,
durationMs, result.stoppedAtLimit or "none",
extra={"event": "ingestion.connection.bootstrap.done", "part": "kdrive",
"connectionId": connectionId, "indexed": result.indexed,
"skippedDup": result.skippedDuplicate, "skippedPolicy": result.skippedPolicy,
"failed": result.failed, "durationMs": durationMs},
"failed": result.failed, "durationMs": durationMs,
"stoppedAtLimit": result.stoppedAtLimit},
)
return {
"connectionId": result.connectionId,
@ -436,4 +469,11 @@ def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: f
"bytesProcessed": result.bytesProcessed,
"durationMs": durationMs,
"errors": result.errors[:20],
"stoppedAtLimit": result.stoppedAtLimit,
"limits": {
"maxItems": MAX_ITEMS_DEFAULT,
"maxBytes": MAX_BYTES_DEFAULT,
"maxFileSize": MAX_FILE_SIZE_DEFAULT,
"maxDepth": MAX_DEPTH_DEFAULT,
},
}

View file

@ -59,6 +59,10 @@ class SharepointBootstrapResult:
failed: int = 0
bytesProcessed: int = 0
errors: List[str] = field(default_factory=list)
# First budget that hit zero; None means the walk completed naturally.
# Surfaces in the bootstrap result so the RAG inventory UI can warn the
# user that the corpus is incomplete and tell them which knob to turn.
stoppedAtLimit: Optional[str] = None # "maxItems" | "maxBytes" | "maxDepth" | "maxFileSize" | None
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
@ -259,14 +263,22 @@ async def _walkFolder(
for entry in entries:
if result.indexed + result.skippedDuplicate >= limits.maxItems:
_recordLimitStop(result, "maxItems", folderPath, limits)
return
if result.bytesProcessed >= limits.maxBytes:
_recordLimitStop(result, "maxBytes", folderPath, limits)
return
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
return
entryPath = getattr(entry, "path", "") or ""
if getattr(entry, "isFolder", False):
if depth + 1 > limits.maxDepth:
# We stop descending here but keep walking siblings.
# Record once per bootstrap so the UI shows "maxDepth" even
# if other budgets aren't exhausted yet.
_recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
continue
await _walkFolder(
adapter=adapter,
knowledgeService=knowledgeService,
@ -289,6 +301,7 @@ async def _walkFolder(
continue
size = int(getattr(entry, "size", 0) or 0)
if size and size > limits.maxFileSize:
_recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
result.skippedPolicy += 1
continue
@ -443,13 +456,44 @@ async def _ingestOne(
await asyncio.sleep(0)
def _recordLimitStop(
    result: SharepointBootstrapResult,
    limitName: str,
    where: str,
    limits: SharepointBootstrapLimits,
    *,
    hard: bool = True,
) -> None:
    """Note on ``result`` which budget the SharePoint walker hit, and log it.

    Soft hits (``hard=False``; callers use this for per-file ``maxFileSize``
    and per-folder ``maxDepth``) are stored only while
    ``result.stoppedAtLimit`` is still ``None``. Hard hits always overwrite
    whatever was recorded before, so a hard cap takes precedence over an
    earlier soft one. A warning is logged on every call, whether or not the
    field changed.

    Args:
        result: Bootstrap result whose ``stoppedAtLimit`` may be updated.
        limitName: Name of the budget that was hit, e.g. ``"maxItems"``.
        where: Path or location that is included in the log line.
        limits: Limits object the configured budget values are read from.
        hard: When true, overwrite any previously recorded limit name.
    """
    nothingRecorded = (not hard) and result.stoppedAtLimit is None
    if hard or nothingRecorded:
        result.stoppedAtLimit = limitName

    # Budgets whose configured value is echoed in the warning; any other name
    # is logged with ``None``.
    knownBudgets = {
        "maxItems": limits.maxItems,
        "maxBytes": limits.maxBytes,
        "maxDepth": limits.maxDepth,
        "maxFileSize": limits.maxFileSize,
    }
    configuredValue = knownBudgets.get(limitName)
    logger.warning(
        "sharepoint walker hit %s=%s at %s — partial index "
        "(indexed=%d, bytesProcessed=%d). Raise the limit or split the data source.",
        limitName,
        configuredValue,
        where,
        result.indexed,
        result.bytesProcessed,
    )
def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000)
logger.info(
"ingestion.connection.bootstrap.done part=sharepoint connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d",
"ingestion.connection.bootstrap.done part=sharepoint connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d stoppedAtLimit=%s",
connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed,
durationMs,
durationMs, result.stoppedAtLimit or "none",
extra={
"event": "ingestion.connection.bootstrap.done",
"part": "sharepoint",
@ -459,6 +503,7 @@ def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startM
"skippedPolicy": result.skippedPolicy,
"failed": result.failed,
"durationMs": durationMs,
"stoppedAtLimit": result.stoppedAtLimit,
},
)
return {
@ -470,4 +515,11 @@ def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startM
"bytesProcessed": result.bytesProcessed,
"durationMs": durationMs,
"errors": result.errors[:20],
"stoppedAtLimit": result.stoppedAtLimit,
"limits": {
"maxItems": MAX_ITEMS_DEFAULT,
"maxBytes": MAX_BYTES_DEFAULT,
"maxFileSize": MAX_FILE_SIZE_DEFAULT,
"maxDepth": MAX_DEPTH_DEFAULT,
},
}

View file

@ -12,7 +12,8 @@ import logging
import json
import base64
import time
from typing import Any, Dict, Optional
import threading
from typing import Any, Dict, Optional, Tuple
from pathlib import Path
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
@ -286,6 +287,16 @@ def handleSecretJson(value: str, userId: str = "system", keyName: str = "unknown
# Structure: {user_id: {key_name: [timestamps]}}
_decryption_attempts = {}
# Process-wide plaintext cache for decrypted secrets.
# Key: the encrypted ciphertext (which already includes env prefix).
# Value: (expiresAtMonotonic, plaintext), where the expiry is a
# `time.monotonic()` timestamp.
# TTL is short enough that key rotation propagates quickly, long enough that
# hot DB-init paths (every API call building a connector) don't blow the
# decryption rate limit. 60s is a deliberate compromise.
# NOTE(review): the lookup skips expired entries, but no eviction is visible
# here — confirm that stale entries are removed somewhere, or bound the cache.
_DECRYPTION_CACHE_TTL_S = 60.0
# Maps ciphertext -> (monotonic expiry timestamp, decrypted plaintext).
_decryption_cache: Dict[str, Tuple[float, str]] = {}
# Held around every read and write of `_decryption_cache`.
_decryption_cache_lock = threading.Lock()
def _getMasterKey(envType: str = None) -> bytes:
"""
Get the master key for the specified environment.
@ -487,6 +498,16 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un
"""
Decrypt a value using the master key for the current environment.
A short-lived plaintext cache (TTL `_DECRYPTION_CACHE_TTL_S`) is consulted
first. The 10/sec rate-limit on cache misses still protects against
brute-force attacks; cache HITS bypass it because they are not actual
cryptographic operations they just return the result of an earlier
successful decrypt. Without this cache, hot paths like
`mainBackgroundJobService._getDb()` (called per RAG inventory poll AND
per walker DB call) trigger the rate limit and surface as
"Decryption rate limit exceeded for user 'system' key 'DB_PASSWORD_SECRET'"
ERRORs in the RAG inventory UI route.
Args:
encryptedValue: The encrypted value with prefix
userId: The user ID making the request (default: "system")
@ -501,7 +522,15 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un
if not _isEncryptedValue(encryptedValue):
return encryptedValue # Return as-is if not encrypted
# Check rate limiting (10 per second per user per key)
# Cache lookup BEFORE the rate-limit check: a cache hit is not a new
# cryptographic operation and must not be throttled.
now = time.monotonic()
with _decryption_cache_lock:
cached = _decryption_cache.get(encryptedValue)
if cached is not None and cached[0] > now:
return cached[1]
# Cache miss → real decrypt → apply rate limit.
if not _checkDecryptionRateLimit(userId, keyName, maxPerSecond=10):
raise ValueError(f"Decryption rate limit exceeded for user '{userId}' key '{keyName}' (10/sec)")
@ -550,10 +579,24 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un
# Don't fail if audit logging fails
pass
# Populate cache so subsequent reads of the same ciphertext don't
# re-decrypt (and don't consume rate-limit budget).
with _decryption_cache_lock:
_decryption_cache[encryptedValue] = (
time.monotonic() + _DECRYPTION_CACHE_TTL_S,
decryptedValue,
)
return decryptedValue
except Exception as e:
raise ValueError(f"Decryption failed: {e}")
def clearDecryptionCache() -> None:
    """Empty the process-wide plaintext cache of decrypted secrets.

    Intended for use after a key rotation or between tests. The cache
    dictionary is cleared in place while the cache lock is held.
    """
    _decryption_cache_lock.acquire()
    try:
        _decryption_cache.clear()
    finally:
        _decryption_cache_lock.release()
# Create the global APP_CONFIG instance
APP_CONFIG = Configuration()

View file

@ -33,20 +33,35 @@ def _ensureUamTablesMatchModels(dbConnector) -> None:
logger.debug(f"_ensureUamTablesMatchModels: {e}")
def _getConnection(dbConnector):
"""Get a connection from the DatabaseConnector.
from contextlib import contextmanager
Ensures the connection is alive and returns it.
Commits any pending transaction first to avoid blocking.
@contextmanager
def _borrowDbConn(dbConnector):
"""Borrow a pooled connection from the DatabaseConnector.
Index/trigger/FK creation traditionally ran with `conn.autocommit = True`
so each CREATE statement is its own transaction (DDL on a managed
connection blocks waiting for COMMIT). This helper preserves that
behaviour on top of the pool: borrow a connection, flip it to autocommit,
yield it, and restore the previous state before returning it to the pool.
"""
dbConnector._ensure_connection()
conn = dbConnector.connection
# Commit any pending transaction to avoid blocking
with dbConnector.borrowConn() as conn:
try:
conn.commit()
previousAutocommit = conn.autocommit
except Exception:
pass # Ignore if nothing to commit
return conn
previousAutocommit = False
try:
conn.autocommit = True
except Exception as e:
logger.debug(f"Could not set autocommit on borrowed connection: {e}")
try:
yield conn
finally:
try:
conn.autocommit = previousAutocommit
except Exception:
pass
# =============================================================================
@ -174,34 +189,15 @@ def applyMultiTenantOptimizations(dbConnector, tables: Optional[List[str]] = Non
}
try:
# Get a connection from the connector
conn = _getConnection(dbConnector)
# Save and set autocommit state
try:
originalAutocommit = conn.autocommit
except Exception:
originalAutocommit = False
try:
conn.autocommit = True
except Exception as autoErr:
logger.debug(f"Could not set autocommit: {autoErr}")
try:
_ensureUamTablesMatchModels(dbConnector)
except Exception as preIdxErr:
logger.debug(f"Pre-index table ensure: {preIdxErr}")
try:
with _borrowDbConn(dbConnector) as conn:
with conn.cursor() as cursor:
# Apply indexes
results["indexesCreated"] = _applyIndexes(cursor, tables)
# Apply foreign keys
results["foreignKeysCreated"] = _applyForeignKeys(cursor, tables)
# Apply immutable triggers
results["triggersCreated"] = _applyImmutableTriggers(cursor, tables)
logger.info(
@ -210,12 +206,6 @@ def applyMultiTenantOptimizations(dbConnector, tables: Optional[List[str]] = Non
f"{results['triggersCreated']} triggers, "
f"{results['foreignKeysCreated']} foreign keys"
)
finally:
# Restore original autocommit state
try:
conn.autocommit = originalAutocommit
except Exception:
pass
except Exception as e:
logger.error(f"Error applying multi-tenant optimizations: {type(e).__name__}: {e}")
@ -227,20 +217,14 @@ def applyMultiTenantOptimizations(dbConnector, tables: Optional[List[str]] = Non
def applyIndexesOnly(dbConnector, tables: Optional[List[str]] = None) -> int:
"""Apply only indexes (lighter operation, safe for frequent calls)."""
try:
conn = _getConnection(dbConnector)
originalAutocommit = conn.autocommit
conn.autocommit = True
try:
_ensureUamTablesMatchModels(dbConnector)
except Exception as preIdxErr:
logger.debug(f"Pre-index table ensure: {preIdxErr}")
try:
with _borrowDbConn(dbConnector) as conn:
with conn.cursor() as cursor:
return _applyIndexes(cursor, tables)
finally:
conn.autocommit = originalAutocommit
except Exception as e:
logger.error(f"Error applying indexes: {e}")
return 0
@ -514,8 +498,7 @@ def getOptimizationStatus(dbConnector) -> dict:
}
try:
conn = _getConnection(dbConnector)
with conn.cursor() as cursor:
with _borrowDbConn(dbConnector) as conn, conn.cursor() as cursor:
# Check regular indexes
for tableName, indexName, _ in _INDEXES:
if _tableExists(cursor, tableName):

View file

@ -60,11 +60,9 @@ def _getTableColumns(dbConnector, tableName: str) -> List[str]:
ORDER BY ordinal_position
"""
cursor = dbConnector.connection.cursor()
with dbConnector.borrowCursor() as cursor:
cursor.execute(query, (tableName,))
columns = [row[0] for row in cursor.fetchall()]
cursor.close()
return columns
except Exception as e:
logger.error(f"Error getting columns for table {tableName}: {e}")
@ -92,11 +90,10 @@ def _getAllTables(dbConnector) -> List[str]:
ORDER BY table_name
"""
cursor = dbConnector.connection.cursor()
with dbConnector.borrowCursor() as cursor:
cursor.execute(query)
allTables = [row[0] for row in cursor.fetchall()]
# Get foreign key relationships to determine dependency order
fkQuery = """
SELECT
tc.table_name,
@ -111,10 +108,8 @@ def _getAllTables(dbConnector) -> List[str]:
WHERE tc.constraint_type = 'FOREIGN KEY'
AND tc.table_schema = 'public'
"""
cursor.execute(fkQuery)
foreignKeys = cursor.fetchall()
cursor.close()
# Build dependency graph (child -> parent mapping)
dependencies = {}
@ -154,10 +149,9 @@ def _getAllTables(dbConnector) -> List[str]:
# Fallback: return simple list without ordering
try:
query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE'"
cursor = dbConnector.connection.cursor()
with dbConnector.borrowCursor() as cursor:
cursor.execute(query)
tables = [row[0] for row in cursor.fetchall()]
cursor.close()
return [t for t in tables if t not in PROTECTED_TABLES]
except Exception:
return []
@ -184,11 +178,9 @@ def _getPrimaryKeyColumns(dbConnector, tableName: str) -> List[str]:
AND i.indisprimary
"""
cursor = dbConnector.connection.cursor()
with dbConnector.borrowCursor() as cursor:
cursor.execute(query, (tableName,))
pkColumns = [row[0] for row in cursor.fetchall()]
cursor.close()
return pkColumns
except Exception as e:
logger.debug(f"Could not get primary key for {tableName}: {e}")
@ -229,21 +221,15 @@ def _findUserReferencesInTable(
return {}
references = {}
cursor = dbConnector.connection.cursor()
with dbConnector.borrowCursor() as cursor:
for userColumn in userColumns:
# Build SELECT for primary key columns
pkSelect = ", ".join([f'"{pk}"' for pk in pkColumns])
query = f'SELECT {pkSelect} FROM "{tableName}" WHERE "{userColumn}" = %s'
cursor.execute(query, (userId,))
recordKeys = cursor.fetchall()
if recordKeys:
references[userColumn] = recordKeys
logger.debug(f"Found {len(recordKeys)} records in {tableName}.{userColumn} for user {userId}")
cursor.close()
return references
except Exception as e:
@ -277,17 +263,15 @@ def _anonymizeRecords(
return 0
try:
cursor = dbConnector.connection.cursor()
count = 0
for recordKey in recordKeys:
# Build WHERE clause for primary key
whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
# Check if table has sysModifiedAt column
# Resolve column metadata once outside the borrow block (it borrows its
# own connection internally).
columns = _getTableColumns(dbConnector, tableName)
hasModifiedAt = "sysModifiedAt" in columns
count = 0
with dbConnector.borrowCursor() as cursor:
for recordKey in recordKeys:
whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
if hasModifiedAt:
query = f'UPDATE "{tableName}" SET "{columnName}" = %s, "sysModifiedAt" = %s WHERE {whereClause}'
params = [anonymousValue, getUtcTimestamp()]
@ -295,7 +279,6 @@ def _anonymizeRecords(
query = f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE {whereClause}'
params = [anonymousValue]
# Add primary key values to params
if isinstance(recordKey, tuple):
params.extend(recordKey)
else:
@ -304,15 +287,11 @@ def _anonymizeRecords(
cursor.execute(query, params)
count += cursor.rowcount
dbConnector.connection.commit()
cursor.close()
logger.info(f"Anonymized {count} records in {tableName}.{columnName}")
return count
except Exception as e:
logger.error(f"Error anonymizing records in {tableName}.{columnName}: {e}")
dbConnector.connection.rollback()
return 0
@ -338,32 +317,23 @@ def _deleteRecords(
return 0
try:
cursor = dbConnector.connection.cursor()
count = 0
with dbConnector.borrowCursor() as cursor:
for recordKey in recordKeys:
# Build WHERE clause for primary key
whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
query = f'DELETE FROM "{tableName}" WHERE {whereClause}'
# Prepare params
if isinstance(recordKey, tuple):
params = list(recordKey)
else:
params = [recordKey]
cursor.execute(query, params)
count += cursor.rowcount
dbConnector.connection.commit()
cursor.close()
logger.info(f"Deleted {count} records from {tableName}")
return count
except Exception as e:
logger.error(f"Error deleting records from {tableName}: {e}")
dbConnector.connection.rollback()
return 0

View file

@ -25,7 +25,7 @@ if not c or not c.connection:
print("STAGE0: DB_CONNECTION=none (check config.ini / .env)")
raise SystemExit(2)
cur = c.connection.cursor()
# `borrowCursor()` is used as a context manager elsewhere in this codebase
# (`with connector.borrowCursor() as cursor:`), so its return value is not
# itself a cursor. Enter it through an ExitStack so `cur` is the real cursor
# for the rest of this script, and leave the context at interpreter exit.
import atexit
import contextlib

_cursorStack = contextlib.ExitStack()
atexit.register(_cursorStack.close)
cur = _cursorStack.enter_context(c.borrowCursor())
def _scalar(cur):

View file

@ -12,11 +12,16 @@ broken query into "no rows found". That hid bugs like:
These tests pin the new contract: empty result sets still return ``[]`` /
``None`` (normal), but any exception inside the query path propagates as
``DatabaseQueryError`` with the table name attached. The transaction is
rolled back so the connection is usable for subsequent queries.
``DatabaseQueryError`` with the table name attached.
Since the 2026-05-17 pool refactor (`c-work/2-build/2026-05-postgres-connection-pool.md`)
the connector borrows a connection from `_PoolRegistry` on every call via the
`borrowConn()` context manager. The tests mock that context manager so the
fast-fail contract is exercised without requiring a live Postgres server.
"""
from __future__ import annotations
from contextlib import contextmanager
from unittest.mock import MagicMock
import pytest
@ -25,7 +30,6 @@ import psycopg2.errors
from modules.connectors.connectorDbPostgre import (
DatabaseConnector,
DatabaseQueryError,
_rollbackQuietly,
)
@ -39,26 +43,44 @@ class DummyTable:
def _makeConnector(cursorBehavior):
"""Build a ``DatabaseConnector`` skeleton with mocked connection/cursor.
"""Build a ``DatabaseConnector`` skeleton with a mocked pool borrow.
``cursorBehavior`` is a callable invoked with the cursor mock so the test
can configure ``execute``/``fetchall``/``fetchone`` per scenario.
Returns ``(connector, conn, cursor)``:
* ``conn`` exposes ``commit`` / ``rollback`` MagicMocks so tests can
assert that the borrow lifecycle did the right thing.
* ``cursor`` is the per-test cursor mock.
"""
connector = DatabaseConnector.__new__(DatabaseConnector)
cursor = MagicMock()
cursorBehavior(cursor)
cursorContext = MagicMock()
cursorContext.__enter__ = MagicMock(return_value=cursor)
cursorContext.__exit__ = MagicMock(return_value=False)
connection = MagicMock()
connection.cursor.return_value = cursorContext
connector.connection = connection
conn = MagicMock()
conn.cursor.return_value = cursorContext
@contextmanager
def fakeBorrow():
try:
yield conn
except Exception:
conn.rollback()
raise
else:
conn.commit()
connector.borrowConn = fakeBorrow
connector._ensureTableExists = MagicMock(return_value=True)
connector._systemTableName = "_system"
cursorBehavior(cursor)
return connector, connection, cursor
return connector, conn, cursor
class TestGetRecordsetFailLoud:
@ -67,11 +89,12 @@ class TestGetRecordsetFailLoud:
def behavior(cursor):
cursor.execute.return_value = None
cursor.fetchall.return_value = []
connector, connection, _ = _makeConnector(behavior)
connector, conn, _ = _makeConnector(behavior)
result = connector.getRecordset(DummyTable)
assert result == []
connection.rollback.assert_not_called()
conn.rollback.assert_not_called()
conn.commit.assert_called_once()
def test_dictAdaptErrorRaisesDatabaseQueryError(self):
"""Reproduces the Trustee bug: passing a dict in WHERE → can't adapt → raise."""
@ -79,7 +102,7 @@ class TestGetRecordsetFailLoud:
cursor.execute.side_effect = psycopg2.ProgrammingError(
"can't adapt type 'dict'"
)
connector, connection, _ = _makeConnector(behavior)
connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError) as excinfo:
connector.getRecordset(
@ -90,30 +113,30 @@ class TestGetRecordsetFailLoud:
assert excinfo.value.table == "DummyTable"
assert "can't adapt type 'dict'" in str(excinfo.value)
assert isinstance(excinfo.value.original, psycopg2.ProgrammingError)
connection.rollback.assert_called_once()
conn.rollback.assert_called_once()
def test_missingColumnRaisesDatabaseQueryError(self):
def behavior(cursor):
cursor.execute.side_effect = psycopg2.errors.UndefinedColumn(
'column "wat" does not exist'
)
connector, connection, _ = _makeConnector(behavior)
connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError) as excinfo:
connector.getRecordset(DummyTable, recordFilter={"wat": "x"})
assert "wat" in str(excinfo.value)
connection.rollback.assert_called_once()
conn.rollback.assert_called_once()
def test_operationalErrorRaisesDatabaseQueryError(self):
"""Connection lost mid-query is also a real failure that must propagate."""
def behavior(cursor):
cursor.execute.side_effect = psycopg2.OperationalError("connection lost")
connector, connection, _ = _makeConnector(behavior)
connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError):
connector.getRecordset(DummyTable)
connection.rollback.assert_called_once()
conn.rollback.assert_called_once()
class TestGetRecordFailLoud:
@ -122,37 +145,22 @@ class TestGetRecordFailLoud:
def behavior(cursor):
cursor.execute.return_value = None
cursor.fetchone.return_value = None
connector, connection, _ = _makeConnector(behavior)
connector, conn, _ = _makeConnector(behavior)
result = connector.getRecord(DummyTable, "missing-id")
assert result is None
connection.rollback.assert_not_called()
conn.rollback.assert_not_called()
conn.commit.assert_called_once()
def test_queryErrorRaisesDatabaseQueryError(self):
def behavior(cursor):
cursor.execute.side_effect = psycopg2.errors.UndefinedTable(
'relation "DummyTable" does not exist'
)
connector, connection, _ = _makeConnector(behavior)
connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError) as excinfo:
connector.getRecord(DummyTable, "any-id")
assert excinfo.value.table == "DummyTable"
connection.rollback.assert_called_once()
class TestRollbackQuietly:
def test_rollsBackOnLiveConnection(self):
connection = MagicMock()
_rollbackQuietly(connection)
connection.rollback.assert_called_once()
def test_swallowsRollbackError(self):
"""Rollback failure must not mask the original query error."""
connection = MagicMock()
connection.rollback.side_effect = RuntimeError("rollback failed")
_rollbackQuietly(connection)
def test_noopOnNoneConnection(self):
_rollbackQuietly(None)
conn.rollback.assert_called_once()

View file

@ -0,0 +1,304 @@
# Copyright (c) 2026 Patrick Motsch
# All rights reserved.
"""Concurrency tests for the PostgreSQL connection pool.
These tests pin the contract that the `c-work/2-build/2026-05-postgres-connection-pool.md`
refactor delivered:
* T1 — 50 threads × 20 reads in parallel produce 0 errors (T1a), and
20 threads × 50 reads stay within a p99 latency budget of 5 s (T1b).
* T2 — Concurrent `getRecord` readers and `recordModify` writers against the
same connector do not corrupt each other's cursors.
* T3 — `statement_timeout` aborts a runaway `pg_sleep` (the test lowers the
timeout to 500 ms via `SET LOCAL`) and releases the connection back into the pool cleanly.
The tests need a real PostgreSQL server because the bug they guard against
only materialises with real psycopg2 sockets — a mocked connection never
hangs in `recv()`. They read DB credentials from `APP_CONFIG` (which loads
`.env`) and are auto-skipped when the connection fails (no local Postgres,
wrong creds, etc.) so `pytest` keeps working in CI-only environments.
To run them locally:
pytest gateway/tests/unit/connectors/test_connectorDbPostgre_pool.py -v
They use a throwaway database name (`poweron_pool_test_<uuid>`) and drop it
in fixture teardown so they leave nothing behind.
"""
from __future__ import annotations
import time
import uuid
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
import psycopg2
import psycopg2.errors
import pytest
from pydantic import Field
from modules.connectors.connectorDbPostgre import (
DatabaseConnector,
_PoolRegistry,
closeAllPools,
)
from modules.datamodels.datamodelBase import PowerOnModel
from modules.shared.configuration import APP_CONFIG
def _dbConfig():
    """Read DB connection params from APP_CONFIG (`.env`).

    Returns:
        A dict with ``host``, ``user``, ``password`` and integer ``port``, or
        ``None`` when host/user/password are not all present or ``DB_PORT``
        is not a valid integer. Returning ``None`` (instead of raising) lets
        the test module skip cleanly rather than blow up at import time.
    """
    host = APP_CONFIG.get("DB_HOST")
    user = APP_CONFIG.get("DB_USER")
    password = APP_CONFIG.get("DB_PASSWORD_SECRET")
    if not host or not user or password is None:
        return None
    try:
        port = int(APP_CONFIG.get("DB_PORT", 5432))
    except (TypeError, ValueError):
        # A malformed DB_PORT must not crash test collection.
        return None
    return {"host": host, "user": user, "password": password, "port": port}
def _canReachPostgres(cfg) -> bool:
    """Probe the ``postgres`` admin database with a 2-second connect timeout.

    Any failure (missing config key, refused connection, bad credentials, ...)
    is reported as ``False`` so callers can skip instead of erroring.
    """
    try:
        probeArgs = {
            "host": cfg["host"],
            "port": cfg["port"],
            "database": "postgres",
            "user": cfg["user"],
            "password": cfg["password"],
            "connect_timeout": 2,
        }
        psycopg2.connect(**probeArgs).close()
    except Exception:  # noqa: BLE001
        return False
    return True
# Resolved once when the module is imported; both the skip marker below and
# the fixtures/tests in this file read the connection parameters from here.
_DB_CFG = _dbConfig()
# Module-level marker: every test in this file is skipped when no DB config
# is available or the quick connectivity probe fails.
pytestmark = pytest.mark.skipif(
_DB_CFG is None or not _canReachPostgres(_DB_CFG),
reason="No reachable PostgreSQL — skipping live-Postgres pool tests",
)
class PoolTestRow(PowerOnModel):
    """Minimal model for pool tests: adds a single string ``payload`` field."""

    payload: str = Field("", description="Test payload")
def _runAdminStatements(cfg, statements):
    """Execute ``statements`` on the ``postgres`` admin database in autocommit mode.

    Args:
        cfg: Connection parameters as returned by ``_dbConfig()``.
        statements: Iterable of ``(sql, params)`` pairs; ``params`` may be ``None``.
    """
    adminConn = psycopg2.connect(
        host=cfg["host"], port=cfg["port"], database="postgres",
        user=cfg["user"], password=cfg["password"],
    )
    # DROP DATABASE cannot run inside a transaction block, hence autocommit.
    adminConn.autocommit = True
    try:
        with adminConn.cursor() as cur:
            for sql, params in statements:
                cur.execute(sql, params)
    finally:
        adminConn.close()


@pytest.fixture
def liveConnector():
    """Yield a `DatabaseConnector` bound to a throwaway database and drop the
    database afterwards.

    The pool registry is wiped before and after each test so state from one
    test cannot mask a bug in another. Teardown runs in a ``finally`` block, so
    the database is also dropped when connector construction or seeding fails
    before the test body starts.
    """
    cfg = _DB_CFG
    dbName = f"poweron_pool_test_{uuid.uuid4().hex[:8]}"
    dropSql = f'DROP DATABASE IF EXISTS "{dbName}"'

    # Pre-clean: drop any orphan test DB with the same name (shouldn't happen
    # because we use a unique uuid, but be defensive).
    _runAdminStatements(cfg, [(dropSql, None)])

    closeAllPools()
    try:
        connector = DatabaseConnector(
            dbHost=cfg["host"],
            dbDatabase=dbName,
            dbUser=cfg["user"],
            dbPassword=cfg["password"],
            dbPort=cfg["port"],
        )
        # Seed exactly one row so every concurrent read has a stable target.
        connector.recordCreate(PoolTestRow, {"id": "seed", "payload": "hello"})
        yield connector
    finally:
        # Teardown: close pools first, then terminate any leftover backends
        # still attached to the test DB, then drop it.
        closeAllPools()
        _runAdminStatements(
            cfg,
            [
                (
                    'SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = %s',
                    (dbName,),
                ),
                (dropSql, None),
            ],
        )
class TestPoolConcurrency:
    """Concurrency tests that run the pooled connector against a live database."""

    # Seconds each worker thread gets to finish once the stop flag is set.
    _JOIN_TIMEOUT_S = 3.0

    def _runWorkers(self, liveConnector, *, threadCount: int, callsPerThread: int):
        """Run N worker threads, each issuing M reads. Return (errors, latencies).

        ``errors`` holds every exception raised by a read (including a failed
        seed-row assertion); ``latencies`` holds one wall-clock duration per
        call, sorted ascending.
        """
        errors: list = []
        latencies: list = []
        lock = threading.Lock()

        def worker():
            for _ in range(callsPerThread):
                t0 = time.perf_counter()
                try:
                    rows = liveConnector.getRecordset(PoolTestRow)
                    assert any(r["id"] == "seed" for r in rows)
                except Exception as e:  # noqa: BLE001 — we want every failure mode
                    with lock:
                        errors.append(e)
                finally:
                    with lock:
                        latencies.append(time.perf_counter() - t0)

        with ThreadPoolExecutor(max_workers=threadCount) as ex:
            futures = [ex.submit(worker) for _ in range(threadCount)]
            for f in as_completed(futures):
                f.result()

        latencies.sort()
        return errors, latencies

    def test_50_threads_x_20_reads_no_errors(self, liveConnector):
        """T1a — STRESS: 50 threads × 20 reads each → 0 errors.

        Pre-pool, this scenario produced either
        `OperationalError: another command is already in progress` or a
        deadlock in `recv()` because the threadpool shared one psycopg2
        socket. With the pool plus `borrowConn`'s bounded wait, every
        thread eventually gets a connection and completes — even with 30
        threads queued waiting at any moment (pool max=20).
        """
        errors, _ = self._runWorkers(liveConnector, threadCount=50, callsPerThread=20)
        assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"

    def test_20_threads_x_50_reads_latency_budget(self, liveConnector):
        """T1b — DESIGN CAPACITY: 20 threads × 50 reads, p99 < 5 s.

        20 threads matches the pool's `max=20` so there is no queueing —
        every borrow returns immediately. This pins a sanity-level per-call
        latency budget; pre-pool it was unbounded (recv() never returned).

        The 5 s ceiling is generous on purpose: `getRecordset` calls
        `_ensureTableExists` which runs two `information_schema` queries
        for column-additive migration, and under 20-way concurrency on a
        single Postgres instance that produces a long tail. The hard
        assertion is `not errors` — the latency check just guarantees
        nothing hangs indefinitely.
        """
        errors, latencies = self._runWorkers(
            liveConnector, threadCount=20, callsPerThread=50
        )
        assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"
        p99 = latencies[int(len(latencies) * 0.99)]
        assert p99 < 5.0, f"p99 latency {p99:.2f}s exceeds 5s budget"

    def test_interleaved_load_and_save_no_collision(self, liveConnector):
        """T2: parallel reads + writes on the same connector → no cursor mix-up.

        Two reader threads and two writer threads hammer the seed row for two
        seconds. The test fails if any call raised, or if any worker is still
        running after the join timeout — a hung worker is exactly the failure
        mode the pool is meant to eliminate, so it must not pass silently.
        """
        stopFlag = threading.Event()
        errors: list = []
        lock = threading.Lock()

        def reader():
            while not stopFlag.is_set():
                try:
                    liveConnector.getRecord(PoolTestRow, "seed")
                except Exception as e:  # noqa: BLE001
                    with lock:
                        errors.append(("read", e))

        def writer():
            i = 0
            while not stopFlag.is_set():
                try:
                    liveConnector.recordModify(
                        PoolTestRow,
                        "seed",
                        {"id": "seed", "payload": f"v{i}"},
                    )
                    i += 1
                except Exception as e:  # noqa: BLE001
                    with lock:
                        errors.append(("write", e))

        # Daemon threads so a genuinely hung worker cannot block interpreter exit.
        threads = [
            threading.Thread(target=reader, name="reader-0", daemon=True),
            threading.Thread(target=reader, name="reader-1", daemon=True),
            threading.Thread(target=writer, name="writer-0", daemon=True),
            threading.Thread(target=writer, name="writer-1", daemon=True),
        ]
        for t in threads:
            t.start()
        time.sleep(2.0)
        stopFlag.set()
        for t in threads:
            t.join(timeout=self._JOIN_TIMEOUT_S)

        # join(timeout=...) returns silently on timeout; check liveness explicitly.
        hung = [t.name for t in threads if t.is_alive()]
        assert not hung, f"worker threads still running after join timeout: {hung}"
        with lock:
            collected = list(errors)
        assert not collected, f"got {len(collected)} errors; first: {collected[0]!r}"

    def test_statement_timeout_releases_connection(self, liveConnector):
        """T3: `pg_sleep` past statement_timeout → QueryCanceled, pool intact.

        The bug we are guarding against: a runaway query with no timeout
        hung `recv()` forever, the psycopg2 connection was poisoned, and the
        whole backend became unresponsive once that connection was reused.
        Here the server cancels the query once the timeout elapses and the
        borrowed connection is released when the `with` block ends — proven
        by the fact that a follow-up call still succeeds quickly.
        """
        # Use a short timeout to keep the test fast — override the session's
        # statement_timeout for this one transaction via SET LOCAL.
        with liveConnector.borrowConn() as conn:
            with conn.cursor() as cursor:
                cursor.execute("SET LOCAL statement_timeout = 500")
                with pytest.raises(psycopg2.errors.QueryCanceled):
                    cursor.execute("SELECT pg_sleep(5)")

        # Follow-up call must succeed quickly: connection is back in the pool.
        t0 = time.perf_counter()
        rows = liveConnector.getRecordset(PoolTestRow)
        elapsed = time.perf_counter() - t0
        assert any(r["id"] == "seed" for r in rows)
        assert elapsed < 1.0, f"follow-up call took {elapsed:.2f}s — pool may be wedged"
class TestPoolRegistry:
    """Checks on `_PoolRegistry` bookkeeping: pool sharing and teardown."""

    def test_one_pool_per_database_identity(self, liveConnector):
        """Requesting a pool twice with identical parameters yields the same object."""
        poolArgs = {
            "dbHost": _DB_CFG["host"],
            "dbDatabase": liveConnector.dbDatabase,
            "dbUser": _DB_CFG["user"],
            "dbPassword": _DB_CFG["password"],
            "dbPort": _DB_CFG["port"],
        }
        firstPool = _PoolRegistry.getPool(**poolArgs)
        secondPool = _PoolRegistry.getPool(**poolArgs)
        assert firstPool is secondPool

    def test_close_all_clears_registry(self, liveConnector):
        """After `closeAllPools()` the registry mapping is empty."""
        # Issue one real call so a pool is registered before we close everything.
        liveConnector.getRecordset(PoolTestRow)
        assert _PoolRegistry._pools, "pool should exist after a real call"

        closeAllPools()
        assert _PoolRegistry._pools == {}, "registry should be empty after closeAllPools()"

View file

@ -68,6 +68,16 @@ class _FakeDb:
def _ensureTableExists(self, modelClass):
return True
def borrowCursor(self):
    """Stand-in for `DatabaseConnector.borrowCursor()`.

    Returns a context manager that yields a fresh `MagicMock` acting as the
    cursor; nothing is executed against a real database.
    """
    import contextlib
    from unittest import mock

    def _yieldFakeCursor():
        yield mock.MagicMock()

    return contextlib.contextmanager(_yieldFakeCursor)()
def seed(self, modelClass, record: Dict[str, Any]):
tableName = modelClass.__name__
self._tables.setdefault(tableName, {})

View file

@ -69,6 +69,16 @@ class _FakeDb:
def _ensureTableExists(self, modelClass):
return True
def borrowCursor(self):
    """Stand-in for `DatabaseConnector.borrowCursor()` used by the cascade test.

    Returns a context manager that yields a fresh `MagicMock` acting as the
    cursor; nothing is executed against a real database.
    """
    import contextlib
    from unittest import mock

    def _yieldFakeCursor():
        yield mock.MagicMock()

    return contextlib.contextmanager(_yieldFakeCursor)()
def seed(self, modelClass, record: Dict[str, Any]):
tableName = modelClass.__name__
self._tables.setdefault(tableName, {})