platform-core/modules/system/databaseMigration.py
ValueOn AG 8c2e9d2183
All checks were successful
Deploy Plattform-Core / test (push) Successful in 47s
Deploy Plattform-Core / deploy (push) Successful in 5s
Deploy Plattform-Core (Int) / test (push) Successful in 59s
Deploy Plattform-Core (Int) / deploy (push) Successful in 10s
db import streaming
2026-05-28 10:52:10 +02:00

1292 lines
46 KiB
Python

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Database migration utilities — backup (export) and restore (import) for all
registered PowerOn databases.
System objects (root mandate, admin user, event user) are protected: they are
never deleted or overwritten during import. Their IDs in the backup payload
are remapped to the IDs of the corresponding live objects so that all FK
references stay consistent.
All functions are intended for SysAdmin use only (access control in the route layer).
"""
import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Set, Tuple
import psycopg2
import psycopg2.extras
from modules.shared.configuration import APP_CONFIG
from modules.shared.dbRegistry import getRegisteredDatabases
from modules.shared.fkRegistry import getFkRelationships
from modules.datamodels.datamodelBase import MODEL_REGISTRY
from modules.system.databaseHealth import _getConnection, _jsonSafe
logger = logging.getLogger(__name__)
_EXPORT_FORMAT_VERSION = "1.0"
_SYSTEM_TABLE = "_system"
_EXCLUDED_TABLES: Dict[str, Set[str]] = {
"poweron_app": {"Token", "AuthEvent"},
}
# ---------------------------------------------------------------------------
# Instance label
# ---------------------------------------------------------------------------
def _getInstanceLabel() -> str:
    """Return the configured instance type (``APP_ENV_TYPE``, e.g. 'dev', 'int', 'prod').

    Falls back to ``"unknown"`` when the setting is absent.
    """
    envType = APP_CONFIG.get("APP_ENV_TYPE", "unknown")
    return envType
# ---------------------------------------------------------------------------
# Database list
# ---------------------------------------------------------------------------
def _getAvailableDatabases() -> List[dict]:
    """Summarise every registered database for the UI.

    Returns one ``{"name", "tableCount", "recordCount"}`` dict per registered
    database, sorted by name; ``poweron_test`` is left out. Counts come from
    ``pg_stat_user_tables`` (``n_live_tup`` is PostgreSQL's statistics
    estimate, not an exact ``COUNT(*)``) and ignore tables whose name starts
    with an underscore. A database that cannot be queried is still listed,
    with zero counts; the error is only logged.
    """
    statsQuery = """
        SELECT relname, n_live_tup
        FROM pg_stat_user_tables
        WHERE schemaname = 'public'
          AND relname NOT LIKE '\\_%%'
    """
    summaries: List[dict] = []
    for dbName in sorted(getRegisteredDatabases()):
        if dbName == "poweron_test":
            continue
        tableCount = 0
        recordCount = 0
        try:
            conn = _getConnection(dbName)
            try:
                with conn.cursor() as cur:
                    cur.execute(statsQuery)
                    stats = cur.fetchall()
            finally:
                conn.close()
            for stat in stats:
                tableCount += 1
                recordCount += int(stat["n_live_tup"])
        except Exception as e:
            logger.warning("Could not stat database %s: %s", dbName, e)
        summaries.append({"name": dbName, "tableCount": tableCount, "recordCount": recordCount})
    return summaries
# ---------------------------------------------------------------------------
# Export
# ---------------------------------------------------------------------------
def _exportDatabases(databases: List[str]) -> dict:
    """Export the selected databases as one JSON-serialisable dict (held in memory).

    Returns ``{meta: {...}, databases: {dbName: {tables: {tbl: [rows]}, summary: {...}}}}``.
    Unregistered names are skipped with a warning; a database whose export
    raises is logged and left out of the result.

    Raises:
        ValueError: If ``databases`` is empty.
    """
    registeredDbs = getRegisteredDatabases()
    if not databases:
        raise ValueError("No databases selected for export.")
    meta: dict = {
        "exportedAt": datetime.now(timezone.utc).isoformat(),
        "version": _EXPORT_FORMAT_VERSION,
        "databaseCount": 0,
        "totalTables": 0,
        "totalRecords": 0,
    }
    exported: dict = {}
    for dbName in databases:
        if dbName not in registeredDbs:
            logger.warning("Export: skipping unregistered database %s", dbName)
            continue
        try:
            dbPayload = _exportSingleDb(dbName)
        except Exception as e:
            logger.error("Export failed for database %s: %s", dbName, e)
            continue
        exported[dbName] = dbPayload
        meta["databaseCount"] += 1
        meta["totalTables"] += dbPayload["tableCount"]
        meta["totalRecords"] += dbPayload["totalRecords"]
    return {"meta": meta, "databases": exported}
def _getModelTablesForDb(dbName: str, physicalTables: List[str]) -> List[str]:
    """Keep only the physical tables backed by a registered Pydantic model, sorted by name.

    A table is kept when ``name in MODEL_REGISTRY``. Legacy or orphan tables
    without a model are dropped, so the backup contains only model-backed data.

    ``dbName`` is not used. The same model can back tables in several
    databases (shared-table pattern), so only registry membership is checked,
    not the model-to-database mapping.
    """
    modelBacked = [name for name in physicalTables if name in MODEL_REGISTRY]
    modelBacked.sort()
    return modelBacked
def _exportSingleDb(dbName: str) -> dict:
    """Read every exportable table of one database into memory.

    A table is exported when it has a registered model and is not listed in
    ``_EXCLUDED_TABLES[dbName]``. Tables without a model are logged as skipped.
    The connection is always closed.

    Returns:
        ``{"tables": {tbl: [rows]}, "summary": {tbl: {"recordCount": n}},
        "tableCount": n, "totalRecords": n}``.
    """
    conn = _getConnection(dbName)
    excluded = _EXCLUDED_TABLES.get(dbName, set())
    try:
        allTables = _listTables(conn)
        modelTables = _getModelTablesForDb(dbName, allTables)
        legacyTables = sorted(set(allTables).difference(modelTables, excluded, {_SYSTEM_TABLE}))
        if legacyTables:
            logger.info("Export %s: skipping %d legacy tables without model: %s",
                        dbName, len(legacyTables), legacyTables)
        tablesOut: dict = {}
        summaryOut: dict = {}
        recordTotal = 0
        for tbl in modelTables:
            if tbl in excluded:
                logger.info("Export: skipping excluded table %s.%s", dbName, tbl)
                continue
            rows = _readTableRows(conn, tbl)
            tablesOut[tbl] = rows
            summaryOut[tbl] = {"recordCount": len(rows)}
            recordTotal += len(rows)
        return {
            "tables": tablesOut,
            "summary": summaryOut,
            "tableCount": len(tablesOut),
            "totalRecords": recordTotal,
        }
    finally:
        conn.close()
def _listTables(conn) -> List[str]:
    """Return the names of all base tables in the ``public`` schema, sorted by name.

    The internal ``_SYSTEM_TABLE`` is left out. Rows are read by column name,
    so the connection's cursors are assumed to yield mapping-like rows
    (TODO confirm the cursor factory configured in ``_getConnection``).
    """
    tableQuery = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = 'public'
          AND table_type = 'BASE TABLE'
          AND table_name != %s
        ORDER BY table_name
    """
    with conn.cursor() as cur:
        cur.execute(tableQuery, (_SYSTEM_TABLE,))
        records = cur.fetchall()
    return [rec["table_name"] for rec in records]
