db connection pooling and rag limit transparency

This commit is contained in:
ValueOn AG 2026-05-17 20:38:37 +02:00
parent f5aba4bf99
commit 2bb65c2303
23 changed files with 1519 additions and 782 deletions

9
app.py
View file

@ -439,6 +439,15 @@ async def lifespan(app: FastAPI):
except Exception as e: except Exception as e:
logger.warning(f"Could not shutdown feature containers: {e}") logger.warning(f"Could not shutdown feature containers: {e}")
# --- Close all PostgreSQL connection pools ---
# Must run LAST: feature `onStop` hooks may still issue DB calls during
# shutdown. Once we tear down the pools, no more borrows are possible.
try:
from modules.connectors.connectorDbPostgre import closeAllPools
closeAllPools()
except Exception as e:
logger.warning(f"Closing DB connection pools failed: {e}")
logger.info("Application has been shut down") logger.info("Application has been shut down")

File diff suppressed because it is too large Load diff

View file

@ -342,7 +342,7 @@ class RealEstateObjects:
# If no exact match, try case-insensitive search via SQL query # If no exact match, try case-insensitive search via SQL query
# This handles cases where the name might have different casing # This handles cases where the name might have different casing
self.db._ensure_connection() self.db._ensure_connection()
with self.db.connection.cursor() as cursor: with self.db.borrowCursor() as cursor:
cursor.execute( cursor.execute(
'SELECT "id" FROM "Gemeinde" WHERE LOWER("label") = LOWER(%s) LIMIT 1', 'SELECT "id" FROM "Gemeinde" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
(name,) (name,)
@ -375,7 +375,7 @@ class RealEstateObjects:
# Try case-insensitive search # Try case-insensitive search
self.db._ensure_connection() self.db._ensure_connection()
with self.db.connection.cursor() as cursor: with self.db.borrowCursor() as cursor:
cursor.execute( cursor.execute(
'SELECT "id" FROM "Kanton" WHERE LOWER("label") = LOWER(%s) LIMIT 1', 'SELECT "id" FROM "Kanton" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
(name,) (name,)
@ -408,7 +408,7 @@ class RealEstateObjects:
# Try case-insensitive search # Try case-insensitive search
self.db._ensure_connection() self.db._ensure_connection()
with self.db.connection.cursor() as cursor: with self.db.borrowCursor() as cursor:
cursor.execute( cursor.execute(
'SELECT "id" FROM "Land" WHERE LOWER("label") = LOWER(%s) LIMIT 1', 'SELECT "id" FROM "Land" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
(name,) (name,)
@ -840,7 +840,7 @@ class RealEstateObjects:
# Ensure connection is alive # Ensure connection is alive
self.db._ensure_connection() self.db._ensure_connection()
with self.db.connection.cursor() as cursor: with self.db.borrowCursor() as cursor:
# Execute query # Execute query
if parameters: if parameters:
# Use parameterized query for safety # Use parameterized query for safety

View file

@ -1659,7 +1659,7 @@ class BillingObjects:
try: try:
appInterface = getAppInterface(self.currentUser) appInterface = getAppInterface(self.currentUser)
appInterface.db._ensure_connection() appInterface.db._ensure_connection()
with appInterface.db.connection.cursor() as cur: with appInterface.db.borrowCursor() as cur:
if appInterface.db._ensureTableExists(UserInDB): if appInterface.db._ensureTableExists(UserInDB):
cur.execute( cur.execute(
'SELECT "id" FROM "UserInDB" WHERE ' 'SELECT "id" FROM "UserInDB" WHERE '
@ -1780,7 +1780,7 @@ class BillingObjects:
try: try:
self.db._ensure_connection() self.db._ensure_connection()
with self.db.connection.cursor() as cur: with self.db.borrowCursor() as cur:
countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}' countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
cur.execute(countSql, whereValues) cur.execute(countSql, whereValues)
totalItems = cur.fetchone()["count"] totalItems = cur.fetchone()["count"]
@ -1797,10 +1797,7 @@ class BillingObjects:
except Exception as e: except Exception as e:
logger.error(f"_searchTransactionsPaginated SQL error: {e}", exc_info=True) logger.error(f"_searchTransactionsPaginated SQL error: {e}", exc_info=True)
try: # Rollback is handled by `borrowCursor()` context manager on exit.
self.db.connection.rollback()
except Exception:
pass
return {"items": [], "totalItems": 0, "totalPages": 0} return {"items": [], "totalItems": 0, "totalPages": 0}
def _buildScopeFilter( def _buildScopeFilter(
@ -1872,7 +1869,7 @@ class BillingObjects:
result: Dict[str, Any] = {} result: Dict[str, Any] = {}
with self.db.connection.cursor() as cur: with self.db.borrowCursor() as cur:
# 1) Totals # 1) Totals
cur.execute( cur.execute(
f'SELECT COALESCE(SUM("amount"), 0) AS total, COUNT(*) AS cnt FROM "{table}"{whereClause}', f'SELECT COALESCE(SUM("amount"), 0) AS total, COUNT(*) AS cnt FROM "{table}"{whereClause}',
@ -1947,17 +1944,12 @@ class BillingObjects:
}) })
result["timeSeries"] = timeSeries result["timeSeries"] = timeSeries
self.db.connection.commit() # Commit/rollback are handled by `borrowCursor()` context manager.
result["_allAccounts"] = allAccounts result["_allAccounts"] = allAccounts
return result return result
except Exception as e: except Exception as e:
logger.error(f"Error in getTransactionStatisticsAggregated: {e}", exc_info=True) logger.error(f"Error in getTransactionStatisticsAggregated: {e}", exc_info=True)
try:
self.db.connection.rollback()
except Exception:
pass
return self._emptyStats() return self._emptyStats()
@staticmethod @staticmethod

View file

@ -228,6 +228,22 @@ class KnowledgeObjects:
"""Get all ContentChunks for a file.""" """Get all ContentChunks for a file."""
return self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId}) return self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})
def countChunksByFileIds(self, fileIds: List[str]) -> Dict[str, int]:
"""Return a {fileId: chunkCount} mapping for the given file IDs.
One aggregate query instead of N round trips. Used by RAG inventory
to display real chunk counts per DataSource without loading the
embedding vectors. Missing file IDs map to 0 in the caller's logic.
"""
if not fileIds:
return {}
if not self.db._ensureTableExists(ContentChunk):
return {}
sql = 'SELECT "fileId", COUNT(*) AS cnt FROM "ContentChunk" WHERE "fileId" = ANY(%s) GROUP BY "fileId"'
with self.db.borrowCursor() as cursor:
cursor.execute(sql, (list(fileIds),))
return {row["fileId"]: int(row["cnt"]) for row in cursor.fetchall()}
def deleteContentChunks(self, fileId: str) -> int: def deleteContentChunks(self, fileId: str) -> int:
"""Delete all ContentChunks for a file. Returns count of deleted chunks.""" """Delete all ContentChunks for a file. Returns count of deleted chunks."""
chunks = self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId}) chunks = self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})

View file

@ -1221,22 +1221,17 @@ class ComponentObjects:
for item in fileRows for item in fileRows
] ]
# Single transaction: delete FileData, FileItem, then FileFolder (children first) # Single transaction: delete FileData, FileItem, then FileFolder (children first).
self.db._ensure_connection() # Commit/rollback are handled by `borrowCursor()` on exit.
try: with self.db.borrowCursor() as cursor:
with self.db.connection.cursor() as cursor: if fileIds:
if fileIds: cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (fileIds,))
cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (fileIds,)) cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (fileIds,))
cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (fileIds,)) orderedIds = list(folderIds)
orderedIds = list(folderIds) orderedIds.remove(folderId)
orderedIds.remove(folderId) orderedIds.append(folderId)
orderedIds.append(folderId) if orderedIds:
if orderedIds: cursor.execute('DELETE FROM "FileFolder" WHERE "id" = ANY(%s)', (orderedIds,))
cursor.execute('DELETE FROM "FileFolder" WHERE "id" = ANY(%s)', (orderedIds,))
self.db.connection.commit()
except Exception:
self.db.connection.rollback()
raise
return {"deletedFolders": len(folderIds), "deletedFiles": len(fileIds)} return {"deletedFolders": len(folderIds), "deletedFiles": len(fileIds)}
@ -1507,7 +1502,7 @@ class ComponentObjects:
try: try:
self.db._ensure_connection() self.db._ensure_connection()
with self.db.connection.cursor() as cursor: with self.db.borrowCursor() as cursor:
cursor.execute( cursor.execute(
'SELECT "id", "sysCreatedBy" FROM "FileItem" WHERE "id" = ANY(%s)', 'SELECT "id", "sysCreatedBy" FROM "FileItem" WHERE "id" = ANY(%s)',
(uniqueIds,), (uniqueIds,),
@ -1526,11 +1521,10 @@ class ComponentObjects:
cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (accessibleIds,)) cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (accessibleIds,))
deletedFiles = cursor.rowcount deletedFiles = cursor.rowcount
self.db.connection.commit() # Commit/rollback are handled by `borrowCursor()` context manager.
return {"deletedFiles": deletedFiles} return {"deletedFiles": deletedFiles}
except Exception as e: except Exception as e:
logger.error(f"Error deleting files in batch: {e}") logger.error(f"Error deleting files in batch: {e}")
self.db.connection.rollback()
raise FileDeletionError(f"Error deleting files in batch: {str(e)}") raise FileDeletionError(f"Error deleting files in batch: {str(e)}")
def _ensureFeatureInstanceGroup(self, featureInstanceId: str, contextKey: str = "files/list") -> Optional[str]: def _ensureFeatureInstanceGroup(self, featureInstanceId: str, contextKey: str = "files/list") -> Optional[str]:

View file

