1327 lines
47 KiB
Python
1327 lines
47 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""
|
|
Database migration utilities — backup (export) and restore (import) for all
|
|
registered PowerOn databases.
|
|
|
|
System objects (root mandate, admin user, event user) are protected: they are
|
|
never deleted or overwritten during import. Their IDs in the backup payload
|
|
are remapped to the IDs of the corresponding live objects so that all FK
|
|
references stay consistent.
|
|
|
|
All functions are intended for SysAdmin use only (access control in the route layer).
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Dict, List, Optional, Set, Tuple
|
|
|
|
import psycopg2
|
|
import psycopg2.extras
|
|
|
|
from modules.shared.configuration import APP_CONFIG
|
|
from modules.shared.dbRegistry import getRegisteredDatabases
|
|
from modules.shared.fkRegistry import getFkRelationships
|
|
from modules.datamodels.datamodelBase import MODEL_REGISTRY
|
|
from modules.system.databaseHealth import _getConnection, _jsonSafe
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_EXPORT_FORMAT_VERSION = "1.0"
|
|
_SYSTEM_TABLE = "_system"
|
|
|
|
_EXCLUDED_TABLES: Dict[str, Set[str]] = {
|
|
"poweron_app": {"Token", "AuthEvent"},
|
|
}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Instance label
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _getInstanceLabel() -> str:
    """Return the configured environment type of this instance.

    Reads ``APP_ENV_TYPE`` from the application config (e.g. ``'dev'``,
    ``'int'``, ``'prod'``) and falls back to ``'unknown'`` when it is unset.
    """
    fallback = "unknown"
    label = APP_CONFIG.get("APP_ENV_TYPE", fallback)
    return label
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Database list
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _getAvailableDatabases() -> List[dict]:
    """List the registered databases with table/record counts for the UI.

    ``poweron_test`` is never listed. Counts come from ``pg_stat_user_tables``
    (``n_live_tup`` is PostgreSQL's statistics estimate, not an exact COUNT)
    and ignore tables whose name starts with an underscore. A database that
    cannot be inspected is still listed, with whatever counts were gathered
    before the failure (normally zero), and a warning is logged.
    """
    # No query parameters are passed, so psycopg2 sends '%%' verbatim; in a
    # LIKE pattern '%%' behaves like '%', and '\_' is a literal underscore.
    statsQuery = """
        SELECT relname, n_live_tup
        FROM pg_stat_user_tables
        WHERE schemaname = 'public'
          AND relname NOT LIKE '\\_%%'
    """
    overview: List[dict] = []
    for dbName in sorted(getRegisteredDatabases()):
        if dbName == "poweron_test":
            continue
        entry: dict = {"name": dbName, "tableCount": 0, "recordCount": 0}
        try:
            conn = _getConnection(dbName)
            try:
                with conn.cursor() as cur:
                    cur.execute(statsQuery)
                    for stat in cur.fetchall():
                        entry["tableCount"] += 1
                        entry["recordCount"] += int(stat["n_live_tup"])
            finally:
                conn.close()
        except Exception as e:
            logger.warning("Could not stat database %s: %s", dbName, e)
        overview.append(entry)
    return overview
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Export
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _exportDatabases(databases: List[str]) -> dict:
    """Build an in-memory export of the selected databases.

    Names that are not registered are skipped with a warning. A database whose
    export raises is logged and left out of the result; the remaining ones are
    still exported.

    Returns ``{meta: {...}, databases: {dbName: {tables: {tbl: [rows]},
    summary: {...}, tableCount, totalRecords}}}`` where ``meta`` carries the
    export timestamp, format version and aggregated counts.

    Raises:
        ValueError: if *databases* is empty.
    """
    registeredDbs = getRegisteredDatabases()

    if not databases:
        raise ValueError("No databases selected for export.")

    meta: dict = {
        "exportedAt": datetime.now(timezone.utc).isoformat(),
        "version": _EXPORT_FORMAT_VERSION,
        "databaseCount": 0,
        "totalTables": 0,
        "totalRecords": 0,
    }
    exportedDbs: dict = {}

    for dbName in databases:
        if dbName not in registeredDbs:
            logger.warning("Export: skipping unregistered database %s", dbName)
            continue
        try:
            dbPayload = _exportSingleDb(dbName)
            exportedDbs[dbName] = dbPayload
            meta["databaseCount"] += 1
            meta["totalTables"] += dbPayload["tableCount"]
            meta["totalRecords"] += dbPayload["totalRecords"]
        except Exception as e:
            logger.error("Export failed for database %s: %s", dbName, e)

    return {"meta": meta, "databases": exportedDbs}
|
|
|
|
|
|
def _getModelTablesForDb(dbName: str, physicalTables: List[str]) -> List[str]:
    """Keep only the physical tables that are backed by a registered model.

    A table qualifies when its name is a key of ``MODEL_REGISTRY``; legacy or
    orphan tables without a Pydantic class are dropped so the backup contains
    only model-backed data.

    ``dbName`` is not consulted: the same model may live in several databases
    (shared-table pattern), so only registry membership matters.

    Returns the qualifying names in sorted order.
    """
    modelBacked = [name for name in physicalTables if name in MODEL_REGISTRY]
    modelBacked.sort()
    return modelBacked
|
|
|
|
|
|
def _exportSingleDb(dbName: str) -> dict:
    """Read every exportable table of *dbName* into memory.

    Only model-backed tables (see ``_getModelTablesForDb``) are read. Tables
    listed in ``_EXCLUDED_TABLES`` for this database are skipped, and legacy
    tables without a model are reported in one info log line.

    Returns ``{"tables": {tbl: [rows]}, "summary": {tbl: {"recordCount": n}},
    "tableCount": int, "totalRecords": int}``. The connection is always closed.
    """
    excluded = _EXCLUDED_TABLES.get(dbName, set())
    conn = _getConnection(dbName)
    try:
        allTables = _listTables(conn)
        modelTables = _getModelTablesForDb(dbName, allTables)

        legacyTables = set(allTables).difference(modelTables, excluded, {_SYSTEM_TABLE})
        if legacyTables:
            logger.info("Export %s: skipping %d legacy tables without model: %s",
                        dbName, len(legacyTables), sorted(legacyTables))

        tablesOut: dict = {}
        summaryOut: dict = {}
        recordTotal = 0
        for tbl in modelTables:
            if tbl in excluded:
                logger.info("Export: skipping excluded table %s.%s", dbName, tbl)
                continue
            rows = _readTableRows(conn, tbl)
            tablesOut[tbl] = rows
            summaryOut[tbl] = {"recordCount": len(rows)}
            recordTotal += len(rows)

        return {
            "tables": tablesOut,
            "summary": summaryOut,
            "tableCount": len(tablesOut),
            "totalRecords": recordTotal,
        }
    finally:
        conn.close()
|
|
|
|
|
|
def _listTables(conn) -> List[str]:
    """Return the names of all base tables in the ``public`` schema, sorted.

    The internal ``_SYSTEM_TABLE`` is left out. Rows are read by column name,
    so *conn* must produce dict-like rows (presumably a RealDictCursor factory
    set up by ``_getConnection`` — confirm there).
    """
    query = """
        SELECT table_name
        FROM information_schema.tables
        WHERE table_schema = 'public'
          AND table_type = 'BASE TABLE'
          AND table_name != %s
        ORDER BY table_name
    """
    with conn.cursor() as cur:
        cur.execute(query, (_SYSTEM_TABLE,))
        records = cur.fetchall()
    return [record["table_name"] for record in records]