def _readTableRows(conn, tableName: str) -> List[dict]:
    """Read all rows of *tableName* and make every value JSON-safe via ``_jsonSafe``.

    The whole table is fetched into memory at once (``fetchall``).
    """
    with conn.cursor() as cur:
        cur.execute(f'SELECT * FROM "{tableName}"')
        records = cur.fetchall()
    converted: List[dict] = []
    for record in records:
        converted.append({column: _jsonSafe(value) for column, value in dict(record).items()})
    return converted
# ---------------------------------------------------------------------------
# Streaming Export (memory-safe, handles arbitrarily large databases)
# ---------------------------------------------------------------------------
def streamExportGenerator(databases: List[str], instanceLabel: str = ""):
    """Yield JSON text fragments that together form one database export.

    The document is written incrementally, row by row and table by table.
    Rows are read through server-side cursors in batches of 2000, so neither
    the backend nor the client has to hold a whole database in memory.

    Parsed, the result has the same shape as ``_exportDatabases()``, plus
    ``meta.instanceLabel``::

        {"databases": {"<db>": {"tables": {"<tbl>": [rows]},
                                "summary": {"<tbl>": {"recordCount": n}},
                                "tableCount": n, "totalRecords": n}},
         "meta": {"exportedAt", "version", "instanceLabel",
                  "databaseCount", "totalTables", "totalRecords"}}

    ``meta`` is written *after* ``databases`` because its totals are only
    known once everything has been streamed. JSON key order is not
    significant, and both ``json.loads`` and ``ijson.items(f, "meta")`` find
    it wherever it is.

    Error handling:

    * If the connection or the table listing fails, the database is omitted
      entirely, as in ``_exportDatabases``.
    * If reading a table fails after its output has started, the open JSON
      structures are still closed properly and the rows already written are
      kept. The remaining tables of that database are skipped, and the
      database entry gets an ``"exportError"`` string so the backup is
      visibly incomplete.

    Args:
        databases: Database names to export; unregistered names are ignored.
        instanceLabel: Free-form label stored in ``meta.instanceLabel``.

    Yields:
        str: Consecutive fragments of the JSON document.
    """
    registeredDbs = getRegisteredDatabases()
    validDbs = [db for db in databases if db in registeredDbs]
    exportedAt = datetime.now(timezone.utc).isoformat()
    totalDbs = 0
    totalTables = 0
    totalRecords = 0

    yield '{"databases":{'
    firstDb = True
    for dbName in validDbs:
        excluded = _EXCLUDED_TABLES.get(dbName, set())
        conn = None
        try:
            # Phase 1: everything that can fail before this database has
            # emitted any output. A failure here simply omits the database.
            try:
                conn = _getConnection(dbName)
                exportTables = [
                    tbl for tbl in _getModelTablesForDb(dbName, _listTables(conn))
                    if tbl not in excluded
                ]
            except Exception:
                logger.exception("Stream export failed for %s, database skipped", dbName)
                continue

            # Phase 2: stream the tables. Output has started, so errors from
            # here on must still leave well-formed JSON behind.
            if not firstDb:
                yield ','
            firstDb = False
            yield json.dumps(dbName, ensure_ascii=False)
            yield ':{"tables":{'
            tableCounts: Dict[str, int] = {}
            exportError: Optional[str] = None
            for tbl in exportTables:
                if tableCounts:
                    yield ','
                yield json.dumps(tbl, ensure_ascii=False)
                yield ':['
                rowCount = 0
                try:
                    # Named (server-side) cursor: rows arrive in batches of
                    # ``itersize`` instead of being fetched all at once.
                    with conn.cursor(name=f"export_{dbName}_{tbl}") as cur:
                        cur.itersize = 2000
                        cur.execute(f'SELECT * FROM "{tbl}"')
                        for row in cur:
                            safeRow = {k: _jsonSafe(v) for k, v in dict(row).items()}
                            rowJson = json.dumps(safeRow, ensure_ascii=False, default=str)
                            if rowCount:
                                yield ',' + rowJson
                            else:
                                yield rowJson
                            rowCount += 1
                except Exception as e:
                    # Only Exception is caught: GeneratorExit (consumer closed
                    # the stream) must still propagate.
                    exportError = f"{tbl}: {e}"
                    logger.exception("Stream export failed for %s.%s after %d rows",
                                     dbName, tbl, rowCount)
                yield ']'
                tableCounts[tbl] = rowCount
                if exportError is not None:
                    # The connection's transaction is most likely aborted now;
                    # skip the remaining tables of this database.
                    break

            dbRecordCount = sum(tableCounts.values())
            yield '},"summary":'
            yield json.dumps(
                {tbl: {"recordCount": count} for tbl, count in tableCounts.items()},
                ensure_ascii=False,
            )
            yield f',"tableCount":{len(tableCounts)},"totalRecords":{dbRecordCount}'
            if exportError is not None:
                yield ',"exportError":' + json.dumps(exportError, ensure_ascii=False)
            yield '}'
            totalDbs += 1
            totalTables += len(tableCounts)
            totalRecords += dbRecordCount
            logger.info("Stream export: %s done (%d tables, %d records)",
                        dbName, len(tableCounts), dbRecordCount)
        finally:
            if conn is not None:
                try:
                    conn.close()
                except Exception:
                    logger.debug("Stream export: closing connection for %s failed",
                                 dbName, exc_info=True)

    # Metadata goes last because its totals are only known now.
    yield '},"meta":'
    yield json.dumps({
        "exportedAt": exportedAt,
        "version": _EXPORT_FORMAT_VERSION,
        "instanceLabel": instanceLabel,
        "databaseCount": totalDbs,
        "totalTables": totalTables,
        "totalRecords": totalRecords,
    }, ensure_ascii=False)
    yield '}'
