2002 lines
84 KiB
Python
2002 lines
84 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
import contextvars
|
|
import re
|
|
import time
|
|
import psycopg2
|
|
import psycopg2.extras
|
|
import psycopg2.pool
|
|
import logging
|
|
from contextlib import contextmanager
|
|
from typing import List, Dict, Any, Optional, Union, get_origin, get_args, Type
|
|
import uuid
|
|
from pydantic import BaseModel, Field
|
|
import threading
|
|
|
|
from modules.shared.timeUtils import getUtcTimestamp
|
|
from modules.shared.configuration import APP_CONFIG
|
|
from modules.datamodels.datamodelBase import PowerOnModel
|
|
from modules.datamodels.datamodelUam import User, AccessLevel, UserPermissions
|
|
from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# No mapping needed - table name = Pydantic model name exactly
|
|
|
|
|
|
class DatabaseQueryError(RuntimeError):
    """Signals that a DB read failed because the SQL itself failed.

    An empty result is *not* an error: read methods keep returning ``[]``,
    ``None`` or ``{"items": [], "totalItems": 0, "totalPages": 0}`` for that
    case. This exception is reserved for genuine failures such as psycopg2
    ``ProgrammingError``, ``DataError``, ``OperationalError``,
    ``IntegrityError``, or any unexpected Python error raised while a query
    runs.

    Such failures used to be swallowed and turned into empty collections, so
    callers could not tell "no rows" apart from "broken query / type adapter
    / dropped column / lost connection". Real bugs (for example a dict bound
    where Postgres expected a UUID string) then only surfaced later as
    misleading "no record found" errors.

    Attributes:
        table: Name of the table the failing query targeted.
        original: The underlying exception, when one was captured.
    """

    def __init__(self, table: str, message: str, original: Optional[BaseException] = None):
        self.table = table
        self.original = original
        super().__init__(f"{table}: {message}")
|
|
|
|
|
|
class SystemTable(PowerOnModel):
    """Row of the system bookkeeping table: a table name and its initial ID."""

    # Name of the tracked table; required, flagged frontend_readonly.
    table_name: str = Field(
        description="Name of the table",
        json_schema_extra={
            "frontend_type": "text",
            "frontend_readonly": True,
            "frontend_required": True,
        },
    )

    # Initial ID recorded for the table; optional (defaults to None).
    initial_id: Optional[str] = Field(
        default=None,
        description="Initial ID for the table",
        json_schema_extra={
            "frontend_type": "text",
            "frontend_readonly": True,
            "frontend_required": False,
        },
    )
|
|
|
|
|
|
def _isVectorType(sqlType: str) -> bool:
    """Tell whether a SQL type string denotes a pgvector column.

    The ``VECTOR`` prefix is matched case-insensitively, so both
    ``"vector(1536)"`` and ``"VECTOR"`` qualify.
    """
    normalized = sqlType.upper()
    return normalized[:6] == "VECTOR"
|
|
|
|
|
|
def _isJsonbType(fieldType) -> bool:
|
|
"""Check if a type should be stored as JSONB in PostgreSQL."""
|
|
# Direct dict or list
|
|
if fieldType == dict or fieldType == list:
|
|
return True
|
|
|
|
# Generic List[X] or Dict[X, Y]
|
|
origin = get_origin(fieldType)
|
|
if origin in (dict, list):
|
|
return True
|
|
|
|
# Direct Pydantic BaseModel subclass
|
|
if isinstance(fieldType, type) and issubclass(fieldType, BaseModel):
|
|
return True
|
|
|
|
# Optional[X] - check the inner type
|
|
if origin is Union:
|
|
args = get_args(fieldType)
|
|
for arg in args:
|
|
if arg is type(None):
|
|
continue
|
|
# Recursively check the inner type
|
|
if _isJsonbType(arg):
|
|
return True
|
|
|
|
return False
|
|
|
|
|
|
def getModelFields(model_class) -> Dict[str, str]:
    """Map every field of a Pydantic model to a PostgreSQL column type.

    Resolution order per field:
      1. ``json_schema_extra={"db_type": ...}`` is used verbatim -- this is
         how pgvector columns such as ``vector(1536)`` are declared without
         special-casing field names.
      2. ``Optional[X]`` / ``X | None`` is unwrapped to ``X`` when exactly one
         non-None member remains.
      3. JSONB-compatible annotations (``_isJsonbType``) become ``JSONB``.
      4. ``bool`` -> BOOLEAN, ``int`` -> INTEGER, ``float`` -> DOUBLE PRECISION.
      5. Anything else falls back to TEXT.

    Returns:
        Mapping of field name -> SQL type string, in model field order.
    """
    scalarTypes = (
        (bool, "BOOLEAN"),
        (int, "INTEGER"),
        (float, "DOUBLE PRECISION"),
    )
    noneType = type(None)
    columnTypes: Dict[str, str] = {}

    for name, info in model_class.model_fields.items():
        extra = info.json_schema_extra
        if isinstance(extra, dict) and "db_type" in extra:
            columnTypes[name] = extra["db_type"]
            continue

        annotation = info.annotation

        # Collapse Optional[X] (typing.Union) and X | None (types.UnionType).
        if get_origin(annotation) is Union:
            members = [m for m in get_args(annotation) if m is not noneType]
        elif noneType in getattr(annotation, "__args__", ()):
            members = [m for m in annotation.__args__ if m is not noneType]
        else:
            members = []
        if len(members) == 1:
            annotation = members[0]

        if _isJsonbType(annotation):
            columnTypes[name] = "JSONB"
            continue

        columnTypes[name] = next(
            (sql for pyType, sql in scalarTypes if annotation is pyType),
            "TEXT",
        )

    return columnTypes
|
|
|
|
|
|
def parseRecordFields(record: Dict[str, Any], fields: Dict[str, str], context: str = "") -> None:
    """Coerce raw database values in *record* to Python types, in place.

    For each ``(name, sqlType)`` in *fields* that is present in *record*:
      * ``DOUBLE PRECISION`` / ``INTEGER`` -> ``float`` / ``int``;
      * ``vector...`` strings such as ``"[1,2,3]"`` -> ``list[float]``
        (lists are left as they are);
      * ``BOOLEAN`` -> ``bool`` (``None`` stays ``None``);
      * ``JSONB`` values that are not already dict/list are JSON-decoded.

    Values that cannot be converted are kept unchanged and a warning, tagged
    with *context*, is logged.
    """
    import json as _json

    for name, sqlType in fields.items():
        if name not in record:
            continue
        raw = record[name]

        if sqlType in ("DOUBLE PRECISION", "INTEGER") and raw is not None:
            caster = float if sqlType == "DOUBLE PRECISION" else int
            try:
                record[name] = caster(raw)
            except (ValueError, TypeError):
                logger.warning(f"Could not convert {name} to {sqlType} ({context}): {raw}")
            continue

        if _isVectorType(sqlType) and raw is not None:
            # pgvector text form is "[x,y,z]"; anything else is left untouched.
            if isinstance(raw, str):
                try:
                    record[name] = [float(part) for part in raw.strip("[]").split(",")]
                except (ValueError, TypeError):
                    logger.warning(f"Could not parse vector field {name} ({context})")
            continue

        if sqlType == "BOOLEAN":
            record[name] = None if raw is None else bool(raw)
            continue

        if sqlType == "JSONB" and raw is not None and not isinstance(raw, (dict, list)):
            try:
                record[name] = _json.loads(raw if isinstance(raw, str) else str(raw))
            except (_json.JSONDecodeError, TypeError, ValueError):
                logger.warning(f"Could not parse JSONB field {name}, keeping as string ({context})")
|
|
|
|
|
|
def _stripNulBytesFromStr(value: Any) -> Any:
    """Remove NUL (``\\x00``) characters from *value* when it is a ``str``.

    psycopg2 refuses to bind a Python ``str`` parameter that contains NUL.
    Text extracted from files (SQL dumps, binary content decoded as text)
    may still carry such characters, so they are dropped before binding.
    Anything that is not a string -- and strings without NUL -- are
    returned unchanged (the very same object).
    """
    if not isinstance(value, str) or "\x00" not in value:
        return value
    return value.replace("\x00", "")
|
|
|
|
|
|
def _quotePgIdent(name: str) -> str:
    """Quote *name* as a single PostgreSQL identifier.

    The value is coerced to ``str``, embedded double quotes are doubled, and
    the result is wrapped in double quotes (which also preserves case).
    """
    escaped = str(name).replace('"', '""')
    return f'"{escaped}"'
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Connection pool registry
|
|
# ---------------------------------------------------------------------------
|
|
# psycopg2 connections are NOT thread-safe; sharing one connection across the
|
|
# FastAPI threadpool (sync `def` routes) or across multiple async tasks that
|
|
# happen to be active simultaneously results in either
|
|
# `OperationalError: another command is already in progress` or — far worse —
|
|
# an unbounded hang in `recv()` on the connection socket, because two cursors
|
|
# wait for one server response.
|
|
#
|
|
# `_PoolRegistry` keeps one `ThreadedConnectionPool` per database identity
|
|
# (`host`, `db`, `port`). Every DB call borrows a connection from the pool,
|
|
# runs its query, and returns the connection. The pool itself guarantees that
|
|
# no two callers share the same psycopg2 connection at the same time.
|
|
#
|
|
# `statement_timeout=30000` (30 s) is a safety net: a runaway query is aborted
|
|
# instead of hanging forever and poisoning the connection. `connect_timeout=10`
|
|
# prevents indefinite blocking when the Postgres host is unreachable.
|
|
|
|
_DEFAULT_POOL_MIN: int = 2  # minconn passed to every ThreadedConnectionPool
_DEFAULT_POOL_MAX: int = 20  # maxconn fallback when DB_POOL_MAX_CONN is unusable
_STATEMENT_TIMEOUT_MS: int = 30000  # applied per connection via "-c statement_timeout"
_CONNECT_TIMEOUT_S: int = 10  # connect_timeout for pooled and bootstrap connections

# ThreadedConnectionPool.getconn() raises psycopg2.pool.PoolError right away
# when every connection is checked out; it does not wait for one to be freed.
# Bursty traffic (e.g. 50 concurrent route calls against a max=20 pool) would
# then fail spuriously, so borrowConn() -- through _acquireConn() -- retries
# with a short fixed backoff for at most _BORROW_WAIT_TIMEOUT_S seconds.
_BORROW_WAIT_TIMEOUT_S: float = 30.0
_BORROW_WAIT_BACKOFF_S: float = 0.05

# Set to True by closeAllPools(); borrow loops check it and abort at once
# instead of polling until the borrow timeout.
_shuttingDown: bool = False
|
|
|
|
|
|
def _resolvePoolMax() -> int:
    """Return the pool ``maxconn``, taken from the ``DB_POOL_MAX_CONN`` setting.

    A falsy setting (missing, empty, ``0``) means ``_DEFAULT_POOL_MAX``;
    parsed values below 2 are raised to 2; a value that ``int()`` rejects
    (``ValueError``/``TypeError``) also yields ``_DEFAULT_POOL_MAX``.
    """
    try:
        requested = int(APP_CONFIG.get("DB_POOL_MAX_CONN") or _DEFAULT_POOL_MAX)
    except (ValueError, TypeError):
        return _DEFAULT_POOL_MAX
    return max(requested, 2)
|
|
|
|
|
|
class _PoolRegistry:
    """Process-wide, lazily populated map of ``ThreadedConnectionPool`` objects.

    Pools are keyed by ``(host, database, port)``, so every
    ``DatabaseConnector`` targeting the same physical database reuses one
    pool. Because user and password are not part of the key, the pool is
    built with the credentials of the first caller for that key. Creation
    is serialised by ``_lock``; lookups of an existing pool take no lock.
    """

    _pools: Dict[tuple, psycopg2.pool.ThreadedConnectionPool] = {}
    _lock = threading.Lock()

    @classmethod
    def getPool(
        cls,
        *,
        dbHost: str,
        dbDatabase: str,
        dbUser: str,
        dbPassword: str,
        dbPort: int,
    ) -> psycopg2.pool.ThreadedConnectionPool:
        """Return the pool for this database, creating it on first use.

        ``dbPort=None`` means 5432. New pools use ``RealDictCursor`` as the
        default cursor factory, ``connect_timeout=_CONNECT_TIMEOUT_S`` and a
        per-connection ``statement_timeout`` of ``_STATEMENT_TIMEOUT_MS``.

        Raises:
            Exception: whatever psycopg2 raises while building the pool;
                it is logged and re-raised.
        """
        port = 5432 if dbPort is None else int(dbPort)
        key = (dbHost, dbDatabase, port)

        existing = cls._pools.get(key)
        if existing is not None:
            return existing

        with cls._lock:
            # Re-check under the lock: another thread may have created it.
            existing = cls._pools.get(key)
            if existing is not None:
                return existing

            poolMax = _resolvePoolMax()
            try:
                created = psycopg2.pool.ThreadedConnectionPool(
                    _DEFAULT_POOL_MIN,
                    poolMax,
                    host=dbHost,
                    port=port,
                    database=dbDatabase,
                    user=dbUser,
                    password=dbPassword,
                    client_encoding="utf8",
                    cursor_factory=psycopg2.extras.RealDictCursor,
                    connect_timeout=_CONNECT_TIMEOUT_S,
                    options=f"-c statement_timeout={_STATEMENT_TIMEOUT_MS}",
                )
            except Exception as err:
                logger.error(
                    "Failed to create connection pool for db=%s host=%s: %s",
                    dbDatabase, dbHost, err,
                )
                raise

            cls._pools[key] = created
            logger.debug(
                "Created connection pool for db=%s host=%s port=%s (min=%d max=%d, stmt_timeout=%dms)",
                dbDatabase, dbHost, port, _DEFAULT_POOL_MIN, poolMax, _STATEMENT_TIMEOUT_MS,
            )
            return created

    @classmethod
    def closeAll(cls) -> None:
        """Close and forget every pool; intended to run once at FastAPI shutdown.

        A failure closing one pool is logged and does not prevent the
        remaining pools from being closed.
        """
        with cls._lock:
            for key, pool in list(cls._pools.items()):
                try:
                    pool.closeall()
                except Exception as err:
                    logger.warning("Error closing pool %s: %s", key, err)
            cls._pools.clear()
            logger.info("All database connection pools closed")
|
|
|
|
|
|
def closeAllPools() -> None:
    """Close every database connection pool (FastAPI lifespan shutdown hook).

    The module-level ``_shuttingDown`` flag is raised *before* the pools are
    closed, so a borrower currently waiting in ``_acquireConn`` gives up at
    its next check instead of polling for the full borrow timeout.
    """
    global _shuttingDown
    _shuttingDown = True
    _PoolRegistry.closeAll()
|
|
|
|
|
|
class _ConnectionShim:
    """Inert replacement for the legacy ``DatabaseConnector.connection`` object.

    Older call sites do things like::

        db.connection.commit()
        if db.connection and not db.connection.closed: ...

    Connections are now owned by the pool, so such calls are accepted and do
    nothing: the shim is always truthy, never ``closed``, and
    ``commit``/``rollback``/``close`` are no-ops. Opening a cursor through it
    raises a ``RuntimeError`` instead of failing silently, which forces those
    call sites to move to ``db.borrowCursor()``.
    """

    closed = False

    def __bool__(self) -> bool:
        return True

    def commit(self) -> None:
        """No-op; borrowConn() commits the borrowed connection."""

    def rollback(self) -> None:
        """No-op; borrowConn() rolls back on error."""

    def close(self) -> None:
        """No-op; callers never close pooled connections themselves."""

    def cursor(self, *args, **kwargs):
        """Always raise: direct cursor access would bypass the pool."""
        message = (
            "DatabaseConnector.connection.cursor() is no longer supported. "
            "Use `db.borrowCursor()` (or `db.borrowConn()` for multi-statement "
            "transactions) so the connection is borrowed from and returned to "
            "the pool correctly."
        )
        raise RuntimeError(message)


# Single shared instance returned by DatabaseConnector.connection.
_CONNECTION_SHIM = _ConnectionShim()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Connector cache (lightweight wrappers — actual connections live in the pool)
|
|
# ---------------------------------------------------------------------------
|
|
# Multiple call sites (`routeI18n`, `aiAuditLogger`, `mainBackgroundJobService`,
|
|
# interfaces) ask for a connector via `getCachedConnector(...)` and expect to
|
|
# get back the same object on subsequent calls. Now that real connection
|
|
# multiplexing happens at the pool layer, the cache returns lightweight
|
|
# `DatabaseConnector` wrappers — they hold no connection themselves, only the
|
|
# DSN params and a reference to the shared pool.
|
|
# Upper bound on cached DatabaseConnector wrappers; FIFO eviction beyond it.
_MAX_CACHED_CONNECTORS = 32

# (host, database, port) -> shared DatabaseConnector wrapper.
_connector_cache: Dict[tuple, "DatabaseConnector"] = {}

# Insertion order of _connector_cache keys; the first entry is evicted first.
_connector_cache_order: List[tuple] = []

# Guards _connector_cache and _connector_cache_order.
_connector_cache_lock = threading.Lock()

# Request-scoped user id, set by getCachedConnector() instead of mutating the
# shared cached connector.
_current_user_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "db_connector_user_id",
    default=None,
)
|
|
|
|
|
|
def getCachedConnector(
    dbHost: str,
    dbDatabase: str,
    dbUser: str = None,
    dbPassword: str = None,
    dbPort: int = None,
    userId: str = None,
) -> "DatabaseConnector":
    """Return the shared ``DatabaseConnector`` wrapper for ``(host, database, port)``.

    Caching happens on two independent levels:

    1. **Pools** (``_PoolRegistry``) own the real psycopg2 connections --
       exactly one ``ThreadedConnectionPool`` per ``(host, db, port)``,
       shared by every wrapper pointing at that database.
    2. **Wrappers** (this function) cache the ``DatabaseConnector`` object
       itself. Constructing one runs ``initDbSystem()`` (CREATE DATABASE if
       missing, CREATE TABLE ``_system``, pool warm-up), so reusing it avoids
       repeating that bootstrap on every request.

    The cache holds at most ``_MAX_CACHED_CONNECTORS`` wrappers; the oldest
    is evicted first (its pool stays alive in ``_PoolRegistry``). The
    wrapper is created while ``_connector_cache_lock`` is held.

    ``userId`` is not written onto the shared wrapper for existing entries;
    when given, it is stored in the ``_current_user_id`` contextvar so that
    concurrent requests keep their own audit identity.
    """
    cacheKey = (dbHost, dbDatabase, 5432 if dbPort is None else int(dbPort))

    with _connector_cache_lock:
        connector = _connector_cache.get(cacheKey)
        if connector is None:
            # Make room first; wrappers are cheap, this is pure bookkeeping.
            while _connector_cache_order and len(_connector_cache) >= _MAX_CACHED_CONNECTORS:
                evictedKey = _connector_cache_order.pop(0)
                _connector_cache.pop(evictedKey, None)

            connector = DatabaseConnector(
                dbHost=dbHost,
                dbDatabase=dbDatabase,
                dbUser=dbUser,
                dbPassword=dbPassword,
                dbPort=dbPort,
                userId=userId,
            )
            connector._isCachedShared = True
            _connector_cache[cacheKey] = connector
            _connector_cache_order.append(cacheKey)

    if userId is not None:
        _current_user_id.set(userId)
    return connector
