408 lines
16 KiB
Python
408 lines
16 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# Graph parsing, validation, and topological sort for automation2.
|
|
|
|
import logging
|
|
from typing import Dict, List, Any, Tuple, Set, Optional
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def parseGraph(graph: Dict[str, Any]) -> Tuple[List[Dict], List[Dict], Set[str]]:
    """
    Split a graph definition into its nodes, connections, and node-ID set.

    Args:
        graph: Mapping shaped like ``{"nodes": [...], "connections": [...]}``.
            A missing or falsy entry is treated as an empty list.

    Returns:
        ``(nodes, connections, nodeIds)``. ``nodeIds`` holds every truthy
        ``id`` found on a node; nodes without a usable id stay in ``nodes``
        but contribute nothing to ``nodeIds``.
    """
    nodes = graph.get("nodes") or []
    connections = graph.get("connections") or []
    nodeIds = {n.get("id") for n in nodes if n.get("id")}
    if logger.isEnabledFor(logging.DEBUG):
        # key=str keeps this diagnostic from raising TypeError when IDs of
        # different types (e.g. int and str) are mixed in one graph; the
        # level guard avoids sorting at all when debug output is off.
        logger.debug(
            "parseGraph: nodes=%d connections=%d nodeIds=%s",
            len(nodes),
            len(connections),
            sorted(nodeIds, key=str),
        )
    return nodes, connections, nodeIds
|
|
|
|
|
|
def buildConnectionMap(connections: List[Dict]) -> Dict[str, List[Tuple[str, int, int]]]:
    """
    Index connections by their target node.

    Each connection is a dict ``{source, sourceOutput?, target, targetInput?}``.
    ``sourceNode`` / ``targetNode`` are accepted as fallback key names, and the
    port indices default to 0 when the key is absent. Connections lacking a
    source or a target are skipped (and reported at debug level).

    Returns:
        ``targetNodeId -> [(sourceNodeId, sourceOutput, targetInput), ...]``,
        preserving the order in which the connections were supplied.
    """
    out: Dict[str, List[Tuple[str, int, int]]] = {}
    for i, c in enumerate(connections):
        src = c.get("source") or c.get("sourceNode")
        tgt = c.get("target") or c.get("targetNode")
        if not src or not tgt:
            logger.debug("buildConnectionMap skip conn[%d]: missing source/target %r", i, c)
            continue
        so = c.get("sourceOutput", 0)
        ti = c.get("targetInput", 0)
        out.setdefault(tgt, []).append((src, so, ti))
        # %s instead of %d for the ports: the values are taken verbatim from
        # the connection dict and are not guaranteed to be ints (e.g. an
        # explicit null), which would make the log record fail to format.
        logger.debug("buildConnectionMap conn[%d]: %s -> %s (so=%s ti=%s)", i, src, tgt, so, ti)
    logger.debug("buildConnectionMap result: %s", out)
    return out
|
|
|
|
|
|
def getLoopBodyNodeIds(loopNodeId: str, connectionMap: Dict[str, List[Tuple[str, int, int]]]) -> Set[str]:
    """
    Collect every node reachable downstream of a loop node.

    Walks the connections forward (breadth-first) starting at ``loopNodeId``
    and returns the IDs of all nodes reached. The loop node itself is only
    part of the result if some connection chain leads back into it.
    """
    from collections import deque

    # connectionMap is keyed by target; invert it into source -> [targets].
    downstream: Dict[str, List[str]] = {}
    for target, incoming in connectionMap.items():
        for source, _sourceOutput, _targetInput in incoming:
            downstream.setdefault(source, []).append(target)

    reached: Set[str] = set()
    pending = deque([loopNodeId])
    while pending:
        current = pending.popleft()
        for successor in downstream.get(current, ()):
            if successor in reached:
                continue
            reached.add(successor)
            pending.append(successor)
    return reached
|
|
|
|
|
|
def getInputSources(nodeId: str, connectionMap: Dict[str, List[Tuple[str, int, int]]]) -> Dict[int, Tuple[str, int]]:
    """
    Map each input port of ``nodeId`` to the output feeding it.

    Returns ``targetInput -> (sourceNodeId, sourceOutput)``. When several
    connections land on the same input, the last one listed wins. A node
    with no incoming connections yields an empty dict.
    """
    incoming = connectionMap.get(nodeId, [])
    return {
        targetInput: (sourceId, sourceOutput)
        for sourceId, sourceOutput, targetInput in incoming
    }
