db connection pooling and rag limit transparency

This commit is contained in:
ValueOn AG 2026-05-17 20:38:37 +02:00
parent f5aba4bf99
commit 2bb65c2303
23 changed files with 1519 additions and 782 deletions

11
app.py
View file

@ -438,7 +438,16 @@ async def lifespan(app: FastAPI):
logger.error(f"Feature '{featureName}' failed to stop: {e}")
except Exception as e:
logger.warning(f"Could not shutdown feature containers: {e}")
# --- Close all PostgreSQL connection pools ---
# Must run LAST: feature `onStop` hooks may still issue DB calls during
# shutdown. Once we tear down the pools, no more borrows are possible.
try:
from modules.connectors.connectorDbPostgre import closeAllPools
closeAllPools()
except Exception as e:
logger.warning(f"Closing DB connection pools failed: {e}")
logger.info("Application has been shut down")

File diff suppressed because it is too large Load diff

View file

@ -342,7 +342,7 @@ class RealEstateObjects:
# If no exact match, try case-insensitive search via SQL query
# This handles cases where the name might have different casing
self.db._ensure_connection()
with self.db.connection.cursor() as cursor:
with self.db.borrowCursor() as cursor:
cursor.execute(
'SELECT "id" FROM "Gemeinde" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
(name,)
@ -375,7 +375,7 @@ class RealEstateObjects:
# Try case-insensitive search
self.db._ensure_connection()
with self.db.connection.cursor() as cursor:
with self.db.borrowCursor() as cursor:
cursor.execute(
'SELECT "id" FROM "Kanton" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
(name,)
@ -408,7 +408,7 @@ class RealEstateObjects:
# Try case-insensitive search
self.db._ensure_connection()
with self.db.connection.cursor() as cursor:
with self.db.borrowCursor() as cursor:
cursor.execute(
'SELECT "id" FROM "Land" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
(name,)
@ -840,7 +840,7 @@ class RealEstateObjects:
# Ensure connection is alive
self.db._ensure_connection()
with self.db.connection.cursor() as cursor:
with self.db.borrowCursor() as cursor:
# Execute query
if parameters:
# Use parameterized query for safety

View file

@ -1659,7 +1659,7 @@ class BillingObjects:
try:
appInterface = getAppInterface(self.currentUser)
appInterface.db._ensure_connection()
with appInterface.db.connection.cursor() as cur:
with appInterface.db.borrowCursor() as cur:
if appInterface.db._ensureTableExists(UserInDB):
cur.execute(
'SELECT "id" FROM "UserInDB" WHERE '
@ -1780,7 +1780,7 @@ class BillingObjects:
try:
self.db._ensure_connection()
with self.db.connection.cursor() as cur:
with self.db.borrowCursor() as cur:
countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
cur.execute(countSql, whereValues)
totalItems = cur.fetchone()["count"]
@ -1797,10 +1797,7 @@ class BillingObjects:
except Exception as e:
logger.error(f"_searchTransactionsPaginated SQL error: {e}", exc_info=True)
try:
self.db.connection.rollback()
except Exception:
pass
# Rollback is handled by `borrowCursor()` context manager on exit.
return {"items": [], "totalItems": 0, "totalPages": 0}
def _buildScopeFilter(
@ -1872,7 +1869,7 @@ class BillingObjects:
result: Dict[str, Any] = {}
with self.db.connection.cursor() as cur:
with self.db.borrowCursor() as cur:
# 1) Totals
cur.execute(
f'SELECT COALESCE(SUM("amount"), 0) AS total, COUNT(*) AS cnt FROM "{table}"{whereClause}',
@ -1947,17 +1944,12 @@ class BillingObjects:
})
result["timeSeries"] = timeSeries
self.db.connection.commit()
# Commit/rollback are handled by `borrowCursor()` context manager.
result["_allAccounts"] = allAccounts
return result
except Exception as e:
logger.error(f"Error in getTransactionStatisticsAggregated: {e}", exc_info=True)
try:
self.db.connection.rollback()
except Exception:
pass
return self._emptyStats()
@staticmethod

View file

@ -228,6 +228,22 @@ class KnowledgeObjects:
"""Get all ContentChunks for a file."""
return self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})
def countChunksByFileIds(self, fileIds: List[str]) -> Dict[str, int]:
    """Count ContentChunk rows per file for a batch of file IDs.

    Issues one grouped ``COUNT(*)`` query against ``"ContentChunk"`` rather
    than one query per file, and selects only the ``fileId`` column plus the
    aggregate. Because the query groups by ``fileId``, files without any
    chunk are simply absent from the result; callers should treat a missing
    key as 0.

    Args:
        fileIds: File IDs whose chunks should be counted.

    Returns:
        Mapping of ``fileId`` to chunk count. Empty when ``fileIds`` is
        empty or the ContentChunk table does not exist.
    """
    # Nothing to look up, or the table has not been created -> no counts.
    if not fileIds or not self.db._ensureTableExists(ContentChunk):
        return {}

    query = 'SELECT "fileId", COUNT(*) AS cnt FROM "ContentChunk" WHERE "fileId" = ANY(%s) GROUP BY "fileId"'
    counts: Dict[str, int] = {}
    with self.db.borrowCursor() as cur:
        # The ID list is passed as one parameter and matched via ANY(%s).
        cur.execute(query, (list(fileIds),))
        for record in cur.fetchall():
            counts[record["fileId"]] = int(record["cnt"])
    return counts
def deleteContentChunks(self, fileId: str) -> int:
"""Delete all ContentChunks for a file. Returns count of deleted chunks."""
chunks = self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})

View file

@ -1221,22 +1221,17 @@ class ComponentObjects:
for item in fileRows
]
# Single transaction: delete FileData, FileItem, then FileFolder (children first)
self.db._ensure_connection()
try:
with self.db.connection.cursor() as cursor:
if fileIds:
cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (fileIds,))
cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (fileIds,))
orderedIds = list(folderIds)
orderedIds.remove(folderId)
orderedIds.append(folderId)
if orderedIds:
cursor.execute('DELETE FROM "FileFolder" WHERE "id" = ANY(%s)', (orderedIds,))
self.db.connection.commit()
except Exception:
self.db.connection.rollback()
raise
# Single transaction: delete FileData, FileItem, then FileFolder (children first).
# Commit/rollback are handled by `borrowCursor()` on exit.
with self.db.borrowCursor() as cursor:
if fileIds:
cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (fileIds,))
cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (fileIds,))
orderedIds = list(folderIds)
orderedIds.remove(folderId)
orderedIds.append(folderId)
if orderedIds:
cursor.execute('DELETE FROM "FileFolder" WHERE "id" = ANY(%s)', (orderedIds,))
return {"deletedFolders": len(folderIds), "deletedFiles": len(fileIds)}
@ -1507,7 +1502,7 @@ class ComponentObjects:
try:
self.db._ensure_connection()
with self.db.connection.cursor() as cursor:
with self.db.borrowCursor() as cursor:
cursor.execute(
'SELECT "id", "sysCreatedBy" FROM "FileItem" WHERE "id" = ANY(%s)',
(uniqueIds,),
@ -1526,11 +1521,10 @@ class ComponentObjects:
cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (accessibleIds,))
deletedFiles = cursor.rowcount
self.db.connection.commit()
# Commit/rollback are handled by `borrowCursor()` context manager.
return {"deletedFiles": deletedFiles}
except Exception as e:
logger.error(f"Error deleting files in batch: {e}")
self.db.connection.rollback()
raise FileDeletionError(f"Error deleting files in batch: {str(e)}")
def _ensureFeatureInstanceGroup(self, featureInstanceId: str, contextKey: str = "files/list") -> Optional[str]:

View file

