fix db import
This commit is contained in:
parent
8c2e9d2183
commit
a2c5360364
2 changed files with 145 additions and 75 deletions
|
|
@ -35,7 +35,6 @@ from modules.system.databaseMigration import (
|
|||
_importDatabases,
|
||||
_importSingleDb,
|
||||
_importSingleDbFromFiles,
|
||||
_prepareImport,
|
||||
_streamSplitToFiles,
|
||||
_streamValidate,
|
||||
_validateImportPayload,
|
||||
|
|
@ -532,39 +531,6 @@ def getMigrationExportStream(
|
|||
)
|
||||
|
||||
|
||||
def _processUploadedFile(filePath: str, tmpDir: str, token: str) -> dict:
    """Validate an uploaded export file, then split it into per-table files.

    The work is done in two streaming passes over ``filePath``:

    1. ``_streamValidate`` checks the payload and returns the validation
       result, including the ``remap`` and ``protectedIds`` entries read below.
    2. ``_streamSplitToFiles`` re-reads the file with that remap and returns
       the mapping of generated per-table temp files.

    The uploaded file is deleted (best effort) once it is no longer needed,
    both when validation fails and after a successful split.

    Returns a dict with ``result`` (the validation result) and ``dbFiles``;
    on success it additionally carries ``protectedIds``.
    """
    import os

    def _discardUpload() -> None:
        # Best effort only: a file that is already gone must not fail the import.
        try:
            os.remove(filePath)
        except OSError:
            pass

    validation = _streamValidate(filePath)
    if not validation.get("valid"):
        _discardUpload()
        return {"result": validation, "dbFiles": {}}

    idRemap = validation.get("remap", {})
    protected = validation.get("protectedIds", [])

    splitFiles = _streamSplitToFiles(filePath, tmpDir, token, idRemap)
    _discardUpload()

    return {"result": validation, "dbFiles": splitFiles, "protectedIds": protected}
|
||||
|
||||
|
||||
@router.post("/migration/upload-import")
|
||||
@limiter.limit("5/minute")
|
||||
async def postMigrationUploadImport(
|
||||
|
|
@ -572,10 +538,9 @@ async def postMigrationUploadImport(
|
|||
file: UploadFile = File(...),
|
||||
currentUser: User = Depends(requireSysAdmin),
|
||||
) -> Dict[str, Any]:
|
||||
"""Upload a backup file to disk (chunked), validate, remap IDs,
|
||||
split into per-DB temp files so the full payload doesn't stay in RAM.
|
||||
"""Upload a backup file to disk (chunked). Returns a token that the
|
||||
frontend passes to ``/process-import-stream`` for streaming validation.
|
||||
"""
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
import uuid
|
||||
|
|
@ -602,50 +567,120 @@ async def postMigrationUploadImport(
|
|||
os.remove(filePath)
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Upload failed: {e}") from e
|
||||
|
||||
logger.info("SysAdmin migration upload-import: %s bytes on disk (%.1f MB)",
|
||||
totalBytes, totalBytes / 1024 / 1024)
|
||||
fileSizeMb = round(totalBytes / (1024 * 1024), 1)
|
||||
logger.info("SysAdmin migration upload-import: %s bytes on disk (%.1f MB)", totalBytes, fileSizeMb)
|
||||
|
||||
try:
|
||||
processed = await asyncio.to_thread(_processUploadedFile, filePath, tmpDir, token)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||
if os.path.exists(filePath):
|
||||
os.remove(filePath)
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid JSON file: {e}") from e
|
||||
except Exception as e:
|
||||
logger.exception("Processing uploaded import file failed: %s", e)
|
||||
if os.path.exists(filePath):
|
||||
os.remove(filePath)
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Processing failed: {e}") from e
|
||||
_pendingProcessing[token] = {"filePath": filePath, "tmpDir": tmpDir}
|
||||
|
||||
result = processed["result"]
|
||||
dbFiles = processed.get("dbFiles", {})
|
||||
|
||||
if not result.get("valid"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"message": "Payload validation failed", "warnings": result.get("warnings", [])},
|
||||
)
|
||||
|
||||
logger.info("SysAdmin migration upload-import: split into %d per-DB files, payload freed",
|
||||
len(dbFiles))
|
||||
|
||||
_pendingImports[token] = {
|
||||
"dbFiles": dbFiles,
|
||||
"protectedIds": processed.get("protectedIds", []),
|
||||
}
|
||||
|
||||
return {
|
||||
"token": token,
|
||||
"valid": result.get("valid", False),
|
||||
"databases": result.get("databases", []),
|
||||
"warnings": result.get("warnings", []),
|
||||
"systemObjectsFound": result.get("systemObjectsFound", []),
|
||||
}
|
||||
return {"token": token, "fileSizeMb": fileSizeMb}
|
||||
|
||||
|
||||
_pendingProcessing: Dict[str, dict] = {}
|
||||
_pendingImports: Dict[str, dict] = {}
|
||||
|
||||
|
||||
@router.get("/migration/process-import-stream")
@limiter.limit("5/minute")
def getProcessImportStream(
    request: Request,
    token: str,
    currentUser: User = Depends(requireSysAdmin),
):
    """Stream validation + split progress as newline-delimited JSON.

    ``token`` must be a key of ``_pendingProcessing``. It is consumed (popped)
    by this call, so a token can only be processed once.

    Each line of the response is a JSON object:
    - ``{"phase":"validate","db":"...","table":"...","rows":N}``
    - ``{"phase":"split","db":"...","table":"...","rows":N}``
    - ``{"phase":"done","result":{valid, databases, warnings, ...}}``
    - ``{"phase":"error","detail":"..."}``

    On a valid payload the split result is stored in ``_pendingImports``
    under the same token before the ``done`` line is emitted.

    Raises:
        HTTPException: 400 when the token is unknown or the uploaded file
            no longer exists on disk.
    """
    import os
    import queue
    import threading

    pending = _pendingProcessing.pop(token, None)
    if not pending:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                            detail="Invalid or expired processing token.")

    filePath = pending["filePath"]
    tmpDir = pending["tmpDir"]

    if not os.path.exists(filePath):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                            detail="Upload file not found.")

    # Worker thread -> response generator hand-off; ``None`` marks end of stream.
    q: queue.Queue = queue.Queue()

    def _progressCb(info: dict):
        q.put(info)

    def _worker():
        try:
            result = _streamValidate(filePath, progressCb=_progressCb)

            if not result.get("valid"):
                q.put({"phase": "done", "result": {
                    "valid": False,
                    "databases": [],
                    "warnings": result.get("warnings", []),
                    "systemObjectsFound": result.get("systemObjectsFound", []),
                }})
                return

            remap = result.get("remap", {})
            protectedIds = result.get("protectedIds", [])

            dbFiles = _streamSplitToFiles(filePath, tmpDir, token, remap, progressCb=_progressCb)

            _pendingImports[token] = {
                "dbFiles": dbFiles,
                "protectedIds": protectedIds,
            }

            q.put({"phase": "done", "result": {
                "token": token,
                "valid": True,
                "databases": result.get("databases", []),
                "warnings": result.get("warnings", []),
                "systemObjectsFound": result.get("systemObjectsFound", []),
            }})
        except Exception as e:
            logger.exception("Processing import stream failed: %s", e)
            q.put({"phase": "error", "detail": str(e)})
        finally:
            # The upload is never needed again, whatever the outcome
            # (invalid, split done, or failure), so remove it in one place.
            try:
                os.remove(filePath)
            except OSError:
                pass
            # Exactly one end-of-stream sentinel per run, on every exit path.
            q.put(None)

    def _generate():
        # The worker is started lazily, when the response body is first consumed.
        thread = threading.Thread(target=_worker, daemon=True)
        thread.start()
        while True:
            item = q.get()
            if item is None:
                break
            yield json.dumps(item, ensure_ascii=False) + "\n"
        thread.join(timeout=5)

    return StreamingResponse(
        _generate(),
        media_type="text/x-ndjson",
    )
|
||||
|
||||
|
||||
@router.post("/migration/import-single")
|
||||
@limiter.limit("60/minute")
|
||||
def postMigrationImportSingle(
|
||||
|
|
|
|||
|
|
@ -1007,12 +1007,15 @@ def _iterStreamRows(filePath: str):
|
|||
stack.pop()
|
||||
|
||||
|
||||
def _streamValidate(filePath: str) -> dict:
|
||||
def _streamValidate(filePath: str, progressCb=None) -> dict:
|
||||
"""Stream-validate an import file without loading it into RAM.
|
||||
|
||||
Extracts ``meta``, counts databases/tables/records, detects system objects,
|
||||
and builds the ID remap -- all with constant memory usage.
|
||||
|
||||
``progressCb``, when provided, is called on every database/table transition
|
||||
with ``{"phase": "validate", "db": str, "table": str, "rows": int}``.
|
||||
|
||||
Returns the same shape as ``_prepareImport`` plus ``remap``.
|
||||
"""
|
||||
import ijson
|
||||
|
|
@ -1030,6 +1033,8 @@ def _streamValidate(filePath: str) -> dict:
|
|||
registeredDbs = getRegisteredDatabases()
|
||||
dbTableCounts: Dict[str, Dict[str, int]] = {}
|
||||
systemObjectsFound: List[dict] = []
|
||||
prevDb: Optional[str] = None
|
||||
prevTable: Optional[str] = None
|
||||
|
||||
for dbName, tableName, row in _iterStreamRows(filePath):
|
||||
if dbName not in dbTableCounts:
|
||||
|
|
@ -1038,6 +1043,13 @@ def _streamValidate(filePath: str) -> dict:
|
|||
dbTableCounts[dbName][tableName] = 0
|
||||
dbTableCounts[dbName][tableName] += 1
|
||||
|
||||
if progressCb and (dbName != prevDb or tableName != prevTable):
|
||||
if prevDb and prevTable:
|
||||
progressCb({"phase": "validate", "db": prevDb, "table": prevTable,
|
||||
"rows": dbTableCounts[prevDb][prevTable]})
|
||||
prevDb = dbName
|
||||
prevTable = tableName
|
||||
|
||||
if dbName == "poweron_app":
|
||||
if (tableName == "Mandate"
|
||||
and row.get("name") == "root"
|
||||
|
|
@ -1062,6 +1074,10 @@ def _streamValidate(filePath: str) -> dict:
|
|||
"payloadId": row.get("id"),
|
||||
})
|
||||
|
||||
if progressCb and prevDb and prevTable:
|
||||
progressCb({"phase": "validate", "db": prevDb, "table": prevTable,
|
||||
"rows": dbTableCounts[prevDb][prevTable]})
|
||||
|
||||
if not dbTableCounts:
|
||||
warnings.append("Fehlende oder ungueltige 'databases'-Sektion")
|
||||
|
||||
|
|
@ -1124,10 +1140,14 @@ def _streamSplitToFiles(
|
|||
tmpDir: str,
|
||||
token: str,
|
||||
remap: Dict[str, str],
|
||||
progressCb=None,
|
||||
) -> Dict[str, Dict[str, str]]:
|
||||
"""Stream through the export file a second time, applying ID remap on
|
||||
each row and writing per-table JSONL temp files.
|
||||
|
||||
``progressCb``, when provided, is called on every database/table transition
|
||||
with ``{"phase": "split", "db": str, "table": str, "rows": int}``.
|
||||
|
||||
Returns ``{dbName: {tableName: filePath}}``.
|
||||
"""
|
||||
import os
|
||||
|
|
@ -1135,6 +1155,9 @@ def _streamSplitToFiles(
|
|||
remapSet = set(remap.keys()) if remap else set()
|
||||
dbFiles: Dict[str, Dict[str, str]] = {}
|
||||
writers: Dict[Tuple[str, str], Any] = {}
|
||||
tableCounts: Dict[Tuple[str, str], int] = {}
|
||||
prevDb: Optional[str] = None
|
||||
prevTable: Optional[str] = None
|
||||
|
||||
try:
|
||||
for dbName, tableName, row in _iterStreamRows(filePath):
|
||||
|
|
@ -1143,6 +1166,12 @@ def _streamSplitToFiles(
|
|||
|
||||
key = (dbName, tableName)
|
||||
if key not in writers:
|
||||
if progressCb and prevDb and prevTable:
|
||||
progressCb({"phase": "split", "db": prevDb, "table": prevTable,
|
||||
"rows": tableCounts.get((prevDb, prevTable), 0)})
|
||||
prevDb = dbName
|
||||
prevTable = tableName
|
||||
|
||||
tblPath = os.path.join(
|
||||
tmpDir,
|
||||
f"poweron_import_{token}_{dbName}__{tableName}.jsonl",
|
||||
|
|
@ -1151,13 +1180,19 @@ def _streamSplitToFiles(
|
|||
if dbName not in dbFiles:
|
||||
dbFiles[dbName] = {}
|
||||
dbFiles[dbName][tableName] = tblPath
|
||||
tableCounts[key] = 0
|
||||
|
||||
writers[key].write(json.dumps(row, ensure_ascii=False, default=str))
|
||||
writers[key].write("\n")
|
||||
tableCounts[key] += 1
|
||||
finally:
|
||||
for fh in writers.values():
|
||||
fh.close()
|
||||
|
||||
if progressCb and prevDb and prevTable:
|
||||
progressCb({"phase": "split", "db": prevDb, "table": prevTable,
|
||||
"rows": tableCounts.get((prevDb, prevTable), 0)})
|
||||
|
||||
return dbFiles
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue