136 lines
4.7 KiB
Python
from typing import Any, Dict, List
|
|
from modules.datamodels.datamodelExtraction import ContentPart
|
|
from ..subUtils import makeId
|
|
|
|
|
|
class TextMerger:
    """Merges text ContentParts into larger parts according to a strategy dict.

    Non-text parts are always passed through unmerged; only parts with
    typeGroup == "text" are ever combined.
    """

    def merge(self, parts: List[ContentPart], strategy: Dict[str, Any]) -> List[ContentPart]:
        """
        Merge text parts based on strategy.

        Strategy options:
        - groupBy: "parentId" (default), "documentId", "none"
        - orderBy: "label" (default), "pageIndex", "sheetIndex", "none"
        - maxSize: maximum metadata "size" per merged part; 0, missing,
          or None means unlimited

        Returns a new list; the input list is not modified.
        """
        if not parts:
            return parts

        groupBy = strategy.get("groupBy", "parentId")
        orderBy = strategy.get("orderBy", "label")
        # Coerce a None maxSize to 0 ("no limit") instead of raising
        # TypeError on the `maxSize > 0` comparison below.
        maxSize = strategy.get("maxSize") or 0

        # Group parts
        groups = self._groupParts(parts, groupBy)

        merged: List[ContentPart] = []
        for groupKey, groupParts in groups.items():
            # Sort within group
            sortedParts = self._sortParts(groupParts, orderBy)

            # Merge respecting maxSize. groupKey is threaded through so
            # chunk labels stay unique across groups (previously every
            # group restarted at "chunk_0", producing duplicate
            # "merged_chunk_N" labels when multiple groups were chunked).
            if maxSize > 0:
                merged.extend(self._mergeWithSizeLimit(sortedParts, maxSize, groupKey))
            else:
                merged.extend(self._mergeGroup(sortedParts, groupKey))

        return merged

    def _groupParts(self, parts: List[ContentPart], groupBy: str) -> Dict[str, List[ContentPart]]:
        """Bucket parts by the requested grouping key.

        Each non-text part gets a singleton group keyed by its id so it is
        never merged with anything. Unknown groupBy values behave as "none".
        """
        groups: Dict[str, List[ContentPart]] = {}

        for part in parts:
            if part.typeGroup != "text":
                # Non-text parts go in their own group
                key = f"nontext_{part.id}"
            elif groupBy == "parentId":
                # Parts with no parent share a synthetic "root" group.
                key = part.parentId or "root"
            elif groupBy == "documentId":
                key = part.metadata.get("documentId", "unknown")
            else:  # "none"
                key = "all"

            groups.setdefault(key, []).append(part)

        return groups

    def _sortParts(self, parts: List[ContentPart], orderBy: str) -> List[ContentPart]:
        """Return parts sorted by the requested key; "none" preserves input order.

        Missing pageIndex/sheetIndex metadata defaults to 0, so unindexed
        parts sort first (stable with respect to input order).
        """
        if orderBy == "pageIndex":
            return sorted(parts, key=lambda p: p.metadata.get("pageIndex", 0))
        if orderBy == "sheetIndex":
            return sorted(parts, key=lambda p: p.metadata.get("sheetIndex", 0))
        if orderBy == "label":
            # NOTE: lexicographic compare — "part10" sorts before "part2".
            return sorted(parts, key=lambda p: p.label)
        return parts  # "none"

    def _mergeGroup(self, parts: List[ContentPart], groupKey: str) -> List[ContentPart]:
        """Collapse all text parts in the group into one ContentPart.

        Empty input returns []; a singleton group is returned unchanged.
        The merged part is labelled "merged_<groupKey>", inherits the first
        text part's parentId, and records provenance in metadata
        ("originalParts"). Non-text parts are appended unmerged.
        """
        if not parts:
            return []
        if len(parts) == 1:
            return parts

        # Merge all text parts in group
        textParts = [p for p in parts if p.typeGroup == "text"]
        nonTextParts = [p for p in parts if p.typeGroup != "text"]

        if not textParts:
            return nonTextParts

        # Combine text data; sizes come from metadata, not len(data),
        # to stay consistent with how _mergeWithSizeLimit measures parts.
        combinedData = "\n".join(p.data for p in textParts)
        totalSize = sum(p.metadata.get("size", 0) for p in textParts)

        mergedPart = ContentPart(
            id=makeId(),
            parentId=textParts[0].parentId,
            label=f"merged_{groupKey}",
            typeGroup="text",
            mimeType="text/plain",
            data=combinedData,
            metadata={
                "size": totalSize,
                "merged": len(textParts),
                "originalParts": [p.id for p in textParts],
            },
        )

        return [mergedPart] + nonTextParts

    def _mergeWithSizeLimit(
        self, parts: List[ContentPart], maxSize: int, groupKey: str = "chunk"
    ) -> List[ContentPart]:
        """Merge text parts into chunks whose metadata "size" sums stay <= maxSize.

        A single part larger than maxSize still becomes its own chunk. Chunks
        are labelled "merged_<groupKey>_<n>" by _mergeGroup; the default
        groupKey="chunk" reproduces the historical "merged_chunk_<n>" labels
        for direct callers. Non-text parts are appended unmerged after the
        chunks.
        """
        if not parts:
            return []

        textParts = [p for p in parts if p.typeGroup == "text"]
        nonTextParts = [p for p in parts if p.typeGroup != "text"]

        if not textParts:
            return nonTextParts

        merged: List[ContentPart] = []
        currentGroup: List[ContentPart] = []
        currentSize = 0
        chunkIndex = 0

        for part in textParts:
            partSize = part.metadata.get("size", 0)

            if currentGroup and currentSize + partSize > maxSize:
                # Flush current group; explicit chunkIndex (rather than
                # len(merged)) keeps numbering correct even if a group
                # ever yields more than one part.
                merged.extend(self._mergeGroup(currentGroup, f"{groupKey}_{chunkIndex}"))
                chunkIndex += 1
                currentGroup = [part]
                currentSize = partSize
            else:
                currentGroup.append(part)
                currentSize += partSize

        # Flush remaining group (always non-empty here since textParts is).
        if currentGroup:
            merged.extend(self._mergeGroup(currentGroup, f"{groupKey}_{chunkIndex}"))

        return merged + nonTextParts
|