diff --git a/modules/connectors/connectorDbPostgre.py b/modules/connectors/connectorDbPostgre.py index f1a34f70..fa4cba44 100644 --- a/modules/connectors/connectorDbPostgre.py +++ b/modules/connectors/connectorDbPostgre.py @@ -172,7 +172,7 @@ def parseRecordFields(record: Dict[str, Any], fields: Dict[str, str], context: s pass # already a list elif fieldType == "BOOLEAN": - record[fieldName] = bool(value) if value is not None else False + record[fieldName] = bool(value) if value is not None else None elif fieldType == "JSONB" and value is not None: try: @@ -184,6 +184,18 @@ def parseRecordFields(record: Dict[str, Any], fields: Dict[str, str], context: s logger.warning(f"Could not parse JSONB field {fieldName}, keeping as string ({context})") +def _stripNulBytesFromStr(value: Any) -> Any: + """psycopg2 rejects bound parameters whose Python str contains NUL (0x00). + + Some extracted files (e.g. SQL dumps, mixed binary treated as text) still + carry those bytes; PostgreSQL TEXT could store them via other paths, but + the client protocol path used here cannot. + """ + if isinstance(value, str) and "\x00" in value: + return value.replace("\x00", "") + return value + + def _quotePgIdent(name: str) -> str: return '"' + str(name).replace('"', '""') + '"' @@ -983,7 +995,7 @@ class DatabaseConnector: else: value = json.dumps(value) - values.append(value) + values.append(_stripNulBytesFromStr(value)) # Build INSERT/UPDATE with quoted identifiers col_names = ", ".join([f'"{col}"' for col in columns]) diff --git a/modules/datamodels/datamodelFeatureDataSource.py b/modules/datamodels/datamodelFeatureDataSource.py index f07a8bda..10fd76a7 100644 --- a/modules/datamodels/datamodelFeatureDataSource.py +++ b/modules/datamodels/datamodelFeatureDataSource.py @@ -76,6 +76,14 @@ class FeatureDataSource(PowerOnModel): ), json_schema_extra={"label": "Neutralisieren", "frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False}, ) + ragIndexEnabled: Optional[bool] = Field( + default=None, + description=( + "Three-state RAG-indexing flag with cascade-inherit semantics. " + "None = inherit; True/False = explicit. Cascade-reset on parent toggle." + ), + json_schema_extra={"label": "RAG-Indexierung", "frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False}, + ) neutralizeFields: Optional[List[str]] = Field( default=None, description="Column names whose values are replaced with placeholders before AI processing", diff --git a/modules/demoConfigs/investorDemo2026.py b/modules/demoConfigs/investorDemo2026.py index f8fc678f..d807921d 100644 --- a/modules/demoConfigs/investorDemo2026.py +++ b/modules/demoConfigs/investorDemo2026.py @@ -124,6 +124,7 @@ class InvestorDemo2026(_BaseDemoConfig): from modules.datamodels.datamodelUam import Mandate, UserInDB from modules.datamodels.datamodelMembership import UserMandate + summary["_removedMandateIds"] = [] for mandateDef in [_MANDATE_HAPPYLIFE, _MANDATE_ALPINA]: try: existing = db.getRecordset(Mandate, recordFilter={"name": mandateDef["name"]}) @@ -132,28 +133,36 @@ class InvestorDemo2026(_BaseDemoConfig): self._removeMandateData(db, mid, mandateDef["label"], summary) db.recordDelete(Mandate, mid) summary["removed"].append(f"Mandate {mandateDef['label']} ({mid})") + summary["_removedMandateIds"].append({"id": mid, "mandateId": mid}) logger.info(f"Removed mandate {mandateDef['label']} ({mid})") except Exception as e: summary["errors"].append(f"Remove mandate {mandateDef['label']}: {e}") + # SAFETY: NEVER delete the user record. The user may have connections, + # chats, workflows, files, and other data across multiple databases. + # Only remove the mandate memberships that THIS demo created. try: existing = db.getRecordset(UserInDB, recordFilter={"username": _USER["username"]}) for u in existing: uid = u.get("id") + removedMandateIds = {m.get("mandateId") for m in summary.get("_removedMandateIds", [])} memberships = db.getRecordset(UserMandate, recordFilter={"userId": uid}) for mem in memberships: - try: - db.recordDelete(UserMandate, mem.get("id")) - except Exception: - pass - db.recordDelete(UserInDB, uid) - summary["removed"].append(f"User {_USER['username']} ({uid})") - logger.info(f"Removed user {_USER['username']} ({uid})") + if mem.get("mandateId") in removedMandateIds: + try: + db.recordDelete(UserMandate, mem.get("id")) + except Exception: + pass + summary["skipped"].append( + f"User {_USER['username']} ({uid}) preserved (only demo mandate memberships removed)" + ) + logger.info(f"Preserved user {_USER['username']} ({uid}) - removed demo mandate memberships only") except Exception as e: - summary["errors"].append(f"Remove user: {e}") + summary["errors"].append(f"Remove user memberships: {e}") self._removeLanguageSet(db, "es", summary) + summary.pop("_removedMandateIds", None) return summary # ------------------------------------------------------------------ diff --git a/modules/demoConfigs/pwgDemo2026.py b/modules/demoConfigs/pwgDemo2026.py index f0dc5e6d..4a6491a3 100644 --- a/modules/demoConfigs/pwgDemo2026.py +++ b/modules/demoConfigs/pwgDemo2026.py @@ -121,32 +121,39 @@ class PwgDemo2026(_BaseDemoConfig): from modules.datamodels.datamodelMembership import UserMandate from modules.datamodels.datamodelUam import Mandate, UserInDB + removedMandateIds = set() try: existing = db.getRecordset(Mandate, recordFilter={"name": _MANDATE_PWG["name"]}) for m in existing: mid = m.get("id") self._removeMandateData(db, mid, _MANDATE_PWG["label"], summary) db.recordDelete(Mandate, mid) + removedMandateIds.add(mid) summary["removed"].append(f"Mandate {_MANDATE_PWG['label']} ({mid})") logger.info(f"Removed mandate {_MANDATE_PWG['label']} ({mid})") except Exception as e: summary["errors"].append(f"Remove mandate {_MANDATE_PWG['label']}: {e}") + # SAFETY: NEVER delete the user record. The user may have connections, + # chats, workflows, files, and other data across multiple databases. + # Only remove the mandate memberships that THIS demo created. try: existing = db.getRecordset(UserInDB, recordFilter={"username": _USER["username"]}) for u in existing: uid = u.get("id") memberships = db.getRecordset(UserMandate, recordFilter={"userId": uid}) or [] for mem in memberships: - try: - db.recordDelete(UserMandate, mem.get("id")) - except Exception: - pass - db.recordDelete(UserInDB, uid) - summary["removed"].append(f"User {_USER['username']} ({uid})") - logger.info(f"Removed user {_USER['username']} ({uid})") + if mem.get("mandateId") in removedMandateIds: + try: + db.recordDelete(UserMandate, mem.get("id")) + except Exception: + pass + summary["skipped"].append( + f"User {_USER['username']} ({uid}) preserved (only demo mandate memberships removed)" + ) + logger.info(f"Preserved user {_USER['username']} ({uid}) - removed demo mandate memberships only") except Exception as e: - summary["errors"].append(f"Remove user: {e}") + summary["errors"].append(f"Remove user memberships: {e}") return summary diff --git a/modules/features/workspace/datamodelFeatureWorkspace.py b/modules/features/workspace/datamodelFeatureWorkspace.py index 4e32702c..d0ba8815 100644 --- a/modules/features/workspace/datamodelFeatureWorkspace.py +++ b/modules/features/workspace/datamodelFeatureWorkspace.py @@ -2,7 +2,7 @@ # All rights reserved. """Workspace feature data models — WorkspaceUserSettings.""" -from typing import List, Optional +from typing import Dict, List, Optional from pydantic import Field from modules.datamodels.datamodelBase import PowerOnModel from modules.shared.i18nRegistry import i18nModel @@ -52,7 +52,7 @@ class WorkspaceUserSettings(PowerOnModel): description="Max agent rounds override (None = instance default)", json_schema_extra={"label": "Max. Agenten-Runden", "frontend_type": "number", "frontend_readonly": False, "frontend_required": False}, ) - requireNeutralization: bool = Field( + requireNeutralization: Optional[bool] = Field( default=False, description="Default neutralization setting for this user", json_schema_extra={"label": "Neutralisierung", "frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False}, @@ -67,3 +67,8 @@ class WorkspaceUserSettings(PowerOnModel): description="Allowed AI models (empty = all permitted)", json_schema_extra={"label": "Erlaubte Modelle", "frontend_type": "modelMultiSelect", "frontend_readonly": False, "frontend_required": False}, ) + uiTreeExpansion: Dict[str, List[str]] = Field( + default_factory=dict, + description="Per-tab expanded tree-node ids for the UDB / FormGeneratorTree. Key = scope name (e.g. 'sources', 'filesOwn', 'filesShared').", + json_schema_extra={"label": "Tree-Expand-Zustand", "frontend_type": "json", "frontend_readonly": True, "frontend_required": False}, + ) diff --git a/modules/features/workspace/routeFeatureWorkspace.py b/modules/features/workspace/routeFeatureWorkspace.py index 2fa788e8..5c24c113 100644 --- a/modules/features/workspace/routeFeatureWorkspace.py +++ b/modules/features/workspace/routeFeatureWorkspace.py @@ -1281,52 +1281,101 @@ async def listWorkspaceDataSources( try: from modules.datamodels.datamodelDataSource import DataSource from modules.interfaces.interfaceDbApp import getRootInterface + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import buildEffectiveByConnection rootIf = getRootInterface() recordFilter: dict = {"featureInstanceId": instanceId} if wsMandateId: recordFilter["mandateId"] = wsMandateId dataSources = rootIf.db.getRecordset(DataSource, recordFilter=recordFilter) - return JSONResponse({"dataSources": dataSources or []}) + if not dataSources: + return JSONResponse({"dataSources": []}) + + # Group by connectionId and compute effective values in aggregate mode + byConnection: dict = {} + for ds in dataSources: + connId = ds.get("connectionId") or "" + byConnection.setdefault(connId, []).append(ds) + + for connDs in byConnection.values(): + effNeutralize = buildEffectiveByConnection(connDs, "neutralize", mode="aggregate") + effScope = buildEffectiveByConnection(connDs, "scope", mode="aggregate") + effRag = buildEffectiveByConnection(connDs, "ragIndexEnabled", mode="aggregate") + for ds in connDs: + dsId = ds.get("id", "") + ds["effectiveNeutralize"] = effNeutralize.get(dsId, False) + ds["effectiveScope"] = effScope.get(dsId, "personal") + ds["effectiveRagIndexEnabled"] = effRag.get(dsId, False) + + return JSONResponse({"dataSources": dataSources}) except Exception: return JSONResponse({"dataSources": []}) -@router.get("/{instanceId}/connections") +class _TreeChildrenRequest(BaseModel): + """Request body for the generic tree children endpoint.""" + parents: List[Optional[str]] = Field( + default_factory=list, + description="List of parent keys to fetch children for. Use null for top-level.", + ) + + +@router.post("/{instanceId}/tree/children") @limiter.limit("300/minute") -async def listWorkspaceConnections( +async def getTreeChildren( request: Request, instanceId: str = Path(...), + body: _TreeChildrenRequest = Body(...), context: RequestContext = Depends(getRequestContext), ): - """Return the user's active connections (UserConnections).""" - _mandateId, _ = _validateInstanceAccess(instanceId, context) - from modules.serviceCenter import getService - from modules.serviceCenter.context import ServiceCenterContext - ctx = ServiceCenterContext( - user=context.user, - mandate_id=_mandateId or "", - feature_instance_id=instanceId, + """Generic UDB tree children resolver. + + The UI sends a list of parent keys (or null for top-level). The backend + returns children for each requested parent, with all effective flag + values pre-computed. The UI builds the visible tree from the resulting + flat per-parent map. + """ + _validateInstanceAccess(instanceId, context) + from modules.serviceCenter.services.serviceKnowledge._buildTree import getChildrenForParents + + try: + nodesByParent = await getChildrenForParents(instanceId, body.parents, context) + except Exception as exc: + logger.exception("Tree children build failed: %s", exc) + raise HTTPException(status_code=500, detail=str(exc)) + return JSONResponse({"nodesByParent": nodesByParent}) + + +class _TreeAttributesRequest(BaseModel): + """Request body for the attribute-refresh endpoint.""" + keys: List[str] = Field( + default_factory=list, + description="List of node keys to fetch current attributes for.", ) - chatService = getService("chat", ctx) - connections = chatService.getUserConnections() - items = [] - for c in connections or []: - conn = c if isinstance(c, dict) else (c.model_dump() if hasattr(c, "model_dump") else {}) - authority = conn.get("authority") - if hasattr(authority, "value"): - authority = authority.value - status = conn.get("status") - if hasattr(status, "value"): - status = status.value - items.append({ - "id": conn.get("id"), - "authority": authority, - "externalUsername": conn.get("externalUsername"), - "externalEmail": conn.get("externalEmail"), - "status": status, - "knowledgeIngestionEnabled": bool(conn.get("knowledgeIngestionEnabled")), - }) - return JSONResponse({"connections": items}) + + +@router.post("/{instanceId}/tree/attributes") +@limiter.limit("300/minute") +async def getTreeAttributes( + request: Request, + instanceId: str = Path(...), + body: _TreeAttributesRequest = Body(...), + context: RequestContext = Depends(getRequestContext), +): + """Return current effective attribute values (neutralize, scope, + ragIndexEnabled) for a list of node keys. Used after a toggle action + to refresh only the visible nodes without reloading tree structure.""" + _validateInstanceAccess(instanceId, context) + from modules.serviceCenter.services.serviceKnowledge._buildTree import getAttributesForKeys + + if len(body.keys) > 500: + raise HTTPException(status_code=400, detail="Max 500 keys per request") + + try: + attrs = await getAttributesForKeys(instanceId, body.keys, context) + except Exception as exc: + logger.exception("Tree attributes failed: %s", exc) + raise HTTPException(status_code=500, detail=str(exc)) + return JSONResponse({"attributes": attrs}) class CreateDataSourceRequest(BaseModel): @@ -1391,303 +1440,6 @@ async def deleteWorkspaceDataSource( # ---- Feature Connections & Feature Data Sources ---- -@router.get("/{instanceId}/feature-connections") -@limiter.limit("120/minute") -async def listFeatureConnections( - request: Request, - instanceId: str = Path(...), - context: RequestContext = Depends(getRequestContext), -): - """List feature instances the user has access to, scoped to the workspace mandate.""" - wsMandateId, _ = _validateInstanceAccess(instanceId, context) - from modules.interfaces.interfaceDbApp import getRootInterface - from modules.security.rbacCatalog import getCatalogService - from modules.datamodels.datamodelUam import Mandate - - rootIf = getRootInterface() - userId = str(context.user.id) - - catalog = getCatalogService() - featureCodesWithData = catalog.getFeaturesWithDataObjects() - - userMandates = rootIf.getUserMandates(userId) - if not userMandates: - return JSONResponse({"featureConnectionsByMandate": []}) - - allowedMandateIds = {um.mandateId for um in userMandates} - if wsMandateId and wsMandateId in allowedMandateIds: - allowedMandateIds = {wsMandateId} - - mandateLabels: dict = {} - for um in userMandates: - if um.mandateId not in allowedMandateIds: - continue - try: - rows = rootIf.db.getRecordset(Mandate, recordFilter={"id": um.mandateId}) - if rows: - m = rows[0] - mandateLabels[um.mandateId] = m.get("label") or m.get("name") or um.mandateId - except Exception: - mandateLabels[um.mandateId] = um.mandateId - - byMandate: dict = {} - seenIds: set = set() - for um in userMandates: - if um.mandateId not in allowedMandateIds: - continue - allInstances = rootIf.getFeatureInstancesByMandate(um.mandateId) - for inst in allInstances: - if inst.id in seenIds: - continue - seenIds.add(inst.id) - if not inst.enabled: - continue - if inst.featureCode not in featureCodesWithData: - continue - featureAccess = rootIf.getFeatureAccess(userId, inst.id) - if not featureAccess or not featureAccess.enabled: - continue - - featureDef = catalog.getFeatureDefinition(inst.featureCode) or {} - dataObjects = catalog.getDataObjects(inst.featureCode) - label = inst.label or inst.featureCode - mid = inst.mandateId - connItem = { - "featureInstanceId": inst.id, - "featureCode": inst.featureCode, - "mandateId": mid, - "label": label, - "icon": featureDef.get("icon", "mdi-database"), - "tableCount": len(dataObjects), - } - if mid not in byMandate: - byMandate[mid] = [] - byMandate[mid].append(connItem) - - def _sortKeyLabel(x: dict) -> str: - return (x.get("label") or "").lower() - - groups = [] - for mid in sorted(byMandate.keys(), key=lambda m: (mandateLabels.get(m, m) or "").lower()): - conns = sorted(byMandate[mid], key=_sortKeyLabel) - groups.append({ - "mandateId": mid, - "mandateLabel": mandateLabels.get(mid, mid), - "featureConnections": conns, - }) - - return JSONResponse({"featureConnectionsByMandate": groups}) - - -@router.get("/{instanceId}/feature-connections/{fiId}/tables") -@limiter.limit("120/minute") -async def listFeatureConnectionTables( - request: Request, - instanceId: str = Path(...), - fiId: str = Path(..., description="Feature instance ID"), - context: RequestContext = Depends(getRequestContext), -): - """List data tables (DATA_OBJECTS) for a feature instance, filtered by RBAC.""" - wsMandateId, _ = _validateInstanceAccess(instanceId, context) - from modules.interfaces.interfaceDbApp import getRootInterface - from modules.security.rbacCatalog import getCatalogService - - rootIf = getRootInterface() - inst = rootIf.getFeatureInstance(fiId) - if not inst: - raise HTTPException(status_code=404, detail=routeApiMsg("Feature instance not found")) - - mandateId = str(inst.mandateId) if inst.mandateId else None - if wsMandateId and mandateId and mandateId != wsMandateId: - raise HTTPException(status_code=403, detail=routeApiMsg("Feature instance does not belong to workspace mandate")) - catalog = getCatalogService() - - try: - from modules.security.rbac import RbacClass - from modules.security.rootAccess import getRootDbAppConnector - dbApp = getRootDbAppConnector() - rbac = RbacClass(dbApp, dbApp=dbApp) - accessible = catalog.getAccessibleDataObjects( - featureCode=inst.featureCode, - rbacInstance=rbac, - user=context.user, - mandateId=mandateId or "", - featureInstanceId=fiId, - ) - except Exception: - accessible = catalog.getDataObjects(inst.featureCode) - - accessibleKeys = {obj.get("objectKey", "") for obj in accessible} - referencedGroups = set() - for obj in accessible: - meta = obj.get("meta", {}) - if meta.get("wildcard") or meta.get("isGroup"): - continue - if meta.get("group"): - referencedGroups.add(meta["group"]) - - tables = [] - for obj in catalog.getDataObjects(inst.featureCode): - meta = obj.get("meta", {}) - if meta.get("wildcard"): - continue - objectKey = obj.get("objectKey", "") - if meta.get("isGroup"): - # Groups are metadata-only; include if at least one child is accessible - # (regardless of whether the group itself was RBAC-granted). - if objectKey not in referencedGroups: - continue - else: - if objectKey not in accessibleKeys: - continue - node = { - "objectKey": objectKey, - "tableName": meta.get("table", ""), - "label": resolveText(obj.get("label", "")), - "fields": meta.get("fields", []), - "isParent": bool(meta.get("isParent", False)), - "parentTable": meta.get("parentTable") or None, - "parentKey": meta.get("parentKey") or None, - "displayFields": meta.get("displayFields", []), - "isGroup": bool(meta.get("isGroup", False)), - "group": meta.get("group") or None, - } - tables.append(node) - - return JSONResponse({"tables": tables}) - - -@router.get("/{instanceId}/feature-connections/{fiId}/parent-objects/{tableName}") -@limiter.limit("120/minute") -async def listParentObjects( - request: Request, - instanceId: str = Path(...), - fiId: str = Path(..., description="Feature instance ID"), - tableName: str = Path(..., description="Parent table name from DATA_OBJECTS"), - parentKey: Optional[str] = Query(None, description="Optional FK column name to filter by ancestor record (nested parent rendering)"), - parentValue: Optional[str] = Query(None, description="Optional FK value matching parentKey to filter children of a specific ancestor record"), - context: RequestContext = Depends(getRequestContext), -): - """List records from a parent table so the user can pick a specific record to scope data. - - When parentKey + parentValue are provided, results are additionally filtered by that FK, - enabling nested record hierarchies (e.g. Sessions OF Context X). - """ - wsMandateId, _ = _validateInstanceAccess(instanceId, context) - from modules.interfaces.interfaceDbApp import getRootInterface - from modules.security.rbacCatalog import getCatalogService - - rootIf = getRootInterface() - inst = rootIf.getFeatureInstance(fiId) - if not inst: - raise HTTPException(status_code=404, detail=routeApiMsg("Feature instance not found")) - - featureCode = inst.featureCode - mandateId = str(inst.mandateId) if inst.mandateId else "" - if wsMandateId and mandateId and mandateId != wsMandateId: - raise HTTPException(status_code=403, detail=routeApiMsg("Feature instance does not belong to workspace mandate")) - catalog = getCatalogService() - - parentObj = None - for obj in catalog.getDataObjects(featureCode): - meta = obj.get("meta", {}) - if meta.get("table") == tableName and meta.get("isParent"): - parentObj = obj - break - if not parentObj: - raise HTTPException(status_code=400, detail=f"Table '{tableName}' is not a registered parent table") - - displayFields = parentObj["meta"].get("displayFields", []) - selectCols = ', '.join(f'"{f}"' for f in (["id"] + displayFields)) if displayFields else "*" - - from modules.connectors.connectorDbPostgre import DatabaseConnector - from modules.shared.configuration import APP_CONFIG - featureDbName = f"poweron_{featureCode.lower()}" - featureDbConn = None - try: - featureDbConn = DatabaseConnector( - dbHost=APP_CONFIG.get("DB_HOST", "localhost"), - dbDatabase=featureDbName, - dbUser=APP_CONFIG.get("DB_USER"), - dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"), - dbPort=int(APP_CONFIG.get("DB_PORT", 5432)), - userId=str(context.user.id), - ) - conn = featureDbConn.connection - with conn.cursor() as cur: - cur.execute( - "SELECT column_name FROM information_schema.columns " - "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) " - "AND column_name IN ('featureInstanceId', 'instanceId')", - [tableName], - ) - instanceCols = [row["column_name"] for row in cur.fetchall()] - instanceCol = "featureInstanceId" if "featureInstanceId" in instanceCols else "instanceId" - - cur.execute( - "SELECT column_name FROM information_schema.columns " - "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) " - "AND column_name = 'userId'", - [tableName], - ) - hasUserId = cur.rowcount > 0 - - sql = ( - f'SELECT {selectCols} FROM "{tableName}" ' - f'WHERE "{instanceCol}" = %s' - ) - params = [fiId] - if mandateId: - sql += ' AND "mandateId" = %s' - params.append(mandateId) - if hasUserId: - sql += ' AND "userId" = %s' - params.append(str(context.user.id)) - - if parentKey and parentValue: - cur.execute( - "SELECT 1 FROM information_schema.columns " - "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) " - "AND column_name = %s", - [tableName, parentKey], - ) - if cur.rowcount > 0: - sql += f' AND "{parentKey}" = %s' - params.append(parentValue) - else: - logger.warning( - f"listParentObjects({tableName}): ignoring parentKey '{parentKey}' (column does not exist)" - ) - - sql += ' ORDER BY "id" DESC LIMIT 100' - cur.execute(sql, params) - rows = [] - for row in cur.fetchall(): - r = dict(row) - for k, v in r.items(): - if hasattr(v, "isoformat"): - r[k] = v.isoformat() - elif isinstance(v, (bytes, bytearray)): - r[k] = f"" - displayParts = [str(r.get(f, "")) for f in displayFields if r.get(f) is not None] - rows.append({ - "id": r.get("id", ""), - "displayLabel": " | ".join(displayParts) if displayParts else r.get("id", ""), - "fields": {f: r.get(f) for f in displayFields}, - }) - except Exception as e: - logger.error(f"listParentObjects({tableName}) failed: {e}", exc_info=True) - raise HTTPException(status_code=500, detail=f"Failed to list parent objects: {e}") - finally: - if featureDbConn: - try: - featureDbConn.close() - except Exception: - pass - - return JSONResponse({"parentObjects": rows}) - - class CreateFeatureDataSourceRequest(BaseModel): """Request body for adding a feature table as data source.""" featureInstanceId: str = Field(description="Feature instance ID") @@ -1706,16 +1458,35 @@ async def createFeatureDataSource( body: CreateFeatureDataSourceRequest = Body(...), context: RequestContext = Depends(getRequestContext), ): - """Create a FeatureDataSource for this workspace instance.""" + """Create a FeatureDataSource for this workspace instance. + + The FDS lives under the WORKSPACE's mandate (not the feature's): that + matches how the tree (`allFds = recordset where workspaceInstanceId = + instanceId`) and the PATCH endpoints scope these records — by workspace, + not by feature mandate. The user can legitimately reference a feature + from another mandate they have access to (via the UDB mandate-group + nodes), and a hard cross-mandate block here would silently 403 those + toggles. Access to the referenced feature is verified by the user's + `FeatureAccess` and the existing tree-children RBAC, which run before + the user can ever click on this node. + """ wsMandateId, _ = _validateInstanceAccess(instanceId, context) from modules.interfaces.interfaceDbApp import getRootInterface from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource rootIf = getRootInterface() - inst = rootIf.getFeatureInstance(body.featureInstanceId) - mandateId = str(inst.mandateId) if inst else (str(context.mandateId) if context.mandateId else "") - if wsMandateId and mandateId and mandateId != wsMandateId: - raise HTTPException(status_code=403, detail=routeApiMsg("Feature instance does not belong to workspace mandate")) + if not rootIf.getFeatureAccess(str(context.user.id), body.featureInstanceId): + raise HTTPException(status_code=403, detail=routeApiMsg("Access denied to this feature instance")) + + existing = rootIf.db.getRecordset(FeatureDataSource, recordFilter={ + "workspaceInstanceId": instanceId, + "featureInstanceId": body.featureInstanceId, + "tableName": body.tableName, + }) or [] + targetFilter = body.recordFilter or None + for rec in existing: + if (rec.get("recordFilter") or None) == targetFilter: + return JSONResponse(rec) fds = FeatureDataSource( featureInstanceId=body.featureInstanceId, @@ -1723,7 +1494,7 @@ async def createFeatureDataSource( tableName=body.tableName, objectKey=body.objectKey, label=body.label, - mandateId=mandateId, + mandateId=wsMandateId or "", userId=str(context.user.id), workspaceInstanceId=instanceId, recordFilter=body.recordFilter, @@ -1743,13 +1514,26 @@ async def listFeatureDataSources( wsMandateId, _ = _validateInstanceAccess(instanceId, context) from modules.interfaces.interfaceDbApp import getRootInterface from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import buildEffectiveByWorkspaceFds rootIf = getRootInterface() recordFilter: dict = {"workspaceInstanceId": instanceId} if wsMandateId: recordFilter["mandateId"] = wsMandateId records = rootIf.db.getRecordset(FeatureDataSource, recordFilter=recordFilter) - return JSONResponse({"featureDataSources": records or []}) + if not records: + return JSONResponse({"featureDataSources": []}) + + effNeutralize = buildEffectiveByWorkspaceFds(records, "neutralize", mode="aggregate") + effScope = buildEffectiveByWorkspaceFds(records, "scope", mode="aggregate") + effRag = buildEffectiveByWorkspaceFds(records, "ragIndexEnabled", mode="aggregate") + for fds in records: + fdsId = fds.get("id", "") + fds["effectiveNeutralize"] = effNeutralize.get(fdsId, False) + fds["effectiveScope"] = effScope.get(fdsId, "personal") + fds["effectiveRagIndexEnabled"] = effRag.get(fdsId, False) + + return JSONResponse({"featureDataSources": records}) @router.delete("/{instanceId}/feature-datasources/{featureDataSourceId}") @@ -1770,112 +1554,6 @@ async def deleteFeatureDataSource( return JSONResponse({"success": True}) -@router.get("/{instanceId}/connections/{connectionId}/services") -@limiter.limit("120/minute") -async def listConnectionServices( - request: Request, - instanceId: str = Path(...), - connectionId: str = Path(...), - context: RequestContext = Depends(getRequestContext), -): - """Return the available services for a specific UserConnection.""" - _mandateId, _ = _validateInstanceAccess(instanceId, context) - try: - from modules.connectors.connectorResolver import ConnectorResolver - from modules.serviceCenter import getService as getSvc - from modules.serviceCenter.context import ServiceCenterContext - ctx = ServiceCenterContext( - user=context.user, - mandate_id=_mandateId or "", - feature_instance_id=instanceId, - ) - chatService = getSvc("chat", ctx) - securityService = getSvc("security", ctx) - dbInterface = _buildResolverDbInterface(chatService) - resolver = ConnectorResolver(securityService, dbInterface) - provider = await resolver.resolve(connectionId) - services = provider.getAvailableServices() - _serviceLabels = { - "sharepoint": "SharePoint", - "outlook": "Outlook", - "teams": "Teams", - "onedrive": "OneDrive", - "drive": "Google Drive", - "gmail": "Gmail", - "files": "Files (FTP)", - "kdrive": "kDrive", - "calendar": "Calendar", - "contact": "Contacts", - } - _serviceIcons = { - "sharepoint": "sharepoint", - "outlook": "mail", - "teams": "chat", - "onedrive": "cloud", - "drive": "cloud", - "gmail": "mail", - "files": "folder", - "kdrive": "cloud", - "calendar": "calendar", - "contact": "contact", - } - items = [ - { - "service": s, - "label": _serviceLabels.get(s, s), - "icon": _serviceIcons.get(s, "folder"), - } - for s in services - ] - return JSONResponse({"services": items}) - except Exception as e: - logger.error(f"Error listing services for connection {connectionId}: {e}") - return JSONResponse({"services": [], "error": str(e)}, status_code=400) - - -@router.get("/{instanceId}/connections/{connectionId}/browse") -@limiter.limit("300/minute") -async def browseConnectionService( - request: Request, - instanceId: str = Path(...), - connectionId: str = Path(...), - service: str = Query(..., description="Service name (e.g. sharepoint, onedrive, outlook)"), - path: str = Query("/", description="Path within the service to browse"), - context: RequestContext = Depends(getRequestContext), -): - """Browse folders/items within a connection's service at a given path.""" - _mandateId, _ = _validateInstanceAccess(instanceId, context) - try: - from modules.connectors.connectorResolver import ConnectorResolver - from modules.serviceCenter import getService as getSvc - from modules.serviceCenter.context import ServiceCenterContext - ctx = ServiceCenterContext( - user=context.user, - mandate_id=_mandateId or "", - feature_instance_id=instanceId, - ) - chatService = getSvc("chat", ctx) - securityService = getSvc("security", ctx) - dbInterface = _buildResolverDbInterface(chatService) - resolver = ConnectorResolver(securityService, dbInterface) - adapter = await resolver.resolveService(connectionId, service) - entries = await adapter.browse(path, filter=None) - items = [] - for entry in (entries or []): - items.append({ - "name": entry.name, - "path": entry.path, - "isFolder": entry.isFolder, - "size": entry.size, - "mimeType": entry.mimeType, - "metadata": entry.metadata if hasattr(entry, "metadata") else {}, - }) - return JSONResponse({"items": items, "path": path, "service": service}) - except Exception as e: - logger.error(f"Error browsing {service} for connection {connectionId} at '{path}': {e}") - return JSONResponse({"items": [], "error": str(e)}, status_code=400) - - # --------------------------------------------------------------------------- # Voice endpoints # --------------------------------------------------------------------------- @@ -2191,6 +1869,71 @@ async def putWorkspaceUserSettings( }) +# ========================================================================= +# Per-user UI state: tree expand/collapse (UDB + FilesTab) +# Persisted on WorkspaceUserSettings.uiTreeExpansion as a {scope: [ids]} map. +# Each FE tab uses its own scope key so collapse-state for one tab doesn't +# bleed into another. + +@router.get("/{instanceId}/ui-tree-expansion/{scope}") +@limiter.limit("300/minute") +async def getUiTreeExpansion( + request: Request, + instanceId: str = Path(...), + scope: str = Path(..., description="UI scope key, e.g. 'sources', 'filesOwn', 'filesShared'"), + context: RequestContext = Depends(getRequestContext), +): + """Return the expanded tree-node ids for the current user + scope. + + Returns `null` when the user has never persisted a state for this scope + (lets the FE fall back to backend `defaultExpanded` hints). Returns `[]` + when the user actively collapsed everything. + """ + _validateInstanceAccess(instanceId, context) + wsInterface = _getWorkspaceInterface(context, instanceId) + settings = wsInterface.getWorkspaceUserSettings(str(context.user.id)) + expansion = (settings.uiTreeExpansion if settings else {}) or {} + if scope not in expansion: + return JSONResponse({"expandedNodes": None}) + return JSONResponse({"expandedNodes": list(expansion.get(scope) or [])}) + + +@router.put("/{instanceId}/ui-tree-expansion/{scope}") +@limiter.limit("300/minute") +async def putUiTreeExpansion( + request: Request, + instanceId: str = Path(...), + scope: str = Path(...), + body: dict = Body(...), + context: RequestContext = Depends(getRequestContext), +): + """Replace the expanded-node list for one scope. + + Body: `{"expandedNodes": List[str]}`. Empty list = explicit collapse-all. + """ + _validateInstanceAccess(instanceId, context) + wsInterface = _getWorkspaceInterface(context, instanceId) + userId = str(context.user.id) + nodes = body.get("expandedNodes") + if not isinstance(nodes, list): + raise HTTPException(status_code=400, detail=routeApiMsg("expandedNodes must be a list")) + cleaned = [str(n) for n in nodes if isinstance(n, (str, int))] + + existing = wsInterface.getWorkspaceUserSettings(userId) + existingMap: Dict[str, List[str]] = (existing.uiTreeExpansion if existing else {}) or {} + existingMap = dict(existingMap) + existingMap[scope] = cleaned + + data = { + "userId": userId, + "mandateId": str(context.mandateId) if context.mandateId else "", + "featureInstanceId": instanceId, + "uiTreeExpansion": existingMap, + } + wsInterface.saveWorkspaceUserSettings(data) + return JSONResponse({"expandedNodes": cleaned}) + + # ========================================================================= # RAG / Knowledge — anonymised instance statistics (presentation / KPIs) diff --git a/modules/routes/routeAdminDemoConfig.py b/modules/routes/routeAdminDemoConfig.py index db37e775..0673c299 100644 --- a/modules/routes/routeAdminDemoConfig.py +++ b/modules/routes/routeAdminDemoConfig.py @@ -68,9 +68,19 @@ def removeDemoConfig( request: Request, currentUser: User = Depends(requirePlatformAdmin), ) -> dict: - """Remove all data created by a demo configuration.""" + """Remove all data created by a demo configuration. + + Requires X-Confirm-Destructive: true header as safety guard. + """ from modules.demoConfigs import getDemoConfigByCode + confirmHeader = request.headers.get("X-Confirm-Destructive", "").lower() + if confirmHeader != "true": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Destructive operation requires header X-Confirm-Destructive: true", + ) + config = getDemoConfigByCode(code) if not config: raise HTTPException( @@ -79,7 +89,7 @@ def removeDemoConfig( ) db = getRootDbAppConnector() - logger.info(f"Removing demo config '{code}' (user: {currentUser.username})") + logger.info(f"Removing demo config '{code}' (user: {currentUser.username}, confirmed)") summary = config.remove(db) logger.info(f"Demo config '{code}' removed: {summary}") diff --git a/modules/routes/routeDataConnections.py b/modules/routes/routeDataConnections.py index e2b08461..2bc48042 100644 --- a/modules/routes/routeDataConnections.py +++ b/modules/routes/routeDataConnections.py @@ -778,7 +778,12 @@ async def _updateKnowledgeConsent( cancelled = cancelJobsByConnection(connectionId) else: from modules.datamodels.datamodelDataSource import DataSource - dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId, "ragIndexEnabled": True}) + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag + allConnDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) + dataSources = [ + ds for ds in (allConnDs or []) + if getEffectiveFlag(ds, "ragIndexEnabled", allConnDs, mode="walk") is True + ] if dataSources: from modules.serviceCenter.services.serviceBackgroundJobs import startJob authority = connection.authority.value if hasattr(connection.authority, "value") else str(connection.authority or "") diff --git a/modules/routes/routeDataFiles.py b/modules/routes/routeDataFiles.py index b22dacae..4bcbcf8f 100644 --- a/modules/routes/routeDataFiles.py +++ b/modules/routes/routeDataFiles.py @@ -211,7 +211,7 @@ async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user, *, man from modules.serviceCenter.services.serviceKnowledge.mainServiceKnowledge import IngestionJob - await knowledgeService.requestIngestion( + handle = await knowledgeService.requestIngestion( IngestionJob( sourceKind="file", sourceId=fileId, @@ -229,7 +229,10 @@ async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user, *, man # Re-acquire interface after await to avoid stale user context from the singleton mgmtInterface = interfaceDbManagement.getInterface(user) mgmtInterface.updateFile(fileId, {"status": "active"}) - logger.info(f"Auto-index complete for file {fileId} ({fileName})") + if handle.status == "failed": + logger.warning(f"Auto-index ingestion failed for file {fileId} ({fileName}): {handle.error}") + else: + logger.info(f"Auto-index complete for file {fileId} ({fileName})") except Exception as e: logger.error(f"Auto-index failed for file {fileId}: {e}", exc_info=True) @@ -256,6 +259,24 @@ router = APIRouter( ) +def _getInterfaceForOwnedItem(currentUser: User, context, itemId: str, modelClass) -> Any: + """Create a management interface scoped to the item's own context. + Looks up the item by ID (unscoped) to resolve its mandateId/featureInstanceId, + then creates the interface with THAT context. This ensures toggle operations + work regardless of which page the user is on.""" + unscoped = interfaceDbManagement.getInterface(currentUser) + record = unscoped.db.getRecord(modelClass, itemId) + if not record: + raise interfaceDbManagement.FileNotFoundError(f"Item {itemId} not found") + itemMandateId = record.get("mandateId") if isinstance(record, dict) else getattr(record, "mandateId", None) + itemInstanceId = record.get("featureInstanceId") if isinstance(record, dict) else getattr(record, "featureInstanceId", None) + return interfaceDbManagement.getInterface( + currentUser, + mandateId=str(itemMandateId) if itemMandateId else None, + featureInstanceId=str(itemInstanceId) if itemInstanceId else None, + ) + + @router.get("/folders/tree") @limiter.limit("120/minute") def get_folder_tree( @@ -272,10 +293,12 @@ def get_folder_tree( ) o = (owner or "me").strip().lower() if o == "me": - return managementInterface.getOwnFolderTree() - if o == "shared": - return managementInterface.getSharedFolderTree() - raise HTTPException(status_code=400, detail="owner must be 'me' or 'shared'") + folders = managementInterface.getOwnFolderTree() + elif o == "shared": + folders = managementInterface.getSharedFolderTree() + else: + raise HTTPException(status_code=400, detail="owner must be 'me' or 'shared'") + return folders except HTTPException: raise except Exception as e: @@ -283,6 +306,185 @@ def get_folder_tree( raise HTTPException(status_code=500, detail=str(e)) +@router.post("/attributes") +@limiter.limit("120/minute") +def getAttributesForIds( + request: Request, + body: Dict[str, Any] = Body(...), + currentUser: User = Depends(getCurrentUser), + context: RequestContext = Depends(getRequestContext), +): + """Return current attribute values (neutralize, scope, ragIndexEnabled) for + a list of node IDs. For folder IDs, computes 'mixed' by checking direct + children. The frontend sends this after every toggle to refresh visible + nodes without reloading the tree structure.""" + ids = body.get("ids", []) + if not isinstance(ids, list) or len(ids) == 0: + return {} + if len(ids) > 500: + raise HTTPException(status_code=400, detail="Max 500 IDs per request") + + try: + managementInterface = interfaceDbManagement.getInterface( + currentUser, + mandateId=str(context.mandateId) if context.mandateId else None, + featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None, + ) + db = managementInterface.db + userId = str(currentUser.id) + + allFolders = db.getRecordset(FileFolder, recordFilter={"sysCreatedBy": userId}) or [] + allFiles = db.getRecordset(FileItem, recordFilter={"sysCreatedBy": userId}) or [] + + folderById = {f["id"]: f for f in allFolders} + fileById = {f["id"]: f for f in allFiles} + + logger.info( + "getAttributesForIds: %d ids requested, %d folders found, %d files found", + len(ids), len(allFolders), len(allFiles), + ) + + result: Dict[str, Dict[str, Any]] = {} + + for nodeId in ids: + if nodeId.startswith("__filesRoot:"): + attrs = _computeSyntheticRootAttrs(allFolders, allFiles) + result[nodeId] = attrs + elif nodeId in folderById: + folder = folderById[nodeId] + attrs = _computeFolderAttrs(folder, allFolders, allFiles) + result[nodeId] = attrs + elif nodeId in fileById: + f = fileById[nodeId] + result[nodeId] = { + "neutralize": bool(f.get("neutralize", False)), + "scope": f.get("scope", "personal"), + } + else: + logger.debug("getAttributesForIds: unknown id=%s", nodeId) + + logger.info("getAttributesForIds: returning %d entries", len(result)) + return result + except HTTPException: + raise + except Exception as e: + logger.error(f"getAttributesForIds error: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +def _computeFolderAttrs( + folder: Dict[str, Any], + allFolders: List[Dict[str, Any]], + allFiles: List[Dict[str, Any]], +) -> Dict[str, Any]: + """Compute attributes for a folder. Recursively checks the entire subtree: + if ANY descendant at any depth has a different value, the folder shows 'mixed'. + This propagates up through all ancestor levels.""" + fid = folder["id"] + neutralizeResult = _effectiveNeutralize(fid, allFolders, allFiles) + scopeResult = _effectiveScope(fid, allFolders, allFiles) + return {"neutralize": neutralizeResult, "scope": scopeResult} + + +def _effectiveNeutralize( + folderId: str, + allFolders: List[Dict[str, Any]], + allFiles: List[Dict[str, Any]], +) -> Any: + """Recursively compute effective neutralize for a folder. + Returns 'mixed' if any descendants diverge, otherwise the folder's own value.""" + childFolders = [f for f in allFolders if f.get("parentId") == folderId] + childFiles = [f for f in allFiles if f.get("folderId") == folderId] + + if not childFolders and not childFiles: + folder = next((f for f in allFolders if f["id"] == folderId), None) + return bool(folder.get("neutralize", False)) if folder else False + + childVals = set() + for cf in childFolders: + effective = _effectiveNeutralize(cf["id"], allFolders, allFiles) + if effective == "mixed": + return "mixed" + childVals.add(effective) + for cf in childFiles: + childVals.add(bool(cf.get("neutralize", False))) + + if len(childVals) > 1: + return "mixed" + if not childVals: + folder = next((f for f in allFolders if f["id"] == folderId), None) + return bool(folder.get("neutralize", False)) if folder else False + return childVals.pop() + + +def _effectiveScope( + folderId: str, + allFolders: List[Dict[str, Any]], + allFiles: List[Dict[str, Any]], +) -> Any: + """Recursively compute effective scope for a folder. + Returns 'mixed' if any descendants diverge, otherwise the folder's own value.""" + childFolders = [f for f in allFolders if f.get("parentId") == folderId] + childFiles = [f for f in allFiles if f.get("folderId") == folderId] + + if not childFolders and not childFiles: + folder = next((f for f in allFolders if f["id"] == folderId), None) + return folder.get("scope", "personal") if folder else "personal" + + childVals = set() + for cf in childFolders: + effective = _effectiveScope(cf["id"], allFolders, allFiles) + if effective == "mixed": + return "mixed" + childVals.add(effective) + for cf in childFiles: + childVals.add(cf.get("scope", "personal")) + + if len(childVals) > 1: + return "mixed" + if not childVals: + folder = next((f for f in allFolders if f["id"] == folderId), None) + return folder.get("scope", "personal") if folder else "personal" + return childVals.pop() + + +def _computeSyntheticRootAttrs( + allFolders: List[Dict[str, Any]], + allFiles: List[Dict[str, Any]], +) -> Dict[str, Any]: + """Compute attributes for the synthetic root by recursively checking the + entire tree. If ANY item at any depth diverges, root shows 'mixed'.""" + topFolders = [f for f in allFolders if not f.get("parentId")] + topFiles = [f for f in allFiles if not f.get("folderId")] + + neutralizeVals = set() + scopeVals = set() + for cf in topFolders: + nEff = _effectiveNeutralize(cf["id"], allFolders, allFiles) + if nEff == "mixed": + neutralizeVals.add(True) + neutralizeVals.add(False) + else: + neutralizeVals.add(nEff) + sEff = _effectiveScope(cf["id"], allFolders, allFiles) + if sEff == "mixed": + scopeVals.add("__mixed_a__") + scopeVals.add("__mixed_b__") + else: + scopeVals.add(sEff) + for cf in topFiles: + neutralizeVals.add(bool(cf.get("neutralize", False))) + scopeVals.add(cf.get("scope", "personal")) + + if not neutralizeVals and not scopeVals: + return {"neutralize": False, "scope": "personal"} + + return { + "neutralize": "mixed" if len(neutralizeVals) > 1 else (neutralizeVals.pop() if neutralizeVals else False), + "scope": "mixed" if len(scopeVals) > 1 else (scopeVals.pop() if scopeVals else "personal"), + } + + @router.post("/folders", status_code=status.HTTP_201_CREATED) @limiter.limit("30/minute") def create_folder( @@ -353,7 +555,12 @@ def move_folder( context: RequestContext = Depends(getRequestContext), ): try: + # FE may send `parentId` or `targetParentId`. Accept both so the + # FormGeneratorTree generic `provider.moveNodes(targetParentId)` API + # remains consistent with the file-move (PUT /api/files/{id}) shape. newParentId = body.get("parentId") + if newParentId is None: + newParentId = body.get("targetParentId") managementInterface = interfaceDbManagement.getInterface( currentUser, mandateId=str(context.mandateId) if context.mandateId else None, @@ -414,11 +621,7 @@ def patch_folder_scope( if not scope: raise HTTPException(status_code=400, detail="scope is required") cascadeToFiles = body.get("cascadeChildren", body.get("cascadeToFiles", False)) - managementInterface = interfaceDbManagement.getInterface( - currentUser, - mandateId=str(context.mandateId) if context.mandateId else None, - featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None, - ) + managementInterface = _getInterfaceForOwnedItem(currentUser, context, folderId, FileFolder) return managementInterface.patchFolderScope(folderId, scope, cascadeToFiles) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -446,11 +649,7 @@ def patch_folder_neutralize( neutralize = body.get("neutralize") if neutralize is None: raise HTTPException(status_code=400, detail="neutralize is required") - managementInterface = interfaceDbManagement.getInterface( - currentUser, - mandateId=str(context.mandateId) if context.mandateId else None, - featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None, - ) + managementInterface = _getInterfaceForOwnedItem(currentUser, context, folderId, FileFolder) return managementInterface.patchFolderNeutralize(folderId, bool(neutralize)) except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) @@ -1031,11 +1230,7 @@ def updateFileScope( if scope == "global" and not context.isSysAdmin: raise HTTPException(status_code=403, detail=routeApiMsg("Only sysadmins can set global scope")) - managementInterface = interfaceDbManagement.getInterface( - context.user, - mandateId=str(context.mandateId) if context.mandateId else None, - featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None, - ) + managementInterface = _getInterfaceForOwnedItem(context.user, context, fileId, FileItem) managementInterface.updateFile(fileId, {"scope": scope}) @@ -1093,11 +1288,7 @@ def updateFileNeutralize( fails the file simply has no index — no un-neutralized data can leak. """ try: - managementInterface = interfaceDbManagement.getInterface( - context.user, - mandateId=str(context.mandateId) if context.mandateId else None, - featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None, - ) + managementInterface = _getInterfaceForOwnedItem(context.user, context, fileId, FileItem) managementInterface.updateFile(fileId, {"neutralize": neutralize}) @@ -1212,7 +1403,8 @@ def update_file( request: Request, fileId: str = Path(..., description="ID of the file to update"), file_info: Dict[str, Any] = Body(...), - currentUser: User = Depends(getCurrentUser) + currentUser: User = Depends(getCurrentUser), + context: RequestContext = Depends(getRequestContext), ) -> FileItem: """Update file info""" try: @@ -1221,7 +1413,11 @@ def update_file( if not safeData: raise HTTPException(status_code=400, detail=routeApiMsg("No editable fields provided")) - managementInterface = interfaceDbManagement.getInterface(currentUser) + managementInterface = interfaceDbManagement.getInterface( + currentUser, + mandateId=str(context.mandateId) if context.mandateId else None, + featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None, + ) file = managementInterface.getFile(fileId) if not file: @@ -1267,10 +1463,15 @@ def update_file( def delete_file( request: Request, fileId: str = Path(..., description="ID of the file to delete"), - currentUser: User = Depends(getCurrentUser) + currentUser: User = Depends(getCurrentUser), + context: RequestContext = Depends(getRequestContext), ) -> Dict[str, Any]: """Delete a file""" - managementInterface = interfaceDbManagement.getInterface(currentUser) + managementInterface = interfaceDbManagement.getInterface( + currentUser, + mandateId=str(context.mandateId) if context.mandateId else None, + featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None, + ) # Check if the file exists existingFile = managementInterface.getFile(fileId) diff --git a/modules/routes/routeDataSources.py b/modules/routes/routeDataSources.py index 5dec19c8..b2f919b7 100644 --- a/modules/routes/routeDataSources.py +++ b/modules/routes/routeDataSources.py @@ -43,6 +43,49 @@ def _ensureConnectionKnowledgeFlag(rootIf, connectionId: str) -> None: except Exception as e: logger.warning("Could not auto-enable knowledgeIngestionEnabled for connection %s: %s", connectionId, e) +def _computeOwnEffective(rootIf, rec, model, sourceId: str, flag: str) -> Any: + """Re-load the record after modification and compute its aggregate effective value.""" + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import ( + getEffectiveFlag, getEffectiveFlagFds, + ) + freshRec = rootIf.db.getRecord(model, sourceId) + if not freshRec: + return None + if model is DataSource: + connectionId = freshRec.get("connectionId", "") + allDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) + return getEffectiveFlag(freshRec, flag, allDs, mode="aggregate") + else: + wsId = freshRec.get("workspaceInstanceId", "") + allFds = rootIf.db.getRecordset(FeatureDataSource, recordFilter={"workspaceInstanceId": wsId}) + return getEffectiveFlagFds(freshRec, flag, allFds, mode="aggregate") + + +def _computeAncestorEffectives(rootIf, rec, model, flag: str) -> List[Dict[str, Any]]: + """Compute the aggregate effective value for all ancestors of `rec`.""" + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import ( + collectAncestorChain, collectAncestorChainFds, + getEffectiveFlag, getEffectiveFlagFds, + ) + effectiveKey = f"effective{flag[0].upper()}{flag[1:]}" + if model is DataSource: + connectionId = rec.get("connectionId", "") + allDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) + ancestors = collectAncestorChain(rec, allDs) + return [ + {"id": a.get("id") or getattr(a, "id", ""), effectiveKey: getEffectiveFlag(a, flag, allDs, mode="aggregate")} + for a in ancestors + ] + else: + wsId = rec.get("workspaceInstanceId", "") + allFds = rootIf.db.getRecordset(FeatureDataSource, recordFilter={"workspaceInstanceId": wsId}) + ancestors = collectAncestorChainFds(rec, allFds) + return [ + {"id": a.get("id") or getattr(a, "id", ""), effectiveKey: getEffectiveFlagFds(a, flag, allFds, mode="aggregate")} + for a in ancestors + ] + + router = APIRouter( prefix="/api/datasources", tags=["Data Sources"], @@ -91,26 +134,41 @@ def _updateDataSourceScope( try: from modules.interfaces.interfaceDbApp import getRootInterface from modules.serviceCenter.services.serviceKnowledge._inheritFlags import ( - cascadeResetDescendants, - cascadeResetDescendantsFds, + cascadeResetDescendants, cascadeResetDescendantsFds, + getEffectiveFlag, getEffectiveFlagFds, + collectAncestorChain, collectAncestorChainFds, ) rootIf = getRootInterface() rec, model = _findSourceRecord(rootIf.db, sourceId) if not rec: raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found") - rootIf.db.recordModify(model, sourceId, {"scope": scope}) - cascaded = 0 + # 1. Cascade reset descendants bottom-up (before modifying master) + resetIds: List[str] = [] if scope is not None: if model is DataSource: - cascaded = cascadeResetDescendants(rootIf, rec, "scope") + resetIds = cascadeResetDescendants(rootIf, rec, "scope") else: - cascaded = cascadeResetDescendantsFds(rootIf, rec, "scope") + resetIds = cascadeResetDescendantsFds(rootIf, rec, "scope") + + # 2. Set master value last (crash-safe) + rootIf.db.recordModify(model, sourceId, {"scope": scope}) + + # 3. Compute effective + ancestor chain for response + updatedAncestors = _computeAncestorEffectives(rootIf, rec, model, "scope") + effectiveScope = _computeOwnEffective(rootIf, rec, model, sourceId, "scope") + logger.info( "Updated scope=%s for %s %s (cascade-reset %d descendants)", - scope, model.__name__, sourceId, cascaded, + scope, model.__name__, sourceId, len(resetIds), ) - return {"sourceId": sourceId, "scope": scope, "updated": True, "cascadedDescendants": cascaded} + return { + "sourceId": sourceId, + "scope": scope, + "effectiveScope": effectiveScope, + "resetDescendantIds": resetIds, + "updatedAncestors": updatedAncestors, + } except HTTPException: raise except Exception as e: @@ -133,26 +191,39 @@ def _updateDataSourceNeutralize( try: from modules.interfaces.interfaceDbApp import getRootInterface from modules.serviceCenter.services.serviceKnowledge._inheritFlags import ( - cascadeResetDescendants, - cascadeResetDescendantsFds, + cascadeResetDescendants, cascadeResetDescendantsFds, ) rootIf = getRootInterface() rec, model = _findSourceRecord(rootIf.db, sourceId) if not rec: raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found") - rootIf.db.recordModify(model, sourceId, {"neutralize": neutralize}) - cascaded = 0 + # 1. Cascade reset descendants bottom-up (before modifying master) + resetIds: List[str] = [] if neutralize is not None: if model is DataSource: - cascaded = cascadeResetDescendants(rootIf, rec, "neutralize") + resetIds = cascadeResetDescendants(rootIf, rec, "neutralize") else: - cascaded = cascadeResetDescendantsFds(rootIf, rec, "neutralize") + resetIds = cascadeResetDescendantsFds(rootIf, rec, "neutralize") + + # 2. Set master value last (crash-safe) + rootIf.db.recordModify(model, sourceId, {"neutralize": neutralize}) + + # 3. Compute effective + ancestor chain for response + updatedAncestors = _computeAncestorEffectives(rootIf, rec, model, "neutralize") + effectiveNeutralize = _computeOwnEffective(rootIf, rec, model, sourceId, "neutralize") + logger.info( "Updated neutralize=%s for %s %s (cascade-reset %d descendants)", - neutralize, model.__name__, sourceId, cascaded, + neutralize, model.__name__, sourceId, len(resetIds), ) - return {"sourceId": sourceId, "neutralize": neutralize, "updated": True, "cascadedDescendants": cascaded} + return { + "sourceId": sourceId, + "neutralize": neutralize, + "effectiveNeutralize": effectiveNeutralize, + "resetDescendantIds": resetIds, + "updatedAncestors": updatedAncestors, + } except HTTPException: raise except Exception as e: @@ -204,46 +275,57 @@ async def _updateDataSourceRagIndex( `True` enqueues a mini-bootstrap. `False` synchronously purges chunks. Must be `async def` so `await startJob(...)` registers `_runJob` in the - main event loop. Sync route → worker thread → temporary loop closes - before the task runs → job stays stuck forever. + main event loop. """ try: from modules.interfaces.interfaceDbApp import getRootInterface - from modules.serviceCenter.services.serviceKnowledge._inheritFlags import cascadeResetDescendants + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import ( + cascadeResetDescendants, cascadeResetDescendantsFds, + ) rootIf = getRootInterface() - rec = rootIf.db.getRecord(DataSource, sourceId) + rec, model = _findSourceRecord(rootIf.db, sourceId) if not rec: raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found") - rootIf.db.recordModify(DataSource, sourceId, {"ragIndexEnabled": ragIndexEnabled}) - cascaded = 0 + # 1. Cascade reset descendants bottom-up (before modifying master) + resetIds: List[str] = [] if ragIndexEnabled is not None: - cascaded = cascadeResetDescendants(rootIf, rec, "ragIndexEnabled") + if model is DataSource: + resetIds = cascadeResetDescendants(rootIf, rec, "ragIndexEnabled") + else: + resetIds = cascadeResetDescendantsFds(rootIf, rec, "ragIndexEnabled") + + # 2. Set master value last (crash-safe) + rootIf.db.recordModify(model, sourceId, {"ragIndexEnabled": ragIndexEnabled}) + logger.info( - "Updated ragIndexEnabled=%s for DataSource %s (cascade-reset %d descendants)", - ragIndexEnabled, sourceId, cascaded, + "Updated ragIndexEnabled=%s for %s %s (cascade-reset %d descendants)", + ragIndexEnabled, model.__name__, sourceId, len(resetIds), ) - connectionId = rec.get("connectionId") or rec.get("connection_id") or "" - if ragIndexEnabled is True: - _ensureConnectionKnowledgeFlag(rootIf, connectionId) - from modules.serviceCenter.services.serviceBackgroundJobs import startJob + # Bootstrap / purge only for personal DataSource (file/folder-based RAG). + # FDS RAG is handled by the feature pipeline; the flag alone is enough. + if model is DataSource: + connectionId = rec.get("connectionId") or rec.get("connection_id") or "" + if ragIndexEnabled is True: + _ensureConnectionKnowledgeFlag(rootIf, connectionId) + from modules.serviceCenter.services.serviceBackgroundJobs import startJob - conn = rootIf.getUserConnectionById(connectionId) if connectionId else None - authority = "" - if conn: - authority = conn.authority.value if hasattr(conn.authority, "value") else str(conn.authority or "") + conn = rootIf.getUserConnectionById(connectionId) if connectionId else None + authority = "" + if conn: + authority = conn.authority.value if hasattr(conn.authority, "value") else str(conn.authority or "") - await startJob( - "connection.bootstrap", - {"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": [sourceId]}, - triggeredBy=str(context.user.id), - ) - elif ragIndexEnabled is False: - from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface - purgeResult = getKnowledgeInterface(None).deleteFileContentIndexByDataSource(sourceId) - logger.info("Purged %d index rows / %d chunks for DataSource %s", - purgeResult.get("indexRows", 0), purgeResult.get("chunks", 0), sourceId) + await startJob( + "connection.bootstrap", + {"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": [sourceId]}, + triggeredBy=str(context.user.id), + ) + elif ragIndexEnabled is False: + from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface + purgeResult = getKnowledgeInterface(None).deleteFileContentIndexByDataSource(sourceId) + logger.info("Purged %d index rows / %d chunks for DataSource %s", + purgeResult.get("indexRows", 0), purgeResult.get("chunks", 0), sourceId) import json from modules.shared.auditLogger import audit_logger @@ -253,10 +335,20 @@ async def _updateDataSourceRagIndex( mandateId=context.mandateId, category=AuditCategory.PERMISSION.value, action="rag_index_toggled", - details=json.dumps({"sourceId": sourceId, "ragIndexEnabled": ragIndexEnabled, "cascadedDescendants": cascaded}), + details=json.dumps({"sourceId": sourceId, "ragIndexEnabled": ragIndexEnabled, "resetDescendants": len(resetIds), "model": model.__name__}), ) - return {"sourceId": sourceId, "ragIndexEnabled": ragIndexEnabled, "updated": True, "cascadedDescendants": cascaded} + # 3. Compute effective + ancestors for response + updatedAncestors = _computeAncestorEffectives(rootIf, rec, model, "ragIndexEnabled") + effectiveRag = _computeOwnEffective(rootIf, rec, model, sourceId, "ragIndexEnabled") + + return { + "sourceId": sourceId, + "ragIndexEnabled": ragIndexEnabled, + "effectiveRagIndexEnabled": effectiveRag, + "resetDescendantIds": resetIds, + "updatedAncestors": updatedAncestors, + } except HTTPException: raise except Exception as e: @@ -339,7 +431,17 @@ def _updateDataSourceSettings( ownerId = str(rec.get("userId") or "") currentUserId = str(context.user.id) if ownerId and ownerId != currentUserId and not context.isSysAdmin: - scope = str(rec.get("scope") or "personal") + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag + if model is DataSource: + connectionId = rec.get("connectionId", "") + allDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) + scope = str(getEffectiveFlag(rec, "scope", allDs, mode="walk")) + else: + from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource as FDS + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlagFds + wsId = rec.get("workspaceInstanceId", "") + allFds = rootIf.db.getRecordset(FDS, recordFilter={"workspaceInstanceId": wsId}) + scope = str(getEffectiveFlagFds(rec, "scope", allFds, mode="walk")) isMandateAdmin = getattr(context, "isMandateAdmin", False) if scope == "personal" or not isMandateAdmin: raise HTTPException(status_code=403, detail="Not allowed to modify this DataSource's settings") diff --git a/modules/routes/routeRagInventory.py b/modules/routes/routeRagInventory.py index 99d5c4df..6a5e9eb5 100644 --- a/modules/routes/routeRagInventory.py +++ b/modules/routes/routeRagInventory.py @@ -86,6 +86,7 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L """ from modules.datamodels.datamodelDataSource import DataSource from modules.datamodels.datamodelKnowledge import FileContentIndex + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag out = [] for conn in connections: @@ -136,8 +137,8 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L "label": ds.get("label") if isinstance(ds, dict) else getattr(ds, "label", ""), "path": dsPath, "sourceType": ds.get("sourceType") if isinstance(ds, dict) else getattr(ds, "sourceType", ""), - "ragIndexEnabled": ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False), - "neutralize": ds.get("neutralize") if isinstance(ds, dict) else getattr(ds, "neutralize", False), + "ragIndexEnabled": getEffectiveFlag(ds, "ragIndexEnabled", dataSources, mode="walk"), + "neutralize": getEffectiveFlag(ds, "neutralize", dataSources, mode="walk"), "lastIndexed": ds.get("lastIndexed") if isinstance(ds, dict) else getattr(ds, "lastIndexed", None), "fileCount": filesByDs.get(dsId, 0), "chunkCount": chunksByDs.get(dsId, 0), @@ -223,13 +224,165 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L return out +def _buildFeatureInstanceInventory(featureInstanceIds, rootIf, knowledgeIf) -> List[Dict[str, Any]]: + """Build per-feature-instance RAG inventory rows. + + Feature-instance data lives in FileContentIndex with a non-empty + featureInstanceId. Additionally each feature instance may have + FeatureDataSource rows that define which tables/data are visible + as sources, with their own ragIndexEnabled flags. + Includes feature.bootstrap job status (running/success/error). + """ + from modules.datamodels.datamodelKnowledge import FileContentIndex + from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource + from modules.interfaces.interfaceFeatures import getFeatureInterface + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlagFds + from modules.serviceCenter.services.serviceBackgroundJobs import mainBackgroundJobService as jobService + from modules.serviceCenter.services.serviceKnowledge.subFeatureBootstrap import FEATURE_BOOTSTRAP_JOB_TYPE + + featureIf = getFeatureInterface(rootIf.db) + + allFeatureJobs = jobService.listJobs(jobType=FEATURE_BOOTSTRAP_JOB_TYPE, limit=100) + + out = [] + for fiId in featureInstanceIds: + instance = featureIf.getFeatureInstance(fiId) + if not instance or not instance.enabled: + continue + + indexRows = knowledgeIf.db.getRecordset( + FileContentIndex, recordFilter={"featureInstanceId": fiId} + ) + fileIds = [ + (r.get("id") if isinstance(r, dict) else getattr(r, "id", "")) + for r in indexRows + ] + fileIds = [fid for fid in fileIds if fid] + chunkCounts = knowledgeIf.countChunksByFileIds(fileIds) if fileIds else {} + + statusCounts: Dict[str, int] = {} + for r in indexRows: + st = (r.get("status") if isinstance(r, dict) else getattr(r, "status", "unknown")) or "unknown" + statusCounts[st] = statusCounts.get(st, 0) + 1 + + allFds = rootIf.db.getRecordset(FeatureDataSource, recordFilter={"workspaceInstanceId": fiId}) + dsItems = [] + anyRagEnabled = False + for fds in allFds: + tblName = (fds.get("tableName") if isinstance(fds, dict) else getattr(fds, "tableName", "")) or "" + fCode = (fds.get("featureCode") if isinstance(fds, dict) else getattr(fds, "featureCode", "")) or "" + if tblName == "*" or not fCode: + continue + fdsId = fds.get("id") if isinstance(fds, dict) else getattr(fds, "id", "") + ragEnabled = getEffectiveFlagFds(fds, "ragIndexEnabled", allFds, mode="aggregate") + if ragEnabled: + anyRagEnabled = True + dsItems.append({ + "id": fdsId, + "label": (fds.get("label") if isinstance(fds, dict) else getattr(fds, "label", "")) or "", + "tableName": tblName, + "featureCode": fCode, + "ragIndexEnabled": ragEnabled, + }) + + fiJobs = [ + j for j in allFeatureJobs + if (j.get("payload") or {}).get("workspaceInstanceId") == fiId + ] + runningJobs = [ + { + "jobId": j["id"], + "progress": j.get("progress", 0), + "progressMessage": ( + resolveJobMessage(j.get("progressMessageData")) + or j.get("progressMessage", "") + ), + } + for j in fiJobs + if j.get("status") in ("PENDING", "RUNNING") + ] + lastError: Optional[Dict[str, Any]] = None + lastSuccess: Optional[Dict[str, Any]] = None + for j in fiJobs: + jStatus = j.get("status") + if jStatus == "ERROR" and lastError is None: + lastError = { + "jobId": j["id"], + "errorMessage": j.get("errorMessage", ""), + "finishedAt": j.get("finishedAt"), + } + elif jStatus == "SUCCESS" and lastSuccess is None: + result = j.get("result") or {} + lastSuccess = { + "jobId": j["id"], + "finishedAt": j.get("finishedAt"), + "indexed": result.get("indexed", 0), + "skippedDuplicate": result.get("skippedDuplicate", 0), + "failed": result.get("failed", 0), + } + if lastError and lastSuccess: + break + + if not indexRows and not dsItems: + continue + + out.append({ + "featureInstanceId": fiId, + "featureCode": instance.featureCode, + "label": instance.label or instance.featureCode, + "mandateId": str(instance.mandateId) if instance.mandateId else "", + "fileCount": len(indexRows), + "chunkCount": sum(chunkCounts.values()), + "statusCounts": statusCounts, + "dataSources": dsItems, + "ragEnabled": anyRagEnabled, + "runningJobs": runningJobs, + "lastSuccess": lastSuccess, + "lastError": lastError, + }) + return out + + +@router.get("/my-mandates") +@limiter.limit("30/minute") +def _getMyMandates( + request: Request, + currentUser: User = Depends(getCurrentUser), +) -> List[Dict[str, Any]]: + """Return mandates where the current user has an active membership. + + Used by the RAG inventory frontend to populate the mandate dropdown + without requiring admin rights (unlike GET /api/mandates/). + """ + try: + from modules.interfaces.interfaceDbApp import getRootInterface + rootIf = getRootInterface() + userMandates = rootIf.getUserMandates(str(currentUser.id)) + result = [] + for um in userMandates: + if not um.enabled: + continue + mandate = rootIf.getMandate(str(um.mandateId)) + if not mandate or not getattr(mandate, "enabled", True): + continue + result.append({ + "id": str(um.mandateId), + "name": getattr(mandate, "name", ""), + "label": getattr(mandate, "label", None) or getattr(mandate, "name", ""), + }) + return result + except Exception as e: + logger.error("Error in RAG inventory /my-mandates: %s", e, exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + @router.get("/me") @limiter.limit("30/minute") def _getInventoryMe( request: Request, currentUser: User = Depends(getCurrentUser), ) -> Dict[str, Any]: - """Personal RAG inventory: own connections + DataSources + chunk counts.""" + """Personal RAG inventory: own connections + DataSources + chunk counts + feature uploads.""" try: from modules.interfaces.interfaceDbApp import getRootInterface from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface @@ -243,7 +396,20 @@ def _getInventoryMe( totalChunks = sum(c.get("totalChunks", 0) for c in items) totalFiles = sum(c.get("totalFiles", 0) for c in items) - return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}} + featureAccesses = rootIf.getFeatureAccessesForUser(str(currentUser.id)) + fiIds = [ + str(fa.featureInstanceId) for fa in featureAccesses + if fa.enabled and fa.featureInstanceId + ] + fiItems = _buildFeatureInstanceInventory(fiIds, rootIf, knowledgeIf) + totalFiles += sum(fi.get("fileCount", 0) for fi in fiItems) + totalChunks += sum(fi.get("chunkCount", 0) for fi in fiItems) + + return { + "connections": items, + "featureInstances": fiItems, + "totals": {"files": totalFiles, "chunks": totalChunks}, + } except Exception as e: logger.error("Error in RAG inventory /me: %s", e, exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @@ -262,21 +428,43 @@ def _getInventoryMandate( from modules.interfaces.interfaceDbApp import getRootInterface from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface, aggregateMandateRagTotalBytes from modules.serviceCenter.services.serviceBackgroundJobs import mainBackgroundJobService as jobService - rootIf = getRootInterface() knowledgeIf = getKnowledgeInterface(None) - mandateId = str(context.mandateId) if context.mandateId else "" + mandateId = str(context.mandateId) + userId = str(context.user.id) - from modules.datamodels.datamodelUam import UserConnection - allConnections = rootIf.db.getRecordset(UserConnection, recordFilter={"mandateId": mandateId}) - connectionObjects = [type("C", (), row)() if isinstance(row, dict) else row for row in allConnections] + userMandates = rootIf.getUserMandates(userId) + isMember = any( + getattr(um, "mandateId", None) == mandateId and um.enabled + for um in userMandates + ) + if not isMember and not context.isSysAdmin: + raise HTTPException(status_code=403, detail=routeApiMsg("No membership in this mandate")) - items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService) + mandateMembers = rootIf.getUserMandatesByMandate(mandateId) + memberUserIds = {getattr(um, "userId", None) for um in mandateMembers} + memberUserIds.discard(None) + + allConnections = [] + for uid in memberUserIds: + allConnections.extend(rootIf.getUserConnections(uid)) + + items = _buildConnectionInventory(allConnections, rootIf, knowledgeIf, jobService) totalChunks = sum(c.get("totalChunks", 0) for c in items) totalFiles = sum(c.get("totalFiles", 0) for c in items) totalBytes = aggregateMandateRagTotalBytes(mandateId) - return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks, "bytes": totalBytes}} + mandateInstances = rootIf.getFeatureInstancesByMandate(mandateId, enabledOnly=True) + fiIds = [str(inst.id) for inst in mandateInstances if inst.id] + fiItems = _buildFeatureInstanceInventory(fiIds, rootIf, knowledgeIf) + totalFiles += sum(fi.get("fileCount", 0) for fi in fiItems) + totalChunks += sum(fi.get("chunkCount", 0) for fi in fiItems) + + return { + "connections": items, + "featureInstances": fiItems, + "totals": {"files": totalFiles, "chunks": totalChunks, "bytes": totalBytes}, + } except HTTPException: raise except Exception as e: @@ -308,7 +496,22 @@ def _getInventoryPlatform( totalChunks = sum(c.get("totalChunks", 0) for c in items) totalFiles = sum(c.get("totalFiles", 0) for c in items) - return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}} + from modules.datamodels.datamodelFeatures import FeatureInstance + allInstances = rootIf.db.getRecordset(FeatureInstance, recordFilter={"enabled": True}) + fiIds = [ + (r.get("id") if isinstance(r, dict) else getattr(r, "id", "")) + for r in allInstances + ] + fiIds = [fid for fid in fiIds if fid] + fiItems = _buildFeatureInstanceInventory(fiIds, rootIf, knowledgeIf) + totalFiles += sum(fi.get("fileCount", 0) for fi in fiItems) + totalChunks += sum(fi.get("chunkCount", 0) for fi in fiItems) + + return { + "connections": items, + "featureInstances": fiItems, + "totals": {"files": totalFiles, "chunks": totalChunks}, + } except HTTPException: raise except Exception as e: @@ -345,8 +548,9 @@ async def _reindexConnection( if str(conn.userId) != str(currentUser.id): raise HTTPException(status_code=403, detail="Not your connection") + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) - ragDs = [ds for ds in dataSources if (ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False))] + ragDs = [ds for ds in dataSources if getEffectiveFlag(ds, "ragIndexEnabled", dataSources, mode="walk") is True] if not ragDs: return {"status": "skipped", "reason": "no_rag_enabled_datasources"} @@ -368,6 +572,47 @@ async def _reindexConnection( raise HTTPException(status_code=500, detail=str(e)) +@router.post("/reindex-feature/{workspaceInstanceId}") +@limiter.limit("10/minute") +async def _reindexFeature( + request: Request, + workspaceInstanceId: str, + currentUser: User = Depends(getCurrentUser), +) -> Dict[str, Any]: + """Re-trigger feature data bootstrap for a workspace instance. + + Indexes all RAG-enabled FeatureDataSource rows into the knowledge store. + Must be ``async def`` so ``await startJob(...)`` registers in the main loop. + """ + try: + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.serviceCenter.services.serviceBackgroundJobs import startJob + from modules.serviceCenter.services.serviceKnowledge.subFeatureBootstrap import FEATURE_BOOTSTRAP_JOB_TYPE + + rootIf = getRootInterface() + featureAccesses = rootIf.getFeatureAccessesForUser(str(currentUser.id)) + hasAccess = any( + str(fa.featureInstanceId) == workspaceInstanceId and fa.enabled + for fa in featureAccesses + ) + if not hasAccess and not getattr(currentUser, "isSysAdmin", False): + raise HTTPException(status_code=403, detail="No access to this feature instance") + + jobId = await startJob( + FEATURE_BOOTSTRAP_JOB_TYPE, + {"workspaceInstanceId": workspaceInstanceId}, + triggeredBy=str(currentUser.id), + ) + + logger.info("Feature reindex triggered for workspace %s (jobId=%s)", workspaceInstanceId, jobId) + return {"status": "queued", "workspaceInstanceId": workspaceInstanceId, "jobId": jobId} + except HTTPException: + raise + except Exception as e: + logger.error("Error triggering feature reindex: %s", e, exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + + @router.get("/jobs") @limiter.limit("60/minute") def _getActiveJobs( diff --git a/modules/security/rbac.py b/modules/security/rbac.py index bec0b70e..59f8f55f 100644 --- a/modules/security/rbac.py +++ b/modules/security/rbac.py @@ -341,11 +341,10 @@ class RbacClass: return [] try: - conn = self.dbApp.connection roleIds = set() - + # 1. Mandant-Rollen via UserMandate → UserMandateRole (SINGLE Query) - with conn.cursor() as cursor: + with self.dbApp.borrowCursor() as cursor: cursor.execute( """ SELECT umr."roleId" @@ -357,10 +356,10 @@ class RbacClass: ) mandateRoles = cursor.fetchall() roleIds.update(r["roleId"] for r in mandateRoles if r.get("roleId")) - + # 2. Instanz-Rollen via FeatureAccess → FeatureAccessRole (SINGLE Query) if featureInstanceId: - with conn.cursor() as cursor: + with self.dbApp.borrowCursor() as cursor: cursor.execute( """ SELECT far."roleId" @@ -372,14 +371,13 @@ class RbacClass: ) instanceRoles = cursor.fetchall() roleIds.update(r["roleId"] for r in instanceRoles if r.get("roleId")) - + if not roleIds: return [] - + # 3. BULK Query: Alle Regeln für alle Rollen + zugehörige Role-Daten - # SINGLE Query mit JOIN statt N+1 roleIdsList = list(roleIds) - with conn.cursor() as cursor: + with self.dbApp.borrowCursor() as cursor: cursor.execute( """ SELECT ar.*, r."mandateId" as "roleMandateId", diff --git a/modules/serviceCenter/services/serviceAgent/coreTools/_dataSourceTools.py b/modules/serviceCenter/services/serviceAgent/coreTools/_dataSourceTools.py index fff1bcb3..dbd28dd4 100644 --- a/modules/serviceCenter/services/serviceAgent/coreTools/_dataSourceTools.py +++ b/modules/serviceCenter/services/serviceAgent/coreTools/_dataSourceTools.py @@ -67,7 +67,12 @@ def _registerDataSourceTools(registry: ToolRegistry, services): sourceType = ds.get("sourceType", "") path = ds.get("path", "/") label = ds.get("label", "") - neutralize = bool(ds.get("neutralize", False)) + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag + from modules.datamodels.datamodelDataSource import DataSource + from modules.interfaces.interfaceDbApp import getRootInterface + rootIf = getRootInterface() + allConnDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) + neutralize = bool(getEffectiveFlag(ds, "neutralize", allConnDs or [ds], mode="walk")) service = _SOURCE_TYPE_TO_SERVICE.get(sourceType, sourceType) if not connectionId: raise ValueError(f"DataSource '{dsId}' has no connectionId") diff --git a/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py b/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py index bdb3d23b..2ebc2720 100644 --- a/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py +++ b/modules/serviceCenter/services/serviceAgent/coreTools/_featureSubAgentTools.py @@ -110,9 +110,11 @@ def _registerFeatureSubAgentTools(registry: ToolRegistry, services): recordFilter={"featureInstanceId": featureInstanceId, "workspaceInstanceId": workspaceInstanceId}, ) + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlagFds + _fdsAll = featureDataSources or [] _anySourceNeutralize = any( - bool(ds.get("neutralize", False) if isinstance(ds, dict) else getattr(ds, "neutralize", False)) - for ds in (featureDataSources or []) + getEffectiveFlagFds(ds, "neutralize", _fdsAll, mode="walk") is True + for ds in _fdsAll ) neutralizeFieldsPerTable: Dict[str, List[str]] = {} diff --git a/modules/serviceCenter/services/serviceAgent/featureDataProvider.py b/modules/serviceCenter/services/serviceAgent/featureDataProvider.py index d7707bdf..27ec36b2 100644 --- a/modules/serviceCenter/services/serviceAgent/featureDataProvider.py +++ b/modules/serviceCenter/services/serviceAgent/featureDataProvider.py @@ -95,8 +95,7 @@ class FeatureDataProvider: def getActualColumns(self, tableName: str) -> List[str]: """Read real column names from PostgreSQL information_schema.""" try: - conn = self._db.connection - with conn.cursor() as cur: + with self._db.borrowCursor() as cur: cur.execute( "SELECT column_name FROM information_schema.columns " "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) " @@ -131,7 +130,6 @@ class FeatureDataProvider: Returns ``{"rows": [...], "total": N, "limit": L, "offset": O}``. """ _validateTableName(tableName) - conn = self._db.connection if fields: invalid = [f for f in fields if not _isValidIdentifier(f)] @@ -141,7 +139,7 @@ class FeatureDataProvider: "error": f"Invalid field name(s): {', '.join(invalid)}. Use getActualColumns to discover valid column names.", } - scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, dbConnection=conn) + scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, db=self._db) extraWhere, extraParams = _buildFilterClauses(extraFilters) fullWhere = scopeFilter["where"] @@ -152,7 +150,7 @@ class FeatureDataProvider: t0 = time.time() try: - with conn.cursor() as cur: + with self._db.borrowCursor() as cur: countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {fullWhere}' cur.execute(countSql, allParams) total = cur.fetchone()["count"] if cur.rowcount else 0 @@ -179,10 +177,6 @@ class FeatureDataProvider: _debugQueryLog("browseTable", tableName, { "fields": fields, "limit": limit, "offset": offset, }, errResult, elapsed) - try: - conn.rollback() - except Exception: - pass return errResult def aggregateTable( @@ -208,8 +202,7 @@ class FeatureDataProvider: if groupBy and not _isValidIdentifier(groupBy): return {"rows": [], "error": f"Invalid groupBy field: {groupBy}"} - conn = self._db.connection - scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, dbConnection=conn) + scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, db=self._db) extraWhere, extraParams = _buildFilterClauses(extraFilters) fullWhere = scopeFilter["where"] @@ -220,7 +213,7 @@ class FeatureDataProvider: t0 = time.time() try: - with conn.cursor() as cur: + with self._db.borrowCursor() as cur: if groupBy: sql = ( f'SELECT "{groupBy}" AS "groupValue", {aggregate}("{field}") AS "result" ' @@ -253,10 +246,6 @@ class FeatureDataProvider: _debugQueryLog("aggregateTable", tableName, { "aggregate": aggregate, "field": field, "groupBy": groupBy, }, errResult, elapsed) - try: - conn.rollback() - except Exception: - pass return errResult def queryTable( @@ -277,7 +266,6 @@ class FeatureDataProvider: ``extraFilters`` are mandatory record-level scoping filters injected by the pipeline. """ _validateTableName(tableName) - conn = self._db.connection if fields: invalid = [f for f in fields if not _isValidIdentifier(f)] @@ -287,7 +275,7 @@ class FeatureDataProvider: "error": f"Invalid field name(s): {', '.join(invalid)}. Use getActualColumns to discover valid column names.", } - scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, dbConnection=conn) + scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, db=self._db) combinedFilters = list(filters or []) + list(extraFilters or []) extraWhere, extraParams = _buildFilterClauses(combinedFilters if combinedFilters else None) @@ -300,7 +288,7 @@ class FeatureDataProvider: t0 = time.time() try: - with conn.cursor() as cur: + with self._db.borrowCursor() as cur: countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {fullWhere}' cur.execute(countSql, allParams) total = cur.fetchone()["count"] if cur.rowcount else 0 @@ -329,10 +317,6 @@ class FeatureDataProvider: "filters": filters, "fields": fields, "orderBy": orderBy, "limit": limit, "offset": offset, }, errResult, elapsed) - try: - conn.rollback() - except Exception: - pass return errResult @@ -343,13 +327,13 @@ class FeatureDataProvider: _instanceColCache: Dict[str, str] = {} -def _resolveInstanceColumn(tableName: str, dbConnection=None) -> str: +def _resolveInstanceColumn(tableName: str, db=None) -> str: """Detect whether the table uses ``instanceId`` or ``featureInstanceId``.""" if tableName in _instanceColCache: return _instanceColCache[tableName] - if dbConnection: + if db: try: - with dbConnection.cursor() as cur: + with db.borrowCursor() as cur: cur.execute( "SELECT column_name FROM information_schema.columns " "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) " @@ -378,14 +362,14 @@ def _isValidIdentifier(name: str) -> bool: return name.isidentifier() -def _buildScopeFilter(tableName: str, featureInstanceId: str, mandateId: str, dbConnection=None) -> Dict[str, Any]: +def _buildScopeFilter(tableName: str, featureInstanceId: str, mandateId: str, db=None, dbConnection=None) -> Dict[str, Any]: """Build the mandatory WHERE clause that scopes rows to the feature instance. Feature tables use either ``instanceId`` (commcoach, teamsbot) or ``featureInstanceId`` (trustee) as the FK. We detect the actual column - from ``information_schema`` when a DB connection is provided. + from ``information_schema`` when a DB connector is provided. """ - instanceCol = _resolveInstanceColumn(tableName, dbConnection) + instanceCol = _resolveInstanceColumn(tableName, db or dbConnection) conditions = [] params = [] diff --git a/modules/serviceCenter/services/serviceKnowledge/_buildTree.py b/modules/serviceCenter/services/serviceKnowledge/_buildTree.py new file mode 100644 index 00000000..9179f3d8 --- /dev/null +++ b/modules/serviceCenter/services/serviceKnowledge/_buildTree.py @@ -0,0 +1,1020 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +"""Generic UDB Tree builder. + +The UDB shows three logical hierarchies as a single user-facing tree: + 1. Personal connections: UserConnection -> Service -> Folder -> File + 2. Mandate groups -> Feature instances -> FDS Workspace(*) -> FDS Table -> FDS Record + 3. (Settings/diagnostics nodes can be added later under the same model.) + +For every visible node the UI needs: + - a stable `key` (used both for expand-state and as parent reference) + - a `kind`, `label`, optional `icon` + - effective values for all three flags (neutralize, scope, ragIndexEnabled) + - whether a backing DB record exists (`dataSourceId` + `modelType`) + - whether the node has children to expand + +This module exposes one function: `getChildrenForParents(parents, ...)`. +The caller asks for the children of a list of parent keys. The orchestrator +does NOT decide what to expand; it only returns the children of what was +asked for. This keeps the contract minimal and predictable. +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional, Tuple + +from modules.serviceCenter.services.serviceKnowledge._inheritFlags import ( + resolveEffectiveForPath, + resolveEffectiveForFds, + _normalisePath, +) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Key encoding / decoding +# --------------------------------------------------------------------------- +# Format: "|||..." for data-bearing keys. +# Synthetic container keys use a single literal token without separator. +# +# Top-level (parent=None) returns: +# personalRoot (synthetic, groups all UserConnections) +# mgrp| (one per accessible mandate) +# +# Data-bearing: +# conn| +# svc|| +# ds||| +# mgrp| +# feat||| +# fdsws|| (synthetic '*' wildcard) +# fdstbl|| +# fdsrec||| + +_KEY_SEP = "|" + +# Stable, parseable synthetic-container key. Never encoded with `_encode` +# (no payload parts), always emitted/matched as literal. +_KEY_PERSONAL_ROOT = "personalRoot" + + +def _decode(key: str) -> Tuple[str, List[str]]: + parts = key.split(_KEY_SEP) + return parts[0], parts[1:] + + +def _encode(kind: str, *parts: str) -> str: + return _KEY_SEP.join((kind, *parts)) + + +# --------------------------------------------------------------------------- +# Sourcetype mapping (was hard-coded in frontend; now backend authority) +# --------------------------------------------------------------------------- +_SERVICE_TO_SOURCE_TYPE: Dict[str, str] = { + "sharepoint": "sharepointFolder", + "onedrive": "onedriveFolder", + "outlook": "outlookFolder", + "drive": "googleDriveFolder", + "gmail": "gmailFolder", + "files": "ftpFolder", + "clickup": "clickup", + "kdrive": "kdriveFolder", + "mail": "mailFolder", + "calendar": "calendarFolder", + "contact": "contactFolder", +} + +_SERVICE_LABELS: Dict[str, str] = { + "sharepoint": "SharePoint", + "outlook": "Outlook", + "teams": "Teams", + "onedrive": "OneDrive", + "drive": "Google Drive", + "gmail": "Gmail", + "files": "Files (FTP)", + "kdrive": "kDrive", + "calendar": "Calendar", + "contact": "Contacts", +} + + +# --------------------------------------------------------------------------- +# Per-node effective-value helpers +# --------------------------------------------------------------------------- + +def _effectiveTripletDs( + connectionId: str, + sourceType: str, + path: str, + allDs: List[Dict[str, Any]], +) -> Dict[str, Any]: + """Return {effectiveNeutralize, effectiveScope, effectiveRagIndexEnabled} + for an arbitrary DS coordinate (whether or not a record exists).""" + out = resolveEffectiveForPath(connectionId, sourceType, path, allDs, mode="aggregate") + return { + "effectiveNeutralize": out.get("effectiveNeutralize", False), + "effectiveScope": out.get("effectiveScope", "personal"), + "effectiveRagIndexEnabled": out.get("effectiveRagIndexEnabled", False), + } + + +def _effectiveTripletFds( + featureInstanceId: str, + tableName: str, + recordFilter: Optional[Dict[str, str]], + allFds: List[Dict[str, Any]], +) -> Dict[str, Any]: + """Return effective-triplet for an FDS coordinate.""" + out = resolveEffectiveForFds(featureInstanceId, tableName, recordFilter, allFds, mode="aggregate") + return { + "effectiveNeutralize": out.get("effectiveNeutralize", False), + "effectiveScope": out.get("effectiveScope", "personal"), + "effectiveRagIndexEnabled": out.get("effectiveRagIndexEnabled", False), + } + + +def _findDsRecord( + allDs: List[Dict[str, Any]], + connectionId: str, + sourceType: str, + path: str, +) -> Optional[Dict[str, Any]]: + norm = _normalisePath(path) + for ds in allDs: + if ( + ds.get("connectionId") == connectionId + and ds.get("sourceType") == sourceType + and _normalisePath(ds.get("path")) == norm + ): + return ds + return None + + +def _findFdsRecord( + allFds: List[Dict[str, Any]], + featureInstanceId: str, + tableName: str, + recordFilter: Optional[Dict[str, str]] = None, +) -> Optional[Dict[str, Any]]: + """Find a FeatureDataSource record by featureInstanceId + tableName. + + `allFds` is already scoped to the workspace (loaded with + recordFilter={'workspaceInstanceId': wsInstanceId}), so the + distinguishing coordinate is featureInstanceId + tableName. + """ + target = recordFilter or None + for fds in allFds: + if ( + fds.get("featureInstanceId") == featureInstanceId + and fds.get("tableName") == tableName + and (fds.get("recordFilter") or None) == target + ): + return fds + return None + + +# --------------------------------------------------------------------------- +# Synthetic container helpers +# --------------------------------------------------------------------------- + +def _emptyTriplet() -> Dict[str, Any]: + """Synthetic container nodes carry no DB record and no inherited flags. + Backend reports neutral defaults so the UI never reads stale values for them.""" + return { + "effectiveNeutralize": False, + "effectiveScope": "personal", + "effectiveRagIndexEnabled": False, + } + + +def _syntheticNode( + key: str, + parentKey: Optional[str], + label: str, + icon: str, + displayOrder: int, + defaultExpanded: bool = False, +) -> Dict[str, Any]: + """Build a synthetic container node (no DB record, not flag-toggleable).""" + return { + "key": key, + "kind": "synthRoot", + "parentKey": parentKey, + "label": label, + "icon": icon, + "hasChildren": True, + "dataSourceId": None, + "modelType": None, + **_emptyTriplet(), + "supportsRag": False, + "canBeAdded": False, + "displayOrder": displayOrder, + "defaultExpanded": defaultExpanded, + } + + +# --------------------------------------------------------------------------- +# Top-level (parent = None) -> personalRoot + mandate groups (flat layout) +# --------------------------------------------------------------------------- + +def _topLevel( + instanceId: str, + context: Any, + rootIf: Any, + _allDs: List[Dict[str, Any]], + allFds: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Return the visible top-level: 'personalRoot' first, then one node per + accessible mandate group. Both layers are marked `defaultExpanded=True` + so the UI opens down to the data-source level on first render. + """ + nodes: List[Dict[str, Any]] = [ + _syntheticNode( + key=_KEY_PERSONAL_ROOT, + parentKey=None, + label=resolveTextSafe("Persönliche Quellen"), + icon="person", + displayOrder=0, + defaultExpanded=True, + ) + ] + nodes.extend(_listMandateGroups(instanceId, context, rootIf, allFds)) + return nodes + + +# --------------------------------------------------------------------------- +# Children of personalRoot -> active UserConnections +# --------------------------------------------------------------------------- + +def _personalRootChildren( + instanceId: str, + context: Any, + allDs: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Return one node per active UserConnection of the current user.""" + from modules.serviceCenter import getService + from modules.serviceCenter.context import ServiceCenterContext + + mandateId = getattr(context, "mandateId", "") or "" + ctx = ServiceCenterContext( + user=context.user, + mandate_id=mandateId, + feature_instance_id=instanceId, + ) + chatService = getService("chat", ctx) + connections = chatService.getUserConnections() or [] + + nodes: List[Dict[str, Any]] = [] + for c in connections: + conn = c if isinstance(c, dict) else (c.model_dump() if hasattr(c, "model_dump") else {}) + status = conn.get("status") + if hasattr(status, "value"): + status = status.value + if status != "active": + continue + authority = conn.get("authority") + if hasattr(authority, "value"): + authority = authority.value + connId = conn.get("id") or "" + label = conn.get("externalEmail") or conn.get("externalUsername") or authority or "" + # Connection root = path '/' on its authority sourceType. + triplet = _effectiveTripletDs(connId, str(authority), "/", allDs) + rec = _findDsRecord(allDs, connId, str(authority), "/") + nodes.append({ + "key": _encode("conn", connId), + "kind": "connection", + "parentKey": _KEY_PERSONAL_ROOT, + "label": label, + "icon": str(authority), + "hasChildren": True, + "dataSourceId": rec.get("id") if rec else None, + "modelType": "DataSource" if rec else None, + **triplet, + "supportsRag": True, + "canBeAdded": rec is None, + "authority": authority, + "connectionId": connId, + }) + return nodes + + +# --------------------------------------------------------------------------- +# Mandate-group nodes (rendered top-level next to personalRoot) +# --------------------------------------------------------------------------- + +def _listMandateGroups( + _instanceId: str, + context: Any, + rootIf: Any, + _allFds: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Return one mandate-group node per accessible mandate that has at least + one enabled feature instance with registered DATA objects. + + Emitted at the top level (parentKey=None). `defaultExpanded=True` so the + UI shows feature-instance children (= mandate data sources) without a + second user click. + """ + from modules.security.rbacCatalog import getCatalogService + from modules.datamodels.datamodelUam import Mandate + + userId = str(context.user.id) + catalog = getCatalogService() + featureCodesWithData = catalog.getFeaturesWithDataObjects() + userMandates = rootIf.getUserMandates(userId) + + wsMandateId = getattr(context, "mandateId", None) + allowedMandateIds = {um.mandateId for um in (userMandates or [])} + if wsMandateId and wsMandateId in allowedMandateIds: + allowedMandateIds = {wsMandateId} + + mandateLabels: Dict[str, str] = {} + for um in userMandates or []: + if um.mandateId not in allowedMandateIds: + continue + try: + rows = rootIf.db.getRecordset(Mandate, recordFilter={"id": um.mandateId}) + if rows: + m = rows[0] + mandateLabels[um.mandateId] = m.get("label") or m.get("name") or um.mandateId + except Exception: + mandateLabels[um.mandateId] = um.mandateId + + nodes: List[Dict[str, Any]] = [] + seenMandates: set = set() + for um in userMandates or []: + mid = um.mandateId + if mid in seenMandates or mid not in allowedMandateIds: + continue + seenMandates.add(mid) + instances = rootIf.getFeatureInstancesByMandate(mid) + hasFeature = False + for inst in instances: + if inst.enabled and inst.featureCode in featureCodesWithData: + fa = rootIf.getFeatureAccess(userId, inst.id) + if fa and fa.enabled: + hasFeature = True + break + if not hasFeature: + continue + nodes.append({ + "key": _encode("mgrp", mid), + "kind": "mandateGroup", + "parentKey": None, + "label": mandateLabels.get(mid, mid), + "icon": "mandate", + "hasChildren": True, + "dataSourceId": None, + "modelType": None, + **_emptyTriplet(), + "supportsRag": False, + "canBeAdded": False, + "mandateId": mid, + "defaultExpanded": True, + }) + return nodes + + +# --------------------------------------------------------------------------- +# Children of a connection -> services +# --------------------------------------------------------------------------- + +async def _connectionServices( + instanceId: str, + context: Any, + connectionId: str, + allDs: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + from modules.connectors.connectorResolver import ConnectorResolver + from modules.serviceCenter import getService + from modules.serviceCenter.context import ServiceCenterContext + + mandateId = getattr(context, "mandateId", "") or "" + ctx = ServiceCenterContext( + user=context.user, + mandate_id=mandateId, + feature_instance_id=instanceId, + ) + chatService = getService("chat", ctx) + securityService = getService("security", ctx) + from modules.features.workspace.routeFeatureWorkspace import _buildResolverDbInterface + dbInterface = _buildResolverDbInterface(chatService) + resolver = ConnectorResolver(securityService, dbInterface) + try: + provider = await resolver.resolve(connectionId) + services = provider.getAvailableServices() + except Exception as exc: + logger.error("Tree: cannot resolve services for connection %s: %s", connectionId, exc) + return [] + + nodes: List[Dict[str, Any]] = [] + for service in services or []: + sourceType = _SERVICE_TO_SOURCE_TYPE.get(service, service) + triplet = _effectiveTripletDs(connectionId, sourceType, "/", allDs) + rec = _findDsRecord(allDs, connectionId, sourceType, "/") + nodes.append({ + "key": _encode("svc", connectionId, service), + "kind": "service", + "parentKey": _encode("conn", connectionId), + "label": _SERVICE_LABELS.get(service, service), + "icon": service, + "hasChildren": True, + "dataSourceId": rec.get("id") if rec else None, + "modelType": "DataSource" if rec else None, + **triplet, + "supportsRag": True, + "canBeAdded": rec is None, + "connectionId": connectionId, + "service": service, + "sourceType": sourceType, + "path": "/", + }) + return nodes + + +# --------------------------------------------------------------------------- +# Children of a folder/service -> next-level folders+files via browse +# --------------------------------------------------------------------------- + +async def _browseChildren( + instanceId: str, + context: Any, + connectionId: str, + service: str, + sourceType: str, + parentPath: str, + allDs: List[Dict[str, Any]], + parentKey: Optional[str] = None, +) -> List[Dict[str, Any]]: + from modules.connectors.connectorResolver import ConnectorResolver + from modules.serviceCenter import getService + from modules.serviceCenter.context import ServiceCenterContext + + mandateId = getattr(context, "mandateId", "") or "" + ctx = ServiceCenterContext( + user=context.user, + mandate_id=mandateId, + feature_instance_id=instanceId, + ) + chatService = getService("chat", ctx) + securityService = getService("security", ctx) + from modules.features.workspace.routeFeatureWorkspace import _buildResolverDbInterface + dbInterface = _buildResolverDbInterface(chatService) + resolver = ConnectorResolver(securityService, dbInterface) + try: + adapter = await resolver.resolveService(connectionId, service) + entries = await adapter.browse(parentPath, filter=None) + except Exception as exc: + logger.error("Tree: cannot browse %s on connection %s path=%s: %s", service, connectionId, parentPath, exc) + return [] + + # Children parentKey must equal the key the caller asked for (= the + # currently-expanded node in the UI). If the caller doesn't pass an + # explicit key, fall back to the encoded ds-coordinate. + effectiveParentKey = parentKey if parentKey is not None else _encode("ds", connectionId, sourceType, parentPath) + nodes: List[Dict[str, Any]] = [] + for e in entries or []: + path = getattr(e, "path", "") or "" + kind = "folder" if getattr(e, "isFolder", False) else "file" + triplet = _effectiveTripletDs(connectionId, sourceType, path, allDs) + rec = _findDsRecord(allDs, connectionId, sourceType, path) + nodes.append({ + "key": _encode("ds", connectionId, sourceType, path), + "kind": kind, + "parentKey": effectiveParentKey, + "label": getattr(e, "name", "") or path, + "icon": kind, + "hasChildren": kind == "folder", + "dataSourceId": rec.get("id") if rec else None, + "modelType": "DataSource" if rec else None, + **triplet, + "supportsRag": True, + "canBeAdded": rec is None, + "connectionId": connectionId, + "service": service, + "sourceType": sourceType, + "path": path, + }) + return nodes + + +# --------------------------------------------------------------------------- +# Mandate group -> feature connections +# --------------------------------------------------------------------------- + +def _featureConnectionsForMandate( + instanceId: str, + context: Any, + rootIf: Any, + mandateId: str, + allFds: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + from modules.security.rbacCatalog import getCatalogService + + userId = str(context.user.id) + catalog = getCatalogService() + featureCodesWithData = catalog.getFeaturesWithDataObjects() + instances = rootIf.getFeatureInstancesByMandate(mandateId) + + parentKey = _encode("mgrp", mandateId) + nodes: List[Dict[str, Any]] = [] + for inst in instances or []: + if not inst.enabled: + continue + if inst.featureCode not in featureCodesWithData: + continue + fa = rootIf.getFeatureAccess(userId, inst.id) + if not fa or not fa.enabled: + continue + # Effective values come from the FDS workspace-wildcard for this featureInstanceId + wsId = inst.id + triplet = _effectiveTripletFds(wsId, "*", None, allFds) + rec = _findFdsRecord(allFds, wsId, "*", None) + featureDef = catalog.getFeatureDefinition(inst.featureCode) or {} + nodes.append({ + "key": _encode("feat", mandateId, inst.featureCode, inst.id), + "kind": "featureNode", + "parentKey": parentKey, + "label": inst.label or inst.featureCode, + "icon": featureDef.get("icon", "mdi-database"), + "hasChildren": True, + "dataSourceId": rec.get("id") if rec else None, + "modelType": "FeatureDataSource" if rec else None, + **triplet, + "supportsRag": True, + "canBeAdded": rec is None, + "featureInstanceId": wsId, + "featureCode": inst.featureCode, + "mandateId": mandateId, + "tableName": "*", + }) + return nodes + + +# --------------------------------------------------------------------------- +# Feature node -> tables +# --------------------------------------------------------------------------- + +def _featureTables( + context: Any, + rootIf: Any, + parentKey: str, + featureInstanceId: str, + featureCode: str, + allFds: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + from modules.security.rbacCatalog import getCatalogService + + inst = rootIf.getFeatureInstance(featureInstanceId) + if not inst: + return [] + catalog = getCatalogService() + try: + from modules.security.rbac import RbacClass + from modules.security.rootAccess import getRootDbAppConnector + dbApp = getRootDbAppConnector() + rbac = RbacClass(dbApp, dbApp=dbApp) + accessible = catalog.getAccessibleDataObjects( + featureCode=inst.featureCode, + rbacInstance=rbac, + user=context.user, + mandateId=str(inst.mandateId) if inst.mandateId else "", + featureInstanceId=featureInstanceId, + ) + except Exception: + accessible = catalog.getDataObjects(inst.featureCode) + + accessibleKeys = {obj.get("objectKey", "") for obj in accessible} + + nodes: List[Dict[str, Any]] = [] + for obj in catalog.getDataObjects(inst.featureCode): + meta = obj.get("meta", {}) + if meta.get("wildcard") or meta.get("isGroup"): + continue + objectKey = obj.get("objectKey", "") + if objectKey not in accessibleKeys: + continue + tableName = meta.get("table", "") + if not tableName: + continue + triplet = _effectiveTripletFds(featureInstanceId, tableName, None, allFds) + rec = _findFdsRecord(allFds, featureInstanceId, tableName, None) + fields = meta.get("fields") if isinstance(meta, dict) else None + hasFields = bool(isinstance(fields, list) and len(fields) > 0) + # Surface the persisted per-field neutralize list so the UI can + # render & toggle field-level icons without an extra GET. + neutralizeFields: List[str] = [] + if rec and isinstance(rec.get("neutralizeFields"), list): + neutralizeFields = [f for f in rec["neutralizeFields"] if isinstance(f, str)] + nodes.append({ + "key": _encode("fdstbl", featureInstanceId, tableName), + "kind": "fdsTable", + "parentKey": parentKey, + "label": resolveTextSafe(obj.get("label", "")) or tableName, + "icon": "table", + # Children = the per-column field nodes. Only emitted when the + # data-object metadata declared a non-empty `fields` list. + "hasChildren": hasFields, + "dataSourceId": rec.get("id") if rec else None, + "modelType": "FeatureDataSource" if rec else None, + **triplet, + "supportsRag": True, + "canBeAdded": rec is None, + "featureInstanceId": featureInstanceId, + "featureCode": featureCode, + "tableName": tableName, + "objectKey": objectKey, + "neutralizeFields": neutralizeFields, + }) + return nodes + + +def _featureTableFields( + parentKey: str, + featureInstanceId: str, + tableName: str, + fieldNames: List[str], + allFds: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Emit one node per declared column of a feature data table. + + Per-field neutralize semantics: + - The table-level FDS record carries `neutralizeFields: List[str]`. + - A field is "effectively neutralized" iff its name is in that list + OR the table's effective `neutralize` is True (blanket). + - Only `neutralize` is meaningful per-field; `scope` and `ragIndexEnabled` + are inherited from the parent table and not toggleable here. + """ + rec = _findFdsRecord(allFds, featureInstanceId, tableName, None) + tableNeutralize = bool(rec.get("neutralize")) if rec else False + neutralizeFields = rec.get("neutralizeFields") if rec else None + if not isinstance(neutralizeFields, list): + neutralizeFields = [] + + nodes: List[Dict[str, Any]] = [] + for field in fieldNames: + if not field: + continue + fieldNeutralized = bool(tableNeutralize or field in neutralizeFields) + nodes.append({ + "key": _encode("fdsfld", featureInstanceId, tableName, field), + "kind": "fdsField", + "parentKey": parentKey, + "label": field, + "icon": "field", + "hasChildren": False, + "dataSourceId": rec.get("id") if rec else None, + "modelType": "FeatureDataSource" if rec else None, + "effectiveNeutralize": fieldNeutralized, + # Field-level scope/RAG do not exist as a concept; the FE hides + # those affordances when supportsRag=False. We still need + # `effectiveScope` + `effectiveRagIndexEnabled` for the + # contract; they reflect the parent's effective values so the + # backend stays single source of truth. + "effectiveScope": "personal", + "effectiveRagIndexEnabled": False, + "supportsRag": False, + "canBeAdded": rec is None, + "featureInstanceId": featureInstanceId, + "tableName": tableName, + "fieldName": field, + }) + return nodes + + +def resolveTextSafe(label: Any) -> str: + try: + from modules.shared.i18nRegistry import resolveText + return resolveText(label) + except Exception: + return str(label or "") + + +# --------------------------------------------------------------------------- +# Public entrypoint +# --------------------------------------------------------------------------- + +async def getChildrenForParents( + instanceId: str, + parents: List[Optional[str]], + context: Any, +) -> Dict[str, List[Dict[str, Any]]]: + """Return per-parent children lists. + + `parents` is a list with `None` representing the top-level. Order is preserved. + Returns a dict keyed by parent key (or '__root__' for None). + + Each child is a fully-rendered TreeNode dict (see module docstring for shape). + """ + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.datamodels.datamodelDataSource import DataSource + from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource + + rootIf = getRootInterface() + + # Pre-load DS (per user) and FDS (per workspace) once for the whole request. + userId = str(context.user.id) + allDs = rootIf.db.getRecordset(DataSource, recordFilter={"userId": userId}) or [] + allFds = rootIf.db.getRecordset(FeatureDataSource, recordFilter={"workspaceInstanceId": instanceId}) or [] + + out: Dict[str, List[Dict[str, Any]]] = {} + + for parentKey in parents: + if parentKey is None: + try: + out["__root__"] = _topLevel(instanceId, context, rootIf, allDs, allFds) + except Exception as exc: + logger.exception("Tree top-level failed: %s", exc) + out["__root__"] = [] + continue + + try: + kind, parts = _decode(parentKey) + except Exception: + out[parentKey] = [] + continue + + try: + if parentKey == _KEY_PERSONAL_ROOT: + out[parentKey] = _personalRootChildren(instanceId, context, allDs) + + elif kind == "conn" and len(parts) == 1: + out[parentKey] = await _connectionServices(instanceId, context, parts[0], allDs) + + elif kind == "svc" and len(parts) == 2: + connId, service = parts + sourceType = _SERVICE_TO_SOURCE_TYPE.get(service, service) + out[parentKey] = await _browseChildren( + instanceId, context, connId, service, sourceType, "/", allDs, + parentKey=parentKey, + ) + + elif kind == "ds" and len(parts) == 3: + connId, sourceType, path = parts + # Determine service from sourceType (reverse map) + service = _reverseService(sourceType) + out[parentKey] = await _browseChildren( + instanceId, context, connId, service, sourceType, path, allDs, + parentKey=parentKey, + ) + + elif kind == "mgrp" and len(parts) == 1: + out[parentKey] = _featureConnectionsForMandate(instanceId, context, rootIf, parts[0], allFds) + + elif kind == "feat" and len(parts) == 3: + _mandateId, featureCode, featureInstanceId = parts + out[parentKey] = _featureTables(context, rootIf, parentKey, featureInstanceId, featureCode, allFds) + + elif kind == "fdstbl" and len(parts) == 2: + featureInstanceId, tableName = parts + fieldNames = _resolveTableFieldNames(featureInstanceId, tableName, rootIf) + out[parentKey] = _featureTableFields( + parentKey, featureInstanceId, tableName, fieldNames, allFds, + ) + + else: + out[parentKey] = [] + except Exception as exc: + logger.exception("Tree children for %s failed: %s", parentKey, exc) + out[parentKey] = [] + + return out + + +def _reverseService(sourceType: str) -> str: + for svc, st in _SERVICE_TO_SOURCE_TYPE.items(): + if st == sourceType: + return svc + return sourceType + + +def _resolveTableFieldNames(featureInstanceId: str, tableName: str, rootIf: Any) -> List[str]: + """Look up the declared column list for a (featureInstance, tableName) + pair via the RBAC catalog data-object metadata. Returns empty list when + the catalog has no entry (e.g. wildcard-only feature).""" + from modules.security.rbacCatalog import getCatalogService + inst = rootIf.getFeatureInstance(featureInstanceId) + if not inst: + return [] + catalog = getCatalogService() + for obj in catalog.getDataObjects(inst.featureCode) or []: + meta = obj.get("meta", {}) if isinstance(obj, dict) else {} + if meta.get("table") == tableName: + fields = meta.get("fields") + if isinstance(fields, list): + return [f for f in fields if isinstance(f, str) and f] + return [] + return [] + + +# --------------------------------------------------------------------------- +# Attribute-only refresh: given node keys, return current effective values +# --------------------------------------------------------------------------- + +async def getAttributesForKeys( + instanceId: str, + keys: List[str], + context: Any, +) -> Dict[str, Dict[str, Any]]: + """Return effective attribute values for a list of node keys. + + Used by the frontend after a toggle to refresh only attributes (neutralize, + scope, ragIndexEnabled) without reloading the tree structure. For container + nodes (personalRoot, mgrp), aggregates child values and returns 'mixed' + when children diverge.""" + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.datamodels.datamodelDataSource import DataSource + from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource + + rootIf = getRootInterface() + userId = str(context.user.id) + allDs = rootIf.db.getRecordset(DataSource, recordFilter={"userId": userId}) or [] + allFds = rootIf.db.getRecordset(FeatureDataSource, recordFilter={"workspaceInstanceId": instanceId}) or [] + + result: Dict[str, Dict[str, Any]] = {} + + for key in keys: + try: + attrs = _resolveAttrsForKey(key, allDs, allFds, instanceId, context, rootIf) + if attrs is not None: + result[key] = attrs + if "mixed" in str(attrs.values()): + logger.info("getAttributesForKeys key=%s returned MIXED: %s", key, attrs) + except Exception as exc: + logger.warning("getAttributesForKeys failed for key=%s: %s", key, exc) + + logger.info("getAttributesForKeys: %d keys requested, %d resolved", len(keys), len(result)) + return result + + +def _resolveAttrsForKey( + key: str, + allDs: List[Dict[str, Any]], + allFds: List[Dict[str, Any]], + instanceId: str, + context: Any, + rootIf: Any, +) -> Optional[Dict[str, Any]]: + """Resolve effective attributes for a single node key.""" + if key == _KEY_PERSONAL_ROOT: + return _aggregatePersonalRoot(allDs) + + try: + kind, parts = _decode(key) + except Exception: + return None + + if kind == "mgrp" and len(parts) == 1: + return _aggregateMandateGroup(parts[0], allFds, instanceId, context, rootIf) + + if kind == "conn" and len(parts) == 1: + connId = parts[0] + return _aggregateConnection(connId, allDs) + + if kind == "svc" and len(parts) == 2: + connId, service = parts + sourceType = _SERVICE_TO_SOURCE_TYPE.get(service, service) + return _effectiveTripletDs(connId, sourceType, "/", allDs) + + if kind == "ds" and len(parts) == 3: + connId, sourceType, path = parts + return _effectiveTripletDs(connId, sourceType, path, allDs) + + if kind == "feat" and len(parts) == 3: + _mandateId, _featureCode, featureInstanceId = parts + return _effectiveTripletFds(featureInstanceId, "*", None, allFds) + + if kind == "fdsws" and len(parts) == 2: + workspaceInstanceId, _featureCode = parts + return _effectiveTripletFds(workspaceInstanceId, "*", None, allFds) + + if kind == "fdstbl" and len(parts) == 2: + featureInstanceId, tableName = parts + return _effectiveTripletFds(featureInstanceId, tableName, None, allFds) + + if kind == "fdsrec" and len(parts) == 3: + featureInstanceId, tableName, recordId = parts + return _effectiveTripletFds(featureInstanceId, tableName, {"objectKey": recordId}, allFds) + + if kind == "fdsfld" and len(parts) >= 3: + featureInstanceId, tableName = parts[0], parts[1] + fieldName = parts[2] if len(parts) > 2 else "" + parentFds = None + for fds in allFds: + if (fds.get("featureInstanceId") == featureInstanceId + and (fds.get("tableName") or "") == tableName + and fds.get("recordFilter") is None): + parentFds = fds + break + neutralizeFields = (parentFds.get("neutralizeFields") or []) if parentFds else [] + return {"effectiveNeutralize": fieldName in neutralizeFields} + + return None + + +def _aggregateConnection(connId: str, allDs: List[Dict[str, Any]]) -> Dict[str, Any]: + """Aggregate effective values for a connection node. + + If the connection has an authority-level DS record (path="/"), use the + standard aggregate mode on it (which already handles subtree correctly). + Otherwise compute effective values for each child DS using walk mode and + aggregate them manually.""" + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import ( + getEffectiveFlag, _AUTHORITY_SOURCE_TYPES, + ) + connRecords = [d for d in allDs if d.get("connectionId") == connId] + if not connRecords: + return {"effectiveNeutralize": False, "effectiveScope": "personal", "effectiveRagIndexEnabled": False} + + rootRec = None + for r in connRecords: + st = r.get("sourceType", "") + if st in _AUTHORITY_SOURCE_TYPES and _normalisePath(r.get("path", "")) == "/": + rootRec = r + break + + if rootRec: + return _effectiveTripletDs(connId, rootRec.get("sourceType", ""), "/", allDs) + + neutralizeVals = set() + scopeVals = set() + ragVals = set() + for r in connRecords: + neutralizeVals.add(getEffectiveFlag(r, "neutralize", allDs, mode="walk")) + scopeVals.add(getEffectiveFlag(r, "scope", allDs, mode="walk")) + ragVals.add(getEffectiveFlag(r, "ragIndexEnabled", allDs, mode="walk")) + return { + "effectiveNeutralize": "mixed" if len(neutralizeVals) > 1 else (neutralizeVals.pop() if neutralizeVals else False), + "effectiveScope": "mixed" if len(scopeVals) > 1 else (scopeVals.pop() if scopeVals else "personal"), + "effectiveRagIndexEnabled": "mixed" if len(ragVals) > 1 else (ragVals.pop() if ragVals else False), + } + + +def _aggregatePersonalRoot(allDs: List[Dict[str, Any]]) -> Dict[str, Any]: + """Aggregate effective values across all personal DS records. + + Uses getEffectiveFlag in aggregate mode on each connection-root record. + If no root records exist, aggregates walk-effective values of all records.""" + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import ( + getEffectiveFlag, _AUTHORITY_SOURCE_TYPES, + ) + if not allDs: + return {"effectiveNeutralize": False, "effectiveScope": "personal", "effectiveRagIndexEnabled": False} + + rootRecords = [ + d for d in allDs + if d.get("sourceType", "") in _AUTHORITY_SOURCE_TYPES + and _normalisePath(d.get("path", "")) == "/" + ] + targets = rootRecords if rootRecords else allDs + + neutralizeVals = set() + scopeVals = set() + ragVals = set() + for ds in targets: + neutralizeVals.add(getEffectiveFlag(ds, "neutralize", allDs, mode="aggregate")) + scopeVals.add(getEffectiveFlag(ds, "scope", allDs, mode="aggregate")) + ragVals.add(getEffectiveFlag(ds, "ragIndexEnabled", allDs, mode="aggregate")) + return { + "effectiveNeutralize": "mixed" if len(neutralizeVals) > 1 else (neutralizeVals.pop() if neutralizeVals else False), + "effectiveScope": "mixed" if len(scopeVals) > 1 else (scopeVals.pop() if scopeVals else "personal"), + "effectiveRagIndexEnabled": "mixed" if len(ragVals) > 1 else (ragVals.pop() if ragVals else False), + } + + +def _aggregateMandateGroup( + mandateId: str, + allFds: List[Dict[str, Any]], + instanceId: str, + context: Any, + rootIf: Any, +) -> Dict[str, Any]: + """Aggregate effective values across FDS records belonging to this mandate group. + + Uses getEffectiveFlagFds in aggregate mode on each workspace-level FDS + (tableName="*") that belongs to the given mandateId. This correctly resolves + inherited values from the full FDS hierarchy.""" + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlagFds + + groupFds = [f for f in allFds if f.get("mandateId") == mandateId] + workspaceLevelFds = [f for f in groupFds if (f.get("tableName") or "") == "*"] + targets = workspaceLevelFds if workspaceLevelFds else groupFds + + if not targets: + return {"effectiveNeutralize": False, "effectiveScope": "personal", "effectiveRagIndexEnabled": False} + + neutralizeVals = set() + scopeVals = set() + ragVals = set() + for fds in targets: + neutralizeVals.add(getEffectiveFlagFds(fds, "neutralize", allFds, mode="aggregate")) + scopeVals.add(getEffectiveFlagFds(fds, "scope", allFds, mode="aggregate")) + ragVals.add(getEffectiveFlagFds(fds, "ragIndexEnabled", allFds, mode="aggregate")) + return { + "effectiveNeutralize": "mixed" if len(neutralizeVals) > 1 else (neutralizeVals.pop() if neutralizeVals else False), + "effectiveScope": "mixed" if len(scopeVals) > 1 else (scopeVals.pop() if scopeVals else "personal"), + "effectiveRagIndexEnabled": "mixed" if len(ragVals) > 1 else (ragVals.pop() if ragVals else False), + } diff --git a/modules/serviceCenter/services/serviceKnowledge/_inheritFlags.py b/modules/serviceCenter/services/serviceKnowledge/_inheritFlags.py index 00180c9f..64a0019c 100644 --- a/modules/serviceCenter/services/serviceKnowledge/_inheritFlags.py +++ b/modules/serviceCenter/services/serviceKnowledge/_inheritFlags.py @@ -3,9 +3,15 @@ """Cascade-inherit semantics for DataSource flags (neutralize, ragIndexEnabled, scope). Three-state flags allow tree elements to either set an explicit value or -inherit the value from their nearest ancestor in the path hierarchy. The -walker (RAG/Neutralize) and routes resolve the *effective* value; the cascade -helper resets explicit descendant values when a parent is toggled. +inherit the value from their nearest ancestor in the path hierarchy. + +Modes: + - 'walk' (default): resolves the *concrete* effective value per-item + (never returns 'mixed'). Used by backend consumers (RAG walker, + neutralization pipeline, scope filter, etc.). + - 'aggregate': resolves the *display* effective value per-item. If the + item has descendants with differing walk-effective values, returns + 'mixed'. Used by listing endpoints and PATCH responses for the UI. Path-traversal rules: - A DataSource is identified by `(connectionId, sourceType, path)`. @@ -17,11 +23,12 @@ Path-traversal rules: """ import logging -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple logger = logging.getLogger(__name__) _INHERITABLE_FLAGS = ("neutralize", "ragIndexEnabled", "scope") +_INHERITABLE_FDS_FLAGS = ("neutralize", "ragIndexEnabled", "scope") # Connection-root DataSources carry the authority as their sourceType # (e.g. 'msft', 'google'). They sit one level above all service DataSources @@ -29,6 +36,12 @@ _INHERITABLE_FLAGS = ("neutralize", "ragIndexEnabled", "scope") # cross sourceType boundaries — but ONLY from these authority roots. _AUTHORITY_SOURCE_TYPES = frozenset({"local", "google", "msft", "clickup", "infomaniak"}) +Mode = Literal["walk", "aggregate"] + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- def _normalisePath(path: Optional[str]) -> str: """Normalize a DataSource path to '/'-prefixed, no trailing slash (except root).""" @@ -49,10 +62,7 @@ def _flagDefault(flag: str) -> Any: def _isExplicit(value: Any) -> bool: - """A flag value is explicit when it is not None. - - Note: legacy rows may carry empty-string scope; treat as inherit too. - """ + """A flag value is explicit when it is not None/empty-string.""" if value is None: return False if isinstance(value, str) and value == "": @@ -66,6 +76,21 @@ def _getRecordValue(rec: Any, key: str) -> Any: return getattr(rec, key, None) +def _isAncestorPath(ancestor: str, descendant: str) -> bool: + """True iff `ancestor` is a strict path-prefix of `descendant`.""" + if ancestor == descendant: + return False + if ancestor == "/": + return descendant != "/" + return descendant.startswith(ancestor + "/") + + +def _pathDepth(path: str) -> int: + if path == "/": + return 0 + return path.count("/") + + def _findAncestorChain( rec: Dict[str, Any], allDs: Iterable[Dict[str, Any]], @@ -74,15 +99,13 @@ def _findAncestorChain( ordered nearest-first. Two ancestor relations are merged: - 1) **same-sourceType path-ancestor** — strict path-prefix within the - same service tree (sharepointFolder, gmailFolder, ...). - 2) **connection-root ancestor** — a DS with `path='/'` and - `sourceType` ∈ authority set (msft, google, ...) is the parent of - every other DS in that connection regardless of sourceType, so a - toggle on the connection node propagates to all services beneath. + 1) same-sourceType path-ancestor — strict path-prefix within the + same service tree. + 2) connection-root ancestor — a DS with `path='/'` and + `sourceType` in authority set is the parent of every other DS + in that connection regardless of sourceType. - The connection-root is always the most distant ancestor and therefore - sorts after any same-sourceType ancestors. + The connection-root is always the most distant ancestor. """ recPath = _normalisePath(_getRecordValue(rec, "path")) recSourceType = _getRecordValue(rec, "sourceType") @@ -114,36 +137,89 @@ def _findAncestorChain( return chain -def _isAncestorPath(ancestor: str, descendant: str) -> bool: - """True iff `ancestor` is a strict path-prefix of `descendant`. +def _isDescendantDs(parentRec: Dict[str, Any], candidate: Dict[str, Any]) -> bool: + """True iff `candidate` is a descendant of `parentRec` in the DS hierarchy.""" + parentSourceType = _getRecordValue(parentRec, "sourceType") + parentPath = _normalisePath(_getRecordValue(parentRec, "path")) + parentConnectionId = _getRecordValue(parentRec, "connectionId") + parentId = _getRecordValue(parentRec, "id") - '/' is ancestor of every non-root path. For non-root prefixes, the - descendant must continue with '/' so '/foo' isn't treated as ancestor of - '/foobar'. - """ - if ancestor == descendant: + candId = _getRecordValue(candidate, "id") + if candId == parentId: + return False + if _getRecordValue(candidate, "connectionId") != parentConnectionId: return False - if ancestor == "/": - return descendant != "/" - return descendant.startswith(ancestor + "/") + candSourceType = _getRecordValue(candidate, "sourceType") + candPath = _normalisePath(_getRecordValue(candidate, "path")) + + parentIsConnectionRoot = ( + parentSourceType in _AUTHORITY_SOURCE_TYPES and parentPath == "/" + ) + if parentIsConnectionRoot: + return True + if candSourceType != parentSourceType: + return False + return _isAncestorPath(parentPath, candPath) + + +# --------------------------------------------------------------------------- +# DataSource: getEffectiveFlag +# --------------------------------------------------------------------------- def getEffectiveFlag( rec: Dict[str, Any], flag: str, sameConnectionDs: Iterable[Dict[str, Any]], + mode: Mode = "walk", ) -> Any: """Resolve the effective value of a flag via path-traversal. - Order: own value (if explicit) → nearest ancestor with explicit value → - static default (`False` or `'personal'`). + mode='walk': own explicit → nearest ancestor explicit → default. + Always returns a concrete value (never 'mixed'). + mode='aggregate': same as walk for leaf value, but if the item has + descendants whose walk-effective values differ from + each other, returns 'mixed'. """ if flag not in _INHERITABLE_FLAGS: raise ValueError(f"Unknown inheritable flag: {flag}") + + allDs = list(sameConnectionDs) + + walkValue = _resolveWalkValue(rec, flag, allDs) + + if mode == "walk": + return walkValue + + # mode == 'aggregate': check subtree for heterogeneous effective values + descendants = [d for d in allDs if _isDescendantDs(rec, d)] + if not descendants: + return walkValue + + subtreeValues = set() + subtreeValues.add(_normaliseForComparison(walkValue)) + for desc in descendants: + descEffective = _resolveWalkValue(desc, flag, allDs) + subtreeValues.add(_normaliseForComparison(descEffective)) + if len(subtreeValues) > 1: + recId = _getRecordValue(rec, "id") + descId = _getRecordValue(desc, "id") + descOwnVal = _getRecordValue(desc, flag) + logger.info( + "DS aggregate MIXED for rec=%s flag=%s: walkValue=%s, " + "divergent desc=%s (own=%s, effective=%s), subtreeValues=%s", + recId, flag, walkValue, descId, descOwnVal, descEffective, subtreeValues, + ) + return "mixed" + return walkValue + + +def _resolveWalkValue(rec: Dict[str, Any], flag: str, allDs: List[Dict[str, Any]]) -> Any: + """Core walk resolution: own explicit → ancestor chain → default.""" own = _getRecordValue(rec, flag) if _isExplicit(own): return own - chain = _findAncestorChain(rec, sameConnectionDs) + chain = _findAncestorChain(rec, allDs) for ancestor in chain: ancestorVal = _getRecordValue(ancestor, flag) if _isExplicit(ancestorVal): @@ -151,69 +227,112 @@ def getEffectiveFlag( return _flagDefault(flag) +def _normaliseForComparison(value: Any) -> Any: + """Normalize values for set-comparison (bool as int to avoid hash issues).""" + if isinstance(value, bool): + return int(value) + return value + + +# --------------------------------------------------------------------------- +# DataSource: cascadeResetDescendants (bottom-up) +# --------------------------------------------------------------------------- + def cascadeResetDescendants( rootIf: Any, parentRec: Dict[str, Any], flag: str, -) -> int: +) -> List[str]: """Reset all explicit descendant values of `flag` to NULL (= inherit). - Descendant relation mirrors `_findAncestorChain`: - - Connection-root (`path='/'` AND `sourceType` ∈ authorities) is parent - of every other DS in that connection (cross-sourceType cascade). - - Otherwise: same-sourceType strict path-descendants only. + Reset order: bottom-up (deepest first) for crash safety. + The parent itself is NOT modified here — the caller sets the master value + after this function returns. - Only the targeted `flag` is reset; other flags on the descendant are - untouched. - - Returns the number of records updated. + Returns list of reset record IDs in bottom-up order. """ if flag not in _INHERITABLE_FLAGS: raise ValueError(f"Unknown inheritable flag: {flag}") from modules.datamodels.datamodelDataSource import DataSource connectionId = _getRecordValue(parentRec, "connectionId") - parentSourceType = _getRecordValue(parentRec, "sourceType") - parentPath = _normalisePath(_getRecordValue(parentRec, "path")) parentId = _getRecordValue(parentRec, "id") - if not connectionId or not parentSourceType: - return 0 - - parentIsConnectionRoot = ( - parentSourceType in _AUTHORITY_SOURCE_TYPES and parentPath == "/" - ) + if not connectionId: + return [] siblings = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) - affected = 0 + + toReset: List[Tuple[int, str]] = [] for sib in siblings: - sibId = _getRecordValue(sib, "id") - if sibId == parentId: + if not _isDescendantDs(parentRec, sib): continue - sibSourceType = _getRecordValue(sib, "sourceType") - sibPath = _normalisePath(_getRecordValue(sib, "path")) - if parentIsConnectionRoot: - # Connection-root resets everything else under this connection. - pass - else: - if sibSourceType != parentSourceType: - continue - if not _isAncestorPath(parentPath, sibPath): - continue sibVal = _getRecordValue(sib, flag) if not _isExplicit(sibVal): continue + sibId = _getRecordValue(sib, "id") + sibPath = _normalisePath(_getRecordValue(sib, "path")) + toReset.append((_pathDepth(sibPath), sibId)) + + # Sort deepest first (bottom-up) + toReset.sort(key=lambda x: x[0], reverse=True) + + resetIds: List[str] = [] + for _, sibId in toReset: try: rootIf.db.recordModify(DataSource, sibId, {flag: None}) - affected += 1 + resetIds.append(sibId) except Exception as exc: logger.warning("Cascade-reset failed for DataSource %s flag=%s: %s", sibId, flag, exc) - if affected: - logger.info( - "Cascade-reset %s on %d descendants of DataSource (connectionId=%s, sourceType=%s, path=%s, connectionRoot=%s)", - flag, affected, connectionId, parentSourceType, parentPath, parentIsConnectionRoot, - ) - return affected + if resetIds: + logger.info( + "Cascade-reset %s on %d descendants of DataSource %s (bottom-up)", + flag, len(resetIds), parentId, + ) + return resetIds + + +# --------------------------------------------------------------------------- +# DataSource: collectAncestorChain (for updatedAncestors in PATCH response) +# --------------------------------------------------------------------------- + +def collectAncestorChain( + rec: Dict[str, Any], + sameConnectionDs: Iterable[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Return ancestor chain of `rec` (nearest-first), same as internal helper. + + Exposed for PATCH endpoints to compute updatedAncestors. + """ + return _findAncestorChain(rec, sameConnectionDs) + + +# --------------------------------------------------------------------------- +# DataSource: buildEffectiveByConnection +# --------------------------------------------------------------------------- + +def buildEffectiveByConnection( + dataSources: Iterable[Dict[str, Any]], + flag: str, + mode: Mode = "walk", +) -> Dict[str, Any]: + """Pre-compute the effective value of `flag` for every DataSource id. + + Uses the specified mode. O(N^2) worst case but N is bounded per connection. + """ + if flag not in _INHERITABLE_FLAGS: + raise ValueError(f"Unknown inheritable flag: {flag}") + allDs = list(dataSources) + out: Dict[str, Any] = {} + for rec in allDs: + recId = _getRecordValue(rec, "id") + out[recId] = getEffectiveFlag(rec, flag, allDs, mode=mode) + return out + + +# --------------------------------------------------------------------------- +# FeatureDataSource helpers +# --------------------------------------------------------------------------- def _fdsClassify(fds: Dict[str, Any]) -> str: """Return 'workspace' | 'table' | 'record' based on the FDS identifier shape.""" @@ -229,14 +348,14 @@ def _fdsClassify(fds: Dict[str, Any]) -> str: def _fdsIsAncestor(parent: Dict[str, Any], child: Dict[str, Any]) -> bool: """Return True iff `parent` FDS is a strict ancestor of `child` FDS. - Hierarchy within one `workspaceInstanceId`: - workspace-wildcard (tableName='*') → table-wildcard (tableName='X', !recordFilter) - → record-fds (tableName='X', recordFilter.id=...) - table-wildcard (tableName='X') → record-fds (tableName='X', recordFilter.id=...) + Hierarchy within one featureInstanceId (allFds is already scoped to + a single workspace): + feature-wildcard (tableName='*') -> table-wildcard / record-fds + table-wildcard (tableName='X') -> record-fds (tableName='X') """ - parentWsId = _getRecordValue(parent, "workspaceInstanceId") - childWsId = _getRecordValue(child, "workspaceInstanceId") - if not parentWsId or parentWsId != childWsId: + parentFiId = _getRecordValue(parent, "featureInstanceId") + childFiId = _getRecordValue(child, "featureInstanceId") + if not parentFiId or parentFiId != childFiId: return False if _getRecordValue(parent, "id") == _getRecordValue(child, "id"): return False @@ -251,23 +370,68 @@ def _fdsIsAncestor(parent: Dict[str, Any], child: Dict[str, Any]) -> bool: return False +def _fdsDepth(fds: Dict[str, Any]) -> int: + kind = _fdsClassify(fds) + if kind == "workspace": + return 0 + if kind == "table": + return 1 + return 2 + + +# --------------------------------------------------------------------------- +# FeatureDataSource: getEffectiveFlagFds +# --------------------------------------------------------------------------- + def getEffectiveFlagFds( rec: Dict[str, Any], flag: str, sameWorkspaceFds: Iterable[Dict[str, Any]], + mode: Mode = "walk", ) -> Any: """Resolve effective value of a FeatureDataSource flag. - Order: own (if explicit) → table-wildcard (if explicit) → - workspace-wildcard (if explicit) → static default. + mode='walk': own explicit -> table-wildcard -> workspace-wildcard -> default. + mode='aggregate': same but returns 'mixed' if descendants diverge. """ - if flag not in ("neutralize", "scope"): + if flag not in _INHERITABLE_FDS_FLAGS: raise ValueError(f"Unknown inheritable FDS flag: {flag}") + + allFds = list(sameWorkspaceFds) + walkValue = _resolveWalkValueFds(rec, flag, allFds) + + if mode == "walk": + return walkValue + + # mode == 'aggregate' + descendants = [f for f in allFds if _fdsIsAncestor(rec, f)] + if not descendants: + return walkValue + + subtreeValues = set() + subtreeValues.add(_normaliseForComparison(walkValue)) + for desc in descendants: + descEffective = _resolveWalkValueFds(desc, flag, allFds) + subtreeValues.add(_normaliseForComparison(descEffective)) + if len(subtreeValues) > 1: + recId = _getRecordValue(rec, "id") + descId = _getRecordValue(desc, "id") + descOwnVal = _getRecordValue(desc, flag) + logger.info( + "FDS aggregate MIXED for rec=%s flag=%s: walkValue=%s, " + "divergent desc=%s (own=%s, effective=%s), subtreeValues=%s", + recId, flag, walkValue, descId, descOwnVal, descEffective, subtreeValues, + ) + return "mixed" + return walkValue + + +def _resolveWalkValueFds(rec: Dict[str, Any], flag: str, allFds: List[Dict[str, Any]]) -> Any: + """Core walk resolution for FDS.""" own = _getRecordValue(rec, flag) if _isExplicit(own): return own - workspaceFds: List[Dict[str, Any]] = list(sameWorkspaceFds) - ancestors = [a for a in workspaceFds if _fdsIsAncestor(a, rec)] + ancestors = [a for a in allFds if _fdsIsAncestor(a, rec)] ancestors.sort(key=lambda a: 0 if _fdsClassify(a) == "table" else 1) for ancestor in ancestors: val = _getRecordValue(ancestor, flag) @@ -276,27 +440,32 @@ def getEffectiveFlagFds( return _flagDefault(flag) +# --------------------------------------------------------------------------- +# FeatureDataSource: cascadeResetDescendantsFds (bottom-up) +# --------------------------------------------------------------------------- + def cascadeResetDescendantsFds( rootIf: Any, parentRec: Dict[str, Any], flag: str, -) -> int: +) -> List[str]: """Reset explicit `flag` to NULL on every descendant FDS of `parentRec`. - Only the targeted flag is reset; other flags on descendants are untouched. - Returns the number of records updated. + Reset order: bottom-up (deepest first) for crash safety. + Returns list of reset record IDs in bottom-up order. """ - if flag not in ("neutralize", "scope"): + if flag not in _INHERITABLE_FDS_FLAGS: raise ValueError(f"Unknown inheritable FDS flag: {flag}") from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource workspaceInstanceId = _getRecordValue(parentRec, "workspaceInstanceId") if not workspaceInstanceId: - return 0 + return [] siblings = rootIf.db.getRecordset( FeatureDataSource, recordFilter={"workspaceInstanceId": workspaceInstanceId} ) - affected = 0 + + toReset: List[Tuple[int, str]] = [] for sib in siblings: if not _fdsIsAncestor(parentRec, sib): continue @@ -304,39 +473,159 @@ def cascadeResetDescendantsFds( if not _isExplicit(sibVal): continue sibId = _getRecordValue(sib, "id") + toReset.append((_fdsDepth(sib), sibId)) + + # Sort deepest first (bottom-up) + toReset.sort(key=lambda x: x[0], reverse=True) + + resetIds: List[str] = [] + for _, sibId in toReset: try: rootIf.db.recordModify(FeatureDataSource, sibId, {flag: None}) - affected += 1 + resetIds.append(sibId) except Exception as exc: logger.warning("FDS cascade-reset failed for %s flag=%s: %s", sibId, flag, exc) - if affected: + + if resetIds: logger.info( - "FDS cascade-reset %s on %d descendants of FDS (workspaceInstanceId=%s, kind=%s)", - flag, affected, workspaceInstanceId, _fdsClassify(parentRec), + "FDS cascade-reset %s on %d descendants of FDS %s (bottom-up)", + flag, len(resetIds), _getRecordValue(parentRec, "id"), ) - return affected + return resetIds -def buildEffectiveByConnection( - dataSources: Iterable[Dict[str, Any]], - flag: str, -) -> Dict[str, Any]: - """Pre-compute the effective value of `flag` for every DataSource id. +# --------------------------------------------------------------------------- +# FeatureDataSource: collectAncestorChainFds +# --------------------------------------------------------------------------- - Useful for batch operations (walker, route DTOs) that touch many records - at once. O(N²) in the worst case but N is bounded per connection. +def collectAncestorChainFds( + rec: Dict[str, Any], + sameWorkspaceFds: Iterable[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """Return ancestor chain of `rec` FDS (nearest-first). + + Exposed for PATCH endpoints to compute updatedAncestors. """ - if flag not in _INHERITABLE_FLAGS: - raise ValueError(f"Unknown inheritable flag: {flag}") - bySourceType: Dict[Tuple[str, str], List[Dict[str, Any]]] = {} - for ds in dataSources: - connId = _getRecordValue(ds, "connectionId") or "" - srcType = _getRecordValue(ds, "sourceType") or "" - bySourceType.setdefault((connId, srcType), []).append(ds) + allFds = list(sameWorkspaceFds) + ancestors = [a for a in allFds if _fdsIsAncestor(a, rec)] + ancestors.sort(key=lambda a: 0 if _fdsClassify(a) == "table" else 1) + return ancestors + +# --------------------------------------------------------------------------- +# FeatureDataSource: buildEffectiveByWorkspaceFds +# --------------------------------------------------------------------------- + +def buildEffectiveByWorkspaceFds( + fdses: Iterable[Dict[str, Any]], + flag: str, + mode: Mode = "walk", +) -> Dict[str, Any]: + """Pre-compute the effective value of `flag` for every FDS id.""" + if flag not in _INHERITABLE_FDS_FLAGS: + raise ValueError(f"Unknown inheritable FDS flag: {flag}") + allFds = list(fdses) out: Dict[str, Any] = {} - for group in bySourceType.values(): - for rec in group: - recId = _getRecordValue(rec, "id") - out[recId] = getEffectiveFlag(rec, flag, group) + for rec in allFds: + recId = _getRecordValue(rec, "id") + out[recId] = getEffectiveFlagFds(rec, flag, allFds, mode=mode) return out + + +# --------------------------------------------------------------------------- +# Bulk resolve: effective flags for arbitrary paths (even without DB record) +# --------------------------------------------------------------------------- + +def resolveEffectiveForPath( + connectionId: str, + sourceType: str, + path: str, + allDs: List[Dict[str, Any]], + mode: Mode = "aggregate", +) -> Dict[str, Any]: + """Resolve effective flags for ANY (connectionId, sourceType, path) tuple. + + Works whether or not a DataSource record exists for this exact path. + Returns dict with effectiveNeutralize, effectiveScope, effectiveRagIndexEnabled. + """ + normPath = _normalisePath(path) + exactRecord = None + for ds in allDs: + if ( + _getRecordValue(ds, "connectionId") == connectionId + and _getRecordValue(ds, "sourceType") == sourceType + and _normalisePath(_getRecordValue(ds, "path")) == normPath + ): + exactRecord = ds + break + + if exactRecord: + return { + "effectiveNeutralize": getEffectiveFlag(exactRecord, "neutralize", allDs, mode=mode), + "effectiveScope": getEffectiveFlag(exactRecord, "scope", allDs, mode=mode), + "effectiveRagIndexEnabled": getEffectiveFlag(exactRecord, "ragIndexEnabled", allDs, mode=mode), + } + + virtualRec = { + "id": "__virtual__", + "connectionId": connectionId, + "sourceType": sourceType, + "path": normPath, + "neutralize": None, + "scope": None, + "ragIndexEnabled": None, + } + return { + "effectiveNeutralize": _resolveWalkValue(virtualRec, "neutralize", allDs), + "effectiveScope": _resolveWalkValue(virtualRec, "scope", allDs), + "effectiveRagIndexEnabled": _resolveWalkValue(virtualRec, "ragIndexEnabled", allDs), + } + + +def resolveEffectiveForFds( + featureInstanceId: str, + tableName: str, + recordFilter: Optional[Dict[str, str]], + allFds: List[Dict[str, Any]], + mode: Mode = "aggregate", +) -> Dict[str, Any]: + """Resolve effective flags for ANY FDS tuple (even without DB record). + + `allFds` is pre-scoped to a single workspace (loaded with + workspaceInstanceId filter). Within that set, the coordinate is + featureInstanceId + tableName + recordFilter. + + Returns dict with effectiveNeutralize, effectiveScope, effectiveRagIndexEnabled. + """ + exactRecord = None + for fds in allFds: + if _getRecordValue(fds, "featureInstanceId") != featureInstanceId: + continue + if (_getRecordValue(fds, "tableName") or "") != tableName: + continue + fdsFilter = _getRecordValue(fds, "recordFilter") + if fdsFilter == recordFilter: + exactRecord = fds + break + + if exactRecord: + return { + "effectiveNeutralize": getEffectiveFlagFds(exactRecord, "neutralize", allFds, mode=mode), + "effectiveScope": getEffectiveFlagFds(exactRecord, "scope", allFds, mode=mode), + "effectiveRagIndexEnabled": getEffectiveFlagFds(exactRecord, "ragIndexEnabled", allFds, mode=mode), + } + + virtualRec = { + "id": "__virtual__", + "featureInstanceId": featureInstanceId, + "tableName": tableName, + "recordFilter": recordFilter, + "neutralize": None, + "scope": None, + "ragIndexEnabled": None, + } + return { + "effectiveNeutralize": _resolveWalkValueFds(virtualRec, "neutralize", allFds), + "effectiveScope": _resolveWalkValueFds(virtualRec, "scope", allFds), + "effectiveRagIndexEnabled": _resolveWalkValueFds(virtualRec, "ragIndexEnabled", allFds), + } diff --git a/modules/serviceCenter/services/serviceKnowledge/mainServiceKnowledge.py b/modules/serviceCenter/services/serviceKnowledge/mainServiceKnowledge.py index 6698e164..01c585d8 100644 --- a/modules/serviceCenter/services/serviceKnowledge/mainServiceKnowledge.py +++ b/modules/serviceCenter/services/serviceKnowledge/mainServiceKnowledge.py @@ -147,7 +147,7 @@ class KnowledgeService: else getattr(existing, "status", "") ) or "" if existingMeta.get("hash") == contentHash and existingStatus == "indexed": - logger.info( + logger.debug( "ingestion.skipped.duplicate sourceKind=%s sourceId=%s hash=%s", job.sourceKind, job.sourceId, contentHash[:12], extra={ diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py index 618a9965..be059eef 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py @@ -431,6 +431,15 @@ def registerKnowledgeIngestionConsumer() -> None: callbackRegistry.register("connection.established", _onConnectionEstablished) callbackRegistry.register("connection.revoked", _onConnectionRevoked) registerJobHandler(BOOTSTRAP_JOB_TYPE, _bootstrapJobHandler) + + from modules.serviceCenter.services.serviceKnowledge.subFeatureBootstrap import ( + FEATURE_BOOTSTRAP_JOB_TYPE, _featureBootstrapHandler, + ) + registerJobHandler(FEATURE_BOOTSTRAP_JOB_TYPE, _featureBootstrapHandler) + registerDailyResyncScheduler() _registered = True - logger.info("KnowledgeIngestionConsumer registered (established/revoked + %s handler + daily resync)", BOOTSTRAP_JOB_TYPE) + logger.info( + "KnowledgeIngestionConsumer registered (established/revoked + %s + %s handler + daily resync)", + BOOTSTRAP_JOB_TYPE, FEATURE_BOOTSTRAP_JOB_TYPE, + ) diff --git a/modules/serviceCenter/services/serviceKnowledge/subFeatureBootstrap.py b/modules/serviceCenter/services/serviceKnowledge/subFeatureBootstrap.py new file mode 100644 index 00000000..aa81d929 --- /dev/null +++ b/modules/serviceCenter/services/serviceKnowledge/subFeatureBootstrap.py @@ -0,0 +1,289 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +"""Feature-data RAG bootstrap: indexes FeatureDataSource rows into the knowledge store. + +Analogous to connection.bootstrap for external connections (Google, Microsoft), +this handler reads FeatureDataSource records with ragIndexEnabled=True, queries +the underlying feature tables via FeatureDataProvider, serialises each row into +text, and feeds it through KnowledgeService.requestIngestion so the data +appears in ContentChunk embeddings for semantic RAG search. + +Job type: ``feature.bootstrap`` +Payload: ``{"workspaceInstanceId": "...", "featureDataSourceIds": [...] (optional)}`` +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + +FEATURE_BOOTSTRAP_JOB_TYPE = "feature.bootstrap" + + +def _loadRagEnabledFds(workspaceInstanceId: str, featureDataSourceIds: Optional[List[str]] = None): + """Load FeatureDataSource rows whose effective ragIndexEnabled is True. + + Returns dicts with resolved flags so downstream code can read them directly. + """ + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource + from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlagFds + + rootIf = getRootInterface() + allFds = rootIf.db.getRecordset( + FeatureDataSource, recordFilter={"workspaceInstanceId": workspaceInstanceId} + ) + resolved = [] + for fds in allFds: + tblName = (fds.get("tableName") if isinstance(fds, dict) else getattr(fds, "tableName", "")) or "" + fCode = (fds.get("featureCode") if isinstance(fds, dict) else getattr(fds, "featureCode", "")) or "" + if tblName == "*" or not tblName or not fCode: + continue + effRag = getEffectiveFlagFds(fds, "ragIndexEnabled", allFds, mode="aggregate") + if effRag is not True: + continue + row = dict(fds) if isinstance(fds, dict) else {**fds.__dict__} + row["_effectiveNeutralize"] = getEffectiveFlagFds(fds, "neutralize", allFds, mode="aggregate") + row["_effectiveScope"] = getEffectiveFlagFds(fds, "scope", allFds, mode="aggregate") or "featureInstance" + row["ragIndexEnabled"] = True + resolved.append(row) + + if featureDataSourceIds: + idSet = set(featureDataSourceIds) + resolved = [r for r in resolved if r.get("id") in idSet] + return resolved + + +def _serializeRowToText(row: Dict[str, Any], neutralizeFields: Optional[List[str]] = None) -> str: + """Convert a feature-table row into readable text for embedding. + + Skips internal fields (starting with '_' or 'sys') and produces + ``key: value`` lines that embed well semantically. + """ + neutralizeSet = set(neutralizeFields or []) + lines = [] + for key, value in row.items(): + if key.startswith("_") or key.startswith("sys"): + continue + if key == "id": + continue + if value is None or value == "" or value == []: + continue + if key in neutralizeSet: + value = "[REDACTED]" + elif isinstance(value, (dict, list)): + value = json.dumps(value, ensure_ascii=False, default=str) + else: + value = str(value) + lines.append(f"{key}: {value}") + return "\n".join(lines) + + +def _getFeatureDbConnector(featureCode: str): + """Create a lightweight DB connector to the feature database.""" + from modules.connectors.connectorDbPostgre import DatabaseConnector + from modules.shared.configuration import APP_CONFIG + + dbName = f"poweron_{featureCode.lower()}" + return DatabaseConnector( + dbHost=APP_CONFIG.get("DB_HOST", "localhost"), + dbDatabase=dbName, + dbUser=APP_CONFIG.get("DB_USER"), + dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"), + dbPort=int(APP_CONFIG.get("DB_PORT", 5432)), + userId="system.feature_bootstrap", + ) + + +async def _featureBootstrapHandler( + job: Dict[str, Any], + progressCb, +) -> Dict[str, Any]: + """Walk RAG-enabled FeatureDataSources and index their rows.""" + payload = job.get("payload") or {} + workspaceInstanceId = payload.get("workspaceInstanceId") + featureDataSourceIds = payload.get("featureDataSourceIds") + if not workspaceInstanceId: + raise ValueError("feature.bootstrap requires payload.workspaceInstanceId") + + progressCb(5, messageKey="Feature-Datenquellen werden geladen...") + + fdsList = _loadRagEnabledFds(workspaceInstanceId, featureDataSourceIds) + if not fdsList: + logger.info( + "feature.bootstrap.skipped — no rag-enabled FDS for workspace %s", + workspaceInstanceId, + ) + return {"workspaceInstanceId": workspaceInstanceId, "skipped": True, "reason": "no_rag_enabled_fds"} + + from modules.serviceCenter.services.serviceAgent.featureDataProvider import FeatureDataProvider + from modules.serviceCenter.services.serviceKnowledge.mainServiceKnowledge import IngestionJob + from modules.serviceCenter.context import ServiceCenterContext + from modules.serviceCenter import getService + from modules.security.rootAccess import getRootUser + + totalIndexed = 0 + totalSkipped = 0 + totalFailed = 0 + fdsResults = [] + + for fdsIdx, fds in enumerate(fdsList): + fdsId = fds.get("id", "") + featureCode = fds.get("featureCode", "") + tableName = fds.get("tableName", "") + featureInstanceId = fds.get("featureInstanceId", "") + mandateId = fds.get("mandateId", "") + neutralizeFields = fds.get("neutralizeFields") or [] + recordFilter = fds.get("recordFilter") or {} + effectiveScope = fds.get("_effectiveScope", "featureInstance") + effectiveNeutralize = bool(fds.get("_effectiveNeutralize", False)) + + progressPct = 5 + int(90 * fdsIdx / len(fdsList)) + progressCb( + progressPct, + messageKey="Indexiere {table} ({n}/{total})...", + messageParams={"table": tableName, "n": fdsIdx + 1, "total": len(fdsList)}, + ) + + if not featureCode or not tableName or not featureInstanceId: + logger.warning("feature.bootstrap: skipping FDS %s — missing featureCode/tableName/fiId", fdsId) + continue + + try: + dbConnector = _getFeatureDbConnector(featureCode) + provider = FeatureDataProvider(dbConnector) + + rootUser = getRootUser() + ctx = ServiceCenterContext( + user=rootUser, + mandate_id=mandateId, + feature_instance_id=workspaceInstanceId, + ) + knowledgeService = getService("knowledge", ctx) + + extraFilters = [ + {"field": k, "op": "=", "value": v} + for k, v in recordFilter.items() + ] if recordFilter else None + + batchSize = 200 + offset = 0 + fdsIndexed = 0 + fdsSkipped = 0 + fdsFailed = 0 + + while True: + result = provider.browseTable( + tableName=tableName, + featureInstanceId=featureInstanceId, + mandateId=mandateId, + limit=batchSize, + offset=offset, + extraFilters=extraFilters, + ) + rows = result.get("rows", []) + if not rows: + break + + for row in rows: + rowId = row.get("id", "") + if not rowId: + continue + + textContent = _serializeRowToText(row, neutralizeFields if effectiveNeutralize else None) + if not textContent.strip(): + fdsSkipped += 1 + continue + + contentVersion = str(row.get("sysUpdatedAt") or row.get("sysCreatedAt") or "") + + ingestionJob = IngestionJob( + sourceKind="feature_record", + sourceId=f"{workspaceInstanceId}:{tableName}:{rowId}", + fileName=f"{tableName}-{rowId}", + mimeType="application/vnd.poweron.feature-record+json", + userId=fds.get("userId") or "system", + featureInstanceId=workspaceInstanceId, + mandateId=mandateId, + contentObjects=[{ + "contentType": "text", + "data": textContent, + "contextRef": { + "table": tableName, + "featureCode": featureCode, + "featureInstanceId": featureInstanceId, + "rowId": rowId, + }, + "contentObjectId": f"{tableName}:{rowId}", + }], + structure={"sourceTable": tableName, "featureCode": featureCode}, + contentVersion=contentVersion, + provenance={ + "featureDataSourceId": fdsId, + "tableName": tableName, + "featureCode": featureCode, + "featureInstanceId": featureInstanceId, + }, + neutralize=effectiveNeutralize, + ) + + try: + handle = await knowledgeService.requestIngestion(ingestionJob) + if handle.status == "failed": + fdsFailed += 1 + logger.warning( + "feature.bootstrap: ingestion failed fds=%s table=%s row=%s error=%s", + fdsId, tableName, rowId, handle.error, + ) + elif handle.status == "duplicate": + fdsSkipped += 1 + else: + fdsIndexed += 1 + except Exception as ingErr: + fdsFailed += 1 + logger.error( + "feature.bootstrap: ingestion error fds=%s row=%s: %s", + fdsId, rowId, ingErr, + ) + + offset += batchSize + if len(rows) < batchSize: + break + + totalIndexed += fdsIndexed + totalSkipped += fdsSkipped + totalFailed += fdsFailed + + fdsResults.append({ + "featureDataSourceId": fdsId, + "tableName": tableName, + "featureCode": featureCode, + "indexed": fdsIndexed, + "skippedDuplicate": fdsSkipped, + "failed": fdsFailed, + }) + + except Exception as fdsErr: + logger.error( + "feature.bootstrap: error processing FDS %s (%s.%s): %s", + fdsId, featureCode, tableName, fdsErr, exc_info=True, + ) + fdsResults.append({ + "featureDataSourceId": fdsId, + "tableName": tableName, + "featureCode": featureCode, + "error": str(fdsErr), + }) + + progressCb(100, messageKey="Feature-Daten-Sync abgeschlossen.") + + return { + "workspaceInstanceId": workspaceInstanceId, + "indexed": totalIndexed, + "skippedDuplicate": totalSkipped, + "failed": totalFailed, + "dataSources": fdsResults, + } diff --git a/modules/serviceCenter/services/serviceKnowledge/subPolicyResolver.py b/modules/serviceCenter/services/serviceKnowledge/subPolicyResolver.py deleted file mode 100644 index 0deae777..00000000 --- a/modules/serviceCenter/services/serviceKnowledge/subPolicyResolver.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) 2025 Patrick Motsch -# All rights reserved. -"""DEPRECATED: Use `_inheritFlags.getEffectiveFlag()` directly. - -Thin shim to the new cascade-inherit helper. Kept so external callers don't -break on import — internal walkers consume pre-resolved dicts via -`_loadRagEnabledDataSources`. -""" - -from __future__ import annotations - -from typing import Any, Dict, List - -from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag - - -def resolveEffectiveNeutralize( - ds: Dict[str, Any], - allDataSources: List[Dict[str, Any]], -) -> bool: - """DEPRECATED: use `getEffectiveFlag(ds, 'neutralize', allDataSources)`.""" - value = getEffectiveFlag(ds, "neutralize", allDataSources) - return bool(value) - - -def resolveEffectiveRagIndexEnabled( - ds: Dict[str, Any], - allDataSources: List[Dict[str, Any]], -) -> bool: - """DEPRECATED: use `getEffectiveFlag(ds, 'ragIndexEnabled', allDataSources)`.""" - value = getEffectiveFlag(ds, "ragIndexEnabled", allDataSources) - return bool(value) diff --git a/modules/serviceCenter/services/serviceKnowledge/subWalkerHelpers.py b/modules/serviceCenter/services/serviceKnowledge/subWalkerHelpers.py index 8e65fd0f..41d9d458 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subWalkerHelpers.py +++ b/modules/serviceCenter/services/serviceKnowledge/subWalkerHelpers.py @@ -15,8 +15,9 @@ up with "Job stuck at 10% for 10h" zombies. These helpers wrap each phase in `asyncio.wait_for`. Sync extraction runs on a worker thread so the loop stays responsive. Every wrapped call also -emits a short start/done log line, so when something hangs we know the -exact item that caused it (path, size, mime). +emits start/done log lines at DEBUG so normal INFO logs stay quiet; for +stuck-job triage, enable DEBUG for this module — the last +``walker.item.start`` before a hang still pinpoints the item (path, size, mime). """ from __future__ import annotations @@ -48,7 +49,7 @@ async def downloadWithTimeout( used in log messages so we can pinpoint the offending item in case of a hang or timeout. """ - logger.info("walker.download.start %s timeout=%ds", label, timeoutSeconds) + logger.debug("walker.download.start %s timeout=%ds", label, timeoutSeconds) try: result = await asyncio.wait_for(awaitable, timeout=timeoutSeconds) logger.debug("walker.download.done %s", label) @@ -71,7 +72,7 @@ async def extractWithTimeout( keep running until the process exits — but at least the walker proceeds to the next item instead of freezing forever. """ - logger.info("walker.extract.start %s timeout=%ds", label, timeoutSeconds) + logger.debug("walker.extract.start %s timeout=%ds", label, timeoutSeconds) try: result = await asyncio.wait_for( asyncio.to_thread(syncFn, *args), @@ -102,15 +103,15 @@ async def ingestWithTimeout( def logItemStart(service: str, label: str, *, sizeBytes: Optional[int] = None, mime: Optional[str] = None) -> None: - """Log that processing of one item is about to begin. + """Log that processing of one item is about to begin (DEBUG). When the worker hangs, the LAST `walker.item.start` line in the log - points to the exact item that caused the freeze. This is the single - most valuable diagnostic for stuck-job triage. + points to the exact item that caused the freeze. Enable DEBUG for this + module during triage. """ parts = [f"walker.item.start service={service} path={label}"] if sizeBytes is not None: parts.append(f"size={sizeBytes}") if mime: parts.append(f"mime={mime}") - logger.info(" ".join(parts)) + logger.debug(" ".join(parts)) diff --git a/scripts/script_migrate_user_uid.py b/scripts/script_migrate_user_uid.py new file mode 100644 index 00000000..07f9b443 --- /dev/null +++ b/scripts/script_migrate_user_uid.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 +"""One-time migration: Reassign all DB references from an old user UID to a new UID. + +When a user is re-created in PORTA (same username, new UUID), all existing records +still reference the old UUID. This script scans ALL registered databases and tables +for VARCHAR columns containing the old UID and updates them to the new UID. + +Affected columns include: + - sysCreatedBy / sysModifiedBy (on every table via PowerOnModel) + - userId, revokedBy, createdByUserId, publishedBy, triggeredBy, assignedTo, etc. + +The script auto-detects the new UID from the UserInDB table by username. + +Usage: + # Dry-run (default) — shows what would change, no writes: + python scripts/script_migrate_user_uid.py --username patrick.helvetia --old-uid + + # Execute for real: + python scripts/script_migrate_user_uid.py --username patrick.helvetia --old-uid --execute +""" + +import argparse +import logging +import os +import sys +from pathlib import Path +from typing import List, Optional, Tuple + +scriptPath = Path(__file__).resolve() +gatewayPath = scriptPath.parent.parent +sys.path.insert(0, str(gatewayPath)) +os.chdir(str(gatewayPath)) + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True) +logger = logging.getLogger(__name__) + +import psycopg2 +import psycopg2.extras +from modules.shared.configuration import APP_CONFIG + + +ALL_DATABASES = [ + "poweron_app", + "poweron_chat", + "poweron_management", + "poweron_knowledge", + "poweron_billing", + "poweron_workspace", + "poweron_graphicaleditor", + "poweron_chatbot", + "poweron_trustee", + "poweron_commcoach", + "poweron_neutralization", + "poweron_realestate", + "poweron_teamsbot", +] + + +def _getConnection(dbName: str): + return psycopg2.connect( + host=APP_CONFIG.get("DB_HOST", "localhost"), + port=int(APP_CONFIG.get("DB_PORT", "5432")), + database=dbName, + user=APP_CONFIG.get("DB_USER"), + password=APP_CONFIG.get("DB_PASSWORD_SECRET"), + client_encoding="utf8", + ) + + +def _getTablesInDb(conn) -> List[str]: + with conn.cursor() as cur: + cur.execute(""" + SELECT table_name FROM information_schema.tables + WHERE table_schema = 'public' + AND table_type = 'BASE TABLE' + AND table_name NOT LIKE '\\_%%' + ORDER BY table_name + """) + return [row[0] for row in cur.fetchall()] + + +def _getVarcharColumns(conn, tableName: str) -> List[str]: + """Get all VARCHAR/TEXT columns for a table (potential user-ID carriers).""" + with conn.cursor() as cur: + cur.execute(""" + SELECT column_name FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = %s + AND data_type IN ('character varying', 'text') + ORDER BY ordinal_position + """, (tableName,)) + return [row[0] for row in cur.fetchall()] + + +def _countMatches(conn, tableName: str, columnName: str, oldUid: str) -> int: + with conn.cursor() as cur: + cur.execute( + f'SELECT COUNT(*) FROM "{tableName}" WHERE "{columnName}" = %s', + (oldUid,), + ) + return cur.fetchone()[0] + + +def _updateColumn(conn, tableName: str, columnName: str, oldUid: str, newUid: str) -> int: + with conn.cursor() as cur: + cur.execute( + f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE "{columnName}" = %s', + (newUid, oldUid), + ) + return cur.rowcount + + +def _lookupNewUid(username: str) -> Optional[str]: + """Find the current UID for a username in poweron_app.UserInDB.""" + conn = _getConnection("poweron_app") + try: + with conn.cursor() as cur: + cur.execute( + 'SELECT "id" FROM "UserInDB" WHERE "username" = %s', + (username,), + ) + row = cur.fetchone() + return row[0] if row else None + finally: + conn.close() + + +def _scanJsonbForUid(conn, tableName: str, columnName: str, oldUid: str) -> int: + """Count JSONB fields that contain the old UID as a text value anywhere.""" + with conn.cursor() as cur: + cur.execute( + f"""SELECT COUNT(*) FROM "{tableName}" + WHERE "{columnName}"::text LIKE %s""", + (f"%{oldUid}%",), + ) + return cur.fetchone()[0] + + +def _updateJsonbColumn(conn, tableName: str, columnName: str, oldUid: str, newUid: str) -> int: + """Replace old UID inside JSONB columns using text replacement.""" + with conn.cursor() as cur: + cur.execute( + f"""UPDATE "{tableName}" + SET "{columnName}" = REPLACE("{columnName}"::text, %s, %s)::jsonb + WHERE "{columnName}"::text LIKE %s""", + (oldUid, newUid, f"%{oldUid}%"), + ) + return cur.rowcount + + +def _getJsonbColumns(conn, tableName: str) -> List[str]: + """Get all JSONB columns for a table.""" + with conn.cursor() as cur: + cur.execute(""" + SELECT column_name FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = %s + AND data_type = 'jsonb' + ORDER BY ordinal_position + """, (tableName,)) + return [row[0] for row in cur.fetchall()] + + +def migrate(username: str, oldUid: str, execute: bool = False): + newUid = _lookupNewUid(username) + if not newUid: + logger.error(f"User '{username}' not found in UserInDB. Cannot determine new UID.") + sys.exit(1) + + if newUid == oldUid: + logger.error(f"Old UID and new UID are identical ({oldUid}). Nothing to migrate.") + sys.exit(1) + + logger.info(f"Migration: user '{username}'") + logger.info(f" Old UID: {oldUid}") + logger.info(f" New UID: {newUid}") + logger.info(f" Mode: {'EXECUTE' if execute else 'DRY-RUN'}") + logger.info("") + + totalUpdated = 0 + findings: List[Tuple[str, str, str, int]] = [] + + for dbName in ALL_DATABASES: + try: + conn = _getConnection(dbName) + except Exception as e: + logger.warning(f" Cannot connect to {dbName}: {e}") + continue + + try: + conn.autocommit = False + tables = _getTablesInDb(conn) + + for tableName in tables: + varcharCols = _getVarcharColumns(conn, tableName) + for col in varcharCols: + count = _countMatches(conn, tableName, col, oldUid) + if count > 0: + findings.append((dbName, tableName, col, count)) + if execute: + updated = _updateColumn(conn, tableName, col, oldUid, newUid) + totalUpdated += updated + logger.info(f" [UPDATED] {dbName}.{tableName}.{col}: {updated} rows") + else: + logger.info(f" [DRY-RUN] {dbName}.{tableName}.{col}: {count} rows would be updated") + + jsonbCols = _getJsonbColumns(conn, tableName) + for col in jsonbCols: + count = _scanJsonbForUid(conn, tableName, col, oldUid) + if count > 0: + findings.append((dbName, tableName, f"{col} (JSONB)", count)) + if execute: + _updateJsonbColumn(conn, tableName, col, oldUid, newUid) + totalUpdated += count + logger.info(f" [UPDATED] {dbName}.{tableName}.{col} (JSONB): {count} rows") + else: + logger.info(f" [DRY-RUN] {dbName}.{tableName}.{col} (JSONB): {count} rows would be updated") + + if execute: + conn.commit() + else: + conn.rollback() + except Exception as e: + conn.rollback() + logger.error(f" Error processing {dbName}: {e}") + finally: + conn.close() + + logger.info("") + logger.info("=" * 70) + logger.info("SUMMARY") + logger.info("=" * 70) + if not findings: + logger.info(" No references to old UID found in any database.") + else: + logger.info(f" Found {len(findings)} column(s) with references to old UID:") + for dbName, tableName, col, count in findings: + logger.info(f" {dbName}.{tableName}.{col}: {count} rows") + logger.info("") + if execute: + logger.info(f" Total rows updated: {totalUpdated}") + else: + logger.info(f" Total rows that would be updated: {sum(c for _, _, _, c in findings)}") + logger.info("") + logger.info(" To apply changes, re-run with --execute") + + +def main(): + parser = argparse.ArgumentParser( + description="Migrate all DB references from old user UID to new UID." + ) + parser.add_argument( + "--username", + required=True, + help="Username to migrate (e.g. 'patrick.helvetia'). Used to look up the new UID.", + ) + parser.add_argument( + "--old-uid", + required=True, + help="The old UUID that is orphaned in the database.", + ) + parser.add_argument( + "--execute", + action="store_true", + default=False, + help="Actually perform the migration. Without this flag, only a dry-run is done.", + ) + args = parser.parse_args() + + migrate(username=args.username, oldUid=args.old_uid, execute=args.execute) + + +if __name__ == "__main__": + main() diff --git a/tests/unit/connectors/test_connectorDbPostgre_failLoud.py b/tests/unit/connectors/test_connectorDbPostgre_failLoud.py index 57094760..5fb505d7 100644 --- a/tests/unit/connectors/test_connectorDbPostgre_failLoud.py +++ b/tests/unit/connectors/test_connectorDbPostgre_failLoud.py @@ -30,6 +30,7 @@ import psycopg2.errors from modules.connectors.connectorDbPostgre import ( DatabaseConnector, DatabaseQueryError, + _stripNulBytesFromStr, ) @@ -164,3 +165,12 @@ class TestGetRecordFailLoud: assert excinfo.value.table == "DummyTable" conn.rollback.assert_called_once() + + +class TestStripNulBytesFromStr: + def test_removesNul(self): + assert _stripNulBytesFromStr("a\x00b") == "ab" + + def test_passthroughNonStr(self): + assert _stripNulBytesFromStr(None) is None + assert _stripNulBytesFromStr(7) == 7 diff --git a/tests/unit/services/test_buildTree.py b/tests/unit/services/test_buildTree.py new file mode 100644 index 00000000..5a2bacb4 --- /dev/null +++ b/tests/unit/services/test_buildTree.py @@ -0,0 +1,359 @@ +"""Unit tests for the generic UDB tree builder. + +Verifies key encoding/decoding and that children for parent keys with +existing handlers (top-level, conn, mgrp, feat) are produced with the +correct effective-flag triplet. +""" + +from __future__ import annotations + +import asyncio +import unittest +from unittest.mock import MagicMock, patch + +from modules.serviceCenter.services.serviceKnowledge import _buildTree + + +class TestKeyCoding(unittest.TestCase): + def test_encode_decode_roundtrip(self): + key = _buildTree._encode("ds", "conn-1", "sharepointFolder", "/sites/x") + kind, parts = _buildTree._decode(key) + self.assertEqual(kind, "ds") + self.assertEqual(parts, ["conn-1", "sharepointFolder", "/sites/x"]) + + def test_top_level_kinds(self): + self.assertEqual(_buildTree._decode("conn|abc")[0], "conn") + self.assertEqual(_buildTree._decode("mgrp|m1")[0], "mgrp") + self.assertEqual(_buildTree._decode("feat|m1|trustee|fi-1")[1], ["m1", "trustee", "fi-1"]) + + +class TestEffectiveTriplets(unittest.TestCase): + def test_ds_triplet_no_record_returns_defaults(self): + result = _buildTree._effectiveTripletDs("c", "msft", "/", []) + self.assertEqual(result, { + "effectiveNeutralize": False, + "effectiveScope": "personal", + "effectiveRagIndexEnabled": False, + }) + + def test_ds_triplet_inherits_from_root(self): + root = { + "id": "r", "connectionId": "c", "sourceType": "msft", "path": "/", + "neutralize": True, "scope": "mandate", "ragIndexEnabled": True, + } + result = _buildTree._effectiveTripletDs("c", "sharepointFolder", "/sites/x", [root]) + self.assertEqual(result["effectiveNeutralize"], True) + self.assertEqual(result["effectiveScope"], "mandate") + self.assertEqual(result["effectiveRagIndexEnabled"], True) + + def test_fds_triplet_inherits_from_workspace_wildcard(self): + ws = { + "id": "ws", "workspaceInstanceId": "ws-inst", "featureInstanceId": "fi1", + "tableName": "*", "recordFilter": None, "neutralize": True, + "scope": "mandate", "ragIndexEnabled": True, + } + result = _buildTree._effectiveTripletFds("fi1", "Pos", None, [ws]) + self.assertEqual(result["effectiveNeutralize"], True) + self.assertEqual(result["effectiveScope"], "mandate") + self.assertEqual(result["effectiveRagIndexEnabled"], True) + + +class TestRecordLookup(unittest.TestCase): + def test_finds_ds_record_by_normalised_path(self): + rec = {"id": "x", "connectionId": "c", "sourceType": "msft", "path": "/folder"} + self.assertEqual(_buildTree._findDsRecord([rec], "c", "msft", "/folder/").get("id"), "x") + self.assertIsNone(_buildTree._findDsRecord([rec], "c", "msft", "/other")) + + def test_finds_fds_record_with_matching_filter(self): + rec = {"id": "f", "workspaceInstanceId": "ws", "featureInstanceId": "fi1", "tableName": "Pos", "recordFilter": {"id": "5"}} + self.assertEqual(_buildTree._findFdsRecord([rec], "fi1", "Pos", {"id": "5"}).get("id"), "f") + self.assertIsNone(_buildTree._findFdsRecord([rec], "fi1", "Pos", {"id": "99"})) + + def test_fds_record_with_none_filter_matches_only_none(self): + rec = {"id": "f", "workspaceInstanceId": "ws", "featureInstanceId": "fi1", "tableName": "*", "recordFilter": None} + self.assertEqual(_buildTree._findFdsRecord([rec], "fi1", "*", None).get("id"), "f") + self.assertIsNone(_buildTree._findFdsRecord([rec], "fi1", "*", {"id": "1"})) + + +class TestGetChildrenForParents(unittest.TestCase): + """End-to-end orchestrator test with mocked dependencies.""" + + def _runAsync(self, coro): + return asyncio.get_event_loop().run_until_complete(coro) + + def test_unknown_parent_key_returns_empty_list(self): + with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot: + rootIf = MagicMock() + rootIf.db.getRecordset.return_value = [] + mockRoot.return_value = rootIf + + ctx = MagicMock() + ctx.user.id = "u1" + ctx.mandateId = "m1" + + result = self._runAsync( + _buildTree.getChildrenForParents("inst-1", ["bogus|key"], ctx) + ) + self.assertEqual(result["bogus|key"], []) + + def test_top_level_emits_personal_root_first(self): + """Top-level emits personalRoot first, then mandate-group nodes inline.""" + with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot: + rootIf = MagicMock() + rootIf.db.getRecordset.return_value = [] + rootIf.getUserMandates.return_value = [] + mockRoot.return_value = rootIf + + ctx = MagicMock() + ctx.user.id = "u1" + ctx.mandateId = "m1" + + result = self._runAsync( + _buildTree.getChildrenForParents("inst-1", [None], ctx) + ) + children = result["__root__"] + self.assertGreaterEqual(len(children), 1) + personalRoot = children[0] + self.assertEqual(personalRoot["key"], "personalRoot") + self.assertEqual(personalRoot["kind"], "synthRoot") + self.assertIsNone(personalRoot["parentKey"]) + self.assertTrue(personalRoot["hasChildren"]) + self.assertTrue(personalRoot["defaultExpanded"]) + + +class TestTopLevelLayout(unittest.TestCase): + """Tests for the flat top-level layout (personalRoot + mandate groups).""" + + def _runAsync(self, coro): + return asyncio.get_event_loop().run_until_complete(coro) + + def test_personal_root_carries_neutral_default_triplet(self): + with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot: + rootIf = MagicMock() + rootIf.db.getRecordset.return_value = [] + rootIf.getUserMandates.return_value = [] + mockRoot.return_value = rootIf + + ctx = MagicMock() + ctx.user.id = "u1" + ctx.mandateId = "m1" + + result = self._runAsync( + _buildTree.getChildrenForParents("inst-1", [None], ctx) + ) + personalRoot = result["__root__"][0] + self.assertFalse(personalRoot["effectiveNeutralize"]) + self.assertEqual(personalRoot["effectiveScope"], "personal") + self.assertFalse(personalRoot["effectiveRagIndexEnabled"]) + self.assertFalse(personalRoot["supportsRag"]) + self.assertFalse(personalRoot["canBeAdded"]) + self.assertIsNone(personalRoot["dataSourceId"]) + self.assertIsNone(personalRoot["modelType"]) + + def test_personal_root_emits_active_connection_with_correct_parent(self): + with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot, \ + patch("modules.serviceCenter.getService") as mockGetService: + rootIf = MagicMock() + rootIf.db.getRecordset.return_value = [] + mockRoot.return_value = rootIf + + chatService = MagicMock() + chatService.getUserConnections.return_value = [{ + "id": "conn-1", + "status": "active", + "authority": "msft", + "externalEmail": "user@example.com", + }] + mockGetService.return_value = chatService + + ctx = MagicMock() + ctx.user.id = "u1" + ctx.mandateId = "m1" + + result = self._runAsync( + _buildTree.getChildrenForParents("inst-1", ["personalRoot"], ctx) + ) + children = result["personalRoot"] + self.assertEqual(len(children), 1) + self.assertEqual(children[0]["key"], "conn|conn-1") + self.assertEqual(children[0]["kind"], "connection") + self.assertEqual(children[0]["parentKey"], "personalRoot") + self.assertEqual(children[0]["label"], "user@example.com") + self.assertTrue(children[0]["supportsRag"]) + + def test_personal_root_skips_inactive_connection(self): + with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot, \ + patch("modules.serviceCenter.getService") as mockGetService: + rootIf = MagicMock() + rootIf.db.getRecordset.return_value = [] + mockRoot.return_value = rootIf + + chatService = MagicMock() + chatService.getUserConnections.return_value = [ + {"id": "c1", "status": "active", "authority": "msft", "externalEmail": "a"}, + {"id": "c2", "status": "expired", "authority": "google", "externalEmail": "b"}, + ] + mockGetService.return_value = chatService + + ctx = MagicMock() + ctx.user.id = "u1" + ctx.mandateId = "m1" + + result = self._runAsync( + _buildTree.getChildrenForParents("inst-1", ["personalRoot"], ctx) + ) + self.assertEqual(len(result["personalRoot"]), 1) + self.assertEqual(result["personalRoot"][0]["connectionId"], "c1") + + def test_mandate_groups_emitted_inline_at_top_level(self): + with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot, \ + patch("modules.security.rbacCatalog.getCatalogService") as mockCatalog: + rootIf = MagicMock() + rootIf.db.getRecordset.return_value = [] + userMandate = MagicMock() + userMandate.mandateId = "m1" + rootIf.getUserMandates.return_value = [userMandate] + featureInst = MagicMock() + featureInst.id = "fi-1" + featureInst.featureCode = "trustee" + featureInst.enabled = True + rootIf.getFeatureInstancesByMandate.return_value = [featureInst] + featureAccess = MagicMock() + featureAccess.enabled = True + rootIf.getFeatureAccess.return_value = featureAccess + mockRoot.return_value = rootIf + + catalog = MagicMock() + catalog.getFeaturesWithDataObjects.return_value = ["trustee"] + mockCatalog.return_value = catalog + + ctx = MagicMock() + ctx.user.id = "u1" + ctx.mandateId = None + + result = self._runAsync( + _buildTree.getChildrenForParents("inst-1", [None], ctx) + ) + children = result["__root__"] + byKey = {c["key"]: c for c in children} + self.assertIn("personalRoot", byKey) + self.assertIn("mgrp|m1", byKey) + mgroup = byKey["mgrp|m1"] + self.assertEqual(mgroup["kind"], "mandateGroup") + self.assertIsNone(mgroup["parentKey"]) + self.assertEqual(mgroup["mandateId"], "m1") + self.assertTrue(mgroup["defaultExpanded"]) + self.assertFalse(mgroup["supportsRag"]) + + def test_top_level_omits_mandates_without_data_features(self): + with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot, \ + patch("modules.security.rbacCatalog.getCatalogService") as mockCatalog: + rootIf = MagicMock() + rootIf.db.getRecordset.return_value = [] + userMandate = MagicMock() + userMandate.mandateId = "m1" + rootIf.getUserMandates.return_value = [userMandate] + rootIf.getFeatureInstancesByMandate.return_value = [] + mockRoot.return_value = rootIf + + catalog = MagicMock() + catalog.getFeaturesWithDataObjects.return_value = ["trustee"] + mockCatalog.return_value = catalog + + ctx = MagicMock() + ctx.user.id = "u1" + ctx.mandateId = None + + result = self._runAsync( + _buildTree.getChildrenForParents("inst-1", [None], ctx) + ) + keys = [c["key"] for c in result["__root__"]] + self.assertEqual(keys, ["personalRoot"]) + + def test_personal_root_listed_first_via_display_order(self): + with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot, \ + patch("modules.security.rbacCatalog.getCatalogService") as mockCatalog: + rootIf = MagicMock() + rootIf.db.getRecordset.return_value = [] + userMandate = MagicMock() + userMandate.mandateId = "m1" + rootIf.getUserMandates.return_value = [userMandate] + featureInst = MagicMock() + featureInst.id = "fi-1" + featureInst.featureCode = "trustee" + featureInst.enabled = True + rootIf.getFeatureInstancesByMandate.return_value = [featureInst] + featureAccess = MagicMock() + featureAccess.enabled = True + rootIf.getFeatureAccess.return_value = featureAccess + mockRoot.return_value = rootIf + + catalog = MagicMock() + catalog.getFeaturesWithDataObjects.return_value = ["trustee"] + mockCatalog.return_value = catalog + + ctx = MagicMock() + ctx.user.id = "u1" + ctx.mandateId = None + + result = self._runAsync( + _buildTree.getChildrenForParents("inst-1", [None], ctx) + ) + children = result["__root__"] + self.assertEqual(children[0]["key"], "personalRoot") + self.assertEqual(children[0]["displayOrder"], 0) + + +class TestFeatureTableFields(unittest.TestCase): + """Per-column field expansion under a feature data-source table.""" + + def test_emits_one_node_per_field(self): + nodes = _buildTree._featureTableFields( + parentKey="fdstbl|fi-1|TrusteePosition", + featureInstanceId="fi-1", + tableName="TrusteePosition", + fieldNames=["id", "valuta", "company"], + allFds=[], + ) + self.assertEqual(len(nodes), 3) + self.assertEqual(nodes[0]["kind"], "fdsField") + self.assertEqual(nodes[0]["fieldName"], "id") + self.assertEqual(nodes[0]["parentKey"], "fdstbl|fi-1|TrusteePosition") + self.assertEqual(nodes[0]["key"], "fdsfld|fi-1|TrusteePosition|id") + self.assertFalse(nodes[0]["hasChildren"]) + self.assertFalse(nodes[0]["supportsRag"]) + + def test_field_neutralize_inherits_from_table_blanket(self): + rec = {"id": "f", "workspaceInstanceId": "ws-1", "featureInstanceId": "fi-1", + "tableName": "TrusteePosition", "recordFilter": None, + "neutralize": True, "neutralizeFields": None, + "scope": None, "ragIndexEnabled": False} + nodes = _buildTree._featureTableFields( + parentKey="fdstbl|fi-1|TrusteePosition", + featureInstanceId="fi-1", + tableName="TrusteePosition", + fieldNames=["email", "company"], + allFds=[rec], + ) + self.assertTrue(nodes[0]["effectiveNeutralize"]) + self.assertTrue(nodes[1]["effectiveNeutralize"]) + + def test_field_neutralize_explicit_via_neutralize_fields(self): + rec = {"id": "f", "workspaceInstanceId": "ws-1", "featureInstanceId": "fi-1", + "tableName": "TrusteePosition", "recordFilter": None, + "neutralize": False, "neutralizeFields": ["email"], + "scope": None, "ragIndexEnabled": False} + nodes = _buildTree._featureTableFields( + parentKey="fdstbl|fi-1|TrusteePosition", + featureInstanceId="fi-1", + tableName="TrusteePosition", + fieldNames=["email", "company"], + allFds=[rec], + ) + byField = {n["fieldName"]: n for n in nodes} + self.assertTrue(byField["email"]["effectiveNeutralize"]) + self.assertFalse(byField["company"]["effectiveNeutralize"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/services/test_inheritFlags.py b/tests/unit/services/test_inheritFlags.py index b177e767..98e6fb41 100644 --- a/tests/unit/services/test_inheritFlags.py +++ b/tests/unit/services/test_inheritFlags.py @@ -1,12 +1,12 @@ """Unit tests for `_inheritFlags` cascade-inherit helpers. Verifies: -- getEffectiveFlag walks ancestors via path-prefix matching -- root default is False (or 'personal' for scope) when nothing explicit in chain -- only same-connectionId AND same-sourceType ancestors are considered -- cascadeResetDescendants only touches descendants with explicit values for THAT flag -- '/' is treated as ancestor of every non-root path -- '/foo' is NOT ancestor of '/foobar' (must require '/' separator) +- getEffectiveFlag mode='walk': walks ancestors via path-prefix matching +- getEffectiveFlag mode='aggregate': returns 'mixed' when subtree diverges +- cascadeResetDescendants: bottom-up reset returning List[str] +- cascadeResetDescendantsFds: same for FeatureDataSource +- collectAncestorChain / collectAncestorChainFds: ancestor discovery +- buildEffectiveByConnection / buildEffectiveByWorkspaceFds: batch compute """ from __future__ import annotations @@ -33,7 +33,26 @@ def _ds(idVal: str, path: str, **flags) -> dict: return base -class TestEffectiveFlag(unittest.TestCase): +def _fds(idVal: str, *, tableName: str, recordFilter=None, featureInstanceId="fi-1", **flags) -> dict: + """Build a FeatureDataSource dict fixture.""" + base = { + "id": idVal, + "workspaceInstanceId": "ws-1", + "featureInstanceId": featureInstanceId, + "tableName": tableName, + "recordFilter": recordFilter, + "neutralize": None, + "scope": None, + } + base.update(flags) + return base + + +# =========================================================================== +# DataSource: getEffectiveFlag mode='walk' +# =========================================================================== + +class TestEffectiveFlagWalk(unittest.TestCase): def test_explicit_own_value_wins(self): root = _ds("r", "/", neutralize=False) leaf = _ds("l", "/folder/sub", neutralize=True) @@ -65,7 +84,6 @@ class TestEffectiveFlag(unittest.TestCase): self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [otherType, leaf])) def test_path_separator_required(self): - """`/foo` must NOT be ancestor of `/foobar` (no shared `/` boundary).""" notAncestor = _ds("a", "/foo", neutralize=True) leaf = _ds("l", "/foobar") self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [notAncestor, leaf])) @@ -90,32 +108,101 @@ class TestEffectiveFlag(unittest.TestCase): _inheritFlags.getEffectiveFlag(leaf, "unknownFlag", [leaf]) def test_explicit_false_overrides_inherited_true(self): - """Explicit False on a child must NOT cascade up to True from an ancestor.""" root = _ds("r", "/", neutralize=True) leaf = _ds("l", "/folder", neutralize=False) self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [root, leaf])) def test_connection_root_inherits_cross_sourcetype(self): - """Connection-root (sourceType=authority, path='/') is ancestor of all DS in that connection.""" connRoot = _ds("conn", "/", sourceType="msft", neutralize=True) spService = _ds("sp", "/", sourceType="sharepointFolder") olService = _ds("ol", "/", sourceType="outlookFolder") - self.assertTrue(_inheritFlags.getEffectiveFlag(spService, "neutralize", [connRoot, spService, olService])) - self.assertTrue(_inheritFlags.getEffectiveFlag(olService, "neutralize", [connRoot, spService, olService])) + allDs = [connRoot, spService, olService] + self.assertTrue(_inheritFlags.getEffectiveFlag(spService, "neutralize", allDs)) + self.assertTrue(_inheritFlags.getEffectiveFlag(olService, "neutralize", allDs)) def test_same_sourcetype_ancestor_wins_over_connection_root(self): - """A same-sourceType service-root ancestor beats the connection-root.""" connRoot = _ds("conn", "/", sourceType="msft", neutralize=True) spRoot = _ds("sp", "/", sourceType="sharepointFolder", neutralize=False) spLeaf = _ds("spl", "/sites/x", sourceType="sharepointFolder") self.assertFalse(_inheritFlags.getEffectiveFlag(spLeaf, "neutralize", [connRoot, spRoot, spLeaf])) def test_connection_root_does_not_self_inherit(self): - """Connection-root has no ancestor — does not infinite-loop on itself.""" connRoot = _ds("conn", "/", sourceType="msft") self.assertFalse(_inheritFlags.getEffectiveFlag(connRoot, "neutralize", [connRoot])) +# =========================================================================== +# DataSource: getEffectiveFlag mode='aggregate' +# =========================================================================== + +class TestEffectiveFlagAggregate(unittest.TestCase): + def test_leaf_without_descendants_returns_concrete(self): + leaf = _ds("l", "/folder", neutralize=True) + self.assertTrue(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [leaf], mode="aggregate")) + + def test_all_descendants_same_returns_concrete(self): + root = _ds("r", "/", neutralize=True) + child1 = _ds("c1", "/a", neutralize=True) + child2 = _ds("c2", "/b") # inherits True from root + allDs = [root, child1, child2] + self.assertTrue(_inheritFlags.getEffectiveFlag(root, "neutralize", allDs, mode="aggregate")) + + def test_divergent_descendants_returns_mixed(self): + root = _ds("r", "/", neutralize=True) + child1 = _ds("c1", "/a", neutralize=False) + child2 = _ds("c2", "/b") # inherits True from root + allDs = [root, child1, child2] + self.assertEqual(_inheritFlags.getEffectiveFlag(root, "neutralize", allDs, mode="aggregate"), "mixed") + + def test_mixed_scope(self): + root = _ds("r", "/", scope="personal") + child1 = _ds("c1", "/a", scope="team") + child2 = _ds("c2", "/b") # inherits personal from root + allDs = [root, child1, child2] + self.assertEqual(_inheritFlags.getEffectiveFlag(root, "scope", allDs, mode="aggregate"), "mixed") + + def test_all_scope_same_explicit_returns_concrete(self): + root = _ds("r", "/", scope="team") + child1 = _ds("c1", "/a", scope="team") + child2 = _ds("c2", "/b") # inherits team + allDs = [root, child1, child2] + self.assertEqual(_inheritFlags.getEffectiveFlag(root, "scope", allDs, mode="aggregate"), "team") + + def test_connection_root_aggregate_cross_sourcetype(self): + connRoot = _ds("conn", "/", sourceType="msft", neutralize=True) + spExplicit = _ds("sp", "/", sourceType="sharepointFolder", neutralize=False) + olInherit = _ds("ol", "/", sourceType="outlookFolder") # inherits True + allDs = [connRoot, spExplicit, olInherit] + self.assertEqual( + _inheritFlags.getEffectiveFlag(connRoot, "neutralize", allDs, mode="aggregate"), + "mixed", + ) + + def test_mid_level_aggregate_only_considers_own_subtree(self): + root = _ds("r", "/", neutralize=True) + mid = _ds("m", "/folder", neutralize=True) + midChild = _ds("mc", "/folder/sub", neutralize=True) + sibling = _ds("s", "/other", neutralize=False) # not under mid + allDs = [root, mid, midChild, sibling] + # mid's subtree is just midChild(True) + mid(True) = uniform + self.assertTrue(_inheritFlags.getEffectiveFlag(mid, "neutralize", allDs, mode="aggregate")) + # root's subtree includes sibling(False) = mixed + self.assertEqual( + _inheritFlags.getEffectiveFlag(root, "neutralize", allDs, mode="aggregate"), + "mixed", + ) + + def test_walk_mode_never_returns_mixed(self): + root = _ds("r", "/", neutralize=True) + child = _ds("c", "/a", neutralize=False) + allDs = [root, child] + self.assertTrue(_inheritFlags.getEffectiveFlag(root, "neutralize", allDs, mode="walk")) + + +# =========================================================================== +# DataSource: cascadeResetDescendants (bottom-up, List[str]) +# =========================================================================== + class TestCascadeReset(unittest.TestCase): def _makeRootIf(self, dataSources: List[dict]): rootIf = MagicMock() @@ -127,54 +214,76 @@ class TestCascadeReset(unittest.TestCase): rootIf.db.recordModify = MagicMock(side_effect=_modify) return rootIf, modified + def test_returns_list_of_ids(self): + parent = _ds("p", "/sites", neutralize=True) + child = _ds("c1", "/sites/folder1", neutralize=False) + rootIf, _ = self._makeRootIf([parent, child]) + result = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize") + self.assertIsInstance(result, list) + self.assertEqual(result, ["c1"]) + def test_resets_only_explicit_descendants(self): parent = _ds("p", "/sites", neutralize=True) explicitChild = _ds("c1", "/sites/folder1", neutralize=False) - inheritChild = _ds("c2", "/sites/folder2") # inherit -> not touched - sibling = _ds("s", "/other", neutralize=True) # NOT a descendant + inheritChild = _ds("c2", "/sites/folder2") + sibling = _ds("s", "/other", neutralize=True) rootIf, modified = self._makeRootIf([parent, explicitChild, inheritChild, sibling]) - affected = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize") + result = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize") - self.assertEqual(affected, 1) + self.assertEqual(result, ["c1"]) self.assertEqual(modified, [("c1", {"neutralize": None})]) - def test_does_not_touch_other_flags(self): - parent = _ds("p", "/sites", neutralize=True) - child = _ds("c", "/sites/sub", neutralize=False, ragIndexEnabled=True) + def test_bottom_up_order(self): + """Deepest items are reset first.""" + parent = _ds("p", "/", neutralize=True) + level1 = _ds("l1", "/a", neutralize=False) + level2 = _ds("l2", "/a/b", neutralize=False) + level3 = _ds("l3", "/a/b/c", neutralize=False) + rootIf, modified = self._makeRootIf([parent, level1, level2, level3]) + + result = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize") + + self.assertEqual(result, ["l3", "l2", "l1"]) + + def test_deep_cascade_through_null_items(self): + """null items are skipped (no DB write) but cascade continues deeper.""" + parent = _ds("p", "/", neutralize=True) + nullChild = _ds("n", "/a") # null — no write, but not a barrier + deepExplicit = _ds("d", "/a/b", neutralize=False) + rootIf, modified = self._makeRootIf([parent, nullChild, deepExplicit]) + + result = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize") + + self.assertEqual(result, ["d"]) + self.assertEqual(modified, [("d", {"neutralize": None})]) + + def test_does_not_modify_parent(self): + parent = _ds("p", "/", neutralize=True) + child = _ds("c", "/a", neutralize=False) rootIf, modified = self._makeRootIf([parent, child]) - _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize") - - self.assertEqual(modified, [("c", {"neutralize": None})]) - # ragIndexEnabled and scope on the child must remain untouched. - - def test_does_not_cross_sourcetype(self): - """Non-connection-root parents stay within their sourceType for cascade.""" - parent = _ds("p", "/", neutralize=True, sourceType="sharepointFolder") - otherTypeDescendant = _ds("o", "/anything", neutralize=False, sourceType="outlookFolder") - rootIf, modified = self._makeRootIf([parent, otherTypeDescendant]) - - affected = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize") - - self.assertEqual(affected, 0) - self.assertEqual(modified, []) + self.assertNotIn("p", [m[0] for m in modified]) def test_connection_root_cascades_cross_sourcetype(self): - """Toggle on connection-root cascades into every explicit DS of that connection.""" connRoot = _ds("conn", "/", sourceType="msft", neutralize=True) spExplicit = _ds("sp", "/", sourceType="sharepointFolder", neutralize=False) olInherit = _ds("ol", "/", sourceType="outlookFolder") - spLeafExplicit = _ds("sp-leaf", "/sites/x", sourceType="sharepointFolder", neutralize=True) - rootIf, modified = self._makeRootIf([connRoot, spExplicit, olInherit, spLeafExplicit]) + spLeaf = _ds("sp-leaf", "/sites/x", sourceType="sharepointFolder", neutralize=True) + rootIf, modified = self._makeRootIf([connRoot, spExplicit, olInherit, spLeaf]) - affected = _inheritFlags.cascadeResetDescendants(rootIf, connRoot, "neutralize") + result = _inheritFlags.cascadeResetDescendants(rootIf, connRoot, "neutralize") - # spExplicit and spLeafExplicit had explicit values → reset. olInherit untouched. - self.assertEqual(affected, 2) - self.assertEqual({m[0] for m in modified}, {"sp", "sp-leaf"}) - for _, fields in modified: - self.assertEqual(fields, {"neutralize": None}) + self.assertEqual(set(result), {"sp", "sp-leaf"}) + # sp-leaf is deeper, should come first + self.assertEqual(result[0], "sp-leaf") + + def test_does_not_cross_sourcetype_for_non_authority(self): + parent = _ds("p", "/", neutralize=True, sourceType="sharepointFolder") + otherType = _ds("o", "/anything", neutralize=False, sourceType="outlookFolder") + rootIf, modified = self._makeRootIf([parent, otherType]) + result = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize") + self.assertEqual(result, []) def test_unknown_flag_raises(self): parent = _ds("p", "/", neutralize=True) @@ -183,57 +292,59 @@ class TestCascadeReset(unittest.TestCase): _inheritFlags.cascadeResetDescendants(rootIf, parent, "unknownFlag") -def _fds(idVal: str, *, tableName: str, recordFilter=None, **flags) -> dict: - """Build a FeatureDataSource dict fixture.""" - base = { - "id": idVal, - "workspaceInstanceId": "ws-1", - "tableName": tableName, - "recordFilter": recordFilter, - "neutralize": None, - "scope": None, - } - base.update(flags) - return base +# =========================================================================== +# DataSource: collectAncestorChain +# =========================================================================== + +class TestCollectAncestorChain(unittest.TestCase): + def test_returns_nearest_first(self): + root = _ds("r", "/", neutralize=True) + mid = _ds("m", "/a") + leaf = _ds("l", "/a/b") + chain = _inheritFlags.collectAncestorChain(leaf, [root, mid, leaf]) + self.assertEqual([_inheritFlags._getRecordValue(c, "id") for c in chain], ["m", "r"]) + + def test_connection_root_is_last(self): + connRoot = _ds("conn", "/", sourceType="msft") + spRoot = _ds("sp", "/", sourceType="sharepointFolder") + spLeaf = _ds("spl", "/sub", sourceType="sharepointFolder") + chain = _inheritFlags.collectAncestorChain(spLeaf, [connRoot, spRoot, spLeaf]) + ids = [_inheritFlags._getRecordValue(c, "id") for c in chain] + self.assertEqual(ids, ["sp", "conn"]) + + def test_root_has_no_ancestors(self): + root = _ds("r", "/") + chain = _inheritFlags.collectAncestorChain(root, [root]) + self.assertEqual(chain, []) -class TestFdsClassifyAndAncestry(unittest.TestCase): - def test_classify_workspace_wildcard(self): - self.assertEqual(_inheritFlags._fdsClassify(_fds("a", tableName="*")), "workspace") +# =========================================================================== +# DataSource: buildEffectiveByConnection +# =========================================================================== - def test_classify_table_wildcard(self): - self.assertEqual(_inheritFlags._fdsClassify(_fds("a", tableName="Pos")), "table") +class TestBuildEffectiveByConnection(unittest.TestCase): + def test_walk_mode(self): + root = _ds("r", "/", neutralize=True) + child = _ds("c", "/a", neutralize=False) + leaf = _ds("l", "/a/b") # inherits False from child + result = _inheritFlags.buildEffectiveByConnection([root, child, leaf], "neutralize", mode="walk") + self.assertEqual(result, {"r": True, "c": False, "l": False}) - def test_classify_record_specific(self): - rec = _fds("a", tableName="Pos", recordFilter={"id": "r-1"}) - self.assertEqual(_inheritFlags._fdsClassify(rec), "record") - - def test_workspace_is_ancestor_of_table_and_record(self): - ws = _fds("ws", tableName="*") - tbl = _fds("t", tableName="Pos") - rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) - self.assertTrue(_inheritFlags._fdsIsAncestor(ws, tbl)) - self.assertTrue(_inheritFlags._fdsIsAncestor(ws, rec)) - - def test_table_is_ancestor_of_record_same_table_only(self): - tbl = _fds("t", tableName="Pos") - recSame = _fds("r1", tableName="Pos", recordFilter={"id": "1"}) - recOther = _fds("r2", tableName="Other", recordFilter={"id": "1"}) - self.assertTrue(_inheritFlags._fdsIsAncestor(tbl, recSame)) - self.assertFalse(_inheritFlags._fdsIsAncestor(tbl, recOther)) - - def test_record_has_no_descendants(self): - rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) - tbl = _fds("t", tableName="Pos") - self.assertFalse(_inheritFlags._fdsIsAncestor(rec, tbl)) - - def test_no_cross_workspace_ancestry(self): - ws = _fds("ws", tableName="*", workspaceInstanceId="ws-A") - rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, workspaceInstanceId="ws-B") - self.assertFalse(_inheritFlags._fdsIsAncestor(ws, rec)) + def test_aggregate_mode(self): + root = _ds("r", "/", neutralize=True) + child = _ds("c", "/a", neutralize=False) + leaf = _ds("l", "/a/b") # inherits False from child + result = _inheritFlags.buildEffectiveByConnection([root, child, leaf], "neutralize", mode="aggregate") + self.assertEqual(result["r"], "mixed") + self.assertEqual(result["c"], False) + self.assertEqual(result["l"], False) -class TestFdsEffectiveFlag(unittest.TestCase): +# =========================================================================== +# FeatureDataSource: getEffectiveFlagFds +# =========================================================================== + +class TestFdsEffectiveFlagWalk(unittest.TestCase): def test_own_explicit_wins(self): ws = _fds("ws", tableName="*", neutralize=False) rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True) @@ -262,9 +373,50 @@ class TestFdsEffectiveFlag(unittest.TestCase): def test_unknown_flag_raises(self): rec = _fds("r", tableName="*") with self.assertRaises(ValueError): - _inheritFlags.getEffectiveFlagFds(rec, "ragIndexEnabled", [rec]) + _inheritFlags.getEffectiveFlagFds(rec, "doesNotExist", [rec]) +class TestFdsEffectiveFlagAggregate(unittest.TestCase): + def test_leaf_without_descendants(self): + rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True) + self.assertTrue(_inheritFlags.getEffectiveFlagFds(rec, "neutralize", [rec], mode="aggregate")) + + def test_all_descendants_same(self): + ws = _fds("ws", tableName="*", neutralize=True) + tbl = _fds("t", tableName="Pos") # inherits True + rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) # inherits True + allFds = [ws, tbl, rec] + self.assertTrue(_inheritFlags.getEffectiveFlagFds(ws, "neutralize", allFds, mode="aggregate")) + + def test_divergent_descendants_returns_mixed(self): + ws = _fds("ws", tableName="*", neutralize=True) + tbl = _fds("t", tableName="Pos", neutralize=False) + rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) # inherits False from tbl + allFds = [ws, tbl, rec] + self.assertEqual( + _inheritFlags.getEffectiveFlagFds(ws, "neutralize", allFds, mode="aggregate"), + "mixed", + ) + + def test_table_aggregate_own_subtree_only(self): + ws = _fds("ws", tableName="*", neutralize=True) + tblA = _fds("tA", tableName="A", neutralize=True) + recA = _fds("rA", tableName="A", recordFilter={"id": "1"}, neutralize=True) + tblB = _fds("tB", tableName="B", neutralize=False) + allFds = [ws, tblA, recA, tblB] + # tblA subtree: all True + self.assertTrue(_inheritFlags.getEffectiveFlagFds(tblA, "neutralize", allFds, mode="aggregate")) + # ws subtree: mixed (tblB is False) + self.assertEqual( + _inheritFlags.getEffectiveFlagFds(ws, "neutralize", allFds, mode="aggregate"), + "mixed", + ) + + +# =========================================================================== +# FeatureDataSource: cascadeResetDescendantsFds (bottom-up, List[str]) +# =========================================================================== + class TestFdsCascadeReset(unittest.TestCase): def _makeRootIf(self, fdses): rootIf = MagicMock() @@ -276,6 +428,14 @@ class TestFdsCascadeReset(unittest.TestCase): rootIf.db.recordModify = MagicMock(side_effect=_modify) return rootIf, modified + def test_returns_list_of_ids(self): + ws = _fds("ws", tableName="*", neutralize=True) + tbl = _fds("t", tableName="Pos", neutralize=False) + rootIf, _ = self._makeRootIf([ws, tbl]) + result = _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "neutralize") + self.assertIsInstance(result, list) + self.assertEqual(result, ["t"]) + def test_workspace_cascades_to_all_explicit_descendants(self): ws = _fds("ws", tableName="*", neutralize=True) tblExplicit = _fds("t", tableName="Pos", neutralize=False) @@ -283,10 +443,11 @@ class TestFdsCascadeReset(unittest.TestCase): recExplicit = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True) rootIf, modified = self._makeRootIf([ws, tblExplicit, tblInherit, recExplicit]) - affected = _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "neutralize") + result = _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "neutralize") - self.assertEqual(affected, 2) - self.assertEqual({m[0] for m in modified}, {"t", "r"}) + self.assertEqual(set(result), {"t", "r"}) + # record is deeper (depth 2) than table (depth 1), should come first + self.assertEqual(result[0], "r") def test_table_cascades_only_to_same_table_records(self): tbl = _fds("t", tableName="Pos", neutralize=True) @@ -294,25 +455,189 @@ class TestFdsCascadeReset(unittest.TestCase): recOther = _fds("r2", tableName="Other", recordFilter={"id": "1"}, neutralize=False) rootIf, modified = self._makeRootIf([tbl, recSame, recOther]) - affected = _inheritFlags.cascadeResetDescendantsFds(rootIf, tbl, "neutralize") + result = _inheritFlags.cascadeResetDescendantsFds(rootIf, tbl, "neutralize") - self.assertEqual(affected, 1) + self.assertEqual(result, ["r1"]) self.assertEqual(modified, [("r1", {"neutralize": None})]) def test_record_has_no_cascade(self): rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True) rootIf, modified = self._makeRootIf([rec]) - affected = _inheritFlags.cascadeResetDescendantsFds(rootIf, rec, "neutralize") - self.assertEqual(affected, 0) - self.assertEqual(modified, []) + result = _inheritFlags.cascadeResetDescendantsFds(rootIf, rec, "neutralize") + self.assertEqual(result, []) def test_unknown_flag_raises(self): ws = _fds("ws", tableName="*", neutralize=True) rootIf, _ = self._makeRootIf([ws]) with self.assertRaises(ValueError): - _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "ragIndexEnabled") + _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "doesNotExist") +# =========================================================================== +# FeatureDataSource: collectAncestorChainFds +# =========================================================================== + +class TestCollectAncestorChainFds(unittest.TestCase): + def test_record_has_table_then_workspace(self): + ws = _fds("ws", tableName="*") + tbl = _fds("t", tableName="Pos") + rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) + chain = _inheritFlags.collectAncestorChainFds(rec, [ws, tbl, rec]) + ids = [c["id"] for c in chain] + self.assertEqual(ids, ["t", "ws"]) + + def test_table_has_only_workspace(self): + ws = _fds("ws", tableName="*") + tbl = _fds("t", tableName="Pos") + chain = _inheritFlags.collectAncestorChainFds(tbl, [ws, tbl]) + self.assertEqual([c["id"] for c in chain], ["ws"]) + + def test_workspace_has_no_ancestors(self): + ws = _fds("ws", tableName="*") + chain = _inheritFlags.collectAncestorChainFds(ws, [ws]) + self.assertEqual(chain, []) + + +# =========================================================================== +# FeatureDataSource: buildEffectiveByWorkspaceFds +# =========================================================================== + +class TestBuildEffectiveByWorkspaceFds(unittest.TestCase): + def test_walk_mode(self): + ws = _fds("ws", tableName="*", neutralize=True) + tbl = _fds("t", tableName="Pos", neutralize=False) + rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) # inherits False from tbl + result = _inheritFlags.buildEffectiveByWorkspaceFds([ws, tbl, rec], "neutralize", mode="walk") + self.assertEqual(result, {"ws": True, "t": False, "r": False}) + + def test_aggregate_mode(self): + ws = _fds("ws", tableName="*", neutralize=True) + tbl = _fds("t", tableName="Pos", neutralize=False) + rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) + result = _inheritFlags.buildEffectiveByWorkspaceFds([ws, tbl, rec], "neutralize", mode="aggregate") + self.assertEqual(result["ws"], "mixed") + self.assertEqual(result["t"], False) + self.assertEqual(result["r"], False) + + +# =========================================================================== +# resolveEffectiveForPath (with and without own record) +# =========================================================================== + +class TestResolveEffectiveForPath(unittest.TestCase): + def test_with_exact_record(self): + root = _ds("r", "/", neutralize=True, scope="mandate", ragIndexEnabled=False) + leaf = _ds("l", "/folder/sub", neutralize=False) + allDs = [root, leaf] + result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/folder/sub", allDs) + self.assertEqual(result["effectiveNeutralize"], False) + self.assertEqual(result["effectiveScope"], "mandate") + self.assertEqual(result["effectiveRagIndexEnabled"], False) + + def test_without_record_inherits_from_ancestor(self): + root = _ds("r", "/", neutralize=True, scope="mandate", ragIndexEnabled=True) + allDs = [root] + result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/deep/path/file.txt", allDs) + self.assertEqual(result["effectiveNeutralize"], True) + self.assertEqual(result["effectiveScope"], "mandate") + self.assertEqual(result["effectiveRagIndexEnabled"], True) + + def test_without_record_inherits_from_closest_ancestor(self): + root = _ds("r", "/", neutralize=True, ragIndexEnabled=True) + mid = _ds("m", "/folder", neutralize=False, ragIndexEnabled=False) + allDs = [root, mid] + result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/folder/sub/file.txt", allDs) + self.assertEqual(result["effectiveNeutralize"], False) + self.assertEqual(result["effectiveRagIndexEnabled"], False) + + def test_without_record_no_ancestors_returns_defaults(self): + allDs: list = [] + result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/path", allDs) + self.assertEqual(result["effectiveNeutralize"], False) + self.assertEqual(result["effectiveScope"], "personal") + self.assertEqual(result["effectiveRagIndexEnabled"], False) + + def test_connection_root_covers_service_subtree(self): + connRoot = _ds("cr", "/", neutralize=True, sourceType="msft") + allDs = [connRoot] + result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/sites/intranet", allDs) + self.assertEqual(result["effectiveNeutralize"], True) + + def test_exact_record_with_aggregate_mixed(self): + root = _ds("r", "/", neutralize=True) + leaf = _ds("l", "/sub", neutralize=False) + allDs = [root, leaf] + result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/", allDs, mode="aggregate") + self.assertEqual(result["effectiveNeutralize"], "mixed") + + +class TestResolveEffectiveForFds(unittest.TestCase): + def test_with_exact_record(self): + ws = _fds("ws", tableName="*", neutralize=True, scope="mandate") + tbl = _fds("t", tableName="Pos", neutralize=False, scope="personal") + allFds = [ws, tbl] + result = _inheritFlags.resolveEffectiveForFds("fi-1", "Pos", None, allFds) + self.assertEqual(result["effectiveNeutralize"], False) + self.assertEqual(result["effectiveScope"], "personal") + self.assertEqual(result["effectiveRagIndexEnabled"], False) + + def test_without_record_inherits_from_workspace_wildcard(self): + ws = _fds("ws", tableName="*", neutralize=True, scope="mandate", ragIndexEnabled=True) + allFds = [ws] + result = _inheritFlags.resolveEffectiveForFds("fi-1", "Unknown", None, allFds) + self.assertEqual(result["effectiveNeutralize"], True) + self.assertEqual(result["effectiveScope"], "mandate") + self.assertEqual(result["effectiveRagIndexEnabled"], True) + + def test_without_record_no_ancestors_returns_defaults(self): + allFds: list = [] + result = _inheritFlags.resolveEffectiveForFds("fi-1", "Pos", None, allFds) + self.assertEqual(result["effectiveNeutralize"], False) + self.assertEqual(result["effectiveScope"], "personal") + self.assertEqual(result["effectiveRagIndexEnabled"], False) + + def test_rag_inherits_when_table_overrides_neutralize_only(self): + """Tables that override only neutralize must still inherit RAG from parent.""" + ws = _fds("ws", tableName="*", ragIndexEnabled=True) + tbl = _fds("t", tableName="Pos", neutralize=False) + allFds = [ws, tbl] + result = _inheritFlags.resolveEffectiveForFds("fi-1", "Pos", None, allFds) + self.assertEqual(result["effectiveRagIndexEnabled"], True) + + def test_rag_aggregate_mixed_when_descendants_diverge(self): + ws = _fds("ws", tableName="*", ragIndexEnabled=True) + tbl = _fds("t", tableName="Pos", ragIndexEnabled=False) + allFds = [ws, tbl] + result = _inheritFlags.resolveEffectiveForFds("fi-1", "*", None, allFds, mode="aggregate") + self.assertEqual(result["effectiveRagIndexEnabled"], "mixed") + + def test_inheritable_fds_flags_includes_rag(self): + self.assertIn("ragIndexEnabled", _inheritFlags._INHERITABLE_FDS_FLAGS) + self.assertIn("neutralize", _inheritFlags._INHERITABLE_FDS_FLAGS) + self.assertIn("scope", _inheritFlags._INHERITABLE_FDS_FLAGS) + + +# =========================================================================== +# FDS cascade resets RAG (in addition to neutralize and scope) +# =========================================================================== + +class TestCascadeResetFdsRag(unittest.TestCase): + def test_cascade_resets_rag_on_descendants(self): + ws = _fds("ws", tableName="*") + tbl = _fds("t", tableName="Pos", ragIndexEnabled=False) + allFds = [ws, tbl] + rootIf = MagicMock() + rootIf.db.getRecordset.return_value = allFds + rootIf.db.recordModify = MagicMock() + result = _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "ragIndexEnabled") + self.assertIn("t", result) + rootIf.db.recordModify.assert_called() + + +# =========================================================================== +# Path normalization +# =========================================================================== + class TestPathNormalization(unittest.TestCase): def test_empty_path_normalises_to_root(self): self.assertEqual(_inheritFlags._normalisePath(""), "/") diff --git a/tests/unit/teamsbot/test_directorPrompts.py b/tests/unit/teamsbot/test_directorPrompts.py index f136438a..b8bdaafc 100644 --- a/tests/unit/teamsbot/test_directorPrompts.py +++ b/tests/unit/teamsbot/test_directorPrompts.py @@ -42,7 +42,7 @@ from modules.features.teamsbot.datamodelTeamsbot import ( from modules.features.teamsbot.service import ( TeamsbotService, _activeServices, - _sessionEvents, + sessionEvents, getActiveService, ) @@ -152,10 +152,10 @@ def _buildService() -> TeamsbotService: def _resetGlobals(): """Avoid cross-test bleed in module-level globals.""" _activeServices.clear() - _sessionEvents.clear() + sessionEvents.clear() yield _activeServices.clear() - _sessionEvents.clear() + sessionEvents.clear() # ============================================================================ @@ -251,7 +251,7 @@ class TestBuildPersistentDirectorContext: ] rendered = svc._buildPersistentDirectorContext() assert "OPERATOR_DIRECTIVES" in rendered - assert "- Antworte immer in Englisch." in rendered + assert "Antworte immer in Englisch." in rendered assert "private" in rendered def test_skipsBlankText(self): @@ -261,7 +261,7 @@ class TestBuildPersistentDirectorContext: {"id": "p2", "text": "Sei hoeflich."}, ] rendered = svc._buildPersistentDirectorContext() - assert "- Sei hoeflich." in rendered + assert "Sei hoeflich." in rendered assert "p1" not in rendered # the blank one is filtered out def test_allBlankPromptsResultInEmpty(self):