239 lines
No EOL
8.8 KiB
Python
239 lines
No EOL
8.8 KiB
Python
from typing import Dict, Any, Optional, List
|
|
import logging
|
|
import json
|
|
import asyncio
|
|
from datetime import datetime, UTC
|
|
import uuid
|
|
|
|
from modules.workflow.managerChat import ChatManager
|
|
from modules.workflow.managerDocument import DocumentManager
|
|
from modules.interfaces.serviceChatModel import AgentTask, TaskStatus, ActionStatus
|
|
from modules.shared.configuration import APP_CONFIG
|
|
|
|
# Module-scoped logger, named after this module, shared by WorkflowManager.
logger = logging.getLogger(name=__name__)
|
|
|
|
class WorkflowManager:
    """Coordinate a workflow's tasks through an asyncio queue.

    A single background coroutine (started by :meth:`initialize`) pulls tasks
    from ``task_queue`` and hands each one to ``ChatManager.process_task``. It
    records the outcome on the task and in ``task_history``, and may enqueue a
    follow-up task. :meth:`process_workflow` enqueues the initial task for a
    user input and polls it until it succeeds or fails.
    """

    # Case-insensitive substrings that mark an error message as transient.
    _RETRYABLE_ERROR_KEYWORDS = (
        "timeout",
        "rate limit",
        "temporary",
        "connection",
        "server error",
    )

    def __init__(self):
        self.chat_manager = ChatManager()
        self.document_manager = DocumentManager()
        self.workflow = None
        self.context = {}
        self.task_queue = asyncio.Queue()
        # NOTE(review): never populated in this class; only cleared in cleanup().
        self.active_tasks = {}
        # One dict per processed task: id, status, startedAt, finishedAt, error.
        self.task_history = []
        # Strong reference to the queue-processor task. The event loop keeps
        # only weak references to tasks, so an unreferenced task can be
        # garbage-collected mid-execution.
        self._processor_task: Optional[asyncio.Task] = None

    async def initialize(self, workflow: Any, context: Dict[str, Any]) -> None:
        """Bind workflow and context, initialize sub-managers, start processing.

        Args:
            workflow: Workflow object. It is stored on ``self`` and passed to
                ``ChatManager.initialize``.
            context: Context mapping passed to both sub-managers.

        If a queue processor started by an earlier call is still running,
        no second processor is started.
        """
        self.workflow = workflow
        self.context = context

        # Initialize managers
        await self.chat_manager.initialize(workflow, context)
        await self.document_manager.initialize(context)

        # Start the task processor and keep a reference to it.
        if self._processor_task is None or self._processor_task.done():
            self._processor_task = asyncio.create_task(self._process_task_queue())

    async def process_workflow(self, user_input: Dict[str, Any]) -> Dict[str, Any]:
        """Run the workflow for one user input and return its outcome.

        Creates the initial task via ``ChatManager.create_initial_task`` and
        enqueues it. It then polls every 0.1 s until the task reports
        completion or failure. This relies on the queue processor started by
        :meth:`initialize`; without it, the loop never ends.

        Args:
            user_input: Payload forwarded to ``create_initial_task``.

        Returns:
            ``{"status": "success", "result": ..., "documents": ...}`` when the
            task ends with ``TaskStatus.SUCCESS``. Otherwise
            ``{"status": "error", "error": ..., "feedback": ...}``. If an
            exception escapes, ``{"status": "error", "error": str(exc)}``.
        """
        try:
            task = await self.chat_manager.create_initial_task(user_input)
            await self.task_queue.put(task)

            # Completion is signalled only by the processor mutating the
            # task's status, so poll it.
            while not task.is_complete() and not task.has_failed():
                await asyncio.sleep(0.1)

            if task.status == TaskStatus.SUCCESS:
                return {
                    "status": "success",
                    "result": task.result,
                    "documents": task.documentsOutput
                }
            return {
                "status": "error",
                "error": task.error,
                "feedback": task.thisTaskFeedback
            }

        except Exception as e:
            logger.exception(f"Error processing workflow: {str(e)}")
            return {
                "status": "error",
                "error": str(e)
            }

    async def _process_task_queue(self) -> None:
        """Consume ``task_queue`` forever; this is the background processor.

        Each dequeued task is handled by :meth:`_run_task`. If that raises,
        the task is marked ``TaskStatus.FAILED`` (unless it already reached
        SUCCESS or FAILED). This releases a caller polling it in
        :meth:`process_workflow`. ``task_done()`` is always called, so
        ``task_queue.join()`` cannot hang. Cancellation is not caught and
        ends the loop.
        """
        while True:
            task = await self.task_queue.get()
            try:
                await self._run_task(task)
            except Exception as e:
                logger.exception(f"Error processing task queue: {str(e)}")
                # Without a terminal status the waiter in process_workflow
                # would poll this task forever.
                if task.status not in (TaskStatus.SUCCESS, TaskStatus.FAILED):
                    task.status = TaskStatus.FAILED
                    task.error = str(e)
                await asyncio.sleep(1)  # Prevent tight loop on error
            finally:
                self.task_queue.task_done()

    async def _run_task(self, task: AgentTask) -> None:
        """Process one task, record its outcome, and enqueue any follow-up.

        Sets ``status``/``result``/``documentsOutput`` on success, or
        ``status``/``error`` on failure, from the dict returned by
        ``ChatManager.process_task``. Appends a history entry. If the task
        is not complete afterwards, asks :meth:`_define_next_task` for a
        follow-up and enqueues it.
        """
        result = await self.chat_manager.process_task(task)

        if result["status"] == "success":
            task.status = TaskStatus.SUCCESS
            task.result = result.get("result")
            task.documentsOutput = result.get("documents", [])
        else:
            task.status = TaskStatus.FAILED
            task.error = result.get("error")

        self.task_history.append({
            "id": task.id,
            "status": task.status,
            "startedAt": task.startedAt,
            "finishedAt": datetime.now(UTC).isoformat(),
            "error": task.error
        })

        if not task.is_complete():
            next_task = await self._define_next_task(task)
            if next_task:
                await self.task_queue.put(next_task)

    async def _define_next_task(self, current_task: AgentTask) -> Optional[AgentTask]:
        """Build a follow-up task from the analysis of *current_task*.

        Calls ``ChatManager._analyze_task_results``. When the analysis has
        ``"isComplete"`` falsy (a missing key counts as complete), it creates
        a next task from ``"nextActions"`` and ``"requiredDocuments"``. The
        new task's ``dependencies`` are set to ``[current_task.id]``.

        Returns:
            The new task, or ``None`` if no follow-up is needed or an error
            occurred (errors are logged).
        """
        try:
            analysis = await self.chat_manager._analyze_task_results(current_task)

            if not analysis.get("isComplete", True):
                next_task = await self.chat_manager.create_next_task(
                    current_task,
                    analysis.get("nextActions", []),
                    analysis.get("requiredDocuments", [])
                )
                next_task.dependencies = [current_task.id]
                return next_task

            return None

        except Exception as e:
            logger.exception(f"Error defining next task: {str(e)}")
            return None

    async def handle_error(self, task: AgentTask, error: str) -> None:
        """Mark *task* failed and either retry it or finalize the failure.

        Retry path: if *error* is retryable (see :meth:`_is_retryable_error`)
        and ``task.retryCount < task.retryMax``, the count is incremented,
        the status is set to ``TaskStatus.RETRY`` and the task is re-enqueued.

        Otherwise the task's actions are rolled back when
        ``task.rollback_on_failure`` is set. The bound workflow (if any) then
        gets ``status = "error"`` and ``error = error``. Exceptions raised
        while handling are logged, not propagated.
        """
        try:
            logger.error(f"Task {task.id} failed: {error}")

            task.status = TaskStatus.FAILED
            task.error = error

            if self._is_retryable_error(error) and task.retryCount < task.retryMax:
                task.retryCount += 1
                task.status = TaskStatus.RETRY
                await self.task_queue.put(task)
                return

            if task.rollback_on_failure:
                await self._rollback_task(task)

            # No workflow is bound before initialize(); nothing to notify.
            if self.workflow is not None:
                self.workflow.status = "error"
                self.workflow.error = error

        except Exception as e:
            logger.exception(f"Error handling task error: {str(e)}")

    async def _rollback_task(self, task: AgentTask) -> None:
        """Ask each successful action's method to roll that action back.

        Actions are visited in ``task.actionList`` order. Only actions with
        ``ActionStatus.SUCCESS`` whose method is found in
        ``chat_manager.service.methods`` are rolled back. A failure on one
        action is logged and does not stop rollback of the remaining actions.
        """
        try:
            for action in task.actionList:
                if action.status != ActionStatus.SUCCESS:
                    continue
                try:
                    method = self.chat_manager.service.methods.get(action.method)
                    if method:
                        await method.rollback(
                            action.action,
                            action.parameters,
                            task.get_auth_data(action.auth_source)
                        )
                except Exception as e:
                    logger.exception(f"Error rolling back task: {str(e)}")
        except Exception as e:
            # e.g. task.actionList missing or not iterable
            logger.exception(f"Error rolling back task: {str(e)}")

    def _is_retryable_error(self, error: Optional[str]) -> bool:
        """Return True if *error* looks like a transient failure.

        This is a case-insensitive substring match against
        ``_RETRYABLE_ERROR_KEYWORDS``. A ``None`` or empty error is not
        retryable. ``task.error`` may be ``None``, because it comes from
        ``result.get("error")``.
        """
        if not error:
            return False
        message = str(error).lower()
        return any(keyword in message for keyword in self._RETRYABLE_ERROR_KEYWORDS)

    async def cleanup(self) -> None:
        """Stop the processor, clean up sub-managers and drop queued work.

        Cleanup steps:

        - Cancel the background queue processor, if one is running.
        - Call ``cleanup()`` on ChatManager and DocumentManager. Each call is
          attempted even if the other fails; failures are logged.
        - Discard everything still in ``task_queue``, calling ``task_done()``
          for each item.
        - Clear ``active_tasks``.
        """
        processor, self._processor_task = self._processor_task, None
        if processor is not None and not processor.done():
            processor.cancel()

        for manager in (self.chat_manager, self.document_manager):
            try:
                await manager.cleanup()
            except Exception as e:
                logger.exception(f"Error during cleanup: {str(e)}")

        while not self.task_queue.empty():
            self.task_queue.get_nowait()
            self.task_queue.task_done()

        self.active_tasks.clear()

    async def get_workflow_status(self, workflow_id: str) -> Dict[str, Any]:
        """Return a snapshot of the workflow's current/previous task and status.

        Reads the ``'current'`` and ``'previous'`` entries of
        ``chat_manager.service.tasks`` and serializes them with ``.dict()``.
        ``status`` is taken from ``chat_manager.workflow`` (``None`` if unset).
        *workflow_id* is only echoed back; it is not used for lookup.
        """
        current_task = self.chat_manager.service.tasks.get('current')
        previous_task = self.chat_manager.service.tasks.get('previous')

        return {
            'workflowId': workflow_id,
            'currentTask': current_task.dict() if current_task else None,
            'previousTask': previous_task.dict() if previous_task else None,
            'status': self.chat_manager.workflow.status if self.chat_manager.workflow else None
        }

    async def stop_workflow(self, workflow_id: str) -> None:
        """Mark the workflow with id *workflow_id* and its current task STOPPED.

        This does nothing if ``chat_manager.workflow`` is unset or has a
        different id. It only updates ``status``/``updatedAt``; queued tasks
        and the background processor are left untouched.
        """
        if self.chat_manager.workflow and self.chat_manager.workflow.id == workflow_id:
            # NOTE(review): handle_error sets workflow.status to the string
            # "error", while this sets a TaskStatus member; confirm which the
            # workflow model expects.
            self.chat_manager.workflow.status = TaskStatus.STOPPED
            self.chat_manager.workflow.updatedAt = datetime.now(UTC)

            current_task = self.chat_manager.service.tasks.get('current')
            if current_task:
                current_task.status = TaskStatus.STOPPED
                current_task.updatedAt = datetime.now(UTC)