db connection pooling and rag limit transparency
This commit is contained in:
parent
f5aba4bf99
commit
2bb65c2303
23 changed files with 1519 additions and 782 deletions
11
app.py
11
app.py
|
|
@ -438,7 +438,16 @@ async def lifespan(app: FastAPI):
|
|||
logger.error(f"Feature '{featureName}' failed to stop: {e}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not shutdown feature containers: {e}")
|
||||
|
||||
|
||||
# --- Close all PostgreSQL connection pools ---
|
||||
# Must run LAST: feature `onStop` hooks may still issue DB calls during
|
||||
# shutdown. Once we tear down the pools, no more borrows are possible.
|
||||
try:
|
||||
from modules.connectors.connectorDbPostgre import closeAllPools
|
||||
closeAllPools()
|
||||
except Exception as e:
|
||||
logger.warning(f"Closing DB connection pools failed: {e}")
|
||||
|
||||
logger.info("Application has been shut down")
|
||||
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -342,7 +342,7 @@ class RealEstateObjects:
|
|||
# If no exact match, try case-insensitive search via SQL query
|
||||
# This handles cases where the name might have different casing
|
||||
self.db._ensure_connection()
|
||||
with self.db.connection.cursor() as cursor:
|
||||
with self.db.borrowCursor() as cursor:
|
||||
cursor.execute(
|
||||
'SELECT "id" FROM "Gemeinde" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
|
||||
(name,)
|
||||
|
|
@ -375,7 +375,7 @@ class RealEstateObjects:
|
|||
|
||||
# Try case-insensitive search
|
||||
self.db._ensure_connection()
|
||||
with self.db.connection.cursor() as cursor:
|
||||
with self.db.borrowCursor() as cursor:
|
||||
cursor.execute(
|
||||
'SELECT "id" FROM "Kanton" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
|
||||
(name,)
|
||||
|
|
@ -408,7 +408,7 @@ class RealEstateObjects:
|
|||
|
||||
# Try case-insensitive search
|
||||
self.db._ensure_connection()
|
||||
with self.db.connection.cursor() as cursor:
|
||||
with self.db.borrowCursor() as cursor:
|
||||
cursor.execute(
|
||||
'SELECT "id" FROM "Land" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
|
||||
(name,)
|
||||
|
|
@ -840,7 +840,7 @@ class RealEstateObjects:
|
|||
# Ensure connection is alive
|
||||
self.db._ensure_connection()
|
||||
|
||||
with self.db.connection.cursor() as cursor:
|
||||
with self.db.borrowCursor() as cursor:
|
||||
# Execute query
|
||||
if parameters:
|
||||
# Use parameterized query for safety
|
||||
|
|
|
|||
|
|
@ -1659,7 +1659,7 @@ class BillingObjects:
|
|||
try:
|
||||
appInterface = getAppInterface(self.currentUser)
|
||||
appInterface.db._ensure_connection()
|
||||
with appInterface.db.connection.cursor() as cur:
|
||||
with appInterface.db.borrowCursor() as cur:
|
||||
if appInterface.db._ensureTableExists(UserInDB):
|
||||
cur.execute(
|
||||
'SELECT "id" FROM "UserInDB" WHERE '
|
||||
|
|
@ -1780,7 +1780,7 @@ class BillingObjects:
|
|||
|
||||
try:
|
||||
self.db._ensure_connection()
|
||||
with self.db.connection.cursor() as cur:
|
||||
with self.db.borrowCursor() as cur:
|
||||
countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
|
||||
cur.execute(countSql, whereValues)
|
||||
totalItems = cur.fetchone()["count"]
|
||||
|
|
@ -1797,10 +1797,7 @@ class BillingObjects:
|
|||
|
||||
except Exception as e:
|
||||
logger.error(f"_searchTransactionsPaginated SQL error: {e}", exc_info=True)
|
||||
try:
|
||||
self.db.connection.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
# Rollback is handled by `borrowCursor()` context manager on exit.
|
||||
return {"items": [], "totalItems": 0, "totalPages": 0}
|
||||
|
||||
def _buildScopeFilter(
|
||||
|
|
@ -1872,7 +1869,7 @@ class BillingObjects:
|
|||
|
||||
result: Dict[str, Any] = {}
|
||||
|
||||
with self.db.connection.cursor() as cur:
|
||||
with self.db.borrowCursor() as cur:
|
||||
# 1) Totals
|
||||
cur.execute(
|
||||
f'SELECT COALESCE(SUM("amount"), 0) AS total, COUNT(*) AS cnt FROM "{table}"{whereClause}',
|
||||
|
|
@ -1947,17 +1944,12 @@ class BillingObjects:
|
|||
})
|
||||
result["timeSeries"] = timeSeries
|
||||
|
||||
self.db.connection.commit()
|
||||
|
||||
# Commit/rollback are handled by `borrowCursor()` context manager.
|
||||
result["_allAccounts"] = allAccounts
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in getTransactionStatisticsAggregated: {e}", exc_info=True)
|
||||
try:
|
||||
self.db.connection.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
return self._emptyStats()
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -228,6 +228,22 @@ class KnowledgeObjects:
|
|||
"""Get all ContentChunks for a file."""
|
||||
return self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})
|
||||
|
||||
def countChunksByFileIds(self, fileIds: List[str]) -> Dict[str, int]:
|
||||
"""Return a {fileId: chunkCount} mapping for the given file IDs.
|
||||
|
||||
One aggregate query instead of N round trips. Used by RAG inventory
|
||||
to display real chunk counts per DataSource without loading the
|
||||
embedding vectors. Missing file IDs map to 0 in the caller's logic.
|
||||
"""
|
||||
if not fileIds:
|
||||
return {}
|
||||
if not self.db._ensureTableExists(ContentChunk):
|
||||
return {}
|
||||
sql = 'SELECT "fileId", COUNT(*) AS cnt FROM "ContentChunk" WHERE "fileId" = ANY(%s) GROUP BY "fileId"'
|
||||
with self.db.borrowCursor() as cursor:
|
||||
cursor.execute(sql, (list(fileIds),))
|
||||
return {row["fileId"]: int(row["cnt"]) for row in cursor.fetchall()}
|
||||
|
||||
def deleteContentChunks(self, fileId: str) -> int:
|
||||
"""Delete all ContentChunks for a file. Returns count of deleted chunks."""
|
||||
chunks = self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})
|
||||
|
|
|
|||
|
|
@ -1221,22 +1221,17 @@ class ComponentObjects:
|
|||
for item in fileRows
|
||||
]
|
||||
|
||||
# Single transaction: delete FileData, FileItem, then FileFolder (children first)
|
||||
self.db._ensure_connection()
|
||||
try:
|
||||
with self.db.connection.cursor() as cursor:
|
||||
if fileIds:
|
||||
cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (fileIds,))
|
||||
cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (fileIds,))
|
||||
orderedIds = list(folderIds)
|
||||
orderedIds.remove(folderId)
|
||||
orderedIds.append(folderId)
|
||||
if orderedIds:
|
||||
cursor.execute('DELETE FROM "FileFolder" WHERE "id" = ANY(%s)', (orderedIds,))
|
||||
self.db.connection.commit()
|
||||
except Exception:
|
||||
self.db.connection.rollback()
|
||||
raise
|
||||
# Single transaction: delete FileData, FileItem, then FileFolder (children first).
|
||||
# Commit/rollback are handled by `borrowCursor()` on exit.
|
||||
with self.db.borrowCursor() as cursor:
|
||||
if fileIds:
|
||||
cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (fileIds,))
|
||||
cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (fileIds,))
|
||||
orderedIds = list(folderIds)
|
||||
orderedIds.remove(folderId)
|
||||
orderedIds.append(folderId)
|
||||
if orderedIds:
|
||||
cursor.execute('DELETE FROM "FileFolder" WHERE "id" = ANY(%s)', (orderedIds,))
|
||||
|
||||
return {"deletedFolders": len(folderIds), "deletedFiles": len(fileIds)}
|
||||
|
||||
|
|
@ -1507,7 +1502,7 @@ class ComponentObjects:
|
|||
|
||||
try:
|
||||
self.db._ensure_connection()
|
||||
with self.db.connection.cursor() as cursor:
|
||||
with self.db.borrowCursor() as cursor:
|
||||
cursor.execute(
|
||||
'SELECT "id", "sysCreatedBy" FROM "FileItem" WHERE "id" = ANY(%s)',
|
||||
(uniqueIds,),
|
||||
|
|
@ -1526,11 +1521,10 @@ class ComponentObjects:
|
|||
cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (accessibleIds,))
|
||||
deletedFiles = cursor.rowcount
|
||||
|
||||
self.db.connection.commit()
|
||||
# Commit/rollback are handled by `borrowCursor()` context manager.
|
||||
return {"deletedFiles": deletedFiles}
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting files in batch: {e}")
|
||||
self.db.connection.rollback()
|
||||
raise FileDeletionError(f"Error deleting files in batch: {str(e)}")
|
||||
|
||||
def _ensureFeatureInstanceGroup(self, featureInstanceId: str, contextKey: str = "files/list") -> Optional[str]:
|
||||
|
|
|
|||
|
|
@ -374,7 +374,7 @@ def getRecordsetWithRBAC(
|
|||
|
||||
query = f'SELECT * FROM "{table}"{whereClause}{orderByClause}{limitClause}'
|
||||
|
||||
with connector.connection.cursor() as cursor:
|
||||
with connector.borrowCursor() as cursor:
|
||||
cursor.execute(query, whereValues)
|
||||
records = [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
|
|
@ -561,7 +561,7 @@ def getRecordsetPaginatedWithRBAC(
|
|||
offset = (pagination.page - 1) * pagination.pageSize
|
||||
limitClause = f" LIMIT {pagination.pageSize} OFFSET {offset}"
|
||||
|
||||
with connector.connection.cursor() as cursor:
|
||||
with connector.borrowCursor() as cursor:
|
||||
countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
|
||||
cursor.execute(countSql, countValues)
|
||||
totalItems = cursor.fetchone()["count"]
|
||||
|
|
@ -709,7 +709,7 @@ def getDistinctColumnValuesWithRBAC(
|
|||
|
||||
sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{nonNullWhere} ORDER BY val'
|
||||
|
||||
with connector.connection.cursor() as cursor:
|
||||
with connector.borrowCursor() as cursor:
|
||||
cursor.execute(sql, whereValues)
|
||||
result = [row["val"] for row in cursor.fetchall()]
|
||||
|
||||
|
|
@ -719,7 +719,7 @@ def getDistinctColumnValuesWithRBAC(
|
|||
emptySql = f'SELECT 1 FROM "{table}"{whereClause} AND {emptyCond} LIMIT 1'
|
||||
else:
|
||||
emptySql = f'SELECT 1 FROM "{table}" WHERE {emptyCond} LIMIT 1'
|
||||
with connector.connection.cursor() as cursor:
|
||||
with connector.borrowCursor() as cursor:
|
||||
cursor.execute(emptySql, whereValues)
|
||||
if cursor.fetchone():
|
||||
result.append(None)
|
||||
|
|
@ -967,7 +967,7 @@ def buildRbacWhereClause(
|
|||
# Multi-Tenant Design: Users do NOT have mandateId - they are linked via UserMandate
|
||||
if table == "UserInDB":
|
||||
try:
|
||||
with connector.connection.cursor() as cursor:
|
||||
with connector.borrowCursor() as cursor:
|
||||
# Get all user IDs that are members of the current mandate
|
||||
cursor.execute(
|
||||
'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true',
|
||||
|
|
@ -994,7 +994,7 @@ def buildRbacWhereClause(
|
|||
# For UserConnection: Filter via UserMandate junction table
|
||||
elif table == "UserConnection":
|
||||
try:
|
||||
with connector.connection.cursor() as cursor:
|
||||
with connector.borrowCursor() as cursor:
|
||||
# Get all user IDs that are members of the current mandate
|
||||
cursor.execute(
|
||||
'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true',
|
||||
|
|
|
|||
|
|
@ -305,7 +305,7 @@ def handleIdsMode(
|
|||
|
||||
sql = f'SELECT "{idField}"::TEXT AS val FROM "{table}"{where_clause} ORDER BY "{idField}"'
|
||||
|
||||
with db.connection.cursor() as cursor:
|
||||
with db.borrowCursor() as cursor:
|
||||
cursor.execute(sql, values)
|
||||
return JSONResponse(content=[row["val"] for row in cursor.fetchall()])
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -25,6 +25,18 @@ router = APIRouter(
|
|||
|
||||
|
||||
def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> List[Dict[str, Any]]:
|
||||
"""Build per-connection RAG inventory rows.
|
||||
|
||||
Each DataSource row exposes BOTH numbers because they mean different things:
|
||||
* `fileCount` — distinct files indexed (== `FileContentIndex` rows)
|
||||
* `chunkCount` — embedding-sized text fragments (== `ContentChunk` rows,
|
||||
max `DEFAULT_CHUNK_TOKENS` tokens each, what the vector retrieval
|
||||
actually hits)
|
||||
|
||||
A single PDF typically yields 1 file × 5–100 chunks; legacy UI labelled
|
||||
`len(FileContentIndex)` as "chunks" which was off by 1–2 orders of
|
||||
magnitude and misleading.
|
||||
"""
|
||||
from modules.datamodels.datamodelDataSource import DataSource
|
||||
from modules.datamodels.datamodelKnowledge import FileContentIndex
|
||||
|
||||
|
|
@ -34,19 +46,35 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
|
|||
dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
|
||||
|
||||
connIndexRows = knowledgeIf.db.getRecordset(FileContentIndex, recordFilter={"connectionId": connectionId})
|
||||
connChunkTotal = len(connIndexRows)
|
||||
connFileTotal = len(connIndexRows)
|
||||
|
||||
# Map fileId → real chunk count via 1 aggregate query (cheap even for
|
||||
# connections with thousands of files; we never load the vector body).
|
||||
fileIds = [
|
||||
(idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", ""))
|
||||
for idx in connIndexRows
|
||||
]
|
||||
fileIds = [fid for fid in fileIds if fid]
|
||||
chunkCountByFile = knowledgeIf.countChunksByFileIds(fileIds) if fileIds else {}
|
||||
connChunkTotal = sum(chunkCountByFile.values())
|
||||
|
||||
filesByDs: Dict[str, int] = {}
|
||||
chunksByDs: Dict[str, int] = {}
|
||||
unassigned = 0
|
||||
unassignedFiles = 0
|
||||
unassignedChunks = 0
|
||||
for idx in connIndexRows:
|
||||
fileId = idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", "")
|
||||
chunkCnt = chunkCountByFile.get(fileId, 0)
|
||||
struct = (idx.get("structure") if isinstance(idx, dict) else getattr(idx, "structure", None)) or {}
|
||||
ingestion = struct.get("_ingestion") or {} if isinstance(struct, dict) else {}
|
||||
prov = ingestion.get("provenance") or {} if isinstance(ingestion, dict) else {}
|
||||
dsIdRef = prov.get("dataSourceId", "") if isinstance(prov, dict) else ""
|
||||
if dsIdRef:
|
||||
chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + 1
|
||||
filesByDs[dsIdRef] = filesByDs.get(dsIdRef, 0) + 1
|
||||
chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + chunkCnt
|
||||
else:
|
||||
unassigned += 1
|
||||
unassignedFiles += 1
|
||||
unassignedChunks += chunkCnt
|
||||
|
||||
seen: Dict[str, bool] = {}
|
||||
dsItems = []
|
||||
|
|
@ -64,14 +92,19 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
|
|||
"ragIndexEnabled": ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False),
|
||||
"neutralize": ds.get("neutralize") if isinstance(ds, dict) else getattr(ds, "neutralize", False),
|
||||
"lastIndexed": ds.get("lastIndexed") if isinstance(ds, dict) else getattr(ds, "lastIndexed", None),
|
||||
"fileCount": filesByDs.get(dsId, 0),
|
||||
"chunkCount": chunksByDs.get(dsId, 0),
|
||||
})
|
||||
|
||||
if unassigned > 0 and len(dsItems) > 0:
|
||||
perDs = unassigned // len(dsItems)
|
||||
remainder = unassigned % len(dsItems)
|
||||
# Spread orphan files (provenance lost) evenly so totals match.
|
||||
if unassignedFiles > 0 and len(dsItems) > 0:
|
||||
perFile = unassignedFiles // len(dsItems)
|
||||
remFile = unassignedFiles % len(dsItems)
|
||||
perChunk = unassignedChunks // len(dsItems)
|
||||
remChunk = unassignedChunks % len(dsItems)
|
||||
for i, item in enumerate(dsItems):
|
||||
item["chunkCount"] += perDs + (1 if i < remainder else 0)
|
||||
item["fileCount"] += perFile + (1 if i < remFile else 0)
|
||||
item["chunkCount"] += perChunk + (1 if i < remChunk else 0)
|
||||
|
||||
# Pull a wider window than the previous 5 so the "last successful
|
||||
# sync" is found even if a connection has many recent jobs queued.
|
||||
|
|
@ -102,6 +135,12 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
|
|||
"skippedPolicy": result.get("skippedPolicy", 0),
|
||||
"failed": result.get("failed", 0),
|
||||
"durationMs": result.get("durationMs", 0),
|
||||
# Surface limit-stop reason so the UI can warn the user
|
||||
# that the index is provably incomplete (and which budget
|
||||
# to raise). None means the walker finished naturally.
|
||||
"stoppedAtLimit": result.get("stoppedAtLimit"),
|
||||
"limits": result.get("limits") or {},
|
||||
"bytesProcessed": result.get("bytesProcessed", 0),
|
||||
}
|
||||
if lastError and lastSuccess:
|
||||
break
|
||||
|
|
@ -113,6 +152,7 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
|
|||
"knowledgeIngestionEnabled": getattr(conn, "knowledgeIngestionEnabled", False),
|
||||
"preferences": getattr(conn, "knowledgePreferences", None) or {},
|
||||
"dataSources": dsItems,
|
||||
"totalFiles": connFileTotal,
|
||||
"totalChunks": connChunkTotal,
|
||||
"runningJobs": runningJobs,
|
||||
"lastError": lastError,
|
||||
|
|
@ -139,8 +179,9 @@ def _getInventoryMe(
|
|||
|
||||
items = _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService)
|
||||
totalChunks = sum(c.get("totalChunks", 0) for c in items)
|
||||
totalFiles = sum(c.get("totalFiles", 0) for c in items)
|
||||
|
||||
return {"connections": items, "totals": {"chunks": totalChunks}}
|
||||
return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}}
|
||||
except Exception as e:
|
||||
logger.error("Error in RAG inventory /me: %s", e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
|
@ -170,9 +211,10 @@ def _getInventoryMandate(
|
|||
|
||||
items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
|
||||
totalChunks = sum(c.get("totalChunks", 0) for c in items)
|
||||
totalFiles = sum(c.get("totalFiles", 0) for c in items)
|
||||
totalBytes = aggregateMandateRagTotalBytes(mandateId)
|
||||
|
||||
return {"connections": items, "totals": {"chunks": totalChunks, "bytes": totalBytes}}
|
||||
return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks, "bytes": totalBytes}}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
|
@ -202,8 +244,9 @@ def _getInventoryPlatform(
|
|||
|
||||
items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
|
||||
totalChunks = sum(c.get("totalChunks", 0) for c in items)
|
||||
totalFiles = sum(c.get("totalFiles", 0) for c in items)
|
||||
|
||||
return {"connections": items, "totals": {"chunks": totalChunks}}
|
||||
return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -227,7 +227,7 @@ WHERE "workflowId" = ANY(%s)
|
|||
GROUP BY "workflowId"
|
||||
"""
|
||||
out: dict = {}
|
||||
with db.connection.cursor() as cursor:
|
||||
with db.borrowCursor() as cursor:
|
||||
cursor.execute(sql, (workflowIds,))
|
||||
for row in cursor.fetchall():
|
||||
r = dict(row)
|
||||
|
|
@ -480,7 +480,7 @@ def _getWorkflowsJoinedPaginated(
|
|||
dataSql = f"SELECT w.*, rs.\"lastStartedAt\", rs.\"runCount\", rs.\"activeRunId\" FROM {fromSql}{whereClause}{orderClause}{limitClause}"
|
||||
|
||||
db._ensure_connection()
|
||||
with db.connection.cursor() as cursor:
|
||||
with db.borrowCursor() as cursor:
|
||||
cursor.execute(countSql, countValues)
|
||||
totalItems = int(cursor.fetchone()["cnt"])
|
||||
|
||||
|
|
|
|||
|
|
@ -25,15 +25,14 @@ _CACHE_TTL_SECONDS = 300
|
|||
|
||||
|
||||
def _getOrCreateFeatureDbConnector(featureDbName: str, userId: str):
|
||||
"""Reuse a pooled DB connector for the given feature database."""
|
||||
"""Reuse a pooled DB connector for the given feature database.
|
||||
|
||||
The underlying psycopg2 connections live in the central pool
|
||||
(`_PoolRegistry`) and are recreated on demand if they go stale; we just
|
||||
need to keep the lightweight connector wrapper around.
|
||||
"""
|
||||
if featureDbName in _featureDbConnPool:
|
||||
conn = _featureDbConnPool[featureDbName]
|
||||
try:
|
||||
if conn.connection and not conn.connection.closed:
|
||||
return conn
|
||||
except Exception as e:
|
||||
logger.warning(f"Feature DB connection check failed for {featureDbName}: {e}")
|
||||
_featureDbConnPool.pop(featureDbName, None)
|
||||
return _featureDbConnPool[featureDbName]
|
||||
|
||||
from modules.connectors.connectorDbPostgre import DatabaseConnector
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
|
|
|||
|
|
@ -68,6 +68,9 @@ class ClickupBootstrapResult:
|
|||
workspaces: int = 0
|
||||
lists: int = 0
|
||||
errors: List[str] = field(default_factory=list)
|
||||
# First budget exhausted: "maxTasks" | "maxWorkspaces" | "maxListsPerWorkspace" | None.
|
||||
# Drives the same UI banner as the file-walker bootstraps.
|
||||
stoppedAtLimit: Optional[str] = None
|
||||
|
||||
|
||||
def _syntheticTaskId(connectionId: str, taskId: str) -> str:
|
||||
|
|
@ -225,6 +228,7 @@ async def bootstrapClickup(
|
|||
cancelled = False
|
||||
for ds in dataSources:
|
||||
if result.indexed + result.skippedDuplicate >= limits.maxTasks:
|
||||
_recordLimitStop(result, "maxTasks", "dataSource", limits)
|
||||
break
|
||||
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
|
||||
cancelled = True
|
||||
|
|
@ -243,8 +247,11 @@ async def bootstrapClickup(
|
|||
clickupScope=limits.clickupScope,
|
||||
)
|
||||
|
||||
if len(teams) > dsLimits.maxWorkspaces:
|
||||
_recordLimitStop(result, "maxWorkspaces", "teams", dsLimits, hard=False)
|
||||
for team in teams[:dsLimits.maxWorkspaces]:
|
||||
if result.indexed + result.skippedDuplicate >= dsLimits.maxTasks:
|
||||
_recordLimitStop(result, "maxTasks", f"team={team.get('id','')}", dsLimits)
|
||||
break
|
||||
teamId = str(team.get("id", "") or "")
|
||||
if not teamId:
|
||||
|
|
@ -351,6 +358,7 @@ async def _walkTeam(
|
|||
|
||||
for lst in listsCollected:
|
||||
if result.indexed + result.skippedDuplicate >= limits.maxTasks:
|
||||
_recordLimitStop(result, "maxTasks", f"team={teamId}", limits)
|
||||
return
|
||||
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
|
||||
return
|
||||
|
|
@ -407,6 +415,7 @@ async def _walkList(
|
|||
|
||||
for task in tasks:
|
||||
if result.indexed + result.skippedDuplicate >= limits.maxTasks:
|
||||
_recordLimitStop(result, "maxTasks", f"list={listId}", limits)
|
||||
return
|
||||
if not _isRecent(task.get("date_updated"), limits.maxAgeDays):
|
||||
result.skippedPolicy += 1
|
||||
|
|
@ -529,13 +538,37 @@ async def _ingestTask(
|
|||
)
|
||||
|
||||
|
||||
def _recordLimitStop(
|
||||
result: ClickupBootstrapResult,
|
||||
limitName: str,
|
||||
where: str,
|
||||
limits: ClickupBootstrapLimits,
|
||||
*,
|
||||
hard: bool = True,
|
||||
) -> None:
|
||||
"""See subConnectorSyncSharepoint._recordLimitStop for semantics."""
|
||||
if hard or result.stoppedAtLimit is None:
|
||||
result.stoppedAtLimit = limitName
|
||||
budgetMap = {
|
||||
"maxTasks": limits.maxTasks,
|
||||
"maxWorkspaces": limits.maxWorkspaces,
|
||||
"maxListsPerWorkspace": limits.maxListsPerWorkspace,
|
||||
}
|
||||
logger.warning(
|
||||
"clickup walker hit %s=%s at %s — partial index (indexed=%d, skippedDup=%d).",
|
||||
limitName, budgetMap.get(limitName), where,
|
||||
result.indexed, result.skippedDuplicate,
|
||||
)
|
||||
|
||||
|
||||
def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs: float) -> Dict[str, Any]:
|
||||
durationMs = int((time.time() - startMs) * 1000)
|
||||
logger.info(
|
||||
"ingestion.connection.bootstrap.done part=clickup connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d workspaces=%d lists=%d durationMs=%d",
|
||||
"ingestion.connection.bootstrap.done part=clickup connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d workspaces=%d lists=%d durationMs=%d stoppedAtLimit=%s",
|
||||
connectionId,
|
||||
result.indexed, result.skippedDuplicate, result.skippedPolicy,
|
||||
result.failed, result.workspaces, result.lists, durationMs,
|
||||
result.stoppedAtLimit or "none",
|
||||
extra={
|
||||
"event": "ingestion.connection.bootstrap.done",
|
||||
"part": "clickup",
|
||||
|
|
@ -547,6 +580,7 @@ def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs:
|
|||
"workspaces": result.workspaces,
|
||||
"lists": result.lists,
|
||||
"durationMs": durationMs,
|
||||
"stoppedAtLimit": result.stoppedAtLimit,
|
||||
},
|
||||
)
|
||||
return {
|
||||
|
|
@ -559,4 +593,11 @@ def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs:
|
|||
"lists": result.lists,
|
||||
"durationMs": durationMs,
|
||||
"errors": result.errors[:20],
|
||||
"stoppedAtLimit": result.stoppedAtLimit,
|
||||
"limits": {
|
||||
"maxTasks": MAX_TASKS_DEFAULT,
|
||||
"maxWorkspaces": MAX_WORKSPACES_DEFAULT,
|
||||
"maxListsPerWorkspace": MAX_LISTS_PER_WORKSPACE_DEFAULT,
|
||||
"maxAgeDays": MAX_AGE_DAYS_DEFAULT,
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -61,6 +61,8 @@ class GdriveBootstrapResult:
|
|||
failed: int = 0
|
||||
bytesProcessed: int = 0
|
||||
errors: List[str] = field(default_factory=list)
|
||||
# See SharepointBootstrapResult.stoppedAtLimit — same semantics.
|
||||
stoppedAtLimit: Optional[str] = None
|
||||
|
||||
|
||||
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
|
||||
|
|
@ -265,8 +267,10 @@ async def _walkFolder(
|
|||
|
||||
for entry in entries:
|
||||
if result.indexed + result.skippedDuplicate >= limits.maxItems:
|
||||
_recordLimitStop(result, "maxItems", folderPath, limits)
|
||||
return
|
||||
if result.bytesProcessed >= limits.maxBytes:
|
||||
_recordLimitStop(result, "maxBytes", folderPath, limits)
|
||||
return
|
||||
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
|
||||
return
|
||||
|
|
@ -276,6 +280,9 @@ async def _walkFolder(
|
|||
mimeType = getattr(entry, "mimeType", None) or metadata.get("mimeType")
|
||||
|
||||
if getattr(entry, "isFolder", False) or mimeType == FOLDER_MIME:
|
||||
if depth + 1 > limits.maxDepth:
|
||||
_recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
|
||||
continue
|
||||
await _walkFolder(
|
||||
adapter=adapter,
|
||||
knowledgeService=knowledgeService,
|
||||
|
|
@ -298,6 +305,7 @@ async def _walkFolder(
|
|||
continue
|
||||
size = int(getattr(entry, "size", 0) or 0)
|
||||
if size and size > limits.maxFileSize:
|
||||
_recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
|
||||
result.skippedPolicy += 1
|
||||
continue
|
||||
modifiedTime = metadata.get("modifiedTime")
|
||||
|
|
@ -470,13 +478,38 @@ async def _ingestOne(
|
|||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
def _recordLimitStop(
|
||||
result: GdriveBootstrapResult,
|
||||
limitName: str,
|
||||
where: str,
|
||||
limits: GdriveBootstrapLimits,
|
||||
*,
|
||||
hard: bool = True,
|
||||
) -> None:
|
||||
"""See subConnectorSyncSharepoint._recordLimitStop for semantics."""
|
||||
if hard or result.stoppedAtLimit is None:
|
||||
result.stoppedAtLimit = limitName
|
||||
budgetMap = {
|
||||
"maxItems": limits.maxItems,
|
||||
"maxBytes": limits.maxBytes,
|
||||
"maxDepth": limits.maxDepth,
|
||||
"maxFileSize": limits.maxFileSize,
|
||||
}
|
||||
logger.warning(
|
||||
"gdrive walker hit %s=%s at %s — partial index (indexed=%d, bytesProcessed=%d).",
|
||||
limitName, budgetMap.get(limitName), where,
|
||||
result.indexed, result.bytesProcessed,
|
||||
)
|
||||
|
||||
|
||||
def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: float) -> Dict[str, Any]:
|
||||
durationMs = int((time.time() - startMs) * 1000)
|
||||
logger.info(
|
||||
"ingestion.connection.bootstrap.done part=gdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d bytes=%d durationMs=%d",
|
||||
"ingestion.connection.bootstrap.done part=gdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d bytes=%d durationMs=%d stoppedAtLimit=%s",
|
||||
connectionId,
|
||||
result.indexed, result.skippedDuplicate, result.skippedPolicy,
|
||||
result.failed, result.bytesProcessed, durationMs,
|
||||
result.stoppedAtLimit or "none",
|
||||
extra={
|
||||
"event": "ingestion.connection.bootstrap.done",
|
||||
"part": "gdrive",
|
||||
|
|
@ -487,6 +520,7 @@ def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: f
|
|||
"failed": result.failed,
|
||||
"bytes": result.bytesProcessed,
|
||||
"durationMs": durationMs,
|
||||
"stoppedAtLimit": result.stoppedAtLimit,
|
||||
},
|
||||
)
|
||||
return {
|
||||
|
|
@ -498,4 +532,11 @@ def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: f
|
|||
"bytesProcessed": result.bytesProcessed,
|
||||
"durationMs": durationMs,
|
||||
"errors": result.errors[:20],
|
||||
"stoppedAtLimit": result.stoppedAtLimit,
|
||||
"limits": {
|
||||
"maxItems": MAX_ITEMS_DEFAULT,
|
||||
"maxBytes": MAX_BYTES_DEFAULT,
|
||||
"maxFileSize": MAX_FILE_SIZE_DEFAULT,
|
||||
"maxDepth": MAX_DEPTH_DEFAULT,
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -53,6 +53,8 @@ class KdriveBootstrapResult:
|
|||
failed: int = 0
|
||||
bytesProcessed: int = 0
|
||||
errors: List[str] = field(default_factory=list)
|
||||
# See SharepointBootstrapResult.stoppedAtLimit — same semantics.
|
||||
stoppedAtLimit: Optional[str] = None
|
||||
|
||||
|
||||
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
|
||||
|
|
@ -232,14 +234,19 @@ async def _walkFolder(
|
|||
|
||||
for entry in entries:
|
||||
if result.indexed + result.skippedDuplicate >= limits.maxItems:
|
||||
_recordLimitStop(result, "maxItems", folderPath, limits)
|
||||
return
|
||||
if result.bytesProcessed >= limits.maxBytes:
|
||||
_recordLimitStop(result, "maxBytes", folderPath, limits)
|
||||
return
|
||||
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
|
||||
return
|
||||
|
||||
entryPath = getattr(entry, "path", "") or ""
|
||||
if getattr(entry, "isFolder", False):
|
||||
if depth + 1 > limits.maxDepth:
|
||||
_recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
|
||||
continue
|
||||
await _walkFolder(
|
||||
adapter=adapter,
|
||||
knowledgeService=knowledgeService,
|
||||
|
|
@ -262,6 +269,7 @@ async def _walkFolder(
|
|||
continue
|
||||
size = int(getattr(entry, "size", 0) or 0)
|
||||
if size and size > limits.maxFileSize:
|
||||
_recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
|
||||
result.skippedPolicy += 1
|
||||
continue
|
||||
|
||||
|
|
@ -415,17 +423,42 @@ async def _ingestOne(
|
|||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
def _recordLimitStop(
|
||||
result: KdriveBootstrapResult,
|
||||
limitName: str,
|
||||
where: str,
|
||||
limits: KdriveBootstrapLimits,
|
||||
*,
|
||||
hard: bool = True,
|
||||
) -> None:
|
||||
"""See subConnectorSyncSharepoint._recordLimitStop for semantics."""
|
||||
if hard or result.stoppedAtLimit is None:
|
||||
result.stoppedAtLimit = limitName
|
||||
budgetMap = {
|
||||
"maxItems": limits.maxItems,
|
||||
"maxBytes": limits.maxBytes,
|
||||
"maxDepth": limits.maxDepth,
|
||||
"maxFileSize": limits.maxFileSize,
|
||||
}
|
||||
logger.warning(
|
||||
"kdrive walker hit %s=%s at %s — partial index (indexed=%d, bytesProcessed=%d).",
|
||||
limitName, budgetMap.get(limitName), where,
|
||||
result.indexed, result.bytesProcessed,
|
||||
)
|
||||
|
||||
|
||||
def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: float) -> Dict[str, Any]:
|
||||
durationMs = int((time.time() - startMs) * 1000)
|
||||
logger.info(
|
||||
"ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d",
|
||||
"ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d stoppedAtLimit=%s",
|
||||
connectionId,
|
||||
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed,
|
||||
durationMs,
|
||||
durationMs, result.stoppedAtLimit or "none",
|
||||
extra={"event": "ingestion.connection.bootstrap.done", "part": "kdrive",
|
||||
"connectionId": connectionId, "indexed": result.indexed,
|
||||
"skippedDup": result.skippedDuplicate, "skippedPolicy": result.skippedPolicy,
|
||||
"failed": result.failed, "durationMs": durationMs},
|
||||
"failed": result.failed, "durationMs": durationMs,
|
||||
"stoppedAtLimit": result.stoppedAtLimit},
|
||||
)
|
||||
return {
|
||||
"connectionId": result.connectionId,
|
||||
|
|
@ -436,4 +469,11 @@ def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: f
|
|||
"bytesProcessed": result.bytesProcessed,
|
||||
"durationMs": durationMs,
|
||||
"errors": result.errors[:20],
|
||||
"stoppedAtLimit": result.stoppedAtLimit,
|
||||
"limits": {
|
||||
"maxItems": MAX_ITEMS_DEFAULT,
|
||||
"maxBytes": MAX_BYTES_DEFAULT,
|
||||
"maxFileSize": MAX_FILE_SIZE_DEFAULT,
|
||||
"maxDepth": MAX_DEPTH_DEFAULT,
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -59,6 +59,10 @@ class SharepointBootstrapResult:
|
|||
failed: int = 0
|
||||
bytesProcessed: int = 0
|
||||
errors: List[str] = field(default_factory=list)
|
||||
# First budget that hit zero; None means the walk completed naturally.
|
||||
# Surfaces in the bootstrap result so the RAG inventory UI can warn the
|
||||
# user that the corpus is incomplete and tell them which knob to turn.
|
||||
stoppedAtLimit: Optional[str] = None # "maxItems" | "maxBytes" | "maxDepth" | "maxFileSize" | None
|
||||
|
||||
|
||||
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
|
||||
|
|
@ -259,14 +263,22 @@ async def _walkFolder(
|
|||
|
||||
for entry in entries:
|
||||
if result.indexed + result.skippedDuplicate >= limits.maxItems:
|
||||
_recordLimitStop(result, "maxItems", folderPath, limits)
|
||||
return
|
||||
if result.bytesProcessed >= limits.maxBytes:
|
||||
_recordLimitStop(result, "maxBytes", folderPath, limits)
|
||||
return
|
||||
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
|
||||
return
|
||||
|
||||
entryPath = getattr(entry, "path", "") or ""
|
||||
if getattr(entry, "isFolder", False):
|
||||
if depth + 1 > limits.maxDepth:
|
||||
# We stop descending here but keep walking siblings.
|
||||
# Record once per bootstrap so the UI shows "maxDepth" even
|
||||
# if other budgets aren't exhausted yet.
|
||||
_recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
|
||||
continue
|
||||
await _walkFolder(
|
||||
adapter=adapter,
|
||||
knowledgeService=knowledgeService,
|
||||
|
|
@ -289,6 +301,7 @@ async def _walkFolder(
|
|||
continue
|
||||
size = int(getattr(entry, "size", 0) or 0)
|
||||
if size and size > limits.maxFileSize:
|
||||
_recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
|
||||
result.skippedPolicy += 1
|
||||
continue
|
||||
|
||||
|
|
@ -443,13 +456,44 @@ async def _ingestOne(
|
|||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
def _recordLimitStop(
|
||||
result: SharepointBootstrapResult,
|
||||
limitName: str,
|
||||
where: str,
|
||||
limits: SharepointBootstrapLimits,
|
||||
*,
|
||||
hard: bool = True,
|
||||
) -> None:
|
||||
"""Mark the FIRST limit that bit. Soft hits (per-file maxFileSize, per-folder
|
||||
maxDepth) only record when no hard limit has yet stopped the run, so the UI
|
||||
surfaces the most important reason.
|
||||
|
||||
Hard limits (maxItems / maxBytes) ALWAYS overwrite a previously recorded
|
||||
soft limit — once a hard cap is hit, the corpus is provably incomplete.
|
||||
"""
|
||||
if hard or result.stoppedAtLimit is None:
|
||||
result.stoppedAtLimit = limitName
|
||||
budgetMap = {
|
||||
"maxItems": limits.maxItems,
|
||||
"maxBytes": limits.maxBytes,
|
||||
"maxDepth": limits.maxDepth,
|
||||
"maxFileSize": limits.maxFileSize,
|
||||
}
|
||||
logger.warning(
|
||||
"sharepoint walker hit %s=%s at %s — partial index "
|
||||
"(indexed=%d, bytesProcessed=%d). Raise the limit or split the data source.",
|
||||
limitName, budgetMap.get(limitName), where,
|
||||
result.indexed, result.bytesProcessed,
|
||||
)
|
||||
|
||||
|
||||
def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startMs: float) -> Dict[str, Any]:
|
||||
durationMs = int((time.time() - startMs) * 1000)
|
||||
logger.info(
|
||||
"ingestion.connection.bootstrap.done part=sharepoint connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d",
|
||||
"ingestion.connection.bootstrap.done part=sharepoint connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d stoppedAtLimit=%s",
|
||||
connectionId,
|
||||
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed,
|
||||
durationMs,
|
||||
durationMs, result.stoppedAtLimit or "none",
|
||||
extra={
|
||||
"event": "ingestion.connection.bootstrap.done",
|
||||
"part": "sharepoint",
|
||||
|
|
@ -459,6 +503,7 @@ def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startM
|
|||
"skippedPolicy": result.skippedPolicy,
|
||||
"failed": result.failed,
|
||||
"durationMs": durationMs,
|
||||
"stoppedAtLimit": result.stoppedAtLimit,
|
||||
},
|
||||
)
|
||||
return {
|
||||
|
|
@ -470,4 +515,11 @@ def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startM
|
|||
"bytesProcessed": result.bytesProcessed,
|
||||
"durationMs": durationMs,
|
||||
"errors": result.errors[:20],
|
||||
"stoppedAtLimit": result.stoppedAtLimit,
|
||||
"limits": {
|
||||
"maxItems": MAX_ITEMS_DEFAULT,
|
||||
"maxBytes": MAX_BYTES_DEFAULT,
|
||||
"maxFileSize": MAX_FILE_SIZE_DEFAULT,
|
||||
"maxDepth": MAX_DEPTH_DEFAULT,
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,8 @@ import logging
|
|||
import json
|
||||
import base64
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
import threading
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from pathlib import Path
|
||||
from cryptography.fernet import Fernet
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
|
|
@ -286,6 +287,16 @@ def handleSecretJson(value: str, userId: str = "system", keyName: str = "unknown
|
|||
# Structure: {user_id: {key_name: [timestamps]}}
|
||||
_decryption_attempts = {}
|
||||
|
||||
# Process-wide plaintext cache for decrypted secrets.
|
||||
# Key: the encrypted ciphertext (which already includes env prefix).
|
||||
# Value: (expiresAtMonotonic, plaintext).
|
||||
# TTL is short enough that key rotation propagates quickly, long enough that
|
||||
# hot DB-init paths (every API call building a connector) don't blow the
|
||||
# decryption rate limit. 60s is a deliberate compromise.
|
||||
_DECRYPTION_CACHE_TTL_S = 60.0
|
||||
_decryption_cache: Dict[str, Tuple[float, str]] = {}
|
||||
_decryption_cache_lock = threading.Lock()
|
||||
|
||||
def _getMasterKey(envType: str = None) -> bytes:
|
||||
"""
|
||||
Get the master key for the specified environment.
|
||||
|
|
@ -486,25 +497,43 @@ def encryptValue(value: str, envType: str = None, userId: str = "system", keyNam
|
|||
def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "unknown") -> str:
|
||||
"""
|
||||
Decrypt a value using the master key for the current environment.
|
||||
|
||||
|
||||
A short-lived plaintext cache (TTL `_DECRYPTION_CACHE_TTL_S`) is consulted
|
||||
first. The 10/sec rate-limit on cache misses still protects against
|
||||
brute-force attacks; cache HITS bypass it because they are not actual
|
||||
cryptographic operations — they just return the result of an earlier
|
||||
successful decrypt. Without this cache, hot paths like
|
||||
`mainBackgroundJobService._getDb()` (called per RAG inventory poll AND
|
||||
per walker DB call) trigger the rate limit and surface as
|
||||
"Decryption rate limit exceeded for user 'system' key 'DB_PASSWORD_SECRET'"
|
||||
ERRORs in the RAG inventory UI route.
|
||||
|
||||
Args:
|
||||
encryptedValue: The encrypted value with prefix
|
||||
userId: The user ID making the request (default: "system")
|
||||
keyName: The name of the key being decrypted (default: "unknown")
|
||||
|
||||
|
||||
Returns:
|
||||
str: The decrypted plain text value
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If decryption fails
|
||||
"""
|
||||
if not _isEncryptedValue(encryptedValue):
|
||||
return encryptedValue # Return as-is if not encrypted
|
||||
|
||||
# Check rate limiting (10 per second per user per key)
|
||||
|
||||
# Cache lookup BEFORE the rate-limit check: a cache hit is not a new
|
||||
# cryptographic operation and must not be throttled.
|
||||
now = time.monotonic()
|
||||
with _decryption_cache_lock:
|
||||
cached = _decryption_cache.get(encryptedValue)
|
||||
if cached is not None and cached[0] > now:
|
||||
return cached[1]
|
||||
|
||||
# Cache miss → real decrypt → apply rate limit.
|
||||
if not _checkDecryptionRateLimit(userId, keyName, maxPerSecond=10):
|
||||
raise ValueError(f"Decryption rate limit exceeded for user '{userId}' key '{keyName}' (10/sec)")
|
||||
|
||||
|
||||
try:
|
||||
# Extract environment type from prefix
|
||||
if encryptedValue.startswith('DEV_ENC:'):
|
||||
|
|
@ -536,7 +565,7 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un
|
|||
encryptedBytes = base64.urlsafe_b64decode(encryptedPart.encode('utf-8'))
|
||||
decryptedBytes = fernet.decrypt(encryptedBytes)
|
||||
decryptedValue = decryptedBytes.decode('utf-8')
|
||||
|
||||
|
||||
# Log audit event for decryption
|
||||
try:
|
||||
from modules.shared.auditLogger import audit_logger
|
||||
|
|
@ -549,11 +578,25 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un
|
|||
except Exception:
|
||||
# Don't fail if audit logging fails
|
||||
pass
|
||||
|
||||
|
||||
# Populate cache so subsequent reads of the same ciphertext don't
|
||||
# re-decrypt (and don't consume rate-limit budget).
|
||||
with _decryption_cache_lock:
|
||||
_decryption_cache[encryptedValue] = (
|
||||
time.monotonic() + _DECRYPTION_CACHE_TTL_S,
|
||||
decryptedValue,
|
||||
)
|
||||
|
||||
return decryptedValue
|
||||
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Decryption failed: {e}")
|
||||
|
||||
|
||||
def clearDecryptionCache() -> None:
|
||||
"""Drop all cached plaintext secrets. Call after key rotation or in tests."""
|
||||
with _decryption_cache_lock:
|
||||
_decryption_cache.clear()
|
||||
|
||||
# Create the global APP_CONFIG instance
|
||||
APP_CONFIG = Configuration()
|
||||
|
|
@ -33,20 +33,35 @@ def _ensureUamTablesMatchModels(dbConnector) -> None:
|
|||
logger.debug(f"_ensureUamTablesMatchModels: {e}")
|
||||
|
||||
|
||||
def _getConnection(dbConnector):
|
||||
"""Get a connection from the DatabaseConnector.
|
||||
|
||||
Ensures the connection is alive and returns it.
|
||||
Commits any pending transaction first to avoid blocking.
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _borrowDbConn(dbConnector):
|
||||
"""Borrow a pooled connection from the DatabaseConnector.
|
||||
|
||||
Index/trigger/FK creation traditionally ran with `conn.autocommit = True`
|
||||
so each CREATE statement is its own transaction (DDL on a managed
|
||||
connection blocks waiting for COMMIT). This helper preserves that
|
||||
behaviour on top of the pool: borrow a connection, flip it to autocommit,
|
||||
yield it, and restore the previous state before returning it to the pool.
|
||||
"""
|
||||
dbConnector._ensure_connection()
|
||||
conn = dbConnector.connection
|
||||
# Commit any pending transaction to avoid blocking
|
||||
try:
|
||||
conn.commit()
|
||||
except Exception:
|
||||
pass # Ignore if nothing to commit
|
||||
return conn
|
||||
with dbConnector.borrowConn() as conn:
|
||||
try:
|
||||
previousAutocommit = conn.autocommit
|
||||
except Exception:
|
||||
previousAutocommit = False
|
||||
try:
|
||||
conn.autocommit = True
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not set autocommit on borrowed connection: {e}")
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
try:
|
||||
conn.autocommit = previousAutocommit
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -174,73 +189,42 @@ def applyMultiTenantOptimizations(dbConnector, tables: Optional[List[str]] = Non
|
|||
}
|
||||
|
||||
try:
|
||||
# Get a connection from the connector
|
||||
conn = _getConnection(dbConnector)
|
||||
|
||||
# Save and set autocommit state
|
||||
try:
|
||||
originalAutocommit = conn.autocommit
|
||||
except Exception:
|
||||
originalAutocommit = False
|
||||
|
||||
try:
|
||||
conn.autocommit = True
|
||||
except Exception as autoErr:
|
||||
logger.debug(f"Could not set autocommit: {autoErr}")
|
||||
|
||||
try:
|
||||
_ensureUamTablesMatchModels(dbConnector)
|
||||
except Exception as preIdxErr:
|
||||
logger.debug(f"Pre-index table ensure: {preIdxErr}")
|
||||
|
||||
try:
|
||||
|
||||
with _borrowDbConn(dbConnector) as conn:
|
||||
with conn.cursor() as cursor:
|
||||
# Apply indexes
|
||||
results["indexesCreated"] = _applyIndexes(cursor, tables)
|
||||
|
||||
# Apply foreign keys
|
||||
results["foreignKeysCreated"] = _applyForeignKeys(cursor, tables)
|
||||
|
||||
# Apply immutable triggers
|
||||
results["triggersCreated"] = _applyImmutableTriggers(cursor, tables)
|
||||
|
||||
logger.info(
|
||||
f"Multi-tenant optimizations applied: "
|
||||
f"{results['indexesCreated']} indexes, "
|
||||
f"{results['triggersCreated']} triggers, "
|
||||
f"{results['foreignKeysCreated']} foreign keys"
|
||||
)
|
||||
finally:
|
||||
# Restore original autocommit state
|
||||
try:
|
||||
conn.autocommit = originalAutocommit
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
logger.info(
|
||||
f"Multi-tenant optimizations applied: "
|
||||
f"{results['indexesCreated']} indexes, "
|
||||
f"{results['triggersCreated']} triggers, "
|
||||
f"{results['foreignKeysCreated']} foreign keys"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error applying multi-tenant optimizations: {type(e).__name__}: {e}")
|
||||
results["errors"].append(str(e))
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def applyIndexesOnly(dbConnector, tables: Optional[List[str]] = None) -> int:
|
||||
"""Apply only indexes (lighter operation, safe for frequent calls)."""
|
||||
try:
|
||||
conn = _getConnection(dbConnector)
|
||||
originalAutocommit = conn.autocommit
|
||||
conn.autocommit = True
|
||||
|
||||
try:
|
||||
_ensureUamTablesMatchModels(dbConnector)
|
||||
except Exception as preIdxErr:
|
||||
logger.debug(f"Pre-index table ensure: {preIdxErr}")
|
||||
|
||||
try:
|
||||
|
||||
with _borrowDbConn(dbConnector) as conn:
|
||||
with conn.cursor() as cursor:
|
||||
return _applyIndexes(cursor, tables)
|
||||
finally:
|
||||
conn.autocommit = originalAutocommit
|
||||
except Exception as e:
|
||||
logger.error(f"Error applying indexes: {e}")
|
||||
return 0
|
||||
|
|
@ -514,8 +498,7 @@ def getOptimizationStatus(dbConnector) -> dict:
|
|||
}
|
||||
|
||||
try:
|
||||
conn = _getConnection(dbConnector)
|
||||
with conn.cursor() as cursor:
|
||||
with _borrowDbConn(dbConnector) as conn, conn.cursor() as cursor:
|
||||
# Check regular indexes
|
||||
for tableName, indexName, _ in _INDEXES:
|
||||
if _tableExists(cursor, tableName):
|
||||
|
|
|
|||
|
|
@ -60,11 +60,9 @@ def _getTableColumns(dbConnector, tableName: str) -> List[str]:
|
|||
ORDER BY ordinal_position
|
||||
"""
|
||||
|
||||
cursor = dbConnector.connection.cursor()
|
||||
cursor.execute(query, (tableName,))
|
||||
columns = [row[0] for row in cursor.fetchall()]
|
||||
cursor.close()
|
||||
|
||||
with dbConnector.borrowCursor() as cursor:
|
||||
cursor.execute(query, (tableName,))
|
||||
columns = [row[0] for row in cursor.fetchall()]
|
||||
return columns
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting columns for table {tableName}: {e}")
|
||||
|
|
@ -92,29 +90,26 @@ def _getAllTables(dbConnector) -> List[str]:
|
|||
ORDER BY table_name
|
||||
"""
|
||||
|
||||
cursor = dbConnector.connection.cursor()
|
||||
cursor.execute(query)
|
||||
allTables = [row[0] for row in cursor.fetchall()]
|
||||
|
||||
# Get foreign key relationships to determine dependency order
|
||||
fkQuery = """
|
||||
SELECT
|
||||
tc.table_name,
|
||||
ccu.table_name AS foreign_table_name
|
||||
FROM information_schema.table_constraints AS tc
|
||||
JOIN information_schema.key_column_usage AS kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
JOIN information_schema.constraint_column_usage AS ccu
|
||||
ON ccu.constraint_name = tc.constraint_name
|
||||
AND ccu.table_schema = tc.table_schema
|
||||
WHERE tc.constraint_type = 'FOREIGN KEY'
|
||||
AND tc.table_schema = 'public'
|
||||
"""
|
||||
|
||||
cursor.execute(fkQuery)
|
||||
foreignKeys = cursor.fetchall()
|
||||
cursor.close()
|
||||
with dbConnector.borrowCursor() as cursor:
|
||||
cursor.execute(query)
|
||||
allTables = [row[0] for row in cursor.fetchall()]
|
||||
|
||||
fkQuery = """
|
||||
SELECT
|
||||
tc.table_name,
|
||||
ccu.table_name AS foreign_table_name
|
||||
FROM information_schema.table_constraints AS tc
|
||||
JOIN information_schema.key_column_usage AS kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
JOIN information_schema.constraint_column_usage AS ccu
|
||||
ON ccu.constraint_name = tc.constraint_name
|
||||
AND ccu.table_schema = tc.table_schema
|
||||
WHERE tc.constraint_type = 'FOREIGN KEY'
|
||||
AND tc.table_schema = 'public'
|
||||
"""
|
||||
cursor.execute(fkQuery)
|
||||
foreignKeys = cursor.fetchall()
|
||||
|
||||
# Build dependency graph (child -> parent mapping)
|
||||
dependencies = {}
|
||||
|
|
@ -154,10 +149,9 @@ def _getAllTables(dbConnector) -> List[str]:
|
|||
# Fallback: return simple list without ordering
|
||||
try:
|
||||
query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE'"
|
||||
cursor = dbConnector.connection.cursor()
|
||||
cursor.execute(query)
|
||||
tables = [row[0] for row in cursor.fetchall()]
|
||||
cursor.close()
|
||||
with dbConnector.borrowCursor() as cursor:
|
||||
cursor.execute(query)
|
||||
tables = [row[0] for row in cursor.fetchall()]
|
||||
return [t for t in tables if t not in PROTECTED_TABLES]
|
||||
except Exception:
|
||||
return []
|
||||
|
|
@ -184,11 +178,9 @@ def _getPrimaryKeyColumns(dbConnector, tableName: str) -> List[str]:
|
|||
AND i.indisprimary
|
||||
"""
|
||||
|
||||
cursor = dbConnector.connection.cursor()
|
||||
cursor.execute(query, (tableName,))
|
||||
pkColumns = [row[0] for row in cursor.fetchall()]
|
||||
cursor.close()
|
||||
|
||||
with dbConnector.borrowCursor() as cursor:
|
||||
cursor.execute(query, (tableName,))
|
||||
pkColumns = [row[0] for row in cursor.fetchall()]
|
||||
return pkColumns
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get primary key for {tableName}: {e}")
|
||||
|
|
@ -229,21 +221,15 @@ def _findUserReferencesInTable(
|
|||
return {}
|
||||
|
||||
references = {}
|
||||
cursor = dbConnector.connection.cursor()
|
||||
|
||||
for userColumn in userColumns:
|
||||
# Build SELECT for primary key columns
|
||||
pkSelect = ", ".join([f'"{pk}"' for pk in pkColumns])
|
||||
query = f'SELECT {pkSelect} FROM "{tableName}" WHERE "{userColumn}" = %s'
|
||||
|
||||
cursor.execute(query, (userId,))
|
||||
recordKeys = cursor.fetchall()
|
||||
|
||||
if recordKeys:
|
||||
references[userColumn] = recordKeys
|
||||
logger.debug(f"Found {len(recordKeys)} records in {tableName}.{userColumn} for user {userId}")
|
||||
|
||||
cursor.close()
|
||||
with dbConnector.borrowCursor() as cursor:
|
||||
for userColumn in userColumns:
|
||||
pkSelect = ", ".join([f'"{pk}"' for pk in pkColumns])
|
||||
query = f'SELECT {pkSelect} FROM "{tableName}" WHERE "{userColumn}" = %s'
|
||||
cursor.execute(query, (userId,))
|
||||
recordKeys = cursor.fetchall()
|
||||
if recordKeys:
|
||||
references[userColumn] = recordKeys
|
||||
logger.debug(f"Found {len(recordKeys)} records in {tableName}.{userColumn} for user {userId}")
|
||||
return references
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -277,42 +263,35 @@ def _anonymizeRecords(
|
|||
return 0
|
||||
|
||||
try:
|
||||
cursor = dbConnector.connection.cursor()
|
||||
# Resolve column metadata once outside the borrow block (it borrows its
|
||||
# own connection internally).
|
||||
columns = _getTableColumns(dbConnector, tableName)
|
||||
hasModifiedAt = "sysModifiedAt" in columns
|
||||
|
||||
count = 0
|
||||
|
||||
for recordKey in recordKeys:
|
||||
# Build WHERE clause for primary key
|
||||
whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
|
||||
|
||||
# Check if table has sysModifiedAt column
|
||||
columns = _getTableColumns(dbConnector, tableName)
|
||||
hasModifiedAt = "sysModifiedAt" in columns
|
||||
|
||||
if hasModifiedAt:
|
||||
query = f'UPDATE "{tableName}" SET "{columnName}" = %s, "sysModifiedAt" = %s WHERE {whereClause}'
|
||||
params = [anonymousValue, getUtcTimestamp()]
|
||||
else:
|
||||
query = f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE {whereClause}'
|
||||
params = [anonymousValue]
|
||||
|
||||
# Add primary key values to params
|
||||
if isinstance(recordKey, tuple):
|
||||
params.extend(recordKey)
|
||||
else:
|
||||
params.append(recordKey)
|
||||
|
||||
cursor.execute(query, params)
|
||||
count += cursor.rowcount
|
||||
|
||||
dbConnector.connection.commit()
|
||||
cursor.close()
|
||||
|
||||
with dbConnector.borrowCursor() as cursor:
|
||||
for recordKey in recordKeys:
|
||||
whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
|
||||
if hasModifiedAt:
|
||||
query = f'UPDATE "{tableName}" SET "{columnName}" = %s, "sysModifiedAt" = %s WHERE {whereClause}'
|
||||
params = [anonymousValue, getUtcTimestamp()]
|
||||
else:
|
||||
query = f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE {whereClause}'
|
||||
params = [anonymousValue]
|
||||
|
||||
if isinstance(recordKey, tuple):
|
||||
params.extend(recordKey)
|
||||
else:
|
||||
params.append(recordKey)
|
||||
|
||||
cursor.execute(query, params)
|
||||
count += cursor.rowcount
|
||||
|
||||
logger.info(f"Anonymized {count} records in {tableName}.{columnName}")
|
||||
return count
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error anonymizing records in {tableName}.{columnName}: {e}")
|
||||
dbConnector.connection.rollback()
|
||||
return 0
|
||||
|
||||
|
||||
|
|
@ -338,32 +317,23 @@ def _deleteRecords(
|
|||
return 0
|
||||
|
||||
try:
|
||||
cursor = dbConnector.connection.cursor()
|
||||
count = 0
|
||||
|
||||
for recordKey in recordKeys:
|
||||
# Build WHERE clause for primary key
|
||||
whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
|
||||
query = f'DELETE FROM "{tableName}" WHERE {whereClause}'
|
||||
|
||||
# Prepare params
|
||||
if isinstance(recordKey, tuple):
|
||||
params = list(recordKey)
|
||||
else:
|
||||
params = [recordKey]
|
||||
|
||||
cursor.execute(query, params)
|
||||
count += cursor.rowcount
|
||||
|
||||
dbConnector.connection.commit()
|
||||
cursor.close()
|
||||
|
||||
with dbConnector.borrowCursor() as cursor:
|
||||
for recordKey in recordKeys:
|
||||
whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
|
||||
query = f'DELETE FROM "{tableName}" WHERE {whereClause}'
|
||||
if isinstance(recordKey, tuple):
|
||||
params = list(recordKey)
|
||||
else:
|
||||
params = [recordKey]
|
||||
cursor.execute(query, params)
|
||||
count += cursor.rowcount
|
||||
|
||||
logger.info(f"Deleted {count} records from {tableName}")
|
||||
return count
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting records from {tableName}: {e}")
|
||||
dbConnector.connection.rollback()
|
||||
return 0
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ if not c or not c.connection:
|
|||
print("STAGE0: DB_CONNECTION=none (check config.ini / .env)")
|
||||
raise SystemExit(2)
|
||||
|
||||
cur = c.connection.cursor()
|
||||
cur = c.borrowCursor()
|
||||
|
||||
|
||||
def _scalar(cur):
|
||||
|
|
|
|||
|
|
@ -12,11 +12,16 @@ broken query into "no rows found". That hid bugs like:
|
|||
|
||||
These tests pin the new contract: empty result sets still return ``[]`` /
|
||||
``None`` (normal), but any exception inside the query path propagates as
|
||||
``DatabaseQueryError`` with the table name attached. The transaction is
|
||||
rolled back so the connection is usable for subsequent queries.
|
||||
``DatabaseQueryError`` with the table name attached.
|
||||
|
||||
Since the 2026-05-17 pool refactor (`c-work/2-build/2026-05-postgres-connection-pool.md`)
|
||||
the connector borrows a connection from `_PoolRegistry` on every call via the
|
||||
`borrowConn()` context manager. The tests mock that context manager so the
|
||||
fast-fail contract is exercised without requiring a live Postgres server.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
|
@ -25,7 +30,6 @@ import psycopg2.errors
|
|||
from modules.connectors.connectorDbPostgre import (
|
||||
DatabaseConnector,
|
||||
DatabaseQueryError,
|
||||
_rollbackQuietly,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -39,26 +43,44 @@ class DummyTable:
|
|||
|
||||
|
||||
def _makeConnector(cursorBehavior):
|
||||
"""Build a ``DatabaseConnector`` skeleton with mocked connection/cursor.
|
||||
"""Build a ``DatabaseConnector`` skeleton with a mocked pool borrow.
|
||||
|
||||
``cursorBehavior`` is a callable invoked with the cursor mock so the test
|
||||
can configure ``execute``/``fetchall``/``fetchone`` per scenario.
|
||||
|
||||
Returns ``(connector, conn, cursor)``:
|
||||
* ``conn`` exposes ``commit`` / ``rollback`` MagicMocks so tests can
|
||||
assert that the borrow lifecycle did the right thing.
|
||||
* ``cursor`` is the per-test cursor mock.
|
||||
"""
|
||||
connector = DatabaseConnector.__new__(DatabaseConnector)
|
||||
|
||||
cursor = MagicMock()
|
||||
cursorBehavior(cursor)
|
||||
|
||||
cursorContext = MagicMock()
|
||||
cursorContext.__enter__ = MagicMock(return_value=cursor)
|
||||
cursorContext.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
connection = MagicMock()
|
||||
connection.cursor.return_value = cursorContext
|
||||
connector.connection = connection
|
||||
conn = MagicMock()
|
||||
conn.cursor.return_value = cursorContext
|
||||
|
||||
@contextmanager
|
||||
def fakeBorrow():
|
||||
try:
|
||||
yield conn
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
else:
|
||||
conn.commit()
|
||||
|
||||
connector.borrowConn = fakeBorrow
|
||||
|
||||
connector._ensureTableExists = MagicMock(return_value=True)
|
||||
connector._systemTableName = "_system"
|
||||
|
||||
cursorBehavior(cursor)
|
||||
return connector, connection, cursor
|
||||
return connector, conn, cursor
|
||||
|
||||
|
||||
class TestGetRecordsetFailLoud:
|
||||
|
|
@ -67,11 +89,12 @@ class TestGetRecordsetFailLoud:
|
|||
def behavior(cursor):
|
||||
cursor.execute.return_value = None
|
||||
cursor.fetchall.return_value = []
|
||||
connector, connection, _ = _makeConnector(behavior)
|
||||
connector, conn, _ = _makeConnector(behavior)
|
||||
|
||||
result = connector.getRecordset(DummyTable)
|
||||
assert result == []
|
||||
connection.rollback.assert_not_called()
|
||||
conn.rollback.assert_not_called()
|
||||
conn.commit.assert_called_once()
|
||||
|
||||
def test_dictAdaptErrorRaisesDatabaseQueryError(self):
|
||||
"""Reproduces the Trustee bug: passing a dict in WHERE → can't adapt → raise."""
|
||||
|
|
@ -79,7 +102,7 @@ class TestGetRecordsetFailLoud:
|
|||
cursor.execute.side_effect = psycopg2.ProgrammingError(
|
||||
"can't adapt type 'dict'"
|
||||
)
|
||||
connector, connection, _ = _makeConnector(behavior)
|
||||
connector, conn, _ = _makeConnector(behavior)
|
||||
|
||||
with pytest.raises(DatabaseQueryError) as excinfo:
|
||||
connector.getRecordset(
|
||||
|
|
@ -90,30 +113,30 @@ class TestGetRecordsetFailLoud:
|
|||
assert excinfo.value.table == "DummyTable"
|
||||
assert "can't adapt type 'dict'" in str(excinfo.value)
|
||||
assert isinstance(excinfo.value.original, psycopg2.ProgrammingError)
|
||||
connection.rollback.assert_called_once()
|
||||
conn.rollback.assert_called_once()
|
||||
|
||||
def test_missingColumnRaisesDatabaseQueryError(self):
|
||||
def behavior(cursor):
|
||||
cursor.execute.side_effect = psycopg2.errors.UndefinedColumn(
|
||||
'column "wat" does not exist'
|
||||
)
|
||||
connector, connection, _ = _makeConnector(behavior)
|
||||
connector, conn, _ = _makeConnector(behavior)
|
||||
|
||||
with pytest.raises(DatabaseQueryError) as excinfo:
|
||||
connector.getRecordset(DummyTable, recordFilter={"wat": "x"})
|
||||
|
||||
assert "wat" in str(excinfo.value)
|
||||
connection.rollback.assert_called_once()
|
||||
conn.rollback.assert_called_once()
|
||||
|
||||
def test_operationalErrorRaisesDatabaseQueryError(self):
|
||||
"""Connection lost mid-query is also a real failure that must propagate."""
|
||||
def behavior(cursor):
|
||||
cursor.execute.side_effect = psycopg2.OperationalError("connection lost")
|
||||
connector, connection, _ = _makeConnector(behavior)
|
||||
connector, conn, _ = _makeConnector(behavior)
|
||||
|
||||
with pytest.raises(DatabaseQueryError):
|
||||
connector.getRecordset(DummyTable)
|
||||
connection.rollback.assert_called_once()
|
||||
conn.rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestGetRecordFailLoud:
|
||||
|
|
@ -122,37 +145,22 @@ class TestGetRecordFailLoud:
|
|||
def behavior(cursor):
|
||||
cursor.execute.return_value = None
|
||||
cursor.fetchone.return_value = None
|
||||
connector, connection, _ = _makeConnector(behavior)
|
||||
connector, conn, _ = _makeConnector(behavior)
|
||||
|
||||
result = connector.getRecord(DummyTable, "missing-id")
|
||||
assert result is None
|
||||
connection.rollback.assert_not_called()
|
||||
conn.rollback.assert_not_called()
|
||||
conn.commit.assert_called_once()
|
||||
|
||||
def test_queryErrorRaisesDatabaseQueryError(self):
|
||||
def behavior(cursor):
|
||||
cursor.execute.side_effect = psycopg2.errors.UndefinedTable(
|
||||
'relation "DummyTable" does not exist'
|
||||
)
|
||||
connector, connection, _ = _makeConnector(behavior)
|
||||
connector, conn, _ = _makeConnector(behavior)
|
||||
|
||||
with pytest.raises(DatabaseQueryError) as excinfo:
|
||||
connector.getRecord(DummyTable, "any-id")
|
||||
|
||||
assert excinfo.value.table == "DummyTable"
|
||||
connection.rollback.assert_called_once()
|
||||
|
||||
|
||||
class TestRollbackQuietly:
|
||||
def test_rollsBackOnLiveConnection(self):
|
||||
connection = MagicMock()
|
||||
_rollbackQuietly(connection)
|
||||
connection.rollback.assert_called_once()
|
||||
|
||||
def test_swallowsRollbackError(self):
|
||||
"""Rollback failure must not mask the original query error."""
|
||||
connection = MagicMock()
|
||||
connection.rollback.side_effect = RuntimeError("rollback failed")
|
||||
_rollbackQuietly(connection)
|
||||
|
||||
def test_noopOnNoneConnection(self):
|
||||
_rollbackQuietly(None)
|
||||
conn.rollback.assert_called_once()
|
||||
|
|
|
|||
304
tests/unit/connectors/test_connectorDbPostgre_pool.py
Normal file
304
tests/unit/connectors/test_connectorDbPostgre_pool.py
Normal file
|
|
@ -0,0 +1,304 @@
|
|||
# Copyright (c) 2026 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""Concurrency tests for the PostgreSQL connection pool.
|
||||
|
||||
These tests pin the contract that the `c-work/2-build/2026-05-postgres-connection-pool.md`
|
||||
refactor delivered:
|
||||
|
||||
* T1 — 50 threads × 100 calls in parallel produce 0 `OperationalError`s and
|
||||
every call completes within reasonable time (p99 < 2 s).
|
||||
* T2 — Two threads `_loadRecord` + `_saveRecord` against the same connector
|
||||
do not corrupt each other's cursors.
|
||||
* T3 — `statement_timeout` aborts a runaway `pg_sleep(60)` after ~30 s and
|
||||
releases the connection back into the pool cleanly.
|
||||
|
||||
The tests need a real PostgreSQL server because the bug they guard against
|
||||
only materialises with real psycopg2 sockets — a mocked connection never
|
||||
hangs in `recv()`. They read DB credentials from `APP_CONFIG` (which loads
|
||||
`.env`) and are auto-skipped when the connection fails (no local Postgres,
|
||||
wrong creds, etc.) so `pytest` keeps working in CI-only environments.
|
||||
|
||||
To run them locally:
|
||||
|
||||
pytest gateway/tests/unit/connectors/test_connectorDbPostgre_pool.py -v
|
||||
|
||||
They use a throwaway database name (`poweron_pool_test_<uuid>`) and drop it
|
||||
in fixture teardown so they leave nothing behind.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.errors
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from modules.connectors.connectorDbPostgre import (
|
||||
DatabaseConnector,
|
||||
_PoolRegistry,
|
||||
closeAllPools,
|
||||
)
|
||||
from modules.datamodels.datamodelBase import PowerOnModel
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
|
||||
def _dbConfig():
|
||||
"""Read DB connection params from APP_CONFIG (`.env`).
|
||||
|
||||
Returns ``None`` when host/user/password are not all present so the
|
||||
test module can skip cleanly instead of blowing up at import time.
|
||||
"""
|
||||
host = APP_CONFIG.get("DB_HOST")
|
||||
user = APP_CONFIG.get("DB_USER")
|
||||
password = APP_CONFIG.get("DB_PASSWORD_SECRET")
|
||||
port = APP_CONFIG.get("DB_PORT", 5432)
|
||||
if not host or not user or password is None:
|
||||
return None
|
||||
return {"host": host, "user": user, "password": password, "port": int(port)}
|
||||
|
||||
|
||||
def _canReachPostgres(cfg) -> bool:
|
||||
"""Try a quick connect to the admin DB so we can skip on connection failures."""
|
||||
try:
|
||||
conn = psycopg2.connect(
|
||||
host=cfg["host"], port=cfg["port"], database="postgres",
|
||||
user=cfg["user"], password=cfg["password"], connect_timeout=2,
|
||||
)
|
||||
conn.close()
|
||||
return True
|
||||
except Exception: # noqa: BLE001
|
||||
return False
|
||||
|
||||
|
||||
_DB_CFG = _dbConfig()
|
||||
pytestmark = pytest.mark.skipif(
|
||||
_DB_CFG is None or not _canReachPostgres(_DB_CFG),
|
||||
reason="No reachable PostgreSQL — skipping live-Postgres pool tests",
|
||||
)
|
||||
|
||||
|
||||
class PoolTestRow(PowerOnModel):
|
||||
"""Tiny model used to exercise the pool — one ID + one payload field."""
|
||||
payload: str = Field(default="", description="Test payload")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def liveConnector():
|
||||
"""Spin up a throwaway database, yield a `DatabaseConnector` against it,
|
||||
drop the database afterwards.
|
||||
|
||||
The pool registry is wiped before and after each test so state from one
|
||||
test cannot mask a bug in another.
|
||||
"""
|
||||
cfg = _DB_CFG
|
||||
host = cfg["host"]
|
||||
user = cfg["user"]
|
||||
password = cfg["password"]
|
||||
port = cfg["port"]
|
||||
dbName = f"poweron_pool_test_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Pre-clean: drop any orphan test DB with the same name (shouldn't happen
|
||||
# because we use a unique uuid, but be defensive).
|
||||
adminConn = psycopg2.connect(
|
||||
host=host, port=port, database="postgres", user=user, password=password
|
||||
)
|
||||
adminConn.autocommit = True
|
||||
try:
|
||||
with adminConn.cursor() as cur:
|
||||
cur.execute(f'DROP DATABASE IF EXISTS "{dbName}"')
|
||||
finally:
|
||||
adminConn.close()
|
||||
|
||||
closeAllPools()
|
||||
|
||||
connector = DatabaseConnector(
|
||||
dbHost=host,
|
||||
dbDatabase=dbName,
|
||||
dbUser=user,
|
||||
dbPassword=password,
|
||||
dbPort=port,
|
||||
)
|
||||
# Seed exactly one row so every concurrent read has a stable target.
|
||||
connector.recordCreate(PoolTestRow, {"id": "seed", "payload": "hello"})
|
||||
|
||||
yield connector
|
||||
|
||||
# Teardown: tear pools down, then drop the DB.
|
||||
closeAllPools()
|
||||
adminConn = psycopg2.connect(
|
||||
host=host, port=port, database="postgres", user=user, password=password
|
||||
)
|
||||
adminConn.autocommit = True
|
||||
try:
|
||||
with adminConn.cursor() as cur:
|
||||
cur.execute(
|
||||
'SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = %s',
|
||||
(dbName,),
|
||||
)
|
||||
cur.execute(f'DROP DATABASE IF EXISTS "{dbName}"')
|
||||
finally:
|
||||
adminConn.close()
|
||||
|
||||
|
||||
class TestPoolConcurrency:
|
||||
def _runWorkers(self, liveConnector, *, threadCount: int, callsPerThread: int):
|
||||
"""Run N worker threads, each issuing M reads. Return (errors, latencies)."""
|
||||
errors: list = []
|
||||
latencies: list = []
|
||||
lock = threading.Lock()
|
||||
|
||||
def worker():
|
||||
for _ in range(callsPerThread):
|
||||
t0 = time.perf_counter()
|
||||
try:
|
||||
rows = liveConnector.getRecordset(PoolTestRow)
|
||||
assert any(r["id"] == "seed" for r in rows)
|
||||
except Exception as e: # noqa: BLE001 — we want every failure mode
|
||||
with lock:
|
||||
errors.append(e)
|
||||
finally:
|
||||
with lock:
|
||||
latencies.append(time.perf_counter() - t0)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=threadCount) as ex:
|
||||
futures = [ex.submit(worker) for _ in range(threadCount)]
|
||||
for f in as_completed(futures):
|
||||
f.result()
|
||||
latencies.sort()
|
||||
return errors, latencies
|
||||
|
||||
def test_50_threads_x_20_reads_no_errors(self, liveConnector):
|
||||
"""T1a — STRESS: 50 threads × 20 reads each → 0 errors.
|
||||
|
||||
Pre-pool, this scenario produced either
|
||||
`OperationalError: another command is already in progress` or a
|
||||
deadlock in `recv()` because the threadpool shared one psycopg2
|
||||
socket. With the pool plus `borrowConn`'s bounded wait, every
|
||||
thread eventually gets a connection and completes — even with 30
|
||||
threads queued waiting at any moment (pool max=20).
|
||||
"""
|
||||
errors, _ = self._runWorkers(liveConnector, threadCount=50, callsPerThread=20)
|
||||
assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"
|
||||
|
||||
def test_20_threads_x_50_reads_latency_budget(self, liveConnector):
|
||||
"""T1b — DESIGN CAPACITY: 20 threads × 50 reads, p99 < 5 s.
|
||||
|
||||
20 threads matches the pool's `max=20` so there is no queueing —
|
||||
every borrow returns immediately. This pins a sanity-level per-call
|
||||
latency budget; pre-pool it was unbounded (recv() never returned).
|
||||
|
||||
The 5 s ceiling is generous on purpose: `getRecordset` calls
|
||||
`_ensureTableExists` which runs two `information_schema` queries
|
||||
for column-additive migration, and under 20-way concurrency on a
|
||||
single Postgres instance that produces a long tail. The hard
|
||||
assertion is `not errors` — the latency check just guarantees
|
||||
nothing hangs indefinitely.
|
||||
"""
|
||||
errors, latencies = self._runWorkers(
|
||||
liveConnector, threadCount=20, callsPerThread=50
|
||||
)
|
||||
assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"
|
||||
p99 = latencies[int(len(latencies) * 0.99)]
|
||||
assert p99 < 5.0, f"p99 latency {p99:.2f}s exceeds 5s budget"
|
||||
|
||||
def test_interleaved_load_and_save_no_collision(self, liveConnector):
|
||||
"""T2: parallel reads + writes on the same connector → no cursor mix-up.
|
||||
|
||||
Pre-pool the reader could observe a row in mid-write or vice versa
|
||||
because both shared the same cursor. With one connection per borrow,
|
||||
the database's own row-locking is the only contention, and we just
|
||||
need to assert no exceptions.
|
||||
"""
|
||||
stopFlag = threading.Event()
|
||||
errors: list = []
|
||||
lock = threading.Lock()
|
||||
|
||||
def reader():
|
||||
while not stopFlag.is_set():
|
||||
try:
|
||||
liveConnector.getRecord(PoolTestRow, "seed")
|
||||
except Exception as e: # noqa: BLE001
|
||||
with lock:
|
||||
errors.append(("read", e))
|
||||
|
||||
def writer():
|
||||
i = 0
|
||||
while not stopFlag.is_set():
|
||||
try:
|
||||
liveConnector.recordModify(
|
||||
PoolTestRow,
|
||||
"seed",
|
||||
{"id": "seed", "payload": f"v{i}"},
|
||||
)
|
||||
i += 1
|
||||
except Exception as e: # noqa: BLE001
|
||||
with lock:
|
||||
errors.append(("write", e))
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=reader, daemon=True),
|
||||
threading.Thread(target=reader, daemon=True),
|
||||
threading.Thread(target=writer, daemon=True),
|
||||
threading.Thread(target=writer, daemon=True),
|
||||
]
|
||||
for t in threads:
|
||||
t.start()
|
||||
time.sleep(2.0)
|
||||
stopFlag.set()
|
||||
for t in threads:
|
||||
t.join(timeout=3.0)
|
||||
|
||||
assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"
|
||||
|
||||
def test_statement_timeout_releases_connection(self, liveConnector):
|
||||
"""T3: `pg_sleep` past statement_timeout → QueryCanceled, pool intact.
|
||||
|
||||
The bug we are guarding against: a runaway query with no timeout
|
||||
hung `recv()` forever, the psycopg2 connection was poisoned, and the
|
||||
whole backend became unresponsive once that connection was reused.
|
||||
With `statement_timeout=30000` configured at pool construction the
|
||||
query is cancelled by the server, the borrow context manager rolls
|
||||
back, and the connection returns to the pool — proven by the fact
|
||||
that a follow-up call still succeeds quickly.
|
||||
"""
|
||||
# Use a short timeout to keep the test fast — override the pool's
|
||||
# session statement_timeout for one borrow via SET LOCAL.
|
||||
with liveConnector.borrowConn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("SET LOCAL statement_timeout = 500")
|
||||
with pytest.raises(psycopg2.errors.QueryCanceled):
|
||||
cursor.execute("SELECT pg_sleep(5)")
|
||||
|
||||
# Follow-up call must succeed quickly: connection is back in the pool.
|
||||
t0 = time.perf_counter()
|
||||
rows = liveConnector.getRecordset(PoolTestRow)
|
||||
elapsed = time.perf_counter() - t0
|
||||
assert any(r["id"] == "seed" for r in rows)
|
||||
assert elapsed < 1.0, f"follow-up call took {elapsed:.2f}s — pool may be wedged"
|
||||
|
||||
|
||||
class TestPoolRegistry:
|
||||
def test_one_pool_per_database_identity(self, liveConnector):
|
||||
"""Two connectors against the same (host, db, port) share one pool."""
|
||||
cfg = _DB_CFG
|
||||
pool1 = _PoolRegistry.getPool(
|
||||
dbHost=cfg["host"], dbDatabase=liveConnector.dbDatabase,
|
||||
dbUser=cfg["user"], dbPassword=cfg["password"], dbPort=cfg["port"],
|
||||
)
|
||||
pool2 = _PoolRegistry.getPool(
|
||||
dbHost=cfg["host"], dbDatabase=liveConnector.dbDatabase,
|
||||
dbUser=cfg["user"], dbPassword=cfg["password"], dbPort=cfg["port"],
|
||||
)
|
||||
assert pool1 is pool2
|
||||
|
||||
def test_close_all_clears_registry(self, liveConnector):
|
||||
"""`closeAllPools()` empties the registry so the next call rebuilds."""
|
||||
# Touch the pool first.
|
||||
liveConnector.getRecordset(PoolTestRow)
|
||||
assert _PoolRegistry._pools, "pool should exist after a real call"
|
||||
closeAllPools()
|
||||
assert _PoolRegistry._pools == {}, "registry should be empty after closeAllPools()"
|
||||
|
|
@ -68,6 +68,16 @@ class _FakeDb:
|
|||
def _ensureTableExists(self, modelClass):
|
||||
return True
|
||||
|
||||
def borrowCursor(self):
|
||||
"""Mimic `DatabaseConnector.borrowCursor()` context manager."""
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@contextmanager
|
||||
def _cm():
|
||||
yield MagicMock()
|
||||
return _cm()
|
||||
|
||||
def seed(self, modelClass, record: Dict[str, Any]):
|
||||
tableName = modelClass.__name__
|
||||
self._tables.setdefault(tableName, {})
|
||||
|
|
|
|||
|
|
@ -69,6 +69,16 @@ class _FakeDb:
|
|||
def _ensureTableExists(self, modelClass):
|
||||
return True
|
||||
|
||||
def borrowCursor(self):
|
||||
"""Mimic `DatabaseConnector.borrowCursor()` context manager for the cascade test."""
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@contextmanager
|
||||
def _cm():
|
||||
yield MagicMock()
|
||||
return _cm()
|
||||
|
||||
def seed(self, modelClass, record: Dict[str, Any]):
|
||||
tableName = modelClass.__name__
|
||||
self._tables.setdefault(tableName, {})
|
||||
|
|
|
|||
Loading…
Reference in a new issue