# Copyright (c) 2026 Patrick Motsch
# All rights reserved.
|
|
"""Action ``context.mergeContext``.

Reads ``_branchInputs`` (injected by ``ActionNodeExecutor`` because the node
declaration sets ``injectBranchInputs: True``) and combines them according to
the selected strategy.

The barrier behaviour — waiting until every connected predecessor has produced
output — is handled by the execution engine via ``waitsForAllPredecessors`` on
the node definition; this action is invoked only after all inputs (or the
``waitFor`` subset of them) are present.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import copy
|
|
import logging
|
|
from typing import Any, Dict, List, Tuple
|
|
|
|
from modules.datamodels.datamodelChat import ActionResult
|
|
|
|
logger = logging.getLogger(__name__)

# Accepted values for the ``strategy`` parameter of ``mergeContext``:
#   shallow         - top-level key merge, later branches overwrite earlier ones
#   deep            - recursive merge of nested dicts
#   firstWins /
#   lastWins        - top-level merge with positional clash resolution
#                     (see _strategy_first_or_last_wins)
#   errorOnConflict - top-level merge that fails on any clashing key
_VALID_STRATEGIES = {
    "shallow",
    "deep",
    "firstWins",
    "lastWins",
    "errorOnConflict",
}
|
|
|
|
|
|
def _shallow_merge(branches: List[Tuple[int, Any]]) -> Tuple[Dict[str, Any], List[str]]:
    """Merge the top-level keys of dict-valued branches; later branches overwrite.

    Branch values that are not dicts are ignored.

    Returns:
        ``(merged, conflicts)``, where ``conflicts`` receives a key each time an
        incoming value differs from the value already stored for it. Keys may
        therefore appear more than once.
    """
    result: Dict[str, Any] = {}
    clashes: List[str] = []
    dict_payloads = (payload for _idx, payload in branches if isinstance(payload, dict))
    for payload in dict_payloads:
        for key, value in payload.items():
            already_present = key in result
            if already_present and result[key] != value:
                clashes.append(key)
            # Unconditional overwrite: the latest branch always wins.
            result[key] = value
    return result, clashes
|
|
|
|
|
|
def _deep_merge(target: Dict[str, Any], source: Dict[str, Any], conflicts: List[str], path: str = "") -> None:
    """Recursively merge ``source`` into ``target`` in place.

    When both sides hold a dict under the same key, the two are merged key by
    key. Any other clash is resolved in favour of ``source``. If the two values
    compare unequal, the dotted key path (``"outer.inner"``, prefixed by
    ``path`` when given) is appended to ``conflicts``. Dict and list values
    taken from ``source`` are deep-copied, so ``target`` never aliases them.
    """
    def _detached(value: Any) -> Any:
        # Only containers need copying; scalars are stored as-is.
        return copy.deepcopy(value) if isinstance(value, (dict, list)) else value

    for key, incoming in source.items():
        dotted = key if not path else f"{path}.{key}"
        if key in target:
            current = target[key]
            if isinstance(current, dict) and isinstance(incoming, dict):
                _deep_merge(current, incoming, conflicts, dotted)
                continue
            if current != incoming:
                conflicts.append(dotted)
        target[key] = _detached(incoming)
|
|
|
|
|
|
def _strategy_first_or_last_wins(
    branches: List[Tuple[int, Any]], last: bool
) -> Tuple[Dict[str, Any], List[str]]:
    """Shallow-merge dict-valued branches, resolving key clashes by position.

    Branches are visited in the order given; ``mergeContext`` passes them
    sorted by branch index. Branch values that are not dicts are skipped.

    Args:
        branches: ``(index, value)`` pairs in merge order.
        last: ``True`` keeps the value from the latest branch (``lastWins``);
            ``False`` keeps the value from the earliest branch (``firstWins``).

    Returns:
        ``(merged, conflicts)``, where ``conflicts`` gets a key every time an
        incoming value differs from the currently kept one (duplicates are
        possible; the caller de-duplicates).
    """
    merged: Dict[str, Any] = {}
    conflicts: List[str] = []
    # Always walk forward. The previous version reversed the branches for
    # firstWins *and* kept the first value seen. Those two inversions cancelled
    # out, so the last branch won and the key order of `merged` was reversed.
    for _, val in branches:
        if not isinstance(val, dict):
            continue
        for k, v in val.items():
            if k in merged and merged[k] != v:
                conflicts.append(k)
            if last or k not in merged:
                merged[k] = v
    return merged, conflicts
|
|
|
|
|
|
async def mergeContext(self, parameters: Dict[str, Any]) -> ActionResult:
    """Merge the outputs of the connected upstream branches.

    Recognised keys in ``parameters``:
        strategy: one of ``_VALID_STRATEGIES``; falls back to ``"deep"`` when
            missing or falsy.
        waitFor: when greater than 0, only the ``waitFor`` branch inputs with
            the lowest indices are merged; 0 or missing means "all".
        _branchInputs: dict mapping branch index (anything ``int()`` accepts)
            to that branch's output. Only dict-valued outputs contribute to
            ``merged``.

    Returns:
        ``ActionResult.isSuccess`` whose ``data`` holds ``inputs``
        (index -> value for the selected branches), ``first`` (value of the
        lowest index), ``merged``, ``strategy`` and ``conflicts`` (sorted,
        de-duplicated key names / dotted paths).

        ``ActionResult.isFailure`` is returned in these cases:

        * an unknown strategy;
        * a non-integer ``waitFor`` or branch index;
        * missing inputs;
        * clashes under ``errorOnConflict``;
        * any unexpected exception, which is also logged.
    """
    try:
        strategy = str(parameters.get("strategy") or "deep")
        if strategy not in _VALID_STRATEGIES:
            return ActionResult.isFailure(
                error=f"Invalid strategy '{strategy}', expected one of {sorted(_VALID_STRATEGIES)}"
            )

        raw_wait_for = parameters.get("waitFor") or 0
        try:
            wait_for = int(raw_wait_for)
        except (TypeError, ValueError):
            # Bad configuration is reported as a normal failure instead of
            # surfacing a raw int() message with a logged traceback.
            return ActionResult.isFailure(
                error=f"Invalid waitFor '{raw_wait_for}', expected an integer"
            )

        raw_inputs = parameters.get("_branchInputs") or {}
        if not isinstance(raw_inputs, dict):
            return ActionResult.isFailure(error="No branch inputs available — connect at least two upstream nodes")

        try:
            indexed = [(int(key), value) for key, value in raw_inputs.items()]
        except (TypeError, ValueError):
            return ActionResult.isFailure(
                error=f"Invalid branch input keys {list(raw_inputs)}, expected integer indices"
            )

        # Numeric branch order drives every order-sensitive strategy below.
        items: List[Tuple[int, Any]] = sorted(indexed, key=lambda kv: kv[0])
        if wait_for > 0:
            items = items[:wait_for]

        if not items:
            return ActionResult.isFailure(error="No branch inputs available")

        first_value = items[0][1]
        conflicts: List[str] = []

        if strategy == "shallow":
            merged, conflicts = _shallow_merge(items)
        elif strategy == "firstWins":
            merged, conflicts = _strategy_first_or_last_wins(items, last=False)
        elif strategy == "lastWins":
            merged, conflicts = _strategy_first_or_last_wins(items, last=True)
        elif strategy == "errorOnConflict":
            merged, conflicts = _shallow_merge(items)
            if conflicts:
                return ActionResult.isFailure(
                    error=f"Conflicting keys: {sorted(set(conflicts))}",
                )
        else:  # deep (default)
            merged = {}
            for _, val in items:
                if isinstance(val, dict):
                    _deep_merge(merged, val, conflicts)

        data: Dict[str, Any] = {
            "inputs": {idx: val for idx, val in items},
            "first": first_value,
            "merged": merged,
            "strategy": strategy,
            # sorted(set([])) is already [], so no special case is needed.
            "conflicts": sorted(set(conflicts)),
        }
        return ActionResult.isSuccess(data=data)
    except Exception as exc:
        # Top-level action boundary: never let an error escape to the engine.
        logger.exception("mergeContext failed")
        return ActionResult.isFailure(error=str(exc))
|