from typing import Any, Dict, List

from modules.datamodels.datamodelExtraction import ContentPart, MergeStrategy
from ..subUtils import makeId


class TextMerger:
    """Groups, orders and concatenates text ContentParts; non-text parts pass through unmerged."""

    def merge(self, parts: List[ContentPart], strategy: MergeStrategy) -> List[ContentPart]:
        """
        Merge text parts based on strategy.

        Strategy options:
        - groupBy: "parentId" or "documentId"; any other value (e.g. "none")
          puts all text parts into one single group
        - orderBy: "label", "pageIndex", "sheetIndex"; any other value
          (e.g. "none") keeps the incoming order
        - maxSize: maximum size per merged part; None or 0 disables the limit

        Returns the merged parts group by group, in the order in which each
        group key first appeared in ``parts``. An empty input is returned as-is.
        """
        if not parts:
            return parts

        groupBy = strategy.groupBy
        orderBy = strategy.orderBy
        maxSize = strategy.maxSize or 0

        # Group parts
        groups = self._groupParts(parts, groupBy)

        merged: List[ContentPart] = []
        for groupKey, groupParts in groups.items():
            # Sort within group
            sortedParts = self._sortParts(groupParts, orderBy)

            # Merge respecting maxSize
            if maxSize > 0:
                # NOTE: groupKey is not forwarded on this path, so chunk labels
                # restart at "merged_chunk_0" for every group.
                merged.extend(self._mergeWithSizeLimit(sortedParts, maxSize))
            else:
                merged.extend(self._mergeGroup(sortedParts, groupKey))

        return merged

    @staticmethod
    def _metaValue(part: ContentPart, key: str, default: Any) -> Any:
        """Return ``part.metadata[key]``, or ``default`` when the key is missing or None."""
        value = part.metadata.get(key)
        return default if value is None else value

    def _groupParts(self, parts: List[ContentPart], groupBy: str) -> Dict[str, List[ContentPart]]:
        """
        Bucket parts by group key, preserving first-seen key order.

        Every non-text part gets a bucket keyed "nontext_<id>". Text parts are
        keyed by parentId (fallback "root"), by metadata["documentId"]
        (fallback "unknown"), or by the constant "all" for any other groupBy.
        """
        groups: Dict[str, List[ContentPart]] = {}
        for part in parts:
            if part.typeGroup != "text":
                # Non-text parts go in their own group
                key = f"nontext_{part.id}"
            elif groupBy == "parentId":
                key = part.parentId or "root"
            elif groupBy == "documentId":
                # A documentId stored as None falls back like a missing one,
                # so the group key is never None.
                key = self._metaValue(part, "documentId", "unknown")
            else:  # "none"
                key = "all"
            groups.setdefault(key, []).append(part)
        return groups

    def _sortParts(self, parts: List[ContentPart], orderBy: str) -> List[ContentPart]:
        """
        Return parts ordered by the requested criterion.

        "pageIndex" and "sheetIndex" sort on the metadata value of that name
        (missing or None counts as 0); "label" sorts on the part label (None
        counts as ""). Any other value returns the input list itself, unsorted.
        """
        if orderBy in ("pageIndex", "sheetIndex"):
            # None-safe key: comparing None with an int would raise TypeError.
            return sorted(parts, key=lambda p: self._metaValue(p, orderBy, 0))
        if orderBy == "label":
            return sorted(parts, key=lambda p: p.label or "")
        return parts  # "none"

    def _mergeGroup(self, parts: List[ContentPart], groupKey: str) -> List[ContentPart]:
        """
        Collapse the text parts of one group into a single part.

        An empty or single-element input is returned unchanged. Otherwise the
        text data is joined with newlines into a new part labelled
        "merged_<groupKey>", which takes its parentId from the first text part.
        Non-text parts are appended after it, unmodified.
        """
        if not parts:
            return []
        if len(parts) == 1:
            return parts

        # Merge all text parts in group
        textParts = [p for p in parts if p.typeGroup == "text"]
        nonTextParts = [p for p in parts if p.typeGroup != "text"]
        if not textParts:
            return nonTextParts

        # Combine text data
        combinedData = "\n".join(p.data for p in textParts)
        totalSize = sum(self._metaValue(p, "size", 0) for p in textParts)

        mergedPart = ContentPart(
            id=makeId(),
            parentId=textParts[0].parentId,
            label=f"merged_{groupKey}",
            typeGroup="text",
            mimeType="text/plain",
            data=combinedData,
            metadata={
                "size": totalSize,
                "merged": len(textParts),
                "originalParts": [p.id for p in textParts],
            },
        )
        return [mergedPart] + nonTextParts

    def _mergeWithSizeLimit(self, parts: List[ContentPart], maxSize: int) -> List[ContentPart]:
        """
        Merge text parts into consecutive chunks whose summed size stays within maxSize.

        Sizes come from metadata["size"] (missing or None counts as 0). Parts
        are never split: a part that alone exceeds maxSize ends up in a chunk
        of its own. Each chunk is merged via _mergeGroup under the key
        "chunk_<n>", so a single-part chunk is emitted unchanged. Non-text
        parts are appended after all chunks.
        """
        if not parts:
            return []

        textParts = [p for p in parts if p.typeGroup == "text"]
        nonTextParts = [p for p in parts if p.typeGroup != "text"]
        if not textParts:
            return nonTextParts

        merged: List[ContentPart] = []
        currentGroup: List[ContentPart] = []
        currentSize = 0

        for part in textParts:
            partSize = self._metaValue(part, "size", 0)
            if currentSize + partSize > maxSize and currentGroup:
                # Flush current group; this part starts the next chunk
                merged.extend(self._mergeGroup(currentGroup, f"chunk_{len(merged)}"))
                currentGroup = [part]
                currentSize = partSize
            else:
                currentGroup.append(part)
                currentSize += partSize

        # Flush remaining group
        if currentGroup:
            merged.extend(self._mergeGroup(currentGroup, f"chunk_{len(merged)}"))

        return merged + nonTextParts