From 125c2bbce7bc88fd8b286f5ba6eeb0287fc3e6ae Mon Sep 17 00:00:00 2001 From: Ida Dittrich Date: Mon, 9 Mar 2026 10:56:24 +0100 Subject: [PATCH] stop automation bug fix --- .../workflows/processing/shared/stateTools.py | 4 +- modules/workflows/workflowManager.py | 51 +++++++++++++++++-- 2 files changed, 49 insertions(+), 6 deletions(-) diff --git a/modules/workflows/processing/shared/stateTools.py b/modules/workflows/processing/shared/stateTools.py index 485539f9..70259b3c 100644 --- a/modules/workflows/processing/shared/stateTools.py +++ b/modules/workflows/processing/shared/stateTools.py @@ -26,8 +26,8 @@ def checkWorkflowStopped(services: Any) -> None: Raises: WorkflowStoppedException: If workflow status is "stopped" """ - workflow = services.workflow - if not workflow: + workflow = getattr(services, 'workflow', None) + if not workflow or not hasattr(workflow, 'id') or workflow.id is None: return try: diff --git a/modules/workflows/workflowManager.py b/modules/workflows/workflowManager.py index c81977c1..b9b64a9a 100644 --- a/modules/workflows/workflowManager.py +++ b/modules/workflows/workflowManager.py @@ -19,6 +19,14 @@ from modules.workflows.processing.shared.stateTools import WorkflowStoppedExcept logger = logging.getLogger(__name__) +# Registry of running workflow tasks: workflowId -> asyncio.Task +# Used to cancel workflow immediately when stop is requested +_workflow_tasks: Dict[str, asyncio.Task] = {} + +def _unregister_workflow_task(workflow_id: str) -> None: + """Remove workflow task from registry when it completes.""" + _workflow_tasks.pop(workflow_id, None) + class WorkflowManager: """Manager for workflow processing and coordination""" @@ -26,6 +34,19 @@ class WorkflowManager: self.services = services self.workflowProcessor = None + def _propagateWorkflowToContext(self, workflow): + """Update workflow in all service contexts. Resolved services may be cached and hold + a different context than hub._service_context; update each service's _context.workflow.""" + # Update stored context if present + ctx = getattr(self.services, '_service_context', None) + if ctx is not None: + ctx.workflow = workflow + # Also update contexts on resolved services (they may be cached with different context refs) + for attr in ('chat', 'ai', 'extraction', 'sharepoint', 'utils', 'billing', 'generation'): + svc = getattr(self.services, attr, None) + if svc is not None and hasattr(svc, '_context') and svc._context is not None: + svc._context.workflow = workflow + # Exported functions async def workflowStart(self, userInput: UserInputRequest, workflowMode: WorkflowModeEnum, workflowId: Optional[str] = None) -> ChatWorkflow: @@ -42,9 +63,14 @@ class WorkflowManager: # Store workflow in services for reference (this is the ChatWorkflow object) self.services.workflow = workflow + self._propagateWorkflowToContext(workflow) if workflow.status == "running": logger.info(f"Stopping running workflow {workflowId} before processing new prompt") + # Cancel existing task immediately so we don't have two tasks for same workflow + existing = _workflow_tasks.pop(workflowId, None) + if existing and not existing.done(): + existing.cancel() workflow.status = "stopped" workflow.lastActivity = currentTime self.services.chat.updateWorkflow(workflowId, { @@ -104,6 +130,7 @@ class WorkflowManager: # Store workflow in services (this is the ChatWorkflow object) self.services.workflow = workflow + self._propagateWorkflowToContext(workflow) # CRITICAL: Update all method instances to use the current Services object with the correct workflow # This ensures cached method instances don't use stale workflow IDs from previous workflows @@ -111,8 +138,11 @@ class WorkflowManager: discoverMethods(self.services) logger.debug(f"Updated method instances to use workflow {self.services.workflow.id}") - # Start workflow processing asynchronously - asyncio.create_task(self._workflowProcess(userInput)) + # Start workflow processing asynchronously; register for immediate cancel on stop + task = asyncio.create_task(self._workflowProcess(userInput)) + wid = workflow.id + _workflow_tasks[wid] = task + task.add_done_callback(lambda _: _unregister_workflow_task(wid)) return workflow except Exception as e: @@ -128,6 +158,7 @@ class WorkflowManager: # Store workflow in services (this is the ChatWorkflow object) self.services.workflow = workflow + self._propagateWorkflowToContext(workflow) workflow.status = "stopped" workflow.lastActivity = self.services.utils.timestampGetUtc() @@ -141,6 +172,10 @@ class WorkflowManager: "status": "stopped", "progress": 1.0 }) + # Cancel the running task immediately so workflow stops without waiting for checkpoints + running_task = _workflow_tasks.pop(workflowId, None) + if running_task and not running_task.done(): + running_task.cancel() return workflow except Exception as e: logger.error(f"Error stopping workflow: {str(e)}") @@ -274,6 +309,10 @@ class WorkflowManager: await self._executeTasks(taskPlan) await self._processWorkflowResults() + except asyncio.CancelledError: + # Task was cancelled (user clicked stop) - ensure stopped message is created, then re-raise + self._handleWorkflowStop() + raise except WorkflowStoppedException: self._handleWorkflowStop() @@ -1317,8 +1356,12 @@ The following is the user's original input message. Analyze intent, normalize th async def _neutralizeContentIfEnabled(self, contentBytes: bytes, mimeType: str) -> bytes: """Neutralize content if neutralization is enabled in user settings""" try: + # Automation hub may not have neutralization service; skip if unavailable + neutralization = getattr(self.services, 'neutralization', None) + if not neutralization: + return contentBytes # Check if neutralization is enabled - config = self.services.neutralization.getConfig() + config = neutralization.getConfig() if not config or not config.enabled: return contentBytes @@ -1340,7 +1383,7 @@ The following is the user's original input message. Analyze intent, normalize th # Neutralize the text content # Note: The neutralization service should use names from config when processing - result = self.services.neutralization.processText(textContent) + result = neutralization.processText(textContent) if result and 'neutralized_text' in result: neutralizedText = result['neutralized_text'] # Encode back to bytes using the same encoding