|
|
|
|
|
|
class DatabaseConnector:
|
|
"""
|
|
A connector for PostgreSQL-based data storage.
|
|
Provides generic database operations without user/mandate filtering.
|
|
Uses PostgreSQL with JSONB columns for flexible data storage.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
dbHost: str,
|
|
dbDatabase: str,
|
|
dbUser: str = None,
|
|
dbPassword: str = None,
|
|
dbPort: int = None,
|
|
userId: str = None,
|
|
):
|
|
# Store the input parameters
|
|
self.dbHost = dbHost
|
|
self.dbDatabase = dbDatabase
|
|
self.dbUser = dbUser
|
|
self.dbPassword = dbPassword
|
|
self.dbPort = dbPort
|
|
|
|
# Set userId (default to empty string if None)
|
|
self.userId = userId if userId is not None else ""
|
|
|
|
# No per-instance connection any more — real connections live in the
|
|
# shared `_PoolRegistry` pool. `_isCachedShared` is retained because
|
|
# `close(forceClose=False)` callers (interface __del__) still ask.
|
|
self._isCachedShared = False
|
|
|
|
# pgvector extension state (cached per connector instance — cheap)
|
|
self._vectorExtensionEnabled = False
|
|
|
|
# System table bootstrap: create database, system table, ensure metadata.
|
|
self._systemTableName = "_system"
|
|
self.initDbSystem()
|
|
self._initializeSystemTable()
|
|
|
|
def initDbSystem(self):
|
|
"""Bootstrap the physical database and the `_system` metadata table.
|
|
|
|
Uses a short-lived autocommit connection (NOT the pool) because
|
|
`CREATE DATABASE` cannot run inside a transaction block. Also warms up
|
|
the pool by acquiring it for this `(host, db, port)` once.
|
|
"""
|
|
try:
|
|
self._create_database_if_not_exists()
|
|
self._create_tables()
|
|
|
|
# Warm the pool so the first request doesn't pay for socket setup.
|
|
_PoolRegistry.getPool(
|
|
dbHost=self.dbHost,
|
|
dbDatabase=self.dbDatabase,
|
|
dbUser=self.dbUser,
|
|
dbPassword=self.dbPassword,
|
|
dbPort=self.dbPort,
|
|
)
|
|
|
|
logger.debug(
|
|
"PostgreSQL database system initialized (db=%s, host=%s, port=%s)",
|
|
self.dbDatabase, self.dbHost, self.dbPort,
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"FATAL ERROR: Database system initialization failed: {e}")
|
|
raise
|
|
|
|
@property
|
|
def connection(self) -> "_ConnectionShim":
|
|
"""Backward-compat shim — see `_ConnectionShim` docstring."""
|
|
return _CONNECTION_SHIM
|
|
|
|
def _ensure_connection(self) -> None:
|
|
"""No-op for backward compatibility.
|
|
|
|
Previously this method tested-or-reconnected the per-instance socket.
|
|
With pooling, the `ThreadedConnectionPool` re-establishes broken
|
|
connections lazily on the next `getconn()`. Kept as a no-op so that
|
|
legacy call sites continue to compile.
|
|
"""
|
|
return
|
|
|
|
@contextmanager
|
|
def borrowCursor(self):
|
|
"""Borrow a cursor for one short SQL block.
|
|
|
|
Convenience wrapper around `borrowConn()` for the common pattern:
|
|
|
|
with db.borrowCursor() as cursor:
|
|
cursor.execute(...)
|
|
rows = cursor.fetchall()
|
|
|
|
Replaces the legacy `with db.connection.cursor() as cursor:` pattern.
|
|
Commit/rollback and pool return are handled automatically — callers
|
|
must NOT call `.commit()`/`.rollback()` on the cursor themselves.
|
|
"""
|
|
with self.borrowConn() as conn:
|
|
with conn.cursor() as cursor:
|
|
yield cursor
|
|
|
|
@contextmanager
|
|
def borrowConn(self):
|
|
"""Borrow a connection from the pool for the duration of the block.
|
|
|
|
Pool-exhaustion semantics: `ThreadedConnectionPool.getconn()` raises
|
|
`psycopg2.pool.PoolError` immediately when the pool is at its `maxconn`
|
|
limit — it does NOT block. We wrap that with a bounded busy-wait so
|
|
bursty workloads (e.g. 50 concurrent route calls against a max=20
|
|
pool) queue up instead of failing. If no connection becomes available
|
|
within `_BORROW_WAIT_TIMEOUT_S`, the `PoolError` is propagated so a
|
|
genuinely deadlocked pool surfaces as a real error.
|
|
|
|
On normal exit the current transaction is committed (idempotent for
|
|
read-only queries — leaves the connection in a clean state for the
|
|
next borrower). On exception the transaction is rolled back and the
|
|
exception propagates. The connection is **always** returned to the
|
|
pool, even when commit/rollback fails — otherwise a single bad query
|
|
would leak a slot and eventually exhaust the pool.
|
|
"""
|
|
pool = _PoolRegistry.getPool(
|
|
dbHost=self.dbHost,
|
|
dbDatabase=self.dbDatabase,
|
|
dbUser=self.dbUser,
|
|
dbPassword=self.dbPassword,
|
|
dbPort=self.dbPort,
|
|
)
|
|
conn = self._acquireConn(pool)
|
|
try:
|
|
yield conn
|
|
except Exception:
|
|
try:
|
|
conn.rollback()
|
|
except Exception:
|
|
pass
|
|
raise
|
|
else:
|
|
# Best-effort commit so the connection goes back to the pool with
|
|
# no in-flight transaction. Failure here is non-fatal — the pool
|
|
# will re-establish the socket on the next `getconn()` if needed.
|
|
try:
|
|
conn.commit()
|
|
except Exception:
|
|
try:
|
|
conn.rollback()
|
|
except Exception:
|
|
pass
|
|
finally:
|
|
try:
|
|
pool.putconn(conn)
|
|
except Exception as e:
|
|
logger.warning("Failed to return connection to pool: %s", e)
|
|
|
|
@staticmethod
|
|
def _acquireConn(pool: psycopg2.pool.ThreadedConnectionPool):
|
|
"""Get a connection from the pool, waiting up to `_BORROW_WAIT_TIMEOUT_S`.
|
|
|
|
psycopg2's pool throws on exhaustion instead of queueing — this helper
|
|
polls with a short backoff so callers see queue semantics.
|
|
Aborts immediately when the application is shutting down.
|
|
"""
|
|
if _shuttingDown:
|
|
raise psycopg2.pool.PoolError("Application is shutting down")
|
|
deadline = time.monotonic() + _BORROW_WAIT_TIMEOUT_S
|
|
attempt = 0
|
|
while True:
|
|
try:
|
|
return pool.getconn()
|
|
except psycopg2.pool.PoolError as e:
|
|
attempt += 1
|
|
if _shuttingDown:
|
|
raise psycopg2.pool.PoolError("Application is shutting down")
|
|
if time.monotonic() >= deadline:
|
|
logger.error(
|
|
"Connection pool exhausted after %.1fs wait (%d retries)",
|
|
_BORROW_WAIT_TIMEOUT_S, attempt,
|
|
)
|
|
raise
|
|
time.sleep(_BORROW_WAIT_BACKOFF_S)
|
|
|
|
def _create_database_if_not_exists(self):
|
|
"""Create the database if it doesn't exist.
|
|
|
|
Uses an autocommit connection on the `postgres` admin DB because
|
|
`CREATE DATABASE` cannot run inside a transaction block — so this
|
|
path intentionally does NOT use the pool.
|
|
"""
|
|
try:
|
|
conn = psycopg2.connect(
|
|
host=self.dbHost,
|
|
port=self.dbPort,
|
|
database="postgres",
|
|
user=self.dbUser,
|
|
password=self.dbPassword,
|
|
client_encoding="utf8",
|
|
connect_timeout=_CONNECT_TIMEOUT_S,
|
|
)
|
|
conn.autocommit = True
|
|
try:
|
|
with conn.cursor() as cursor:
|
|
cursor.execute(
|
|
"SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,)
|
|
)
|
|
exists = cursor.fetchone()
|
|
if not exists:
|
|
quoted_db_name = f'"{self.dbDatabase}"'
|
|
cursor.execute(f"CREATE DATABASE {quoted_db_name}")
|
|
logger.info(f"Created database: {self.dbDatabase}")
|
|
finally:
|
|
conn.close()
|
|
|
|
except Exception as e:
|
|
logger.error(f"FATAL ERROR: Cannot create database: {e}")
|
|
logger.error("Database connection failed - application cannot start")
|
|
raise RuntimeError(
|
|
f"FATAL ERROR: Cannot create database '{self.dbDatabase}': {e}"
|
|
)
|
|
|
|
def _create_tables(self):
|
|
"""Create the `_system` table.
|
|
|
|
Uses a short-lived autocommit connection (not the pool) — runs exactly
|
|
once at connector creation.
|
|
"""
|
|
try:
|
|
conn = psycopg2.connect(
|
|
host=self.dbHost,
|
|
port=self.dbPort,
|
|
database=self.dbDatabase,
|
|
user=self.dbUser,
|
|
password=self.dbPassword,
|
|
client_encoding="utf8",
|
|
connect_timeout=_CONNECT_TIMEOUT_S,
|
|
)
|
|
conn.autocommit = True
|
|
try:
|
|
with conn.cursor() as cursor:
|
|
cursor.execute("""
|
|
CREATE TABLE IF NOT EXISTS _system (
|
|
id SERIAL PRIMARY KEY,
|
|
table_name VARCHAR(255) UNIQUE NOT NULL,
|
|
initial_id VARCHAR(255) NOT NULL,
|
|
"sysCreatedAt" DOUBLE PRECISION,
|
|
"sysCreatedBy" VARCHAR(255),
|
|
"sysModifiedAt" DOUBLE PRECISION,
|
|
"sysModifiedBy" VARCHAR(255)
|
|
)
|
|
""")
|
|
finally:
|
|
conn.close()
|
|
|
|
except Exception as e:
|
|
logger.error(f"FATAL ERROR: Cannot create system table: {e}")
|
|
logger.error(
|
|
"Database system table creation failed - application cannot start"
|
|
)
|
|
raise RuntimeError(f"FATAL ERROR: Cannot create system table: {e}")
|
|
|
|
def _initializeSystemTable(self):
|
|
"""Initializes the system table if it doesn't exist yet."""
|
|
try:
|
|
self._ensureTableExists(SystemTable)
|
|
with self.borrowConn() as conn:
|
|
with conn.cursor() as cursor:
|
|
cursor.execute('SELECT COUNT(*) FROM "_system"')
|
|
cursor.fetchone() # noqa: just verifies table is readable
|
|
except Exception as e:
|
|
logger.error(f"Error initializing system table: {e}")
|
|
raise
|
|
|
|
def _loadSystemTable(self) -> Dict[str, str]:
|
|
"""Loads the system table with the initial IDs."""
|
|
try:
|
|
with self.borrowConn() as conn:
|
|
with conn.cursor() as cursor:
|
|
cursor.execute('SELECT "table_name", "initial_id" FROM "_system"')
|
|
rows = cursor.fetchall()
|
|
return {row["table_name"]: row["initial_id"] for row in rows}
|
|
except Exception as e:
|
|
logger.error(f"Error loading system table: {e}")
|
|
return {}
|
|
|
|
def _saveSystemTable(self, data: Dict[str, str]) -> bool:
|
|
"""Saves the system table with the initial IDs."""
|
|
try:
|
|
with self.borrowConn() as conn:
|
|
with conn.cursor() as cursor:
|
|
cursor.execute('DELETE FROM "_system"')
|
|
for table_name, initial_id in data.items():
|
|
cursor.execute(
|
|
"""
|
|
INSERT INTO "_system" ("table_name", "initial_id", "sysModifiedAt")
|
|
VALUES (%s, %s, %s)
|
|
""",
|
|
(table_name, initial_id, getUtcTimestamp()),
|
|
)
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"Error saving system table: {e}")
|
|
return False
|
|
|
|
def _ensureSystemTableExists(self) -> bool:
|
|
"""Ensures the system table exists, creates it if it doesn't."""
|
|
try:
|
|
with self.borrowConn() as conn:
|
|
with conn.cursor() as cursor:
|
|
cursor.execute(
|
|
"SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s",
|
|
(self._systemTableName,),
|
|
)
|
|
exists = cursor.fetchone()["count"] > 0
|
|
|
|
if not exists:
|
|
cursor.execute(f"""
|
|
CREATE TABLE "{self._systemTableName}" (
|
|
"table_name" VARCHAR(255) PRIMARY KEY,
|
|
"initial_id" VARCHAR(255),
|
|
"sysCreatedAt" DOUBLE PRECISION,
|
|
"sysCreatedBy" VARCHAR(255),
|
|
"sysModifiedAt" DOUBLE PRECISION,
|
|
"sysModifiedBy" VARCHAR(255)
|
|
)
|
|
""")
|
|
logger.info("System table created successfully")
|
|
else:
|
|
cursor.execute(
|
|
"""
|
|
SELECT column_name FROM information_schema.columns
|
|
WHERE table_name = %s AND table_schema = 'public'
|
|
""",
|
|
(self._systemTableName,),
|
|
)
|
|
existing_columns = [row["column_name"] for row in cursor.fetchall()]
|
|
|
|
for sys_col, sys_sql in [
|
|
("sysCreatedAt", "DOUBLE PRECISION"),
|
|
("sysCreatedBy", "VARCHAR(255)"),
|
|
("sysModifiedAt", "DOUBLE PRECISION"),
|
|
("sysModifiedBy", "VARCHAR(255)"),
|
|
]:
|
|
if sys_col not in existing_columns:
|
|
cursor.execute(
|
|
f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "{sys_col}" {sys_sql}'
|
|
)
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"Error ensuring system table exists: {e}")
|
|
return False
|
|
|
|
def _ensureTableExists(self, model_class: type) -> bool:
|
|
"""Ensures a table exists, creates it if it doesn't."""
|
|
table = model_class.__name__
|
|
|
|
if table == "SystemTable":
|
|
# Handle system table specially - it uses _system as the actual table name
|
|
return self._ensureSystemTableExists()
|
|
|
|
try:
|
|
with self.borrowConn() as conn:
|
|
with conn.cursor() as cursor:
|
|
cursor.execute(
|
|
"""
|
|
SELECT COUNT(*) FROM information_schema.tables
|
|
WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public'
|
|
""",
|
|
(table,),
|
|
)
|
|
exists = cursor.fetchone()["count"] > 0
|
|
|
|
if not exists:
|
|
self._create_table_from_model(cursor, table, model_class)
|
|
logger.info(
|
|
f"Created table '{table}' with columns from Pydantic model"
|
|
)
|
|
else:
|
|
# Table exists: ensure all columns from model are present (simple additive migration)
|
|
try:
|
|
cursor.execute(
|
|
"""
|
|
SELECT column_name, data_type
|
|
FROM information_schema.columns
|
|
WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public'
|
|
""",
|
|
(table,),
|
|
)
|
|
existing_column_rows = cursor.fetchall()
|
|
existing_columns = {
|
|
row["column_name"] for row in existing_column_rows
|
|
}
|
|
existing_column_types = {
|
|
row["column_name"]: (row["data_type"] or "").lower()
|
|
for row in existing_column_rows
|
|
}
|
|
|
|
model_fields = getModelFields(model_class)
|
|
desired_columns = set(["id"]) | set(model_fields.keys())
|
|
|
|
for col in sorted(desired_columns - existing_columns):
|
|
if col in ["id"]:
|
|
continue
|
|
sql_type = model_fields.get(col)
|
|
if not sql_type:
|
|
sql_type = "TEXT"
|
|
try:
|
|
cursor.execute(
|
|
f'ALTER TABLE "{table}" ADD COLUMN "{col}" {sql_type}'
|
|
)
|
|
logger.info(
|
|
f"Added missing column '{col}' ({sql_type}) to '{table}'"
|
|
)
|
|
except Exception as add_err:
|
|
logger.warning(
|
|
f"Could not add column '{col}' to '{table}': {add_err}"
|
|
)
|
|
|
|
# Column type migrations for existing tables.
|
|
# TEXT→DOUBLE PRECISION handles three value shapes:
|
|
# 1. NULL / empty string → NULL
|
|
# 2. ISO date(time) like "2025-01-22" or "2025-01-22T10:00:00+00" → epoch via EXTRACT
|
|
# 3. Plain numeric string like "3.14" → direct cast
|
|
_TEXT_TO_DOUBLE = (
|
|
'DOUBLE PRECISION USING CASE'
|
|
' WHEN "{col}" IS NULL OR "{col}" = \'\' THEN NULL'
|
|
' WHEN "{col}" ~ \'^\\d{4}-\\d{2}-\\d{2}\''
|
|
' THEN EXTRACT(EPOCH FROM "{col}"::timestamptz)'
|
|
' ELSE NULLIF("{col}", \'\')::double precision'
|
|
' END'
|
|
)
|
|
_SAFE_TYPE_CHANGES = {
|
|
("jsonb", "TEXT"): "TEXT USING \"{col}\"::text",
|
|
("text", "DOUBLE PRECISION"): _TEXT_TO_DOUBLE,
|
|
("text", "INTEGER"): "INTEGER USING NULLIF(\"{col}\", '')::integer",
|
|
("timestamp without time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}" AT TIME ZONE \'UTC\')',
|
|
("timestamp with time zone", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}")',
|
|
("date", "DOUBLE PRECISION"): 'DOUBLE PRECISION USING EXTRACT(EPOCH FROM "{col}"::timestamp AT TIME ZONE \'UTC\')',
|
|
}
|
|
for col in sorted(desired_columns & existing_columns):
|
|
if col == "id":
|
|
continue
|
|
desired_sql = (model_fields.get(col) or "").upper()
|
|
currentType = existing_column_types.get(col, "")
|
|
migration = _SAFE_TYPE_CHANGES.get((currentType, desired_sql))
|
|
if not migration and desired_sql.startswith("VECTOR") and currentType == "text":
|
|
migration = f'{desired_sql} USING CASE WHEN "{col}" IS NULL OR "{col}" = \'\' THEN NULL ELSE "{col}"::vector END'
|
|
if migration:
|
|
castExpr = migration.replace("{col}", col)
|
|
try:
|
|
cursor.execute('SAVEPOINT col_migrate')
|
|
cursor.execute(
|
|
f'ALTER TABLE "{table}" ALTER COLUMN "{col}" TYPE {castExpr}'
|
|
)
|
|
cursor.execute('RELEASE SAVEPOINT col_migrate')
|
|
logger.info(
|
|
f"Migrated column '{col}' from {currentType} to {desired_sql} on '{table}'"
|
|
)
|
|
except Exception as alter_err:
|
|
cursor.execute('ROLLBACK TO SAVEPOINT col_migrate')
|
|
logger.warning(
|
|
f"Could not migrate column '{col}' on '{table}': {alter_err}"
|
|
)
|
|
except Exception as ensure_err:
|
|
logger.warning(
|
|
f"Could not ensure columns for existing table '{table}': {ensure_err}"
|
|
)
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"Error ensuring table {table} exists: {e}")
|
|
return False
|
|
|
|
def _ensureVectorExtension(self) -> bool:
|
|
"""Enable pgvector extension if not already enabled. Called lazily on first vector table."""
|
|
if self._vectorExtensionEnabled:
|
|
return True
|
|
try:
|
|
with self.borrowConn() as conn:
|
|
with conn.cursor() as cursor:
|
|
cursor.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
|
self._vectorExtensionEnabled = True
|
|
logger.info("pgvector extension enabled")
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"Failed to enable pgvector extension: {e}")
|
|
return False
|
|
|
|
def _create_table_from_model(self, cursor, table: str, model_class: type) -> None:
|
|
"""Create table with columns matching Pydantic model fields."""
|
|
fields = getModelFields(model_class)
|
|
|
|
# Enable pgvector if any field uses vector type
|
|
if any(_isVectorType(sqlType) for sqlType in fields.values()):
|
|
self._ensureVectorExtension()
|
|
|
|
# Build column definitions with quoted identifiers to preserve exact case
|
|
columns = ['"id" VARCHAR(255) PRIMARY KEY']
|
|
for field_name, sql_type in fields.items():
|
|
if field_name != "id": # Skip id, already defined
|
|
columns.append(f'"{field_name}" {sql_type}')
|
|
|
|
# Create table
|
|
sql = f'CREATE TABLE IF NOT EXISTS "{table}" ({", ".join(columns)})'
|
|
cursor.execute(sql)
|
|
|
|
# Create indexes for foreign keys
|
|
for field_name in fields:
|
|
if field_name.endswith("Id") and field_name != "id":
|
|
cursor.execute(
|
|
f'CREATE INDEX IF NOT EXISTS "idx_{table}_{field_name}" ON "{table}" ("{field_name}")'
|
|
)
|
|
|
|
def _save_record(
|
|
self,
|
|
cursor,
|
|
table: str,
|
|
recordId: str,
|
|
record: Dict[str, Any],
|
|
model_class: type,
|
|
) -> None:
|
|
"""Save record to normalized table with explicit columns."""
|
|
# Get columns from Pydantic model instead of database schema
|
|
fields = getModelFields(model_class)
|
|
columns = ["id"] + [field for field in fields.keys() if field != "id"]
|
|
|
|
if not columns:
|
|
logger.error(f"No columns found for table {table}")
|
|
return
|
|
|
|
# Filter record data to only include columns that exist in the table
|
|
filtered_record = {k: v for k, v in record.items() if k in columns}
|
|
|
|
# Ensure id is set
|
|
filtered_record["id"] = recordId
|
|
|
|
# Prepare values in the correct order
|
|
values = []
|
|
for col in columns:
|
|
value = filtered_record.get(col)
|
|
|
|
# Handle timestamp fields - store as Unix timestamps (floats) for consistency
|
|
if col in ["sysCreatedAt", "sysModifiedAt"] and value is not None:
|
|
if isinstance(value, str):
|
|
# Try to parse string as timestamp
|
|
try:
|
|
value = float(value)
|
|
except:
|
|
pass # Keep as string if parsing fails
|
|
|
|
# Convert enum values to their string representation
|
|
elif hasattr(value, "value"):
|
|
value = value.value
|
|
|
|
# Handle vector fields (pgvector) - convert List[float] to string
|
|
elif col in fields and _isVectorType(fields[col]) and value is not None:
|
|
if isinstance(value, list):
|
|
value = f"[{','.join(str(v) for v in value)}]"
|
|
|
|
# Handle JSONB fields - ensure proper JSON format for PostgreSQL
|
|
elif col in fields and fields[col] == "JSONB" and value is not None:
|
|
import json
|
|
|
|
if isinstance(value, (dict, list)):
|
|
value = json.dumps(value)
|
|
elif isinstance(value, str):
|
|
try:
|
|
json.loads(value)
|
|
except (json.JSONDecodeError, TypeError):
|
|
value = json.dumps(value)
|
|
elif hasattr(value, 'model_dump'):
|
|
value = json.dumps(value.model_dump())
|
|
else:
|
|
value = json.dumps(value)
|
|
|
|
values.append(_stripNulBytesFromStr(value))
|
|
|
|
# Build INSERT/UPDATE with quoted identifiers
|
|
col_names = ", ".join([f'"{col}"' for col in columns])
|
|
placeholders = ", ".join(["%s"] * len(columns))
|
|
updates = ", ".join(
|
|
[
|
|
f'"{col}" = EXCLUDED."{col}"'
|
|
for col in columns[1:]
|
|
if col not in ["sysCreatedAt", "sysCreatedBy"]
|
|
]
|
|
)
|
|
|
|
sql = f'INSERT INTO "{table}" ({col_names}) VALUES ({placeholders}) ON CONFLICT ("id") DO UPDATE SET {updates}'
|
|
|
|
cursor.execute(sql, values)
|
|
|
|
def _loadRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]:
|
|
"""Loads a single record from the normalized table."""
|
|
table = model_class.__name__
|
|
|
|
try:
|
|
if not self._ensureTableExists(model_class):
|
|
return None
|
|
|
|
with self.borrowConn() as conn:
|
|
with conn.cursor() as cursor:
|
|
cursor.execute(f'SELECT * FROM "{table}" WHERE "id" = %s', (recordId,))
|
|
row = cursor.fetchone()
|
|
if not row:
|
|
return None
|
|
record = dict(row)
|
|
|
|
fields = getModelFields(model_class)
|
|
parseRecordFields(record, fields, f"record {recordId}")
|
|
return record
|
|
except Exception as e:
|
|
logger.error(f"Error loading record {recordId} from table {table}: {e}")
|
|
raise DatabaseQueryError(table, str(e), original=e) from e
|
|
|
|
def getRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]:
|
|
"""Load one row by primary key (routes / services; wraps _loadRecord)."""
|
|
return self._loadRecord(model_class, str(recordId))
|
|
|
|
def _saveRecord(
|
|
self, model_class: type, recordId: str, record: Dict[str, Any]
|
|
) -> bool:
|
|
"""Saves a single record to the table."""
|
|
table = model_class.__name__
|
|
|
|
try:
|
|
if not self._ensureTableExists(model_class):
|
|
return False
|
|
|
|
recordId = str(recordId)
|
|
if "id" in record and str(record["id"]) != recordId:
|
|
raise ValueError(f"Record ID mismatch: {recordId} != {record['id']}")
|
|
|
|
# Add metadata - use contextvar for request-scoped userId when sharing connector
|
|
effective_user_id = _current_user_id.get()
|
|
if effective_user_id is None:
|
|
effective_user_id = self.userId
|
|
currentTime = getUtcTimestamp()
|
|
# Set sysCreatedAt/sysCreatedBy on first persist; always refresh modified fields.
|
|
# Treat None and 0 as unset (empty / bad defaults); model_dump often has sysCreatedAt=None.
|
|
createdTs = record.get("sysCreatedAt")
|
|
if createdTs is None or createdTs == 0 or createdTs == 0.0:
|
|
record["sysCreatedAt"] = currentTime
|
|
# Do not wipe caller-provided sysCreatedBy (e.g. FileItem from createFile with
|
|
# real user). ContextVar can be "system" for the DB pool while the business
|
|
# user is set on the record from model_dump().
|
|
if effective_user_id and not record.get("sysCreatedBy"):
|
|
record["sysCreatedBy"] = effective_user_id
|
|
elif not record.get("sysCreatedBy"):
|
|
if effective_user_id:
|
|
record["sysCreatedBy"] = effective_user_id
|
|
record["sysModifiedAt"] = currentTime
|
|
if effective_user_id:
|
|
record["sysModifiedBy"] = effective_user_id
|
|
|
|
with self.borrowConn() as conn:
|
|
with conn.cursor() as cursor:
|
|
self._save_record(cursor, table, recordId, record, model_class)
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"Error saving record {recordId} to table {table}: {e}")
|
|
return False
|
|
|
|
def _loadTable(self, model_class: type) -> List[Dict[str, Any]]:
|
|
"""Loads all records from a normalized table."""
|
|
table = model_class.__name__
|
|
|
|
if table == self._systemTableName:
|
|
return self._loadSystemTable()
|
|
|
|
try:
|
|
if not self._ensureTableExists(model_class):
|
|
return []
|
|
|
|
with self.borrowConn() as conn:
|
|
with conn.cursor() as cursor:
|
|
cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"')
|
|
records = [dict(row) for row in cursor.fetchall()]
|
|
|
|
fields = getModelFields(model_class)
|
|
modelFields = model_class.model_fields
|
|
for record in records:
|
|
parseRecordFields(record, fields, f"table {table}")
|
|
for fieldName, fieldType in fields.items():
|
|
if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
|
|
fieldInfo = modelFields.get(fieldName)
|
|
if fieldInfo:
|
|
fieldAnnotation = fieldInfo.annotation
|
|
if (fieldAnnotation == list or
|
|
(hasattr(fieldAnnotation, "__origin__") and
|
|
fieldAnnotation.__origin__ is list)):
|
|
record[fieldName] = []
|
|
elif (fieldAnnotation == dict or
|
|
(hasattr(fieldAnnotation, "__origin__") and
|
|
fieldAnnotation.__origin__ is dict)):
|
|
record[fieldName] = {}
|
|
|
|
return records
|
|
except Exception as e:
|
|
logger.error(f"Error loading table {table}: {e}")
|
|
raise DatabaseQueryError(table, str(e), original=e) from e
|
|
|
|
def _registerInitialId(self, table: str, initialId: str) -> bool:
|
|
"""Registers the initial ID for a table."""
|
|
try:
|
|
systemData = self._loadSystemTable()
|
|
|
|
if table not in systemData:
|
|
systemData[table] = initialId
|
|
success = self._saveSystemTable(systemData)
|
|
if success:
|
|
logger.info(f"Initial ID {initialId} for table {table} registered")
|
|
return success
|
|
else:
|
|
# Table already has an initial ID registered
|
|
logger.debug(f"Table {table} already has initial ID {systemData[table]}")
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"Error registering the initial ID for table {table}: {e}")
|
|
return False
|
|
|
|
def _removeInitialId(self, table: str) -> bool:
|
|
"""Removes the initial ID for a table from the system table."""
|
|
try:
|
|
systemData = self._loadSystemTable()
|
|
|
|
if table in systemData:
|
|
del systemData[table]
|
|
success = self._saveSystemTable(systemData)
|
|
if success:
|
|
logger.info(
|
|
f"Initial ID for table {table} removed from system table"
|
|
)
|
|
return success
|
|
return True # If not present, this is not an error
|
|
except Exception as e:
|
|
logger.error(f"Error removing initial ID for table {table}: {e}")
|
|
return False
|
|
|
|
def buildRbacWhereClause(
|
|
self,
|
|
permissions: UserPermissions,
|
|
currentUser: User,
|
|
table: str,
|
|
mandateId: Optional[str] = None,
|
|
featureInstanceId: Optional[str] = None,
|
|
) -> Optional[Dict[str, Any]]:
|
|
"""Delegate to interfaceRbac.buildRbacWhereClause (tests and call sites use connector as entry)."""
|
|
from modules.interfaces.interfaceRbac import buildRbacWhereClause as _buildRbacWhereClause
|
|
|
|
return _buildRbacWhereClause(
|
|
permissions,
|
|
currentUser,
|
|
table,
|
|
self,
|
|
mandateId=mandateId,
|
|
featureInstanceId=featureInstanceId,
|
|
)
|
|
|
|
def updateContext(self, userId: str) -> None:
|
|
"""Updates the context of the database connector.
|
|
Sets both instance userId and contextvar for request-scoped use when connector is shared.
|
|
"""
|
|
if userId is None:
|
|
raise ValueError("userId must be provided")
|
|
self.userId = userId
|
|
_current_user_id.set(userId)
|
|
|
|
# Public API
|
|
|
|
def getTables(self) -> List[str]:
|
|
"""Returns a list of all available tables."""
|
|
tables: List[str] = []
|
|
try:
|
|
with self.borrowConn() as conn:
|
|
with conn.cursor() as cursor:
|
|
cursor.execute("""
|
|
SELECT table_name
|
|
FROM information_schema.tables
|
|
WHERE table_schema = 'public'
|
|
ORDER BY table_name
|
|
""")
|
|
rows = cursor.fetchall()
|
|
tables = [row["table_name"] for row in rows]
|
|
except Exception as e:
|
|
logger.error(f"Error reading the database {self.dbDatabase}: {e}")
|
|
return tables
|
|
|
|
def getFields(self, model_class: type) -> List[str]:
|
|
"""Returns a list of all fields in a table."""
|
|
data = self._loadTable(model_class)
|
|
|
|
if not data:
|
|
return []
|
|
|
|
fields = list(data[0].keys()) if data else []
|
|
|
|
return fields
|
|
|
|
def getSchema(
|
|
self, model_class: type, language: str = None
|
|
) -> Dict[str, Dict[str, Any]]:
|
|
"""Returns a schema object for a table with data types and labels."""
|
|
data = self._loadTable(model_class)
|
|
|
|
schema = {}
|
|
|
|
if not data:
|
|
return schema
|
|
|
|
firstRecord = data[0]
|
|
|
|
for field, value in firstRecord.items():
|
|
dataType = type(value).__name__
|
|
label = field
|
|
|
|
schema[field] = {"type": dataType, "label": label}
|
|
|
|
return schema
|
|
|
|
def getRecordset(
|
|
self,
|
|
model_class: type,
|
|
fieldFilter: List[str] = None,
|
|
recordFilter: Dict[str, Any] = None,
|
|
) -> List[Dict[str, Any]]:
|
|
"""Returns a list of records from a table, filtered by criteria."""
|
|
table = model_class.__name__
|
|
|
|
try:
|
|
if not self._ensureTableExists(model_class):
|
|
return []
|
|
|
|
# Build WHERE clause from recordFilter
|
|
where_conditions = []
|
|
where_values = []
|
|
|
|
if recordFilter:
|
|
for field, value in recordFilter.items():
|
|
if value is None:
|
|
where_conditions.append(f'"{field}" IS NULL')
|
|
elif field == "isTemplate" and value is False:
|
|
# NULL must count as non-template (legacy rows omit the flag)
|
|
where_conditions.append(f'"{field}" IS NOT TRUE')
|
|
elif isinstance(value, list):
|
|
where_conditions.append(f'"{field}" = ANY(%s)')
|
|
where_values.append(value)
|
|
else:
|
|
where_conditions.append(f'"{field}" = %s')
|
|
where_values.append(value)
|
|
|
|
if where_conditions:
|
|
where_clause = " WHERE " + " AND ".join(where_conditions)
|
|
else:
|
|
where_clause = ""
|
|
|
|
query = f'SELECT * FROM "{table}"{where_clause} ORDER BY "id"'
|
|
|
|
with self.borrowConn() as conn:
|
|
with conn.cursor() as cursor:
|
|
cursor.execute(query, where_values)
|
|
records = [dict(row) for row in cursor.fetchall()]
|
|
|
|
fields = getModelFields(model_class)
|
|
modelFields = model_class.model_fields
|
|
for record in records:
|
|
parseRecordFields(record, fields, f"table {table}")
|
|
for fieldName, fieldType in fields.items():
|
|
if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
|
|
fieldInfo = modelFields.get(fieldName)
|
|
if fieldInfo:
|
|
fieldAnnotation = fieldInfo.annotation
|
|
if (fieldAnnotation == list or
|
|
(hasattr(fieldAnnotation, "__origin__") and
|
|
fieldAnnotation.__origin__ is list)):
|
|
record[fieldName] = []
|
|
elif (fieldAnnotation == dict or
|
|
(hasattr(fieldAnnotation, "__origin__") and
|
|
fieldAnnotation.__origin__ is dict)):
|
|
record[fieldName] = {}
|
|
|
|
if fieldFilter and isinstance(fieldFilter, list):
|
|
result = []
|
|
for record in records:
|
|
filteredRecord = {}
|
|
for field in fieldFilter:
|
|
if field in record:
|
|
filteredRecord[field] = record[field]
|
|
result.append(filteredRecord)
|
|
return result
|
|
|
|
return records
|
|
except Exception as e:
|
|
logger.error(f"Error loading records from table {table}: {e}")
|
|
raise DatabaseQueryError(table, str(e), original=e) from e
|
|
|
|
def _buildPaginationClauses(
|
|
self,
|
|
model_class: type,
|
|
pagination,
|
|
recordFilter: Dict[str, Any] = None,
|
|
):
|
|
"""
|
|
Translate PaginationParams + recordFilter into SQL clauses.
|
|
Returns (where_clause, order_clause, limit_clause, values, count_values).
|
|
"""
|
|
fields = getModelFields(model_class)
|
|
validColumns = set(fields.keys())
|
|
|
|
where_parts: List[str] = []
|
|
values: List[Any] = []
|
|
|
|
if recordFilter:
|
|
for field, value in recordFilter.items():
|
|
if value is None:
|
|
where_parts.append(f'"{field}" IS NULL')
|
|
elif field == "isTemplate" and value is False:
|
|
where_parts.append(f'"{field}" IS NOT TRUE')
|
|
elif isinstance(value, list):
|
|
where_parts.append(f'"{field}" = ANY(%s)')
|
|
values.append(value)
|
|
else:
|
|
where_parts.append(f'"{field}" = %s')
|
|
values.append(value)
|
|
|
|
if pagination and pagination.filters:
|
|
for key, val in pagination.filters.items():
|
|
if key == "search" and isinstance(val, str) and val.strip():
|
|
term = f"%{val.strip()}%"
|
|
textCols = [c for c, t in fields.items() if t == "TEXT"]
|
|
if textCols:
|
|
orParts = [f'COALESCE("{c}"::TEXT, \'\') ILIKE %s' for c in textCols]
|
|
where_parts.append(f"({' OR '.join(orParts)})")
|
|
values.extend([term] * len(textCols))
|
|
continue
|
|
if key not in validColumns:
|
|
logger.debug(f"_buildPaginationClauses: key '{key}' NOT in validColumns {list(validColumns)[:10]}")
|
|
continue
|
|
colType = fields.get(key, "TEXT")
|
|
logger.debug(f"_buildPaginationClauses: filter key='{key}' val={val!r} type(val)={type(val).__name__} colType={colType}")
|
|
if val is None:
|
|
where_parts.append(f'("{key}" IS NULL OR "{key}" = \'\')')
|
|
continue
|
|
if isinstance(val, dict):
|
|
op = val.get("operator", "equals")
|
|
v = val.get("value", "")
|
|
if op in ("equals", "eq"):
|
|
if colType == "BOOLEAN":
|
|
where_parts.append(f'COALESCE("{key}", FALSE) = %s')
|
|
values.append(str(v).lower() == "true")
|
|
else:
|
|
where_parts.append(f'"{key}"::TEXT = %s')
|
|
values.append(str(v))
|
|
elif op == "contains":
|
|
where_parts.append(f'"{key}"::TEXT ILIKE %s')
|
|
values.append(f"%{v}%")
|
|
elif op == "startsWith":
|
|
where_parts.append(f'"{key}"::TEXT ILIKE %s')
|
|
values.append(f"{v}%")
|
|
elif op == "endsWith":
|
|
where_parts.append(f'"{key}"::TEXT ILIKE %s')
|
|
values.append(f"%{v}")
|
|
elif op in ("gt", "gte", "lt", "lte"):
|
|
sqlOp = {"gt": ">", "gte": ">=", "lt": "<", "lte": "<="}[op]
|
|
if colType in ("INTEGER", "DOUBLE PRECISION"):
|
|
try:
|
|
where_parts.append(f'"{key}"::double precision {sqlOp} %s')
|
|
values.append(float(v))
|
|
except (ValueError, TypeError):
|
|
continue
|
|
else:
|
|
where_parts.append(f'"{key}"::TEXT {sqlOp} %s')
|
|
values.append(str(v))
|
|
elif op == "between":
|
|
fromVal = v.get("from", "") if isinstance(v, dict) else ""
|
|
toVal = v.get("to", "") if isinstance(v, dict) else ""
|
|
if not fromVal and not toVal:
|
|
continue
|
|
colType = fields.get(key, "TEXT")
|
|
isNumericCol = colType in ("INTEGER", "DOUBLE PRECISION")
|
|
isDateVal = bool(fromVal and re.match(r'^\d{4}-\d{2}-\d{2}$', str(fromVal))) or \
|
|
bool(toVal and re.match(r'^\d{4}-\d{2}-\d{2}$', str(toVal)))
|
|
if isNumericCol and isDateVal:
|
|
from datetime import datetime as _dt, timezone as _tz
|
|
if fromVal and toVal:
|
|
fromTs = _dt.strptime(str(fromVal), '%Y-%m-%d').replace(tzinfo=_tz.utc).timestamp()
|
|
toTs = _dt.strptime(str(toVal), '%Y-%m-%d').replace(hour=23, minute=59, second=59, tzinfo=_tz.utc).timestamp()
|
|
where_parts.append(f'"{key}" >= %s AND "{key}" <= %s')
|
|
values.extend([fromTs, toTs])
|
|
elif fromVal:
|
|
fromTs = _dt.strptime(str(fromVal), '%Y-%m-%d').replace(tzinfo=_tz.utc).timestamp()
|
|
where_parts.append(f'"{key}" >= %s')
|
|
values.append(fromTs)
|
|
else:
|
|
toTs = _dt.strptime(str(toVal), '%Y-%m-%d').replace(hour=23, minute=59, second=59, tzinfo=_tz.utc).timestamp()
|
|
where_parts.append(f'"{key}" <= %s')
|
|
values.append(toTs)
|
|
elif isNumericCol:
|
|
try:
|
|
if fromVal and toVal:
|
|
where_parts.append(
|
|
f'"{key}"::double precision >= %s AND "{key}"::double precision <= %s'
|
|
)
|
|
values.extend([float(fromVal), float(toVal)])
|
|
elif fromVal:
|
|
where_parts.append(f'"{key}"::double precision >= %s')
|
|
values.append(float(fromVal))
|
|
elif toVal:
|
|
where_parts.append(f'"{key}"::double precision <= %s')
|
|
values.append(float(toVal))
|
|
except (ValueError, TypeError):
|
|
continue
|
|
else:
|
|
if fromVal and toVal:
|
|
where_parts.append(f'"{key}"::TEXT >= %s AND "{key}"::TEXT <= %s')
|
|
values.extend([str(fromVal), str(toVal)])
|
|
elif fromVal:
|
|
where_parts.append(f'"{key}"::TEXT >= %s')
|
|
values.append(str(fromVal))
|
|
elif toVal:
|
|
where_parts.append(f'"{key}"::TEXT <= %s')
|
|
values.append(str(toVal))
|
|
else:
|
|
if colType == "BOOLEAN":
|
|
where_parts.append(f'COALESCE("{key}", FALSE) = %s')
|
|
values.append(str(val).lower() == "true")
|
|
else:
|
|
where_parts.append(f'"{key}"::TEXT ILIKE %s')
|
|
values.append(str(val))
|
|
|
|
where_clause = " WHERE " + " AND ".join(where_parts) if where_parts else ""
|
|
count_values = list(values)
|
|
|
|
orderParts: List[str] = []
|
|
if pagination and pagination.sort:
|
|
for sf in pagination.sort:
|
|
sfField = sf.get("field") if isinstance(sf, dict) else getattr(sf, "field", None)
|
|
sfDir = sf.get("direction", "asc") if isinstance(sf, dict) else getattr(sf, "direction", "asc")
|
|
if sfField and sfField in validColumns:
|
|
direction = "DESC" if str(sfDir).lower() == "desc" else "ASC"
|
|
colType = fields.get(sfField, "TEXT")
|
|
if colType == "BOOLEAN":
|
|
orderParts.append(f'COALESCE("{sfField}", FALSE) {direction}')
|
|
else:
|
|
orderParts.append(f'"{sfField}" {direction} NULLS LAST')
|
|
if not orderParts:
|
|
orderParts.append('"id"')
|
|
order_clause = " ORDER BY " + ", ".join(orderParts)
|
|
|
|
limit_clause = ""
|
|
if pagination:
|
|
offset = (pagination.page - 1) * pagination.pageSize
|
|
limit_clause = f" LIMIT {pagination.pageSize} OFFSET {offset}"
|
|
|
|
return where_clause, order_clause, limit_clause, values, count_values
|
|
|
|
def getRecordsetPaginated(
|
|
self,
|
|
model_class: type,
|
|
pagination=None,
|
|
recordFilter: Dict[str, Any] = None,
|
|
fieldFilter: List[str] = None,
|
|
) -> Dict[str, Any]:
|
|
"""
|
|
Returns paginated records with filtering + sorting at the SQL level.
|
|
Returns { "items": [...], "totalItems": int, "totalPages": int }.
|
|
If pagination is None, returns all records (no LIMIT/OFFSET).
|
|
"""
|
|
from modules.datamodels.datamodelPagination import PaginationParams
|
|
import math
|
|
|
|
table = model_class.__name__
|
|
|
|
try:
|
|
if not self._ensureTableExists(model_class):
|
|
return {"items": [], "totalItems": 0, "totalPages": 0}
|
|
|
|
where_clause, order_clause, limit_clause, values, count_values = \
|
|
self._buildPaginationClauses(model_class, pagination, recordFilter)
|
|
|
|
with self.borrowConn() as conn:
|
|
with conn.cursor() as cursor:
|
|
countSql = f'SELECT COUNT(*) FROM "{table}"{where_clause}'
|
|
dataSql = f'SELECT * FROM "{table}"{where_clause}{order_clause}{limit_clause}'
|
|
cursor.execute(countSql, count_values)
|
|
totalItems = cursor.fetchone()["count"]
|
|
|
|
cursor.execute(dataSql, values)
|
|
records = [dict(row) for row in cursor.fetchall()]
|
|
|
|
fields = getModelFields(model_class)
|
|
modelFields = model_class.model_fields
|
|
for record in records:
|
|
parseRecordFields(record, fields, f"table {table}")
|
|
for fieldName, fieldType in fields.items():
|
|
if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
|
|
fieldInfo = modelFields.get(fieldName)
|
|
if fieldInfo:
|
|
fieldAnnotation = fieldInfo.annotation
|
|
if (fieldAnnotation == list or
|
|
(hasattr(fieldAnnotation, "__origin__") and
|
|
fieldAnnotation.__origin__ is list)):
|
|
record[fieldName] = []
|
|
elif (fieldAnnotation == dict or
|
|
(hasattr(fieldAnnotation, "__origin__") and
|
|
fieldAnnotation.__origin__ is dict)):
|
|
record[fieldName] = {}
|
|
|
|
if fieldFilter and isinstance(fieldFilter, list):
|
|
records = [{f: r[f] for f in fieldFilter if f in r} for r in records]
|
|
|
|
from modules.routes.routeHelpers import enrichRowsWithFkLabels
|
|
enrichRowsWithFkLabels(records, model_class)
|
|
|
|
pageSize = pagination.pageSize if pagination else max(totalItems, 1)
|
|
totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0
|
|
|
|
return {"items": records, "totalItems": totalItems, "totalPages": totalPages}
|
|
except Exception as e:
|
|
logger.error(f"Error in getRecordsetPaginated for table {table}: {e}")
|
|
raise DatabaseQueryError(table, str(e), original=e) from e
|
|
|
|
def getDistinctColumnValues(
|
|
self,
|
|
model_class: type,
|
|
column: str,
|
|
pagination=None,
|
|
recordFilter: Dict[str, Any] = None,
|
|
includeEmpty: bool = True,
|
|
) -> List[Optional[str]]:
|
|
"""Return sorted distinct values for a column using SQL DISTINCT.
|
|
|
|
When ``includeEmpty`` is True (default), NULL and empty-string rows are
|
|
represented as a single ``None`` entry at the end of the list — this
|
|
allows the frontend to offer a "(Leer)" filter option.
|
|
|
|
Applies cross-filtering (all filters except the requested column).
|
|
"""
|
|
table = model_class.__name__
|
|
fields = getModelFields(model_class)
|
|
|
|
if column not in fields:
|
|
return []
|
|
|
|
try:
|
|
if not self._ensureTableExists(model_class):
|
|
return []
|
|
|
|
if pagination:
|
|
import copy
|
|
pagination = copy.deepcopy(pagination)
|
|
if pagination.filters and column in pagination.filters:
|
|
pagination.filters.pop(column, None)
|
|
pagination.sort = []
|
|
|
|
where_clause, _, _, values, _ = \
|
|
self._buildPaginationClauses(model_class, pagination, recordFilter)
|
|
|
|
nonNullCond = f'"{column}" IS NOT NULL AND "{column}"::TEXT != \'\''
|
|
if where_clause:
|
|
sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{where_clause} AND {nonNullCond} ORDER BY val'
|
|
else:
|
|
sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}" WHERE {nonNullCond} ORDER BY val'
|
|
|
|
with self.borrowConn() as conn:
|
|
with conn.cursor() as cursor:
|
|
cursor.execute(sql, values)
|
|
result: List[Optional[str]] = [row["val"] for row in cursor.fetchall()]
|
|
|
|
if includeEmpty:
|
|
emptyCond = f'"{column}" IS NULL OR "{column}"::TEXT = \'\''
|
|
if where_clause:
|
|
emptySql = f'SELECT 1 FROM "{table}"{where_clause} AND ({emptyCond}) LIMIT 1'
|
|
else:
|
|
emptySql = f'SELECT 1 FROM "{table}" WHERE ({emptyCond}) LIMIT 1'
|
|
cursor.execute(emptySql, values)
|
|
if cursor.fetchone():
|
|
result.append(None)
|
|
|
|
return result
|
|
except Exception as e:
|
|
logger.error(f"Error in getDistinctColumnValues for {table}.{column}: {e}")
|
|
raise DatabaseQueryError(table, str(e), original=e) from e
|
|
|
|
def recordCreate(
|
|
self, model_class: type, record: Union[Dict[str, Any], BaseModel]
|
|
) -> Dict[str, Any]:
|
|
"""Creates a new record in a table based on Pydantic model class."""
|
|
# If record is a Pydantic model, convert to dict
|
|
if isinstance(record, BaseModel):
|
|
record = record.model_dump()
|
|
elif isinstance(record, dict):
|
|
record = record.copy()
|
|
else:
|
|
raise ValueError("Record must be a Pydantic model or dictionary")
|
|
|
|
# Ensure record has an ID
|
|
if "id" not in record:
|
|
record["id"] = str(uuid.uuid4())
|
|
|
|
# Save record
|
|
success = self._saveRecord(model_class, record["id"], record)
|
|
if not success:
|
|
table = model_class.__name__
|
|
raise ValueError(f"Failed to save record {record['id']} to table {table}")
|
|
|
|
# Check if this is the first record in the table and register as initial ID
|
|
table = model_class.__name__
|
|
existingInitialId = self.getInitialId(model_class)
|
|
if existingInitialId is None:
|
|
# This is the first record, register it as the initial ID
|
|
self._registerInitialId(table, record["id"])
|
|
logger.info(f"Registered initial ID {record['id']} for table {table}")
|
|
|
|
return record
|
|
|
|
def recordModify(
|
|
self, model_class: type, recordId: str, record: Union[Dict[str, Any], BaseModel]
|
|
) -> Dict[str, Any]:
|
|
"""Modifies an existing record in a table based on Pydantic model class."""
|
|
# Load existing record
|
|
existingRecord = self._loadRecord(model_class, recordId)
|
|
if not existingRecord:
|
|
table = model_class.__name__
|
|
raise ValueError(f"Record {recordId} not found in table {table}")
|
|
|
|
# If record is a Pydantic model, convert to dict
|
|
if isinstance(record, BaseModel):
|
|
record = record.model_dump()
|
|
elif isinstance(record, dict):
|
|
record = record.copy()
|
|
else:
|
|
raise ValueError("Record must be a Pydantic model or dictionary")
|
|
|
|
# CRITICAL: Ensure we never modify the ID
|
|
if "id" in record and str(record["id"]) != recordId:
|
|
logger.error(
|
|
f"Attempted to modify record ID from {recordId} to {record['id']}"
|
|
)
|
|
raise ValueError(
|
|
"Cannot modify record ID - it must match the provided recordId"
|
|
)
|
|
|
|
# Update existing record with new data
|
|
existingRecord.update(record)
|
|
|
|
# Save updated record
|
|
saved = self._saveRecord(model_class, recordId, existingRecord)
|
|
if not saved:
|
|
table = model_class.__name__
|
|
raise ValueError(f"Failed to save record {recordId} to table {table}")
|
|
return existingRecord
|
|
|
|
def recordDelete(self, model_class: type, recordId: str) -> bool:
|
|
"""Deletes a record from the table based on Pydantic model class."""
|
|
table = model_class.__name__
|
|
|
|
try:
|
|
if not self._ensureTableExists(model_class):
|
|
return False
|
|
|
|
# `getInitialId` opens its own borrow; do it BEFORE we acquire a
|
|
# connection ourselves so we don't pin two slots concurrently.
|
|
initialId = self.getInitialId(model_class)
|
|
|
|
with self.borrowConn() as conn:
|
|
with conn.cursor() as cursor:
|
|
cursor.execute(
|
|
f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,)
|
|
)
|
|
if not cursor.fetchone():
|
|
return False
|
|
|
|
if initialId is not None and initialId == recordId:
|
|
# `_removeInitialId` borrows its own conn — done outside
|
|
# this block on purpose to avoid nested borrows.
|
|
pass
|
|
cursor.execute(f'DELETE FROM "{table}" WHERE "id" = %s', (recordId,))
|
|
|
|
if initialId is not None and initialId == recordId:
|
|
self._removeInitialId(table)
|
|
logger.info(
|
|
f"Initial ID {recordId} for table {table} has been removed from the system table"
|
|
)
|
|
return True
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error deleting record {recordId} from table {table}: {e}")
|
|
return False
|
|
|
|
def recordCreateBulk(
|
|
self, model_class: type, records: List[Union[Dict[str, Any], BaseModel]]
|
|
) -> int:
|
|
"""Bulk-insert many records in a single transaction.
|
|
|
|
Use this instead of calling recordCreate() in a tight loop when importing
|
|
large datasets (>100 rows). Performance gain is roughly two orders of
|
|
magnitude because:
|
|
- one network round-trip via execute_values() instead of N
|
|
- one COMMIT instead of N
|
|
- initial ID is registered once for the whole batch instead of every row
|
|
|
|
Returns the number of rows successfully inserted. Caller is responsible
|
|
for catching exceptions; on any error the transaction is rolled back so
|
|
the table stays consistent (all-or-nothing).
|
|
"""
|
|
if not records:
|
|
return 0
|
|
|
|
table = model_class.__name__
|
|
if not self._ensureTableExists(model_class):
|
|
raise ValueError(f"Table {table} does not exist")
|
|
|
|
fields = getModelFields(model_class)
|
|
columns = ["id"] + [f for f in fields.keys() if f != "id"]
|
|
modelFields = model_class.model_fields
|
|
|
|
effectiveUserId = _current_user_id.get()
|
|
if effectiveUserId is None:
|
|
effectiveUserId = self.userId
|
|
currentTime = getUtcTimestamp()
|
|
|
|
normalised: List[Dict[str, Any]] = []
|
|
for raw in records:
|
|
if isinstance(raw, BaseModel):
|
|
rec = raw.model_dump()
|
|
elif isinstance(raw, dict):
|
|
rec = raw.copy()
|
|
else:
|
|
raise ValueError("Bulk record must be a Pydantic model or dictionary")
|
|
if "id" not in rec or not rec["id"]:
|
|
rec["id"] = str(uuid.uuid4())
|
|
createdTs = rec.get("sysCreatedAt")
|
|
if createdTs is None or createdTs == 0 or createdTs == 0.0:
|
|
rec["sysCreatedAt"] = currentTime
|
|
if effectiveUserId and not rec.get("sysCreatedBy"):
|
|
rec["sysCreatedBy"] = effectiveUserId
|
|
elif not rec.get("sysCreatedBy") and effectiveUserId:
|
|
rec["sysCreatedBy"] = effectiveUserId
|
|
rec["sysModifiedAt"] = currentTime
|
|
if effectiveUserId:
|
|
rec["sysModifiedBy"] = effectiveUserId
|
|
normalised.append(rec)
|
|
|
|
rows = [self._coerceRowForInsert(rec, columns, fields, modelFields) for rec in normalised]
|
|
|
|
col_names = ", ".join([f'"{c}"' for c in columns])
|
|
updates = ", ".join(
|
|
[f'"{c}" = EXCLUDED."{c}"' for c in columns[1:]
|
|
if c not in ("sysCreatedAt", "sysCreatedBy")]
|
|
)
|
|
sql = (
|
|
f'INSERT INTO "{table}" ({col_names}) VALUES %s '
|
|
f'ON CONFLICT ("id") DO UPDATE SET {updates}'
|
|
)
|
|
|
|
try:
|
|
with self.borrowConn() as conn:
|
|
with conn.cursor() as cursor:
|
|
psycopg2.extras.execute_values(cursor, sql, rows, page_size=500)
|
|
except Exception as e:
|
|
logger.error(f"Bulk insert into {table} failed (n={len(rows)}): {e}")
|
|
raise
|
|
|
|
if self.getInitialId(model_class) is None and normalised:
|
|
self._registerInitialId(table, normalised[0]["id"])
|
|
logger.info(f"Registered initial ID {normalised[0]['id']} for table {table}")
|
|
|
|
return len(rows)
|
|
|
|
def _coerceRowForInsert(
    self,
    record: Dict[str, Any],
    columns: List[str],
    fields: Dict[str, str],
    modelFields: Dict[str, Any],
) -> tuple:
    """Build the positional insert tuple for one record, ordered like `columns`.

    The per-column conversions follow the ones in `_save_record`, so a row
    written through the bulk path is stored exactly like one written through
    the single-record path:

    * ``sysCreatedAt`` / ``sysModifiedAt``: strings that parse as numbers become
      floats; unparseable strings and non-string values pass through untouched.
    * Any other value exposing a ``.value`` attribute (e.g. an Enum member) is
      replaced by that attribute.
    * Vector columns: Python lists are rendered as pgvector text ``[a,b,...]``.
    * JSONB columns: dicts/lists, Pydantic models (via ``model_dump()``) and any
      other value are JSON-encoded; strings that already parse as JSON are kept
      verbatim, while other strings are JSON-encoded.
    * ``None`` is always passed through as ``None``.

    ``modelFields`` is accepted for signature compatibility but is not read.
    """
    import json as _json

    timestampCols = ("sysCreatedAt", "sysModifiedAt")

    def _convert(colName: str, raw: Any) -> Any:
        if raw is None:
            return None

        if colName in timestampCols:
            # Timestamps may arrive as strings; store them as floats when possible.
            if isinstance(raw, str):
                try:
                    return float(raw)
                except Exception:
                    return raw
            return raw

        # NOTE(review): any object with a `.value` attribute is unwrapped here,
        # before the vector/JSONB checks. E.g. a Pydantic model with a `value`
        # field bound for a JSONB column would be replaced by that field.
        # Kept as-is to stay in step with `_save_record`. Confirm intent.
        if hasattr(raw, "value"):
            return raw.value

        if colName not in fields:
            return raw

        sqlType = fields[colName]
        if _isVectorType(sqlType):
            if isinstance(raw, list):
                return "[" + ",".join(map(str, raw)) + "]"
            return raw

        if sqlType != "JSONB":
            return raw

        if isinstance(raw, (dict, list)):
            return _json.dumps(raw)
        if isinstance(raw, str):
            # Already-serialised JSON is stored verbatim; plain text gets quoted.
            try:
                _json.loads(raw)
            except (ValueError, TypeError):
                return _json.dumps(raw)
            return raw
        if hasattr(raw, "model_dump"):
            return _json.dumps(raw.model_dump())
        return _json.dumps(raw)

    return tuple(_convert(colName, record.get(colName)) for colName in columns)