@ -374,7 +374,7 @@ def getRecordsetWithRBAC(
query = f'SELECT * FROM "{table}"{whereClause}{orderByClause}{limitClause}' query = f'SELECT * FROM "{table}"{whereClause}{orderByClause}{limitClause}'
with connector.connection.cursor() as cursor: with connector.borrowCursor() as cursor:
cursor.execute(query, whereValues) cursor.execute(query, whereValues)
records = [dict(row) for row in cursor.fetchall()] records = [dict(row) for row in cursor.fetchall()]
@ -561,7 +561,7 @@ def getRecordsetPaginatedWithRBAC(
offset = (pagination.page - 1) * pagination.pageSize offset = (pagination.page - 1) * pagination.pageSize
limitClause = f" LIMIT {pagination.pageSize} OFFSET {offset}" limitClause = f" LIMIT {pagination.pageSize} OFFSET {offset}"
with connector.connection.cursor() as cursor: with connector.borrowCursor() as cursor:
countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}' countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
cursor.execute(countSql, countValues) cursor.execute(countSql, countValues)
totalItems = cursor.fetchone()["count"] totalItems = cursor.fetchone()["count"]
@ -709,7 +709,7 @@ def getDistinctColumnValuesWithRBAC(
sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{nonNullWhere} ORDER BY val' sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{nonNullWhere} ORDER BY val'
with connector.connection.cursor() as cursor: with connector.borrowCursor() as cursor:
cursor.execute(sql, whereValues) cursor.execute(sql, whereValues)
result = [row["val"] for row in cursor.fetchall()] result = [row["val"] for row in cursor.fetchall()]
@ -719,7 +719,7 @@ def getDistinctColumnValuesWithRBAC(
emptySql = f'SELECT 1 FROM "{table}"{whereClause} AND {emptyCond} LIMIT 1' emptySql = f'SELECT 1 FROM "{table}"{whereClause} AND {emptyCond} LIMIT 1'
else: else:
emptySql = f'SELECT 1 FROM "{table}" WHERE {emptyCond} LIMIT 1' emptySql = f'SELECT 1 FROM "{table}" WHERE {emptyCond} LIMIT 1'
with connector.connection.cursor() as cursor: with connector.borrowCursor() as cursor:
cursor.execute(emptySql, whereValues) cursor.execute(emptySql, whereValues)
if cursor.fetchone(): if cursor.fetchone():
result.append(None) result.append(None)
@ -967,7 +967,7 @@ def buildRbacWhereClause(
# Multi-Tenant Design: Users do NOT have mandateId - they are linked via UserMandate # Multi-Tenant Design: Users do NOT have mandateId - they are linked via UserMandate
if table == "UserInDB": if table == "UserInDB":
try: try:
with connector.connection.cursor() as cursor: with connector.borrowCursor() as cursor:
# Get all user IDs that are members of the current mandate # Get all user IDs that are members of the current mandate
cursor.execute( cursor.execute(
'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true', 'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true',
@ -994,7 +994,7 @@ def buildRbacWhereClause(
# For UserConnection: Filter via UserMandate junction table # For UserConnection: Filter via UserMandate junction table
elif table == "UserConnection": elif table == "UserConnection":
try: try:
with connector.connection.cursor() as cursor: with connector.borrowCursor() as cursor:
# Get all user IDs that are members of the current mandate # Get all user IDs that are members of the current mandate
cursor.execute( cursor.execute(
'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true', 'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true',

View file

@ -305,7 +305,7 @@ def handleIdsMode(
sql = f'SELECT "{idField}"::TEXT AS val FROM "{table}"{where_clause} ORDER BY "{idField}"' sql = f'SELECT "{idField}"::TEXT AS val FROM "{table}"{where_clause} ORDER BY "{idField}"'
with db.connection.cursor() as cursor: with db.borrowCursor() as cursor:
cursor.execute(sql, values) cursor.execute(sql, values)
return JSONResponse(content=[row["val"] for row in cursor.fetchall()]) return JSONResponse(content=[row["val"] for row in cursor.fetchall()])
except Exception as e: except Exception as e:

View file

@ -25,6 +25,18 @@ router = APIRouter(
def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> List[Dict[str, Any]]: def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> List[Dict[str, Any]]:
"""Build per-connection RAG inventory rows.
Each DataSource row exposes BOTH numbers because they mean different things:
* `fileCount` distinct files indexed (== `FileContentIndex` rows)
* `chunkCount` embedding-sized text fragments (== `ContentChunk` rows,
max `DEFAULT_CHUNK_TOKENS` tokens each, what the vector retrieval
actually hits)
A single PDF typically yields 1 file × 5100 chunks; legacy UI labelled
`len(FileContentIndex)` as "chunks" which was off by 12 orders of
magnitude and misleading.
"""
from modules.datamodels.datamodelDataSource import DataSource from modules.datamodels.datamodelDataSource import DataSource
from modules.datamodels.datamodelKnowledge import FileContentIndex from modules.datamodels.datamodelKnowledge import FileContentIndex
@ -34,19 +46,35 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
connIndexRows = knowledgeIf.db.getRecordset(FileContentIndex, recordFilter={"connectionId": connectionId}) connIndexRows = knowledgeIf.db.getRecordset(FileContentIndex, recordFilter={"connectionId": connectionId})
connChunkTotal = len(connIndexRows) connFileTotal = len(connIndexRows)
# Map fileId → real chunk count via 1 aggregate query (cheap even for
# connections with thousands of files; we never load the vector body).
fileIds = [
(idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", ""))
for idx in connIndexRows
]
fileIds = [fid for fid in fileIds if fid]
chunkCountByFile = knowledgeIf.countChunksByFileIds(fileIds) if fileIds else {}
connChunkTotal = sum(chunkCountByFile.values())
filesByDs: Dict[str, int] = {}
chunksByDs: Dict[str, int] = {} chunksByDs: Dict[str, int] = {}
unassigned = 0 unassignedFiles = 0
unassignedChunks = 0
for idx in connIndexRows: for idx in connIndexRows:
fileId = idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", "")
chunkCnt = chunkCountByFile.get(fileId, 0)
struct = (idx.get("structure") if isinstance(idx, dict) else getattr(idx, "structure", None)) or {} struct = (idx.get("structure") if isinstance(idx, dict) else getattr(idx, "structure", None)) or {}
ingestion = struct.get("_ingestion") or {} if isinstance(struct, dict) else {} ingestion = struct.get("_ingestion") or {} if isinstance(struct, dict) else {}
prov = ingestion.get("provenance") or {} if isinstance(ingestion, dict) else {} prov = ingestion.get("provenance") or {} if isinstance(ingestion, dict) else {}
dsIdRef = prov.get("dataSourceId", "") if isinstance(prov, dict) else "" dsIdRef = prov.get("dataSourceId", "") if isinstance(prov, dict) else ""
if dsIdRef: if dsIdRef:
chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + 1 filesByDs[dsIdRef] = filesByDs.get(dsIdRef, 0) + 1
chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + chunkCnt
else: else:
unassigned += 1 unassignedFiles += 1
unassignedChunks += chunkCnt
seen: Dict[str, bool] = {} seen: Dict[str, bool] = {}
dsItems = [] dsItems = []
@ -64,14 +92,19 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"ragIndexEnabled": ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False), "ragIndexEnabled": ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False),
"neutralize": ds.get("neutralize") if isinstance(ds, dict) else getattr(ds, "neutralize", False), "neutralize": ds.get("neutralize") if isinstance(ds, dict) else getattr(ds, "neutralize", False),
"lastIndexed": ds.get("lastIndexed") if isinstance(ds, dict) else getattr(ds, "lastIndexed", None), "lastIndexed": ds.get("lastIndexed") if isinstance(ds, dict) else getattr(ds, "lastIndexed", None),
"fileCount": filesByDs.get(dsId, 0),
"chunkCount": chunksByDs.get(dsId, 0), "chunkCount": chunksByDs.get(dsId, 0),
}) })
if unassigned > 0 and len(dsItems) > 0: # Spread orphan files (provenance lost) evenly so totals match.
perDs = unassigned // len(dsItems) if unassignedFiles > 0 and len(dsItems) > 0:
remainder = unassigned % len(dsItems) perFile = unassignedFiles // len(dsItems)
remFile = unassignedFiles % len(dsItems)
perChunk = unassignedChunks // len(dsItems)
remChunk = unassignedChunks % len(dsItems)
for i, item in enumerate(dsItems): for i, item in enumerate(dsItems):
item["chunkCount"] += perDs + (1 if i < remainder else 0) item["fileCount"] += perFile + (1 if i < remFile else 0)
item["chunkCount"] += perChunk + (1 if i < remChunk else 0)
# Pull a wider window than the previous 5 so the "last successful # Pull a wider window than the previous 5 so the "last successful
# sync" is found even if a connection has many recent jobs queued. # sync" is found even if a connection has many recent jobs queued.
@ -102,6 +135,12 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"skippedPolicy": result.get("skippedPolicy", 0), "skippedPolicy": result.get("skippedPolicy", 0),
"failed": result.get("failed", 0), "failed": result.get("failed", 0),
"durationMs": result.get("durationMs", 0), "durationMs": result.get("durationMs", 0),
# Surface limit-stop reason so the UI can warn the user
# that the index is provably incomplete (and which budget
# to raise). None means the walker finished naturally.
"stoppedAtLimit": result.get("stoppedAtLimit"),
"limits": result.get("limits") or {},
"bytesProcessed": result.get("bytesProcessed", 0),
} }
if lastError and lastSuccess: if lastError and lastSuccess:
break break
@ -113,6 +152,7 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"knowledgeIngestionEnabled": getattr(conn, "knowledgeIngestionEnabled", False), "knowledgeIngestionEnabled": getattr(conn, "knowledgeIngestionEnabled", False),
"preferences": getattr(conn, "knowledgePreferences", None) or {}, "preferences": getattr(conn, "knowledgePreferences", None) or {},
"dataSources": dsItems, "dataSources": dsItems,
"totalFiles": connFileTotal,
"totalChunks": connChunkTotal, "totalChunks": connChunkTotal,
"runningJobs": runningJobs, "runningJobs": runningJobs,
"lastError": lastError, "lastError": lastError,
@ -139,8 +179,9 @@ def _getInventoryMe(
items = _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) items = _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService)
totalChunks = sum(c.get("totalChunks", 0) for c in items) totalChunks = sum(c.get("totalChunks", 0) for c in items)
totalFiles = sum(c.get("totalFiles", 0) for c in items)
return {"connections": items, "totals": {"chunks": totalChunks}} return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}}
except Exception as e: except Exception as e:
logger.error("Error in RAG inventory /me: %s", e, exc_info=True) logger.error("Error in RAG inventory /me: %s", e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@ -170,9 +211,10 @@ def _getInventoryMandate(
items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService) items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
totalChunks = sum(c.get("totalChunks", 0) for c in items) totalChunks = sum(c.get("totalChunks", 0) for c in items)
totalFiles = sum(c.get("totalFiles", 0) for c in items)
totalBytes = aggregateMandateRagTotalBytes(mandateId) totalBytes = aggregateMandateRagTotalBytes(mandateId)
return {"connections": items, "totals": {"chunks": totalChunks, "bytes": totalBytes}} return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks, "bytes": totalBytes}}
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@ -202,8 +244,9 @@ def _getInventoryPlatform(
items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService) items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
totalChunks = sum(c.get("totalChunks", 0) for c in items) totalChunks = sum(c.get("totalChunks", 0) for c in items)
totalFiles = sum(c.get("totalFiles", 0) for c in items)
return {"connections": items, "totals": {"chunks": totalChunks}} return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}}
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:

