db connection pooling and rag limit transparency
This commit is contained in:
parent
f5aba4bf99
commit
2bb65c2303
23 changed files with 1519 additions and 782 deletions
9
app.py
9
app.py
|
|
@ -439,6 +439,15 @@ async def lifespan(app: FastAPI):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Could not shutdown feature containers: {e}")
|
logger.warning(f"Could not shutdown feature containers: {e}")
|
||||||
|
|
||||||
|
# --- Close all PostgreSQL connection pools ---
|
||||||
|
# Must run LAST: feature `onStop` hooks may still issue DB calls during
|
||||||
|
# shutdown. Once we tear down the pools, no more borrows are possible.
|
||||||
|
try:
|
||||||
|
from modules.connectors.connectorDbPostgre import closeAllPools
|
||||||
|
closeAllPools()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Closing DB connection pools failed: {e}")
|
||||||
|
|
||||||
logger.info("Application has been shut down")
|
logger.info("Application has been shut down")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,12 @@
|
||||||
# All rights reserved.
|
# All rights reserved.
|
||||||
import contextvars
|
import contextvars
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
import psycopg2
|
import psycopg2
|
||||||
import psycopg2.extras
|
import psycopg2.extras
|
||||||
|
import psycopg2.pool
|
||||||
import logging
|
import logging
|
||||||
|
from contextlib import contextmanager
|
||||||
from typing import List, Dict, Any, Optional, Union, get_origin, get_args, Type
|
from typing import List, Dict, Any, Optional, Union, get_origin, get_args, Type
|
||||||
import uuid
|
import uuid
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
@ -44,24 +47,6 @@ class DatabaseQueryError(RuntimeError):
|
||||||
self.original = original
|
self.original = original
|
||||||
|
|
||||||
|
|
||||||
def _rollbackQuietly(connection) -> None:
|
|
||||||
"""Restore the connection state after a failed query.
|
|
||||||
|
|
||||||
Postgres puts the connection in an error state after any failed
|
|
||||||
statement; subsequent queries on the same connection raise
|
|
||||||
``InFailedSqlTransaction`` until we rollback. We swallow rollback
|
|
||||||
errors because the original query error is what the caller should
|
|
||||||
see — a secondary rollback failure typically means the connection
|
|
||||||
is gone and will be reopened on the next ``_ensure_connection``.
|
|
||||||
"""
|
|
||||||
if connection is None:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
connection.rollback()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SystemTable(PowerOnModel):
|
class SystemTable(PowerOnModel):
|
||||||
"""Data model for system table entries"""
|
"""Data model for system table entries"""
|
||||||
|
|
||||||
|
|
@ -203,9 +188,174 @@ def _quotePgIdent(name: str) -> str:
|
||||||
return '"' + str(name).replace('"', '""') + '"'
|
return '"' + str(name).replace('"', '""') + '"'
|
||||||
|
|
||||||
|
|
||||||
# Cache connectors by (host, database, port) to avoid duplicate inits for same database.
|
# ---------------------------------------------------------------------------
|
||||||
# Thread safety: _connector_cache_lock protects cache access. userId is request-scoped via
|
# Connection pool registry
|
||||||
# contextvars to avoid races when concurrent requests share the same connector.
|
# ---------------------------------------------------------------------------
|
||||||
|
# psycopg2 connections are NOT thread-safe; sharing one connection across the
|
||||||
|
# FastAPI threadpool (sync `def` routes) or across multiple async tasks that
|
||||||
|
# happen to be active simultaneously results in either
|
||||||
|
# `OperationalError: another command is already in progress` or — far worse —
|
||||||
|
# an unbounded hang in `recv()` on the connection socket, because two cursors
|
||||||
|
# wait for one server response.
|
||||||
|
#
|
||||||
|
# `_PoolRegistry` keeps one `ThreadedConnectionPool` per database identity
|
||||||
|
# (`host`, `db`, `port`). Every DB call borrows a connection from the pool,
|
||||||
|
# runs its query, and returns the connection. The pool itself guarantees that
|
||||||
|
# no two callers share the same psycopg2 connection at the same time.
|
||||||
|
#
|
||||||
|
# `statement_timeout=30000` (30 s) is a safety net: a runaway query is aborted
|
||||||
|
# instead of hanging forever and poisoning the connection. `connect_timeout=10`
|
||||||
|
# prevents indefinite blocking when the Postgres host is unreachable.
|
||||||
|
|
||||||
|
_DEFAULT_POOL_MIN = 2
|
||||||
|
_DEFAULT_POOL_MAX = 20
|
||||||
|
_STATEMENT_TIMEOUT_MS = 30000
|
||||||
|
_CONNECT_TIMEOUT_S = 10
|
||||||
|
# psycopg2.pool.ThreadedConnectionPool.getconn() does NOT block when the pool
|
||||||
|
# is exhausted — it raises `psycopg2.pool.PoolError` immediately. That makes
|
||||||
|
# bursty workloads (50 concurrent route calls against a max=20 pool) fail
|
||||||
|
# spuriously. `borrowConn()` therefore retries with a small backoff up to
|
||||||
|
# `_BORROW_WAIT_TIMEOUT_S` seconds before giving up.
|
||||||
|
_BORROW_WAIT_TIMEOUT_S = 30.0
|
||||||
|
_BORROW_WAIT_BACKOFF_S = 0.05
|
||||||
|
|
||||||
|
|
||||||
|
def _resolvePoolMax() -> int:
|
||||||
|
"""Pool size is configurable via `DB_POOL_MAX_CONN` (default 20)."""
|
||||||
|
try:
|
||||||
|
return max(2, int(APP_CONFIG.get("DB_POOL_MAX_CONN") or _DEFAULT_POOL_MAX))
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return _DEFAULT_POOL_MAX
|
||||||
|
|
||||||
|
|
||||||
|
class _PoolRegistry:
|
||||||
|
"""Process-wide registry of `ThreadedConnectionPool` instances.
|
||||||
|
|
||||||
|
Keyed by `(host, database, port)` so that two `DatabaseConnector` instances
|
||||||
|
pointing at the same physical database share one pool. Lazy-initialised and
|
||||||
|
thread-safe.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_pools: Dict[tuple, psycopg2.pool.ThreadedConnectionPool] = {}
|
||||||
|
_lock = threading.Lock()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def getPool(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
dbHost: str,
|
||||||
|
dbDatabase: str,
|
||||||
|
dbUser: str,
|
||||||
|
dbPassword: str,
|
||||||
|
dbPort: int,
|
||||||
|
) -> psycopg2.pool.ThreadedConnectionPool:
|
||||||
|
port = int(dbPort) if dbPort is not None else 5432
|
||||||
|
key = (dbHost, dbDatabase, port)
|
||||||
|
# Fast path: pool exists.
|
||||||
|
pool = cls._pools.get(key)
|
||||||
|
if pool is not None:
|
||||||
|
return pool
|
||||||
|
# Slow path: create exactly one pool per key, even under contention.
|
||||||
|
with cls._lock:
|
||||||
|
pool = cls._pools.get(key)
|
||||||
|
if pool is not None:
|
||||||
|
return pool
|
||||||
|
poolMax = _resolvePoolMax()
|
||||||
|
options = f"-c statement_timeout={_STATEMENT_TIMEOUT_MS}"
|
||||||
|
try:
|
||||||
|
pool = psycopg2.pool.ThreadedConnectionPool(
|
||||||
|
_DEFAULT_POOL_MIN,
|
||||||
|
poolMax,
|
||||||
|
host=dbHost,
|
||||||
|
port=port,
|
||||||
|
database=dbDatabase,
|
||||||
|
user=dbUser,
|
||||||
|
password=dbPassword,
|
||||||
|
client_encoding="utf8",
|
||||||
|
cursor_factory=psycopg2.extras.RealDictCursor,
|
||||||
|
connect_timeout=_CONNECT_TIMEOUT_S,
|
||||||
|
options=options,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"Failed to create connection pool for db=%s host=%s: %s",
|
||||||
|
dbDatabase, dbHost, e,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
cls._pools[key] = pool
|
||||||
|
logger.debug(
|
||||||
|
"Created connection pool for db=%s host=%s port=%s (min=%d max=%d, stmt_timeout=%dms)",
|
||||||
|
dbDatabase, dbHost, port, _DEFAULT_POOL_MIN, poolMax, _STATEMENT_TIMEOUT_MS,
|
||||||
|
)
|
||||||
|
return pool
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def closeAll(cls) -> None:
|
||||||
|
"""Close every pool. Call once during FastAPI shutdown."""
|
||||||
|
with cls._lock:
|
||||||
|
for key, pool in list(cls._pools.items()):
|
||||||
|
try:
|
||||||
|
pool.closeall()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Error closing pool %s: %s", key, e)
|
||||||
|
cls._pools.clear()
|
||||||
|
logger.info("All database connection pools closed")
|
||||||
|
|
||||||
|
|
||||||
|
def closeAllPools() -> None:
|
||||||
|
"""Public entry point for FastAPI lifespan shutdown hook."""
|
||||||
|
_PoolRegistry.closeAll()
|
||||||
|
|
||||||
|
|
||||||
|
class _ConnectionShim:
|
||||||
|
"""Backward-compatibility stand-in for the old `DatabaseConnector.connection`.
|
||||||
|
|
||||||
|
Old callers used patterns like::
|
||||||
|
|
||||||
|
db.connection.commit()
|
||||||
|
if db.connection and not db.connection.closed: ...
|
||||||
|
|
||||||
|
These are now no-ops because the pool owns the connection lifecycle.
|
||||||
|
Direct cursor access through `self.connection.cursor()` is intentionally
|
||||||
|
blocked with a clear error so silent breakage is impossible — every such
|
||||||
|
call site must migrate to `db.borrowCursor()`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
closed = False
|
||||||
|
|
||||||
|
def __bool__(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def commit(self) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
def rollback(self) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
def cursor(self, *args, **kwargs):
|
||||||
|
raise RuntimeError(
|
||||||
|
"DatabaseConnector.connection.cursor() is no longer supported. "
|
||||||
|
"Use `db.borrowCursor()` (or `db.borrowConn()` for multi-statement "
|
||||||
|
"transactions) so the connection is borrowed from and returned to "
|
||||||
|
"the pool correctly."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_CONNECTION_SHIM = _ConnectionShim()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Connector cache (lightweight wrappers — actual connections live in the pool)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Multiple call sites (`routeI18n`, `aiAuditLogger`, `mainBackgroundJobService`,
|
||||||
|
# interfaces) ask for a connector via `getCachedConnector(...)` and expect to
|
||||||
|
# get back the same object on subsequent calls. Now that real connection
|
||||||
|
# multiplexing happens at the pool layer, the cache returns lightweight
|
||||||
|
# `DatabaseConnector` wrappers — they hold no connection themselves, only the
|
||||||
|
# DSN params and a reference to the shared pool.
|
||||||
_MAX_CACHED_CONNECTORS = 32
|
_MAX_CACHED_CONNECTORS = 32
|
||||||
_connector_cache: Dict[tuple, "DatabaseConnector"] = {}
|
_connector_cache: Dict[tuple, "DatabaseConnector"] = {}
|
||||||
_connector_cache_order: List[tuple] = [] # FIFO order for eviction
|
_connector_cache_order: List[tuple] = [] # FIFO order for eviction
|
||||||
|
|
@ -223,22 +373,36 @@ def getCachedConnector(
|
||||||
dbPort: int = None,
|
dbPort: int = None,
|
||||||
userId: str = None,
|
userId: str = None,
|
||||||
) -> "DatabaseConnector":
|
) -> "DatabaseConnector":
|
||||||
"""Return cached DatabaseConnector for same (host, database, port) to avoid duplicate PostgreSQL inits.
|
"""Return a cached `DatabaseConnector` wrapper for `(host, database, port)`.
|
||||||
Uses contextvars for userId so concurrent requests sharing the same connector get correct sysCreatedBy/sysModifiedBy.
|
|
||||||
|
Two layers of caching, both intentional:
|
||||||
|
|
||||||
|
1. **Pool layer** (`_PoolRegistry`) — owns the actual psycopg2 connections.
|
||||||
|
Per `(host, db, port)` exactly one `ThreadedConnectionPool`, shared
|
||||||
|
across every wrapper that points at the same physical database. This
|
||||||
|
is what actually saves Postgres connection slots.
|
||||||
|
|
||||||
|
2. **Wrapper layer** (this function) — caches the `DatabaseConnector`
|
||||||
|
Python object. Each wrapper triggers an `initDbSystem()` on first
|
||||||
|
instantiation (CREATE DATABASE if missing, CREATE TABLE _system,
|
||||||
|
pool warm-up). Caching the wrapper avoids paying that bootstrap cost
|
||||||
|
on every request and stops the log from filling with "PostgreSQL
|
||||||
|
database system initialized" lines.
|
||||||
|
|
||||||
|
`userId` is request-scoped via the `_current_user_id` contextvar so two
|
||||||
|
concurrent requests sharing the same cached wrapper still produce
|
||||||
|
correct `sysCreatedBy` / `sysModifiedBy` audit fields.
|
||||||
"""
|
"""
|
||||||
port = int(dbPort) if dbPort is not None else 5432
|
port = int(dbPort) if dbPort is not None else 5432
|
||||||
key = (dbHost, dbDatabase, port)
|
key = (dbHost, dbDatabase, port)
|
||||||
with _connector_cache_lock:
|
with _connector_cache_lock:
|
||||||
if key not in _connector_cache:
|
if key not in _connector_cache:
|
||||||
# Evict oldest if at capacity
|
# FIFO eviction. Connectors are now lightweight (no per-instance
|
||||||
|
# connection), so eviction is purely a memory bookkeeping concern;
|
||||||
|
# the underlying pool stays alive in `_PoolRegistry`.
|
||||||
while len(_connector_cache) >= _MAX_CACHED_CONNECTORS and _connector_cache_order:
|
while len(_connector_cache) >= _MAX_CACHED_CONNECTORS and _connector_cache_order:
|
||||||
oldest_key = _connector_cache_order.pop(0)
|
oldest_key = _connector_cache_order.pop(0)
|
||||||
if oldest_key in _connector_cache:
|
_connector_cache.pop(oldest_key, None)
|
||||||
try:
|
|
||||||
_connector_cache[oldest_key].close(forceClose=True)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error closing evicted connector: {e}")
|
|
||||||
del _connector_cache[oldest_key]
|
|
||||||
_connector_cache[key] = DatabaseConnector(
|
_connector_cache[key] = DatabaseConnector(
|
||||||
dbHost=dbHost,
|
dbHost=dbHost,
|
||||||
dbDatabase=dbDatabase,
|
dbDatabase=dbDatabase,
|
||||||
|
|
@ -282,34 +446,38 @@ class DatabaseConnector:
|
||||||
# Set userId (default to empty string if None)
|
# Set userId (default to empty string if None)
|
||||||
self.userId = userId if userId is not None else ""
|
self.userId = userId if userId is not None else ""
|
||||||
|
|
||||||
# Initialize database system first (creates database if needed)
|
# No per-instance connection any more — real connections live in the
|
||||||
self.connection = None
|
# shared `_PoolRegistry` pool. `_isCachedShared` is retained because
|
||||||
|
# `close(forceClose=False)` callers (interface __del__) still ask.
|
||||||
self._isCachedShared = False
|
self._isCachedShared = False
|
||||||
self.initDbSystem()
|
|
||||||
|
|
||||||
# No caching needed with proper database - PostgreSQL handles performance
|
# pgvector extension state (cached per connector instance — cheap)
|
||||||
|
|
||||||
# Thread safety
|
|
||||||
self._lock = threading.Lock()
|
|
||||||
|
|
||||||
# pgvector extension state
|
|
||||||
self._vectorExtensionEnabled = False
|
self._vectorExtensionEnabled = False
|
||||||
|
|
||||||
# Initialize system table
|
# System table bootstrap: create database, system table, ensure metadata.
|
||||||
self._systemTableName = "_system"
|
self._systemTableName = "_system"
|
||||||
|
self.initDbSystem()
|
||||||
self._initializeSystemTable()
|
self._initializeSystemTable()
|
||||||
|
|
||||||
def initDbSystem(self):
|
def initDbSystem(self):
|
||||||
"""Initialize the database system - creates database and tables."""
|
"""Bootstrap the physical database and the `_system` metadata table.
|
||||||
try:
|
|
||||||
# Create database if it doesn't exist
|
|
||||||
self._create_database_if_not_exists()
|
|
||||||
|
|
||||||
# Create tables
|
Uses a short-lived autocommit connection (NOT the pool) because
|
||||||
|
`CREATE DATABASE` cannot run inside a transaction block. Also warms up
|
||||||
|
the pool by acquiring it for this `(host, db, port)` once.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._create_database_if_not_exists()
|
||||||
self._create_tables()
|
self._create_tables()
|
||||||
|
|
||||||
# Establish connection to the database
|
# Warm the pool so the first request doesn't pay for socket setup.
|
||||||
self._connect()
|
_PoolRegistry.getPool(
|
||||||
|
dbHost=self.dbHost,
|
||||||
|
dbDatabase=self.dbDatabase,
|
||||||
|
dbUser=self.dbUser,
|
||||||
|
dbPassword=self.dbPassword,
|
||||||
|
dbPort=self.dbPort,
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"PostgreSQL database system initialized (db=%s, host=%s, port=%s)",
|
"PostgreSQL database system initialized (db=%s, host=%s, port=%s)",
|
||||||
|
|
@ -319,10 +487,121 @@ class DatabaseConnector:
|
||||||
logger.error(f"FATAL ERROR: Database system initialization failed: {e}")
|
logger.error(f"FATAL ERROR: Database system initialization failed: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _create_database_if_not_exists(self):
|
@property
|
||||||
"""Create the database if it doesn't exist."""
|
def connection(self) -> "_ConnectionShim":
|
||||||
|
"""Backward-compat shim — see `_ConnectionShim` docstring."""
|
||||||
|
return _CONNECTION_SHIM
|
||||||
|
|
||||||
|
def _ensure_connection(self) -> None:
|
||||||
|
"""No-op for backward compatibility.
|
||||||
|
|
||||||
|
Previously this method tested-or-reconnected the per-instance socket.
|
||||||
|
With pooling, the `ThreadedConnectionPool` re-establishes broken
|
||||||
|
connections lazily on the next `getconn()`. Kept as a no-op so that
|
||||||
|
legacy call sites continue to compile.
|
||||||
|
"""
|
||||||
|
return
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def borrowCursor(self):
|
||||||
|
"""Borrow a cursor for one short SQL block.
|
||||||
|
|
||||||
|
Convenience wrapper around `borrowConn()` for the common pattern:
|
||||||
|
|
||||||
|
with db.borrowCursor() as cursor:
|
||||||
|
cursor.execute(...)
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
|
||||||
|
Replaces the legacy `with db.connection.cursor() as cursor:` pattern.
|
||||||
|
Commit/rollback and pool return are handled automatically — callers
|
||||||
|
must NOT call `.commit()`/`.rollback()` on the cursor themselves.
|
||||||
|
"""
|
||||||
|
with self.borrowConn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
|
yield cursor
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def borrowConn(self):
|
||||||
|
"""Borrow a connection from the pool for the duration of the block.
|
||||||
|
|
||||||
|
Pool-exhaustion semantics: `ThreadedConnectionPool.getconn()` raises
|
||||||
|
`psycopg2.pool.PoolError` immediately when the pool is at its `maxconn`
|
||||||
|
limit — it does NOT block. We wrap that with a bounded busy-wait so
|
||||||
|
bursty workloads (e.g. 50 concurrent route calls against a max=20
|
||||||
|
pool) queue up instead of failing. If no connection becomes available
|
||||||
|
within `_BORROW_WAIT_TIMEOUT_S`, the `PoolError` is propagated so a
|
||||||
|
genuinely deadlocked pool surfaces as a real error.
|
||||||
|
|
||||||
|
On normal exit the current transaction is committed (idempotent for
|
||||||
|
read-only queries — leaves the connection in a clean state for the
|
||||||
|
next borrower). On exception the transaction is rolled back and the
|
||||||
|
exception propagates. The connection is **always** returned to the
|
||||||
|
pool, even when commit/rollback fails — otherwise a single bad query
|
||||||
|
would leak a slot and eventually exhaust the pool.
|
||||||
|
"""
|
||||||
|
pool = _PoolRegistry.getPool(
|
||||||
|
dbHost=self.dbHost,
|
||||||
|
dbDatabase=self.dbDatabase,
|
||||||
|
dbUser=self.dbUser,
|
||||||
|
dbPassword=self.dbPassword,
|
||||||
|
dbPort=self.dbPort,
|
||||||
|
)
|
||||||
|
conn = self._acquireConn(pool)
|
||||||
|
try:
|
||||||
|
yield conn
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
conn.rollback()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
# Best-effort commit so the connection goes back to the pool with
|
||||||
|
# no in-flight transaction. Failure here is non-fatal — the pool
|
||||||
|
# will re-establish the socket on the next `getconn()` if needed.
|
||||||
|
try:
|
||||||
|
conn.commit()
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
conn.rollback()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
pool.putconn(conn)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to return connection to pool: %s", e)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _acquireConn(pool: psycopg2.pool.ThreadedConnectionPool):
|
||||||
|
"""Get a connection from the pool, waiting up to `_BORROW_WAIT_TIMEOUT_S`.
|
||||||
|
|
||||||
|
psycopg2's pool throws on exhaustion instead of queueing — this helper
|
||||||
|
polls with a short backoff so callers see queue semantics.
|
||||||
|
"""
|
||||||
|
deadline = time.monotonic() + _BORROW_WAIT_TIMEOUT_S
|
||||||
|
attempt = 0
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
return pool.getconn()
|
||||||
|
except psycopg2.pool.PoolError as e:
|
||||||
|
attempt += 1
|
||||||
|
if time.monotonic() >= deadline:
|
||||||
|
logger.error(
|
||||||
|
"Connection pool exhausted after %.1fs wait (%d retries)",
|
||||||
|
_BORROW_WAIT_TIMEOUT_S, attempt,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
time.sleep(_BORROW_WAIT_BACKOFF_S)
|
||||||
|
|
||||||
|
def _create_database_if_not_exists(self):
|
||||||
|
"""Create the database if it doesn't exist.
|
||||||
|
|
||||||
|
Uses an autocommit connection on the `postgres` admin DB because
|
||||||
|
`CREATE DATABASE` cannot run inside a transaction block — so this
|
||||||
|
path intentionally does NOT use the pool.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
# Use the configured user for database creation
|
|
||||||
conn = psycopg2.connect(
|
conn = psycopg2.connect(
|
||||||
host=self.dbHost,
|
host=self.dbHost,
|
||||||
port=self.dbPort,
|
port=self.dbPort,
|
||||||
|
|
@ -330,22 +609,20 @@ class DatabaseConnector:
|
||||||
user=self.dbUser,
|
user=self.dbUser,
|
||||||
password=self.dbPassword,
|
password=self.dbPassword,
|
||||||
client_encoding="utf8",
|
client_encoding="utf8",
|
||||||
|
connect_timeout=_CONNECT_TIMEOUT_S,
|
||||||
)
|
)
|
||||||
conn.autocommit = True
|
conn.autocommit = True
|
||||||
|
try:
|
||||||
with conn.cursor() as cursor:
|
with conn.cursor() as cursor:
|
||||||
# Check if database exists
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,)
|
"SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,)
|
||||||
)
|
)
|
||||||
exists = cursor.fetchone()
|
exists = cursor.fetchone()
|
||||||
|
|
||||||
if not exists:
|
if not exists:
|
||||||
# Create database with proper quoting for names with hyphens
|
|
||||||
quoted_db_name = f'"{self.dbDatabase}"'
|
quoted_db_name = f'"{self.dbDatabase}"'
|
||||||
cursor.execute(f"CREATE DATABASE {quoted_db_name}")
|
cursor.execute(f"CREATE DATABASE {quoted_db_name}")
|
||||||
logger.info(f"Created database: {self.dbDatabase}")
|
logger.info(f"Created database: {self.dbDatabase}")
|
||||||
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -356,9 +633,12 @@ class DatabaseConnector:
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_tables(self):
|
def _create_tables(self):
|
||||||
"""Create only the system table - application tables are created by interfaces."""
|
"""Create the `_system` table.
|
||||||
|
|
||||||
|
Uses a short-lived autocommit connection (not the pool) — runs exactly
|
||||||
|
once at connector creation.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
# Use the configured user for table creation
|
|
||||||
conn = psycopg2.connect(
|
conn = psycopg2.connect(
|
||||||
host=self.dbHost,
|
host=self.dbHost,
|
||||||
port=self.dbPort,
|
port=self.dbPort,
|
||||||
|
|
@ -366,11 +646,11 @@ class DatabaseConnector:
|
||||||
user=self.dbUser,
|
user=self.dbUser,
|
||||||
password=self.dbPassword,
|
password=self.dbPassword,
|
||||||
client_encoding="utf8",
|
client_encoding="utf8",
|
||||||
|
connect_timeout=_CONNECT_TIMEOUT_S,
|
||||||
)
|
)
|
||||||
conn.autocommit = True
|
conn.autocommit = True
|
||||||
|
try:
|
||||||
with conn.cursor() as cursor:
|
with conn.cursor() as cursor:
|
||||||
# Create only the system table
|
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
CREATE TABLE IF NOT EXISTS _system (
|
CREATE TABLE IF NOT EXISTS _system (
|
||||||
id SERIAL PRIMARY KEY,
|
id SERIAL PRIMARY KEY,
|
||||||
|
|
@ -382,6 +662,7 @@ class DatabaseConnector:
|
||||||
"sysModifiedBy" VARCHAR(255)
|
"sysModifiedBy" VARCHAR(255)
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -391,67 +672,26 @@ class DatabaseConnector:
|
||||||
)
|
)
|
||||||
raise RuntimeError(f"FATAL ERROR: Cannot create system table: {e}")
|
raise RuntimeError(f"FATAL ERROR: Cannot create system table: {e}")
|
||||||
|
|
||||||
def _connect(self):
|
|
||||||
"""Establish connection to PostgreSQL database."""
|
|
||||||
try:
|
|
||||||
# Use configured user for main connection with proper parameter handling
|
|
||||||
self.connection = psycopg2.connect(
|
|
||||||
host=self.dbHost,
|
|
||||||
port=self.dbPort,
|
|
||||||
database=self.dbDatabase,
|
|
||||||
user=self.dbUser,
|
|
||||||
password=self.dbPassword,
|
|
||||||
client_encoding="utf8",
|
|
||||||
cursor_factory=psycopg2.extras.RealDictCursor,
|
|
||||||
)
|
|
||||||
self.connection.autocommit = False # Use transactions
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to connect to PostgreSQL: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _ensure_connection(self):
|
|
||||||
"""Ensure database connection is alive, reconnect if necessary."""
|
|
||||||
try:
|
|
||||||
if self.connection is None or self.connection.closed:
|
|
||||||
self._connect()
|
|
||||||
else:
|
|
||||||
# Test connection with a simple query
|
|
||||||
with self.connection.cursor() as cursor:
|
|
||||||
cursor.execute("SELECT 1")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Connection lost, reconnecting: {e}")
|
|
||||||
self._connect()
|
|
||||||
|
|
||||||
def _initializeSystemTable(self):
|
def _initializeSystemTable(self):
|
||||||
"""Initializes the system table if it doesn't exist yet."""
|
"""Initializes the system table if it doesn't exist yet."""
|
||||||
try:
|
try:
|
||||||
# First ensure the system table exists
|
|
||||||
self._ensureTableExists(SystemTable)
|
self._ensureTableExists(SystemTable)
|
||||||
|
with self.borrowConn() as conn:
|
||||||
with self.connection.cursor() as cursor:
|
with conn.cursor() as cursor:
|
||||||
# Check if system table has any data
|
|
||||||
cursor.execute('SELECT COUNT(*) FROM "_system"')
|
cursor.execute('SELECT COUNT(*) FROM "_system"')
|
||||||
row = cursor.fetchone()
|
cursor.fetchone() # noqa: just verifies table is readable
|
||||||
count = row["count"] if row else 0
|
|
||||||
|
|
||||||
self.connection.commit()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error initializing system table: {e}")
|
logger.error(f"Error initializing system table: {e}")
|
||||||
self.connection.rollback()
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _loadSystemTable(self) -> Dict[str, str]:
|
def _loadSystemTable(self) -> Dict[str, str]:
|
||||||
"""Loads the system table with the initial IDs."""
|
"""Loads the system table with the initial IDs."""
|
||||||
try:
|
try:
|
||||||
with self.connection.cursor() as cursor:
|
with self.borrowConn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
cursor.execute('SELECT "table_name", "initial_id" FROM "_system"')
|
cursor.execute('SELECT "table_name", "initial_id" FROM "_system"')
|
||||||
rows = cursor.fetchall()
|
rows = cursor.fetchall()
|
||||||
|
return {row["table_name"]: row["initial_id"] for row in rows}
|
||||||
system_data = {}
|
|
||||||
for row in rows:
|
|
||||||
system_data[row["table_name"]] = row["initial_id"]
|
|
||||||
|
|
||||||
return system_data
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error loading system table: {e}")
|
logger.error(f"Error loading system table: {e}")
|
||||||
return {}
|
return {}
|
||||||
|
|
@ -459,11 +699,9 @@ class DatabaseConnector:
|
||||||
def _saveSystemTable(self, data: Dict[str, str]) -> bool:
|
def _saveSystemTable(self, data: Dict[str, str]) -> bool:
|
||||||
"""Saves the system table with the initial IDs."""
|
"""Saves the system table with the initial IDs."""
|
||||||
try:
|
try:
|
||||||
with self.connection.cursor() as cursor:
|
with self.borrowConn() as conn:
|
||||||
# Clear existing data
|
with conn.cursor() as cursor:
|
||||||
cursor.execute('DELETE FROM "_system"')
|
cursor.execute('DELETE FROM "_system"')
|
||||||
|
|
||||||
# Insert new data
|
|
||||||
for table_name, initial_id in data.items():
|
for table_name, initial_id in data.items():
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
|
|
@ -472,21 +710,16 @@ class DatabaseConnector:
|
||||||
""",
|
""",
|
||||||
(table_name, initial_id, getUtcTimestamp()),
|
(table_name, initial_id, getUtcTimestamp()),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.connection.commit()
|
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error saving system table: {e}")
|
logger.error(f"Error saving system table: {e}")
|
||||||
self.connection.rollback()
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _ensureSystemTableExists(self) -> bool:
|
def _ensureSystemTableExists(self) -> bool:
|
||||||
"""Ensures the system table exists, creates it if it doesn't."""
|
"""Ensures the system table exists, creates it if it doesn't."""
|
||||||
try:
|
try:
|
||||||
self._ensure_connection()
|
with self.borrowConn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
with self.connection.cursor() as cursor:
|
|
||||||
# Check if system table exists
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s",
|
"SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s",
|
||||||
(self._systemTableName,),
|
(self._systemTableName,),
|
||||||
|
|
@ -494,7 +727,6 @@ class DatabaseConnector:
|
||||||
exists = cursor.fetchone()["count"] > 0
|
exists = cursor.fetchone()["count"] > 0
|
||||||
|
|
||||||
if not exists:
|
if not exists:
|
||||||
# Create system table
|
|
||||||
cursor.execute(f"""
|
cursor.execute(f"""
|
||||||
CREATE TABLE "{self._systemTableName}" (
|
CREATE TABLE "{self._systemTableName}" (
|
||||||
"table_name" VARCHAR(255) PRIMARY KEY,
|
"table_name" VARCHAR(255) PRIMARY KEY,
|
||||||
|
|
@ -507,7 +739,6 @@ class DatabaseConnector:
|
||||||
""")
|
""")
|
||||||
logger.info("System table created successfully")
|
logger.info("System table created successfully")
|
||||||
else:
|
else:
|
||||||
# Check if we need to add missing columns to existing table
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
SELECT column_name FROM information_schema.columns
|
SELECT column_name FROM information_schema.columns
|
||||||
|
|
@ -527,7 +758,6 @@ class DatabaseConnector:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "{sys_col}" {sys_sql}'
|
f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "{sys_col}" {sys_sql}'
|
||||||
)
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error ensuring system table exists: {e}")
|
logger.error(f"Error ensuring system table exists: {e}")
|
||||||
|
|
@ -542,10 +772,8 @@ class DatabaseConnector:
|
||||||
return self._ensureSystemTableExists()
|
return self._ensureSystemTableExists()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._ensure_connection()
|
with self.borrowConn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
with self.connection.cursor() as cursor:
|
|
||||||
# Check if table exists by querying information_schema with case-insensitive search
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"""
|
"""
|
||||||
SELECT COUNT(*) FROM information_schema.tables
|
SELECT COUNT(*) FROM information_schema.tables
|
||||||
|
|
@ -556,7 +784,6 @@ class DatabaseConnector:
|
||||||
exists = cursor.fetchone()["count"] > 0
|
exists = cursor.fetchone()["count"] > 0
|
||||||
|
|
||||||
if not exists:
|
if not exists:
|
||||||
# Create table from Pydantic model
|
|
||||||
self._create_table_from_model(cursor, table, model_class)
|
self._create_table_from_model(cursor, table, model_class)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Created table '{table}' with columns from Pydantic model"
|
f"Created table '{table}' with columns from Pydantic model"
|
||||||
|
|
@ -581,15 +808,12 @@ class DatabaseConnector:
|
||||||
for row in existing_column_rows
|
for row in existing_column_rows
|
||||||
}
|
}
|
||||||
|
|
||||||
# Desired columns based on model
|
|
||||||
model_fields = getModelFields(model_class)
|
model_fields = getModelFields(model_class)
|
||||||
desired_columns = set(["id"]) | set(model_fields.keys())
|
desired_columns = set(["id"]) | set(model_fields.keys())
|
||||||
|
|
||||||
# Add missing columns
|
|
||||||
for col in sorted(desired_columns - existing_columns):
|
for col in sorted(desired_columns - existing_columns):
|
||||||
# Determine SQL type
|
|
||||||
if col in ["id"]:
|
if col in ["id"]:
|
||||||
continue # primary key exists already
|
continue
|
||||||
sql_type = model_fields.get(col)
|
sql_type = model_fields.get(col)
|
||||||
if not sql_type:
|
if not sql_type:
|
||||||
sql_type = "TEXT"
|
sql_type = "TEXT"
|
||||||
|
|
@ -652,13 +876,9 @@ class DatabaseConnector:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Could not ensure columns for existing table '{table}': {ensure_err}"
|
f"Could not ensure columns for existing table '{table}': {ensure_err}"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.connection.commit()
|
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error ensuring table {table} exists: {e}")
|
logger.error(f"Error ensuring table {table} exists: {e}")
|
||||||
if hasattr(self, "connection") and self.connection:
|
|
||||||
self.connection.rollback()
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _ensureVectorExtension(self) -> bool:
|
def _ensureVectorExtension(self) -> bool:
|
||||||
|
|
@ -666,17 +886,14 @@ class DatabaseConnector:
|
||||||
if self._vectorExtensionEnabled:
|
if self._vectorExtensionEnabled:
|
||||||
return True
|
return True
|
||||||
try:
|
try:
|
||||||
self._ensure_connection()
|
with self.borrowConn() as conn:
|
||||||
with self.connection.cursor() as cursor:
|
with conn.cursor() as cursor:
|
||||||
cursor.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
cursor.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
||||||
self.connection.commit()
|
|
||||||
self._vectorExtensionEnabled = True
|
self._vectorExtensionEnabled = True
|
||||||
logger.info("pgvector extension enabled")
|
logger.info("pgvector extension enabled")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to enable pgvector extension: {e}")
|
logger.error(f"Failed to enable pgvector extension: {e}")
|
||||||
if hasattr(self, "connection") and self.connection:
|
|
||||||
self.connection.rollback()
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _create_table_from_model(self, cursor, table: str, model_class: type) -> None:
|
def _create_table_from_model(self, cursor, table: str, model_class: type) -> None:
|
||||||
|
|
@ -791,22 +1008,19 @@ class DatabaseConnector:
|
||||||
if not self._ensureTableExists(model_class):
|
if not self._ensureTableExists(model_class):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
with self.connection.cursor() as cursor:
|
with self.borrowConn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
cursor.execute(f'SELECT * FROM "{table}" WHERE "id" = %s', (recordId,))
|
cursor.execute(f'SELECT * FROM "{table}" WHERE "id" = %s', (recordId,))
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
if not row:
|
if not row:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Convert row to dict and handle JSONB fields
|
|
||||||
record = dict(row)
|
record = dict(row)
|
||||||
|
|
||||||
fields = getModelFields(model_class)
|
fields = getModelFields(model_class)
|
||||||
|
|
||||||
parseRecordFields(record, fields, f"record {recordId}")
|
parseRecordFields(record, fields, f"record {recordId}")
|
||||||
|
|
||||||
return record
|
return record
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error loading record {recordId} from table {table}: {e}")
|
logger.error(f"Error loading record {recordId} from table {table}: {e}")
|
||||||
_rollbackQuietly(getattr(self, "connection", None))
|
|
||||||
raise DatabaseQueryError(table, str(e), original=e) from e
|
raise DatabaseQueryError(table, str(e), original=e) from e
|
||||||
|
|
||||||
def getRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]:
|
def getRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]:
|
||||||
|
|
@ -849,14 +1063,12 @@ class DatabaseConnector:
|
||||||
if effective_user_id:
|
if effective_user_id:
|
||||||
record["sysModifiedBy"] = effective_user_id
|
record["sysModifiedBy"] = effective_user_id
|
||||||
|
|
||||||
with self.connection.cursor() as cursor:
|
with self.borrowConn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
self._save_record(cursor, table, recordId, record, model_class)
|
self._save_record(cursor, table, recordId, record, model_class)
|
||||||
|
|
||||||
self.connection.commit()
|
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error saving record {recordId} to table {table}: {e}")
|
logger.error(f"Error saving record {recordId} to table {table}: {e}")
|
||||||
self.connection.rollback()
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _loadTable(self, model_class: type) -> List[Dict[str, Any]]:
|
def _loadTable(self, model_class: type) -> List[Dict[str, Any]]:
|
||||||
|
|
@ -870,7 +1082,8 @@ class DatabaseConnector:
|
||||||
if not self._ensureTableExists(model_class):
|
if not self._ensureTableExists(model_class):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
with self.connection.cursor() as cursor:
|
with self.borrowConn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"')
|
cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"')
|
||||||
records = [dict(row) for row in cursor.fetchall()]
|
records = [dict(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
|
@ -878,7 +1091,6 @@ class DatabaseConnector:
|
||||||
modelFields = model_class.model_fields
|
modelFields = model_class.model_fields
|
||||||
for record in records:
|
for record in records:
|
||||||
parseRecordFields(record, fields, f"table {table}")
|
parseRecordFields(record, fields, f"table {table}")
|
||||||
# Set type-aware defaults for NULL JSONB fields
|
|
||||||
for fieldName, fieldType in fields.items():
|
for fieldName, fieldType in fields.items():
|
||||||
if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
|
if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
|
||||||
fieldInfo = modelFields.get(fieldName)
|
fieldInfo = modelFields.get(fieldName)
|
||||||
|
|
@ -896,7 +1108,6 @@ class DatabaseConnector:
|
||||||
return records
|
return records
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error loading table {table}: {e}")
|
logger.error(f"Error loading table {table}: {e}")
|
||||||
_rollbackQuietly(getattr(self, "connection", None))
|
|
||||||
raise DatabaseQueryError(table, str(e), original=e) from e
|
raise DatabaseQueryError(table, str(e), original=e) from e
|
||||||
|
|
||||||
def _registerInitialId(self, table: str, initialId: str) -> bool:
|
def _registerInitialId(self, table: str, initialId: str) -> bool:
|
||||||
|
|
@ -969,17 +1180,10 @@ class DatabaseConnector:
|
||||||
|
|
||||||
def getTables(self) -> List[str]:
|
def getTables(self) -> List[str]:
|
||||||
"""Returns a list of all available tables."""
|
"""Returns a list of all available tables."""
|
||||||
tables = []
|
tables: List[str] = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Ensure connection is alive
|
with self.borrowConn() as conn:
|
||||||
self._ensure_connection()
|
with conn.cursor() as cursor:
|
||||||
|
|
||||||
if not self.connection or self.connection.closed:
|
|
||||||
logger.error("Database connection is not available")
|
|
||||||
return tables
|
|
||||||
|
|
||||||
with self.connection.cursor() as cursor:
|
|
||||||
cursor.execute("""
|
cursor.execute("""
|
||||||
SELECT table_name
|
SELECT table_name
|
||||||
FROM information_schema.tables
|
FROM information_schema.tables
|
||||||
|
|
@ -990,7 +1194,6 @@ class DatabaseConnector:
|
||||||
tables = [row["table_name"] for row in rows]
|
tables = [row["table_name"] for row in rows]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error reading the database {self.dbDatabase}: {e}")
|
logger.error(f"Error reading the database {self.dbDatabase}: {e}")
|
||||||
|
|
||||||
return tables
|
return tables
|
||||||
|
|
||||||
def getFields(self, model_class: type) -> List[str]:
|
def getFields(self, model_class: type) -> List[str]:
|
||||||
|
|
@ -1060,7 +1263,8 @@ class DatabaseConnector:
|
||||||
|
|
||||||
query = f'SELECT * FROM "{table}"{where_clause} ORDER BY "id"'
|
query = f'SELECT * FROM "{table}"{where_clause} ORDER BY "id"'
|
||||||
|
|
||||||
with self.connection.cursor() as cursor:
|
with self.borrowConn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
cursor.execute(query, where_values)
|
cursor.execute(query, where_values)
|
||||||
records = [dict(row) for row in cursor.fetchall()]
|
records = [dict(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
|
@ -1082,7 +1286,6 @@ class DatabaseConnector:
|
||||||
fieldAnnotation.__origin__ is dict)):
|
fieldAnnotation.__origin__ is dict)):
|
||||||
record[fieldName] = {}
|
record[fieldName] = {}
|
||||||
|
|
||||||
# If fieldFilter is available, reduce the fields
|
|
||||||
if fieldFilter and isinstance(fieldFilter, list):
|
if fieldFilter and isinstance(fieldFilter, list):
|
||||||
result = []
|
result = []
|
||||||
for record in records:
|
for record in records:
|
||||||
|
|
@ -1096,7 +1299,6 @@ class DatabaseConnector:
|
||||||
return records
|
return records
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error loading records from table {table}: {e}")
|
logger.error(f"Error loading records from table {table}: {e}")
|
||||||
_rollbackQuietly(getattr(self, "connection", None))
|
|
||||||
raise DatabaseQueryError(table, str(e), original=e) from e
|
raise DatabaseQueryError(table, str(e), original=e) from e
|
||||||
|
|
||||||
def _buildPaginationClauses(
|
def _buildPaginationClauses(
|
||||||
|
|
@ -1281,7 +1483,8 @@ class DatabaseConnector:
|
||||||
where_clause, order_clause, limit_clause, values, count_values = \
|
where_clause, order_clause, limit_clause, values, count_values = \
|
||||||
self._buildPaginationClauses(model_class, pagination, recordFilter)
|
self._buildPaginationClauses(model_class, pagination, recordFilter)
|
||||||
|
|
||||||
with self.connection.cursor() as cursor:
|
with self.borrowConn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
countSql = f'SELECT COUNT(*) FROM "{table}"{where_clause}'
|
countSql = f'SELECT COUNT(*) FROM "{table}"{where_clause}'
|
||||||
dataSql = f'SELECT * FROM "{table}"{where_clause}{order_clause}{limit_clause}'
|
dataSql = f'SELECT * FROM "{table}"{where_clause}{order_clause}{limit_clause}'
|
||||||
cursor.execute(countSql, count_values)
|
cursor.execute(countSql, count_values)
|
||||||
|
|
@ -1320,7 +1523,6 @@ class DatabaseConnector:
|
||||||
return {"items": records, "totalItems": totalItems, "totalPages": totalPages}
|
return {"items": records, "totalItems": totalItems, "totalPages": totalPages}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in getRecordsetPaginated for table {table}: {e}")
|
logger.error(f"Error in getRecordsetPaginated for table {table}: {e}")
|
||||||
_rollbackQuietly(getattr(self, "connection", None))
|
|
||||||
raise DatabaseQueryError(table, str(e), original=e) from e
|
raise DatabaseQueryError(table, str(e), original=e) from e
|
||||||
|
|
||||||
def getDistinctColumnValues(
|
def getDistinctColumnValues(
|
||||||
|
|
@ -1365,7 +1567,8 @@ class DatabaseConnector:
|
||||||
else:
|
else:
|
||||||
sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}" WHERE {nonNullCond} ORDER BY val'
|
sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}" WHERE {nonNullCond} ORDER BY val'
|
||||||
|
|
||||||
with self.connection.cursor() as cursor:
|
with self.borrowConn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
cursor.execute(sql, values)
|
cursor.execute(sql, values)
|
||||||
result: List[Optional[str]] = [row["val"] for row in cursor.fetchall()]
|
result: List[Optional[str]] = [row["val"] for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
|
@ -1375,7 +1578,6 @@ class DatabaseConnector:
|
||||||
emptySql = f'SELECT 1 FROM "{table}"{where_clause} AND ({emptyCond}) LIMIT 1'
|
emptySql = f'SELECT 1 FROM "{table}"{where_clause} AND ({emptyCond}) LIMIT 1'
|
||||||
else:
|
else:
|
||||||
emptySql = f'SELECT 1 FROM "{table}" WHERE ({emptyCond}) LIMIT 1'
|
emptySql = f'SELECT 1 FROM "{table}" WHERE ({emptyCond}) LIMIT 1'
|
||||||
with self.connection.cursor() as cursor:
|
|
||||||
cursor.execute(emptySql, values)
|
cursor.execute(emptySql, values)
|
||||||
if cursor.fetchone():
|
if cursor.fetchone():
|
||||||
result.append(None)
|
result.append(None)
|
||||||
|
|
@ -1383,7 +1585,6 @@ class DatabaseConnector:
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in getDistinctColumnValues for {table}.{column}: {e}")
|
logger.error(f"Error in getDistinctColumnValues for {table}.{column}: {e}")
|
||||||
_rollbackQuietly(getattr(self, "connection", None))
|
|
||||||
raise DatabaseQueryError(table, str(e), original=e) from e
|
raise DatabaseQueryError(table, str(e), original=e) from e
|
||||||
|
|
||||||
def recordCreate(
|
def recordCreate(
|
||||||
|
|
@ -1463,33 +1664,33 @@ class DatabaseConnector:
|
||||||
if not self._ensureTableExists(model_class):
|
if not self._ensureTableExists(model_class):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
with self.connection.cursor() as cursor:
|
# `getInitialId` opens its own borrow; do it BEFORE we acquire a
|
||||||
# Check if record exists
|
# connection ourselves so we don't pin two slots concurrently.
|
||||||
|
initialId = self.getInitialId(model_class)
|
||||||
|
|
||||||
|
with self.borrowConn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,)
|
f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,)
|
||||||
)
|
)
|
||||||
if not cursor.fetchone():
|
if not cursor.fetchone():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check if it's an initial record
|
if initialId is not None and initialId == recordId:
|
||||||
initialId = self.getInitialId(model_class)
|
# `_removeInitialId` borrows its own conn — done outside
|
||||||
|
# this block on purpose to avoid nested borrows.
|
||||||
|
pass
|
||||||
|
cursor.execute(f'DELETE FROM "{table}" WHERE "id" = %s', (recordId,))
|
||||||
|
|
||||||
if initialId is not None and initialId == recordId:
|
if initialId is not None and initialId == recordId:
|
||||||
self._removeInitialId(table)
|
self._removeInitialId(table)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Initial ID {recordId} for table {table} has been removed from the system table"
|
f"Initial ID {recordId} for table {table} has been removed from the system table"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Delete the record
|
|
||||||
cursor.execute(f'DELETE FROM "{table}" WHERE "id" = %s', (recordId,))
|
|
||||||
|
|
||||||
# No cache to update - database handles consistency
|
|
||||||
|
|
||||||
self.connection.commit()
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting record {recordId} from table {table}: {e}")
|
logger.error(f"Error deleting record {recordId} from table {table}: {e}")
|
||||||
self.connection.rollback()
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def recordCreateBulk(
|
def recordCreateBulk(
|
||||||
|
|
@ -1559,16 +1760,11 @@ class DatabaseConnector:
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._ensure_connection()
|
with self.borrowConn() as conn:
|
||||||
with self.connection.cursor() as cursor:
|
with conn.cursor() as cursor:
|
||||||
psycopg2.extras.execute_values(cursor, sql, rows, page_size=500)
|
psycopg2.extras.execute_values(cursor, sql, rows, page_size=500)
|
||||||
self.connection.commit()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Bulk insert into {table} failed (n={len(rows)}): {e}")
|
logger.error(f"Bulk insert into {table} failed (n={len(rows)}): {e}")
|
||||||
try:
|
|
||||||
self.connection.rollback()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if self.getInitialId(model_class) is None and normalised:
|
if self.getInitialId(model_class) is None and normalised:
|
||||||
|
|
@ -1649,8 +1845,8 @@ class DatabaseConnector:
|
||||||
|
|
||||||
initialId = self.getInitialId(model_class)
|
initialId = self.getInitialId(model_class)
|
||||||
try:
|
try:
|
||||||
self._ensure_connection()
|
with self.borrowConn() as conn:
|
||||||
with self.connection.cursor() as cursor:
|
with conn.cursor() as cursor:
|
||||||
if initialId is not None:
|
if initialId is not None:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
f'SELECT 1 FROM "{table}" WHERE "id" = %s AND ' + whereSql,
|
f'SELECT 1 FROM "{table}" WHERE "id" = %s AND ' + whereSql,
|
||||||
|
|
@ -1662,13 +1858,8 @@ class DatabaseConnector:
|
||||||
|
|
||||||
cursor.execute(f'DELETE FROM "{table}" WHERE ' + whereSql, params)
|
cursor.execute(f'DELETE FROM "{table}" WHERE ' + whereSql, params)
|
||||||
deleted = cursor.rowcount or 0
|
deleted = cursor.rowcount or 0
|
||||||
self.connection.commit()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Bulk delete from {table} failed (filter={recordFilter}): {e}")
|
logger.error(f"Bulk delete from {table} failed (filter={recordFilter}): {e}")
|
||||||
try:
|
|
||||||
self.connection.rollback()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if deleted and initialIsAffected:
|
if deleted and initialIsAffected:
|
||||||
|
|
@ -1751,39 +1942,30 @@ class DatabaseConnector:
|
||||||
)
|
)
|
||||||
params = [vectorStr] + whereValues + [vectorStr, limit]
|
params = [vectorStr] + whereValues + [vectorStr, limit]
|
||||||
|
|
||||||
with self.connection.cursor() as cursor:
|
with self.borrowConn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
cursor.execute(query, params)
|
cursor.execute(query, params)
|
||||||
records = [dict(row) for row in cursor.fetchall()]
|
records = [dict(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
fields = getModelFields(modelClass)
|
fields = getModelFields(modelClass)
|
||||||
for record in records:
|
for record in records:
|
||||||
parseRecordFields(record, fields, f"semanticSearch {table}")
|
parseRecordFields(record, fields, f"semanticSearch {table}")
|
||||||
|
|
||||||
return records
|
return records
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in semantic search on {table}: {e}")
|
logger.error(f"Error in semantic search on {table}: {e}")
|
||||||
_rollbackQuietly(getattr(self, "connection", None))
|
|
||||||
raise DatabaseQueryError(table, str(e), original=e) from e
|
raise DatabaseQueryError(table, str(e), original=e) from e
|
||||||
|
|
||||||
def close(self, forceClose: bool = False):
|
def close(self, forceClose: bool = False):
|
||||||
"""Close the database connection.
|
"""No-op for backward compatibility.
|
||||||
|
|
||||||
Shared cached connectors are intentionally kept open unless forceClose=True.
|
Connections are now owned by the `_PoolRegistry` pool and live for the
|
||||||
This prevents accidental shutdown from interface __del__ methods while
|
process lifetime. Pool shutdown happens centrally via `closeAllPools()`
|
||||||
other requests are still using the same cached connector instance.
|
from the FastAPI lifespan hook — never from a connector instance.
|
||||||
|
Interface `__del__` paths used to call `close()` to release a per-
|
||||||
|
connector socket; with pooling there is nothing to close here.
|
||||||
"""
|
"""
|
||||||
if self._isCachedShared and not forceClose:
|
|
||||||
return
|
return
|
||||||
if (
|
|
||||||
hasattr(self, "connection")
|
|
||||||
and self.connection
|
|
||||||
and not self.connection.closed
|
|
||||||
):
|
|
||||||
self.connection.close()
|
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
"""Cleanup method to close connection."""
|
"""Cleanup hook (intentionally no-op — see `close`)."""
|
||||||
try:
|
return
|
||||||
self.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
|
||||||
|
|
@ -342,7 +342,7 @@ class RealEstateObjects:
|
||||||
# If no exact match, try case-insensitive search via SQL query
|
# If no exact match, try case-insensitive search via SQL query
|
||||||
# This handles cases where the name might have different casing
|
# This handles cases where the name might have different casing
|
||||||
self.db._ensure_connection()
|
self.db._ensure_connection()
|
||||||
with self.db.connection.cursor() as cursor:
|
with self.db.borrowCursor() as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
'SELECT "id" FROM "Gemeinde" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
|
'SELECT "id" FROM "Gemeinde" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
|
||||||
(name,)
|
(name,)
|
||||||
|
|
@ -375,7 +375,7 @@ class RealEstateObjects:
|
||||||
|
|
||||||
# Try case-insensitive search
|
# Try case-insensitive search
|
||||||
self.db._ensure_connection()
|
self.db._ensure_connection()
|
||||||
with self.db.connection.cursor() as cursor:
|
with self.db.borrowCursor() as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
'SELECT "id" FROM "Kanton" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
|
'SELECT "id" FROM "Kanton" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
|
||||||
(name,)
|
(name,)
|
||||||
|
|
@ -408,7 +408,7 @@ class RealEstateObjects:
|
||||||
|
|
||||||
# Try case-insensitive search
|
# Try case-insensitive search
|
||||||
self.db._ensure_connection()
|
self.db._ensure_connection()
|
||||||
with self.db.connection.cursor() as cursor:
|
with self.db.borrowCursor() as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
'SELECT "id" FROM "Land" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
|
'SELECT "id" FROM "Land" WHERE LOWER("label") = LOWER(%s) LIMIT 1',
|
||||||
(name,)
|
(name,)
|
||||||
|
|
@ -840,7 +840,7 @@ class RealEstateObjects:
|
||||||
# Ensure connection is alive
|
# Ensure connection is alive
|
||||||
self.db._ensure_connection()
|
self.db._ensure_connection()
|
||||||
|
|
||||||
with self.db.connection.cursor() as cursor:
|
with self.db.borrowCursor() as cursor:
|
||||||
# Execute query
|
# Execute query
|
||||||
if parameters:
|
if parameters:
|
||||||
# Use parameterized query for safety
|
# Use parameterized query for safety
|
||||||
|
|
|
||||||
|
|
@ -1659,7 +1659,7 @@ class BillingObjects:
|
||||||
try:
|
try:
|
||||||
appInterface = getAppInterface(self.currentUser)
|
appInterface = getAppInterface(self.currentUser)
|
||||||
appInterface.db._ensure_connection()
|
appInterface.db._ensure_connection()
|
||||||
with appInterface.db.connection.cursor() as cur:
|
with appInterface.db.borrowCursor() as cur:
|
||||||
if appInterface.db._ensureTableExists(UserInDB):
|
if appInterface.db._ensureTableExists(UserInDB):
|
||||||
cur.execute(
|
cur.execute(
|
||||||
'SELECT "id" FROM "UserInDB" WHERE '
|
'SELECT "id" FROM "UserInDB" WHERE '
|
||||||
|
|
@ -1780,7 +1780,7 @@ class BillingObjects:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.db._ensure_connection()
|
self.db._ensure_connection()
|
||||||
with self.db.connection.cursor() as cur:
|
with self.db.borrowCursor() as cur:
|
||||||
countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
|
countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
|
||||||
cur.execute(countSql, whereValues)
|
cur.execute(countSql, whereValues)
|
||||||
totalItems = cur.fetchone()["count"]
|
totalItems = cur.fetchone()["count"]
|
||||||
|
|
@ -1797,10 +1797,7 @@ class BillingObjects:
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"_searchTransactionsPaginated SQL error: {e}", exc_info=True)
|
logger.error(f"_searchTransactionsPaginated SQL error: {e}", exc_info=True)
|
||||||
try:
|
# Rollback is handled by `borrowCursor()` context manager on exit.
|
||||||
self.db.connection.rollback()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return {"items": [], "totalItems": 0, "totalPages": 0}
|
return {"items": [], "totalItems": 0, "totalPages": 0}
|
||||||
|
|
||||||
def _buildScopeFilter(
|
def _buildScopeFilter(
|
||||||
|
|
@ -1872,7 +1869,7 @@ class BillingObjects:
|
||||||
|
|
||||||
result: Dict[str, Any] = {}
|
result: Dict[str, Any] = {}
|
||||||
|
|
||||||
with self.db.connection.cursor() as cur:
|
with self.db.borrowCursor() as cur:
|
||||||
# 1) Totals
|
# 1) Totals
|
||||||
cur.execute(
|
cur.execute(
|
||||||
f'SELECT COALESCE(SUM("amount"), 0) AS total, COUNT(*) AS cnt FROM "{table}"{whereClause}',
|
f'SELECT COALESCE(SUM("amount"), 0) AS total, COUNT(*) AS cnt FROM "{table}"{whereClause}',
|
||||||
|
|
@ -1947,17 +1944,12 @@ class BillingObjects:
|
||||||
})
|
})
|
||||||
result["timeSeries"] = timeSeries
|
result["timeSeries"] = timeSeries
|
||||||
|
|
||||||
self.db.connection.commit()
|
# Commit/rollback are handled by `borrowCursor()` context manager.
|
||||||
|
|
||||||
result["_allAccounts"] = allAccounts
|
result["_allAccounts"] = allAccounts
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in getTransactionStatisticsAggregated: {e}", exc_info=True)
|
logger.error(f"Error in getTransactionStatisticsAggregated: {e}", exc_info=True)
|
||||||
try:
|
|
||||||
self.db.connection.rollback()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return self._emptyStats()
|
return self._emptyStats()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
||||||
|
|
@ -228,6 +228,22 @@ class KnowledgeObjects:
|
||||||
"""Get all ContentChunks for a file."""
|
"""Get all ContentChunks for a file."""
|
||||||
return self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})
|
return self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})
|
||||||
|
|
||||||
|
def countChunksByFileIds(self, fileIds: List[str]) -> Dict[str, int]:
|
||||||
|
"""Return a {fileId: chunkCount} mapping for the given file IDs.
|
||||||
|
|
||||||
|
One aggregate query instead of N round trips. Used by RAG inventory
|
||||||
|
to display real chunk counts per DataSource without loading the
|
||||||
|
embedding vectors. Missing file IDs map to 0 in the caller's logic.
|
||||||
|
"""
|
||||||
|
if not fileIds:
|
||||||
|
return {}
|
||||||
|
if not self.db._ensureTableExists(ContentChunk):
|
||||||
|
return {}
|
||||||
|
sql = 'SELECT "fileId", COUNT(*) AS cnt FROM "ContentChunk" WHERE "fileId" = ANY(%s) GROUP BY "fileId"'
|
||||||
|
with self.db.borrowCursor() as cursor:
|
||||||
|
cursor.execute(sql, (list(fileIds),))
|
||||||
|
return {row["fileId"]: int(row["cnt"]) for row in cursor.fetchall()}
|
||||||
|
|
||||||
def deleteContentChunks(self, fileId: str) -> int:
|
def deleteContentChunks(self, fileId: str) -> int:
|
||||||
"""Delete all ContentChunks for a file. Returns count of deleted chunks."""
|
"""Delete all ContentChunks for a file. Returns count of deleted chunks."""
|
||||||
chunks = self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})
|
chunks = self.db.getRecordset(ContentChunk, recordFilter={"fileId": fileId})
|
||||||
|
|
|
||||||
|
|
@ -1221,10 +1221,9 @@ class ComponentObjects:
|
||||||
for item in fileRows
|
for item in fileRows
|
||||||
]
|
]
|
||||||
|
|
||||||
# Single transaction: delete FileData, FileItem, then FileFolder (children first)
|
# Single transaction: delete FileData, FileItem, then FileFolder (children first).
|
||||||
self.db._ensure_connection()
|
# Commit/rollback are handled by `borrowCursor()` on exit.
|
||||||
try:
|
with self.db.borrowCursor() as cursor:
|
||||||
with self.db.connection.cursor() as cursor:
|
|
||||||
if fileIds:
|
if fileIds:
|
||||||
cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (fileIds,))
|
cursor.execute('DELETE FROM "FileData" WHERE "id" = ANY(%s)', (fileIds,))
|
||||||
cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (fileIds,))
|
cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (fileIds,))
|
||||||
|
|
@ -1233,10 +1232,6 @@ class ComponentObjects:
|
||||||
orderedIds.append(folderId)
|
orderedIds.append(folderId)
|
||||||
if orderedIds:
|
if orderedIds:
|
||||||
cursor.execute('DELETE FROM "FileFolder" WHERE "id" = ANY(%s)', (orderedIds,))
|
cursor.execute('DELETE FROM "FileFolder" WHERE "id" = ANY(%s)', (orderedIds,))
|
||||||
self.db.connection.commit()
|
|
||||||
except Exception:
|
|
||||||
self.db.connection.rollback()
|
|
||||||
raise
|
|
||||||
|
|
||||||
return {"deletedFolders": len(folderIds), "deletedFiles": len(fileIds)}
|
return {"deletedFolders": len(folderIds), "deletedFiles": len(fileIds)}
|
||||||
|
|
||||||
|
|
@ -1507,7 +1502,7 @@ class ComponentObjects:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.db._ensure_connection()
|
self.db._ensure_connection()
|
||||||
with self.db.connection.cursor() as cursor:
|
with self.db.borrowCursor() as cursor:
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
'SELECT "id", "sysCreatedBy" FROM "FileItem" WHERE "id" = ANY(%s)',
|
'SELECT "id", "sysCreatedBy" FROM "FileItem" WHERE "id" = ANY(%s)',
|
||||||
(uniqueIds,),
|
(uniqueIds,),
|
||||||
|
|
@ -1526,11 +1521,10 @@ class ComponentObjects:
|
||||||
cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (accessibleIds,))
|
cursor.execute('DELETE FROM "FileItem" WHERE "id" = ANY(%s)', (accessibleIds,))
|
||||||
deletedFiles = cursor.rowcount
|
deletedFiles = cursor.rowcount
|
||||||
|
|
||||||
self.db.connection.commit()
|
# Commit/rollback are handled by `borrowCursor()` context manager.
|
||||||
return {"deletedFiles": deletedFiles}
|
return {"deletedFiles": deletedFiles}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting files in batch: {e}")
|
logger.error(f"Error deleting files in batch: {e}")
|
||||||
self.db.connection.rollback()
|
|
||||||
raise FileDeletionError(f"Error deleting files in batch: {str(e)}")
|
raise FileDeletionError(f"Error deleting files in batch: {str(e)}")
|
||||||
|
|
||||||
def _ensureFeatureInstanceGroup(self, featureInstanceId: str, contextKey: str = "files/list") -> Optional[str]:
|
def _ensureFeatureInstanceGroup(self, featureInstanceId: str, contextKey: str = "files/list") -> Optional[str]:
|
||||||
|
|
|
||||||
|
|
@ -374,7 +374,7 @@ def getRecordsetWithRBAC(
|
||||||
|
|
||||||
query = f'SELECT * FROM "{table}"{whereClause}{orderByClause}{limitClause}'
|
query = f'SELECT * FROM "{table}"{whereClause}{orderByClause}{limitClause}'
|
||||||
|
|
||||||
with connector.connection.cursor() as cursor:
|
with connector.borrowCursor() as cursor:
|
||||||
cursor.execute(query, whereValues)
|
cursor.execute(query, whereValues)
|
||||||
records = [dict(row) for row in cursor.fetchall()]
|
records = [dict(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
|
@ -561,7 +561,7 @@ def getRecordsetPaginatedWithRBAC(
|
||||||
offset = (pagination.page - 1) * pagination.pageSize
|
offset = (pagination.page - 1) * pagination.pageSize
|
||||||
limitClause = f" LIMIT {pagination.pageSize} OFFSET {offset}"
|
limitClause = f" LIMIT {pagination.pageSize} OFFSET {offset}"
|
||||||
|
|
||||||
with connector.connection.cursor() as cursor:
|
with connector.borrowCursor() as cursor:
|
||||||
countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
|
countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
|
||||||
cursor.execute(countSql, countValues)
|
cursor.execute(countSql, countValues)
|
||||||
totalItems = cursor.fetchone()["count"]
|
totalItems = cursor.fetchone()["count"]
|
||||||
|
|
@ -709,7 +709,7 @@ def getDistinctColumnValuesWithRBAC(
|
||||||
|
|
||||||
sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{nonNullWhere} ORDER BY val'
|
sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{nonNullWhere} ORDER BY val'
|
||||||
|
|
||||||
with connector.connection.cursor() as cursor:
|
with connector.borrowCursor() as cursor:
|
||||||
cursor.execute(sql, whereValues)
|
cursor.execute(sql, whereValues)
|
||||||
result = [row["val"] for row in cursor.fetchall()]
|
result = [row["val"] for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
|
@ -719,7 +719,7 @@ def getDistinctColumnValuesWithRBAC(
|
||||||
emptySql = f'SELECT 1 FROM "{table}"{whereClause} AND {emptyCond} LIMIT 1'
|
emptySql = f'SELECT 1 FROM "{table}"{whereClause} AND {emptyCond} LIMIT 1'
|
||||||
else:
|
else:
|
||||||
emptySql = f'SELECT 1 FROM "{table}" WHERE {emptyCond} LIMIT 1'
|
emptySql = f'SELECT 1 FROM "{table}" WHERE {emptyCond} LIMIT 1'
|
||||||
with connector.connection.cursor() as cursor:
|
with connector.borrowCursor() as cursor:
|
||||||
cursor.execute(emptySql, whereValues)
|
cursor.execute(emptySql, whereValues)
|
||||||
if cursor.fetchone():
|
if cursor.fetchone():
|
||||||
result.append(None)
|
result.append(None)
|
||||||
|
|
@ -967,7 +967,7 @@ def buildRbacWhereClause(
|
||||||
# Multi-Tenant Design: Users do NOT have mandateId - they are linked via UserMandate
|
# Multi-Tenant Design: Users do NOT have mandateId - they are linked via UserMandate
|
||||||
if table == "UserInDB":
|
if table == "UserInDB":
|
||||||
try:
|
try:
|
||||||
with connector.connection.cursor() as cursor:
|
with connector.borrowCursor() as cursor:
|
||||||
# Get all user IDs that are members of the current mandate
|
# Get all user IDs that are members of the current mandate
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true',
|
'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true',
|
||||||
|
|
@ -994,7 +994,7 @@ def buildRbacWhereClause(
|
||||||
# For UserConnection: Filter via UserMandate junction table
|
# For UserConnection: Filter via UserMandate junction table
|
||||||
elif table == "UserConnection":
|
elif table == "UserConnection":
|
||||||
try:
|
try:
|
||||||
with connector.connection.cursor() as cursor:
|
with connector.borrowCursor() as cursor:
|
||||||
# Get all user IDs that are members of the current mandate
|
# Get all user IDs that are members of the current mandate
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true',
|
'SELECT "userId" FROM "UserMandate" WHERE "mandateId" = %s AND "enabled" = true',
|
||||||
|
|
|
||||||
|
|
@ -305,7 +305,7 @@ def handleIdsMode(
|
||||||
|
|
||||||
sql = f'SELECT "{idField}"::TEXT AS val FROM "{table}"{where_clause} ORDER BY "{idField}"'
|
sql = f'SELECT "{idField}"::TEXT AS val FROM "{table}"{where_clause} ORDER BY "{idField}"'
|
||||||
|
|
||||||
with db.connection.cursor() as cursor:
|
with db.borrowCursor() as cursor:
|
||||||
cursor.execute(sql, values)
|
cursor.execute(sql, values)
|
||||||
return JSONResponse(content=[row["val"] for row in cursor.fetchall()])
|
return JSONResponse(content=[row["val"] for row in cursor.fetchall()])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,18 @@ router = APIRouter(
|
||||||
|
|
||||||
|
|
||||||
def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> List[Dict[str, Any]]:
|
def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> List[Dict[str, Any]]:
|
||||||
|
"""Build per-connection RAG inventory rows.
|
||||||
|
|
||||||
|
Each DataSource row exposes BOTH numbers because they mean different things:
|
||||||
|
* `fileCount` — distinct files indexed (== `FileContentIndex` rows)
|
||||||
|
* `chunkCount` — embedding-sized text fragments (== `ContentChunk` rows,
|
||||||
|
max `DEFAULT_CHUNK_TOKENS` tokens each, what the vector retrieval
|
||||||
|
actually hits)
|
||||||
|
|
||||||
|
A single PDF typically yields 1 file × 5–100 chunks; legacy UI labelled
|
||||||
|
`len(FileContentIndex)` as "chunks" which was off by 1–2 orders of
|
||||||
|
magnitude and misleading.
|
||||||
|
"""
|
||||||
from modules.datamodels.datamodelDataSource import DataSource
|
from modules.datamodels.datamodelDataSource import DataSource
|
||||||
from modules.datamodels.datamodelKnowledge import FileContentIndex
|
from modules.datamodels.datamodelKnowledge import FileContentIndex
|
||||||
|
|
||||||
|
|
@ -34,19 +46,35 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
|
||||||
dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
|
dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
|
||||||
|
|
||||||
connIndexRows = knowledgeIf.db.getRecordset(FileContentIndex, recordFilter={"connectionId": connectionId})
|
connIndexRows = knowledgeIf.db.getRecordset(FileContentIndex, recordFilter={"connectionId": connectionId})
|
||||||
connChunkTotal = len(connIndexRows)
|
connFileTotal = len(connIndexRows)
|
||||||
|
|
||||||
|
# Map fileId → real chunk count via 1 aggregate query (cheap even for
|
||||||
|
# connections with thousands of files; we never load the vector body).
|
||||||
|
fileIds = [
|
||||||
|
(idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", ""))
|
||||||
|
for idx in connIndexRows
|
||||||
|
]
|
||||||
|
fileIds = [fid for fid in fileIds if fid]
|
||||||
|
chunkCountByFile = knowledgeIf.countChunksByFileIds(fileIds) if fileIds else {}
|
||||||
|
connChunkTotal = sum(chunkCountByFile.values())
|
||||||
|
|
||||||
|
filesByDs: Dict[str, int] = {}
|
||||||
chunksByDs: Dict[str, int] = {}
|
chunksByDs: Dict[str, int] = {}
|
||||||
unassigned = 0
|
unassignedFiles = 0
|
||||||
|
unassignedChunks = 0
|
||||||
for idx in connIndexRows:
|
for idx in connIndexRows:
|
||||||
|
fileId = idx.get("id") if isinstance(idx, dict) else getattr(idx, "id", "")
|
||||||
|
chunkCnt = chunkCountByFile.get(fileId, 0)
|
||||||
struct = (idx.get("structure") if isinstance(idx, dict) else getattr(idx, "structure", None)) or {}
|
struct = (idx.get("structure") if isinstance(idx, dict) else getattr(idx, "structure", None)) or {}
|
||||||
ingestion = struct.get("_ingestion") or {} if isinstance(struct, dict) else {}
|
ingestion = struct.get("_ingestion") or {} if isinstance(struct, dict) else {}
|
||||||
prov = ingestion.get("provenance") or {} if isinstance(ingestion, dict) else {}
|
prov = ingestion.get("provenance") or {} if isinstance(ingestion, dict) else {}
|
||||||
dsIdRef = prov.get("dataSourceId", "") if isinstance(prov, dict) else ""
|
dsIdRef = prov.get("dataSourceId", "") if isinstance(prov, dict) else ""
|
||||||
if dsIdRef:
|
if dsIdRef:
|
||||||
chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + 1
|
filesByDs[dsIdRef] = filesByDs.get(dsIdRef, 0) + 1
|
||||||
|
chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + chunkCnt
|
||||||
else:
|
else:
|
||||||
unassigned += 1
|
unassignedFiles += 1
|
||||||
|
unassignedChunks += chunkCnt
|
||||||
|
|
||||||
seen: Dict[str, bool] = {}
|
seen: Dict[str, bool] = {}
|
||||||
dsItems = []
|
dsItems = []
|
||||||
|
|
@ -64,14 +92,19 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
|
||||||
"ragIndexEnabled": ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False),
|
"ragIndexEnabled": ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False),
|
||||||
"neutralize": ds.get("neutralize") if isinstance(ds, dict) else getattr(ds, "neutralize", False),
|
"neutralize": ds.get("neutralize") if isinstance(ds, dict) else getattr(ds, "neutralize", False),
|
||||||
"lastIndexed": ds.get("lastIndexed") if isinstance(ds, dict) else getattr(ds, "lastIndexed", None),
|
"lastIndexed": ds.get("lastIndexed") if isinstance(ds, dict) else getattr(ds, "lastIndexed", None),
|
||||||
|
"fileCount": filesByDs.get(dsId, 0),
|
||||||
"chunkCount": chunksByDs.get(dsId, 0),
|
"chunkCount": chunksByDs.get(dsId, 0),
|
||||||
})
|
})
|
||||||
|
|
||||||
if unassigned > 0 and len(dsItems) > 0:
|
# Spread orphan files (provenance lost) evenly so totals match.
|
||||||
perDs = unassigned // len(dsItems)
|
if unassignedFiles > 0 and len(dsItems) > 0:
|
||||||
remainder = unassigned % len(dsItems)
|
perFile = unassignedFiles // len(dsItems)
|
||||||
|
remFile = unassignedFiles % len(dsItems)
|
||||||
|
perChunk = unassignedChunks // len(dsItems)
|
||||||
|
remChunk = unassignedChunks % len(dsItems)
|
||||||
for i, item in enumerate(dsItems):
|
for i, item in enumerate(dsItems):
|
||||||
item["chunkCount"] += perDs + (1 if i < remainder else 0)
|
item["fileCount"] += perFile + (1 if i < remFile else 0)
|
||||||
|
item["chunkCount"] += perChunk + (1 if i < remChunk else 0)
|
||||||
|
|
||||||
# Pull a wider window than the previous 5 so the "last successful
|
# Pull a wider window than the previous 5 so the "last successful
|
||||||
# sync" is found even if a connection has many recent jobs queued.
|
# sync" is found even if a connection has many recent jobs queued.
|
||||||
|
|
@ -102,6 +135,12 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
|
||||||
"skippedPolicy": result.get("skippedPolicy", 0),
|
"skippedPolicy": result.get("skippedPolicy", 0),
|
||||||
"failed": result.get("failed", 0),
|
"failed": result.get("failed", 0),
|
||||||
"durationMs": result.get("durationMs", 0),
|
"durationMs": result.get("durationMs", 0),
|
||||||
|
# Surface limit-stop reason so the UI can warn the user
|
||||||
|
# that the index is provably incomplete (and which budget
|
||||||
|
# to raise). None means the walker finished naturally.
|
||||||
|
"stoppedAtLimit": result.get("stoppedAtLimit"),
|
||||||
|
"limits": result.get("limits") or {},
|
||||||
|
"bytesProcessed": result.get("bytesProcessed", 0),
|
||||||
}
|
}
|
||||||
if lastError and lastSuccess:
|
if lastError and lastSuccess:
|
||||||
break
|
break
|
||||||
|
|
@ -113,6 +152,7 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
|
||||||
"knowledgeIngestionEnabled": getattr(conn, "knowledgeIngestionEnabled", False),
|
"knowledgeIngestionEnabled": getattr(conn, "knowledgeIngestionEnabled", False),
|
||||||
"preferences": getattr(conn, "knowledgePreferences", None) or {},
|
"preferences": getattr(conn, "knowledgePreferences", None) or {},
|
||||||
"dataSources": dsItems,
|
"dataSources": dsItems,
|
||||||
|
"totalFiles": connFileTotal,
|
||||||
"totalChunks": connChunkTotal,
|
"totalChunks": connChunkTotal,
|
||||||
"runningJobs": runningJobs,
|
"runningJobs": runningJobs,
|
||||||
"lastError": lastError,
|
"lastError": lastError,
|
||||||
|
|
@ -139,8 +179,9 @@ def _getInventoryMe(
|
||||||
|
|
||||||
items = _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService)
|
items = _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService)
|
||||||
totalChunks = sum(c.get("totalChunks", 0) for c in items)
|
totalChunks = sum(c.get("totalChunks", 0) for c in items)
|
||||||
|
totalFiles = sum(c.get("totalFiles", 0) for c in items)
|
||||||
|
|
||||||
return {"connections": items, "totals": {"chunks": totalChunks}}
|
return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error in RAG inventory /me: %s", e, exc_info=True)
|
logger.error("Error in RAG inventory /me: %s", e, exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
@ -170,9 +211,10 @@ def _getInventoryMandate(
|
||||||
|
|
||||||
items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
|
items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
|
||||||
totalChunks = sum(c.get("totalChunks", 0) for c in items)
|
totalChunks = sum(c.get("totalChunks", 0) for c in items)
|
||||||
|
totalFiles = sum(c.get("totalFiles", 0) for c in items)
|
||||||
totalBytes = aggregateMandateRagTotalBytes(mandateId)
|
totalBytes = aggregateMandateRagTotalBytes(mandateId)
|
||||||
|
|
||||||
return {"connections": items, "totals": {"chunks": totalChunks, "bytes": totalBytes}}
|
return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks, "bytes": totalBytes}}
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -202,8 +244,9 @@ def _getInventoryPlatform(
|
||||||
|
|
||||||
items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
|
items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService)
|
||||||
totalChunks = sum(c.get("totalChunks", 0) for c in items)
|
totalChunks = sum(c.get("totalChunks", 0) for c in items)
|
||||||
|
totalFiles = sum(c.get("totalFiles", 0) for c in items)
|
||||||
|
|
||||||
return {"connections": items, "totals": {"chunks": totalChunks}}
|
return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}}
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -227,7 +227,7 @@ WHERE "workflowId" = ANY(%s)
|
||||||
GROUP BY "workflowId"
|
GROUP BY "workflowId"
|
||||||
"""
|
"""
|
||||||
out: dict = {}
|
out: dict = {}
|
||||||
with db.connection.cursor() as cursor:
|
with db.borrowCursor() as cursor:
|
||||||
cursor.execute(sql, (workflowIds,))
|
cursor.execute(sql, (workflowIds,))
|
||||||
for row in cursor.fetchall():
|
for row in cursor.fetchall():
|
||||||
r = dict(row)
|
r = dict(row)
|
||||||
|
|
@ -480,7 +480,7 @@ def _getWorkflowsJoinedPaginated(
|
||||||
dataSql = f"SELECT w.*, rs.\"lastStartedAt\", rs.\"runCount\", rs.\"activeRunId\" FROM {fromSql}{whereClause}{orderClause}{limitClause}"
|
dataSql = f"SELECT w.*, rs.\"lastStartedAt\", rs.\"runCount\", rs.\"activeRunId\" FROM {fromSql}{whereClause}{orderClause}{limitClause}"
|
||||||
|
|
||||||
db._ensure_connection()
|
db._ensure_connection()
|
||||||
with db.connection.cursor() as cursor:
|
with db.borrowCursor() as cursor:
|
||||||
cursor.execute(countSql, countValues)
|
cursor.execute(countSql, countValues)
|
||||||
totalItems = int(cursor.fetchone()["cnt"])
|
totalItems = int(cursor.fetchone()["cnt"])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,15 +25,14 @@ _CACHE_TTL_SECONDS = 300
|
||||||
|
|
||||||
|
|
||||||
def _getOrCreateFeatureDbConnector(featureDbName: str, userId: str):
|
def _getOrCreateFeatureDbConnector(featureDbName: str, userId: str):
|
||||||
"""Reuse a pooled DB connector for the given feature database."""
|
"""Reuse a pooled DB connector for the given feature database.
|
||||||
|
|
||||||
|
The underlying psycopg2 connections live in the central pool
|
||||||
|
(`_PoolRegistry`) and are recreated on demand if they go stale; we just
|
||||||
|
need to keep the lightweight connector wrapper around.
|
||||||
|
"""
|
||||||
if featureDbName in _featureDbConnPool:
|
if featureDbName in _featureDbConnPool:
|
||||||
conn = _featureDbConnPool[featureDbName]
|
return _featureDbConnPool[featureDbName]
|
||||||
try:
|
|
||||||
if conn.connection and not conn.connection.closed:
|
|
||||||
return conn
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Feature DB connection check failed for {featureDbName}: {e}")
|
|
||||||
_featureDbConnPool.pop(featureDbName, None)
|
|
||||||
|
|
||||||
from modules.connectors.connectorDbPostgre import DatabaseConnector
|
from modules.connectors.connectorDbPostgre import DatabaseConnector
|
||||||
from modules.shared.configuration import APP_CONFIG
|
from modules.shared.configuration import APP_CONFIG
|
||||||
|
|
|
||||||
|
|
@ -68,6 +68,9 @@ class ClickupBootstrapResult:
|
||||||
workspaces: int = 0
|
workspaces: int = 0
|
||||||
lists: int = 0
|
lists: int = 0
|
||||||
errors: List[str] = field(default_factory=list)
|
errors: List[str] = field(default_factory=list)
|
||||||
|
# First budget exhausted: "maxTasks" | "maxWorkspaces" | "maxListsPerWorkspace" | None.
|
||||||
|
# Drives the same UI banner as the file-walker bootstraps.
|
||||||
|
stoppedAtLimit: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
def _syntheticTaskId(connectionId: str, taskId: str) -> str:
|
def _syntheticTaskId(connectionId: str, taskId: str) -> str:
|
||||||
|
|
@ -225,6 +228,7 @@ async def bootstrapClickup(
|
||||||
cancelled = False
|
cancelled = False
|
||||||
for ds in dataSources:
|
for ds in dataSources:
|
||||||
if result.indexed + result.skippedDuplicate >= limits.maxTasks:
|
if result.indexed + result.skippedDuplicate >= limits.maxTasks:
|
||||||
|
_recordLimitStop(result, "maxTasks", "dataSource", limits)
|
||||||
break
|
break
|
||||||
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
|
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
|
||||||
cancelled = True
|
cancelled = True
|
||||||
|
|
@ -243,8 +247,11 @@ async def bootstrapClickup(
|
||||||
clickupScope=limits.clickupScope,
|
clickupScope=limits.clickupScope,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if len(teams) > dsLimits.maxWorkspaces:
|
||||||
|
_recordLimitStop(result, "maxWorkspaces", "teams", dsLimits, hard=False)
|
||||||
for team in teams[:dsLimits.maxWorkspaces]:
|
for team in teams[:dsLimits.maxWorkspaces]:
|
||||||
if result.indexed + result.skippedDuplicate >= dsLimits.maxTasks:
|
if result.indexed + result.skippedDuplicate >= dsLimits.maxTasks:
|
||||||
|
_recordLimitStop(result, "maxTasks", f"team={team.get('id','')}", dsLimits)
|
||||||
break
|
break
|
||||||
teamId = str(team.get("id", "") or "")
|
teamId = str(team.get("id", "") or "")
|
||||||
if not teamId:
|
if not teamId:
|
||||||
|
|
@ -351,6 +358,7 @@ async def _walkTeam(
|
||||||
|
|
||||||
for lst in listsCollected:
|
for lst in listsCollected:
|
||||||
if result.indexed + result.skippedDuplicate >= limits.maxTasks:
|
if result.indexed + result.skippedDuplicate >= limits.maxTasks:
|
||||||
|
_recordLimitStop(result, "maxTasks", f"team={teamId}", limits)
|
||||||
return
|
return
|
||||||
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
|
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
|
||||||
return
|
return
|
||||||
|
|
@ -407,6 +415,7 @@ async def _walkList(
|
||||||
|
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
if result.indexed + result.skippedDuplicate >= limits.maxTasks:
|
if result.indexed + result.skippedDuplicate >= limits.maxTasks:
|
||||||
|
_recordLimitStop(result, "maxTasks", f"list={listId}", limits)
|
||||||
return
|
return
|
||||||
if not _isRecent(task.get("date_updated"), limits.maxAgeDays):
|
if not _isRecent(task.get("date_updated"), limits.maxAgeDays):
|
||||||
result.skippedPolicy += 1
|
result.skippedPolicy += 1
|
||||||
|
|
@ -529,13 +538,37 @@ async def _ingestTask(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _recordLimitStop(
|
||||||
|
result: ClickupBootstrapResult,
|
||||||
|
limitName: str,
|
||||||
|
where: str,
|
||||||
|
limits: ClickupBootstrapLimits,
|
||||||
|
*,
|
||||||
|
hard: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""See subConnectorSyncSharepoint._recordLimitStop for semantics."""
|
||||||
|
if hard or result.stoppedAtLimit is None:
|
||||||
|
result.stoppedAtLimit = limitName
|
||||||
|
budgetMap = {
|
||||||
|
"maxTasks": limits.maxTasks,
|
||||||
|
"maxWorkspaces": limits.maxWorkspaces,
|
||||||
|
"maxListsPerWorkspace": limits.maxListsPerWorkspace,
|
||||||
|
}
|
||||||
|
logger.warning(
|
||||||
|
"clickup walker hit %s=%s at %s — partial index (indexed=%d, skippedDup=%d).",
|
||||||
|
limitName, budgetMap.get(limitName), where,
|
||||||
|
result.indexed, result.skippedDuplicate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs: float) -> Dict[str, Any]:
|
def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs: float) -> Dict[str, Any]:
|
||||||
durationMs = int((time.time() - startMs) * 1000)
|
durationMs = int((time.time() - startMs) * 1000)
|
||||||
logger.info(
|
logger.info(
|
||||||
"ingestion.connection.bootstrap.done part=clickup connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d workspaces=%d lists=%d durationMs=%d",
|
"ingestion.connection.bootstrap.done part=clickup connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d workspaces=%d lists=%d durationMs=%d stoppedAtLimit=%s",
|
||||||
connectionId,
|
connectionId,
|
||||||
result.indexed, result.skippedDuplicate, result.skippedPolicy,
|
result.indexed, result.skippedDuplicate, result.skippedPolicy,
|
||||||
result.failed, result.workspaces, result.lists, durationMs,
|
result.failed, result.workspaces, result.lists, durationMs,
|
||||||
|
result.stoppedAtLimit or "none",
|
||||||
extra={
|
extra={
|
||||||
"event": "ingestion.connection.bootstrap.done",
|
"event": "ingestion.connection.bootstrap.done",
|
||||||
"part": "clickup",
|
"part": "clickup",
|
||||||
|
|
@ -547,6 +580,7 @@ def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs:
|
||||||
"workspaces": result.workspaces,
|
"workspaces": result.workspaces,
|
||||||
"lists": result.lists,
|
"lists": result.lists,
|
||||||
"durationMs": durationMs,
|
"durationMs": durationMs,
|
||||||
|
"stoppedAtLimit": result.stoppedAtLimit,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
|
|
@ -559,4 +593,11 @@ def _finalizeResult(connectionId: str, result: ClickupBootstrapResult, startMs:
|
||||||
"lists": result.lists,
|
"lists": result.lists,
|
||||||
"durationMs": durationMs,
|
"durationMs": durationMs,
|
||||||
"errors": result.errors[:20],
|
"errors": result.errors[:20],
|
||||||
|
"stoppedAtLimit": result.stoppedAtLimit,
|
||||||
|
"limits": {
|
||||||
|
"maxTasks": MAX_TASKS_DEFAULT,
|
||||||
|
"maxWorkspaces": MAX_WORKSPACES_DEFAULT,
|
||||||
|
"maxListsPerWorkspace": MAX_LISTS_PER_WORKSPACE_DEFAULT,
|
||||||
|
"maxAgeDays": MAX_AGE_DAYS_DEFAULT,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -61,6 +61,8 @@ class GdriveBootstrapResult:
|
||||||
failed: int = 0
|
failed: int = 0
|
||||||
bytesProcessed: int = 0
|
bytesProcessed: int = 0
|
||||||
errors: List[str] = field(default_factory=list)
|
errors: List[str] = field(default_factory=list)
|
||||||
|
# See SharepointBootstrapResult.stoppedAtLimit — same semantics.
|
||||||
|
stoppedAtLimit: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
|
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
|
||||||
|
|
@ -265,8 +267,10 @@ async def _walkFolder(
|
||||||
|
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
if result.indexed + result.skippedDuplicate >= limits.maxItems:
|
if result.indexed + result.skippedDuplicate >= limits.maxItems:
|
||||||
|
_recordLimitStop(result, "maxItems", folderPath, limits)
|
||||||
return
|
return
|
||||||
if result.bytesProcessed >= limits.maxBytes:
|
if result.bytesProcessed >= limits.maxBytes:
|
||||||
|
_recordLimitStop(result, "maxBytes", folderPath, limits)
|
||||||
return
|
return
|
||||||
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
|
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
|
||||||
return
|
return
|
||||||
|
|
@ -276,6 +280,9 @@ async def _walkFolder(
|
||||||
mimeType = getattr(entry, "mimeType", None) or metadata.get("mimeType")
|
mimeType = getattr(entry, "mimeType", None) or metadata.get("mimeType")
|
||||||
|
|
||||||
if getattr(entry, "isFolder", False) or mimeType == FOLDER_MIME:
|
if getattr(entry, "isFolder", False) or mimeType == FOLDER_MIME:
|
||||||
|
if depth + 1 > limits.maxDepth:
|
||||||
|
_recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
|
||||||
|
continue
|
||||||
await _walkFolder(
|
await _walkFolder(
|
||||||
adapter=adapter,
|
adapter=adapter,
|
||||||
knowledgeService=knowledgeService,
|
knowledgeService=knowledgeService,
|
||||||
|
|
@ -298,6 +305,7 @@ async def _walkFolder(
|
||||||
continue
|
continue
|
||||||
size = int(getattr(entry, "size", 0) or 0)
|
size = int(getattr(entry, "size", 0) or 0)
|
||||||
if size and size > limits.maxFileSize:
|
if size and size > limits.maxFileSize:
|
||||||
|
_recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
|
||||||
result.skippedPolicy += 1
|
result.skippedPolicy += 1
|
||||||
continue
|
continue
|
||||||
modifiedTime = metadata.get("modifiedTime")
|
modifiedTime = metadata.get("modifiedTime")
|
||||||
|
|
@ -470,13 +478,38 @@ async def _ingestOne(
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
|
||||||
|
def _recordLimitStop(
|
||||||
|
result: GdriveBootstrapResult,
|
||||||
|
limitName: str,
|
||||||
|
where: str,
|
||||||
|
limits: GdriveBootstrapLimits,
|
||||||
|
*,
|
||||||
|
hard: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""See subConnectorSyncSharepoint._recordLimitStop for semantics."""
|
||||||
|
if hard or result.stoppedAtLimit is None:
|
||||||
|
result.stoppedAtLimit = limitName
|
||||||
|
budgetMap = {
|
||||||
|
"maxItems": limits.maxItems,
|
||||||
|
"maxBytes": limits.maxBytes,
|
||||||
|
"maxDepth": limits.maxDepth,
|
||||||
|
"maxFileSize": limits.maxFileSize,
|
||||||
|
}
|
||||||
|
logger.warning(
|
||||||
|
"gdrive walker hit %s=%s at %s — partial index (indexed=%d, bytesProcessed=%d).",
|
||||||
|
limitName, budgetMap.get(limitName), where,
|
||||||
|
result.indexed, result.bytesProcessed,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: float) -> Dict[str, Any]:
|
def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: float) -> Dict[str, Any]:
|
||||||
durationMs = int((time.time() - startMs) * 1000)
|
durationMs = int((time.time() - startMs) * 1000)
|
||||||
logger.info(
|
logger.info(
|
||||||
"ingestion.connection.bootstrap.done part=gdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d bytes=%d durationMs=%d",
|
"ingestion.connection.bootstrap.done part=gdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d bytes=%d durationMs=%d stoppedAtLimit=%s",
|
||||||
connectionId,
|
connectionId,
|
||||||
result.indexed, result.skippedDuplicate, result.skippedPolicy,
|
result.indexed, result.skippedDuplicate, result.skippedPolicy,
|
||||||
result.failed, result.bytesProcessed, durationMs,
|
result.failed, result.bytesProcessed, durationMs,
|
||||||
|
result.stoppedAtLimit or "none",
|
||||||
extra={
|
extra={
|
||||||
"event": "ingestion.connection.bootstrap.done",
|
"event": "ingestion.connection.bootstrap.done",
|
||||||
"part": "gdrive",
|
"part": "gdrive",
|
||||||
|
|
@ -487,6 +520,7 @@ def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: f
|
||||||
"failed": result.failed,
|
"failed": result.failed,
|
||||||
"bytes": result.bytesProcessed,
|
"bytes": result.bytesProcessed,
|
||||||
"durationMs": durationMs,
|
"durationMs": durationMs,
|
||||||
|
"stoppedAtLimit": result.stoppedAtLimit,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
|
|
@ -498,4 +532,11 @@ def _finalizeResult(connectionId: str, result: GdriveBootstrapResult, startMs: f
|
||||||
"bytesProcessed": result.bytesProcessed,
|
"bytesProcessed": result.bytesProcessed,
|
||||||
"durationMs": durationMs,
|
"durationMs": durationMs,
|
||||||
"errors": result.errors[:20],
|
"errors": result.errors[:20],
|
||||||
|
"stoppedAtLimit": result.stoppedAtLimit,
|
||||||
|
"limits": {
|
||||||
|
"maxItems": MAX_ITEMS_DEFAULT,
|
||||||
|
"maxBytes": MAX_BYTES_DEFAULT,
|
||||||
|
"maxFileSize": MAX_FILE_SIZE_DEFAULT,
|
||||||
|
"maxDepth": MAX_DEPTH_DEFAULT,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -53,6 +53,8 @@ class KdriveBootstrapResult:
|
||||||
failed: int = 0
|
failed: int = 0
|
||||||
bytesProcessed: int = 0
|
bytesProcessed: int = 0
|
||||||
errors: List[str] = field(default_factory=list)
|
errors: List[str] = field(default_factory=list)
|
||||||
|
# See SharepointBootstrapResult.stoppedAtLimit — same semantics.
|
||||||
|
stoppedAtLimit: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
|
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
|
||||||
|
|
@ -232,14 +234,19 @@ async def _walkFolder(
|
||||||
|
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
if result.indexed + result.skippedDuplicate >= limits.maxItems:
|
if result.indexed + result.skippedDuplicate >= limits.maxItems:
|
||||||
|
_recordLimitStop(result, "maxItems", folderPath, limits)
|
||||||
return
|
return
|
||||||
if result.bytesProcessed >= limits.maxBytes:
|
if result.bytesProcessed >= limits.maxBytes:
|
||||||
|
_recordLimitStop(result, "maxBytes", folderPath, limits)
|
||||||
return
|
return
|
||||||
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
|
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
|
||||||
return
|
return
|
||||||
|
|
||||||
entryPath = getattr(entry, "path", "") or ""
|
entryPath = getattr(entry, "path", "") or ""
|
||||||
if getattr(entry, "isFolder", False):
|
if getattr(entry, "isFolder", False):
|
||||||
|
if depth + 1 > limits.maxDepth:
|
||||||
|
_recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
|
||||||
|
continue
|
||||||
await _walkFolder(
|
await _walkFolder(
|
||||||
adapter=adapter,
|
adapter=adapter,
|
||||||
knowledgeService=knowledgeService,
|
knowledgeService=knowledgeService,
|
||||||
|
|
@ -262,6 +269,7 @@ async def _walkFolder(
|
||||||
continue
|
continue
|
||||||
size = int(getattr(entry, "size", 0) or 0)
|
size = int(getattr(entry, "size", 0) or 0)
|
||||||
if size and size > limits.maxFileSize:
|
if size and size > limits.maxFileSize:
|
||||||
|
_recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
|
||||||
result.skippedPolicy += 1
|
result.skippedPolicy += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
@ -415,17 +423,42 @@ async def _ingestOne(
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
|
||||||
|
def _recordLimitStop(
|
||||||
|
result: KdriveBootstrapResult,
|
||||||
|
limitName: str,
|
||||||
|
where: str,
|
||||||
|
limits: KdriveBootstrapLimits,
|
||||||
|
*,
|
||||||
|
hard: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""See subConnectorSyncSharepoint._recordLimitStop for semantics."""
|
||||||
|
if hard or result.stoppedAtLimit is None:
|
||||||
|
result.stoppedAtLimit = limitName
|
||||||
|
budgetMap = {
|
||||||
|
"maxItems": limits.maxItems,
|
||||||
|
"maxBytes": limits.maxBytes,
|
||||||
|
"maxDepth": limits.maxDepth,
|
||||||
|
"maxFileSize": limits.maxFileSize,
|
||||||
|
}
|
||||||
|
logger.warning(
|
||||||
|
"kdrive walker hit %s=%s at %s — partial index (indexed=%d, bytesProcessed=%d).",
|
||||||
|
limitName, budgetMap.get(limitName), where,
|
||||||
|
result.indexed, result.bytesProcessed,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: float) -> Dict[str, Any]:
|
def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: float) -> Dict[str, Any]:
|
||||||
durationMs = int((time.time() - startMs) * 1000)
|
durationMs = int((time.time() - startMs) * 1000)
|
||||||
logger.info(
|
logger.info(
|
||||||
"ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d",
|
"ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d stoppedAtLimit=%s",
|
||||||
connectionId,
|
connectionId,
|
||||||
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed,
|
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed,
|
||||||
durationMs,
|
durationMs, result.stoppedAtLimit or "none",
|
||||||
extra={"event": "ingestion.connection.bootstrap.done", "part": "kdrive",
|
extra={"event": "ingestion.connection.bootstrap.done", "part": "kdrive",
|
||||||
"connectionId": connectionId, "indexed": result.indexed,
|
"connectionId": connectionId, "indexed": result.indexed,
|
||||||
"skippedDup": result.skippedDuplicate, "skippedPolicy": result.skippedPolicy,
|
"skippedDup": result.skippedDuplicate, "skippedPolicy": result.skippedPolicy,
|
||||||
"failed": result.failed, "durationMs": durationMs},
|
"failed": result.failed, "durationMs": durationMs,
|
||||||
|
"stoppedAtLimit": result.stoppedAtLimit},
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"connectionId": result.connectionId,
|
"connectionId": result.connectionId,
|
||||||
|
|
@ -436,4 +469,11 @@ def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: f
|
||||||
"bytesProcessed": result.bytesProcessed,
|
"bytesProcessed": result.bytesProcessed,
|
||||||
"durationMs": durationMs,
|
"durationMs": durationMs,
|
||||||
"errors": result.errors[:20],
|
"errors": result.errors[:20],
|
||||||
|
"stoppedAtLimit": result.stoppedAtLimit,
|
||||||
|
"limits": {
|
||||||
|
"maxItems": MAX_ITEMS_DEFAULT,
|
||||||
|
"maxBytes": MAX_BYTES_DEFAULT,
|
||||||
|
"maxFileSize": MAX_FILE_SIZE_DEFAULT,
|
||||||
|
"maxDepth": MAX_DEPTH_DEFAULT,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -59,6 +59,10 @@ class SharepointBootstrapResult:
|
||||||
failed: int = 0
|
failed: int = 0
|
||||||
bytesProcessed: int = 0
|
bytesProcessed: int = 0
|
||||||
errors: List[str] = field(default_factory=list)
|
errors: List[str] = field(default_factory=list)
|
||||||
|
# First budget that hit zero; None means the walk completed naturally.
|
||||||
|
# Surfaces in the bootstrap result so the RAG inventory UI can warn the
|
||||||
|
# user that the corpus is incomplete and tell them which knob to turn.
|
||||||
|
stoppedAtLimit: Optional[str] = None # "maxItems" | "maxBytes" | "maxDepth" | "maxFileSize" | None
|
||||||
|
|
||||||
|
|
||||||
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
|
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
|
||||||
|
|
@ -259,14 +263,22 @@ async def _walkFolder(
|
||||||
|
|
||||||
for entry in entries:
|
for entry in entries:
|
||||||
if result.indexed + result.skippedDuplicate >= limits.maxItems:
|
if result.indexed + result.skippedDuplicate >= limits.maxItems:
|
||||||
|
_recordLimitStop(result, "maxItems", folderPath, limits)
|
||||||
return
|
return
|
||||||
if result.bytesProcessed >= limits.maxBytes:
|
if result.bytesProcessed >= limits.maxBytes:
|
||||||
|
_recordLimitStop(result, "maxBytes", folderPath, limits)
|
||||||
return
|
return
|
||||||
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
|
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
|
||||||
return
|
return
|
||||||
|
|
||||||
entryPath = getattr(entry, "path", "") or ""
|
entryPath = getattr(entry, "path", "") or ""
|
||||||
if getattr(entry, "isFolder", False):
|
if getattr(entry, "isFolder", False):
|
||||||
|
if depth + 1 > limits.maxDepth:
|
||||||
|
# We stop descending here but keep walking siblings.
|
||||||
|
# Record once per bootstrap so the UI shows "maxDepth" even
|
||||||
|
# if other budgets aren't exhausted yet.
|
||||||
|
_recordLimitStop(result, "maxDepth", entryPath, limits, hard=False)
|
||||||
|
continue
|
||||||
await _walkFolder(
|
await _walkFolder(
|
||||||
adapter=adapter,
|
adapter=adapter,
|
||||||
knowledgeService=knowledgeService,
|
knowledgeService=knowledgeService,
|
||||||
|
|
@ -289,6 +301,7 @@ async def _walkFolder(
|
||||||
continue
|
continue
|
||||||
size = int(getattr(entry, "size", 0) or 0)
|
size = int(getattr(entry, "size", 0) or 0)
|
||||||
if size and size > limits.maxFileSize:
|
if size and size > limits.maxFileSize:
|
||||||
|
_recordLimitStop(result, "maxFileSize", entryPath, limits, hard=False)
|
||||||
result.skippedPolicy += 1
|
result.skippedPolicy += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
@ -443,13 +456,44 @@ async def _ingestOne(
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
|
||||||
|
def _recordLimitStop(
|
||||||
|
result: SharepointBootstrapResult,
|
||||||
|
limitName: str,
|
||||||
|
where: str,
|
||||||
|
limits: SharepointBootstrapLimits,
|
||||||
|
*,
|
||||||
|
hard: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""Mark the FIRST limit that bit. Soft hits (per-file maxFileSize, per-folder
|
||||||
|
maxDepth) only record when no hard limit has yet stopped the run, so the UI
|
||||||
|
surfaces the most important reason.
|
||||||
|
|
||||||
|
Hard limits (maxItems / maxBytes) ALWAYS overwrite a previously recorded
|
||||||
|
soft limit — once a hard cap is hit, the corpus is provably incomplete.
|
||||||
|
"""
|
||||||
|
if hard or result.stoppedAtLimit is None:
|
||||||
|
result.stoppedAtLimit = limitName
|
||||||
|
budgetMap = {
|
||||||
|
"maxItems": limits.maxItems,
|
||||||
|
"maxBytes": limits.maxBytes,
|
||||||
|
"maxDepth": limits.maxDepth,
|
||||||
|
"maxFileSize": limits.maxFileSize,
|
||||||
|
}
|
||||||
|
logger.warning(
|
||||||
|
"sharepoint walker hit %s=%s at %s — partial index "
|
||||||
|
"(indexed=%d, bytesProcessed=%d). Raise the limit or split the data source.",
|
||||||
|
limitName, budgetMap.get(limitName), where,
|
||||||
|
result.indexed, result.bytesProcessed,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startMs: float) -> Dict[str, Any]:
|
def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startMs: float) -> Dict[str, Any]:
|
||||||
durationMs = int((time.time() - startMs) * 1000)
|
durationMs = int((time.time() - startMs) * 1000)
|
||||||
logger.info(
|
logger.info(
|
||||||
"ingestion.connection.bootstrap.done part=sharepoint connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d",
|
"ingestion.connection.bootstrap.done part=sharepoint connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d stoppedAtLimit=%s",
|
||||||
connectionId,
|
connectionId,
|
||||||
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed,
|
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed,
|
||||||
durationMs,
|
durationMs, result.stoppedAtLimit or "none",
|
||||||
extra={
|
extra={
|
||||||
"event": "ingestion.connection.bootstrap.done",
|
"event": "ingestion.connection.bootstrap.done",
|
||||||
"part": "sharepoint",
|
"part": "sharepoint",
|
||||||
|
|
@ -459,6 +503,7 @@ def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startM
|
||||||
"skippedPolicy": result.skippedPolicy,
|
"skippedPolicy": result.skippedPolicy,
|
||||||
"failed": result.failed,
|
"failed": result.failed,
|
||||||
"durationMs": durationMs,
|
"durationMs": durationMs,
|
||||||
|
"stoppedAtLimit": result.stoppedAtLimit,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
|
|
@ -470,4 +515,11 @@ def _finalizeResult(connectionId: str, result: SharepointBootstrapResult, startM
|
||||||
"bytesProcessed": result.bytesProcessed,
|
"bytesProcessed": result.bytesProcessed,
|
||||||
"durationMs": durationMs,
|
"durationMs": durationMs,
|
||||||
"errors": result.errors[:20],
|
"errors": result.errors[:20],
|
||||||
|
"stoppedAtLimit": result.stoppedAtLimit,
|
||||||
|
"limits": {
|
||||||
|
"maxItems": MAX_ITEMS_DEFAULT,
|
||||||
|
"maxBytes": MAX_BYTES_DEFAULT,
|
||||||
|
"maxFileSize": MAX_FILE_SIZE_DEFAULT,
|
||||||
|
"maxDepth": MAX_DEPTH_DEFAULT,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,8 @@ import logging
|
||||||
import json
|
import json
|
||||||
import base64
|
import base64
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict, Optional
|
import threading
|
||||||
|
from typing import Any, Dict, Optional, Tuple
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
from cryptography.hazmat.primitives import hashes
|
from cryptography.hazmat.primitives import hashes
|
||||||
|
|
@ -286,6 +287,16 @@ def handleSecretJson(value: str, userId: str = "system", keyName: str = "unknown
|
||||||
# Structure: {user_id: {key_name: [timestamps]}}
|
# Structure: {user_id: {key_name: [timestamps]}}
|
||||||
_decryption_attempts = {}
|
_decryption_attempts = {}
|
||||||
|
|
||||||
|
# Process-wide plaintext cache for decrypted secrets.
|
||||||
|
# Key: the encrypted ciphertext (which already includes env prefix).
|
||||||
|
# Value: (expiresAtMonotonic, plaintext).
|
||||||
|
# TTL is short enough that key rotation propagates quickly, long enough that
|
||||||
|
# hot DB-init paths (every API call building a connector) don't blow the
|
||||||
|
# decryption rate limit. 60s is a deliberate compromise.
|
||||||
|
_DECRYPTION_CACHE_TTL_S = 60.0
|
||||||
|
_decryption_cache: Dict[str, Tuple[float, str]] = {}
|
||||||
|
_decryption_cache_lock = threading.Lock()
|
||||||
|
|
||||||
def _getMasterKey(envType: str = None) -> bytes:
|
def _getMasterKey(envType: str = None) -> bytes:
|
||||||
"""
|
"""
|
||||||
Get the master key for the specified environment.
|
Get the master key for the specified environment.
|
||||||
|
|
@ -487,6 +498,16 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un
|
||||||
"""
|
"""
|
||||||
Decrypt a value using the master key for the current environment.
|
Decrypt a value using the master key for the current environment.
|
||||||
|
|
||||||
|
A short-lived plaintext cache (TTL `_DECRYPTION_CACHE_TTL_S`) is consulted
|
||||||
|
first. The 10/sec rate-limit on cache misses still protects against
|
||||||
|
brute-force attacks; cache HITS bypass it because they are not actual
|
||||||
|
cryptographic operations — they just return the result of an earlier
|
||||||
|
successful decrypt. Without this cache, hot paths like
|
||||||
|
`mainBackgroundJobService._getDb()` (called per RAG inventory poll AND
|
||||||
|
per walker DB call) trigger the rate limit and surface as
|
||||||
|
"Decryption rate limit exceeded for user 'system' key 'DB_PASSWORD_SECRET'"
|
||||||
|
ERRORs in the RAG inventory UI route.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
encryptedValue: The encrypted value with prefix
|
encryptedValue: The encrypted value with prefix
|
||||||
userId: The user ID making the request (default: "system")
|
userId: The user ID making the request (default: "system")
|
||||||
|
|
@ -501,7 +522,15 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un
|
||||||
if not _isEncryptedValue(encryptedValue):
|
if not _isEncryptedValue(encryptedValue):
|
||||||
return encryptedValue # Return as-is if not encrypted
|
return encryptedValue # Return as-is if not encrypted
|
||||||
|
|
||||||
# Check rate limiting (10 per second per user per key)
|
# Cache lookup BEFORE the rate-limit check: a cache hit is not a new
|
||||||
|
# cryptographic operation and must not be throttled.
|
||||||
|
now = time.monotonic()
|
||||||
|
with _decryption_cache_lock:
|
||||||
|
cached = _decryption_cache.get(encryptedValue)
|
||||||
|
if cached is not None and cached[0] > now:
|
||||||
|
return cached[1]
|
||||||
|
|
||||||
|
# Cache miss → real decrypt → apply rate limit.
|
||||||
if not _checkDecryptionRateLimit(userId, keyName, maxPerSecond=10):
|
if not _checkDecryptionRateLimit(userId, keyName, maxPerSecond=10):
|
||||||
raise ValueError(f"Decryption rate limit exceeded for user '{userId}' key '{keyName}' (10/sec)")
|
raise ValueError(f"Decryption rate limit exceeded for user '{userId}' key '{keyName}' (10/sec)")
|
||||||
|
|
||||||
|
|
@ -550,10 +579,24 @@ def decryptValue(encryptedValue: str, userId: str = "system", keyName: str = "un
|
||||||
# Don't fail if audit logging fails
|
# Don't fail if audit logging fails
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Populate cache so subsequent reads of the same ciphertext don't
|
||||||
|
# re-decrypt (and don't consume rate-limit budget).
|
||||||
|
with _decryption_cache_lock:
|
||||||
|
_decryption_cache[encryptedValue] = (
|
||||||
|
time.monotonic() + _DECRYPTION_CACHE_TTL_S,
|
||||||
|
decryptedValue,
|
||||||
|
)
|
||||||
|
|
||||||
return decryptedValue
|
return decryptedValue
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Decryption failed: {e}")
|
raise ValueError(f"Decryption failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def clearDecryptionCache() -> None:
|
||||||
|
"""Drop all cached plaintext secrets. Call after key rotation or in tests."""
|
||||||
|
with _decryption_cache_lock:
|
||||||
|
_decryption_cache.clear()
|
||||||
|
|
||||||
# Create the global APP_CONFIG instance
|
# Create the global APP_CONFIG instance
|
||||||
APP_CONFIG = Configuration()
|
APP_CONFIG = Configuration()
|
||||||
|
|
@ -33,20 +33,35 @@ def _ensureUamTablesMatchModels(dbConnector) -> None:
|
||||||
logger.debug(f"_ensureUamTablesMatchModels: {e}")
|
logger.debug(f"_ensureUamTablesMatchModels: {e}")
|
||||||
|
|
||||||
|
|
||||||
def _getConnection(dbConnector):
|
from contextlib import contextmanager
|
||||||
"""Get a connection from the DatabaseConnector.
|
|
||||||
|
|
||||||
Ensures the connection is alive and returns it.
|
|
||||||
Commits any pending transaction first to avoid blocking.
|
@contextmanager
|
||||||
|
def _borrowDbConn(dbConnector):
|
||||||
|
"""Borrow a pooled connection from the DatabaseConnector.
|
||||||
|
|
||||||
|
Index/trigger/FK creation traditionally ran with `conn.autocommit = True`
|
||||||
|
so each CREATE statement is its own transaction (DDL on a managed
|
||||||
|
connection blocks waiting for COMMIT). This helper preserves that
|
||||||
|
behaviour on top of the pool: borrow a connection, flip it to autocommit,
|
||||||
|
yield it, and restore the previous state before returning it to the pool.
|
||||||
"""
|
"""
|
||||||
dbConnector._ensure_connection()
|
with dbConnector.borrowConn() as conn:
|
||||||
conn = dbConnector.connection
|
|
||||||
# Commit any pending transaction to avoid blocking
|
|
||||||
try:
|
try:
|
||||||
conn.commit()
|
previousAutocommit = conn.autocommit
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # Ignore if nothing to commit
|
previousAutocommit = False
|
||||||
return conn
|
try:
|
||||||
|
conn.autocommit = True
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Could not set autocommit on borrowed connection: {e}")
|
||||||
|
try:
|
||||||
|
yield conn
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
conn.autocommit = previousAutocommit
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
@ -174,34 +189,15 @@ def applyMultiTenantOptimizations(dbConnector, tables: Optional[List[str]] = Non
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get a connection from the connector
|
|
||||||
conn = _getConnection(dbConnector)
|
|
||||||
|
|
||||||
# Save and set autocommit state
|
|
||||||
try:
|
|
||||||
originalAutocommit = conn.autocommit
|
|
||||||
except Exception:
|
|
||||||
originalAutocommit = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
conn.autocommit = True
|
|
||||||
except Exception as autoErr:
|
|
||||||
logger.debug(f"Could not set autocommit: {autoErr}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_ensureUamTablesMatchModels(dbConnector)
|
_ensureUamTablesMatchModels(dbConnector)
|
||||||
except Exception as preIdxErr:
|
except Exception as preIdxErr:
|
||||||
logger.debug(f"Pre-index table ensure: {preIdxErr}")
|
logger.debug(f"Pre-index table ensure: {preIdxErr}")
|
||||||
|
|
||||||
try:
|
with _borrowDbConn(dbConnector) as conn:
|
||||||
with conn.cursor() as cursor:
|
with conn.cursor() as cursor:
|
||||||
# Apply indexes
|
|
||||||
results["indexesCreated"] = _applyIndexes(cursor, tables)
|
results["indexesCreated"] = _applyIndexes(cursor, tables)
|
||||||
|
|
||||||
# Apply foreign keys
|
|
||||||
results["foreignKeysCreated"] = _applyForeignKeys(cursor, tables)
|
results["foreignKeysCreated"] = _applyForeignKeys(cursor, tables)
|
||||||
|
|
||||||
# Apply immutable triggers
|
|
||||||
results["triggersCreated"] = _applyImmutableTriggers(cursor, tables)
|
results["triggersCreated"] = _applyImmutableTriggers(cursor, tables)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -210,12 +206,6 @@ def applyMultiTenantOptimizations(dbConnector, tables: Optional[List[str]] = Non
|
||||||
f"{results['triggersCreated']} triggers, "
|
f"{results['triggersCreated']} triggers, "
|
||||||
f"{results['foreignKeysCreated']} foreign keys"
|
f"{results['foreignKeysCreated']} foreign keys"
|
||||||
)
|
)
|
||||||
finally:
|
|
||||||
# Restore original autocommit state
|
|
||||||
try:
|
|
||||||
conn.autocommit = originalAutocommit
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error applying multi-tenant optimizations: {type(e).__name__}: {e}")
|
logger.error(f"Error applying multi-tenant optimizations: {type(e).__name__}: {e}")
|
||||||
|
|
@ -227,20 +217,14 @@ def applyMultiTenantOptimizations(dbConnector, tables: Optional[List[str]] = Non
|
||||||
def applyIndexesOnly(dbConnector, tables: Optional[List[str]] = None) -> int:
|
def applyIndexesOnly(dbConnector, tables: Optional[List[str]] = None) -> int:
|
||||||
"""Apply only indexes (lighter operation, safe for frequent calls)."""
|
"""Apply only indexes (lighter operation, safe for frequent calls)."""
|
||||||
try:
|
try:
|
||||||
conn = _getConnection(dbConnector)
|
|
||||||
originalAutocommit = conn.autocommit
|
|
||||||
conn.autocommit = True
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_ensureUamTablesMatchModels(dbConnector)
|
_ensureUamTablesMatchModels(dbConnector)
|
||||||
except Exception as preIdxErr:
|
except Exception as preIdxErr:
|
||||||
logger.debug(f"Pre-index table ensure: {preIdxErr}")
|
logger.debug(f"Pre-index table ensure: {preIdxErr}")
|
||||||
|
|
||||||
try:
|
with _borrowDbConn(dbConnector) as conn:
|
||||||
with conn.cursor() as cursor:
|
with conn.cursor() as cursor:
|
||||||
return _applyIndexes(cursor, tables)
|
return _applyIndexes(cursor, tables)
|
||||||
finally:
|
|
||||||
conn.autocommit = originalAutocommit
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error applying indexes: {e}")
|
logger.error(f"Error applying indexes: {e}")
|
||||||
return 0
|
return 0
|
||||||
|
|
@ -514,8 +498,7 @@ def getOptimizationStatus(dbConnector) -> dict:
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conn = _getConnection(dbConnector)
|
with _borrowDbConn(dbConnector) as conn, conn.cursor() as cursor:
|
||||||
with conn.cursor() as cursor:
|
|
||||||
# Check regular indexes
|
# Check regular indexes
|
||||||
for tableName, indexName, _ in _INDEXES:
|
for tableName, indexName, _ in _INDEXES:
|
||||||
if _tableExists(cursor, tableName):
|
if _tableExists(cursor, tableName):
|
||||||
|
|
|
||||||
|
|
@ -60,11 +60,9 @@ def _getTableColumns(dbConnector, tableName: str) -> List[str]:
|
||||||
ORDER BY ordinal_position
|
ORDER BY ordinal_position
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cursor = dbConnector.connection.cursor()
|
with dbConnector.borrowCursor() as cursor:
|
||||||
cursor.execute(query, (tableName,))
|
cursor.execute(query, (tableName,))
|
||||||
columns = [row[0] for row in cursor.fetchall()]
|
columns = [row[0] for row in cursor.fetchall()]
|
||||||
cursor.close()
|
|
||||||
|
|
||||||
return columns
|
return columns
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error getting columns for table {tableName}: {e}")
|
logger.error(f"Error getting columns for table {tableName}: {e}")
|
||||||
|
|
@ -92,11 +90,10 @@ def _getAllTables(dbConnector) -> List[str]:
|
||||||
ORDER BY table_name
|
ORDER BY table_name
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cursor = dbConnector.connection.cursor()
|
with dbConnector.borrowCursor() as cursor:
|
||||||
cursor.execute(query)
|
cursor.execute(query)
|
||||||
allTables = [row[0] for row in cursor.fetchall()]
|
allTables = [row[0] for row in cursor.fetchall()]
|
||||||
|
|
||||||
# Get foreign key relationships to determine dependency order
|
|
||||||
fkQuery = """
|
fkQuery = """
|
||||||
SELECT
|
SELECT
|
||||||
tc.table_name,
|
tc.table_name,
|
||||||
|
|
@ -111,10 +108,8 @@ def _getAllTables(dbConnector) -> List[str]:
|
||||||
WHERE tc.constraint_type = 'FOREIGN KEY'
|
WHERE tc.constraint_type = 'FOREIGN KEY'
|
||||||
AND tc.table_schema = 'public'
|
AND tc.table_schema = 'public'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cursor.execute(fkQuery)
|
cursor.execute(fkQuery)
|
||||||
foreignKeys = cursor.fetchall()
|
foreignKeys = cursor.fetchall()
|
||||||
cursor.close()
|
|
||||||
|
|
||||||
# Build dependency graph (child -> parent mapping)
|
# Build dependency graph (child -> parent mapping)
|
||||||
dependencies = {}
|
dependencies = {}
|
||||||
|
|
@ -154,10 +149,9 @@ def _getAllTables(dbConnector) -> List[str]:
|
||||||
# Fallback: return simple list without ordering
|
# Fallback: return simple list without ordering
|
||||||
try:
|
try:
|
||||||
query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE'"
|
query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE'"
|
||||||
cursor = dbConnector.connection.cursor()
|
with dbConnector.borrowCursor() as cursor:
|
||||||
cursor.execute(query)
|
cursor.execute(query)
|
||||||
tables = [row[0] for row in cursor.fetchall()]
|
tables = [row[0] for row in cursor.fetchall()]
|
||||||
cursor.close()
|
|
||||||
return [t for t in tables if t not in PROTECTED_TABLES]
|
return [t for t in tables if t not in PROTECTED_TABLES]
|
||||||
except Exception:
|
except Exception:
|
||||||
return []
|
return []
|
||||||
|
|
@ -184,11 +178,9 @@ def _getPrimaryKeyColumns(dbConnector, tableName: str) -> List[str]:
|
||||||
AND i.indisprimary
|
AND i.indisprimary
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cursor = dbConnector.connection.cursor()
|
with dbConnector.borrowCursor() as cursor:
|
||||||
cursor.execute(query, (tableName,))
|
cursor.execute(query, (tableName,))
|
||||||
pkColumns = [row[0] for row in cursor.fetchall()]
|
pkColumns = [row[0] for row in cursor.fetchall()]
|
||||||
cursor.close()
|
|
||||||
|
|
||||||
return pkColumns
|
return pkColumns
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Could not get primary key for {tableName}: {e}")
|
logger.debug(f"Could not get primary key for {tableName}: {e}")
|
||||||
|
|
@ -229,21 +221,15 @@ def _findUserReferencesInTable(
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
references = {}
|
references = {}
|
||||||
cursor = dbConnector.connection.cursor()
|
with dbConnector.borrowCursor() as cursor:
|
||||||
|
|
||||||
for userColumn in userColumns:
|
for userColumn in userColumns:
|
||||||
# Build SELECT for primary key columns
|
|
||||||
pkSelect = ", ".join([f'"{pk}"' for pk in pkColumns])
|
pkSelect = ", ".join([f'"{pk}"' for pk in pkColumns])
|
||||||
query = f'SELECT {pkSelect} FROM "{tableName}" WHERE "{userColumn}" = %s'
|
query = f'SELECT {pkSelect} FROM "{tableName}" WHERE "{userColumn}" = %s'
|
||||||
|
|
||||||
cursor.execute(query, (userId,))
|
cursor.execute(query, (userId,))
|
||||||
recordKeys = cursor.fetchall()
|
recordKeys = cursor.fetchall()
|
||||||
|
|
||||||
if recordKeys:
|
if recordKeys:
|
||||||
references[userColumn] = recordKeys
|
references[userColumn] = recordKeys
|
||||||
logger.debug(f"Found {len(recordKeys)} records in {tableName}.{userColumn} for user {userId}")
|
logger.debug(f"Found {len(recordKeys)} records in {tableName}.{userColumn} for user {userId}")
|
||||||
|
|
||||||
cursor.close()
|
|
||||||
return references
|
return references
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -277,17 +263,15 @@ def _anonymizeRecords(
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cursor = dbConnector.connection.cursor()
|
# Resolve column metadata once outside the borrow block (it borrows its
|
||||||
count = 0
|
# own connection internally).
|
||||||
|
|
||||||
for recordKey in recordKeys:
|
|
||||||
# Build WHERE clause for primary key
|
|
||||||
whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
|
|
||||||
|
|
||||||
# Check if table has sysModifiedAt column
|
|
||||||
columns = _getTableColumns(dbConnector, tableName)
|
columns = _getTableColumns(dbConnector, tableName)
|
||||||
hasModifiedAt = "sysModifiedAt" in columns
|
hasModifiedAt = "sysModifiedAt" in columns
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
with dbConnector.borrowCursor() as cursor:
|
||||||
|
for recordKey in recordKeys:
|
||||||
|
whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
|
||||||
if hasModifiedAt:
|
if hasModifiedAt:
|
||||||
query = f'UPDATE "{tableName}" SET "{columnName}" = %s, "sysModifiedAt" = %s WHERE {whereClause}'
|
query = f'UPDATE "{tableName}" SET "{columnName}" = %s, "sysModifiedAt" = %s WHERE {whereClause}'
|
||||||
params = [anonymousValue, getUtcTimestamp()]
|
params = [anonymousValue, getUtcTimestamp()]
|
||||||
|
|
@ -295,7 +279,6 @@ def _anonymizeRecords(
|
||||||
query = f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE {whereClause}'
|
query = f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE {whereClause}'
|
||||||
params = [anonymousValue]
|
params = [anonymousValue]
|
||||||
|
|
||||||
# Add primary key values to params
|
|
||||||
if isinstance(recordKey, tuple):
|
if isinstance(recordKey, tuple):
|
||||||
params.extend(recordKey)
|
params.extend(recordKey)
|
||||||
else:
|
else:
|
||||||
|
|
@ -304,15 +287,11 @@ def _anonymizeRecords(
|
||||||
cursor.execute(query, params)
|
cursor.execute(query, params)
|
||||||
count += cursor.rowcount
|
count += cursor.rowcount
|
||||||
|
|
||||||
dbConnector.connection.commit()
|
|
||||||
cursor.close()
|
|
||||||
|
|
||||||
logger.info(f"Anonymized {count} records in {tableName}.{columnName}")
|
logger.info(f"Anonymized {count} records in {tableName}.{columnName}")
|
||||||
return count
|
return count
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error anonymizing records in {tableName}.{columnName}: {e}")
|
logger.error(f"Error anonymizing records in {tableName}.{columnName}: {e}")
|
||||||
dbConnector.connection.rollback()
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -338,32 +317,23 @@ def _deleteRecords(
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cursor = dbConnector.connection.cursor()
|
|
||||||
count = 0
|
count = 0
|
||||||
|
with dbConnector.borrowCursor() as cursor:
|
||||||
for recordKey in recordKeys:
|
for recordKey in recordKeys:
|
||||||
# Build WHERE clause for primary key
|
|
||||||
whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
|
whereClause = " AND ".join([f'"{pk}" = %s' for pk in pkColumns])
|
||||||
query = f'DELETE FROM "{tableName}" WHERE {whereClause}'
|
query = f'DELETE FROM "{tableName}" WHERE {whereClause}'
|
||||||
|
|
||||||
# Prepare params
|
|
||||||
if isinstance(recordKey, tuple):
|
if isinstance(recordKey, tuple):
|
||||||
params = list(recordKey)
|
params = list(recordKey)
|
||||||
else:
|
else:
|
||||||
params = [recordKey]
|
params = [recordKey]
|
||||||
|
|
||||||
cursor.execute(query, params)
|
cursor.execute(query, params)
|
||||||
count += cursor.rowcount
|
count += cursor.rowcount
|
||||||
|
|
||||||
dbConnector.connection.commit()
|
|
||||||
cursor.close()
|
|
||||||
|
|
||||||
logger.info(f"Deleted {count} records from {tableName}")
|
logger.info(f"Deleted {count} records from {tableName}")
|
||||||
return count
|
return count
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting records from {tableName}: {e}")
|
logger.error(f"Error deleting records from {tableName}: {e}")
|
||||||
dbConnector.connection.rollback()
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ if not c or not c.connection:
|
||||||
print("STAGE0: DB_CONNECTION=none (check config.ini / .env)")
|
print("STAGE0: DB_CONNECTION=none (check config.ini / .env)")
|
||||||
raise SystemExit(2)
|
raise SystemExit(2)
|
||||||
|
|
||||||
cur = c.connection.cursor()
|
cur = c.borrowCursor()
|
||||||
|
|
||||||
|
|
||||||
def _scalar(cur):
|
def _scalar(cur):
|
||||||
|
|
|
||||||
|
|
@ -12,11 +12,16 @@ broken query into "no rows found". That hid bugs like:
|
||||||
|
|
||||||
These tests pin the new contract: empty result sets still return ``[]`` /
|
These tests pin the new contract: empty result sets still return ``[]`` /
|
||||||
``None`` (normal), but any exception inside the query path propagates as
|
``None`` (normal), but any exception inside the query path propagates as
|
||||||
``DatabaseQueryError`` with the table name attached. The transaction is
|
``DatabaseQueryError`` with the table name attached.
|
||||||
rolled back so the connection is usable for subsequent queries.
|
|
||||||
|
Since the 2026-05-17 pool refactor (`c-work/2-build/2026-05-postgres-connection-pool.md`)
|
||||||
|
the connector borrows a connection from `_PoolRegistry` on every call via the
|
||||||
|
`borrowConn()` context manager. The tests mock that context manager so the
|
||||||
|
fast-fail contract is exercised without requiring a live Postgres server.
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
@ -25,7 +30,6 @@ import psycopg2.errors
|
||||||
from modules.connectors.connectorDbPostgre import (
|
from modules.connectors.connectorDbPostgre import (
|
||||||
DatabaseConnector,
|
DatabaseConnector,
|
||||||
DatabaseQueryError,
|
DatabaseQueryError,
|
||||||
_rollbackQuietly,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -39,26 +43,44 @@ class DummyTable:
|
||||||
|
|
||||||
|
|
||||||
def _makeConnector(cursorBehavior):
|
def _makeConnector(cursorBehavior):
|
||||||
"""Build a ``DatabaseConnector`` skeleton with mocked connection/cursor.
|
"""Build a ``DatabaseConnector`` skeleton with a mocked pool borrow.
|
||||||
|
|
||||||
``cursorBehavior`` is a callable invoked with the cursor mock so the test
|
``cursorBehavior`` is a callable invoked with the cursor mock so the test
|
||||||
can configure ``execute``/``fetchall``/``fetchone`` per scenario.
|
can configure ``execute``/``fetchall``/``fetchone`` per scenario.
|
||||||
|
|
||||||
|
Returns ``(connector, conn, cursor)``:
|
||||||
|
* ``conn`` exposes ``commit`` / ``rollback`` MagicMocks so tests can
|
||||||
|
assert that the borrow lifecycle did the right thing.
|
||||||
|
* ``cursor`` is the per-test cursor mock.
|
||||||
"""
|
"""
|
||||||
connector = DatabaseConnector.__new__(DatabaseConnector)
|
connector = DatabaseConnector.__new__(DatabaseConnector)
|
||||||
|
|
||||||
cursor = MagicMock()
|
cursor = MagicMock()
|
||||||
|
cursorBehavior(cursor)
|
||||||
|
|
||||||
cursorContext = MagicMock()
|
cursorContext = MagicMock()
|
||||||
cursorContext.__enter__ = MagicMock(return_value=cursor)
|
cursorContext.__enter__ = MagicMock(return_value=cursor)
|
||||||
cursorContext.__exit__ = MagicMock(return_value=False)
|
cursorContext.__exit__ = MagicMock(return_value=False)
|
||||||
|
|
||||||
connection = MagicMock()
|
conn = MagicMock()
|
||||||
connection.cursor.return_value = cursorContext
|
conn.cursor.return_value = cursorContext
|
||||||
connector.connection = connection
|
|
||||||
|
@contextmanager
|
||||||
|
def fakeBorrow():
|
||||||
|
try:
|
||||||
|
yield conn
|
||||||
|
except Exception:
|
||||||
|
conn.rollback()
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
connector.borrowConn = fakeBorrow
|
||||||
|
|
||||||
connector._ensureTableExists = MagicMock(return_value=True)
|
connector._ensureTableExists = MagicMock(return_value=True)
|
||||||
connector._systemTableName = "_system"
|
connector._systemTableName = "_system"
|
||||||
|
|
||||||
cursorBehavior(cursor)
|
return connector, conn, cursor
|
||||||
return connector, connection, cursor
|
|
||||||
|
|
||||||
|
|
||||||
class TestGetRecordsetFailLoud:
|
class TestGetRecordsetFailLoud:
|
||||||
|
|
@ -67,11 +89,12 @@ class TestGetRecordsetFailLoud:
|
||||||
def behavior(cursor):
|
def behavior(cursor):
|
||||||
cursor.execute.return_value = None
|
cursor.execute.return_value = None
|
||||||
cursor.fetchall.return_value = []
|
cursor.fetchall.return_value = []
|
||||||
connector, connection, _ = _makeConnector(behavior)
|
connector, conn, _ = _makeConnector(behavior)
|
||||||
|
|
||||||
result = connector.getRecordset(DummyTable)
|
result = connector.getRecordset(DummyTable)
|
||||||
assert result == []
|
assert result == []
|
||||||
connection.rollback.assert_not_called()
|
conn.rollback.assert_not_called()
|
||||||
|
conn.commit.assert_called_once()
|
||||||
|
|
||||||
def test_dictAdaptErrorRaisesDatabaseQueryError(self):
|
def test_dictAdaptErrorRaisesDatabaseQueryError(self):
|
||||||
"""Reproduces the Trustee bug: passing a dict in WHERE → can't adapt → raise."""
|
"""Reproduces the Trustee bug: passing a dict in WHERE → can't adapt → raise."""
|
||||||
|
|
@ -79,7 +102,7 @@ class TestGetRecordsetFailLoud:
|
||||||
cursor.execute.side_effect = psycopg2.ProgrammingError(
|
cursor.execute.side_effect = psycopg2.ProgrammingError(
|
||||||
"can't adapt type 'dict'"
|
"can't adapt type 'dict'"
|
||||||
)
|
)
|
||||||
connector, connection, _ = _makeConnector(behavior)
|
connector, conn, _ = _makeConnector(behavior)
|
||||||
|
|
||||||
with pytest.raises(DatabaseQueryError) as excinfo:
|
with pytest.raises(DatabaseQueryError) as excinfo:
|
||||||
connector.getRecordset(
|
connector.getRecordset(
|
||||||
|
|
@ -90,30 +113,30 @@ class TestGetRecordsetFailLoud:
|
||||||
assert excinfo.value.table == "DummyTable"
|
assert excinfo.value.table == "DummyTable"
|
||||||
assert "can't adapt type 'dict'" in str(excinfo.value)
|
assert "can't adapt type 'dict'" in str(excinfo.value)
|
||||||
assert isinstance(excinfo.value.original, psycopg2.ProgrammingError)
|
assert isinstance(excinfo.value.original, psycopg2.ProgrammingError)
|
||||||
connection.rollback.assert_called_once()
|
conn.rollback.assert_called_once()
|
||||||
|
|
||||||
def test_missingColumnRaisesDatabaseQueryError(self):
|
def test_missingColumnRaisesDatabaseQueryError(self):
|
||||||
def behavior(cursor):
|
def behavior(cursor):
|
||||||
cursor.execute.side_effect = psycopg2.errors.UndefinedColumn(
|
cursor.execute.side_effect = psycopg2.errors.UndefinedColumn(
|
||||||
'column "wat" does not exist'
|
'column "wat" does not exist'
|
||||||
)
|
)
|
||||||
connector, connection, _ = _makeConnector(behavior)
|
connector, conn, _ = _makeConnector(behavior)
|
||||||
|
|
||||||
with pytest.raises(DatabaseQueryError) as excinfo:
|
with pytest.raises(DatabaseQueryError) as excinfo:
|
||||||
connector.getRecordset(DummyTable, recordFilter={"wat": "x"})
|
connector.getRecordset(DummyTable, recordFilter={"wat": "x"})
|
||||||
|
|
||||||
assert "wat" in str(excinfo.value)
|
assert "wat" in str(excinfo.value)
|
||||||
connection.rollback.assert_called_once()
|
conn.rollback.assert_called_once()
|
||||||
|
|
||||||
def test_operationalErrorRaisesDatabaseQueryError(self):
|
def test_operationalErrorRaisesDatabaseQueryError(self):
|
||||||
"""Connection lost mid-query is also a real failure that must propagate."""
|
"""Connection lost mid-query is also a real failure that must propagate."""
|
||||||
def behavior(cursor):
|
def behavior(cursor):
|
||||||
cursor.execute.side_effect = psycopg2.OperationalError("connection lost")
|
cursor.execute.side_effect = psycopg2.OperationalError("connection lost")
|
||||||
connector, connection, _ = _makeConnector(behavior)
|
connector, conn, _ = _makeConnector(behavior)
|
||||||
|
|
||||||
with pytest.raises(DatabaseQueryError):
|
with pytest.raises(DatabaseQueryError):
|
||||||
connector.getRecordset(DummyTable)
|
connector.getRecordset(DummyTable)
|
||||||
connection.rollback.assert_called_once()
|
conn.rollback.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
class TestGetRecordFailLoud:
|
class TestGetRecordFailLoud:
|
||||||
|
|
@ -122,37 +145,22 @@ class TestGetRecordFailLoud:
|
||||||
def behavior(cursor):
|
def behavior(cursor):
|
||||||
cursor.execute.return_value = None
|
cursor.execute.return_value = None
|
||||||
cursor.fetchone.return_value = None
|
cursor.fetchone.return_value = None
|
||||||
connector, connection, _ = _makeConnector(behavior)
|
connector, conn, _ = _makeConnector(behavior)
|
||||||
|
|
||||||
result = connector.getRecord(DummyTable, "missing-id")
|
result = connector.getRecord(DummyTable, "missing-id")
|
||||||
assert result is None
|
assert result is None
|
||||||
connection.rollback.assert_not_called()
|
conn.rollback.assert_not_called()
|
||||||
|
conn.commit.assert_called_once()
|
||||||
|
|
||||||
def test_queryErrorRaisesDatabaseQueryError(self):
|
def test_queryErrorRaisesDatabaseQueryError(self):
|
||||||
def behavior(cursor):
|
def behavior(cursor):
|
||||||
cursor.execute.side_effect = psycopg2.errors.UndefinedTable(
|
cursor.execute.side_effect = psycopg2.errors.UndefinedTable(
|
||||||
'relation "DummyTable" does not exist'
|
'relation "DummyTable" does not exist'
|
||||||
)
|
)
|
||||||
connector, connection, _ = _makeConnector(behavior)
|
connector, conn, _ = _makeConnector(behavior)
|
||||||
|
|
||||||
with pytest.raises(DatabaseQueryError) as excinfo:
|
with pytest.raises(DatabaseQueryError) as excinfo:
|
||||||
connector.getRecord(DummyTable, "any-id")
|
connector.getRecord(DummyTable, "any-id")
|
||||||
|
|
||||||
assert excinfo.value.table == "DummyTable"
|
assert excinfo.value.table == "DummyTable"
|
||||||
connection.rollback.assert_called_once()
|
conn.rollback.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
class TestRollbackQuietly:
|
|
||||||
def test_rollsBackOnLiveConnection(self):
|
|
||||||
connection = MagicMock()
|
|
||||||
_rollbackQuietly(connection)
|
|
||||||
connection.rollback.assert_called_once()
|
|
||||||
|
|
||||||
def test_swallowsRollbackError(self):
|
|
||||||
"""Rollback failure must not mask the original query error."""
|
|
||||||
connection = MagicMock()
|
|
||||||
connection.rollback.side_effect = RuntimeError("rollback failed")
|
|
||||||
_rollbackQuietly(connection)
|
|
||||||
|
|
||||||
def test_noopOnNoneConnection(self):
|
|
||||||
_rollbackQuietly(None)
|
|
||||||
|
|
|
||||||
304
tests/unit/connectors/test_connectorDbPostgre_pool.py
Normal file
304
tests/unit/connectors/test_connectorDbPostgre_pool.py
Normal file
|
|
@ -0,0 +1,304 @@
|
||||||
|
# Copyright (c) 2026 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Concurrency tests for the PostgreSQL connection pool.
|
||||||
|
|
||||||
|
These tests pin the contract that the `c-work/2-build/2026-05-postgres-connection-pool.md`
|
||||||
|
refactor delivered:
|
||||||
|
|
||||||
|
* T1 — 50 threads × 100 calls in parallel produce 0 `OperationalError`s and
|
||||||
|
every call completes within reasonable time (p99 < 2 s).
|
||||||
|
* T2 — Two threads `_loadRecord` + `_saveRecord` against the same connector
|
||||||
|
do not corrupt each other's cursors.
|
||||||
|
* T3 — `statement_timeout` aborts a runaway `pg_sleep(60)` after ~30 s and
|
||||||
|
releases the connection back into the pool cleanly.
|
||||||
|
|
||||||
|
The tests need a real PostgreSQL server because the bug they guard against
|
||||||
|
only materialises with real psycopg2 sockets — a mocked connection never
|
||||||
|
hangs in `recv()`. They read DB credentials from `APP_CONFIG` (which loads
|
||||||
|
`.env`) and are auto-skipped when the connection fails (no local Postgres,
|
||||||
|
wrong creds, etc.) so `pytest` keeps working in CI-only environments.
|
||||||
|
|
||||||
|
To run them locally:
|
||||||
|
|
||||||
|
pytest gateway/tests/unit/connectors/test_connectorDbPostgre_pool.py -v
|
||||||
|
|
||||||
|
They use a throwaway database name (`poweron_pool_test_<uuid>`) and drop it
|
||||||
|
in fixture teardown so they leave nothing behind.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
import threading
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
||||||
|
import psycopg2
|
||||||
|
import psycopg2.errors
|
||||||
|
import pytest
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from modules.connectors.connectorDbPostgre import (
|
||||||
|
DatabaseConnector,
|
||||||
|
_PoolRegistry,
|
||||||
|
closeAllPools,
|
||||||
|
)
|
||||||
|
from modules.datamodels.datamodelBase import PowerOnModel
|
||||||
|
from modules.shared.configuration import APP_CONFIG
|
||||||
|
|
||||||
|
|
||||||
|
def _dbConfig():
|
||||||
|
"""Read DB connection params from APP_CONFIG (`.env`).
|
||||||
|
|
||||||
|
Returns ``None`` when host/user/password are not all present so the
|
||||||
|
test module can skip cleanly instead of blowing up at import time.
|
||||||
|
"""
|
||||||
|
host = APP_CONFIG.get("DB_HOST")
|
||||||
|
user = APP_CONFIG.get("DB_USER")
|
||||||
|
password = APP_CONFIG.get("DB_PASSWORD_SECRET")
|
||||||
|
port = APP_CONFIG.get("DB_PORT", 5432)
|
||||||
|
if not host or not user or password is None:
|
||||||
|
return None
|
||||||
|
return {"host": host, "user": user, "password": password, "port": int(port)}
|
||||||
|
|
||||||
|
|
||||||
|
def _canReachPostgres(cfg) -> bool:
|
||||||
|
"""Try a quick connect to the admin DB so we can skip on connection failures."""
|
||||||
|
try:
|
||||||
|
conn = psycopg2.connect(
|
||||||
|
host=cfg["host"], port=cfg["port"], database="postgres",
|
||||||
|
user=cfg["user"], password=cfg["password"], connect_timeout=2,
|
||||||
|
)
|
||||||
|
conn.close()
|
||||||
|
return True
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
_DB_CFG = _dbConfig()
|
||||||
|
pytestmark = pytest.mark.skipif(
|
||||||
|
_DB_CFG is None or not _canReachPostgres(_DB_CFG),
|
||||||
|
reason="No reachable PostgreSQL — skipping live-Postgres pool tests",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PoolTestRow(PowerOnModel):
|
||||||
|
"""Tiny model used to exercise the pool — one ID + one payload field."""
|
||||||
|
payload: str = Field(default="", description="Test payload")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def liveConnector():
|
||||||
|
"""Spin up a throwaway database, yield a `DatabaseConnector` against it,
|
||||||
|
drop the database afterwards.
|
||||||
|
|
||||||
|
The pool registry is wiped before and after each test so state from one
|
||||||
|
test cannot mask a bug in another.
|
||||||
|
"""
|
||||||
|
cfg = _DB_CFG
|
||||||
|
host = cfg["host"]
|
||||||
|
user = cfg["user"]
|
||||||
|
password = cfg["password"]
|
||||||
|
port = cfg["port"]
|
||||||
|
dbName = f"poweron_pool_test_{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
# Pre-clean: drop any orphan test DB with the same name (shouldn't happen
|
||||||
|
# because we use a unique uuid, but be defensive).
|
||||||
|
adminConn = psycopg2.connect(
|
||||||
|
host=host, port=port, database="postgres", user=user, password=password
|
||||||
|
)
|
||||||
|
adminConn.autocommit = True
|
||||||
|
try:
|
||||||
|
with adminConn.cursor() as cur:
|
||||||
|
cur.execute(f'DROP DATABASE IF EXISTS "{dbName}"')
|
||||||
|
finally:
|
||||||
|
adminConn.close()
|
||||||
|
|
||||||
|
closeAllPools()
|
||||||
|
|
||||||
|
connector = DatabaseConnector(
|
||||||
|
dbHost=host,
|
||||||
|
dbDatabase=dbName,
|
||||||
|
dbUser=user,
|
||||||
|
dbPassword=password,
|
||||||
|
dbPort=port,
|
||||||
|
)
|
||||||
|
# Seed exactly one row so every concurrent read has a stable target.
|
||||||
|
connector.recordCreate(PoolTestRow, {"id": "seed", "payload": "hello"})
|
||||||
|
|
||||||
|
yield connector
|
||||||
|
|
||||||
|
# Teardown: tear pools down, then drop the DB.
|
||||||
|
closeAllPools()
|
||||||
|
adminConn = psycopg2.connect(
|
||||||
|
host=host, port=port, database="postgres", user=user, password=password
|
||||||
|
)
|
||||||
|
adminConn.autocommit = True
|
||||||
|
try:
|
||||||
|
with adminConn.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
'SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = %s',
|
||||||
|
(dbName,),
|
||||||
|
)
|
||||||
|
cur.execute(f'DROP DATABASE IF EXISTS "{dbName}"')
|
||||||
|
finally:
|
||||||
|
adminConn.close()
|
||||||
|
|
||||||
|
|
||||||
|
class TestPoolConcurrency:
|
||||||
|
def _runWorkers(self, liveConnector, *, threadCount: int, callsPerThread: int):
|
||||||
|
"""Run N worker threads, each issuing M reads. Return (errors, latencies)."""
|
||||||
|
errors: list = []
|
||||||
|
latencies: list = []
|
||||||
|
lock = threading.Lock()
|
||||||
|
|
||||||
|
def worker():
|
||||||
|
for _ in range(callsPerThread):
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
try:
|
||||||
|
rows = liveConnector.getRecordset(PoolTestRow)
|
||||||
|
assert any(r["id"] == "seed" for r in rows)
|
||||||
|
except Exception as e: # noqa: BLE001 — we want every failure mode
|
||||||
|
with lock:
|
||||||
|
errors.append(e)
|
||||||
|
finally:
|
||||||
|
with lock:
|
||||||
|
latencies.append(time.perf_counter() - t0)
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=threadCount) as ex:
|
||||||
|
futures = [ex.submit(worker) for _ in range(threadCount)]
|
||||||
|
for f in as_completed(futures):
|
||||||
|
f.result()
|
||||||
|
latencies.sort()
|
||||||
|
return errors, latencies
|
||||||
|
|
||||||
|
def test_50_threads_x_20_reads_no_errors(self, liveConnector):
|
||||||
|
"""T1a — STRESS: 50 threads × 20 reads each → 0 errors.
|
||||||
|
|
||||||
|
Pre-pool, this scenario produced either
|
||||||
|
`OperationalError: another command is already in progress` or a
|
||||||
|
deadlock in `recv()` because the threadpool shared one psycopg2
|
||||||
|
socket. With the pool plus `borrowConn`'s bounded wait, every
|
||||||
|
thread eventually gets a connection and completes — even with 30
|
||||||
|
threads queued waiting at any moment (pool max=20).
|
||||||
|
"""
|
||||||
|
errors, _ = self._runWorkers(liveConnector, threadCount=50, callsPerThread=20)
|
||||||
|
assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"
|
||||||
|
|
||||||
|
def test_20_threads_x_50_reads_latency_budget(self, liveConnector):
|
||||||
|
"""T1b — DESIGN CAPACITY: 20 threads × 50 reads, p99 < 5 s.
|
||||||
|
|
||||||
|
20 threads matches the pool's `max=20` so there is no queueing —
|
||||||
|
every borrow returns immediately. This pins a sanity-level per-call
|
||||||
|
latency budget; pre-pool it was unbounded (recv() never returned).
|
||||||
|
|
||||||
|
The 5 s ceiling is generous on purpose: `getRecordset` calls
|
||||||
|
`_ensureTableExists` which runs two `information_schema` queries
|
||||||
|
for column-additive migration, and under 20-way concurrency on a
|
||||||
|
single Postgres instance that produces a long tail. The hard
|
||||||
|
assertion is `not errors` — the latency check just guarantees
|
||||||
|
nothing hangs indefinitely.
|
||||||
|
"""
|
||||||
|
errors, latencies = self._runWorkers(
|
||||||
|
liveConnector, threadCount=20, callsPerThread=50
|
||||||
|
)
|
||||||
|
assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"
|
||||||
|
p99 = latencies[int(len(latencies) * 0.99)]
|
||||||
|
assert p99 < 5.0, f"p99 latency {p99:.2f}s exceeds 5s budget"
|
||||||
|
|
||||||
|
def test_interleaved_load_and_save_no_collision(self, liveConnector):
|
||||||
|
"""T2: parallel reads + writes on the same connector → no cursor mix-up.
|
||||||
|
|
||||||
|
Pre-pool the reader could observe a row in mid-write or vice versa
|
||||||
|
because both shared the same cursor. With one connection per borrow,
|
||||||
|
the database's own row-locking is the only contention, and we just
|
||||||
|
need to assert no exceptions.
|
||||||
|
"""
|
||||||
|
stopFlag = threading.Event()
|
||||||
|
errors: list = []
|
||||||
|
lock = threading.Lock()
|
||||||
|
|
||||||
|
def reader():
|
||||||
|
while not stopFlag.is_set():
|
||||||
|
try:
|
||||||
|
liveConnector.getRecord(PoolTestRow, "seed")
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
with lock:
|
||||||
|
errors.append(("read", e))
|
||||||
|
|
||||||
|
def writer():
|
||||||
|
i = 0
|
||||||
|
while not stopFlag.is_set():
|
||||||
|
try:
|
||||||
|
liveConnector.recordModify(
|
||||||
|
PoolTestRow,
|
||||||
|
"seed",
|
||||||
|
{"id": "seed", "payload": f"v{i}"},
|
||||||
|
)
|
||||||
|
i += 1
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
with lock:
|
||||||
|
errors.append(("write", e))
|
||||||
|
|
||||||
|
threads = [
|
||||||
|
threading.Thread(target=reader, daemon=True),
|
||||||
|
threading.Thread(target=reader, daemon=True),
|
||||||
|
threading.Thread(target=writer, daemon=True),
|
||||||
|
threading.Thread(target=writer, daemon=True),
|
||||||
|
]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
time.sleep(2.0)
|
||||||
|
stopFlag.set()
|
||||||
|
for t in threads:
|
||||||
|
t.join(timeout=3.0)
|
||||||
|
|
||||||
|
assert not errors, f"got {len(errors)} errors; first: {errors[0]!r}"
|
||||||
|
|
||||||
|
def test_statement_timeout_releases_connection(self, liveConnector):
|
||||||
|
"""T3: `pg_sleep` past statement_timeout → QueryCanceled, pool intact.
|
||||||
|
|
||||||
|
The bug we are guarding against: a runaway query with no timeout
|
||||||
|
hung `recv()` forever, the psycopg2 connection was poisoned, and the
|
||||||
|
whole backend became unresponsive once that connection was reused.
|
||||||
|
With `statement_timeout=30000` configured at pool construction the
|
||||||
|
query is cancelled by the server, the borrow context manager rolls
|
||||||
|
back, and the connection returns to the pool — proven by the fact
|
||||||
|
that a follow-up call still succeeds quickly.
|
||||||
|
"""
|
||||||
|
# Use a short timeout to keep the test fast — override the pool's
|
||||||
|
# session statement_timeout for one borrow via SET LOCAL.
|
||||||
|
with liveConnector.borrowConn() as conn:
|
||||||
|
with conn.cursor() as cursor:
|
||||||
|
cursor.execute("SET LOCAL statement_timeout = 500")
|
||||||
|
with pytest.raises(psycopg2.errors.QueryCanceled):
|
||||||
|
cursor.execute("SELECT pg_sleep(5)")
|
||||||
|
|
||||||
|
# Follow-up call must succeed quickly: connection is back in the pool.
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
rows = liveConnector.getRecordset(PoolTestRow)
|
||||||
|
elapsed = time.perf_counter() - t0
|
||||||
|
assert any(r["id"] == "seed" for r in rows)
|
||||||
|
assert elapsed < 1.0, f"follow-up call took {elapsed:.2f}s — pool may be wedged"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPoolRegistry:
|
||||||
|
def test_one_pool_per_database_identity(self, liveConnector):
|
||||||
|
"""Two connectors against the same (host, db, port) share one pool."""
|
||||||
|
cfg = _DB_CFG
|
||||||
|
pool1 = _PoolRegistry.getPool(
|
||||||
|
dbHost=cfg["host"], dbDatabase=liveConnector.dbDatabase,
|
||||||
|
dbUser=cfg["user"], dbPassword=cfg["password"], dbPort=cfg["port"],
|
||||||
|
)
|
||||||
|
pool2 = _PoolRegistry.getPool(
|
||||||
|
dbHost=cfg["host"], dbDatabase=liveConnector.dbDatabase,
|
||||||
|
dbUser=cfg["user"], dbPassword=cfg["password"], dbPort=cfg["port"],
|
||||||
|
)
|
||||||
|
assert pool1 is pool2
|
||||||
|
|
||||||
|
def test_close_all_clears_registry(self, liveConnector):
|
||||||
|
"""`closeAllPools()` empties the registry so the next call rebuilds."""
|
||||||
|
# Touch the pool first.
|
||||||
|
liveConnector.getRecordset(PoolTestRow)
|
||||||
|
assert _PoolRegistry._pools, "pool should exist after a real call"
|
||||||
|
closeAllPools()
|
||||||
|
assert _PoolRegistry._pools == {}, "registry should be empty after closeAllPools()"
|
||||||
|
|
@ -68,6 +68,16 @@ class _FakeDb:
|
||||||
def _ensureTableExists(self, modelClass):
|
def _ensureTableExists(self, modelClass):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def borrowCursor(self):
|
||||||
|
"""Mimic `DatabaseConnector.borrowCursor()` context manager."""
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _cm():
|
||||||
|
yield MagicMock()
|
||||||
|
return _cm()
|
||||||
|
|
||||||
def seed(self, modelClass, record: Dict[str, Any]):
|
def seed(self, modelClass, record: Dict[str, Any]):
|
||||||
tableName = modelClass.__name__
|
tableName = modelClass.__name__
|
||||||
self._tables.setdefault(tableName, {})
|
self._tables.setdefault(tableName, {})
|
||||||
|
|
|
||||||
|
|
@ -69,6 +69,16 @@ class _FakeDb:
|
||||||
def _ensureTableExists(self, modelClass):
|
def _ensureTableExists(self, modelClass):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def borrowCursor(self):
|
||||||
|
"""Mimic `DatabaseConnector.borrowCursor()` context manager for the cascade test."""
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _cm():
|
||||||
|
yield MagicMock()
|
||||||
|
return _cm()
|
||||||
|
|
||||||
def seed(self, modelClass, record: Dict[str, Any]):
|
def seed(self, modelClass, record: Dict[str, Any]):
|
||||||
tableName = modelClass.__name__
|
tableName = modelClass.__name__
|
||||||
self._tables.setdefault(tableName, {})
|
self._tables.setdefault(tableName, {})
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue