db import streaming
This commit is contained in:
parent
3e2c07a776
commit
8c2e9d2183
3 changed files with 420 additions and 53 deletions
|
|
@ -34,7 +34,10 @@ from modules.system.databaseMigration import (
|
|||
_getInstanceLabel,
|
||||
_importDatabases,
|
||||
_importSingleDb,
|
||||
_importSingleDbFromFiles,
|
||||
_prepareImport,
|
||||
_streamSplitToFiles,
|
||||
_streamValidate,
|
||||
_validateImportPayload,
|
||||
streamExportGenerator,
|
||||
)
|
||||
|
|
@ -530,43 +533,35 @@ def getMigrationExportStream(
|
|||
|
||||
|
||||
def _processUploadedFile(filePath: str, tmpDir: str, token: str) -> dict:
|
||||
"""Parse JSON, validate, remap, split into per-DB files.
|
||||
"""Streaming validate + split: never loads the full JSON into RAM.
|
||||
|
||||
Runs in a thread pool to avoid blocking the asyncio event loop
|
||||
during the CPU-heavy json.load() of large (500+ MB) files.
|
||||
Pass 1 (``_streamValidate``): extract meta, count tables/records,
|
||||
detect system objects, build ID remap -- constant memory.
|
||||
|
||||
Pass 2 (``_streamSplitToFiles``): iterate rows again, apply remap,
|
||||
write per-table JSONL temp files -- one row in RAM at a time.
|
||||
"""
|
||||
import gc
|
||||
import os
|
||||
|
||||
with open(filePath, "r", encoding="utf-8") as f:
|
||||
payload = json.load(f)
|
||||
result = _streamValidate(filePath)
|
||||
|
||||
if not result.get("valid"):
|
||||
try:
|
||||
os.remove(filePath)
|
||||
except OSError:
|
||||
pass
|
||||
return {"result": result, "dbFiles": {}}
|
||||
|
||||
remap = result.get("remap", {})
|
||||
protectedIds = result.get("protectedIds", [])
|
||||
|
||||
dbFiles = _streamSplitToFiles(filePath, tmpDir, token, remap)
|
||||
|
||||
try:
|
||||
os.remove(filePath)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
result = _prepareImport(payload)
|
||||
|
||||
if not result.get("valid"):
|
||||
del payload
|
||||
gc.collect()
|
||||
return {"result": result, "dbFiles": {}}
|
||||
|
||||
protectedIds = result.get("protectedIds", [])
|
||||
|
||||
dbFiles = {}
|
||||
databases = payload.get("databases", {})
|
||||
for dbName, dbData in databases.items():
|
||||
dbPath = os.path.join(tmpDir, f"poweron_import_{token}_{dbName}.json")
|
||||
with open(dbPath, "w", encoding="utf-8") as dbF:
|
||||
json.dump(dbData, dbF, ensure_ascii=False, default=str)
|
||||
dbFiles[dbName] = dbPath
|
||||
|
||||
del payload
|
||||
del databases
|
||||
gc.collect()
|
||||
|
||||
return {"result": result, "dbFiles": dbFiles, "protectedIds": protectedIds}
|
||||
|
||||
|
||||
|
|
@ -617,6 +612,7 @@ async def postMigrationUploadImport(
|
|||
os.remove(filePath)
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid JSON file: {e}") from e
|
||||
except Exception as e:
|
||||
logger.exception("Processing uploaded import file failed: %s", e)
|
||||
if os.path.exists(filePath):
|
||||
os.remove(filePath)
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Processing failed: {e}") from e
|
||||
|
|
@ -657,7 +653,10 @@ def postMigrationImportSingle(
|
|||
body: dict,
|
||||
currentUser: User = Depends(requireSysAdmin),
|
||||
) -> Dict[str, Any]:
|
||||
"""Import a single database from a previously uploaded + prepared payload.
|
||||
"""Import a single database from previously uploaded + prepared payload.
|
||||
|
||||
Supports both the new per-table JSONL format (``{tableName: filePath}``)
|
||||
and the legacy single-JSON-per-DB format (plain file path string).
|
||||
|
||||
Body: ``{token, database, mode}``
|
||||
"""
|
||||
|
|
@ -674,29 +673,26 @@ def postMigrationImportSingle(
|
|||
if not pending:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid or expired import token.")
|
||||
|
||||
dbFiles = pending.get("dbFiles", {})
|
||||
dbFilePath = dbFiles.get(database)
|
||||
if not dbFilePath or not os.path.exists(dbFilePath):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"No data for database '{database}'.",
|
||||
)
|
||||
dbEntry = pending.get("dbFiles", {}).get(database)
|
||||
if not dbEntry:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"No data for database '{database}'.")
|
||||
|
||||
logger.info("SysAdmin migration import-single: user=%s db=%s mode=%s", currentUser.username, database, mode)
|
||||
|
||||
try:
|
||||
with open(dbFilePath, "r", encoding="utf-8") as f:
|
||||
dbData = json.load(f)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to read import data for '{database}': {e}",
|
||||
) from e
|
||||
|
||||
payload = {"databases": {database: dbData}}
|
||||
|
||||
try:
|
||||
result = _importSingleDb(payload, database, mode, pending["protectedIds"])
|
||||
if isinstance(dbEntry, dict):
|
||||
result = _importSingleDbFromFiles(dbEntry, database, mode, pending["protectedIds"])
|
||||
else:
|
||||
dbFilePath = dbEntry
|
||||
if not os.path.exists(dbFilePath):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"No data for database '{database}'.")
|
||||
with open(dbFilePath, "r", encoding="utf-8") as f:
|
||||
dbData = json.load(f)
|
||||
payload = {"databases": {database: dbData}}
|
||||
result = _importSingleDb(payload, database, mode, pending["protectedIds"])
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Import-single failed for %s: %s", database, e)
|
||||
raise HTTPException(
|
||||
|
|
@ -714,15 +710,22 @@ def postMigrationImportDone(
|
|||
body: dict,
|
||||
currentUser: User = Depends(requireSysAdmin),
|
||||
) -> Dict[str, Any]:
|
||||
"""Clean up the per-DB temp files."""
|
||||
"""Clean up the per-DB / per-table temp files."""
|
||||
import os
|
||||
|
||||
token = body.get("token", "")
|
||||
pending = _pendingImports.pop(token, None)
|
||||
if pending:
|
||||
for dbPath in pending.get("dbFiles", {}).values():
|
||||
try:
|
||||
os.remove(dbPath)
|
||||
except OSError:
|
||||
pass
|
||||
for dbEntry in pending.get("dbFiles", {}).values():
|
||||
if isinstance(dbEntry, str):
|
||||
try:
|
||||
os.remove(dbEntry)
|
||||
except OSError:
|
||||
pass
|
||||
elif isinstance(dbEntry, dict):
|
||||
for tblPath in dbEntry.values():
|
||||
try:
|
||||
os.remove(tblPath)
|
||||
except OSError:
|
||||
pass
|
||||
return {"ok": True}
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ references stay consistent.
|
|||
All functions are intended for SysAdmin use only (access control in the route layer).
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
|
@ -929,3 +930,363 @@ def _importSingleDb(payload: dict, dbName: str, mode: str, protectedIds: List[st
|
|||
|
||||
recordCount = sum(dbResult.values())
|
||||
return {"database": dbName, "tables": dbResult, "recordCount": recordCount, "warnings": warnings}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Streaming Import (memory-safe, ijson-based)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _iterStreamRows(filePath: str):
|
||||
"""Yield ``(dbName, tableName, rowDict)`` one row at a time from an export
|
||||
JSON file using ijson streaming parser. Never holds more than one row in RAM.
|
||||
"""
|
||||
import ijson
|
||||
|
||||
with open(filePath, "rb") as f:
|
||||
currentDb: Optional[str] = None
|
||||
currentTable: Optional[str] = None
|
||||
rowPrefix: Optional[str] = None
|
||||
inRow = False
|
||||
stack: List[Tuple[Any, Optional[str]]] = []
|
||||
root: dict = {}
|
||||
|
||||
for prefix, event, value in ijson.parse(f, use_float=True):
|
||||
if not inRow:
|
||||
if prefix == "databases" and event == "map_key":
|
||||
currentDb = value
|
||||
currentTable = None
|
||||
rowPrefix = None
|
||||
elif (currentDb
|
||||
and prefix == f"databases.{currentDb}.tables"
|
||||
and event == "map_key"):
|
||||
currentTable = value
|
||||
rowPrefix = f"databases.{currentDb}.tables.{currentTable}.item"
|
||||
elif rowPrefix and prefix == rowPrefix and event == "start_map":
|
||||
inRow = True
|
||||
root = {}
|
||||
stack = [(root, None)]
|
||||
continue
|
||||
|
||||
container, pendingKey = stack[-1]
|
||||
|
||||
if event == "map_key":
|
||||
stack[-1] = (container, value)
|
||||
|
||||
elif event in ("string", "number", "boolean", "null"):
|
||||
if isinstance(container, dict):
|
||||
container[pendingKey] = value
|
||||
stack[-1] = (container, None)
|
||||
else:
|
||||
container.append(value)
|
||||
|
||||
elif event == "start_map":
|
||||
nested: dict = {}
|
||||
if isinstance(container, dict):
|
||||
container[pendingKey] = nested
|
||||
stack[-1] = (container, None)
|
||||
else:
|
||||
container.append(nested)
|
||||
stack.append((nested, None))
|
||||
|
||||
elif event == "end_map":
|
||||
stack.pop()
|
||||
if not stack:
|
||||
yield (currentDb, currentTable, root)
|
||||
inRow = False
|
||||
|
||||
elif event == "start_array":
|
||||
nested_list: list = []
|
||||
if isinstance(container, dict):
|
||||
container[pendingKey] = nested_list
|
||||
stack[-1] = (container, None)
|
||||
else:
|
||||
container.append(nested_list)
|
||||
stack.append((nested_list, None))
|
||||
|
||||
elif event == "end_array":
|
||||
stack.pop()
|
||||
|
||||
|
||||
def _streamValidate(filePath: str) -> dict:
|
||||
"""Stream-validate an import file without loading it into RAM.
|
||||
|
||||
Extracts ``meta``, counts databases/tables/records, detects system objects,
|
||||
and builds the ID remap -- all with constant memory usage.
|
||||
|
||||
Returns the same shape as ``_prepareImport`` plus ``remap``.
|
||||
"""
|
||||
import ijson
|
||||
|
||||
meta = None
|
||||
with open(filePath, "rb") as f:
|
||||
for m in ijson.items(f, "meta"):
|
||||
meta = m
|
||||
break
|
||||
|
||||
warnings: List[str] = []
|
||||
if not meta:
|
||||
warnings.append("Fehlende oder ungueltige 'meta'-Sektion")
|
||||
|
||||
registeredDbs = getRegisteredDatabases()
|
||||
dbTableCounts: Dict[str, Dict[str, int]] = {}
|
||||
systemObjectsFound: List[dict] = []
|
||||
|
||||
for dbName, tableName, row in _iterStreamRows(filePath):
|
||||
if dbName not in dbTableCounts:
|
||||
dbTableCounts[dbName] = {}
|
||||
if tableName not in dbTableCounts[dbName]:
|
||||
dbTableCounts[dbName][tableName] = 0
|
||||
dbTableCounts[dbName][tableName] += 1
|
||||
|
||||
if dbName == "poweron_app":
|
||||
if (tableName == "Mandate"
|
||||
and row.get("name") == "root"
|
||||
and row.get("isSystem") is True):
|
||||
systemObjectsFound.append({
|
||||
"type": "mandate",
|
||||
"label": "Root Mandate",
|
||||
"payloadId": row.get("id"),
|
||||
})
|
||||
elif tableName == "UserInDB":
|
||||
uname = row.get("username")
|
||||
if uname == "admin":
|
||||
systemObjectsFound.append({
|
||||
"type": "user",
|
||||
"label": "Admin User",
|
||||
"payloadId": row.get("id"),
|
||||
})
|
||||
elif uname == "event":
|
||||
systemObjectsFound.append({
|
||||
"type": "user",
|
||||
"label": "Event User",
|
||||
"payloadId": row.get("id"),
|
||||
})
|
||||
|
||||
if not dbTableCounts:
|
||||
warnings.append("Fehlende oder ungueltige 'databases'-Sektion")
|
||||
|
||||
summary: List[dict] = []
|
||||
for dbName, tables in dbTableCounts.items():
|
||||
registered = dbName in registeredDbs
|
||||
if not registered:
|
||||
warnings.append(f"Datenbank '{dbName}' ist nicht registriert und wird uebersprungen")
|
||||
summary.append({
|
||||
"database": dbName,
|
||||
"tableCount": len(tables),
|
||||
"recordCount": sum(tables.values()),
|
||||
"registered": registered,
|
||||
})
|
||||
|
||||
liveIds = _loadLiveSystemObjectIds()
|
||||
|
||||
remap: Dict[str, str] = {}
|
||||
for obj in systemObjectsFound:
|
||||
oldId = str(obj.get("payloadId", ""))
|
||||
if obj["type"] == "mandate":
|
||||
newId = liveIds.get("rootMandate", "")
|
||||
elif obj["label"] == "Admin User":
|
||||
newId = liveIds.get("adminUser", "")
|
||||
elif obj["label"] == "Event User":
|
||||
newId = liveIds.get("eventUser", "")
|
||||
else:
|
||||
continue
|
||||
if oldId and newId and oldId != newId:
|
||||
remap[oldId] = newId
|
||||
|
||||
if remap:
|
||||
logger.info("System-object ID remap: %s", remap)
|
||||
|
||||
protectedIdSet = set(liveIds.values())
|
||||
valid = any(s["registered"] for s in summary)
|
||||
|
||||
dbList: List[dict] = []
|
||||
for s in summary:
|
||||
if s["registered"]:
|
||||
dbList.append({
|
||||
"database": s["database"],
|
||||
"tableCount": s["tableCount"],
|
||||
"recordCount": s["recordCount"],
|
||||
})
|
||||
|
||||
return {
|
||||
"valid": valid,
|
||||
"summary": summary,
|
||||
"warnings": warnings,
|
||||
"systemObjectsFound": systemObjectsFound,
|
||||
"databases": dbList,
|
||||
"protectedIds": list(protectedIdSet),
|
||||
"remap": remap,
|
||||
}
|
||||
|
||||
|
||||
def _streamSplitToFiles(
|
||||
filePath: str,
|
||||
tmpDir: str,
|
||||
token: str,
|
||||
remap: Dict[str, str],
|
||||
) -> Dict[str, Dict[str, str]]:
|
||||
"""Stream through the export file a second time, applying ID remap on
|
||||
each row and writing per-table JSONL temp files.
|
||||
|
||||
Returns ``{dbName: {tableName: filePath}}``.
|
||||
"""
|
||||
import os
|
||||
|
||||
remapSet = set(remap.keys()) if remap else set()
|
||||
dbFiles: Dict[str, Dict[str, str]] = {}
|
||||
writers: Dict[Tuple[str, str], Any] = {}
|
||||
|
||||
try:
|
||||
for dbName, tableName, row in _iterStreamRows(filePath):
|
||||
if remap:
|
||||
_remapRowValues(row, remap, remapSet)
|
||||
|
||||
key = (dbName, tableName)
|
||||
if key not in writers:
|
||||
tblPath = os.path.join(
|
||||
tmpDir,
|
||||
f"poweron_import_{token}_{dbName}__{tableName}.jsonl",
|
||||
)
|
||||
writers[key] = open(tblPath, "w", encoding="utf-8")
|
||||
if dbName not in dbFiles:
|
||||
dbFiles[dbName] = {}
|
||||
dbFiles[dbName][tableName] = tblPath
|
||||
|
||||
writers[key].write(json.dumps(row, ensure_ascii=False, default=str))
|
||||
writers[key].write("\n")
|
||||
finally:
|
||||
for fh in writers.values():
|
||||
fh.close()
|
||||
|
||||
return dbFiles
|
||||
|
||||
|
||||
def _importSingleDbFromFiles(
|
||||
tableFiles: Dict[str, str],
|
||||
dbName: str,
|
||||
mode: str,
|
||||
protectedIds: List[str],
|
||||
) -> dict:
|
||||
"""Import a single database from per-table JSONL files.
|
||||
|
||||
Each file contains one JSON object per line (rows).
|
||||
Tables are sorted by FK dependencies before import.
|
||||
|
||||
Returns ``{database, tables, recordCount, warnings}``.
|
||||
"""
|
||||
if mode not in ("replace", "merge"):
|
||||
raise ValueError(f"Invalid import mode: {mode}")
|
||||
|
||||
registeredDbs = getRegisteredDatabases()
|
||||
if dbName not in registeredDbs:
|
||||
return {"database": dbName, "tables": {}, "recordCount": 0,
|
||||
"warnings": [f"Datenbank '{dbName}' nicht registriert"]}
|
||||
|
||||
if not tableFiles:
|
||||
return {"database": dbName, "tables": {}, "recordCount": 0,
|
||||
"warnings": [f"Keine Daten fuer '{dbName}'"]}
|
||||
|
||||
try:
|
||||
dbCreated = _ensureDatabaseExists(dbName)
|
||||
except Exception as e:
|
||||
logger.error("Failed to ensure database %s exists: %s", dbName, e)
|
||||
return {"database": dbName, "tables": {}, "recordCount": 0,
|
||||
"warnings": [f"Datenbank '{dbName}' konnte nicht erstellt werden: {e}"]}
|
||||
|
||||
protectedIdSet = set(protectedIds)
|
||||
warnings: List[str] = []
|
||||
dbResult: Dict[str, int] = {}
|
||||
excluded = _EXCLUDED_TABLES.get(dbName, set())
|
||||
|
||||
if dbCreated:
|
||||
warnings.append(f"Datenbank '{dbName}' wurde neu erstellt")
|
||||
|
||||
conn = _getConnection(dbName)
|
||||
try:
|
||||
existingTables = set(_listTables(conn))
|
||||
conn.rollback()
|
||||
|
||||
conn.autocommit = True
|
||||
for tableName, tblPath in tableFiles.items():
|
||||
if tableName in excluded:
|
||||
continue
|
||||
if tableName not in existingTables:
|
||||
sampleRows: List[dict] = []
|
||||
with open(tblPath, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
sampleRows.append(json.loads(line))
|
||||
if len(sampleRows) >= 10:
|
||||
break
|
||||
if sampleRows:
|
||||
_createTableFromExport(conn, tableName, sampleRows)
|
||||
existingTables.add(tableName)
|
||||
logger.info("Pre-created missing table %s.%s", dbName, tableName)
|
||||
|
||||
importable = [t for t in tableFiles if t not in excluded and t in existingTables]
|
||||
importOrder = _getTableImportOrder(conn, importable, dbName)
|
||||
|
||||
logger.info("Import order for %s: %s", dbName, importOrder)
|
||||
|
||||
for tableName in tableFiles:
|
||||
if tableName in excluded:
|
||||
warnings.append(f"Table '{dbName}.{tableName}' excluded (security/transient)")
|
||||
|
||||
if mode == "replace":
|
||||
conn.autocommit = False
|
||||
for tableName in reversed(importOrder):
|
||||
try:
|
||||
_deleteNonProtected(conn, tableName, protectedIdSet)
|
||||
conn.commit()
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
warnings.append(f"DELETE from {dbName}.{tableName} failed: {e}")
|
||||
logger.warning("DELETE from %s.%s failed: %s", dbName, tableName, e)
|
||||
|
||||
conn.autocommit = False
|
||||
batchSize = 100
|
||||
for tableName in importOrder:
|
||||
tblPath = tableFiles.get(tableName)
|
||||
if not tblPath:
|
||||
continue
|
||||
try:
|
||||
physicalCols = _getPhysicalColumns(conn, tableName)
|
||||
if not physicalCols:
|
||||
conn.rollback()
|
||||
continue
|
||||
|
||||
insertedCount = 0
|
||||
batch: List[dict] = []
|
||||
with open(tblPath, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
row = json.loads(line)
|
||||
if _isProtectedRow(tableName, row):
|
||||
continue
|
||||
if row.get("id") and str(row["id"]) in protectedIdSet:
|
||||
continue
|
||||
batch.append(row)
|
||||
if len(batch) >= batchSize:
|
||||
insertedCount += _insertRows(conn, tableName, batch, physicalCols, mode)
|
||||
batch = []
|
||||
if batch:
|
||||
insertedCount += _insertRows(conn, tableName, batch, physicalCols, mode)
|
||||
|
||||
conn.commit()
|
||||
dbResult[tableName] = insertedCount
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
warnings.append(f"INSERT into {dbName}.{tableName} failed: {e}")
|
||||
logger.warning("INSERT into %s.%s failed: %s", dbName, tableName, e)
|
||||
except Exception as e:
|
||||
logger.error("Import failed for database %s: %s", dbName, e)
|
||||
return {"database": dbName, "tables": {}, "recordCount": 0,
|
||||
"warnings": [f"Import fuer '{dbName}' fehlgeschlagen: {e}"]}
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
recordCount = sum(dbResult.values())
|
||||
return {"database": dbName, "tables": dbResult, "recordCount": recordCount, "warnings": warnings}
|
||||
|
|
|
|||
|
|
@ -83,6 +83,9 @@ pytest-asyncio>=0.21.0
|
|||
## Configuration Validation
|
||||
jsonschema>=4.0.0 # Required for chatbot workflow config validation
|
||||
|
||||
## Streaming JSON parser (memory-safe import of large DB exports)
|
||||
ijson>=3.2.0
|
||||
|
||||
## For Scheduling / Repeated Tasks
|
||||
APScheduler==3.11.0
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue