# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Database migration utilities — backup (export) and restore (import) for all
registered PowerOn databases.

System objects (root mandate, admin user, event user) are protected: they are
never deleted or overwritten during import. Their IDs in the backup payload
are remapped to the IDs of the corresponding live objects so that all FK
references stay consistent.

All functions are intended for SysAdmin use only (access control in the
route layer).
"""

import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Set, Tuple

import psycopg2
import psycopg2.extras

from modules.shared.configuration import APP_CONFIG
from modules.shared.dbRegistry import getRegisteredDatabases
from modules.system.databaseHealth import _getConnection, _jsonSafe

logger = logging.getLogger(__name__)

_EXPORT_FORMAT_VERSION = "1.0"
_SYSTEM_TABLE = "_system"


# ---------------------------------------------------------------------------
# Instance label
# ---------------------------------------------------------------------------

def _getInstanceLabel() -> str:
    """Return the instance type from APP_ENV_TYPE (e.g. 'dev', 'int', 'prod').

    Falls back to ``"unknown"`` when the key is not configured.
    """
    return APP_CONFIG.get("APP_ENV_TYPE", "unknown")


# ---------------------------------------------------------------------------
# Database list
# ---------------------------------------------------------------------------

def _getAvailableDatabases() -> List[dict]:
    """Return registered databases with table/row counts for the UI.

    Each entry is ``{"name", "tableCount", "recordCount"}``. The counts come
    from ``pg_stat_user_tables`` (``n_live_tup``), i.e. they are the server's
    statistics, not an exact ``COUNT(*)``. Tables whose name starts with an
    underscore are excluded, and the ``poweron_test`` database is never
    listed. A database that cannot be queried is still listed, with whatever
    counts were collected before the failure (zero if none).
    """
    registeredDbs = getRegisteredDatabases()
    results: List[dict] = []

    for dbName in sorted(registeredDbs):
        if dbName == "poweron_test":
            continue

        entry: dict = {"name": dbName, "tableCount": 0, "recordCount": 0}
        try:
            conn = _getConnection(dbName)
            try:
                with conn.cursor() as cur:
                    cur.execute("""
                        SELECT relname, n_live_tup
                        FROM pg_stat_user_tables
                        WHERE schemaname = 'public'
                          AND relname NOT LIKE '\\_%%'
                    """)
                    # NOTE(review): rows are accessed by column name, so
                    # _getConnection presumably yields dict-style cursors.
                    for row in cur.fetchall():
                        entry["tableCount"] += 1
                        # Guard against a NULL counter instead of aborting
                        # the whole database entry with a TypeError.
                        entry["recordCount"] += int(row["n_live_tup"] or 0)
            finally:
                conn.close()
        except Exception as e:
            # Best effort: the UI list must not fail because one DB is down.
            logger.warning("Could not stat database %s: %s", dbName, e)

        results.append(entry)

    return results


# ---------------------------------------------------------------------------
# Export
# ---------------------------------------------------------------------------

def _exportDatabases(databases: List[str]) -> dict:
    """Export selected databases as a JSON-serialisable dict.

    Returns ``{meta: {...}, databases: {dbName: {tables: {tbl: [rows]},
    summary: {...}, tableCount, totalRecords}}}``.

    Databases that are not registered, or whose export raises, are left out
    of ``databases``. Each such omission is recorded as a message in
    ``meta["warnings"]`` so the caller can tell that the backup is
    incomplete instead of silently receiving a partial export.

    Raises:
        ValueError: if ``databases`` is empty.
    """
    registeredDbs = getRegisteredDatabases()

    if not databases:
        raise ValueError("No databases selected for export.")

    meta: dict = {
        "exportedAt": datetime.now(timezone.utc).isoformat(),
        "version": _EXPORT_FORMAT_VERSION,
        "databaseCount": 0,
        "totalTables": 0,
        "totalRecords": 0,
        "warnings": [],
    }
    exportData: dict = {"meta": meta, "databases": {}}

    for dbName in databases:
        if dbName not in registeredDbs:
            logger.warning("Export: skipping unregistered database %s", dbName)
            meta["warnings"].append(
                f"Datenbank '{dbName}' ist nicht registriert und wurde uebersprungen"
            )
            continue

        try:
            dbPayload = _exportSingleDb(dbName)
        except Exception as e:
            # Keep exporting the remaining databases, but surface the failure.
            logger.exception("Export failed for database %s: %s", dbName, e)
            meta["warnings"].append(f"Export fuer '{dbName}' fehlgeschlagen: {e}")
            continue

        exportData["databases"][dbName] = dbPayload
        meta["databaseCount"] += 1
        meta["totalTables"] += dbPayload["tableCount"]
        meta["totalRecords"] += dbPayload["totalRecords"]

    return exportData


def _exportSingleDb(dbName: str) -> dict:
    """Read every table of one database into a JSON-serialisable dict.

    Returns ``{"tables": {tbl: [rows]}, "summary": {tbl: {"recordCount": n}},
    "tableCount": int, "totalRecords": int}``. The connection is always
    closed, also when reading fails.
    """
    conn = _getConnection(dbName)
    try:
        dbPayload: dict = {"tables": {}, "summary": {}, "tableCount": 0, "totalRecords": 0}
        for tbl in _listTables(conn):
            rows = _readTableRows(conn, tbl)
            dbPayload["tables"][tbl] = rows
            dbPayload["summary"][tbl] = {"recordCount": len(rows)}
            dbPayload["tableCount"] += 1
            dbPayload["totalRecords"] += len(rows)
        return dbPayload
    finally:
        conn.close()


def _listTables(conn) -> List[str]:
    """Return the names of all base tables in schema ``public``, sorted.

    The internal ``_system`` table is excluded.
    """
    with conn.cursor() as cur:
        cur.execute("""
            SELECT table_name
            FROM information_schema.tables
            WHERE table_schema = 'public'
              AND table_type = 'BASE TABLE'
              AND table_name != %s
            ORDER BY table_name
        """, (_SYSTEM_TABLE,))
        return [row["table_name"] for row in cur.fetchall()]


def _readTableRows(conn, tableName: str) -> List[dict]:
    """Return all rows of ``tableName`` as dicts with JSON-safe values."""
    # Identifiers cannot be bound as query parameters; double any embedded
    # quote so an unusual table name cannot break out of the identifier.
    quotedName = '"' + tableName.replace('"', '""') + '"'
    with conn.cursor() as cur:
        cur.execute(f"SELECT * FROM {quotedName}")
        return [{k: _jsonSafe(v) for k, v in dict(row).items()} for row in cur.fetchall()]


# ---------------------------------------------------------------------------
# Validate
# ---------------------------------------------------------------------------

def _tablesOf(dbData: Any) -> dict:
    """Return the ``tables`` mapping of one database payload, or ``{}`` if malformed."""
    if not isinstance(dbData, dict):
        return {}
    tables = dbData.get("tables", {})
    return tables if isinstance(tables, dict) else {}


def _appTableRows(payload: dict, tableName: str) -> List[dict]:
    """Return the dict rows of ``poweron_app.<tableName>`` from a payload (``[]`` if absent/malformed)."""
    databases = payload.get("databases")
    if not isinstance(databases, dict):
        return []
    rows = _tablesOf(databases.get("poweron_app")).get(tableName, [])
    if not isinstance(rows, list):
        return []
    return [row for row in rows if isinstance(row, dict)]


def _validateImportPayload(payload: dict) -> dict:
    """Validate an import payload without writing anything.

    Returns ``{valid, summary, warnings, systemObjectsFound}``.

    ``valid`` is True when at least one database in the payload is both
    registered and structurally well-formed (a dict whose ``tables`` entry,
    if present, is a dict). Malformed database entries produce a warning
    instead of raising.
    """
    warnings: List[str] = []
    summary: List[dict] = []

    meta = payload.get("meta")
    if not meta or not isinstance(meta, dict):
        return {"valid": False, "summary": [], "warnings": ["Fehlende oder ungueltige 'meta'-Sektion"], "systemObjectsFound": []}

    version = meta.get("version", "")
    if version != _EXPORT_FORMAT_VERSION:
        warnings.append(f"Unbekannte Format-Version: {version} (erwartet: {_EXPORT_FORMAT_VERSION})")

    databases = payload.get("databases")
    if not databases or not isinstance(databases, dict):
        return {"valid": False, "summary": [], "warnings": ["Fehlende oder ungueltige 'databases'-Sektion"], "systemObjectsFound": []}

    registeredDbs = getRegisteredDatabases()
    hasImportable = False

    for dbName, dbData in databases.items():
        wellFormed = isinstance(dbData, dict) and isinstance(dbData.get("tables", {}), dict)
        tables = _tablesOf(dbData)
        registered = dbName in registeredDbs

        if not wellFormed:
            warnings.append(f"Datenbank '{dbName}' hat eine ungueltige Struktur und wird uebersprungen")
        if not registered:
            warnings.append(f"Datenbank '{dbName}' ist nicht registriert und wird uebersprungen")
        if wellFormed and registered:
            hasImportable = True

        summary.append({
            "database": dbName,
            "tableCount": len(tables),
            "recordCount": sum(len(rows) for rows in tables.values() if isinstance(rows, list)),
            "registered": registered,
        })

    return {
        "valid": hasImportable,
        "summary": summary,
        "warnings": warnings,
        "systemObjectsFound": _detectSystemObjectsInPayload(payload),
    }


def _detectSystemObjectsInPayload(payload: dict) -> List[dict]:
    """Find system objects (root mandate, admin user, event user) in a payload.

    Only ``poweron_app.Mandate`` and ``poweron_app.UserInDB`` are inspected.
    Each hit is ``{"type", "label", "payloadId"}``.
    """
    found: List[dict] = []

    for row in _appTableRows(payload, "Mandate"):
        if row.get("name") == "root" and row.get("isSystem") is True:
            found.append({"type": "mandate", "label": "Root Mandate", "payloadId": row.get("id")})

    userLabels = {"admin": "Admin User", "event": "Event User"}
    for row in _appTableRows(payload, "UserInDB"):
        username = row.get("username")
        # isinstance guard: an unhashable username must not crash the lookup.
        if isinstance(username, str) and username in userLabels:
            found.append({"type": "user", "label": userLabels[username], "payloadId": row.get("id")})

    return found


# ---------------------------------------------------------------------------
# System-object ID remapping
# ---------------------------------------------------------------------------

def _loadLiveSystemObjectIds() -> Dict[str, str]:
    """Load the IDs of the 3 protected system objects from the live DB.

    Returns a dict like ``{"rootMandate": "", "adminUser": "", "eventUser": ""}``.
    Objects that do not exist in the live DB are omitted; an empty dict is
    returned when ``poweron_app`` is not registered.
    """
    if "poweron_app" not in getRegisteredDatabases():
        return {}

    lookups = (
        ("rootMandate", """SELECT id FROM "Mandate" WHERE "name" = 'root' AND "isSystem" = true LIMIT 1"""),
        ("adminUser", """SELECT id FROM "UserInDB" WHERE "username" = 'admin' LIMIT 1"""),
        ("eventUser", """SELECT id FROM "UserInDB" WHERE "username" = 'event' LIMIT 1"""),
    )

    result: Dict[str, str] = {}
    conn = _getConnection("poweron_app")
    try:
        with conn.cursor() as cur:
            for key, query in lookups:
                cur.execute(query)
                row = cur.fetchone()
                if row:
                    result[key] = str(row["id"])
    finally:
        conn.close()

    return result


def _buildIdRemapFromPayload(payload: dict, liveIds: Dict[str, str]) -> Dict[str, str]:
    """Build an ``{oldId: newId}`` mapping for system objects.

    Compares IDs found in the payload with the live system-object IDs. Only
    entries where the IDs actually differ are included. Payload rows without
    an ``id`` (missing or ``None``) are ignored — otherwise the literal string
    ``"None"`` would become a remap key.
    """
    remap: Dict[str, str] = {}

    def addMapping(rawId: Any, liveKey: str) -> None:
        if rawId is None:
            return
        oldId = str(rawId)
        newId = liveIds.get(liveKey, "")
        if oldId and newId and oldId != newId:
            remap[oldId] = newId

    for row in _appTableRows(payload, "Mandate"):
        if row.get("name") == "root" and row.get("isSystem") is True:
            addMapping(row.get("id"), "rootMandate")

    userKeys = {"admin": "adminUser", "event": "eventUser"}
    for row in _appTableRows(payload, "UserInDB"):
        username = row.get("username")
        if isinstance(username, str) and username in userKeys:
            addMapping(row.get("id"), userKeys[username])

    return remap


def _remapSystemObjectIds(payload: dict, remap: Dict[str, str]) -> dict:
    """Walk the entire payload and replace every value that matches an old system-object ID.

    The payload is modified in place and also returned. Non-list table
    entries and non-dict rows are left untouched.
    """
    if not remap:
        return payload

    remapSet = set(remap)
    databases = payload.get("databases", {})
    if not isinstance(databases, dict):
        return payload

    for dbData in databases.values():
        for rows in _tablesOf(dbData).values():
            if not isinstance(rows, list):
                continue
            for row in rows:
                if isinstance(row, dict):
                    _remapRowValues(row, remap, remapSet)

    return payload


def _remapValue(val: Any, remap: Dict[str, str], remapSet: Set[str]) -> Any:
    """Return ``val`` with remapped IDs; dicts and lists are updated in place at any depth."""
    if isinstance(val, str):
        return remap[val] if val in remapSet else val
    if isinstance(val, dict):
        _remapRowValues(val, remap, remapSet)
    elif isinstance(val, list):
        for i, item in enumerate(val):
            val[i] = _remapValue(item, remap, remapSet)
    return val


def _remapRowValues(row: dict, remap: Dict[str, str], remapSet: Set[str]) -> None:
    """In-place replace string values in a row dict that match a remap key.

    Nested dicts and lists (including lists inside lists) are traversed.
    Only values are replaced, never keys, so the dict size stays constant
    while iterating.
    """
    for key, val in row.items():
        row[key] = _remapValue(val, remap, remapSet)


# ---------------------------------------------------------------------------
# Import
# ---------------------------------------------------------------------------

# Field patterns identifying rows that must never be imported over live data.
_PROTECTED_ROWS: Dict[str, List[dict]] = {
    "Mandate": [{"name": "root", "isSystem": True}],
    "UserInDB": [{"username": "admin"}, {"username": "event"}],
}


def _isProtectedRow(tableName: str, row: dict) -> bool:
    """Return True if a row represents a protected system object."""
    return any(
        all(row.get(k) == v for k, v in pattern.items())
        for pattern in _PROTECTED_ROWS.get(tableName, [])
    )


def _quoteIdent(name: str) -> str:
    """Return ``name`` as a double-quoted SQL identifier with embedded quotes doubled."""
    return '"' + str(name).replace('"', '""') + '"'


def _filterImportableRows(tableName: str, rows: List[dict], protectedIdSet: Set[str]) -> List[dict]:
    """Drop malformed rows and rows that are (or collide with) protected system objects."""
    importable: List[dict] = []
    for row in rows:
        if not isinstance(row, dict):
            continue
        if _isProtectedRow(tableName, row):
            continue
        if row.get("id") and str(row["id"]) in protectedIdSet:
            continue
        importable.append(row)
    return importable


def _importTableRows(conn, tableName: str, rows: List[dict], mode: str, protectedIdSet: Set[str]) -> Optional[int]:
    """Import the rows of one existing table; return the inserted count.

    Returns ``None`` (nothing done) when the table has no columns in schema
    ``public``. In ``"replace"`` mode all non-protected rows are deleted
    before inserting.
    """
    physicalCols = _getPhysicalColumns(conn, tableName)
    if not physicalCols:
        return None

    importable = _filterImportableRows(tableName, rows, protectedIdSet)
    if mode == "replace":
        _deleteNonProtected(conn, tableName, protectedIdSet)
    return _insertRows(conn, tableName, importable, physicalCols, mode)


def _importDatabases(payload: dict, mode: str) -> dict:
    """Import databases from a validated payload.

    ``mode`` is ``"replace"`` (clear + insert) or ``"merge"`` (insert missing only).

    System-object IDs in ``payload`` are remapped in place before importing.
    Each database is imported in its own transaction; if a database fails it
    is rolled back, reported in ``warnings`` and left out of ``imported``,
    while the remaining databases are still processed. Tables missing in the
    live database are skipped with a warning.

    Returns ``{success, imported: {db: {table: count}}, totalRecords, warnings}``.

    Raises:
        ValueError: if ``mode`` is neither ``"replace"`` nor ``"merge"``.
    """
    if mode not in ("replace", "merge"):
        raise ValueError(f"Invalid import mode: {mode}")

    registeredDbs = getRegisteredDatabases()
    liveIds = _loadLiveSystemObjectIds()
    remap = _buildIdRemapFromPayload(payload, liveIds)
    if remap:
        logger.info("System-object ID remap: %s", remap)
        _remapSystemObjectIds(payload, remap)

    protectedIdSet = set(liveIds.values())
    imported: Dict[str, dict] = {}
    warnings: List[str] = []

    databases = payload.get("databases", {})
    if not isinstance(databases, dict):
        databases = {}

    for dbName, dbData in databases.items():
        if dbName not in registeredDbs:
            warnings.append(f"Datenbank '{dbName}' uebersprungen (nicht registriert)")
            continue

        # A connection failure affects only this database, like any other
        # per-database import error.
        try:
            conn = _getConnection(dbName)
        except Exception as e:
            logger.error("Import failed for database %s: %s", dbName, e)
            warnings.append(f"Import fuer '{dbName}' fehlgeschlagen: {e}")
            continue

        dbResult: Dict[str, int] = {}
        try:
            conn.autocommit = False
            existingTables = set(_listTables(conn))

            for tableName, rows in _tablesOf(dbData).items():
                if not isinstance(rows, list):
                    continue
                if tableName not in existingTables:
                    warnings.append(f"Tabelle '{dbName}.{tableName}' existiert nicht, uebersprungen")
                    continue

                insertedCount = _importTableRows(conn, tableName, rows, mode, protectedIdSet)
                if insertedCount is not None:
                    dbResult[tableName] = insertedCount

            conn.commit()
        except Exception as e:
            conn.rollback()
            logger.error("Import failed for database %s: %s", dbName, e)
            warnings.append(f"Import fuer '{dbName}' fehlgeschlagen: {e}")
            continue
        finally:
            conn.close()

        imported[dbName] = dbResult

    totalRecords = sum(sum(v.values()) for v in imported.values())
    return {
        # NOTE(review): always True, even when every database failed; callers
        # must inspect ``warnings``/``imported``. Kept for compatibility.
        "success": True,
        "imported": imported,
        "totalRecords": totalRecords,
        "warnings": warnings,
    }


def _getPhysicalColumns(conn, tableName: str) -> List[str]:
    """Return the column names of ``public.<tableName>`` in ordinal order (``[]`` if none)."""
    with conn.cursor() as cur:
        cur.execute("""
            SELECT column_name
            FROM information_schema.columns
            WHERE table_schema = 'public' AND table_name = %s
            ORDER BY ordinal_position
        """, (tableName,))
        return [row["column_name"] for row in cur.fetchall()]


def _deleteNonProtected(conn, tableName: str, protectedIds: Set[str]) -> int:
    """Delete all rows except protected system objects.

    Rows whose ``id`` (compared as text) is in ``protectedIds`` are kept.
    Returns the number of deleted rows. Does not commit.
    """
    table = _quoteIdent(tableName)
    with conn.cursor() as cur:
        if not protectedIds:
            cur.execute(f"DELETE FROM {table}")
        else:
            cur.execute(
                f'DELETE FROM {table} WHERE "id"::text != ALL(%(ids)s)',
                {"ids": list(protectedIds)},
            )
        return cur.rowcount


def _insertRows(
    conn,
    tableName: str,
    rows: List[dict],
    physicalCols: List[str],
    mode: str,
) -> int:
    """Insert rows into a table. In merge mode, skip rows whose id already exists.

    Only keys that are physical columns of the table are written. Each row is
    inserted under its own SAVEPOINT: a failing row is logged and undone on
    its own, without rolling back the surrounding transaction (which holds
    the preceding replace-mode DELETE and all rows inserted so far). The
    connection must therefore be in transactional (non-autocommit) mode, as
    set up by the import entry points. Does not commit.

    Returns the number of rows actually inserted.
    """
    if not rows:
        return 0

    physicalColSet = set(physicalCols)
    table = _quoteIdent(tableName)
    conflictClause = ' ON CONFLICT ("id") DO NOTHING' if mode == "merge" else ""
    inserted = 0

    with conn.cursor() as cur:
        for row in rows:
            cols = [c for c in row.keys() if c in physicalColSet]
            if not cols:
                continue

            values = [_pgSafe(row[c]) for c in cols]
            colNames = ", ".join(_quoteIdent(c) for c in cols)
            placeholders = ", ".join(["%s"] * len(cols))
            sql = f"INSERT INTO {table} ({colNames}) VALUES ({placeholders}){conflictClause}"

            cur.execute("SAVEPOINT import_row")
            try:
                cur.execute(sql, values)
            except Exception as e:
                # Deliberately broad: bad data can also surface as non-DB
                # errors (e.g. NUL characters); skip the row either way.
                logger.warning("Insert failed for %s row: %s", tableName, e)
                cur.execute("ROLLBACK TO SAVEPOINT import_row")
            else:
                # Read rowcount before RELEASE overwrites it.
                inserted += cur.rowcount
            cur.execute("RELEASE SAVEPOINT import_row")

    return inserted


def _pgSafe(v: Any) -> Any:
    """Convert Python values to psycopg2-compatible types.

    ``None`` and scalars pass through, dicts/lists become JSON text, anything
    else is stringified.
    """
    import json as _json
    if v is None or isinstance(v, (str, int, float, bool)):
        return v
    if isinstance(v, (dict, list)):
        return _json.dumps(v)
    return str(v)


# ---------------------------------------------------------------------------
# Prepare import (validate + remap, return context for per-DB import)
# ---------------------------------------------------------------------------

def _prepareImport(payload: dict) -> dict:
    """Validate the payload and remap system-object IDs for a per-DB import.

    ``payload`` is remapped **in place**; the caller passes the same object
    on to ``_importSingleDb``. The remapped payload is not part of the
    return value.

    Returns ``{valid, warnings, systemObjectsFound, databases, protectedIds}``
    where ``databases`` lists the registered databases in the payload with
    their table and record counts.
    """
    validation = _validateImportPayload(payload)
    if not validation.get("valid"):
        return {
            "valid": False,
            "warnings": validation.get("warnings", []),
            "systemObjectsFound": validation.get("systemObjectsFound", []),
            "databases": [],
            "protectedIds": [],
        }

    liveIds = _loadLiveSystemObjectIds()
    remap = _buildIdRemapFromPayload(payload, liveIds)
    if remap:
        logger.info("System-object ID remap: %s", remap)
        _remapSystemObjectIds(payload, remap)

    protectedIdSet = set(liveIds.values())
    registeredDbs = getRegisteredDatabases()

    dbList = []
    for dbName, dbData in payload.get("databases", {}).items():
        if dbName not in registeredDbs:
            continue
        tables = _tablesOf(dbData)
        dbList.append({
            "database": dbName,
            "tableCount": len(tables),
            "recordCount": sum(len(rows) for rows in tables.values() if isinstance(rows, list)),
        })

    return {
        "valid": True,
        "warnings": validation.get("warnings", []),
        "systemObjectsFound": validation.get("systemObjectsFound", []),
        "databases": dbList,
        "protectedIds": list(protectedIdSet),
    }


def _ensureDatabaseExists(dbName: str) -> bool:
    """Create the PostgreSQL database if it does not yet exist.

    Connects to the ``postgres`` admin database using the same credentials
    as the target DB. Returns True if the database was created, False if it
    already existed or ``dbName`` is not registered.
    """
    configPrefix = getRegisteredDatabases().get(dbName)
    if configPrefix is None:
        return False

    adminConn = psycopg2.connect(
        host=APP_CONFIG.get(f"{configPrefix}_HOST", "localhost"),
        port=int(APP_CONFIG.get(f"{configPrefix}_PORT", 5432)),
        database="postgres",
        user=APP_CONFIG.get(f"{configPrefix}_USER"),
        password=APP_CONFIG.get(f"{configPrefix}_PASSWORD_SECRET"),
        client_encoding="utf8",
    )
    try:
        # CREATE DATABASE cannot run inside a transaction block.
        adminConn.autocommit = True
        with adminConn.cursor() as cur:
            cur.execute("SELECT 1 FROM pg_database WHERE datname = %s", (dbName,))
            if cur.fetchone():
                return False
            cur.execute(f"CREATE DATABASE {_quoteIdent(dbName)}")
            logger.info("Created missing database: %s", dbName)
            return True
    finally:
        adminConn.close()


def _createTableFromExport(conn, tableName: str, rows: List[dict]) -> None:
    """Create a table based on the column structure found in the export data.

    Uses TEXT for all columns since we don't have the original DDL.
    The ``id`` column gets a PRIMARY KEY constraint. Columns are the union of
    all row keys in first-seen order; nothing is created when no keys are
    found. Table and column names come from the uploaded payload, so they
    are quoted with embedded quotes escaped. Does not commit.
    """
    # dict.fromkeys de-duplicates while preserving first-seen order.
    allKeys = list(dict.fromkeys(
        key for row in rows if isinstance(row, dict) for key in row
    ))
    if not allKeys:
        return

    colDefs = [
        f"{_quoteIdent(col)} TEXT PRIMARY KEY" if col == "id" else f"{_quoteIdent(col)} TEXT"
        for col in allKeys
    ]
    ddl = f'CREATE TABLE IF NOT EXISTS {_quoteIdent(tableName)} ({", ".join(colDefs)})'
    with conn.cursor() as cur:
        cur.execute(ddl)
    logger.info("Created table %s with %d columns", tableName, len(allKeys))


def _importSingleDb(payload: dict, dbName: str, mode: str, protectedIds: List[str]) -> dict:
    """Import a single database from the (already remapped) payload.

    Unlike ``_importDatabases``, a missing database and missing tables are
    created (tables as all-TEXT, committed immediately). The row import
    itself runs in one transaction that is rolled back on failure.

    Returns ``{database, tables: {tableName: insertedCount}, recordCount, warnings}``.

    Raises:
        ValueError: if ``mode`` is neither ``"replace"`` nor ``"merge"``.
    """
    if mode not in ("replace", "merge"):
        raise ValueError(f"Invalid import mode: {mode}")

    def failure(messages: List[str]) -> dict:
        return {"database": dbName, "tables": {}, "recordCount": 0, "warnings": messages}

    if dbName not in getRegisteredDatabases():
        return failure([f"Datenbank '{dbName}' nicht registriert"])

    databases = payload.get("databases", {})
    dbData = databases.get(dbName) if isinstance(databases, dict) else None
    if not dbData:
        return failure([f"Keine Daten fuer '{dbName}' im Payload"])

    try:
        dbCreated = _ensureDatabaseExists(dbName)
    except Exception as e:
        logger.error("Failed to ensure database %s exists: %s", dbName, e)
        return failure([f"Datenbank '{dbName}' konnte nicht erstellt werden: {e}"])

    protectedIdSet = set(protectedIds)
    warnings: List[str] = []
    dbResult: Dict[str, int] = {}

    if dbCreated:
        warnings.append(f"Datenbank '{dbName}' wurde neu erstellt")

    try:
        conn = _getConnection(dbName)
    except Exception as e:
        logger.error("Import failed for database %s: %s", dbName, e)
        return failure(warnings + [f"Import fuer '{dbName}' fehlgeschlagen: {e}"])

    try:
        conn.autocommit = False
        existingTables = set(_listTables(conn))

        for tableName, rows in _tablesOf(dbData).items():
            if not isinstance(rows, list):
                continue

            if tableName not in existingTables:
                # DDL is committed on its own so the new table survives a
                # later rollback of the row import.
                _createTableFromExport(conn, tableName, rows)
                conn.commit()

            insertedCount = _importTableRows(conn, tableName, rows, mode, protectedIdSet)
            if insertedCount is not None:
                dbResult[tableName] = insertedCount

        conn.commit()
    except Exception as e:
        conn.rollback()
        logger.error("Import failed for database %s: %s", dbName, e)
        return failure(warnings + [f"Import fuer '{dbName}' fehlgeschlagen: {e}"])
    finally:
        conn.close()

    return {
        "database": dbName,
        "tables": dbResult,
        "recordCount": sum(dbResult.values()),
        "warnings": warnings,
    }