From 8c2e9d21837d992ac9e34571b23768cdce555770 Mon Sep 17 00:00:00 2001 From: ValueOn AG Date: Thu, 28 May 2026 10:52:10 +0200 Subject: [PATCH] db import streaming --- modules/routes/routeAdminDatabaseHealth.py | 109 ++++--- modules/system/databaseMigration.py | 361 +++++++++++++++++++++ requirements.txt | 3 + 3 files changed, 420 insertions(+), 53 deletions(-) diff --git a/modules/routes/routeAdminDatabaseHealth.py b/modules/routes/routeAdminDatabaseHealth.py index 15ab1c5a..61b56bac 100644 --- a/modules/routes/routeAdminDatabaseHealth.py +++ b/modules/routes/routeAdminDatabaseHealth.py @@ -34,7 +34,10 @@ from modules.system.databaseMigration import ( _getInstanceLabel, _importDatabases, _importSingleDb, + _importSingleDbFromFiles, _prepareImport, + _streamSplitToFiles, + _streamValidate, _validateImportPayload, streamExportGenerator, ) @@ -530,43 +533,35 @@ def getMigrationExportStream( def _processUploadedFile(filePath: str, tmpDir: str, token: str) -> dict: - """Parse JSON, validate, remap, split into per-DB files. + """Streaming validate + split: never loads the full JSON into RAM. - Runs in a thread pool to avoid blocking the asyncio event loop - during the CPU-heavy json.load() of large (500+ MB) files. + Pass 1 (``_streamValidate``): extract meta, count tables/records, + detect system objects, build ID remap -- constant memory. + + Pass 2 (``_streamSplitToFiles``): iterate rows again, apply remap, + write per-table JSONL temp files -- one row in RAM at a time. """ - import gc import os - with open(filePath, "r", encoding="utf-8") as f: - payload = json.load(f) + result = _streamValidate(filePath) + + if not result.get("valid"): + try: + os.remove(filePath) + except OSError: + pass + return {"result": result, "dbFiles": {}} + + remap = result.get("remap", {}) + protectedIds = result.get("protectedIds", []) + + dbFiles = _streamSplitToFiles(filePath, tmpDir, token, remap) try: os.remove(filePath) except OSError: pass - result = _prepareImport(payload) - - if not result.get("valid"): - del payload - gc.collect() - return {"result": result, "dbFiles": {}} - - protectedIds = result.get("protectedIds", []) - - dbFiles = {} - databases = payload.get("databases", {}) - for dbName, dbData in databases.items(): - dbPath = os.path.join(tmpDir, f"poweron_import_{token}_{dbName}.json") - with open(dbPath, "w", encoding="utf-8") as dbF: - json.dump(dbData, dbF, ensure_ascii=False, default=str) - dbFiles[dbName] = dbPath - - del payload - del databases - gc.collect() - return {"result": result, "dbFiles": dbFiles, "protectedIds": protectedIds} @@ -617,6 +612,7 @@ async def postMigrationUploadImport( os.remove(filePath) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid JSON file: {e}") from e except Exception as e: + logger.exception("Processing uploaded import file failed: %s", e) if os.path.exists(filePath): os.remove(filePath) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Processing failed: {e}") from e @@ -657,7 +653,10 @@ def postMigrationImportSingle( body: dict, currentUser: User = Depends(requireSysAdmin), ) -> Dict[str, Any]: - """Import a single database from a previously uploaded + prepared payload. + """Import a single database from previously uploaded + prepared payload. + + Supports both the new per-table JSONL format (``{tableName: filePath}``) + and the legacy single-JSON-per-DB format (plain file path string). Body: ``{token, database, mode}`` """ @@ -674,29 +673,26 @@ def postMigrationImportSingle( if not pending: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or expired import token.") - dbFiles = pending.get("dbFiles", {}) - dbFilePath = dbFiles.get(database) - if not dbFilePath or not os.path.exists(dbFilePath): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"No data for database '{database}'.", - ) + dbEntry = pending.get("dbFiles", {}).get(database) + if not dbEntry: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"No data for database '{database}'.") logger.info("SysAdmin migration import-single: user=%s db=%s mode=%s", currentUser.username, database, mode) try: - with open(dbFilePath, "r", encoding="utf-8") as f: - dbData = json.load(f) - except Exception as e: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Failed to read import data for '{database}': {e}", - ) from e - - payload = {"databases": {database: dbData}} - - try: - result = _importSingleDb(payload, database, mode, pending["protectedIds"]) + if isinstance(dbEntry, dict): + result = _importSingleDbFromFiles(dbEntry, database, mode, pending["protectedIds"]) + else: + dbFilePath = dbEntry + if not os.path.exists(dbFilePath): + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, + detail=f"No data for database '{database}'.") + with open(dbFilePath, "r", encoding="utf-8") as f: + dbData = json.load(f) + payload = {"databases": {database: dbData}} + result = _importSingleDb(payload, database, mode, pending["protectedIds"]) + except HTTPException: + raise except Exception as e: logger.error("Import-single failed for %s: %s", database, e) raise HTTPException( @@ -714,15 +710,22 @@ def postMigrationImportDone( body: dict, currentUser: User = Depends(requireSysAdmin), ) -> Dict[str, Any]: - """Clean up the per-DB temp files.""" + """Clean up the per-DB / per-table temp files.""" import os token = body.get("token", "") pending = _pendingImports.pop(token, None) if pending: - for dbPath in pending.get("dbFiles", {}).values(): - try: - os.remove(dbPath) - except OSError: - pass + for dbEntry in pending.get("dbFiles", {}).values(): + if isinstance(dbEntry, str): + try: + os.remove(dbEntry) + except OSError: + pass + elif isinstance(dbEntry, dict): + for tblPath in dbEntry.values(): + try: + os.remove(tblPath) + except OSError: + pass return {"ok": True} diff --git a/modules/system/databaseMigration.py b/modules/system/databaseMigration.py index 8244ca4e..da54d4a7 100644 --- a/modules/system/databaseMigration.py +++ b/modules/system/databaseMigration.py @@ -12,6 +12,7 @@ references stay consistent. All functions are intended for SysAdmin use only (access control in the route layer). """ +import json import logging from datetime import datetime, timezone from typing import Any, Dict, List, Optional, Set, Tuple @@ -929,3 +930,363 @@ def _importSingleDb(payload: dict, dbName: str, mode: str, protectedIds: List[st recordCount = sum(dbResult.values()) return {"database": dbName, "tables": dbResult, "recordCount": recordCount, "warnings": warnings} + + +# --------------------------------------------------------------------------- +# Streaming Import (memory-safe, ijson-based) +# --------------------------------------------------------------------------- + +def _iterStreamRows(filePath: str): + """Yield ``(dbName, tableName, rowDict)`` one row at a time from an export + JSON file using ijson streaming parser. Never holds more than one row in RAM. + """ + import ijson + + with open(filePath, "rb") as f: + currentDb: Optional[str] = None + currentTable: Optional[str] = None + rowPrefix: Optional[str] = None + inRow = False + stack: List[Tuple[Any, Optional[str]]] = [] + root: dict = {} + + for prefix, event, value in ijson.parse(f, use_float=True): + if not inRow: + if prefix == "databases" and event == "map_key": + currentDb = value + currentTable = None + rowPrefix = None + elif (currentDb + and prefix == f"databases.{currentDb}.tables" + and event == "map_key"): + currentTable = value + rowPrefix = f"databases.{currentDb}.tables.{currentTable}.item" + elif rowPrefix and prefix == rowPrefix and event == "start_map": + inRow = True + root = {} + stack = [(root, None)] + continue + + container, pendingKey = stack[-1] + + if event == "map_key": + stack[-1] = (container, value) + + elif event in ("string", "number", "boolean", "null"): + if isinstance(container, dict): + container[pendingKey] = value + stack[-1] = (container, None) + else: + container.append(value) + + elif event == "start_map": + nested: dict = {} + if isinstance(container, dict): + container[pendingKey] = nested + stack[-1] = (container, None) + else: + container.append(nested) + stack.append((nested, None)) + + elif event == "end_map": + stack.pop() + if not stack: + yield (currentDb, currentTable, root) + inRow = False + + elif event == "start_array": + nested_list: list = [] + if isinstance(container, dict): + container[pendingKey] = nested_list + stack[-1] = (container, None) + else: + container.append(nested_list) + stack.append((nested_list, None)) + + elif event == "end_array": + stack.pop() + + +def _streamValidate(filePath: str) -> dict: + """Stream-validate an import file without loading it into RAM. + + Extracts ``meta``, counts databases/tables/records, detects system objects, + and builds the ID remap -- all with constant memory usage. + + Returns the same shape as ``_prepareImport`` plus ``remap``. + """ + import ijson + + meta = None + with open(filePath, "rb") as f: + for m in ijson.items(f, "meta"): + meta = m + break + + warnings: List[str] = [] + if not meta: + warnings.append("Fehlende oder ungueltige 'meta'-Sektion") + + registeredDbs = getRegisteredDatabases() + dbTableCounts: Dict[str, Dict[str, int]] = {} + systemObjectsFound: List[dict] = [] + + for dbName, tableName, row in _iterStreamRows(filePath): + if dbName not in dbTableCounts: + dbTableCounts[dbName] = {} + if tableName not in dbTableCounts[dbName]: + dbTableCounts[dbName][tableName] = 0 + dbTableCounts[dbName][tableName] += 1 + + if dbName == "poweron_app": + if (tableName == "Mandate" + and row.get("name") == "root" + and row.get("isSystem") is True): + systemObjectsFound.append({ + "type": "mandate", + "label": "Root Mandate", + "payloadId": row.get("id"), + }) + elif tableName == "UserInDB": + uname = row.get("username") + if uname == "admin": + systemObjectsFound.append({ + "type": "user", + "label": "Admin User", + "payloadId": row.get("id"), + }) + elif uname == "event": + systemObjectsFound.append({ + "type": "user", + "label": "Event User", + "payloadId": row.get("id"), + }) + + if not dbTableCounts: + warnings.append("Fehlende oder ungueltige 'databases'-Sektion") + + summary: List[dict] = [] + for dbName, tables in dbTableCounts.items(): + registered = dbName in registeredDbs + if not registered: + warnings.append(f"Datenbank '{dbName}' ist nicht registriert und wird uebersprungen") + summary.append({ + "database": dbName, + "tableCount": len(tables), + "recordCount": sum(tables.values()), + "registered": registered, + }) + + liveIds = _loadLiveSystemObjectIds() + + remap: Dict[str, str] = {} + for obj in systemObjectsFound: + oldId = str(obj.get("payloadId", "")) + if obj["type"] == "mandate": + newId = liveIds.get("rootMandate", "") + elif obj["label"] == "Admin User": + newId = liveIds.get("adminUser", "") + elif obj["label"] == "Event User": + newId = liveIds.get("eventUser", "") + else: + continue + if oldId and newId and oldId != newId: + remap[oldId] = newId + + if remap: + logger.info("System-object ID remap: %s", remap) + + protectedIdSet = set(liveIds.values()) + valid = any(s["registered"] for s in summary) + + dbList: List[dict] = [] + for s in summary: + if s["registered"]: + dbList.append({ + "database": s["database"], + "tableCount": s["tableCount"], + "recordCount": s["recordCount"], + }) + + return { + "valid": valid, + "summary": summary, + "warnings": warnings, + "systemObjectsFound": systemObjectsFound, + "databases": dbList, + "protectedIds": list(protectedIdSet), + "remap": remap, + } + + +def _streamSplitToFiles( + filePath: str, + tmpDir: str, + token: str, + remap: Dict[str, str], +) -> Dict[str, Dict[str, str]]: + """Stream through the export file a second time, applying ID remap on + each row and writing per-table JSONL temp files. + + Returns ``{dbName: {tableName: filePath}}``. + """ + import os + + remapSet = set(remap.keys()) if remap else set() + dbFiles: Dict[str, Dict[str, str]] = {} + writers: Dict[Tuple[str, str], Any] = {} + + try: + for dbName, tableName, row in _iterStreamRows(filePath): + if remap: + _remapRowValues(row, remap, remapSet) + + key = (dbName, tableName) + if key not in writers: + tblPath = os.path.join( + tmpDir, + f"poweron_import_{token}_{dbName}__{tableName}.jsonl", + ) + writers[key] = open(tblPath, "w", encoding="utf-8") + if dbName not in dbFiles: + dbFiles[dbName] = {} + dbFiles[dbName][tableName] = tblPath + + writers[key].write(json.dumps(row, ensure_ascii=False, default=str)) + writers[key].write("\n") + finally: + for fh in writers.values(): + fh.close() + + return dbFiles + + +def _importSingleDbFromFiles( + tableFiles: Dict[str, str], + dbName: str, + mode: str, + protectedIds: List[str], +) -> dict: + """Import a single database from per-table JSONL files. + + Each file contains one JSON object per line (rows). + Tables are sorted by FK dependencies before import. + + Returns ``{database, tables, recordCount, warnings}``. + """ + if mode not in ("replace", "merge"): + raise ValueError(f"Invalid import mode: {mode}") + + registeredDbs = getRegisteredDatabases() + if dbName not in registeredDbs: + return {"database": dbName, "tables": {}, "recordCount": 0, + "warnings": [f"Datenbank '{dbName}' nicht registriert"]} + + if not tableFiles: + return {"database": dbName, "tables": {}, "recordCount": 0, + "warnings": [f"Keine Daten fuer '{dbName}'"]} + + try: + dbCreated = _ensureDatabaseExists(dbName) + except Exception as e: + logger.error("Failed to ensure database %s exists: %s", dbName, e) + return {"database": dbName, "tables": {}, "recordCount": 0, + "warnings": [f"Datenbank '{dbName}' konnte nicht erstellt werden: {e}"]} + + protectedIdSet = set(protectedIds) + warnings: List[str] = [] + dbResult: Dict[str, int] = {} + excluded = _EXCLUDED_TABLES.get(dbName, set()) + + if dbCreated: + warnings.append(f"Datenbank '{dbName}' wurde neu erstellt") + + conn = _getConnection(dbName) + try: + existingTables = set(_listTables(conn)) + conn.rollback() + + conn.autocommit = True + for tableName, tblPath in tableFiles.items(): + if tableName in excluded: + continue + if tableName not in existingTables: + sampleRows: List[dict] = [] + with open(tblPath, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + sampleRows.append(json.loads(line)) + if len(sampleRows) >= 10: + break + if sampleRows: + _createTableFromExport(conn, tableName, sampleRows) + existingTables.add(tableName) + logger.info("Pre-created missing table %s.%s", dbName, tableName) + + importable = [t for t in tableFiles if t not in excluded and t in existingTables] + importOrder = _getTableImportOrder(conn, importable, dbName) + + logger.info("Import order for %s: %s", dbName, importOrder) + + for tableName in tableFiles: + if tableName in excluded: + warnings.append(f"Table '{dbName}.{tableName}' excluded (security/transient)") + + if mode == "replace": + conn.autocommit = False + for tableName in reversed(importOrder): + try: + _deleteNonProtected(conn, tableName, protectedIdSet) + conn.commit() + except Exception as e: + conn.rollback() + warnings.append(f"DELETE from {dbName}.{tableName} failed: {e}") + logger.warning("DELETE from %s.%s failed: %s", dbName, tableName, e) + + conn.autocommit = False + batchSize = 100 + for tableName in importOrder: + tblPath = tableFiles.get(tableName) + if not tblPath: + continue + try: + physicalCols = _getPhysicalColumns(conn, tableName) + if not physicalCols: + conn.rollback() + continue + + insertedCount = 0 + batch: List[dict] = [] + with open(tblPath, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + row = json.loads(line) + if _isProtectedRow(tableName, row): + continue + if row.get("id") and str(row["id"]) in protectedIdSet: + continue + batch.append(row) + if len(batch) >= batchSize: + insertedCount += _insertRows(conn, tableName, batch, physicalCols, mode) + batch = [] + if batch: + insertedCount += _insertRows(conn, tableName, batch, physicalCols, mode) + + conn.commit() + dbResult[tableName] = insertedCount + except Exception as e: + conn.rollback() + warnings.append(f"INSERT into {dbName}.{tableName} failed: {e}") + logger.warning("INSERT into %s.%s failed: %s", dbName, tableName, e) + except Exception as e: + logger.error("Import failed for database %s: %s", dbName, e) + return {"database": dbName, "tables": {}, "recordCount": 0, + "warnings": [f"Import fuer '{dbName}' fehlgeschlagen: {e}"]} + finally: + conn.close() + + recordCount = sum(dbResult.values()) + return {"database": dbName, "tables": dbResult, "recordCount": recordCount, "warnings": warnings} diff --git a/requirements.txt b/requirements.txt index 9aafd048..0e330d51 100644 --- a/requirements.txt +++ b/requirements.txt @@ -83,6 +83,9 @@ pytest-asyncio>=0.21.0 ## Configuration Validation jsonschema>=4.0.0 # Required for chatbot workflow config validation +## Streaming JSON parser (memory-safe import of large DB exports) +ijson>=3.2.0 + ## For Scheduling / Repeated Tasks APScheduler==3.11.0