# ---------------------------------------------------------------------------
# Validate
# ---------------------------------------------------------------------------
def _validateImportPayload(payload: dict) -> dict:
"""Validate an import payload without writing anything.
Returns ``{valid, summary, warnings, systemObjectsFound}``.
"""
warnings: List[str] = []
summary: List[dict] = []
meta = payload.get("meta")
if not meta or not isinstance(meta, dict):
return {"valid": False, "summary": [], "warnings": ["Fehlende oder ungueltige 'meta'-Sektion"], "systemObjectsFound": []}
version = meta.get("version", "")
if version != _EXPORT_FORMAT_VERSION:
warnings.append(f"Unbekannte Format-Version: {version} (erwartet: {_EXPORT_FORMAT_VERSION})")
databases = payload.get("databases")
if not databases or not isinstance(databases, dict):
return {"valid": False, "summary": [], "warnings": ["Fehlende oder ungueltige 'databases'-Sektion"], "systemObjectsFound": []}
registeredDbs = getRegisteredDatabases()
for dbName, dbData in databases.items():
tables = dbData.get("tables", {})
tableCount = len(tables)
recordCount = sum(len(rows) for rows in tables.values() if isinstance(rows, list))
registered = dbName in registeredDbs
if not registered:
warnings.append(f"Datenbank '{dbName}' ist nicht registriert und wird uebersprungen")
summary.append({
"database": dbName,
"tableCount": tableCount,
"recordCount": recordCount,
"registered": registered,
})
systemObjectsFound = _detectSystemObjectsInPayload(payload)
valid = any(s["registered"] for s in summary)
return {
"valid": valid,
"summary": summary,
"warnings": warnings,
"systemObjectsFound": systemObjectsFound,
}
def _detectSystemObjectsInPayload(payload: dict) -> List[dict]:
"""Find system objects (root mandate, admin user, event user) in a payload."""
found: List[dict] = []
appData = payload.get("databases", {}).get("poweron_app", {}).get("tables", {})
for row in appData.get("Mandate", []):
if row.get("name") == "root" and row.get("isSystem") is True:
found.append({"type": "mandate", "label": "Root Mandate", "payloadId": row.get("id")})
for row in appData.get("UserInDB", []):
if row.get("username") == "admin":
found.append({"type": "user", "label": "Admin User", "payloadId": row.get("id")})
elif row.get("username") == "event":
found.append({"type": "user", "label": "Event User", "payloadId": row.get("id")})
return found
# ---------------------------------------------------------------------------
# System-object ID remapping
# ---------------------------------------------------------------------------
def _loadLiveSystemObjectIds() -> Dict[str, str]:
    """Look up the IDs of the three protected system objects in the live ``poweron_app`` DB.

    Returns a dict with up to three keys, ``rootMandate``, ``adminUser`` and
    ``eventUser``, each mapped to the ID as a string. A key is absent when no
    matching row exists. Returns ``{}`` when ``poweron_app`` is not registered.
    """
    if "poweron_app" not in getRegisteredDatabases():
        return {}
    lookups = (
        ("rootMandate", """SELECT id FROM "Mandate" WHERE "name" = 'root' AND "isSystem" = true LIMIT 1"""),
        ("adminUser", """SELECT id FROM "UserInDB" WHERE "username" = 'admin' LIMIT 1"""),
        ("eventUser", """SELECT id FROM "UserInDB" WHERE "username" = 'event' LIMIT 1"""),
    )
    liveIds: Dict[str, str] = {}
    conn = _getConnection("poweron_app")
    try:
        with conn.cursor() as cur:
            for resultKey, query in lookups:
                cur.execute(query)
                hit = cur.fetchone()
                if hit:
                    liveIds[resultKey] = str(hit["id"])
    finally:
        conn.close()
    return liveIds
