# gateway/modules/services/serviceAi/subAiCallLooping.py
# (file viewer metadata: 627 lines, 37 KiB, Python)
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
AI Call Looping Module
Handles AI calls with looping and repair logic, including:
- Looping with JSON repair and continuation
- KPI definition and tracking
- Progress tracking and iteration management
"""
import json
import logging
from typing import Dict, Any, List, Optional, Callable
from modules.datamodels.datamodelAi import (
AiCallRequest, AiCallOptions
)
from modules.datamodels.datamodelExtraction import ContentPart
from modules.shared.jsonUtils import buildContinuationContext, extractJsonString, tryParseJson
from modules.services.serviceAi.subJsonResponseHandling import JsonResponseHandler
from modules.services.serviceAi.subLoopingUseCases import LoopingUseCaseRegistry
from modules.workflows.processing.shared.stateTools import checkWorkflowStopped
logger = logging.getLogger(__name__)
class AiCallLooper:
    """AI-call driver that loops, repairs broken JSON, and merges partial responses."""

    def __init__(self, services, aiService, responseParser):
        """Store service handles and set up the looping use-case registry.

        Args:
            services: Service center giving access to chat, utils and workflow.
            aiService: Service used to perform the actual AI calls.
            responseParser: Parser for AI responses.
        """
        # Registry of explicitly registered looping use cases (no auto-detection).
        self.useCaseRegistry = LoopingUseCaseRegistry()
        self.responseParser = responseParser
        self.aiService = aiService
        self.services = services
async def callAiWithLooping(
self,
prompt: str,
options: AiCallOptions,
debugPrefix: str = "ai_call",
promptBuilder: Optional[Callable] = None,
promptArgs: Optional[Dict[str, Any]] = None,
operationId: Optional[str] = None,
userPrompt: Optional[str] = None,
contentParts: Optional[List[ContentPart]] = None, # ARCHITECTURE: Support ContentParts for large content
useCaseId: str = None # REQUIRED: Explicit use case ID - no auto-detection, no fallback
) -> str:
"""
Shared core function for AI calls with repair-based looping system.
Automatically repairs broken JSON and continues generation seamlessly.
Args:
prompt: The prompt to send to AI
options: AI call configuration options
debugPrefix: Prefix for debug file names
promptBuilder: Optional function to rebuild prompts for continuation
promptArgs: Optional arguments for prompt builder
operationId: Optional operation ID for progress tracking
userPrompt: Optional user prompt for KPI definition
contentParts: Optional content parts for first iteration
useCaseId: REQUIRED: Explicit use case ID - no auto-detection, no fallback
Returns:
Complete AI response after all iterations
"""
# REQUIRED: useCaseId must be provided - no auto-detection, no fallback
if not useCaseId:
errorMsg = (
"useCaseId is REQUIRED for callAiWithLooping. "
"No auto-detection - must explicitly specify use case ID. "
f"Available use cases: {list(self.useCaseRegistry.useCases.keys())}"
)
logger.error(errorMsg)
raise ValueError(errorMsg)
# Validate use case exists
useCase = self.useCaseRegistry.get(useCaseId)
if not useCase:
errorMsg = (
f"Use case '{useCaseId}' not found in registry. "
f"Available use cases: {list(self.useCaseRegistry.useCases.keys())}"
)
logger.error(errorMsg)
raise ValueError(errorMsg)
maxIterations = 50 # Prevent infinite loops
iteration = 0
allSections = [] # Accumulate all sections across iterations
lastRawResponse = None # Store last raw JSON response for continuation
accumulatedDirectJson = [] # Accumulate JSON strings for direct return use cases (chapter_structure, code_structure)
# Get parent operation ID for iteration operations (parentId should be operationId, not log entry ID)
parentOperationId = operationId # Use the parent's operationId directly
while iteration < maxIterations:
iteration += 1
# Create separate operation for each iteration with parent reference
iterationOperationId = None
if operationId:
iterationOperationId = f"{operationId}_iter_{iteration}"
self.services.chat.progressLogStart(
iterationOperationId,
"AI Call",
f"Iteration {iteration}",
"",
parentOperationId=parentOperationId
)
# Build iteration prompt
# CRITICAL: Build continuation prompt if we have sections OR if we have a previous response (even if broken)
# This ensures continuation prompts are built even when JSON is so broken that no sections can be extracted
if (len(allSections) > 0 or lastRawResponse) and promptBuilder and promptArgs:
# Extract templateStructure and basePrompt from promptArgs (REQUIRED)
templateStructure = promptArgs.get("templateStructure")
if not templateStructure:
raise ValueError(
f"templateStructure is REQUIRED in promptArgs for use case '{useCaseId}'. "
"Prompt creation functions must return (prompt, templateStructure) tuple."
)
basePrompt = promptArgs.get("basePrompt")
if not basePrompt:
# Fallback: use prompt parameter (should be the same)
basePrompt = prompt
logger.warning(
f"basePrompt not found in promptArgs for use case '{useCaseId}', "
"using prompt parameter instead. This may indicate a bug."
)
# This is a continuation - build continuation context with raw JSON and rebuild prompt
continuationContext = buildContinuationContext(
allSections, lastRawResponse, useCaseId, templateStructure
)
if not lastRawResponse:
logger.warning(f"Iteration {iteration}: No previous response available for continuation!")
# Unified prompt builder call: Continuation builders only need continuationContext, templateStructure, and basePrompt
# All initial context (section, userPrompt, etc.) is already in basePrompt, so promptArgs is not needed
# Extract templateStructure and basePrompt from promptArgs (they're explicit parameters)
iterationPrompt = await promptBuilder(
continuationContext=continuationContext,
templateStructure=templateStructure,
basePrompt=basePrompt
)
else:
# First iteration - use original prompt
iterationPrompt = prompt
# Make AI call
try:
checkWorkflowStopped(self.services)
if iterationOperationId:
self.services.chat.progressLogUpdate(iterationOperationId, 0.3, "Calling AI model")
# ARCHITECTURE: Pass ContentParts directly to AiCallRequest
# This allows model-aware chunking to handle large content properly
# ContentParts are only passed in first iteration (continuations don't need them)
request = AiCallRequest(
prompt=iterationPrompt,
context="",
options=options,
contentParts=contentParts if iteration == 1 else None # Only pass ContentParts in first iteration
)
# Write the ACTUAL prompt sent to AI
# For section content generation: write prompt for first iteration and continuation iterations
# For document generation: write prompt for each iteration
isSectionContent = "_section_" in debugPrefix
if iteration == 1:
self.services.utils.writeDebugFile(iterationPrompt, f"{debugPrefix}_prompt")
elif isSectionContent:
# Save continuation prompts for section_content debugging
self.services.utils.writeDebugFile(iterationPrompt, f"{debugPrefix}_prompt_iteration_{iteration}")
else:
# Document generation - save all iteration prompts
self.services.utils.writeDebugFile(iterationPrompt, f"{debugPrefix}_prompt_iteration_{iteration}")
response = await self.aiService.callAi(request)
result = response.content
# Track bytes for progress reporting
bytesReceived = len(result.encode('utf-8')) if result else 0
totalBytesSoFar = sum(len(section.get('content', '').encode('utf-8')) if isinstance(section.get('content'), str) else 0 for section in allSections) + bytesReceived
# Update progress after AI call with byte information
if iterationOperationId:
# Format bytes for display (kB or MB)
if totalBytesSoFar < 1024:
bytesDisplay = f"{totalBytesSoFar}B"
elif totalBytesSoFar < 1024 * 1024:
bytesDisplay = f"{totalBytesSoFar / 1024:.1f}kB"
else:
bytesDisplay = f"{totalBytesSoFar / (1024 * 1024):.1f}MB"
self.services.chat.progressLogUpdate(iterationOperationId, 0.6, f"AI response received ({bytesDisplay})")
# Write raw AI response to debug file
# For section content generation: write response for first iteration and continuation iterations
# For document generation: write response for each iteration
if iteration == 1:
self.services.utils.writeDebugFile(result, f"{debugPrefix}_response")
elif isSectionContent:
# Save continuation responses for section_content debugging
self.services.utils.writeDebugFile(result, f"{debugPrefix}_response_iteration_{iteration}")
else:
# Document generation - save all iteration responses
self.services.utils.writeDebugFile(result, f"{debugPrefix}_response_iteration_{iteration}")
# Emit stats for this iteration (only if workflow exists and has id)
if self.services.workflow and hasattr(self.services.workflow, 'id') and self.services.workflow.id:
try:
self.services.chat.storeWorkflowStat(
self.services.workflow,
response,
f"ai.call.{debugPrefix}.iteration_{iteration}"
)
except Exception as statError:
# Don't break the main loop if stat storage fails
logger.warning(f"Failed to store workflow stat: {str(statError)}")
# Check for error response using generic error detection (errorCount > 0 or modelName == "error")
if hasattr(response, 'errorCount') and response.errorCount > 0:
errorMsg = f"Iteration {iteration}: Error response detected (errorCount={response.errorCount}), stopping loop: {result[:200] if result else 'empty'}"
logger.error(errorMsg)
break
if hasattr(response, 'modelName') and response.modelName == "error":
errorMsg = f"Iteration {iteration}: Error response detected (modelName=error), stopping loop: {result[:200] if result else 'empty'}"
logger.error(errorMsg)
break
if not result or not result.strip():
logger.warning(f"Iteration {iteration}: Empty response, stopping")
break
# Check if this is a text response (not document generation)
# Text responses don't need JSON parsing - return immediately after first successful response
isTextResponse = (promptBuilder is None and promptArgs is None) or debugPrefix == "text"
if isTextResponse:
# For text responses, return the text immediately - no JSON parsing needed
logger.info(f"Iteration {iteration}: Text response received, returning immediately")
if iterationOperationId:
self.services.chat.progressLogFinish(iterationOperationId, True)
return result
# Store raw response for continuation (even if broken)
lastRawResponse = result
# Parse JSON for use case handling
parsedJsonForUseCase = None
extractedJsonForUseCase = None
try:
extractedJsonForUseCase = extractJsonString(result)
parsedJson, parseError, _ = tryParseJson(extractedJsonForUseCase)
if parseError is None and parsedJson:
parsedJsonForUseCase = parsedJson
except Exception:
pass
# Handle use cases that return JSON directly (no section extraction needed)
# Check if use case supports direct return (all registered use cases do)
if useCase and not useCase.requiresExtraction:
# For all direct return use cases, check completeness and support looping
if True: # All registered use cases support looping
# If parsing failed (e.g., invalid JSON with comments or truncated JSON), continue looping to get valid JSON
if not parsedJsonForUseCase:
logger.info(f"Iteration {iteration}: Use case '{useCaseId}' - JSON parsing failed (likely incomplete/truncated), continuing iteration to complete")
# Accumulate response for merging in next iteration
accumulatedDirectJson.append(result)
# Continue to next iteration - continuation prompt builder will handle the rest
if iterationOperationId:
self.services.chat.progressLogUpdate(iterationOperationId, 0.7, "JSON incomplete, requesting continuation")
self.services.chat.progressLogFinish(iterationOperationId, True)
continue
# If we successfully parsed JSON, check completeness using parsed structure
# CRITICAL: If parsing succeeded, trust the parsed JSON structure check
# The string-based check can have false positives on valid JSON (e.g., due to normalization issues)
# Only use string-based check when parsing fails (already handled above)
isComplete = JsonResponseHandler.isJsonComplete(parsedJsonForUseCase)
# If parsed check says complete, trust it - don't override with string check
# String check is only reliable when parsing fails (truncated JSON that gets closed for parsing)
# For successfully parsed JSON, the structure check is definitive
if not isComplete:
logger.warning(f"Iteration {iteration}: Use case '{useCaseId}' - JSON is incomplete, continuing for continuation")
# Accumulate response for merging in next iteration
accumulatedDirectJson.append(result)
# Continue to next iteration - continuation prompt builder will handle the rest
if iterationOperationId:
self.services.chat.progressLogUpdate(iterationOperationId, 0.7, "JSON incomplete, requesting continuation")
self.services.chat.progressLogFinish(iterationOperationId, True)
continue
else:
# JSON is complete - merge accumulated responses if any
if accumulatedDirectJson:
logger.info(f"Iteration {iteration}: Merging {len(accumulatedDirectJson) + 1} accumulated responses")
# Use generic data-based merging for all use cases
try:
# Strategy: Merge strings first for incomplete JSON, then parse and merge parsed objects
# This ensures incomplete JSON from part 1 is preserved
allJsonStrings = accumulatedDirectJson + [result]
# Step 1: Merge all JSON strings using existing overlap detection
mergedJsonString = allJsonStrings[0] if allJsonStrings else ""
hasOverlap = True # Track if any overlap was found
for jsonStr in allJsonStrings[1:]:
mergedJsonString, hasOverlapInMerge = JsonResponseHandler.mergeJsonStringsWithOverlap(mergedJsonString, jsonStr)
# If no overlap found in any merge, stop iterations
if not hasOverlapInMerge:
hasOverlap = False
logger.info(f"Iteration {iteration}: No overlap found during merge - stopping iterations and closing JSON")
break
# If no overlap was found, mark as complete and use closed JSON
if not hasOverlap:
isComplete = True
# JSON is already closed by mergeJsonStringsWithOverlap when no overlap
# Use the merged (closed) JSON string directly
result = mergedJsonString
# CRITICAL: Update lastRawResponse with merged result for next iteration
lastRawResponse = mergedJsonString
# Try to parse it to get parsedJsonForUseCase
try:
extracted = extractJsonString(mergedJsonString)
parsed, parseErr, _ = tryParseJson(extracted)
if parseErr is None and parsed:
# Use callback to normalize JSON structure
normalized = self._normalizeJsonStructure(parsed, useCase)
parsedJsonForUseCase = normalized
result = json.dumps(normalized, indent=2, ensure_ascii=False)
# CRITICAL: Update lastRawResponse with final result
lastRawResponse = result
else:
# Parsing failed - try to repair JSON
from modules.shared.jsonUtils import repairBrokenJson
logger.warning(
f"Iteration {iteration}: JSON parse failed after no-overlap merge, "
f"attempting repair: {str(parseErr) if parseErr else 'Unknown error'}"
)
repairedJson = repairBrokenJson(extracted)
if repairedJson and isinstance(repairedJson, dict):
# repairBrokenJson returns a dict directly - use it
normalized = self._normalizeJsonStructure(repairedJson, useCase)
parsedJsonForUseCase = normalized
result = json.dumps(normalized, indent=2, ensure_ascii=False)
# CRITICAL: Update lastRawResponse with final result
lastRawResponse = result
logger.info(f"Iteration {iteration}: Successfully repaired JSON after no-overlap merge")
except Exception as e:
# Last resort: try repair on the original merged string
logger.warning(
f"Iteration {iteration}: Exception during no-overlap JSON processing, "
f"attempting repair: {str(e)}"
)
try:
from modules.shared.jsonUtils import repairBrokenJson
repairedJson = repairBrokenJson(mergedJsonString)
if repairedJson and isinstance(repairedJson, dict):
normalized = self._normalizeJsonStructure(repairedJson, useCase)
parsedJsonForUseCase = normalized
result = json.dumps(normalized, indent=2, ensure_ascii=False)
logger.info(f"Iteration {iteration}: Successfully repaired JSON after exception")
else:
logger.error(f"Iteration {iteration}: JSON repair failed, using string result as-is")
except Exception as repairError:
logger.error(
f"Iteration {iteration}: JSON repair also failed: {str(repairError)}, "
"using string result as-is"
)
else:
# Overlap found - continue with normal processing
# Step 2: Try to parse the merged string
extracted = extractJsonString(mergedJsonString)
parsed, parseErr, _ = tryParseJson(extracted)
if parseErr is None and parsed:
# Parsing succeeded - normalize and use (via callback)
normalized = self._normalizeJsonStructure(parsed, useCase)
parsedJsonForUseCase = normalized
result = json.dumps(normalized, indent=2, ensure_ascii=False)
# CRITICAL: Update lastRawResponse with merged result
lastRawResponse = result
else:
# Parsing failed - try to extract partial data using Deep-Structure-Merging
# This fallback works for all use cases: parse what we can from each part
allParsed = []
for jsonStr in allJsonStrings:
extracted = extractJsonString(jsonStr)
parsed, parseErr, _ = tryParseJson(extracted)
if parseErr is None and parsed:
# Use callback to normalize JSON structure
normalized = self._normalizeJsonStructure(parsed, useCase)
allParsed.append(normalized)
if allParsed:
# Use mergeDeepStructures for intelligent merging across all use cases
if len(allParsed) > 1:
mergedJsonObj = allParsed[0]
for nextObj in allParsed[1:]:
mergedJsonObj = JsonResponseHandler.mergeDeepStructures(
mergedJsonObj, nextObj, iteration, f"{useCaseId}.merge"
)
else:
mergedJsonObj = allParsed[0]
parsedJsonForUseCase = mergedJsonObj
result = json.dumps(mergedJsonObj, indent=2, ensure_ascii=False)
# CRITICAL: Update lastRawResponse with merged result
lastRawResponse = result
else:
# All parsing failed - use string merge result
result = mergedJsonString
# CRITICAL: Update lastRawResponse with merged result
lastRawResponse = mergedJsonString
except Exception as e:
logger.warning(f"Failed data-based merge, falling back to string merging: {e}")
# Fallback to string merging
mergedJsonString = accumulatedDirectJson[0] if accumulatedDirectJson else result
hasOverlap = True # Track if any overlap was found
for prevJson in accumulatedDirectJson[1:]:
mergedJsonString, hasOverlapInMerge = JsonResponseHandler.mergeJsonStringsWithOverlap(mergedJsonString, prevJson)
if not hasOverlapInMerge:
hasOverlap = False
logger.info(f"Iteration {iteration}: No overlap found during fallback merge - stopping iterations")
break
if hasOverlap:
mergedJsonString, hasOverlapInMerge = JsonResponseHandler.mergeJsonStringsWithOverlap(mergedJsonString, result)
if not hasOverlapInMerge:
hasOverlap = False
logger.info(f"Iteration {iteration}: No overlap found in final fallback merge - stopping iterations")
result = mergedJsonString
# CRITICAL: Update lastRawResponse with merged result
lastRawResponse = mergedJsonString
# If no overlap was found, mark as complete and use closed JSON
if not hasOverlap:
isComplete = True
# JSON is already closed by mergeJsonStringsWithOverlap when no overlap
# Try to parse it to get parsedJsonForUseCase
try:
extractedMerged = extractJsonString(result)
parsedMerged, parseError, _ = tryParseJson(extractedMerged)
if parseError is None and parsedMerged:
parsedJsonForUseCase = parsedMerged
except Exception:
pass # Use string result if parsing fails
# Try to parse the string-merged result
try:
extractedMerged = extractJsonString(result)
parsedMerged, parseError, _ = tryParseJson(extractedMerged)
if parseError is None and parsedMerged:
parsedJsonForUseCase = parsedMerged
except Exception:
pass
logger.info(f"Iteration {iteration}: Use case '{useCaseId}' - JSON is complete")
logger.info(f"Iteration {iteration}: Use case '{useCaseId}' - returning JSON directly")
if iterationOperationId:
self.services.chat.progressLogFinish(iterationOperationId, True)
# Use callback to handle final result formatting and debug file writing (REQUIRED - no fallback)
if not useCase.finalResultHandler:
raise ValueError(
f"Use case '{useCaseId}' is missing required 'finalResultHandler' callback. "
"All use cases must provide a finalResultHandler function."
)
final_json = useCase.finalResultHandler(
result, parsedJsonForUseCase, extractedJsonForUseCase,
debugPrefix, self.services
)
return final_json
except Exception as e:
logger.error(f"Error in AI call iteration {iteration}: {str(e)}")
if iterationOperationId:
self.services.chat.progressLogFinish(iterationOperationId, False)
break
if iteration >= maxIterations:
logger.warning(f"AI call stopped after maximum iterations ({maxIterations})")
# This code path should never be reached because all registered use cases
# return early when JSON is complete. This would only execute for use cases that
# require section extraction, but no such use cases are currently registered.
logger.error(f"Unexpected code path: reached end of loop without return for use case '{useCaseId}'")
return result if result else ""
def _isJsonStringIncomplete(self, jsonString: str) -> bool:
"""
Check if JSON string is incomplete (truncated) BEFORE closing/parsing.
This is critical because if JSON is truncated, closing it makes it appear complete,
but we need to detect the truncation to continue iteration.
Args:
jsonString: JSON string to check
Returns:
True if JSON string appears incomplete/truncated, False otherwise
"""
if not jsonString or not jsonString.strip():
return False
from modules.shared.jsonUtils import stripCodeFences, normalizeJsonText
# Normalize JSON string
normalized = stripCodeFences(normalizeJsonText(jsonString)).strip()
if not normalized:
return False
# Find first '{' or '[' to start
startIdx = -1
for i, char in enumerate(normalized):
if char in '{[':
startIdx = i
break
if startIdx == -1:
return False
jsonContent = normalized[startIdx:]
# Check if structures are balanced (all opened structures are closed)
braceCount = 0
bracketCount = 0
inString = False
escapeNext = False
for char in jsonContent:
if escapeNext:
escapeNext = False
continue
if char == '\\':
escapeNext = True
continue
if char == '"':
inString = not inString
continue
if not inString:
if char == '{':
braceCount += 1
elif char == '}':
braceCount -= 1
elif char == '[':
bracketCount += 1
elif char == ']':
bracketCount -= 1
# If structures are unbalanced, JSON is incomplete
if braceCount > 0 or bracketCount > 0:
return True
# Check if JSON ends with incomplete value (e.g., unclosed string, incomplete number, trailing comma)
trimmed = jsonContent.rstrip()
if not trimmed:
return False
# Check for trailing comma (might indicate incomplete)
if trimmed.endswith(','):
# Trailing comma might indicate incomplete, but could also be valid
# Check if there's a closing bracket/brace after the comma
return False # Trailing comma alone doesn't mean incomplete
# Check if ends with incomplete string (odd number of quotes)
quoteCount = jsonContent.count('"')
if quoteCount % 2 == 1:
# Odd number of quotes - string is not closed
return True
# Check if ends mid-value (e.g., ends with "417 instead of "4170. 41719"])
# Look for patterns that suggest truncation:
# - Ends with incomplete number (e.g., "417)
# - Ends with incomplete array element (e.g., ["417)
# - Ends with incomplete object property (e.g., {"key": "val)
# If JSON parses successfully without closing, it's complete
from modules.shared.jsonUtils import tryParseJson
parsed, parseErr, _ = tryParseJson(jsonContent)
if parseErr is None:
# Parses successfully - it's complete
return False
# If it doesn't parse, try closing it and see if that helps
from modules.shared.jsonUtils import closeJsonStructures
closed = closeJsonStructures(jsonContent)
parsedClosed, parseErrClosed, _ = tryParseJson(closed)
if parseErrClosed is None:
# Only parses after closing - it was incomplete
return True
# Doesn't parse even after closing - might be malformed, but assume incomplete to be safe
return True
def _normalizeJsonStructure(self, parsed: Any, useCase) -> Any:
"""
Normalize JSON structure to ensure consistent format before merging.
Handles different response formats and converts them to expected structure.
Args:
parsed: Parsed JSON object (can be dict, list, or primitive)
useCase: LoopingUseCase instance with jsonNormalizer callback
Returns:
Normalized JSON structure
"""
# Use callback to normalize JSON structure (REQUIRED - no fallback)
if not useCase or not useCase.jsonNormalizer:
raise ValueError(
f"Use case '{useCase.useCaseId if useCase else 'unknown'}' is missing required 'jsonNormalizer' callback. "
"All use cases must provide a jsonNormalizer function."
)
return useCase.jsonNormalizer(parsed, useCase.useCaseId)