241 lines
11 KiB
Python
241 lines
11 KiB
Python
# taskPlanner.py
|
|
# Task planning functionality for workflows
|
|
|
|
import json
|
|
import logging
|
|
from typing import Dict, Any
|
|
from modules.datamodels.datamodelChat import TaskStep, TaskContext, TaskPlan
|
|
from modules.datamodels.datamodelAi import AiCallOptions, OperationType, ProcessingMode, Priority
|
|
from modules.workflows.processing.shared.promptGenerationTaskplan import (
|
|
generateTaskPlanningPrompt
|
|
)
|
|
from modules.workflows.processing.adaptive import IntentAnalyzer
|
|
|
|
# Module-level logger named after this module; used by TaskPlanner below.
logger = logging.getLogger(__name__)
|
|
|
|
class TaskPlanner:
    """Handles task planning for workflows.

    Asks the AI service for a JSON task plan based on the user's
    (intent-cleaned) request. The plan is validated and converted into a
    ``TaskPlan`` made of ``TaskStep`` objects.
    """

    def __init__(self, services):
        """Store the service container.

        Args:
            services: Service container. This class reads ``interfaceDbChat``,
                ``ai``, ``currentUserPrompt`` and ``currentUserLanguage`` from it
                and passes it to ``IntentAnalyzer`` and
                ``generateTaskPlanningPrompt``.
        """
        self.services = services

    def _checkWorkflowStopped(self, workflow):
        """Raise if the workflow has been stopped by the user.

        The status is read from the database first to avoid acting on a stale
        in-memory object. Only if that lookup fails is the in-memory
        ``workflow.status`` consulted instead.

        Args:
            workflow: Workflow object (``id`` and ``status`` are read).

        Raises:
            Exception: "Workflow was stopped by user" when the workflow is stopped.
        """
        try:
            # Get the current workflow status from the database to avoid stale data
            currentWorkflow = self.services.interfaceDbChat.getWorkflow(workflow.id)
            stoppedInDb = bool(currentWorkflow) and currentWorkflow.status == "stopped"
        except Exception as e:
            # If we can't get the current status due to database issues, fall back to the in-memory object
            logger.warning(f"Could not check current workflow status from database: {str(e)}")
            if workflow and workflow.status == "stopped":
                logger.info("Workflow stopped by user (from in-memory object), aborting task planning")
                raise Exception("Workflow was stopped by user")
            return

        # Raised outside the try on purpose: raising inside it would be caught by
        # the database-error fallback above, and the stop could be ignored.
        if stoppedInDb:
            logger.info("Workflow stopped by user, aborting task planning")
            raise Exception("Workflow was stopped by user")

    async def generateTaskPlan(self, userInput: str, workflow) -> TaskPlan:
        """Generate a high-level task plan for the workflow.

        Args:
            userInput: Raw user input. Used only when ``services.currentUserPrompt``
                is not set or is empty.
            workflow: Workflow being planned (``id`` and ``status`` are read).

        Returns:
            A ``TaskPlan`` built from the AI response's ``tasks``, ``overview``
            and ``userMessage`` fields.

        Raises:
            Exception: If the workflow was stopped or the AI plan fails validation.
            ValueError: If the AI returns nothing, the plan has no tasks, or no
                task could be turned into a ``TaskStep``.
        """
        try:
            # Check workflow status before generating task plan
            self._checkWorkflowStopped(workflow)

            logger.info("=== STARTING TASK PLAN GENERATION ===")
            logger.info(f"Workflow ID: {workflow.id}")
            logger.info(f"User Input: {userInput}")

            # Use stored user prompt if available, otherwise use the input
            storedPrompt = getattr(self.services, 'currentUserPrompt', None) if self.services else None
            actualUserPrompt = storedPrompt or userInput
            logger.info(f"Actual User Prompt: {actualUserPrompt}")

            # Check workflow status before calling AI service
            self._checkWorkflowStopped(workflow)

            # Analyze user intent to obtain cleaned user objective for planning
            intentAnalyzer = IntentAnalyzer(self.services)
            intent = await intentAnalyzer.analyzeUserIntent(actualUserPrompt, None)
            cleanedObjective = intent.get('primaryGoal', actualUserPrompt) if isinstance(intent, dict) else actualUserPrompt

            taskPlanningContext = self._buildPlanningContext(workflow, cleanedObjective)

            # Build prompt bundle (template + placeholders)
            bundle = generateTaskPlanningPrompt(self.services, taskPlanningContext)

            # Centralized AI call: Task planning (quality, detailed) with placeholders
            options = AiCallOptions(
                operationType=OperationType.GENERATE_PLAN,
                priority=Priority.QUALITY,
                compressPrompt=False,
                compressContext=False,
                processingMode=ProcessingMode.DETAILED,
                maxCost=0.10,
                maxProcessingTime=30,
            )
            response = await self.services.ai.callAiPlanning(
                prompt=bundle.prompt,
                placeholders=bundle.placeholders,
                options=options,
                loopInstructionFormat="json",
            )

            if not response:
                raise ValueError("AI service returned no response for task planning")

            taskPlanDict = self._parseTaskPlanResponse(response)

            # Map the legacy 'description' field to 'objective' BEFORE validation.
            # Validation requires 'objective', so doing this afterwards could never apply.
            rawTasks = taskPlanDict.get('tasks')
            if isinstance(rawTasks, list):
                for taskDict in rawTasks:
                    if isinstance(taskDict, dict) and 'description' in taskDict and 'objective' not in taskDict:
                        taskDict['objective'] = taskDict.pop('description')

            if not self._validateTaskPlan(taskPlanDict):
                logger.error("Generated task plan failed validation")
                logger.error(f"AI Response: {response}")
                logger.error(f"Parsed Task Plan: {json.dumps(taskPlanDict, indent=2)}")
                raise Exception("AI-generated task plan failed validation - AI is required for task planning")

            if not taskPlanDict.get('tasks'):
                raise ValueError("Task plan contains no tasks")

            # Use already detected language from services; do not detect here
            userLanguage = getattr(self.services, 'currentUserLanguage', None) or 'en'
            logger.info(f"Task planning using user language: {userLanguage}")

            tasks = []
            for i, taskDict in enumerate(taskPlanDict['tasks']):
                if not isinstance(taskDict, dict):
                    logger.warning(f"Skipping invalid task {i+1}: not a dictionary")
                    continue
                try:
                    tasks.append(TaskStep(**taskDict))
                except Exception as e:
                    # One malformed task should not discard the whole plan
                    logger.warning(f"Skipping invalid task {i+1}: {str(e)}")

            if not tasks:
                raise ValueError("No valid tasks could be created from AI response")

            taskPlan = TaskPlan(
                overview=taskPlanDict.get('overview', ''),
                tasks=tasks,
                userMessage=taskPlanDict.get('userMessage', ''),
            )

            logger.info(f"Task plan generated successfully with {len(tasks)} tasks")
            return taskPlan

        except Exception as e:
            logger.error(f"Error in generateTaskPlan: {str(e)}")
            raise

    def _buildPlanningContext(self, workflow, objective):
        """Build the minimal ``TaskContext`` used to render the planning prompt.

        ``TaskContext`` requires a ``TaskStep``, so a placeholder step with id
        ``"planning"`` that carries the cleaned objective is created.
        """
        planningTaskStep = TaskStep(
            id="planning",
            objective=objective,
            dependencies=[],
            success_criteria=[],
            estimated_complexity="medium",
        )
        return TaskContext(
            task_step=planningTaskStep,
            workflow=workflow,
            workflow_id=workflow.id,
            available_documents=None,
            available_connections=None,
            previous_results=[],
            previous_handover=None,
            improvements=[],
            retry_count=0,
            previous_action_results=[],
            previous_review_result=None,
            is_regeneration=False,
            failure_patterns=[],
            failed_actions=[],
            successful_actions=[],
            criteria_progress={
                'met_criteria': set(),
                'unmet_criteria': set(),
                'attempt_history': [],
            },
        )

    def _parseTaskPlanResponse(self, response) -> Dict[str, Any]:
        """Extract and parse the JSON task plan embedded in an AI response.

        The JSON is taken as the span from the first ``{`` to the last ``}``.
        Any parsing problem is logged and ``{'tasks': []}`` is returned.
        ``generateTaskPlan`` then rejects that result as a plan with no tasks.
        """
        try:
            jsonStart = response.find('{')
            jsonEnd = response.rfind('}') + 1
            if jsonStart == -1 or jsonEnd == 0:
                raise ValueError("No JSON found in response")
            taskPlanDict = json.loads(response[jsonStart:jsonEnd])
            if 'tasks' not in taskPlanDict:
                raise ValueError("Task plan missing 'tasks' field")
            return taskPlanDict
        except Exception as e:
            logger.error(f"Error parsing task plan response: {str(e)}")
            return {'tasks': []}

    def _validateTaskPlan(self, taskPlan: Dict[str, Any]) -> bool:
        """Validate task plan structure.

        Checks:
        - ``tasks`` is a list of dicts.
        - Every task has ``id``, ``objective`` and ``success_criteria``.
        - Task IDs are unique.
        - ``dependencies``, if present, is a list whose entries are known task
          IDs or the special ID ``'task_0'``.

        Returns:
            True if valid. False otherwise; the reason is logged, and any
            unexpected error is also reported as False.
        """
        try:
            if not isinstance(taskPlan, dict):
                logger.error("Task plan is not a dictionary")
                return False

            if 'tasks' not in taskPlan or not isinstance(taskPlan['tasks'], list):
                logger.error(f"Task plan missing 'tasks' field or not a list. Found: {type(taskPlan.get('tasks', 'MISSING'))}")
                return False

            tasks = taskPlan['tasks']

            # First pass: collect all task IDs (dependencies may reference later
            # tasks) and reject duplicates while collecting.
            taskIds = set()
            for i, task in enumerate(tasks):
                if not isinstance(task, dict):
                    logger.error(f"Task is not a dictionary: {type(task)}")
                    return False
                if 'id' not in task:
                    logger.error(f"Task missing 'id' field: {task}")
                    return False
                if task['id'] in taskIds:
                    logger.error(f"Task {i} has duplicate ID: {task['id']}")
                    return False
                taskIds.add(task['id'])

            # Second pass: validate each task (all entries are dicts at this point)
            requiredFields = ['id', 'objective', 'success_criteria']
            for i, task in enumerate(tasks):
                missingFields = [field for field in requiredFields if field not in task]
                if missingFields:
                    logger.error(f"Task {i} missing required fields: {missingFields}")
                    return False

                dependencies = task.get('dependencies', [])
                if not isinstance(dependencies, list):
                    logger.error(f"Task {i} dependencies is not a list: {type(dependencies)}")
                    return False

                for dep in dependencies:
                    if dep not in taskIds and dep != 'task_0':
                        logger.error(f"Task {i} has invalid dependency: {dep} (available: {list(taskIds) + ['task_0']})")
                        return False

            logger.info(f"Task plan validation successful with {len(taskIds)} tasks")
            return True

        except Exception as e:
            logger.error(f"Error validating task plan: {str(e)}")
            return False

    def _writeTraceLog(self, contextText: str, data: Any) -> None:
        """Disabled extra trace file outputs (per chat debug simplification)."""
        return
|