def _buildIdRemapFromPayload(payload: dict, liveIds: Dict[str, str]) -> Dict[str, str]:
"""Build an ``{oldId: newId}`` mapping for system objects.
Compares IDs found in the payload with the live system-object IDs.
Only entries where the IDs actually differ are included.
"""
remap: Dict[str, str] = {}
appTables = payload.get("databases", {}).get("poweron_app", {}).get("tables", {})
for row in appTables.get("Mandate", []):
if row.get("name") == "root" and row.get("isSystem") is True:
oldId = str(row.get("id", ""))
newId = liveIds.get("rootMandate", "")
if oldId and newId and oldId != newId:
remap[oldId] = newId
for row in appTables.get("UserInDB", []):
username = row.get("username")
oldId = str(row.get("id", ""))
if username == "admin":
newId = liveIds.get("adminUser", "")
elif username == "event":
newId = liveIds.get("eventUser", "")
else:
continue
if oldId and newId and oldId != newId:
remap[oldId] = newId
return remap
def _remapSystemObjectIds(payload: dict, remap: Dict[str, str]) -> dict:
"""Walk the entire payload and replace every value that matches an old system-object ID."""
if not remap:
return payload
remapSet = set(remap.keys())
databases = payload.get("databases", {})
for dbName, dbData in databases.items():
tables = dbData.get("tables", {})
for tableName, rows in tables.items():
if not isinstance(rows, list):
continue
for row in rows:
_remapRowValues(row, remap, remapSet)
return payload
def _remapDbTables(tables: dict, remap: Dict[str, str]) -> None:
"""In-place remap system-object IDs in a single DB's tables dict."""
if not remap:
return
remapSet = set(remap.keys())
for tableName, rows in tables.items():
if not isinstance(rows, list):
continue
for row in rows:
_remapRowValues(row, remap, remapSet)
def _remapRowValues(row: dict, remap: Dict[str, str], remapSet: Set[str]) -> None:
"""In-place replace string values in a row dict that match a remap key."""
for key, val in row.items():
if isinstance(val, str) and val in remapSet:
row[key] = remap[val]
elif isinstance(val, dict):
_remapRowValues(val, remap, remapSet)
elif isinstance(val, list):
for i, item in enumerate(val):
if isinstance(item, str) and item in remapSet:
val[i] = remap[item]
elif isinstance(item, dict):
_remapRowValues(item, remap, remapSet)
# ---------------------------------------------------------------------------
# Import
# ---------------------------------------------------------------------------
_PROTECTED_ROWS: Dict[str, List[dict]] = {
"Mandate": [{"name": "root", "isSystem": True}],
"UserInDB": [{"username": "admin"}, {"username": "event"}],
}
def _isProtectedRow(tableName: str, row: dict) -> bool:
"""Return True if a row represents a protected system object."""
patterns = _PROTECTED_ROWS.get(tableName, [])
for pattern in patterns:
if all(row.get(k) == v for k, v in pattern.items()):
return True
return False
def _importDatabases(payload: dict, mode: str) -> dict:
"""Import databases from a validated payload.
``mode`` is ``"replace"`` (clear + insert) or ``"merge"`` (insert missing only).
"""
if mode not in ("replace", "merge"):
raise ValueError(f"Invalid import mode: {mode}")
registeredDbs = getRegisteredDatabases()
liveIds = _loadLiveSystemObjectIds()
remap = _buildIdRemapFromPayload(payload, liveIds)
if remap:
logger.info("System-object ID remap: %s", remap)
_remapSystemObjectIds(payload, remap)
protectedIdSet = set(liveIds.values())
imported: Dict[str, dict] = {}
warnings: List[str] = []
databases = payload.get("databases", {})
for dbName, dbData in databases.items():
if dbName not in registeredDbs:
warnings.append(f"Datenbank '{dbName}' uebersprungen (nicht registriert)")
continue
tables = dbData.get("tables", {})
dbResult: Dict[str, int] = {}
conn = _getConnection(dbName)
try:
conn.autocommit = False
existingTables = set(_listTables(conn))
for tableName, rows in tables.items():
if not isinstance(rows, list):
continue
if tableName not in existingTables:
warnings.append(f"Tabelle '{dbName}.{tableName}' existiert nicht, uebersprungen")
continue
physicalCols = _getPhysicalColumns(conn, tableName)
if not physicalCols:
continue
filteredRows = []
for row in rows:
if _isProtectedRow(tableName, row):
continue
if row.get("id") and str(row["id"]) in protectedIdSet:
continue
filteredRows.append(row)
if mode == "replace":
_deleteNonProtected(conn, tableName, protectedIdSet)
insertedCount = _insertRows(conn, tableName, filteredRows, physicalCols, mode)
dbResult[tableName] = insertedCount
conn.commit()
except Exception as e:
conn.rollback()
logger.error("Import failed for database %s: %s", dbName, e)
warnings.append(f"Import fuer '{dbName}' fehlgeschlagen: {e}")
continue
finally:
conn.close()
imported[dbName] = dbResult
totalRecords = sum(sum(v.values()) for v in imported.values())
return {
"success": True,
"imported": imported,
"totalRecords": totalRecords,
"warnings": warnings,
}
def _getPhysicalColumns(conn, tableName: str) -> List[str]:
    """Return the column names of ``public.<tableName>`` in ordinal (physical) order.

    An empty list means the table has no columns or does not exist. Rows are
    read by column name, so mapping-like cursor rows are assumed (TODO confirm
    the cursor factory used by ``_getConnection``).
    """
    columnQuery = """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = 'public' AND table_name = %s
        ORDER BY ordinal_position
    """
    with conn.cursor() as cur:
        cur.execute(columnQuery, (tableName,))
        records = cur.fetchall()
    return [rec["column_name"] for rec in records]
def _deleteNonProtected(conn, tableName: str, protectedIds: Set[str]) -> int:
    """Delete all rows of *tableName* except the protected system objects.

    Rows whose ``id`` (compared as text) is in *protectedIds* are kept. When
    *protectedIds* is empty, the table is emptied completely, and no ``id``
    column is needed in that case. Does not commit.

    Returns:
        The cursor's ``rowcount`` for the DELETE.
    """
    with conn.cursor() as cur:
        if protectedIds:
            cur.execute(
                f'DELETE FROM "{tableName}" WHERE "id"::text != ALL(%(ids)s)',
                {"ids": list(protectedIds)},
            )
        else:
            cur.execute(f'DELETE FROM "{tableName}"')
        return cur.rowcount
def _insertRows(
conn,
tableName: str,
rows: List[dict],
physicalCols: List[str],
mode: str,
) -> int:
"""Insert rows into a table. In merge mode, skip rows whose id already exists."""
if not rows:
return 0
physicalColSet = set(physicalCols)
inserted = 0
for row in rows:
cols = [c for c in row.keys() if c in physicalColSet]
if not cols:
continue
values = [_pgSafe(row[c]) for c in cols]
colNames = ", ".join(f'"{c}"' for c in cols)
placeholders = ", ".join(["%s"] * len(cols))
if mode == "merge":
sql = f'INSERT INTO "{tableName}" ({colNames}) VALUES ({placeholders}) ON CONFLICT ("id") DO NOTHING'
else:
sql = f'INSERT INTO "{tableName}" ({colNames}) VALUES ({placeholders})'
try:
with conn.cursor() as cur:
cur.execute("SAVEPOINT row_sp")
cur.execute(sql, values)
inserted += cur.rowcount
cur.execute("RELEASE SAVEPOINT row_sp")
except Exception as e:
logger.warning("Insert failed for %s row: %s", tableName, e)
with conn.cursor() as cur:
cur.execute("ROLLBACK TO SAVEPOINT row_sp")
return inserted
def _pgSafe(v: Any) -> Any:
"""Convert Python values to psycopg2-compatible types."""
import json as _json
if v is None or isinstance(v, (str, int, float, bool)):
return v
if isinstance(v, (dict, list)):
return _json.dumps(v)
return str(v)
# ---------------------------------------------------------------------------
# Prepare import (validate + remap, return context for per-DB import)
# ---------------------------------------------------------------------------
def _prepareImport(payload: dict) -> dict:
    """Validate a parsed payload, remap system-object IDs, and list the importable databases.

    The result carries the metadata the frontend needs to drive the
    per-database import.

    ``payload`` itself is remapped **in place**: old system-object IDs are
    replaced by the live IDs. No separate remapped copy is returned, so the
    remapped data is available only through this same payload object.

    Database entries that are not registered are left out of ``databases``,
    as are entries that are not an object containing a ``tables`` object.

    Returns:
        ``{valid, warnings, systemObjectsFound, databases, protectedIds}``.
        ``databases`` is a list of ``{database, tableCount, recordCount}``;
        ``protectedIds`` holds the live IDs of the protected system objects.
    """
    validation = _validateImportPayload(payload)
    if not validation.get("valid"):
        return {
            "valid": False,
            "warnings": validation.get("warnings", []),
            "systemObjectsFound": validation.get("systemObjectsFound", []),
            "databases": [],
            "protectedIds": [],
        }
    liveIds = _loadLiveSystemObjectIds()
    remap = _buildIdRemapFromPayload(payload, liveIds)
    if remap:
        logger.info("System-object ID remap: %s", remap)
        _remapSystemObjectIds(payload, remap)
    protectedIdSet = set(liveIds.values())
    registeredDbs = getRegisteredDatabases()
    dbList = []
    for dbName, dbData in payload.get("databases", {}).items():
        if dbName not in registeredDbs:
            continue
        tables = dbData.get("tables", {}) if isinstance(dbData, dict) else None
        if not isinstance(tables, dict):
            continue
        dbList.append({
            "database": dbName,
            "tableCount": len(tables),
            "recordCount": sum(len(rows) for rows in tables.values() if isinstance(rows, list)),
        })
    return {
        "valid": True,
        "warnings": validation.get("warnings", []),
        "systemObjectsFound": validation.get("systemObjectsFound", []),
        "databases": dbList,
        "protectedIds": list(protectedIdSet),
    }
