201 lines
7.1 KiB
Python
201 lines
7.1 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# Main execution engine for automation2 graphs.
|
|
|
|
import logging
|
|
from typing import Dict, Any, List, Set, Optional
|
|
|
|
from modules.workflows.automation2.graphUtils import (
|
|
parseGraph,
|
|
buildConnectionMap,
|
|
validateGraph,
|
|
topoSort,
|
|
getInputSources,
|
|
)
|
|
|
|
from modules.workflows.automation2.executors import (
|
|
TriggerExecutor,
|
|
FlowExecutor,
|
|
DataExecutor,
|
|
IOExecutor,
|
|
InputExecutor,
|
|
PauseForHumanTaskError,
|
|
)
|
|
from modules.features.automation2.nodeDefinitions import STATIC_NODE_TYPES
|
|
from modules.workflows.processing.shared.methodDiscovery import discoverMethods, methods
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _getNodeTypeIds(services: Any) -> Set[str]:
    """Return every node type id the engine accepts during validation.

    The result is the union of:
      * the ``"id"`` of each entry in ``STATIC_NODE_TYPES``, and
      * one ``io.<method>.<action>`` id for every action of every entry in the
        ``methods`` registry whose key starts with ``"Method"`` and whose
        ``"instance"`` is truthy.

    ``discoverMethods(services)`` is invoked first; presumably it populates the
    module-level ``methods`` registry (NOTE(review): confirm in methodDiscovery).
    """
    discoverMethods(services)
    nodeTypeIds: Set[str] = {staticNode["id"] for staticNode in STATIC_NODE_TYPES}

    for registryKey, registryEntry in methods.items():
        # Only "Method..." registry keys contribute io.* node types.
        if not registryKey.startswith("Method"):
            continue
        # NOTE(review): str.replace removes *every* "Method" occurrence, not only
        # the leading prefix; kept for compatibility with existing node ids.
        methodSlug = registryKey.replace("Method", "").lower()
        methodInstance = registryEntry.get("instance")
        if not methodInstance:
            continue
        nodeTypeIds.update(
            f"io.{methodSlug}.{actionName}" for actionName in methodInstance.actions
        )

    return nodeTypeIds
|
|
|
|
|
|
def _getExecutor(
    nodeType: str,
    services: Any,
    automation2_interface: Optional[Any] = None,
) -> Any:
    """Build a new executor for ``nodeType`` based on its dotted prefix.

    * ``trigger.`` / ``flow.`` / ``data.`` -> argument-less executors.
    * ``io.`` -> ``IOExecutor(services)``.
    * ``input.`` -> ``InputExecutor(automation2_interface)``, but only when an
      interface is supplied (truthy).

    Returns ``None`` for any other type, including ``input.`` without an
    interface; the caller treats that as "no executor".
    """
    # Factories are lambdas so executor classes are only touched when matched.
    prefixFactories = (
        ("trigger.", lambda: TriggerExecutor()),
        ("flow.", lambda: FlowExecutor()),
        ("data.", lambda: DataExecutor()),
        ("io.", lambda: IOExecutor(services)),
    )
    for prefix, makeExecutor in prefixFactories:
        if nodeType.startswith(prefix):
            return makeExecutor()

    wantsInput = nodeType.startswith("input.")
    if wantsInput and automation2_interface:
        return InputExecutor(automation2_interface)
    return None
|
|
|
|
|
|
async def executeGraph(
    graph: Dict[str, Any],
    services: Any,
    workflowId: Optional[str] = None,
    instanceId: Optional[str] = None,
    userId: Optional[str] = None,
    mandateId: Optional[str] = None,
    automation2_interface: Optional[Any] = None,
    initialNodeOutputs: Optional[Dict[str, Any]] = None,
    startAfterNodeId: Optional[str] = None,
    runId: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Execute an automation2 graph node by node in topological order.

    The graph is validated against the known node type ids, parsed, and
    topologically sorted. Each node is then dispatched to the executor chosen
    by its type prefix, and the executor's result is stored in ``nodeOutputs``
    under the node id. Nodes with no matching executor get ``None``.

    Args:
        graph: Graph definition understood by ``parseGraph``/``validateGraph``.
        services: Service container passed to method discovery and executors.
        workflowId, instanceId, userId, mandateId: Identifiers copied into the
            execution context (``workflowId`` is also required to create a run).
        automation2_interface: Optional persistence interface. When given, runs
            are created/updated through it and ``input.`` nodes get an executor.
        initialNodeOutputs: Outputs to seed ``nodeOutputs`` with; used on resume
            (e.g. the result for the human-task node).
        startAfterNodeId: Resume point. Every node up to and including this id
            (in topological order) is skipped and no new run is created.
        runId: Existing run id. When absent on a fresh (non-resume) execution and
            both ``automation2_interface`` and ``workflowId`` are given, a run
            is created via ``automation2_interface.createRun``.

    Returns:
        - ``{"success": False, "error", "nodeOutputs": {}}`` if validation fails.
        - ``{"success": False, "error", "nodeOutputs"}`` if ``startAfterNodeId``
          does not name a node of the graph (nothing is executed or persisted).
        - ``{"success": False, "paused": True, "taskId", "runId", "nodeId",
          "nodeOutputs"}`` when an executor raises ``PauseForHumanTaskError``.
        - ``{"success": False, "error", "nodeOutputs", "failedNode"}`` when a node
          raises any other exception; the run, if any, is marked ``failed``.
        - ``{"success": True, "nodeOutputs", "stopped"}`` otherwise; the run, if
          any, is marked ``completed``. ``stopped`` is true when an executor set
          ``context["_stopped"]``.
    """
    is_resume = startAfterNodeId is not None
    logger.info(
        "executeGraph start: instanceId=%s workflowId=%s userId=%s mandateId=%s resume=%s",
        instanceId,
        workflowId,
        userId,
        mandateId,
        is_resume,
    )

    nodeTypeIds = _getNodeTypeIds(services)
    logger.debug("executeGraph nodeTypeIds (%d): %s", len(nodeTypeIds), sorted(nodeTypeIds))
    errors = validateGraph(graph, nodeTypeIds)
    if errors:
        logger.warning("executeGraph validation failed: %s", errors)
        return {"success": False, "error": "; ".join(errors), "nodeOutputs": {}}

    nodes, connections = parseGraph(graph)[:2]
    connectionMap = buildConnectionMap(connections)
    inputSources = {n["id"]: getInputSources(n["id"], connectionMap) for n in nodes if n.get("id")}
    logger.info(
        "executeGraph parsed: nodes=%d connections=%d connectionMap_targets=%s",
        len(nodes),
        len(connections),
        list(connectionMap.keys()),
    )

    ordered = topoSort(nodes, connectionMap)
    ordered_ids = [n.get("id") for n in ordered if n.get("id")]
    logger.info("executeGraph topoSort order: %s", ordered_ids)

    # Resolve the resume point up front. Previously an unknown startAfterNodeId
    # skipped every node and reported success (marking the run "completed").
    startIndex = 0
    if is_resume:
        resumeIndex = next(
            (idx for idx, n in enumerate(ordered) if n.get("id") == startAfterNodeId),
            None,
        )
        if resumeIndex is None:
            message = f"startAfterNodeId '{startAfterNodeId}' not found in graph"
            logger.warning("executeGraph resume failed: %s", message)
            return {
                "success": False,
                "error": message,
                "nodeOutputs": dict(initialNodeOutputs or {}),
            }
        startIndex = resumeIndex + 1

    nodeOutputs: Dict[str, Any] = dict(initialNodeOutputs or {})
    if not runId and automation2_interface and workflowId and not is_resume:
        run = automation2_interface.createRun(
            workflowId=workflowId,
            nodeOutputs=nodeOutputs,
            context={"connectionMap": connectionMap, "inputSources": inputSources, "orderedNodeIds": ordered_ids},
        )
        runId = run.get("id") if run else None
        logger.info("executeGraph created run %s", runId)

    # Shared execution context handed to every executor; executors may set
    # "_stopped" to end the run early.
    context = {
        "workflowId": workflowId,
        "instanceId": instanceId,
        "userId": userId,
        "mandateId": mandateId,
        "nodeOutputs": nodeOutputs,
        "connectionMap": connectionMap,
        "inputSources": inputSources,
        "services": services,
        "_runId": runId,
        "_orderedNodes": ordered,
    }

    # `start=` keeps step numbers aligned with the full topological order.
    for i, node in enumerate(ordered[startIndex:], start=startIndex):
        if context.get("_stopped"):
            logger.info("executeGraph stopped early (flow.stop) at step %d", i)
            break

        nodeId = node.get("id")
        nodeType = node.get("type", "")
        executor = _getExecutor(nodeType, services, automation2_interface)
        logger.info(
            "executeGraph step %d/%d: nodeId=%s nodeType=%s executor=%s",
            i + 1,
            len(ordered),
            nodeId,
            nodeType,
            type(executor).__name__ if executor else "None",
        )
        if not executor:
            nodeOutputs[nodeId] = None
            logger.debug("executeGraph node %s: no executor, output=None", nodeId)
            continue

        try:
            result = await executor.execute(node, context)
            nodeOutputs[nodeId] = result
            logger.info(
                "executeGraph node %s done: result_type=%s result_keys=%s",
                nodeId,
                type(result).__name__,
                list(result.keys()) if isinstance(result, dict) else "n/a",
            )
        except PauseForHumanTaskError as e:
            logger.info("executeGraph paused for human task %s", e.taskId)
            return {
                "success": False,
                "paused": True,
                "taskId": e.taskId,
                "runId": e.runId,
                "nodeId": e.nodeId,
                "nodeOutputs": dict(nodeOutputs),
            }
        except Exception as e:
            logger.exception("executeGraph node %s (%s) FAILED: %s", nodeId, nodeType, e)
            nodeOutputs[nodeId] = {"error": str(e), "success": False}
            if runId and automation2_interface:
                try:
                    automation2_interface.updateRun(runId, status="failed", nodeOutputs=nodeOutputs)
                except Exception:
                    # Persisting the failure must not mask the node error returned below.
                    logger.exception("executeGraph could not mark run %s as failed", runId)
            return {
                "success": False,
                "error": str(e),
                "nodeOutputs": nodeOutputs,
                "failedNode": nodeId,
            }

    if runId and automation2_interface:
        automation2_interface.updateRun(runId, status="completed", nodeOutputs=nodeOutputs)
    logger.info(
        "executeGraph complete: success=True nodeOutputs_keys=%s stopped=%s",
        list(nodeOutputs.keys()),
        context.get("_stopped", False),
    )
    return {
        "success": True,
        "nodeOutputs": nodeOutputs,
        "stopped": context.get("_stopped", False),
    }
|