|
|
|
|
|
|
def getTriggerNodes(nodes: List[Dict]) -> List[Dict]:
    """
    Return the nodes that act as triggers, in their original order.

    A node counts as a trigger when its ``type`` is a string starting with
    ``"trigger."`` or its ``category`` equals ``"trigger"``. A ``type`` that is
    missing, ``None`` or not a string is simply treated as "not a trigger
    type" instead of raising.
    """
    def isTrigger(node: Dict) -> bool:
        nodeType = node.get("type")
        if isinstance(nodeType, str) and nodeType.startswith("trigger."):
            return True
        return node.get("category") == "trigger"

    return [n for n in nodes if isTrigger(n)]
|
|
|
|
|
|
def validateGraph(graph: Dict[str, Any], nodeTypeIds: Set[str]) -> List[str]:
    """
    Validate a graph definition and return the problems found.

    Checks, in this order:
      1. every node has an ``id`` and a ``type``, and the type is contained in
         ``nodeTypeIds`` (a node missing its id or type is reported once and
         not checked further);
      2. every node ID referenced by a connection exists among the nodes;
      3. typed-port compatibility of the connections
         (``_checkPortCompatibility``), treated as hard errors.

    Args:
        graph: ``{"nodes": [...], "connections": [...]}``.
        nodeTypeIds: Set of known node type identifiers.

    Returns:
        List of error messages; empty when the graph is valid.
    """
    errors: List[str] = []
    nodes, connections, nodeIds = parseGraph(graph)

    for n in nodes:
        nid = n.get("id")
        ntype = n.get("type")
        if not nid:
            errors.append("Node missing id")
            continue
        if not ntype:
            errors.append(f"Node {nid} missing type")
            continue
        if ntype not in nodeTypeIds:
            errors.append(f"Unknown node type '{ntype}' for node {nid}")

    connMap = buildConnectionMap(connections)
    # A dict serves as an insertion-ordered set here: iterating a plain set of
    # strings yields a different order from run to run (hash randomisation),
    # which made the order of the reported errors unstable.
    referred: Dict[str, None] = {}
    for tgt, pairs in connMap.items():
        referred[tgt] = None
        for src, _, _ in pairs:
            referred[src] = None
    errors.extend(
        f"Connection references non-existent node {nid}"
        for nid in referred
        if nid not in nodeIds
    )

    # Port compatibility: hard-fail (Pick-not-Push typed graph)
    port_errors = _checkPortCompatibility(nodes, connMap)
    if port_errors:
        logger.warning("validateGraph port mismatches: %s", port_errors)
        errors.extend(port_errors)

    if errors:
        logger.debug("validateGraph errors: %s", errors)
    else:
        logger.debug("validateGraph: OK")
    return errors
|
|
|
|
|
|
def parse_graph_defined_schema(node: Dict[str, Any], parameter_key: str) -> Optional[Dict[str, Any]]:
    """
    Turn a graph-defined parameter (e.g. form ``fields``) into a plain schema dict.

    Delegates to ``deriveFormPayloadSchemaFromParam(node, parameter_key)``.
    When that yields ``None`` this returns ``None`` as well; otherwise the
    result is ``{"name": <schema name>, "fields": [<field.model_dump()>, ...]}``.

    Per the original note this is used by tooling and future API surfaces and
    is meant to mirror ``parse_graph_defined_output_schema`` (not visible in
    this file section).
    """
    from modules.features.graphicalEditor.portTypes import deriveFormPayloadSchemaFromParam

    schema = deriveFormPayloadSchemaFromParam(node, parameter_key)
    if schema is not None:
        schemaName = schema.name
        dumpedFields = [field.model_dump() for field in schema.fields]
        return {"name": schemaName, "fields": dumpedFields}
    return None
