gateway/modules/workflows/processing/core/taskPlanner.py

238 lines
No EOL
11 KiB
Python

# taskPlanner.py
# Task planning functionality for workflows
import json
import logging
from typing import Dict, Any
from modules.datamodels.datamodelChat import TaskStep, TaskContext, TaskPlan
from modules.datamodels.datamodelAi import AiCallOptions, OperationType, ProcessingMode, Priority
from modules.workflows.processing.shared.promptGenerationTaskplan import (
generateTaskPlanningPrompt
)
from modules.workflows.processing.adaptive import IntentAnalyzer
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class TaskPlanner:
    """Handles task planning for workflows.

    Builds a planning prompt, asks the AI service
    (``services.ai.callAiPlanning``) for a JSON task plan, validates the
    plan's structure and converts each task entry into a ``TaskStep``.
    """

    def __init__(self, services):
        """Store the service container.

        Args:
            services: Service container. This class reads ``interfaceDbChat``,
                ``ai``, ``currentUserPrompt`` and ``currentUserLanguage`` from it.
        """
        self.services = services

    def _checkWorkflowStopped(self, workflow):
        """Raise if the workflow has been stopped by the user.

        The status is read from the database first, so a stop issued after
        ``workflow`` was loaded is still noticed. The in-memory ``workflow``
        object is consulted only when the database lookup itself fails.

        Args:
            workflow: Workflow object; ``workflow.id`` and ``workflow.status``
                are read.

        Raises:
            Exception: "Workflow was stopped by user" if the status is "stopped".
        """
        try:
            # Get the current workflow status from the database to avoid stale data.
            currentWorkflow = self.services.interfaceDbChat.getWorkflow(workflow.id)
            isStopped = bool(currentWorkflow) and currentWorkflow.status == "stopped"
        except Exception as e:
            # Only lookup failures reach this handler. The stop signal is raised
            # outside the try, so it can no longer be swallowed here.
            logger.warning(f"Could not check current workflow status from database: {str(e)}")
            if workflow and workflow.status == "stopped":
                logger.info("Workflow stopped by user (from in-memory object), aborting task planning")
                raise Exception("Workflow was stopped by user")
            return
        if isStopped:
            logger.info("Workflow stopped by user, aborting task planning")
            raise Exception("Workflow was stopped by user")

    async def generateTaskPlan(self, userInput: str, workflow) -> TaskPlan:
        """Generate a high-level task plan for the workflow.

        Args:
            userInput: The user's request. Used only when the services container
                has no non-empty ``currentUserPrompt``.
            workflow: Workflow being planned; its ``id`` and status are read.

        Returns:
            A ``TaskPlan`` with the overview, the successfully constructed
            ``TaskStep`` objects and the AI-provided user message.

        Raises:
            Exception: If the workflow was stopped, or the AI plan fails
                structural validation.
            ValueError: If the AI returns nothing, the plan has no tasks, or
                no task can be turned into a ``TaskStep``.
            Any error from the intent analyzer or AI service is logged and re-raised.
        """
        try:
            # Check workflow status before generating task plan
            self._checkWorkflowStopped(workflow)
            logger.info("=== STARTING TASK PLAN GENERATION ===")
            logger.info(f"Workflow ID: {workflow.id}")
            logger.info(f"User Input: {userInput}")

            # Use stored user prompt if available (and non-empty), otherwise use the input
            storedPrompt = getattr(self.services, 'currentUserPrompt', None) if self.services else None
            actualUserPrompt = storedPrompt or userInput
            logger.info(f"Actual User Prompt: {actualUserPrompt}")

            # Check workflow status before calling AI service
            self._checkWorkflowStopped(workflow)

            # Analyze user intent to obtain cleaned user objective for planning
            intentAnalyzer = IntentAnalyzer(self.services)
            intent = await intentAnalyzer.analyzeUserIntent(actualUserPrompt, None)
            cleanedObjective = intent.get('primaryGoal', actualUserPrompt) if isinstance(intent, dict) else actualUserPrompt

            # TaskContext requires a TaskStep, so planning uses a minimal placeholder step
            planningTaskStep = TaskStep(
                id="planning",
                objective=cleanedObjective,
                dependencies=[],
                success_criteria=[],
                estimated_complexity="medium"
            )
            taskPlanningContext = TaskContext(
                task_step=planningTaskStep,
                workflow=workflow,
                workflow_id=workflow.id,
                available_documents=None,
                available_connections=None,
                previous_results=[],
                previous_handover=None,
                improvements=[],
                retry_count=0,
                previous_action_results=[],
                previous_review_result=None,
                is_regeneration=False,
                failure_patterns=[],
                failed_actions=[],
                successful_actions=[],
                criteria_progress={
                    'met_criteria': set(),
                    'unmet_criteria': set(),
                    'attempt_history': []
                }
            )

            # Build prompt bundle (template + placeholders)
            bundle = generateTaskPlanningPrompt(self.services, taskPlanningContext)
            taskPlanningPromptTemplate = bundle.prompt
            placeholders = bundle.placeholders

            # Centralized AI call: task planning (quality, detailed) with placeholders
            options = AiCallOptions(
                operationType=OperationType.GENERATE_PLAN,
                priority=Priority.QUALITY,
                compressPrompt=False,
                compressContext=False,
                processingMode=ProcessingMode.DETAILED,
                maxCost=0.10,
                maxProcessingTime=30
            )
            prompt = await self.services.ai.callAiPlanning(
                prompt=taskPlanningPromptTemplate,
                placeholders=placeholders,
                options=options,
                loopInstructionFormat="json"
            )
            # Check if AI response is valid
            if not prompt:
                raise ValueError("AI service returned no response for task planning")

            taskPlanDict = self._parseTaskPlanResponse(prompt)
            # Validation requires 'objective', so legacy 'description' keys
            # must be renamed before validating, not after.
            self._normalizeLegacyTaskFields(taskPlanDict)

            if not self._validateTaskPlan(taskPlanDict):
                logger.error("Generated task plan failed validation")
                logger.error(f"AI Response: {prompt}")
                logger.error(f"Parsed Task Plan: {json.dumps(taskPlanDict, indent=2)}")
                raise Exception("AI-generated task plan failed validation - AI is required for task planning")
            if not taskPlanDict.get('tasks'):
                raise ValueError("Task plan contains no tasks")

            # Use already detected language from services; do not detect here
            userLanguage = self.services.currentUserLanguage or 'en'
            logger.info(f"Task planning using user language: {userLanguage}")

            tasks = []
            for i, taskDict in enumerate(taskPlanDict.get('tasks', [])):
                if not isinstance(taskDict, dict):
                    logger.warning(f"Skipping invalid task {i+1}: not a dictionary")
                    continue
                try:
                    tasks.append(TaskStep(**taskDict))
                except Exception as e:
                    # One malformed task should not discard the whole plan
                    logger.warning(f"Skipping invalid task {i+1}: {str(e)}")
            if not tasks:
                raise ValueError("No valid tasks could be created from AI response")

            taskPlan = TaskPlan(
                overview=taskPlanDict.get('overview', ''),
                tasks=tasks,
                userMessage=taskPlanDict.get('userMessage', '')
            )
            logger.info(f"Task plan generated successfully with {len(tasks)} tasks")
            return taskPlan
        except Exception as e:
            logger.error(f"Error in generateTaskPlan: {str(e)}")
            raise

    @staticmethod
    def _parseTaskPlanResponse(response) -> Dict[str, Any]:
        """Extract the JSON object embedded in an AI response.

        Takes the text from the first '{' to the last '}' and parses it.
        Failures are logged, not raised. In that case ``{'tasks': []}`` is
        returned and the caller rejects the empty plan. Failures include:
        no JSON present, invalid JSON, a missing 'tasks' key, or a non-string
        response.
        """
        try:
            jsonStart = response.find('{')
            jsonEnd = response.rfind('}') + 1
            if jsonStart == -1 or jsonEnd == 0:
                raise ValueError("No JSON found in response")
            taskPlanDict = json.loads(response[jsonStart:jsonEnd])
            if 'tasks' not in taskPlanDict:
                raise ValueError("Task plan missing 'tasks' field")
            return taskPlanDict
        except Exception as e:
            logger.error(f"Error parsing task plan response: {str(e)}")
            return {'tasks': []}

    @staticmethod
    def _normalizeLegacyTaskFields(taskPlanDict: Dict[str, Any]) -> None:
        """Rename each task's legacy 'description' key to 'objective', in place.

        The rename applies only when the task has no 'objective' already.
        Non-dict tasks and a non-list 'tasks' value are left untouched for
        validation to reject.
        """
        tasks = taskPlanDict.get('tasks') if isinstance(taskPlanDict, dict) else None
        if not isinstance(tasks, list):
            return
        for taskDict in tasks:
            if isinstance(taskDict, dict) and 'description' in taskDict and 'objective' not in taskDict:
                taskDict['objective'] = taskDict.pop('description')

    def _validateTaskPlan(self, taskPlan: Dict[str, Any]) -> bool:
        """Validate task plan structure.

        Checks performed:
        - the plan is a dict whose 'tasks' value is a list of dicts;
        - every task has a unique 'id';
        - every task has 'objective' and 'success_criteria';
        - 'dependencies', if present, is a list of known task IDs (or 'task_0').

        Returns:
            True if valid. False otherwise, including when an unexpected error
            occurs (e.g. an unhashable id); the reason is logged.
        """
        try:
            if not isinstance(taskPlan, dict):
                logger.error("Task plan is not a dictionary")
                return False
            if 'tasks' not in taskPlan or not isinstance(taskPlan['tasks'], list):
                logger.error(f"Task plan missing 'tasks' field or not a list. Found: {type(taskPlan.get('tasks', 'MISSING'))}")
                return False

            # First pass: collect all task IDs (dependencies may reference
            # later tasks) and reject duplicates, which a set would silently merge.
            taskIds = set()
            for i, task in enumerate(taskPlan['tasks']):
                if not isinstance(task, dict):
                    logger.error(f"Task is not a dictionary: {type(task)}")
                    return False
                if 'id' not in task:
                    logger.error(f"Task missing 'id' field: {task}")
                    return False
                if task['id'] in taskIds:
                    logger.error(f"Task {i} has duplicate ID: {task['id']}")
                    return False
                taskIds.add(task['id'])

            # Second pass: validate each task's fields and dependencies
            requiredFields = ['id', 'objective', 'success_criteria']
            for i, task in enumerate(taskPlan['tasks']):
                missingFields = [field for field in requiredFields if field not in task]
                if missingFields:
                    logger.error(f"Task {i} missing required fields: {missingFields}")
                    return False
                dependencies = task.get('dependencies', [])
                if not isinstance(dependencies, list):
                    logger.error(f"Task {i} dependencies is not a list: {type(dependencies)}")
                    return False
                for dep in dependencies:
                    # NOTE(review): 'task_0' is always accepted, presumably as an
                    # implicit root task - confirm with the prompt template.
                    if dep not in taskIds and dep != 'task_0':
                        logger.error(f"Task {i} has invalid dependency: {dep} (available: {list(taskIds) + ['task_0']})")
                        return False

            logger.info(f"Task plan validation successful with {len(taskIds)} tasks")
            return True
        except Exception as e:
            logger.error(f"Error validating task plan: {str(e)}")
            return False