242 lines
8.4 KiB
Python
242 lines
8.4 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
"""Compute pickable upstream paths for DataPicker / AI workflow tools."""
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Dict, List, Set
|
|
|
|
from modules.features.graphicalEditor.nodeDefinitions import STATIC_NODE_TYPES
|
|
from modules.features.graphicalEditor.portTypes import PORT_TYPE_CATALOG, PortSchema, parse_graph_defined_output_schema
|
|
from modules.workflows.automation2.graphUtils import buildConnectionMap, getLoopBodyNodeIds, getLoopDoneNodeIds
|
|
|
|
# Index static node definitions by their type id (e.g. "flow.loop") for O(1) lookup.
_NODE_BY_TYPE = {n["id"]: n for n in STATIC_NODE_TYPES}
|
|
|
|
|
|
def _paths_for_port_schema(schema: PortSchema, producer_node_id: str) -> List[Dict[str, Any]]:
|
|
out: List[Dict[str, Any]] = []
|
|
for field in schema.fields:
|
|
path = [field.name]
|
|
out.append(
|
|
{
|
|
"producerNodeId": producer_node_id,
|
|
"path": path,
|
|
"type": field.type,
|
|
"label": ".".join(str(p) for p in path),
|
|
"scopeOrigin": "data",
|
|
}
|
|
)
|
|
out.append(
|
|
{
|
|
"producerNodeId": producer_node_id,
|
|
"path": [],
|
|
"type": schema.name,
|
|
"label": "(whole output)",
|
|
"scopeOrigin": "data",
|
|
}
|
|
)
|
|
return out
|
|
|
|
|
|
def _paths_for_data_pick_options(
|
|
options: List[Dict[str, Any]],
|
|
producer_node_id: str,
|
|
) -> List[Dict[str, Any]]:
|
|
"""Explicit per-port pick list from node definition (authoritative; no catalog expansion)."""
|
|
out: List[Dict[str, Any]] = []
|
|
for o in options:
|
|
if not isinstance(o, dict):
|
|
continue
|
|
path = o.get("path")
|
|
if not isinstance(path, list):
|
|
continue
|
|
label = o.get("pickerLabel")
|
|
out.append(
|
|
{
|
|
"producerNodeId": producer_node_id,
|
|
"path": path,
|
|
"type": o.get("type") or "Any",
|
|
"label": label if isinstance(label, str) else ".".join(str(p) for p in path),
|
|
"scopeOrigin": "data",
|
|
}
|
|
)
|
|
return out
|
|
|
|
|
|
def _paths_for_schema(schema_name: str, producer_node_id: str) -> List[Dict[str, Any]]:
|
|
if not schema_name or schema_name == "Transit":
|
|
return []
|
|
schema = PORT_TYPE_CATALOG.get(schema_name)
|
|
if not schema:
|
|
return []
|
|
return _paths_for_port_schema(schema, producer_node_id)
|
|
|
|
|
|
def compute_upstream_paths(graph: Dict[str, Any], target_node_id: str) -> List[Dict[str, Any]]:
    """Return flattened first-level paths for every ancestor node's primary output schema.

    For each ancestor of *target_node_id* the primary output port (index 0) of
    its static node definition is inspected, in order of precedence:

    1. an explicit ``dataPickOptions`` list on the port (authoritative),
    2. a graph-defined output schema parsed from the node instance,
    3. the port's named catalog schema (falling back to ``"ActionResult"``).

    If the target sits inside the body of an ancestor ``flow.loop``, the
    lexical loop variables (currentItem, currentIndex, count) are appended.

    Returns an empty list when *target_node_id* is not present in the graph.
    """
    nodes = graph.get("nodes") or []
    connections = graph.get("connections") or []
    node_by_id = {n["id"]: n for n in nodes if n.get("id")}
    if target_node_id not in node_by_id:
        return []

    conn_map = buildConnectionMap(connections)

    # Invert the connection map so we can walk backwards (target -> sources).
    preds: Dict[str, Set[str]] = {}
    for tgt, pairs in conn_map.items():
        for src, _, _ in pairs:
            preds.setdefault(tgt, set()).add(src)

    # Backward DFS. `ancestors` doubles as the visited set: in the original
    # code `seen` and `ancestors` always received identical additions, so the
    # separate `seen` set was redundant.
    ancestors: Set[str] = set()
    stack = [target_node_id]
    while stack:
        cur = stack.pop()
        for p in preds.get(cur, ()):
            if p not in ancestors:
                ancestors.add(p)
                stack.append(p)

    paths: List[Dict[str, Any]] = []
    for aid in sorted(ancestors):
        anode = node_by_id.get(aid)
        if not anode:
            continue
        ndef = _NODE_BY_TYPE.get(anode.get("type", ""))
        if not ndef:
            # Node type has no static definition; nothing pickable.
            continue
        out0 = (ndef.get("outputPorts") or {}).get(0, {})
        out0 = out0 if isinstance(out0, dict) else {}
        # Human-readable producer label: node title, falling back to its id.
        producer_label = (anode.get("title") or "").strip() or aid

        dpo = out0.get("dataPickOptions")
        if isinstance(dpo, list) and dpo:
            # Explicit pick list on the port is authoritative.
            entries = _paths_for_data_pick_options(dpo, aid)
        else:
            derived = parse_graph_defined_output_schema(anode, out0)
            if derived:
                entries = _paths_for_port_schema(derived, aid)
            else:
                # out0 is guaranteed a dict here (coerced above).
                raw_schema = out0.get("schema")
                schema_name = raw_schema if isinstance(raw_schema, str) and raw_schema else "ActionResult"
                entries = _paths_for_schema(schema_name, aid)
        for entry in entries:
            entry["producerLabel"] = producer_label
            paths.append(entry)

    # Lexical loop hints (flow.loop): only for nodes inside the loop body.
    loop_vars = (("currentItem", "Any"), ("currentIndex", "int"), ("count", "int"))
    for aid in ancestors:
        anode = node_by_id.get(aid) or {}
        if anode.get("type") != "flow.loop":
            continue
        if target_node_id in getLoopBodyNodeIds(aid, conn_map):
            for var_name, var_type in loop_vars:
                paths.append(
                    {
                        "producerNodeId": aid,
                        "path": [var_name],
                        "type": var_type,
                        "label": f"loop.{var_name}",
                        "scopeOrigin": "loop",
                    }
                )

    return paths
|
|
|
|
|
|
def compute_graph_data_sources(graph: Dict[str, Any], target_node_id: str) -> Dict[str, Any]:
    """Return scope-aware data sources for the DataPicker.

    Determines which ancestor nodes are valid sources for ``target_node_id``,
    taking loop scoping into account:

    - If ``target_node_id`` is on the *Done* branch of a ``flow.loop``, the
      loop body nodes are excluded from ``availableSourceIds`` and the loop
      node itself is mapped to its *Fertig* output port (index 1) via
      ``portIndexOverrides``.
    - If ``target_node_id`` is *inside* the loop body, the loop node id is
      included in ``loopBodyContextIds`` so the frontend can show the lexical
      loop variables (currentItem, currentIndex, count).

    Returns::

        {
            "availableSourceIds": [...],        # ordered list
            "portIndexOverrides": {nodeId: n},  # non-zero port indices
            "loopBodyContextIds": [...],        # loops whose body this node is in
        }
    """
    nodes = graph.get("nodes") or []
    connections = graph.get("connections") or []
    node_by_id: Dict[str, Any] = {n["id"]: n for n in nodes if n.get("id")}

    if target_node_id not in node_by_id:
        return {"availableSourceIds": [], "portIndexOverrides": {}, "loopBodyContextIds": []}

    conn_map = buildConnectionMap(connections)

    # Invert the connection map so we can walk backwards (target -> sources).
    preds: Dict[str, Set[str]] = {}
    for tgt, pairs in conn_map.items():
        for src, _, _ in pairs:
            preds.setdefault(tgt, set()).add(src)

    # Backward DFS. `ancestors` doubles as the visited set: in the original
    # code `seen` and `ancestors` always received identical additions, so the
    # separate `seen` set was redundant.
    ancestors: Set[str] = set()
    stack = [target_node_id]
    while stack:
        cur = stack.pop()
        for p in preds.get(cur, ()):
            if p not in ancestors:
                ancestors.add(p)
                stack.append(p)

    body_nodes_to_exclude: Set[str] = set()
    port_index_overrides: Dict[str, int] = {}
    loop_body_context_ids: List[str] = []

    for aid in ancestors:
        anode = node_by_id.get(aid) or {}
        if anode.get("type") != "flow.loop":
            continue
        body_ids = getLoopBodyNodeIds(aid, conn_map)
        done_ids = getLoopDoneNodeIds(aid, conn_map)

        if target_node_id in body_ids:
            # Inside the body: expose this loop's lexical variables.
            loop_body_context_ids.append(aid)
        elif target_node_id in done_ids:
            # On the Done branch: body-local results are out of scope, and the
            # loop's aggregate output lives on port index 1 ("Fertig").
            body_nodes_to_exclude.update(body_ids)
            port_index_overrides[aid] = 1

    available_source_ids = [
        aid for aid in sorted(ancestors) if aid not in body_nodes_to_exclude
    ]

    return {
        "availableSourceIds": available_source_ids,
        "portIndexOverrides": port_index_overrides,
        "loopBodyContextIds": loop_body_context_ids,
    }
|