# gateway/modules/connectors/connectorDbPostgre.py
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
import contextvars
import re
import psycopg2
import psycopg2.extras
import logging
from typing import List, Dict, Any, Optional, Union, get_origin, get_args, Type
import uuid
from pydantic import BaseModel, Field
import threading
from modules.shared.timeUtils import getUtcTimestamp
from modules.shared.configuration import APP_CONFIG
from modules.datamodels.datamodelBase import PowerOnModel
from modules.datamodels.datamodelUam import User, AccessLevel, UserPermissions
from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext
logger = logging.getLogger(__name__)
# No mapping needed - table name = Pydantic model name exactly
class SystemTable(PowerOnModel):
"""Data model for system table entries"""
table_name: str = Field(
description="Name of the table",
json_schema_extra={
"frontend_type": "text",
"frontend_readonly": True,
"frontend_required": True,
}
)
initial_id: Optional[str] = Field(
default=None,
description="Initial ID for the table",
json_schema_extra={
"frontend_type": "text",
"frontend_readonly": True,
"frontend_required": False,
}
)
def _isVectorType(sqlType: str) -> bool:
"""Check if a SQL type string represents a pgvector column."""
return sqlType.upper().startswith("VECTOR")
def _isJsonbType(fieldType) -> bool:
"""Check if a type should be stored as JSONB in PostgreSQL."""
# Direct dict or list
if fieldType == dict or fieldType == list:
return True
# Generic List[X] or Dict[X, Y]
origin = get_origin(fieldType)
if origin in (dict, list):
return True
# Direct Pydantic BaseModel subclass
if isinstance(fieldType, type) and issubclass(fieldType, BaseModel):
return True
# Optional[X] - check the inner type
if origin is Union:
args = get_args(fieldType)
for arg in args:
if arg is type(None):
continue
# Recursively check the inner type
if _isJsonbType(arg):
return True
return False
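# A quick sketch of what _isJsonbType treats as JSONB (the model below is
# illustrative, not part of this module):
#
#   class Example(BaseModel):
#       tags: List[str]                 # JSONB (generic list)
#       meta: Optional[Dict[str, int]]  # JSONB (Optional unwraps to Dict)
#       child: Optional["Example"]      # JSONB (BaseModel subclass)
#       name: str                       # not JSONB -> plain TEXT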
def getModelFields(model_class) -> Dict[str, str]:
"""Get all fields from Pydantic model and map to SQL types.
Supports explicit db_type override via json_schema_extra={"db_type": "vector(1536)"}.
This enables pgvector columns without special-casing field names.
"""
model_fields = model_class.model_fields
fields = {}
for field_name, field_info in model_fields.items():
field_type = field_info.annotation
# Explicit db_type override (e.g. vector columns)
extra = field_info.json_schema_extra
if extra and isinstance(extra, dict) and "db_type" in extra:
fields[field_name] = extra["db_type"]
continue
# Unwrap Optional[X] → X (handles both typing.Union and types.UnionType)
origin = get_origin(field_type)
if origin is Union:
args = [a for a in get_args(field_type) if a is not type(None)]
if len(args) == 1:
field_type = args[0]
elif hasattr(field_type, '__args__') and type(None) in getattr(field_type, '__args__', ()):
args = [a for a in field_type.__args__ if a is not type(None)]
if len(args) == 1:
field_type = args[0]
if _isJsonbType(field_type):
fields[field_name] = "JSONB"
elif field_type is bool:
fields[field_name] = "BOOLEAN"
elif field_type is int:
fields[field_name] = "INTEGER"
elif field_type is float:
fields[field_name] = "DOUBLE PRECISION"
        else:
            # str, None, and any unrecognised type fall back to TEXT
            fields[field_name] = "TEXT"
return fields
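# Example mapping produced by getModelFields for a hypothetical model
# (inherited PowerOnModel fields such as id and the sys* columns are
# mapped as well, omitted here for brevity):
#
#   class Document(PowerOnModel):
#       title: str                    # -> TEXT
#       pageCount: Optional[int]      # -> INTEGER (Optional unwrapped)
#       score: float                  # -> DOUBLE PRECISION
#       attributes: Dict[str, Any]    # -> JSONB
#       embedding: List[float] = Field(
#           default=None, json_schema_extra={"db_type": "vector(1536)"}
#       )                             # -> vector(1536) via explicit db_type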
def parseRecordFields(record: Dict[str, Any], fields: Dict[str, str], context: str = "") -> None:
"""Parse record fields in-place: numeric typing, vector parsing, JSONB deserialization."""
import json as _json
for fieldName, fieldType in fields.items():
if fieldName not in record:
continue
value = record[fieldName]
if fieldType in ("DOUBLE PRECISION", "INTEGER") and value is not None:
try:
record[fieldName] = float(value) if fieldType == "DOUBLE PRECISION" else int(value)
except (ValueError, TypeError):
logger.warning(f"Could not convert {fieldName} to {fieldType} ({context}): {value}")
elif _isVectorType(fieldType) and value is not None:
if isinstance(value, str):
try:
record[fieldName] = [float(v) for v in value.strip("[]").split(",")]
except (ValueError, TypeError):
logger.warning(f"Could not parse vector field {fieldName} ({context})")
elif isinstance(value, list):
pass # already a list
elif fieldType == "BOOLEAN":
record[fieldName] = bool(value) if value is not None else False
elif fieldType == "JSONB" and value is not None:
try:
if isinstance(value, str):
record[fieldName] = _json.loads(value)
elif not isinstance(value, (dict, list)):
record[fieldName] = _json.loads(str(value))
except (_json.JSONDecodeError, TypeError, ValueError):
logger.warning(f"Could not parse JSONB field {fieldName}, keeping as string ({context})")
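# Illustration (values made up): a raw psycopg2 row may carry Decimal
# numerics, pgvector text, and JSONB strings; parseRecordFields normalises
# all of them in place:
#
#   row = {"score": Decimal("0.5"), "embedding": "[0.1,0.2]", "meta": '{"a": 1}'}
#   parseRecordFields(row, {"score": "DOUBLE PRECISION",
#                           "embedding": "vector(2)", "meta": "JSONB"})
#   # row == {"score": 0.5, "embedding": [0.1, 0.2], "meta": {"a": 1}}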
def _quotePgIdent(name: str) -> str:
return '"' + str(name).replace('"', '""') + '"'
# Cache connectors by (host, database, port) to avoid duplicate inits for same database.
# Thread safety: _connector_cache_lock protects cache access. userId is request-scoped via
# contextvars to avoid races when concurrent requests share the same connector.
_MAX_CACHED_CONNECTORS = 32
_connector_cache: Dict[tuple, "DatabaseConnector"] = {}
_connector_cache_order: List[tuple] = [] # FIFO order for eviction
_connector_cache_lock = threading.Lock()
_current_user_id: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
"db_connector_user_id", default=None
)
def getCachedConnector(
dbHost: str,
dbDatabase: str,
dbUser: str = None,
dbPassword: str = None,
dbPort: int = None,
userId: str = None,
) -> "DatabaseConnector":
"""Return cached DatabaseConnector for same (host, database, port) to avoid duplicate PostgreSQL inits.
Uses contextvars for userId so concurrent requests sharing the same connector get correct sysCreatedBy/sysModifiedBy.
"""
port = int(dbPort) if dbPort is not None else 5432
key = (dbHost, dbDatabase, port)
with _connector_cache_lock:
if key not in _connector_cache:
# Evict oldest if at capacity
while len(_connector_cache) >= _MAX_CACHED_CONNECTORS and _connector_cache_order:
oldest_key = _connector_cache_order.pop(0)
if oldest_key in _connector_cache:
try:
_connector_cache[oldest_key].close(forceClose=True)
except Exception as e:
logger.warning(f"Error closing evicted connector: {e}")
del _connector_cache[oldest_key]
_connector_cache[key] = DatabaseConnector(
dbHost=dbHost,
dbDatabase=dbDatabase,
dbUser=dbUser,
dbPassword=dbPassword,
dbPort=dbPort,
userId=userId,
)
_connector_cache[key]._isCachedShared = True
_connector_cache_order.append(key)
conn = _connector_cache[key]
# Set request-scoped userId via contextvar (avoids mutating shared connector)
if userId is not None:
_current_user_id.set(userId)
return conn
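# Usage sketch (connection values are illustrative):
#
#   connector = getCachedConnector(
#       dbHost="localhost", dbDatabase="appdb",
#       dbUser="app", dbPassword="secret", dbPort=5432,
#       userId="user-123",
#   )
#   # A second call with the same (host, database, port) returns the same
#   # instance; only the request-scoped userId contextvar is refreshed.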
class DatabaseConnector:
"""
A connector for PostgreSQL-based data storage.
Provides generic database operations without user/mandate filtering.
Uses PostgreSQL with JSONB columns for flexible data storage.
"""
def __init__(
self,
dbHost: str,
dbDatabase: str,
dbUser: str = None,
dbPassword: str = None,
dbPort: int = None,
userId: str = None,
):
# Store the input parameters
self.dbHost = dbHost
self.dbDatabase = dbDatabase
self.dbUser = dbUser
self.dbPassword = dbPassword
self.dbPort = dbPort
# Set userId (default to empty string if None)
self.userId = userId if userId is not None else ""
# Initialize database system first (creates database if needed)
self.connection = None
self._isCachedShared = False
self.initDbSystem()
# No caching needed with proper database - PostgreSQL handles performance
# Thread safety
self._lock = threading.Lock()
# pgvector extension state
self._vectorExtensionEnabled = False
# Initialize system table
self._systemTableName = "_system"
self._initializeSystemTable()
def initDbSystem(self):
"""Initialize the database system - creates database and tables."""
try:
# Create database if it doesn't exist
self._create_database_if_not_exists()
# Create tables
self._create_tables()
# Establish connection to the database
self._connect()
logger.info("PostgreSQL database system initialized successfully")
except Exception as e:
logger.error(f"FATAL ERROR: Database system initialization failed: {e}")
raise
def _create_database_if_not_exists(self):
"""Create the database if it doesn't exist."""
try:
# Use the configured user for database creation
conn = psycopg2.connect(
host=self.dbHost,
port=self.dbPort,
database="postgres",
user=self.dbUser,
password=self.dbPassword,
client_encoding="utf8",
)
conn.autocommit = True
with conn.cursor() as cursor:
# Check if database exists
cursor.execute(
"SELECT 1 FROM pg_database WHERE datname = %s", (self.dbDatabase,)
)
exists = cursor.fetchone()
if not exists:
                    # Create database; _quotePgIdent also escapes embedded
                    # quotes, not just names with hyphens
                    quoted_db_name = _quotePgIdent(self.dbDatabase)
                    cursor.execute(f"CREATE DATABASE {quoted_db_name}")
logger.info(f"Created database: {self.dbDatabase}")
conn.close()
except Exception as e:
logger.error(f"FATAL ERROR: Cannot create database: {e}")
logger.error("Database connection failed - application cannot start")
raise RuntimeError(
f"FATAL ERROR: Cannot create database '{self.dbDatabase}': {e}"
)
def _create_tables(self):
"""Create only the system table - application tables are created by interfaces."""
try:
# Use the configured user for table creation
conn = psycopg2.connect(
host=self.dbHost,
port=self.dbPort,
database=self.dbDatabase,
user=self.dbUser,
password=self.dbPassword,
client_encoding="utf8",
)
conn.autocommit = True
with conn.cursor() as cursor:
                # Create only the system table. Keep this schema in sync with
                # _ensureSystemTableExists(), which can also create "_system".
cursor.execute("""
CREATE TABLE IF NOT EXISTS _system (
id SERIAL PRIMARY KEY,
table_name VARCHAR(255) UNIQUE NOT NULL,
initial_id VARCHAR(255) NOT NULL,
"sysCreatedAt" DOUBLE PRECISION,
"sysCreatedBy" VARCHAR(255),
"sysModifiedAt" DOUBLE PRECISION,
"sysModifiedBy" VARCHAR(255)
)
""")
conn.close()
except Exception as e:
logger.error(f"FATAL ERROR: Cannot create system table: {e}")
logger.error(
"Database system table creation failed - application cannot start"
)
raise RuntimeError(f"FATAL ERROR: Cannot create system table: {e}")
def _connect(self):
"""Establish connection to PostgreSQL database."""
try:
# Use configured user for main connection with proper parameter handling
self.connection = psycopg2.connect(
host=self.dbHost,
port=self.dbPort,
database=self.dbDatabase,
user=self.dbUser,
password=self.dbPassword,
client_encoding="utf8",
cursor_factory=psycopg2.extras.RealDictCursor,
)
self.connection.autocommit = False # Use transactions
except Exception as e:
logger.error(f"Failed to connect to PostgreSQL: {e}")
raise
def _ensure_connection(self):
"""Ensure database connection is alive, reconnect if necessary."""
try:
if self.connection is None or self.connection.closed:
self._connect()
else:
# Test connection with a simple query
with self.connection.cursor() as cursor:
cursor.execute("SELECT 1")
except Exception as e:
logger.warning(f"Connection lost, reconnecting: {e}")
self._connect()
def _initializeSystemTable(self):
"""Initializes the system table if it doesn't exist yet."""
try:
# First ensure the system table exists
self._ensureTableExists(SystemTable)
            with self.connection.cursor() as cursor:
                # Cheap probe: the count itself is unused; this just verifies
                # the system table is readable before first use
                cursor.execute('SELECT COUNT(*) FROM "_system"')
                cursor.fetchone()
            self.connection.commit()
except Exception as e:
logger.error(f"Error initializing system table: {e}")
self.connection.rollback()
raise
def _loadSystemTable(self) -> Dict[str, str]:
"""Loads the system table with the initial IDs."""
try:
with self.connection.cursor() as cursor:
cursor.execute('SELECT "table_name", "initial_id" FROM "_system"')
rows = cursor.fetchall()
system_data = {}
for row in rows:
system_data[row["table_name"]] = row["initial_id"]
return system_data
except Exception as e:
logger.error(f"Error loading system table: {e}")
return {}
def _saveSystemTable(self, data: Dict[str, str]) -> bool:
"""Saves the system table with the initial IDs."""
try:
with self.connection.cursor() as cursor:
# Clear existing data
cursor.execute('DELETE FROM "_system"')
# Insert new data
for table_name, initial_id in data.items():
cursor.execute(
"""
INSERT INTO "_system" ("table_name", "initial_id", "sysModifiedAt")
VALUES (%s, %s, %s)
""",
(table_name, initial_id, getUtcTimestamp()),
)
self.connection.commit()
return True
except Exception as e:
logger.error(f"Error saving system table: {e}")
self.connection.rollback()
return False
def _ensureSystemTableExists(self) -> bool:
"""Ensures the system table exists, creates it if it doesn't."""
try:
self._ensure_connection()
with self.connection.cursor() as cursor:
# Check if system table exists
cursor.execute(
"SELECT COUNT(*) FROM pg_stat_user_tables WHERE relname = %s",
(self._systemTableName,),
)
exists = cursor.fetchone()["count"] > 0
if not exists:
# Create system table
cursor.execute(f"""
CREATE TABLE "{self._systemTableName}" (
"table_name" VARCHAR(255) PRIMARY KEY,
"initial_id" VARCHAR(255),
"sysCreatedAt" DOUBLE PRECISION,
"sysCreatedBy" VARCHAR(255),
"sysModifiedAt" DOUBLE PRECISION,
"sysModifiedBy" VARCHAR(255)
)
""")
logger.info("System table created successfully")
else:
# Check if we need to add missing columns to existing table
cursor.execute(
"""
SELECT column_name FROM information_schema.columns
WHERE table_name = %s AND table_schema = 'public'
""",
(self._systemTableName,),
)
existing_columns = [row["column_name"] for row in cursor.fetchall()]
for sys_col, sys_sql in [
("sysCreatedAt", "DOUBLE PRECISION"),
("sysCreatedBy", "VARCHAR(255)"),
("sysModifiedAt", "DOUBLE PRECISION"),
("sysModifiedBy", "VARCHAR(255)"),
]:
if sys_col not in existing_columns:
cursor.execute(
f'ALTER TABLE "{self._systemTableName}" ADD COLUMN "{sys_col}" {sys_sql}'
)
            # Commit DDL here; unlike _ensureTableExists, this path returns
            # directly to the caller without passing the shared commit below
            self.connection.commit()
            return True
except Exception as e:
logger.error(f"Error ensuring system table exists: {e}")
return False
def _ensureTableExists(self, model_class: type) -> bool:
"""Ensures a table exists, creates it if it doesn't."""
table = model_class.__name__
if table == "SystemTable":
# Handle system table specially - it uses _system as the actual table name
return self._ensureSystemTableExists()
try:
self._ensure_connection()
with self.connection.cursor() as cursor:
# Check if table exists by querying information_schema with case-insensitive search
cursor.execute(
"""
SELECT COUNT(*) FROM information_schema.tables
WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public'
""",
(table,),
)
exists = cursor.fetchone()["count"] > 0
if not exists:
# Create table from Pydantic model
self._create_table_from_model(cursor, table, model_class)
logger.info(
f"Created table '{table}' with columns from Pydantic model"
)
else:
# Table exists: ensure all columns from model are present (simple additive migration)
try:
cursor.execute(
"""
SELECT column_name, data_type
FROM information_schema.columns
WHERE LOWER(table_name) = LOWER(%s) AND table_schema = 'public'
""",
(table,),
)
existing_column_rows = cursor.fetchall()
existing_columns = {
row["column_name"] for row in existing_column_rows
}
existing_column_types = {
row["column_name"]: (row["data_type"] or "").lower()
for row in existing_column_rows
}
# Desired columns based on model
model_fields = getModelFields(model_class)
desired_columns = set(["id"]) | set(model_fields.keys())
# Add missing columns
for col in sorted(desired_columns - existing_columns):
# Determine SQL type
if col in ["id"]:
continue # primary key exists already
sql_type = model_fields.get(col)
if not sql_type:
sql_type = "TEXT"
try:
cursor.execute(
f'ALTER TABLE "{table}" ADD COLUMN "{col}" {sql_type}'
)
logger.info(
f"Added missing column '{col}' ({sql_type}) to '{table}'"
)
except Exception as add_err:
logger.warning(
f"Could not add column '{col}' to '{table}': {add_err}"
)
# Targeted type-downgrade: if a model field has been
# changed from a structured type (JSONB) to a plain
# TEXT field, alter the column so writes don't fail.
# JSONB -> TEXT is a safe, lossless cast (JSONB is
# rendered as its JSON-text representation; the
# corresponding Pydantic ``@field_validator`` is
# responsible for re-decoding legacy data on read).
for col in sorted(desired_columns & existing_columns):
if col == "id":
continue
desired_sql = (model_fields.get(col) or "").upper()
currentType = existing_column_types.get(col, "")
if desired_sql == "TEXT" and currentType == "jsonb":
try:
cursor.execute(
f'ALTER TABLE "{table}" ALTER COLUMN "{col}" TYPE TEXT USING "{col}"::text'
)
logger.info(
f"Downgraded column '{col}' from JSONB to TEXT on '{table}'"
)
except Exception as alter_err:
logger.warning(
f"Could not downgrade column '{col}' on '{table}': {alter_err}"
)
except Exception as ensure_err:
logger.warning(
f"Could not ensure columns for existing table '{table}': {ensure_err}"
)
self.connection.commit()
return True
except Exception as e:
logger.error(f"Error ensuring table {table} exists: {e}")
if hasattr(self, "connection") and self.connection:
self.connection.rollback()
return False
def _ensureVectorExtension(self) -> bool:
"""Enable pgvector extension if not already enabled. Called lazily on first vector table."""
if self._vectorExtensionEnabled:
return True
try:
self._ensure_connection()
with self.connection.cursor() as cursor:
cursor.execute("CREATE EXTENSION IF NOT EXISTS vector")
self.connection.commit()
self._vectorExtensionEnabled = True
logger.info("pgvector extension enabled")
return True
except Exception as e:
logger.error(f"Failed to enable pgvector extension: {e}")
if hasattr(self, "connection") and self.connection:
self.connection.rollback()
return False
def _create_table_from_model(self, cursor, table: str, model_class: type) -> None:
"""Create table with columns matching Pydantic model fields."""
fields = getModelFields(model_class)
# Enable pgvector if any field uses vector type
if any(_isVectorType(sqlType) for sqlType in fields.values()):
self._ensureVectorExtension()
# Build column definitions with quoted identifiers to preserve exact case
columns = ['"id" VARCHAR(255) PRIMARY KEY']
for field_name, sql_type in fields.items():
if field_name != "id": # Skip id, already defined
columns.append(f'"{field_name}" {sql_type}')
# Create table
sql = f'CREATE TABLE IF NOT EXISTS "{table}" ({", ".join(columns)})'
cursor.execute(sql)
# Create indexes for foreign keys
for field_name in fields:
if field_name.endswith("Id") and field_name != "id":
cursor.execute(
f'CREATE INDEX IF NOT EXISTS "idx_{table}_{field_name}" ON "{table}" ("{field_name}")'
)
def _save_record(
self,
cursor,
table: str,
recordId: str,
record: Dict[str, Any],
model_class: type,
) -> None:
"""Save record to normalized table with explicit columns."""
# Get columns from Pydantic model instead of database schema
fields = getModelFields(model_class)
columns = ["id"] + [field for field in fields.keys() if field != "id"]
if not columns:
logger.error(f"No columns found for table {table}")
return
# Filter record data to only include columns that exist in the table
filtered_record = {k: v for k, v in record.items() if k in columns}
# Ensure id is set
filtered_record["id"] = recordId
# Prepare values in the correct order
values = []
for col in columns:
value = filtered_record.get(col)
# Handle timestamp fields - store as Unix timestamps (floats) for consistency
if col in ["sysCreatedAt", "sysModifiedAt"] and value is not None:
if isinstance(value, str):
# Try to parse string as timestamp
try:
value = float(value)
                    except (ValueError, TypeError):
                        pass  # Keep as string if parsing fails
# Convert enum values to their string representation
elif hasattr(value, "value"):
value = value.value
# Handle vector fields (pgvector) - convert List[float] to string
elif col in fields and _isVectorType(fields[col]) and value is not None:
if isinstance(value, list):
value = f"[{','.join(str(v) for v in value)}]"
# Handle JSONB fields - ensure proper JSON format for PostgreSQL
elif col in fields and fields[col] == "JSONB" and value is not None:
import json
if isinstance(value, (dict, list)):
value = json.dumps(value)
elif isinstance(value, str):
try:
json.loads(value)
except (json.JSONDecodeError, TypeError):
value = json.dumps(value)
elif hasattr(value, 'model_dump'):
value = json.dumps(value.model_dump())
else:
value = json.dumps(value)
values.append(value)
# Build INSERT/UPDATE with quoted identifiers
col_names = ", ".join([f'"{col}"' for col in columns])
placeholders = ", ".join(["%s"] * len(columns))
updates = ", ".join(
[
f'"{col}" = EXCLUDED."{col}"'
for col in columns[1:]
if col not in ["sysCreatedAt", "sysCreatedBy"]
]
)
sql = f'INSERT INTO "{table}" ({col_names}) VALUES ({placeholders}) ON CONFLICT ("id") DO UPDATE SET {updates}'
cursor.execute(sql, values)
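    # For a model with fields {"id", "name", "tags"} (names illustrative),
    # _save_record issues an upsert of this shape; note that sysCreatedAt /
    # sysCreatedBy are excluded from the UPDATE so they survive re-saves:
    #
    #   INSERT INTO "MyModel" ("id", "name", "tags") VALUES (%s, %s, %s)
    #   ON CONFLICT ("id") DO UPDATE SET
    #       "name" = EXCLUDED."name", "tags" = EXCLUDED."tags"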
def _loadRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]:
"""Loads a single record from the normalized table."""
table = model_class.__name__
try:
if not self._ensureTableExists(model_class):
return None
with self.connection.cursor() as cursor:
cursor.execute(f'SELECT * FROM "{table}" WHERE "id" = %s', (recordId,))
row = cursor.fetchone()
if not row:
return None
# Convert row to dict and handle JSONB fields
record = dict(row)
fields = getModelFields(model_class)
parseRecordFields(record, fields, f"record {recordId}")
return record
except Exception as e:
logger.error(f"Error loading record {recordId} from table {table}: {e}")
return None
    def getRecord(self, model_class: type, recordId: str) -> Optional[Dict[str, Any]]:
        """Load one row by primary key; thin public wrapper around _loadRecord for routes and services."""
        return self._loadRecord(model_class, str(recordId))
def _saveRecord(
self, model_class: type, recordId: str, record: Dict[str, Any]
) -> bool:
"""Saves a single record to the table."""
table = model_class.__name__
try:
if not self._ensureTableExists(model_class):
return False
recordId = str(recordId)
if "id" in record and str(record["id"]) != recordId:
raise ValueError(f"Record ID mismatch: {recordId} != {record['id']}")
# Add metadata - use contextvar for request-scoped userId when sharing connector
effective_user_id = _current_user_id.get()
if effective_user_id is None:
effective_user_id = self.userId
currentTime = getUtcTimestamp()
# Set sysCreatedAt/sysCreatedBy on first persist; always refresh modified fields.
# Treat None and 0 as unset (empty / bad defaults); model_dump often has sysCreatedAt=None.
createdTs = record.get("sysCreatedAt")
            if createdTs is None or createdTs == 0:
record["sysCreatedAt"] = currentTime
if effective_user_id:
record["sysCreatedBy"] = effective_user_id
elif not record.get("sysCreatedBy"):
if effective_user_id:
record["sysCreatedBy"] = effective_user_id
record["sysModifiedAt"] = currentTime
if effective_user_id:
record["sysModifiedBy"] = effective_user_id
with self.connection.cursor() as cursor:
self._save_record(cursor, table, recordId, record, model_class)
self.connection.commit()
return True
except Exception as e:
logger.error(f"Error saving record {recordId} to table {table}: {e}")
self.connection.rollback()
return False
def _loadTable(self, model_class: type) -> List[Dict[str, Any]]:
"""Loads all records from a normalized table."""
table = model_class.__name__
        if model_class is SystemTable:
            # The system table lives in "_system"; it is returned as a
            # {table_name: initial_id} mapping rather than a row list
            return self._loadSystemTable()
try:
if not self._ensureTableExists(model_class):
return []
with self.connection.cursor() as cursor:
cursor.execute(f'SELECT * FROM "{table}" ORDER BY "id"')
records = [dict(row) for row in cursor.fetchall()]
fields = getModelFields(model_class)
modelFields = model_class.model_fields
for record in records:
parseRecordFields(record, fields, f"table {table}")
# Set type-aware defaults for NULL JSONB fields
for fieldName, fieldType in fields.items():
if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
fieldInfo = modelFields.get(fieldName)
if fieldInfo:
fieldAnnotation = fieldInfo.annotation
if (fieldAnnotation == list or
(hasattr(fieldAnnotation, "__origin__") and
fieldAnnotation.__origin__ is list)):
record[fieldName] = []
elif (fieldAnnotation == dict or
(hasattr(fieldAnnotation, "__origin__") and
fieldAnnotation.__origin__ is dict)):
record[fieldName] = {}
return records
except Exception as e:
logger.error(f"Error loading table {table}: {e}")
return []
def _registerInitialId(self, table: str, initialId: str) -> bool:
"""Registers the initial ID for a table."""
try:
systemData = self._loadSystemTable()
if table not in systemData:
systemData[table] = initialId
success = self._saveSystemTable(systemData)
if success:
logger.info(f"Initial ID {initialId} for table {table} registered")
return success
else:
# Table already has an initial ID registered
logger.debug(f"Table {table} already has initial ID {systemData[table]}")
return True
except Exception as e:
logger.error(f"Error registering the initial ID for table {table}: {e}")
return False
def _removeInitialId(self, table: str) -> bool:
"""Removes the initial ID for a table from the system table."""
try:
systemData = self._loadSystemTable()
if table in systemData:
del systemData[table]
success = self._saveSystemTable(systemData)
if success:
logger.info(
f"Initial ID for table {table} removed from system table"
)
return success
return True # If not present, this is not an error
except Exception as e:
logger.error(f"Error removing initial ID for table {table}: {e}")
return False
def buildRbacWhereClause(
self,
permissions: UserPermissions,
currentUser: User,
table: str,
mandateId: Optional[str] = None,
featureInstanceId: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
"""Delegate to interfaceRbac.buildRbacWhereClause (tests and call sites use connector as entry)."""
from modules.interfaces.interfaceRbac import buildRbacWhereClause as _buildRbacWhereClause
return _buildRbacWhereClause(
permissions,
currentUser,
table,
self,
mandateId=mandateId,
featureInstanceId=featureInstanceId,
)
def updateContext(self, userId: str) -> None:
"""Updates the context of the database connector.
Sets both instance userId and contextvar for request-scoped use when connector is shared.
"""
if userId is None:
raise ValueError("userId must be provided")
self.userId = userId
_current_user_id.set(userId)
# Public API
def getTables(self) -> List[str]:
"""Returns a list of all available tables."""
tables = []
try:
# Ensure connection is alive
self._ensure_connection()
if not self.connection or self.connection.closed:
logger.error("Database connection is not available")
return tables
with self.connection.cursor() as cursor:
cursor.execute("""
SELECT table_name
FROM information_schema.tables
WHERE table_schema = 'public'
ORDER BY table_name
""")
rows = cursor.fetchall()
tables = [row["table_name"] for row in rows]
except Exception as e:
logger.error(f"Error reading the database {self.dbDatabase}: {e}")
return tables
    def getFields(self, model_class: type) -> List[str]:
        """Returns the field names of a table, inferred from its first record."""
        data = self._loadTable(model_class)
        if not data:
            return []
        return list(data[0].keys())
    def getSchema(
        self, model_class: type, language: str = None
    ) -> Dict[str, Dict[str, Any]]:
        """Returns a schema object for a table; types and labels are inferred from the first record (language is currently unused)."""
data = self._loadTable(model_class)
schema = {}
if not data:
return schema
firstRecord = data[0]
for field, value in firstRecord.items():
dataType = type(value).__name__
label = field
schema[field] = {"type": dataType, "label": label}
return schema
def getRecordset(
self,
model_class: type,
fieldFilter: List[str] = None,
recordFilter: Dict[str, Any] = None,
) -> List[Dict[str, Any]]:
"""Returns a list of records from a table, filtered by criteria."""
table = model_class.__name__
try:
if not self._ensureTableExists(model_class):
return []
# Build WHERE clause from recordFilter
where_conditions = []
where_values = []
if recordFilter:
for field, value in recordFilter.items():
if value is None:
where_conditions.append(f'"{field}" IS NULL')
elif isinstance(value, list):
where_conditions.append(f'"{field}" = ANY(%s)')
where_values.append(value)
else:
where_conditions.append(f'"{field}" = %s')
where_values.append(value)
if where_conditions:
where_clause = " WHERE " + " AND ".join(where_conditions)
else:
where_clause = ""
query = f'SELECT * FROM "{table}"{where_clause} ORDER BY "id"'
with self.connection.cursor() as cursor:
cursor.execute(query, where_values)
records = [dict(row) for row in cursor.fetchall()]
fields = getModelFields(model_class)
modelFields = model_class.model_fields
for record in records:
parseRecordFields(record, fields, f"table {table}")
for fieldName, fieldType in fields.items():
if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
fieldInfo = modelFields.get(fieldName)
if fieldInfo:
fieldAnnotation = fieldInfo.annotation
if (fieldAnnotation == list or
(hasattr(fieldAnnotation, "__origin__") and
fieldAnnotation.__origin__ is list)):
record[fieldName] = []
elif (fieldAnnotation == dict or
(hasattr(fieldAnnotation, "__origin__") and
fieldAnnotation.__origin__ is dict)):
record[fieldName] = {}
# If fieldFilter is available, reduce the fields
if fieldFilter and isinstance(fieldFilter, list):
result = []
for record in records:
filteredRecord = {}
for field in fieldFilter:
if field in record:
filteredRecord[field] = record[field]
result.append(filteredRecord)
return result
return records
except Exception as e:
logger.error(f"Error loading records from table {table}: {e}")
return []
def _buildPaginationClauses(
self,
model_class: type,
pagination,
recordFilter: Dict[str, Any] = None,
):
"""
Translate PaginationParams + recordFilter into SQL clauses.
Returns (where_clause, order_clause, limit_clause, values, count_values).
"""
fields = getModelFields(model_class)
validColumns = set(fields.keys())
where_parts: List[str] = []
values: List[Any] = []
if recordFilter:
for field, value in recordFilter.items():
if value is None:
where_parts.append(f'"{field}" IS NULL')
elif isinstance(value, list):
where_parts.append(f'"{field}" = ANY(%s)')
values.append(value)
else:
where_parts.append(f'"{field}" = %s')
values.append(value)
if pagination and pagination.filters:
for key, val in pagination.filters.items():
if key == "search" and isinstance(val, str) and val.strip():
term = f"%{val.strip()}%"
textCols = [c for c, t in fields.items() if t == "TEXT"]
if textCols:
orParts = [f'COALESCE("{c}"::TEXT, \'\') ILIKE %s' for c in textCols]
where_parts.append(f"({' OR '.join(orParts)})")
values.extend([term] * len(textCols))
continue
if key not in validColumns:
logger.debug(f"_buildPaginationClauses: key '{key}' NOT in validColumns {list(validColumns)[:10]}")
continue
colType = fields.get(key, "TEXT")
logger.debug(f"_buildPaginationClauses: filter key='{key}' val={val!r} type(val)={type(val).__name__} colType={colType}")
if val is None:
where_parts.append(f'("{key}" IS NULL OR "{key}" = \'\')')
continue
if isinstance(val, dict):
op = val.get("operator", "equals")
v = val.get("value", "")
if op in ("equals", "eq"):
if colType == "BOOLEAN":
where_parts.append(f'COALESCE("{key}", FALSE) = %s')
values.append(str(v).lower() == "true")
else:
where_parts.append(f'"{key}"::TEXT = %s')
values.append(str(v))
elif op == "contains":
where_parts.append(f'"{key}"::TEXT ILIKE %s')
values.append(f"%{v}%")
elif op == "startsWith":
where_parts.append(f'"{key}"::TEXT ILIKE %s')
values.append(f"{v}%")
elif op == "endsWith":
where_parts.append(f'"{key}"::TEXT ILIKE %s')
values.append(f"%{v}")
                    elif op in ("gt", "gte", "lt", "lte"):
                        sqlOp = {"gt": ">", "gte": ">=", "lt": "<", "lte": "<="}[op]
                        numVal = None
                        if colType in ("INTEGER", "DOUBLE PRECISION"):
                            try:
                                numVal = float(v)
                            except (ValueError, TypeError):
                                numVal = None
                        if numVal is not None:
                            # Compare numerically; a ::TEXT compare sorts "10" before "9"
                            where_parts.append(f'"{key}" {sqlOp} %s')
                            values.append(numVal)
                        else:
                            where_parts.append(f'"{key}"::TEXT {sqlOp} %s')
                            values.append(str(v))
elif op == "between":
fromVal = v.get("from", "") if isinstance(v, dict) else ""
toVal = v.get("to", "") if isinstance(v, dict) else ""
if not fromVal and not toVal:
continue
colType = fields.get(key, "TEXT")
isNumericCol = colType in ("INTEGER", "DOUBLE PRECISION")
isDateVal = bool(fromVal and re.match(r'^\d{4}-\d{2}-\d{2}$', str(fromVal))) or \
bool(toVal and re.match(r'^\d{4}-\d{2}-\d{2}$', str(toVal)))
if isNumericCol and isDateVal:
from datetime import datetime as _dt, timezone as _tz
if fromVal and toVal:
fromTs = _dt.strptime(str(fromVal), '%Y-%m-%d').replace(tzinfo=_tz.utc).timestamp()
toTs = _dt.strptime(str(toVal), '%Y-%m-%d').replace(hour=23, minute=59, second=59, tzinfo=_tz.utc).timestamp()
where_parts.append(f'"{key}" >= %s AND "{key}" <= %s')
values.extend([fromTs, toTs])
elif fromVal:
fromTs = _dt.strptime(str(fromVal), '%Y-%m-%d').replace(tzinfo=_tz.utc).timestamp()
where_parts.append(f'"{key}" >= %s')
values.append(fromTs)
else:
toTs = _dt.strptime(str(toVal), '%Y-%m-%d').replace(hour=23, minute=59, second=59, tzinfo=_tz.utc).timestamp()
where_parts.append(f'"{key}" <= %s')
values.append(toTs)
else:
if fromVal and toVal:
where_parts.append(f'"{key}"::TEXT >= %s AND "{key}"::TEXT <= %s')
values.extend([str(fromVal), str(toVal)])
elif fromVal:
where_parts.append(f'"{key}"::TEXT >= %s')
values.append(str(fromVal))
elif toVal:
where_parts.append(f'"{key}"::TEXT <= %s')
values.append(str(toVal))
else:
if colType == "BOOLEAN":
where_parts.append(f'COALESCE("{key}", FALSE) = %s')
values.append(str(val).lower() == "true")
else:
where_parts.append(f'"{key}"::TEXT ILIKE %s')
values.append(str(val))
where_clause = " WHERE " + " AND ".join(where_parts) if where_parts else ""
count_values = list(values)
orderParts: List[str] = []
if pagination and pagination.sort:
for sf in pagination.sort:
sfField = sf.get("field") if isinstance(sf, dict) else getattr(sf, "field", None)
sfDir = sf.get("direction", "asc") if isinstance(sf, dict) else getattr(sf, "direction", "asc")
if sfField and sfField in validColumns:
direction = "DESC" if str(sfDir).lower() == "desc" else "ASC"
colType = fields.get(sfField, "TEXT")
if colType == "BOOLEAN":
orderParts.append(f'COALESCE("{sfField}", FALSE) {direction}')
else:
orderParts.append(f'"{sfField}" {direction} NULLS LAST')
if not orderParts:
orderParts.append('"id"')
order_clause = " ORDER BY " + ", ".join(orderParts)
limit_clause = ""
if pagination:
offset = (pagination.page - 1) * pagination.pageSize
limit_clause = f" LIMIT {pagination.pageSize} OFFSET {offset}"
return where_clause, order_clause, limit_clause, values, count_values
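    # Sketch of the translation (filter values illustrative): pagination with
    # filters={"search": "urgent", "status": {"operator": "contains",
    # "value": "open"}} and sort=[{"field": "name", "direction": "desc"}]
    # yields roughly:
    #
    #   WHERE (COALESCE("title"::TEXT, '') ILIKE '%urgent%' OR ...)  -- all TEXT cols
    #     AND "status"::TEXT ILIKE '%open%'
    #   ORDER BY "name" DESC NULLS LAST
    #   LIMIT <pageSize> OFFSET <(page - 1) * pageSize>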
def getRecordsetPaginated(
self,
model_class: type,
pagination=None,
recordFilter: Dict[str, Any] = None,
fieldFilter: List[str] = None,
) -> Dict[str, Any]:
"""
Returns paginated records with filtering + sorting at the SQL level.
Returns { "items": [...], "totalItems": int, "totalPages": int }.
If pagination is None, returns all records (no LIMIT/OFFSET).
"""
        import math
table = model_class.__name__
try:
if not self._ensureTableExists(model_class):
return {"items": [], "totalItems": 0, "totalPages": 0}
where_clause, order_clause, limit_clause, values, count_values = \
self._buildPaginationClauses(model_class, pagination, recordFilter)
with self.connection.cursor() as cursor:
countSql = f'SELECT COUNT(*) FROM "{table}"{where_clause}'
dataSql = f'SELECT * FROM "{table}"{where_clause}{order_clause}{limit_clause}'
cursor.execute(countSql, count_values)
totalItems = cursor.fetchone()["count"]
cursor.execute(dataSql, values)
records = [dict(row) for row in cursor.fetchall()]
fields = getModelFields(model_class)
modelFields = model_class.model_fields
for record in records:
parseRecordFields(record, fields, f"table {table}")
for fieldName, fieldType in fields.items():
if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
fieldInfo = modelFields.get(fieldName)
if fieldInfo:
fieldAnnotation = fieldInfo.annotation
if (fieldAnnotation == list or
(hasattr(fieldAnnotation, "__origin__") and
fieldAnnotation.__origin__ is list)):
record[fieldName] = []
elif (fieldAnnotation == dict or
(hasattr(fieldAnnotation, "__origin__") and
fieldAnnotation.__origin__ is dict)):
record[fieldName] = {}
if fieldFilter and isinstance(fieldFilter, list):
records = [{f: r[f] for f in fieldFilter if f in r} for r in records]
from modules.routes.routeHelpers import enrichRowsWithFkLabels
enrichRowsWithFkLabels(records, model_class)
pageSize = pagination.pageSize if pagination else max(totalItems, 1)
totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0
return {"items": records, "totalItems": totalItems, "totalPages": totalPages}
except Exception as e:
logger.error(f"Error in getRecordsetPaginated for table {table}: {e}")
return {"items": [], "totalItems": 0, "totalPages": 0}
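    # Usage sketch (model and filter values illustrative; PaginationParams
    # field names assumed from datamodelPagination):
    #
    #   page = connector.getRecordsetPaginated(
    #       MyModel,
    #       pagination=PaginationParams(page=1, pageSize=25,
    #                                   filters={"search": "urgent"}, sort=[]),
    #       recordFilter={"mandateId": "m-1"},
    #   )
    #   page["items"], page["totalItems"], page["totalPages"]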
def getDistinctColumnValues(
self,
model_class: type,
column: str,
pagination=None,
recordFilter: Dict[str, Any] = None,
includeEmpty: bool = True,
) -> List[Optional[str]]:
"""Return sorted distinct values for a column using SQL DISTINCT.
        When ``includeEmpty`` is True (default), NULL and empty-string rows are
        represented by a single ``None`` entry at the end of the list; this
        lets the frontend offer a "(Leer)" ("empty") filter option.
Applies cross-filtering (all filters except the requested column).
"""
table = model_class.__name__
fields = getModelFields(model_class)
if column not in fields:
return []
try:
if not self._ensureTableExists(model_class):
return []
if pagination:
import copy
pagination = copy.deepcopy(pagination)
if pagination.filters and column in pagination.filters:
pagination.filters.pop(column, None)
pagination.sort = []
where_clause, _, _, values, _ = \
self._buildPaginationClauses(model_class, pagination, recordFilter)
nonNullCond = f'"{column}" IS NOT NULL AND "{column}"::TEXT != \'\''
if where_clause:
sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{where_clause} AND {nonNullCond} ORDER BY val'
else:
sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}" WHERE {nonNullCond} ORDER BY val'
with self.connection.cursor() as cursor:
cursor.execute(sql, values)
result: List[Optional[str]] = [row["val"] for row in cursor.fetchall()]
if includeEmpty:
emptyCond = f'"{column}" IS NULL OR "{column}"::TEXT = \'\''
if where_clause:
emptySql = f'SELECT 1 FROM "{table}"{where_clause} AND ({emptyCond}) LIMIT 1'
else:
emptySql = f'SELECT 1 FROM "{table}" WHERE ({emptyCond}) LIMIT 1'
with self.connection.cursor() as cursor:
cursor.execute(emptySql, values)
if cursor.fetchone():
result.append(None)
return result
except Exception as e:
logger.error(f"Error in getDistinctColumnValues for {table}.{column}: {e}")
return []
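    # Usage sketch (MyModel and "status" illustrative): distinct values for a
    # filter dropdown, with a trailing None when blank rows exist:
    #
    #   connector.getDistinctColumnValues(MyModel, "status")
    #   # -> ['closed', 'open', None]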
def recordCreate(
self, model_class: type, record: Union[Dict[str, Any], BaseModel]
) -> Dict[str, Any]:
"""Creates a new record in a table based on Pydantic model class."""
# If record is a Pydantic model, convert to dict
if isinstance(record, BaseModel):
record = record.model_dump()
elif isinstance(record, dict):
record = record.copy()
else:
raise ValueError("Record must be a Pydantic model or dictionary")
# Ensure record has an ID
if "id" not in record:
record["id"] = str(uuid.uuid4())
# Save record
success = self._saveRecord(model_class, record["id"], record)
if not success:
table = model_class.__name__
raise ValueError(f"Failed to save record {record['id']} to table {table}")
# Check if this is the first record in the table and register as initial ID
table = model_class.__name__
existingInitialId = self.getInitialId(model_class)
if existingInitialId is None:
# This is the first record, register it as the initial ID
self._registerInitialId(table, record["id"])
logger.info(f"Registered initial ID {record['id']} for table {table}")
return record
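    # Usage sketch (model and values illustrative):
    #
    #   created = connector.recordCreate(MyModel, {"name": "first"})
    #   created["id"]  # the supplied ID, or a generated UUID
    #   # The first row ever saved to a table is also registered as its
    #   # initial ID in the "_system" table.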
def recordModify(
self, model_class: type, recordId: str, record: Union[Dict[str, Any], BaseModel]
) -> Dict[str, Any]:
"""Modifies an existing record in a table based on Pydantic model class."""
# Load existing record
existingRecord = self._loadRecord(model_class, recordId)
if not existingRecord:
table = model_class.__name__
raise ValueError(f"Record {recordId} not found in table {table}")
# If record is a Pydantic model, convert to dict
if isinstance(record, BaseModel):
record = record.model_dump()
elif isinstance(record, dict):
record = record.copy()
else:
raise ValueError("Record must be a Pydantic model or dictionary")
# CRITICAL: Ensure we never modify the ID
if "id" in record and str(record["id"]) != recordId:
logger.error(
f"Attempted to modify record ID from {recordId} to {record['id']}"
)
raise ValueError(
"Cannot modify record ID - it must match the provided recordId"
)
# Update existing record with new data
existingRecord.update(record)
# Save updated record
saved = self._saveRecord(model_class, recordId, existingRecord)
if not saved:
table = model_class.__name__
raise ValueError(f"Failed to save record {recordId} to table {table}")
return existingRecord
def recordDelete(self, model_class: type, recordId: str) -> bool:
"""Deletes a record from the table based on Pydantic model class."""
table = model_class.__name__
try:
if not self._ensureTableExists(model_class):
return False
with self.connection.cursor() as cursor:
# Check if record exists
cursor.execute(
f'SELECT "id" FROM "{table}" WHERE "id" = %s', (recordId,)
)
if not cursor.fetchone():
return False
# Check if it's an initial record
initialId = self.getInitialId(model_class)
if initialId is not None and initialId == recordId:
self._removeInitialId(table)
logger.info(
f"Initial ID {recordId} for table {table} has been removed from the system table"
)
# Delete the record
cursor.execute(f'DELETE FROM "{table}" WHERE "id" = %s', (recordId,))
# No cache to update - database handles consistency
self.connection.commit()
return True
except Exception as e:
logger.error(f"Error deleting record {recordId} from table {table}: {e}")
self.connection.rollback()
return False
def recordCreateBulk(
self, model_class: type, records: List[Union[Dict[str, Any], BaseModel]]
) -> int:
"""Bulk-insert many records in a single transaction.
Use this instead of calling recordCreate() in a tight loop when importing
large datasets (>100 rows). Performance gain is roughly two orders of
magnitude because:
- one network round-trip via execute_values() instead of N
- one COMMIT instead of N
- initial ID is registered once for the whole batch instead of every row
Returns the number of rows successfully inserted. Caller is responsible
for catching exceptions; on any error the transaction is rolled back so
the table stays consistent (all-or-nothing).
"""
if not records:
return 0
table = model_class.__name__
if not self._ensureTableExists(model_class):
raise ValueError(f"Table {table} does not exist")
fields = getModelFields(model_class)
columns = ["id"] + [f for f in fields.keys() if f != "id"]
modelFields = model_class.model_fields
effectiveUserId = _current_user_id.get()
if effectiveUserId is None:
effectiveUserId = self.userId
currentTime = getUtcTimestamp()
normalised: List[Dict[str, Any]] = []
for raw in records:
if isinstance(raw, BaseModel):
rec = raw.model_dump()
elif isinstance(raw, dict):
rec = raw.copy()
else:
raise ValueError("Bulk record must be a Pydantic model or dictionary")
if "id" not in rec or not rec["id"]:
rec["id"] = str(uuid.uuid4())
createdTs = rec.get("sysCreatedAt")
            if createdTs is None or createdTs == 0:
rec["sysCreatedAt"] = currentTime
if effectiveUserId:
rec["sysCreatedBy"] = effectiveUserId
elif not rec.get("sysCreatedBy") and effectiveUserId:
rec["sysCreatedBy"] = effectiveUserId
rec["sysModifiedAt"] = currentTime
if effectiveUserId:
rec["sysModifiedBy"] = effectiveUserId
normalised.append(rec)
rows = [self._coerceRowForInsert(rec, columns, fields, modelFields) for rec in normalised]
col_names = ", ".join([f'"{c}"' for c in columns])
updates = ", ".join(
[f'"{c}" = EXCLUDED."{c}"' for c in columns[1:]
if c not in ("sysCreatedAt", "sysCreatedBy")]
)
sql = (
f'INSERT INTO "{table}" ({col_names}) VALUES %s '
f'ON CONFLICT ("id") DO UPDATE SET {updates}'
)
try:
self._ensure_connection()
with self.connection.cursor() as cursor:
psycopg2.extras.execute_values(cursor, sql, rows, page_size=500)
self.connection.commit()
except Exception as e:
logger.error(f"Bulk insert into {table} failed (n={len(rows)}): {e}")
try:
self.connection.rollback()
except Exception:
pass
raise
if self.getInitialId(model_class) is None and normalised:
self._registerInitialId(table, normalised[0]["id"])
logger.info(f"Registered initial ID {normalised[0]['id']} for table {table}")
return len(rows)
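    # Usage sketch (records illustrative): import many rows in one transaction
    # instead of looping over recordCreate():
    #
    #   rows = [{"name": f"item-{i}"} for i in range(1000)]
    #   inserted = connector.recordCreateBulk(MyModel, rows)  # -> 1000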
def _coerceRowForInsert(
self,
record: Dict[str, Any],
columns: List[str],
fields: Dict[str, str],
modelFields: Dict[str, Any],
) -> tuple:
"""Convert one record dict to a positional tuple matching `columns`.
Mirrors the per-column coercion logic in `_save_record` so that bulk and
single inserts produce identical on-disk values (timestamps as floats,
enums as strings, vectors as pgvector text, JSONB as JSON strings).
"""
import json as _json
out = []
for col in columns:
value = record.get(col)
if col in ("sysCreatedAt", "sysModifiedAt") and value is not None:
if isinstance(value, str):
try:
value = float(value)
except Exception:
pass
elif hasattr(value, "value"):
value = value.value
elif col in fields and _isVectorType(fields[col]) and value is not None:
if isinstance(value, list):
value = f"[{','.join(str(v) for v in value)}]"
elif col in fields and fields[col] == "JSONB" and value is not None:
if isinstance(value, (dict, list)):
value = _json.dumps(value)
elif isinstance(value, str):
try:
_json.loads(value)
except (ValueError, TypeError):
value = _json.dumps(value)
elif hasattr(value, "model_dump"):
value = _json.dumps(value.model_dump())
else:
value = _json.dumps(value)
out.append(value)
return tuple(out)
def recordDeleteWhere(
self, model_class: type, recordFilter: Dict[str, Any]
) -> int:
"""Delete all records matching a simple equality filter, in one statement.
Replaces the N+1 pattern `for r in getRecordset(...): recordDelete(r.id)`.
Returns the number of rows actually deleted. If the table holds the
initial ID and that row gets deleted, the initial ID registration is
cleared so the next insert can re-register a fresh one.
"""
if not recordFilter:
raise ValueError("recordDeleteWhere requires a non-empty recordFilter (refusing to truncate)")
table = model_class.__name__
if not self._ensureTableExists(model_class):
return 0
fields = getModelFields(model_class)
clauses: List[str] = []
params: List[Any] = []
for key, val in recordFilter.items():
if key not in fields and key != "id":
raise ValueError(f"recordDeleteWhere: unknown column {table}.{key}")
clauses.append(f'"{key}" = %s')
params.append(val)
whereSql = " AND ".join(clauses)
initialId = self.getInitialId(model_class)
try:
self._ensure_connection()
with self.connection.cursor() as cursor:
if initialId is not None:
cursor.execute(
f'SELECT 1 FROM "{table}" WHERE "id" = %s AND ' + whereSql,
[initialId, *params],
)
initialIsAffected = cursor.fetchone() is not None
else:
initialIsAffected = False
cursor.execute(f'DELETE FROM "{table}" WHERE ' + whereSql, params)
deleted = cursor.rowcount or 0
self.connection.commit()
except Exception as e:
logger.error(f"Bulk delete from {table} failed (filter={recordFilter}): {e}")
try:
self.connection.rollback()
except Exception:
pass
raise
if deleted and initialIsAffected:
self._removeInitialId(table)
logger.info(f"Initial ID for table {table} cleared (bulk-delete removed it)")
if deleted:
logger.info(f"recordDeleteWhere: deleted {deleted} rows from {table} where {recordFilter}")
return deleted
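    # Usage sketch (filter illustrative): one DELETE statement instead of the
    # N+1 select-then-delete loop:
    #
    #   removed = connector.recordDeleteWhere(MyModel, {"mandateId": "m-1"})
    #   # Raises ValueError on an empty filter (refuses to truncate).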
def getInitialId(self, model_class: type) -> Optional[str]:
"""Returns the initial ID for a table."""
table = model_class.__name__
systemData = self._loadSystemTable()
initialId = systemData.get(table)
return initialId
def semanticSearch(
self,
modelClass: type,
vectorColumn: str,
queryVector: List[float],
limit: int = 10,
recordFilter: Dict[str, Any] = None,
minScore: float = None,
) -> List[Dict[str, Any]]:
"""Semantic search using pgvector cosine distance.
Args:
modelClass: Pydantic model class for the table.
vectorColumn: Name of the vector column to search.
queryVector: Query vector as List[float].
limit: Maximum number of results.
recordFilter: Additional WHERE filters (field: value).
minScore: Minimum cosine similarity (0.0 - 1.0).
Returns:
List of records with an added '_score' field (cosine similarity),
sorted by similarity descending.
"""
table = modelClass.__name__
try:
if not self._ensureTableExists(modelClass):
return []
vectorStr = f"[{','.join(str(v) for v in queryVector)}]"
whereConditions = []
whereValues = []
if recordFilter:
for field, value in recordFilter.items():
if value is None:
whereConditions.append(f'"{field}" IS NULL')
elif isinstance(value, (list, tuple)):
if not value:
whereConditions.append("1 = 0")
else:
whereConditions.append(f'"{field}" = ANY(%s)')
whereValues.append(list(value))
else:
whereConditions.append(f'"{field}" = %s')
whereValues.append(value)
if minScore is not None:
whereConditions.append(
f'1 - ("{vectorColumn}" <=> %s::vector) >= %s'
)
whereValues.extend([vectorStr, minScore])
whereClause = ""
if whereConditions:
whereClause = " WHERE " + " AND ".join(whereConditions)
query = (
f'SELECT *, 1 - ("{vectorColumn}" <=> %s::vector) AS "_score" '
f'FROM "{table}"{whereClause} '
f'ORDER BY "{vectorColumn}" <=> %s::vector '
f'LIMIT %s'
)
params = [vectorStr] + whereValues + [vectorStr, limit]
with self.connection.cursor() as cursor:
cursor.execute(query, params)
records = [dict(row) for row in cursor.fetchall()]
fields = getModelFields(modelClass)
for record in records:
parseRecordFields(record, fields, f"semanticSearch {table}")
return records
except Exception as e:
logger.error(f"Error in semantic search on {table}: {e}")
return []
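    # Usage sketch (model, column, and vector illustrative): top-5 nearest
    # rows by cosine similarity, keeping only matches scoring >= 0.75:
    #
    #   hits = connector.semanticSearch(
    #       DocumentChunk, "embedding", queryVector=[0.1] * 1536,
    #       limit=5, minScore=0.75,
    #   )
    #   hits[0]["_score"]  # cosine similarity of the best match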
def close(self, forceClose: bool = False):
"""Close the database connection.
Shared cached connectors are intentionally kept open unless forceClose=True.
This prevents accidental shutdown from interface __del__ methods while
other requests are still using the same cached connector instance.
"""
if self._isCachedShared and not forceClose:
return
if (
hasattr(self, "connection")
and self.connection
and not self.connection.closed
):
self.connection.close()
def __del__(self):
"""Cleanup method to close connection."""
try:
self.close()
except Exception:
pass