gateway/modules/workflows/processing/core/taskPlanner.py
2025-10-04 18:44:42 +02:00

343 lines
16 KiB
Python

# taskPlanner.py
# Task planning functionality for workflows
import json
import logging
from typing import Dict, Any
from modules.datamodels.datamodelWorkflow import TaskStep, TaskContext, TaskPlan
from modules.datamodels.datamodelAi import AiCallOptions, OperationType, ProcessingMode, Priority
from modules.workflows.processing.shared.promptFactoryPlaceholders import (
createTaskPlanningPromptTemplate,
extractUserPrompt,
extractAvailableDocuments,
extractWorkflowHistory
)
logger = logging.getLogger(__name__)


class TaskPlanner:
    """Handles task planning for workflows.

    Asks the AI service for a JSON task plan, validates its structure and turns
    it into a ``TaskPlan`` made of ``TaskStep`` objects.
    """

    def __init__(self, services):
        # Service container. Members used by this class: interfaceDbChat,
        # workflow, ai, user and utils.
        self.services = services

    def _checkWorkflowStopped(self, workflow):
        """Raise if the workflow has been stopped by the user.

        The status stored in the database is preferred over the in-memory
        ``workflow`` object, which may be stale. The in-memory status is only
        consulted when the database lookup itself fails.

        Raises:
            Exception: "Workflow was stopped by user".
        """
        try:
            currentWorkflow = self.services.interfaceDbChat.getWorkflow(workflow.id)
        except Exception as e:
            # Only lookup failures land here. The stop check below must stay
            # outside this try: raising inside it would be caught by this
            # handler and the stop ignored whenever the in-memory status is stale.
            logger.warning(f"Could not check current workflow status from database: {str(e)}")
            if workflow and workflow.status == "stopped":
                logger.info("Workflow stopped by user (from in-memory object), aborting task planning")
                raise Exception("Workflow was stopped by user")
            return
        if currentWorkflow and getattr(currentWorkflow, "status", None) == "stopped":
            logger.info("Workflow stopped by user, aborting task planning")
            raise Exception("Workflow was stopped by user")

    async def generateTaskPlan(self, userInput: str, workflow) -> "TaskPlan":
        """Generate a high-level task plan for the workflow.

        Side effect: stores the resolved user language on
        ``services.user.language`` (when a user object is present) so the rest
        of the workflow can use it.

        Args:
            userInput: The user's request; it becomes the objective of the
                synthetic planning step used to render the prompt.
            workflow: The workflow being planned; its ``id`` and ``status`` are read.

        Returns:
            A TaskPlan with the overview, the valid tasks and the user message
            taken from the AI response.

        Raises:
            Exception: If the workflow was stopped or the plan fails validation.
            ValueError: If the AI returns nothing or no usable tasks.
        """
        try:
            self._checkWorkflowStopped(workflow)
            logger.info("=== STARTING TASK PLAN GENERATION ===")
            logger.info(f"Workflow ID: {workflow.id}")
            logger.info(f"User Input: {userInput}")

            taskPlanningContext = self._buildPlanningContext(userInput, workflow)
            taskPlanningPromptTemplate = createTaskPlanningPromptTemplate()
            placeholders = {
                "USER_PROMPT": extractUserPrompt(taskPlanningContext),
                # The planner only needs the document count, not the full list.
                "AVAILABLE_DOCUMENTS": self.services.workflow.getDocumentCount(),
                "WORKFLOW_HISTORY": self.services.workflow.getWorkflowHistoryContext(),
            }

            logger.info("=== TASK PLANNING PROMPT SENT TO AI ===")
            self._writeTraceLog("Task Plan Prompt", taskPlanningPromptTemplate)
            self._writeTraceLog("Task Plan Placeholders", placeholders)

            # Task planning favours quality: detailed mode, no prompt/context compression.
            options = AiCallOptions(
                operationType=OperationType.GENERATE_PLAN,
                priority=Priority.QUALITY,
                compressPrompt=False,
                compressContext=False,
                processingMode=ProcessingMode.DETAILED,
                maxCost=0.10,
                maxProcessingTime=30,
            )
            # Re-check right before the AI call so that a stop issued while the
            # prompt was being prepared is honoured.
            self._checkWorkflowStopped(workflow)
            response = await self.services.ai.callAi(
                prompt=taskPlanningPromptTemplate,
                placeholders=placeholders,
                options=options,
            )
            if not response:
                raise ValueError("AI service returned no response for task planning")

            logger.info("=== TASK PLANNING AI RESPONSE RECEIVED ===")
            logger.info(f"Response length: {len(response)}")
            self._writeTraceLog("Task Plan Response", response)

            taskPlanDict = self._parseTaskPlanResponse(response)
            if not self._validateTaskPlan(taskPlanDict):
                logger.error("Generated task plan failed validation")
                logger.error(f"AI Response: {response}")
                logger.error(f"Parsed Task Plan: {json.dumps(taskPlanDict, indent=2)}")
                raise Exception("AI-generated task plan failed validation - AI is required for task planning")
            if not taskPlanDict.get('tasks'):
                raise ValueError("Task plan contains no tasks")

            # Determine the user language once for the entire workflow.
            userLanguage = self._resolveUserLanguage(taskPlanDict)
            user = getattr(self.services, 'user', None) if self.services else None
            if user:
                user.language = userLanguage
                logger.info(f"Set workflow user language to: {userLanguage}")

            tasks = self._buildTaskSteps(taskPlanDict['tasks'])
            if not tasks:
                raise ValueError("No valid tasks could be created from AI response")

            taskPlan = TaskPlan(
                overview=taskPlanDict.get('overview', ''),
                tasks=tasks,
                userMessage=taskPlanDict.get('userMessage', ''),
            )
            logger.info(f"Task plan generated successfully with {len(tasks)} tasks")
            logger.info(f"Workflow user language set to: {userLanguage}")
            return taskPlan
        except Exception as e:
            logger.error(f"Error in generateTaskPlan: {str(e)}")
            raise

    def _buildPlanningContext(self, userInput: str, workflow) -> "TaskContext":
        """Build the minimal TaskContext needed to render the planning prompt.

        TaskContext requires a task_step, so a synthetic "planning" TaskStep
        whose objective is the raw user input is used.
        """
        planningTaskStep = TaskStep(
            id="planning",
            objective=userInput,
            dependencies=[],
            success_criteria=[],
            estimated_complexity="medium",
        )
        return TaskContext(
            task_step=planningTaskStep,
            workflow=workflow,
            workflow_id=workflow.id,
            available_documents=None,
            available_connections=None,
            previous_results=[],
            previous_handover=None,
            improvements=[],
            retry_count=0,
            previous_action_results=[],
            previous_review_result=None,
            is_regeneration=False,
            failure_patterns=[],
            failed_actions=[],
            successful_actions=[],
            criteria_progress={
                'met_criteria': set(),
                'unmet_criteria': set(),
                'attempt_history': [],
            },
        )

    def _parseTaskPlanResponse(self, response: str) -> Dict[str, Any]:
        """Parse the JSON task plan embedded in an AI response.

        The JSON text is taken as the span from the first ``{`` to the last
        ``}``. Legacy ``description`` fields on task dicts are renamed to
        ``objective`` here, before validation, because validation requires
        ``objective``.

        Returns:
            The parsed dict, or ``{'tasks': []}`` (after logging the error) when
            no JSON object with a ``tasks`` key can be parsed.
        """
        try:
            jsonStart = response.find('{')
            jsonEnd = response.rfind('}') + 1
            if jsonStart == -1 or jsonEnd == 0:
                raise ValueError("No JSON found in response")
            taskPlanDict = json.loads(response[jsonStart:jsonEnd])
            if 'tasks' not in taskPlanDict:
                raise ValueError("Task plan missing 'tasks' field")
        except Exception as e:
            logger.error(f"Error parsing task plan response: {str(e)}")
            return {'tasks': []}

        tasks = taskPlanDict['tasks']
        if isinstance(tasks, list):
            for taskDict in tasks:
                # Map old 'description' field to new 'objective' field
                if isinstance(taskDict, dict) and 'description' in taskDict and 'objective' not in taskDict:
                    taskDict['objective'] = taskDict.pop('description')
        return taskPlanDict

    def _resolveUserLanguage(self, taskPlanDict: Dict[str, Any]) -> str:
        """Pick the workflow language as a two-letter code.

        Priority: ``languageUserDetected`` from the AI response, then
        ``services.user.language``, then ``"en"``. Non-string values (e.g. a
        JSON ``null``) are ignored rather than crashing on ``.strip()``.
        """
        detected = taskPlanDict.get('languageUserDetected')
        detectedLanguage = detected.strip() if isinstance(detected, str) else ''
        user = getattr(self.services, 'user', None) if self.services else None
        serviceLanguage = getattr(user, 'language', '') if user else ''
        if not isinstance(serviceLanguage, str):
            serviceLanguage = ''

        if len(detectedLanguage) == 2:  # Valid language code like "en", "de", "fr"
            logger.info(f"Using detected language from AI response: {detectedLanguage}")
            return detectedLanguage
        if len(serviceLanguage) == 2:
            logger.info(f"Using language from service user object: {serviceLanguage}")
            return serviceLanguage
        userLanguage = "en"
        logger.info(f"Using default language: {userLanguage}")
        return userLanguage

    def _buildTaskSteps(self, taskDicts) -> "list[TaskStep]":
        """Instantiate TaskStep objects, skipping (and logging) invalid entries."""
        tasks = []
        for position, taskDict in enumerate(taskDicts, start=1):
            if not isinstance(taskDict, dict):
                logger.warning(f"Skipping invalid task {position}: not a dictionary")
                continue
            try:
                tasks.append(TaskStep(**taskDict))
            except Exception as e:
                logger.warning(f"Skipping invalid task {position}: {str(e)}")
        return tasks

    def _validateTaskPlan(self, taskPlan: Dict[str, Any]) -> bool:
        """Check that a parsed task plan is structurally usable.

        The plan is valid when all of these hold:
        - ``taskPlan`` is a dict whose ``tasks`` is a list of dicts.
        - Every task has ``id``, ``objective`` and ``success_criteria``.
        - Task IDs are unique.
        - The optional ``dependencies`` is a list whose entries are IDs of tasks
          in the plan or the literal ``'task_0'``.

        Returns:
            True if valid; otherwise False, with the reason logged. Unexpected
            errors (e.g. unhashable IDs) also yield False.
        """
        try:
            if not isinstance(taskPlan, dict):
                logger.error("Task plan is not a dictionary")
                return False
            if 'tasks' not in taskPlan or not isinstance(taskPlan['tasks'], list):
                logger.error(f"Task plan missing 'tasks' field or not a list. Found: {type(taskPlan.get('tasks', 'MISSING'))}")
                return False
            tasks = taskPlan['tasks']

            # First pass: collect IDs (needed to validate dependencies) and reject duplicates.
            taskIds = set()
            for i, task in enumerate(tasks):
                if not isinstance(task, dict):
                    logger.error(f"Task is not a dictionary: {type(task)}")
                    return False
                if 'id' not in task:
                    logger.error(f"Task missing 'id' field: {task}")
                    return False
                if task['id'] in taskIds:
                    logger.error(f"Task {i} has duplicate ID: {task['id']}")
                    return False
                taskIds.add(task['id'])

            # Second pass: required fields and dependency references.
            # 'task_0' is accepted as a dependency even if no task has that ID.
            availableDependencies = list(taskIds) + ['task_0']
            requiredFields = ('id', 'objective', 'success_criteria')
            for i, task in enumerate(tasks):
                missingFields = [field for field in requiredFields if field not in task]
                if missingFields:
                    logger.error(f"Task {i} missing required fields: {missingFields}")
                    return False
                dependencies = task.get('dependencies', [])
                if not isinstance(dependencies, list):
                    logger.error(f"Task {i} dependencies is not a list: {type(dependencies)}")
                    return False
                for dep in dependencies:
                    if dep not in taskIds and dep != 'task_0':
                        logger.error(f"Task {i} has invalid dependency: {dep} (available: {availableDependencies})")
                        return False

            logger.info(f"Task plan validation successful with {len(taskIds)} tasks")
            return True
        except Exception as e:
            logger.error(f"Error validating task plan: {str(e)}")
            return False

    @staticmethod
    def _formatTraceData(data: Any) -> str:
        """Render ``data`` for the trace file, pretty-printing anything that is JSON.

        Dicts and lists are dumped directly. Strings and other objects (via
        ``str``) are pretty-printed if they parse as JSON. Otherwise they are
        emitted as text, with literal backslash-n sequences turned into real
        newlines.
        """
        if data is None:
            return "No data provided\n"
        dumpOptions = {"indent": 2, "default": str, "ensure_ascii": False, "sort_keys": False}
        try:
            if isinstance(data, (dict, list)):
                return f"JSON Data:\n{json.dumps(data, **dumpOptions)}\n"
            if isinstance(data, str):
                text, parsedLabel, textLabel = data, "JSON Data (parsed from string)", "Text Data"
            else:
                text, parsedLabel, textLabel = str(data), "JSON Data (parsed from object)", "Object Data"
            try:
                parsed = json.loads(text)
            except (json.JSONDecodeError, TypeError):
                formattedText = text.replace('\\n', '\n')
                return f"{textLabel}:\n{formattedText}\n"
            return f"{parsedLabel}:\n{json.dumps(parsed, **dumpOptions)}\n"
        except Exception:
            return f"Data (fallback): {str(data)}\n"

    def _writeTraceLog(self, contextText: str, data: Any) -> None:
        """Append a trace entry to ``log_trace.log`` when DEBUG logging is enabled.

        The directory comes from the ``APP_LOGGING_LOG_DIR`` config value
        (default ``"./"``); a relative path is resolved against the gateway
        directory. Tracing is best-effort: every error is swallowed so that it
        can never break task planning.
        """
        try:
            # isEnabledFor() honours the effective level inherited from parent
            # loggers. ``logger.level`` is 0 (NOTSET) unless set explicitly, so
            # comparing it against DEBUG let traces through in every mode.
            if not logger.isEnabledFor(logging.DEBUG):
                return
            import os
            from datetime import datetime, UTC

            logDir = self.services.utils.configGet("APP_LOGGING_LOG_DIR", "./")
            if not os.path.isabs(logDir):
                # This file is gateway/modules/workflows/processing/core/taskPlanner.py,
                # so the gateway directory is five dirname() steps up from it.
                gatewayDir = os.path.abspath(__file__)
                for _ in range(5):
                    gatewayDir = os.path.dirname(gatewayDir)
                logDir = os.path.join(gatewayDir, logDir)
            os.makedirs(logDir, exist_ok=True)
            traceFile = os.path.join(logDir, "log_trace.log")

            timestamp = datetime.fromtimestamp(self.services.utils.getUtcTimestamp(), UTC).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
            separator = "=" * 80 + "\n"
            traceEntry = (
                f"[{timestamp}] {contextText}\n"
                + separator
                + self._formatTraceData(data)
                + separator
                + "\n"
            )
            with open(traceFile, "a", encoding="utf-8") as f:
                f.write(traceEntry)
        except Exception:
            # Deliberately silent: tracing must never interfere with planning.
            pass