|
|
|
|
def recordDeleteWhere(
    self, model_class: type, recordFilter: Dict[str, Any]
) -> int:
    """Delete all records matching ``recordFilter``, in one statement.

    Replaces the N+1 pattern `for r in getRecordset(...): recordDelete(r.id)`.

    Filter semantics (all conditions are AND-ed together):

    * ``{"col": value}``  -> ``"col" = value``
    * ``{"col": None}``   -> ``"col" IS NULL`` (SQL ``= NULL`` never matches)
    * ``{"col": [a, b]}`` -> ``"col" = ANY([a, b])``; an empty list/tuple
      matches no rows, so nothing is deleted.

    Args:
        model_class: Pydantic model class; its class name is the table name.
        recordFilter: Non-empty mapping of column name -> match value. Every key
            must be a known model field (or ``"id"``), because column names are
            interpolated into the SQL text.

    Returns:
        The number of rows actually deleted (0 if ``_ensureTableExists`` fails).

    Raises:
        ValueError: if ``recordFilter`` is empty (refusing to truncate) or
            names an unknown column.
        Any database error raised while executing is logged and re-raised.

    If the table holds the initial ID and that row gets deleted, the initial ID
    registration is cleared so the next insert can re-register a fresh one.
    """
    if not recordFilter:
        raise ValueError("recordDeleteWhere requires a non-empty recordFilter (refusing to truncate)")

    table = model_class.__name__
    if not self._ensureTableExists(model_class):
        return 0

    fields = getModelFields(model_class)
    clauses: List[str] = []
    params: List[Any] = []
    for key, val in recordFilter.items():
        # Whitelist column names: they are spliced into the SQL string.
        if key not in fields and key != "id":
            raise ValueError(f"recordDeleteWhere: unknown column {table}.{key}")
        if val is None:
            clauses.append(f'"{key}" IS NULL')
        elif isinstance(val, (list, tuple)):
            if val:
                # psycopg2 adapts a Python list to an SQL ARRAY for ANY().
                clauses.append(f'"{key}" = ANY(%s)')
                params.append(list(val))
            else:
                # Empty membership set: match nothing (never widen the delete).
                clauses.append("1 = 0")
        else:
            clauses.append(f'"{key}" = %s')
            params.append(val)
    whereSql = " AND ".join(clauses)

    initialId = self.getInitialId(model_class)
    try:
        with self.borrowConn() as conn:
            with conn.cursor() as cursor:
                # Probe, on the same connection right before the DELETE,
                # whether the registered initial row is in the set to delete.
                initialIsAffected = False
                if initialId is not None:
                    cursor.execute(
                        f'SELECT 1 FROM "{table}" WHERE "id" = %s AND ' + whereSql,
                        [initialId, *params],
                    )
                    initialIsAffected = cursor.fetchone() is not None

                cursor.execute(f'DELETE FROM "{table}" WHERE ' + whereSql, params)
                deleted = cursor.rowcount or 0
    except Exception as e:
        logger.error(f"Bulk delete from {table} failed (filter={recordFilter}): {e}")
        raise

    if deleted and initialIsAffected:
        self._removeInitialId(table)
        logger.info(f"Initial ID for table {table} cleared (bulk-delete removed it)")
    if deleted:
        logger.info(f"recordDeleteWhere: deleted {deleted} rows from {table} where {recordFilter}")
    return deleted
