790 lines
28 KiB
Python
790 lines
28 KiB
Python
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
|
|
"""
|
|
Database migration utilities — backup (export) and restore (import) for all
|
|
registered PowerOn databases.
|
|
|
|
System objects (root mandate, admin user, event user) are protected: they are
|
|
never deleted or overwritten during import. Their IDs in the backup payload
|
|
are remapped to the IDs of the corresponding live objects so that all FK
|
|
references stay consistent.
|
|
|
|
All functions are intended for SysAdmin use only (access control in the route layer).
|
|
"""
|
|
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Dict, List, Optional, Set, Tuple
|
|
|
|
import psycopg2
|
|
import psycopg2.extras
|
|
|
|
from modules.shared.configuration import APP_CONFIG
|
|
from modules.shared.dbRegistry import getRegisteredDatabases
|
|
from modules.system.databaseHealth import _getConnection, _jsonSafe
|
|
|
|
logger = logging.getLogger(__name__)

# Format tag written into the ``meta`` section of every export and compared
# against on import validation.
_EXPORT_FORMAT_VERSION = "1.0"

# Table name filtered out of the table listing used by export/import
# (presumably internal bookkeeping).
_SYSTEM_TABLE = "_system"

# Per-database tables that are never exported and are skipped on import
# (security-sensitive / transient data).
_EXCLUDED_TABLES: Dict[str, Set[str]] = dict(
    poweron_app={"Token", "AuthEvent"},
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Instance label
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _getInstanceLabel() -> str:
    """Report which kind of deployment this instance is.

    Reads ``APP_ENV_TYPE`` from the application configuration (e.g.
    ``'dev'``, ``'int'``, ``'prod'``); falls back to ``'unknown'`` when the
    key is not configured.
    """
    envType = APP_CONFIG.get("APP_ENV_TYPE", "unknown")
    return envType
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Database list
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _getAvailableDatabases() -> List[dict]:
    """List every registered database with table/row counts for the UI.

    ``poweron_test`` is hidden. Counts come from ``pg_stat_user_tables``
    (``n_live_tup`` is a statistics value, not an exact ``COUNT(*)``) for
    ``public`` tables whose names do not start with an underscore. A
    database that cannot be queried is still listed (with the counts
    gathered so far) and a warning is logged.

    Returns a list of ``{name, tableCount, recordCount}`` dicts sorted by
    database name.
    """
    statsQuery = """
        SELECT relname, n_live_tup
        FROM pg_stat_user_tables
        WHERE schemaname = 'public'
          AND relname NOT LIKE '\\_%%'
    """
    visibleDbs = [name for name in sorted(getRegisteredDatabases()) if name != "poweron_test"]

    overview: List[dict] = []
    for dbName in visibleDbs:
        entry: dict = {"name": dbName, "tableCount": 0, "recordCount": 0}
        try:
            conn = _getConnection(dbName)
            try:
                with conn.cursor() as cur:
                    cur.execute(statsQuery)
                    for stat in cur.fetchall():
                        entry["tableCount"] += 1
                        entry["recordCount"] += int(stat["n_live_tup"])
            finally:
                conn.close()
        except Exception as e:
            # Best effort: one unreachable database must not hide the others.
            logger.warning("Could not stat database %s: %s", dbName, e)
        overview.append(entry)
    return overview
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Export
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _exportDatabases(databases: List[str]) -> dict:
    """Export the selected databases as one JSON-serialisable backup dict.

    Args:
        databases: Names of the databases to export. Unregistered names are
            skipped with a log warning. Duplicate names are exported once;
            previously a repeated name was exported twice and its totals
            counted twice.

    Returns:
        ``{meta: {...}, databases: {dbName: {tables: {tbl: [rows]}, summary: {...},
        tableCount, totalRecords}}}``. ``meta.failedDatabases`` lists the
        databases whose export raised. This makes an incomplete backup visible
        in the file itself, not only in the server log.

    Raises:
        ValueError: If ``databases`` is empty.
    """
    if not databases:
        raise ValueError("No databases selected for export.")

    registeredDbs = getRegisteredDatabases()

    meta: dict = {
        "exportedAt": datetime.now(timezone.utc).isoformat(),
        "version": _EXPORT_FORMAT_VERSION,
        "databaseCount": 0,
        "totalTables": 0,
        "totalRecords": 0,
        "failedDatabases": [],
    }
    exportData: dict = {"meta": meta, "databases": {}}

    # dict.fromkeys removes duplicates while keeping the caller's order.
    for dbName in dict.fromkeys(databases):
        if dbName not in registeredDbs:
            logger.warning("Export: skipping unregistered database %s", dbName)
            continue
        try:
            dbPayload = _exportSingleDb(dbName)
        except Exception as e:
            # Keep exporting the remaining databases, but record the gap.
            logger.exception("Export failed for database %s: %s", dbName, e)
            meta["failedDatabases"].append(dbName)
            continue

        exportData["databases"][dbName] = dbPayload
        meta["databaseCount"] += 1
        meta["totalTables"] += dbPayload["tableCount"]
        meta["totalRecords"] += dbPayload["totalRecords"]

    return exportData
|
|
|
|
|
|
def _exportSingleDb(dbName: str) -> dict:
    """Dump every exportable table of one database.

    Tables listed in ``_EXCLUDED_TABLES`` for ``dbName`` are skipped and
    logged. The connection is closed even if reading a table fails.

    Returns ``{tables: {tbl: [rows]}, summary: {tbl: {recordCount}},
    tableCount, totalRecords}``.
    """
    conn = _getConnection(dbName)
    skipTables = _EXCLUDED_TABLES.get(dbName, set())
    try:
        tablesOut: dict = {}
        summaryOut: dict = {}
        recordTotal = 0
        for tbl in _listTables(conn):
            if tbl in skipTables:
                logger.info("Export: skipping excluded table %s.%s", dbName, tbl)
                continue
            tableRows = _readTableRows(conn, tbl)
            tablesOut[tbl] = tableRows
            summaryOut[tbl] = {"recordCount": len(tableRows)}
            recordTotal += len(tableRows)
        return {
            "tables": tablesOut,
            "summary": summaryOut,
            "tableCount": len(tablesOut),
            "totalRecords": recordTotal,
        }
    finally:
        conn.close()
|
|
|
|
|
|
def _listTables(conn) -> List[str]:
    """Return the names of the base tables in the ``public`` schema.

    The names are sorted alphabetically. ``_SYSTEM_TABLE`` is left out.
    """
    tableQuery = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = 'public'
          AND table_type = 'BASE TABLE'
          AND table_name != %s
        ORDER BY table_name
    """
    with conn.cursor() as cur:
        cur.execute(tableQuery, (_SYSTEM_TABLE,))
        fetched = cur.fetchall()
    return [entry["table_name"] for entry in fetched]
|
|
|
|
|
|
def _readTableRows(conn, tableName: str) -> List[dict]:
    """Read every row of ``tableName`` as a list of JSON-safe dicts.

    Each value is converted with ``_jsonSafe``. The table name is quoted as a
    PostgreSQL identifier with embedded ``"`` doubled. A name containing a
    double quote therefore neither breaks the statement nor injects SQL.
    No parameters are passed, so ``%`` in the name needs no escaping.
    """
    quotedTable = '"' + tableName.replace('"', '""') + '"'
    with conn.cursor() as cur:
        cur.execute(f"SELECT * FROM {quotedTable}")
        return [{k: _jsonSafe(v) for k, v in dict(row).items()} for row in cur.fetchall()]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Validate
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _validateImportPayload(payload: dict) -> dict:
    """Check an import payload without writing anything.

    The payload comes from an uploaded backup file, so every level is
    type-checked. Malformed parts are reported as warnings instead of raising.
    Previously, a non-dict payload, database section or ``tables`` entry
    crashed with ``AttributeError``/``TypeError``.

    Returns ``{valid, summary, warnings, systemObjectsFound}``:

    * ``summary`` holds one ``{database, tableCount, recordCount, registered}``
      entry per well-formed database section.
    * ``valid`` is True when at least one of those sections targets a
      registered database.
    * An unexpected format version only adds a warning.
    """
    def _rejected(message: str) -> dict:
        return {"valid": False, "summary": [], "warnings": [message], "systemObjectsFound": []}

    if not isinstance(payload, dict):
        return _rejected("Ungueltiges Payload-Format (JSON-Objekt erwartet)")

    meta = payload.get("meta")
    if not meta or not isinstance(meta, dict):
        return _rejected("Fehlende oder ungueltige 'meta'-Sektion")

    warnings: List[str] = []
    version = meta.get("version", "")
    if version != _EXPORT_FORMAT_VERSION:
        warnings.append(f"Unbekannte Format-Version: {version} (erwartet: {_EXPORT_FORMAT_VERSION})")

    databases = payload.get("databases")
    if not databases or not isinstance(databases, dict):
        return _rejected("Fehlende oder ungueltige 'databases'-Sektion")

    registeredDbs = getRegisteredDatabases()
    summary: List[dict] = []
    for dbName, dbData in databases.items():
        tables = dbData.get("tables", {}) if isinstance(dbData, dict) else None
        if not isinstance(tables, dict):
            warnings.append(f"Datenbank '{dbName}' hat eine ungueltige Struktur und wird uebersprungen")
            continue

        registered = dbName in registeredDbs
        if not registered:
            warnings.append(f"Datenbank '{dbName}' ist nicht registriert und wird uebersprungen")
        summary.append({
            "database": dbName,
            "tableCount": len(tables),
            "recordCount": sum(len(rows) for rows in tables.values() if isinstance(rows, list)),
            "registered": registered,
        })

    return {
        "valid": any(entry["registered"] for entry in summary),
        "summary": summary,
        "warnings": warnings,
        "systemObjectsFound": _detectSystemObjectsInPayload(payload),
    }
|
|
|
|
|
|
def _detectSystemObjectsInPayload(payload: dict) -> List[dict]:
    """Find the protected system objects contained in a backup payload.

    The function searches ``databases.poweron_app.tables``:

    * ``Mandate`` for the root mandate (``name == "root"`` and
      ``isSystem is True``);
    * ``UserInDB`` for the ``admin`` and ``event`` users.

    Each hit is returned as ``{type, label, payloadId}``. Malformed sections
    are ignored rather than raising, because the payload is an uploaded file.
    That covers non-dict levels, non-list tables and non-dict rows.
    """
    def _section(container: Any, key: str) -> dict:
        value = container.get(key) if isinstance(container, dict) else None
        return value if isinstance(value, dict) else {}

    def _rows(tables: dict, tableName: str) -> List[dict]:
        value = tables.get(tableName)
        if not isinstance(value, list):
            return []
        return [row for row in value if isinstance(row, dict)]

    appTables = _section(_section(_section(payload, "databases"), "poweron_app"), "tables")

    found: List[dict] = []
    for row in _rows(appTables, "Mandate"):
        if row.get("name") == "root" and row.get("isSystem") is True:
            found.append({"type": "mandate", "label": "Root Mandate", "payloadId": row.get("id")})

    for row in _rows(appTables, "UserInDB"):
        username = row.get("username")
        if username == "admin":
            found.append({"type": "user", "label": "Admin User", "payloadId": row.get("id")})
        elif username == "event":
            found.append({"type": "user", "label": "Event User", "payloadId": row.get("id")})

    return found
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# System-object ID remapping
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _loadLiveSystemObjectIds() -> Dict[str, str]:
    """Look up the IDs of the three protected system objects in the live DB.

    Returns ``{"rootMandate": ..., "adminUser": ..., "eventUser": ...}`` with
    each ID converted to ``str``. A key is omitted when its object is not
    found. An empty dict is returned when ``poweron_app`` is not registered.
    """
    if "poweron_app" not in getRegisteredDatabases():
        return {}

    # (result key, lookup query) pairs, run in this order on one cursor.
    lookups = (
        ("rootMandate", """SELECT id FROM "Mandate" WHERE "name" = 'root' AND "isSystem" = true LIMIT 1"""),
        ("adminUser", """SELECT id FROM "UserInDB" WHERE "username" = 'admin' LIMIT 1"""),
        ("eventUser", """SELECT id FROM "UserInDB" WHERE "username" = 'event' LIMIT 1"""),
    )

    found: Dict[str, str] = {}
    conn = _getConnection("poweron_app")
    try:
        with conn.cursor() as cur:
            for resultKey, query in lookups:
                cur.execute(query)
                hit = cur.fetchone()
                if hit:
                    found[resultKey] = str(hit["id"])
    finally:
        conn.close()
    return found
|
|
|
|
|
|
def _buildIdRemapFromPayload(payload: dict, liveIds: Dict[str, str]) -> Dict[str, str]:
    """Build an ``{oldId: newId}`` mapping for the protected system objects.

    The root mandate and the ``admin`` / ``event`` users found in the
    payload's ``poweron_app`` tables are paired with the live IDs in
    ``liveIds`` (keys ``rootMandate``, ``adminUser``, ``eventUser``). A pair
    is included only when both IDs are known and they differ.

    A payload row whose ``id`` is missing or ``None`` is ignored. Previously
    ``str(None)`` produced the key ``"None"``, which then rewrote every
    literal ``"None"`` string in the payload to a live ID. Malformed
    sections and non-dict rows are skipped.
    """
    def _section(container: Any, key: str) -> dict:
        value = container.get(key) if isinstance(container, dict) else None
        return value if isinstance(value, dict) else {}

    def _rows(tables: dict, tableName: str) -> List[dict]:
        value = tables.get(tableName)
        if not isinstance(value, list):
            return []
        return [row for row in value if isinstance(row, dict)]

    def _idOf(row: dict) -> str:
        rawId = row.get("id")
        return "" if rawId is None else str(rawId)

    remap: Dict[str, str] = {}

    def _add(oldId: str, newId: str) -> None:
        if oldId and newId and oldId != newId:
            remap[oldId] = newId

    appTables = _section(_section(_section(payload, "databases"), "poweron_app"), "tables")

    for row in _rows(appTables, "Mandate"):
        if row.get("name") == "root" and row.get("isSystem") is True:
            _add(_idOf(row), liveIds.get("rootMandate", ""))

    for row in _rows(appTables, "UserInDB"):
        username = row.get("username")
        if username == "admin":
            liveKey = "adminUser"
        elif username == "event":
            liveKey = "eventUser"
        else:
            continue
        _add(_idOf(row), liveIds.get(liveKey, ""))

    return remap
|
|
|
|
|
|
def _remapSystemObjectIds(payload: dict, remap: Dict[str, str]) -> dict:
    """Rewrite old system-object IDs to their live IDs throughout a payload.

    Every dict row of every list-valued table under
    ``payload["databases"][*]["tables"]`` is updated in place via
    ``_remapRowValues``. The payload itself is returned after mutation, and
    an empty ``remap`` is a no-op.

    Malformed parts are skipped instead of raising ``AttributeError``: a
    non-dict database section, a non-dict ``tables`` entry, a non-list table
    or a non-dict row. The payload is an uploaded file and may have passed
    validation with such sections present.
    """
    if not remap:
        return payload

    remapSet = set(remap)
    databases = payload.get("databases", {})
    if not isinstance(databases, dict):
        return payload

    for dbData in databases.values():
        tables = dbData.get("tables", {}) if isinstance(dbData, dict) else None
        if not isinstance(tables, dict):
            continue
        for rows in tables.values():
            if not isinstance(rows, list):
                continue
            for row in rows:
                if isinstance(row, dict):
                    _remapRowValues(row, remap, remapSet)

    return payload
|
|
|
|
|
|
def _remapDbTables(tables: dict, remap: Dict[str, str]) -> None:
    """Apply ``remap`` in place to every row of one database's ``tables`` dict.

    Non-list table entries and non-dict rows are skipped. Previously a
    non-dict row raised ``AttributeError``. An empty ``remap`` does nothing.
    """
    if not remap:
        return
    remapSet = set(remap)
    for rows in tables.values():
        if not isinstance(rows, list):
            continue
        for row in rows:
            if isinstance(row, dict):
                _remapRowValues(row, remap, remapSet)
|
|
|
|
|
|
def _remapRowValues(row: dict, remap: Dict[str, str], remapSet: Set[str]) -> None:
    """Replace, in place, every string value in ``row`` that is a key of ``remap``.

    The function recurses through nested dicts and lists to any depth. Lists
    nested inside lists are included; the previous version skipped them.
    Only values are rewritten; dict keys are left untouched.

    Args:
        row: Row dict to update in place.
        remap: ``{oldId: newId}`` mapping.
        remapSet: ``set(remap)``, passed in so callers build it only once.
    """
    def _remapValue(value: Any) -> Any:
        if isinstance(value, str):
            return remap[value] if value in remapSet else value
        if isinstance(value, dict):
            _remapRowValues(value, remap, remapSet)
        elif isinstance(value, list):
            for index, item in enumerate(value):
                value[index] = _remapValue(item)
        return value

    # Re-assigning existing keys does not change the dict's size, so this is
    # safe while iterating.
    for key, val in row.items():
        row[key] = _remapValue(val)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Import
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Column/value patterns that identify protected system rows, keyed by table
# name. ``_isProtectedRow`` treats a row as protected when it matches every
# pair of at least one pattern.
_PROTECTED_ROWS: Dict[str, List[dict]] = {
    "Mandate": [
        {"name": "root", "isSystem": True},
    ],
    "UserInDB": [
        {"username": "admin"},
        {"username": "event"},
    ],
}
|
|
|
|
|
|
def _isProtectedRow(tableName: str, row: dict) -> bool:
    """Tell whether ``row`` of ``tableName`` is a protected system object.

    Returns True when, for at least one ``_PROTECTED_ROWS`` pattern of the
    table, every column/value pair equals the row's value. Tables without
    patterns never match.
    """
    return any(
        all(row.get(column) == expected for column, expected in pattern.items())
        for pattern in _PROTECTED_ROWS.get(tableName, [])
    )
|
|
|
|
|
|
def _importDatabases(payload: dict, mode: str) -> dict:
    """Import every registered database contained in a backup payload.

    ``mode`` is ``"replace"`` (delete non-protected rows, then insert) or
    ``"merge"`` (insert only rows whose ``id`` does not exist yet).

    Behaviour:

    * System-object IDs in ``payload`` are first remapped to the live IDs.
      The payload is modified in place.
    * The following are skipped with a warning:
      - unregistered databases;
      - tables missing on the target;
      - tables listed in ``_EXCLUDED_TABLES``. These are never exported, so
        importing them from a hand-edited payload would bypass that policy.
    * Tables are processed in FK order (``_getTableImportOrder``), matching
      ``_importSingleDb``. In replace mode children are cleared before
      parents, and parents are always inserted before children.
    * Protected system rows are never deleted or overwritten.
    * Each database is imported in one transaction. On error it is rolled
      back, reported in ``warnings`` and omitted from ``imported``.

    Returns ``{success, imported: {db: {table: insertedCount}}, totalRecords, warnings}``.

    Raises:
        ValueError: If ``mode`` is not ``"replace"`` or ``"merge"``.
    """
    if mode not in ("replace", "merge"):
        raise ValueError(f"Invalid import mode: {mode}")

    registeredDbs = getRegisteredDatabases()

    # Point the payload's system-object references at the live objects.
    liveIds = _loadLiveSystemObjectIds()
    remap = _buildIdRemapFromPayload(payload, liveIds)
    if remap:
        logger.info("System-object ID remap: %s", remap)
        _remapSystemObjectIds(payload, remap)

    protectedIdSet = set(liveIds.values())

    imported: Dict[str, dict] = {}
    warnings: List[str] = []

    for dbName, dbData in payload.get("databases", {}).items():
        if dbName not in registeredDbs:
            warnings.append(f"Datenbank '{dbName}' uebersprungen (nicht registriert)")
            continue

        tables = dbData.get("tables", {}) if isinstance(dbData, dict) else {}
        if not isinstance(tables, dict):
            tables = {}
        excluded = _EXCLUDED_TABLES.get(dbName, set())
        dbResult: Dict[str, int] = {}

        conn = _getConnection(dbName)
        try:
            conn.autocommit = False
            existingTables = set(_listTables(conn))

            # Select the importable tables and remember their physical columns.
            columnsByTable: Dict[str, List[str]] = {}
            for tableName, rows in tables.items():
                if not isinstance(rows, list):
                    continue
                if tableName in excluded:
                    warnings.append(f"Table '{dbName}.{tableName}' excluded (security/transient)")
                    continue
                if tableName not in existingTables:
                    warnings.append(f"Tabelle '{dbName}.{tableName}' existiert nicht, uebersprungen")
                    continue
                physicalCols = _getPhysicalColumns(conn, tableName)
                if physicalCols:
                    columnsByTable[tableName] = physicalCols

            importOrder = _getTableImportOrder(conn, list(columnsByTable))

            # Replace mode: clear children before parents so FKs hold.
            if mode == "replace":
                for tableName in reversed(importOrder):
                    _deleteNonProtected(conn, tableName, protectedIdSet)

            # Insert parents before children; protected rows are never written.
            for tableName in importOrder:
                filteredRows = [
                    row for row in tables[tableName]
                    if not _isProtectedRow(tableName, row)
                    and not (row.get("id") and str(row["id"]) in protectedIdSet)
                ]
                dbResult[tableName] = _insertRows(
                    conn, tableName, filteredRows, columnsByTable[tableName], mode
                )

            conn.commit()
        except Exception as e:
            conn.rollback()
            logger.exception("Import failed for database %s: %s", dbName, e)
            warnings.append(f"Import fuer '{dbName}' fehlgeschlagen: {e}")
            continue
        finally:
            conn.close()

        imported[dbName] = dbResult

    return {
        "success": True,
        "imported": imported,
        "totalRecords": sum(sum(counts.values()) for counts in imported.values()),
        "warnings": warnings,
    }
|
|
|
|
|
|
def _getPhysicalColumns(conn, tableName: str) -> List[str]:
    """Return the column names of ``public.<tableName>`` in declaration order.

    Returns an empty list when the table has no columns or does not exist.
    """
    columnQuery = """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = 'public' AND table_name = %s
        ORDER BY ordinal_position
    """
    with conn.cursor() as cur:
        cur.execute(columnQuery, (tableName,))
        columnRows = cur.fetchall()
    return [columnRow["column_name"] for columnRow in columnRows]
|
|
|
|
|
|
def _deleteNonProtected(conn, tableName: str, protectedIds: Set[str]) -> int:
    """Delete every row of ``tableName`` except the protected system objects.

    Args:
        conn: Open connection. The delete runs inside the caller's transaction.
        tableName: Target table. It can originate from an uploaded payload.
        protectedIds: IDs, compared as text, of rows that must survive.

    Returns:
        Number of deleted rows.

    Notes:
        * The table name is quoted with embedded ``"`` doubled. In the
          parameterised statement ``%`` is also doubled, because psycopg2
          would otherwise read it as a placeholder.
        * The ``id`` filter is applied only when the table has an ``id``
          column. Otherwise the statement would fail with an undefined-column
          error and abort the whole import transaction.
        * Rows with a NULL ``id`` are deleted as well. The old
          ``id <> ALL(...)`` filter evaluates to NULL for them and kept them.
    """
    quotedTable = '"' + tableName.replace('"', '""') + '"'
    with conn.cursor() as cur:
        hasIdColumn = False
        if protectedIds:
            cur.execute(
                """
                SELECT 1
                FROM information_schema.columns
                WHERE table_schema = 'public' AND table_name = %s AND column_name = 'id'
                """,
                (tableName,),
            )
            hasIdColumn = cur.fetchone() is not None

        if hasIdColumn:
            cur.execute(
                f'DELETE FROM {quotedTable.replace("%", "%%")} '
                'WHERE "id" IS NULL OR NOT ("id"::text = ANY(%(ids)s))',
                {"ids": list(protectedIds)},
            )
        else:
            # No parameters are passed here, so psycopg2 does no % processing.
            cur.execute(f"DELETE FROM {quotedTable}")
        return cur.rowcount
|
|
|
|
|
|
def _insertRows(
    conn,
    tableName: str,
    rows: List[dict],
    physicalCols: List[str],
    mode: str,
) -> int:
    """Insert ``rows`` into ``tableName`` and return how many were written.

    Only keys that are physical columns of the table are inserted, so extra
    or stale keys from the backup are dropped. In ``"merge"`` mode, rows
    whose ``id`` already exists are skipped via ``ON CONFLICT ("id") DO
    NOTHING``. That assumes a unique constraint on ``id``; TODO confirm this
    holds for every table.

    Each row runs under its own savepoint inside the caller's transaction.
    A failing row is logged and rolled back without aborting the others.
    The savepoint is released after a rollback as well, so failed rows do
    not leave savepoints stacked for the rest of the transaction.

    Identifiers are quoted with embedded ``"`` doubled, and ``%`` is doubled
    because the statement is executed with parameters. The table name can
    originate from an uploaded payload.
    """
    if not rows:
        return 0

    def _quoteIdent(name: str) -> str:
        return ('"' + name.replace('"', '""') + '"').replace("%", "%%")

    physicalColSet = set(physicalCols)
    quotedTable = _quoteIdent(tableName)
    conflictClause = ' ON CONFLICT ("id") DO NOTHING' if mode == "merge" else ""
    inserted = 0

    for row in rows:
        cols = [c for c in row.keys() if c in physicalColSet]
        if not cols:
            continue

        values = [_pgSafe(row[c]) for c in cols]
        colNames = ", ".join(_quoteIdent(c) for c in cols)
        placeholders = ", ".join(["%s"] * len(cols))
        sql = f"INSERT INTO {quotedTable} ({colNames}) VALUES ({placeholders}){conflictClause}"

        with conn.cursor() as cur:
            cur.execute("SAVEPOINT row_sp")
            try:
                cur.execute(sql, values)
            except Exception as e:
                logger.warning("Insert failed for %s row: %s", tableName, e)
                cur.execute("ROLLBACK TO SAVEPOINT row_sp")
            else:
                inserted += cur.rowcount
            cur.execute("RELEASE SAVEPOINT row_sp")

    return inserted
|
|
|
|
|
|
def _pgSafe(v: Any) -> Any:
    """Adapt one exported value for use as a psycopg2 query parameter.

    ``None``, ``str``, ``int``, ``float`` and ``bool`` pass through unchanged.
    Dicts and lists become JSON text. Any other value is converted with
    ``str()``.
    """
    import json as _json

    passthroughTypes = (str, int, float, bool)
    if v is not None and not isinstance(v, passthroughTypes):
        return _json.dumps(v) if isinstance(v, (dict, list)) else str(v)
    return v
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Prepare import (validate + remap, return context for per-DB import)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _prepareImport(payload: dict) -> dict:
    """Validate a payload and remap its system-object IDs for per-DB import.

    ``payload`` is modified in place. IDs of the root mandate and of the
    ``admin`` / ``event`` users are rewritten to the live IDs. The remapped
    payload is *not* part of the returned dict. ``_importSingleDb`` expects
    the already-remapped payload, so the caller must keep this same object.

    Returns ``{valid, warnings, systemObjectsFound, databases, protectedIds}``:

    * ``databases`` lists ``{database, tableCount, recordCount}`` for each
      registered, well-formed database section.
    * ``protectedIds`` holds the live system-object IDs that import must not
      touch.

    On validation failure ``valid`` is False and both lists are empty.
    """
    validation = _validateImportPayload(payload)
    if not validation.get("valid"):
        return {
            "valid": False,
            "warnings": validation.get("warnings", []),
            "systemObjectsFound": validation.get("systemObjectsFound", []),
            "databases": [],
            "protectedIds": [],
        }

    liveIds = _loadLiveSystemObjectIds()
    remap = _buildIdRemapFromPayload(payload, liveIds)
    if remap:
        logger.info("System-object ID remap: %s", remap)
        _remapSystemObjectIds(payload, remap)

    registeredDbs = getRegisteredDatabases()
    dbList: List[dict] = []
    for dbName, dbData in payload.get("databases", {}).items():
        # Malformed sections were already reported by validation; skip them here.
        if dbName not in registeredDbs or not isinstance(dbData, dict):
            continue
        tables = dbData.get("tables", {})
        if not isinstance(tables, dict):
            continue
        dbList.append({
            "database": dbName,
            "tableCount": len(tables),
            "recordCount": sum(len(rows) for rows in tables.values() if isinstance(rows, list)),
        })

    return {
        "valid": True,
        "warnings": validation.get("warnings", []),
        "systemObjectsFound": validation.get("systemObjectsFound", []),
        "databases": dbList,
        "protectedIds": list(set(liveIds.values())),
    }
|
|
|
|
|
|
def _ensureDatabaseExists(dbName: str) -> bool:
    """Create the PostgreSQL database ``dbName`` unless it already exists.

    Connection settings are read from the database's registered config
    prefix: ``<prefix>_HOST`` (default ``localhost``), ``<prefix>_PORT``
    (default 5432), ``<prefix>_USER`` and ``<prefix>_PASSWORD_SECRET``. The
    connection goes to the ``postgres`` admin database, because the target
    may not exist yet.

    Returns True if the database was created. Returns False if it already
    existed or ``dbName`` is not registered.
    """
    configPrefix = getRegisteredDatabases().get(dbName)
    if configPrefix is None:
        return False

    # For the default "DB" prefix this already yields "DB_HOST", "DB_PORT", ...
    # so no special case is needed.
    hostKey, portKey, userKey, passwordKey = (
        f"{configPrefix}_{suffix}" for suffix in ("HOST", "PORT", "USER", "PASSWORD_SECRET")
    )

    adminConn = psycopg2.connect(
        host=APP_CONFIG.get(hostKey, "localhost"),
        port=int(APP_CONFIG.get(portKey, 5432)),
        database="postgres",
        user=APP_CONFIG.get(userKey),
        password=APP_CONFIG.get(passwordKey),
        client_encoding="utf8",
    )
    try:
        # CREATE DATABASE cannot run inside a transaction block.
        adminConn.autocommit = True
        with adminConn.cursor() as cur:
            cur.execute("SELECT 1 FROM pg_database WHERE datname = %s", (dbName,))
            alreadyExists = bool(cur.fetchone())
            if not alreadyExists:
                cur.execute(f'CREATE DATABASE "{dbName}"')
                logger.info("Created missing database: %s", dbName)
        return not alreadyExists
    finally:
        adminConn.close()
|
|
|
|
|
|
def _createTableFromExport(conn, tableName: str, rows: List[dict]) -> None:
    """Create ``tableName`` from the column names seen in exported rows.

    The backup does not contain the original DDL. Every column is therefore
    created as ``TEXT``, in first-seen order across all rows, and an ``id``
    column becomes the ``PRIMARY KEY``. Nothing is created when the rows
    carry no columns (e.g. a table exported empty). ``CREATE TABLE IF NOT
    EXISTS`` makes the call a no-op for an existing table.

    Table and column names come from the uploaded payload. They are quoted
    as identifiers with embedded ``"`` doubled; otherwise a crafted key such
    as ``x" TEXT); DROP TABLE ...`` would be executed as SQL. Non-dict rows
    are ignored.
    """
    def _quoteIdent(name: Any) -> str:
        return '"' + str(name).replace('"', '""') + '"'

    allKeys: List[str] = []
    seen: Set[str] = set()
    for row in rows:
        if not isinstance(row, dict):
            continue
        for key in row:
            if key not in seen:
                seen.add(key)
                allKeys.append(key)

    if not allKeys:
        return

    colDefs = [
        f"{_quoteIdent(col)} TEXT PRIMARY KEY" if col == "id" else f"{_quoteIdent(col)} TEXT"
        for col in allKeys
    ]
    # Executed without parameters, so psycopg2 does no % processing.
    ddl = f"CREATE TABLE IF NOT EXISTS {_quoteIdent(tableName)} ({', '.join(colDefs)})"
    with conn.cursor() as cur:
        cur.execute(ddl)
    logger.info("Created table %s with %d columns", tableName, len(allKeys))
|
|
|
|
|
|
def _getTableImportOrder(conn, tableNames: List[str]) -> List[str]:
    """Order tables so that FK parents come before their children.

    FK edges are read from ``pg_constraint`` (``contype = 'f'``) between
    tables of the ``public`` schema. ``information_schema`` is avoided for
    two reasons:

    * ``constraint_column_usage`` only shows tables owned by a currently
      enabled role.
    * The view had to be joined on the constraint *name*, which PostgreSQL
      does not require to be unique within a schema.

    Either could yield missing or spurious edges.

    Only edges between tables in ``tableNames`` count, and self-references
    are ignored. The sort is Kahn's algorithm with a min-heap, which always
    emits the alphabetically smallest ready table, so the result is
    deterministic. Tables caught in an FK cycle are appended at the end,
    sorted, with a warning. Duplicate input names appear only once.
    """
    import heapq

    names = list(dict.fromkeys(tableNames))
    nameSet = set(names)

    with conn.cursor() as cur:
        cur.execute("""
            SELECT DISTINCT
                child.relname AS child_table,
                parent.relname AS parent_table
            FROM pg_constraint con
            JOIN pg_class child ON child.oid = con.conrelid
            JOIN pg_namespace child_ns ON child_ns.oid = child.relnamespace
            JOIN pg_class parent ON parent.oid = con.confrelid
            JOIN pg_namespace parent_ns ON parent_ns.oid = parent.relnamespace
            WHERE con.contype = 'f'
              AND child_ns.nspname = 'public'
              AND parent_ns.nspname = 'public'
              AND con.conrelid <> con.confrelid
        """)
        fks = cur.fetchall()

    parentsOf: Dict[str, Set[str]] = {name: set() for name in names}
    childrenOf: Dict[str, Set[str]] = {name: set() for name in names}
    for fk in fks:
        child, parent = fk["child_table"], fk["parent_table"]
        if child in nameSet and parent in nameSet and child != parent:
            parentsOf[child].add(parent)
            childrenOf[parent].add(child)

    pendingParents = {name: len(parentsOf[name]) for name in names}
    ready = [name for name in names if pendingParents[name] == 0]
    heapq.heapify(ready)

    ordered: List[str] = []
    while ready:
        node = heapq.heappop(ready)
        ordered.append(node)
        for child in childrenOf[node]:
            pendingParents[child] -= 1
            if pendingParents[child] == 0:
                heapq.heappush(ready, child)

    placed = set(ordered)
    remaining = sorted(name for name in names if name not in placed)
    if remaining:
        logger.warning("FK cycle detected, appending without order guarantee: %s", remaining)
        ordered.extend(remaining)

    return ordered
|
|
|
|
|
|
def _importSingleDb(payload: dict, dbName: str, mode: str, protectedIds: List[str]) -> dict:
    """Import a single database from the (already remapped) payload.

    Steps:

    1. Create the database if it is missing (``_ensureDatabaseExists``).
       Create missing tables from the exported rows
       (``_createTableFromExport``), committing each immediately. The table
       list is then re-read from the database. ``_createTableFromExport``
       creates nothing for a table exported without rows, and treating such
       a table as existing made the replace-mode DELETE hit a non-existent
       table and abort the import.
    2. Sort the importable tables by FK dependencies. In replace mode,
       delete non-protected rows child-first (reverse order).
    3. Insert rows parent-first. Protected system rows and rows whose ``id``
       is in ``protectedIds`` are skipped.

    Tables in ``_EXCLUDED_TABLES`` are never imported. Steps 2-3 run in one
    transaction that is rolled back on error. The failure is then reported
    in ``warnings`` together with any warnings collected before it.

    Returns ``{database, tables: {tableName: insertedCount}, recordCount, warnings}``.

    Raises:
        ValueError: If ``mode`` is not ``"replace"`` or ``"merge"``.
    """
    if mode not in ("replace", "merge"):
        raise ValueError(f"Invalid import mode: {mode}")

    def _result(tablesOut: Dict[str, int], warningsOut: List[str]) -> dict:
        return {"database": dbName, "tables": tablesOut,
                "recordCount": sum(tablesOut.values()), "warnings": warningsOut}

    if dbName not in getRegisteredDatabases():
        return _result({}, [f"Datenbank '{dbName}' nicht registriert"])

    dbData = payload.get("databases", {}).get(dbName)
    if not dbData or not isinstance(dbData, dict):
        return _result({}, [f"Keine Daten fuer '{dbName}' im Payload"])

    try:
        dbCreated = _ensureDatabaseExists(dbName)
    except Exception as e:
        logger.exception("Failed to ensure database %s exists: %s", dbName, e)
        return _result({}, [f"Datenbank '{dbName}' konnte nicht erstellt werden: {e}"])

    protectedIdSet = set(protectedIds)
    tables = dbData.get("tables", {})
    if not isinstance(tables, dict):
        tables = {}
    excluded = _EXCLUDED_TABLES.get(dbName, set())
    warnings: List[str] = []
    dbResult: Dict[str, int] = {}

    if dbCreated:
        warnings.append(f"Datenbank '{dbName}' wurde neu erstellt")

    conn = _getConnection(dbName)
    try:
        conn.autocommit = False
        existingTables = set(_listTables(conn))

        candidates = [name for name, rows in tables.items()
                      if name not in excluded and isinstance(rows, list)]

        # Pre-create missing tables so FK ordering can discover them.
        createdAny = False
        for tableName in candidates:
            if tableName not in existingTables:
                _createTableFromExport(conn, tableName, tables[tableName])
                conn.commit()
                createdAny = True
        if createdAny:
            # Re-read: tables without exported columns were not created.
            existingTables = set(_listTables(conn))

        importable = []
        for tableName in candidates:
            if tableName in existingTables:
                importable.append(tableName)
            else:
                warnings.append(f"Tabelle '{dbName}.{tableName}' existiert nicht, uebersprungen")

        importOrder = _getTableImportOrder(conn, importable)
        logger.info("Import order for %s: %s", dbName, importOrder)

        for tableName, rows in tables.items():
            if tableName in excluded and isinstance(rows, list):
                warnings.append(f"Table '{dbName}.{tableName}' excluded (security/transient)")

        # Phase 1 (replace only): delete children first (reverse topological order).
        if mode == "replace":
            for tableName in reversed(importOrder):
                _deleteNonProtected(conn, tableName, protectedIdSet)

        # Phase 2: insert parents first (topological order).
        for tableName in importOrder:
            physicalCols = _getPhysicalColumns(conn, tableName)
            if not physicalCols:
                continue
            filteredRows = [
                row for row in tables[tableName]
                if not _isProtectedRow(tableName, row)
                and not (row.get("id") and str(row["id"]) in protectedIdSet)
            ]
            dbResult[tableName] = _insertRows(conn, tableName, filteredRows, physicalCols, mode)

        conn.commit()
    except Exception as e:
        conn.rollback()
        logger.exception("Import failed for database %s: %s", dbName, e)
        return _result({}, warnings + [f"Import fuer '{dbName}' fehlgeschlagen: {e}"])
    finally:
        conn.close()

    return _result(dbResult, warnings)
|