@ -374,7 +374,7 @@ def getRecordsetWithRBAC(
query = f'SELECT * FROM "{table}"{whereClause}{orderByClause}{limitClause}'
with connector.connection.cursor() as cursor:
with connector.borrowCursor() as cursor:
cursor.execute(query, whereValues)
records = [dict(row) for row in cursor.fetchall()]
@ -561,7 +561,7 @@ def getRecordsetPaginatedWithRBAC(
offset = (pagination.page - 1) * pagination.pageSize
limitClause = f" LIMIT {pagination.pageSize} OFFSET {offset}"
with connector.connection.cursor() as cursor:
with connector.borrowCursor() as cursor:
countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
cursor.execute(countSql, countValues)
totalItems = cursor.fetchone()["count"]
@ -709,7 +709,7 @@ def getDistinctColumnValuesWithRBAC(
sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{nonNullWhere} ORDER BY val'
with connector.connection.cursor() as cursor:
with connector.borrowCursor() as cursor:
cursor.execute(sql, whereValues)
result = [row["val"] for row in cursor.fetchall()]
@ -719,7 +719,7 @@ def getDistinctColumnValuesWithRBAC(
emptySql = f'SELECT 1 FROM "{table}"{whereClause} AND {emptyCond} LIMIT 1'
else:
emptySql = f'SELECT 1 FROM "{table}" WHERE {emptyCond} LIMIT 1'
with connector.connection.cursor() as cursor:
with connector.borrowCursor() as cursor:
cursor.execute(emptySql, whereValues)
if cursor.fetchone():
result.append(None)
@ -967,7 +967,7 @@ def buildRbacWhereClause(
# Multi-Tenant Design: Users do NOT have mandateId - they are linked via UserMandate
if table == "UserInDB":
try:
with connector.connection.cursor() as cursor:
with connector.borrowCursor() as cursor:
# Get all user IDs that are members of the current mandate
cursor.execute(
'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true',
@ -994,7 +994,7 @@ def buildRbacWhereClause(
# For UserConnection: Filter via UserMandate junction table
elif table == "UserConnection":
try:
with connector.connection.cursor() as cursor:
with connector.borrowCursor() as cursor:
# Get all user IDs that are members of the current mandate
cursor.execute(
'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true',

View file

@ -305,7 +305,7 @@ def handleIdsMode(
sql = f'SELECT "{idField}"::TEXT AS val FROM "{table}"{where_clause} ORDER BY "{idField}"'
with db.connection.cursor() as cursor:
with db.borrowCursor() as cursor:
cursor.execute(sql, values)
return JSONResponse(content=[row["val"] for row in cursor.fetchall()])
except Exception as e:

View file

@ -25,6 +25,18 @@ router = APIRouter(
def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> List[Dict[str, Any]]:
"""Build per-connection RAG inventory rows.
Each DataSource row exposes BOTH numbers because they mean different things:
* `fileCount` – distinct files indexed (== `FileContentIndex` rows)
* `chunkCount` – embedding-sized text fragments (== `ContentChunk` rows,
max `DEFAULT_CHUNK_TOKENS` tokens each, what the vector retrieval
actually hits)
A single PDF typically yields 1 file × 5–100 chunks; legacy UI labelled
`len(FileContentIndex)` as "chunks" which was off by 1–2 orders of
magnitude and misleading.
"""
from modules.datamodels.datamodelDataSource import DataSource
from modules.datamodels.datamodelKnowledge import FileContentIndex
@ -34,19 +46,35 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
connIndexRows = knowledgeIf.db.getRecordset(FileContentIndex, recordFilter={"connectionId": connectionId})
connChunkTotal = len(connIndexRows)
connFileTotal = len(connIndexRows)
# Map fileId → real chunk count via 1 aggregate query (cheap even for
# connections with thousands of files; we never load the vector body).
fileIds = [
(idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", ""))
for idx in connIndexRows
]
fileIds = [fid for fid in fileIds if fid]
chunkCountByFile = knowledgeIf.countChunksByFileIds(fileIds) if fileIds else {}
connChunkTotal = sum(chunkCountByFile.values())
filesByDs: Dict[str, int] = {}
chunksByDs: Dict[str, int] = {}
unassigned = 0
unassignedFiles = 0
unassignedChunks = 0
for idx in connIndexRows:
fileId = idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", "")
chunkCnt = chunkCountByFile.get(fileId, 0)
struct = (idx.get("structure") if isinstance(idx, dict) else getattr(idx, "structure", None)) or {}
ingestion = struct.get("_ingestion") or {} if isinstance(struct, dict) else {}
prov = ingestion.get("provenance") or {} if isinstance(ingestion, dict) else {}
dsIdRef = prov.get("dataSourceId", "") if isinstance(prov, dict) else ""
if dsIdRef:
chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + 1
filesByDs[dsIdRef] = filesByDs.get(dsIdRef, 0) + 1
chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + chunkCnt
else:
unassigned += 1
unassignedFiles += 1
unassignedChunks += chunkCnt
seen: Dict[str, bool] = {}
dsItems = []
@ -64,14 +92,19 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"ragIndexEnabled": ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False),
"neutralize": ds.get("neutralize") if isinstance(ds, dict) else getattr(ds, "neutralize", False),
"lastIndexed": ds.get("lastIndexed") if isinstance(ds, dict) else getattr(ds, "lastIndexed", None),
"fileCount": filesByDs.get(dsId, 0),
"chunkCount": chunksByDs.get(dsId, 0),
})
if unassigned > 0 and len(dsItems) > 0:
perDs = unassigned // len(dsItems)
remainder = unassigned % len(dsItems)
# Spread orphan files (provenance lost) evenly so totals match.
if unassignedFiles > 0 and len(dsItems) > 0:
perFile = unassignedFiles // len(dsItems)
remFile = unassignedFiles % len(dsItems)
perChunk = unassignedChunks // len(dsItems)
remChunk = unassignedChunks % len(dsItems)
for i, item in enumerate(dsItems):
item["chunkCount"] += perDs + (1 if i < remainder else 0)
item["fileCount"] += perFile + (1 if i < remFile else 0)
item["chunkCount"] += perChunk + (1 if i < remChunk else 0)
# Pull a wider window than the previous 5 so the "last successful
# sync" is found even if a connection has many recent jobs queued.
@ -102,6 +135,12 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"skippedPolicy": result.get("skippedPolicy", 0),
"failed": result.get("failed", 0),
"durationMs": result.get("durationMs", 0),
# Surface limit-stop reason so the UI can warn the user
# that the index is provably incomplete (and which budget
# to raise). None means the walker finished naturally.
"stoppedAtLimit": result.get("stoppedAtLimit"),
"limits": result.get("limits") or {},
"bytesProcessed": result.get("bytesProcessed", 0),
}
if lastError and lastSuccess:
break
@ -113,6 +152,7 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"knowledgeIngestionEnabled": getattr(conn, "knowledgeIngestionEnabled", False),
"preferences": getattr(conn, "knowledgePreferences", None) or {},
"dataSources": dsItems,
"totalFiles": connFileTotal,
"totalChunks": connChunkTotal,
"runningJobs": runningJobs,
"lastError": lastError,
@ -139,8 +179,9 @@ def _getInventoryMe(
items = _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService)
totalChunks = sum(c.get("totalChunks", 0) for c in items)
totalFiles = sum(c.get("totalFiles", 0) for c in items)
return {"connections": items, "totals": {"chunks": totalChunks}}
return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}}
except Exception as e:
logger.error("Error in RAG inventory /me: %s", e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@ -170,9 +211,10 @@ def _getInventoryMandate(
items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
totalChunks = sum(c.get("totalChunks", 0) for c in items)
totalFiles = sum(c.get("totalFiles", 0) for c in items)
totalBytes = aggregateMandateRagTotalBytes(mandateId)
return {"connections": items, "totals": {"chunks": totalChunks, "bytes": totalBytes}}
return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks, "bytes": totalBytes}}
except HTTPException:
raise
except Exception as e:
@ -202,8 +244,9 @@ def _getInventoryPlatform(
items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
totalChunks = sum(c.get("totalChunks", 0) for c in items)
totalFiles = sum(c.get("totalFiles", 0) for c in items)
return {"connections": items, "totals": {"chunks": totalChunks}}
return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}}
except HTTPException:
raise
except Exception as e:

View file

@ -227,7 +227,7 @@ WHERE "workflowId" = ANY(%s)
GROUP BY "workflowId"
"""
out: dict = {}
with db.connection.cursor() as cursor:
with db.borrowCursor() as cursor:
cursor.execute(sql, (workflowIds,))
for row in cursor.fetchall():
r = dict(row)
@ -480,7 +480,7 @@ def _getWorkflowsJoinedPaginated(
dataSql = f"SELECT w.*, rs.\"lastStartedAt\", rs.\"runCount\", rs.\"activeRunId\" FROM {fromSql}{whereClause}{orderClause}{limitClause}"
db._ensure_connection()
with db.connection.cursor() as cursor:
with db.borrowCursor() as cursor:
cursor.execute(countSql, countValues)
totalItems = int(cursor.fetchone()["cnt"])

View file

@ -25,15 +25,14 @@ _CACHE_TTL_SECONDS = 300
def _getOrCreateFeatureDbConnector(featureDbName: str, userId: str):
"""Reuse a pooled DB connector for the given feature database."""
"""Reuse a pooled DB connector for the given feature database.
The underlying psycopg2 connections live in the central pool
(`_PoolRegistry`) and are recreated on demand if they go stale; we just
need to keep the lightweight connector wrapper around.
"""
if featureDbName in _featureDbConnPool:
conn = _featureDbConnPool[featureDbName]
try:
if conn.connection and not conn.connection.closed:
return conn
except Exception as e:
logger.warning(f"Feature DB connection check failed for {featureDbName}: {e}")
_featureDbConnPool.pop(featureDbName, None)
return _featureDbConnPool[featureDbName]
from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG

View file

@ -68,6 +68,9 @@ class ClickupBootstrapResult:
workspaces: int = 0
lists: int = 0
errors: List[str] = field(default_factory=list)
# First budget exhausted: "maxTasks" | "maxWorkspaces" | "maxListsPerWorkspace" | None.
# Drives the same UI banner as the file-walker bootstraps.
stoppedAtLimit: Optional[str] = None
def _syntheticTaskId(connectionId: str, taskId: str) -> str:
@ -225,6 +228,7 @@ async def bootstrapClickup(
cancelled = False
for ds in dataSources:
if result.indexed + result.skippedDuplicate >= limits.maxTasks:
_recordLimitStop(result, "maxTasks", "dataSource", limits)
break
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
cancelled = True
@ -243,8 +247,11 @@ async def bootstrapClickup(
clickupScope=limits.clickupScope,
)
if len(teams) > dsLimits.maxWorkspaces:
_recordLimitStop(result, "maxWorkspaces", "teams", dsLimits, hard=False)
for team in teams[:dsLimits.maxWorkspaces]:
if result.indexed + result.skippedDuplicate >= dsLimits.maxTasks:
_recordLimitStop(result, "maxTasks", f"team={team.get('id','')}", dsLimits)
break
teamId = str(team.get("id", "") or "")
if not teamId:
@ -351,6 +358,7 @@ async def _walkTeam(
for lst in listsCollected:
if result.indexed + result.skippedDuplicate >= limits.maxTasks:
_recordLimitStop(result, "maxTasks", f"team={teamId}", limits)
return
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
return
@ -407,6 +415,7 @@ async def _walkList(
for task in tasks:
if result.indexed + result.skippedDuplicate >= limits.maxTasks:
_recordLimitStop(result, "maxTasks", f"list={listId}", limits)
return
if not _isRecent(task.get("date_updated"), limits.maxAgeDays):
result.skippedPolicy += 1
@ -529,13 +538,37 @@ async def _ingestTask(
)
def _recordLimitStop(
    result: ClickupBootstrapResult,
    limitName: str,
    where: str,
    limits: ClickupBootstrapLimits,
    *,
    hard: bool = True,
) -> None:
    """Note on ``result`` that the ClickUp walker ran into a budget, and log it.

    ``result.stoppedAtLimit`` is set to ``limitName`` when ``hard`` is true,
    or when no limit has been recorded so far. A soft hit (``hard=False``)
    therefore never replaces a limit that is already recorded. A warning is
    logged on every call, whether or not the recorded limit changed.

    Args:
        result: Bootstrap result, mutated in place.
        limitName: Name of the exhausted budget, e.g. ``"maxTasks"``.
        where: Free-form location text; only used in the log message.
        limits: Limits object supplying the budget value shown in the log.
        hard: Whether this hit may overwrite a previously recorded limit.
    """
    nothingRecordedYet = result.stoppedAtLimit is None
    if hard or nothingRecordedYet:
        result.stoppedAtLimit = limitName

    # Names missing from this table are logged with a budget of None.
    configuredBudgets = {
        "maxTasks": limits.maxTasks,
        "maxWorkspaces": limits.maxWorkspaces,
        "maxListsPerWorkspace": limits.maxListsPerWorkspace,
    }
    budget = configuredBudgets.get(limitName)
    logger.warning(
        "clickup walker hit %s=%s at %s — partial index (indexed=%d, skippedDup=%d).",
        limitName,
        budget,
        where,
        result.indexed,
        result.skippedDuplicate,
    )
def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000)
logger.info(
"ingestion.connection.bootstrap.done part=clickup connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d workspaces=%d lists=%d durationMs=%d",
"ingestion.connection.bootstrap.done part=clickup connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d workspaces=%d lists=%d durationMs=%d stoppedAtLimit=%s",
connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy,
result.failed, result.workspaces, result.lists, durationMs,
result.stoppedAtLimit or "none",
extra={
"event": "ingestion.connection.bootstrap.done",
"part": "clickup",
@ -547,6 +580,7 @@ def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs:
"workspaces": result.workspaces,
"lists": result.lists,
"durationMs": durationMs,
"stoppedAtLimit": result.stoppedAtLimit,
},
)
return {
@ -559,4 +593,11 @@ def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs:
"lists": result.lists,
"durationMs": durationMs,
"errors": result.errors[:20],
"stoppedAtLimit": result.stoppedAtLimit,
"limits": {
"maxTasks": MAX_TASKS_DEFAULT,
"maxWorkspaces": MAX_WORKSPACES_DEFAULT,
"maxListsPerWorkspace": MAX_LISTS_PER_WORKSPACE_DEFAULT,
"maxAgeDays": MAX_AGE_DAYS_DEFAULT,
},
}

View file

@ -61,6 +61,8 @@ class GdriveBootstrapResult:
failed: int = 0
bytesProcessed: int = 0
errors: List[str] = field(default_factory=list)
# See SharepointBootstrapResult.stoppedAtLimit — same semantics.
stoppedAtLimit: Optional[str] = None
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
@ -265,8 +267,10 @@ async def _walkFolder(
for entry in entries:
if result.indexed + result.skippedDuplicate >= limits.maxItems:
_recordLimitStop(result, "maxItems", folderPath, limits)
return
if result.bytesProcessed >= limits.maxBytes:
_recordLimitStop(result, "maxBytes", folderPath, limits)
return
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
return
@ -276,6 +280,9 @@ async def _walkFolder(
mimeType = getattr(entry, "mimeType", None) or metadata.get("mimeType")
if getattr(entry, "isFolder", False) or mimeType == FOLDER_MIME:
if depth + 1 > limits.maxDepth:
_recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
continue
await _walkFolder(
adapter=adapter,
knowledgeService=knowledgeService,
@ -298,6 +305,7 @@ async def _walkFolder(
continue
size = int(getattr(entry, "size", 0) or 0)
if size and size > limits.maxFileSize:
_recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
result.skippedPolicy += 1
continue
modifiedTime = metadata.get("modifiedTime")
@ -470,13 +478,38 @@ async def _ingestOne(
await asyncio.sleep(0)
def _recordLimitStop(
    result: GdriveBootstrapResult,
    limitName: str,
    where: str,
    limits: GdriveBootstrapLimits,
    *,
    hard: bool = True,
) -> None:
    """Note on ``result`` that the Google Drive walker ran into a budget, and log it.

    ``result.stoppedAtLimit`` is set to ``limitName`` when ``hard`` is true,
    or when no limit has been recorded so far. A soft hit (``hard=False``)
    therefore never replaces a limit that is already recorded. A warning is
    logged on every call, whether or not the recorded limit changed.

    Args:
        result: Bootstrap result, mutated in place.
        limitName: Name of the exhausted budget, e.g. ``"maxItems"``.
        where: Free-form location text; only used in the log message.
        limits: Limits object supplying the budget value shown in the log.
        hard: Whether this hit may overwrite a previously recorded limit.
    """
    nothingRecordedYet = result.stoppedAtLimit is None
    if hard or nothingRecordedYet:
        result.stoppedAtLimit = limitName

    # Names missing from this table are logged with a budget of None.
    configuredBudgets = {
        "maxItems": limits.maxItems,
        "maxBytes": limits.maxBytes,
        "maxDepth": limits.maxDepth,
        "maxFileSize": limits.maxFileSize,
    }
    budget = configuredBudgets.get(limitName)
    logger.warning(
        "gdrive walker hit %s=%s at %s — partial index (indexed=%d, bytesProcessed=%d).",
        limitName,
        budget,
        where,
        result.indexed,
        result.bytesProcessed,
    )
def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000)
logger.info(
"ingestion.connection.bootstrap.done part=gdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d bytes=%d durationMs=%d",
"ingestion.connection.bootstrap.done part=gdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d bytes=%d durationMs=%d stoppedAtLimit=%s",
connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy,
result.failed, result.bytesProcessed, durationMs,
result.stoppedAtLimit or "none",
extra={
"event": "ingestion.connection.bootstrap.done",
"part": "gdrive",
@ -487,6 +520,7 @@ def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: f
"failed": result.failed,
"bytes": result.bytesProcessed,
"durationMs": durationMs,
"stoppedAtLimit": result.stoppedAtLimit,
},
)
return {
@ -498,4 +532,11 @@ def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: f
"bytesProcessed": result.bytesProcessed,
"durationMs": durationMs,
"errors": result.errors[:20],
"stoppedAtLimit": result.stoppedAtLimit,
"limits": {
"maxItems": MAX_ITEMS_DEFAULT,
"maxBytes": MAX_BYTES_DEFAULT,
"maxFileSize": MAX_FILE_SIZE_DEFAULT,
"maxDepth": MAX_DEPTH_DEFAULT,
},
}

View file

@ -53,6 +53,8 @@ class KdriveBootstrapResult:
failed: int = 0
bytesProcessed: int = 0
errors: List[str] = field(default_factory=list)
# See SharepointBootstrapResult.stoppedAtLimit — same semantics.
stoppedAtLimit: Optional[str] = None
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
@ -232,14 +234,19 @@ async def _walkFolder(
for entry in entries:
if result.indexed + result.skippedDuplicate >= limits.maxItems:
_recordLimitStop(result, "maxItems", folderPath, limits)
return
if result.bytesProcessed >= limits.maxBytes:
_recordLimitStop(result, "maxBytes", folderPath, limits)
return
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
return
entryPath = getattr(entry, "path", "") or ""
if getattr(entry, "isFolder", False):
if depth + 1 > limits.maxDepth:
_recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
continue
await _walkFolder(
adapter=adapter,
knowledgeService=knowledgeService,
@ -262,6 +269,7 @@ async def _walkFolder(
continue
size = int(getattr(entry, "size", 0) or 0)
if size and size > limits.maxFileSize:
_recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
result.skippedPolicy += 1
continue
@ -415,17 +423,42 @@ async def _ingestOne(
await asyncio.sleep(0)
def _recordLimitStop(
    result: KdriveBootstrapResult,
    limitName: str,
    where: str,
    limits: KdriveBootstrapLimits,
    *,
    hard: bool = True,
) -> None:
    """Note on ``result`` that the kDrive walker ran into a budget, and log it.

    ``result.stoppedAtLimit`` is set to ``limitName`` when ``hard`` is true,
    or when no limit has been recorded so far. A soft hit (``hard=False``)
    therefore never replaces a limit that is already recorded. A warning is
    logged on every call, whether or not the recorded limit changed.

    Args:
        result: Bootstrap result, mutated in place.
        limitName: Name of the exhausted budget, e.g. ``"maxItems"``.
        where: Free-form location text; only used in the log message.
        limits: Limits object supplying the budget value shown in the log.
        hard: Whether this hit may overwrite a previously recorded limit.
    """
    nothingRecordedYet = result.stoppedAtLimit is None
    if hard or nothingRecordedYet:
        result.stoppedAtLimit = limitName

    # Names missing from this table are logged with a budget of None.
    configuredBudgets = {
        "maxItems": limits.maxItems,
        "maxBytes": limits.maxBytes,
        "maxDepth": limits.maxDepth,
        "maxFileSize": limits.maxFileSize,
    }
    budget = configuredBudgets.get(limitName)
    logger.warning(
        "kdrive walker hit %s=%s at %s — partial index (indexed=%d, bytesProcessed=%d).",
        limitName,
        budget,
        where,
        result.indexed,
        result.bytesProcessed,
    )
def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000)
logger.info(
"ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d",
"ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d stoppedAtLimit=%s",
connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed,
durationMs,
durationMs, result.stoppedAtLimit or "none",
extra={"event": "ingestion.connection.bootstrap.done", "part": "kdrive",
"connectionId": connectionId, "indexed": result.indexed,
"skippedDup": result.skippedDuplicate, "skippedPolicy": result.skippedPolicy,
"failed": result.failed, "durationMs": durationMs},
"failed": result.failed, "durationMs": durationMs,
"stoppedAtLimit": result.stoppedAtLimit},
)
return {
"connectionId": result.connectionId,
@ -436,4 +469,11 @@ def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: f
"bytesProcessed": result.bytesProcessed,
"durationMs": durationMs,
"errors": result.errors[:20],
"stoppedAtLimit": result.stoppedAtLimit,
"limits": {
"maxItems": MAX_ITEMS_DEFAULT,
"maxBytes": MAX_BYTES_DEFAULT,
"maxFileSize": MAX_FILE_SIZE_DEFAULT,
"maxDepth": MAX_DEPTH_DEFAULT,
},
}

View file

@ -59,6 +59,10 @@ class SharepointBootstrapResult:
failed: int = 0
bytesProcessed: int = 0
errors: List[str] = field(default_factory=list)
# First budget that hit zero; None means the walk completed naturally.
# Surfaces in the bootstrap result so the RAG inventory UI can warn the
# user that the corpus is incomplete and tell them which knob to turn.
stoppedAtLimit: Optional[str] = None # "maxItems" | "maxBytes" | "maxDepth" | "maxFileSize" | None
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
@ -259,14 +263,22 @@ async def _walkFolder(
for entry in entries:
if result.indexed + result.skippedDuplicate >= limits.maxItems:
_recordLimitStop(result, "maxItems", folderPath, limits)
return
if result.bytesProcessed >= limits.maxBytes:
_recordLimitStop(result, "maxBytes", folderPath, limits)
return
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
return
entryPath = getattr(entry, "path", "") or ""
if getattr(entry, "isFolder", False):
if depth + 1 > limits.maxDepth:
# We stop descending here but keep walking siblings.
# Record once per bootstrap so the UI shows "maxDepth" even
# if other budgets aren't exhausted yet.
_recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
continue
await _walkFolder(
adapter=adapter,
knowledgeService=knowledgeService,
@ -289,6 +301,7 @@ async def _walkFolder(
continue
size = int(getattr(entry, "size", 0) or 0)
if size and size > limits.maxFileSize:
_recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
result.skippedPolicy += 1
continue
@ -443,13 +456,44 @@ async def _ingestOne(
await asyncio.sleep(0)
def _recordLimitStop(
    result: SharepointBootstrapResult,
    limitName: str,
    where: str,
    limits: SharepointBootstrapLimits,
    *,
    hard: bool = True,
) -> None:
    """Note on ``result`` that the SharePoint walker ran into a budget, and log it.

    Soft hits (``hard=False``; the walker uses this for per-file
    ``maxFileSize`` and per-folder ``maxDepth``) are recorded only when no
    limit has been recorded yet, so they never replace an earlier reason.
    Hard hits (``hard=True``, the default) always overwrite whatever was
    recorded before — once a hard cap is hit, the corpus is provably
    incomplete. A warning is logged on every call, whether or not the
    recorded limit changed.

    Args:
        result: Bootstrap result, mutated in place.
        limitName: Name of the exhausted budget, e.g. ``"maxItems"``.
        where: Free-form location text; only used in the log message.
        limits: Limits object supplying the budget value shown in the log.
        hard: Whether this hit may overwrite a previously recorded limit.
    """
    nothingRecordedYet = result.stoppedAtLimit is None
    if hard or nothingRecordedYet:
        result.stoppedAtLimit = limitName

    # Names missing from this table are logged with a budget of None.
    configuredBudgets = {
        "maxItems": limits.maxItems,
        "maxBytes": limits.maxBytes,
        "maxDepth": limits.maxDepth,
        "maxFileSize": limits.maxFileSize,
    }
    budget = configuredBudgets.get(limitName)
    logger.warning(
        "sharepoint walker hit %s=%s at %s — partial index "
        "(indexed=%d, bytesProcessed=%d). Raise the limit or split the data source.",
        limitName,
        budget,
        where,
        result.indexed,
        result.bytesProcessed,
    )
def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000)
logger.info(
"ingestion.connection.bootstrap.done part=sharepoint connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d",
"ingestion.connection.bootstrap.done part=sharepoint connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d stoppedAtLimit=%s",
connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed,
durationMs,
durationMs, result.stoppedAtLimit or "none",
extra={
"event": "ingestion.connection.bootstrap.done",
"part": "sharepoint",
@ -459,6 +503,7 @@ def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startM
"skippedPolicy": result.skippedPolicy,
"failed": result.failed,
"durationMs": durationMs,
"stoppedAtLimit": result.stoppedAtLimit,
},
)
return {
@ -470,4 +515,11 @@ def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startM
"bytesProcessed": result.bytesProcessed,
"durationMs": durationMs,
"errors": result.errors[:20],
"stoppedAtLimit": result.stoppedAtLimit,
"limits": {
"maxItems": MAX_ITEMS_DEFAULT,
"maxBytes": MAX_BYTES_DEFAULT,
"maxFileSize": MAX_FILE_SIZE_DEFAULT,
"maxDepth": MAX_DEPTH_DEFAULT,
},
}

View file

@ -12,7 +12,8 @@ import logging
import json
import base64
import time
from typing import Any, Dict, Optional
import threading
from typing import Any, Dict, Optional, Tuple
from pathlib import Path
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
@ -286,6 +287,16 @@ def handleSecretJson(value: str, userId: str = "system", keyName: str = "unknown
# Structure: {user_id: {key_name: [timestamps]}}
_decryption_attempts = {}
# Process-wide plaintext cache for decrypted secrets.
# Key: the encrypted ciphertext (which already includes env prefix).
# Value: (expiresAtMonotonic, plaintext).
# TTL is short enough that key rotation propagates quickly, long enough that
# hot DB-init paths (every API call building a connector) don't blow the
# decryption rate limit. 60s is a deliberate compromise.
_DECRYPTION_CACHE_TTL_S = 60.0
_decryption_cache: Dict[str, Tuple[float, str]] = {}
_decryption_cache_lock = threading.Lock()
def _getMasterKey(envType: str = None) -> bytes:
"""
Get the master key for the specified environment.
@ -486,25 +497,43 @@ def encryptValue(value: str, envType: str = None, userId: str = "system", keyNam
def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "unknown") -> str:
"""
Decrypt a value using the master key for the current environment.
A short-lived plaintext cache (TTL `_DECRYPTION_CACHE_TTL_S`) is consulted
first. The 10/sec rate-limit on cache misses still protects against
brute-force attacks; cache HITS bypass it because they are not actual
cryptographic operations — they just return the result of an earlier
successful decrypt. Without this cache, hot paths like
`mainBackgroundJobService._getDb()` (called per RAG inventory poll AND
per walker DB call) trigger the rate limit and surface as
"Decryption rate limit exceeded for user 'system' key 'DB_PASSWORD_SECRET'"
ERRORs in the RAG inventory UI route.
Args:
encryptedValue: The encrypted value with prefix
userId: The user ID making the request (default: "system")
keyName: The name of the key being decrypted (default: "unknown")
Returns:
str: The decrypted plain text value
Raises:
ValueError: If decryption fails
"""
if not _isEncryptedValue(encryptedValue):
return encryptedValue # Return as-is if not encrypted
# Check rate limiting (10 per second per user per key)
# Cache lookup BEFORE the rate-limit check: a cache hit is not a new
# cryptographic operation and must not be throttled.
now = time.monotonic()
with _decryption_cache_lock:
cached = _decryption_cache.get(encryptedValue)
if cached is not None and cached[0] > now:
return cached[1]
# Cache miss → real decrypt → apply rate limit.
if not _checkDecryptionRateLimit(userId, keyName, maxPerSecond=10):
raise ValueError(f"Decryption rate limit exceeded for user '{userId}' key '{keyName}' (10/sec)")
try:
# Extract environment type from prefix
if encryptedValue.startswith('DEV_ENC:'):
@ -536,7 +565,7 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un
encryptedBytes = base64.urlsafe_b64decode(encryptedPart.encode('utf-8'))
decryptedBytes = fernet.decrypt(encryptedBytes)
decryptedValue = decryptedBytes.decode('utf-8')
# Log audit event for decryption
try:
from modules.shared.auditLogger import audit_logger
@ -549,11 +578,25 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un
except Exception:
# Don't fail if audit logging fails
pass
# Populate cache so subsequent reads of the same ciphertext don't
# re-decrypt (and don't consume rate-limit budget).
with _decryption_cache_lock:
_decryption_cache[encryptedValue] = (
time.monotonic() + _DECRYPTION_CACHE_TTL_S,
decryptedValue,
)
return decryptedValue
except Exception as e:
raise ValueError(f"Decryption failed: {e}")
def clearDecryptionCache() -> None:
    """Empty the process-wide cache of decrypted secrets.

    Intended to be called after a key rotation, or between tests, so that no
    previously decrypted plaintext keeps being served from memory.
    """
    # Take the cache lock so the wipe cannot interleave with a lookup or an
    # insert that runs under the same lock.
    with _decryption_cache_lock:
        _decryption_cache.clear()
# Create the global APP_CONFIG instance
APP_CONFIG = Configuration()

View file

@ -33,20 +33,35 @@ def _ensureUamTablesMatchModels(dbConnector) -> None:
logger.debug(f"_ensureUamTablesMatchModels: {e}")
def _getConnection(dbConnector):
"""Get a connection from the DatabaseConnector.
Ensures the connection is alive and returns it.
Commits any pending transaction first to avoid blocking.
from contextlib import contextmanager
@contextmanager
def _borrowDbConn(dbConnector):
"""Borrow a pooled connection from the DatabaseConnector.
Index/trigger/FK creation traditionally ran with `conn.autocommit = True`
so each CREATE statement is its own transaction (DDL on a managed
connection blocks waiting for COMMIT). This helper preserves that
behaviour on top of the pool: borrow a connection, flip it to autocommit,
yield it, and restore the previous state before returning it to the pool.
"""
dbConnector._ensure_connection()
conn = dbConnector.connection
# Commit any pending transaction to avoid blocking
try:
conn.commit()
except Exception:
pass # Ignore if nothing to commit
return conn
with dbConnector.borrowConn() as conn:
try:
previousAutocommit = conn.autocommit
except Exception:
previousAutocommit = False
try:
conn.autocommit = True
except Exception as e:
logger.debug(f"Could not set autocommit on borrowed connection: {e}")
try:
yield conn
finally:
try:
conn.autocommit = previousAutocommit
except Exception:
pass
# =============================================================================
@ -174,73 +189,42 @@ def applyMultiTenantOptimizations(dbConnector, tables: Optional[List[str]] = Non
}
try:
# Get a connection from the connector
conn = _getConnection(dbConnector)
# Save and set autocommit state
try:
originalAutocommit = conn.autocommit
except Exception:
originalAutocommit = False
try:
conn.autocommit = True
except Exception as autoErr:
logger.debug(f"Could not set autocommit: {autoErr}")
try:
_ensureUamTablesMatchModels(dbConnector)
except Exception as preIdxErr:
logger.debug(f"Pre-index table ensure: {preIdxErr}")
try:
with _borrowDbConn(dbConnector) as conn:
with conn.cursor() as cursor:
# Apply indexes
results["indexesCreated"] = _applyIndexes(cursor, tables)
# Apply foreign keys
results["foreignKeysCreated"] = _applyForeignKeys(cursor, tables)
# Apply immutable triggers
results["triggersCreated"] = _applyImmutableTriggers(cursor, tables)
logger.info(
f"Multi-tenant optimizations applied: "
f"{results['indexesCreated']} indexes, "
f"{results['triggersCreated']} triggers, "
f"{results['foreignKeysCreated']} foreign keys"
)
finally:
# Restore original autocommit state
try:
conn.autocommit = originalAutocommit
except Exception:
pass
logger.info(
f"Multi-tenant optimizations applied: "
f"{results['indexesCreated']} indexes, "
f"{results['triggersCreated']} triggers, "
f"{results['foreignKeysCreated']} foreign keys"
)
except Exception as e:
logger.error(f"Error applying multi-tenant optimizations: {type(e).__name__}: {e}")
results["errors"].append(str(e))
return results
def applyIndexesOnly(dbConnector, tables: Optional[List[str]] = None) -> int:
"""Apply only indexes (lighter operation, safe for frequent calls)."""
try:
conn = _getConnection(dbConnector)
originalAutocommit = conn.autocommit
conn.autocommit = True
try:
_ensureUamTablesMatchModels(dbConnector)
except Exception as preIdxErr:
logger.debug(f"Pre-index table ensure: {preIdxErr}")
try:
with _borrowDbConn(dbConnector) as conn:
with conn.cursor() as cursor:
return _applyIndexes(cursor, tables)
finally:
conn.autocommit = originalAutocommit
except Exception as e:
logger.error(f"Error applying indexes: {e}")
return 0
@ -514,8 +498,7 @@ def getOptimizationStatus(dbConnector) -> dict:
}
try:
conn = _getConnection(dbConnector)
with conn.cursor() as cursor:
with _borrowDbConn(dbConnector) as conn, conn.cursor() as cursor:
# Check regular indexes
for tableName, indexName, _ in _INDEXES:
if _tableExists(cursor, tableName):

View file

@ -60,11 +60,9 @@ def _getTableColumns(dbConnector, tableName: str) -> List[str]:
ORDER BY ordinal_position
"""
cursor = dbConnector.connection.cursor()
cursor.execute(query, (tableName,))
columns = [row[0] for row in cursor.fetchall()]
cursor.close()
with dbConnector.borrowCursor() as cursor:
cursor.execute(query, (tableName,))
columns = [row[0] for row in cursor.fetchall()]
return columns
except Exception as e:
logger.error(f"Error getting columns for table {tableName}: {e}")
@ -92,29 +90,26 @@ def _getAllTables(dbConnector) -> List[str]:
ORDER BY table_name
"""
cursor = dbConnector.connection.cursor()
cursor.execute(query)
allTables = [row[0] for row in cursor.fetchall()]
# Get foreign key relationships to determine dependency order
fkQuery = """
SELECT
tc.table_name,
ccu.table_name AS foreign_table_name
FROM information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY'
AND tc.table_schema = 'public'
"""
cursor.execute(fkQuery)
foreignKeys = cursor.fetchall()
cursor.close()
with dbConnector.borrowCursor() as cursor:
cursor.execute(query)
allTables = [row[0] for row in cursor.fetchall()]
fkQuery = """
SELECT
tc.table_name,
ccu.table_name AS foreign_table_name
FROM information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY'
AND tc.table_schema = 'public'
"""
cursor.execute(fkQuery)
foreignKeys = cursor.fetchall()
# Build dependency graph (child -> parent mapping)
dependencies = {}
@ -154,10 +149,9 @@ def _getAllTables(dbConnector) -> List[str]:
# Fallback: return simple list without ordering
try:
query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE'"
cursor = dbConnector.connection.cursor()
cursor.execute(query)
tables = [row[0] for row in cursor.fetchall()]
cursor.close()
with dbConnector.borrowCursor() as cursor:
cursor.execute(query)
tables = [row[0] for row in cursor.fetchall()]
return [t for t in tables if t not in PROTECTED_TABLES]
except Exception:
return []
@ -184,11 +178,9 @@ def _getPrimaryKeyColumns(dbConnector, tableName: str) -> List[str]:
AND i.indisprimary
"""
cursor = dbConnector.connection.cursor()
cursor.execute(query, (tableName,))
pkColumns = [row[0] for row in cursor.fetchall()]
cursor.close()
with dbConnector.borrowCursor() as cursor:
cursor.execute(query, (tableName,))
pkColumns = [row[0] for row in cursor.fetchall()]
return pkColumns
except Exception as e:
logger.debug(f"Could not get primary key for {tableName}: {e}")
@ -229,21 +221,15 @@ def _findUserReferencesInTable(
return {}
references = {}
cursor = dbConnector.connection.cursor()
for userColumn in userColumns:
# Build SELECT for primary key columns
pkSelect = ", ".join([f'"{pk}"' for pk in pkColumns])
query = f'SELECT {pkSelect} FROM "{tableName}" WHERE "{userColumn}" = %s'
cursor.execute(query, (userId,))
recordKeys = cursor.fetchall()
if recordKeys:
references[userColumn] = recordKeys
logger.debug(f"Found {len(recordKeys)} records in {tableName}.{userColumn} for user {userId}")
cursor.close()
with dbConnector.borrowCursor() as cursor:
for userColumn in userColumns:
pkSelect = ", ".join([f'"{pk}"' for pk in pkColumns])
query = f'SELECT {pkSelect} FROM "{tableName}" WHERE "{userColumn}" = %s'
cursor.execute(query, (userId,))
recordKeys = cursor.fetchall()
if recordKeys:
references[userColumn] = recordKeys
logger.debug(f"Found {len(recordKeys)} records in {tableName}.{userColumn} for user {userId}")
return references
except Exception as e:
@ -277,42 +263,35 @@ def _anonymizeRecords(
return 0
try:
cursor = dbConnector.connection.cursor()
# Resolve column metadata once outside the borrow block (it borrows its
# own connection internally).
columns = _getTableColumns(dbConnector, tableName)
hasModifiedAt = "sysModifiedAt" in columns
count = 0
for recordKey in recordKeys:
# Build WHERE clause for primary key
whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
# Check if table has sysModifiedAt column
columns = _getTableColumns(dbConnector, tableName)
hasModifiedAt = "sysModifiedAt" in columns
if hasModifiedAt:
query = f'UPDATE "{tableName}" SET "{columnName}" = %s, "sysModifiedAt" = %s WHERE {whereClause}'
params = [anonymousValue, getUtcTimestamp()]
else:
query = f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE {whereClause}'
params = [anonymousValue]
# Add primary key values to params
if isinstance(recordKey, tuple):
params.extend(recordKey)
else:
params.append(recordKey)
cursor.execute(query, params)
count += cursor.rowcount
dbConnector.connection.commit()
cursor.close()
with dbConnector.borrowCursor() as cursor:
for recordKey in recordKeys:
whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
if hasModifiedAt:
query = f'UPDATE "{tableName}" SET "{columnName}" = %s, "sysModifiedAt" = %s WHERE {whereClause}'
params = [anonymousValue, getUtcTimestamp()]
else:
query = f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE {whereClause}'
params = [anonymousValue]
if isinstance(recordKey, tuple):
params.extend(recordKey)
else:
params.append(recordKey)
cursor.execute(query, params)
count += cursor.rowcount
logger.info(f"Anonymized {count} records in {tableName}.{columnName}")
return count
except Exception as e:
logger.error(f"Error anonymizing records in {tableName}.{columnName}: {e}")
dbConnector.connection.rollback()
return 0
@ -338,32 +317,23 @@ def _deleteRecords(
return 0
try:
cursor = dbConnector.connection.cursor()
count = 0
for recordKey in recordKeys:
# Build WHERE clause for primary key
whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
query = f'DELETE FROM "{tableName}" WHERE {whereClause}'
# Prepare params
if isinstance(recordKey, tuple):
params = list(recordKey)
else:
params = [recordKey]
cursor.execute(query, params)
count += cursor.rowcount
dbConnector.connection.commit()
cursor.close()
with dbConnector.borrowCursor() as cursor:
for recordKey in recordKeys:
whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
query = f'DELETE FROM "{tableName}" WHERE {whereClause}'
if isinstance(recordKey, tuple):
params = list(recordKey)
else:
params = [recordKey]
cursor.execute(query, params)
count += cursor.rowcount
logger.info(f"Deleted {count} records from {tableName}")
return count
except Exception as e:
logger.error(f"Error deleting records from {tableName}: {e}")
dbConnector.connection.rollback()
return 0

View file

@ -25,7 +25,7 @@ if not c or not c.connection:
print("STAGE0: DB_CONNECTION=none (check config.ini / .env)")
raise SystemExit(2)
cur = c.connection.cursor()
cur = c.borrowCursor()
def _scalar(cur):

View file

@ -12,11 +12,16 @@ broken query into "no rows found". That hid bugs like:
These tests pin the new contract: empty result sets still return ``[]`` /
``None`` (normal), but any exception inside the query path propagates as
``DatabaseQueryError`` with the table name attached. The transaction is
rolled back so the connection is usable for subsequent queries.
``DatabaseQueryError`` with the table name attached.
Since the 2026-05-17 pool refactor (`c-work/2-build/2026-05-postgres-connection-pool.md`)
the connector borrows a connection from `_PoolRegistry` on every call via the
`borrowConn()` context manager. The tests mock that context manager so the
fast-fail contract is exercised without requiring a live Postgres server.
"""
from __future__ import annotations
from contextlib import contextmanager
from unittest.mock import MagicMock
import pytest
@ -25,7 +30,6 @@ import psycopg2.errors
from modules.connectors.connectorDbPostgre import (
DatabaseConnector,
DatabaseQueryError,
_rollbackQuietly,
)
@ -39,26 +43,44 @@ class DummyTable:
def _makeConnector(cursorBehavior):
"""Build a ``DatabaseConnector`` skeleton with mocked connection/cursor.
"""Build a ``DatabaseConnector`` skeleton with a mocked pool borrow.
``cursorBehavior`` is a callable invoked with the cursor mock so the test
can configure ``execute``/``fetchall``/``fetchone`` per scenario.
Returns ``(connector, conn, cursor)``:
* ``conn`` exposes ``commit`` / ``rollback`` MagicMocks so tests can
assert that the borrow lifecycle did the right thing.
* ``cursor`` is the per-test cursor mock.
"""
connector = DatabaseConnector.__new__(DatabaseConnector)
cursor = MagicMock()
cursorBehavior(cursor)
cursorContext = MagicMock()
cursorContext.__enter__ = MagicMock(return_value=cursor)
cursorContext.__exit__ = MagicMock(return_value=False)
connection = MagicMock()
connection.cursor.return_value = cursorContext
connector.connection = connection
conn = MagicMock()
conn.cursor.return_value = cursorContext
@contextmanager
def fakeBorrow():
try:
yield conn
except Exception:
conn.rollback()
raise
else:
conn.commit()
connector.borrowConn = fakeBorrow
connector._ensureTableExists = MagicMock(return_value=True)
connector._systemTableName = "_system"
cursorBehavior(cursor)
return connector, connection, cursor
return connector, conn, cursor
class TestGetRecordsetFailLoud:
@ -67,11 +89,12 @@ class TestGetRecordsetFailLoud:
def behavior(cursor):
cursor.execute.return_value = None
cursor.fetchall.return_value = []
connector, connection, _ = _makeConnector(behavior)
connector, conn, _ = _makeConnector(behavior)
result = connector.getRecordset(DummyTable)
assert result == []
connection.rollback.assert_not_called()
conn.rollback.assert_not_called()
conn.commit.assert_called_once()
def test_dictAdaptErrorRaisesDatabaseQueryError(self):
"""Reproduces the Trustee bug: passing a dict in WHERE → can't adapt → raise."""
@ -79,7 +102,7 @@ class TestGetRecordsetFailLoud:
cursor.execute.side_effect = psycopg2.ProgrammingError(
"can't adapt type 'dict'"
)
connector, connection, _ = _makeConnector(behavior)
connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError) as excinfo:
connector.getRecordset(
@ -90,30 +113,30 @@ class TestGetRecordsetFailLoud:
assert excinfo.value.table == "DummyTable"
assert "can't adapt type 'dict'" in str(excinfo.value)
assert isinstance(excinfo.value.original, psycopg2.ProgrammingError)
connection.rollback.assert_called_once()
conn.rollback.assert_called_once()
def test_missingColumnRaisesDatabaseQueryError(self):
def behavior(cursor):
cursor.execute.side_effect = psycopg2.errors.UndefinedColumn(
'column "wat" does not exist'
)
connector, connection, _ = _makeConnector(behavior)
connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError) as excinfo:
connector.getRecordset(DummyTable, recordFilter={"wat": "x"})
assert "wat" in str(excinfo.value)
connection.rollback.assert_called_once()
conn.rollback.assert_called_once()
def test_operationalErrorRaisesDatabaseQueryError(self):
"""Connection lost mid-query is also a real failure that must propagate."""
def behavior(cursor):
cursor.execute.side_effect = psycopg2.OperationalError("connection lost")
connector, connection, _ = _makeConnector(behavior)
connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError):
connector.getRecordset(DummyTable)
connection.rollback.assert_called_once()
conn.rollback.assert_called_once()
class TestGetRecordFailLoud:
@ -122,37 +145,22 @@ class TestGetRecordFailLoud:
def behavior(cursor):
cursor.execute.return_value = None
cursor.fetchone.return_value = None
connector, connection, _ = _makeConnector(behavior)
connector, conn, _ = _makeConnector(behavior)
result = connector.getRecord(DummyTable, "missing-id")
assert result is None
connection.rollback.assert_not_called()
conn.rollback.assert_not_called()
conn.commit.assert_called_once()
def test_queryErrorRaisesDatabaseQueryError(self):
def behavior(cursor):
cursor.execute.side_effect = psycopg2.errors.UndefinedTable(
'relation "DummyTable" does not exist'
)
connector, connection, _ = _makeConnector(behavior)
connector, conn, _ = _makeConnector(behavior)
with pytest.raises(DatabaseQueryError) as excinfo:
connector.getRecord(DummyTable, "any-id")
assert excinfo.value.table == "DummyTable"
connection.rollback.assert_called_once()
class TestRollbackQuietly:
def test_rollsBackOnLiveConnection(self):
connection = MagicMock()
_rollbackQuietly(connection)
connection.rollback.assert_called_once()
def test_swallowsRollbackError(self):
"""Rollback failure must not mask the original query error."""
connection = MagicMock()
connection.rollback.side_effect = RuntimeError("rollback failed")
_rollbackQuietly(connection)
def test_noopOnNoneConnection(self):
_rollbackQuietly(None)
conn.rollback.assert_called_once()

View file

@ -0,0 +1,304 @@
# Copyright (c) 2026 Patrick Motsch
# All rights reserved.
"""Concurrency tests for the PostgreSQL connection pool.
These tests pin the contract that the `c-work/2-build/2026-05-postgres-connection-pool.md`
refactor delivered:
* T1 — 50 threads × 20 reads in parallel produce 0 errors (no
  `OperationalError`s), and 20 threads × 50 reads (the pool's design
  capacity) complete within a generous latency budget (p99 < 5 s).
* T2 — Two threads `_loadRecord` + `_saveRecord` against the same connector
  do not corrupt each other's cursors.
* T3 — `statement_timeout` aborts a runaway `pg_sleep` (the test lowers the
  timeout to 500 ms via `SET LOCAL` to stay fast) and releases the
  connection back into the pool cleanly.
The tests need a real PostgreSQL server because the bug they guard against
only materialises with real psycopg2 sockets — a mocked connection never
hangs in `recv()`. They read DB credentials from `APP_CONFIG` (which loads
`.env`) and are auto-skipped when the connection fails (no local Postgres,
wrong creds, etc.) so `pytest` keeps working in CI-only environments.
To run them locally:
pytest gateway/tests/unit/connectors/test_connectorDbPostgre_pool.py -v
They use a throwaway database name (`poweron_pool_test_<uuid>`) and drop it
in fixture teardown so they leave nothing behind.
"""
from __future__ import annotations
import time
import uuid
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
import psycopg2
import psycopg2.errors
import pytest
from pydantic import Field
from modules.connectors.connectorDbPostgre import (
DatabaseConnector,
_PoolRegistry,
closeAllPools,
)
from modules.datamodels.datamodelBase import PowerOnModel
from modules.shared.configuration import APP_CONFIG
def _dbConfig():
    """Read the PostgreSQL connection parameters from ``APP_CONFIG``.

    Returns:
        dict | None: ``{"host", "user", "password", "port"}`` when host, user
        and password are all present and the port converts to an integer;
        ``None`` otherwise, so the test module can skip cleanly instead of
        blowing up at import time.
    """
    host = APP_CONFIG.get("DB_HOST")
    user = APP_CONFIG.get("DB_USER")
    password = APP_CONFIG.get("DB_PASSWORD_SECRET")
    # The password is tested with `is None` (not falsiness), so an
    # empty-string password still counts as configured.
    if not host or not user or password is None:
        return None
    try:
        port = int(APP_CONFIG.get("DB_PORT", 5432))
    except (TypeError, ValueError):
        # A malformed DB_PORT must not crash test collection: this function
        # runs at import time, so treat it as "not configured" and let the
        # module-level skip take over.
        return None
    return {"host": host, "user": user, "password": password, "port": port}
def _canReachPostgres(cfg) -> bool:
    """Report whether the admin database (``postgres``) accepts a connection.

    Used only to decide whether the live-Postgres tests should be skipped, so
    every failure (bad config, unreachable host, rejected credentials) is
    reported as ``False`` rather than raised.
    """
    try:
        probe = psycopg2.connect(
            host=cfg["host"],
            port=cfg["port"],
            database="postgres",
            user=cfg["user"],
            password=cfg["password"],
            connect_timeout=2,  # seconds; keeps test collection fast when the host is down
        )
        probe.close()
    except Exception:  # noqa: BLE001
        return False
    return True
# Resolved once at import time; None means the DB settings are incomplete.
_DB_CFG = _dbConfig()
# Module-level pytest mark: skip every test in this file when no PostgreSQL
# server is usable. The `or` short-circuits, so the connection probe is only
# attempted when a configuration was found.
pytestmark = pytest.mark.skipif(
    _DB_CFG is None or not _canReachPostgres(_DB_CFG),
    reason="No reachable PostgreSQL — skipping live-Postgres pool tests",
)
class PoolTestRow(PowerOnModel):
    """Tiny row model used to exercise the pool.

    Adds a single string ``payload`` field to whatever ``PowerOnModel``
    provides; the tests address rows by their ``id``.
    """

    payload: str = Field(
        default="",
        description="Test payload",
    )
def _dropPoolTestDatabase(cfg, dbName, *, terminateBackends=False):
    """Drop the throwaway database ``dbName`` through a short-lived admin connection.

    Args:
        cfg: Connection parameters with ``host``, ``port``, ``user`` and
            ``password`` keys.
        dbName: Name of the database to drop (no error if it does not exist).
        terminateBackends: When true, first terminate every backend still
            connected to ``dbName`` so the drop is not blocked by them.
    """
    adminConn = psycopg2.connect(
        host=cfg["host"], port=cfg["port"], database="postgres",
        user=cfg["user"], password=cfg["password"],
    )
    # DROP DATABASE cannot run inside a transaction block.
    adminConn.autocommit = True
    try:
        with adminConn.cursor() as cur:
            if terminateBackends:
                cur.execute(
                    'SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = %s',
                    (dbName,),
                )
            cur.execute(f'DROP DATABASE IF EXISTS "{dbName}"')
    finally:
        adminConn.close()


@pytest.fixture
def liveConnector():
    """Yield a ``DatabaseConnector`` bound to a throwaway database.

    The database name is unique per test (``poweron_pool_test_<uuid>``) and the
    database is dropped afterwards. The pool registry is wiped before and after
    each test so state from one test cannot mask a bug in another.

    Cleanup runs in a ``finally`` block, so it also happens when building or
    seeding the connector fails before the test body ever starts.
    """
    cfg = _DB_CFG
    dbName = f"poweron_pool_test_{uuid.uuid4().hex[:8]}"

    # Pre-clean: drop any orphan test DB with the same name (shouldn't happen
    # because the name carries a unique uuid, but be defensive).
    _dropPoolTestDatabase(cfg, dbName)
    closeAllPools()

    try:
        connector = DatabaseConnector(
            dbHost=cfg["host"],
            dbDatabase=dbName,
            dbUser=cfg["user"],
            dbPassword=cfg["password"],
            dbPort=cfg["port"],
        )
        # Seed exactly one row so every concurrent read has a stable target.
        connector.recordCreate(PoolTestRow, {"id": "seed", "payload": "hello"})
        yield connector
    finally:
        # Teardown: tear pools down first, then drop the DB.
        closeAllPools()
        _dropPoolTestDatabase(cfg, dbName, terminateBackends=True)
class TestPoolConcurrency:
    """Live-Postgres concurrency checks for the pooled ``DatabaseConnector``."""

    def _runWorkers(self, liveConnector, *, threadCount: int, callsPerThread: int):
        """Run ``threadCount`` worker threads, each issuing ``callsPerThread`` reads.

        Returns:
            tuple: ``(errors, latencies)`` — every exception raised by a read
            (a failed seed-row assertion included) and the per-call durations
            in seconds, sorted ascending.
        """
        errors: list = []
        latencies: list = []
        lock = threading.Lock()

        def worker():
            for _ in range(callsPerThread):
                t0 = time.perf_counter()
                try:
                    rows = liveConnector.getRecordset(PoolTestRow)
                    assert any(r["id"] == "seed" for r in rows)
                except Exception as e:  # noqa: BLE001 — we want every failure mode
                    with lock:
                        errors.append(e)
                finally:
                    # Recorded for failed calls too, so the latency list always
                    # has threadCount * callsPerThread entries.
                    with lock:
                        latencies.append(time.perf_counter() - t0)

        with ThreadPoolExecutor(max_workers=threadCount) as ex:
            futures = [ex.submit(worker) for _ in range(threadCount)]
            for f in as_completed(futures):
                f.result()

        latencies.sort()
        return errors, latencies

    def test_50_threads_x_20_reads_no_errors(self, liveConnector):
        """T1a — STRESS: 50 threads × 20 reads each → 0 errors.

        Pre-pool, this scenario produced either
        `OperationalError: another command is already in progress` or a
        deadlock in `recv()` because the threadpool shared one psycopg2
        socket. With the pool plus `borrowConn`'s bounded wait, every
        thread eventually gets a connection and completes — even with 30
        threads queued waiting at any moment (pool max=20).
        """
        errors, _ = self._runWorkers(liveConnector, threadCount=50, callsPerThread=20)
        assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"

    def test_20_threads_x_50_reads_latency_budget(self, liveConnector):
        """T1b — DESIGN CAPACITY: 20 threads × 50 reads, p99 < 5 s.

        20 threads matches the pool's `max=20` so there is no queueing —
        every borrow returns immediately. This pins a sanity-level per-call
        latency budget; pre-pool it was unbounded (recv() never returned).

        The 5 s ceiling is generous on purpose: `getRecordset` calls
        `_ensureTableExists` which runs two `information_schema` queries
        for column-additive migration, and under 20-way concurrency on a
        single Postgres instance that produces a long tail. The hard
        assertion is `not errors` — the latency check just guarantees
        nothing hangs indefinitely.
        """
        errors, latencies = self._runWorkers(
            liveConnector, threadCount=20, callsPerThread=50
        )
        assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"
        # Latencies come back sorted, so this index is the 99th percentile.
        p99 = latencies[int(len(latencies) * 0.99)]
        assert p99 < 5.0, f"p99 latency {p99:.2f}s exceeds 5s budget"

    def test_interleaved_load_and_save_no_collision(self, liveConnector):
        """T2: parallel reads + writes on the same connector → no cursor mix-up.

        Pre-pool the reader could observe a row in mid-write or vice versa
        because both shared the same cursor. With one connection per borrow,
        the database's own row-locking is the only contention, so the test
        asserts that no call raised and that every worker actually stopped.
        """
        stopFlag = threading.Event()
        errors: list = []
        lock = threading.Lock()

        def reader():
            while not stopFlag.is_set():
                try:
                    liveConnector.getRecord(PoolTestRow, "seed")
                except Exception as e:  # noqa: BLE001
                    with lock:
                        errors.append(("read", e))

        def writer():
            i = 0
            while not stopFlag.is_set():
                try:
                    liveConnector.recordModify(
                        PoolTestRow,
                        "seed",
                        {"id": "seed", "payload": f"v{i}"},
                    )
                    i += 1
                except Exception as e:  # noqa: BLE001
                    with lock:
                        errors.append(("write", e))

        threads = [
            threading.Thread(target=reader, daemon=True),
            threading.Thread(target=reader, daemon=True),
            threading.Thread(target=writer, daemon=True),
            threading.Thread(target=writer, daemon=True),
        ]
        for t in threads:
            t.start()
        time.sleep(2.0)
        stopFlag.set()
        for t in threads:
            t.join(timeout=3.0)

        # join(timeout=...) returns silently when the timeout expires, so a
        # worker wedged inside a DB call would otherwise go unnoticed: it
        # raises nothing and leaves `errors` empty.
        stuck = [t.name for t in threads if t.is_alive()]
        assert not stuck, f"worker threads still running after stop: {stuck}"
        assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"

    def test_statement_timeout_releases_connection(self, liveConnector):
        """T3: `pg_sleep` past statement_timeout → QueryCanceled, pool intact.

        The bug we are guarding against: a runaway query with no timeout
        hung `recv()` forever, the psycopg2 connection was poisoned, and the
        whole backend became unresponsive once that connection was reused.

        With `statement_timeout=30000` configured at pool construction the
        query is cancelled by the server and the connection returns to the
        pool — proven by the fact that a follow-up call still succeeds quickly.
        """
        # Use a short timeout to keep the test fast — override the pool's
        # session statement_timeout for one borrow via SET LOCAL.
        # NOTE(review): the QueryCanceled is caught inside the borrow, so
        # borrowConn exits on its success path with the transaction aborted —
        # confirm it handles that state as intended.
        with liveConnector.borrowConn() as conn:
            with conn.cursor() as cursor:
                cursor.execute("SET LOCAL statement_timeout = 500")
                with pytest.raises(psycopg2.errors.QueryCanceled):
                    cursor.execute("SELECT pg_sleep(5)")

        # Follow-up call must succeed quickly: connection is back in the pool.
        t0 = time.perf_counter()
        rows = liveConnector.getRecordset(PoolTestRow)
        elapsed = time.perf_counter() - t0
        assert any(r["id"] == "seed" for r in rows)
        assert elapsed < 1.0, f"follow-up call took {elapsed:.2f}s — pool may be wedged"
class TestPoolRegistry:
    """Checks on `_PoolRegistry` bookkeeping: pool sharing and teardown."""

    def test_one_pool_per_database_identity(self, liveConnector):
        """Asking for a pool twice with identical parameters yields the same object."""
        cfg = _DB_CFG
        poolArgs = dict(
            dbHost=cfg["host"],
            dbDatabase=liveConnector.dbDatabase,
            dbUser=cfg["user"],
            dbPassword=cfg["password"],
            dbPort=cfg["port"],
        )
        first = _PoolRegistry.getPool(**poolArgs)
        second = _PoolRegistry.getPool(**poolArgs)
        assert first is second

    def test_close_all_clears_registry(self, liveConnector):
        """`closeAllPools()` leaves the registry empty, so the next call must rebuild."""
        # Issue one real query so a pool is present in the registry.
        liveConnector.getRecordset(PoolTestRow)
        assert _PoolRegistry._pools, "pool should exist after a real call"

        closeAllPools()
        assert _PoolRegistry._pools == {}, "registry should be empty after closeAllPools()"

View file

@ -68,6 +68,16 @@ class _FakeDb:
def _ensureTableExists(self, modelClass):
return True
def borrowCursor(self):
    """Stand-in for `DatabaseConnector.borrowCursor()`: a context manager yielding a mock cursor."""
    from contextlib import contextmanager
    from unittest.mock import MagicMock

    @contextmanager
    def _fakeCursorScope():
        # A MagicMock accepts any attribute access or call, so code under
        # test can issue cursor calls without a database.
        fakeCursor = MagicMock()
        yield fakeCursor

    return _fakeCursorScope()
def seed(self, modelClass, record: Dict[str, Any]):
tableName = modelClass.__name__
self._tables.setdefault(tableName, {})

View file

@ -69,6 +69,16 @@ class _FakeDb:
def _ensureTableExists(self, modelClass):
return True
def borrowCursor(self):
    """Stand-in for `DatabaseConnector.borrowCursor()` in the cascade test: yields a mock cursor."""
    from contextlib import contextmanager
    from unittest.mock import MagicMock

    @contextmanager
    def _fakeCursorScope():
        # A MagicMock accepts any attribute access or call, so code under
        # test can issue cursor calls without a database.
        fakeCursor = MagicMock()
        yield fakeCursor

    return _fakeCursorScope()
def seed(self, modelClass, record: Dict[str, Any]):
tableName = modelClass.__name__
self._tables.setdefault(tableName, {})