gateway/modules/workflows/automation2/graphUtils.py

243 lines
8.8 KiB
Python

# Copyright (c) 2025 Patrick Motsch
# Graph parsing, validation, and topological sort for automation2.
import logging
from typing import Dict, List, Any, Tuple, Set
logger = logging.getLogger(__name__)
def parseGraph(graph: Dict[str, Any]) -> Tuple[List[Dict], List[Dict], Set[str]]:
    """
    Parse graph into nodes, connections, and node IDs.

    graph: { nodes: [...], connections: [...] }
    A missing or null "nodes"/"connections" entry, or a falsy graph, is treated
    as empty.

    Returns (nodes, connections, nodeIds), where nodeIds is the set of every
    truthy "id" value found on a node.
    """
    graph = graph or {}
    nodes = graph.get("nodes") or []
    connections = graph.get("connections") or []
    nodeIds = {n.get("id") for n in nodes if n.get("id")}
    # Build the sorted listing only when it will actually be emitted; sorting is
    # wasted work otherwise and must never be able to fail the parse itself.
    if logger.isEnabledFor(logging.DEBUG):
        try:
            listing = sorted(nodeIds)
        except TypeError:
            # Mixed-type IDs (e.g. int and str) are not mutually orderable.
            listing = sorted(nodeIds, key=str)
        logger.debug(
            "parseGraph: nodes=%d connections=%d nodeIds=%s",
            len(nodes),
            len(connections),
            listing,
        )
    return nodes, connections, nodeIds
def buildConnectionMap(connections: List[Dict]) -> Dict[str, List[Tuple[str, int, int]]]:
    """
    Build map: targetNodeId -> [(sourceNodeId, sourceOutput, targetInput), ...]

    connection: { source, sourceOutput?, target, targetInput? }
    "sourceNode" / "targetNode" are accepted as fallbacks for "source" /
    "target", and both port indices default to 0 when the key is absent.
    Connections lacking a source or a target are skipped. Entries for a target
    keep the order in which the connections were listed. A falsy connections
    argument yields an empty map.
    """
    out: Dict[str, List[Tuple[str, int, int]]] = {}
    for i, c in enumerate(connections or []):
        src = c.get("source") or c.get("sourceNode")
        tgt = c.get("target") or c.get("targetNode")
        if not src or not tgt:
            logger.debug("buildConnectionMap skip conn[%d]: missing source/target %r", i, c)
            continue
        so = c.get("sourceOutput", 0)
        ti = c.get("targetInput", 0)
        out.setdefault(tgt, []).append((src, so, ti))
        # %s instead of %d for the ports: they are taken from the payload
        # without validation, and %d cannot format a non-numeric value.
        logger.debug("buildConnectionMap conn[%d]: %s -> %s (so=%s ti=%s)", i, src, tgt, so, ti)
    logger.debug("buildConnectionMap result: %s", out)
    return out
def getLoopBodyNodeIds(loopNodeId: str, connectionMap: Dict[str, List[Tuple[str, int, int]]]) -> Set[str]:
    """
    Collect the IDs of all nodes reachable downstream of a loop node.

    connectionMap is keyed by target, so it is first inverted into a
    source -> [targets] adjacency and then walked breadth-first starting at
    loopNodeId. The loop node itself is only part of the result when some
    connection leads back to it.
    """
    from collections import deque

    # Invert target -> [(source, sourceOutput, targetInput)] into source -> [targets].
    successors: Dict[str, List[str]] = {}
    for target, incoming in connectionMap.items():
        for source, _sourceOutput, _targetInput in incoming:
            successors.setdefault(source, []).append(target)

    reachable = set()
    pending = deque([loopNodeId])
    while pending:
        current = pending.popleft()
        for successor in successors.get(current, []):
            if successor in reachable:
                continue
            reachable.add(successor)
            pending.append(successor)
    return reachable
def getInputSources(nodeId: str, connectionMap: Dict[str, List[Tuple[str, int, int]]]) -> Dict[int, Tuple[str, int]]:
    """
    Map each connected input of nodeId to the (sourceNodeId, sourceOutput) feeding it.

    A node without incoming connections yields an empty dict. When several
    connections address the same targetInput, the one listed last wins.
    """
    incoming = connectionMap.get(nodeId, [])
    return {
        targetInput: (source, sourceOutput)
        for source, sourceOutput, targetInput in incoming
    }
def getTriggerNodes(nodes: List[Dict]) -> List[Dict]:
    """
    Return nodes with category == "trigger" or a type starting with "trigger.".

    Input order is preserved. A missing, null or non-string "type" simply does
    not match the prefix test (it no longer raises), so such a node is still
    returned when its category is "trigger".
    """
    def isTrigger(n: Dict) -> bool:
        ntype = n.get("type")
        if isinstance(ntype, str) and ntype.startswith("trigger."):
            return True
        return n.get("category") == "trigger"

    return [n for n in nodes if isTrigger(n)]
def validateGraph(graph: Dict[str, Any], nodeTypeIds: Set[str]) -> List[str]:
    """
    Validate graph: all node IDs referenced in connections exist, all node types in registry.

    graph: { nodes: [...], connections: [...] }
    nodeTypeIds: the set of known node type identifiers.

    Returns list of error messages (empty if valid). Node errors come first, in
    node order; each unknown node ID referenced by a connection is then
    reported once, in the order the connection map yields it.
    """
    errors: List[str] = []
    nodes, connections, nodeIds = parseGraph(graph)
    for n in nodes:
        nid = n.get("id")
        if not nid:
            errors.append("Node missing id")
            continue
        ntype = n.get("type")
        if not ntype:
            errors.append(f"Node {nid} missing type")
        elif ntype not in nodeTypeIds:
            errors.append(f"Unknown node type '{ntype}' for node {nid}")

    connMap = buildConnectionMap(connections)
    # Walk the endpoints in map order (target, then its sources) rather than
    # collecting them into a set first: iterating a set of strings gives a
    # different order from run to run, which made the error list unstable.
    reported = set()
    for tgt, pairs in connMap.items():
        for nid in [tgt] + [src for src, _, _ in pairs]:
            if nid in nodeIds or nid in reported:
                continue
            reported.add(nid)
            errors.append(f"Connection references non-existent node {nid}")

    if errors:
        logger.debug("validateGraph errors: %s", errors)
    else:
        logger.debug("validateGraph: OK")
    return errors
def topoSort(nodes: List[Dict], connectionMap: Dict[str, List[Tuple[str, int, int]]]) -> List[Dict]:
    """
    Order nodes: trigger nodes first, then BFS by connections, then the rest.

    Returns ordered list of nodes (trigger first, then downstream, then nodes
    not reachable from any trigger, in their original order). When there are
    no trigger nodes, a copy of the input list is returned unchanged.

    Notes:
    - The traversal is breadth-first, so a node with several inputs is placed
      when the first of them is processed and may precede its other inputs.
    - When triggers exist, nodes without an "id" are left out of the result.
    - Trigger nodes without an "id" are ignored as start points, and a trigger
      ID that occurs more than once is only emitted once.
    """
    from collections import deque

    triggers = getTriggerNodes(nodes)
    if not triggers:
        return list(nodes)

    nodeById = {n["id"]: n for n in nodes if n.get("id")}

    # Invert target -> sources into source -> [targets] once, instead of
    # rescanning the whole connection map for every dequeued node. Targets are
    # appended in connectionMap order, so the visit order is unchanged.
    downstream: Dict[str, List[str]] = {}
    for tgt, pairs in connectionMap.items():
        for src, _, _ in pairs:
            downstream.setdefault(src, []).append(tgt)

    visited: Set[str] = set()
    order: List[Dict] = []

    def visit(nid: str) -> bool:
        """Mark nid visited and emit its node; return False if already seen."""
        if nid in visited:
            return False
        visited.add(nid)
        if nid in nodeById:
            order.append(nodeById[nid])
        return True

    triggerIds = [t["id"] for t in triggers if t.get("id")]
    logger.debug("topoSort triggers: %s", triggerIds)

    # All triggers are emitted before any downstream node is expanded.
    q = deque([nid for nid in triggerIds if visit(nid)])
    while q:
        nid = q.popleft()
        for tgt in downstream.get(nid, ()):
            if visit(tgt):
                q.append(tgt)

    # Append any orphan nodes (e.g. disconnected)
    for n in nodes:
        if n.get("id") and n["id"] not in visited:
            order.append(n)
    logger.debug("topoSort order (%d nodes): %s", len(order), [n.get("id") for n in order])
    return order
def _get_by_path(data: Any, path: List[Any]) -> Any:
"""Traverse data by path (strings and ints); return None if not found."""
current = data
for seg in path:
if current is None:
return None
if isinstance(current, dict) and isinstance(seg, str) and seg in current:
current = current[seg]
elif isinstance(current, (list, tuple)) and isinstance(seg, (int, str)):
idx = int(seg) if isinstance(seg, str) and seg.isdigit() else seg
if isinstance(idx, int) and 0 <= idx < len(current):
current = current[idx]
else:
return None
else:
return None
return current
def resolveParameterReferences(value: Any, nodeOutputs: Dict[str, Any]) -> Any:
    """
    Resolve parameter references:
    - {{nodeId.output}} or {{nodeId.output.path}} in strings (legacy)
    - { "type": "ref", "nodeId": "...", "path": ["field", "nested"] } -> resolved value
    - { "type": "value", "value": ... } -> value (then recursively resolve)

    Other dicts and lists are resolved element by element; any other value is
    returned unchanged. A "ref" dict without a nodeId or without a list/tuple
    path is returned as is.

    In strings, a placeholder that cannot be resolved (unknown node, missing
    key, non-numeric or out-of-range list index, or a None result) is left in
    the text verbatim. A bare {{nodeId}} is rendered with json.dumps for
    dict/list outputs and str() otherwise; a value reached through a path is
    always rendered with str().
    """
    import json
    import re

    if isinstance(value, dict):
        if value.get("type") == "ref":
            node_id = value.get("nodeId")
            path = value.get("path")
            if node_id is not None and isinstance(path, (list, tuple)):
                data = nodeOutputs.get(node_id)
                plist = list(path)
                resolved = _get_by_path(data, plist)
                # input.form historically stored flat field dict; refs use payload.<field>
                if (
                    resolved is None
                    and isinstance(data, dict)
                    and plist
                    and plist[0] == "payload"
                    and len(plist) > 1
                ):
                    resolved = _get_by_path(data, plist[1:])
                return resolveParameterReferences(resolved, nodeOutputs)
            return value
        if value.get("type") == "value":
            inner = value.get("value")
            return resolveParameterReferences(inner, nodeOutputs)
        return {k: resolveParameterReferences(v, nodeOutputs) for k, v in value.items()}

    if isinstance(value, str):
        def repl(m):
            parts = m.group(1).strip().split(".")
            data = nodeOutputs.get(parts[0])
            if data is None:
                return m.group(0)
            if len(parts) < 2:
                return json.dumps(data) if isinstance(data, (dict, list)) else str(data)
            for k in parts[1:]:
                if isinstance(data, dict) and k in data:
                    data = data[k]
                elif (
                    isinstance(data, (list, tuple))
                    # isdecimal() guarantees int(k) succeeds; isdigit() does not.
                    and k.isdecimal()
                    # Out-of-range indexes keep the placeholder instead of
                    # raising IndexError from inside re.sub.
                    and int(k) < len(data)
                ):
                    data = data[int(k)]
                else:
                    return m.group(0)
            return str(data) if data is not None else m.group(0)

        return re.sub(r"\{\{\s*([^}]+)\s*\}\}", repl, value)

    if isinstance(value, list):
        return [resolveParameterReferences(v, nodeOutputs) for v in value]
    return value