163 lines
5.5 KiB
Python
163 lines
5.5 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# Graph parsing, validation, and topological sort for automation2.
|
|
|
|
import logging
|
|
from typing import Dict, List, Any, Tuple, Set
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def parseGraph(graph: Dict[str, Any]) -> Tuple[List[Dict], List[Dict], Set[str]]:
    """
    Split a graph definition into its node list, connection list and id set.

    graph: { nodes: [...], connections: [...] }; either key may be missing or
        None, in which case an empty list is used.

    Returns (nodes, connections, nodeIds). The two lists are the objects stored
    in ``graph``, not copies. ``nodeIds`` holds every truthy ``id`` found on a
    node; nodes without one are not represented in it.
    """
    nodeList = graph.get("nodes") or []
    connList = graph.get("connections") or []

    ids: Set[str] = set()
    for node in nodeList:
        candidate = node.get("id")
        if candidate:
            ids.add(candidate)

    return nodeList, connList, ids
|
|
|
|
|
|
def buildConnectionMap(connections: List[Dict]) -> Dict[str, List[Tuple[str, int, int]]]:
    """
    Index connections by their target node.

    connection: { source | sourceNode, sourceOutput?, target | targetNode, targetInput? }

    Returns targetNodeId -> [(sourceNodeId, sourceOutput, targetInput), ...],
    with entries in connection order. ``sourceOutput`` / ``targetInput`` default
    to 0 when the key is absent. Connections lacking a source or a target are
    skipped silently.
    """
    byTarget: Dict[str, List[Tuple[str, int, int]]] = {}
    for conn in connections:
        # Accept both the short and the "*Node" spelling of the endpoint keys.
        source = conn.get("source") or conn.get("sourceNode")
        target = conn.get("target") or conn.get("targetNode")
        if source and target:
            byTarget.setdefault(target, []).append(
                (source, conn.get("sourceOutput", 0), conn.get("targetInput", 0))
            )
    return byTarget
|
|
|
|
|
|
def getInputSources(nodeId: str, connectionMap: Dict[str, List[Tuple[str, int, int]]]) -> Dict[int, Tuple[str, int]]:
    """
    Describe what feeds each input of ``nodeId``.

    connectionMap: as produced by ``buildConnectionMap``.

    Returns targetInput -> (sourceNodeId, sourceOutput). When several
    connections target the same input, the last one in ``connectionMap`` wins.
    A node with no incoming connections yields an empty dict.
    """
    incoming = connectionMap.get(nodeId, [])
    return {
        inputIndex: (sourceId, outputIndex)
        for sourceId, outputIndex, inputIndex in incoming
    }
|
|
|
|
|
|
def getTriggerNodes(nodes: List[Dict]) -> List[Dict]:
    """
    Return the trigger nodes of a graph, in their original order.

    A node is a trigger when its ``type`` is a string starting with
    ``"trigger."`` or its ``category`` equals ``"trigger"``. A missing,
    ``None`` or non-string ``type`` is treated as "not a trigger type" rather
    than raising. This function runs on unvalidated graphs (e.g. from
    ``topoSort``), so it must tolerate malformed nodes.
    """
    def isTrigger(n: Dict) -> bool:
        ntype = n.get("type")
        # n.get("type", "") would still return None for an explicit null type.
        hasTriggerType = isinstance(ntype, str) and ntype.startswith("trigger.")
        return hasTriggerType or n.get("category") == "trigger"

    return [n for n in nodes if isTrigger(n)]
|
|
|
|
|
|
def validateGraph(graph: Dict[str, Any], nodeTypeIds: Set[str]) -> List[str]:
    """
    Validate a graph definition against the node-type registry.

    Checks that every node has an ``id`` and a ``type``, that every type is in
    ``nodeTypeIds``, and that every node id referenced by a connection exists.

    graph: { nodes: [...], connections: [...] }
    nodeTypeIds: identifiers of the registered node types.

    Returns a list of error messages (empty if valid). The order is
    deterministic: node errors follow node order, then there is one error per
    unknown connection endpoint, in connection-map order.
    """
    errors: List[str] = []
    nodes, connections, nodeIds = parseGraph(graph)

    for n in nodes:
        nid = n.get("id")
        ntype = n.get("type")
        if not nid:
            errors.append("Node missing id")
            continue
        if not ntype:
            errors.append(f"Node {nid} missing type")
            continue
        if ntype not in nodeTypeIds:
            errors.append(f"Unknown node type '{ntype}' for node {nid}")

    # NOTE: buildConnectionMap drops connections that lack a source or target,
    # so such connections are not reported here.
    connMap = buildConnectionMap(connections)

    # Collect referenced ids in an insertion-ordered dict rather than a set.
    # Set iteration order for strings changes between interpreter runs (hash
    # randomization), which made the error list order non-deterministic.
    referenced: Dict[str, None] = {}
    for tgt, pairs in connMap.items():
        referenced[tgt] = None
        for src, _, _ in pairs:
            referenced[src] = None

    for nid in referenced:
        if nid not in nodeIds:
            errors.append(f"Connection references non-existent node {nid}")

    return errors
