# Copyright (c) 2026 PowerOn AG
# All rights reserved.
"""Cascade-inherit semantics for DataSource flags (neutralize, ragIndexEnabled).

Three-state flags allow tree elements to either set an explicit value or
inherit the value from their nearest ancestor in the path hierarchy.

Modes:
  - 'walk' (default): resolves the *concrete* effective value per-item
    (never returns 'mixed'). Used by backend consumers (RAG walker,
    neutralization pipeline, scope filter, etc.).
  - 'aggregate': resolves the *display* effective value per-item. If the
    item has descendants with differing walk-effective values, returns
    'mixed'. Used by listing endpoints and PATCH responses for the UI.

Path-traversal rules:
  - A DataSource is identified by `(connectionId, sourceType, path)`.
  - The root of a service tree is `path == '/'`.
  - Sub-elements have paths like `/folder1/sub`. Their parent path is the
    longest prefix path that exists as a DataSource record (string-based).
  - If no ancestor with an explicit value exists, the default is `False` —
    matching the legacy behavior of NULL = inherit.

(scope was removed from DataSource in 2026-06 for privacy reasons.)
"""

import logging
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple

logger = logging.getLogger(__name__)

_INHERITABLE_FLAGS = ("neutralize", "ragIndexEnabled")
_INHERITABLE_FDS_FLAGS = ("neutralize", "ragIndexEnabled")

# Connection-root DataSources carry the authority as their sourceType
# (e.g. 'msft', 'google'). They sit one level above all service DataSources
# of the same connection in the visual tree, so flag inheritance must
# cross sourceType boundaries — but ONLY from these authority roots.
_AUTHORITY_SOURCE_TYPES = frozenset({"local", "google", "msft", "clickup", "infomaniak"})

Mode = Literal["walk", "aggregate"]


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

def normalisePath(path: Optional[str]) -> str:
    """Normalize a DataSource path to '/'-prefixed, no trailing slash (except root).

    Empty/None input and input consisting only of slashes both map to the
    root path '/'. Surrounding whitespace is stripped.
    """
    if not path:
        return "/"
    p = str(path).strip()
    if not p.startswith("/"):
        p = "/" + p
    if len(p) > 1 and p.endswith("/"):
        # An all-slash input such as '//' strips down to '' — that is the root.
        p = p.rstrip("/") or "/"
    return p


def _flagDefault(flag: str) -> Any:
    """Return the value used when neither the record nor any ancestor is explicit."""
    return False


def _isExplicit(value: Any) -> bool:
    """A flag value is explicit when it is not None/empty-string."""
    if value is None:
        return False
    if isinstance(value, str) and value == "":
        return False
    return True


def _getRecordValue(rec: Any, key: str) -> Any:
    """Read `key` from a record that is either a dict or an attribute object.

    Missing keys/attributes yield None.
    """
    if isinstance(rec, dict):
        return rec.get(key)
    return getattr(rec, key, None)


def _isAncestorPath(ancestor: str, descendant: str) -> bool:
    """True iff `ancestor` is a strict path-prefix of `descendant`.

    Both arguments are expected to be already normalised (see normalisePath).
    """
    if ancestor == descendant:
        return False
    if ancestor == "/":
        return descendant != "/"
    # Append '/' so that '/foo' is not treated as an ancestor of '/foobar'.
    return descendant.startswith(ancestor + "/")


def _pathDepth(path: str) -> int:
    """Return the nesting depth of a normalised path ('/' -> 0, '/a/b' -> 2)."""
    if path == "/":
        return 0
    return path.count("/")


