diff --git a/modules/routes/routeAdminDatabaseHealth.py b/modules/routes/routeAdminDatabaseHealth.py index 61b56bac..02343e83 100644 --- a/modules/routes/routeAdminDatabaseHealth.py +++ b/modules/routes/routeAdminDatabaseHealth.py @@ -35,7 +35,6 @@ from modules.system.databaseMigration import ( _importDatabases, _importSingleDb, _importSingleDbFromFiles, - _prepareImport, _streamSplitToFiles, _streamValidate, _validateImportPayload, @@ -532,39 +531,6 @@ def getMigrationExportStream( ) -def _processUploadedFile(filePath: str, tmpDir: str, token: str) -> dict: - """Streaming validate + split: never loads the full JSON into RAM. - - Pass 1 (``_streamValidate``): extract meta, count tables/records, - detect system objects, build ID remap -- constant memory. - - Pass 2 (``_streamSplitToFiles``): iterate rows again, apply remap, - write per-table JSONL temp files -- one row in RAM at a time. - """ - import os - - result = _streamValidate(filePath) - - if not result.get("valid"): - try: - os.remove(filePath) - except OSError: - pass - return {"result": result, "dbFiles": {}} - - remap = result.get("remap", {}) - protectedIds = result.get("protectedIds", []) - - dbFiles = _streamSplitToFiles(filePath, tmpDir, token, remap) - - try: - os.remove(filePath) - except OSError: - pass - - return {"result": result, "dbFiles": dbFiles, "protectedIds": protectedIds} - - @router.post("/migration/upload-import") @limiter.limit("5/minute") async def postMigrationUploadImport( @@ -572,10 +538,9 @@ async def postMigrationUploadImport( file: UploadFile = File(...), currentUser: User = Depends(requireSysAdmin), ) -> Dict[str, Any]: - """Upload a backup file to disk (chunked), validate, remap IDs, - split into per-DB temp files so the full payload doesn't stay in RAM. + """Upload a backup file to disk (chunked). Returns a token that the + frontend passes to ``/process-import-stream`` for streaming validation. 
""" - import asyncio import os import tempfile import uuid @@ -602,50 +567,120 @@ async def postMigrationUploadImport( os.remove(filePath) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Upload failed: {e}") from e - logger.info("SysAdmin migration upload-import: %s bytes on disk (%.1f MB)", - totalBytes, totalBytes / 1024 / 1024) + fileSizeMb = round(totalBytes / (1024 * 1024), 1) + logger.info("SysAdmin migration upload-import: %s bytes on disk (%.1f MB)", totalBytes, fileSizeMb) - try: - processed = await asyncio.to_thread(_processUploadedFile, filePath, tmpDir, token) - except (json.JSONDecodeError, UnicodeDecodeError) as e: - if os.path.exists(filePath): - os.remove(filePath) - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid JSON file: {e}") from e - except Exception as e: - logger.exception("Processing uploaded import file failed: %s", e) - if os.path.exists(filePath): - os.remove(filePath) - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Processing failed: {e}") from e + _pendingProcessing[token] = {"filePath": filePath, "tmpDir": tmpDir} - result = processed["result"] - dbFiles = processed.get("dbFiles", {}) - - if not result.get("valid"): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={"message": "Payload validation failed", "warnings": result.get("warnings", [])}, - ) - - logger.info("SysAdmin migration upload-import: split into %d per-DB files, payload freed", - len(dbFiles)) - - _pendingImports[token] = { - "dbFiles": dbFiles, - "protectedIds": processed.get("protectedIds", []), - } - - return { - "token": token, - "valid": result.get("valid", False), - "databases": result.get("databases", []), - "warnings": result.get("warnings", []), - "systemObjectsFound": result.get("systemObjectsFound", []), - } + return {"token": token, "fileSizeMb": fileSizeMb} +_pendingProcessing: Dict[str, dict] = {} _pendingImports: Dict[str, dict] = {} 
@router.get("/migration/process-import-stream")
@limiter.limit("5/minute")
def getProcessImportStream(
    request: Request,
    token: str,
    currentUser: User = Depends(requireSysAdmin),
):
    """Stream validation + split progress as newline-delimited JSON.

    ``token`` is the one-time key stored in ``_pendingProcessing`` by the
    upload endpoint; it is consumed (popped) here, so a second call with the
    same token is rejected.

    The uploaded file is processed in a background thread:

    1. ``_streamValidate`` checks the payload and builds the ID remap.
    2. ``_streamSplitToFiles`` writes per-table JSONL files into ``tmpDir``.
    3. On success the split files are registered in ``_pendingImports``
       under the same token.

    The uploaded file is always deleted once processing ends. If processing
    fails, split files already written for this token are deleted as well
    and nothing is left registered in ``_pendingImports``.

    Each streamed line is a JSON object:
    - ``{"phase":"validate","db":"...","table":"...","rows":N}``
    - ``{"phase":"split","db":"...","table":"...","rows":N}``
    - ``{"phase":"done","result":{valid, databases, warnings, ...}}``
    - ``{"phase":"error","detail":"..."}``

    Raises:
        HTTPException: 400 if the token is unknown/already used, or if the
            uploaded file no longer exists on disk.
    """
    import glob
    import os
    import queue
    import threading

    pending = _pendingProcessing.pop(token, None)
    if not pending:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                            detail="Invalid or expired processing token.")

    filePath = pending["filePath"]
    tmpDir = pending["tmpDir"]

    if not os.path.exists(filePath):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                            detail="Upload file not found.")

    # Worker -> generator channel; ``None`` marks end of stream.
    q: queue.Queue = queue.Queue()

    def _removeQuietly(path: str) -> None:
        """Best-effort delete; a missing or locked file must not abort processing."""
        try:
            os.remove(path)
        except OSError:
            pass

    def _removeSplitFiles() -> None:
        """Delete per-table files written for this token.

        The pattern must match the naming used by ``_streamSplitToFiles``
        (``poweron_import_{token}_{dbName}__{tableName}.jsonl``).
        """
        pattern = os.path.join(
            glob.escape(tmpDir),
            f"poweron_import_{glob.escape(token)}_*.jsonl",
        )
        for path in glob.glob(pattern):
            _removeQuietly(path)

    def _worker():
        try:
            result = _streamValidate(filePath, progressCb=q.put)

            if not result.get("valid"):
                q.put({"phase": "done", "result": {
                    "valid": False,
                    "databases": [],
                    "warnings": result.get("warnings", []),
                    "systemObjectsFound": result.get("systemObjectsFound", []),
                }})
                return

            dbFiles = _streamSplitToFiles(
                filePath, tmpDir, token, result.get("remap", {}), progressCb=q.put,
            )

            _pendingImports[token] = {
                "dbFiles": dbFiles,
                "protectedIds": result.get("protectedIds", []),
            }

            q.put({"phase": "done", "result": {
                "token": token,
                "valid": True,
                "databases": result.get("databases", []),
                "warnings": result.get("warnings", []),
                "systemObjectsFound": result.get("systemObjectsFound", []),
            }})
        except Exception as e:
            logger.exception("Processing import stream failed: %s", e)
            # Roll back: a failed run must leave neither a registered import
            # nor partially written split files behind.
            _pendingImports.pop(token, None)
            _removeSplitFiles()
            q.put({"phase": "error", "detail": str(e)})
        finally:
            # The upload is no longer needed on any path (invalid, done, error).
            _removeQuietly(filePath)
            # Exactly one sentinel, regardless of which path was taken.
            q.put(None)

    def _generate():
        thread = threading.Thread(target=_worker, daemon=True)
        thread.start()
        # Drain events until the worker's end-of-stream sentinel arrives.
        for item in iter(q.get, None):
            yield json.dumps(item, ensure_ascii=False) + "\n"
        thread.join(timeout=5)

    return StreamingResponse(
        _generate(),
        media_type="text/x-ndjson",
    )
""" import ijson @@ -1030,6 +1033,8 @@ def _streamValidate(filePath: str) -> dict: registeredDbs = getRegisteredDatabases() dbTableCounts: Dict[str, Dict[str, int]] = {} systemObjectsFound: List[dict] = [] + prevDb: Optional[str] = None + prevTable: Optional[str] = None for dbName, tableName, row in _iterStreamRows(filePath): if dbName not in dbTableCounts: @@ -1038,6 +1043,13 @@ def _streamValidate(filePath: str) -> dict: dbTableCounts[dbName][tableName] = 0 dbTableCounts[dbName][tableName] += 1 + if progressCb and (dbName != prevDb or tableName != prevTable): + if prevDb and prevTable: + progressCb({"phase": "validate", "db": prevDb, "table": prevTable, + "rows": dbTableCounts[prevDb][prevTable]}) + prevDb = dbName + prevTable = tableName + if dbName == "poweron_app": if (tableName == "Mandate" and row.get("name") == "root" @@ -1062,6 +1074,10 @@ def _streamValidate(filePath: str) -> dict: "payloadId": row.get("id"), }) + if progressCb and prevDb and prevTable: + progressCb({"phase": "validate", "db": prevDb, "table": prevTable, + "rows": dbTableCounts[prevDb][prevTable]}) + if not dbTableCounts: warnings.append("Fehlende oder ungueltige 'databases'-Sektion") @@ -1124,10 +1140,14 @@ def _streamSplitToFiles( tmpDir: str, token: str, remap: Dict[str, str], + progressCb=None, ) -> Dict[str, Dict[str, str]]: """Stream through the export file a second time, applying ID remap on each row and writing per-table JSONL temp files. + ``progressCb``, when provided, is called on every database/table transition + with ``{"phase": "split", "db": str, "table": str, "rows": int}``. + Returns ``{dbName: {tableName: filePath}}``. 
""" import os @@ -1135,6 +1155,9 @@ def _streamSplitToFiles( remapSet = set(remap.keys()) if remap else set() dbFiles: Dict[str, Dict[str, str]] = {} writers: Dict[Tuple[str, str], Any] = {} + tableCounts: Dict[Tuple[str, str], int] = {} + prevDb: Optional[str] = None + prevTable: Optional[str] = None try: for dbName, tableName, row in _iterStreamRows(filePath): @@ -1143,6 +1166,12 @@ def _streamSplitToFiles( key = (dbName, tableName) if key not in writers: + if progressCb and prevDb and prevTable: + progressCb({"phase": "split", "db": prevDb, "table": prevTable, + "rows": tableCounts.get((prevDb, prevTable), 0)}) + prevDb = dbName + prevTable = tableName + tblPath = os.path.join( tmpDir, f"poweron_import_{token}_{dbName}__{tableName}.jsonl", @@ -1151,13 +1180,19 @@ def _streamSplitToFiles( if dbName not in dbFiles: dbFiles[dbName] = {} dbFiles[dbName][tableName] = tblPath + tableCounts[key] = 0 writers[key].write(json.dumps(row, ensure_ascii=False, default=str)) writers[key].write("\n") + tableCounts[key] += 1 finally: for fh in writers.values(): fh.close() + if progressCb and prevDb and prevTable: + progressCb({"phase": "split", "db": prevDb, "table": prevTable, + "rows": tableCounts.get((prevDb, prevTable), 0)}) + return dbFiles