|
|
|
|
|
|
def topoSort(nodes: List[Dict], connectionMap: Dict[str, List[Tuple[str, int, int]]]) -> List[Dict]:
    """
    Order nodes for execution: triggers first, then downstream nodes, then the rest.

    Nodes reachable from a trigger are emitted in topological order (Kahn's
    algorithm seeded with the trigger nodes). A node with several inputs is
    therefore emitted only after every reachable upstream node that feeds it.
    Siblings keep the order in which they appear in ``connectionMap``. If the
    reachable subgraph contains a cycle, the earliest-discovered blocked node
    is emitted anyway so that ordering always completes.

    Nodes not reachable from any trigger (e.g. disconnected nodes) are appended
    afterwards in their original order. If there are no trigger nodes, a copy
    of ``nodes`` is returned unchanged.

    Connection endpoints that are not in ``nodes`` are traversed but never
    emitted. When triggers exist, nodes without an ``id`` are not emitted.

    Runs in O(V + E).
    """
    from collections import deque

    nodeById = {n["id"]: n for n in nodes if n.get("id")}
    triggers = getTriggerNodes(nodes)
    if not triggers:
        return list(nodes)

    # Forward adjacency src -> [tgt, ...], one entry per connection. It is built
    # by walking connectionMap in its own order so sibling order is stable.
    successors: Dict[str, List[str]] = {}
    for tgt, pairs in connectionMap.items():
        for src, _, _ in pairs:
            successors.setdefault(src, []).append(tgt)

    # Trigger ids in node order, de-duplicated; triggers without an id cannot
    # be referenced by connections and are skipped.
    triggerIds = list(dict.fromkeys(t["id"] for t in triggers if t.get("id")))
    triggerSet = set(triggerIds)

    # 1) Discover everything reachable from the triggers, in BFS order.
    reachOrder: List[str] = list(triggerIds)
    reachable: Set[str] = set(triggerIds)
    frontier = deque(triggerIds)
    while frontier:
        nid = frontier.popleft()
        for tgt in successors.get(nid, []):
            if tgt not in reachable:
                reachable.add(tgt)
                reachOrder.append(tgt)
                frontier.append(tgt)

    # 2) Count incoming connections for each reachable non-trigger node. Only
    #    reachable sources count, because unreachable ones never run first.
    #    Edges into triggers are ignored: triggers always come first.
    pending: Dict[str, int] = {nid: 0 for nid in reachOrder if nid not in triggerSet}
    for src in reachOrder:
        for tgt in successors.get(src, []):
            if tgt in pending:
                pending[tgt] += 1

    # 3) Kahn's algorithm. A node is "scheduled" once it has been queued; it
    #    is queued exactly once.
    order: List[Dict] = []
    scheduled: Set[str] = set(triggerIds)
    ready = deque(triggerIds)
    cursor = 0  # position in reachOrder used to pick a node that breaks a cycle
    while True:
        while ready:
            nid = ready.popleft()
            if nid in nodeById:
                order.append(nodeById[nid])
            for tgt in successors.get(nid, []):
                if tgt in scheduled or tgt not in pending:
                    continue
                pending[tgt] -= 1
                if pending[tgt] == 0:
                    scheduled.add(tgt)
                    ready.append(tgt)

        # Every remaining reachable node waits on a cycle: force the
        # earliest-discovered one so the rest can proceed.
        while cursor < len(reachOrder) and reachOrder[cursor] in scheduled:
            cursor += 1
        if cursor == len(reachOrder):
            break
        stuck = reachOrder[cursor]
        scheduled.add(stuck)
        ready.append(stuck)

    # Append any orphan nodes (not reachable from a trigger, e.g. disconnected).
    for n in nodes:
        nid = n.get("id")
        if nid and nid not in reachable:
            order.append(n)

    return order
|
|
|
|
|
|
def resolveParameterReferences(value: Any, nodeOutputs: Dict[str, Any]) -> Any:
    """
    Resolve ``{{nodeId}}``, ``{{nodeId.output}}`` or ``{{nodeId.output.path}}``
    placeholders in strings, recursing into dicts and lists.

    Path segments after the node id index into dicts by key and into
    lists/tuples by non-negative decimal index. A resolved dict/list/tuple is
    substituted as JSON text; any other value via ``str()``. A placeholder is
    left unchanged when:

    - the node has no output (or its output is None);
    - a path segment does not exist (missing key, non-numeric or out-of-range
      index);
    - the resolved value is None.

    value: a string, dict, list, or any other value (returned as-is).
    nodeOutputs: mapping of node id -> that node's output.

    Returns the substituted value. Dicts and lists are rebuilt rather than
    modified in place.
    """
    import json
    import re

    def fmt(data: Any) -> str:
        # Containers become JSON so the substituted text is machine-readable,
        # both for {{node}} and for {{node.path}} references.
        return json.dumps(data) if isinstance(data, (dict, list, tuple)) else str(data)

    if isinstance(value, str):
        def repl(m) -> str:
            placeholder = m.group(0)
            nodeId, *path = m.group(1).strip().split(".")
            data = nodeOutputs.get(nodeId)
            if data is None:
                return placeholder
            for k in path:
                if isinstance(data, dict) and k in data:
                    data = data[k]
                elif isinstance(data, (list, tuple)) and k.isdigit():
                    try:
                        data = data[int(k)]
                    except (IndexError, ValueError):
                        # Out-of-range index, or a Unicode digit int() rejects:
                        # treat it as an unresolved path instead of raising.
                        return placeholder
                else:
                    return placeholder
            return placeholder if data is None else fmt(data)

        return re.sub(r"\{\{\s*([^}]+)\s*\}\}", repl, value)

    if isinstance(value, dict):
        return {k: resolveParameterReferences(v, nodeOutputs) for k, v in value.items()}

    if isinstance(value, list):
        return [resolveParameterReferences(v, nodeOutputs) for v in value]

    return value
|