# gateway/modules/services/serviceExtraction/intelligent_merger.py
"""
Intelligent Token-Aware Merger for optimizing AI calls based on LLM token limits.
"""
from typing import List, Dict, Any, Tuple
import logging
from modules.datamodels.datamodelExtraction import ContentPart
from .subUtils import makeId

logger = logging.getLogger(__name__)


class IntelligentTokenAwareMerger:
"""
Intelligent merger that groups chunks based on LLM token limits to minimize AI calls.
Strategy:
1. Calculate token count for each chunk
2. Group chunks to maximize token usage without exceeding limits
3. Preserve document structure and semantic boundaries
4. Minimize total number of AI calls
"""

    def __init__(self, model_capabilities: Dict[str, Any]):
self.max_tokens = model_capabilities.get("maxTokens", 4000)
self.safety_margin = model_capabilities.get("safetyMargin", 0.1)
self.effective_max_tokens = int(self.max_tokens * (1 - self.safety_margin))
self.chars_per_token = model_capabilities.get("charsPerToken", 4) # Rough estimation
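        # Worked example: with the defaults maxTokens=4000 and safetyMargin=0.1,
        # effective_max_tokens = int(4000 * 0.9) = 3600 tokens per AI call.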

    def merge_chunks_intelligently(self, chunks: List[ContentPart], prompt: str = "") -> List[ContentPart]:
"""
Merge chunks intelligently based on token limits.
Args:
chunks: List of ContentPart chunks to merge
prompt: AI prompt to account for in token calculation
Returns:
List of optimally merged ContentPart objects
"""
if not chunks:
return chunks
logger.info(f"🧠 Intelligent merging: {len(chunks)} chunks, max_tokens={self.effective_max_tokens}")
# Calculate tokens for prompt
prompt_tokens = self._estimate_tokens(prompt)
available_tokens = self.effective_max_tokens - prompt_tokens
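        # e.g. with effective_max_tokens=3600 and a 400-token prompt (illustrative
        # figure), roughly 3200 tokens remain for chunk content in each call.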
logger.info(f"📊 Prompt tokens: {prompt_tokens}, Available for content: {available_tokens}")
# Group chunks by document and type for semantic coherence
grouped_chunks = self._group_chunks_by_document_and_type(chunks)
merged_parts = []
for group_key, group_chunks in grouped_chunks.items():
logger.info(f"📁 Processing group: {group_key} ({len(group_chunks)} chunks)")
# Merge chunks within this group optimally
group_merged = self._merge_group_optimally(group_chunks, available_tokens)
merged_parts.extend(group_merged)
logger.info(f"✅ Intelligent merging complete: {len(chunks)}{len(merged_parts)} parts")
return merged_parts

    def _group_chunks_by_document_and_type(self, chunks: List[ContentPart]) -> Dict[str, List[ContentPart]]:
"""Group chunks by document and type for semantic coherence."""
groups = {}
for chunk in chunks:
# Create group key: document_id + type_group
doc_id = chunk.metadata.get("documentId", "unknown")
type_group = chunk.typeGroup
group_key = f"{doc_id}_{type_group}"
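            # e.g. group_key == "doc-42_table" (illustrative id), so tables from the
            # same document are kept together when merging.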
if group_key not in groups:
groups[group_key] = []
groups[group_key].append(chunk)
return groups

    def _merge_group_optimally(self, chunks: List[ContentPart], available_tokens: int) -> List[ContentPart]:
"""Merge chunks within a group optimally to minimize AI calls."""
if not chunks:
return []
# Sort chunks by size (smallest first for better packing)
sorted_chunks = sorted(chunks, key=lambda c: self._estimate_tokens(c.data))
merged_parts = []
current_group = []
current_tokens = 0
for chunk in sorted_chunks:
chunk_tokens = self._estimate_tokens(chunk.data)
# Special case: If single chunk is already at max size, process it alone
if chunk_tokens >= available_tokens * 0.9: # 90% of available tokens
# Finalize current group if it exists
if current_group:
merged_part = self._create_merged_part(current_group, current_tokens)
merged_parts.append(merged_part)
current_group = []
current_tokens = 0
# Process large chunk individually
merged_parts.append(chunk)
logger.debug(f"🔍 Large chunk processed individually: {chunk_tokens} tokens")
continue
# If adding this chunk would exceed limit, finalize current group
if current_tokens + chunk_tokens > available_tokens and current_group:
merged_part = self._create_merged_part(current_group, current_tokens)
merged_parts.append(merged_part)
current_group = [chunk]
current_tokens = chunk_tokens
else:
current_group.append(chunk)
current_tokens += chunk_tokens
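        # Packing example: with 3200 available tokens and chunks of 1000, 1100 and
        # 1200 tokens (already sorted ascending), the first two fit in one part
        # (2100 tokens); adding the third would reach 3300 > 3200, so it starts a
        # second part. Three chunks therefore become two AI calls.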
# Finalize remaining group
if current_group:
merged_part = self._create_merged_part(current_group, current_tokens)
merged_parts.append(merged_part)
logger.info(f"📦 Group merged: {len(chunks)}{len(merged_parts)} parts")
return merged_parts

    def _create_merged_part(self, chunks: List[ContentPart], total_tokens: int) -> ContentPart:
"""Create a merged ContentPart from multiple chunks."""
if len(chunks) == 1:
return chunks[0] # No need to merge single chunk
# Combine data with semantic separators
combined_data = self._combine_chunk_data(chunks)
# Use metadata from first chunk as base
base_chunk = chunks[0]
merged_metadata = base_chunk.metadata.copy()
merged_metadata.update({
"merged": True,
"originalChunkCount": len(chunks),
"totalTokens": total_tokens,
"originalChunkIds": [c.id for c in chunks],
"size": len(combined_data.encode('utf-8'))
})
merged_part = ContentPart(
id=makeId(),
parentId=base_chunk.parentId,
label=f"merged_{len(chunks)}_chunks",
typeGroup=base_chunk.typeGroup,
mimeType=base_chunk.mimeType,
data=combined_data,
metadata=merged_metadata
)
logger.debug(f"🔗 Created merged part: {len(chunks)} chunks, {total_tokens} tokens")
return merged_part

    def _combine_chunk_data(self, chunks: List[ContentPart]) -> str:
"""Combine chunk data with appropriate separators."""
if not chunks:
return ""
# Use different separators based on content type
if chunks[0].typeGroup == "text":
separator = "\n\n---\n\n" # Clear text separation
elif chunks[0].typeGroup == "table":
separator = "\n\n[TABLE BREAK]\n\n" # Table separation
else:
separator = "\n\n---\n\n" # Default separation
return separator.join([chunk.data for chunk in chunks])

    def _estimate_tokens(self, text: str) -> int:
"""Estimate token count for text."""
if not text:
return 0
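        # Rough heuristic: e.g. 1200 characters with chars_per_token=4 estimates to 300 tokens.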
return len(text) // self.chars_per_token

    def calculate_optimization_stats(self, original_chunks: List[ContentPart], merged_parts: List[ContentPart]) -> Dict[str, Any]:
"""Calculate optimization statistics with detailed analysis."""
original_calls = len(original_chunks)
optimized_calls = len(merged_parts)
reduction_percent = ((original_calls - optimized_calls) / original_calls * 100) if original_calls > 0 else 0
# Analyze chunk sizes
large_chunks = [c for c in original_chunks if self._estimate_tokens(c.data) >= self.effective_max_tokens * 0.9]
small_chunks = [c for c in original_chunks if self._estimate_tokens(c.data) < self.effective_max_tokens * 0.9]
# Calculate theoretical maximum optimization (if all small chunks could be merged)
theoretical_min_calls = len(large_chunks) + max(1, len(small_chunks) // 3) # Assume 3 small chunks per call
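        # e.g. 2 large + 9 small chunks -> 2 + max(1, 9 // 3) = 5 theoretical minimum calls.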
theoretical_reduction = ((original_calls - theoretical_min_calls) / original_calls * 100) if original_calls > 0 else 0
return {
"original_ai_calls": original_calls,
"optimized_ai_calls": optimized_calls,
"reduction_percent": round(reduction_percent, 1),
"cost_savings": f"{reduction_percent:.1f}%",
"efficiency_gain": f"{original_calls / optimized_calls:.1f}x" if optimized_calls > 0 else "",
"analysis": {
"large_chunks": len(large_chunks),
"small_chunks": len(small_chunks),
"theoretical_min_calls": theoretical_min_calls,
"theoretical_reduction": round(theoretical_reduction, 1),
"optimization_potential": "high" if reduction_percent > 50 else "moderate" if reduction_percent > 20 else "low"
}
}
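

# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal sketch of how the merger is driven, assuming ContentPart accepts the
# same keyword arguments used in _create_merged_part above and that makeId()
# returns a fresh identifier. The capability values and field contents below are
# illustrative assumptions, not part of this module's contract.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    merger = IntelligentTokenAwareMerger({
        "maxTokens": 4000,      # assumed capability keys, mirroring the __init__ defaults
        "safetyMargin": 0.1,
        "charsPerToken": 4,
    })

    demo_chunks = [
        ContentPart(
            id=makeId(),
            parentId="doc-1",                           # hypothetical parent document id
            label=f"chunk_{i}",
            typeGroup="text",
            mimeType="text/plain",
            data="Lorem ipsum dolor sit amet. " * 40,   # ~1100 characters of filler text
            metadata={"documentId": "doc-1"},
        )
        for i in range(4)
    ]

    # Four small text chunks comfortably fit one call, so they merge into a single part.
    merged = merger.merge_chunks_intelligently(demo_chunks, prompt="Extract the key facts.")
    stats = merger.calculate_optimization_stats(demo_chunks, merged)
    print(stats)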