def _findAncestorChain(
    rec: Dict[str, Any],
    allDs: Iterable[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Return all ancestor DataSources of `rec` in the same connection, ordered nearest-first.

    Two ancestor relations are merged:
      1) same-sourceType path-ancestor — strict path-prefix within the same
         service tree.
      2) connection-root ancestor — a DS with `path='/'` and `sourceType` in
         the authority set is the parent of every other DS in that connection
         regardless of sourceType.

    The connection-root is always the most distant ancestor. If several
    candidates qualify as connection-root, the last one encountered wins.
    """
    recId = _getRecordValue(rec, "id")
    recPath = normalisePath(_getRecordValue(rec, "path"))
    recSourceType = _getRecordValue(rec, "sourceType")
    recConnectionId = _getRecordValue(rec, "connectionId")
    # A connection root has no cross-sourceType parent of its own.
    recIsConnectionRoot = recSourceType in _AUTHORITY_SOURCE_TYPES and recPath == "/"

    sameTypeCandidates: List[Tuple[int, Dict[str, Any]]] = []
    connectionRoot: Optional[Dict[str, Any]] = None

    for cand in allDs:
        if _getRecordValue(cand, "id") == recId:
            continue
        if _getRecordValue(cand, "connectionId") != recConnectionId:
            continue
        candSourceType = _getRecordValue(cand, "sourceType")
        candPath = normalisePath(_getRecordValue(cand, "path"))
        if candSourceType == recSourceType:
            # _isAncestorPath is strict, so an equal path is rejected as well.
            if _isAncestorPath(candPath, recPath):
                sameTypeCandidates.append((len(candPath), cand))
        elif (
            not recIsConnectionRoot
            and candSourceType in _AUTHORITY_SOURCE_TYPES
            and candPath == "/"
        ):
            connectionRoot = cand

    # Longest path first == nearest ancestor first.
    sameTypeCandidates.sort(key=lambda x: x[0], reverse=True)
    chain = [c for _, c in sameTypeCandidates]
    if connectionRoot is not None:
        chain.append(connectionRoot)
    return chain


def _isDescendantDs(parentRec: Dict[str, Any], candidate: Dict[str, Any]) -> bool:
    """True iff `candidate` is a descendant of `parentRec` in the DS hierarchy.

    A connection root (authority sourceType at path '/') is the parent of
    every other record of the same connection; otherwise the candidate must
    share the sourceType and live strictly below the parent path.
    """
    parentSourceType = _getRecordValue(parentRec, "sourceType")
    parentPath = normalisePath(_getRecordValue(parentRec, "path"))
    parentConnectionId = _getRecordValue(parentRec, "connectionId")

    if _getRecordValue(candidate, "id") == _getRecordValue(parentRec, "id"):
        return False
    if _getRecordValue(candidate, "connectionId") != parentConnectionId:
        return False

    parentIsConnectionRoot = (
        parentSourceType in _AUTHORITY_SOURCE_TYPES and parentPath == "/"
    )
    if parentIsConnectionRoot:
        return True

    if _getRecordValue(candidate, "sourceType") != parentSourceType:
        return False
    candPath = normalisePath(_getRecordValue(candidate, "path"))
    return _isAncestorPath(parentPath, candPath)


# ---------------------------------------------------------------------------
# DataSource: getEffectiveFlag
# ---------------------------------------------------------------------------

def getEffectiveFlag(
    rec: Dict[str, Any],
    flag: str,
    sameConnectionDs: Iterable[Dict[str, Any]],
    mode: Mode = "walk",
) -> Any:
    """Resolve the effective value of a flag via path-traversal.

    mode='walk': own explicit → nearest ancestor explicit → default.
                 Always returns a concrete value (never 'mixed').
    mode='aggregate': same as walk for leaf value, but if the item has
                      descendants whose walk-effective value differs from
                      the item's own walk-effective value, returns 'mixed'.

    Any mode other than 'walk' is treated as 'aggregate'.

    Raises:
        ValueError: if `flag` is not one of the inheritable DataSource flags.
    """
    if flag not in _INHERITABLE_FLAGS:
        raise ValueError(f"Unknown inheritable flag: {flag}")

    allDs = list(sameConnectionDs)
    walkValue = _resolveWalkValue(rec, flag, allDs)

    if mode == "walk":
        return walkValue

    # mode == 'aggregate': the subtree is heterogeneous as soon as a single
    # descendant resolves to something else than the item itself, so stop at
    # the first divergence and report exactly that descendant.
    baseline = _normaliseForComparison(walkValue)
    for desc in allDs:
        if not _isDescendantDs(rec, desc):
            continue
        descEffective = _resolveWalkValue(desc, flag, allDs)
        descNormalised = _normaliseForComparison(descEffective)
        if descNormalised != baseline:
            logger.info(
                "DS aggregate MIXED for rec=%s flag=%s: walkValue=%s, "
                "divergent desc=%s (own=%s, effective=%s), subtreeValues=%s",
                _getRecordValue(rec, "id"), flag, walkValue,
                _getRecordValue(desc, "id"), _getRecordValue(desc, flag),
                descEffective, {baseline, descNormalised},
            )
            return "mixed"
    return walkValue


def _resolveWalkValue(rec: Dict[str, Any], flag: str, allDs: List[Dict[str, Any]]) -> Any:
    """Core walk resolution: own explicit → ancestor chain → default."""
    own = _getRecordValue(rec, flag)
    if _isExplicit(own):
        return own

    for ancestor in _findAncestorChain(rec, allDs):
        ancestorVal = _getRecordValue(ancestor, flag)
        if _isExplicit(ancestorVal):
            return ancestorVal

    return _flagDefault(flag)


def _normaliseForComparison(value: Any) -> Any:
    """Normalize values for comparison (bools are mapped to their int value)."""
    if isinstance(value, bool):
        return int(value)
    return value


# ---------------------------------------------------------------------------
# DataSource: cascadeResetDescendants (bottom-up)
# ---------------------------------------------------------------------------

def cascadeResetDescendants(
    rootIf: Any,
    parentRec: Dict[str, Any],
    flag: str,
) -> List[str]:
    """Reset all explicit descendant values of `flag` to NULL (= inherit).

    Reset order: bottom-up (deepest first) for crash safety.
    The parent itself is NOT modified here — the caller sets the master value
    after this function returns.

    A failing write for one descendant is logged as a warning and does not
    abort the remaining resets (best-effort).

    Returns list of reset record IDs in bottom-up order.

    Raises:
        ValueError: if `flag` is not one of the inheritable DataSource flags.
    """
    if flag not in _INHERITABLE_FLAGS:
        raise ValueError(f"Unknown inheritable flag: {flag}")

    from modules.datamodels.datamodelDataSource import DataSource

    connectionId = _getRecordValue(parentRec, "connectionId")
    parentId = _getRecordValue(parentRec, "id")
    if not connectionId:
        return []

    siblings = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})

    toReset: List[Tuple[int, str]] = []
    for sib in siblings:
        if not _isDescendantDs(parentRec, sib):
            continue
        if not _isExplicit(_getRecordValue(sib, flag)):
            continue
        sibPath = normalisePath(_getRecordValue(sib, "path"))
        toReset.append((_pathDepth(sibPath), _getRecordValue(sib, "id")))

    # Sort deepest first (bottom-up)
    toReset.sort(key=lambda x: x[0], reverse=True)

    resetIds: List[str] = []
    for _, sibId in toReset:
        try:
            rootIf.db.recordModify(DataSource, sibId, {flag: None})
            resetIds.append(sibId)
        except Exception as exc:
            # Deliberate best-effort: keep resetting the remaining descendants.
            logger.warning("Cascade-reset failed for DataSource %s flag=%s: %s", sibId, flag, exc)

    if resetIds:
        logger.info(
            "Cascade-reset %s on %d descendants of DataSource %s (bottom-up)",
            flag, len(resetIds), parentId,
        )
    return resetIds


# ---------------------------------------------------------------------------
# DataSource: collectAncestorChain (for updatedAncestors in PATCH response)
# ---------------------------------------------------------------------------

def collectAncestorChain(
    rec: Dict[str, Any],
    sameConnectionDs: Iterable[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Return ancestor chain of `rec` (nearest-first), same as internal helper.

    Exposed for PATCH endpoints to compute updatedAncestors.
    """
    return _findAncestorChain(rec, sameConnectionDs)


