# Copyright (c) 2025 Patrick Motsch # All rights reserved. """ Database migration utilities — backup (export) and restore (import) for all registered PowerOn databases. System objects (root mandate, admin user, event user) are protected: they are never deleted or overwritten during import. Their IDs in the backup payload are remapped to the IDs of the corresponding live objects so that all FK references stay consistent. All functions are intended for SysAdmin use only (access control in the route layer). """ import json import logging from datetime import datetime, timezone from typing import Any, Dict, List, Optional, Set, Tuple import psycopg2 import psycopg2.extras from modules.shared.configuration import APP_CONFIG from modules.shared.dbRegistry import getRegisteredDatabases from modules.shared.fkRegistry import getFkRelationships from modules.datamodels.datamodelBase import MODEL_REGISTRY from modules.system.databaseHealth import _getConnection, _jsonSafe logger = logging.getLogger(__name__) _EXPORT_FORMAT_VERSION = "1.0" _SYSTEM_TABLE = "_system" _EXCLUDED_TABLES: Dict[str, Set[str]] = { "poweron_app": {"Token", "AuthEvent"}, } # --------------------------------------------------------------------------- # Instance label # --------------------------------------------------------------------------- def _getInstanceLabel() -> str: """Return the instance type from APP_ENV_TYPE (e.g. 
# ---------------------------------------------------------------------------
# Database list
# ---------------------------------------------------------------------------

def _getAvailableDatabases() -> List[dict]:
    """List the registered databases with table and row counts for the UI.

    The ``poweron_test`` database is never listed. Counts are read from
    ``pg_stat_user_tables`` (``n_live_tup``) for tables in the ``public``
    schema whose name does not start with an underscore.

    A database that cannot be queried is still listed; it keeps whatever
    counts were accumulated before the failure (zero if connecting failed).

    Returns:
        One ``{"name", "tableCount", "recordCount"}`` dict per database,
        ordered by database name.
    """
    statSql = """
        SELECT relname, n_live_tup
        FROM pg_stat_user_tables
        WHERE schemaname = 'public'
          AND relname NOT LIKE '\\_%%'
    """
    overview: List[dict] = []
    for dbName in sorted(getRegisteredDatabases()):
        if dbName == "poweron_test":
            continue
        info: dict = {"name": dbName, "tableCount": 0, "recordCount": 0}
        overview.append(info)
        try:
            conn = _getConnection(dbName)
            try:
                with conn.cursor() as cur:
                    cur.execute(statSql)
                    for stat in cur.fetchall():
                        info["tableCount"] += 1
                        info["recordCount"] += int(stat["n_live_tup"])
            finally:
                conn.close()
        except Exception as e:
            logger.warning("Could not stat database %s: %s", dbName, e)
    return overview


# ---------------------------------------------------------------------------
# Export
# ---------------------------------------------------------------------------

def _exportDatabases(databases: List[str]) -> dict:
    """Export the selected databases into one JSON-serialisable dict.

    Unregistered database names are skipped with a warning. A database whose
    export raises is logged and left out of the result; the remaining
    databases are still exported.

    Args:
        databases: Names of the databases to export.

    Returns:
        ``{meta: {...}, databases: {dbName: {tables: {tbl: [rows]}, summary: {...}}}}``

    Raises:
        ValueError: If ``databases`` is empty.
    """
    registry = getRegisteredDatabases()
    if not databases:
        raise ValueError("No databases selected for export.")

    meta: dict = {
        "exportedAt": datetime.now(timezone.utc).isoformat(),
        "version": _EXPORT_FORMAT_VERSION,
        "databaseCount": 0,
        "totalTables": 0,
        "totalRecords": 0,
    }
    exported: dict = {}
    bundle: dict = {"meta": meta, "databases": exported}

    for dbName in databases:
        if dbName not in registry:
            logger.warning("Export: skipping unregistered database %s", dbName)
            continue
        try:
            dump = _exportSingleDb(dbName)
            exported[dbName] = dump
            meta["databaseCount"] += 1
            meta["totalTables"] += dump["tableCount"]
            meta["totalRecords"] += dump["totalRecords"]
        except Exception as e:
            logger.error("Export failed for database %s: %s", dbName, e)
    return bundle
""" return sorted( t for t in physicalTables if t in MODEL_REGISTRY ) def _exportSingleDb(dbName: str) -> dict: conn = _getConnection(dbName) excluded = _EXCLUDED_TABLES.get(dbName, set()) try: allTables = _listTables(conn) modelTables = _getModelTablesForDb(dbName, allTables) skippedLegacy = set(allTables) - set(modelTables) - excluded - {_SYSTEM_TABLE} if skippedLegacy: logger.info("Export %s: skipping %d legacy tables without model: %s", dbName, len(skippedLegacy), sorted(skippedLegacy)) dbPayload: dict = {"tables": {}, "summary": {}, "tableCount": 0, "totalRecords": 0} for tbl in modelTables: if tbl in excluded: logger.info("Export: skipping excluded table %s.%s", dbName, tbl) continue rows = _readTableRows(conn, tbl) dbPayload["tables"][tbl] = rows dbPayload["summary"][tbl] = {"recordCount": len(rows)} dbPayload["tableCount"] += 1 dbPayload["totalRecords"] += len(rows) return dbPayload finally: conn.close() def _listTables(conn) -> List[str]: with conn.cursor() as cur: cur.execute(""" SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE' AND table_name != %s ORDER BY table_name """, (_SYSTEM_TABLE,)) return [row["table_name"] for row in cur.fetchall()] def _readTableRows(conn, tableName: str) -> List[dict]: with conn.cursor() as cur: cur.execute(f'SELECT * FROM "{tableName}"') return [{k: _jsonSafe(v) for k, v in dict(row).items()} for row in cur.fetchall()] # --------------------------------------------------------------------------- # Streaming Export (memory-safe, handles arbitrarily large databases) # --------------------------------------------------------------------------- def streamExportGenerator(databases: List[str], instanceLabel: str = ""): """Yield JSON fragments for a streaming database export. Writes valid JSON incrementally (row-by-row, table-by-table) so that neither the backend RAM nor the browser JS heap is saturated — works for databases of any size. 
The output format is identical to the non-streaming _exportDatabases(): {"meta": {...}, "databases": {"dbName": {"tables": {"tbl": [rows]}, ...}}} """ import json registeredDbs = getRegisteredDatabases() validDbs = [db for db in databases if db in registeredDbs] totalDbs = 0 totalTables = 0 totalRecords = 0 yield '{"meta":' metaPlaceholder = json.dumps({ "exportedAt": datetime.now(timezone.utc).isoformat(), "version": _EXPORT_FORMAT_VERSION, "instanceLabel": instanceLabel, "databaseCount": "<>", }, ensure_ascii=False) yield metaPlaceholder yield ',"databases":{' firstDb = True for dbName in validDbs: excluded = _EXCLUDED_TABLES.get(dbName, set()) conn = None try: conn = _getConnection(dbName) allTables = _listTables(conn) modelTables = _getModelTablesForDb(dbName, allTables) if not firstDb: yield ',' firstDb = False yield json.dumps(dbName, ensure_ascii=False) yield ':{"tables":{' firstTable = True dbTableCount = 0 dbRecordCount = 0 for tbl in modelTables: if tbl in excluded: continue if not firstTable: yield ',' firstTable = False yield json.dumps(tbl, ensure_ascii=False) yield ':[' with conn.cursor(name=f"export_{dbName}_{tbl}") as cur: cur.itersize = 2000 cur.execute(f'SELECT * FROM "{tbl}"') firstRow = True rowCount = 0 for row in cur: if not firstRow: yield ',' firstRow = False safeRow = {k: _jsonSafe(v) for k, v in dict(row).items()} yield json.dumps(safeRow, ensure_ascii=False, default=str) rowCount += 1 yield ']' dbTableCount += 1 dbRecordCount += rowCount yield '},"summary":{' firstSummaryTable = True for tbl in modelTables: if tbl in excluded: continue if not firstSummaryTable: yield ',' firstSummaryTable = False yield json.dumps(tbl, ensure_ascii=False) yield ':{"recordCount":0}' yield '}' yield f',"tableCount":{dbTableCount},"totalRecords":{dbRecordCount}' yield '}' totalDbs += 1 totalTables += dbTableCount totalRecords += dbRecordCount logger.info("Stream export: %s done (%d tables, %d records)", dbName, dbTableCount, dbRecordCount) except Exception as 
e: logger.error("Stream export failed for %s: %s", dbName, e) if not firstDb or not firstTable: pass finally: if conn: try: conn.close() except Exception: pass yield '}}' # --------------------------------------------------------------------------- # Validate # --------------------------------------------------------------------------- def _validateImportPayload(payload: dict) -> dict: """Validate an import payload without writing anything. Returns ``{valid, summary, warnings, systemObjectsFound}``. """ warnings: List[str] = [] summary: List[dict] = [] meta = payload.get("meta") if not meta or not isinstance(meta, dict): return {"valid": False, "summary": [], "warnings": ["Fehlende oder ungueltige 'meta'-Sektion"], "systemObjectsFound": []} version = meta.get("version", "") if version != _EXPORT_FORMAT_VERSION: warnings.append(f"Unbekannte Format-Version: {version} (erwartet: {_EXPORT_FORMAT_VERSION})") databases = payload.get("databases") if not databases or not isinstance(databases, dict): return {"valid": False, "summary": [], "warnings": ["Fehlende oder ungueltige 'databases'-Sektion"], "systemObjectsFound": []} registeredDbs = getRegisteredDatabases() for dbName, dbData in databases.items(): tables = dbData.get("tables", {}) tableCount = len(tables) recordCount = sum(len(rows) for rows in tables.values() if isinstance(rows, list)) registered = dbName in registeredDbs if not registered: warnings.append(f"Datenbank '{dbName}' ist nicht registriert und wird uebersprungen") summary.append({ "database": dbName, "tableCount": tableCount, "recordCount": recordCount, "registered": registered, }) systemObjectsFound = _detectSystemObjectsInPayload(payload) valid = any(s["registered"] for s in summary) return { "valid": valid, "summary": summary, "warnings": warnings, "systemObjectsFound": systemObjectsFound, } def _detectSystemObjectsInPayload(payload: dict) -> List[dict]: """Find system objects (root mandate, admin user, event user) in a payload.""" found: List[dict] = 
# ---------------------------------------------------------------------------
# System-object ID remapping
# ---------------------------------------------------------------------------

def _loadLiveSystemObjectIds() -> Dict[str, str]:
    """Load the IDs of the protected system objects from the live ``poweron_app`` DB.

    Returns:
        A dict with the keys ``rootMandate``, ``adminUser`` and ``eventUser``
        mapped to the stringified IDs. A key is absent when the object does
        not exist; the dict is empty when ``poweron_app`` is not registered.
    """
    registeredDbs = getRegisteredDatabases()
    if "poweron_app" not in registeredDbs:
        return {}

    lookups = (
        ("rootMandate", """SELECT id FROM "Mandate" WHERE "name" = 'root' AND "isSystem" = true LIMIT 1"""),
        ("adminUser", """SELECT id FROM "UserInDB" WHERE "username" = 'admin' LIMIT 1"""),
        ("eventUser", """SELECT id FROM "UserInDB" WHERE "username" = 'event' LIMIT 1"""),
    )

    result: Dict[str, str] = {}
    conn = _getConnection("poweron_app")
    try:
        with conn.cursor() as cur:
            for key, sql in lookups:
                cur.execute(sql)
                row = cur.fetchone()
                if row:
                    result[key] = str(row["id"])
    finally:
        conn.close()
    return result


def _buildIdRemapFromPayload(payload: dict, liveIds: Dict[str, str]) -> Dict[str, str]:
    """Build an ``{oldId: newId}`` mapping for the system objects in a payload.

    The root mandate (``name == "root"`` and ``isSystem is True``) and the
    ``admin`` / ``event`` users are looked up in the ``poweron_app`` tables of
    the payload. An entry is only created when the payload ID and the live ID
    are both present and differ. Rows without an ID (missing or ``None``) and
    rows that are not dicts are ignored.

    Args:
        payload: The import payload.
        liveIds: Result of ``_loadLiveSystemObjectIds``.
    """
    remap: Dict[str, str] = {}
    appTables = payload.get("databases", {}).get("poweron_app", {}).get("tables", {})

    def _register(rawId: Any, liveKey: str) -> None:
        # A JSON null must not become the literal key "None".
        if rawId is None:
            return
        oldId = str(rawId)
        newId = liveIds.get(liveKey, "")
        if oldId and newId and oldId != newId:
            remap[oldId] = newId

    for row in appTables.get("Mandate", []):
        if not isinstance(row, dict):
            continue
        if row.get("name") == "root" and row.get("isSystem") is True:
            _register(row.get("id"), "rootMandate")

    userKeys = {"admin": "adminUser", "event": "eventUser"}
    for row in appTables.get("UserInDB", []):
        if not isinstance(row, dict):
            continue
        liveKey = userKeys.get(row.get("username"))
        if liveKey is not None:
            _register(row.get("id"), liveKey)

    return remap


def _remapSystemObjectIds(payload: dict, remap: Dict[str, str]) -> dict:
    """Replace, in place, every value in the payload that equals an old system-object ID.

    Returns the same ``payload`` object for convenience.
    """
    if not remap:
        return payload
    for dbData in payload.get("databases", {}).values():
        if isinstance(dbData, dict):
            _remapDbTables(dbData.get("tables", {}), remap)
    return payload


def _remapDbTables(tables: dict, remap: Dict[str, str]) -> None:
    """In-place remap system-object IDs in a single DB's ``tables`` dict.

    Table entries that are not lists are left untouched.
    """
    if not remap or not isinstance(tables, dict):
        return
    remapSet = set(remap.keys())
    for rows in tables.values():
        if not isinstance(rows, list):
            continue
        for row in rows:
            _remapRowValues(row, remap, remapSet)


def _remapRowValues(row: dict, remap: Dict[str, str], remapSet: Set[str]) -> None:
    """In-place replace string values in a row that match a remap key.

    Walks nested dicts and lists to any depth (including lists inside lists).
    Only values are replaced, never dict keys. Anything that is neither a
    dict nor a list is ignored.

    Args:
        row: The row dict (nested containers are passed in recursively).
        remap: ``{oldId: newId}``.
        remapSet: ``set(remap)``, passed in so it is built only once.
    """
    if isinstance(row, dict):
        entries = list(row.items())
    elif isinstance(row, list):
        entries = list(enumerate(row))
    else:
        return
    for key, val in entries:
        if isinstance(val, str):
            if val in remapSet:
                row[key] = remap[val]
        elif isinstance(val, (dict, list)):
            _remapRowValues(val, remap, remapSet)
# ---------------------------------------------------------------------------
# Import
# ---------------------------------------------------------------------------

# Field/value patterns identifying protected system objects, keyed by table.
# A row is protected when it matches all fields of any pattern of its table.
_PROTECTED_ROWS: Dict[str, List[dict]] = {
    "Mandate": [{"name": "root", "isSystem": True}],
    "UserInDB": [{"username": "admin"}, {"username": "event"}],
}


def _isProtectedRow(tableName: str, row: dict) -> bool:
    """Tell whether ``row`` of ``tableName`` is a protected system object.

    A row is protected when, for at least one pattern registered for the
    table in ``_PROTECTED_ROWS``, every field of the pattern equals the
    corresponding field of the row. Tables without patterns never match.
    """
    return any(
        all(row.get(field) == expected for field, expected in pattern.items())
        for pattern in _PROTECTED_ROWS.get(tableName, [])
    )
def _getPhysicalColumns(conn, tableName: str) -> List[str]:
    """Return the column names of ``public.<tableName>`` in ordinal order (empty if the table is unknown)."""
    with conn.cursor() as cur:
        cur.execute("""
            SELECT column_name
            FROM information_schema.columns
            WHERE table_schema = 'public' AND table_name = %s
            ORDER BY ordinal_position
        """, (tableName,))
        return [row["column_name"] for row in cur.fetchall()]


def _deleteNonProtected(conn, tableName: str, protectedIds: Set[str]) -> int:
    """Delete all rows of ``tableName`` except those whose ``id`` is protected.

    With an empty ``protectedIds`` the whole table is cleared. Otherwise the
    ``id`` column is compared as text against the protected IDs.
    ``tableName`` is interpolated into the statement, so callers must pass a
    name that is known to exist in the database.

    Returns:
        The number of deleted rows as reported by the cursor.
    """
    if not protectedIds:
        with conn.cursor() as cur:
            cur.execute(f'DELETE FROM "{tableName}"')
            return cur.rowcount

    idList = list(protectedIds)
    with conn.cursor() as cur:
        cur.execute(
            f'DELETE FROM "{tableName}" WHERE "id"::text != ALL(%(ids)s)',
            {"ids": idList},
        )
        return cur.rowcount


def _insertRows(
    conn,
    tableName: str,
    rows: List[dict],
    physicalCols: List[str],
    mode: str,
) -> int:
    """Insert rows into a table, one statement per row.

    Only keys that are physical columns of the table are written; rows with
    no matching column are skipped. In ``"merge"`` mode rows whose ``id``
    already exists are silently skipped (``ON CONFLICT ("id") DO NOTHING``).

    Every insert runs inside its own savepoint, so a failing row is rolled
    back and logged without aborting the surrounding transaction. The caller
    owns the transaction (commit / rollback).

    Returns:
        The number of rows actually inserted.
    """
    if not rows:
        return 0

    physicalColSet = set(physicalCols)
    inserted = 0
    for row in rows:
        cols = [c for c in row.keys() if c in physicalColSet]
        if not cols:
            continue

        values = [_pgSafe(row[c]) for c in cols]
        colNames = ", ".join(f'"{c}"' for c in cols)
        placeholders = ", ".join(["%s"] * len(cols))

        sql = f'INSERT INTO "{tableName}" ({colNames}) VALUES ({placeholders})'
        if mode == "merge":
            sql += ' ON CONFLICT ("id") DO NOTHING'

        try:
            with conn.cursor() as cur:
                cur.execute("SAVEPOINT row_sp")
                cur.execute(sql, values)
                inserted += cur.rowcount
                cur.execute("RELEASE SAVEPOINT row_sp")
        except Exception as e:
            logger.warning("Insert failed for %s row: %s", tableName, e)
            with conn.cursor() as cur:
                cur.execute("ROLLBACK TO SAVEPOINT row_sp")
                # ROLLBACK TO keeps the savepoint alive; release it so failed
                # rows do not pile up nested savepoints until commit.
                cur.execute("RELEASE SAVEPOINT row_sp")

    return inserted


def _pgSafe(v: Any) -> Any:
    """Convert a Python value into something psycopg2 can bind directly.

    ``None``, ``str``, ``int``, ``float`` and ``bool`` pass through unchanged,
    ``dict`` / ``list`` are serialised to a JSON string, anything else is
    converted with ``str()``.
    """
    if v is None or isinstance(v, (str, int, float, bool)):
        return v
    if isinstance(v, (dict, list)):
        return json.dumps(v)
    return str(v)
""" validation = _validateImportPayload(payload) if not validation.get("valid"): return { "valid": False, "warnings": validation.get("warnings", []), "systemObjectsFound": validation.get("systemObjectsFound", []), "databases": [], "protectedIds": [], } liveIds = _loadLiveSystemObjectIds() remap = _buildIdRemapFromPayload(payload, liveIds) if remap: logger.info("System-object ID remap: %s", remap) _remapSystemObjectIds(payload, remap) protectedIdSet = set(liveIds.values()) registeredDbs = getRegisteredDatabases() dbList = [] for dbName, dbData in payload.get("databases", {}).items(): if dbName not in registeredDbs: continue tables = dbData.get("tables", {}) recordCount = sum(len(rows) for rows in tables.values() if isinstance(rows, list)) dbList.append({ "database": dbName, "tableCount": len(tables), "recordCount": recordCount, }) return { "valid": True, "warnings": validation.get("warnings", []), "systemObjectsFound": validation.get("systemObjectsFound", []), "databases": dbList, "protectedIds": list(protectedIdSet), } def _ensureDatabaseExists(dbName: str) -> bool: """Create the PostgreSQL database if it does not yet exist. Connects to the ``postgres`` admin database using the same credentials as the target DB. Returns True if the database was created, False if it already existed. 
""" registeredDbs = getRegisteredDatabases() configPrefix = registeredDbs.get(dbName) if configPrefix is None: return False hostKey = f"{configPrefix}_HOST" if configPrefix != "DB" else "DB_HOST" portKey = f"{configPrefix}_PORT" if configPrefix != "DB" else "DB_PORT" userKey = f"{configPrefix}_USER" if configPrefix != "DB" else "DB_USER" passwordKey = f"{configPrefix}_PASSWORD_SECRET" if configPrefix != "DB" else "DB_PASSWORD_SECRET" adminConn = psycopg2.connect( host=APP_CONFIG.get(hostKey, "localhost"), port=int(APP_CONFIG.get(portKey, 5432)), database="postgres", user=APP_CONFIG.get(userKey), password=APP_CONFIG.get(passwordKey), client_encoding="utf8", ) try: adminConn.autocommit = True with adminConn.cursor() as cur: cur.execute("SELECT 1 FROM pg_database WHERE datname = %s", (dbName,)) if cur.fetchone(): return False cur.execute(f'CREATE DATABASE "{dbName}"') logger.info("Created missing database: %s", dbName) return True finally: adminConn.close() def _createTableFromExport(conn, tableName: str, rows: List[dict]) -> None: """Create a table based on the column structure found in the export data. Uses TEXT for all columns since we don't have the original DDL. The ``id`` column gets a PRIMARY KEY constraint. """ allKeys: List[str] = [] seen: set = set() for row in rows: for k in row.keys(): if k not in seen: allKeys.append(k) seen.add(k) if not allKeys: return colDefs = [] for col in allKeys: if col == "id": colDefs.append(f'"{col}" TEXT PRIMARY KEY') else: colDefs.append(f'"{col}" TEXT') ddl = f'CREATE TABLE IF NOT EXISTS "{tableName}" ({", ".join(colDefs)})' with conn.cursor() as cur: cur.execute(ddl) logger.info("Created table %s with %d columns", tableName, len(allKeys)) def _getTableImportOrder(conn, tableNames: List[str], dbName: str = "") -> List[str]: """Sort tables by FK dependencies (parents first) using topological sort. 
Uses Pydantic ``fk_target`` metadata from ``fkRegistry`` as the single source of truth (works for ALL databases, not just those with SQL FKs). Only *intra-DB* dependencies are considered; cross-DB FKs (e.g. to ``poweron_app.Mandate``) are handled by importing databases in order. """ tableSet = set(tableNames) allRels = getFkRelationships() deps: Dict[str, Set[str]] = {t: set() for t in tableNames} for rel in allRels: if rel.sourceDb != dbName or rel.targetDb != dbName: continue child = rel.sourceTable parent = rel.targetTable if child in tableSet and parent in tableSet and child != parent: deps[child].add(parent) inDegree = {t: len(deps[t]) for t in tableNames} queue = sorted(t for t in tableNames if inDegree[t] == 0) ordered: List[str] = [] while queue: node = queue.pop(0) ordered.append(node) for t in tableNames: if node in deps[t]: deps[t].discard(node) inDegree[t] -= 1 if inDegree[t] == 0: queue.append(t) queue.sort() remaining = [t for t in tableNames if t not in set(ordered)] if remaining: logger.warning("FK cycle detected, appending without order guarantee: %s", remaining) ordered.extend(sorted(remaining)) return ordered def _importSingleDb(payload: dict, dbName: str, mode: str, protectedIds: List[str]) -> dict: """Import a single database from the (already remapped) payload. Tables are sorted by FK dependencies: parent tables are inserted first, child tables are deleted first (reverse order) in replace mode. Returns ``{database, tables: {tableName: insertedCount}, recordCount, warnings}``. 
""" if mode not in ("replace", "merge"): raise ValueError(f"Invalid import mode: {mode}") registeredDbs = getRegisteredDatabases() if dbName not in registeredDbs: return {"database": dbName, "tables": {}, "recordCount": 0, "warnings": [f"Datenbank '{dbName}' nicht registriert"]} dbData = payload.get("databases", {}).get(dbName) if not dbData: return {"database": dbName, "tables": {}, "recordCount": 0, "warnings": [f"Keine Daten fuer '{dbName}' im Payload"]} try: dbCreated = _ensureDatabaseExists(dbName) except Exception as e: logger.error("Failed to ensure database %s exists: %s", dbName, e) return {"database": dbName, "tables": {}, "recordCount": 0, "warnings": [f"Datenbank '{dbName}' konnte nicht erstellt werden: {e}"]} protectedIdSet = set(protectedIds) tables = dbData.get("tables", {}) warnings: List[str] = [] dbResult: Dict[str, int] = {} excluded = _EXCLUDED_TABLES.get(dbName, set()) if dbCreated: warnings.append(f"Datenbank '{dbName}' wurde neu erstellt") conn = _getConnection(dbName) try: existingTables = set(_listTables(conn)) conn.rollback() # Ensure all import tables exist (create missing ones from export schema) conn.autocommit = True for tableName, rows in tables.items(): if tableName in excluded or not isinstance(rows, list) or not rows: continue if tableName not in existingTables: _createTableFromExport(conn, tableName, rows) existingTables.add(tableName) logger.info("Pre-created missing table %s.%s", dbName, tableName) # Build importable table list and sort by FK dependencies importable = [t for t in tables if t not in excluded and isinstance(tables.get(t), list) and t in existingTables] importOrder = _getTableImportOrder(conn, importable, dbName) logger.info("Import order for %s: %s", dbName, importOrder) for tableName in tables: if tableName in excluded and isinstance(tables.get(tableName), list): warnings.append(f"Table '{dbName}.{tableName}' excluded (security/transient)") # Phase 1 (replace only): DELETE children first (reverse topological 
# ---------------------------------------------------------------------------
# Streaming Import (memory-safe, ijson-based)
# ---------------------------------------------------------------------------

def _iterStreamRows(filePath: str):
    """Yield ``(dbName, tableName, rowDict)`` one row at a time from an export
    JSON file using the ijson streaming parser.

    Only objects located at ``databases.<db>.tables.<table>[*]`` are yielded;
    every other part of the document is skipped. Each row is rebuilt from the
    parser events (including nested objects and arrays), so only the row
    currently being assembled is held by this generator.

    Args:
        filePath: Path of the export JSON file; opened in binary mode.
    """
    import ijson

    with open(filePath, "rb") as f:
        currentDb: Optional[str] = None
        currentTable: Optional[str] = None
        # ijson prefix of the row objects of the current table.
        rowPrefix: Optional[str] = None
        inRow = False
        # One (container, pendingKey) entry per open dict/list of the current
        # row; pendingKey is the dict key waiting for its value.
        stack: List[Tuple[Any, Optional[str]]] = []
        root: dict = {}

        for prefix, event, value in ijson.parse(f, use_float=True):
            if not inRow:
                # Outside of a row: only track which database / table we are
                # in and detect the start of the next row object.
                if prefix == "databases" and event == "map_key":
                    currentDb = value
                    currentTable = None
                    rowPrefix = None
                elif (currentDb
                      and prefix == f"databases.{currentDb}.tables"
                      and event == "map_key"):
                    currentTable = value
                    rowPrefix = f"databases.{currentDb}.tables.{currentTable}.item"
                elif rowPrefix and prefix == rowPrefix and event == "start_map":
                    inRow = True
                    root = {}
                    stack = [(root, None)]
                continue

            # Inside a row: attach the event to the innermost open container.
            container, pendingKey = stack[-1]

            if event == "map_key":
                stack[-1] = (container, value)
            # NOTE(review): ijson's event list also names 'integer' and
            # 'double' for some backends -- confirm they cannot occur here,
            # otherwise such values would be dropped silently.
            elif event in ("string", "number", "boolean", "null"):
                if isinstance(container, dict):
                    container[pendingKey] = value
                    stack[-1] = (container, None)
                else:
                    container.append(value)
            elif event == "start_map":
                nested: dict = {}
                if isinstance(container, dict):
                    container[pendingKey] = nested
                    stack[-1] = (container, None)
                else:
                    container.append(nested)
                stack.append((nested, None))
            elif event == "end_map":
                stack.pop()
                if not stack:
                    # The row's own map was closed: the row is complete.
                    yield (currentDb, currentTable, root)
                    inRow = False
            elif event == "start_array":
                nested_list: list = []
                if isinstance(container, dict):
                    container[pendingKey] = nested_list
                    stack[-1] = (container, None)
                else:
                    container.append(nested_list)
                stack.append((nested_list, None))
            elif event == "end_array":
                stack.pop()
""" import ijson meta = None with open(filePath, "rb") as f: for m in ijson.items(f, "meta"): meta = m break warnings: List[str] = [] if not meta: warnings.append("Fehlende oder ungueltige 'meta'-Sektion") registeredDbs = getRegisteredDatabases() dbTableCounts: Dict[str, Dict[str, int]] = {} systemObjectsFound: List[dict] = [] prevDb: Optional[str] = None prevTable: Optional[str] = None for dbName, tableName, row in _iterStreamRows(filePath): if dbName not in dbTableCounts: dbTableCounts[dbName] = {} if tableName not in dbTableCounts[dbName]: dbTableCounts[dbName][tableName] = 0 dbTableCounts[dbName][tableName] += 1 if progressCb and (dbName != prevDb or tableName != prevTable): if prevDb and prevTable: progressCb({"phase": "validate", "db": prevDb, "table": prevTable, "rows": dbTableCounts[prevDb][prevTable]}) prevDb = dbName prevTable = tableName if dbName == "poweron_app": if (tableName == "Mandate" and row.get("name") == "root" and row.get("isSystem") is True): systemObjectsFound.append({ "type": "mandate", "label": "Root Mandate", "payloadId": row.get("id"), }) elif tableName == "UserInDB": uname = row.get("username") if uname == "admin": systemObjectsFound.append({ "type": "user", "label": "Admin User", "payloadId": row.get("id"), }) elif uname == "event": systemObjectsFound.append({ "type": "user", "label": "Event User", "payloadId": row.get("id"), }) if progressCb and prevDb and prevTable: progressCb({"phase": "validate", "db": prevDb, "table": prevTable, "rows": dbTableCounts[prevDb][prevTable]}) if not dbTableCounts: warnings.append("Fehlende oder ungueltige 'databases'-Sektion") summary: List[dict] = [] for dbName, tables in dbTableCounts.items(): registered = dbName in registeredDbs if not registered: warnings.append(f"Datenbank '{dbName}' ist nicht registriert und wird uebersprungen") summary.append({ "database": dbName, "tableCount": len(tables), "recordCount": sum(tables.values()), "registered": registered, }) liveIds = _loadLiveSystemObjectIds() 
def _safeFileNamePart(name: str, uniqueIdx: int) -> str:
    """Return ``name`` reduced to characters that are safe inside a file name.

    Everything except letters, digits, ``_``, ``-`` and ``.`` is replaced by
    ``_``. Names that needed no change are returned as-is; changed names get
    ``_<uniqueIdx>`` appended so two different raw names are unlikely to end
    up with the same result.
    """
    cleaned = "".join(ch if (ch.isalnum() or ch in "_-.") else "_" for ch in str(name))
    if cleaned == name:
        return cleaned
    return f"{cleaned}_{uniqueIdx}"


def _streamSplitToFiles(
    filePath: str,
    tmpDir: str,
    token: str,
    remap: Dict[str, str],
    progressCb=None,
) -> Dict[str, Dict[str, str]]:
    """Stream through the export file a second time, applying the ID remap on
    each row and writing per-table JSONL temp files.

    Each row is written as one JSON object per line to
    ``<tmpDir>/poweron_import_<token>_<db>__<table>.jsonl``. Database and table
    names come from the import file, so characters that are unsafe in file
    names are replaced before the path is built. Only one output file is open
    at any time; a table that shows up again later is appended to.

    ``progressCb``, when provided, is called on every database/table
    transition with ``{"phase": "split", "db": str, "table": str, "rows": int}``
    for the table that was just finished, and once more for the last table.

    Args:
        filePath: Path of the export JSON file.
        tmpDir: Directory that receives the temp files.
        token: Identifier embedded in the temp file names.
        remap: ``{oldId: newId}`` for system objects; may be empty.
        progressCb: Optional progress callback.

    Returns:
        ``{dbName: {tableName: filePath}}``.
    """
    import os

    remapSet = set(remap.keys()) if remap else set()
    dbFiles: Dict[str, Dict[str, str]] = {}
    tableCounts: Dict[Tuple[str, str], int] = {}
    usedPaths: Set[str] = set()
    prevDb: Optional[str] = None
    prevTable: Optional[str] = None
    currentKey: Optional[Tuple[str, str]] = None
    currentFh: Any = None

    try:
        for dbName, tableName, row in _iterStreamRows(filePath):
            if remap:
                _remapRowValues(row, remap, remapSet)

            key = (dbName, tableName)
            if key != currentKey:
                # Rows arrive grouped by table, so the previous file is done.
                if currentFh is not None:
                    currentFh.close()
                    currentFh = None

                firstVisit = key not in tableCounts
                if firstVisit:
                    if progressCb and prevDb and prevTable:
                        progressCb({"phase": "split", "db": prevDb, "table": prevTable,
                                    "rows": tableCounts.get((prevDb, prevTable), 0)})
                    prevDb = dbName
                    prevTable = tableName

                    fileIdx = len(tableCounts)
                    stem = (f"poweron_import_{token}_{_safeFileNamePart(dbName, fileIdx)}"
                            f"__{_safeFileNamePart(tableName, fileIdx)}")
                    tblPath = os.path.join(tmpDir, f"{stem}.jsonl")
                    while tblPath in usedPaths:
                        # Extremely unlikely; keeps two tables from sharing a file.
                        stem += "_"
                        tblPath = os.path.join(tmpDir, f"{stem}.jsonl")
                    usedPaths.add(tblPath)

                    dbFiles.setdefault(dbName, {})[tableName] = tblPath
                    tableCounts[key] = 0

                # Append when a table is revisited so earlier rows survive.
                currentFh = open(dbFiles[dbName][tableName],
                                 "w" if firstVisit else "a", encoding="utf-8")
                currentKey = key

            currentFh.write(json.dumps(row, ensure_ascii=False, default=str))
            currentFh.write("\n")
            tableCounts[key] += 1
    finally:
        if currentFh is not None:
            currentFh.close()

    if progressCb and prevDb and prevTable:
        progressCb({"phase": "split", "db": prevDb, "table": prevTable,
                    "rows": tableCounts.get((prevDb, prevTable), 0)})

    return dbFiles
""" if mode not in ("replace", "merge"): raise ValueError(f"Invalid import mode: {mode}") registeredDbs = getRegisteredDatabases() if dbName not in registeredDbs: return {"database": dbName, "tables": {}, "recordCount": 0, "warnings": [f"Datenbank '{dbName}' nicht registriert"]} if not tableFiles: return {"database": dbName, "tables": {}, "recordCount": 0, "warnings": [f"Keine Daten fuer '{dbName}'"]} try: dbCreated = _ensureDatabaseExists(dbName) except Exception as e: logger.error("Failed to ensure database %s exists: %s", dbName, e) return {"database": dbName, "tables": {}, "recordCount": 0, "warnings": [f"Datenbank '{dbName}' konnte nicht erstellt werden: {e}"]} protectedIdSet = set(protectedIds) warnings: List[str] = [] dbResult: Dict[str, int] = {} excluded = _EXCLUDED_TABLES.get(dbName, set()) if dbCreated: warnings.append(f"Datenbank '{dbName}' wurde neu erstellt") conn = _getConnection(dbName) try: existingTables = set(_listTables(conn)) conn.rollback() conn.autocommit = True for tableName, tblPath in tableFiles.items(): if tableName in excluded: continue if tableName not in existingTables: sampleRows: List[dict] = [] with open(tblPath, "r", encoding="utf-8") as f: for line in f: line = line.strip() if line: sampleRows.append(json.loads(line)) if len(sampleRows) >= 10: break if sampleRows: _createTableFromExport(conn, tableName, sampleRows) existingTables.add(tableName) logger.info("Pre-created missing table %s.%s", dbName, tableName) importable = [t for t in tableFiles if t not in excluded and t in existingTables] importOrder = _getTableImportOrder(conn, importable, dbName) logger.info("Import order for %s: %s", dbName, importOrder) for tableName in tableFiles: if tableName in excluded: warnings.append(f"Table '{dbName}.{tableName}' excluded (security/transient)") if mode == "replace": conn.autocommit = False for tableName in reversed(importOrder): try: _deleteNonProtected(conn, tableName, protectedIdSet) conn.commit() except Exception as e: 
conn.rollback() warnings.append(f"DELETE from {dbName}.{tableName} failed: {e}") logger.warning("DELETE from %s.%s failed: %s", dbName, tableName, e) conn.autocommit = False batchSize = 100 for tableName in importOrder: tblPath = tableFiles.get(tableName) if not tblPath: continue try: physicalCols = _getPhysicalColumns(conn, tableName) if not physicalCols: conn.rollback() continue insertedCount = 0 batch: List[dict] = [] with open(tblPath, "r", encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue row = json.loads(line) if _isProtectedRow(tableName, row): continue if row.get("id") and str(row["id"]) in protectedIdSet: continue batch.append(row) if len(batch) >= batchSize: insertedCount += _insertRows(conn, tableName, batch, physicalCols, mode) batch = [] if batch: insertedCount += _insertRows(conn, tableName, batch, physicalCols, mode) conn.commit() dbResult[tableName] = insertedCount except Exception as e: conn.rollback() warnings.append(f"INSERT into {dbName}.{tableName} failed: {e}") logger.warning("INSERT into %s.%s failed: %s", dbName, tableName, e) except Exception as e: logger.error("Import failed for database %s: %s", dbName, e) return {"database": dbName, "tables": {}, "recordCount": 0, "warnings": [f"Import fuer '{dbName}' fehlgeschlagen: {e}"]} finally: conn.close() recordCount = sum(dbResult.values()) return {"database": dbName, "tables": dbResult, "recordCount": recordCount, "warnings": warnings}