115 lines
3.5 KiB
Python
115 lines
3.5 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# Main execution engine for automation2 graphs.
|
|
|
|
import logging
|
|
from typing import Dict, Any, List, Set
|
|
|
|
from modules.workflows.automation2.graphUtils import (
|
|
parseGraph,
|
|
buildConnectionMap,
|
|
validateGraph,
|
|
topoSort,
|
|
getInputSources,
|
|
)
|
|
|
|
from modules.workflows.automation2.executors import (
|
|
TriggerExecutor,
|
|
FlowExecutor,
|
|
DataExecutor,
|
|
IOExecutor,
|
|
)
|
|
from modules.features.automation2.nodeDefinitions import STATIC_NODE_TYPES
|
|
from modules.workflows.processing.shared.methodDiscovery import discoverMethods, methods
|
|
|
|
# Module-scoped logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _getNodeTypeIds(services: Any) -> Set[str]:
    """Return every node type ID the graph validator should accept.

    The result is the union of:
      * the ``id`` of each entry in ``STATIC_NODE_TYPES``;
      * one ``io.<method>.<action>`` ID per action of every discovered
        ``Method*`` registry entry that has an ``instance``.

    Args:
        services: Forwarded to ``discoverMethods`` before the ``methods``
            registry is read.

    Returns:
        Set of node type ID strings.
    """
    discoverMethods(services)

    knownIds: Set[str] = set()
    knownIds.update(staticType["id"] for staticType in STATIC_NODE_TYPES)

    for registeredName, registryEntry in methods.items():
        # Only "Method<Name>" registry entries contribute io.* node types.
        if not registeredName.startswith("Method"):
            continue
        methodInstance = registryEntry.get("instance")
        if not methodInstance:
            continue
        # NOTE: str.replace removes every "Method" occurrence, not only the
        # prefix; kept so the derived IDs stay exactly as before.
        methodKey = registeredName.replace("Method", "").lower()
        knownIds.update(
            f"io.{methodKey}.{actionName}" for actionName in methodInstance.actions
        )

    return knownIds
|
|
|
|
|
|
def _getExecutor(nodeType: str, services: Any) -> Any:
|
|
"""Dispatch to correct executor based on node type."""
|
|
if nodeType.startswith("trigger."):
|
|
return TriggerExecutor()
|
|
if nodeType.startswith("flow."):
|
|
return FlowExecutor()
|
|
if nodeType.startswith("data."):
|
|
return DataExecutor()
|
|
if nodeType.startswith("io."):
|
|
return IOExecutor(services)
|
|
return None
|
|
|
|
|
|
async def executeGraph(
    graph: Dict[str, Any],
    services: Any,
    workflowId: str = None,
    instanceId: str = None,
    userId: str = None,
    mandateId: str = None,
) -> Dict[str, Any]:
    """
    Execute an automation2 graph node by node in topological order.

    The graph is first validated against every known node type ID; if
    ``validateGraph`` reports errors, nothing is executed. Otherwise each node
    is dispatched to the executor selected by its type prefix, and the result
    is stored in ``nodeOutputs`` under the node's ID. The run ends early once
    ``context["_stopped"]`` is truthy or when a node raises.

    Exceptions raised while discovering node types, validating, parsing or
    ordering the graph propagate to the caller; only per-node failures are
    converted into a result dict.

    Args:
        graph: Raw graph definition passed to ``validateGraph`` and ``parseGraph``.
        services: Service container passed to node-type discovery and to the
            executor factory.
        workflowId: Optional identifier copied into the execution context.
        instanceId: Optional identifier copied into the execution context.
        userId: Optional identifier copied into the execution context.
        mandateId: Optional identifier copied into the execution context.

    Returns:
        - Validation failure: ``{"success": False, "error": str, "nodeOutputs": {}}``,
          where ``error`` is the validation messages joined with ``"; "``.
        - Node failure: ``{"success": False, "error": str, "nodeOutputs": dict,
          "failedNode": nodeId}``; the failing node's own output is
          ``{"error": str, "success": False}``.
        - Otherwise: ``{"success": True, "nodeOutputs": dict, "stopped": ...}``,
          where ``stopped`` is the context's ``_stopped`` value (``False`` if unset).

        Nodes whose type matches no executor get an output of ``None``.
    """
    nodeTypeIds = _getNodeTypeIds(services)
    errors = validateGraph(graph, nodeTypeIds)
    if errors:
        return {"success": False, "error": "; ".join(errors), "nodeOutputs": {}}

    nodes, connections = parseGraph(graph)[:2]
    connectionMap = buildConnectionMap(connections)
    inputSources = {
        n["id"]: getInputSources(n["id"], connectionMap) for n in nodes if n.get("id")
    }
    ordered = topoSort(nodes, connectionMap)

    nodeOutputs: Dict[str, Any] = {}
    # One shared dict is handed to every executor. nodeOutputs is the same
    # object that is returned to the caller. The loop halts once "_stopped"
    # is truthy (presumably set by an executor, e.g. a flow node).
    context: Dict[str, Any] = {
        "workflowId": workflowId,
        "instanceId": instanceId,
        "userId": userId,
        "mandateId": mandateId,
        "nodeOutputs": nodeOutputs,
        "connectionMap": connectionMap,
        "inputSources": inputSources,
        "services": services,
    }

    for node in ordered:
        if context.get("_stopped"):
            break
        nodeId = node.get("id")
        # `or ""` also normalizes an explicit "type": None, which would
        # otherwise raise AttributeError in _getExecutor outside the handler.
        nodeType = node.get("type") or ""
        try:
            # Executor construction is inside the try so a failing constructor
            # (e.g. IOExecutor(services)) yields the structured failure result
            # below instead of an unhandled exception.
            executor = _getExecutor(nodeType, services)
            result = None if executor is None else await executor.execute(node, context)
        except Exception as e:
            logger.exception(
                "automation2 execution failed for node %s (%s)", nodeId, nodeType
            )
            nodeOutputs[nodeId] = {"error": str(e), "success": False}
            return {
                "success": False,
                "error": str(e),
                "nodeOutputs": nodeOutputs,
                "failedNode": nodeId,
            }
        if executor is None:
            logger.debug(
                "automation2: no executor for node %s (type %r); output set to None",
                nodeId,
                nodeType,
            )
        nodeOutputs[nodeId] = result

    return {
        "success": True,
        "nodeOutputs": nodeOutputs,
        "stopped": context.get("_stopped", False),
    }
|