382 lines
14 KiB
Python
382 lines
14 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""Container extractor for ZIP, TAR, GZ, and 7Z archives.
|
|
|
|
Recursively unpacks containers and delegates each contained file to the
|
|
appropriate type-specific extractor via the ExtractorRegistry.
|
|
|
|
Safety limits:
|
|
- MAX_TOTAL_EXTRACTED_SIZE: 500 MB
|
|
- MAX_FILE_COUNT: 10000
|
|
- MAX_DEPTH: 5
|
|
- Symlinks blocked
|
|
"""
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
import io
|
|
import logging
|
|
import mimetypes
|
|
import zipfile
|
|
import tarfile
|
|
|
|
from ..subUtils import makeId
|
|
from modules.datamodels.datamodelExtraction import ContentPart
|
|
from modules.datamodels.datamodelContent import ContainerLimitError, ContentContextRef
|
|
from ..subRegistry import Extractor
|
|
|
|
logger = logging.getLogger(__name__)

# Safety limits shared across one whole extraction run (tracked in the
# mutable `state` dict threaded through the recursive helpers).
MAX_TOTAL_EXTRACTED_SIZE = 500 * 1024 * 1024  # 500 MB
MAX_FILE_COUNT = 10000
# Maximum container-in-container nesting depth before extraction aborts.
MAX_DEPTH = 5

# MIME types routed to ContainerExtractor (both common ZIP/GZ variants).
_CONTAINER_MIME_TYPES = [
    "application/zip",
    "application/x-zip-compressed",
    "application/x-tar",
    "application/gzip",
    "application/x-gzip",
    "application/x-7z-compressed",
]

# File-name extensions recognized as containers; compared case-insensitively.
_CONTAINER_EXTENSIONS = [".zip", ".tar", ".gz", ".tar.gz", ".tgz", ".7z"]
|
|
|
|
|
|
def _detectMimeType(fileName: str) -> str:
    """Guess the MIME type from the file name, defaulting to octet-stream."""
    mime, _encoding = mimetypes.guess_type(fileName)
    if mime is None:
        return "application/octet-stream"
    return mime
|
|
|
|
|
|
def _isSymlink(info) -> bool:
    """Return True for tar members that are symbolic or hard links."""
    issym = getattr(info, "issym", None)
    if not callable(issym):
        # Objects without the TarInfo link API are treated as regular files.
        return False
    return info.issym() or info.islnk()
|
|
|
|
|
|
class ContainerExtractor(Extractor):
    """Extractor for archive containers (ZIP, TAR, GZ, 7Z).

    Recursively resolves nested containers and produces a flat list of
    ContentPart entries -- one per contained file -- with containerPath metadata.
    """

    def detect(self, fileName: str, mimeType: str, headBytes: bytes) -> bool:
        """Match on a known container MIME type or file-name extension."""
        nameLower = (fileName or "").lower()
        return (
            mimeType in _CONTAINER_MIME_TYPES
            or any(nameLower.endswith(ext) for ext in _CONTAINER_EXTENSIONS)
        )

    def getSupportedExtensions(self) -> list[str]:
        """Return a fresh copy of the supported extension list."""
        return _CONTAINER_EXTENSIONS.copy()

    def getSupportedMimeTypes(self) -> list[str]:
        """Return a fresh copy of the supported MIME-type list."""
        return _CONTAINER_MIME_TYPES.copy()

    def extract(self, fileBytes: bytes, context: Dict[str, Any]) -> List[ContentPart]:
        """Extract by recursively unpacking the container.

        Returns a root container part followed by one part per contained
        file. When the safety limits are exceeded, a warning part is
        appended instead of failing the whole extraction.
        """
        fileName = context.get("fileName", "archive")
        mimeType = context.get("mimeType", "application/octet-stream")

        rootId = makeId()
        rootPart = ContentPart(
            id=rootId,
            parentId=None,
            label=fileName,
            typeGroup="container",
            mimeType=mimeType,
            data="",
            metadata={"size": len(fileBytes), "containerType": "archive"},
        )
        parts: List[ContentPart] = [rootPart]

        # Lazy mode lists ZIP members without unpacking them; non-ZIP
        # containers return None here and fall through to full extraction.
        if context.get("lazyContainer"):
            listing = _extractLazyListing(fileBytes, mimeType, fileName, rootId)
            if listing is not None:
                parts.extend(listing)
                return parts

        # Shared budget for the whole recursive walk (size + file count).
        state = {"totalSize": 0, "fileCount": 0}
        try:
            parts.extend(
                _resolveContainerRecursive(
                    fileBytes, mimeType, fileName, rootId, "", 0, state
                )
            )
        except ContainerLimitError as e:
            # Partial results gathered before the limit hit are discarded by
            # the recursion; surface the limit as a text part instead.
            logger.warning(f"Container limit reached for {fileName}: {e}")
            parts.append(ContentPart(
                id=makeId(),
                parentId=rootId,
                label="limit_exceeded",
                typeGroup="text",
                mimeType="text/plain",
                data=str(e),
                metadata={"warning": "Container extraction limit exceeded"},
            ))

        return parts
|
|
|
|
|
|
def _extractLazyListing(
    fileBytes: bytes,
    containerMime: str,
    containerName: str,
    parentId: str,
) -> Optional[List[ContentPart]]:
    """ZIP only: list member files with metadata (no nested extraction).

    Returns None for non-ZIP input or a corrupt archive so the caller can
    fall back to full extraction.
    """
    isZipMime = containerMime in ("application/zip", "application/x-zip-compressed")
    isZipName = (containerName or "").lower().endswith(".zip")
    if not (isZipMime or isZipName):
        return None

    entries: List[ContentPart] = []
    try:
        with zipfile.ZipFile(io.BytesIO(fileBytes)) as archive:
            for member in archive.infolist():
                if member.is_dir():
                    continue
                entries.append(
                    ContentPart(
                        id=makeId(),
                        parentId=parentId,
                        label=member.filename,
                        typeGroup="container",
                        mimeType=_detectMimeType(member.filename),
                        data="",
                        metadata={
                            "containerPath": member.filename,
                            # Declared (uncompressed) size from the ZIP directory.
                            "size": member.file_size,
                            "lazyReference": True,
                        },
                    )
                )
    except zipfile.BadZipFile:
        return None
    return entries
|
|
|
|
|
|
def _resolveContainerRecursive(
    containerBytes: bytes,
    containerMime: str,
    containerName: str,
    parentId: str,
    containerPath: str,
    depth: int,
    state: Dict[str, int],
) -> List[ContentPart]:
    """Recursively unpack containers. No AI calls.

    Dispatches to the format-specific helper based on MIME type with a
    file-extension fallback; unknown formats are logged and yield no parts.
    Raises ContainerLimitError when nesting exceeds MAX_DEPTH.
    """
    if depth > MAX_DEPTH:
        raise ContainerLimitError(f"Max nesting depth {MAX_DEPTH} exceeded")

    nameLower = containerName.lower()

    if containerMime in ("application/zip", "application/x-zip-compressed") or nameLower.endswith(".zip"):
        return _extractZip(containerBytes, parentId, containerPath, depth, state)
    if containerMime == "application/x-tar" or nameLower.endswith(".tar"):
        return _extractTar(containerBytes, parentId, containerPath, depth, state, compressed=False)
    # .tgz / .tar.gz also match the plain ".gz" suffix; the TAR helper
    # handles both via the gzip-aware open mode.
    if containerMime in ("application/gzip", "application/x-gzip") or nameLower.endswith((".gz", ".tgz", ".tar.gz")):
        return _extractTar(containerBytes, parentId, containerPath, depth, state, compressed=True)
    if nameLower.endswith(".7z"):
        return _extract7z(containerBytes, parentId, containerPath, depth, state)

    logger.warning(f"Unknown container format: {containerMime} ({containerName})")
    return []
|
|
|
|
|
|
def _addFilePart(
    data: bytes,
    fileName: str,
    parentId: str,
    containerPath: str,
    state: Dict[str, int],
) -> List[ContentPart]:
    """Extract a file via its type-specific Extractor and return ContentParts.

    Charges `data` against the shared extraction budget in `state`, then
    tries to delegate to a registered non-container extractor.  If none
    applies (or delegation fails), the file is returned as a single
    base64-encoded binary part.

    Raises:
        ContainerLimitError: when total size or file count limits are hit.
    """
    # Charge the budget before any extraction work so limits hold even if
    # the type-specific extractor below fails.
    state["totalSize"] += len(data)
    state["fileCount"] += 1

    if state["totalSize"] > MAX_TOTAL_EXTRACTED_SIZE:
        raise ContainerLimitError(f"Total extracted size exceeds {MAX_TOTAL_EXTRACTED_SIZE // (1024 * 1024)} MB")
    if state["fileCount"] > MAX_FILE_COUNT:
        raise ContainerLimitError(f"File count exceeds {MAX_FILE_COUNT}")

    entryPath = f"{containerPath}/{fileName}" if containerPath else fileName
    detectedMime = _detectMimeType(fileName)

    # Imported lazily to avoid a circular import with the registry module,
    # which itself imports this extractor.
    from ..subRegistry import getExtractorRegistry

    registry = getExtractorRegistry()
    extractor = registry.resolve(detectedMime, fileName)

    # Delegate to the type-specific extractor, but never back to a
    # ContainerExtractor -- nested archives are handled by the caller.
    if extractor and not isinstance(extractor, ContainerExtractor):
        try:
            childParts = extractor.extract(data, {"fileName": fileName, "mimeType": detectedMime})
            # Reparent every returned part onto this container node and tag it
            # with the path inside the archive.  NOTE(review): this flattens
            # any internal parent/child links the child extractor may have
            # produced -- presumably delegates return flat lists; confirm.
            for part in childParts:
                part.parentId = parentId
                if not part.metadata:
                    part.metadata = {}
                part.metadata["containerPath"] = entryPath
            return childParts
        except Exception as e:
            # Deliberate best-effort: a failing delegate must not abort the
            # whole archive; fall through to the raw-binary representation.
            logger.warning(f"Type-extractor failed for {fileName} in container: {e}")

    import base64
    encodedData = base64.b64encode(data).decode("utf-8") if data else ""

    # Fallback: emit the raw bytes as a base64 binary part.
    return [ContentPart(
        id=makeId(),
        parentId=parentId,
        label=fileName,
        typeGroup="binary",
        mimeType=detectedMime,
        data=encodedData,
        metadata={
            "size": len(data),
            "containerPath": entryPath,
            "contextRef": ContentContextRef(
                containerPath=entryPath,
                location="file",
            ).model_dump(),
        },
    )]
|
|
|
|
|
|
def _isNestedContainer(fileName: str, mimeType: str) -> bool:
    """Return True when an archive entry is itself a supported container."""
    if mimeType in _CONTAINER_MIME_TYPES:
        return True
    nameLower = fileName.lower()
    return any(nameLower.endswith(ext) for ext in _CONTAINER_EXTENSIONS)
|
|
|
|
|
|
def _extractZip(
    data: bytes, parentId: str, containerPath: str, depth: int, state: Dict[str, int]
) -> List[ContentPart]:
    """Unpack a ZIP archive into ContentParts.

    Directories, empty entries, and symlinks are skipped; nested containers
    recurse via _resolveContainerRecursive.  Size/count limits are enforced
    by _addFilePart and propagate as ContainerLimitError.  A corrupt archive
    yields a single error part instead of raising.
    """
    parts: List[ContentPart] = []
    try:
        with zipfile.ZipFile(io.BytesIO(data)) as zf:
            for info in zf.infolist():
                if info.is_dir():
                    continue
                if info.file_size == 0:
                    continue
                # Bug fix: block ZIP symlinks, matching the TAR path (the
                # module contract says "Symlinks blocked").  The Unix mode
                # lives in the high 16 bits of external_attr; 0o120000 is
                # S_IFLNK.  Archives from non-Unix tools leave these bits 0
                # and are unaffected.
                if ((info.external_attr >> 16) & 0o170000) == 0o120000:
                    logger.warning(f"Skipping symlink in ZIP: {info.filename}")
                    continue

                entryPath = f"{containerPath}/{info.filename}" if containerPath else info.filename
                entryMime = _detectMimeType(info.filename)
                entryData = zf.read(info.filename)

                if _isNestedContainer(info.filename, entryMime):
                    # Emit a container node, then recurse one level deeper.
                    nestedId = makeId()
                    parts.append(ContentPart(
                        id=nestedId,
                        parentId=parentId,
                        label=info.filename,
                        typeGroup="container",
                        mimeType=entryMime,
                        data="",
                        metadata={"size": len(entryData), "containerPath": entryPath},
                    ))
                    nested = _resolveContainerRecursive(
                        entryData, entryMime, info.filename, nestedId, entryPath, depth + 1, state
                    )
                    parts.extend(nested)
                else:
                    parts.extend(_addFilePart(entryData, info.filename, parentId, containerPath, state))
    except zipfile.BadZipFile as e:
        logger.error(f"Invalid ZIP file: {e}")
        parts.append(ContentPart(
            id=makeId(), parentId=parentId, label="error",
            typeGroup="text", mimeType="text/plain",
            data=f"Invalid ZIP archive: {e}", metadata={"error": True},
        ))
    return parts
|
|
|
|
|
|
def _extractTar(
    data: bytes, parentId: str, containerPath: str, depth: int, state: Dict[str, int],
    compressed: bool = False,
) -> List[ContentPart]:
    """Unpack a TAR (optionally gzip-compressed) archive into ContentParts.

    Directories, symlinks/hardlinks, and empty members are skipped; nested
    containers recurse via _resolveContainerRecursive.  When `compressed`
    is True and the gzip stream does not wrap a tarball (e.g. notes.txt.gz),
    the payload is gunzipped and handed to _addFilePart as a single file.
    """
    parts: List[ContentPart] = []
    mode = "r:gz" if compressed else "r:"
    try:
        with tarfile.open(fileobj=io.BytesIO(data), mode=mode) as tf:
            for member in tf.getmembers():
                if member.isdir():
                    continue
                if _isSymlink(member):
                    # Safety: never follow links out of the archive.
                    logger.warning(f"Skipping symlink in TAR: {member.name}")
                    continue
                if member.size == 0:
                    continue

                entryPath = f"{containerPath}/{member.name}" if containerPath else member.name
                entryMime = _detectMimeType(member.name)
                fobj = tf.extractfile(member)
                if fobj is None:
                    # Non-file members (devices, FIFOs) yield None.
                    continue
                entryData = fobj.read()

                if _isNestedContainer(member.name, entryMime):
                    nestedId = makeId()
                    parts.append(ContentPart(
                        id=nestedId, parentId=parentId, label=member.name,
                        typeGroup="container", mimeType=entryMime, data="",
                        metadata={"size": len(entryData), "containerPath": entryPath},
                    ))
                    nested = _resolveContainerRecursive(
                        entryData, entryMime, member.name, nestedId, entryPath, depth + 1, state
                    )
                    parts.extend(nested)
                else:
                    parts.extend(_addFilePart(entryData, member.name, parentId, containerPath, state))
    except tarfile.TarError as e:
        # Bug fix: the dispatcher routes every *.gz here, but a plain gzip
        # file that is not a tarball lands in this handler (ReadError) even
        # though it is perfectly valid.  Fall back to plain gunzip before
        # reporting an error.  gzip.decompress does not expose the original
        # member name, so a generic label is used.
        if compressed:
            import gzip
            try:
                inner = gzip.decompress(data)
            except OSError:
                inner = None
            if inner is not None:
                return _addFilePart(inner, "decompressed", parentId, containerPath, state)
        logger.error(f"Invalid TAR file: {e}")
        parts.append(ContentPart(
            id=makeId(), parentId=parentId, label="error",
            typeGroup="text", mimeType="text/plain",
            data=f"Invalid TAR archive: {e}", metadata={"error": True},
        ))
    return parts
|
|
|
|
|
|
def _extract7z(
    data: bytes, parentId: str, containerPath: str, depth: int, state: Dict[str, int]
) -> List[ContentPart]:
    """Extract 7z archive. Requires py7zr (optional dependency).

    When py7zr is missing, a warning part is returned; any other failure
    yields an error part instead of raising.
    """
    parts: List[ContentPart] = []
    try:
        import py7zr

        with py7zr.SevenZipFile(io.BytesIO(data), mode="r") as szf:
            # readall() maps member name -> file-like object (or raw bytes).
            for fileName, bio in szf.readall().items():
                if hasattr(bio, "read"):
                    entryData = bio.read()
                else:
                    entryData = bytes(bio)
                if not entryData:
                    continue

                entryMime = _detectMimeType(fileName)
                entryPath = f"{containerPath}/{fileName}" if containerPath else fileName

                if not _isNestedContainer(fileName, entryMime):
                    parts.extend(_addFilePart(entryData, fileName, parentId, containerPath, state))
                    continue

                # Nested archive: emit a container node, then recurse.
                nestedId = makeId()
                parts.append(ContentPart(
                    id=nestedId, parentId=parentId, label=fileName,
                    typeGroup="container", mimeType=entryMime, data="",
                    metadata={"size": len(entryData), "containerPath": entryPath},
                ))
                parts.extend(_resolveContainerRecursive(
                    entryData, entryMime, fileName, nestedId, entryPath, depth + 1, state
                ))
    except ImportError:
        logger.warning("py7zr not installed -- 7z files will be treated as binary")
        parts.append(ContentPart(
            id=makeId(), parentId=parentId, label="unsupported",
            typeGroup="text", mimeType="text/plain",
            data="7z extraction requires py7zr package", metadata={"warning": True},
        ))
    except Exception as e:
        logger.error(f"Invalid 7z file: {e}")
        parts.append(ContentPart(
            id=makeId(), parentId=parentId, label="error",
            typeGroup="text", mimeType="text/plain",
            data=f"Invalid 7z archive: {e}", metadata={"error": True},
        ))
    return parts
|