|
|
|
|
|
|
def _readTableRows(conn, tableName: str) -> List[dict]:
    """Read all rows of *tableName* and return them as JSON-safe dicts.

    Each value is passed through ``_jsonSafe``. The table name is quoted as a
    PostgreSQL identifier (embedded double quotes doubled), so an unusual name
    can neither break nor alter the query.
    """
    quotedTable = '"' + tableName.replace('"', '""') + '"'
    with conn.cursor() as cur:
        # No parameters are passed, so '%' in the name needs no escaping here.
        cur.execute(f"SELECT * FROM {quotedTable}")
        return [{k: _jsonSafe(v) for k, v in dict(row).items()} for row in cur.fetchall()]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Streaming Export (memory-safe, handles arbitrarily large databases)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def streamExportGenerator(databases: List[str], instanceLabel: str = ""):
    """Yield JSON text fragments that together form a database export.

    Rows are fetched through a server-side (named) cursor and emitted one at a
    time, so neither the backend nor the consumer has to hold a whole database
    in memory. Concatenating all fragments gives one JSON document shaped like
    the output of ``_exportDatabases()``::

        {"meta": {...},
         "databases": {"dbName": {"tables": {"tbl": [rows]},
                                  "summary": {"tbl": {"recordCount": n}},
                                  "tableCount": n, "totalRecords": n}}}

    ``meta`` is written before any data, so ``meta.databaseCount`` is the
    number of requested databases that are registered; unregistered names are
    skipped with a warning.

    Failure handling: a database that fails before any of its output was
    written is omitted. If it fails part-way, the fragments already sent
    cannot be retracted, so the open array/objects are closed to keep the
    document valid JSON, ``summary``/counts describe exactly the rows that
    were written, and the database entry gets an ``"error"`` message.
    """
    registeredDbs = getRegisteredDatabases()
    validDbs = [db for db in databases if db in registeredDbs]
    for db in databases:
        if db not in registeredDbs:
            logger.warning("Stream export: skipping unregistered database %s", db)

    yield '{"meta":'
    yield json.dumps({
        "exportedAt": datetime.now(timezone.utc).isoformat(),
        "version": _EXPORT_FORMAT_VERSION,
        "instanceLabel": instanceLabel,
        "databaseCount": len(validDbs),
    }, ensure_ascii=False)
    yield ',"databases":{'

    wroteDb = False
    totalDbs = 0
    totalTables = 0
    totalRecords = 0

    for dbName in validDbs:
        excluded = _EXCLUDED_TABLES.get(dbName, set())
        conn = None
        dbOpened = False          # '"db":{"tables":{' has been emitted
        tableOpen = False         # '"tbl":[' emitted, closing ']' not yet
        currentTable: Optional[str] = None
        rowCount = 0
        summary: Dict[str, int] = {}   # tables fully written -> row count
        error: Optional[str] = None

        try:
            conn = _getConnection(dbName)
            exportTables = [
                t for t in _getModelTablesForDb(dbName, _listTables(conn))
                if t not in excluded
            ]

            # Separator and key are emitted together so nothing dangles if
            # this database fails later on.
            yield (',' if wroteDb else '') + json.dumps(dbName, ensure_ascii=False) + ':{"tables":{'
            wroteDb = True
            dbOpened = True

            for tbl in exportTables:
                yield (',' if summary else '') + json.dumps(tbl, ensure_ascii=False) + ':['
                currentTable = tbl
                tableOpen = True
                rowCount = 0

                quotedTable = '"' + tbl.replace('"', '""') + '"'
                with conn.cursor(name=f"export_{dbName}_{tbl}") as cur:
                    cur.itersize = 2000
                    cur.execute(f"SELECT * FROM {quotedTable}")
                    for row in cur:
                        safeRow = {k: _jsonSafe(v) for k, v in dict(row).items()}
                        # Serialise before emitting the separator so a failing
                        # row cannot leave a trailing comma behind.
                        rowJson = json.dumps(safeRow, ensure_ascii=False, default=str)
                        yield (',' if rowCount else '') + rowJson
                        rowCount += 1

                yield ']'
                tableOpen = False
                summary[tbl] = rowCount
        except Exception as e:
            error = str(e)
            logger.error("Stream export failed for %s: %s", dbName, e)
        finally:
            if conn is not None:
                try:
                    conn.close()
                except Exception:
                    pass

        if not dbOpened:
            # Nothing was written for this database; just leave it out.
            continue

        if tableOpen:
            # Close the partially written row array and account for its rows.
            yield ']'
            summary[currentTable] = rowCount

        dbRecordCount = sum(summary.values())
        summaryJson = json.dumps(
            {tbl: {"recordCount": count} for tbl, count in summary.items()},
            ensure_ascii=False,
        )
        yield '},"summary":' + summaryJson
        yield f',"tableCount":{len(summary)},"totalRecords":{dbRecordCount}'
        if error is not None:
            yield ',"error":' + json.dumps(error, ensure_ascii=False)
        yield '}'

        totalDbs += 1
        totalTables += len(summary)
        totalRecords += dbRecordCount
        if error is None:
            logger.info("Stream export: %s done (%d tables, %d records)", dbName, len(summary), dbRecordCount)

    yield '}}'
    logger.info("Stream export finished: %d databases, %d tables, %d records",
                totalDbs, totalTables, totalRecords)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Validate
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _validateImportPayload(payload: dict) -> dict:
|
|
"""Validate an import payload without writing anything.
|
|
|
|
Returns ``{valid, summary, warnings, systemObjectsFound}``.
|
|
"""
|
|
warnings: List[str] = []
|
|
summary: List[dict] = []
|
|
|
|
meta = payload.get("meta")
|
|
if not meta or not isinstance(meta, dict):
|
|
return {"valid": False, "summary": [], "warnings": ["Fehlende oder ungueltige 'meta'-Sektion"], "systemObjectsFound": []}
|
|
|
|
version = meta.get("version", "")
|
|
if version != _EXPORT_FORMAT_VERSION:
|
|
warnings.append(f"Unbekannte Format-Version: {version} (erwartet: {_EXPORT_FORMAT_VERSION})")
|
|
|
|
databases = payload.get("databases")
|
|
if not databases or not isinstance(databases, dict):
|
|
return {"valid": False, "summary": [], "warnings": ["Fehlende oder ungueltige 'databases'-Sektion"], "systemObjectsFound": []}
|
|
|
|
registeredDbs = getRegisteredDatabases()
|
|
|
|
for dbName, dbData in databases.items():
|
|
tables = dbData.get("tables", {})
|
|
tableCount = len(tables)
|
|
recordCount = sum(len(rows) for rows in tables.values() if isinstance(rows, list))
|
|
registered = dbName in registeredDbs
|
|
if not registered:
|
|
warnings.append(f"Datenbank '{dbName}' ist nicht registriert und wird uebersprungen")
|
|
summary.append({
|
|
"database": dbName,
|
|
"tableCount": tableCount,
|
|
"recordCount": recordCount,
|
|
"registered": registered,
|
|
})
|
|
|
|
systemObjectsFound = _detectSystemObjectsInPayload(payload)
|
|
|
|
valid = any(s["registered"] for s in summary)
|
|
return {
|
|
"valid": valid,
|
|
"summary": summary,
|
|
"warnings": warnings,
|
|
"systemObjectsFound": systemObjectsFound,
|
|
}
|
|
|
|
|
|
def _detectSystemObjectsInPayload(payload: dict) -> List[dict]:
|
|
"""Find system objects (root mandate, admin user, event user) in a payload."""
|
|
found: List[dict] = []
|
|
appData = payload.get("databases", {}).get("poweron_app", {}).get("tables", {})
|
|
|
|
for row in appData.get("Mandate", []):
|
|
if row.get("name") == "root" and row.get("isSystem") is True:
|
|
found.append({"type": "mandate", "label": "Root Mandate", "payloadId": row.get("id")})
|
|
|
|
for row in appData.get("UserInDB", []):
|
|
if row.get("username") == "admin":
|
|
found.append({"type": "user", "label": "Admin User", "payloadId": row.get("id")})
|
|
elif row.get("username") == "event":
|
|
found.append({"type": "user", "label": "Event User", "payloadId": row.get("id")})
|
|
|
|
return found
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# System-object ID remapping
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _loadLiveSystemObjectIds() -> Dict[str, str]:
    """Look up the IDs of the protected system objects in the live database.

    Queries ``poweron_app`` for the root mandate (``name = 'root'`` and
    ``isSystem``), the ``admin`` user and the ``event`` user. The result maps
    ``rootMandate`` / ``adminUser`` / ``eventUser`` to the stringified ID of
    each object that was found; missing objects are simply absent. An empty
    dict is returned when ``poweron_app`` is not registered.
    """
    if "poweron_app" not in getRegisteredDatabases():
        return {}

    lookups = (
        ("rootMandate", """SELECT id FROM "Mandate" WHERE "name" = 'root' AND "isSystem" = true LIMIT 1"""),
        ("adminUser", """SELECT id FROM "UserInDB" WHERE "username" = 'admin' LIMIT 1"""),
        ("eventUser", """SELECT id FROM "UserInDB" WHERE "username" = 'event' LIMIT 1"""),
    )

    liveIds: Dict[str, str] = {}
    conn = _getConnection("poweron_app")
    try:
        with conn.cursor() as cur:
            for resultKey, query in lookups:
                cur.execute(query)
                match = cur.fetchone()
                if match:
                    liveIds[resultKey] = str(match["id"])
    finally:
        conn.close()

    return liveIds
|
|
|
|
|
|
def _buildIdRemapFromPayload(payload: dict, liveIds: Dict[str, str]) -> Dict[str, str]:
|
|
"""Build an ``{oldId: newId}`` mapping for system objects.
|
|
|
|
Compares IDs found in the payload with the live system-object IDs.
|
|
Only entries where the IDs actually differ are included.
|
|
"""
|
|
remap: Dict[str, str] = {}
|
|
appTables = payload.get("databases", {}).get("poweron_app", {}).get("tables", {})
|
|
|
|
for row in appTables.get("Mandate", []):
|
|
if row.get("name") == "root" and row.get("isSystem") is True:
|
|
oldId = str(row.get("id", ""))
|
|
newId = liveIds.get("rootMandate", "")
|
|
if oldId and newId and oldId != newId:
|
|
remap[oldId] = newId
|
|
|
|
for row in appTables.get("UserInDB", []):
|
|
username = row.get("username")
|
|
oldId = str(row.get("id", ""))
|
|
if username == "admin":
|
|
newId = liveIds.get("adminUser", "")
|
|
elif username == "event":
|
|
newId = liveIds.get("eventUser", "")
|
|
else:
|
|
continue
|
|
if oldId and newId and oldId != newId:
|
|
remap[oldId] = newId
|
|
|
|
return remap
|
|
|
|
|
|
def _remapSystemObjectIds(payload: dict, remap: Dict[str, str]) -> dict:
|
|
"""Walk the entire payload and replace every value that matches an old system-object ID."""
|
|
if not remap:
|
|
return payload
|
|
|
|
remapSet = set(remap.keys())
|
|
|
|
databases = payload.get("databases", {})
|
|
for dbName, dbData in databases.items():
|
|
tables = dbData.get("tables", {})
|
|
for tableName, rows in tables.items():
|
|
if not isinstance(rows, list):
|
|
continue
|
|
for row in rows:
|
|
_remapRowValues(row, remap, remapSet)
|
|
|
|
return payload
|
|
|
|
|
|
def _remapDbTables(tables: dict, remap: Dict[str, str]) -> None:
|
|
"""In-place remap system-object IDs in a single DB's tables dict."""
|
|
if not remap:
|
|
return
|
|
remapSet = set(remap.keys())
|
|
for tableName, rows in tables.items():
|
|
if not isinstance(rows, list):
|
|
continue
|
|
for row in rows:
|
|
_remapRowValues(row, remap, remapSet)
|
|
|
|
|
|
def _remapRowValues(row: dict, remap: Dict[str, str], remapSet: Set[str]) -> None:
|
|
"""In-place replace string values in a row dict that match a remap key."""
|
|
for key, val in row.items():
|
|
if isinstance(val, str) and val in remapSet:
|
|
row[key] = remap[val]
|
|
elif isinstance(val, dict):
|
|
_remapRowValues(val, remap, remapSet)
|
|
elif isinstance(val, list):
|
|
for i, item in enumerate(val):
|
|
if isinstance(item, str) and item in remapSet:
|
|
val[i] = remap[item]
|
|
elif isinstance(item, dict):
|
|
_remapRowValues(item, remap, remapSet)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Import
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_PROTECTED_ROWS: Dict[str, List[dict]] = {
|
|
"Mandate": [{"name": "root", "isSystem": True}],
|
|
"UserInDB": [{"username": "admin"}, {"username": "event"}],
|
|
}
|
|
|
|
|
|
def _isProtectedRow(tableName: str, row: dict) -> bool:
|
|
"""Return True if a row represents a protected system object."""
|
|
patterns = _PROTECTED_ROWS.get(tableName, [])
|
|
for pattern in patterns:
|
|
if all(row.get(k) == v for k, v in pattern.items()):
|
|
return True
|
|
return False
|
|
|
|
|
|
def _importDatabases(payload: dict, mode: str) -> dict:
    """Import every registered database of a payload in a single call.

    ``mode`` is ``"replace"`` (delete non-protected rows, then insert) or
    ``"merge"`` (insert only rows whose id does not exist yet).

    Processing steps:

    * System-object IDs in *payload* are first remapped in place to the live
      IDs.
    * Protected rows are never written. These are rows matching
      ``_PROTECTED_ROWS`` or carrying a live system-object id.
    * Each database is imported in one transaction. Tables are ordered by
      their FK dependencies (``_getTableImportOrder``): in replace mode rows
      are deleted children-first, and rows are inserted parents-first, so the
      table order in the payload cannot cause FK violations.
    * Tables in ``_EXCLUDED_TABLES`` and tables missing from the live schema
      are skipped with a warning.
    * If anything fails, the whole database is rolled back and the error is
      reported in ``warnings``.

    Returns ``{success, imported: {db: {table: insertedCount}}, totalRecords, warnings}``.

    Raises:
        ValueError: if *mode* is not ``"replace"`` or ``"merge"``.
    """
    if mode not in ("replace", "merge"):
        raise ValueError(f"Invalid import mode: {mode}")

    registeredDbs = getRegisteredDatabases()

    liveIds = _loadLiveSystemObjectIds()
    remap = _buildIdRemapFromPayload(payload, liveIds)
    if remap:
        logger.info("System-object ID remap: %s", remap)
        _remapSystemObjectIds(payload, remap)

    protectedIdSet = set(liveIds.values())

    imported: Dict[str, dict] = {}
    warnings: List[str] = []
    databases = payload.get("databases", {})

    for dbName, dbData in databases.items():
        if dbName not in registeredDbs:
            warnings.append(f"Datenbank '{dbName}' uebersprungen (nicht registriert)")
            continue

        tables = dbData.get("tables", {}) if isinstance(dbData, dict) else {}
        if not isinstance(tables, dict):
            tables = {}
        excluded = _EXCLUDED_TABLES.get(dbName, set())
        dbResult: Dict[str, int] = {}

        conn = _getConnection(dbName)
        try:
            conn.autocommit = False
            existingTables = set(_listTables(conn))

            candidates: List[str] = []
            for tableName, rows in tables.items():
                if not isinstance(rows, list):
                    continue
                if tableName in excluded:
                    warnings.append(f"Table '{dbName}.{tableName}' excluded (security/transient)")
                    continue
                if tableName not in existingTables:
                    warnings.append(f"Tabelle '{dbName}.{tableName}' existiert nicht, uebersprungen")
                    continue
                candidates.append(tableName)

            importOrder = _getTableImportOrder(conn, candidates, dbName)
            columnsByTable = {t: _getPhysicalColumns(conn, t) for t in importOrder}
            importOrder = [t for t in importOrder if columnsByTable[t]]

            if mode == "replace":
                # Children first, so FK references disappear before their targets.
                for tableName in reversed(importOrder):
                    _deleteNonProtected(conn, tableName, protectedIdSet)

            # Parents first, so FK targets exist before the rows referencing them.
            for tableName in importOrder:
                filteredRows = [
                    row for row in tables[tableName]
                    if isinstance(row, dict)
                    and not _isProtectedRow(tableName, row)
                    and not (row.get("id") and str(row["id"]) in protectedIdSet)
                ]
                dbResult[tableName] = _insertRows(
                    conn, tableName, filteredRows, columnsByTable[tableName], mode
                )

            conn.commit()
        except Exception as e:
            conn.rollback()
            logger.error("Import failed for database %s: %s", dbName, e)
            warnings.append(f"Import fuer '{dbName}' fehlgeschlagen: {e}")
            continue
        finally:
            conn.close()

        imported[dbName] = dbResult

    totalRecords = sum(sum(v.values()) for v in imported.values())
    return {
        "success": True,
        "imported": imported,
        "totalRecords": totalRecords,
        "warnings": warnings,
    }
|
|
|
|
|
|
def _getPhysicalColumns(conn, tableName: str) -> List[str]:
    """Return the column names of ``public.<tableName>`` in ordinal order.

    The table name is bound as a query parameter. A table that does not exist
    yields an empty list, because information_schema returns no rows for it.
    """
    sql = """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = 'public' AND table_name = %s
        ORDER BY ordinal_position
    """
    with conn.cursor() as cur:
        cur.execute(sql, (tableName,))
        columnNames = [record["column_name"] for record in cur.fetchall()]
    return columnNames
|
|
|
|
|
|
def _deleteNonProtected(conn, tableName: str, protectedIds: Set[str]) -> int:
    """Delete every row of *tableName* except the protected system objects.

    A row is kept only when its ``id`` (compared as text) is in
    *protectedIds*. With no protected IDs the table is emptied.

    Rows with a NULL ``id`` are deleted as well. ``NULL <> ALL(...)``
    evaluates to NULL, not TRUE, so without the explicit ``IS NULL`` such rows
    would survive a "replace" import.

    The table name is quoted as an identifier. ``%`` is additionally doubled
    in the parameterised statement, where psycopg2 treats it as a placeholder
    marker. Runs in the caller's transaction and does not commit.

    Returns the number of deleted rows.
    """
    quotedTable = '"' + tableName.replace('"', '""') + '"'
    with conn.cursor() as cur:
        if not protectedIds:
            # No parameters -> psycopg2 performs no '%' substitution here.
            cur.execute(f"DELETE FROM {quotedTable}")
        else:
            paramSafeTable = quotedTable.replace("%", "%%")
            cur.execute(
                f'DELETE FROM {paramSafeTable} WHERE "id" IS NULL OR "id"::text <> ALL(%(ids)s)',
                {"ids": list(protectedIds)},
            )
        return cur.rowcount
|
|
|
|
|
|
def _insertRows(
|
|
conn,
|
|
tableName: str,
|
|
rows: List[dict],
|
|
physicalCols: List[str],
|
|
mode: str,
|
|
) -> int:
|
|
"""Insert rows into a table. In merge mode, skip rows whose id already exists."""
|
|
if not rows:
|
|
return 0
|
|
|
|
physicalColSet = set(physicalCols)
|
|
inserted = 0
|
|
|
|
for row in rows:
|
|
cols = [c for c in row.keys() if c in physicalColSet]
|
|
if not cols:
|
|
continue
|
|
|
|
values = [_pgSafe(row[c]) for c in cols]
|
|
colNames = ", ".join(f'"{c}"' for c in cols)
|
|
placeholders = ", ".join(["%s"] * len(cols))
|
|
|
|
if mode == "merge":
|
|
sql = f'INSERT INTO "{tableName}" ({colNames}) VALUES ({placeholders}) ON CONFLICT ("id") DO NOTHING'
|
|
else:
|
|
sql = f'INSERT INTO "{tableName}" ({colNames}) VALUES ({placeholders})'
|
|
|
|
try:
|
|
with conn.cursor() as cur:
|
|
cur.execute("SAVEPOINT row_sp")
|
|
cur.execute(sql, values)
|
|
inserted += cur.rowcount
|
|
cur.execute("RELEASE SAVEPOINT row_sp")
|
|
except Exception as e:
|
|
logger.warning("Insert failed for %s row: %s", tableName, e)
|
|
with conn.cursor() as cur:
|
|
cur.execute("ROLLBACK TO SAVEPOINT row_sp")
|
|
|
|
return inserted
|
|
|
|
|
|
def _pgSafe(v: Any) -> Any:
|
|
"""Convert Python values to psycopg2-compatible types."""
|
|
import json as _json
|
|
|
|
if v is None or isinstance(v, (str, int, float, bool)):
|
|
return v
|
|
if isinstance(v, (dict, list)):
|
|
return _json.dumps(v)
|
|
return str(v)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Prepare import (validate + remap, return context for per-DB import)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _prepareImport(payload: dict) -> dict:
    """Validate a payload and remap its system-object IDs for a per-DB import.

    When validation succeeds, system-object IDs in *payload* are rewritten in
    place to the live IDs. The remapped payload is not returned: the caller
    must keep using this same *payload* object when calling
    ``_importSingleDb`` for each database.

    Returns ``{valid, warnings, systemObjectsFound, databases, protectedIds}``:

    * ``databases`` lists ``{database, tableCount, recordCount}`` for each
      registered, well-formed database section.
    * ``protectedIds`` holds the live system-object IDs.

    On validation failure ``valid`` is False, both lists are empty, and the
    payload is left untouched.
    """
    validation = _validateImportPayload(payload)
    if not validation.get("valid"):
        return {
            "valid": False,
            "warnings": validation.get("warnings", []),
            "systemObjectsFound": validation.get("systemObjectsFound", []),
            "databases": [],
            "protectedIds": [],
        }

    liveIds = _loadLiveSystemObjectIds()
    remap = _buildIdRemapFromPayload(payload, liveIds)
    if remap:
        logger.info("System-object ID remap: %s", remap)
        _remapSystemObjectIds(payload, remap)

    protectedIdSet = set(liveIds.values())

    registeredDbs = getRegisteredDatabases()
    dbList = []
    for dbName, dbData in payload.get("databases", {}).items():
        # Malformed sections of the uploaded file are skipped, not fatal.
        if dbName not in registeredDbs or not isinstance(dbData, dict):
            continue
        tables = dbData.get("tables", {})
        if not isinstance(tables, dict):
            continue
        recordCount = sum(len(rows) for rows in tables.values() if isinstance(rows, list))
        dbList.append({
            "database": dbName,
            "tableCount": len(tables),
            "recordCount": recordCount,
        })

    return {
        "valid": True,
        "warnings": validation.get("warnings", []),
        "systemObjectsFound": validation.get("systemObjectsFound", []),
        "databases": dbList,
        "protectedIds": list(protectedIdSet),
    }
|
|
|
|
|
|
def _ensureDatabaseExists(dbName: str) -> bool:
    """Create the PostgreSQL database *dbName* if it does not exist yet.

    Connection settings come from ``APP_CONFIG`` using the database's
    registered config prefix: ``<PREFIX>_HOST`` (default ``localhost``),
    ``<PREFIX>_PORT`` (default 5432), ``<PREFIX>_USER`` and
    ``<PREFIX>_PASSWORD_SECRET``. The existence check and the creation run
    against the ``postgres`` maintenance database in autocommit mode, because
    PostgreSQL refuses CREATE DATABASE inside a transaction block.

    Returns True if the database was created, and False if it already existed
    or *dbName* is not registered.
    """
    configPrefix = getRegisteredDatabases().get(dbName)
    if configPrefix is None:
        return False

    hostKey, portKey, userKey, passwordKey = (
        f"{configPrefix}_{suffix}" for suffix in ("HOST", "PORT", "USER", "PASSWORD_SECRET")
    )

    adminConn = psycopg2.connect(
        host=APP_CONFIG.get(hostKey, "localhost"),
        port=int(APP_CONFIG.get(portKey, 5432)),
        database="postgres",
        user=APP_CONFIG.get(userKey),
        password=APP_CONFIG.get(passwordKey),
        client_encoding="utf8",
    )
    try:
        adminConn.autocommit = True
        with adminConn.cursor() as cur:
            cur.execute("SELECT 1 FROM pg_database WHERE datname = %s", (dbName,))
            alreadyPresent = cur.fetchone() is not None
            if not alreadyPresent:
                cur.execute(f'CREATE DATABASE "{dbName}"')
                logger.info("Created missing database: %s", dbName)
        return not alreadyPresent
    finally:
        adminConn.close()
|
|
|
|
|
|
def _createTableFromExport(conn, tableName: str, rows: List[dict]) -> None:
|
|
"""Create a table based on the column structure found in the export data.
|
|
|
|
Uses TEXT for all columns since we don't have the original DDL.
|
|
The ``id`` column gets a PRIMARY KEY constraint.
|
|
"""
|
|
allKeys: List[str] = []
|
|
seen: set = set()
|
|
for row in rows:
|
|
for k in row.keys():
|
|
if k not in seen:
|
|
allKeys.append(k)
|
|
seen.add(k)
|
|
|
|
if not allKeys:
|
|
return
|
|
|
|
colDefs = []
|
|
for col in allKeys:
|
|
if col == "id":
|
|
colDefs.append(f'"{col}" TEXT PRIMARY KEY')
|
|
else:
|
|
colDefs.append(f'"{col}" TEXT')
|
|
|
|
ddl = f'CREATE TABLE IF NOT EXISTS "{tableName}" ({", ".join(colDefs)})'
|
|
with conn.cursor() as cur:
|
|
cur.execute(ddl)
|
|
logger.info("Created table %s with %d columns", tableName, len(allKeys))
|
|
|
|
|
|
def _getTableImportOrder(conn, tableNames: List[str], dbName: str = "") -> List[str]:
    """Order *tableNames* so that FK parent tables come before their children.

    Dependencies come from ``getFkRelationships()``, i.e. the Pydantic
    ``fk_target`` metadata. This works for every database, including those
    without SQL-level FKs. Only relationships whose source and target both
    live in *dbName* and both appear in *tableNames* are used, and
    self-references are ignored. Cross-DB FKs are expected to be satisfied by
    importing databases in a suitable order.

    Whenever several tables are ready at once, the alphabetically smallest is
    taken first, so the order is deterministic. Tables left over because of a
    dependency cycle are appended alphabetically, with a warning. ``conn`` is
    accepted for interface compatibility but is not used.
    """
    wanted = set(tableNames)
    parentsOf: Dict[str, Set[str]] = {name: set() for name in tableNames}
    for rel in getFkRelationships():
        if rel.sourceDb == dbName and rel.targetDb == dbName:
            child, parent = rel.sourceTable, rel.targetTable
            if child != parent and child in wanted and parent in wanted:
                parentsOf[child].add(parent)

    childrenOf: Dict[str, Set[str]] = {name: set() for name in tableNames}
    for child, parents in parentsOf.items():
        for parent in parents:
            childrenOf[parent].add(child)

    unresolved = {name: len(parents) for name, parents in parentsOf.items()}
    ready = {name for name, count in unresolved.items() if count == 0}
    ordered: List[str] = []

    while ready:
        current = min(ready)
        ready.discard(current)
        ordered.append(current)
        for child in childrenOf[current]:
            unresolved[child] -= 1
            if unresolved[child] == 0:
                ready.add(child)

    placed = set(ordered)
    leftovers = [name for name in tableNames if name not in placed]
    if leftovers:
        logger.warning("FK cycle detected, appending without order guarantee: %s", leftovers)
        ordered.extend(sorted(leftovers))

    return ordered
|
|
|
|
|
|
def _importSingleDb(payload: dict, dbName: str, mode: str, protectedIds: List[str]) -> dict:
    """Import a single database from the (already remapped) payload.

    ``mode`` is ``"replace"`` (delete non-protected rows, then insert) or
    ``"merge"`` (inserts that conflict on ``id`` are skipped).

    Steps:
      1. Create the PostgreSQL database if missing (``_ensureDatabaseExists``).
      2. Create missing tables from the export rows (all-TEXT schema) in
         autocommit mode. Tables whose row list is empty are not created,
         and are then silently left out of the import.
      3. Order the tables by FK dependencies (``_getTableImportOrder``).
      4. Replace mode only: delete child tables first (reverse order),
         committing after each table.
      5. Insert parent tables first, committing after each table; protected
         rows (pattern match or live system-object id) are filtered out.

    Because every table is committed separately, the import is not atomic: a
    failing table is rolled back and reported in ``warnings`` while earlier
    tables stay committed. Tables listed in ``_EXCLUDED_TABLES`` are skipped
    with a warning.

    Returns ``{database, tables: {tableName: insertedCount}, recordCount, warnings}``.
    An unexpected error outside the per-table handling returns only that
    error message, dropping earlier warnings and results.

    Raises:
        ValueError: if *mode* is not ``"replace"`` or ``"merge"``.
    """
    if mode not in ("replace", "merge"):
        raise ValueError(f"Invalid import mode: {mode}")

    registeredDbs = getRegisteredDatabases()
    if dbName not in registeredDbs:
        return {"database": dbName, "tables": {}, "recordCount": 0,
                "warnings": [f"Datenbank '{dbName}' nicht registriert"]}

    dbData = payload.get("databases", {}).get(dbName)
    if not dbData:
        return {"database": dbName, "tables": {}, "recordCount": 0,
                "warnings": [f"Keine Daten fuer '{dbName}' im Payload"]}

    try:
        dbCreated = _ensureDatabaseExists(dbName)
    except Exception as e:
        logger.error("Failed to ensure database %s exists: %s", dbName, e)
        return {"database": dbName, "tables": {}, "recordCount": 0,
                "warnings": [f"Datenbank '{dbName}' konnte nicht erstellt werden: {e}"]}

    protectedIdSet = set(protectedIds)
    tables = dbData.get("tables", {})
    warnings: List[str] = []
    dbResult: Dict[str, int] = {}
    excluded = _EXCLUDED_TABLES.get(dbName, set())

    if dbCreated:
        warnings.append(f"Datenbank '{dbName}' wurde neu erstellt")

    conn = _getConnection(dbName)
    try:
        existingTables = set(_listTables(conn))
        # End the transaction implicitly opened by the SELECT above; psycopg2
        # does not allow switching autocommit while a transaction is open.
        conn.rollback()

        # Ensure all import tables exist (create missing ones from export schema).
        # Autocommit makes each CREATE TABLE take effect immediately.
        conn.autocommit = True
        for tableName, rows in tables.items():
            if tableName in excluded or not isinstance(rows, list) or not rows:
                continue
            if tableName not in existingTables:
                _createTableFromExport(conn, tableName, rows)
                existingTables.add(tableName)
                logger.info("Pre-created missing table %s.%s", dbName, tableName)

        # Build importable table list and sort by FK dependencies
        importable = [t for t in tables
                      if t not in excluded
                      and isinstance(tables.get(t), list)
                      and t in existingTables]
        importOrder = _getTableImportOrder(conn, importable, dbName)

        logger.info("Import order for %s: %s", dbName, importOrder)

        # Report excluded tables that were present in the payload.
        for tableName in tables:
            if tableName in excluded and isinstance(tables.get(tableName), list):
                warnings.append(f"Table '{dbName}.{tableName}' excluded (security/transient)")

        # Phase 1 (replace only): DELETE children first (reverse topological order)
        if mode == "replace":
            conn.autocommit = False
            for tableName in reversed(importOrder):
                try:
                    _deleteNonProtected(conn, tableName, protectedIdSet)
                    conn.commit()
                except Exception as e:
                    conn.rollback()
                    warnings.append(f"DELETE from {dbName}.{tableName} failed: {e}")
                    logger.warning("DELETE from %s.%s failed: %s", dbName, tableName, e)

        # Phase 2: INSERT parents first (topological order); one commit per table.
        conn.autocommit = False
        for tableName in importOrder:
            try:
                rows = tables[tableName]
                physicalCols = _getPhysicalColumns(conn, tableName)
                if not physicalCols:
                    # End the transaction opened by the column lookup.
                    conn.rollback()
                    continue

                # Never write protected system objects, whether matched by
                # pattern or by their live id.
                filteredRows = []
                for row in rows:
                    if _isProtectedRow(tableName, row):
                        continue
                    if row.get("id") and str(row["id"]) in protectedIdSet:
                        continue
                    filteredRows.append(row)

                insertedCount = _insertRows(conn, tableName, filteredRows, physicalCols, mode)
                conn.commit()
                dbResult[tableName] = insertedCount
            except Exception as e:
                conn.rollback()
                warnings.append(f"INSERT into {dbName}.{tableName} failed: {e}")
                logger.warning("INSERT into %s.%s failed: %s", dbName, tableName, e)
    except Exception as e:
        logger.error("Import failed for database %s: %s", dbName, e)
        return {"database": dbName, "tables": {}, "recordCount": 0,
                "warnings": [f"Import fuer '{dbName}' fehlgeschlagen: {e}"]}
    finally:
        conn.close()

    recordCount = sum(dbResult.values())
    return {"database": dbName, "tables": dbResult, "recordCount": recordCount, "warnings": warnings}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Streaming Import (memory-safe, ijson-based)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _iterStreamRows(filePath: str):
    """Yield ``(dbName, tableName, rowDict)`` one row at a time from an export
    JSON file using the ijson streaming parser, so only the row currently
    being assembled is held in memory.

    The file is expected to have the export layout
    ``{"databases": {db: {"tables": {tbl: [row, ...]}}}}``. Only objects that
    are direct items of a ``tables.<tbl>`` array are yielded; other sections
    (``meta``, ``summary``, ...) are skipped. ``use_float=True`` makes ijson
    return non-integer numbers as ``float`` rather than ``Decimal``.
    """
    import ijson

    with open(filePath, "rb") as f:
        currentDb: Optional[str] = None
        currentTable: Optional[str] = None
        # ijson prefix of the row objects of the current table.
        rowPrefix: Optional[str] = None
        inRow = False
        # Stack of (container, pending map key) while a row is being rebuilt.
        stack: List[Tuple[Any, Optional[str]]] = []
        root: dict = {}

        for prefix, event, value in ijson.parse(f, use_float=True):
            if not inRow:
                # Outside a row: only track our position in the document.
                if prefix == "databases" and event == "map_key":
                    # A new database key inside the top-level "databases" object.
                    currentDb = value
                    currentTable = None
                    rowPrefix = None
                elif (currentDb
                      and prefix == f"databases.{currentDb}.tables"
                      and event == "map_key"):
                    # A new table key inside the current database's "tables".
                    currentTable = value
                    rowPrefix = f"databases.{currentDb}.tables.{currentTable}.item"
                elif rowPrefix and prefix == rowPrefix and event == "start_map":
                    # Opening brace of a row object: start rebuilding it.
                    inRow = True
                    root = {}
                    stack = [(root, None)]
                continue

            # Inside a row: rebuild nested dicts/lists from the event stream.
            container, pendingKey = stack[-1]

            if event == "map_key":
                stack[-1] = (container, value)

            elif event in ("string", "number", "boolean", "null"):
                if isinstance(container, dict):
                    container[pendingKey] = value
                    stack[-1] = (container, None)
                else:
                    container.append(value)

            elif event == "start_map":
                nested: dict = {}
                if isinstance(container, dict):
                    container[pendingKey] = nested
                    stack[-1] = (container, None)
                else:
                    container.append(nested)
                stack.append((nested, None))

            elif event == "end_map":
                stack.pop()
                if not stack:
                    # The row object itself was closed: hand it out.
                    yield (currentDb, currentTable, root)
                    inRow = False

            elif event == "start_array":
                nested_list: list = []
                if isinstance(container, dict):
                    container[pendingKey] = nested_list
                    stack[-1] = (container, None)
                else:
                    container.append(nested_list)
                stack.append((nested_list, None))

            elif event == "end_array":
                stack.pop()
|
|
|
|
|
|
def _streamValidate(filePath: str, progressCb=None) -> dict:
    """Stream-validate an import file without loading it into RAM.

    Extracts ``meta``, counts databases/tables/records, detects system objects
    (root mandate, admin user, event user) and builds the ID remap from backup
    IDs to the IDs of the live system objects. Rows are read through
    ``_iterStreamRows``, so only one row is held in memory at a time.

    Note: only tables that contain at least one row are counted, because
    ``_iterStreamRows`` yields rows, not table headers.

    Args:
        filePath: path of the export JSON file.
        progressCb: optional callable, called on every database/table
            transition with ``{"phase": "validate", "db": str, "table": str, "rows": int}``.

    Returns:
        The same shape as ``_prepareImport`` plus ``remap``: a dict with
        ``valid``, ``summary``, ``warnings``, ``systemObjectsFound``,
        ``databases``, ``protectedIds`` and ``remap``.
    """
    import ijson

    # Only the first ``meta`` value is needed; stop as soon as it is parsed.
    meta = None
    with open(filePath, "rb") as f:
        for m in ijson.items(f, "meta"):
            meta = m
            break

    warnings: List[str] = []
    if not meta:
        warnings.append("Fehlende oder ungueltige 'meta'-Sektion")

    registeredDbs = getRegisteredDatabases()
    dbTableCounts: Dict[str, Dict[str, int]] = {}
    systemObjectsFound: List[dict] = []
    # Usernames of system users in poweron_app.UserInDB -> display label.
    systemUserLabels = {"admin": "Admin User", "event": "Event User"}
    prevDb: Optional[str] = None
    prevTable: Optional[str] = None

    for dbName, tableName, row in _iterStreamRows(filePath):
        countsForDb = dbTableCounts.setdefault(dbName, {})
        countsForDb[tableName] = countsForDb.get(tableName, 0) + 1

        if progressCb and (dbName != prevDb or tableName != prevTable):
            # Report the table that just finished before switching to the new one.
            if prevDb and prevTable:
                progressCb({"phase": "validate", "db": prevDb, "table": prevTable,
                            "rows": dbTableCounts[prevDb][prevTable]})
            prevDb = dbName
            prevTable = tableName

        if dbName != "poweron_app":
            continue

        if tableName == "Mandate":
            if row.get("name") == "root" and row.get("isSystem") is True:
                systemObjectsFound.append({
                    "type": "mandate",
                    "label": "Root Mandate",
                    "payloadId": row.get("id"),
                })
        elif tableName == "UserInDB":
            uname = row.get("username")
            label = systemUserLabels.get(uname) if isinstance(uname, str) else None
            if label:
                systemObjectsFound.append({
                    "type": "user",
                    "label": label,
                    "payloadId": row.get("id"),
                })

    # Report the last table seen.
    if progressCb and prevDb and prevTable:
        progressCb({"phase": "validate", "db": prevDb, "table": prevTable,
                    "rows": dbTableCounts[prevDb][prevTable]})

    if not dbTableCounts:
        warnings.append("Fehlende oder ungueltige 'databases'-Sektion")

    summary: List[dict] = []
    for dbName, tables in dbTableCounts.items():
        registered = dbName in registeredDbs
        if not registered:
            warnings.append(f"Datenbank '{dbName}' ist nicht registriert und wird uebersprungen")
        summary.append({
            "database": dbName,
            "tableCount": len(tables),
            "recordCount": sum(tables.values()),
            "registered": registered,
        })

    liveIds = _loadLiveSystemObjectIds()
    liveKeyByLabel = {
        "Root Mandate": "rootMandate",
        "Admin User": "adminUser",
        "Event User": "eventUser",
    }

    remap: Dict[str, str] = {}
    for obj in systemObjectsFound:
        payloadId = obj.get("payloadId")
        liveKey = liveKeyByLabel.get(obj["label"])
        # A backup row without an id must be skipped; str(None) would otherwise
        # put the bogus key "None" into the remap.
        if payloadId is None or liveKey is None:
            continue
        oldId = str(payloadId)
        newId = liveIds.get(liveKey, "")
        if oldId and newId and oldId != newId:
            remap[oldId] = newId

    if remap:
        logger.info("System-object ID remap: %s", remap)

    # Empty/missing live IDs are treated as absent, mirroring the ``newId``
    # check above, so they never end up in the protected set.
    protectedIdSet = {v for v in liveIds.values() if v}
    valid = any(s["registered"] for s in summary)

    dbList: List[dict] = [
        {
            "database": s["database"],
            "tableCount": s["tableCount"],
            "recordCount": s["recordCount"],
        }
        for s in summary
        if s["registered"]
    ]

    return {
        "valid": valid,
        "summary": summary,
        "warnings": warnings,
        "systemObjectsFound": systemObjectsFound,
        "databases": dbList,
        "protectedIds": list(protectedIdSet),
        "remap": remap,
    }
|
|
|
|
|
|
def _streamSplitToFiles(
    filePath: str,
    tmpDir: str,
    token: str,
    remap: Dict[str, str],
    progressCb=None,
) -> Dict[str, Dict[str, str]]:
    """Stream the export file a second time and split it into per-table JSONL files.

    Each row is passed through ``_remapRowValues`` when ``remap`` is non-empty.
    It is then written as one JSON line to
    ``<tmpDir>/poweron_import_<token>_<db>__<table>.jsonl``.

    Database and table names come from the uploaded file. Before they become
    part of a filename, every character outside ``[A-Za-z0-9_.-]`` is
    replaced by ``_``, which removes path separators. A numeric suffix keeps
    sanitized names unique.

    Only the file of the table currently being streamed is kept open, so the
    number of open file handles does not grow with the number of tables. If
    a table's rows reappear later in the stream, its file is reopened in
    append mode.

    Args:
        filePath: path of the export JSON file.
        tmpDir: directory that receives the JSONL files.
        token: import token embedded in every filename.
        remap: ``{oldId: newId}`` applied to row values; may be empty.
        progressCb: optional callable, called on every database/table
            transition with ``{"phase": "split", "db": str, "table": str, "rows": int}``.

    Returns:
        ``{dbName: {tableName: filePath}}``.
    """
    import os
    import re

    unsafeChars = re.compile(r"[^A-Za-z0-9_.-]")
    remapSet = set(remap.keys()) if remap else set()
    dbFiles: Dict[str, Dict[str, str]] = {}
    tableCounts: Dict[Tuple[str, str], int] = {}
    usedStems: Set[str] = set()
    currentKey: Optional[Tuple[str, str]] = None
    currentFh = None

    def _tablePath(dbName: str, tableName: str) -> str:
        # Build a collision-free, separator-free filename for one table.
        safeDb = unsafeChars.sub("_", str(dbName))
        safeTable = unsafeChars.sub("_", str(tableName))
        stem = f"poweron_import_{token}_{safeDb}__{safeTable}"
        candidate = stem
        counter = 1
        while candidate in usedStems:
            candidate = f"{stem}_{counter}"
            counter += 1
        usedStems.add(candidate)
        return os.path.join(tmpDir, f"{candidate}.jsonl")

    def _reportProgress(key: Optional[Tuple[str, str]]) -> None:
        # Report the row count of a finished (db, table) run, if any.
        if progressCb and key is not None and key[0] and key[1]:
            progressCb({"phase": "split", "db": key[0], "table": key[1],
                        "rows": tableCounts.get(key, 0)})

    try:
        for dbName, tableName, row in _iterStreamRows(filePath):
            if remap:
                _remapRowValues(row, remap, remapSet)

            key = (dbName, tableName)
            if key != currentKey:
                if currentFh is not None:
                    currentFh.close()
                    currentFh = None
                _reportProgress(currentKey)
                currentKey = key

                tblPath = dbFiles.get(dbName, {}).get(tableName)
                if tblPath is None:
                    tblPath = _tablePath(dbName, tableName)
                    dbFiles.setdefault(dbName, {})[tableName] = tblPath
                    tableCounts[key] = 0
                    openMode = "w"
                else:
                    # Same table seen again later in the stream: keep earlier rows.
                    openMode = "a"
                currentFh = open(tblPath, openMode, encoding="utf-8")

            currentFh.write(json.dumps(row, ensure_ascii=False, default=str))
            currentFh.write("\n")
            tableCounts[key] += 1
    finally:
        if currentFh is not None:
            currentFh.close()

    _reportProgress(currentKey)
    return dbFiles
|
|
|
|
|
|
def _importSingleDbFromFiles(
    tableFiles: Dict[str, str],
    dbName: str,
    mode: str,
    protectedIds: List[str],
) -> dict:
    """Import a single database from per-table JSONL files.

    Each file contains one JSON object (row) per line; blank lines are
    skipped. Tables missing in the target database are created from a sample
    of up to 10 rows. Tables are then ordered by ``_getTableImportOrder``
    (FK dependencies). In ``replace`` mode, non-protected rows are deleted
    first, children before parents. Rows are inserted in batches of 100,
    with one commit per table.

    Failures on a single table (CREATE, DELETE or INSERT) are reported as
    warnings and do not stop the other tables. If an unexpected error aborts
    the import, the counts of tables already committed and all warnings
    collected so far are still returned.

    Args:
        tableFiles: ``{tableName: jsonlPath}`` for this database.
        dbName: target database name; must be registered.
        mode: ``"replace"`` or ``"merge"``.
        protectedIds: IDs of live system objects; rows with these IDs are
            never deleted or inserted.

    Returns:
        ``{database, tables, recordCount, warnings}``, where ``tables`` maps
        table name to the number of rows reported by ``_insertRows``.

    Raises:
        ValueError: if ``mode`` is not ``"replace"`` or ``"merge"``.
    """
    if mode not in ("replace", "merge"):
        raise ValueError(f"Invalid import mode: {mode}")

    registeredDbs = getRegisteredDatabases()
    if dbName not in registeredDbs:
        return {"database": dbName, "tables": {}, "recordCount": 0,
                "warnings": [f"Datenbank '{dbName}' nicht registriert"]}

    if not tableFiles:
        return {"database": dbName, "tables": {}, "recordCount": 0,
                "warnings": [f"Keine Daten fuer '{dbName}'"]}

    try:
        dbCreated = _ensureDatabaseExists(dbName)
    except Exception as e:
        logger.error("Failed to ensure database %s exists: %s", dbName, e)
        return {"database": dbName, "tables": {}, "recordCount": 0,
                "warnings": [f"Datenbank '{dbName}' konnte nicht erstellt werden: {e}"]}

    protectedIdSet = set(protectedIds)
    warnings: List[str] = []
    dbResult: Dict[str, int] = {}
    excluded = _EXCLUDED_TABLES.get(dbName, set())

    if dbCreated:
        warnings.append(f"Datenbank '{dbName}' wurde neu erstellt")

    conn = _getConnection(dbName)
    try:
        existingTables = set(_listTables(conn))
        # End the transaction opened by the catalog query before switching
        # the connection to autocommit.
        conn.rollback()

        # Phase 0: create tables present in the backup but missing in the DB.
        conn.autocommit = True
        for tableName, tblPath in tableFiles.items():
            if tableName in excluded or tableName in existingTables:
                continue
            try:
                sampleRows: List[dict] = []
                with open(tblPath, "r", encoding="utf-8") as f:
                    for line in f:
                        line = line.strip()
                        if line:
                            sampleRows.append(json.loads(line))
                            if len(sampleRows) >= 10:
                                break
                if sampleRows:
                    _createTableFromExport(conn, tableName, sampleRows)
                    existingTables.add(tableName)
                    logger.info("Pre-created missing table %s.%s", dbName, tableName)
            except Exception as e:
                # The table stays out of ``existingTables`` and is therefore
                # skipped below; the remaining tables are still imported.
                warnings.append(f"CREATE TABLE {dbName}.{tableName} failed: {e}")
                logger.warning("CREATE TABLE %s.%s failed: %s", dbName, tableName, e)

        importable = [t for t in tableFiles if t not in excluded and t in existingTables]
        importOrder = _getTableImportOrder(conn, importable, dbName)

        logger.info("Import order for %s: %s", dbName, importOrder)

        for tableName in tableFiles:
            if tableName in excluded:
                warnings.append(f"Table '{dbName}.{tableName}' excluded (security/transient)")

        # Phase 1 (replace only): DELETE children first (reverse topological order).
        if mode == "replace":
            conn.autocommit = False
            for tableName in reversed(importOrder):
                try:
                    _deleteNonProtected(conn, tableName, protectedIdSet)
                    conn.commit()
                except Exception as e:
                    conn.rollback()
                    warnings.append(f"DELETE from {dbName}.{tableName} failed: {e}")
                    logger.warning("DELETE from %s.%s failed: %s", dbName, tableName, e)

        # Phase 2: INSERT parents first (topological order), one commit per table.
        conn.autocommit = False
        batchSize = 100
        for tableName in importOrder:
            tblPath = tableFiles.get(tableName)
            if not tblPath:
                continue
            try:
                physicalCols = _getPhysicalColumns(conn, tableName)
                if not physicalCols:
                    conn.rollback()
                    continue

                insertedCount = 0
                batch: List[dict] = []
                with open(tblPath, "r", encoding="utf-8") as f:
                    for line in f:
                        line = line.strip()
                        if not line:
                            continue
                        row = json.loads(line)
                        # Never touch live system objects.
                        if _isProtectedRow(tableName, row):
                            continue
                        if row.get("id") and str(row["id"]) in protectedIdSet:
                            continue
                        batch.append(row)
                        if len(batch) >= batchSize:
                            insertedCount += _insertRows(conn, tableName, batch, physicalCols, mode)
                            batch = []
                if batch:
                    insertedCount += _insertRows(conn, tableName, batch, physicalCols, mode)

                conn.commit()
                dbResult[tableName] = insertedCount
            except Exception as e:
                conn.rollback()
                warnings.append(f"INSERT into {dbName}.{tableName} failed: {e}")
                logger.warning("INSERT into %s.%s failed: %s", dbName, tableName, e)
    except Exception as e:
        logger.error("Import failed for database %s: %s", dbName, e)
        # Tables committed before the failure stay committed, so report them
        # together with every warning collected so far.
        warnings.append(f"Import fuer '{dbName}' fehlgeschlagen: {e}")
        return {"database": dbName, "tables": dbResult,
                "recordCount": sum(dbResult.values()), "warnings": warnings}
    finally:
        conn.close()

    recordCount = sum(dbResult.values())
    return {"database": dbName, "tables": dbResult, "recordCount": recordCount, "warnings": warnings}
|