def _ensureDatabaseExists(dbName: str) -> bool:
    """Create the PostgreSQL database *dbName* if it does not exist yet.

    Connection settings come from ``APP_CONFIG`` under the database's
    registered config prefix: ``<PREFIX>_HOST`` (default ``localhost``),
    ``<PREFIX>_PORT`` (default 5432), ``<PREFIX>_USER`` and
    ``<PREFIX>_PASSWORD_SECRET``. For the prefix ``DB`` these are ``DB_HOST``
    and so on. The function connects to the ``postgres`` maintenance database
    in autocommit mode, because PostgreSQL does not allow ``CREATE DATABASE``
    inside a transaction block.

    Returns:
        True if the database was created. False if it already existed or
        *dbName* is not registered.
    """
    configPrefix = getRegisteredDatabases().get(dbName)
    if configPrefix is None:
        return False
    adminConn = psycopg2.connect(
        host=APP_CONFIG.get(f"{configPrefix}_HOST", "localhost"),
        port=int(APP_CONFIG.get(f"{configPrefix}_PORT", 5432)),
        database="postgres",
        user=APP_CONFIG.get(f"{configPrefix}_USER"),
        password=APP_CONFIG.get(f"{configPrefix}_PASSWORD_SECRET"),
        client_encoding="utf8",
    )
    try:
        adminConn.autocommit = True
        with adminConn.cursor() as cur:
            cur.execute("SELECT 1 FROM pg_database WHERE datname = %s", (dbName,))
            alreadyExists = cur.fetchone() is not None
            if not alreadyExists:
                cur.execute(f'CREATE DATABASE "{dbName}"')
                logger.info("Created missing database: %s", dbName)
        return not alreadyExists
    finally:
        adminConn.close()
def _createTableFromExport(conn, tableName: str, rows: List[dict]) -> None:
"""Create a table based on the column structure found in the export data.
Uses TEXT for all columns since we don't have the original DDL.
The ``id`` column gets a PRIMARY KEY constraint.
"""
allKeys: List[str] = []
seen: set = set()
for row in rows:
for k in row.keys():
if k not in seen:
allKeys.append(k)
seen.add(k)
if not allKeys:
return
colDefs = []
for col in allKeys:
if col == "id":
colDefs.append(f'"{col}" TEXT PRIMARY KEY')
else:
colDefs.append(f'"{col}" TEXT')
ddl = f'CREATE TABLE IF NOT EXISTS "{tableName}" ({", ".join(colDefs)})'
with conn.cursor() as cur:
cur.execute(ddl)
logger.info("Created table %s with %d columns", tableName, len(allKeys))
def _getTableImportOrder(conn, tableNames: List[str], dbName: str = "") -> List[str]:
    """Order *tableNames* so that FK parent tables come before their children.

    Dependencies come from the Pydantic ``fk_target`` metadata returned by
    ``getFkRelationships()``, so this also works for databases without SQL
    foreign keys. A relationship is used only when source and target both
    live in *dbName* and both appear in *tableNames*; self references are
    ignored. Cross-DB references, such as those to ``poweron_app.Mandate``,
    rely on the databases being imported in the right order.

    The ordering uses Kahn's algorithm. Among tables that become ready at the
    same time, the alphabetically smallest comes first, which makes the order
    deterministic. Tables caught in a dependency cycle are appended at the
    end in alphabetical order, with a warning.

    ``conn`` is not used; it is kept for interface compatibility.
    """
    tableSet = set(tableNames)
    parentsOf: Dict[str, Set[str]] = {name: set() for name in tableNames}
    childrenOf: Dict[str, Set[str]] = {}
    for rel in getFkRelationships():
        if not (rel.sourceDb == dbName and rel.targetDb == dbName):
            continue
        child, parent = rel.sourceTable, rel.targetTable
        if child == parent or child not in tableSet or parent not in tableSet:
            continue
        parentsOf[child].add(parent)
        childrenOf.setdefault(parent, set()).add(child)

    pending = {name: len(parentsOf[name]) for name in tableNames}
    # Kept sorted in descending order, so pop() always yields the smallest name.
    ready = sorted((name for name in tableNames if pending[name] == 0), reverse=True)
    ordered: List[str] = []
    while ready:
        node = ready.pop()
        ordered.append(node)
        becameReady = False
        for child in childrenOf.get(node, ()):
            if node in parentsOf[child]:
                parentsOf[child].discard(node)
                pending[child] -= 1
                if pending[child] == 0:
                    ready.append(child)
                    becameReady = True
        if becameReady:
            ready.sort(reverse=True)

    placed = set(ordered)
    leftovers = [name for name in tableNames if name not in placed]
    if leftovers:
        logger.warning("FK cycle detected, appending without order guarantee: %s", leftovers)
        ordered.extend(sorted(leftovers))
    return ordered