View file

@ -227,7 +227,7 @@ WHERE "workflowId" = ANY(%s)
GROUP BY "workflowId" GROUP BY "workflowId"
""" """
out: dict = {} out: dict = {}
with db.connection.cursor() as cursor: with db.borrowCursor() as cursor:
cursor.execute(sql, (workflowIds,)) cursor.execute(sql, (workflowIds,))
for row in cursor.fetchall(): for row in cursor.fetchall():
r = dict(row) r = dict(row)
@ -480,7 +480,7 @@ def _getWorkflowsJoinedPaginated(
dataSql = f"SELECT w.*, rs.\"lastStartedAt\", rs.\"runCount\", rs.\"activeRunId\" FROM {fromSql}{whereClause}{orderClause}{limitClause}" dataSql = f"SELECT w.*, rs.\"lastStartedAt\", rs.\"runCount\", rs.\"activeRunId\" FROM {fromSql}{whereClause}{orderClause}{limitClause}"
db._ensure_connection() db._ensure_connection()
with db.connection.cursor() as cursor: with db.borrowCursor() as cursor:
cursor.execute(countSql, countValues) cursor.execute(countSql, countValues)
totalItems = int(cursor.fetchone()["cnt"]) totalItems = int(cursor.fetchone()["cnt"])

View file

@ -25,15 +25,14 @@ _CACHE_TTL_SECONDS = 300
def _getOrCreateFeatureDbConnector(featureDbName: str, userId: str): def _getOrCreateFeatureDbConnector(featureDbName: str, userId: str):
"""Reuse a pooled DB connector for the given feature database.""" """Reuse a pooled DB connector for the given feature database.
The underlying psycopg2 connections live in the central pool
(`_PoolRegistry`) and are recreated on demand if they go stale; we just
need to keep the lightweight connector wrapper around.
"""
if featureDbName in _featureDbConnPool: if featureDbName in _featureDbConnPool:
conn = _featureDbConnPool[featureDbName] return _featureDbConnPool[featureDbName]
try:
if conn.connection and not conn.connection.closed:
return conn
except Exception as e:
logger.warning(f"Feature DB connection check failed for {featureDbName}: {e}")
_featureDbConnPool.pop(featureDbName, None)
from modules.connectors.connectorDbPostgre import DatabaseConnector from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG from modules.shared.configuration import APP_CONFIG

View file

@ -68,6 +68,9 @@ class ClickupBootstrapResult:
workspaces: int = 0 workspaces: int = 0
lists: int = 0 lists: int = 0
errors: List[str] = field(default_factory=list) errors: List[str] = field(default_factory=list)
# First budget exhausted: "maxTasks" | "maxWorkspaces" | "maxListsPerWorkspace" | None.
# Drives the same UI banner as the file-walker bootstraps.
stoppedAtLimit: Optional[str] = None
def _syntheticTaskId(connectionId: str, taskId: str) -> str: def _syntheticTaskId(connectionId: str, taskId: str) -> str:
@ -225,6 +228,7 @@ async def bootstrapClickup(
cancelled = False cancelled = False
for ds in dataSources: for ds in dataSources:
if result.indexed + result.skippedDuplicate >= limits.maxTasks: if result.indexed + result.skippedDuplicate >= limits.maxTasks:
_recordLimitStop(result, "maxTasks", "dataSource", limits)
break break
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled(): if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
cancelled = True cancelled = True
@ -243,8 +247,11 @@ async def bootstrapClickup(
clickupScope=limits.clickupScope, clickupScope=limits.clickupScope,
) )
if len(teams) > dsLimits.maxWorkspaces:
_recordLimitStop(result, "maxWorkspaces", "teams", dsLimits, hard=False)
for team in teams[:dsLimits.maxWorkspaces]: for team in teams[:dsLimits.maxWorkspaces]:
if result.indexed + result.skippedDuplicate >= dsLimits.maxTasks: if result.indexed + result.skippedDuplicate >= dsLimits.maxTasks:
_recordLimitStop(result, "maxTasks", f"team={team.get('id','')}", dsLimits)
break break
teamId = str(team.get("id", "") or "") teamId = str(team.get("id", "") or "")
if not teamId: if not teamId:
@ -351,6 +358,7 @@ async def _walkTeam(
for lst in listsCollected: for lst in listsCollected:
if result.indexed + result.skippedDuplicate >= limits.maxTasks: if result.indexed + result.skippedDuplicate >= limits.maxTasks:
_recordLimitStop(result, "maxTasks", f"team={teamId}", limits)
return return
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled(): if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
return return
@ -407,6 +415,7 @@ async def _walkList(
for task in tasks: for task in tasks:
if result.indexed + result.skippedDuplicate >= limits.maxTasks: if result.indexed + result.skippedDuplicate >= limits.maxTasks:
_recordLimitStop(result, "maxTasks", f"list={listId}", limits)
return return
if not _isRecent(task.get("date_updated"), limits.maxAgeDays): if not _isRecent(task.get("date_updated"), limits.maxAgeDays):
result.skippedPolicy += 1 result.skippedPolicy += 1
@ -529,13 +538,37 @@ async def _ingestTask(
) )
def _recordLimitStop(
result: ClickupBootstrapResult,
limitName: str,
where: str,
limits: ClickupBootstrapLimits,
*,
hard: bool = True,
) -> None:
"""See subConnectorSyncSharepoint._recordLimitStop for semantics."""
if hard or result.stoppedAtLimit is None:
result.stoppedAtLimit = limitName
budgetMap = {
"maxTasks": limits.maxTasks,
"maxWorkspaces": limits.maxWorkspaces,
"maxListsPerWorkspace": limits.maxListsPerWorkspace,
}
logger.warning(
"clickup walker hit %s=%s at %s — partial index (indexed=%d, skippedDup=%d).",
limitName, budgetMap.get(limitName), where,
result.indexed, result.skippedDuplicate,
)
def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs: float) -> Dict[str, Any]: def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000) durationMs = int((time.time() - startMs) * 1000)
logger.info( logger.info(
"ingestion.connection.bootstrap.done part=clickup connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d workspaces=%d lists=%d durationMs=%d", "ingestion.connection.bootstrap.done part=clickup connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d workspaces=%d lists=%d durationMs=%d stoppedAtLimit=%s",
connectionId, connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.indexed, result.skippedDuplicate, result.skippedPolicy,
result.failed, result.workspaces, result.lists, durationMs, result.failed, result.workspaces, result.lists, durationMs,
result.stoppedAtLimit or "none",
extra={ extra={
"event": "ingestion.connection.bootstrap.done", "event": "ingestion.connection.bootstrap.done",
"part": "clickup", "part": "clickup",
@ -547,6 +580,7 @@ def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs:
"workspaces": result.workspaces, "workspaces": result.workspaces,
"lists": result.lists, "lists": result.lists,
"durationMs": durationMs, "durationMs": durationMs,
"stoppedAtLimit": result.stoppedAtLimit,
}, },
) )
return { return {
@ -559,4 +593,11 @@ def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs:
"lists": result.lists, "lists": result.lists,
"durationMs": durationMs, "durationMs": durationMs,
"errors": result.errors[:20], "errors": result.errors[:20],
"stoppedAtLimit": result.stoppedAtLimit,
"limits": {
"maxTasks": MAX_TASKS_DEFAULT,
"maxWorkspaces": MAX_WORKSPACES_DEFAULT,
"maxListsPerWorkspace": MAX_LISTS_PER_WORKSPACE_DEFAULT,
"maxAgeDays": MAX_AGE_DAYS_DEFAULT,
},
} }

View file