|
|
|
|
|
|
def _checkPortCompatibility(
    nodes: List[Dict],
    connMap: Dict[str, List[Tuple[str, int, int]]],
) -> List[str]:
    """
    Typed-port check over all connections; returns one message per mismatch.

    For each connection the source node's output port and the target node's
    input port are looked up in ``STATIC_NODE_TYPES`` (by node ``type``). The
    connection is flagged when the resolved source schema name is not among
    the target port's ``accepts`` list. Connections are silently skipped when
    either node or its definition is unknown, when the source port entry is
    not a dict, or when the schema name or the ``accepts`` list is empty.

    Accepted without an exact match:
      - a target port whose only accepted schema is ``"Transit"``;
      - any ``FormPayload*`` schema into a port accepting ``"FormPayload"``.
    """
    from modules.features.graphicalEditor.nodeDefinitions import STATIC_NODE_TYPES
    from modules.features.graphicalEditor.portTypes import resolve_output_schema_name

    definitionsByType = {definition["id"]: definition for definition in STATIC_NODE_TYPES}
    graphNodes = {node["id"]: node for node in nodes if node.get("id")}
    mismatches: List[str] = []

    for tgt, incoming in connMap.items():
        targetNode = graphNodes.get(tgt)
        targetDef = definitionsByType.get(targetNode.get("type", "")) if targetNode else None
        if not targetDef:
            continue
        targetInputs = targetDef.get("inputPorts", {})

        for src, srcOut, tgtIn in incoming:
            sourceNode = graphNodes.get(src)
            sourceDef = definitionsByType.get(sourceNode.get("type", "")) if sourceNode else None
            if not sourceDef:
                continue
            sourceOutputs = sourceDef.get("outputPorts", {})
            outputPort = sourceOutputs.get(srcOut, {}) or {}
            inputPort = targetInputs.get(tgtIn, {}) or {}

            if not isinstance(outputPort, dict):
                continue
            src_schema = resolve_output_schema_name(sourceNode, outputPort)
            accepts = inputPort.get("accepts", [])

            # Nothing to compare against, or nothing to compare: not an error.
            if not (accepts and src_schema):
                continue
            compatible = (
                src_schema in accepts
                # Port that only declares Transit behaves as an untyped sink (legacy graphs).
                or (len(accepts) == 1 and accepts[0] == "Transit")
                or (src_schema == "FormPayload_dynamic" and "FormPayload" in accepts)
                or (src_schema.startswith("FormPayload") and "FormPayload" in accepts)
            )
            if compatible:
                continue
            mismatches.append(
                f"Port mismatch: {src}[out:{srcOut}] ({src_schema}) -> {tgt}[in:{tgtIn}] (accepts: {accepts})"
            )

    return mismatches
|
|
|
|
|
|
def topoSort(nodes: List[Dict], connectionMap: Dict[str, List[Tuple[str, int, int]]]) -> List[Dict]:
    """
    Order nodes for execution: triggers first, then everything downstream.

    Starting from the trigger nodes (see ``getTriggerNodes``), connections are
    followed breadth-first and each node is emitted the first time it is
    reached. Nodes with an id that were never reached (disconnected ones) are
    appended at the end in their original order.

    Args:
        nodes: Graph nodes; only nodes with a truthy ``id`` take part in the
            traversal.
        connectionMap: ``target -> [(source, sourceOutput, targetInput), ...]``
            as produced by ``buildConnectionMap``.

    Returns:
        The ordered node list. When the graph has no trigger node, a shallow
        copy of ``nodes`` is returned unchanged.

    NOTE(review): this is breadth-first discovery order, not a strict
    topological order - a node fed from branches of different depth can be
    emitted before one of its upstream nodes. Confirm the executor does not
    rely on strict ordering.
    """
    from collections import deque

    nodeById = {n["id"]: n for n in nodes if n.get("id")}
    triggers = getTriggerNodes(nodes)
    if not triggers:
        return list(nodes)

    # Forward adjacency (source -> targets), built once so the traversal does
    # not rescan the whole connection map for every dequeued node. Targets
    # keep connectionMap's iteration order, so discovery order is unchanged.
    successors: Dict[str, List[str]] = {}
    for tgt, pairs in connectionMap.items():
        for src, _, _ in pairs:
            successors.setdefault(src, []).append(tgt)

    visited: Set[str] = set()
    order: List[Dict] = []
    q = deque()

    def visit(nid: str) -> None:
        """Mark ``nid`` as reached, queue it, and emit its node if known."""
        if nid in visited:
            return
        visited.add(nid)
        q.append(nid)
        if nid in nodeById:
            order.append(nodeById[nid])

    # A trigger without an id cannot be the source of any connection; skip it
    # rather than failing with KeyError.
    triggerIds = [t["id"] for t in triggers if t.get("id")]
    logger.debug("topoSort triggers: %s", triggerIds)
    for nid in triggerIds:
        visit(nid)
    while q:
        for tgt in successors.get(q.popleft(), ()):
            visit(tgt)

    # Append any orphan nodes (e.g. disconnected)
    for n in nodes:
        if n.get("id") and n["id"] not in visited:
            order.append(n)
    logger.debug("topoSort order (%d nodes): %s", len(order), [n.get("id") for n in order])
    return order