def _importSingleDb(payload: dict, dbName: str, mode: str, protectedIds: List[str]) -> dict:
"""Import a single database from the (already remapped) payload.
Tables are sorted by FK dependencies: parent tables are inserted first,
child tables are deleted first (reverse order) in replace mode.
Returns ``{database, tables: {tableName: insertedCount}, recordCount, warnings}``.
"""
if mode not in ("replace", "merge"):
raise ValueError(f"Invalid import mode: {mode}")
registeredDbs = getRegisteredDatabases()
if dbName not in registeredDbs:
return {"database": dbName, "tables": {}, "recordCount": 0,
"warnings": [f"Datenbank '{dbName}' nicht registriert"]}
dbData = payload.get("databases", {}).get(dbName)
if not dbData:
return {"database": dbName, "tables": {}, "recordCount": 0,
"warnings": [f"Keine Daten fuer '{dbName}' im Payload"]}
try:
dbCreated = _ensureDatabaseExists(dbName)
except Exception as e:
logger.error("Failed to ensure database %s exists: %s", dbName, e)
return {"database": dbName, "tables": {}, "recordCount": 0,
"warnings": [f"Datenbank '{dbName}' konnte nicht erstellt werden: {e}"]}
protectedIdSet = set(protectedIds)
tables = dbData.get("tables", {})
warnings: List[str] = []
dbResult: Dict[str, int] = {}
excluded = _EXCLUDED_TABLES.get(dbName, set())
if dbCreated:
warnings.append(f"Datenbank '{dbName}' wurde neu erstellt")
conn = _getConnection(dbName)
try:
existingTables = set(_listTables(conn))
conn.rollback()
# Ensure all import tables exist (create missing ones from export schema)
conn.autocommit = True
for tableName, rows in tables.items():
if tableName in excluded or not isinstance(rows, list) or not rows:
continue
if tableName not in existingTables:
_createTableFromExport(conn, tableName, rows)
existingTables.add(tableName)
logger.info("Pre-created missing table %s.%s", dbName, tableName)
# Build importable table list and sort by FK dependencies
importable = [t for t in tables
if t not in excluded
and isinstance(tables.get(t), list)
and t in existingTables]
importOrder = _getTableImportOrder(conn, importable, dbName)
logger.info("Import order for %s: %s", dbName, importOrder)
for tableName in tables:
if tableName in excluded and isinstance(tables.get(tableName), list):
warnings.append(f"Table '{dbName}.{tableName}' excluded (security/transient)")
# Phase 1 (replace only): DELETE children first (reverse topological order)
if mode == "replace":
conn.autocommit = False
for tableName in reversed(importOrder):
try:
_deleteNonProtected(conn, tableName, protectedIdSet)
conn.commit()
except Exception as e:
conn.rollback()
warnings.append(f"DELETE from {dbName}.{tableName} failed: {e}")
logger.warning("DELETE from %s.%s failed: %s", dbName, tableName, e)
# Phase 2: INSERT parents first (topological order)
conn.autocommit = False
for tableName in importOrder:
try:
rows = tables[tableName]
physicalCols = _getPhysicalColumns(conn, tableName)
if not physicalCols:
conn.rollback()
continue
filteredRows = []
for row in rows:
if _isProtectedRow(tableName, row):
continue
if row.get("id") and str(row["id"]) in protectedIdSet:
continue
filteredRows.append(row)
insertedCount = _insertRows(conn, tableName, filteredRows, physicalCols, mode)
conn.commit()
dbResult[tableName] = insertedCount
except Exception as e:
conn.rollback()
warnings.append(f"INSERT into {dbName}.{tableName} failed: {e}")
logger.warning("INSERT into %s.%s failed: %s", dbName, tableName, e)
except Exception as e:
logger.error("Import failed for database %s: %s", dbName, e)
return {"database": dbName, "tables": {}, "recordCount": 0,
"warnings": [f"Import fuer '{dbName}' fehlgeschlagen: {e}"]}
finally:
conn.close()
recordCount = sum(dbResult.values())
return {"database": dbName, "tables": dbResult, "recordCount": recordCount, "warnings": warnings}
# ---------------------------------------------------------------------------
# Streaming Import (memory-safe, ijson-based)
# ---------------------------------------------------------------------------
def _iterStreamRows(filePath: str):
    """Yield ``(dbName, tableName, rowDict)`` one row at a time from an export file.

    Uses the ``ijson`` event parser, so the file is never loaded as a whole.
    Only the row currently being assembled is held in memory.

    A row is any JSON object that is an item of an array at
    ``databases.<db>.tables.<tbl>``. Everything else is skipped: ``meta``,
    ``summary``, the per-database counters, and array items that are not
    objects. Numbers are parsed with ``use_float=True``, so non-integer
    numbers become ``float`` instead of ``Decimal``.

    Args:
        filePath: Path of the export JSON file (opened in binary mode).

    Yields:
        Tuple of database name, table name and the fully built row dict.
    """
    import ijson
    with open(filePath, "rb") as f:
        currentDb: Optional[str] = None
        currentTable: Optional[str] = None
        # ijson prefix of the array items of the current table, e.g.
        # "databases.<db>.tables.<tbl>.item".
        rowPrefix: Optional[str] = None
        # True while the events of one row object are being consumed.
        inRow = False
        # Stack of (container, pendingKey). pendingKey is the map key that is
        # waiting for its value when container is a dict; None otherwise.
        stack: List[Tuple[Any, Optional[str]]] = []
        root: dict = {}
        for prefix, event, value in ijson.parse(f, use_float=True):
            if not inRow:
                # Outside a row: only track where we are in the document.
                if prefix == "databases" and event == "map_key":
                    currentDb = value
                    currentTable = None
                    rowPrefix = None
                elif (currentDb
                      and prefix == f"databases.{currentDb}.tables"
                      and event == "map_key"):
                    currentTable = value
                    rowPrefix = f"databases.{currentDb}.tables.{currentTable}.item"
                elif rowPrefix and prefix == rowPrefix and event == "start_map":
                    # Start of a new row object; its events are handled below.
                    inRow = True
                    root = {}
                    stack = [(root, None)]
                continue
            # Inside a row: rebuild the nested value from the event stream.
            container, pendingKey = stack[-1]
            if event == "map_key":
                stack[-1] = (container, value)
            elif event in ("string", "number", "boolean", "null"):
                if isinstance(container, dict):
                    container[pendingKey] = value
                    stack[-1] = (container, None)
                else:
                    container.append(value)
            elif event == "start_map":
                nested: dict = {}
                if isinstance(container, dict):
                    container[pendingKey] = nested
                    stack[-1] = (container, None)
                else:
                    container.append(nested)
                stack.append((nested, None))
            elif event == "end_map":
                stack.pop()
                # Popping the root object completes the row.
                if not stack:
                    yield (currentDb, currentTable, root)
                    inRow = False
            elif event == "start_array":
                nested_list: list = []
                if isinstance(container, dict):
                    container[pendingKey] = nested_list
                    stack[-1] = (container, None)
                else:
                    container.append(nested_list)
                stack.append((nested_list, None))
            elif event == "end_array":
                stack.pop()
