""" Intelligent Token-Aware Merger for optimizing AI calls based on LLM token limits. """ from typing import List, Dict, Any import logging from modules.datamodels.datamodelExtraction import ContentPart from .subUtils import makeId logger = logging.getLogger(__name__) class IntelligentTokenAwareMerger: """ Intelligent merger that groups chunks based on LLM token limits to minimize AI calls. Strategy: 1. Calculate token count for each chunk 2. Group chunks to maximize token usage without exceeding limits 3. Preserve document structure and semantic boundaries 4. Minimize total number of AI calls """ def __init__(self, modelCapabilities: Dict[str, Any]): self.maxTokens = modelCapabilities.get("maxTokens", 4000) self.safetyMargin = modelCapabilities.get("safetyMargin", 0.1) self.effectiveMaxTokens = int(self.maxTokens * (1 - self.safetyMargin)) self.charsPerToken = modelCapabilities.get("charsPerToken", 4) # Rough estimation def mergeChunksIntelligently(self, chunks: List[ContentPart], prompt: str = "") -> List[ContentPart]: """ Merge chunks intelligently based on token limits. Args: chunks: List of ContentPart chunks to merge prompt: AI prompt to account for in token calculation Returns: List of optimally merged ContentPart objects """ if not chunks: return chunks logger.info(f"🧠 Intelligent merging: {len(chunks)} chunks, maxTokens={self.effectiveMaxTokens}") # Calculate tokens for prompt promptTokens = self._estimateTokens(prompt) availableTokens = self.effectiveMaxTokens - promptTokens logger.info(f"📊 Prompt tokens: {promptTokens}, Available for content: {availableTokens}") # Group chunks by document and type for semantic coherence groupedChunks = self._groupChunksByDocumentAndType(chunks) mergedParts = [] for groupKey, groupChunks in groupedChunks.items(): logger.info(f"📁 Processing group: {groupKey} ({len(groupChunks)} chunks)") # Merge chunks within this group optimally groupMerged = self._mergeGroupOptimally(groupChunks, availableTokens) mergedParts.extend(groupMerged) logger.info(f"✅ Intelligent merging complete: {len(chunks)} → {len(mergedParts)} parts") return mergedParts def _groupChunksByDocumentAndType(self, chunks: List[ContentPart]) -> Dict[str, List[ContentPart]]: """Group chunks by document and type for semantic coherence.""" groups = {} for chunk in chunks: # Create group key: document_id + type_group docId = chunk.metadata.get("documentId", "unknown") typeGroup = chunk.typeGroup groupKey = f"{docId}_{typeGroup}" if groupKey not in groups: groups[groupKey] = [] groups[groupKey].append(chunk) return groups def _mergeGroupOptimally(self, chunks: List[ContentPart], availableTokens: int) -> List[ContentPart]: """Merge chunks within a group optimally to minimize AI calls.""" if not chunks: return [] # Sort chunks by size (smallest first for better packing) sortedChunks = sorted(chunks, key=lambda c: self._estimateTokens(c.data)) mergedParts = [] currentGroup = [] currentTokens = 0 for chunk in sortedChunks: chunkTokens = self._estimateTokens(chunk.data) # Special case: If single chunk is already at max size, process it alone if chunkTokens >= availableTokens * 0.9: # 90% of available tokens # Finalize current group if it exists if currentGroup: mergedPart = self._createMergedPart(currentGroup, currentTokens) mergedParts.append(mergedPart) currentGroup = [] currentTokens = 0 # Process large chunk individually mergedParts.append(chunk) logger.debug(f"🔍 Large chunk processed individually: {chunkTokens} tokens") continue # If adding this chunk would exceed limit, finalize current 
group if currentTokens + chunkTokens > availableTokens and currentGroup: mergedPart = self._createMergedPart(currentGroup, currentTokens) mergedParts.append(mergedPart) currentGroup = [chunk] currentTokens = chunkTokens else: currentGroup.append(chunk) currentTokens += chunkTokens # Finalize remaining group if currentGroup: mergedPart = self._createMergedPart(currentGroup, currentTokens) mergedParts.append(mergedPart) logger.info(f"📦 Group merged: {len(chunks)} → {len(mergedParts)} parts") return mergedParts def _createMergedPart(self, chunks: List[ContentPart], totalTokens: int) -> ContentPart: """Create a merged ContentPart from multiple chunks.""" if len(chunks) == 1: return chunks[0] # No need to merge single chunk # Combine data with semantic separators combinedData = self._combineChunkData(chunks) # Use metadata from first chunk as base baseChunk = chunks[0] mergedMetadata = baseChunk.metadata.copy() mergedMetadata.update({ "merged": True, "originalChunkCount": len(chunks), "totalTokens": totalTokens, "originalChunkIds": [c.id for c in chunks], "size": len(combinedData.encode('utf-8')) }) mergedPart = ContentPart( id=makeId(), parentId=baseChunk.parentId, label=f"merged_{len(chunks)}_chunks", typeGroup=baseChunk.typeGroup, mimeType=baseChunk.mimeType, data=combinedData, metadata=mergedMetadata ) logger.debug(f"🔗 Created merged part: {len(chunks)} chunks, {totalTokens} tokens") return mergedPart def _combineChunkData(self, chunks: List[ContentPart]) -> str: """Combine chunk data with appropriate separators.""" if not chunks: return "" # Use different separators based on content type if chunks[0].typeGroup == "text": separator = "\n\n---\n\n" # Clear text separation elif chunks[0].typeGroup == "table": separator = "\n\n[TABLE BREAK]\n\n" # Table separation else: separator = "\n\n---\n\n" # Default separation return separator.join([chunk.data for chunk in chunks]) def _estimateTokens(self, text: str) -> int: """Estimate token count for text.""" if not text: return 0 return len(text) // self.charsPerToken def calculateOptimizationStats(self, originalChunks: List[ContentPart], mergedParts: List[ContentPart]) -> Dict[str, Any]: """Calculate optimization statistics with detailed analysis.""" originalCalls = len(originalChunks) optimizedCalls = len(mergedParts) reductionPercent = ((originalCalls - optimizedCalls) / originalCalls * 100) if originalCalls > 0 else 0 # Analyze chunk sizes largeChunks = [c for c in originalChunks if self._estimateTokens(c.data) >= self.effectiveMaxTokens * 0.9] smallChunks = [c for c in originalChunks if self._estimateTokens(c.data) < self.effectiveMaxTokens * 0.9] # Calculate theoretical maximum optimization (if all small chunks could be merged) theoreticalMinCalls = len(largeChunks) + max(1, len(smallChunks) // 3) # Assume 3 small chunks per call theoreticalReduction = ((originalCalls - theoreticalMinCalls) / originalCalls * 100) if originalCalls > 0 else 0 return { "original_ai_calls": originalCalls, "optimized_ai_calls": optimizedCalls, "reduction_percent": round(reductionPercent, 1), "cost_savings": f"{reductionPercent:.1f}%", "efficiency_gain": f"{originalCalls / optimizedCalls:.1f}x" if optimizedCalls > 0 else "∞", "analysis": { "large_chunks": len(largeChunks), "small_chunks": len(smallChunks), "theoretical_min_calls": theoreticalMinCalls, "theoretical_reduction": round(theoreticalReduction, 1), "optimization_potential": "high" if reductionPercent > 50 else "moderate" if reductionPercent > 20 else "low" } }
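

# --- Usage sketch (illustrative only) ---
# A minimal driver showing how the merger above might be exercised. The
# ContentPart keyword fields mirror the ones _createMergedPart passes
# (id, parentId, label, typeGroup, mimeType, data, metadata) and the
# metadata key _groupChunksByDocumentAndType reads ("documentId"); the
# concrete values and the 8000-token capability are assumptions for
# demonstration, not fixtures from the real pipeline. Run as a module
# (python -m <package>.<this_module>) so the relative import resolves.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Twelve small same-document text chunks; each is well under the
    # budget, so the greedy packer should merge them into very few parts.
    sampleChunks = [
        ContentPart(
            id=makeId(),
            parentId="doc-1",
            label=f"chunk_{i}",
            typeGroup="text",
            mimeType="text/plain",
            data=f"Section {i}. " + "lorem ipsum dolor sit amet " * 40,
            metadata={"documentId": "doc-1"},
        )
        for i in range(12)
    ]

    merger = IntelligentTokenAwareMerger({"maxTokens": 8000, "safetyMargin": 0.1, "charsPerToken": 4})
    merged = merger.mergeChunksIntelligently(sampleChunks, prompt="Extract the key facts from each section.")
    stats = merger.calculateOptimizationStats(sampleChunks, merged)
    print(f"AI calls: {stats['original_ai_calls']} → {stats['optimized_ai_calls']} ({stats['cost_savings']} saved)")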