348 lines
16 KiB
Python
348 lines
16 KiB
Python
# taskPlanner.py
|
|
# Task planning functionality for workflows
|
|
|
|
import json
|
|
import logging
|
|
from typing import Dict, Any
|
|
from modules.datamodels.datamodelWorkflow import TaskStep, TaskContext, TaskPlan
|
|
from modules.datamodels.datamodelAi import AiCallOptions, OperationType, ProcessingMode, Priority
|
|
from modules.workflows.processing.shared.promptGenerationTaskplan import (
|
|
createTaskPlanningPromptTemplate
|
|
)
|
|
from modules.workflows.processing.shared.placeholderFactory import (
|
|
extractUserPrompt
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class TaskPlanner:
|
|
"""Handles task planning for workflows"""
|
|
|
|
def __init__(self, services):
    """Store the shared service container (DB, AI, workflow, utils access)."""
    self.services = services
|
|
|
|
def _checkWorkflowStopped(self, workflow):
    """Abort task planning if the user has stopped the workflow.

    Re-reads the workflow status from the database so a stop requested after
    `workflow` was loaded is still seen.  If the database lookup fails, falls
    back to the status on the in-memory object.

    Args:
        workflow: the in-memory workflow object (must expose `id` and `status`).

    Raises:
        Exception: "Workflow was stopped by user" when the status is "stopped".
    """
    current_workflow = None
    try:
        # Only the DB call is guarded here.  The previous implementation also
        # raised the stop exception inside this try, so its own
        # `except Exception` caught the stop signal, logged it as a DB
        # warning, and then consulted the (possibly stale) in-memory status —
        # silently swallowing a DB-reported stop.
        current_workflow = self.services.interfaceDbChat.getWorkflow(workflow.id)
    except Exception as e:
        # Database unavailable: fall back to the in-memory object below.
        logger.warning(f"Could not check current workflow status from database: {str(e)}")

    if current_workflow is not None:
        if current_workflow.status == "stopped":
            logger.info("Workflow stopped by user, aborting task planning")
            raise Exception("Workflow was stopped by user")
    elif workflow and workflow.status == "stopped":
        logger.info("Workflow stopped by user (from in-memory object), aborting task planning")
        raise Exception("Workflow was stopped by user")
|
|
|
|
async def generateTaskPlan(self, userInput: str, workflow) -> TaskPlan:
    """Generate a high-level task plan for the workflow.

    Builds a planning prompt, sends it to the AI service, parses the JSON
    response into TaskStep objects, and returns a TaskPlan.  As a side
    effect, detects the user's language and stores it on
    ``self.services.user.language`` for the rest of the workflow.

    Args:
        userInput: raw user request; used only when no stored prompt exists.
        workflow: workflow object (must expose ``id``; ``status`` is checked
            via :meth:`_checkWorkflowStopped`).

    Returns:
        TaskPlan: validated plan with at least one task.

    Raises:
        Exception: if the workflow was stopped, the AI response is empty,
            the plan fails validation, or no valid tasks can be built.
    """
    try:
        # Check workflow status before generating task plan.
        self._checkWorkflowStopped(workflow)

        logger.info(f"=== STARTING TASK PLAN GENERATION ===")
        logger.info(f"Workflow ID: {workflow.id}")
        logger.info(f"User Input: {userInput}")

        # Use stored user prompt if available, otherwise use the input.
        actualUserPrompt = self.services.currentUserPrompt if self.services and hasattr(self.services, 'currentUserPrompt') and self.services.currentUserPrompt else userInput
        logger.info(f"Actual User Prompt: {actualUserPrompt}")

        # Re-check: the user may have stopped the workflow meanwhile;
        # this runs again just before the (expensive) AI call.
        self._checkWorkflowStopped(workflow)

        # TaskContext requires a task_step, so create a minimal synthetic
        # TaskStep that only carries the user's objective for planning.
        planningTaskStep = TaskStep(
            id="planning",
            objective=actualUserPrompt,
            dependencies=[],
            success_criteria=[],
            estimated_complexity="medium"
        )

        # Fresh planning context: no documents/connections/history yet.
        taskPlanningContext = TaskContext(
            task_step=planningTaskStep,
            workflow=workflow,
            workflow_id=workflow.id,
            available_documents=None,
            available_connections=None,
            previous_results=[],
            previous_handover=None,
            improvements=[],
            retry_count=0,
            previous_action_results=[],
            previous_review_result=None,
            is_regeneration=False,
            failure_patterns=[],
            failed_actions=[],
            successful_actions=[],
            criteria_progress={
                'met_criteria': set(),
                'unmet_criteria': set(),
                'attempt_history': []
            }
        )

        # Generate the task planning prompt template with placeholders.
        taskPlanningPromptTemplate = createTaskPlanningPromptTemplate()

        # Extract content for the placeholders.
        userPrompt = extractUserPrompt(taskPlanningContext)
        # Task planner only needs the document count, not the full list.
        availableDocuments = self.services.workflow.getDocumentCount()
        # Centralized workflow history context function.
        workflowHistory = self.services.workflow.getWorkflowHistoryContext()

        # Placeholder values substituted into the prompt template by callAi.
        placeholders = {
            "USER_PROMPT": userPrompt,
            "AVAILABLE_DOCUMENTS": availableDocuments,
            "WORKFLOW_HISTORY": workflowHistory
        }

        # Log and trace the outgoing prompt (trace only in debug mode).
        logger.info("=== TASK PLANNING PROMPT SENT TO AI ===")
        self._writeTraceLog("Task Plan Prompt", taskPlanningPromptTemplate)
        self._writeTraceLog("Task Plan Placeholders", placeholders)

        # Centralized AI call: task planning prefers quality/detail over
        # speed, capped at $0.10 and 30s.
        options = AiCallOptions(
            operationType=OperationType.GENERATE_PLAN,
            priority=Priority.QUALITY,
            compressPrompt=False,
            compressContext=False,
            processingMode=ProcessingMode.DETAILED,
            maxCost=0.10,
            maxProcessingTime=30
        )

        # NOTE(review): despite the name, `prompt` holds the AI *response*
        # string from here on.
        prompt = await self.services.ai.callAi(
            prompt=taskPlanningPromptTemplate,
            placeholders=placeholders,
            options=options
        )

        # Empty/None response is a hard failure — AI is required for planning.
        if not prompt:
            raise ValueError("AI service returned no response for task planning")

        logger.info("=== TASK PLANNING AI RESPONSE RECEIVED ===")
        logger.info(f"Response length: {len(prompt) if prompt else 0}")
        self._writeTraceLog("Task Plan Response", prompt)

        # Parse the task plan: extract the outermost {...} span so the JSON
        # survives surrounding prose in the AI response.  On any parse error
        # fall back to an empty plan, which fails the no-tasks check below.
        try:
            jsonStart = prompt.find('{')
            jsonEnd = prompt.rfind('}') + 1
            if jsonStart == -1 or jsonEnd == 0:
                raise ValueError("No JSON found in response")
            jsonStr = prompt[jsonStart:jsonEnd]
            taskPlanDict = json.loads(jsonStr)

            if 'tasks' not in taskPlanDict:
                raise ValueError("Task plan missing 'tasks' field")
        except Exception as e:
            logger.error(f"Error parsing task plan response: {str(e)}")
            taskPlanDict = {'tasks': []}

        # Structural validation (IDs, required fields, dependencies).
        if not self._validateTaskPlan(taskPlanDict):
            logger.error("Generated task plan failed validation")
            logger.error(f"AI Response: {prompt}")
            logger.error(f"Parsed Task Plan: {json.dumps(taskPlanDict, indent=2)}")
            raise Exception("AI-generated task plan failed validation - AI is required for task planning")

        if not taskPlanDict.get('tasks'):
            raise ValueError("Task plan contains no tasks")

        # LANGUAGE DETECTION: determine user language once per workflow.
        # Priority: 1. languageUserDetected from AI response,
        #           2. service.user.language, 3. "en".
        detectedLanguage = taskPlanDict.get('languageUserDetected', '').strip()
        serviceUserLanguage = getattr(self.services.user, 'language', '') if self.services and self.services.user else ''

        if detectedLanguage and len(detectedLanguage) == 2:  # two-letter code like "en", "de", "fr"
            userLanguage = detectedLanguage
            logger.info(f"Using detected language from AI response: {userLanguage}")
        elif serviceUserLanguage and len(serviceUserLanguage) == 2:
            userLanguage = serviceUserLanguage
            logger.info(f"Using language from service user object: {userLanguage}")
        else:
            userLanguage = "en"
            logger.info(f"Using default language: {userLanguage}")

        # Side effect: persist the detected language on the service object
        # so later workflow steps use it.
        if self.services and self.services.user:
            self.services.user.language = userLanguage
            logger.info(f"Set workflow user language to: {userLanguage}")

        # Build TaskStep objects; invalid entries are skipped, not fatal.
        tasks = []
        for i, taskDict in enumerate(taskPlanDict.get('tasks', [])):
            if not isinstance(taskDict, dict):
                logger.warning(f"Skipping invalid task {i+1}: not a dictionary")
                continue

            # Backward compatibility: map old 'description' field to the new
            # 'objective' field.
            if 'description' in taskDict and 'objective' not in taskDict:
                taskDict['objective'] = taskDict.pop('description')

            try:
                task = TaskStep(**taskDict)
                tasks.append(task)
            except Exception as e:
                # TaskStep construction rejects unexpected/missing fields.
                logger.warning(f"Skipping invalid task {i+1}: {str(e)}")
                continue

        if not tasks:
            raise ValueError("No valid tasks could be created from AI response")

        taskPlan = TaskPlan(
            overview=taskPlanDict.get('overview', ''),
            tasks=tasks,
            userMessage=taskPlanDict.get('userMessage', '')
        )

        logger.info(f"Task plan generated successfully with {len(tasks)} tasks")
        logger.info(f"Workflow user language set to: {userLanguage}")

        return taskPlan
    except Exception as e:
        # Log at this boundary, then propagate to the caller.
        logger.error(f"Error in generateTaskPlan: {str(e)}")
        raise
|
|
|
|
|
|
|
|
def _validateTaskPlan(self, taskPlan: Dict[str, Any]) -> bool:
    """Validate task plan structure.

    Checks that the plan is a dict with a 'tasks' list, that every task is a
    dict with a unique 'id' and the required fields, and that every
    dependency refers to an existing task id (or the special 'task_0').

    Args:
        taskPlan: parsed AI response.

    Returns:
        bool: True when the plan is structurally valid, False otherwise
        (reasons are logged; this method never raises).
    """
    try:
        if not isinstance(taskPlan, dict):
            logger.error("Task plan is not a dictionary")
            return False

        if 'tasks' not in taskPlan or not isinstance(taskPlan['tasks'], list):
            logger.error(f"Task plan missing 'tasks' field or not a list. Found: {type(taskPlan.get('tasks', 'MISSING'))}")
            return False

        # First pass: collect all task IDs so dependencies can be validated,
        # and reject per-task shape problems and duplicate IDs.
        # BUG FIX: the old duplicate check compared task *dicts* against an
        # id *string* (`list(taskPlan['tasks']).count(task['id'])`), which
        # was always 0, so duplicate IDs silently collapsed into the set and
        # were never detected.  Detect them while collecting instead.
        taskIds = set()
        for task in taskPlan['tasks']:
            if not isinstance(task, dict):
                logger.error(f"Task is not a dictionary: {type(task)}")
                return False
            if 'id' not in task:
                logger.error(f"Task missing 'id' field: {task}")
                return False
            if task['id'] in taskIds:
                logger.error(f"Task has duplicate ID: {task['id']}")
                return False
            taskIds.add(task['id'])

        # Second pass: validate required fields and dependency references.
        # (The first pass already guaranteed every task is a dict.)
        for i, task in enumerate(taskPlan['tasks']):
            requiredFields = ['id', 'objective', 'success_criteria']
            missingFields = [field for field in requiredFields if field not in task]
            if missingFields:
                logger.error(f"Task {i} missing required fields: {missingFields}")
                return False

            dependencies = task.get('dependencies', [])
            if not isinstance(dependencies, list):
                logger.error(f"Task {i} dependencies is not a list: {type(dependencies)}")
                return False

            for dep in dependencies:
                # 'task_0' is an always-allowed synthetic root dependency.
                if dep not in taskIds and dep != 'task_0':
                    logger.error(f"Task {i} has invalid dependency: {dep} (available: {list(taskIds) + ['task_0']})")
                    return False

        logger.info(f"Task plan validation successful with {len(taskIds)} tasks")
        return True

    except Exception as e:
        # Validation must never raise; any unexpected error means invalid.
        logger.error(f"Error validating task plan: {str(e)}")
        return False
|
|
|
|
def _writeTraceLog(self, contextText: str, data: Any) -> None:
    """Append a trace entry to the configured trace file when debug logging is on.

    Writes a timestamped, banner-delimited entry to ``log_trace.log`` in the
    configured log directory, pretty-printing ``data`` as JSON where possible.
    Best-effort: any failure is swallowed so tracing can never break the
    workflow.

    Args:
        contextText: short label for the entry (e.g. "Task Plan Prompt").
        data: payload to render; dict/list/str/None or any object.
    """
    try:
        import os
        import json
        from datetime import datetime, UTC

        # BUG FIX: the old guard `logger.level > logging.DEBUG` read only the
        # level set directly on this logger; with the common NOTSET (0)
        # default it evaluated False and traces were written even when debug
        # logging was disabled.  isEnabledFor() honours the effective
        # (inherited) level.
        if not logger.isEnabledFor(logging.DEBUG):
            return

        # Resolve the log directory from configuration; relative paths are
        # anchored at the gateway directory (four levels above this file).
        logDir = self.services.utils.configGet("APP_LOGGING_LOG_DIR", "./")
        if not os.path.isabs(logDir):
            gatewayDir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
            logDir = os.path.join(gatewayDir, logDir)

        os.makedirs(logDir, exist_ok=True)
        traceFile = os.path.join(logDir, "log_trace.log")

        # Millisecond-precision UTC timestamp from the shared clock source.
        timestamp = datetime.fromtimestamp(self.services.utils.getUtcTimestamp(), UTC).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]

        traceEntry = f"[{timestamp}] {contextText}\n"
        traceEntry += "=" * 80 + "\n"
        traceEntry += self._formatTraceData(data)
        traceEntry += "=" * 80 + "\n\n"

        with open(traceFile, "a", encoding="utf-8") as f:
            f.write(traceEntry)

    except Exception:
        # Deliberately silent: never log trace errors to avoid recursion
        # into the logging system.
        pass

def _formatTraceData(self, data: Any) -> str:
    """Render trace payload as text, preferring pretty-printed JSON.

    Consolidates the three duplicated formatting branches of the old
    implementation (dict/list, str, other) into one place; labels and
    fallbacks are unchanged.
    """
    if data is None:
        return "No data provided\n"
    try:
        dumpArgs = dict(indent=2, default=str, ensure_ascii=False, sort_keys=False)
        if isinstance(data, (dict, list)):
            return f"JSON Data:\n{json.dumps(data, **dumpArgs)}\n"
        if isinstance(data, str):
            # Try to parse string data as JSON first, then fall back to text.
            try:
                parsed = json.loads(data)
                return f"JSON Data (parsed from string):\n{json.dumps(parsed, **dumpArgs)}\n"
            except (json.JSONDecodeError, TypeError):
                # Turn escaped "\n" sequences into real line breaks.
                return f"Text Data:\n{data.replace(chr(92) + 'n', chr(10))}\n"
        # Other types: stringify, then try the same JSON-first approach.
        dataStr = str(data)
        try:
            parsed = json.loads(dataStr)
            return f"JSON Data (parsed from object):\n{json.dumps(parsed, **dumpArgs)}\n"
        except (json.JSONDecodeError, TypeError):
            return f"Object Data:\n{dataStr.replace(chr(92) + 'n', chr(10))}\n"
    except Exception:
        # Fallback to a simple string representation.
        return f"Data (fallback): {str(data)}\n"
|