def _streamValidate(filePath: str) -> dict:
    """Stream-validate an import file without loading it into memory.

    In one pass it does the following:

    * reads ``meta`` and checks its presence and format version, as
      ``_validateImportPayload`` does;
    * counts databases, tables and rows;
    * detects the protected system objects (root mandate, admin and event
      user) in ``poweron_app``;
    * builds the old-to-live ID remap for those objects.

    Only one row is held in memory at a time. Tables are counted from their
    rows, so tables with zero rows do not appear in the counts. Rows without
    an id produce no remap entry.

    Args:
        filePath: Path of the uploaded export JSON file.

    Returns:
        ``{valid, summary, warnings, systemObjectsFound, databases,
        protectedIds, remap}``. This is the shape of ``_prepareImport`` plus
        ``summary`` and ``remap``. ``valid`` is True when at least one
        database in the file is registered.
    """
    import ijson

    warnings: List[str] = []
    meta = None
    with open(filePath, "rb") as f:
        # ijson.items finds ``meta`` wherever it appears; the first one is used.
        for m in ijson.items(f, "meta"):
            meta = m
            break
    if not meta or not isinstance(meta, dict):
        warnings.append("Fehlende oder ungueltige 'meta'-Sektion")
    else:
        version = meta.get("version", "")
        if version != _EXPORT_FORMAT_VERSION:
            warnings.append(f"Unbekannte Format-Version: {version} (erwartet: {_EXPORT_FORMAT_VERSION})")

    registeredDbs = getRegisteredDatabases()
    dbTableCounts: Dict[str, Dict[str, int]] = {}
    systemObjectsFound: List[dict] = []
    for dbName, tableName, row in _iterStreamRows(filePath):
        tableCounts = dbTableCounts.setdefault(dbName, {})
        tableCounts[tableName] = tableCounts.get(tableName, 0) + 1
        if dbName != "poweron_app":
            continue
        if tableName == "Mandate":
            if row.get("name") == "root" and row.get("isSystem") is True:
                systemObjectsFound.append({
                    "type": "mandate",
                    "label": "Root Mandate",
                    "payloadId": row.get("id"),
                })
        elif tableName == "UserInDB":
            uname = row.get("username")
            if uname == "admin":
                systemObjectsFound.append({
                    "type": "user",
                    "label": "Admin User",
                    "payloadId": row.get("id"),
                })
            elif uname == "event":
                systemObjectsFound.append({
                    "type": "user",
                    "label": "Event User",
                    "payloadId": row.get("id"),
                })

    if not dbTableCounts:
        warnings.append("Fehlende oder ungueltige 'databases'-Sektion")

    summary: List[dict] = []
    for dbName, tables in dbTableCounts.items():
        registered = dbName in registeredDbs
        if not registered:
            warnings.append(f"Datenbank '{dbName}' ist nicht registriert und wird uebersprungen")
        summary.append({
            "database": dbName,
            "tableCount": len(tables),
            "recordCount": sum(tables.values()),
            "registered": registered,
        })

    liveIds = _loadLiveSystemObjectIds()
    remap: Dict[str, str] = {}
    for obj in systemObjectsFound:
        payloadId = obj.get("payloadId")
        if payloadId is None:
            # str(None) would create a bogus "None" -> liveId mapping.
            continue
        oldId = str(payloadId)
        if obj["type"] == "mandate":
            newId = liveIds.get("rootMandate", "")
        elif obj["label"] == "Admin User":
            newId = liveIds.get("adminUser", "")
        elif obj["label"] == "Event User":
            newId = liveIds.get("eventUser", "")
        else:
            continue
        if oldId and newId and oldId != newId:
            remap[oldId] = newId
    if remap:
        logger.info("System-object ID remap: %s", remap)

    protectedIdSet = set(liveIds.values())
    dbList = [
        {
            "database": s["database"],
            "tableCount": s["tableCount"],
            "recordCount": s["recordCount"],
        }
        for s in summary if s["registered"]
    ]
    return {
        "valid": any(s["registered"] for s in summary),
        "summary": summary,
        "warnings": warnings,
        "systemObjectsFound": systemObjectsFound,
        "databases": dbList,
        "protectedIds": list(protectedIdSet),
        "remap": remap,
    }
def _streamSplitToFiles(
    filePath: str,
    tmpDir: str,
    token: str,
    remap: Dict[str, str],
) -> Dict[str, Dict[str, str]]:
    """Split an export file into one JSONL temp file per (database, table).

    Streams through the export file a second time via ``_iterStreamRows``.
    When *remap* is non-empty, it applies the system-object ID remap to each
    row; ``_remapRowValues`` presumably mutates the row in place, since its
    return value is not used. Each row is then written as one JSON line to a
    per-table temp file in *tmpDir*.

    Temp file names start with ``poweron_import_{token}_``, followed by a
    running index and the sanitized database/table names.

    - The index keeps names unique, even when two (db, table) pairs sanitize
      to the same text or differ only in where ``__`` falls.
    - Sanitizing stops database/table names read from the uploaded file from
      injecting path separators (e.g. ``../``) or exceeding filename length
      limits.

    If splitting fails part-way, the temp files created by this call are
    removed before the exception propagates.

    Args:
        filePath: Path of the export file to read.
        tmpDir: Directory in which the per-table files are created.
        token: Import token embedded in every temp file name.
        remap: Mapping ``oldId -> newId`` for protected system objects; may
            be empty.

    Returns:
        ``{dbName: {tableName: filePath}}`` for every table that yielded at
        least one row.
    """
    import os
    import re

    def _safeNamePart(value: Any) -> str:
        # Keep only filename-safe characters and cap the length so untrusted
        # names cannot contain path separators or overflow NAME_MAX.
        cleaned = re.sub(r"[^A-Za-z0-9_-]", "_", str(value))[:64]
        return cleaned or "_"

    remapKeys = set(remap) if remap else set()
    dbFiles: Dict[str, Dict[str, str]] = {}
    writers: Dict[Tuple[str, str], Any] = {}
    succeeded = False
    try:
        for dbName, tableName, row in _iterStreamRows(filePath):
            if remap:
                _remapRowValues(row, remap, remapKeys)
            key = (dbName, tableName)
            fh = writers.get(key)
            if fh is None:
                # len(writers) is a per-call running index that guarantees
                # a distinct file per (db, table) pair.
                tblPath = os.path.join(
                    tmpDir,
                    f"poweron_import_{token}_{len(writers):04d}_"
                    f"{_safeNamePart(dbName)}__{_safeNamePart(tableName)}.jsonl",
                )
                fh = open(tblPath, "w", encoding="utf-8")
                writers[key] = fh
                dbFiles.setdefault(dbName, {})[tableName] = tblPath
            fh.write(json.dumps(row, ensure_ascii=False, default=str) + "\n")
        succeeded = True
    finally:
        for openFh in writers.values():
            openFh.close()
        if not succeeded:
            # Do not leak half-written temp files when the split aborts.
            for tables in dbFiles.values():
                for partialPath in tables.values():
                    try:
                        os.remove(partialPath)
                    except OSError as e:
                        logger.warning("Could not remove temp file %s: %s", partialPath, e)
    return dbFiles
