db connection pooling and rag limit transparency

This commit is contained in:
ValueOn AG 2026-05-17 20:38:37 +02:00
parent f5aba4bf99
commit 2bb65c2303
23 changed files with 1519 additions and 782 deletions

9
app.py
View file

@ -439,6 +439,15 @@ async def lifespan(app: FastAPI):
except Exception as e: except Exception as e:
logger.warning(f"Could not shutdown feature containers: {e}") logger.warning(f"Could not shutdown feature containers: {e}")
# --- Close all PostgreSQL connection pools ---
# Must run LAST: feature `onStop` hooks may still issue DB calls during
# shutdown. Once we tear down the pools, no more borrows are possible.
try:
from modules.connectors.connectorDbPostgre import closeAllPools
closeAllPools()
except Exception as e:
logger.warning(f"Closing DB connection pools failed: {e}")
logger.info("Application has been shut down") logger.info("Application has been shut down")

View file

@ -2,9 +2,12 @@
# All rights reserved. # All rights reserved.
import contextvars import contextvars
import re import re
import time
import psycopg2 import psycopg2
import psycopg2.extras import psycopg2.extras
import psycopg2.pool
import logging import logging
from contextlib import contextmanager
from typing import List, Dict, Any, Optional, Union, get_origin, get_args, Type from typing import List, Dict, Any, Optional, Union, get_origin, get_args, Type
import uuid import uuid
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -44,24 +47,6 @@ class DatabaseQueryError(RuntimeError):
self.original = original self.original = original
def _rollbackQuietly(connection) -> None:
"""Restore the connection state after a failed query.
Postgres puts the connection in an error state after any failed
statement; subsequent queries on the same connection raise
``InFailedSqlTransaction`` until we rollback. We swallow rollback
errors because the original query error is what the caller should
see a secondary rollback failure typically means the connection
is gone and will be reopened on the next ``_ensure_connection``.
"""
if connection is None:
return
try:
connection.rollback()
except Exception:
pass
class SystemTable(PowerOnModel): class SystemTable(PowerOnModel):
"""Data model for system table entries""" """Data model for system table entries"""
@ -203,9 +188,174 @@ def _quotePgIdent(name: str) -> str:
return '"' + str(name).replace('"', '""') + '"' return '"' + str(name).replace('"', '""') + '"'
# Cache connectors by (host, database, port) to avoid duplicate inits for same database. # ---------------------------------------------------------------------------
# Thread safety: _connector_cache_lock protects cache access. userId is request-scoped via # Connection pool registry
# contextvars to avoid races when concurrent requests share the same connector. # ---------------------------------------------------------------------------
# psycopg2 connections are NOT thread-safe; sharing one connection across the
# FastAPI threadpool (sync `def` routes) or across multiple async tasks that
# happen to be active simultaneously results in either
# `OperationalError: another command is already in progress` or — far worse —
# an unbounded hang in `recv()` on the connection socket, because two cursors
# wait for one server response.
#
# `_PoolRegistry` keeps one `ThreadedConnectionPool` per database identity
# (`host`, `db`, `port`). Every DB call borrows a connection from the pool,
# runs its query, and returns the connection. The pool itself guarantees that
# no two callers share the same psycopg2 connection at the same time.
#
# `statement_timeout=30000` (30 s) is a safety net: a runaway query is aborted
# instead of hanging forever and poisoning the connection. `connect_timeout=10`
# prevents indefinite blocking when the Postgres host is unreachable.
_DEFAULT_POOL_MIN = 2
_DEFAULT_POOL_MAX = 20
_STATEMENT_TIMEOUT_MS = 30000
_CONNECT_TIMEOUT_S = 10
# psycopg2.pool.ThreadedConnectionPool.getconn() does NOT block when the pool
# is exhausted — it raises `psycopg2.pool.PoolError` immediately. That makes
# bursty workloads (50 concurrent route calls against a max=20 pool) fail
# spuriously. `borrowConn()` therefore retries with a small backoff up to
# `_BORROW_WAIT_TIMEOUT_S` seconds before giving up.
_BORROW_WAIT_TIMEOUT_S = 30.0
_BORROW_WAIT_BACKOFF_S = 0.05
def _resolvePoolMax() -> int:
    """Return the maximum size for newly created connection pools.

    Reads `DB_POOL_MAX_CONN` from `APP_CONFIG`. A missing or empty value
    falls back to `_DEFAULT_POOL_MAX`; a value that cannot be converted to
    `int` also falls back and is logged, so a typo in the configuration does
    not go unnoticed. The result is never smaller than `_DEFAULT_POOL_MIN`,
    which keeps the pool's `maxconn` from dropping below its `minconn`.

    Returns:
        The effective `maxconn` value.
    """
    configured = None
    try:
        configured = APP_CONFIG.get("DB_POOL_MAX_CONN")
        poolMax = int(configured or _DEFAULT_POOL_MAX)
    except (ValueError, TypeError):
        logger.warning(
            "Invalid DB_POOL_MAX_CONN=%r, falling back to default %d",
            configured, _DEFAULT_POOL_MAX,
        )
        return _DEFAULT_POOL_MAX
    # Floor at the pool minimum instead of a duplicated literal.
    return max(_DEFAULT_POOL_MIN, poolMax)
class _PoolRegistry:
    """Process-wide registry of `ThreadedConnectionPool` instances.

    Pools are keyed by `(host, database, port)`, so every `DatabaseConnector`
    that targets the same physical database shares a single pool. Pools are
    created on first request; creation and shutdown are serialised through
    a class-level lock.
    """

    # Shared by all callers; mutated only while `_lock` is held.
    _pools: Dict[tuple, psycopg2.pool.ThreadedConnectionPool] = {}
    _lock = threading.Lock()

    @classmethod
    def getPool(
        cls,
        *,
        dbHost: str,
        dbDatabase: str,
        dbUser: str,
        dbPassword: str,
        dbPort: int,
    ) -> psycopg2.pool.ThreadedConnectionPool:
        """Return the pool for `(dbHost, dbDatabase, dbPort)`, creating it if needed.

        `dbUser` / `dbPassword` are only used when the pool is first created;
        they are not part of the registry key. A `dbPort` of `None` is
        treated as 5432.

        Raises:
            Exception: whatever `ThreadedConnectionPool(...)` raises when the
                pool cannot be created (logged, then re-raised).
        """
        resolvedPort = 5432 if dbPort is None else int(dbPort)
        registryKey = (dbHost, dbDatabase, resolvedPort)

        # Unlocked lookup first: the common case is an existing pool.
        existing = cls._pools.get(registryKey)
        if existing is not None:
            return existing

        with cls._lock:
            # Re-check under the lock so concurrent callers create only one pool.
            existing = cls._pools.get(registryKey)
            if existing is not None:
                return existing

            maxConn = _resolvePoolMax()
            connectKwargs = {
                "host": dbHost,
                "port": resolvedPort,
                "database": dbDatabase,
                "user": dbUser,
                "password": dbPassword,
                "client_encoding": "utf8",
                "cursor_factory": psycopg2.extras.RealDictCursor,
                "connect_timeout": _CONNECT_TIMEOUT_S,
                # Server-side guard against runaway statements.
                "options": f"-c statement_timeout={_STATEMENT_TIMEOUT_MS}",
            }
            try:
                created = psycopg2.pool.ThreadedConnectionPool(
                    _DEFAULT_POOL_MIN, maxConn, **connectKwargs
                )
            except Exception as e:
                logger.error(
                    "Failed to create connection pool for db=%s host=%s: %s",
                    dbDatabase, dbHost, e,
                )
                raise

            cls._pools[registryKey] = created
            logger.debug(
                "Created connection pool for db=%s host=%s port=%s (min=%d max=%d, stmt_timeout=%dms)",
                dbDatabase, dbHost, resolvedPort, _DEFAULT_POOL_MIN, maxConn, _STATEMENT_TIMEOUT_MS,
            )
            return created

    @classmethod
    def closeAll(cls) -> None:
        """Close every registered pool and empty the registry.

        Intended to be called once during FastAPI shutdown. A failure to
        close one pool is logged and does not prevent closing the others.
        """
        with cls._lock:
            registered = list(cls._pools.items())
            for registryKey, registeredPool in registered:
                try:
                    registeredPool.closeall()
                except Exception as e:
                    logger.warning("Error closing pool %s: %s", registryKey, e)
            cls._pools.clear()
            logger.info("All database connection pools closed")
def closeAllPools() -> None:
    """Close every registered connection pool.

    Module-level wrapper around `_PoolRegistry.closeAll()`, meant to be
    called from the FastAPI lifespan shutdown hook.
    """
    return _PoolRegistry.closeAll()
class _ConnectionShim:
"""Backward-compatibility stand-in for the old `DatabaseConnector.connection`.
Old callers used patterns like::
db.connection.commit()
if db.connection and not db.connection.closed: ...
These are now no-ops because the pool owns the connection lifecycle.
Direct cursor access through `self.connection.cursor()` is intentionally
blocked with a clear error so silent breakage is impossible every such
call site must migrate to `db.borrowCursor()`.
"""
closed = False
def __bool__(self) -> bool:
return True
def commit(self) -> None:
return
def rollback(self) -> None:
return
def close(self) -> None:
return
def cursor(self, *args, **kwargs):
raise RuntimeError(
"DatabaseConnector.connection.cursor() is no longer supported. "
"Use `db.borrowCursor()` (or `db.borrowConn()` for multi-statement "
"transactions) so the connection is borrowed from and returned to "
"the pool correctly."
)
_CONNECTION_SHIM = _ConnectionShim()
# ---------------------------------------------------------------------------
# Connector cache (lightweight wrappers — actual connections live in the pool)
# ---------------------------------------------------------------------------
# Multiple call sites (`routeI18n`, `aiAuditLogger`, `mainBackgroundJobService`,
# interfaces) ask for a connector via `getCachedConnector(...)` and expect to
# get back the same object on subsequent calls. Now that real connection
# multiplexing happens at the pool layer, the cache returns lightweight
# `DatabaseConnector` wrappers — they hold no connection themselves, only the
# DSN params and a reference to the shared pool.
_MAX_CACHED_CONNECTORS = 32 _MAX_CACHED_CONNECTORS = 32
_connector_cache: Dict[tuple, "DatabaseConnector"] = {} _connector_cache: Dict[tuple, "DatabaseConnector"] = {}
_connector_cache_order: List[tuple] = [] # FIFO order for eviction _connector_cache_order: List[tuple] = [] # FIFO order for eviction
@ -223,22 +373,36 @@ def getCachedConnector(
dbPort: int = None, dbPort: int = None,
userId: str = None, userId: str = None,
) -> "DatabaseConnector": ) -> "DatabaseConnector":
"""Return cached DatabaseConnector for same (host, database, port) to avoid duplicate PostgreSQL inits. """Return a cached `DatabaseConnector` wrapper for `(host, database, port)`.
Uses contextvars for userId so concurrent requests sharing the same connector get correct sysCreatedBy/sysModifiedBy.
Two layers of caching, both intentional:
1. **Pool layer** (`_PoolRegistry`) owns the actual psycopg2 connections.
Per `(host, db, port)` exactly one `ThreadedConnectionPool`, shared
across every wrapper that points at the same physical database. This
is what actually saves Postgres connection slots.
2. **Wrapper layer** (this function) caches the `DatabaseConnector`
Python object. Each wrapper triggers an `initDbSystem()` on first
instantiation (CREATE DATABASE if missing, CREATE TABLE _system,
pool warm-up). Caching the wrapper avoids paying that bootstrap cost
on every request and stops the log from filling with "PostgreSQL
database system initialized" lines.
`userId` is request-scoped via the `_current_user_id` contextvar so two
concurrent requests sharing the same cached wrapper still produce
correct `sysCreatedBy` / `sysModifiedBy` audit fields.
""" """
port = int(dbPort) if dbPort is not None else 5432 port = int(dbPort) if dbPort is not None else 5432
key = (dbHost, dbDatabase, port) key = (dbHost, dbDatabase, port)
with _connector_cache_lock: with _connector_cache_lock:
if key not in _connector_cache: if key not in _connector_cache:
# Evict oldest if at capacity # FIFO eviction. Connectors are now lightweight (no per-instance
# connection), so eviction is purely a memory bookkeeping concern;
# the underlying pool stays alive in `_PoolRegistry`.
while len(_connector_cache) >= _MAX_CACHED_CONNECTORS and _connector_cache_order: while len(_connector_cache) >= _MAX_CACHED_CONNECTORS and _connector_cache_order:
oldest_key = _connector_cache_order.pop(0) oldest_key = _connector_cache_order.pop(0)
if oldest_key in _connector_cache: _connector_cache.pop(oldest_key, None)
try:
_connector_cache[oldest_key].close(forceClose=True)
except Exception as e:
logger.warning(f"Error closing evicted connector: {e}")
del _connector_cache[oldest_key]
_connector_cache[key] = DatabaseConnector( _connector_cache[key] = DatabaseConnector(
dbHost=dbHost, dbHost=dbHost,
dbDatabase=dbDatabase, dbDatabase=dbDatabase,
@ -282,34 +446,38 @@ class DatabaseConnector:
# Set userId (default to empty string if None) # Set userId (default to empty string if None)
self.userId = userId if userId is not None else "" self.userId = userId if userId is not None else ""
# Initialize database system first (creates database if needed) # No per-instance connection any more — real connections live in the
self.connection = None # shared `_PoolRegistry` pool. `_isCachedShared` is retained because
# `close(forceClose=False)` callers (interface __del__) still ask.
self._isCachedShared = False self._isCachedShared = False
self.initDbSystem()
# No caching needed with proper database - PostgreSQL handles performance # pgvector extension state (cached per connector instance — cheap)
# Thread safety
self._lock = threading.Lock()
# pgvector extension state
self._vectorExtensionEnabled = False self._vectorExtensionEnabled = False
# Initialize system table # System table bootstrap: create database, system table, ensure metadata.
self._systemTableName = "_system" self._systemTableName = "_system"
self.initDbSystem()
self._initializeSystemTable() self._initializeSystemTable()
def initDbSystem(self): def initDbSystem(self):
"""Initialize the database system - creates database and tables.""" """Bootstrap the physical database and the `_system` metadata table.
try:
# Create database if it doesn't exist
self._create_database_if_not_exists()
# Create tables Uses a short-lived autocommit connection (NOT the pool) because
`CREATE DATABASE` cannot run inside a transaction block. Also warms up
the pool by acquiring it for this `(host, db, port)` once.
"""
try:
self._create_database_if_not_exists()
self._create_tables() self._create_tables()
# Establish connection to the database # Warm the pool so the first request doesn't pay for socket setup.
self._connect() _PoolRegistry.getPool(
dbHost=self.dbHost,
dbDatabase=self.dbDatabase,
dbUser=self.dbUser,
dbPassword=self.dbPassword,
dbPort=self.dbPort,
)
logger.debug( logger.debug(
"PostgreSQL database system initialized (db=%s, host=%s, port=%s)", "PostgreSQL database system initialized (db=%s, host=%s, port=%s)",
@ -319,10 +487,121 @@ class DatabaseConnector:
logger.error(f"FATAL ERROR: Database system initialization failed: {e}") logger.error(f"FATAL ERROR: Database system initialization failed: {e}")
raise raise
def _create_database_if_not_exists(self): @property
"""Create the database if it doesn't exist.""" def connection(self) -> "_ConnectionShim":
"""Backward-compat shim — see `_ConnectionShim` docstring."""
return _CONNECTION_SHIM
def _ensure_connection(self) -> None:
"""No-op for backward compatibility.
Previously this method tested-or-reconnected the per-instance socket.
With pooling, the `ThreadedConnectionPool` re-establishes broken
connections lazily on the next `getconn()`. Kept as a no-op so that
legacy call sites continue to compile.
"""
return
@contextmanager
def borrowCursor(self):
    """Yield a cursor on a pooled connection for one short SQL block.

    Thin convenience layer over `borrowConn()` for the common pattern::

        with db.borrowCursor() as cursor:
            cursor.execute(...)
            rows = cursor.fetchall()

    It replaces the legacy `with db.connection.cursor() as cursor:` idiom.
    Commit/rollback and returning the connection to the pool are handled
    by `borrowConn()` — callers must NOT commit or roll back themselves.

    Yields:
        A cursor created by the borrowed connection's `cursor()`.
    """
    with self.borrowConn() as pooledConn, pooledConn.cursor() as cur:
        yield cur
@contextmanager
def borrowConn(self):
    """Borrow a connection from the shared pool for the duration of the block.

    Acquisition goes through `_acquireConn()`, which waits a bounded time
    when the pool is exhausted instead of failing immediately.

    Transaction handling:

    * Block exits normally: the transaction is committed. A commit failure
      is re-raised — swallowing it would let write callers report success
      for data that was never persisted.
    * Block raises: the transaction is rolled back (best effort) and the
      original exception propagates.

    The connection is **always** handed back to the pool. If rollback or
    commit failed, the connection is considered unusable and is returned
    with `close=True` so the pool discards it instead of recycling a
    broken socket.

    Yields:
        The borrowed psycopg2 connection.

    Raises:
        psycopg2.pool.PoolError: no connection became available in time.
        Exception: anything raised inside the block, or by `commit()`.
    """
    pool = _PoolRegistry.getPool(
        dbHost=self.dbHost,
        dbDatabase=self.dbDatabase,
        dbUser=self.dbUser,
        dbPassword=self.dbPassword,
        dbPort=self.dbPort,
    )
    conn = self._acquireConn(pool)
    discard = False  # set when the connection must not be reused
    try:
        yield conn
    except Exception:
        try:
            conn.rollback()
        except Exception:
            # Rollback failing means the connection itself is unusable.
            discard = True
        raise
    else:
        try:
            conn.commit()
        except Exception:
            # The caller's work was NOT persisted; make that visible.
            discard = True
            raise
    finally:
        try:
            pool.putconn(conn, close=discard)
        except Exception as e:
            logger.warning("Failed to return connection to pool: %s", e)
@staticmethod
def _acquireConn(pool: psycopg2.pool.ThreadedConnectionPool):
    """Get a connection from `pool`, waiting up to `_BORROW_WAIT_TIMEOUT_S`.

    `getconn()` raises `PoolError` instead of blocking when the pool is
    exhausted, so this helper polls with a short sleep to give callers
    queue-like behaviour.

    A *closed* pool raises the same exception type but can never recover,
    so that case is re-raised immediately rather than polled until the
    timeout.

    Args:
        pool: The pool to borrow from.

    Returns:
        A connection obtained from `pool.getconn()`.

    Raises:
        psycopg2.pool.PoolError: the pool is closed, or it stayed exhausted
            for the whole wait window.
    """
    deadline = time.monotonic() + _BORROW_WAIT_TIMEOUT_S
    attempt = 0
    while True:
        try:
            return pool.getconn()
        except psycopg2.pool.PoolError:
            # Waiting cannot help once the pool has been shut down.
            if getattr(pool, "closed", False):
                raise
            attempt += 1
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                logger.error(
                    "Connection pool exhausted after %.1fs wait (%d retries)",
                    _BORROW_WAIT_TIMEOUT_S, attempt,
                )
                raise
            # Never sleep past the deadline.
            time.sleep(min(_BORROW_WAIT_BACKOFF_S, remaining))
def _create_database_if_not_exists(self):
"""Create the database if it doesn't exist.
Uses an autocommit connection on the `postgres` admin DB because
`CREATE DATABASE` cannot run inside a transaction block — so this
path intentionally does NOT use the pool.
"""
try: try:
# Use the configured user for database creation
conn = psycopg2.connect( conn = psycopg2.connect(
host=self.dbHost, host=self.dbHost,
port=self.dbPort, port=self.dbPort,
@ -330,22 +609,20 @@ class DatabaseConnector:
user=self.dbUser, user=self.dbUser,
password=self.dbPassword, password=self.dbPassword,
client_encoding="utf8", client_encoding="utf8",
connect_timeout=_CONNECT_TIMEOUT_S,
) )
conn.autocommit = True conn.autocommit = True
try:
with conn.cursor() as cursor: with conn.cursor() as cursor:
# Check if database exists
cursor.execute( cursor.execute(
"SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,) "SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,)
) )
exists = cursor.fetchone() exists = cursor.fetchone()
if not exists: if not exists:
# Create database with proper quoting for names with hyphens
quoted_db_name = f'"{self.dbDatabase}"' quoted_db_name = f'"{self.dbDatabase}"'
cursor.execute(f"CREATE DATABASE {quoted_db_name}") cursor.execute(f"CREATE DATABASE {quoted_db_name}")
logger.info(f"Created database: {self.dbDatabase}") logger.info(f"Created database: {self.dbDatabase}")
finally:
conn.close() conn.close()
except Exception as e: except Exception as e:
@ -356,9 +633,12 @@ class DatabaseConnector:
) )
def _create_tables(self): def _create_tables(self):
"""Create only the system table - application tables are created by interfaces.""" """Create the `_system` table.
Uses a short-lived autocommit connection (not the pool) — runs exactly
once at connector creation.
"""
try: try:
# Use the configured user for table creation
conn = psycopg2.connect( conn = psycopg2.connect(
host=self.dbHost, host=self.dbHost,
port=self.dbPort, port=self.dbPort,
@ -366,11 +646,11 @@ class DatabaseConnector:
user=self.dbUser, user=self.dbUser,
password=self.dbPassword, password=self.dbPassword,
client_encoding="utf8", client_encoding="utf8",
connect_timeout=_CONNECT_TIMEOUT_S,
) )
conn.autocommit = True conn.autocommit = True
try:
with conn.cursor() as cursor: with conn.cursor() as cursor:
# Create only the system table
cursor.execute(""" cursor.execute("""
CREATE TABLE IF NOT EXISTS _system ( CREATE TABLE IF NOT EXISTS _system (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
@ -382,6 +662,7 @@ class DatabaseConnector:
"sysModifiedBy" VARCHAR(255) "sysModifiedBy" VARCHAR(255)
) )
""") """)
finally:
conn.close() conn.close()
except Exception as e: except Exception as e:
@ -391,67 +672,26 @@ class DatabaseConnector:
) )
raise RuntimeError(f"FATAL ERROR: Cannot create system table: {e}") raise RuntimeError(f"FATAL ERROR: Cannot create system table: {e}")
def _connect(self):
    """Open the connector's own PostgreSQL connection.

    Connects with the configured host, port, database and credentials,
    using `RealDictCursor` as cursor factory, stores the result on
    `self.connection` and turns autocommit off so statements run inside
    transactions.

    Raises:
        Exception: any connection error is logged and re-raised.
    """
    try:
        connectArgs = {
            "host": self.dbHost,
            "port": self.dbPort,
            "database": self.dbDatabase,
            "user": self.dbUser,
            "password": self.dbPassword,
            "client_encoding": "utf8",
            "cursor_factory": psycopg2.extras.RealDictCursor,
        }
        self.connection = psycopg2.connect(**connectArgs)
        # Explicit transactions: callers commit or roll back themselves.
        self.connection.autocommit = False
    except Exception as e:
        logger.error(f"Failed to connect to PostgreSQL: {e}")
        raise
def _ensure_connection(self):
"""Ensure database connection is alive, reconnect if necessary."""
try:
if self.connection is None or self.connection.closed:
self._connect()
else:
# Test connection with a simple query
with self.connection.cursor() as cursor:
cursor.execute("SELECT 1")
except Exception as e:
logger.warning(f"Connection lost, reconnecting: {e}")
self._connect()
def _initializeSystemTable(self): def _initializeSystemTable(self):
"""Initializes the system table if it doesn't exist yet.""" """Initializes the system table if it doesn't exist yet."""
try: try:
# First ensure the system table exists
self._ensureTableExists(SystemTable) self._ensureTableExists(SystemTable)
with self.borrowConn() as conn:
with self.connection.cursor() as cursor: with conn.cursor() as cursor:
# Check if system table has any data
cursor.execute('SELECT COUNT(*) FROM "_system"') cursor.execute('SELECT COUNT(*) FROM "_system"')
row = cursor.fetchone() cursor.fetchone() # noqa: just verifies table is readable
count = row["count"] if row else 0
self.connection.commit()
except Exception as e: except Exception as e:
logger.error(f"Error initializing system table: {e}") logger.error(f"Error initializing system table: {e}")
self.connection.rollback()
raise raise
def _loadSystemTable(self) -> Dict[str, str]: def _loadSystemTable(self) -> Dict[str, str]:
"""Loads the system table with the initial IDs.""" """Loads the system table with the initial IDs."""
try: try:
with self.connection.cursor() as cursor: with self.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute('SELECT "table_name", "initial_id" FROM "_system"') cursor.execute('SELECT "table_name", "initial_id" FROM "_system"')
rows = cursor.fetchall() rows = cursor.fetchall()
return {row["table_name"]: row["initial_id"] for row in rows}
system_data = {}
for row in rows:
system_data[row["table_name"]] = row["initial_id"]
return system_data
except Exception as e: except Exception as e:
logger.error(f"Error loading system table: {e}") logger.error(f"Error loading system table: {e}")
return {} return {}
@ -459,11 +699,9 @@ class DatabaseConnector:
def _saveSystemTable(self, data: Dict[str, str]) -> bool: def _saveSystemTable(self, data: Dict[str, str]) -> bool:
"""Saves the system table with the initial IDs.""" """Saves the system table with the initial IDs."""
try: try:
with self.connection.cursor() as cursor: with self.borrowConn() as conn:
# Clear existing data with conn.cursor() as cursor:
cursor.execute('DELETE FROM "_system"') cursor.execute('DELETE FROM "_system"')
# Insert new data
for table_name, initial_id in data.items(): for table_name, initial_id in data.items():
cursor.execute( cursor.execute(
""" """
@ -472,21 +710,16 @@ class DatabaseConnector:
""", """,
(table_name, initial_id, getUtcTimestamp()), (table_name, initial_id, getUtcTimestamp()),
) )
self.connection.commit()
return True return True
except Exception as e: except Exception as e:
logger.error(f"Error saving system table: {e}") logger.error(f"Error saving system table: {e}")
self.connection.rollback()
return False return False
def _ensureSystemTableExists(self) -> bool: def _ensureSystemTableExists(self) -> bool:
"""Ensures the system table exists, creates it if it doesn't.""" """Ensures the system table exists, creates it if it doesn't."""
try: try:
self._ensure_connection() with self.borrowConn() as conn:
with conn.cursor() as cursor:
with self.connection.cursor() as cursor:
# Check if system table exists
cursor.execute( cursor.execute(
"SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s", "SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s",
(self._systemTableName,), (self._systemTableName,),
@ -494,7 +727,6 @@ class DatabaseConnector:
exists = cursor.fetchone()["count"] > 0 exists = cursor.fetchone()["count"] > 0
if not exists: if not exists:
# Create system table
cursor.execute(f""" cursor.execute(f"""
CREATE TABLE "{self._systemTableName}" ( CREATE TABLE "{self._systemTableName}" (
"table_name" VARCHAR(255) PRIMARY KEY, "table_name" VARCHAR(255) PRIMARY KEY,
@ -507,7 +739,6 @@ class DatabaseConnector:
""") """)
logger.info("System table created successfully") logger.info("System table created successfully")
else: else:
# Check if we need to add missing columns to existing table
cursor.execute( cursor.execute(
""" """
SELECT column_name FROM information_schema.columns SELECT column_name FROM information_schema.columns
@ -527,7 +758,6 @@ class DatabaseConnector:
cursor.execute( cursor.execute(
f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "{sys_col}" {sys_sql}' f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "{sys_col}" {sys_sql}'
) )
return True return True
except Exception as e: except Exception as e:
logger.error(f"Error ensuring system table exists: {e}") logger.error(f"Error ensuring system table exists: {e}")
@ -542,10 +772,8 @@ class DatabaseConnector:
return self._ensureSystemTableExists() return self._ensureSystemTableExists()
try: try:
self._ensure_connection() with self.borrowConn() as conn:
with conn.cursor() as cursor:
with self.connection.cursor() as cursor:
# Check if table exists by querying information_schema with case-insensitive search
cursor.execute( cursor.execute(
""" """
SELECT COUNT(*) FROM information_schema.tables SELECT COUNT(*) FROM information_schema.tables
@ -556,7 +784,6 @@ class DatabaseConnector:
exists = cursor.fetchone()["count"] > 0 exists = cursor.fetchone()["count"] > 0
if not exists: if not exists:
# Create table from Pydantic model
self._create_table_from_model(cursor, table, model_class) self._create_table_from_model(cursor, table, model_class)
logger.info( logger.info(
f"Created table '{table}' with columns from Pydantic model" f"Created table '{table}' with columns from Pydantic model"
@ -581,15 +808,12 @@ class DatabaseConnector:
for row in existing_column_rows for row in existing_column_rows
} }
# Desired columns based on model
model_fields = getModelFields(model_class) model_fields = getModelFields(model_class)
desired_columns = set(["id"]) | set(model_fields.keys()) desired_columns = set(["id"]) | set(model_fields.keys())
# Add missing columns
for col in sorted(desired_columns - existing_columns): for col in sorted(desired_columns - existing_columns):
# Determine SQL type
if col in ["id"]: if col in ["id"]:
continue # primary key exists already continue
sql_type = model_fields.get(col) sql_type = model_fields.get(col)
if not sql_type: if not sql_type:
sql_type = "TEXT" sql_type = "TEXT"
@ -652,13 +876,9 @@ class DatabaseConnector:
logger.warning( logger.warning(
f"Could not ensure columns for existing table '{table}': {ensure_err}" f"Could not ensure columns for existing table '{table}': {ensure_err}"
) )
self.connection.commit()
return True return True
except Exception as e: except Exception as e:
logger.error(f"Error ensuring table {table} exists: {e}") logger.error(f"Error ensuring table {table} exists: {e}")
if hasattr(self, "connection") and self.connection:
self.connection.rollback()
return False return False
def _ensureVectorExtension(self) -> bool: def _ensureVectorExtension(self) -> bool:
@ -666,17 +886,14 @@ class DatabaseConnector:
if self._vectorExtensionEnabled: if self._vectorExtensionEnabled:
return True return True
try: try:
self._ensure_connection() with self.borrowConn() as conn:
with self.connection.cursor() as cursor: with conn.cursor() as cursor:
cursor.execute("CREATE EXTENSION IF NOT EXISTS vector") cursor.execute("CREATE EXTENSION IF NOT EXISTS vector")
self.connection.commit()
self._vectorExtensionEnabled = True self._vectorExtensionEnabled = True
logger.info("pgvector extension enabled") logger.info("pgvector extension enabled")
return True return True
except Exception as e: except Exception as e:
logger.error(f"Failed to enable pgvector extension: {e}") logger.error(f"Failed to enable pgvector extension: {e}")
if hasattr(self, "connection") and self.connection:
self.connection.rollback()
return False return False
def _create_table_from_model(self, cursor, table: str, model_class: type) -> None: def _create_table_from_model(self, cursor, table: str, model_class: type) -> None:
@ -791,22 +1008,19 @@ class DatabaseConnector:
if not self._ensureTableExists(model_class): if not self._ensureTableExists(model_class):
return None return None
with self.connection.cursor() as cursor: with self.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute(f'SELECT * FROM "{table}" WHERE "id" = %s', (recordId,)) cursor.execute(f'SELECT * FROM "{table}" WHERE "id" = %s', (recordId,))
row = cursor.fetchone() row = cursor.fetchone()
if not row: if not row:
return None return None
# Convert row to dict and handle JSONB fields
record = dict(row) record = dict(row)
fields = getModelFields(model_class) fields = getModelFields(model_class)
parseRecordFields(record, fields, f"record {recordId}") parseRecordFields(record, fields, f"record {recordId}")
return record return record
except Exception as e: except Exception as e:
logger.error(f"Error loading record {recordId} from table {table}: {e}") logger.error(f"Error loading record {recordId} from table {table}: {e}")
_rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e raise DatabaseQueryError(table, str(e), original=e) from e
def getRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]: def getRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]:
@ -849,14 +1063,12 @@ class DatabaseConnector:
if effective_user_id: if effective_user_id:
record["sysModifiedBy"] = effective_user_id record["sysModifiedBy"] = effective_user_id
with self.connection.cursor() as cursor: with self.borrowConn() as conn:
with conn.cursor() as cursor:
self._save_record(cursor, table, recordId, record, model_class) self._save_record(cursor, table, recordId, record, model_class)
self.connection.commit()
return True return True
except Exception as e: except Exception as e:
logger.error(f"Error saving record {recordId} to table {table}: {e}") logger.error(f"Error saving record {recordId} to table {table}: {e}")
self.connection.rollback()
return False return False
def _loadTable(self, model_class: type) -> List[Dict[str, Any]]: def _loadTable(self, model_class: type) -> List[Dict[str, Any]]:
@ -870,7 +1082,8 @@ class DatabaseConnector:
if not self._ensureTableExists(model_class): if not self._ensureTableExists(model_class):
return [] return []
with self.connection.cursor() as cursor: with self.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"') cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"')
records = [dict(row) for row in cursor.fetchall()] records = [dict(row) for row in cursor.fetchall()]
@ -878,7 +1091,6 @@ class DatabaseConnector:
modelFields = model_class.model_fields modelFields = model_class.model_fields
for record in records: for record in records:
parseRecordFields(record, fields, f"table {table}") parseRecordFields(record, fields, f"table {table}")
# Set type-aware defaults for NULL JSONB fields
for fieldName, fieldType in fields.items(): for fieldName, fieldType in fields.items():
if fieldType == "JSONB" and fieldName in record and record[fieldName] is None: if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
fieldInfo = modelFields.get(fieldName) fieldInfo = modelFields.get(fieldName)
@ -896,7 +1108,6 @@ class DatabaseConnector:
return records return records
except Exception as e: except Exception as e:
logger.error(f"Error loading table {table}: {e}") logger.error(f"Error loading table {table}: {e}")
_rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e raise DatabaseQueryError(table, str(e), original=e) from e
def _registerInitialId(self, table: str, initialId: str) -> bool: def _registerInitialId(self, table: str, initialId: str) -> bool:
@ -969,17 +1180,10 @@ class DatabaseConnector:
def getTables(self) -> List[str]: def getTables(self) -> List[str]:
"""Returns a list of all available tables.""" """Returns a list of all available tables."""
tables = [] tables: List[str] = []
try: try:
# Ensure connection is alive with self.borrowConn() as conn:
self._ensure_connection() with conn.cursor() as cursor:
if not self.connection or self.connection.closed:
logger.error("Database connection is not available")
return tables
with self.connection.cursor() as cursor:
cursor.execute(""" cursor.execute("""
SELECT table_name SELECT table_name
FROM information_schema.tables FROM information_schema.tables
@ -990,7 +1194,6 @@ class DatabaseConnector:
tables = [row["table_name"] for row in rows] tables = [row["table_name"] for row in rows]
except Exception as e: except Exception as e:
logger.error(f"Error reading the database {self.dbDatabase}: {e}") logger.error(f"Error reading the database {self.dbDatabase}: {e}")
return tables return tables
def getFields(self, model_class: type) -> List[str]: def getFields(self, model_class: type) -> List[str]:
@ -1060,7 +1263,8 @@ class DatabaseConnector:
query = f'SELECT * FROM "{table}"{where_clause} ORDER BY "id"' query = f'SELECT * FROM "{table}"{where_clause} ORDER BY "id"'
with self.connection.cursor() as cursor: with self.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute(query, where_values) cursor.execute(query, where_values)
records = [dict(row) for row in cursor.fetchall()] records = [dict(row) for row in cursor.fetchall()]
@ -1082,7 +1286,6 @@ class DatabaseConnector:
fieldAnnotation.__origin__ is dict)): fieldAnnotation.__origin__ is dict)):
record[fieldName] = {} record[fieldName] = {}
# If fieldFilter is available, reduce the fields
if fieldFilter and isinstance(fieldFilter, list): if fieldFilter and isinstance(fieldFilter, list):
result = [] result = []
for record in records: for record in records:
@ -1096,7 +1299,6 @@ class DatabaseConnector:
return records return records
except Exception as e: except Exception as e:
logger.error(f"Error loading records from table {table}: {e}") logger.error(f"Error loading records from table {table}: {e}")
_rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e raise DatabaseQueryError(table, str(e), original=e) from e
def _buildPaginationClauses( def _buildPaginationClauses(
@ -1281,7 +1483,8 @@ class DatabaseConnector:
where_clause, order_clause, limit_clause, values, count_values = \ where_clause, order_clause, limit_clause, values, count_values = \
self._buildPaginationClauses(model_class, pagination, recordFilter) self._buildPaginationClauses(model_class, pagination, recordFilter)
with self.connection.cursor() as cursor: with self.borrowConn() as conn:
with conn.cursor() as cursor:
countSql = f'SELECT COUNT(*) FROM "{table}"{where_clause}' countSql = f'SELECT COUNT(*) FROM "{table}"{where_clause}'
dataSql = f'SELECT * FROM "{table}"{where_clause}{order_clause}{limit_clause}' dataSql = f'SELECT * FROM "{table}"{where_clause}{order_clause}{limit_clause}'
cursor.execute(countSql, count_values) cursor.execute(countSql, count_values)
@ -1320,7 +1523,6 @@ class DatabaseConnector:
return {"items": records, "totalItems": totalItems, "totalPages": totalPages} return {"items": records, "totalItems": totalItems, "totalPages": totalPages}
except Exception as e: except Exception as e:
logger.error(f"Error in getRecordsetPaginated for table {table}: {e}") logger.error(f"Error in getRecordsetPaginated for table {table}: {e}")
_rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e raise DatabaseQueryError(table, str(e), original=e) from e
def getDistinctColumnValues( def getDistinctColumnValues(
@ -1365,7 +1567,8 @@ class DatabaseConnector:
else: else:
sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}" WHERE {nonNullCond} ORDER BY val' sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}" WHERE {nonNullCond} ORDER BY val'
with self.connection.cursor() as cursor: with self.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute(sql, values) cursor.execute(sql, values)
result: List[Optional[str]] = [row["val"] for row in cursor.fetchall()] result: List[Optional[str]] = [row["val"] for row in cursor.fetchall()]
@ -1375,7 +1578,6 @@ class DatabaseConnector:
emptySql = f'SELECT 1 FROM "{table}"{where_clause} AND ({emptyCond}) LIMIT 1' emptySql = f'SELECT 1 FROM "{table}"{where_clause} AND ({emptyCond}) LIMIT 1'
else: else:
emptySql = f'SELECT 1 FROM "{table}" WHERE ({emptyCond}) LIMIT 1' emptySql = f'SELECT 1 FROM "{table}" WHERE ({emptyCond}) LIMIT 1'
with self.connection.cursor() as cursor:
cursor.execute(emptySql, values) cursor.execute(emptySql, values)
if cursor.fetchone(): if cursor.fetchone():
result.append(None) result.append(None)
@ -1383,7 +1585,6 @@ class DatabaseConnector:
return result return result
except Exception as e: except Exception as e:
logger.error(f"Error in getDistinctColumnValues for {table}.{column}: {e}") logger.error(f"Error in getDistinctColumnValues for {table}.{column}: {e}")
_rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e raise DatabaseQueryError(table, str(e), original=e) from e
def recordCreate( def recordCreate(
@ -1463,33 +1664,33 @@ class DatabaseConnector:
if not self._ensureTableExists(model_class): if not self._ensureTableExists(model_class):
return False return False
with self.connection.cursor() as cursor: # `getInitialId` opens its own borrow; do it BEFORE we acquire a
# Check if record exists # connection ourselves so we don't pin two slots concurrently.
initialId = self.getInitialId(model_class)
with self.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute( cursor.execute(
f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,) f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,)
) )
if not cursor.fetchone(): if not cursor.fetchone():
return False return False
# Check if it's an initial record if initialId is not None and initialId == recordId:
initialId = self.getInitialId(model_class) # `_removeInitialId` borrows its own conn — done outside
# this block on purpose to avoid nested borrows.
pass
cursor.execute(f'DELETE FROM "{table}" WHERE "id" = %s', (recordId,))
if initialId is not None and initialId == recordId: if initialId is not None and initialId == recordId:
self._removeInitialId(table) self._removeInitialId(table)
logger.info( logger.info(
f"Initial ID {recordId} for table {table} has been removed from the system table" f"Initial ID {recordId} for table {table} has been removed from the system table"
) )
# Delete the record
cursor.execute(f'DELETE FROM "{table}" WHERE "id" = %s', (recordId,))
# No cache to update - database handles consistency
self.connection.commit()
return True return True
except Exception as e: except Exception as e:
logger.error(f"Error deleting record {recordId} from table {table}: {e}") logger.error(f"Error deleting record {recordId} from table {table}: {e}")
self.connection.rollback()
return False return False
def recordCreateBulk( def recordCreateBulk(
@ -1559,16 +1760,11 @@ class DatabaseConnector:
) )
try: try:
self._ensure_connection() with self.borrowConn() as conn:
with self.connection.cursor() as cursor: with conn.cursor() as cursor:
psycopg2.extras.execute_values(cursor, sql, rows, page_size=500) psycopg2.extras.execute_values(cursor, sql, rows, page_size=500)
self.connection.commit()
except Exception as e: except Exception as e:
logger.error(f"Bulk insert into {table} failed (n={len(rows)}): {e}") logger.error(f"Bulk insert into {table} failed (n={len(rows)}): {e}")
try:
self.connection.rollback()
except Exception:
pass
raise raise
if self.getInitialId(model_class) is None and normalised: if self.getInitialId(model_class) is None and normalised:
@ -1649,8 +1845,8 @@ class DatabaseConnector:
initialId = self.getInitialId(model_class) initialId = self.getInitialId(model_class)
try: try:
self._ensure_connection() with self.borrowConn() as conn:
with self.connection.cursor() as cursor: with conn.cursor() as cursor:
if initialId is not None: if initialId is not None:
cursor.execute( cursor.execute(
f'SELECT 1 FROM "{table}" WHERE "id" = %s AND ' + whereSql, f'SELECT 1 FROM "{table}" WHERE "id" = %s AND ' + whereSql,
@ -1662,13 +1858,8 @@ class DatabaseConnector:
cursor.execute(f'DELETE FROM "{table}" WHERE ' + whereSql, params) cursor.execute(f'DELETE FROM "{table}" WHERE ' + whereSql, params)
deleted = cursor.rowcount or 0 deleted = cursor.rowcount or 0
self.connection.commit()
except Exception as e: except Exception as e:
logger.error(f"Bulk delete from {table} failed (filter={recordFilter}): {e}") logger.error(f"Bulk delete from {table} failed (filter={recordFilter}): {e}")
try:
self.connection.rollback()
except Exception:
pass
raise raise
if deleted and initialIsAffected: if deleted and initialIsAffected:
@ -1751,39 +1942,30 @@ class DatabaseConnector:
) )
params = [vectorStr] + whereValues + [vectorStr, limit] params = [vectorStr] + whereValues + [vectorStr, limit]
with self.connection.cursor() as cursor: with self.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute(query, params) cursor.execute(query, params)
records = [dict(row) for row in cursor.fetchall()] records = [dict(row) for row in cursor.fetchall()]
fields = getModelFields(modelClass) fields = getModelFields(modelClass)
for record in records: for record in records:
parseRecordFields(record, fields, f"semanticSearch {table}") parseRecordFields(record, fields, f"semanticSearch {table}")
return records return records
except Exception as e: except Exception as e:
logger.error(f"Error in semantic search on {table}: {e}") logger.error(f"Error in semantic search on {table}: {e}")
_rollbackQuietly(getattr(self, "connection", None))
raise DatabaseQueryError(table, str(e), original=e) from e raise DatabaseQueryError(table, str(e), original=e) from e
def close(self, forceClose: bool = False): def close(self, forceClose: bool = False):
"""Close the database connection. """No-op for backward compatibility.
Shared cached connectors are intentionally kept open unless forceClose=True. Connections are now owned by the `_PoolRegistry` pool and live for the
This prevents accidental shutdown from interface __del__ methods while process lifetime. Pool shutdown happens centrally via `closeAllPools()`
other requests are still using the same cached connector instance. from the FastAPI lifespan hook – never from a connector instance.
Interface `__del__` paths used to call `close()` to release a per-
connector socket; with pooling there is nothing to close here.
""" """
if self._isCachedShared and not forceClose:
return return
if (
hasattr(self, "connection")
and self.connection
and not self.connection.closed
):
self.connection.close()
def __del__(self): def __del__(self):
"""Cleanup method to close connection.""" """Cleanup hook (intentionally no-op — see `close`)."""
try: return
self.close()
except Exception:
pass

View file

@ -342,7 +342,7 @@ class RealEstateObjects:
# If no exact match, try case-insensitive search via SQL query # If no exact match, try case-insensitive search via SQL query
# This handles cases where the name might have different casing # This handles cases where the name might have different casing
self.db._ensure_connection() self.db._ensure_connection()
with self.db.connection.cursor() as cursor: with self.db.borrowCursor() as cursor:
cursor.execute( cursor.execute(
'SELECT "id" FROM "Gemeinde" WHERE LOWER("label") = LOWER(%s) LIMIT 1', 'SELECT "id" FROM "Gemeinde" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
(name,) (name,)
@ -375,7 +375,7 @@ class RealEstateObjects:
# Try case-insensitive search # Try case-insensitive search
self.db._ensure_connection() self.db._ensure_connection()
with self.db.connection.cursor() as cursor: with self.db.borrowCursor() as cursor:
cursor.execute( cursor.execute(
'SELECT "id" FROM "Kanton" WHERE LOWER("label") = LOWER(%s) LIMIT 1', 'SELECT "id" FROM "Kanton" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
(name,) (name,)
@ -408,7 +408,7 @@ class RealEstateObjects:
# Try case-insensitive search # Try case-insensitive search
self.db._ensure_connection() self.db._ensure_connection()
with self.db.connection.cursor() as cursor: with self.db.borrowCursor() as cursor:
cursor.execute( cursor.execute(
'SELECT "id" FROM "Land" WHERE LOWER("label") = LOWER(%s) LIMIT 1', 'SELECT "id" FROM "Land" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
(name,) (name,)
@ -840,7 +840,7 @@ class RealEstateObjects:
# Ensure connection is alive # Ensure connection is alive
self.db._ensure_connection() self.db._ensure_connection()
with self.db.connection.cursor() as cursor: with self.db.borrowCursor() as cursor:
# Execute query # Execute query
if parameters: if parameters:
# Use parameterized query for safety # Use parameterized query for safety

View file

@ -1659,7 +1659,7 @@ class BillingObjects:
try: try:
appInterface = getAppInterface(self.currentUser) appInterface = getAppInterface(self.currentUser)
appInterface.db._ensure_connection() appInterface.db._ensure_connection()
with appInterface.db.connection.cursor() as cur: with appInterface.db.borrowCursor() as cur:
if appInterface.db._ensureTableExists(UserInDB): if appInterface.db._ensureTableExists(UserInDB):
cur.execute( cur.execute(
'SELECT "id" FROM "UserInDB" WHERE ' 'SELECT "id" FROM "UserInDB" WHERE '
@ -1780,7 +1780,7 @@ class BillingObjects:
try: try:
self.db._ensure_connection() self.db._ensure_connection()
with self.db.connection.cursor() as cur: with self.db.borrowCursor() as cur:
countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}' countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
cur.execute(countSql, whereValues) cur.execute(countSql, whereValues)
totalItems = cur.fetchone()["count"] totalItems = cur.fetchone()["count"]
@ -1797,10 +1797,7 @@ class BillingObjects:
except Exception as e: except Exception as e:
logger.error(f"_searchTransactionsPaginated SQL error: {e}", exc_info=True) logger.error(f"_searchTransactionsPaginated SQL error: {e}", exc_info=True)
try: # Rollback is handled by `borrowCursor()` context manager on exit.
self.db.connection.rollback()
except Exception:
pass
return {"items": [], "totalItems": 0, "totalPages": 0} return {"items": [], "totalItems": 0, "totalPages": 0}
def _buildScopeFilter( def _buildScopeFilter(
@ -1872,7 +1869,7 @@ class BillingObjects:
result: Dict[str, Any] = {} result: Dict[str, Any] = {}
with self.db.connection.cursor() as cur: with self.db.borrowCursor() as cur:
# 1) Totals # 1) Totals
cur.execute( cur.execute(
f'SELECT COALESCE(SUM("amount"), 0) AS total, COUNT(*) AS cnt FROM "{table}"{whereClause}', f'SELECT COALESCE(SUM("amount"), 0) AS total, COUNT(*) AS cnt FROM "{table}"{whereClause}',
@ -1947,17 +1944,12 @@ class BillingObjects:
}) })
result["timeSeries"] = timeSeries result["timeSeries"] = timeSeries
self.db.connection.commit() # Commit/rollback are handled by `borrowCursor()` context manager.
result["_allAccounts"] = allAccounts result["_allAccounts"] = allAccounts
return result return result
except Exception as e: except Exception as e:
logger.error(f"Error in getTransactionStatisticsAggregated: {e}", exc_info=True) logger.error(f"Error in getTransactionStatisticsAggregated: {e}", exc_info=True)
try:
self.db.connection.rollback()
except Exception:
pass
return self._emptyStats() return self._emptyStats()
@staticmethod @staticmethod

View file

@ -228,6 +228,22 @@ class KnowledgeObjects:
"""Get all ContentChunks for a file.""" """Get all ContentChunks for a file."""
return self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId}) return self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})
def countChunksByFileIds(self, fileIds: List[str]) -> Dict[str, int]:
    """Count ContentChunk rows per file with a single grouped query.

    Issues one ``GROUP BY "fileId"`` aggregate instead of one query per
    file, so chunk counts can be shown (e.g. in the RAG inventory) without
    loading any chunk rows or their embedding vectors.

    Args:
        fileIds: File IDs to count chunks for.

    Returns:
        A ``{fileId: chunkCount}`` mapping. File IDs that have no chunks do
        not appear in the mapping; callers treat a missing key as 0. An
        empty mapping is returned when ``fileIds`` is empty or the
        ContentChunk table is not available.
    """
    counts: Dict[str, int] = {}

    # Nothing to look up, or no table to look in: skip the query entirely.
    if not fileIds or not self.db._ensureTableExists(ContentChunk):
        return counts

    query = 'SELECT "fileId", COUNT(*) AS cnt FROM "ContentChunk" WHERE "fileId" = ANY(%s) GROUP BY "fileId"'
    with self.db.borrowCursor() as cur:
        # The ID list is passed as ONE parameter and matched via ANY(...).
        cur.execute(query, (list(fileIds),))
        for row in cur.fetchall():
            counts[row["fileId"]] = int(row["cnt"])
    return counts
def deleteContentChunks(self, fileId: str) -> int: def deleteContentChunks(self, fileId: str) -> int:
"""Delete all ContentChunks for a file. Returns count of deleted chunks.""" """Delete all ContentChunks for a file. Returns count of deleted chunks."""
chunks = self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId}) chunks = self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})

View file

@ -1221,10 +1221,9 @@ class ComponentObjects:
for item in fileRows for item in fileRows
] ]
# Single transaction: delete FileData, FileItem, then FileFolder (children first) # Single transaction: delete FileData, FileItem, then FileFolder (children first).
self.db._ensure_connection() # Commit/rollback are handled by `borrowCursor()` on exit.
try: with self.db.borrowCursor() as cursor:
with self.db.connection.cursor() as cursor:
if fileIds: if fileIds:
cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (fileIds,)) cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (fileIds,))
cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (fileIds,)) cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (fileIds,))
@ -1233,10 +1232,6 @@ class ComponentObjects:
orderedIds.append(folderId) orderedIds.append(folderId)
if orderedIds: if orderedIds:
cursor.execute('DELETE FROM "FileFolder" WHERE "id" = ANY(%s)', (orderedIds,)) cursor.execute('DELETE FROM "FileFolder" WHERE "id" = ANY(%s)', (orderedIds,))
self.db.connection.commit()
except Exception:
self.db.connection.rollback()
raise
return {"deletedFolders": len(folderIds), "deletedFiles": len(fileIds)} return {"deletedFolders": len(folderIds), "deletedFiles": len(fileIds)}
@ -1507,7 +1502,7 @@ class ComponentObjects:
try: try:
self.db._ensure_connection() self.db._ensure_connection()
with self.db.connection.cursor() as cursor: with self.db.borrowCursor() as cursor:
cursor.execute( cursor.execute(
'SELECT "id", "sysCreatedBy" FROM "FileItem" WHERE "id" = ANY(%s)', 'SELECT "id", "sysCreatedBy" FROM "FileItem" WHERE "id" = ANY(%s)',
(uniqueIds,), (uniqueIds,),
@ -1526,11 +1521,10 @@ class ComponentObjects:
cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (accessibleIds,)) cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (accessibleIds,))
deletedFiles = cursor.rowcount deletedFiles = cursor.rowcount
self.db.connection.commit() # Commit/rollback are handled by `borrowCursor()` context manager.
return {"deletedFiles": deletedFiles} return {"deletedFiles": deletedFiles}
except Exception as e: except Exception as e:
logger.error(f"Error deleting files in batch: {e}") logger.error(f"Error deleting files in batch: {e}")
self.db.connection.rollback()
raise FileDeletionError(f"Error deleting files in batch: {str(e)}") raise FileDeletionError(f"Error deleting files in batch: {str(e)}")
def _ensureFeatureInstanceGroup(self, featureInstanceId: str, contextKey: str = "files/list") -> Optional[str]: def _ensureFeatureInstanceGroup(self, featureInstanceId: str, contextKey: str = "files/list") -> Optional[str]:

View file

@ -374,7 +374,7 @@ def getRecordsetWithRBAC(
query = f'SELECT * FROM "{table}"{whereClause}{orderByClause}{limitClause}' query = f'SELECT * FROM "{table}"{whereClause}{orderByClause}{limitClause}'
with connector.connection.cursor() as cursor: with connector.borrowCursor() as cursor:
cursor.execute(query, whereValues) cursor.execute(query, whereValues)
records = [dict(row) for row in cursor.fetchall()] records = [dict(row) for row in cursor.fetchall()]
@ -561,7 +561,7 @@ def getRecordsetPaginatedWithRBAC(
offset = (pagination.page - 1) * pagination.pageSize offset = (pagination.page - 1) * pagination.pageSize
limitClause = f" LIMIT {pagination.pageSize} OFFSET {offset}" limitClause = f" LIMIT {pagination.pageSize} OFFSET {offset}"
with connector.connection.cursor() as cursor: with connector.borrowCursor() as cursor:
countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}' countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
cursor.execute(countSql, countValues) cursor.execute(countSql, countValues)
totalItems = cursor.fetchone()["count"] totalItems = cursor.fetchone()["count"]
@ -709,7 +709,7 @@ def getDistinctColumnValuesWithRBAC(
sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{nonNullWhere} ORDER BY val' sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{nonNullWhere} ORDER BY val'
with connector.connection.cursor() as cursor: with connector.borrowCursor() as cursor:
cursor.execute(sql, whereValues) cursor.execute(sql, whereValues)
result = [row["val"] for row in cursor.fetchall()] result = [row["val"] for row in cursor.fetchall()]
@ -719,7 +719,7 @@ def getDistinctColumnValuesWithRBAC(
emptySql = f'SELECT 1 FROM "{table}"{whereClause} AND {emptyCond} LIMIT 1' emptySql = f'SELECT 1 FROM "{table}"{whereClause} AND {emptyCond} LIMIT 1'
else: else:
emptySql = f'SELECT 1 FROM "{table}" WHERE {emptyCond} LIMIT 1' emptySql = f'SELECT 1 FROM "{table}" WHERE {emptyCond} LIMIT 1'
with connector.connection.cursor() as cursor: with connector.borrowCursor() as cursor:
cursor.execute(emptySql, whereValues) cursor.execute(emptySql, whereValues)
if cursor.fetchone(): if cursor.fetchone():
result.append(None) result.append(None)
@ -967,7 +967,7 @@ def buildRbacWhereClause(
# Multi-Tenant Design: Users do NOT have mandateId - they are linked via UserMandate # Multi-Tenant Design: Users do NOT have mandateId - they are linked via UserMandate
if table == "UserInDB": if table == "UserInDB":
try: try:
with connector.connection.cursor() as cursor: with connector.borrowCursor() as cursor:
# Get all user IDs that are members of the current mandate # Get all user IDs that are members of the current mandate
cursor.execute( cursor.execute(
'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true', 'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true',
@ -994,7 +994,7 @@ def buildRbacWhereClause(
# For UserConnection: Filter via UserMandate junction table # For UserConnection: Filter via UserMandate junction table
elif table == "UserConnection": elif table == "UserConnection":
try: try:
with connector.connection.cursor() as cursor: with connector.borrowCursor() as cursor:
# Get all user IDs that are members of the current mandate # Get all user IDs that are members of the current mandate
cursor.execute( cursor.execute(
'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true', 'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true',

View file

@ -305,7 +305,7 @@ def handleIdsMode(
sql = f'SELECT "{idField}"::TEXT AS val FROM "{table}"{where_clause} ORDER BY "{idField}"' sql = f'SELECT "{idField}"::TEXT AS val FROM "{table}"{where_clause} ORDER BY "{idField}"'
with db.connection.cursor() as cursor: with db.borrowCursor() as cursor:
cursor.execute(sql, values) cursor.execute(sql, values)
return JSONResponse(content=[row["val"] for row in cursor.fetchall()]) return JSONResponse(content=[row["val"] for row in cursor.fetchall()])
except Exception as e: except Exception as e:

View file

@ -25,6 +25,18 @@ router = APIRouter(
def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> List[Dict[str, Any]]: def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> List[Dict[str, Any]]:
"""Build per-connection RAG inventory rows.
Each DataSource row exposes BOTH numbers because they mean different things:
* `fileCount` – distinct files indexed (== `FileContentIndex` rows)
* `chunkCount` – embedding-sized text fragments (== `ContentChunk` rows,
max `DEFAULT_CHUNK_TOKENS` tokens each, what the vector retrieval
actually hits)
A single PDF typically yields 1 file × 5–100 chunks; legacy UI labelled
`len(FileContentIndex)` as "chunks" which was off by 1–2 orders of
magnitude and misleading.
"""
from modules.datamodels.datamodelDataSource import DataSource from modules.datamodels.datamodelDataSource import DataSource
from modules.datamodels.datamodelKnowledge import FileContentIndex from modules.datamodels.datamodelKnowledge import FileContentIndex
@ -34,19 +46,35 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
connIndexRows = knowledgeIf.db.getRecordset(FileContentIndex, recordFilter={"connectionId": connectionId}) connIndexRows = knowledgeIf.db.getRecordset(FileContentIndex, recordFilter={"connectionId": connectionId})
connChunkTotal = len(connIndexRows) connFileTotal = len(connIndexRows)
# Map fileId → real chunk count via 1 aggregate query (cheap even for
# connections with thousands of files; we never load the vector body).
fileIds = [
(idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", ""))
for idx in connIndexRows
]
fileIds = [fid for fid in fileIds if fid]
chunkCountByFile = knowledgeIf.countChunksByFileIds(fileIds) if fileIds else {}
connChunkTotal = sum(chunkCountByFile.values())
filesByDs: Dict[str, int] = {}
chunksByDs: Dict[str, int] = {} chunksByDs: Dict[str, int] = {}
unassigned = 0 unassignedFiles = 0
unassignedChunks = 0
for idx in connIndexRows: for idx in connIndexRows:
fileId = idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", "")
chunkCnt = chunkCountByFile.get(fileId, 0)
struct = (idx.get("structure") if isinstance(idx, dict) else getattr(idx, "structure", None)) or {} struct = (idx.get("structure") if isinstance(idx, dict) else getattr(idx, "structure", None)) or {}
ingestion = struct.get("_ingestion") or {} if isinstance(struct, dict) else {} ingestion = struct.get("_ingestion") or {} if isinstance(struct, dict) else {}
prov = ingestion.get("provenance") or {} if isinstance(ingestion, dict) else {} prov = ingestion.get("provenance") or {} if isinstance(ingestion, dict) else {}
dsIdRef = prov.get("dataSourceId", "") if isinstance(prov, dict) else "" dsIdRef = prov.get("dataSourceId", "") if isinstance(prov, dict) else ""
if dsIdRef: if dsIdRef:
chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + 1 filesByDs[dsIdRef] = filesByDs.get(dsIdRef, 0) + 1
chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + chunkCnt
else: else:
unassigned += 1 unassignedFiles += 1
unassignedChunks += chunkCnt
seen: Dict[str, bool] = {} seen: Dict[str, bool] = {}
dsItems = [] dsItems = []
@ -64,14 +92,19 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"ragIndexEnabled": ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False), "ragIndexEnabled": ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False),
"neutralize": ds.get("neutralize") if isinstance(ds, dict) else getattr(ds, "neutralize", False), "neutralize": ds.get("neutralize") if isinstance(ds, dict) else getattr(ds, "neutralize", False),
"lastIndexed": ds.get("lastIndexed") if isinstance(ds, dict) else getattr(ds, "lastIndexed", None), "lastIndexed": ds.get("lastIndexed") if isinstance(ds, dict) else getattr(ds, "lastIndexed", None),
"fileCount": filesByDs.get(dsId, 0),
"chunkCount": chunksByDs.get(dsId, 0), "chunkCount": chunksByDs.get(dsId, 0),
}) })
if unassigned > 0 and len(dsItems) > 0: # Spread orphan files (provenance lost) evenly so totals match.
perDs = unassigned // len(dsItems) if unassignedFiles > 0 and len(dsItems) > 0:
remainder = unassigned % len(dsItems) perFile = unassignedFiles // len(dsItems)
remFile = unassignedFiles % len(dsItems)
perChunk = unassignedChunks // len(dsItems)
remChunk = unassignedChunks % len(dsItems)
for i, item in enumerate(dsItems): for i, item in enumerate(dsItems):
item["chunkCount"] += perDs + (1 if i < remainder else 0) item["fileCount"] += perFile + (1 if i < remFile else 0)
item["chunkCount"] += perChunk + (1 if i < remChunk else 0)
# Pull a wider window than the previous 5 so the "last successful # Pull a wider window than the previous 5 so the "last successful
# sync" is found even if a connection has many recent jobs queued. # sync" is found even if a connection has many recent jobs queued.
@ -102,6 +135,12 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"skippedPolicy": result.get("skippedPolicy", 0), "skippedPolicy": result.get("skippedPolicy", 0),
"failed": result.get("failed", 0), "failed": result.get("failed", 0),
"durationMs": result.get("durationMs", 0), "durationMs": result.get("durationMs", 0),
# Surface limit-stop reason so the UI can warn the user
# that the index is provably incomplete (and which budget
# to raise). None means the walker finished naturally.
"stoppedAtLimit": result.get("stoppedAtLimit"),
"limits": result.get("limits") or {},
"bytesProcessed": result.get("bytesProcessed", 0),
} }
if lastError and lastSuccess: if lastError and lastSuccess:
break break
@ -113,6 +152,7 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"knowledgeIngestionEnabled": getattr(conn, "knowledgeIngestionEnabled", False), "knowledgeIngestionEnabled": getattr(conn, "knowledgeIngestionEnabled", False),
"preferences": getattr(conn, "knowledgePreferences", None) or {}, "preferences": getattr(conn, "knowledgePreferences", None) or {},
"dataSources": dsItems, "dataSources": dsItems,
"totalFiles": connFileTotal,
"totalChunks": connChunkTotal, "totalChunks": connChunkTotal,
"runningJobs": runningJobs, "runningJobs": runningJobs,
"lastError": lastError, "lastError": lastError,
@ -139,8 +179,9 @@ def _getInventoryMe(
items = _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) items = _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService)
totalChunks = sum(c.get("totalChunks", 0) for c in items) totalChunks = sum(c.get("totalChunks", 0) for c in items)
totalFiles = sum(c.get("totalFiles", 0) for c in items)
return {"connections": items, "totals": {"chunks": totalChunks}} return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}}
except Exception as e: except Exception as e:
logger.error("Error in RAG inventory /me: %s", e, exc_info=True) logger.error("Error in RAG inventory /me: %s", e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@ -170,9 +211,10 @@ def _getInventoryMandate(
items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService) items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
totalChunks = sum(c.get("totalChunks", 0) for c in items) totalChunks = sum(c.get("totalChunks", 0) for c in items)
totalFiles = sum(c.get("totalFiles", 0) for c in items)
totalBytes = aggregateMandateRagTotalBytes(mandateId) totalBytes = aggregateMandateRagTotalBytes(mandateId)
return {"connections": items, "totals": {"chunks": totalChunks, "bytes": totalBytes}} return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks, "bytes": totalBytes}}
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@ -202,8 +244,9 @@ def _getInventoryPlatform(
items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService) items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
totalChunks = sum(c.get("totalChunks", 0) for c in items) totalChunks = sum(c.get("totalChunks", 0) for c in items)
totalFiles = sum(c.get("totalFiles", 0) for c in items)
return {"connections": items, "totals": {"chunks": totalChunks}} return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}}
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:

View file

@ -227,7 +227,7 @@ WHERE "workflowId" = ANY(%s)
GROUP BY "workflowId" GROUP BY "workflowId"
""" """
out: dict = {} out: dict = {}
with db.connection.cursor() as cursor: with db.borrowCursor() as cursor:
cursor.execute(sql, (workflowIds,)) cursor.execute(sql, (workflowIds,))
for row in cursor.fetchall(): for row in cursor.fetchall():
r = dict(row) r = dict(row)
@ -480,7 +480,7 @@ def _getWorkflowsJoinedPaginated(
dataSql = f"SELECT w.*, rs.\"lastStartedAt\", rs.\"runCount\", rs.\"activeRunId\" FROM {fromSql}{whereClause}{orderClause}{limitClause}" dataSql = f"SELECT w.*, rs.\"lastStartedAt\", rs.\"runCount\", rs.\"activeRunId\" FROM {fromSql}{whereClause}{orderClause}{limitClause}"
db._ensure_connection() db._ensure_connection()
with db.connection.cursor() as cursor: with db.borrowCursor() as cursor:
cursor.execute(countSql, countValues) cursor.execute(countSql, countValues)
totalItems = int(cursor.fetchone()["cnt"]) totalItems = int(cursor.fetchone()["cnt"])

View file

@ -25,15 +25,14 @@ _CACHE_TTL_SECONDS = 300
def _getOrCreateFeatureDbConnector(featureDbName: str, userId: str): def _getOrCreateFeatureDbConnector(featureDbName: str, userId: str):
"""Reuse a pooled DB connector for the given feature database.""" """Reuse a pooled DB connector for the given feature database.
The underlying psycopg2 connections live in the central pool
(`_PoolRegistry`) and are recreated on demand if they go stale; we just
need to keep the lightweight connector wrapper around.
"""
if featureDbName in _featureDbConnPool: if featureDbName in _featureDbConnPool:
conn = _featureDbConnPool[featureDbName] return _featureDbConnPool[featureDbName]
try:
if conn.connection and not conn.connection.closed:
return conn
except Exception as e:
logger.warning(f"Feature DB connection check failed for {featureDbName}: {e}")
_featureDbConnPool.pop(featureDbName, None)
from modules.connectors.connectorDbPostgre import DatabaseConnector from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG from modules.shared.configuration import APP_CONFIG

View file

@ -68,6 +68,9 @@ class ClickupBootstrapResult:
workspaces: int = 0 workspaces: int = 0
lists: int = 0 lists: int = 0
errors: List[str] = field(default_factory=list) errors: List[str] = field(default_factory=list)
# First budget exhausted: "maxTasks" | "maxWorkspaces" | "maxListsPerWorkspace" | None.
# Drives the same UI banner as the file-walker bootstraps.
stoppedAtLimit: Optional[str] = None
def _syntheticTaskId(connectionId: str, taskId: str) -> str: def _syntheticTaskId(connectionId: str, taskId: str) -> str:
@ -225,6 +228,7 @@ async def bootstrapClickup(
cancelled = False cancelled = False
for ds in dataSources: for ds in dataSources:
if result.indexed + result.skippedDuplicate >= limits.maxTasks: if result.indexed + result.skippedDuplicate >= limits.maxTasks:
_recordLimitStop(result, "maxTasks", "dataSource", limits)
break break
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled(): if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
cancelled = True cancelled = True
@ -243,8 +247,11 @@ async def bootstrapClickup(
clickupScope=limits.clickupScope, clickupScope=limits.clickupScope,
) )
if len(teams) > dsLimits.maxWorkspaces:
_recordLimitStop(result, "maxWorkspaces", "teams", dsLimits, hard=False)
for team in teams[:dsLimits.maxWorkspaces]: for team in teams[:dsLimits.maxWorkspaces]:
if result.indexed + result.skippedDuplicate >= dsLimits.maxTasks: if result.indexed + result.skippedDuplicate >= dsLimits.maxTasks:
_recordLimitStop(result, "maxTasks", f"team={team.get('id','')}", dsLimits)
break break
teamId = str(team.get("id", "") or "") teamId = str(team.get("id", "") or "")
if not teamId: if not teamId:
@ -351,6 +358,7 @@ async def _walkTeam(
for lst in listsCollected: for lst in listsCollected:
if result.indexed + result.skippedDuplicate >= limits.maxTasks: if result.indexed + result.skippedDuplicate >= limits.maxTasks:
_recordLimitStop(result, "maxTasks", f"team={teamId}", limits)
return return
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled(): if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
return return
@ -407,6 +415,7 @@ async def _walkList(
for task in tasks: for task in tasks:
if result.indexed + result.skippedDuplicate >= limits.maxTasks: if result.indexed + result.skippedDuplicate >= limits.maxTasks:
_recordLimitStop(result, "maxTasks", f"list={listId}", limits)
return return
if not _isRecent(task.get("date_updated"), limits.maxAgeDays): if not _isRecent(task.get("date_updated"), limits.maxAgeDays):
result.skippedPolicy += 1 result.skippedPolicy += 1
@ -529,13 +538,37 @@ async def _ingestTask(
) )
def _recordLimitStop(
    result: ClickupBootstrapResult,
    limitName: str,
    where: str,
    limits: ClickupBootstrapLimits,
    *,
    hard: bool = True,
) -> None:
    """Record on ``result`` which budget cut the ClickUp walk short and log it.

    Args:
        result: Bootstrap result whose ``stoppedAtLimit`` field is updated.
        limitName: Name of the exhausted budget ("maxTasks", "maxWorkspaces"
            or "maxListsPerWorkspace").
        where: Free-text location of the hit; only used in the log message.
        limits: Limits object the budget value is read from for the log.
        hard: ``True`` (default) overwrites any previously recorded limit.
            ``False`` records the limit only if nothing was recorded yet.
    """
    if hard and result.stoppedAtLimit == limitName:
        # The walker loops re-check the same budget at every nesting level
        # while unwinding (list -> team -> data source). The stop is already
        # recorded and logged, so do not emit the same warning again.
        return
    if hard or result.stoppedAtLimit is None:
        result.stoppedAtLimit = limitName
    budgetMap = {
        "maxTasks": limits.maxTasks,
        "maxWorkspaces": limits.maxWorkspaces,
        "maxListsPerWorkspace": limits.maxListsPerWorkspace,
    }
    logger.warning(
        "clickup walker hit %s=%s at %s — partial index (indexed=%d, skippedDup=%d).",
        limitName, budgetMap.get(limitName), where,
        result.indexed, result.skippedDuplicate,
    )
def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs: float) -> Dict[str, Any]: def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000) durationMs = int((time.time() - startMs) * 1000)
logger.info( logger.info(
"ingestion.connection.bootstrap.done part=clickup connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d workspaces=%d lists=%d durationMs=%d", "ingestion.connection.bootstrap.done part=clickup connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d workspaces=%d lists=%d durationMs=%d stoppedAtLimit=%s",
connectionId, connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.indexed, result.skippedDuplicate, result.skippedPolicy,
result.failed, result.workspaces, result.lists, durationMs, result.failed, result.workspaces, result.lists, durationMs,
result.stoppedAtLimit or "none",
extra={ extra={
"event": "ingestion.connection.bootstrap.done", "event": "ingestion.connection.bootstrap.done",
"part": "clickup", "part": "clickup",
@ -547,6 +580,7 @@ def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs:
"workspaces": result.workspaces, "workspaces": result.workspaces,
"lists": result.lists, "lists": result.lists,
"durationMs": durationMs, "durationMs": durationMs,
"stoppedAtLimit": result.stoppedAtLimit,
}, },
) )
return { return {
@ -559,4 +593,11 @@ def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs:
"lists": result.lists, "lists": result.lists,
"durationMs": durationMs, "durationMs": durationMs,
"errors": result.errors[:20], "errors": result.errors[:20],
"stoppedAtLimit": result.stoppedAtLimit,
"limits": {
"maxTasks": MAX_TASKS_DEFAULT,
"maxWorkspaces": MAX_WORKSPACES_DEFAULT,
"maxListsPerWorkspace": MAX_LISTS_PER_WORKSPACE_DEFAULT,
"maxAgeDays": MAX_AGE_DAYS_DEFAULT,
},
} }

View file

@ -61,6 +61,8 @@ class GdriveBootstrapResult:
failed: int = 0 failed: int = 0
bytesProcessed: int = 0 bytesProcessed: int = 0
errors: List[str] = field(default_factory=list) errors: List[str] = field(default_factory=list)
# See SharepointBootstrapResult.stoppedAtLimit — same semantics.
stoppedAtLimit: Optional[str] = None
def _syntheticFileId(connectionId: str, externalItemId: str) -> str: def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
@ -265,8 +267,10 @@ async def _walkFolder(
for entry in entries: for entry in entries:
if result.indexed + result.skippedDuplicate >= limits.maxItems: if result.indexed + result.skippedDuplicate >= limits.maxItems:
_recordLimitStop(result, "maxItems", folderPath, limits)
return return
if result.bytesProcessed >= limits.maxBytes: if result.bytesProcessed >= limits.maxBytes:
_recordLimitStop(result, "maxBytes", folderPath, limits)
return return
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled(): if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
return return
@ -276,6 +280,9 @@ async def _walkFolder(
mimeType = getattr(entry, "mimeType", None) or metadata.get("mimeType") mimeType = getattr(entry, "mimeType", None) or metadata.get("mimeType")
if getattr(entry, "isFolder", False) or mimeType == FOLDER_MIME: if getattr(entry, "isFolder", False) or mimeType == FOLDER_MIME:
if depth + 1 > limits.maxDepth:
_recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
continue
await _walkFolder( await _walkFolder(
adapter=adapter, adapter=adapter,
knowledgeService=knowledgeService, knowledgeService=knowledgeService,
@ -298,6 +305,7 @@ async def _walkFolder(
continue continue
size = int(getattr(entry, "size", 0) or 0) size = int(getattr(entry, "size", 0) or 0)
if size and size > limits.maxFileSize: if size and size > limits.maxFileSize:
_recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
result.skippedPolicy += 1 result.skippedPolicy += 1
continue continue
modifiedTime = metadata.get("modifiedTime") modifiedTime = metadata.get("modifiedTime")
@ -470,13 +478,38 @@ async def _ingestOne(
await asyncio.sleep(0) await asyncio.sleep(0)
def _recordLimitStop(
    result: GdriveBootstrapResult,
    limitName: str,
    where: str,
    limits: GdriveBootstrapLimits,
    *,
    hard: bool = True,
) -> None:
    """Record on ``result`` which budget cut the Google Drive walk short and log it.

    Args:
        result: Bootstrap result whose ``stoppedAtLimit`` field is updated.
        limitName: Name of the budget that was hit ("maxItems", "maxBytes",
            "maxDepth" or "maxFileSize").
        where: Path of the folder or entry where the limit was hit; only
            used in the log message.
        limits: Limits object the budget value is read from for the log.
        hard: ``True`` (default) overwrites any previously recorded limit.
            ``False`` records the limit only if nothing was recorded yet.
    """
    if hard and result.stoppedAtLimit == limitName:
        # _walkFolder recurses, and every enclosing folder loop re-checks the
        # same budget while unwinding. The stop is already recorded and
        # logged, so do not emit the same warning once per recursion level.
        return
    if hard or result.stoppedAtLimit is None:
        result.stoppedAtLimit = limitName
    budgetMap = {
        "maxItems": limits.maxItems,
        "maxBytes": limits.maxBytes,
        "maxDepth": limits.maxDepth,
        "maxFileSize": limits.maxFileSize,
    }
    logger.warning(
        "gdrive walker hit %s=%s at %s — partial index (indexed=%d, bytesProcessed=%d).",
        limitName, budgetMap.get(limitName), where,
        result.indexed, result.bytesProcessed,
    )
def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: float) -> Dict[str, Any]: def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000) durationMs = int((time.time() - startMs) * 1000)
logger.info( logger.info(
"ingestion.connection.bootstrap.done part=gdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d bytes=%d durationMs=%d", "ingestion.connection.bootstrap.done part=gdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d bytes=%d durationMs=%d stoppedAtLimit=%s",
connectionId, connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.indexed, result.skippedDuplicate, result.skippedPolicy,
result.failed, result.bytesProcessed, durationMs, result.failed, result.bytesProcessed, durationMs,
result.stoppedAtLimit or "none",
extra={ extra={
"event": "ingestion.connection.bootstrap.done", "event": "ingestion.connection.bootstrap.done",
"part": "gdrive", "part": "gdrive",
@ -487,6 +520,7 @@ def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: f
"failed": result.failed, "failed": result.failed,
"bytes": result.bytesProcessed, "bytes": result.bytesProcessed,
"durationMs": durationMs, "durationMs": durationMs,
"stoppedAtLimit": result.stoppedAtLimit,
}, },
) )
return { return {
@ -498,4 +532,11 @@ def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: f
"bytesProcessed": result.bytesProcessed, "bytesProcessed": result.bytesProcessed,
"durationMs": durationMs, "durationMs": durationMs,
"errors": result.errors[:20], "errors": result.errors[:20],
"stoppedAtLimit": result.stoppedAtLimit,
"limits": {
"maxItems": MAX_ITEMS_DEFAULT,
"maxBytes": MAX_BYTES_DEFAULT,
"maxFileSize": MAX_FILE_SIZE_DEFAULT,
"maxDepth": MAX_DEPTH_DEFAULT,
},
} }

View file

@ -53,6 +53,8 @@ class KdriveBootstrapResult:
failed: int = 0 failed: int = 0
bytesProcessed: int = 0 bytesProcessed: int = 0
errors: List[str] = field(default_factory=list) errors: List[str] = field(default_factory=list)
# See SharepointBootstrapResult.stoppedAtLimit — same semantics.
stoppedAtLimit: Optional[str] = None
def _syntheticFileId(connectionId: str, externalItemId: str) -> str: def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
@ -232,14 +234,19 @@ async def _walkFolder(
for entry in entries: for entry in entries:
if result.indexed + result.skippedDuplicate >= limits.maxItems: if result.indexed + result.skippedDuplicate >= limits.maxItems:
_recordLimitStop(result, "maxItems", folderPath, limits)
return return
if result.bytesProcessed >= limits.maxBytes: if result.bytesProcessed >= limits.maxBytes:
_recordLimitStop(result, "maxBytes", folderPath, limits)
return return
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled(): if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
return return
entryPath = getattr(entry, "path", "") or "" entryPath = getattr(entry, "path", "") or ""
if getattr(entry, "isFolder", False): if getattr(entry, "isFolder", False):
if depth + 1 > limits.maxDepth:
_recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
continue
await _walkFolder( await _walkFolder(
adapter=adapter, adapter=adapter,
knowledgeService=knowledgeService, knowledgeService=knowledgeService,
@ -262,6 +269,7 @@ async def _walkFolder(
continue continue
size = int(getattr(entry, "size", 0) or 0) size = int(getattr(entry, "size", 0) or 0)
if size and size > limits.maxFileSize: if size and size > limits.maxFileSize:
_recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
result.skippedPolicy += 1 result.skippedPolicy += 1
continue continue
@ -415,17 +423,42 @@ async def _ingestOne(
await asyncio.sleep(0) await asyncio.sleep(0)
def _recordLimitStop(
    result: KdriveBootstrapResult,
    limitName: str,
    where: str,
    limits: KdriveBootstrapLimits,
    *,
    hard: bool = True,
) -> None:
    """Record on ``result`` which budget cut the kDrive walk short and log it.

    Args:
        result: Bootstrap result whose ``stoppedAtLimit`` field is updated.
        limitName: Name of the budget that was hit ("maxItems", "maxBytes",
            "maxDepth" or "maxFileSize").
        where: Path of the folder or entry where the limit was hit; only
            used in the log message.
        limits: Limits object the budget value is read from for the log.
        hard: ``True`` (default) overwrites any previously recorded limit.
            ``False`` records the limit only if nothing was recorded yet.
    """
    if hard and result.stoppedAtLimit == limitName:
        # _walkFolder recurses, and every enclosing folder loop re-checks the
        # same budget while unwinding. The stop is already recorded and
        # logged, so do not emit the same warning once per recursion level.
        return
    if hard or result.stoppedAtLimit is None:
        result.stoppedAtLimit = limitName
    budgetMap = {
        "maxItems": limits.maxItems,
        "maxBytes": limits.maxBytes,
        "maxDepth": limits.maxDepth,
        "maxFileSize": limits.maxFileSize,
    }
    logger.warning(
        "kdrive walker hit %s=%s at %s — partial index (indexed=%d, bytesProcessed=%d).",
        limitName, budgetMap.get(limitName), where,
        result.indexed, result.bytesProcessed,
    )
def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: float) -> Dict[str, Any]: def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000) durationMs = int((time.time() - startMs) * 1000)
logger.info( logger.info(
"ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d", "ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d stoppedAtLimit=%s",
connectionId, connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed, result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed,
durationMs, durationMs, result.stoppedAtLimit or "none",
extra={"event": "ingestion.connection.bootstrap.done", "part": "kdrive", extra={"event": "ingestion.connection.bootstrap.done", "part": "kdrive",
"connectionId": connectionId, "indexed": result.indexed, "connectionId": connectionId, "indexed": result.indexed,
"skippedDup": result.skippedDuplicate, "skippedPolicy": result.skippedPolicy, "skippedDup": result.skippedDuplicate, "skippedPolicy": result.skippedPolicy,
"failed": result.failed, "durationMs": durationMs}, "failed": result.failed, "durationMs": durationMs,
"stoppedAtLimit": result.stoppedAtLimit},
) )
return { return {
"connectionId": result.connectionId, "connectionId": result.connectionId,
@ -436,4 +469,11 @@ def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: f
"bytesProcessed": result.bytesProcessed, "bytesProcessed": result.bytesProcessed,
"durationMs": durationMs, "durationMs": durationMs,
"errors": result.errors[:20], "errors": result.errors[:20],
"stoppedAtLimit": result.stoppedAtLimit,
"limits": {
"maxItems": MAX_ITEMS_DEFAULT,
"maxBytes": MAX_BYTES_DEFAULT,
"maxFileSize": MAX_FILE_SIZE_DEFAULT,
"maxDepth": MAX_DEPTH_DEFAULT,
},
} }

View file

@ -59,6 +59,10 @@ class SharepointBootstrapResult:
failed: int = 0 failed: int = 0
bytesProcessed: int = 0 bytesProcessed: int = 0
errors: List[str] = field(default_factory=list) errors: List[str] = field(default_factory=list)
# First budget that hit zero; None means the walk completed naturally.
# Surfaces in the bootstrap result so the RAG inventory UI can warn the
# user that the corpus is incomplete and tell them which knob to turn.
stoppedAtLimit: Optional[str] = None # "maxItems" | "maxBytes" | "maxDepth" | "maxFileSize" | None
def _syntheticFileId(connectionId: str, externalItemId: str) -> str: def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
@ -259,14 +263,22 @@ async def _walkFolder(
for entry in entries: for entry in entries:
if result.indexed + result.skippedDuplicate >= limits.maxItems: if result.indexed + result.skippedDuplicate >= limits.maxItems:
_recordLimitStop(result, "maxItems", folderPath, limits)
return return
if result.bytesProcessed >= limits.maxBytes: if result.bytesProcessed >= limits.maxBytes:
_recordLimitStop(result, "maxBytes", folderPath, limits)
return return
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled(): if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
return return
entryPath = getattr(entry, "path", "") or "" entryPath = getattr(entry, "path", "") or ""
if getattr(entry, "isFolder", False): if getattr(entry, "isFolder", False):
if depth + 1 > limits.maxDepth:
# We stop descending here but keep walking siblings.
# Record once per bootstrap so the UI shows "maxDepth" even
# if other budgets aren't exhausted yet.
_recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
continue
await _walkFolder( await _walkFolder(
adapter=adapter, adapter=adapter,
knowledgeService=knowledgeService, knowledgeService=knowledgeService,
@ -289,6 +301,7 @@ async def _walkFolder(
continue continue
size = int(getattr(entry, "size", 0) or 0) size = int(getattr(entry, "size", 0) or 0)
if size and size > limits.maxFileSize: if size and size > limits.maxFileSize:
_recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
result.skippedPolicy += 1 result.skippedPolicy += 1
continue continue
@ -443,13 +456,44 @@ async def _ingestOne(
await asyncio.sleep(0) await asyncio.sleep(0)
def _recordLimitStop(
    result: SharepointBootstrapResult,
    limitName: str,
    where: str,
    limits: SharepointBootstrapLimits,
    *,
    hard: bool = True,
) -> None:
    """Mark the limit that bit on ``result`` and log a warning.

    Soft hits (per-file maxFileSize, per-folder maxDepth) are passed with
    ``hard=False`` and are only recorded when no limit has been recorded
    yet, so the UI surfaces the most important reason.

    Hard limits (maxItems / maxBytes) ALWAYS overwrite a previously recorded
    soft limit: once a hard cap is hit, the corpus is provably incomplete.

    Args:
        result: Bootstrap result whose ``stoppedAtLimit`` field is updated.
        limitName: Name of the budget that was hit.
        where: Path of the folder or entry where the limit was hit; only
            used in the log message.
        limits: Limits object the budget value is read from for the log.
        hard: Whether this is a hard stop (default) or a soft hit.
    """
    if hard and result.stoppedAtLimit == limitName:
        # _walkFolder recurses, and every enclosing folder loop re-checks the
        # same budget while unwinding. The stop is already recorded and
        # logged, so do not emit the same warning once per recursion level.
        return
    if hard or result.stoppedAtLimit is None:
        result.stoppedAtLimit = limitName
    budgetMap = {
        "maxItems": limits.maxItems,
        "maxBytes": limits.maxBytes,
        "maxDepth": limits.maxDepth,
        "maxFileSize": limits.maxFileSize,
    }
    logger.warning(
        "sharepoint walker hit %s=%s at %s — partial index "
        "(indexed=%d, bytesProcessed=%d). Raise the limit or split the data source.",
        limitName, budgetMap.get(limitName), where,
        result.indexed, result.bytesProcessed,
    )
def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startMs: float) -> Dict[str, Any]: def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000) durationMs = int((time.time() - startMs) * 1000)
logger.info( logger.info(
"ingestion.connection.bootstrap.done part=sharepoint connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d", "ingestion.connection.bootstrap.done part=sharepoint connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d stoppedAtLimit=%s",
connectionId, connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed, result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed,
durationMs, durationMs, result.stoppedAtLimit or "none",
extra={ extra={
"event": "ingestion.connection.bootstrap.done", "event": "ingestion.connection.bootstrap.done",
"part": "sharepoint", "part": "sharepoint",
@ -459,6 +503,7 @@ def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startM
"skippedPolicy": result.skippedPolicy, "skippedPolicy": result.skippedPolicy,
"failed": result.failed, "failed": result.failed,
"durationMs": durationMs, "durationMs": durationMs,
"stoppedAtLimit": result.stoppedAtLimit,
}, },
) )
return { return {
@ -470,4 +515,11 @@ def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startM
"bytesProcessed": result.bytesProcessed, "bytesProcessed": result.bytesProcessed,
"durationMs": durationMs, "durationMs": durationMs,
"errors": result.errors[:20], "errors": result.errors[:20],
"stoppedAtLimit": result.stoppedAtLimit,
"limits": {
"maxItems": MAX_ITEMS_DEFAULT,
"maxBytes": MAX_BYTES_DEFAULT,
"maxFileSize": MAX_FILE_SIZE_DEFAULT,
"maxDepth": MAX_DEPTH_DEFAULT,
},
} }

View file

@ -12,7 +12,8 @@ import logging
import json import json
import base64 import base64
import time import time
from typing import Any, Dict, Optional import threading
from typing import Any, Dict, Optional, Tuple
from pathlib import Path from pathlib import Path
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
@ -286,6 +287,16 @@ def handleSecretJson(value: str, userId: str = "system", keyName: str = "unknown
# Structure: {user_id: {key_name: [timestamps]}} # Structure: {user_id: {key_name: [timestamps]}}
_decryption_attempts = {} _decryption_attempts = {}
# Process-wide plaintext cache for decrypted secrets.
# Key: the encrypted ciphertext (which already includes env prefix).
# Value: (expiresAtMonotonic, plaintext).
# TTL is short enough that key rotation propagates quickly, long enough that
# hot DB-init paths (every API call building a connector) don't blow the
# decryption rate limit. 60s is a deliberate compromise.
_DECRYPTION_CACHE_TTL_S = 60.0
_decryption_cache: Dict[str, Tuple[float, str]] = {}
_decryption_cache_lock = threading.Lock()
def _getMasterKey(envType: str = None) -> bytes: def _getMasterKey(envType: str = None) -> bytes:
""" """
Get the master key for the specified environment. Get the master key for the specified environment.
@ -487,6 +498,16 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un
""" """
Decrypt a value using the master key for the current environment. Decrypt a value using the master key for the current environment.
A short-lived plaintext cache (TTL `_DECRYPTION_CACHE_TTL_S`) is consulted
first. The 10/sec rate-limit on cache misses still protects against
brute-force attacks; cache HITS bypass it because they are not actual
cryptographic operations — they just return the result of an earlier
successful decrypt. Without this cache, hot paths like
`mainBackgroundJobService._getDb()` (called per RAG inventory poll AND
per walker DB call) trigger the rate limit and surface as
"Decryption rate limit exceeded for user 'system' key 'DB_PASSWORD_SECRET'"
ERRORs in the RAG inventory UI route.
Args: Args:
encryptedValue: The encrypted value with prefix encryptedValue: The encrypted value with prefix
userId: The user ID making the request (default: "system") userId: The user ID making the request (default: "system")
@ -501,7 +522,15 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un
if not _isEncryptedValue(encryptedValue): if not _isEncryptedValue(encryptedValue):
return encryptedValue # Return as-is if not encrypted return encryptedValue # Return as-is if not encrypted
# Check rate limiting (10 per second per user per key) # Cache lookup BEFORE the rate-limit check: a cache hit is not a new
# cryptographic operation and must not be throttled.
now = time.monotonic()
with _decryption_cache_lock:
cached = _decryption_cache.get(encryptedValue)
if cached is not None and cached[0] > now:
return cached[1]
# Cache miss → real decrypt → apply rate limit.
if not _checkDecryptionRateLimit(userId, keyName, maxPerSecond=10): if not _checkDecryptionRateLimit(userId, keyName, maxPerSecond=10):
raise ValueError(f"Decryption rate limit exceeded for user '{userId}' key '{keyName}' (10/sec)") raise ValueError(f"Decryption rate limit exceeded for user '{userId}' key '{keyName}' (10/sec)")
@ -550,10 +579,24 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un
# Don't fail if audit logging fails # Don't fail if audit logging fails
pass pass
# Populate cache so subsequent reads of the same ciphertext don't
# re-decrypt (and don't consume rate-limit budget).
with _decryption_cache_lock:
_decryption_cache[encryptedValue] = (
time.monotonic() + _DECRYPTION_CACHE_TTL_S,
decryptedValue,
)
return decryptedValue return decryptedValue
except Exception as e: except Exception as e:
raise ValueError(f"Decryption failed: {e}") raise ValueError(f"Decryption failed: {e}")
def clearDecryptionCache() -> None:
    """Drop all cached plaintext secrets. Call after key rotation or in tests.

    Empties the module-level ``_decryption_cache`` while holding
    ``_decryption_cache_lock``.
    """
    with _decryption_cache_lock:
        _decryption_cache.clear()
# Create the global APP_CONFIG instance # Create the global APP_CONFIG instance
APP_CONFIG = Configuration() APP_CONFIG = Configuration()

View file

@ -33,20 +33,35 @@ def _ensureUamTablesMatchModels(dbConnector) -> None:
logger.debug(f"_ensureUamTablesMatchModels: {e}") logger.debug(f"_ensureUamTablesMatchModels: {e}")
def _getConnection(dbConnector): from contextlib import contextmanager
"""Get a connection from the DatabaseConnector.
Ensures the connection is alive and returns it.
Commits any pending transaction first to avoid blocking. @contextmanager
def _borrowDbConn(dbConnector):
"""Borrow a pooled connection from the DatabaseConnector.
Index/trigger/FK creation traditionally ran with `conn.autocommit = True`
so each CREATE statement is its own transaction (DDL on a managed
connection blocks waiting for COMMIT). This helper preserves that
behaviour on top of the pool: borrow a connection, flip it to autocommit,
yield it, and restore the previous state before returning it to the pool.
""" """
dbConnector._ensure_connection() with dbConnector.borrowConn() as conn:
conn = dbConnector.connection
# Commit any pending transaction to avoid blocking
try: try:
conn.commit() previousAutocommit = conn.autocommit
except Exception: except Exception:
pass # Ignore if nothing to commit previousAutocommit = False
return conn try:
conn.autocommit = True
except Exception as e:
logger.debug(f"Could not set autocommit on borrowed connection: {e}")
try:
yield conn
finally:
try:
conn.autocommit = previousAutocommit
except Exception:
pass
# ============================================================================= # =============================================================================
@ -174,34 +189,15 @@ def applyMultiTenantOptimizations(dbConnector, tables: Optional[List[str]] = Non
} }
try: try:
# Get a connection from the connector
conn = _getConnection(dbConnector)
# Save and set autocommit state
try:
originalAutocommit = conn.autocommit
except Exception:
originalAutocommit = False
try:
conn.autocommit = True
except Exception as autoErr:
logger.debug(f"Could not set autocommit: {autoErr}")
try: try:
_ensureUamTablesMatchModels(dbConnector) _ensureUamTablesMatchModels(dbConnector)
except Exception as preIdxErr: except Exception as preIdxErr:
logger.debug(f"Pre-index table ensure: {preIdxErr}") logger.debug(f"Pre-index table ensure: {preIdxErr}")
try: with _borrowDbConn(dbConnector) as conn:
with conn.cursor() as cursor: with conn.cursor() as cursor:
# Apply indexes
results["indexesCreated"] = _applyIndexes(cursor, tables) results["indexesCreated"] = _applyIndexes(cursor, tables)
# Apply foreign keys
results["foreignKeysCreated"] = _applyForeignKeys(cursor, tables) results["foreignKeysCreated"] = _applyForeignKeys(cursor, tables)
# Apply immutable triggers
results["triggersCreated"] = _applyImmutableTriggers(cursor, tables) results["triggersCreated"] = _applyImmutableTriggers(cursor, tables)
logger.info( logger.info(
@ -210,12 +206,6 @@ def applyMultiTenantOptimizations(dbConnector, tables: Optional[List[str]] = Non
f"{results['triggersCreated']} triggers, " f"{results['triggersCreated']} triggers, "
f"{results['foreignKeysCreated']} foreign keys" f"{results['foreignKeysCreated']} foreign keys"
) )
finally:
# Restore original autocommit state
try:
conn.autocommit = originalAutocommit
except Exception:
pass
except Exception as e: except Exception as e:
logger.error(f"Error applying multi-tenant optimizations: {type(e).__name__}: {e}") logger.error(f"Error applying multi-tenant optimizations: {type(e).__name__}: {e}")
@ -227,20 +217,14 @@ def applyMultiTenantOptimizations(dbConnector, tables: Optional[List[str]] = Non
def applyIndexesOnly(dbConnector, tables: Optional[List[str]] = None) -> int: def applyIndexesOnly(dbConnector, tables: Optional[List[str]] = None) -> int:
"""Apply only indexes (lighter operation, safe for frequent calls).""" """Apply only indexes (lighter operation, safe for frequent calls)."""
try: try:
conn = _getConnection(dbConnector)
originalAutocommit = conn.autocommit
conn.autocommit = True
try: try:
_ensureUamTablesMatchModels(dbConnector) _ensureUamTablesMatchModels(dbConnector)
except Exception as preIdxErr: except Exception as preIdxErr:
logger.debug(f"Pre-index table ensure: {preIdxErr}") logger.debug(f"Pre-index table ensure: {preIdxErr}")
try: with _borrowDbConn(dbConnector) as conn:
with conn.cursor() as cursor: with conn.cursor() as cursor:
return _applyIndexes(cursor, tables) return _applyIndexes(cursor, tables)
finally:
conn.autocommit = originalAutocommit
except Exception as e: except Exception as e:
logger.error(f"Error applying indexes: {e}") logger.error(f"Error applying indexes: {e}")
return 0 return 0
@ -514,8 +498,7 @@ def getOptimizationStatus(dbConnector) -> dict:
} }
try: try:
conn = _getConnection(dbConnector) with _borrowDbConn(dbConnector) as conn, conn.cursor() as cursor:
with conn.cursor() as cursor:
# Check regular indexes # Check regular indexes
for tableName, indexName, _ in _INDEXES: for tableName, indexName, _ in _INDEXES:
if _tableExists(cursor, tableName): if _tableExists(cursor, tableName):

View file

@ -60,11 +60,9 @@ def _getTableColumns(dbConnector, tableName: str) -> List[str]:
ORDER BY ordinal_position ORDER BY ordinal_position
""" """
cursor = dbConnector.connection.cursor() with dbConnector.borrowCursor() as cursor:
cursor.execute(query, (tableName,)) cursor.execute(query, (tableName,))
columns = [row[0] for row in cursor.fetchall()] columns = [row[0] for row in cursor.fetchall()]
cursor.close()
return columns return columns
except Exception as e: except Exception as e:
logger.error(f"Error getting columns for table {tableName}: {e}") logger.error(f"Error getting columns for table {tableName}: {e}")
@ -92,11 +90,10 @@ def _getAllTables(dbConnector) -> List[str]:
ORDER BY table_name ORDER BY table_name
""" """
cursor = dbConnector.connection.cursor() with dbConnector.borrowCursor() as cursor:
cursor.execute(query) cursor.execute(query)
allTables = [row[0] for row in cursor.fetchall()] allTables = [row[0] for row in cursor.fetchall()]
# Get foreign key relationships to determine dependency order
fkQuery = """ fkQuery = """
SELECT SELECT
tc.table_name, tc.table_name,
@ -111,10 +108,8 @@ def _getAllTables(dbConnector) -> List[str]:
WHERE tc.constraint_type = 'FOREIGN KEY' WHERE tc.constraint_type = 'FOREIGN KEY'
AND tc.table_schema = 'public' AND tc.table_schema = 'public'
""" """
cursor.execute(fkQuery) cursor.execute(fkQuery)
foreignKeys = cursor.fetchall() foreignKeys = cursor.fetchall()
cursor.close()
# Build dependency graph (child -> parent mapping) # Build dependency graph (child -> parent mapping)
dependencies = {} dependencies = {}
@ -154,10 +149,9 @@ def _getAllTables(dbConnector) -> List[str]:
# Fallback: return simple list without ordering # Fallback: return simple list without ordering
try: try:
query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE'" query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE'"
cursor = dbConnector.connection.cursor() with dbConnector.borrowCursor() as cursor:
cursor.execute(query) cursor.execute(query)
tables = [row[0] for row in cursor.fetchall()] tables = [row[0] for row in cursor.fetchall()]
cursor.close()
return [t for t in tables if t not in PROTECTED_TABLES] return [t for t in tables if t not in PROTECTED_TABLES]
except Exception: except Exception:
return [] return []
@ -184,11 +178,9 @@ def _getPrimaryKeyColumns(dbConnector, tableName: str) -> List[str]:
AND i.indisprimary AND i.indisprimary
""" """
cursor = dbConnector.connection.cursor() with dbConnector.borrowCursor() as cursor:
cursor.execute(query, (tableName,)) cursor.execute(query, (tableName,))
pkColumns = [row[0] for row in cursor.fetchall()] pkColumns = [row[0] for row in cursor.fetchall()]
cursor.close()
return pkColumns return pkColumns
except Exception as e: except Exception as e:
logger.debug(f"Could not get primary key for {tableName}: {e}") logger.debug(f"Could not get primary key for {tableName}: {e}")
@ -229,21 +221,15 @@ def _findUserReferencesInTable(
return {} return {}
references = {} references = {}
cursor = dbConnector.connection.cursor() with dbConnector.borrowCursor() as cursor:
for userColumn in userColumns: for userColumn in userColumns:
# Build SELECT for primary key columns
pkSelect = ", ".join([f'"{pk}"' for pk in pkColumns]) pkSelect = ", ".join([f'"{pk}"' for pk in pkColumns])
query = f'SELECT {pkSelect} FROM "{tableName}" WHERE "{userColumn}" = %s' query = f'SELECT {pkSelect} FROM "{tableName}" WHERE "{userColumn}" = %s'
cursor.execute(query, (userId,)) cursor.execute(query, (userId,))
recordKeys = cursor.fetchall() recordKeys = cursor.fetchall()
if recordKeys: if recordKeys:
references[userColumn] = recordKeys references[userColumn] = recordKeys
logger.debug(f"Found {len(recordKeys)} records in {tableName}.{userColumn} for user {userId}") logger.debug(f"Found {len(recordKeys)} records in {tableName}.{userColumn} for user {userId}")
cursor.close()
return references return references
except Exception as e: except Exception as e:
@ -277,17 +263,15 @@ def _anonymizeRecords(
return 0 return 0
try: try:
cursor = dbConnector.connection.cursor() # Resolve column metadata once outside the borrow block (it borrows its
count = 0 # own connection internally).
for recordKey in recordKeys:
# Build WHERE clause for primary key
whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
# Check if table has sysModifiedAt column
columns = _getTableColumns(dbConnector, tableName) columns = _getTableColumns(dbConnector, tableName)
hasModifiedAt = "sysModifiedAt" in columns hasModifiedAt = "sysModifiedAt" in columns
count = 0
with dbConnector.borrowCursor() as cursor:
for recordKey in recordKeys:
whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
if hasModifiedAt: if hasModifiedAt:
query = f'UPDATE "{tableName}" SET "{columnName}" = %s, "sysModifiedAt" = %s WHERE {whereClause}' query = f'UPDATE "{tableName}" SET "{columnName}" = %s, "sysModifiedAt" = %s WHERE {whereClause}'
params = [anonymousValue, getUtcTimestamp()] params = [anonymousValue, getUtcTimestamp()]
@ -295,7 +279,6 @@ def _anonymizeRecords(
query = f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE {whereClause}' query = f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE {whereClause}'
params = [anonymousValue] params = [anonymousValue]
# Add primary key values to params
if isinstance(recordKey, tuple): if isinstance(recordKey, tuple):
params.extend(recordKey) params.extend(recordKey)
else: else:
@ -304,15 +287,11 @@ def _anonymizeRecords(
cursor.execute(query, params) cursor.execute(query, params)
count += cursor.rowcount count += cursor.rowcount
dbConnector.connection.commit()
cursor.close()
logger.info(f"Anonymized {count} records in {tableName}.{columnName}") logger.info(f"Anonymized {count} records in {tableName}.{columnName}")
return count return count
except Exception as e: except Exception as e:
logger.error(f"Error anonymizing records in {tableName}.{columnName}: {e}") logger.error(f"Error anonymizing records in {tableName}.{columnName}: {e}")
dbConnector.connection.rollback()
return 0 return 0
@ -338,32 +317,23 @@ def _deleteRecords(
return 0 return 0
try: try:
cursor = dbConnector.connection.cursor()
count = 0 count = 0
with dbConnector.borrowCursor() as cursor:
for recordKey in recordKeys: for recordKey in recordKeys:
# Build WHERE clause for primary key
whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns]) whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
query = f'DELETE FROM "{tableName}" WHERE {whereClause}' query = f'DELETE FROM "{tableName}" WHERE {whereClause}'
# Prepare params
if isinstance(recordKey, tuple): if isinstance(recordKey, tuple):
params = list(recordKey) params = list(recordKey)
else: else:
params = [recordKey] params = [recordKey]
cursor.execute(query, params) cursor.execute(query, params)
count += cursor.rowcount count += cursor.rowcount
dbConnector.connection.commit()
cursor.close()
logger.info(f"Deleted {count} records from {tableName}") logger.info(f"Deleted {count} records from {tableName}")
return count return count
except Exception as e: except Exception as e:
logger.error(f"Error deleting records from {tableName}: {e}") logger.error(f"Error deleting records from {tableName}: {e}")
dbConnector.connection.rollback()
return 0 return 0

View file

@ -25,7 +25,7 @@ if not c or not c.connection:
print("STAGE0: DB_CONNECTION=none (check config.ini / .env)") print("STAGE0: DB_CONNECTION=none (check config.ini / .env)")
raise SystemExit(2) raise SystemExit(2)
cur = c.connection.cursor() cur = c.borrowCursor()
def _scalar(cur): def _scalar(cur):

View file

@ -12,11 +12,16 @@ broken query into "no rows found". That hid bugs like:
These tests pin the new contract: empty result sets still return ``[]`` / These tests pin the new contract: empty result sets still return ``[]`` /
``None`` (normal), but any exception inside the query path propagates as ``None`` (normal), but any exception inside the query path propagates as
``DatabaseQueryError`` with the table name attached. The transaction is ``DatabaseQueryError`` with the table name attached.
rolled back so the connection is usable for subsequent queries.
Since the 2026-05-17 pool refactor (`c-work/2-build/2026-05-postgres-connection-pool.md`)
the connector borrows a connection from `_PoolRegistry` on every call via the
`borrowConn()` context manager. The tests mock that context manager so the
fast-fail contract is exercised without requiring a live Postgres server.
""" """
from __future__ import annotations from __future__ import annotations
from contextlib import contextmanager
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
@ -25,7 +30,6 @@ import psycopg2.errors
from modules.connectors.connectorDbPostgre import ( from modules.connectors.connectorDbPostgre import (
DatabaseConnector, DatabaseConnector,
DatabaseQueryError, DatabaseQueryError,
_rollbackQuietly,
) )
@ -39,26 +43,44 @@ class DummyTable:
def _makeConnector(cursorBehavior): def _makeConnector(cursorBehavior):
"""Build a ``DatabaseConnector`` skeleton with mocked connection/cursor. """Build a ``DatabaseConnector`` skeleton with a mocked pool borrow.
``cursorBehavior`` is a callable invoked with the cursor mock so the test ``cursorBehavior`` is a callable invoked with the cursor mock so the test
can configure ``execute``/``fetchall``/``fetchone`` per scenario. can configure ``execute``/``fetchall``/``fetchone`` per scenario.
Returns ``(connector, conn, cursor)``:
* ``conn`` exposes ``commit`` / ``rollback`` MagicMocks so tests can
assert that the borrow lifecycle did the right thing.
* ``cursor`` is the per-test cursor mock.
""" """
connector = DatabaseConnector.__new__(DatabaseConnector) connector = DatabaseConnector.__new__(DatabaseConnector)
cursor = MagicMock() cursor = MagicMock()
cursorBehavior(cursor)
cursorContext = MagicMock() cursorContext = MagicMock()
cursorContext.__enter__ = MagicMock(return_value=cursor) cursorContext.__enter__ = MagicMock(return_value=cursor)
cursorContext.__exit__ = MagicMock(return_value=False) cursorContext.__exit__ = MagicMock(return_value=False)
connection = MagicMock() conn = MagicMock()
connection.cursor.return_value = cursorContext conn.cursor.return_value = cursorContext
connector.connection = connection
@contextmanager
def fakeBorrow():
try:
yield conn
except Exception:
conn.rollback()
raise
else:
conn.commit()
connector.borrowConn = fakeBorrow
connector._ensureTableExists = MagicMock(return_value=True) connector._ensureTableExists = MagicMock(return_value=True)
connector._systemTableName = "_system" connector._systemTableName = "_system"
cursorBehavior(cursor) return connector, conn, cursor
return connector, connection, cursor
class TestGetRecordsetFailLoud: class TestGetRecordsetFailLoud:
@ -67,11 +89,12 @@ class TestGetRecordsetFailLoud:
def behavior(cursor): def behavior(cursor):
cursor.execute.return_value = None cursor.execute.return_value = None
cursor.fetchall.return_value = [] cursor.fetchall.return_value = []
connector, connection, _ = _makeConnector(behavior) connector, conn, _ = _makeConnector(behavior)
result = connector.getRecordset(DummyTable) result = connector.getRecordset(DummyTable)
assert result == [] assert result == []
connection.rollback.assert_not_called() conn.rollback.assert_not_called()
conn.commit.assert_called_once()
def test_dictAdaptErrorRaisesDatabaseQueryError(self): def test_dictAdaptErrorRaisesDatabaseQueryError(self):
"""Reproduces the Trustee bug: passing a dict in WHERE → can't adapt → raise.""" """Reproduces the Trustee bug: passing a dict in WHERE → can't adapt → raise."""
@ -79,7 +102,7 @@ class TestGetRecordsetFailLoud:
cursor.execute.side_effect = psycopg2.ProgrammingError( cursor.execute.side_effect = psycopg2.ProgrammingError(
"can't adapt type 'dict'" "can't adapt type 'dict'"
) )
connector, connection, _ = _makeConnector(behavior) connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError) as excinfo: with pytest.raises(DatabaseQueryError) as excinfo:
connector.getRecordset( connector.getRecordset(
@ -90,30 +113,30 @@ class TestGetRecordsetFailLoud:
assert excinfo.value.table == "DummyTable" assert excinfo.value.table == "DummyTable"
assert "can't adapt type 'dict'" in str(excinfo.value) assert "can't adapt type 'dict'" in str(excinfo.value)
assert isinstance(excinfo.value.original, psycopg2.ProgrammingError) assert isinstance(excinfo.value.original, psycopg2.ProgrammingError)
connection.rollback.assert_called_once() conn.rollback.assert_called_once()
def test_missingColumnRaisesDatabaseQueryError(self): def test_missingColumnRaisesDatabaseQueryError(self):
def behavior(cursor): def behavior(cursor):
cursor.execute.side_effect = psycopg2.errors.UndefinedColumn( cursor.execute.side_effect = psycopg2.errors.UndefinedColumn(
'column "wat" does not exist' 'column "wat" does not exist'
) )
connector, connection, _ = _makeConnector(behavior) connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError) as excinfo: with pytest.raises(DatabaseQueryError) as excinfo:
connector.getRecordset(DummyTable, recordFilter={"wat": "x"}) connector.getRecordset(DummyTable, recordFilter={"wat": "x"})
assert "wat" in str(excinfo.value) assert "wat" in str(excinfo.value)
connection.rollback.assert_called_once() conn.rollback.assert_called_once()
def test_operationalErrorRaisesDatabaseQueryError(self): def test_operationalErrorRaisesDatabaseQueryError(self):
"""Connection lost mid-query is also a real failure that must propagate.""" """Connection lost mid-query is also a real failure that must propagate."""
def behavior(cursor): def behavior(cursor):
cursor.execute.side_effect = psycopg2.OperationalError("connection lost") cursor.execute.side_effect = psycopg2.OperationalError("connection lost")
connector, connection, _ = _makeConnector(behavior) connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError): with pytest.raises(DatabaseQueryError):
connector.getRecordset(DummyTable) connector.getRecordset(DummyTable)
connection.rollback.assert_called_once() conn.rollback.assert_called_once()
class TestGetRecordFailLoud: class TestGetRecordFailLoud:
@ -122,37 +145,22 @@ class TestGetRecordFailLoud:
def behavior(cursor): def behavior(cursor):
cursor.execute.return_value = None cursor.execute.return_value = None
cursor.fetchone.return_value = None cursor.fetchone.return_value = None
connector, connection, _ = _makeConnector(behavior) connector, conn, _ = _makeConnector(behavior)
result = connector.getRecord(DummyTable, "missing-id") result = connector.getRecord(DummyTable, "missing-id")
assert result is None assert result is None
connection.rollback.assert_not_called() conn.rollback.assert_not_called()
conn.commit.assert_called_once()
def test_queryErrorRaisesDatabaseQueryError(self): def test_queryErrorRaisesDatabaseQueryError(self):
def behavior(cursor): def behavior(cursor):
cursor.execute.side_effect = psycopg2.errors.UndefinedTable( cursor.execute.side_effect = psycopg2.errors.UndefinedTable(
'relation "DummyTable" does not exist' 'relation "DummyTable" does not exist'
) )
connector, connection, _ = _makeConnector(behavior) connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError) as excinfo: with pytest.raises(DatabaseQueryError) as excinfo:
connector.getRecord(DummyTable, "any-id") connector.getRecord(DummyTable, "any-id")
assert excinfo.value.table == "DummyTable" assert excinfo.value.table == "DummyTable"
connection.rollback.assert_called_once() conn.rollback.assert_called_once()
class TestRollbackQuietly:
def test_rollsBackOnLiveConnection(self):
connection = MagicMock()
_rollbackQuietly(connection)
connection.rollback.assert_called_once()
def test_swallowsRollbackError(self):
"""Rollback failure must not mask the original query error."""
connection = MagicMock()
connection.rollback.side_effect = RuntimeError("rollback failed")
_rollbackQuietly(connection)
def test_noopOnNoneConnection(self):
_rollbackQuietly(None)

View file

@ -0,0 +1,304 @@
# Copyright (c) 2026 Patrick Motsch
# All rights reserved.
"""Concurrency tests for the PostgreSQL connection pool.
These tests pin the contract that the `c-work/2-build/2026-05-postgres-connection-pool.md`
refactor delivered:
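* T1 — many threads reading in parallel produce 0 errors (no
  `OperationalError`s; stress run: 50 threads × 20 calls), and at the pool's
  design capacity (20 threads × 50 calls) every call completes within
  reasonable time (p99 < 5 s).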
* T2 — Two threads `_loadRecord` + `_saveRecord` against the same connector
  do not corrupt each other's cursors.
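* T3 — `statement_timeout` aborts a runaway `pg_sleep` and releases the
  connection back into the pool cleanly. To stay fast, the test lowers the
  timeout to 500 ms via `SET LOCAL` and sleeps 5 s instead of waiting for
  the pool-wide timeout.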
The tests need a real PostgreSQL server because the bug they guard against
only materialises with real psycopg2 sockets — a mocked connection never
hangs in `recv()`. They read DB credentials from `APP_CONFIG` (which loads
`.env`) and are auto-skipped when the connection fails (no local Postgres,
wrong creds, etc.) so `pytest` keeps working in CI-only environments.
To run them locally:
pytest gateway/tests/unit/connectors/test_connectorDbPostgre_pool.py -v
They use a throwaway database name (`poweron_pool_test_<uuid>`) and drop it
in fixture teardown so they leave nothing behind.
"""
from __future__ import annotations
import time
import uuid
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
import psycopg2
import psycopg2.errors
import pytest
from pydantic import Field
from modules.connectors.connectorDbPostgre import (
DatabaseConnector,
_PoolRegistry,
closeAllPools,
)
from modules.datamodels.datamodelBase import PowerOnModel
from modules.shared.configuration import APP_CONFIG
def _dbConfig():
    """Collect the DB connection parameters from ``APP_CONFIG``.

    Returns a dict with ``host``, ``user``, ``password`` and an integer
    ``port`` (default 5432).  Returns ``None`` when host or user is empty or
    the password is missing, so the module can skip its tests instead of
    failing at import time.
    """
    settings = {
        "host": APP_CONFIG.get("DB_HOST"),
        "user": APP_CONFIG.get("DB_USER"),
        "password": APP_CONFIG.get("DB_PASSWORD_SECRET"),
        "port": APP_CONFIG.get("DB_PORT", 5432),
    }
    # An empty password is allowed; only a missing one (None) disqualifies.
    incomplete = (
        not settings["host"]
        or not settings["user"]
        or settings["password"] is None
    )
    if incomplete:
        return None
    settings["port"] = int(settings["port"])
    return settings
def _canReachPostgres(cfg) -> bool:
"""Try a quick connect to the admin DB so we can skip on connection failures."""
try:
conn = psycopg2.connect(
host=cfg["host"], port=cfg["port"], database="postgres",
user=cfg["user"], password=cfg["password"], connect_timeout=2,
)
conn.close()
return True
except Exception: # noqa: BLE001
return False
# Resolved once at import time; ``None`` means the DB settings are incomplete.
_DB_CFG = _dbConfig()
# Module-wide skip marker: every test in this file is skipped when the
# settings are incomplete or the connectivity probe fails.
pytestmark = pytest.mark.skipif(
    _DB_CFG is None or not _canReachPostgres(_DB_CFG),
    reason="No reachable PostgreSQL — skipping live-Postgres pool tests",
)
class PoolTestRow(PowerOnModel):
    """Minimal model for exercising the pool.

    Declares a single ``payload`` field; the tests also rely on an ``id``
    field, which presumably comes from ``PowerOnModel``.
    """

    payload: str = Field("", description="Test payload")
def _dropTestDatabase(cfg, dbName: str, *, terminateBackends: bool = False) -> None:
    """Drop ``dbName`` (if it exists) through a short-lived admin connection.

    Args:
        cfg: Connection settings with ``host``, ``port``, ``user``, ``password``.
        dbName: Name of the throwaway database to drop.
        terminateBackends: When true, first terminate every backend still
            attached to ``dbName`` so the drop is not blocked by open sessions.
    """
    adminConn = psycopg2.connect(
        host=cfg["host"], port=cfg["port"], database="postgres",
        user=cfg["user"], password=cfg["password"],
    )
    try:
        # DROP DATABASE cannot run inside a transaction block.
        adminConn.autocommit = True
        with adminConn.cursor() as cur:
            if terminateBackends:
                cur.execute(
                    'SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = %s',
                    (dbName,),
                )
            cur.execute(f'DROP DATABASE IF EXISTS "{dbName}"')
    finally:
        adminConn.close()


@pytest.fixture
def liveConnector():
    """Yield a `DatabaseConnector` bound to a throwaway database.

    The pool registry is wiped before and after each test so state from one
    test cannot mask a bug in another.  Cleanup runs in a ``finally`` block:
    pytest does not execute the code after ``yield`` when the fixture raises
    during setup, so without it a failed connector construction or seed
    insert would leave the throwaway database and its pools behind.
    """
    cfg = _DB_CFG
    dbName = f"poweron_pool_test_{uuid.uuid4().hex[:8]}"

    # Pre-clean: drop any orphan test DB with the same name (shouldn't happen
    # because we use a unique uuid, but be defensive).
    _dropTestDatabase(cfg, dbName)
    closeAllPools()

    try:
        connector = DatabaseConnector(
            dbHost=cfg["host"],
            dbDatabase=dbName,
            dbUser=cfg["user"],
            dbPassword=cfg["password"],
            dbPort=cfg["port"],
        )
        # Seed exactly one row so every concurrent read has a stable target.
        connector.recordCreate(PoolTestRow, {"id": "seed", "payload": "hello"})
        yield connector
    finally:
        # Teardown: tear pools down, then drop the DB.
        closeAllPools()
        _dropTestDatabase(cfg, dbName, terminateBackends=True)
class TestPoolConcurrency:
def _runWorkers(self, liveConnector, *, threadCount: int, callsPerThread: int):
"""Run N worker threads, each issuing M reads. Return (errors, latencies)."""
errors: list = []
latencies: list = []
lock = threading.Lock()
def worker():
for _ in range(callsPerThread):
t0 = time.perf_counter()
try:
rows = liveConnector.getRecordset(PoolTestRow)
assert any(r["id"] == "seed" for r in rows)
except Exception as e: # noqa: BLE001 — we want every failure mode
with lock:
errors.append(e)
finally:
with lock:
latencies.append(time.perf_counter() - t0)
with ThreadPoolExecutor(max_workers=threadCount) as ex:
futures = [ex.submit(worker) for _ in range(threadCount)]
for f in as_completed(futures):
f.result()
latencies.sort()
return errors, latencies
def test_50_threads_x_20_reads_no_errors(self, liveConnector):
"""T1a — STRESS: 50 threads × 20 reads each → 0 errors.
Pre-pool, this scenario produced either
`OperationalError: another command is already in progress` or a
deadlock in `recv()` because the threadpool shared one psycopg2
socket. With the pool plus `borrowConn`'s bounded wait, every
thread eventually gets a connection and completes even with 30
threads queued waiting at any moment (pool max=20).
"""
errors, _ = self._runWorkers(liveConnector, threadCount=50, callsPerThread=20)
assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"
def test_20_threads_x_50_reads_latency_budget(self, liveConnector):
"""T1b — DESIGN CAPACITY: 20 threads × 50 reads, p99 < 5 s.
20 threads matches the pool's `max=20` so there is no queueing —
every borrow returns immediately. This pins a sanity-level per-call
latency budget; pre-pool it was unbounded (recv() never returned).
The 5 s ceiling is generous on purpose: `getRecordset` calls
`_ensureTableExists` which runs two `information_schema` queries
for column-additive migration, and under 20-way concurrency on a
single Postgres instance that produces a long tail. The hard
assertion is `not errors` the latency check just guarantees
nothing hangs indefinitely.
"""
errors, latencies = self._runWorkers(
liveConnector, threadCount=20, callsPerThread=50
)
assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"
p99 = latencies[int(len(latencies) * 0.99)]
assert p99 < 5.0, f"p99 latency {p99:.2f}s exceeds 5s budget"
def test_interleaved_load_and_save_no_collision(self, liveConnector):
"""T2: parallel reads + writes on the same connector → no cursor mix-up.
Pre-pool the reader could observe a row in mid-write or vice versa
because both shared the same cursor. With one connection per borrow,
the database's own row-locking is the only contention, and we just
need to assert no exceptions.
"""
stopFlag = threading.Event()
errors: list = []
lock = threading.Lock()
def reader():
while not stopFlag.is_set():
try:
liveConnector.getRecord(PoolTestRow, "seed")
except Exception as e: # noqa: BLE001
with lock:
errors.append(("read", e))
def writer():
i = 0
while not stopFlag.is_set():
try:
liveConnector.recordModify(
PoolTestRow,
"seed",
{"id": "seed", "payload": f"v{i}"},
)
i += 1
except Exception as e: # noqa: BLE001
with lock:
errors.append(("write", e))
threads = [
threading.Thread(target=reader, daemon=True),
threading.Thread(target=reader, daemon=True),
threading.Thread(target=writer, daemon=True),
threading.Thread(target=writer, daemon=True),
]
for t in threads:
t.start()
time.sleep(2.0)
stopFlag.set()
for t in threads:
t.join(timeout=3.0)
assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"
def test_statement_timeout_releases_connection(self, liveConnector):
"""T3: `pg_sleep` past statement_timeout → QueryCanceled, pool intact.
The bug we are guarding against: a runaway query with no timeout
hung `recv()` forever, the psycopg2 connection was poisoned, and the
whole backend became unresponsive once that connection was reused.
With `statement_timeout=30000` configured at pool construction the
query is cancelled by the server, the borrow context manager rolls
back, and the connection returns to the pool proven by the fact
that a follow-up call still succeeds quickly.
"""
# Use a short timeout to keep the test fast — override the pool's
# session statement_timeout for one borrow via SET LOCAL.
with liveConnector.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute("SET LOCAL statement_timeout = 500")
with pytest.raises(psycopg2.errors.QueryCanceled):
cursor.execute("SELECT pg_sleep(5)")
# Follow-up call must succeed quickly: connection is back in the pool.
t0 = time.perf_counter()
rows = liveConnector.getRecordset(PoolTestRow)
elapsed = time.perf_counter() - t0
assert any(r["id"] == "seed" for r in rows)
assert elapsed < 1.0, f"follow-up call took {elapsed:.2f}s — pool may be wedged"
class TestPoolRegistry:
    """Checks on the pool registry's bookkeeping (needs the live fixture)."""

    def test_one_pool_per_database_identity(self, liveConnector):
        """Two lookups for the same (host, db, port) hand back the same pool object."""
        cfg = _DB_CFG
        identity = {
            "dbHost": cfg["host"],
            "dbDatabase": liveConnector.dbDatabase,
            "dbUser": cfg["user"],
            "dbPassword": cfg["password"],
            "dbPort": cfg["port"],
        }
        first = _PoolRegistry.getPool(**identity)
        second = _PoolRegistry.getPool(**identity)
        assert first is second

    def test_close_all_clears_registry(self, liveConnector):
        """`closeAllPools()` leaves the registry empty so the next call rebuilds."""
        # A real call guarantees a pool has been registered first.
        liveConnector.getRecordset(PoolTestRow)
        assert _PoolRegistry._pools, "pool should exist after a real call"

        closeAllPools()
        assert _PoolRegistry._pools == {}, "registry should be empty after closeAllPools()"

View file

@ -68,6 +68,16 @@ class _FakeDb:
def _ensureTableExists(self, modelClass): def _ensureTableExists(self, modelClass):
return True return True
def borrowCursor(self):
    """Stand-in for `DatabaseConnector.borrowCursor()`.

    Returns a context manager whose ``with`` target is a fresh ``MagicMock``
    playing the role of the cursor.
    """
    from contextlib import nullcontext
    from unittest.mock import MagicMock

    return nullcontext(MagicMock())
def seed(self, modelClass, record: Dict[str, Any]): def seed(self, modelClass, record: Dict[str, Any]):
tableName = modelClass.__name__ tableName = modelClass.__name__
self._tables.setdefault(tableName, {}) self._tables.setdefault(tableName, {})

View file

@ -69,6 +69,16 @@ class _FakeDb:
def _ensureTableExists(self, modelClass): def _ensureTableExists(self, modelClass):
return True return True
def borrowCursor(self):
    """Stand-in for `DatabaseConnector.borrowCursor()` used by the cascade test.

    Returns a context manager whose ``with`` target is a fresh ``MagicMock``
    playing the role of the cursor.
    """
    from contextlib import nullcontext
    from unittest.mock import MagicMock

    return nullcontext(MagicMock())
def seed(self, modelClass, record: Dict[str, Any]): def seed(self, modelClass, record: Dict[str, Any]):
tableName = modelClass.__name__ tableName = modelClass.__name__
self._tables.setdefault(tableName, {}) self._tables.setdefault(tableName, {})