# ---------------------------------------------------------------------------
# DataSource: buildEffectiveByConnection
# ---------------------------------------------------------------------------

def buildEffectiveByConnection(
    dataSources: Iterable[Dict[str, Any]],
    flag: str,
    mode: Mode = "walk",
) -> Dict[str, Any]:
    """Pre-compute the effective value of `flag` for every DataSource id.

    Uses the specified mode. Every record is resolved against the full list,
    so the cost is at least quadratic in the number of records ('aggregate'
    additionally resolves each descendant); N is bounded per connection.

    Raises:
        ValueError: if `flag` is not one of the inheritable DataSource flags.
    """
    if flag not in _INHERITABLE_FLAGS:
        raise ValueError(f"Unknown inheritable flag: {flag}")
    allDs = list(dataSources)
    return {
        _getRecordValue(rec, "id"): getEffectiveFlag(rec, flag, allDs, mode=mode)
        for rec in allDs
    }


# ---------------------------------------------------------------------------
# FeatureDataSource helpers
# ---------------------------------------------------------------------------

def _fdsClassify(fds: Dict[str, Any]) -> str:
    """Return 'workspace' | 'table' | 'record' based on the FDS identifier shape."""
    tableName = _getRecordValue(fds, "tableName") or ""
    recordFilter = _getRecordValue(fds, "recordFilter")
    if tableName == "*":
        return "workspace"
    if not recordFilter:
        return "table"
    return "record"