def _parseJsonlLines(lines: Any) -> Any:
    """Yield one decoded JSON object per non-blank line of *lines*.

    *lines* is any iterable of strings (typically an open text file). Blank
    lines are skipped; a malformed line raises ``json.JSONDecodeError``.
    """
    for rawLine in lines:
        stripped = rawLine.strip()
        if stripped:
            yield json.loads(stripped)


def _importSingleDbFromFiles(
    tableFiles: Dict[str, str],
    dbName: str,
    mode: str,
    protectedIds: List[str],
) -> dict:
    """Import a single database from per-table JSONL files.

    Each file in *tableFiles* contains one JSON object (row) per line.

    Steps:

    - Ensure the database exists.
    - Pre-create missing tables from a sample of up to 10 rows.
    - Order the importable tables with ``_getTableImportOrder``.
    - In ``replace`` mode, delete non-protected rows in reverse import order.
    - Insert rows in batches of 100.

    Each table's delete and insert runs in its own transaction, so one failing
    table only produces a warning. Rows matched by ``_isProtectedRow``, or
    whose ``id`` is in *protectedIds*, are never inserted.

    Args:
        tableFiles: ``{tableName: jsonlFilePath}`` for this database.
        dbName: Name of the registered target database.
        mode: ``"replace"`` or ``"merge"``.
        protectedIds: IDs of live system objects that must not be touched.

    Returns:
        ``{database, tables, recordCount, warnings}``. ``tables`` maps each
        successfully imported table to its inserted-row count. On a fatal
        error, the counts and warnings gathered up to that point are kept and
        the error is appended to ``warnings``.

    Raises:
        ValueError: If *mode* is not ``"replace"`` or ``"merge"``.
    """
    import itertools

    if mode not in ("replace", "merge"):
        raise ValueError(f"Invalid import mode: {mode}")
    registeredDbs = getRegisteredDatabases()
    if dbName not in registeredDbs:
        return {"database": dbName, "tables": {}, "recordCount": 0,
                "warnings": [f"Datenbank '{dbName}' nicht registriert"]}
    if not tableFiles:
        return {"database": dbName, "tables": {}, "recordCount": 0,
                "warnings": [f"Keine Daten fuer '{dbName}'"]}
    try:
        dbCreated = _ensureDatabaseExists(dbName)
    except Exception as e:
        logger.error("Failed to ensure database %s exists: %s", dbName, e)
        return {"database": dbName, "tables": {}, "recordCount": 0,
                "warnings": [f"Datenbank '{dbName}' konnte nicht erstellt werden: {e}"]}

    protectedIdSet = set(protectedIds)
    warnings: List[str] = []
    dbResult: Dict[str, int] = {}
    excluded = _EXCLUDED_TABLES.get(dbName, set())
    if dbCreated:
        warnings.append(f"Datenbank '{dbName}' wurde neu erstellt")

    # Report a connection failure like every other failure path instead of
    # letting the exception escape.
    try:
        conn = _getConnection(dbName)
    except Exception as e:
        logger.error("Failed to connect to database %s: %s", dbName, e)
        warnings.append(f"Verbindung zu '{dbName}' fehlgeschlagen: {e}")
        return {"database": dbName, "tables": {}, "recordCount": 0, "warnings": warnings}

    try:
        existingTables = set(_listTables(conn))
        # End the implicit transaction opened by _listTables before switching
        # autocommit (psycopg2 refuses to change it inside a transaction).
        conn.rollback()
        conn.autocommit = True

        # Pre-create tables that exist in the backup but not in the live DB.
        # A failure on one table is reported and does not abort the others.
        for tableName, tblPath in tableFiles.items():
            if tableName in excluded or tableName in existingTables:
                continue
            try:
                with open(tblPath, "r", encoding="utf-8") as fh:
                    sampleRows = list(itertools.islice(_parseJsonlLines(fh), 10))
                if not sampleRows:
                    continue
                _createTableFromExport(conn, tableName, sampleRows)
            except Exception as e:
                warnings.append(f"CREATE TABLE {dbName}.{tableName} failed: {e}")
                logger.warning("CREATE TABLE %s.%s failed: %s", dbName, tableName, e)
                continue
            existingTables.add(tableName)
            logger.info("Pre-created missing table %s.%s", dbName, tableName)

        importable = [t for t in tableFiles if t not in excluded and t in existingTables]
        importOrder = _getTableImportOrder(conn, importable, dbName)
        logger.info("Import order for %s: %s", dbName, importOrder)
        warnings.extend(
            f"Table '{dbName}.{tableName}' excluded (security/transient)"
            for tableName in tableFiles
            if tableName in excluded
        )

        if mode == "replace":
            conn.autocommit = False
            # Children before parents, so FK constraints do not block deletes.
            for tableName in reversed(importOrder):
                try:
                    _deleteNonProtected(conn, tableName, protectedIdSet)
                    conn.commit()
                except Exception as e:
                    conn.rollback()
                    warnings.append(f"DELETE from {dbName}.{tableName} failed: {e}")
                    logger.warning("DELETE from %s.%s failed: %s", dbName, tableName, e)

        conn.autocommit = False
        batchSize = 100
        for tableName in importOrder:
            tblPath = tableFiles.get(tableName)
            if not tblPath:
                continue
            try:
                physicalCols = _getPhysicalColumns(conn, tableName)
                if not physicalCols:
                    conn.rollback()
                    continue
                insertedCount = 0
                batch: List[dict] = []
                with open(tblPath, "r", encoding="utf-8") as fh:
                    for row in _parseJsonlLines(fh):
                        if _isProtectedRow(tableName, row):
                            continue
                        if row.get("id") and str(row["id"]) in protectedIdSet:
                            continue
                        batch.append(row)
                        if len(batch) >= batchSize:
                            insertedCount += _insertRows(conn, tableName, batch, physicalCols, mode)
                            batch = []
                if batch:
                    insertedCount += _insertRows(conn, tableName, batch, physicalCols, mode)
                # One commit per table: a failure rolls back only this table.
                conn.commit()
                dbResult[tableName] = insertedCount
            except Exception as e:
                conn.rollback()
                warnings.append(f"INSERT into {dbName}.{tableName} failed: {e}")
                logger.warning("INSERT into %s.%s failed: %s", dbName, tableName, e)
    except Exception as e:
        # Keep what was already committed and warned about; data may already
        # have been deleted/inserted, so an empty report would be misleading.
        logger.error("Import failed for database %s: %s", dbName, e)
        warnings.append(f"Import fuer '{dbName}' fehlgeschlagen: {e}")
    finally:
        conn.close()

    return {
        "database": dbName,
        "tables": dbResult,
        "recordCount": sum(dbResult.values()),
        "warnings": warnings,
    }