|
|
|
|
|
|
_WILDCARD_SEGMENT = "*"


def _get_by_path(data: Any, path: List[Any]) -> Any:
    """Traverse ``data`` along ``path`` (strings and ints); return None if not found.

    Segment handling:
      - on a dict, a string segment is used as key;
      - on a list/tuple, an int segment or a decimal-digit string is used as
        a (non-negative, in-range) index;
      - the iteration wildcard ``"*"`` applied to a list/tuple maps the rest
        of the path over each element and returns the results as a list,
        dropping elements that resolve to ``None``. With no remaining path the
        elements are returned as a list as-is. Applied to anything that is
        not a list/tuple it yields ``None``.
    An empty path returns ``data`` itself.

    This is the "typed Bindings-Resolver" iteration primitive defined for
    Schicht 4 (layer 4) of the Typed Action Architecture.
    """
    current = data
    for i, seg in enumerate(path):
        if current is None:
            return None
        if isinstance(seg, str) and seg == _WILDCARD_SEGMENT:
            if not isinstance(current, (list, tuple)):
                return None
            tail = list(path[i + 1 :])
            if not tail:
                return list(current)
            resolvedItems = (_get_by_path(item, tail) for item in current)
            return [resolved for resolved in resolvedItems if resolved is not None]
        if isinstance(current, dict) and isinstance(seg, str) and seg in current:
            current = current[seg]
        elif isinstance(current, (list, tuple)) and isinstance(seg, (int, str)):
            # isdecimal(), not isdigit(): isdigit() also accepts characters
            # such as superscript digits that int() cannot parse, which would
            # raise ValueError instead of reporting "not found".
            idx = int(seg) if isinstance(seg, str) and seg.isdecimal() else seg
            if isinstance(idx, int) and 0 <= idx < len(current):
                current = current[idx]
            else:
                return None
        else:
            return None
    return current
|
|
|
|
|
|
def _pathContainsWildcard(path: List[Any]) -> bool:
    """Report whether ``path`` has at least one iteration-wildcard (``"*"``) segment."""
    for segment in path:
        if isinstance(segment, str) and segment == _WILDCARD_SEGMENT:
            return True
    return False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Phase 5, layer 4 ("Schicht 4") — typed-ref envelope unwrap
# ---------------------------------------------------------------------------
#
# Workflow params can carry a typed-ref envelope such as
# ``{"$type": "FeatureInstanceRef", "id": "<uuid>", "featureCode": "trustee"}``.
# Action implementations historically receive the canonical primitive (the
# referenced ``id``) as a string. ``_unwrapTypedRef`` extracts that primitive
# without changing the typed envelope shape stored on disk: the migration script
# (``featureInstanceRefMigration.materializeFeatureInstanceRefs``) writes the
# envelope, and the resolver unwraps it on its way to the action.
|
|
|
|
# Known typed-ref envelope types, each mapped to the field that carries its
# canonical primitive value.
_TYPED_REF_PRIMARY_FIELD = {
    "FeatureInstanceRef": "id",
    "ConnectionRef": "id",
    "PromptTemplateRef": "id",
    "ClickUpListRef": "listId",
    "SharePointFileRef": "filePath",
    "SharePointFolderRef": "folderPath",
}


def _isTypedRefEnvelope(value: Any) -> bool:
    """Tell whether ``value`` is a dict whose ``"$type"`` names a known typed-ref type.

    Only types listed in ``_TYPED_REF_PRIMARY_FIELD`` qualify; any other
    input (non-dict, missing/non-string ``"$type"``, unknown type) gives False.
    """
    if isinstance(value, dict):
        declaredType = value.get("$type")
        if isinstance(declaredType, str):
            return declaredType in _TYPED_REF_PRIMARY_FIELD
    return False
|
|
|
|
|
|
def _unwrapTypedRef(value: Any) -> Any:
    """Reduce a typed-ref envelope to its primary primitive.

    For a recognised envelope the value stored under the type's primary field
    (see ``_TYPED_REF_PRIMARY_FIELD``) is returned. Anything that is not a
    recognised envelope - and an envelope that lacks its primary field - is
    handed back unchanged.
    """
    if _isTypedRefEnvelope(value):
        primaryField = _TYPED_REF_PRIMARY_FIELD[value["$type"]]
        if primaryField in value:
            return value[primaryField]
    return value