@ -61,6 +61,8 @@ class GdriveBootstrapResult:
failed: int = 0 failed: int = 0
bytesProcessed: int = 0 bytesProcessed: int = 0
errors: List[str] = field(default_factory=list) errors: List[str] = field(default_factory=list)
# See SharepointBootstrapResult.stoppedAtLimit — same semantics.
stoppedAtLimit: Optional[str] = None
def _syntheticFileId(connectionId: str, externalItemId: str) -> str: def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
@ -265,8 +267,10 @@ async def _walkFolder(
for entry in entries: for entry in entries:
if result.indexed + result.skippedDuplicate >= limits.maxItems: if result.indexed + result.skippedDuplicate >= limits.maxItems:
_recordLimitStop(result, "maxItems", folderPath, limits)
return return
if result.bytesProcessed >= limits.maxBytes: if result.bytesProcessed >= limits.maxBytes:
_recordLimitStop(result, "maxBytes", folderPath, limits)
return return
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled(): if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
return return
@ -276,6 +280,9 @@ async def _walkFolder(
mimeType = getattr(entry, "mimeType", None) or metadata.get("mimeType") mimeType = getattr(entry, "mimeType", None) or metadata.get("mimeType")
if getattr(entry, "isFolder", False) or mimeType == FOLDER_MIME: if getattr(entry, "isFolder", False) or mimeType == FOLDER_MIME:
if depth + 1 > limits.maxDepth:
_recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
continue
await _walkFolder( await _walkFolder(
adapter=adapter, adapter=adapter,
knowledgeService=knowledgeService, knowledgeService=knowledgeService,
@ -298,6 +305,7 @@ async def _walkFolder(
continue continue
size = int(getattr(entry, "size", 0) or 0) size = int(getattr(entry, "size", 0) or 0)
if size and size > limits.maxFileSize: if size and size > limits.maxFileSize:
_recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
result.skippedPolicy += 1 result.skippedPolicy += 1
continue continue
modifiedTime = metadata.get("modifiedTime") modifiedTime = metadata.get("modifiedTime")
@ -470,13 +478,38 @@ async def _ingestOne(
await asyncio.sleep(0) await asyncio.sleep(0)
def _recordLimitStop(
result: GdriveBootstrapResult,
limitName: str,
where: str,
limits: GdriveBootstrapLimits,
*,
hard: bool = True,
) -> None:
"""See subConnectorSyncSharepoint._recordLimitStop for semantics."""
if hard or result.stoppedAtLimit is None:
result.stoppedAtLimit = limitName
budgetMap = {
"maxItems": limits.maxItems,
"maxBytes": limits.maxBytes,
"maxDepth": limits.maxDepth,
"maxFileSize": limits.maxFileSize,
}
logger.warning(
"gdrive walker hit %s=%s at %s — partial index (indexed=%d, bytesProcessed=%d).",
limitName, budgetMap.get(limitName), where,
result.indexed, result.bytesProcessed,
)
def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: float) -> Dict[str, Any]: def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000) durationMs = int((time.time() - startMs) * 1000)
logger.info( logger.info(
"ingestion.connection.bootstrap.done part=gdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d bytes=%d durationMs=%d", "ingestion.connection.bootstrap.done part=gdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d bytes=%d durationMs=%d stoppedAtLimit=%s",
connectionId, connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.indexed, result.skippedDuplicate, result.skippedPolicy,
result.failed, result.bytesProcessed, durationMs, result.failed, result.bytesProcessed, durationMs,
result.stoppedAtLimit or "none",
extra={ extra={
"event": "ingestion.connection.bootstrap.done", "event": "ingestion.connection.bootstrap.done",
"part": "gdrive", "part": "gdrive",
@ -487,6 +520,7 @@ def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: f
"failed": result.failed, "failed": result.failed,
"bytes": result.bytesProcessed, "bytes": result.bytesProcessed,
"durationMs": durationMs, "durationMs": durationMs,
"stoppedAtLimit": result.stoppedAtLimit,
}, },
) )
return { return {
@ -498,4 +532,11 @@ def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: f
"bytesProcessed": result.bytesProcessed, "bytesProcessed": result.bytesProcessed,
"durationMs": durationMs, "durationMs": durationMs,
"errors": result.errors[:20], "errors": result.errors[:20],
"stoppedAtLimit": result.stoppedAtLimit,
"limits": {
"maxItems": MAX_ITEMS_DEFAULT,
"maxBytes": MAX_BYTES_DEFAULT,
"maxFileSize": MAX_FILE_SIZE_DEFAULT,
"maxDepth": MAX_DEPTH_DEFAULT,
},
} }

View file

@ -53,6 +53,8 @@ class KdriveBootstrapResult:
failed: int = 0 failed: int = 0
bytesProcessed: int = 0 bytesProcessed: int = 0
errors: List[str] = field(default_factory=list) errors: List[str] = field(default_factory=list)
# See SharepointBootstrapResult.stoppedAtLimit — same semantics.
stoppedAtLimit: Optional[str] = None
def _syntheticFileId(connectionId: str, externalItemId: str) -> str: def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
@ -232,14 +234,19 @@ async def _walkFolder(
for entry in entries: for entry in entries:
if result.indexed + result.skippedDuplicate >= limits.maxItems: if result.indexed + result.skippedDuplicate >= limits.maxItems:
_recordLimitStop(result, "maxItems", folderPath, limits)
return return
if result.bytesProcessed >= limits.maxBytes: if result.bytesProcessed >= limits.maxBytes:
_recordLimitStop(result, "maxBytes", folderPath, limits)
return return
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled(): if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
return return
entryPath = getattr(entry, "path", "") or "" entryPath = getattr(entry, "path", "") or ""
if getattr(entry, "isFolder", False): if getattr(entry, "isFolder", False):
if depth + 1 > limits.maxDepth:
_recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
continue
await _walkFolder( await _walkFolder(
adapter=adapter, adapter=adapter,
knowledgeService=knowledgeService, knowledgeService=knowledgeService,
@ -262,6 +269,7 @@ async def _walkFolder(
continue continue
size = int(getattr(entry, "size", 0) or 0) size = int(getattr(entry, "size", 0) or 0)
if size and size > limits.maxFileSize: if size and size > limits.maxFileSize:
_recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
result.skippedPolicy += 1 result.skippedPolicy += 1
continue continue
@ -415,17 +423,42 @@ async def _ingestOne(
await asyncio.sleep(0) await asyncio.sleep(0)
def _recordLimitStop(
result: KdriveBootstrapResult,
limitName: str,
where: str,
limits: KdriveBootstrapLimits,
*,
hard: bool = True,
) -> None:
"""See subConnectorSyncSharepoint._recordLimitStop for semantics."""
if hard or result.stoppedAtLimit is None:
result.stoppedAtLimit = limitName
budgetMap = {
"maxItems": limits.maxItems,
"maxBytes": limits.maxBytes,
"maxDepth": limits.maxDepth,
"maxFileSize": limits.maxFileSize,
}
logger.warning(
"kdrive walker hit %s=%s at %s — partial index (indexed=%d, bytesProcessed=%d).",
limitName, budgetMap.get(limitName), where,
result.indexed, result.bytesProcessed,
)
def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: float) -> Dict[str, Any]: def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000) durationMs = int((time.time() - startMs) * 1000)
logger.info( logger.info(
"ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d", "ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d stoppedAtLimit=%s",
connectionId, connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed, result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed,
durationMs, durationMs, result.stoppedAtLimit or "none",
extra={"event": "ingestion.connection.bootstrap.done", "part": "kdrive", extra={"event": "ingestion.connection.bootstrap.done", "part": "kdrive",
"connectionId": connectionId, "indexed": result.indexed, "connectionId": connectionId, "indexed": result.indexed,
"skippedDup": result.skippedDuplicate, "skippedPolicy": result.skippedPolicy, "skippedDup": result.skippedDuplicate, "skippedPolicy": result.skippedPolicy,
"failed": result.failed, "durationMs": durationMs}, "failed": result.failed, "durationMs": durationMs,
"stoppedAtLimit": result.stoppedAtLimit},
) )
return { return {
"connectionId": result.connectionId, "connectionId": result.connectionId,
@ -436,4 +469,11 @@ def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: f
"bytesProcessed": result.bytesProcessed, "bytesProcessed": result.bytesProcessed,
"durationMs": durationMs, "durationMs": durationMs,
"errors": result.errors[:20], "errors": result.errors[:20],
"stoppedAtLimit": result.stoppedAtLimit,
"limits": {
"maxItems": MAX_ITEMS_DEFAULT,
"maxBytes": MAX_BYTES_DEFAULT,
"maxFileSize": MAX_FILE_SIZE_DEFAULT,
"maxDepth": MAX_DEPTH_DEFAULT,
},
} }

View file

@ -59,6 +59,10 @@ class SharepointBootstrapResult:
failed: int = 0 failed: int = 0
bytesProcessed: int = 0 bytesProcessed: int = 0
errors: List[str] = field(default_factory=list) errors: List[str] = field(default_factory=list)
# First budget that hit zero; None means the walk completed naturally.
# Surfaces in the bootstrap result so the RAG inventory UI can warn the
# user that the corpus is incomplete and tell them which knob to turn.
stoppedAtLimit: Optional[str] = None # "maxItems" | "maxBytes" | "maxDepth" | "maxFileSize" | None
def _syntheticFileId(connectionId: str, externalItemId: str) -> str: def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
@ -259,14 +263,22 @@ async def _walkFolder(
for entry in entries: for entry in entries:
if result.indexed + result.skippedDuplicate >= limits.maxItems: if result.indexed + result.skippedDuplicate >= limits.maxItems:
_recordLimitStop(result, "maxItems", folderPath, limits)
return return
if result.bytesProcessed >= limits.maxBytes: if result.bytesProcessed >= limits.maxBytes:
_recordLimitStop(result, "maxBytes", folderPath, limits)
return return
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled(): if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
return return
entryPath = getattr(entry, "path", "") or "" entryPath = getattr(entry, "path", "") or ""
if getattr(entry, "isFolder", False): if getattr(entry, "isFolder", False):
if depth + 1 > limits.maxDepth:
# We stop descending here but keep walking siblings.
# Record once per bootstrap so the UI shows "maxDepth" even
# if other budgets aren't exhausted yet.
_recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
continue
await _walkFolder( await _walkFolder(
adapter=adapter, adapter=adapter,
knowledgeService=knowledgeService, knowledgeService=knowledgeService,
@ -289,6 +301,7 @@ async def _walkFolder(
continue continue
size = int(getattr(entry, "size", 0) or 0) size = int(getattr(entry, "size", 0) or 0)
if size and size > limits.maxFileSize: if size and size > limits.maxFileSize:
_recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
result.skippedPolicy += 1 result.skippedPolicy += 1
continue continue
@ -443,13 +456,44 @@ async def _ingestOne(
await asyncio.sleep(0) await asyncio.sleep(0)
def _recordLimitStop(
result: SharepointBootstrapResult,
limitName: str,
where: str,
limits: SharepointBootstrapLimits,
*,
hard: bool = True,
) -> None:
"""Mark the FIRST limit that bit. Soft hits (per-file maxFileSize, per-folder
maxDepth) only record when no hard limit has yet stopped the run, so the UI
surfaces the most important reason.
Hard limits (maxItems / maxBytes) ALWAYS overwrite a previously recorded
soft limit once a hard cap is hit, the corpus is provably incomplete.
"""
if hard or result.stoppedAtLimit is None:
result.stoppedAtLimit = limitName
budgetMap = {
"maxItems": limits.maxItems,
"maxBytes": limits.maxBytes,
"maxDepth": limits.maxDepth,
"maxFileSize": limits.maxFileSize,
}
logger.warning(
"sharepoint walker hit %s=%s at %s — partial index "
"(indexed=%d, bytesProcessed=%d). Raise the limit or split the data source.",
limitName, budgetMap.get(limitName), where,
result.indexed, result.bytesProcessed,
)
def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startMs: float) -> Dict[str, Any]: def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000) durationMs = int((time.time() - startMs) * 1000)
logger.info( logger.info(
"ingestion.connection.bootstrap.done part=sharepoint connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d", "ingestion.connection.bootstrap.done part=sharepoint connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d stoppedAtLimit=%s",
connectionId, connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed, result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed,
durationMs, durationMs, result.stoppedAtLimit or "none",
extra={ extra={
"event": "ingestion.connection.bootstrap.done", "event": "ingestion.connection.bootstrap.done",
"part": "sharepoint", "part": "sharepoint",
@ -459,6 +503,7 @@ def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startM
"skippedPolicy": result.skippedPolicy, "skippedPolicy": result.skippedPolicy,
"failed": result.failed, "failed": result.failed,
"durationMs": durationMs, "durationMs": durationMs,
"stoppedAtLimit": result.stoppedAtLimit,
}, },
) )
return { return {
@ -470,4 +515,11 @@ def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startM
"bytesProcessed": result.bytesProcessed, "bytesProcessed": result.bytesProcessed,
"durationMs": durationMs, "durationMs": durationMs,
"errors": result.errors[:20], "errors": result.errors[:20],
"stoppedAtLimit": result.stoppedAtLimit,
"limits": {
"maxItems": MAX_ITEMS_DEFAULT,
"maxBytes": MAX_BYTES_DEFAULT,
"maxFileSize": MAX_FILE_SIZE_DEFAULT,
"maxDepth": MAX_DEPTH_DEFAULT,
},
} }

View file

@ -12,7 +12,8 @@ import logging
import json import json
import base64 import base64
import time import time
from typing import Any, Dict, Optional import threading
from typing import Any, Dict, Optional, Tuple
from pathlib import Path from pathlib import Path
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
@ -286,6 +287,16 @@ def handleSecretJson(value: str, userId: str = "system", keyName: str = "unknown
# Structure: {user_id: {key_name: [timestamps]}} # Structure: {user_id: {key_name: [timestamps]}}
_decryption_attempts = {} _decryption_attempts = {}
# Process-wide plaintext cache for decrypted secrets.
# Key: the encrypted ciphertext (which already includes env prefix).
# Value: (expiresAtMonotonic, plaintext).
# TTL is short enough that key rotation propagates quickly, long enough that
# hot DB-init paths (every API call building a connector) don't blow the
# decryption rate limit. 60s is a deliberate compromise.
_DECRYPTION_CACHE_TTL_S = 60.0
_decryption_cache: Dict[str, Tuple[float, str]] = {}
_decryption_cache_lock = threading.Lock()
def _getMasterKey(envType: str = None) -> bytes: def _getMasterKey(envType: str = None) -> bytes:
""" """
Get the master key for the specified environment. Get the master key for the specified environment.
@ -487,6 +498,16 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un
""" """
Decrypt a value using the master key for the current environment. Decrypt a value using the master key for the current environment.
A short-lived plaintext cache (TTL `_DECRYPTION_CACHE_TTL_S`) is consulted
first. The 10/sec rate-limit on cache misses still protects against
brute-force attacks; cache HITS bypass it because they are not actual
cryptographic operations they just return the result of an earlier
successful decrypt. Without this cache, hot paths like
`mainBackgroundJobService._getDb()` (called per RAG inventory poll AND
per walker DB call) trigger the rate limit and surface as
"Decryption rate limit exceeded for user 'system' key 'DB_PASSWORD_SECRET'"
ERRORs in the RAG inventory UI route.
Args: Args:
encryptedValue: The encrypted value with prefix encryptedValue: The encrypted value with prefix
userId: The user ID making the request (default: "system") userId: The user ID making the request (default: "system")
@ -501,7 +522,15 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un
if not _isEncryptedValue(encryptedValue): if not _isEncryptedValue(encryptedValue):
return encryptedValue # Return as-is if not encrypted return encryptedValue # Return as-is if not encrypted
# Check rate limiting (10 per second per user per key) # Cache lookup BEFORE the rate-limit check: a cache hit is not a new
# cryptographic operation and must not be throttled.
now = time.monotonic()
with _decryption_cache_lock:
cached = _decryption_cache.get(encryptedValue)
if cached is not None and cached[0] > now:
return cached[1]
# Cache miss → real decrypt → apply rate limit.
if not _checkDecryptionRateLimit(userId, keyName, maxPerSecond=10): if not _checkDecryptionRateLimit(userId, keyName, maxPerSecond=10):
raise ValueError(f"Decryption rate limit exceeded for user '{userId}' key '{keyName}' (10/sec)") raise ValueError(f"Decryption rate limit exceeded for user '{userId}' key '{keyName}' (10/sec)")
@ -550,10 +579,24 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un
# Don't fail if audit logging fails # Don't fail if audit logging fails
pass pass
# Populate cache so subsequent reads of the same ciphertext don't
# re-decrypt (and don't consume rate-limit budget).
with _decryption_cache_lock:
_decryption_cache[encryptedValue] = (
time.monotonic() + _DECRYPTION_CACHE_TTL_S,
decryptedValue,
)
return decryptedValue return decryptedValue
except Exception as e: except Exception as e:
raise ValueError(f"Decryption failed: {e}") raise ValueError(f"Decryption failed: {e}")
def clearDecryptionCache() -> None:
"""Drop all cached plaintext secrets. Call after key rotation or in tests."""
with _decryption_cache_lock:
_decryption_cache.clear()
# Create the global APP_CONFIG instance # Create the global APP_CONFIG instance
APP_CONFIG = Configuration() APP_CONFIG = Configuration()

View file

@ -33,20 +33,35 @@ def _ensureUamTablesMatchModels(dbConnector) -> None:
logger.debug(f"_ensureUamTablesMatchModels: {e}") logger.debug(f"_ensureUamTablesMatchModels: {e}")
def _getConnection(dbConnector): from contextlib import contextmanager
"""Get a connection from the DatabaseConnector.
Ensures the connection is alive and returns it.
Commits any pending transaction first to avoid blocking. @contextmanager
def _borrowDbConn(dbConnector):
"""Borrow a pooled connection from the DatabaseConnector.
Index/trigger/FK creation traditionally ran with `conn.autocommit = True`
so each CREATE statement is its own transaction (DDL on a managed
connection blocks waiting for COMMIT). This helper preserves that
behaviour on top of the pool: borrow a connection, flip it to autocommit,
yield it, and restore the previous state before returning it to the pool.
""" """
dbConnector._ensure_connection() with dbConnector.borrowConn() as conn:
conn = dbConnector.connection try:
# Commit any pending transaction to avoid blocking previousAutocommit = conn.autocommit
try: except Exception:
conn.commit() previousAutocommit = False
except Exception: try:
pass # Ignore if nothing to commit conn.autocommit = True
return conn except Exception as e:
logger.debug(f"Could not set autocommit on borrowed connection: {e}")
try:
yield conn
finally:
try:
conn.autocommit = previousAutocommit
except Exception:
pass
# ============================================================================= # =============================================================================
@ -174,48 +189,23 @@ def applyMultiTenantOptimizations(dbConnector, tables: Optional[List[str]] = Non
} }
try: try:
# Get a connection from the connector
conn = _getConnection(dbConnector)
# Save and set autocommit state
try:
originalAutocommit = conn.autocommit
except Exception:
originalAutocommit = False
try:
conn.autocommit = True
except Exception as autoErr:
logger.debug(f"Could not set autocommit: {autoErr}")
try: try:
_ensureUamTablesMatchModels(dbConnector) _ensureUamTablesMatchModels(dbConnector)
except Exception as preIdxErr: except Exception as preIdxErr:
logger.debug(f"Pre-index table ensure: {preIdxErr}") logger.debug(f"Pre-index table ensure: {preIdxErr}")
try: with _borrowDbConn(dbConnector) as conn:
with conn.cursor() as cursor: with conn.cursor() as cursor:
# Apply indexes
results["indexesCreated"] = _applyIndexes(cursor, tables) results["indexesCreated"] = _applyIndexes(cursor, tables)
# Apply foreign keys
results["foreignKeysCreated"] = _applyForeignKeys(cursor, tables) results["foreignKeysCreated"] = _applyForeignKeys(cursor, tables)
# Apply immutable triggers
results["triggersCreated"] = _applyImmutableTriggers(cursor, tables) results["triggersCreated"] = _applyImmutableTriggers(cursor, tables)
logger.info( logger.info(
f"Multi-tenant optimizations applied: " f"Multi-tenant optimizations applied: "
f"{results['indexesCreated']} indexes, " f"{results['indexesCreated']} indexes, "
f"{results['triggersCreated']} triggers, " f"{results['triggersCreated']} triggers, "
f"{results['foreignKeysCreated']} foreign keys" f"{results['foreignKeysCreated']} foreign keys"
) )
finally:
# Restore original autocommit state
try:
conn.autocommit = originalAutocommit
except Exception:
pass
except Exception as e: except Exception as e:
logger.error(f"Error applying multi-tenant optimizations: {type(e).__name__}: {e}") logger.error(f"Error applying multi-tenant optimizations: {type(e).__name__}: {e}")
@ -227,20 +217,14 @@ def applyMultiTenantOptimizations(dbConnector, tables: Optional[List[str]] = Non
def applyIndexesOnly(dbConnector, tables: Optional[List[str]] = None) -> int: def applyIndexesOnly(dbConnector, tables: Optional[List[str]] = None) -> int:
"""Apply only indexes (lighter operation, safe for frequent calls).""" """Apply only indexes (lighter operation, safe for frequent calls)."""
try: try:
conn = _getConnection(dbConnector)
originalAutocommit = conn.autocommit
conn.autocommit = True
try: try:
_ensureUamTablesMatchModels(dbConnector) _ensureUamTablesMatchModels(dbConnector)
except Exception as preIdxErr: except Exception as preIdxErr:
logger.debug(f"Pre-index table ensure: {preIdxErr}") logger.debug(f"Pre-index table ensure: {preIdxErr}")
try: with _borrowDbConn(dbConnector) as conn:
with conn.cursor() as cursor: with conn.cursor() as cursor:
return _applyIndexes(cursor, tables) return _applyIndexes(cursor, tables)
finally:
conn.autocommit = originalAutocommit
except Exception as e: except Exception as e:
logger.error(f"Error applying indexes: {e}") logger.error(f"Error applying indexes: {e}")
return 0 return 0
@ -514,8 +498,7 @@ def getOptimizationStatus(dbConnector) -> dict:
} }
try: try:
conn = _getConnection(dbConnector) with _borrowDbConn(dbConnector) as conn, conn.cursor() as cursor:
with conn.cursor() as cursor:
# Check regular indexes # Check regular indexes
for tableName, indexName, _ in _INDEXES: for tableName, indexName, _ in _INDEXES:
if _tableExists(cursor, tableName): if _tableExists(cursor, tableName):

View file

@ -60,11 +60,9 @@ def _getTableColumns(dbConnector, tableName: str) -> List[str]:
ORDER BY ordinal_position ORDER BY ordinal_position
""" """
cursor = dbConnector.connection.cursor() with dbConnector.borrowCursor() as cursor:
cursor.execute(query, (tableName,)) cursor.execute(query, (tableName,))
columns = [row[0] for row in cursor.fetchall()] columns = [row[0] for row in cursor.fetchall()]
cursor.close()
return columns return columns
except Exception as e: except Exception as e:
logger.error(f"Error getting columns for table {tableName}: {e}") logger.error(f"Error getting columns for table {tableName}: {e}")
@ -92,29 +90,26 @@ def _getAllTables(dbConnector) -> List[str]:
ORDER BY table_name ORDER BY table_name
""" """
cursor = dbConnector.connection.cursor() with dbConnector.borrowCursor() as cursor:
cursor.execute(query) cursor.execute(query)
allTables = [row[0] for row in cursor.fetchall()] allTables = [row[0] for row in cursor.fetchall()]
# Get foreign key relationships to determine dependency order fkQuery = """
fkQuery = """ SELECT
SELECT tc.table_name,
tc.table_name, ccu.table_name AS foreign_table_name
ccu.table_name AS foreign_table_name FROM information_schema.table_constraints AS tc
FROM information_schema.table_constraints AS tc JOIN information_schema.key_column_usage AS kcu
JOIN information_schema.key_column_usage AS kcu ON tc.constraint_name = kcu.constraint_name
ON tc.constraint_name = kcu.constraint_name AND tc.table_schema = kcu.table_schema
AND tc.table_schema = kcu.table_schema JOIN information_schema.constraint_column_usage AS ccu
JOIN information_schema.constraint_column_usage AS ccu ON ccu.constraint_name = tc.constraint_name
ON ccu.constraint_name = tc.constraint_name AND ccu.table_schema = tc.table_schema
AND ccu.table_schema = tc.table_schema WHERE tc.constraint_type = 'FOREIGN KEY'
WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_schema = 'public'
AND tc.table_schema = 'public' """
""" cursor.execute(fkQuery)
foreignKeys = cursor.fetchall()
cursor.execute(fkQuery)
foreignKeys = cursor.fetchall()
cursor.close()
# Build dependency graph (child -> parent mapping) # Build dependency graph (child -> parent mapping)
dependencies = {} dependencies = {}
@ -154,10 +149,9 @@ def _getAllTables(dbConnector) -> List[str]:
# Fallback: return simple list without ordering # Fallback: return simple list without ordering
try: try:
query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE'" query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE'"
cursor = dbConnector.connection.cursor() with dbConnector.borrowCursor() as cursor:
cursor.execute(query) cursor.execute(query)
tables = [row[0] for row in cursor.fetchall()] tables = [row[0] for row in cursor.fetchall()]
cursor.close()
return [t for t in tables if t not in PROTECTED_TABLES] return [t for t in tables if t not in PROTECTED_TABLES]
except Exception: except Exception:
return [] return []
@ -184,11 +178,9 @@ def _getPrimaryKeyColumns(dbConnector, tableName: str) -> List[str]:
AND i.indisprimary AND i.indisprimary
""" """
cursor = dbConnector.connection.cursor() with dbConnector.borrowCursor() as cursor:
cursor.execute(query, (tableName,)) cursor.execute(query, (tableName,))
pkColumns = [row[0] for row in cursor.fetchall()] pkColumns = [row[0] for row in cursor.fetchall()]
cursor.close()
return pkColumns return pkColumns
except Exception as e: except Exception as e:
logger.debug(f"Could not get primary key for {tableName}: {e}") logger.debug(f"Could not get primary key for {tableName}: {e}")
@ -229,21 +221,15 @@ def _findUserReferencesInTable(
return {} return {}
references = {} references = {}
cursor = dbConnector.connection.cursor() with dbConnector.borrowCursor() as cursor:
for userColumn in userColumns:
for userColumn in userColumns: pkSelect = ", ".join([f'"{pk}"' for pk in pkColumns])
# Build SELECT for primary key columns query = f'SELECT {pkSelect} FROM "{tableName}" WHERE "{userColumn}" = %s'
pkSelect = ", ".join([f'"{pk}"' for pk in pkColumns]) cursor.execute(query, (userId,))
query = f'SELECT {pkSelect} FROM "{tableName}" WHERE "{userColumn}" = %s' recordKeys = cursor.fetchall()
if recordKeys:
cursor.execute(query, (userId,)) references[userColumn] = recordKeys
recordKeys = cursor.fetchall() logger.debug(f"Found {len(recordKeys)} records in {tableName}.{userColumn} for user {userId}")
if recordKeys:
references[userColumn] = recordKeys
logger.debug(f"Found {len(recordKeys)} records in {tableName}.{userColumn} for user {userId}")
cursor.close()
return references return references
except Exception as e: except Exception as e:
@ -277,42 +263,35 @@ def _anonymizeRecords(
return 0 return 0
try: try:
cursor = dbConnector.connection.cursor() # Resolve column metadata once outside the borrow block (it borrows its
# own connection internally).
columns = _getTableColumns(dbConnector, tableName)
hasModifiedAt = "sysModifiedAt" in columns
count = 0 count = 0
with dbConnector.borrowCursor() as cursor:
for recordKey in recordKeys:
whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
if hasModifiedAt:
query = f'UPDATE "{tableName}" SET "{columnName}" = %s, "sysModifiedAt" = %s WHERE {whereClause}'
params = [anonymousValue, getUtcTimestamp()]
else:
query = f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE {whereClause}'
params = [anonymousValue]
for recordKey in recordKeys: if isinstance(recordKey, tuple):
# Build WHERE clause for primary key params.extend(recordKey)
whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns]) else:
params.append(recordKey)
# Check if table has sysModifiedAt column cursor.execute(query, params)
columns = _getTableColumns(dbConnector, tableName) count += cursor.rowcount
hasModifiedAt = "sysModifiedAt" in columns
if hasModifiedAt:
query = f'UPDATE "{tableName}" SET "{columnName}" = %s, "sysModifiedAt" = %s WHERE {whereClause}'
params = [anonymousValue, getUtcTimestamp()]
else:
query = f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE {whereClause}'
params = [anonymousValue]
# Add primary key values to params
if isinstance(recordKey, tuple):
params.extend(recordKey)
else:
params.append(recordKey)
cursor.execute(query, params)
count += cursor.rowcount
dbConnector.connection.commit()
cursor.close()
logger.info(f"Anonymized {count} records in {tableName}.{columnName}") logger.info(f"Anonymized {count} records in {tableName}.{columnName}")
return count return count
except Exception as e: except Exception as e:
logger.error(f"Error anonymizing records in {tableName}.{columnName}: {e}") logger.error(f"Error anonymizing records in {tableName}.{columnName}: {e}")
dbConnector.connection.rollback()
return 0 return 0
@ -338,32 +317,23 @@ def _deleteRecords(
return 0 return 0
try: try:
cursor = dbConnector.connection.cursor()
count = 0 count = 0
with dbConnector.borrowCursor() as cursor:
for recordKey in recordKeys: for recordKey in recordKeys:
# Build WHERE clause for primary key whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns]) query = f'DELETE FROM "{tableName}" WHERE {whereClause}'
query = f'DELETE FROM "{tableName}" WHERE {whereClause}' if isinstance(recordKey, tuple):
params = list(recordKey)
# Prepare params else:
if isinstance(recordKey, tuple): params = [recordKey]
params = list(recordKey) cursor.execute(query, params)
else: count += cursor.rowcount
params = [recordKey]
cursor.execute(query, params)
count += cursor.rowcount
dbConnector.connection.commit()
cursor.close()
logger.info(f"Deleted {count} records from {tableName}") logger.info(f"Deleted {count} records from {tableName}")
return count return count
except Exception as e: except Exception as e:
logger.error(f"Error deleting records from {tableName}: {e}") logger.error(f"Error deleting records from {tableName}: {e}")
dbConnector.connection.rollback()
return 0 return 0

View file

@ -25,7 +25,7 @@ if not c or not c.connection:
print("STAGE0: DB_CONNECTION=none (check config.ini / .env)") print("STAGE0: DB_CONNECTION=none (check config.ini / .env)")
raise SystemExit(2) raise SystemExit(2)
cur = c.connection.cursor() cur = c.borrowCursor()
def _scalar(cur): def _scalar(cur):

View file

@ -12,11 +12,16 @@ broken query into "no rows found". That hid bugs like:
These tests pin the new contract: empty result sets still return ``[]`` / These tests pin the new contract: empty result sets still return ``[]`` /
``None`` (normal), but any exception inside the query path propagates as ``None`` (normal), but any exception inside the query path propagates as
``DatabaseQueryError`` with the table name attached. The transaction is ``DatabaseQueryError`` with the table name attached.
rolled back so the connection is usable for subsequent queries.
Since the 2026-05-17 pool refactor (`c-work/2-build/2026-05-postgres-connection-pool.md`)
the connector borrows a connection from `_PoolRegistry` on every call via the
`borrowConn()` context manager. The tests mock that context manager so the
fast-fail contract is exercised without requiring a live Postgres server.
""" """
from __future__ import annotations from __future__ import annotations
from contextlib import contextmanager
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
@ -25,7 +30,6 @@ import psycopg2.errors
from modules.connectors.connectorDbPostgre import ( from modules.connectors.connectorDbPostgre import (
DatabaseConnector, DatabaseConnector,
DatabaseQueryError, DatabaseQueryError,
_rollbackQuietly,
) )
@ -39,26 +43,44 @@ class DummyTable:
def _makeConnector(cursorBehavior): def _makeConnector(cursorBehavior):
"""Build a ``DatabaseConnector`` skeleton with mocked connection/cursor. """Build a ``DatabaseConnector`` skeleton with a mocked pool borrow.
``cursorBehavior`` is a callable invoked with the cursor mock so the test ``cursorBehavior`` is a callable invoked with the cursor mock so the test
can configure ``execute``/``fetchall``/``fetchone`` per scenario. can configure ``execute``/``fetchall``/``fetchone`` per scenario.
Returns ``(connector, conn, cursor)``:
* ``conn`` exposes ``commit`` / ``rollback`` MagicMocks so tests can
assert that the borrow lifecycle did the right thing.
* ``cursor`` is the per-test cursor mock.
""" """
connector = DatabaseConnector.__new__(DatabaseConnector) connector = DatabaseConnector.__new__(DatabaseConnector)
cursor = MagicMock() cursor = MagicMock()
cursorBehavior(cursor)
cursorContext = MagicMock() cursorContext = MagicMock()
cursorContext.__enter__ = MagicMock(return_value=cursor) cursorContext.__enter__ = MagicMock(return_value=cursor)
cursorContext.__exit__ = MagicMock(return_value=False) cursorContext.__exit__ = MagicMock(return_value=False)
connection = MagicMock() conn = MagicMock()
connection.cursor.return_value = cursorContext conn.cursor.return_value = cursorContext
connector.connection = connection
@contextmanager
def fakeBorrow():
try:
yield conn
except Exception:
conn.rollback()
raise
else:
conn.commit()
connector.borrowConn = fakeBorrow
connector._ensureTableExists = MagicMock(return_value=True) connector._ensureTableExists = MagicMock(return_value=True)
connector._systemTableName = "_system" connector._systemTableName = "_system"
cursorBehavior(cursor) return connector, conn, cursor
return connector, connection, cursor
class TestGetRecordsetFailLoud: class TestGetRecordsetFailLoud:
@ -67,11 +89,12 @@ class TestGetRecordsetFailLoud:
def behavior(cursor): def behavior(cursor):
cursor.execute.return_value = None cursor.execute.return_value = None
cursor.fetchall.return_value = [] cursor.fetchall.return_value = []
connector, connection, _ = _makeConnector(behavior) connector, conn, _ = _makeConnector(behavior)
result = connector.getRecordset(DummyTable) result = connector.getRecordset(DummyTable)
assert result == [] assert result == []
connection.rollback.assert_not_called() conn.rollback.assert_not_called()
conn.commit.assert_called_once()
def test_dictAdaptErrorRaisesDatabaseQueryError(self): def test_dictAdaptErrorRaisesDatabaseQueryError(self):
"""Reproduces the Trustee bug: passing a dict in WHERE → can't adapt → raise.""" """Reproduces the Trustee bug: passing a dict in WHERE → can't adapt → raise."""
@ -79,7 +102,7 @@ class TestGetRecordsetFailLoud:
cursor.execute.side_effect = psycopg2.ProgrammingError( cursor.execute.side_effect = psycopg2.ProgrammingError(
"can't adapt type 'dict'" "can't adapt type 'dict'"
) )
connector, connection, _ = _makeConnector(behavior) connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError) as excinfo: with pytest.raises(DatabaseQueryError) as excinfo:
connector.getRecordset( connector.getRecordset(
@ -90,30 +113,30 @@ class TestGetRecordsetFailLoud:
assert excinfo.value.table == "DummyTable" assert excinfo.value.table == "DummyTable"
assert "can't adapt type 'dict'" in str(excinfo.value) assert "can't adapt type 'dict'" in str(excinfo.value)
assert isinstance(excinfo.value.original, psycopg2.ProgrammingError) assert isinstance(excinfo.value.original, psycopg2.ProgrammingError)
connection.rollback.assert_called_once() conn.rollback.assert_called_once()
def test_missingColumnRaisesDatabaseQueryError(self): def test_missingColumnRaisesDatabaseQueryError(self):
def behavior(cursor): def behavior(cursor):
cursor.execute.side_effect = psycopg2.errors.UndefinedColumn( cursor.execute.side_effect = psycopg2.errors.UndefinedColumn(
'column "wat" does not exist' 'column "wat" does not exist'
) )
connector, connection, _ = _makeConnector(behavior) connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError) as excinfo: with pytest.raises(DatabaseQueryError) as excinfo:
connector.getRecordset(DummyTable, recordFilter={"wat": "x"}) connector.getRecordset(DummyTable, recordFilter={"wat": "x"})
assert "wat" in str(excinfo.value) assert "wat" in str(excinfo.value)
connection.rollback.assert_called_once() conn.rollback.assert_called_once()
def test_operationalErrorRaisesDatabaseQueryError(self): def test_operationalErrorRaisesDatabaseQueryError(self):
"""Connection lost mid-query is also a real failure that must propagate.""" """Connection lost mid-query is also a real failure that must propagate."""
def behavior(cursor): def behavior(cursor):
cursor.execute.side_effect = psycopg2.OperationalError("connection lost") cursor.execute.side_effect = psycopg2.OperationalError("connection lost")
connector, connection, _ = _makeConnector(behavior) connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError): with pytest.raises(DatabaseQueryError):
connector.getRecordset(DummyTable) connector.getRecordset(DummyTable)
connection.rollback.assert_called_once() conn.rollback.assert_called_once()
class TestGetRecordFailLoud: class TestGetRecordFailLoud:
@ -122,37 +145,22 @@ class TestGetRecordFailLoud:
def behavior(cursor): def behavior(cursor):
cursor.execute.return_value = None cursor.execute.return_value = None
cursor.fetchone.return_value = None cursor.fetchone.return_value = None
connector, connection, _ = _makeConnector(behavior) connector, conn, _ = _makeConnector(behavior)
result = connector.getRecord(DummyTable, "missing-id") result = connector.getRecord(DummyTable, "missing-id")
assert result is None assert result is None
connection.rollback.assert_not_called() conn.rollback.assert_not_called()
conn.commit.assert_called_once()
def test_queryErrorRaisesDatabaseQueryError(self): def test_queryErrorRaisesDatabaseQueryError(self):
def behavior(cursor): def behavior(cursor):
cursor.execute.side_effect = psycopg2.errors.UndefinedTable( cursor.execute.side_effect = psycopg2.errors.UndefinedTable(
'relation "DummyTable" does not exist' 'relation "DummyTable" does not exist'
) )
connector, connection, _ = _makeConnector(behavior) connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError) as excinfo: with pytest.raises(DatabaseQueryError) as excinfo:
connector.getRecord(DummyTable, "any-id") connector.getRecord(DummyTable, "any-id")
assert excinfo.value.table == "DummyTable" assert excinfo.value.table == "DummyTable"
connection.rollback.assert_called_once() conn.rollback.assert_called_once()
class TestRollbackQuietly:
def test_rollsBackOnLiveConnection(self):
connection = MagicMock()
_rollbackQuietly(connection)
connection.rollback.assert_called_once()
def test_swallowsRollbackError(self):
"""Rollback failure must not mask the original query error."""
connection = MagicMock()
connection.rollback.side_effect = RuntimeError("rollback failed")
_rollbackQuietly(connection)
def test_noopOnNoneConnection(self):
_rollbackQuietly(None)

View file

@ -0,0 +1,304 @@
# Copyright (c) 2026 Patrick Motsch
# All rights reserved.
"""Concurrency tests for the PostgreSQL connection pool.
These tests pin the contract that the `c-work/2-build/2026-05-postgres-connection-pool.md`
refactor delivered:
* T1 50 threads × 100 calls in parallel produce 0 `OperationalError`s and
every call completes within reasonable time (p99 < 2 s).
* T2 Two threads `_loadRecord` + `_saveRecord` against the same connector
do not corrupt each other's cursors.
* T3 `statement_timeout` aborts a runaway `pg_sleep(60)` after ~30 s and
releases the connection back into the pool cleanly.
The tests need a real PostgreSQL server because the bug they guard against
only materialises with real psycopg2 sockets a mocked connection never
hangs in `recv()`. They read DB credentials from `APP_CONFIG` (which loads
`.env`) and are auto-skipped when the connection fails (no local Postgres,
wrong creds, etc.) so `pytest` keeps working in CI-only environments.
To run them locally:
pytest gateway/tests/unit/connectors/test_connectorDbPostgre_pool.py -v
They use a throwaway database name (`poweron_pool_test_<uuid>`) and drop it
in fixture teardown so they leave nothing behind.
"""
from __future__ import annotations
import time
import uuid
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
import psycopg2
import psycopg2.errors
import pytest
from pydantic import Field
from modules.connectors.connectorDbPostgre import (
DatabaseConnector,
_PoolRegistry,
closeAllPools,
)
from modules.datamodels.datamodelBase import PowerOnModel
from modules.shared.configuration import APP_CONFIG
def _dbConfig():
"""Read DB connection params from APP_CONFIG (`.env`).
Returns ``None`` when host/user/password are not all present so the
test module can skip cleanly instead of blowing up at import time.
"""
host = APP_CONFIG.get("DB_HOST")
user = APP_CONFIG.get("DB_USER")
password = APP_CONFIG.get("DB_PASSWORD_SECRET")
port = APP_CONFIG.get("DB_PORT", 5432)
if not host or not user or password is None:
return None
return {"host": host, "user": user, "password": password, "port": int(port)}
def _canReachPostgres(cfg) -> bool:
"""Try a quick connect to the admin DB so we can skip on connection failures."""
try:
conn = psycopg2.connect(
host=cfg["host"], port=cfg["port"], database="postgres",
user=cfg["user"], password=cfg["password"], connect_timeout=2,
)
conn.close()
return True
except Exception: # noqa: BLE001
return False
_DB_CFG = _dbConfig()
pytestmark = pytest.mark.skipif(
_DB_CFG is None or not _canReachPostgres(_DB_CFG),
reason="No reachable PostgreSQL — skipping live-Postgres pool tests",
)
class PoolTestRow(PowerOnModel):
"""Tiny model used to exercise the pool — one ID + one payload field."""
payload: str = Field(default="", description="Test payload")
@pytest.fixture
def liveConnector():
"""Spin up a throwaway database, yield a `DatabaseConnector` against it,
drop the database afterwards.
The pool registry is wiped before and after each test so state from one
test cannot mask a bug in another.
"""
cfg = _DB_CFG
host = cfg["host"]
user = cfg["user"]
password = cfg["password"]
port = cfg["port"]
dbName = f"poweron_pool_test_{uuid.uuid4().hex[:8]}"
# Pre-clean: drop any orphan test DB with the same name (shouldn't happen
# because we use a unique uuid, but be defensive).
adminConn = psycopg2.connect(
host=host, port=port, database="postgres", user=user, password=password
)
adminConn.autocommit = True
try:
with adminConn.cursor() as cur:
cur.execute(f'DROP DATABASE IF EXISTS "{dbName}"')
finally:
adminConn.close()
closeAllPools()
connector = DatabaseConnector(
dbHost=host,
dbDatabase=dbName,
dbUser=user,
dbPassword=password,
dbPort=port,
)
# Seed exactly one row so every concurrent read has a stable target.
connector.recordCreate(PoolTestRow, {"id": "seed", "payload": "hello"})
yield connector
# Teardown: tear pools down, then drop the DB.
closeAllPools()
adminConn = psycopg2.connect(
host=host, port=port, database="postgres", user=user, password=password
)
adminConn.autocommit = True
try:
with adminConn.cursor() as cur:
cur.execute(
'SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = %s',
(dbName,),
)
cur.execute(f'DROP DATABASE IF EXISTS "{dbName}"')
finally:
adminConn.close()
class TestPoolConcurrency:
def _runWorkers(self, liveConnector, *, threadCount: int, callsPerThread: int):
"""Run N worker threads, each issuing M reads. Return (errors, latencies)."""
errors: list = []
latencies: list = []
lock = threading.Lock()
def worker():
for _ in range(callsPerThread):
t0 = time.perf_counter()
try:
rows = liveConnector.getRecordset(PoolTestRow)
assert any(r["id"] == "seed" for r in rows)
except Exception as e: # noqa: BLE001 — we want every failure mode
with lock:
errors.append(e)
finally:
with lock:
latencies.append(time.perf_counter() - t0)
with ThreadPoolExecutor(max_workers=threadCount) as ex:
futures = [ex.submit(worker) for _ in range(threadCount)]
for f in as_completed(futures):
f.result()
latencies.sort()
return errors, latencies
def test_50_threads_x_20_reads_no_errors(self, liveConnector):
"""T1a — STRESS: 50 threads × 20 reads each → 0 errors.
Pre-pool, this scenario produced either
`OperationalError: another command is already in progress` or a
deadlock in `recv()` because the threadpool shared one psycopg2
socket. With the pool plus `borrowConn`'s bounded wait, every
thread eventually gets a connection and completes even with 30
threads queued waiting at any moment (pool max=20).
"""
errors, _ = self._runWorkers(liveConnector, threadCount=50, callsPerThread=20)
assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"
def test_20_threads_x_50_reads_latency_budget(self, liveConnector):
"""T1b — DESIGN CAPACITY: 20 threads × 50 reads, p99 < 5 s.
20 threads matches the pool's `max=20` so there is no queueing —
every borrow returns immediately. This pins a sanity-level per-call
latency budget; pre-pool it was unbounded (recv() never returned).
The 5 s ceiling is generous on purpose: `getRecordset` calls
`_ensureTableExists` which runs two `information_schema` queries
for column-additive migration, and under 20-way concurrency on a
single Postgres instance that produces a long tail. The hard
assertion is `not errors` the latency check just guarantees
nothing hangs indefinitely.
"""
errors, latencies = self._runWorkers(
liveConnector, threadCount=20, callsPerThread=50
)
assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"
p99 = latencies[int(len(latencies) * 0.99)]
assert p99 < 5.0, f"p99 latency {p99:.2f}s exceeds 5s budget"
def test_interleaved_load_and_save_no_collision(self, liveConnector):
"""T2: parallel reads + writes on the same connector → no cursor mix-up.
Pre-pool the reader could observe a row in mid-write or vice versa
because both shared the same cursor. With one connection per borrow,
the database's own row-locking is the only contention, and we just
need to assert no exceptions.
"""
stopFlag = threading.Event()
errors: list = []
lock = threading.Lock()
def reader():
while not stopFlag.is_set():
try:
liveConnector.getRecord(PoolTestRow, "seed")
except Exception as e: # noqa: BLE001
with lock:
errors.append(("read", e))
def writer():
i = 0
while not stopFlag.is_set():
try:
liveConnector.recordModify(
PoolTestRow,
"seed",
{"id": "seed", "payload": f"v{i}"},
)
i += 1
except Exception as e: # noqa: BLE001
with lock:
errors.append(("write", e))
threads = [
threading.Thread(target=reader, daemon=True),
threading.Thread(target=reader, daemon=True),
threading.Thread(target=writer, daemon=True),
threading.Thread(target=writer, daemon=True),
]
for t in threads:
t.start()
time.sleep(2.0)
stopFlag.set()
for t in threads:
t.join(timeout=3.0)
assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"
def test_statement_timeout_releases_connection(self, liveConnector):
"""T3: `pg_sleep` past statement_timeout → QueryCanceled, pool intact.
The bug we are guarding against: a runaway query with no timeout
hung `recv()` forever, the psycopg2 connection was poisoned, and the
whole backend became unresponsive once that connection was reused.
With `statement_timeout=30000` configured at pool construction the
query is cancelled by the server, the borrow context manager rolls
back, and the connection returns to the pool proven by the fact
that a follow-up call still succeeds quickly.
"""
# Use a short timeout to keep the test fast — override the pool's
# session statement_timeout for one borrow via SET LOCAL.
with liveConnector.borrowConn() as conn:
with conn.cursor() as cursor:
cursor.execute("SET LOCAL statement_timeout = 500")
with pytest.raises(psycopg2.errors.QueryCanceled):
cursor.execute("SELECT pg_sleep(5)")
# Follow-up call must succeed quickly: connection is back in the pool.
t0 = time.perf_counter()
rows = liveConnector.getRecordset(PoolTestRow)
elapsed = time.perf_counter() - t0
assert any(r["id"] == "seed" for r in rows)
assert elapsed < 1.0, f"follow-up call took {elapsed:.2f}s — pool may be wedged"
class TestPoolRegistry:
def test_one_pool_per_database_identity(self, liveConnector):
"""Two connectors against the same (host, db, port) share one pool."""
cfg = _DB_CFG
pool1 = _PoolRegistry.getPool(
dbHost=cfg["host"], dbDatabase=liveConnector.dbDatabase,
dbUser=cfg["user"], dbPassword=cfg["password"], dbPort=cfg["port"],
)
pool2 = _PoolRegistry.getPool(
dbHost=cfg["host"], dbDatabase=liveConnector.dbDatabase,
dbUser=cfg["user"], dbPassword=cfg["password"], dbPort=cfg["port"],
)
assert pool1 is pool2
def test_close_all_clears_registry(self, liveConnector):
"""`closeAllPools()` empties the registry so the next call rebuilds."""
# Touch the pool first.
liveConnector.getRecordset(PoolTestRow)
assert _PoolRegistry._pools, "pool should exist after a real call"
closeAllPools()
assert _PoolRegistry._pools == {}, "registry should be empty after closeAllPools()"

View file

@ -68,6 +68,16 @@ class _FakeDb:
def _ensureTableExists(self, modelClass): def _ensureTableExists(self, modelClass):
return True return True
def borrowCursor(self):
"""Mimic `DatabaseConnector.borrowCursor()` context manager."""
from contextlib import contextmanager
from unittest.mock import MagicMock
@contextmanager
def _cm():
yield MagicMock()
return _cm()
def seed(self, modelClass, record: Dict[str, Any]): def seed(self, modelClass, record: Dict[str, Any]):
tableName = modelClass.__name__ tableName = modelClass.__name__
self._tables.setdefault(tableName, {}) self._tables.setdefault(tableName, {})

View file

@ -69,6 +69,16 @@ class _FakeDb:
def _ensureTableExists(self, modelClass): def _ensureTableExists(self, modelClass):
return True return True
def borrowCursor(self):
"""Mimic `DatabaseConnector.borrowCursor()` context manager for the cascade test."""
from contextlib import contextmanager
from unittest.mock import MagicMock
@contextmanager
def _cm():
yield MagicMock()
return _cm()
def seed(self, modelClass, record: Dict[str, Any]): def seed(self, modelClass, record: Dict[str, Any]):
tableName = modelClass.__name__ tableName = modelClass.__name__
self._tables.setdefault(tableName, {}) self._tables.setdefault(tableName, {})