"""
|
|
Intelligent Token-Aware Merger for optimizing AI calls based on LLM token limits.
|
|
"""
|
|
from typing import List, Dict, Any, Tuple
|
|
import logging
|
|
from modules.datamodels.datamodelExtraction import ContentPart
|
|
from .subUtils import makeId
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|

class IntelligentTokenAwareMerger:
    """
    Intelligent merger that groups chunks based on LLM token limits to minimize AI calls.

    Strategy:
    1. Calculate the token count of each chunk.
    2. Group chunks to maximize token usage without exceeding limits.
    3. Preserve document structure and semantic boundaries.
    4. Minimize the total number of AI calls.
    """

    def __init__(self, model_capabilities: Dict[str, Any]):
        self.max_tokens = model_capabilities.get("maxTokens", 4000)
        self.safety_margin = model_capabilities.get("safetyMargin", 0.1)
        self.effective_max_tokens = int(self.max_tokens * (1 - self.safety_margin))
        self.chars_per_token = model_capabilities.get("charsPerToken", 4)  # Rough estimation
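
    # Illustrative budget math: with the defaults above (maxTokens=4000,
    # safetyMargin=0.1), the effective budget is int(4000 * 0.9) = 3600
    # tokens, shared between the prompt and the merged content.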

    def merge_chunks_intelligently(self, chunks: List[ContentPart], prompt: str = "") -> List[ContentPart]:
        """
        Merge chunks intelligently based on token limits.

        Args:
            chunks: List of ContentPart chunks to merge.
            prompt: AI prompt to account for in the token calculation.

        Returns:
            List of optimally merged ContentPart objects.
        """
        if not chunks:
            return chunks

        logger.info(f"🧠 Intelligent merging: {len(chunks)} chunks, max_tokens={self.effective_max_tokens}")

        # Reserve tokens for the prompt; the remainder is available for content.
        prompt_tokens = self._estimate_tokens(prompt)
        available_tokens = self.effective_max_tokens - prompt_tokens

        logger.info(f"📊 Prompt tokens: {prompt_tokens}, available for content: {available_tokens}")

        # Group chunks by document and type for semantic coherence.
        grouped_chunks = self._group_chunks_by_document_and_type(chunks)

        merged_parts = []
        for group_key, group_chunks in grouped_chunks.items():
            logger.info(f"📁 Processing group: {group_key} ({len(group_chunks)} chunks)")
            # Merge the chunks within this group optimally.
            group_merged = self._merge_group_optimally(group_chunks, available_tokens)
            merged_parts.extend(group_merged)

        logger.info(f"✅ Intelligent merging complete: {len(chunks)} → {len(merged_parts)} parts")
        return merged_parts

    def _group_chunks_by_document_and_type(self, chunks: List[ContentPart]) -> Dict[str, List[ContentPart]]:
        """Group chunks by document and type for semantic coherence."""
        groups: Dict[str, List[ContentPart]] = {}

        for chunk in chunks:
            # Group key: document_id + type_group
            doc_id = chunk.metadata.get("documentId", "unknown")
            group_key = f"{doc_id}_{chunk.typeGroup}"
            groups.setdefault(group_key, []).append(chunk)

        return groups

    def _merge_group_optimally(self, chunks: List[ContentPart], available_tokens: int) -> List[ContentPart]:
        """Merge chunks within a group optimally to minimize AI calls."""
        if not chunks:
            return []

        # Sort chunks by size (smallest first) for better packing.
        sorted_chunks = sorted(chunks, key=lambda c: self._estimate_tokens(c.data))

        merged_parts = []
        current_group = []
        current_tokens = 0

        for chunk in sorted_chunks:
            chunk_tokens = self._estimate_tokens(chunk.data)

            # Special case: a chunk at or above 90% of the available tokens is
            # already near the limit, so process it alone.
            if chunk_tokens >= available_tokens * 0.9:
                # Finalize the current group first, if one exists.
                if current_group:
                    merged_parts.append(self._create_merged_part(current_group, current_tokens))
                    current_group = []
                    current_tokens = 0

                # Process the large chunk individually.
                merged_parts.append(chunk)
                logger.debug(f"🔍 Large chunk processed individually: {chunk_tokens} tokens")
                continue

            # If adding this chunk would exceed the limit, finalize the current group.
            if current_tokens + chunk_tokens > available_tokens and current_group:
                merged_parts.append(self._create_merged_part(current_group, current_tokens))
                current_group = [chunk]
                current_tokens = chunk_tokens
            else:
                current_group.append(chunk)
                current_tokens += chunk_tokens

        # Finalize the remaining group.
        if current_group:
            merged_parts.append(self._create_merged_part(current_group, current_tokens))

        logger.info(f"📦 Group merged: {len(chunks)} → {len(merged_parts)} parts")
        return merged_parts
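
    # Worked example (with assumed sizes): given chunks of 500, 800, 1200, and
    # 3500 tokens against a 3600-token budget, the 3500-token chunk crosses the
    # 90% threshold (3240) and is emitted alone, while 500 + 800 + 1200 = 2500
    # tokens pack into one merged part, so four chunks become two AI calls.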

    def _create_merged_part(self, chunks: List[ContentPart], total_tokens: int) -> ContentPart:
        """Create a merged ContentPart from multiple chunks."""
        if len(chunks) == 1:
            return chunks[0]  # No need to merge a single chunk.

        # Combine the data with semantic separators.
        combined_data = self._combine_chunk_data(chunks)

        # Use the first chunk's metadata as the base.
        base_chunk = chunks[0]
        merged_metadata = base_chunk.metadata.copy()
        merged_metadata.update({
            "merged": True,
            "originalChunkCount": len(chunks),
            "totalTokens": total_tokens,
            "originalChunkIds": [c.id for c in chunks],
            "size": len(combined_data.encode("utf-8")),
        })

        merged_part = ContentPart(
            id=makeId(),
            parentId=base_chunk.parentId,
            label=f"merged_{len(chunks)}_chunks",
            typeGroup=base_chunk.typeGroup,
            mimeType=base_chunk.mimeType,
            data=combined_data,
            metadata=merged_metadata,
        )

        logger.debug(f"🔗 Created merged part: {len(chunks)} chunks, {total_tokens} tokens")
        return merged_part

    def _combine_chunk_data(self, chunks: List[ContentPart]) -> str:
        """Combine chunk data with separators appropriate to the content type."""
        if not chunks:
            return ""

        # Choose a separator based on the content type.
        if chunks[0].typeGroup == "text":
            separator = "\n\n---\n\n"  # Clear text separation
        elif chunks[0].typeGroup == "table":
            separator = "\n\n[TABLE BREAK]\n\n"  # Table separation
        else:
            separator = "\n\n---\n\n"  # Default separation

        return separator.join(chunk.data for chunk in chunks)

    def _estimate_tokens(self, text: str) -> int:
        """Estimate the token count of a text via the chars-per-token heuristic."""
        if not text:
            return 0
        return len(text) // self.chars_per_token
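
    # A sketch of a more accurate estimator, assuming the optional `tiktoken`
    # package is available; this method is illustrative and nothing else in
    # this module depends on it.
    def _estimate_tokens_precise(self, text: str) -> int:
        """Estimate tokens with a real tokenizer when available (illustrative)."""
        if not text:
            return 0
        try:
            import tiktoken  # Optional dependency; an assumption, not a requirement.
        except ImportError:
            # Fall back to the chars-per-token heuristic above.
            return len(text) // self.chars_per_token
        encoding = tiktoken.get_encoding("cl100k_base")  # Encoding choice is an assumption.
        return len(encoding.encode(text))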

    def calculate_optimization_stats(self, original_chunks: List[ContentPart], merged_parts: List[ContentPart]) -> Dict[str, Any]:
        """Calculate optimization statistics with detailed analysis."""
        original_calls = len(original_chunks)
        optimized_calls = len(merged_parts)
        reduction_percent = ((original_calls - optimized_calls) / original_calls * 100) if original_calls > 0 else 0

        # Analyze chunk sizes relative to the effective token budget.
        large_chunks = [c for c in original_chunks if self._estimate_tokens(c.data) >= self.effective_max_tokens * 0.9]
        small_chunks = [c for c in original_chunks if self._estimate_tokens(c.data) < self.effective_max_tokens * 0.9]

        # Theoretical maximum optimization, assuming roughly three small chunks per call.
        theoretical_min_calls = len(large_chunks) + max(1, len(small_chunks) // 3)
        theoretical_reduction = ((original_calls - theoretical_min_calls) / original_calls * 100) if original_calls > 0 else 0

        return {
            "original_ai_calls": original_calls,
            "optimized_ai_calls": optimized_calls,
            "reduction_percent": round(reduction_percent, 1),
            "cost_savings": f"{reduction_percent:.1f}%",
            "efficiency_gain": f"{original_calls / optimized_calls:.1f}x" if optimized_calls > 0 else "∞",
            "analysis": {
                "large_chunks": len(large_chunks),
                "small_chunks": len(small_chunks),
                "theoretical_min_calls": theoretical_min_calls,
                "theoretical_reduction": round(theoretical_reduction, 1),
                "optimization_potential": (
                    "high" if reduction_percent > 50
                    else "moderate" if reduction_percent > 20
                    else "low"
                ),
            },
        }
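

# Minimal usage sketch (illustrative): the ContentPart fields and capability
# keys below mirror what this class reads and writes; the sample values and
# the "doc-1" identifiers are assumptions.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    merger = IntelligentTokenAwareMerger({
        "maxTokens": 4000,    # Model context budget
        "safetyMargin": 0.1,  # Keep 10% headroom
        "charsPerToken": 4,   # Heuristic used by _estimate_tokens
    })

    sample_chunks = [
        ContentPart(
            id=makeId(),
            parentId="doc-1",
            label=f"chunk_{i}",
            typeGroup="text",
            mimeType="text/plain",
            data="Lorem ipsum dolor sit amet. " * 40,
            metadata={"documentId": "doc-1"},
        )
        for i in range(6)
    ]

    merged = merger.merge_chunks_intelligently(sample_chunks, prompt="Summarize the following content:")
    stats = merger.calculate_optimization_stats(sample_chunks, merged)
    logger.info(f"Optimization stats: {stats}")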