def _fdsIsAncestor(parent: Dict[str, Any], child: Dict[str, Any]) -> bool:
    """Return True iff `parent` FDS is a strict ancestor of `child` FDS.

    Hierarchy within one featureInstanceId (allFds is already scoped to a
    single workspace):
      feature-wildcard (tableName='*')   -> table-wildcard / record-fds
      table-wildcard   (tableName='X')   -> record-fds (tableName='X')
    """
    parentFiId = _getRecordValue(parent, "featureInstanceId")
    childFiId = _getRecordValue(child, "featureInstanceId")
    if not parentFiId or parentFiId != childFiId:
        return False
    if _getRecordValue(parent, "id") == _getRecordValue(child, "id"):
        return False

    parentKind = _fdsClassify(parent)
    childKind = _fdsClassify(child)

    if parentKind == "workspace":
        return childKind in ("table", "record")
    if parentKind == "table":
        if childKind != "record":
            return False
        return _getRecordValue(parent, "tableName") == _getRecordValue(child, "tableName")
    # A record-level FDS has no descendants.
    return False


def _fdsDepth(fds: Dict[str, Any]) -> int:
    """Return the hierarchy depth of an FDS: workspace=0, table=1, record=2."""
    kind = _fdsClassify(fds)
    if kind == "workspace":
        return 0
    if kind == "table":
        return 1
    return 2