|
|
|
|
def getInitialId(self, model_class: type) -> Optional[str]:
    """Look up the initial ID registered for ``model_class``'s table.

    The table name is the model's class name. The value is read from the
    mapping returned by ``_loadSystemTable()``. ``.get`` yields ``None`` when
    the table has no entry.
    """
    return self._loadSystemTable().get(model_class.__name__)
|
|
|
|
def semanticSearch(
    self,
    modelClass: type,
    vectorColumn: str,
    queryVector: List[float],
    limit: int = 10,
    recordFilter: Dict[str, Any] = None,
    minScore: float = None,
) -> List[Dict[str, Any]]:
    """Semantic search using pgvector cosine distance (``<=>``).

    Args:
        modelClass: Pydantic model class for the table (table name = class name).
        vectorColumn: Name of the vector column to search.
        queryVector: Query vector as List[float].
        limit: Maximum number of results.
        recordFilter: Additional AND-ed WHERE filters (field: value). ``None``
            values become ``IS NULL``. Lists/tuples become ``= ANY(...)``, and an
            empty one matches nothing. Anything else becomes ``= value``.
        minScore: Minimum cosine similarity (0.0 - 1.0); ``None`` disables it.

    Returns:
        List of records with an added '_score' field (``1 - cosine distance``),
        sorted by similarity descending. Rows whose vector column is NULL are
        excluded because they have no similarity score. Returns ``[]`` when
        ``_ensureTableExists`` reports the table is unavailable.

    Raises:
        DatabaseQueryError: on any failure, including a ``vectorColumn`` or
            ``recordFilter`` key that is not a plain identifier.
    """
    table = modelClass.__name__

    def _checkedIdent(name: Any) -> str:
        # Column names are spliced into the SQL text inside double quotes, so
        # only plain identifiers (the shape of Pydantic field names) may pass.
        # This also keeps stray '"' or '%' characters out of the statement.
        if not isinstance(name, str) or not name.isidentifier():
            raise ValueError(f"invalid column name {name!r}")
        return name

    try:
        if not self._ensureTableExists(modelClass):
            return []

        vecCol = _checkedIdent(vectorColumn)
        vectorStr = f"[{','.join(str(v) for v in queryVector)}]"

        # A NULL vector yields a NULL distance/score; such rows are not results.
        whereConditions = [f'"{vecCol}" IS NOT NULL']
        whereValues = []

        for field, value in (recordFilter or {}).items():
            col = _checkedIdent(field)
            if value is None:
                whereConditions.append(f'"{col}" IS NULL')
            elif isinstance(value, (list, tuple)):
                if not value:
                    whereConditions.append("1 = 0")
                else:
                    whereConditions.append(f'"{col}" = ANY(%s)')
                    whereValues.append(list(value))
            else:
                whereConditions.append(f'"{col}" = %s')
                whereValues.append(value)

        if minScore is not None:
            whereConditions.append(
                f'1 - ("{vecCol}" <=> %s::vector) >= %s'
            )
            whereValues.extend([vectorStr, minScore])

        whereClause = " WHERE " + " AND ".join(whereConditions)

        query = (
            f'SELECT *, 1 - ("{vecCol}" <=> %s::vector) AS "_score" '
            f'FROM "{table}"{whereClause} '
            f'ORDER BY "{vecCol}" <=> %s::vector '
            f'LIMIT %s'
        )
        # Placeholder order: score expression, WHERE values, ORDER BY, LIMIT.
        params = [vectorStr] + whereValues + [vectorStr, limit]

        with self.borrowConn() as conn:
            with conn.cursor() as cursor:
                cursor.execute(query, params)
                records = [dict(row) for row in cursor.fetchall()]

        fields = getModelFields(modelClass)
        for record in records:
            parseRecordFields(record, fields, f"semanticSearch {table}")
        return records
    except Exception as e:
        logger.error(f"Error in semantic search on {table}: {e}")
        raise DatabaseQueryError(table, str(e), original=e) from e
|
|
|
|
def close(self, forceClose: bool = False):
    """Do nothing; kept so legacy callers of ``close()`` keep working.

    Connections belong to the process-wide `_PoolRegistry` pool rather than
    to a connector instance, and they live for the process lifetime. Pools
    are shut down in one place, `closeAllPools()`, from the FastAPI lifespan
    hook; a connector instance never does it. ``forceClose`` is accepted and
    ignored.

    Interface `__del__` paths historically called `close()` to release a
    per-connector socket. With pooling there is nothing left to release here.
    """
    return None
|
|
|
|
def __del__(self):
    """Finalizer that deliberately does nothing.

    Pooled connections are not owned by the connector, so nothing needs to
    be released when an instance is garbage-collected.
    """
    return None
|