"""Intelligent Token-Aware Merger for optimizing AI calls based on LLM token limits."""
|
|
from typing import List, Dict, Any
|
|
import logging
|
|
from modules.datamodels.datamodelExtraction import ContentPart
|
|
from .subUtils import makeId
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class IntelligentTokenAwareMerger:
    """
    Intelligent merger that groups chunks based on LLM token limits to minimize AI calls.

    Strategy:
        1. Calculate an (approximate) token count for each chunk.
        2. Group chunks to maximize token usage without exceeding limits.
        3. Preserve document structure and semantic boundaries by grouping
           per document and content type.
        4. Minimize the total number of AI calls.
    """

    def __init__(self, modelCapabilities: Dict[str, Any]):
        """
        Args:
            modelCapabilities: Mapping with optional keys:
                - "maxTokens": model context budget (default 4000)
                - "safetyMargin": fraction of maxTokens reserved as headroom
                  against estimation error (default 0.1)
                - "charsPerToken": rough characters-per-token factor used for
                  token estimation (default 4)
        """
        self.maxTokens = modelCapabilities.get("maxTokens", 4000)
        self.safetyMargin = modelCapabilities.get("safetyMargin", 0.1)
        # Keep headroom so the rough estimation never overflows the real limit.
        self.effectiveMaxTokens = int(self.maxTokens * (1 - self.safetyMargin))
        charsPerToken = modelCapabilities.get("charsPerToken", 4)  # Rough estimation
        # Guard: a zero/negative factor would make _estimateTokens raise
        # ZeroDivisionError (or return nonsense); fall back to the default.
        self.charsPerToken = charsPerToken if charsPerToken and charsPerToken > 0 else 4

    def mergeChunksIntelligently(self, chunks: "List[ContentPart]", prompt: str = "") -> "List[ContentPart]":
        """
        Merge chunks intelligently based on token limits.

        Args:
            chunks: List of ContentPart chunks to merge
            prompt: AI prompt to account for in token calculation

        Returns:
            List of optimally merged ContentPart objects
        """
        if not chunks:
            return chunks

        logger.info("🧠 Intelligent merging: %d chunks, maxTokens=%d",
                    len(chunks), self.effectiveMaxTokens)

        # Budget left for content once the prompt is accounted for.
        promptTokens = self._estimateTokens(prompt)
        availableTokens = self.effectiveMaxTokens - promptTokens
        if availableTokens <= 0:
            # The prompt alone exhausts the budget. Every chunk then satisfies
            # the "large chunk" condition in _mergeGroupOptimally and is passed
            # through individually — surface this instead of failing silently.
            logger.warning("⚠️ Prompt (%d tokens) exhausts the %d-token budget; chunks will not be merged",
                           promptTokens, self.effectiveMaxTokens)

        logger.info("📊 Prompt tokens: %d, Available for content: %d",
                    promptTokens, availableTokens)

        # Group chunks by document and type for semantic coherence.
        groupedChunks = self._groupChunksByDocumentAndType(chunks)

        mergedParts = []
        for groupKey, groupChunks in groupedChunks.items():
            logger.info("📁 Processing group: %s (%d chunks)", groupKey, len(groupChunks))
            # Merge chunks within this group optimally.
            mergedParts.extend(self._mergeGroupOptimally(groupChunks, availableTokens))

        logger.info("✅ Intelligent merging complete: %d → %d parts",
                    len(chunks), len(mergedParts))
        return mergedParts

    def _groupChunksByDocumentAndType(self, chunks: "List[ContentPart]") -> "Dict[str, List[ContentPart]]":
        """Group chunks by document and type for semantic coherence.

        Returns:
            Mapping "<documentId>_<typeGroup>" -> chunks, preserving the
            original chunk order within each group.
        """
        groups: "Dict[str, List[ContentPart]]" = {}

        for chunk in chunks:
            # Group key: document id + type group. Chunks without a documentId
            # in their metadata fall into a shared "unknown" bucket.
            docId = chunk.metadata.get("documentId", "unknown")
            groupKey = f"{docId}_{chunk.typeGroup}"
            groups.setdefault(groupKey, []).append(chunk)

        return groups

    def _mergeGroupOptimally(self, chunks: "List[ContentPart]", availableTokens: int) -> "List[ContentPart]":
        """Merge chunks within a group optimally to minimize AI calls.

        Packs chunks greedily (smallest first, for better bin packing) into
        parts that stay within ``availableTokens``; chunks already near the
        limit are passed through unmerged.
        """
        if not chunks:
            return []

        # Smallest first gives better packing.
        sortedChunks = sorted(chunks, key=lambda c: self._estimateTokens(c.data))

        mergedParts: "List[ContentPart]" = []
        currentGroup: "List[ContentPart]" = []
        currentTokens = 0

        for chunk in sortedChunks:
            chunkTokens = self._estimateTokens(chunk.data)

            # Special case: a chunk already at >= 90% of the available tokens
            # is processed alone.
            if chunkTokens >= availableTokens * 0.9:
                # Finalize the accumulating group first, if any.
                if currentGroup:
                    mergedParts.append(self._createMergedPart(currentGroup, currentTokens))
                    currentGroup = []
                    currentTokens = 0

                mergedParts.append(chunk)
                logger.debug("🔍 Large chunk processed individually: %d tokens", chunkTokens)
                continue

            # Close the current part when adding this chunk would overflow it.
            if currentTokens + chunkTokens > availableTokens and currentGroup:
                mergedParts.append(self._createMergedPart(currentGroup, currentTokens))
                currentGroup = [chunk]
                currentTokens = chunkTokens
            else:
                currentGroup.append(chunk)
                currentTokens += chunkTokens

        # Finalize the remaining group.
        if currentGroup:
            mergedParts.append(self._createMergedPart(currentGroup, currentTokens))

        logger.info("📦 Group merged: %d → %d parts", len(chunks), len(mergedParts))
        return mergedParts

    def _createMergedPart(self, chunks: "List[ContentPart]", totalTokens: int) -> "ContentPart":
        """Create a merged ContentPart from multiple chunks.

        The first chunk supplies the base metadata/typing; provenance
        (original ids, chunk count, token total) is recorded in the merged
        part's metadata.
        """
        if len(chunks) == 1:
            return chunks[0]  # No need to merge a single chunk

        # Combine data with semantic separators.
        combinedData = self._combineChunkData(chunks)

        # Use metadata from the first chunk as base.
        baseChunk = chunks[0]
        mergedMetadata = baseChunk.metadata.copy()
        mergedMetadata.update({
            "merged": True,
            "originalChunkCount": len(chunks),
            "totalTokens": totalTokens,
            "originalChunkIds": [c.id for c in chunks],
            "size": len(combinedData.encode('utf-8')),
        })

        mergedPart = ContentPart(
            id=makeId(),
            parentId=baseChunk.parentId,
            label=f"merged_{len(chunks)}_chunks",
            typeGroup=baseChunk.typeGroup,
            mimeType=baseChunk.mimeType,
            data=combinedData,
            metadata=mergedMetadata,
        )

        logger.debug("🔗 Created merged part: %d chunks, %d tokens", len(chunks), totalTokens)
        return mergedPart

    def _combineChunkData(self, chunks: "List[ContentPart]") -> str:
        """Combine chunk payloads with a separator appropriate to the content type."""
        if not chunks:
            return ""

        # Separator chosen from the first chunk's type group.
        separators = {
            "text": "\n\n---\n\n",             # Clear text separation
            "table": "\n\n[TABLE BREAK]\n\n",  # Table separation
        }
        separator = separators.get(chunks[0].typeGroup, "\n\n---\n\n")

        return separator.join(chunk.data for chunk in chunks)

    def _estimateTokens(self, text: str) -> int:
        """Estimate the token count for ``text`` (~charsPerToken chars per token)."""
        if not text:
            return 0
        return len(text) // self.charsPerToken

    def calculateOptimizationStats(self, originalChunks: "List[ContentPart]", mergedParts: "List[ContentPart]") -> Dict[str, Any]:
        """Calculate optimization statistics with detailed analysis.

        Returns:
            Dict with AI-call counts, reduction percentages, and an analysis
            of how close the merge came to its theoretical optimum.
        """
        originalCalls = len(originalChunks)
        optimizedCalls = len(mergedParts)
        reductionPercent = ((originalCalls - optimizedCalls) / originalCalls * 100) if originalCalls > 0 else 0

        # Split chunks by size: "large" ones (>= 90% of the effective limit)
        # can never be merged and always cost one call each.
        largeThreshold = self.effectiveMaxTokens * 0.9
        largeChunks = [c for c in originalChunks if self._estimateTokens(c.data) >= largeThreshold]
        smallChunks = [c for c in originalChunks if self._estimateTokens(c.data) < largeThreshold]

        # Theoretical maximum optimization if all small chunks could be merged
        # (assume ~3 small chunks per call).
        theoreticalMinCalls = len(largeChunks) + max(1, len(smallChunks) // 3)
        theoreticalReduction = ((originalCalls - theoreticalMinCalls) / originalCalls * 100) if originalCalls > 0 else 0

        return {
            "original_ai_calls": originalCalls,
            "optimized_ai_calls": optimizedCalls,
            "reduction_percent": round(reductionPercent, 1),
            "cost_savings": f"{reductionPercent:.1f}%",
            "efficiency_gain": f"{originalCalls / optimizedCalls:.1f}x" if optimizedCalls > 0 else "∞",
            "analysis": {
                "large_chunks": len(largeChunks),
                "small_chunks": len(smallChunks),
                "theoretical_min_calls": theoreticalMinCalls,
                "theoretical_reduction": round(theoreticalReduction, 1),
                "optimization_potential": "high" if reductionPercent > 50 else "moderate" if reductionPercent > 20 else "low",
            },
        }