def _findAncestorChainFds(
    rec: Dict[str, Any],
    allFds: Iterable[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Return the FDS ancestors of `rec`, table-level before workspace-level."""
    ancestors = [a for a in allFds if _fdsIsAncestor(a, rec)]
    # Stable sort: table-wildcards (nearest) first, everything else after.
    ancestors.sort(key=lambda a: 0 if _fdsClassify(a) == "table" else 1)
    return ancestors


# ---------------------------------------------------------------------------
# FeatureDataSource: getEffectiveFlagFds
# ---------------------------------------------------------------------------

def getEffectiveFlagFds(
    rec: Dict[str, Any],
    flag: str,
    sameWorkspaceFds: Iterable[Dict[str, Any]],
    mode: Mode = "walk",
) -> Any:
    """Resolve effective value of a FeatureDataSource flag.

    mode='walk': own explicit -> table-wildcard -> workspace-wildcard -> default.
    mode='aggregate': same but returns 'mixed' if a descendant's walk value
                      differs from the item's own walk value.

    Any mode other than 'walk' is treated as 'aggregate'.

    Raises:
        ValueError: if `flag` is not one of the inheritable FDS flags.
    """
    if flag not in _INHERITABLE_FDS_FLAGS:
        raise ValueError(f"Unknown inheritable FDS flag: {flag}")

    allFds = list(sameWorkspaceFds)
    walkValue = _resolveWalkValueFds(rec, flag, allFds)

    if mode == "walk":
        return walkValue

    # mode == 'aggregate': stop at the first descendant that diverges and
    # report exactly that descendant.
    baseline = _normaliseForComparison(walkValue)
    for desc in allFds:
        if not _fdsIsAncestor(rec, desc):
            continue
        descEffective = _resolveWalkValueFds(desc, flag, allFds)
        descNormalised = _normaliseForComparison(descEffective)
        if descNormalised != baseline:
            logger.info(
                "FDS aggregate MIXED for rec=%s flag=%s: walkValue=%s, "
                "divergent desc=%s (own=%s, effective=%s), subtreeValues=%s",
                _getRecordValue(rec, "id"), flag, walkValue,
                _getRecordValue(desc, "id"), _getRecordValue(desc, flag),
                descEffective, {baseline, descNormalised},
            )
            return "mixed"
    return walkValue


def _resolveWalkValueFds(rec: Dict[str, Any], flag: str, allFds: List[Dict[str, Any]]) -> Any:
    """Core walk resolution for FDS: own explicit → ancestors → default."""
    own = _getRecordValue(rec, flag)
    if _isExplicit(own):
        return own

    for ancestor in _findAncestorChainFds(rec, allFds):
        val = _getRecordValue(ancestor, flag)
        if _isExplicit(val):
            return val

    return _flagDefault(flag)


# ---------------------------------------------------------------------------
# FeatureDataSource: cascadeResetDescendantsFds (bottom-up)
# ---------------------------------------------------------------------------

def cascadeResetDescendantsFds(
    rootIf: Any,
    parentRec: Dict[str, Any],
    flag: str,
) -> List[str]:
    """Reset explicit `flag` to NULL on every descendant FDS of `parentRec`.

    Reset order: bottom-up (deepest first) for crash safety.
    A failing write for one descendant is logged as a warning and does not
    abort the remaining resets (best-effort).

    Returns list of reset record IDs in bottom-up order.

    Raises:
        ValueError: if `flag` is not one of the inheritable FDS flags.
    """
    if flag not in _INHERITABLE_FDS_FLAGS:
        raise ValueError(f"Unknown inheritable FDS flag: {flag}")

    from modules.datamodels.datamodelFeatures import FeatureDataSource

    featureInstanceId = _getRecordValue(parentRec, "featureInstanceId")
    if not featureInstanceId:
        return []

    siblings = rootIf.db.getRecordset(
        FeatureDataSource, recordFilter={"featureInstanceId": featureInstanceId}
    )

    toReset: List[Tuple[int, str]] = []
    for sib in siblings:
        if not _fdsIsAncestor(parentRec, sib):
            continue
        if not _isExplicit(_getRecordValue(sib, flag)):
            continue
        toReset.append((_fdsDepth(sib), _getRecordValue(sib, "id")))

    # Deepest first (bottom-up).
    toReset.sort(key=lambda x: x[0], reverse=True)

    resetIds: List[str] = []
    for _, sibId in toReset:
        try:
            rootIf.db.recordModify(FeatureDataSource, sibId, {flag: None})
            resetIds.append(sibId)
        except Exception as exc:
            # Deliberate best-effort: keep resetting the remaining descendants.
            logger.warning("FDS cascade-reset failed for %s flag=%s: %s", sibId, flag, exc)

    if resetIds:
        logger.info(
            "FDS cascade-reset %s on %d descendants of FDS %s (bottom-up)",
            flag, len(resetIds), _getRecordValue(parentRec, "id"),
        )
    return resetIds


# ---------------------------------------------------------------------------
# FeatureDataSource: collectAncestorChainFds
# ---------------------------------------------------------------------------

def collectAncestorChainFds(
    rec: Dict[str, Any],
    sameWorkspaceFds: Iterable[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Return ancestor chain of `rec` FDS (nearest-first).

    Exposed for PATCH endpoints to compute updatedAncestors.
    """
    return _findAncestorChainFds(rec, list(sameWorkspaceFds))


# ---------------------------------------------------------------------------
# FeatureDataSource: buildEffectiveByWorkspaceFds
# ---------------------------------------------------------------------------

def buildEffectiveByWorkspaceFds(
    fdses: Iterable[Dict[str, Any]],
    flag: str,
    mode: Mode = "walk",
) -> Dict[str, Any]:
    """Pre-compute the effective value of `flag` for every FDS id.

    Raises:
        ValueError: if `flag` is not one of the inheritable FDS flags.
    """
    if flag not in _INHERITABLE_FDS_FLAGS:
        raise ValueError(f"Unknown inheritable FDS flag: {flag}")
    allFds = list(fdses)
    return {
        _getRecordValue(rec, "id"): getEffectiveFlagFds(rec, flag, allFds, mode=mode)
        for rec in allFds
    }


# ---------------------------------------------------------------------------
# Bulk resolve: effective flags for arbitrary paths (even without DB record)
# ---------------------------------------------------------------------------

def resolveEffectiveForPath(
    connectionId: str,
    sourceType: str,
    path: str,
    allDs: List[Dict[str, Any]],
    mode: Mode = "aggregate",
) -> Dict[str, Any]:
    """Resolve effective flags for ANY (connectionId, sourceType, path) tuple.

    Works whether or not a DataSource record exists for this exact path: if
    none exists, a virtual record without explicit values is resolved in its
    place, so it inherits from whatever ancestors are present in `allDs`.

    Returns dict with effectiveNeutralize, effectiveRagIndexEnabled.
    (effectiveScope removed 2026-06 — personal sources have no scope.)
    """
    normPath = normalisePath(path)

    target: Any = None
    for ds in allDs:
        if (
            _getRecordValue(ds, "connectionId") == connectionId
            and _getRecordValue(ds, "sourceType") == sourceType
            and normalisePath(_getRecordValue(ds, "path")) == normPath
        ):
            target = ds
            break

    if target is None:
        target = {
            "id": "__virtual__",
            "connectionId": connectionId,
            "sourceType": sourceType,
            "path": normPath,
            "neutralize": None,
            "ragIndexEnabled": None,
        }

    return {
        "effectiveNeutralize": getEffectiveFlag(target, "neutralize", allDs, mode=mode),
        "effectiveRagIndexEnabled": getEffectiveFlag(target, "ragIndexEnabled", allDs, mode=mode),
    }


def resolveEffectiveForFds(
    featureInstanceId: str,
    tableName: str,
    recordFilter: Optional[Dict[str, str]],
    allFds: List[Dict[str, Any]],
    mode: Mode = "aggregate",
) -> Dict[str, Any]:
    """Resolve effective flags for ANY FDS tuple (even without DB record).

    `allFds` is pre-scoped (typically to a mandate). Within that set, the
    coordinate is featureInstanceId + tableName + recordFilter. An empty
    recordFilter and a missing one (None) denote the same coordinate, in
    line with how the FDS kind is classified.

    Returns dict with effectiveNeutralize, effectiveRagIndexEnabled.
    FDS has no `scope` attribute (visibility is governed by feature RBAC).
    """
    # Treat {} and None alike so a stored table-wildcard is found either way.
    wantedFilter = recordFilter or None

    target: Any = None
    for fds in allFds:
        if _getRecordValue(fds, "featureInstanceId") != featureInstanceId:
            continue
        if (_getRecordValue(fds, "tableName") or "") != tableName:
            continue
        if (_getRecordValue(fds, "recordFilter") or None) == wantedFilter:
            target = fds
            break

    if target is None:
        target = {
            "id": "__virtual__",
            "featureInstanceId": featureInstanceId,
            "tableName": tableName,
            "recordFilter": recordFilter,
            "neutralize": None,
            "ragIndexEnabled": None,
        }

    return {
        "effectiveNeutralize": getEffectiveFlagFds(target, "neutralize", allFds, mode=mode),
        "effectiveRagIndexEnabled": getEffectiveFlagFds(target, "ragIndexEnabled", allFds, mode=mode),
    }