|
|
|
|
|
|
def resolveParameterReferences(value: Any, nodeOutputs: Dict[str, Any]) -> Any:
    """
    Resolve parameter references against the outputs of already-run nodes.

    Supported forms:
    - ``{{nodeId}}`` or ``{{nodeId.path.to.field}}`` placeholders inside
      strings (legacy). A placeholder that cannot be resolved (unknown node,
      missing key, non-numeric or out-of-range list index, ``None`` result)
      is left in the string verbatim.
    - ``{"type": "ref", "nodeId": "...", "path": ["field", "nested"]}`` ->
      the referenced value, itself resolved recursively. Transit envelopes
      (``_transit`` set) are unwrapped to their ``data`` first, and a path
      starting with ``"payload"`` is retried without that prefix if the full
      path resolves to nothing. A ref without ``nodeId`` or with a non-list
      ``path`` is returned unchanged.
    - ``{"type": "value", "value": ...}`` -> the inner value, resolved
      recursively.
    - ``{"type": "system", "variable": "..."}`` -> ``resolveSystemVariable``
      evaluated against ``nodeOutputs["_context"]``.
    - typed-ref envelopes (see ``_unwrapTypedRef``) -> their primary primitive.
    - any other dict or list -> resolved element-wise; a non-empty list made
      up solely of ``ref`` dicts is serialised and joined into one string.

    Values of any other type are returned as-is.
    """
    import json
    import re

    if isinstance(value, dict):
        # Phase-5 Schicht-4: typed-ref envelopes (FeatureInstanceRef etc.) on
        # disk get unwrapped to their canonical primitive (e.g. ``id``) so
        # legacy action signatures keep working. See ``_unwrapTypedRef``.
        if _isTypedRefEnvelope(value):
            return _unwrapTypedRef(value)
        if value.get("type") == "ref":
            node_id = value.get("nodeId")
            path = value.get("path")
            if node_id is not None and isinstance(path, (list, tuple)):
                data = nodeOutputs.get(node_id)
                # Unwrap transit envelopes to access the real data
                if isinstance(data, dict) and data.get("_transit"):
                    data = data.get("data", data)
                plist = list(path)
                resolved = _get_by_path(data, plist)
                if (
                    resolved is None
                    and isinstance(data, dict)
                    and plist
                    and plist[0] == "payload"
                    and len(plist) > 1
                ):
                    # Fallback: retry the same path without the leading "payload".
                    resolved = _get_by_path(data, plist[1:])
                return resolveParameterReferences(resolved, nodeOutputs)
            return value
        if value.get("type") == "value":
            inner = value.get("value")
            return resolveParameterReferences(inner, nodeOutputs)
        if value.get("type") == "system":
            variable = value.get("variable", "")
            from modules.features.graphicalEditor.portTypes import resolveSystemVariable
            return resolveSystemVariable(variable, nodeOutputs.get("_context", {}))
        return {k: resolveParameterReferences(v, nodeOutputs) for k, v in value.items()}

    if isinstance(value, str):
        def repl(m):
            """Replace one ``{{...}}`` match, or return it untouched if unresolvable."""
            parts = m.group(1).strip().split(".")
            data = nodeOutputs.get(parts[0])
            if data is None:
                return m.group(0)
            if len(parts) < 2:
                return json.dumps(data) if isinstance(data, (dict, list)) else str(data)
            for k in parts[1:]:
                if isinstance(data, dict) and k in data:
                    data = data[k]
                elif (
                    isinstance(data, (list, tuple))
                    # isdecimal() only admits what int() can parse; the bounds
                    # check keeps a too-large index from raising IndexError.
                    and k.isdecimal()
                    and int(k) < len(data)
                ):
                    data = data[int(k)]
                else:
                    return m.group(0)
            return str(data) if data is not None else m.group(0)

        return re.sub(r"\{\{\s*([^}]+)\s*\}\}", repl, value)

    if isinstance(value, list):
        # contextBuilder: list where every item is a `{"type":"ref",...}` envelope.
        # Resolve each ref and join the serialised parts into a single prompt string.
        if value and all(isinstance(v, dict) and v.get("type") == "ref" for v in value):
            from modules.workflows.methods.methodAi._common import serialize_context
            parts = [serialize_context(resolveParameterReferences(v, nodeOutputs)) for v in value]
            return "\n\n".join(p for p in parts if p)
        return [resolveParameterReferences(v, nodeOutputs) for v in value]
    return value
|