fixes rag and workflow

This commit is contained in:
ValueOn AG 2026-05-19 16:48:01 +02:00
parent 4064ac0266
commit 1ed462ad13
27 changed files with 3728 additions and 848 deletions

View file

@ -172,7 +172,7 @@ def parseRecordFields(record: Dict[str, Any], fields: Dict[str, str], context: s
pass # already a list pass # already a list
elif fieldType == "BOOLEAN": elif fieldType == "BOOLEAN":
record[fieldName] = bool(value) if value is not None else False record[fieldName] = bool(value) if value is not None else None
elif fieldType == "JSONB" and value is not None: elif fieldType == "JSONB" and value is not None:
try: try:
@ -184,6 +184,18 @@ def parseRecordFields(record: Dict[str, Any], fields: Dict[str, str], context: s
logger.warning(f"Could not parse JSONB field {fieldName}, keeping as string ({context})") logger.warning(f"Could not parse JSONB field {fieldName}, keeping as string ({context})")
def _stripNulBytesFromStr(value: Any) -> Any:
"""psycopg2 rejects bound parameters whose Python str contains NUL (0x00).
Some extracted files (e.g. SQL dumps, mixed binary treated as text) still
carry those bytes; PostgreSQL TEXT could store them via other paths, but
the client protocol path used here cannot.
"""
if isinstance(value, str) and "\x00" in value:
return value.replace("\x00", "")
return value
def _quotePgIdent(name: str) -> str: def _quotePgIdent(name: str) -> str:
return '"' + str(name).replace('"', '""') + '"' return '"' + str(name).replace('"', '""') + '"'
@ -983,7 +995,7 @@ class DatabaseConnector:
else: else:
value = json.dumps(value) value = json.dumps(value)
values.append(value) values.append(_stripNulBytesFromStr(value))
# Build INSERT/UPDATE with quoted identifiers # Build INSERT/UPDATE with quoted identifiers
col_names = ", ".join([f'"{col}"' for col in columns]) col_names = ", ".join([f'"{col}"' for col in columns])

View file

@ -76,6 +76,14 @@ class FeatureDataSource(PowerOnModel):
), ),
json_schema_extra={"label": "Neutralisieren", "frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False}, json_schema_extra={"label": "Neutralisieren", "frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False},
) )
ragIndexEnabled: Optional[bool] = Field(
default=None,
description=(
"Three-state RAG-indexing flag with cascade-inherit semantics. "
"None = inherit; True/False = explicit. Cascade-reset on parent toggle."
),
json_schema_extra={"label": "RAG-Indexierung", "frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False},
)
neutralizeFields: Optional[List[str]] = Field( neutralizeFields: Optional[List[str]] = Field(
default=None, default=None,
description="Column names whose values are replaced with placeholders before AI processing", description="Column names whose values are replaced with placeholders before AI processing",

View file

@ -124,6 +124,7 @@ class InvestorDemo2026(_BaseDemoConfig):
from modules.datamodels.datamodelUam import Mandate, UserInDB from modules.datamodels.datamodelUam import Mandate, UserInDB
from modules.datamodels.datamodelMembership import UserMandate from modules.datamodels.datamodelMembership import UserMandate
summary["_removedMandateIds"] = []
for mandateDef in [_MANDATE_HAPPYLIFE, _MANDATE_ALPINA]: for mandateDef in [_MANDATE_HAPPYLIFE, _MANDATE_ALPINA]:
try: try:
existing = db.getRecordset(Mandate, recordFilter={"name": mandateDef["name"]}) existing = db.getRecordset(Mandate, recordFilter={"name": mandateDef["name"]})
@ -132,28 +133,36 @@ class InvestorDemo2026(_BaseDemoConfig):
self._removeMandateData(db, mid, mandateDef["label"], summary) self._removeMandateData(db, mid, mandateDef["label"], summary)
db.recordDelete(Mandate, mid) db.recordDelete(Mandate, mid)
summary["removed"].append(f"Mandate {mandateDef['label']} ({mid})") summary["removed"].append(f"Mandate {mandateDef['label']} ({mid})")
summary["_removedMandateIds"].append({"id": mid, "mandateId": mid})
logger.info(f"Removed mandate {mandateDef['label']} ({mid})") logger.info(f"Removed mandate {mandateDef['label']} ({mid})")
except Exception as e: except Exception as e:
summary["errors"].append(f"Remove mandate {mandateDef['label']}: {e}") summary["errors"].append(f"Remove mandate {mandateDef['label']}: {e}")
# SAFETY: NEVER delete the user record. The user may have connections,
# chats, workflows, files, and other data across multiple databases.
# Only remove the mandate memberships that THIS demo created.
try: try:
existing = db.getRecordset(UserInDB, recordFilter={"username": _USER["username"]}) existing = db.getRecordset(UserInDB, recordFilter={"username": _USER["username"]})
for u in existing: for u in existing:
uid = u.get("id") uid = u.get("id")
removedMandateIds = {m.get("mandateId") for m in summary.get("_removedMandateIds", [])}
memberships = db.getRecordset(UserMandate, recordFilter={"userId": uid}) memberships = db.getRecordset(UserMandate, recordFilter={"userId": uid})
for mem in memberships: for mem in memberships:
if mem.get("mandateId") in removedMandateIds:
try: try:
db.recordDelete(UserMandate, mem.get("id")) db.recordDelete(UserMandate, mem.get("id"))
except Exception: except Exception:
pass pass
db.recordDelete(UserInDB, uid) summary["skipped"].append(
summary["removed"].append(f"User {_USER['username']} ({uid})") f"User {_USER['username']} ({uid}) preserved (only demo mandate memberships removed)"
logger.info(f"Removed user {_USER['username']} ({uid})") )
logger.info(f"Preserved user {_USER['username']} ({uid}) - removed demo mandate memberships only")
except Exception as e: except Exception as e:
summary["errors"].append(f"Remove user: {e}") summary["errors"].append(f"Remove user memberships: {e}")
self._removeLanguageSet(db, "es", summary) self._removeLanguageSet(db, "es", summary)
summary.pop("_removedMandateIds", None)
return summary return summary
# ------------------------------------------------------------------ # ------------------------------------------------------------------

View file

@ -121,32 +121,39 @@ class PwgDemo2026(_BaseDemoConfig):
from modules.datamodels.datamodelMembership import UserMandate from modules.datamodels.datamodelMembership import UserMandate
from modules.datamodels.datamodelUam import Mandate, UserInDB from modules.datamodels.datamodelUam import Mandate, UserInDB
removedMandateIds = set()
try: try:
existing = db.getRecordset(Mandate, recordFilter={"name": _MANDATE_PWG["name"]}) existing = db.getRecordset(Mandate, recordFilter={"name": _MANDATE_PWG["name"]})
for m in existing: for m in existing:
mid = m.get("id") mid = m.get("id")
self._removeMandateData(db, mid, _MANDATE_PWG["label"], summary) self._removeMandateData(db, mid, _MANDATE_PWG["label"], summary)
db.recordDelete(Mandate, mid) db.recordDelete(Mandate, mid)
removedMandateIds.add(mid)
summary["removed"].append(f"Mandate {_MANDATE_PWG['label']} ({mid})") summary["removed"].append(f"Mandate {_MANDATE_PWG['label']} ({mid})")
logger.info(f"Removed mandate {_MANDATE_PWG['label']} ({mid})") logger.info(f"Removed mandate {_MANDATE_PWG['label']} ({mid})")
except Exception as e: except Exception as e:
summary["errors"].append(f"Remove mandate {_MANDATE_PWG['label']}: {e}") summary["errors"].append(f"Remove mandate {_MANDATE_PWG['label']}: {e}")
# SAFETY: NEVER delete the user record. The user may have connections,
# chats, workflows, files, and other data across multiple databases.
# Only remove the mandate memberships that THIS demo created.
try: try:
existing = db.getRecordset(UserInDB, recordFilter={"username": _USER["username"]}) existing = db.getRecordset(UserInDB, recordFilter={"username": _USER["username"]})
for u in existing: for u in existing:
uid = u.get("id") uid = u.get("id")
memberships = db.getRecordset(UserMandate, recordFilter={"userId": uid}) or [] memberships = db.getRecordset(UserMandate, recordFilter={"userId": uid}) or []
for mem in memberships: for mem in memberships:
if mem.get("mandateId") in removedMandateIds:
try: try:
db.recordDelete(UserMandate, mem.get("id")) db.recordDelete(UserMandate, mem.get("id"))
except Exception: except Exception:
pass pass
db.recordDelete(UserInDB, uid) summary["skipped"].append(
summary["removed"].append(f"User {_USER['username']} ({uid})") f"User {_USER['username']} ({uid}) preserved (only demo mandate memberships removed)"
logger.info(f"Removed user {_USER['username']} ({uid})") )
logger.info(f"Preserved user {_USER['username']} ({uid}) - removed demo mandate memberships only")
except Exception as e: except Exception as e:
summary["errors"].append(f"Remove user: {e}") summary["errors"].append(f"Remove user memberships: {e}")
return summary return summary

View file

@ -2,7 +2,7 @@
# All rights reserved. # All rights reserved.
"""Workspace feature data models — WorkspaceUserSettings.""" """Workspace feature data models — WorkspaceUserSettings."""
from typing import List, Optional from typing import Dict, List, Optional
from pydantic import Field from pydantic import Field
from modules.datamodels.datamodelBase import PowerOnModel from modules.datamodels.datamodelBase import PowerOnModel
from modules.shared.i18nRegistry import i18nModel from modules.shared.i18nRegistry import i18nModel
@ -52,7 +52,7 @@ class WorkspaceUserSettings(PowerOnModel):
description="Max agent rounds override (None = instance default)", description="Max agent rounds override (None = instance default)",
json_schema_extra={"label": "Max. Agenten-Runden", "frontend_type": "number", "frontend_readonly": False, "frontend_required": False}, json_schema_extra={"label": "Max. Agenten-Runden", "frontend_type": "number", "frontend_readonly": False, "frontend_required": False},
) )
requireNeutralization: bool = Field( requireNeutralization: Optional[bool] = Field(
default=False, default=False,
description="Default neutralization setting for this user", description="Default neutralization setting for this user",
json_schema_extra={"label": "Neutralisierung", "frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False}, json_schema_extra={"label": "Neutralisierung", "frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False},
@ -67,3 +67,8 @@ class WorkspaceUserSettings(PowerOnModel):
description="Allowed AI models (empty = all permitted)", description="Allowed AI models (empty = all permitted)",
json_schema_extra={"label": "Erlaubte Modelle", "frontend_type": "modelMultiSelect", "frontend_readonly": False, "frontend_required": False}, json_schema_extra={"label": "Erlaubte Modelle", "frontend_type": "modelMultiSelect", "frontend_readonly": False, "frontend_required": False},
) )
uiTreeExpansion: Dict[str, List[str]] = Field(
default_factory=dict,
description="Per-tab expanded tree-node ids for the UDB / FormGeneratorTree. Key = scope name (e.g. 'sources', 'filesOwn', 'filesShared').",
json_schema_extra={"label": "Tree-Expand-Zustand", "frontend_type": "json", "frontend_readonly": True, "frontend_required": False},
)

View file

@ -1281,52 +1281,101 @@ async def listWorkspaceDataSources(
try: try:
from modules.datamodels.datamodelDataSource import DataSource from modules.datamodels.datamodelDataSource import DataSource
from modules.interfaces.interfaceDbApp import getRootInterface from modules.interfaces.interfaceDbApp import getRootInterface
from modules.serviceCenter.services.serviceKnowledge._inheritFlags import buildEffectiveByConnection
rootIf = getRootInterface() rootIf = getRootInterface()
recordFilter: dict = {"featureInstanceId": instanceId} recordFilter: dict = {"featureInstanceId": instanceId}
if wsMandateId: if wsMandateId:
recordFilter["mandateId"] = wsMandateId recordFilter["mandateId"] = wsMandateId
dataSources = rootIf.db.getRecordset(DataSource, recordFilter=recordFilter) dataSources = rootIf.db.getRecordset(DataSource, recordFilter=recordFilter)
return JSONResponse({"dataSources": dataSources or []}) if not dataSources:
return JSONResponse({"dataSources": []})
# Group by connectionId and compute effective values in aggregate mode
byConnection: dict = {}
for ds in dataSources:
connId = ds.get("connectionId") or ""
byConnection.setdefault(connId, []).append(ds)
for connDs in byConnection.values():
effNeutralize = buildEffectiveByConnection(connDs, "neutralize", mode="aggregate")
effScope = buildEffectiveByConnection(connDs, "scope", mode="aggregate")
effRag = buildEffectiveByConnection(connDs, "ragIndexEnabled", mode="aggregate")
for ds in connDs:
dsId = ds.get("id", "")
ds["effectiveNeutralize"] = effNeutralize.get(dsId, False)
ds["effectiveScope"] = effScope.get(dsId, "personal")
ds["effectiveRagIndexEnabled"] = effRag.get(dsId, False)
return JSONResponse({"dataSources": dataSources})
except Exception: except Exception:
return JSONResponse({"dataSources": []}) return JSONResponse({"dataSources": []})
@router.get("/{instanceId}/connections") class _TreeChildrenRequest(BaseModel):
"""Request body for the generic tree children endpoint."""
parents: List[Optional[str]] = Field(
default_factory=list,
description="List of parent keys to fetch children for. Use null for top-level.",
)
@router.post("/{instanceId}/tree/children")
@limiter.limit("300/minute") @limiter.limit("300/minute")
async def listWorkspaceConnections( async def getTreeChildren(
request: Request, request: Request,
instanceId: str = Path(...), instanceId: str = Path(...),
body: _TreeChildrenRequest = Body(...),
context: RequestContext = Depends(getRequestContext), context: RequestContext = Depends(getRequestContext),
): ):
"""Return the user's active connections (UserConnections).""" """Generic UDB tree children resolver.
_mandateId, _ = _validateInstanceAccess(instanceId, context)
from modules.serviceCenter import getService The UI sends a list of parent keys (or null for top-level). The backend
from modules.serviceCenter.context import ServiceCenterContext returns children for each requested parent, with all effective flag
ctx = ServiceCenterContext( values pre-computed. The UI builds the visible tree from the resulting
user=context.user, flat per-parent map.
mandate_id=_mandateId or "", """
feature_instance_id=instanceId, _validateInstanceAccess(instanceId, context)
from modules.serviceCenter.services.serviceKnowledge._buildTree import getChildrenForParents
try:
nodesByParent = await getChildrenForParents(instanceId, body.parents, context)
except Exception as exc:
logger.exception("Tree children build failed: %s", exc)
raise HTTPException(status_code=500, detail=str(exc))
return JSONResponse({"nodesByParent": nodesByParent})
class _TreeAttributesRequest(BaseModel):
"""Request body for the attribute-refresh endpoint."""
keys: List[str] = Field(
default_factory=list,
description="List of node keys to fetch current attributes for.",
) )
chatService = getService("chat", ctx)
connections = chatService.getUserConnections()
items = [] @router.post("/{instanceId}/tree/attributes")
for c in connections or []: @limiter.limit("300/minute")
conn = c if isinstance(c, dict) else (c.model_dump() if hasattr(c, "model_dump") else {}) async def getTreeAttributes(
authority = conn.get("authority") request: Request,
if hasattr(authority, "value"): instanceId: str = Path(...),
authority = authority.value body: _TreeAttributesRequest = Body(...),
status = conn.get("status") context: RequestContext = Depends(getRequestContext),
if hasattr(status, "value"): ):
status = status.value """Return current effective attribute values (neutralize, scope,
items.append({ ragIndexEnabled) for a list of node keys. Used after a toggle action
"id": conn.get("id"), to refresh only the visible nodes without reloading tree structure."""
"authority": authority, _validateInstanceAccess(instanceId, context)
"externalUsername": conn.get("externalUsername"), from modules.serviceCenter.services.serviceKnowledge._buildTree import getAttributesForKeys
"externalEmail": conn.get("externalEmail"),
"status": status, if len(body.keys) > 500:
"knowledgeIngestionEnabled": bool(conn.get("knowledgeIngestionEnabled")), raise HTTPException(status_code=400, detail="Max 500 keys per request")
})
return JSONResponse({"connections": items}) try:
attrs = await getAttributesForKeys(instanceId, body.keys, context)
except Exception as exc:
logger.exception("Tree attributes failed: %s", exc)
raise HTTPException(status_code=500, detail=str(exc))
return JSONResponse({"attributes": attrs})
class CreateDataSourceRequest(BaseModel): class CreateDataSourceRequest(BaseModel):
@ -1391,303 +1440,6 @@ async def deleteWorkspaceDataSource(
# ---- Feature Connections & Feature Data Sources ---- # ---- Feature Connections & Feature Data Sources ----
@router.get("/{instanceId}/feature-connections")
@limiter.limit("120/minute")
async def listFeatureConnections(
request: Request,
instanceId: str = Path(...),
context: RequestContext = Depends(getRequestContext),
):
"""List feature instances the user has access to, scoped to the workspace mandate."""
wsMandateId, _ = _validateInstanceAccess(instanceId, context)
from modules.interfaces.interfaceDbApp import getRootInterface
from modules.security.rbacCatalog import getCatalogService
from modules.datamodels.datamodelUam import Mandate
rootIf = getRootInterface()
userId = str(context.user.id)
catalog = getCatalogService()
featureCodesWithData = catalog.getFeaturesWithDataObjects()
userMandates = rootIf.getUserMandates(userId)
if not userMandates:
return JSONResponse({"featureConnectionsByMandate": []})
allowedMandateIds = {um.mandateId for um in userMandates}
if wsMandateId and wsMandateId in allowedMandateIds:
allowedMandateIds = {wsMandateId}
mandateLabels: dict = {}
for um in userMandates:
if um.mandateId not in allowedMandateIds:
continue
try:
rows = rootIf.db.getRecordset(Mandate, recordFilter={"id": um.mandateId})
if rows:
m = rows[0]
mandateLabels[um.mandateId] = m.get("label") or m.get("name") or um.mandateId
except Exception:
mandateLabels[um.mandateId] = um.mandateId
byMandate: dict = {}
seenIds: set = set()
for um in userMandates:
if um.mandateId not in allowedMandateIds:
continue
allInstances = rootIf.getFeatureInstancesByMandate(um.mandateId)
for inst in allInstances:
if inst.id in seenIds:
continue
seenIds.add(inst.id)
if not inst.enabled:
continue
if inst.featureCode not in featureCodesWithData:
continue
featureAccess = rootIf.getFeatureAccess(userId, inst.id)
if not featureAccess or not featureAccess.enabled:
continue
featureDef = catalog.getFeatureDefinition(inst.featureCode) or {}
dataObjects = catalog.getDataObjects(inst.featureCode)
label = inst.label or inst.featureCode
mid = inst.mandateId
connItem = {
"featureInstanceId": inst.id,
"featureCode": inst.featureCode,
"mandateId": mid,
"label": label,
"icon": featureDef.get("icon", "mdi-database"),
"tableCount": len(dataObjects),
}
if mid not in byMandate:
byMandate[mid] = []
byMandate[mid].append(connItem)
def _sortKeyLabel(x: dict) -> str:
return (x.get("label") or "").lower()
groups = []
for mid in sorted(byMandate.keys(), key=lambda m: (mandateLabels.get(m, m) or "").lower()):
conns = sorted(byMandate[mid], key=_sortKeyLabel)
groups.append({
"mandateId": mid,
"mandateLabel": mandateLabels.get(mid, mid),
"featureConnections": conns,
})
return JSONResponse({"featureConnectionsByMandate": groups})
@router.get("/{instanceId}/feature-connections/{fiId}/tables")
@limiter.limit("120/minute")
async def listFeatureConnectionTables(
request: Request,
instanceId: str = Path(...),
fiId: str = Path(..., description="Feature instance ID"),
context: RequestContext = Depends(getRequestContext),
):
"""List data tables (DATA_OBJECTS) for a feature instance, filtered by RBAC."""
wsMandateId, _ = _validateInstanceAccess(instanceId, context)
from modules.interfaces.interfaceDbApp import getRootInterface
from modules.security.rbacCatalog import getCatalogService
rootIf = getRootInterface()
inst = rootIf.getFeatureInstance(fiId)
if not inst:
raise HTTPException(status_code=404, detail=routeApiMsg("Feature instance not found"))
mandateId = str(inst.mandateId) if inst.mandateId else None
if wsMandateId and mandateId and mandateId != wsMandateId:
raise HTTPException(status_code=403, detail=routeApiMsg("Feature instance does not belong to workspace mandate"))
catalog = getCatalogService()
try:
from modules.security.rbac import RbacClass
from modules.security.rootAccess import getRootDbAppConnector
dbApp = getRootDbAppConnector()
rbac = RbacClass(dbApp, dbApp=dbApp)
accessible = catalog.getAccessibleDataObjects(
featureCode=inst.featureCode,
rbacInstance=rbac,
user=context.user,
mandateId=mandateId or "",
featureInstanceId=fiId,
)
except Exception:
accessible = catalog.getDataObjects(inst.featureCode)
accessibleKeys = {obj.get("objectKey", "") for obj in accessible}
referencedGroups = set()
for obj in accessible:
meta = obj.get("meta", {})
if meta.get("wildcard") or meta.get("isGroup"):
continue
if meta.get("group"):
referencedGroups.add(meta["group"])
tables = []
for obj in catalog.getDataObjects(inst.featureCode):
meta = obj.get("meta", {})
if meta.get("wildcard"):
continue
objectKey = obj.get("objectKey", "")
if meta.get("isGroup"):
# Groups are metadata-only; include if at least one child is accessible
# (regardless of whether the group itself was RBAC-granted).
if objectKey not in referencedGroups:
continue
else:
if objectKey not in accessibleKeys:
continue
node = {
"objectKey": objectKey,
"tableName": meta.get("table", ""),
"label": resolveText(obj.get("label", "")),
"fields": meta.get("fields", []),
"isParent": bool(meta.get("isParent", False)),
"parentTable": meta.get("parentTable") or None,
"parentKey": meta.get("parentKey") or None,
"displayFields": meta.get("displayFields", []),
"isGroup": bool(meta.get("isGroup", False)),
"group": meta.get("group") or None,
}
tables.append(node)
return JSONResponse({"tables": tables})
@router.get("/{instanceId}/feature-connections/{fiId}/parent-objects/{tableName}")
@limiter.limit("120/minute")
async def listParentObjects(
request: Request,
instanceId: str = Path(...),
fiId: str = Path(..., description="Feature instance ID"),
tableName: str = Path(..., description="Parent table name from DATA_OBJECTS"),
parentKey: Optional[str] = Query(None, description="Optional FK column name to filter by ancestor record (nested parent rendering)"),
parentValue: Optional[str] = Query(None, description="Optional FK value matching parentKey to filter children of a specific ancestor record"),
context: RequestContext = Depends(getRequestContext),
):
"""List records from a parent table so the user can pick a specific record to scope data.
When parentKey + parentValue are provided, results are additionally filtered by that FK,
enabling nested record hierarchies (e.g. Sessions OF Context X).
"""
wsMandateId, _ = _validateInstanceAccess(instanceId, context)
from modules.interfaces.interfaceDbApp import getRootInterface
from modules.security.rbacCatalog import getCatalogService
rootIf = getRootInterface()
inst = rootIf.getFeatureInstance(fiId)
if not inst:
raise HTTPException(status_code=404, detail=routeApiMsg("Feature instance not found"))
featureCode = inst.featureCode
mandateId = str(inst.mandateId) if inst.mandateId else ""
if wsMandateId and mandateId and mandateId != wsMandateId:
raise HTTPException(status_code=403, detail=routeApiMsg("Feature instance does not belong to workspace mandate"))
catalog = getCatalogService()
parentObj = None
for obj in catalog.getDataObjects(featureCode):
meta = obj.get("meta", {})
if meta.get("table") == tableName and meta.get("isParent"):
parentObj = obj
break
if not parentObj:
raise HTTPException(status_code=400, detail=f"Table '{tableName}' is not a registered parent table")
displayFields = parentObj["meta"].get("displayFields", [])
selectCols = ', '.join(f'"{f}"' for f in (["id"] + displayFields)) if displayFields else "*"
from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG
featureDbName = f"poweron_{featureCode.lower()}"
featureDbConn = None
try:
featureDbConn = DatabaseConnector(
dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
dbDatabase=featureDbName,
dbUser=APP_CONFIG.get("DB_USER"),
dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"),
dbPort=int(APP_CONFIG.get("DB_PORT", 5432)),
userId=str(context.user.id),
)
conn = featureDbConn.connection
with conn.cursor() as cur:
cur.execute(
"SELECT column_name FROM information_schema.columns "
"WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
"AND column_name IN ('featureInstanceId', 'instanceId')",
[tableName],
)
instanceCols = [row["column_name"] for row in cur.fetchall()]
instanceCol = "featureInstanceId" if "featureInstanceId" in instanceCols else "instanceId"
cur.execute(
"SELECT column_name FROM information_schema.columns "
"WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
"AND column_name = 'userId'",
[tableName],
)
hasUserId = cur.rowcount > 0
sql = (
f'SELECT {selectCols} FROM "{tableName}" '
f'WHERE "{instanceCol}" = %s'
)
params = [fiId]
if mandateId:
sql += ' AND "mandateId" = %s'
params.append(mandateId)
if hasUserId:
sql += ' AND "userId" = %s'
params.append(str(context.user.id))
if parentKey and parentValue:
cur.execute(
"SELECT 1 FROM information_schema.columns "
"WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
"AND column_name = %s",
[tableName, parentKey],
)
if cur.rowcount > 0:
sql += f' AND "{parentKey}" = %s'
params.append(parentValue)
else:
logger.warning(
f"listParentObjects({tableName}): ignoring parentKey '{parentKey}' (column does not exist)"
)
sql += ' ORDER BY "id" DESC LIMIT 100'
cur.execute(sql, params)
rows = []
for row in cur.fetchall():
r = dict(row)
for k, v in r.items():
if hasattr(v, "isoformat"):
r[k] = v.isoformat()
elif isinstance(v, (bytes, bytearray)):
r[k] = f"<binary {len(v)} bytes>"
displayParts = [str(r.get(f, "")) for f in displayFields if r.get(f) is not None]
rows.append({
"id": r.get("id", ""),
"displayLabel": " | ".join(displayParts) if displayParts else r.get("id", ""),
"fields": {f: r.get(f) for f in displayFields},
})
except Exception as e:
logger.error(f"listParentObjects({tableName}) failed: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to list parent objects: {e}")
finally:
if featureDbConn:
try:
featureDbConn.close()
except Exception:
pass
return JSONResponse({"parentObjects": rows})
class CreateFeatureDataSourceRequest(BaseModel): class CreateFeatureDataSourceRequest(BaseModel):
"""Request body for adding a feature table as data source.""" """Request body for adding a feature table as data source."""
featureInstanceId: str = Field(description="Feature instance ID") featureInstanceId: str = Field(description="Feature instance ID")
@ -1706,16 +1458,35 @@ async def createFeatureDataSource(
body: CreateFeatureDataSourceRequest = Body(...), body: CreateFeatureDataSourceRequest = Body(...),
context: RequestContext = Depends(getRequestContext), context: RequestContext = Depends(getRequestContext),
): ):
"""Create a FeatureDataSource for this workspace instance.""" """Create a FeatureDataSource for this workspace instance.
The FDS lives under the WORKSPACE's mandate (not the feature's): that
matches how the tree (`allFds = recordset where workspaceInstanceId =
instanceId`) and the PATCH endpoints scope these records by workspace,
not by feature mandate. The user can legitimately reference a feature
from another mandate they have access to (via the UDB mandate-group
nodes), and a hard cross-mandate block here would silently 403 those
toggles. Access to the referenced feature is verified by the user's
`FeatureAccess` and the existing tree-children RBAC, which run before
the user can ever click on this node.
"""
wsMandateId, _ = _validateInstanceAccess(instanceId, context) wsMandateId, _ = _validateInstanceAccess(instanceId, context)
from modules.interfaces.interfaceDbApp import getRootInterface from modules.interfaces.interfaceDbApp import getRootInterface
from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource
rootIf = getRootInterface() rootIf = getRootInterface()
inst = rootIf.getFeatureInstance(body.featureInstanceId) if not rootIf.getFeatureAccess(str(context.user.id), body.featureInstanceId):
mandateId = str(inst.mandateId) if inst else (str(context.mandateId) if context.mandateId else "") raise HTTPException(status_code=403, detail=routeApiMsg("Access denied to this feature instance"))
if wsMandateId and mandateId and mandateId != wsMandateId:
raise HTTPException(status_code=403, detail=routeApiMsg("Feature instance does not belong to workspace mandate")) existing = rootIf.db.getRecordset(FeatureDataSource, recordFilter={
"workspaceInstanceId": instanceId,
"featureInstanceId": body.featureInstanceId,
"tableName": body.tableName,
}) or []
targetFilter = body.recordFilter or None
for rec in existing:
if (rec.get("recordFilter") or None) == targetFilter:
return JSONResponse(rec)
fds = FeatureDataSource( fds = FeatureDataSource(
featureInstanceId=body.featureInstanceId, featureInstanceId=body.featureInstanceId,
@ -1723,7 +1494,7 @@ async def createFeatureDataSource(
tableName=body.tableName, tableName=body.tableName,
objectKey=body.objectKey, objectKey=body.objectKey,
label=body.label, label=body.label,
mandateId=mandateId, mandateId=wsMandateId or "",
userId=str(context.user.id), userId=str(context.user.id),
workspaceInstanceId=instanceId, workspaceInstanceId=instanceId,
recordFilter=body.recordFilter, recordFilter=body.recordFilter,
@ -1743,13 +1514,26 @@ async def listFeatureDataSources(
wsMandateId, _ = _validateInstanceAccess(instanceId, context) wsMandateId, _ = _validateInstanceAccess(instanceId, context)
from modules.interfaces.interfaceDbApp import getRootInterface from modules.interfaces.interfaceDbApp import getRootInterface
from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource
from modules.serviceCenter.services.serviceKnowledge._inheritFlags import buildEffectiveByWorkspaceFds
rootIf = getRootInterface() rootIf = getRootInterface()
recordFilter: dict = {"workspaceInstanceId": instanceId} recordFilter: dict = {"workspaceInstanceId": instanceId}
if wsMandateId: if wsMandateId:
recordFilter["mandateId"] = wsMandateId recordFilter["mandateId"] = wsMandateId
records = rootIf.db.getRecordset(FeatureDataSource, recordFilter=recordFilter) records = rootIf.db.getRecordset(FeatureDataSource, recordFilter=recordFilter)
return JSONResponse({"featureDataSources": records or []}) if not records:
return JSONResponse({"featureDataSources": []})
effNeutralize = buildEffectiveByWorkspaceFds(records, "neutralize", mode="aggregate")
effScope = buildEffectiveByWorkspaceFds(records, "scope", mode="aggregate")
effRag = buildEffectiveByWorkspaceFds(records, "ragIndexEnabled", mode="aggregate")
for fds in records:
fdsId = fds.get("id", "")
fds["effectiveNeutralize"] = effNeutralize.get(fdsId, False)
fds["effectiveScope"] = effScope.get(fdsId, "personal")
fds["effectiveRagIndexEnabled"] = effRag.get(fdsId, False)
return JSONResponse({"featureDataSources": records})
@router.delete("/{instanceId}/feature-datasources/{featureDataSourceId}") @router.delete("/{instanceId}/feature-datasources/{featureDataSourceId}")
@ -1770,112 +1554,6 @@ async def deleteFeatureDataSource(
return JSONResponse({"success": True}) return JSONResponse({"success": True})
@router.get("/{instanceId}/connections/{connectionId}/services")
@limiter.limit("120/minute")
async def listConnectionServices(
request: Request,
instanceId: str = Path(...),
connectionId: str = Path(...),
context: RequestContext = Depends(getRequestContext),
):
"""Return the available services for a specific UserConnection."""
_mandateId, _ = _validateInstanceAccess(instanceId, context)
try:
from modules.connectors.connectorResolver import ConnectorResolver
from modules.serviceCenter import getService as getSvc
from modules.serviceCenter.context import ServiceCenterContext
ctx = ServiceCenterContext(
user=context.user,
mandate_id=_mandateId or "",
feature_instance_id=instanceId,
)
chatService = getSvc("chat", ctx)
securityService = getSvc("security", ctx)
dbInterface = _buildResolverDbInterface(chatService)
resolver = ConnectorResolver(securityService, dbInterface)
provider = await resolver.resolve(connectionId)
services = provider.getAvailableServices()
_serviceLabels = {
"sharepoint": "SharePoint",
"outlook": "Outlook",
"teams": "Teams",
"onedrive": "OneDrive",
"drive": "Google Drive",
"gmail": "Gmail",
"files": "Files (FTP)",
"kdrive": "kDrive",
"calendar": "Calendar",
"contact": "Contacts",
}
_serviceIcons = {
"sharepoint": "sharepoint",
"outlook": "mail",
"teams": "chat",
"onedrive": "cloud",
"drive": "cloud",
"gmail": "mail",
"files": "folder",
"kdrive": "cloud",
"calendar": "calendar",
"contact": "contact",
}
items = [
{
"service": s,
"label": _serviceLabels.get(s, s),
"icon": _serviceIcons.get(s, "folder"),
}
for s in services
]
return JSONResponse({"services": items})
except Exception as e:
logger.error(f"Error listing services for connection {connectionId}: {e}")
return JSONResponse({"services": [], "error": str(e)}, status_code=400)
@router.get("/{instanceId}/connections/{connectionId}/browse")
@limiter.limit("300/minute")
async def browseConnectionService(
request: Request,
instanceId: str = Path(...),
connectionId: str = Path(...),
service: str = Query(..., description="Service name (e.g. sharepoint, onedrive, outlook)"),
path: str = Query("/", description="Path within the service to browse"),
context: RequestContext = Depends(getRequestContext),
):
"""Browse folders/items within a connection's service at a given path."""
_mandateId, _ = _validateInstanceAccess(instanceId, context)
try:
from modules.connectors.connectorResolver import ConnectorResolver
from modules.serviceCenter import getService as getSvc
from modules.serviceCenter.context import ServiceCenterContext
ctx = ServiceCenterContext(
user=context.user,
mandate_id=_mandateId or "",
feature_instance_id=instanceId,
)
chatService = getSvc("chat", ctx)
securityService = getSvc("security", ctx)
dbInterface = _buildResolverDbInterface(chatService)
resolver = ConnectorResolver(securityService, dbInterface)
adapter = await resolver.resolveService(connectionId, service)
entries = await adapter.browse(path, filter=None)
items = []
for entry in (entries or []):
items.append({
"name": entry.name,
"path": entry.path,
"isFolder": entry.isFolder,
"size": entry.size,
"mimeType": entry.mimeType,
"metadata": entry.metadata if hasattr(entry, "metadata") else {},
})
return JSONResponse({"items": items, "path": path, "service": service})
except Exception as e:
logger.error(f"Error browsing {service} for connection {connectionId} at '{path}': {e}")
return JSONResponse({"items": [], "error": str(e)}, status_code=400)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Voice endpoints # Voice endpoints
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -2191,6 +1869,71 @@ async def putWorkspaceUserSettings(
}) })
# =========================================================================
# Per-user UI state: tree expand/collapse (UDB + FilesTab)
# Persisted on WorkspaceUserSettings.uiTreeExpansion as a {scope: [ids]} map.
# Each FE tab uses its own scope key so collapse-state for one tab doesn't
# bleed into another.
@router.get("/{instanceId}/ui-tree-expansion/{scope}")
@limiter.limit("300/minute")
async def getUiTreeExpansion(
request: Request,
instanceId: str = Path(...),
scope: str = Path(..., description="UI scope key, e.g. 'sources', 'filesOwn', 'filesShared'"),
context: RequestContext = Depends(getRequestContext),
):
"""Return the expanded tree-node ids for the current user + scope.
Returns `null` when the user has never persisted a state for this scope
(lets the FE fall back to backend `defaultExpanded` hints). Returns `[]`
when the user actively collapsed everything.
"""
_validateInstanceAccess(instanceId, context)
wsInterface = _getWorkspaceInterface(context, instanceId)
settings = wsInterface.getWorkspaceUserSettings(str(context.user.id))
expansion = (settings.uiTreeExpansion if settings else {}) or {}
if scope not in expansion:
return JSONResponse({"expandedNodes": None})
return JSONResponse({"expandedNodes": list(expansion.get(scope) or [])})
@router.put("/{instanceId}/ui-tree-expansion/{scope}")
@limiter.limit("300/minute")
async def putUiTreeExpansion(
request: Request,
instanceId: str = Path(...),
scope: str = Path(...),
body: dict = Body(...),
context: RequestContext = Depends(getRequestContext),
):
"""Replace the expanded-node list for one scope.
Body: `{"expandedNodes": List[str]}`. Empty list = explicit collapse-all.
"""
_validateInstanceAccess(instanceId, context)
wsInterface = _getWorkspaceInterface(context, instanceId)
userId = str(context.user.id)
nodes = body.get("expandedNodes")
if not isinstance(nodes, list):
raise HTTPException(status_code=400, detail=routeApiMsg("expandedNodes must be a list"))
cleaned = [str(n) for n in nodes if isinstance(n, (str, int))]
existing = wsInterface.getWorkspaceUserSettings(userId)
existingMap: Dict[str, List[str]] = (existing.uiTreeExpansion if existing else {}) or {}
existingMap = dict(existingMap)
existingMap[scope] = cleaned
data = {
"userId": userId,
"mandateId": str(context.mandateId) if context.mandateId else "",
"featureInstanceId": instanceId,
"uiTreeExpansion": existingMap,
}
wsInterface.saveWorkspaceUserSettings(data)
return JSONResponse({"expandedNodes": cleaned})
# ========================================================================= # =========================================================================
# RAG / Knowledge — anonymised instance statistics (presentation / KPIs) # RAG / Knowledge — anonymised instance statistics (presentation / KPIs)

View file

@ -68,9 +68,19 @@ def removeDemoConfig(
request: Request, request: Request,
currentUser: User = Depends(requirePlatformAdmin), currentUser: User = Depends(requirePlatformAdmin),
) -> dict: ) -> dict:
"""Remove all data created by a demo configuration.""" """Remove all data created by a demo configuration.
Requires X-Confirm-Destructive: true header as safety guard.
"""
from modules.demoConfigs import getDemoConfigByCode from modules.demoConfigs import getDemoConfigByCode
confirmHeader = request.headers.get("X-Confirm-Destructive", "").lower()
if confirmHeader != "true":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Destructive operation requires header X-Confirm-Destructive: true",
)
config = getDemoConfigByCode(code) config = getDemoConfigByCode(code)
if not config: if not config:
raise HTTPException( raise HTTPException(
@ -79,7 +89,7 @@ def removeDemoConfig(
) )
db = getRootDbAppConnector() db = getRootDbAppConnector()
logger.info(f"Removing demo config '{code}' (user: {currentUser.username})") logger.info(f"Removing demo config '{code}' (user: {currentUser.username}, confirmed)")
summary = config.remove(db) summary = config.remove(db)
logger.info(f"Demo config '{code}' removed: {summary}") logger.info(f"Demo config '{code}' removed: {summary}")

View file

@ -778,7 +778,12 @@ async def _updateKnowledgeConsent(
cancelled = cancelJobsByConnection(connectionId) cancelled = cancelJobsByConnection(connectionId)
else: else:
from modules.datamodels.datamodelDataSource import DataSource from modules.datamodels.datamodelDataSource import DataSource
dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId, "ragIndexEnabled": True}) from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag
allConnDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
dataSources = [
ds for ds in (allConnDs or [])
if getEffectiveFlag(ds, "ragIndexEnabled", allConnDs, mode="walk") is True
]
if dataSources: if dataSources:
from modules.serviceCenter.services.serviceBackgroundJobs import startJob from modules.serviceCenter.services.serviceBackgroundJobs import startJob
authority = connection.authority.value if hasattr(connection.authority, "value") else str(connection.authority or "") authority = connection.authority.value if hasattr(connection.authority, "value") else str(connection.authority or "")

View file

@ -211,7 +211,7 @@ async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user, *, man
from modules.serviceCenter.services.serviceKnowledge.mainServiceKnowledge import IngestionJob from modules.serviceCenter.services.serviceKnowledge.mainServiceKnowledge import IngestionJob
await knowledgeService.requestIngestion( handle = await knowledgeService.requestIngestion(
IngestionJob( IngestionJob(
sourceKind="file", sourceKind="file",
sourceId=fileId, sourceId=fileId,
@ -229,6 +229,9 @@ async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user, *, man
# Re-acquire interface after await to avoid stale user context from the singleton # Re-acquire interface after await to avoid stale user context from the singleton
mgmtInterface = interfaceDbManagement.getInterface(user) mgmtInterface = interfaceDbManagement.getInterface(user)
mgmtInterface.updateFile(fileId, {"status": "active"}) mgmtInterface.updateFile(fileId, {"status": "active"})
if handle.status == "failed":
logger.warning(f"Auto-index ingestion failed for file {fileId} ({fileName}): {handle.error}")
else:
logger.info(f"Auto-index complete for file {fileId} ({fileName})") logger.info(f"Auto-index complete for file {fileId} ({fileName})")
except Exception as e: except Exception as e:
@ -256,6 +259,24 @@ router = APIRouter(
) )
def _getInterfaceForOwnedItem(currentUser: User, context, itemId: str, modelClass) -> Any:
"""Create a management interface scoped to the item's own context.
Looks up the item by ID (unscoped) to resolve its mandateId/featureInstanceId,
then creates the interface with THAT context. This ensures toggle operations
work regardless of which page the user is on."""
unscoped = interfaceDbManagement.getInterface(currentUser)
record = unscoped.db.getRecord(modelClass, itemId)
if not record:
raise interfaceDbManagement.FileNotFoundError(f"Item {itemId} not found")
itemMandateId = record.get("mandateId") if isinstance(record, dict) else getattr(record, "mandateId", None)
itemInstanceId = record.get("featureInstanceId") if isinstance(record, dict) else getattr(record, "featureInstanceId", None)
return interfaceDbManagement.getInterface(
currentUser,
mandateId=str(itemMandateId) if itemMandateId else None,
featureInstanceId=str(itemInstanceId) if itemInstanceId else None,
)
@router.get("/folders/tree") @router.get("/folders/tree")
@limiter.limit("120/minute") @limiter.limit("120/minute")
def get_folder_tree( def get_folder_tree(
@ -272,10 +293,12 @@ def get_folder_tree(
) )
o = (owner or "me").strip().lower() o = (owner or "me").strip().lower()
if o == "me": if o == "me":
return managementInterface.getOwnFolderTree() folders = managementInterface.getOwnFolderTree()
if o == "shared": elif o == "shared":
return managementInterface.getSharedFolderTree() folders = managementInterface.getSharedFolderTree()
else:
raise HTTPException(status_code=400, detail="owner must be 'me' or 'shared'") raise HTTPException(status_code=400, detail="owner must be 'me' or 'shared'")
return folders
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@ -283,6 +306,185 @@ def get_folder_tree(
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post("/attributes")
@limiter.limit("120/minute")
def getAttributesForIds(
request: Request,
body: Dict[str, Any] = Body(...),
currentUser: User = Depends(getCurrentUser),
context: RequestContext = Depends(getRequestContext),
):
"""Return current attribute values (neutralize, scope, ragIndexEnabled) for
a list of node IDs. For folder IDs, computes 'mixed' by checking direct
children. The frontend sends this after every toggle to refresh visible
nodes without reloading the tree structure."""
ids = body.get("ids", [])
if not isinstance(ids, list) or len(ids) == 0:
return {}
if len(ids) > 500:
raise HTTPException(status_code=400, detail="Max 500 IDs per request")
try:
managementInterface = interfaceDbManagement.getInterface(
currentUser,
mandateId=str(context.mandateId) if context.mandateId else None,
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
)
db = managementInterface.db
userId = str(currentUser.id)
allFolders = db.getRecordset(FileFolder, recordFilter={"sysCreatedBy": userId}) or []
allFiles = db.getRecordset(FileItem, recordFilter={"sysCreatedBy": userId}) or []
folderById = {f["id"]: f for f in allFolders}
fileById = {f["id"]: f for f in allFiles}
logger.info(
"getAttributesForIds: %d ids requested, %d folders found, %d files found",
len(ids), len(allFolders), len(allFiles),
)
result: Dict[str, Dict[str, Any]] = {}
for nodeId in ids:
if nodeId.startswith("__filesRoot:"):
attrs = _computeSyntheticRootAttrs(allFolders, allFiles)
result[nodeId] = attrs
elif nodeId in folderById:
folder = folderById[nodeId]
attrs = _computeFolderAttrs(folder, allFolders, allFiles)
result[nodeId] = attrs
elif nodeId in fileById:
f = fileById[nodeId]
result[nodeId] = {
"neutralize": bool(f.get("neutralize", False)),
"scope": f.get("scope", "personal"),
}
else:
logger.debug("getAttributesForIds: unknown id=%s", nodeId)
logger.info("getAttributesForIds: returning %d entries", len(result))
return result
except HTTPException:
raise
except Exception as e:
logger.error(f"getAttributesForIds error: {e}")
raise HTTPException(status_code=500, detail=str(e))
def _computeFolderAttrs(
folder: Dict[str, Any],
allFolders: List[Dict[str, Any]],
allFiles: List[Dict[str, Any]],
) -> Dict[str, Any]:
"""Compute attributes for a folder. Recursively checks the entire subtree:
if ANY descendant at any depth has a different value, the folder shows 'mixed'.
This propagates up through all ancestor levels."""
fid = folder["id"]
neutralizeResult = _effectiveNeutralize(fid, allFolders, allFiles)
scopeResult = _effectiveScope(fid, allFolders, allFiles)
return {"neutralize": neutralizeResult, "scope": scopeResult}
def _effectiveNeutralize(
folderId: str,
allFolders: List[Dict[str, Any]],
allFiles: List[Dict[str, Any]],
) -> Any:
"""Recursively compute effective neutralize for a folder.
Returns 'mixed' if any descendants diverge, otherwise the folder's own value."""
childFolders = [f for f in allFolders if f.get("parentId") == folderId]
childFiles = [f for f in allFiles if f.get("folderId") == folderId]
if not childFolders and not childFiles:
folder = next((f for f in allFolders if f["id"] == folderId), None)
return bool(folder.get("neutralize", False)) if folder else False
childVals = set()
for cf in childFolders:
effective = _effectiveNeutralize(cf["id"], allFolders, allFiles)
if effective == "mixed":
return "mixed"
childVals.add(effective)
for cf in childFiles:
childVals.add(bool(cf.get("neutralize", False)))
if len(childVals) > 1:
return "mixed"
if not childVals:
folder = next((f for f in allFolders if f["id"] == folderId), None)
return bool(folder.get("neutralize", False)) if folder else False
return childVals.pop()
def _effectiveScope(
folderId: str,
allFolders: List[Dict[str, Any]],
allFiles: List[Dict[str, Any]],
) -> Any:
"""Recursively compute effective scope for a folder.
Returns 'mixed' if any descendants diverge, otherwise the folder's own value."""
childFolders = [f for f in allFolders if f.get("parentId") == folderId]
childFiles = [f for f in allFiles if f.get("folderId") == folderId]
if not childFolders and not childFiles:
folder = next((f for f in allFolders if f["id"] == folderId), None)
return folder.get("scope", "personal") if folder else "personal"
childVals = set()
for cf in childFolders:
effective = _effectiveScope(cf["id"], allFolders, allFiles)
if effective == "mixed":
return "mixed"
childVals.add(effective)
for cf in childFiles:
childVals.add(cf.get("scope", "personal"))
if len(childVals) > 1:
return "mixed"
if not childVals:
folder = next((f for f in allFolders if f["id"] == folderId), None)
return folder.get("scope", "personal") if folder else "personal"
return childVals.pop()
def _computeSyntheticRootAttrs(
allFolders: List[Dict[str, Any]],
allFiles: List[Dict[str, Any]],
) -> Dict[str, Any]:
"""Compute attributes for the synthetic root by recursively checking the
entire tree. If ANY item at any depth diverges, root shows 'mixed'."""
topFolders = [f for f in allFolders if not f.get("parentId")]
topFiles = [f for f in allFiles if not f.get("folderId")]
neutralizeVals = set()
scopeVals = set()
for cf in topFolders:
nEff = _effectiveNeutralize(cf["id"], allFolders, allFiles)
if nEff == "mixed":
neutralizeVals.add(True)
neutralizeVals.add(False)
else:
neutralizeVals.add(nEff)
sEff = _effectiveScope(cf["id"], allFolders, allFiles)
if sEff == "mixed":
scopeVals.add("__mixed_a__")
scopeVals.add("__mixed_b__")
else:
scopeVals.add(sEff)
for cf in topFiles:
neutralizeVals.add(bool(cf.get("neutralize", False)))
scopeVals.add(cf.get("scope", "personal"))
if not neutralizeVals and not scopeVals:
return {"neutralize": False, "scope": "personal"}
return {
"neutralize": "mixed" if len(neutralizeVals) > 1 else (neutralizeVals.pop() if neutralizeVals else False),
"scope": "mixed" if len(scopeVals) > 1 else (scopeVals.pop() if scopeVals else "personal"),
}
@router.post("/folders", status_code=status.HTTP_201_CREATED) @router.post("/folders", status_code=status.HTTP_201_CREATED)
@limiter.limit("30/minute") @limiter.limit("30/minute")
def create_folder( def create_folder(
@ -353,7 +555,12 @@ def move_folder(
context: RequestContext = Depends(getRequestContext), context: RequestContext = Depends(getRequestContext),
): ):
try: try:
# FE may send `parentId` or `targetParentId`. Accept both so the
# FormGeneratorTree generic `provider.moveNodes(targetParentId)` API
# remains consistent with the file-move (PUT /api/files/{id}) shape.
newParentId = body.get("parentId") newParentId = body.get("parentId")
if newParentId is None:
newParentId = body.get("targetParentId")
managementInterface = interfaceDbManagement.getInterface( managementInterface = interfaceDbManagement.getInterface(
currentUser, currentUser,
mandateId=str(context.mandateId) if context.mandateId else None, mandateId=str(context.mandateId) if context.mandateId else None,
@ -414,11 +621,7 @@ def patch_folder_scope(
if not scope: if not scope:
raise HTTPException(status_code=400, detail="scope is required") raise HTTPException(status_code=400, detail="scope is required")
cascadeToFiles = body.get("cascadeChildren", body.get("cascadeToFiles", False)) cascadeToFiles = body.get("cascadeChildren", body.get("cascadeToFiles", False))
managementInterface = interfaceDbManagement.getInterface( managementInterface = _getInterfaceForOwnedItem(currentUser, context, folderId, FileFolder)
currentUser,
mandateId=str(context.mandateId) if context.mandateId else None,
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
)
return managementInterface.patchFolderScope(folderId, scope, cascadeToFiles) return managementInterface.patchFolderScope(folderId, scope, cascadeToFiles)
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
@ -446,11 +649,7 @@ def patch_folder_neutralize(
neutralize = body.get("neutralize") neutralize = body.get("neutralize")
if neutralize is None: if neutralize is None:
raise HTTPException(status_code=400, detail="neutralize is required") raise HTTPException(status_code=400, detail="neutralize is required")
managementInterface = interfaceDbManagement.getInterface( managementInterface = _getInterfaceForOwnedItem(currentUser, context, folderId, FileFolder)
currentUser,
mandateId=str(context.mandateId) if context.mandateId else None,
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
)
return managementInterface.patchFolderNeutralize(folderId, bool(neutralize)) return managementInterface.patchFolderNeutralize(folderId, bool(neutralize))
except PermissionError as e: except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e)) raise HTTPException(status_code=403, detail=str(e))
@ -1031,11 +1230,7 @@ def updateFileScope(
if scope == "global" and not context.isSysAdmin: if scope == "global" and not context.isSysAdmin:
raise HTTPException(status_code=403, detail=routeApiMsg("Only sysadmins can set global scope")) raise HTTPException(status_code=403, detail=routeApiMsg("Only sysadmins can set global scope"))
managementInterface = interfaceDbManagement.getInterface( managementInterface = _getInterfaceForOwnedItem(context.user, context, fileId, FileItem)
context.user,
mandateId=str(context.mandateId) if context.mandateId else None,
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
)
managementInterface.updateFile(fileId, {"scope": scope}) managementInterface.updateFile(fileId, {"scope": scope})
@ -1093,11 +1288,7 @@ def updateFileNeutralize(
fails the file simply has no index no un-neutralized data can leak. fails the file simply has no index no un-neutralized data can leak.
""" """
try: try:
managementInterface = interfaceDbManagement.getInterface( managementInterface = _getInterfaceForOwnedItem(context.user, context, fileId, FileItem)
context.user,
mandateId=str(context.mandateId) if context.mandateId else None,
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
)
managementInterface.updateFile(fileId, {"neutralize": neutralize}) managementInterface.updateFile(fileId, {"neutralize": neutralize})
@ -1212,7 +1403,8 @@ def update_file(
request: Request, request: Request,
fileId: str = Path(..., description="ID of the file to update"), fileId: str = Path(..., description="ID of the file to update"),
file_info: Dict[str, Any] = Body(...), file_info: Dict[str, Any] = Body(...),
currentUser: User = Depends(getCurrentUser) currentUser: User = Depends(getCurrentUser),
context: RequestContext = Depends(getRequestContext),
) -> FileItem: ) -> FileItem:
"""Update file info""" """Update file info"""
try: try:
@ -1221,7 +1413,11 @@ def update_file(
if not safeData: if not safeData:
raise HTTPException(status_code=400, detail=routeApiMsg("No editable fields provided")) raise HTTPException(status_code=400, detail=routeApiMsg("No editable fields provided"))
managementInterface = interfaceDbManagement.getInterface(currentUser) managementInterface = interfaceDbManagement.getInterface(
currentUser,
mandateId=str(context.mandateId) if context.mandateId else None,
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
)
file = managementInterface.getFile(fileId) file = managementInterface.getFile(fileId)
if not file: if not file:
@ -1267,10 +1463,15 @@ def update_file(
def delete_file( def delete_file(
request: Request, request: Request,
fileId: str = Path(..., description="ID of the file to delete"), fileId: str = Path(..., description="ID of the file to delete"),
currentUser: User = Depends(getCurrentUser) currentUser: User = Depends(getCurrentUser),
context: RequestContext = Depends(getRequestContext),
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Delete a file""" """Delete a file"""
managementInterface = interfaceDbManagement.getInterface(currentUser) managementInterface = interfaceDbManagement.getInterface(
currentUser,
mandateId=str(context.mandateId) if context.mandateId else None,
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None,
)
# Check if the file exists # Check if the file exists
existingFile = managementInterface.getFile(fileId) existingFile = managementInterface.getFile(fileId)

View file

@ -43,6 +43,49 @@ def _ensureConnectionKnowledgeFlag(rootIf, connectionId: str) -> None:
except Exception as e: except Exception as e:
logger.warning("Could not auto-enable knowledgeIngestionEnabled for connection %s: %s", connectionId, e) logger.warning("Could not auto-enable knowledgeIngestionEnabled for connection %s: %s", connectionId, e)
def _computeOwnEffective(rootIf, rec, model, sourceId: str, flag: str) -> Any:
"""Re-load the record after modification and compute its aggregate effective value."""
from modules.serviceCenter.services.serviceKnowledge._inheritFlags import (
getEffectiveFlag, getEffectiveFlagFds,
)
freshRec = rootIf.db.getRecord(model, sourceId)
if not freshRec:
return None
if model is DataSource:
connectionId = freshRec.get("connectionId", "")
allDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
return getEffectiveFlag(freshRec, flag, allDs, mode="aggregate")
else:
wsId = freshRec.get("workspaceInstanceId", "")
allFds = rootIf.db.getRecordset(FeatureDataSource, recordFilter={"workspaceInstanceId": wsId})
return getEffectiveFlagFds(freshRec, flag, allFds, mode="aggregate")
def _computeAncestorEffectives(rootIf, rec, model, flag: str) -> List[Dict[str, Any]]:
"""Compute the aggregate effective value for all ancestors of `rec`."""
from modules.serviceCenter.services.serviceKnowledge._inheritFlags import (
collectAncestorChain, collectAncestorChainFds,
getEffectiveFlag, getEffectiveFlagFds,
)
effectiveKey = f"effective{flag[0].upper()}{flag[1:]}"
if model is DataSource:
connectionId = rec.get("connectionId", "")
allDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
ancestors = collectAncestorChain(rec, allDs)
return [
{"id": a.get("id") or getattr(a, "id", ""), effectiveKey: getEffectiveFlag(a, flag, allDs, mode="aggregate")}
for a in ancestors
]
else:
wsId = rec.get("workspaceInstanceId", "")
allFds = rootIf.db.getRecordset(FeatureDataSource, recordFilter={"workspaceInstanceId": wsId})
ancestors = collectAncestorChainFds(rec, allFds)
return [
{"id": a.get("id") or getattr(a, "id", ""), effectiveKey: getEffectiveFlagFds(a, flag, allFds, mode="aggregate")}
for a in ancestors
]
router = APIRouter( router = APIRouter(
prefix="/api/datasources", prefix="/api/datasources",
tags=["Data Sources"], tags=["Data Sources"],
@ -91,26 +134,41 @@ def _updateDataSourceScope(
try: try:
from modules.interfaces.interfaceDbApp import getRootInterface from modules.interfaces.interfaceDbApp import getRootInterface
from modules.serviceCenter.services.serviceKnowledge._inheritFlags import ( from modules.serviceCenter.services.serviceKnowledge._inheritFlags import (
cascadeResetDescendants, cascadeResetDescendants, cascadeResetDescendantsFds,
cascadeResetDescendantsFds, getEffectiveFlag, getEffectiveFlagFds,
collectAncestorChain, collectAncestorChainFds,
) )
rootIf = getRootInterface() rootIf = getRootInterface()
rec, model = _findSourceRecord(rootIf.db, sourceId) rec, model = _findSourceRecord(rootIf.db, sourceId)
if not rec: if not rec:
raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found") raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found")
rootIf.db.recordModify(model, sourceId, {"scope": scope}) # 1. Cascade reset descendants bottom-up (before modifying master)
cascaded = 0 resetIds: List[str] = []
if scope is not None: if scope is not None:
if model is DataSource: if model is DataSource:
cascaded = cascadeResetDescendants(rootIf, rec, "scope") resetIds = cascadeResetDescendants(rootIf, rec, "scope")
else: else:
cascaded = cascadeResetDescendantsFds(rootIf, rec, "scope") resetIds = cascadeResetDescendantsFds(rootIf, rec, "scope")
# 2. Set master value last (crash-safe)
rootIf.db.recordModify(model, sourceId, {"scope": scope})
# 3. Compute effective + ancestor chain for response
updatedAncestors = _computeAncestorEffectives(rootIf, rec, model, "scope")
effectiveScope = _computeOwnEffective(rootIf, rec, model, sourceId, "scope")
logger.info( logger.info(
"Updated scope=%s for %s %s (cascade-reset %d descendants)", "Updated scope=%s for %s %s (cascade-reset %d descendants)",
scope, model.__name__, sourceId, cascaded, scope, model.__name__, sourceId, len(resetIds),
) )
return {"sourceId": sourceId, "scope": scope, "updated": True, "cascadedDescendants": cascaded} return {
"sourceId": sourceId,
"scope": scope,
"effectiveScope": effectiveScope,
"resetDescendantIds": resetIds,
"updatedAncestors": updatedAncestors,
}
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@ -133,26 +191,39 @@ def _updateDataSourceNeutralize(
try: try:
from modules.interfaces.interfaceDbApp import getRootInterface from modules.interfaces.interfaceDbApp import getRootInterface
from modules.serviceCenter.services.serviceKnowledge._inheritFlags import ( from modules.serviceCenter.services.serviceKnowledge._inheritFlags import (
cascadeResetDescendants, cascadeResetDescendants, cascadeResetDescendantsFds,
cascadeResetDescendantsFds,
) )
rootIf = getRootInterface() rootIf = getRootInterface()
rec, model = _findSourceRecord(rootIf.db, sourceId) rec, model = _findSourceRecord(rootIf.db, sourceId)
if not rec: if not rec:
raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found") raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found")
rootIf.db.recordModify(model, sourceId, {"neutralize": neutralize}) # 1. Cascade reset descendants bottom-up (before modifying master)
cascaded = 0 resetIds: List[str] = []
if neutralize is not None: if neutralize is not None:
if model is DataSource: if model is DataSource:
cascaded = cascadeResetDescendants(rootIf, rec, "neutralize") resetIds = cascadeResetDescendants(rootIf, rec, "neutralize")
else: else:
cascaded = cascadeResetDescendantsFds(rootIf, rec, "neutralize") resetIds = cascadeResetDescendantsFds(rootIf, rec, "neutralize")
# 2. Set master value last (crash-safe)
rootIf.db.recordModify(model, sourceId, {"neutralize": neutralize})
# 3. Compute effective + ancestor chain for response
updatedAncestors = _computeAncestorEffectives(rootIf, rec, model, "neutralize")
effectiveNeutralize = _computeOwnEffective(rootIf, rec, model, sourceId, "neutralize")
logger.info( logger.info(
"Updated neutralize=%s for %s %s (cascade-reset %d descendants)", "Updated neutralize=%s for %s %s (cascade-reset %d descendants)",
neutralize, model.__name__, sourceId, cascaded, neutralize, model.__name__, sourceId, len(resetIds),
) )
return {"sourceId": sourceId, "neutralize": neutralize, "updated": True, "cascadedDescendants": cascaded} return {
"sourceId": sourceId,
"neutralize": neutralize,
"effectiveNeutralize": effectiveNeutralize,
"resetDescendantIds": resetIds,
"updatedAncestors": updatedAncestors,
}
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@ -204,26 +275,37 @@ async def _updateDataSourceRagIndex(
`True` enqueues a mini-bootstrap. `False` synchronously purges chunks. `True` enqueues a mini-bootstrap. `False` synchronously purges chunks.
Must be `async def` so `await startJob(...)` registers `_runJob` in the Must be `async def` so `await startJob(...)` registers `_runJob` in the
main event loop. Sync route worker thread temporary loop closes main event loop.
before the task runs job stays stuck forever.
""" """
try: try:
from modules.interfaces.interfaceDbApp import getRootInterface from modules.interfaces.interfaceDbApp import getRootInterface
from modules.serviceCenter.services.serviceKnowledge._inheritFlags import cascadeResetDescendants from modules.serviceCenter.services.serviceKnowledge._inheritFlags import (
cascadeResetDescendants, cascadeResetDescendantsFds,
)
rootIf = getRootInterface() rootIf = getRootInterface()
rec = rootIf.db.getRecord(DataSource, sourceId) rec, model = _findSourceRecord(rootIf.db, sourceId)
if not rec: if not rec:
raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found") raise HTTPException(status_code=404, detail=f"DataSource {sourceId} not found")
rootIf.db.recordModify(DataSource, sourceId, {"ragIndexEnabled": ragIndexEnabled}) # 1. Cascade reset descendants bottom-up (before modifying master)
cascaded = 0 resetIds: List[str] = []
if ragIndexEnabled is not None: if ragIndexEnabled is not None:
cascaded = cascadeResetDescendants(rootIf, rec, "ragIndexEnabled") if model is DataSource:
resetIds = cascadeResetDescendants(rootIf, rec, "ragIndexEnabled")
else:
resetIds = cascadeResetDescendantsFds(rootIf, rec, "ragIndexEnabled")
# 2. Set master value last (crash-safe)
rootIf.db.recordModify(model, sourceId, {"ragIndexEnabled": ragIndexEnabled})
logger.info( logger.info(
"Updated ragIndexEnabled=%s for DataSource %s (cascade-reset %d descendants)", "Updated ragIndexEnabled=%s for %s %s (cascade-reset %d descendants)",
ragIndexEnabled, sourceId, cascaded, ragIndexEnabled, model.__name__, sourceId, len(resetIds),
) )
# Bootstrap / purge only for personal DataSource (file/folder-based RAG).
# FDS RAG is handled by the feature pipeline; the flag alone is enough.
if model is DataSource:
connectionId = rec.get("connectionId") or rec.get("connection_id") or "" connectionId = rec.get("connectionId") or rec.get("connection_id") or ""
if ragIndexEnabled is True: if ragIndexEnabled is True:
_ensureConnectionKnowledgeFlag(rootIf, connectionId) _ensureConnectionKnowledgeFlag(rootIf, connectionId)
@ -253,10 +335,20 @@ async def _updateDataSourceRagIndex(
mandateId=context.mandateId, mandateId=context.mandateId,
category=AuditCategory.PERMISSION.value, category=AuditCategory.PERMISSION.value,
action="rag_index_toggled", action="rag_index_toggled",
details=json.dumps({"sourceId": sourceId, "ragIndexEnabled": ragIndexEnabled, "cascadedDescendants": cascaded}), details=json.dumps({"sourceId": sourceId, "ragIndexEnabled": ragIndexEnabled, "resetDescendants": len(resetIds), "model": model.__name__}),
) )
return {"sourceId": sourceId, "ragIndexEnabled": ragIndexEnabled, "updated": True, "cascadedDescendants": cascaded} # 3. Compute effective + ancestors for response
updatedAncestors = _computeAncestorEffectives(rootIf, rec, model, "ragIndexEnabled")
effectiveRag = _computeOwnEffective(rootIf, rec, model, sourceId, "ragIndexEnabled")
return {
"sourceId": sourceId,
"ragIndexEnabled": ragIndexEnabled,
"effectiveRagIndexEnabled": effectiveRag,
"resetDescendantIds": resetIds,
"updatedAncestors": updatedAncestors,
}
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@ -339,7 +431,17 @@ def _updateDataSourceSettings(
ownerId = str(rec.get("userId") or "") ownerId = str(rec.get("userId") or "")
currentUserId = str(context.user.id) currentUserId = str(context.user.id)
if ownerId and ownerId != currentUserId and not context.isSysAdmin: if ownerId and ownerId != currentUserId and not context.isSysAdmin:
scope = str(rec.get("scope") or "personal") from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag
if model is DataSource:
connectionId = rec.get("connectionId", "")
allDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
scope = str(getEffectiveFlag(rec, "scope", allDs, mode="walk"))
else:
from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource as FDS
from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlagFds
wsId = rec.get("workspaceInstanceId", "")
allFds = rootIf.db.getRecordset(FDS, recordFilter={"workspaceInstanceId": wsId})
scope = str(getEffectiveFlagFds(rec, "scope", allFds, mode="walk"))
isMandateAdmin = getattr(context, "isMandateAdmin", False) isMandateAdmin = getattr(context, "isMandateAdmin", False)
if scope == "personal" or not isMandateAdmin: if scope == "personal" or not isMandateAdmin:
raise HTTPException(status_code=403, detail="Not allowed to modify this DataSource's settings") raise HTTPException(status_code=403, detail="Not allowed to modify this DataSource's settings")

View file

@ -86,6 +86,7 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
""" """
from modules.datamodels.datamodelDataSource import DataSource from modules.datamodels.datamodelDataSource import DataSource
from modules.datamodels.datamodelKnowledge import FileContentIndex from modules.datamodels.datamodelKnowledge import FileContentIndex
from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag
out = [] out = []
for conn in connections: for conn in connections:
@ -136,8 +137,8 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"label": ds.get("label") if isinstance(ds, dict) else getattr(ds, "label", ""), "label": ds.get("label") if isinstance(ds, dict) else getattr(ds, "label", ""),
"path": dsPath, "path": dsPath,
"sourceType": ds.get("sourceType") if isinstance(ds, dict) else getattr(ds, "sourceType", ""), "sourceType": ds.get("sourceType") if isinstance(ds, dict) else getattr(ds, "sourceType", ""),
"ragIndexEnabled": ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False), "ragIndexEnabled": getEffectiveFlag(ds, "ragIndexEnabled", dataSources, mode="walk"),
"neutralize": ds.get("neutralize") if isinstance(ds, dict) else getattr(ds, "neutralize", False), "neutralize": getEffectiveFlag(ds, "neutralize", dataSources, mode="walk"),
"lastIndexed": ds.get("lastIndexed") if isinstance(ds, dict) else getattr(ds, "lastIndexed", None), "lastIndexed": ds.get("lastIndexed") if isinstance(ds, dict) else getattr(ds, "lastIndexed", None),
"fileCount": filesByDs.get(dsId, 0), "fileCount": filesByDs.get(dsId, 0),
"chunkCount": chunksByDs.get(dsId, 0), "chunkCount": chunksByDs.get(dsId, 0),
@ -223,13 +224,165 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
return out return out
def _buildFeatureInstanceInventory(featureInstanceIds, rootIf, knowledgeIf) -> List[Dict[str, Any]]:
"""Build per-feature-instance RAG inventory rows.
Feature-instance data lives in FileContentIndex with a non-empty
featureInstanceId. Additionally each feature instance may have
FeatureDataSource rows that define which tables/data are visible
as sources, with their own ragIndexEnabled flags.
Includes feature.bootstrap job status (running/success/error).
"""
from modules.datamodels.datamodelKnowledge import FileContentIndex
from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource
from modules.interfaces.interfaceFeatures import getFeatureInterface
from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlagFds
from modules.serviceCenter.services.serviceBackgroundJobs import mainBackgroundJobService as jobService
from modules.serviceCenter.services.serviceKnowledge.subFeatureBootstrap import FEATURE_BOOTSTRAP_JOB_TYPE
featureIf = getFeatureInterface(rootIf.db)
allFeatureJobs = jobService.listJobs(jobType=FEATURE_BOOTSTRAP_JOB_TYPE, limit=100)
out = []
for fiId in featureInstanceIds:
instance = featureIf.getFeatureInstance(fiId)
if not instance or not instance.enabled:
continue
indexRows = knowledgeIf.db.getRecordset(
FileContentIndex, recordFilter={"featureInstanceId": fiId}
)
fileIds = [
(r.get("id") if isinstance(r, dict) else getattr(r, "id", ""))
for r in indexRows
]
fileIds = [fid for fid in fileIds if fid]
chunkCounts = knowledgeIf.countChunksByFileIds(fileIds) if fileIds else {}
statusCounts: Dict[str, int] = {}
for r in indexRows:
st = (r.get("status") if isinstance(r, dict) else getattr(r, "status", "unknown")) or "unknown"
statusCounts[st] = statusCounts.get(st, 0) + 1
allFds = rootIf.db.getRecordset(FeatureDataSource, recordFilter={"workspaceInstanceId": fiId})
dsItems = []
anyRagEnabled = False
for fds in allFds:
tblName = (fds.get("tableName") if isinstance(fds, dict) else getattr(fds, "tableName", "")) or ""
fCode = (fds.get("featureCode") if isinstance(fds, dict) else getattr(fds, "featureCode", "")) or ""
if tblName == "*" or not fCode:
continue
fdsId = fds.get("id") if isinstance(fds, dict) else getattr(fds, "id", "")
ragEnabled = getEffectiveFlagFds(fds, "ragIndexEnabled", allFds, mode="aggregate")
if ragEnabled:
anyRagEnabled = True
dsItems.append({
"id": fdsId,
"label": (fds.get("label") if isinstance(fds, dict) else getattr(fds, "label", "")) or "",
"tableName": tblName,
"featureCode": fCode,
"ragIndexEnabled": ragEnabled,
})
fiJobs = [
j for j in allFeatureJobs
if (j.get("payload") or {}).get("workspaceInstanceId") == fiId
]
runningJobs = [
{
"jobId": j["id"],
"progress": j.get("progress", 0),
"progressMessage": (
resolveJobMessage(j.get("progressMessageData"))
or j.get("progressMessage", "")
),
}
for j in fiJobs
if j.get("status") in ("PENDING", "RUNNING")
]
lastError: Optional[Dict[str, Any]] = None
lastSuccess: Optional[Dict[str, Any]] = None
for j in fiJobs:
jStatus = j.get("status")
if jStatus == "ERROR" and lastError is None:
lastError = {
"jobId": j["id"],
"errorMessage": j.get("errorMessage", ""),
"finishedAt": j.get("finishedAt"),
}
elif jStatus == "SUCCESS" and lastSuccess is None:
result = j.get("result") or {}
lastSuccess = {
"jobId": j["id"],
"finishedAt": j.get("finishedAt"),
"indexed": result.get("indexed", 0),
"skippedDuplicate": result.get("skippedDuplicate", 0),
"failed": result.get("failed", 0),
}
if lastError and lastSuccess:
break
if not indexRows and not dsItems:
continue
out.append({
"featureInstanceId": fiId,
"featureCode": instance.featureCode,
"label": instance.label or instance.featureCode,
"mandateId": str(instance.mandateId) if instance.mandateId else "",
"fileCount": len(indexRows),
"chunkCount": sum(chunkCounts.values()),
"statusCounts": statusCounts,
"dataSources": dsItems,
"ragEnabled": anyRagEnabled,
"runningJobs": runningJobs,
"lastSuccess": lastSuccess,
"lastError": lastError,
})
return out
@router.get("/my-mandates")
@limiter.limit("30/minute")
def _getMyMandates(
request: Request,
currentUser: User = Depends(getCurrentUser),
) -> List[Dict[str, Any]]:
"""Return mandates where the current user has an active membership.
Used by the RAG inventory frontend to populate the mandate dropdown
without requiring admin rights (unlike GET /api/mandates/).
"""
try:
from modules.interfaces.interfaceDbApp import getRootInterface
rootIf = getRootInterface()
userMandates = rootIf.getUserMandates(str(currentUser.id))
result = []
for um in userMandates:
if not um.enabled:
continue
mandate = rootIf.getMandate(str(um.mandateId))
if not mandate or not getattr(mandate, "enabled", True):
continue
result.append({
"id": str(um.mandateId),
"name": getattr(mandate, "name", ""),
"label": getattr(mandate, "label", None) or getattr(mandate, "name", ""),
})
return result
except Exception as e:
logger.error("Error in RAG inventory /my-mandates: %s", e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/me") @router.get("/me")
@limiter.limit("30/minute") @limiter.limit("30/minute")
def _getInventoryMe( def _getInventoryMe(
request: Request, request: Request,
currentUser: User = Depends(getCurrentUser), currentUser: User = Depends(getCurrentUser),
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Personal RAG inventory: own connections + DataSources + chunk counts.""" """Personal RAG inventory: own connections + DataSources + chunk counts + feature uploads."""
try: try:
from modules.interfaces.interfaceDbApp import getRootInterface from modules.interfaces.interfaceDbApp import getRootInterface
from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
@ -243,7 +396,20 @@ def _getInventoryMe(
totalChunks = sum(c.get("totalChunks", 0) for c in items) totalChunks = sum(c.get("totalChunks", 0) for c in items)
totalFiles = sum(c.get("totalFiles", 0) for c in items) totalFiles = sum(c.get("totalFiles", 0) for c in items)
return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}} featureAccesses = rootIf.getFeatureAccessesForUser(str(currentUser.id))
fiIds = [
str(fa.featureInstanceId) for fa in featureAccesses
if fa.enabled and fa.featureInstanceId
]
fiItems = _buildFeatureInstanceInventory(fiIds, rootIf, knowledgeIf)
totalFiles += sum(fi.get("fileCount", 0) for fi in fiItems)
totalChunks += sum(fi.get("chunkCount", 0) for fi in fiItems)
return {
"connections": items,
"featureInstances": fiItems,
"totals": {"files": totalFiles, "chunks": totalChunks},
}
except Exception as e: except Exception as e:
logger.error("Error in RAG inventory /me: %s", e, exc_info=True) logger.error("Error in RAG inventory /me: %s", e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@ -262,21 +428,43 @@ def _getInventoryMandate(
from modules.interfaces.interfaceDbApp import getRootInterface from modules.interfaces.interfaceDbApp import getRootInterface
from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface, aggregateMandateRagTotalBytes from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface, aggregateMandateRagTotalBytes
from modules.serviceCenter.services.serviceBackgroundJobs import mainBackgroundJobService as jobService from modules.serviceCenter.services.serviceBackgroundJobs import mainBackgroundJobService as jobService
rootIf = getRootInterface() rootIf = getRootInterface()
knowledgeIf = getKnowledgeInterface(None) knowledgeIf = getKnowledgeInterface(None)
mandateId = str(context.mandateId) if context.mandateId else "" mandateId = str(context.mandateId)
userId = str(context.user.id)
from modules.datamodels.datamodelUam import UserConnection userMandates = rootIf.getUserMandates(userId)
allConnections = rootIf.db.getRecordset(UserConnection, recordFilter={"mandateId": mandateId}) isMember = any(
connectionObjects = [type("C", (), row)() if isinstance(row, dict) else row for row in allConnections] getattr(um, "mandateId", None) == mandateId and um.enabled
for um in userMandates
)
if not isMember and not context.isSysAdmin:
raise HTTPException(status_code=403, detail=routeApiMsg("No membership in this mandate"))
items = _buildConnectionInventory(connectionObjects, rootIf, knowledgeIf, jobService) mandateMembers = rootIf.getUserMandatesByMandate(mandateId)
memberUserIds = {getattr(um, "userId", None) for um in mandateMembers}
memberUserIds.discard(None)
allConnections = []
for uid in memberUserIds:
allConnections.extend(rootIf.getUserConnections(uid))
items = _buildConnectionInventory(allConnections, rootIf, knowledgeIf, jobService)
totalChunks = sum(c.get("totalChunks", 0) for c in items) totalChunks = sum(c.get("totalChunks", 0) for c in items)
totalFiles = sum(c.get("totalFiles", 0) for c in items) totalFiles = sum(c.get("totalFiles", 0) for c in items)
totalBytes = aggregateMandateRagTotalBytes(mandateId) totalBytes = aggregateMandateRagTotalBytes(mandateId)
return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks, "bytes": totalBytes}} mandateInstances = rootIf.getFeatureInstancesByMandate(mandateId, enabledOnly=True)
fiIds = [str(inst.id) for inst in mandateInstances if inst.id]
fiItems = _buildFeatureInstanceInventory(fiIds, rootIf, knowledgeIf)
totalFiles += sum(fi.get("fileCount", 0) for fi in fiItems)
totalChunks += sum(fi.get("chunkCount", 0) for fi in fiItems)
return {
"connections": items,
"featureInstances": fiItems,
"totals": {"files": totalFiles, "chunks": totalChunks, "bytes": totalBytes},
}
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@ -308,7 +496,22 @@ def _getInventoryPlatform(
totalChunks = sum(c.get("totalChunks", 0) for c in items) totalChunks = sum(c.get("totalChunks", 0) for c in items)
totalFiles = sum(c.get("totalFiles", 0) for c in items) totalFiles = sum(c.get("totalFiles", 0) for c in items)
return {"connections": items, "totals": {"files": totalFiles, "chunks": totalChunks}} from modules.datamodels.datamodelFeatures import FeatureInstance
allInstances = rootIf.db.getRecordset(FeatureInstance, recordFilter={"enabled": True})
fiIds = [
(r.get("id") if isinstance(r, dict) else getattr(r, "id", ""))
for r in allInstances
]
fiIds = [fid for fid in fiIds if fid]
fiItems = _buildFeatureInstanceInventory(fiIds, rootIf, knowledgeIf)
totalFiles += sum(fi.get("fileCount", 0) for fi in fiItems)
totalChunks += sum(fi.get("chunkCount", 0) for fi in fiItems)
return {
"connections": items,
"featureInstances": fiItems,
"totals": {"files": totalFiles, "chunks": totalChunks},
}
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@ -345,8 +548,9 @@ async def _reindexConnection(
if str(conn.userId) != str(currentUser.id): if str(conn.userId) != str(currentUser.id):
raise HTTPException(status_code=403, detail="Not your connection") raise HTTPException(status_code=403, detail="Not your connection")
from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag
dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
ragDs = [ds for ds in dataSources if (ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False))] ragDs = [ds for ds in dataSources if getEffectiveFlag(ds, "ragIndexEnabled", dataSources, mode="walk") is True]
if not ragDs: if not ragDs:
return {"status": "skipped", "reason": "no_rag_enabled_datasources"} return {"status": "skipped", "reason": "no_rag_enabled_datasources"}
@ -368,6 +572,47 @@ async def _reindexConnection(
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post("/reindex-feature/{workspaceInstanceId}")
@limiter.limit("10/minute")
async def _reindexFeature(
request: Request,
workspaceInstanceId: str,
currentUser: User = Depends(getCurrentUser),
) -> Dict[str, Any]:
"""Re-trigger feature data bootstrap for a workspace instance.
Indexes all RAG-enabled FeatureDataSource rows into the knowledge store.
Must be ``async def`` so ``await startJob(...)`` registers in the main loop.
"""
try:
from modules.interfaces.interfaceDbApp import getRootInterface
from modules.serviceCenter.services.serviceBackgroundJobs import startJob
from modules.serviceCenter.services.serviceKnowledge.subFeatureBootstrap import FEATURE_BOOTSTRAP_JOB_TYPE
rootIf = getRootInterface()
featureAccesses = rootIf.getFeatureAccessesForUser(str(currentUser.id))
hasAccess = any(
str(fa.featureInstanceId) == workspaceInstanceId and fa.enabled
for fa in featureAccesses
)
if not hasAccess and not getattr(currentUser, "isSysAdmin", False):
raise HTTPException(status_code=403, detail="No access to this feature instance")
jobId = await startJob(
FEATURE_BOOTSTRAP_JOB_TYPE,
{"workspaceInstanceId": workspaceInstanceId},
triggeredBy=str(currentUser.id),
)
logger.info("Feature reindex triggered for workspace %s (jobId=%s)", workspaceInstanceId, jobId)
return {"status": "queued", "workspaceInstanceId": workspaceInstanceId, "jobId": jobId}
except HTTPException:
raise
except Exception as e:
logger.error("Error triggering feature reindex: %s", e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/jobs") @router.get("/jobs")
@limiter.limit("60/minute") @limiter.limit("60/minute")
def _getActiveJobs( def _getActiveJobs(

View file

@ -341,11 +341,10 @@ class RbacClass:
return [] return []
try: try:
conn = self.dbApp.connection
roleIds = set() roleIds = set()
# 1. Mandant-Rollen via UserMandate → UserMandateRole (SINGLE Query) # 1. Mandant-Rollen via UserMandate → UserMandateRole (SINGLE Query)
with conn.cursor() as cursor: with self.dbApp.borrowCursor() as cursor:
cursor.execute( cursor.execute(
""" """
SELECT umr."roleId" SELECT umr."roleId"
@ -360,7 +359,7 @@ class RbacClass:
# 2. Instanz-Rollen via FeatureAccess → FeatureAccessRole (SINGLE Query) # 2. Instanz-Rollen via FeatureAccess → FeatureAccessRole (SINGLE Query)
if featureInstanceId: if featureInstanceId:
with conn.cursor() as cursor: with self.dbApp.borrowCursor() as cursor:
cursor.execute( cursor.execute(
""" """
SELECT far."roleId" SELECT far."roleId"
@ -377,9 +376,8 @@ class RbacClass:
return [] return []
# 3. BULK Query: Alle Regeln für alle Rollen + zugehörige Role-Daten # 3. BULK Query: Alle Regeln für alle Rollen + zugehörige Role-Daten
# SINGLE Query mit JOIN statt N+1
roleIdsList = list(roleIds) roleIdsList = list(roleIds)
with conn.cursor() as cursor: with self.dbApp.borrowCursor() as cursor:
cursor.execute( cursor.execute(
""" """
SELECT ar.*, r."mandateId" as "roleMandateId", SELECT ar.*, r."mandateId" as "roleMandateId",

View file

@ -67,7 +67,12 @@ def _registerDataSourceTools(registry: ToolRegistry, services):
sourceType = ds.get("sourceType", "") sourceType = ds.get("sourceType", "")
path = ds.get("path", "/") path = ds.get("path", "/")
label = ds.get("label", "") label = ds.get("label", "")
neutralize = bool(ds.get("neutralize", False)) from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag
from modules.datamodels.datamodelDataSource import DataSource
from modules.interfaces.interfaceDbApp import getRootInterface
rootIf = getRootInterface()
allConnDs = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
neutralize = bool(getEffectiveFlag(ds, "neutralize", allConnDs or [ds], mode="walk"))
service = _SOURCE_TYPE_TO_SERVICE.get(sourceType, sourceType) service = _SOURCE_TYPE_TO_SERVICE.get(sourceType, sourceType)
if not connectionId: if not connectionId:
raise ValueError(f"DataSource '{dsId}' has no connectionId") raise ValueError(f"DataSource '{dsId}' has no connectionId")

View file

@ -110,9 +110,11 @@ def _registerFeatureSubAgentTools(registry: ToolRegistry, services):
recordFilter={"featureInstanceId": featureInstanceId, "workspaceInstanceId": workspaceInstanceId}, recordFilter={"featureInstanceId": featureInstanceId, "workspaceInstanceId": workspaceInstanceId},
) )
from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlagFds
_fdsAll = featureDataSources or []
_anySourceNeutralize = any( _anySourceNeutralize = any(
bool(ds.get("neutralize", False) if isinstance(ds, dict) else getattr(ds, "neutralize", False)) getEffectiveFlagFds(ds, "neutralize", _fdsAll, mode="walk") is True
for ds in (featureDataSources or []) for ds in _fdsAll
) )
neutralizeFieldsPerTable: Dict[str, List[str]] = {} neutralizeFieldsPerTable: Dict[str, List[str]] = {}

View file

@ -95,8 +95,7 @@ class FeatureDataProvider:
def getActualColumns(self, tableName: str) -> List[str]: def getActualColumns(self, tableName: str) -> List[str]:
"""Read real column names from PostgreSQL information_schema.""" """Read real column names from PostgreSQL information_schema."""
try: try:
conn = self._db.connection with self._db.borrowCursor() as cur:
with conn.cursor() as cur:
cur.execute( cur.execute(
"SELECT column_name FROM information_schema.columns " "SELECT column_name FROM information_schema.columns "
"WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) " "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
@ -131,7 +130,6 @@ class FeatureDataProvider:
Returns ``{"rows": [...], "total": N, "limit": L, "offset": O}``. Returns ``{"rows": [...], "total": N, "limit": L, "offset": O}``.
""" """
_validateTableName(tableName) _validateTableName(tableName)
conn = self._db.connection
if fields: if fields:
invalid = [f for f in fields if not _isValidIdentifier(f)] invalid = [f for f in fields if not _isValidIdentifier(f)]
@ -141,7 +139,7 @@ class FeatureDataProvider:
"error": f"Invalid field name(s): {', '.join(invalid)}. Use getActualColumns to discover valid column names.", "error": f"Invalid field name(s): {', '.join(invalid)}. Use getActualColumns to discover valid column names.",
} }
scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, dbConnection=conn) scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, db=self._db)
extraWhere, extraParams = _buildFilterClauses(extraFilters) extraWhere, extraParams = _buildFilterClauses(extraFilters)
fullWhere = scopeFilter["where"] fullWhere = scopeFilter["where"]
@ -152,7 +150,7 @@ class FeatureDataProvider:
t0 = time.time() t0 = time.time()
try: try:
with conn.cursor() as cur: with self._db.borrowCursor() as cur:
countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {fullWhere}' countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {fullWhere}'
cur.execute(countSql, allParams) cur.execute(countSql, allParams)
total = cur.fetchone()["count"] if cur.rowcount else 0 total = cur.fetchone()["count"] if cur.rowcount else 0
@ -179,10 +177,6 @@ class FeatureDataProvider:
_debugQueryLog("browseTable", tableName, { _debugQueryLog("browseTable", tableName, {
"fields": fields, "limit": limit, "offset": offset, "fields": fields, "limit": limit, "offset": offset,
}, errResult, elapsed) }, errResult, elapsed)
try:
conn.rollback()
except Exception:
pass
return errResult return errResult
def aggregateTable( def aggregateTable(
@ -208,8 +202,7 @@ class FeatureDataProvider:
if groupBy and not _isValidIdentifier(groupBy): if groupBy and not _isValidIdentifier(groupBy):
return {"rows": [], "error": f"Invalid groupBy field: {groupBy}"} return {"rows": [], "error": f"Invalid groupBy field: {groupBy}"}
conn = self._db.connection scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, db=self._db)
scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, dbConnection=conn)
extraWhere, extraParams = _buildFilterClauses(extraFilters) extraWhere, extraParams = _buildFilterClauses(extraFilters)
fullWhere = scopeFilter["where"] fullWhere = scopeFilter["where"]
@ -220,7 +213,7 @@ class FeatureDataProvider:
t0 = time.time() t0 = time.time()
try: try:
with conn.cursor() as cur: with self._db.borrowCursor() as cur:
if groupBy: if groupBy:
sql = ( sql = (
f'SELECT "{groupBy}" AS "groupValue", {aggregate}("{field}") AS "result" ' f'SELECT "{groupBy}" AS "groupValue", {aggregate}("{field}") AS "result" '
@ -253,10 +246,6 @@ class FeatureDataProvider:
_debugQueryLog("aggregateTable", tableName, { _debugQueryLog("aggregateTable", tableName, {
"aggregate": aggregate, "field": field, "groupBy": groupBy, "aggregate": aggregate, "field": field, "groupBy": groupBy,
}, errResult, elapsed) }, errResult, elapsed)
try:
conn.rollback()
except Exception:
pass
return errResult return errResult
def queryTable( def queryTable(
@ -277,7 +266,6 @@ class FeatureDataProvider:
``extraFilters`` are mandatory record-level scoping filters injected by the pipeline. ``extraFilters`` are mandatory record-level scoping filters injected by the pipeline.
""" """
_validateTableName(tableName) _validateTableName(tableName)
conn = self._db.connection
if fields: if fields:
invalid = [f for f in fields if not _isValidIdentifier(f)] invalid = [f for f in fields if not _isValidIdentifier(f)]
@ -287,7 +275,7 @@ class FeatureDataProvider:
"error": f"Invalid field name(s): {', '.join(invalid)}. Use getActualColumns to discover valid column names.", "error": f"Invalid field name(s): {', '.join(invalid)}. Use getActualColumns to discover valid column names.",
} }
scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, dbConnection=conn) scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, db=self._db)
combinedFilters = list(filters or []) + list(extraFilters or []) combinedFilters = list(filters or []) + list(extraFilters or [])
extraWhere, extraParams = _buildFilterClauses(combinedFilters if combinedFilters else None) extraWhere, extraParams = _buildFilterClauses(combinedFilters if combinedFilters else None)
@ -300,7 +288,7 @@ class FeatureDataProvider:
t0 = time.time() t0 = time.time()
try: try:
with conn.cursor() as cur: with self._db.borrowCursor() as cur:
countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {fullWhere}' countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {fullWhere}'
cur.execute(countSql, allParams) cur.execute(countSql, allParams)
total = cur.fetchone()["count"] if cur.rowcount else 0 total = cur.fetchone()["count"] if cur.rowcount else 0
@ -329,10 +317,6 @@ class FeatureDataProvider:
"filters": filters, "fields": fields, "orderBy": orderBy, "filters": filters, "fields": fields, "orderBy": orderBy,
"limit": limit, "offset": offset, "limit": limit, "offset": offset,
}, errResult, elapsed) }, errResult, elapsed)
try:
conn.rollback()
except Exception:
pass
return errResult return errResult
@ -343,13 +327,13 @@ class FeatureDataProvider:
_instanceColCache: Dict[str, str] = {} _instanceColCache: Dict[str, str] = {}
def _resolveInstanceColumn(tableName: str, dbConnection=None) -> str: def _resolveInstanceColumn(tableName: str, db=None) -> str:
"""Detect whether the table uses ``instanceId`` or ``featureInstanceId``.""" """Detect whether the table uses ``instanceId`` or ``featureInstanceId``."""
if tableName in _instanceColCache: if tableName in _instanceColCache:
return _instanceColCache[tableName] return _instanceColCache[tableName]
if dbConnection: if db:
try: try:
with dbConnection.cursor() as cur: with db.borrowCursor() as cur:
cur.execute( cur.execute(
"SELECT column_name FROM information_schema.columns " "SELECT column_name FROM information_schema.columns "
"WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) " "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
@ -378,14 +362,14 @@ def _isValidIdentifier(name: str) -> bool:
return name.isidentifier() return name.isidentifier()
def _buildScopeFilter(tableName: str, featureInstanceId: str, mandateId: str, dbConnection=None) -> Dict[str, Any]: def _buildScopeFilter(tableName: str, featureInstanceId: str, mandateId: str, db=None, dbConnection=None) -> Dict[str, Any]:
"""Build the mandatory WHERE clause that scopes rows to the feature instance. """Build the mandatory WHERE clause that scopes rows to the feature instance.
Feature tables use either ``instanceId`` (commcoach, teamsbot) or Feature tables use either ``instanceId`` (commcoach, teamsbot) or
``featureInstanceId`` (trustee) as the FK. We detect the actual column ``featureInstanceId`` (trustee) as the FK. We detect the actual column
from ``information_schema`` when a DB connection is provided. from ``information_schema`` when a DB connector is provided.
""" """
instanceCol = _resolveInstanceColumn(tableName, dbConnection) instanceCol = _resolveInstanceColumn(tableName, db or dbConnection)
conditions = [] conditions = []
params = [] params = []

File diff suppressed because it is too large Load diff

View file

@ -3,9 +3,15 @@
"""Cascade-inherit semantics for DataSource flags (neutralize, ragIndexEnabled, scope). """Cascade-inherit semantics for DataSource flags (neutralize, ragIndexEnabled, scope).
Three-state flags allow tree elements to either set an explicit value or Three-state flags allow tree elements to either set an explicit value or
inherit the value from their nearest ancestor in the path hierarchy. The inherit the value from their nearest ancestor in the path hierarchy.
walker (RAG/Neutralize) and routes resolve the *effective* value; the cascade
helper resets explicit descendant values when a parent is toggled. Modes:
- 'walk' (default): resolves the *concrete* effective value per-item
(never returns 'mixed'). Used by backend consumers (RAG walker,
neutralization pipeline, scope filter, etc.).
- 'aggregate': resolves the *display* effective value per-item. If the
item has descendants with differing walk-effective values, returns
'mixed'. Used by listing endpoints and PATCH responses for the UI.
Path-traversal rules: Path-traversal rules:
- A DataSource is identified by `(connectionId, sourceType, path)`. - A DataSource is identified by `(connectionId, sourceType, path)`.
@ -17,11 +23,12 @@ Path-traversal rules:
""" """
import logging import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_INHERITABLE_FLAGS = ("neutralize", "ragIndexEnabled", "scope") _INHERITABLE_FLAGS = ("neutralize", "ragIndexEnabled", "scope")
_INHERITABLE_FDS_FLAGS = ("neutralize", "ragIndexEnabled", "scope")
# Connection-root DataSources carry the authority as their sourceType # Connection-root DataSources carry the authority as their sourceType
# (e.g. 'msft', 'google'). They sit one level above all service DataSources # (e.g. 'msft', 'google'). They sit one level above all service DataSources
@ -29,6 +36,12 @@ _INHERITABLE_FLAGS = ("neutralize", "ragIndexEnabled", "scope")
# cross sourceType boundaries — but ONLY from these authority roots. # cross sourceType boundaries — but ONLY from these authority roots.
_AUTHORITY_SOURCE_TYPES = frozenset({"local", "google", "msft", "clickup", "infomaniak"}) _AUTHORITY_SOURCE_TYPES = frozenset({"local", "google", "msft", "clickup", "infomaniak"})
Mode = Literal["walk", "aggregate"]
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _normalisePath(path: Optional[str]) -> str: def _normalisePath(path: Optional[str]) -> str:
"""Normalize a DataSource path to '/'-prefixed, no trailing slash (except root).""" """Normalize a DataSource path to '/'-prefixed, no trailing slash (except root)."""
@ -49,10 +62,7 @@ def _flagDefault(flag: str) -> Any:
def _isExplicit(value: Any) -> bool: def _isExplicit(value: Any) -> bool:
"""A flag value is explicit when it is not None. """A flag value is explicit when it is not None/empty-string."""
Note: legacy rows may carry empty-string scope; treat as inherit too.
"""
if value is None: if value is None:
return False return False
if isinstance(value, str) and value == "": if isinstance(value, str) and value == "":
@ -66,6 +76,21 @@ def _getRecordValue(rec: Any, key: str) -> Any:
return getattr(rec, key, None) return getattr(rec, key, None)
def _isAncestorPath(ancestor: str, descendant: str) -> bool:
"""True iff `ancestor` is a strict path-prefix of `descendant`."""
if ancestor == descendant:
return False
if ancestor == "/":
return descendant != "/"
return descendant.startswith(ancestor + "/")
def _pathDepth(path: str) -> int:
if path == "/":
return 0
return path.count("/")
def _findAncestorChain( def _findAncestorChain(
rec: Dict[str, Any], rec: Dict[str, Any],
allDs: Iterable[Dict[str, Any]], allDs: Iterable[Dict[str, Any]],
@ -74,15 +99,13 @@ def _findAncestorChain(
ordered nearest-first. ordered nearest-first.
Two ancestor relations are merged: Two ancestor relations are merged:
1) **same-sourceType path-ancestor** strict path-prefix within the 1) same-sourceType path-ancestor strict path-prefix within the
same service tree (sharepointFolder, gmailFolder, ...). same service tree.
2) **connection-root ancestor** a DS with `path='/'` and 2) connection-root ancestor a DS with `path='/'` and
`sourceType` authority set (msft, google, ...) is the parent of `sourceType` in authority set is the parent of every other DS
every other DS in that connection regardless of sourceType, so a in that connection regardless of sourceType.
toggle on the connection node propagates to all services beneath.
The connection-root is always the most distant ancestor and therefore The connection-root is always the most distant ancestor.
sorts after any same-sourceType ancestors.
""" """
recPath = _normalisePath(_getRecordValue(rec, "path")) recPath = _normalisePath(_getRecordValue(rec, "path"))
recSourceType = _getRecordValue(rec, "sourceType") recSourceType = _getRecordValue(rec, "sourceType")
@ -114,36 +137,89 @@ def _findAncestorChain(
return chain return chain
def _isAncestorPath(ancestor: str, descendant: str) -> bool: def _isDescendantDs(parentRec: Dict[str, Any], candidate: Dict[str, Any]) -> bool:
"""True iff `ancestor` is a strict path-prefix of `descendant`. """True iff `candidate` is a descendant of `parentRec` in the DS hierarchy."""
parentSourceType = _getRecordValue(parentRec, "sourceType")
parentPath = _normalisePath(_getRecordValue(parentRec, "path"))
parentConnectionId = _getRecordValue(parentRec, "connectionId")
parentId = _getRecordValue(parentRec, "id")
'/' is ancestor of every non-root path. For non-root prefixes, the candId = _getRecordValue(candidate, "id")
descendant must continue with '/' so '/foo' isn't treated as ancestor of if candId == parentId:
'/foobar'. return False
""" if _getRecordValue(candidate, "connectionId") != parentConnectionId:
if ancestor == descendant:
return False return False
if ancestor == "/":
return descendant != "/"
return descendant.startswith(ancestor + "/")
candSourceType = _getRecordValue(candidate, "sourceType")
candPath = _normalisePath(_getRecordValue(candidate, "path"))
parentIsConnectionRoot = (
parentSourceType in _AUTHORITY_SOURCE_TYPES and parentPath == "/"
)
if parentIsConnectionRoot:
return True
if candSourceType != parentSourceType:
return False
return _isAncestorPath(parentPath, candPath)
# ---------------------------------------------------------------------------
# DataSource: getEffectiveFlag
# ---------------------------------------------------------------------------
def getEffectiveFlag( def getEffectiveFlag(
rec: Dict[str, Any], rec: Dict[str, Any],
flag: str, flag: str,
sameConnectionDs: Iterable[Dict[str, Any]], sameConnectionDs: Iterable[Dict[str, Any]],
mode: Mode = "walk",
) -> Any: ) -> Any:
"""Resolve the effective value of a flag via path-traversal. """Resolve the effective value of a flag via path-traversal.
Order: own value (if explicit) nearest ancestor with explicit value mode='walk': own explicit nearest ancestor explicit default.
static default (`False` or `'personal'`). Always returns a concrete value (never 'mixed').
mode='aggregate': same as walk for leaf value, but if the item has
descendants whose walk-effective values differ from
each other, returns 'mixed'.
""" """
if flag not in _INHERITABLE_FLAGS: if flag not in _INHERITABLE_FLAGS:
raise ValueError(f"Unknown inheritable flag: {flag}") raise ValueError(f"Unknown inheritable flag: {flag}")
allDs = list(sameConnectionDs)
walkValue = _resolveWalkValue(rec, flag, allDs)
if mode == "walk":
return walkValue
# mode == 'aggregate': check subtree for heterogeneous effective values
descendants = [d for d in allDs if _isDescendantDs(rec, d)]
if not descendants:
return walkValue
subtreeValues = set()
subtreeValues.add(_normaliseForComparison(walkValue))
for desc in descendants:
descEffective = _resolveWalkValue(desc, flag, allDs)
subtreeValues.add(_normaliseForComparison(descEffective))
if len(subtreeValues) > 1:
recId = _getRecordValue(rec, "id")
descId = _getRecordValue(desc, "id")
descOwnVal = _getRecordValue(desc, flag)
logger.info(
"DS aggregate MIXED for rec=%s flag=%s: walkValue=%s, "
"divergent desc=%s (own=%s, effective=%s), subtreeValues=%s",
recId, flag, walkValue, descId, descOwnVal, descEffective, subtreeValues,
)
return "mixed"
return walkValue
def _resolveWalkValue(rec: Dict[str, Any], flag: str, allDs: List[Dict[str, Any]]) -> Any:
"""Core walk resolution: own explicit → ancestor chain → default."""
own = _getRecordValue(rec, flag) own = _getRecordValue(rec, flag)
if _isExplicit(own): if _isExplicit(own):
return own return own
chain = _findAncestorChain(rec, sameConnectionDs) chain = _findAncestorChain(rec, allDs)
for ancestor in chain: for ancestor in chain:
ancestorVal = _getRecordValue(ancestor, flag) ancestorVal = _getRecordValue(ancestor, flag)
if _isExplicit(ancestorVal): if _isExplicit(ancestorVal):
@ -151,69 +227,112 @@ def getEffectiveFlag(
return _flagDefault(flag) return _flagDefault(flag)
def _normaliseForComparison(value: Any) -> Any:
"""Normalize values for set-comparison (bool as int to avoid hash issues)."""
if isinstance(value, bool):
return int(value)
return value
# ---------------------------------------------------------------------------
# DataSource: cascadeResetDescendants (bottom-up)
# ---------------------------------------------------------------------------
def cascadeResetDescendants( def cascadeResetDescendants(
rootIf: Any, rootIf: Any,
parentRec: Dict[str, Any], parentRec: Dict[str, Any],
flag: str, flag: str,
) -> int: ) -> List[str]:
"""Reset all explicit descendant values of `flag` to NULL (= inherit). """Reset all explicit descendant values of `flag` to NULL (= inherit).
Descendant relation mirrors `_findAncestorChain`: Reset order: bottom-up (deepest first) for crash safety.
- Connection-root (`path='/'` AND `sourceType` authorities) is parent The parent itself is NOT modified here the caller sets the master value
of every other DS in that connection (cross-sourceType cascade). after this function returns.
- Otherwise: same-sourceType strict path-descendants only.
Only the targeted `flag` is reset; other flags on the descendant are Returns list of reset record IDs in bottom-up order.
untouched.
Returns the number of records updated.
""" """
if flag not in _INHERITABLE_FLAGS: if flag not in _INHERITABLE_FLAGS:
raise ValueError(f"Unknown inheritable flag: {flag}") raise ValueError(f"Unknown inheritable flag: {flag}")
from modules.datamodels.datamodelDataSource import DataSource from modules.datamodels.datamodelDataSource import DataSource
connectionId = _getRecordValue(parentRec, "connectionId") connectionId = _getRecordValue(parentRec, "connectionId")
parentSourceType = _getRecordValue(parentRec, "sourceType")
parentPath = _normalisePath(_getRecordValue(parentRec, "path"))
parentId = _getRecordValue(parentRec, "id") parentId = _getRecordValue(parentRec, "id")
if not connectionId or not parentSourceType: if not connectionId:
return 0 return []
parentIsConnectionRoot = (
parentSourceType in _AUTHORITY_SOURCE_TYPES and parentPath == "/"
)
siblings = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId}) siblings = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId})
affected = 0
toReset: List[Tuple[int, str]] = []
for sib in siblings: for sib in siblings:
sibId = _getRecordValue(sib, "id") if not _isDescendantDs(parentRec, sib):
if sibId == parentId:
continue
sibSourceType = _getRecordValue(sib, "sourceType")
sibPath = _normalisePath(_getRecordValue(sib, "path"))
if parentIsConnectionRoot:
# Connection-root resets everything else under this connection.
pass
else:
if sibSourceType != parentSourceType:
continue
if not _isAncestorPath(parentPath, sibPath):
continue continue
sibVal = _getRecordValue(sib, flag) sibVal = _getRecordValue(sib, flag)
if not _isExplicit(sibVal): if not _isExplicit(sibVal):
continue continue
sibId = _getRecordValue(sib, "id")
sibPath = _normalisePath(_getRecordValue(sib, "path"))
toReset.append((_pathDepth(sibPath), sibId))
# Sort deepest first (bottom-up)
toReset.sort(key=lambda x: x[0], reverse=True)
resetIds: List[str] = []
for _, sibId in toReset:
try: try:
rootIf.db.recordModify(DataSource, sibId, {flag: None}) rootIf.db.recordModify(DataSource, sibId, {flag: None})
affected += 1 resetIds.append(sibId)
except Exception as exc: except Exception as exc:
logger.warning("Cascade-reset failed for DataSource %s flag=%s: %s", sibId, flag, exc) logger.warning("Cascade-reset failed for DataSource %s flag=%s: %s", sibId, flag, exc)
if affected:
logger.info(
"Cascade-reset %s on %d descendants of DataSource (connectionId=%s, sourceType=%s, path=%s, connectionRoot=%s)",
flag, affected, connectionId, parentSourceType, parentPath, parentIsConnectionRoot,
)
return affected
if resetIds:
logger.info(
"Cascade-reset %s on %d descendants of DataSource %s (bottom-up)",
flag, len(resetIds), parentId,
)
return resetIds
# ---------------------------------------------------------------------------
# DataSource: collectAncestorChain (for updatedAncestors in PATCH response)
# ---------------------------------------------------------------------------
def collectAncestorChain(
rec: Dict[str, Any],
sameConnectionDs: Iterable[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""Return ancestor chain of `rec` (nearest-first), same as internal helper.
Exposed for PATCH endpoints to compute updatedAncestors.
"""
return _findAncestorChain(rec, sameConnectionDs)
# ---------------------------------------------------------------------------
# DataSource: buildEffectiveByConnection
# ---------------------------------------------------------------------------
def buildEffectiveByConnection(
dataSources: Iterable[Dict[str, Any]],
flag: str,
mode: Mode = "walk",
) -> Dict[str, Any]:
"""Pre-compute the effective value of `flag` for every DataSource id.
Uses the specified mode. O(N^2) worst case but N is bounded per connection.
"""
if flag not in _INHERITABLE_FLAGS:
raise ValueError(f"Unknown inheritable flag: {flag}")
allDs = list(dataSources)
out: Dict[str, Any] = {}
for rec in allDs:
recId = _getRecordValue(rec, "id")
out[recId] = getEffectiveFlag(rec, flag, allDs, mode=mode)
return out
# ---------------------------------------------------------------------------
# FeatureDataSource helpers
# ---------------------------------------------------------------------------
def _fdsClassify(fds: Dict[str, Any]) -> str: def _fdsClassify(fds: Dict[str, Any]) -> str:
"""Return 'workspace' | 'table' | 'record' based on the FDS identifier shape.""" """Return 'workspace' | 'table' | 'record' based on the FDS identifier shape."""
@ -229,14 +348,14 @@ def _fdsClassify(fds: Dict[str, Any]) -> str:
def _fdsIsAncestor(parent: Dict[str, Any], child: Dict[str, Any]) -> bool: def _fdsIsAncestor(parent: Dict[str, Any], child: Dict[str, Any]) -> bool:
"""Return True iff `parent` FDS is a strict ancestor of `child` FDS. """Return True iff `parent` FDS is a strict ancestor of `child` FDS.
Hierarchy within one `workspaceInstanceId`: Hierarchy within one featureInstanceId (allFds is already scoped to
workspace-wildcard (tableName='*') table-wildcard (tableName='X', !recordFilter) a single workspace):
record-fds (tableName='X', recordFilter.id=...) feature-wildcard (tableName='*') -> table-wildcard / record-fds
table-wildcard (tableName='X') record-fds (tableName='X', recordFilter.id=...) table-wildcard (tableName='X') -> record-fds (tableName='X')
""" """
parentWsId = _getRecordValue(parent, "workspaceInstanceId") parentFiId = _getRecordValue(parent, "featureInstanceId")
childWsId = _getRecordValue(child, "workspaceInstanceId") childFiId = _getRecordValue(child, "featureInstanceId")
if not parentWsId or parentWsId != childWsId: if not parentFiId or parentFiId != childFiId:
return False return False
if _getRecordValue(parent, "id") == _getRecordValue(child, "id"): if _getRecordValue(parent, "id") == _getRecordValue(child, "id"):
return False return False
@ -251,23 +370,68 @@ def _fdsIsAncestor(parent: Dict[str, Any], child: Dict[str, Any]) -> bool:
return False return False
def _fdsDepth(fds: Dict[str, Any]) -> int:
kind = _fdsClassify(fds)
if kind == "workspace":
return 0
if kind == "table":
return 1
return 2
# ---------------------------------------------------------------------------
# FeatureDataSource: getEffectiveFlagFds
# ---------------------------------------------------------------------------
def getEffectiveFlagFds( def getEffectiveFlagFds(
rec: Dict[str, Any], rec: Dict[str, Any],
flag: str, flag: str,
sameWorkspaceFds: Iterable[Dict[str, Any]], sameWorkspaceFds: Iterable[Dict[str, Any]],
mode: Mode = "walk",
) -> Any: ) -> Any:
"""Resolve effective value of a FeatureDataSource flag. """Resolve effective value of a FeatureDataSource flag.
Order: own (if explicit) table-wildcard (if explicit) mode='walk': own explicit -> table-wildcard -> workspace-wildcard -> default.
workspace-wildcard (if explicit) static default. mode='aggregate': same but returns 'mixed' if descendants diverge.
""" """
if flag not in ("neutralize", "scope"): if flag not in _INHERITABLE_FDS_FLAGS:
raise ValueError(f"Unknown inheritable FDS flag: {flag}") raise ValueError(f"Unknown inheritable FDS flag: {flag}")
allFds = list(sameWorkspaceFds)
walkValue = _resolveWalkValueFds(rec, flag, allFds)
if mode == "walk":
return walkValue
# mode == 'aggregate'
descendants = [f for f in allFds if _fdsIsAncestor(rec, f)]
if not descendants:
return walkValue
subtreeValues = set()
subtreeValues.add(_normaliseForComparison(walkValue))
for desc in descendants:
descEffective = _resolveWalkValueFds(desc, flag, allFds)
subtreeValues.add(_normaliseForComparison(descEffective))
if len(subtreeValues) > 1:
recId = _getRecordValue(rec, "id")
descId = _getRecordValue(desc, "id")
descOwnVal = _getRecordValue(desc, flag)
logger.info(
"FDS aggregate MIXED for rec=%s flag=%s: walkValue=%s, "
"divergent desc=%s (own=%s, effective=%s), subtreeValues=%s",
recId, flag, walkValue, descId, descOwnVal, descEffective, subtreeValues,
)
return "mixed"
return walkValue
def _resolveWalkValueFds(rec: Dict[str, Any], flag: str, allFds: List[Dict[str, Any]]) -> Any:
"""Core walk resolution for FDS."""
own = _getRecordValue(rec, flag) own = _getRecordValue(rec, flag)
if _isExplicit(own): if _isExplicit(own):
return own return own
workspaceFds: List[Dict[str, Any]] = list(sameWorkspaceFds) ancestors = [a for a in allFds if _fdsIsAncestor(a, rec)]
ancestors = [a for a in workspaceFds if _fdsIsAncestor(a, rec)]
ancestors.sort(key=lambda a: 0 if _fdsClassify(a) == "table" else 1) ancestors.sort(key=lambda a: 0 if _fdsClassify(a) == "table" else 1)
for ancestor in ancestors: for ancestor in ancestors:
val = _getRecordValue(ancestor, flag) val = _getRecordValue(ancestor, flag)
@ -276,27 +440,32 @@ def getEffectiveFlagFds(
return _flagDefault(flag) return _flagDefault(flag)
# ---------------------------------------------------------------------------
# FeatureDataSource: cascadeResetDescendantsFds (bottom-up)
# ---------------------------------------------------------------------------
def cascadeResetDescendantsFds( def cascadeResetDescendantsFds(
rootIf: Any, rootIf: Any,
parentRec: Dict[str, Any], parentRec: Dict[str, Any],
flag: str, flag: str,
) -> int: ) -> List[str]:
"""Reset explicit `flag` to NULL on every descendant FDS of `parentRec`. """Reset explicit `flag` to NULL on every descendant FDS of `parentRec`.
Only the targeted flag is reset; other flags on descendants are untouched. Reset order: bottom-up (deepest first) for crash safety.
Returns the number of records updated. Returns list of reset record IDs in bottom-up order.
""" """
if flag not in ("neutralize", "scope"): if flag not in _INHERITABLE_FDS_FLAGS:
raise ValueError(f"Unknown inheritable FDS flag: {flag}") raise ValueError(f"Unknown inheritable FDS flag: {flag}")
from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource
workspaceInstanceId = _getRecordValue(parentRec, "workspaceInstanceId") workspaceInstanceId = _getRecordValue(parentRec, "workspaceInstanceId")
if not workspaceInstanceId: if not workspaceInstanceId:
return 0 return []
siblings = rootIf.db.getRecordset( siblings = rootIf.db.getRecordset(
FeatureDataSource, recordFilter={"workspaceInstanceId": workspaceInstanceId} FeatureDataSource, recordFilter={"workspaceInstanceId": workspaceInstanceId}
) )
affected = 0
toReset: List[Tuple[int, str]] = []
for sib in siblings: for sib in siblings:
if not _fdsIsAncestor(parentRec, sib): if not _fdsIsAncestor(parentRec, sib):
continue continue
@ -304,39 +473,159 @@ def cascadeResetDescendantsFds(
if not _isExplicit(sibVal): if not _isExplicit(sibVal):
continue continue
sibId = _getRecordValue(sib, "id") sibId = _getRecordValue(sib, "id")
toReset.append((_fdsDepth(sib), sibId))
# Sort deepest first (bottom-up)
toReset.sort(key=lambda x: x[0], reverse=True)
resetIds: List[str] = []
for _, sibId in toReset:
try: try:
rootIf.db.recordModify(FeatureDataSource, sibId, {flag: None}) rootIf.db.recordModify(FeatureDataSource, sibId, {flag: None})
affected += 1 resetIds.append(sibId)
except Exception as exc: except Exception as exc:
logger.warning("FDS cascade-reset failed for %s flag=%s: %s", sibId, flag, exc) logger.warning("FDS cascade-reset failed for %s flag=%s: %s", sibId, flag, exc)
if affected:
if resetIds:
logger.info( logger.info(
"FDS cascade-reset %s on %d descendants of FDS (workspaceInstanceId=%s, kind=%s)", "FDS cascade-reset %s on %d descendants of FDS %s (bottom-up)",
flag, affected, workspaceInstanceId, _fdsClassify(parentRec), flag, len(resetIds), _getRecordValue(parentRec, "id"),
) )
return affected return resetIds
def buildEffectiveByConnection( # ---------------------------------------------------------------------------
dataSources: Iterable[Dict[str, Any]], # FeatureDataSource: collectAncestorChainFds
flag: str, # ---------------------------------------------------------------------------
) -> Dict[str, Any]:
"""Pre-compute the effective value of `flag` for every DataSource id.
Useful for batch operations (walker, route DTOs) that touch many records def collectAncestorChainFds(
at once. O() in the worst case but N is bounded per connection. rec: Dict[str, Any],
sameWorkspaceFds: Iterable[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""Return ancestor chain of `rec` FDS (nearest-first).
Exposed for PATCH endpoints to compute updatedAncestors.
""" """
if flag not in _INHERITABLE_FLAGS: allFds = list(sameWorkspaceFds)
raise ValueError(f"Unknown inheritable flag: {flag}") ancestors = [a for a in allFds if _fdsIsAncestor(a, rec)]
bySourceType: Dict[Tuple[str, str], List[Dict[str, Any]]] = {} ancestors.sort(key=lambda a: 0 if _fdsClassify(a) == "table" else 1)
for ds in dataSources: return ancestors
connId = _getRecordValue(ds, "connectionId") or ""
srcType = _getRecordValue(ds, "sourceType") or ""
bySourceType.setdefault((connId, srcType), []).append(ds)
# ---------------------------------------------------------------------------
# FeatureDataSource: buildEffectiveByWorkspaceFds
# ---------------------------------------------------------------------------
def buildEffectiveByWorkspaceFds(
fdses: Iterable[Dict[str, Any]],
flag: str,
mode: Mode = "walk",
) -> Dict[str, Any]:
"""Pre-compute the effective value of `flag` for every FDS id."""
if flag not in _INHERITABLE_FDS_FLAGS:
raise ValueError(f"Unknown inheritable FDS flag: {flag}")
allFds = list(fdses)
out: Dict[str, Any] = {} out: Dict[str, Any] = {}
for group in bySourceType.values(): for rec in allFds:
for rec in group:
recId = _getRecordValue(rec, "id") recId = _getRecordValue(rec, "id")
out[recId] = getEffectiveFlag(rec, flag, group) out[recId] = getEffectiveFlagFds(rec, flag, allFds, mode=mode)
return out return out
# ---------------------------------------------------------------------------
# Bulk resolve: effective flags for arbitrary paths (even without DB record)
# ---------------------------------------------------------------------------
def resolveEffectiveForPath(
connectionId: str,
sourceType: str,
path: str,
allDs: List[Dict[str, Any]],
mode: Mode = "aggregate",
) -> Dict[str, Any]:
"""Resolve effective flags for ANY (connectionId, sourceType, path) tuple.
Works whether or not a DataSource record exists for this exact path.
Returns dict with effectiveNeutralize, effectiveScope, effectiveRagIndexEnabled.
"""
normPath = _normalisePath(path)
exactRecord = None
for ds in allDs:
if (
_getRecordValue(ds, "connectionId") == connectionId
and _getRecordValue(ds, "sourceType") == sourceType
and _normalisePath(_getRecordValue(ds, "path")) == normPath
):
exactRecord = ds
break
if exactRecord:
return {
"effectiveNeutralize": getEffectiveFlag(exactRecord, "neutralize", allDs, mode=mode),
"effectiveScope": getEffectiveFlag(exactRecord, "scope", allDs, mode=mode),
"effectiveRagIndexEnabled": getEffectiveFlag(exactRecord, "ragIndexEnabled", allDs, mode=mode),
}
virtualRec = {
"id": "__virtual__",
"connectionId": connectionId,
"sourceType": sourceType,
"path": normPath,
"neutralize": None,
"scope": None,
"ragIndexEnabled": None,
}
return {
"effectiveNeutralize": _resolveWalkValue(virtualRec, "neutralize", allDs),
"effectiveScope": _resolveWalkValue(virtualRec, "scope", allDs),
"effectiveRagIndexEnabled": _resolveWalkValue(virtualRec, "ragIndexEnabled", allDs),
}
def resolveEffectiveForFds(
featureInstanceId: str,
tableName: str,
recordFilter: Optional[Dict[str, str]],
allFds: List[Dict[str, Any]],
mode: Mode = "aggregate",
) -> Dict[str, Any]:
"""Resolve effective flags for ANY FDS tuple (even without DB record).
`allFds` is pre-scoped to a single workspace (loaded with
workspaceInstanceId filter). Within that set, the coordinate is
featureInstanceId + tableName + recordFilter.
Returns dict with effectiveNeutralize, effectiveScope, effectiveRagIndexEnabled.
"""
exactRecord = None
for fds in allFds:
if _getRecordValue(fds, "featureInstanceId") != featureInstanceId:
continue
if (_getRecordValue(fds, "tableName") or "") != tableName:
continue
fdsFilter = _getRecordValue(fds, "recordFilter")
if fdsFilter == recordFilter:
exactRecord = fds
break
if exactRecord:
return {
"effectiveNeutralize": getEffectiveFlagFds(exactRecord, "neutralize", allFds, mode=mode),
"effectiveScope": getEffectiveFlagFds(exactRecord, "scope", allFds, mode=mode),
"effectiveRagIndexEnabled": getEffectiveFlagFds(exactRecord, "ragIndexEnabled", allFds, mode=mode),
}
virtualRec = {
"id": "__virtual__",
"featureInstanceId": featureInstanceId,
"tableName": tableName,
"recordFilter": recordFilter,
"neutralize": None,
"scope": None,
"ragIndexEnabled": None,
}
return {
"effectiveNeutralize": _resolveWalkValueFds(virtualRec, "neutralize", allFds),
"effectiveScope": _resolveWalkValueFds(virtualRec, "scope", allFds),
"effectiveRagIndexEnabled": _resolveWalkValueFds(virtualRec, "ragIndexEnabled", allFds),
}

View file

@ -147,7 +147,7 @@ class KnowledgeService:
else getattr(existing, "status", "") else getattr(existing, "status", "")
) or "" ) or ""
if existingMeta.get("hash") == contentHash and existingStatus == "indexed": if existingMeta.get("hash") == contentHash and existingStatus == "indexed":
logger.info( logger.debug(
"ingestion.skipped.duplicate sourceKind=%s sourceId=%s hash=%s", "ingestion.skipped.duplicate sourceKind=%s sourceId=%s hash=%s",
job.sourceKind, job.sourceId, contentHash[:12], job.sourceKind, job.sourceId, contentHash[:12],
extra={ extra={

View file

@ -431,6 +431,15 @@ def registerKnowledgeIngestionConsumer() -> None:
callbackRegistry.register("connection.established", _onConnectionEstablished) callbackRegistry.register("connection.established", _onConnectionEstablished)
callbackRegistry.register("connection.revoked", _onConnectionRevoked) callbackRegistry.register("connection.revoked", _onConnectionRevoked)
registerJobHandler(BOOTSTRAP_JOB_TYPE, _bootstrapJobHandler) registerJobHandler(BOOTSTRAP_JOB_TYPE, _bootstrapJobHandler)
from modules.serviceCenter.services.serviceKnowledge.subFeatureBootstrap import (
FEATURE_BOOTSTRAP_JOB_TYPE, _featureBootstrapHandler,
)
registerJobHandler(FEATURE_BOOTSTRAP_JOB_TYPE, _featureBootstrapHandler)
registerDailyResyncScheduler() registerDailyResyncScheduler()
_registered = True _registered = True
logger.info("KnowledgeIngestionConsumer registered (established/revoked + %s handler + daily resync)", BOOTSTRAP_JOB_TYPE) logger.info(
"KnowledgeIngestionConsumer registered (established/revoked + %s + %s handler + daily resync)",
BOOTSTRAP_JOB_TYPE, FEATURE_BOOTSTRAP_JOB_TYPE,
)

View file

@ -0,0 +1,289 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Feature-data RAG bootstrap: indexes FeatureDataSource rows into the knowledge store.
Analogous to connection.bootstrap for external connections (Google, Microsoft),
this handler reads FeatureDataSource records with ragIndexEnabled=True, queries
the underlying feature tables via FeatureDataProvider, serialises each row into
text, and feeds it through KnowledgeService.requestIngestion so the data
appears in ContentChunk embeddings for semantic RAG search.
Job type: ``feature.bootstrap``
Payload: ``{"workspaceInstanceId": "...", "featureDataSourceIds": [...] (optional)}``
"""
from __future__ import annotations
import json
import logging
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
FEATURE_BOOTSTRAP_JOB_TYPE = "feature.bootstrap"
def _loadRagEnabledFds(workspaceInstanceId: str, featureDataSourceIds: Optional[List[str]] = None):
"""Load FeatureDataSource rows whose effective ragIndexEnabled is True.
Returns dicts with resolved flags so downstream code can read them directly.
"""
from modules.interfaces.interfaceDbApp import getRootInterface
from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource
from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlagFds
rootIf = getRootInterface()
allFds = rootIf.db.getRecordset(
FeatureDataSource, recordFilter={"workspaceInstanceId": workspaceInstanceId}
)
resolved = []
for fds in allFds:
tblName = (fds.get("tableName") if isinstance(fds, dict) else getattr(fds, "tableName", "")) or ""
fCode = (fds.get("featureCode") if isinstance(fds, dict) else getattr(fds, "featureCode", "")) or ""
if tblName == "*" or not tblName or not fCode:
continue
effRag = getEffectiveFlagFds(fds, "ragIndexEnabled", allFds, mode="aggregate")
if effRag is not True:
continue
row = dict(fds) if isinstance(fds, dict) else {**fds.__dict__}
row["_effectiveNeutralize"] = getEffectiveFlagFds(fds, "neutralize", allFds, mode="aggregate")
row["_effectiveScope"] = getEffectiveFlagFds(fds, "scope", allFds, mode="aggregate") or "featureInstance"
row["ragIndexEnabled"] = True
resolved.append(row)
if featureDataSourceIds:
idSet = set(featureDataSourceIds)
resolved = [r for r in resolved if r.get("id") in idSet]
return resolved
def _serializeRowToText(row: Dict[str, Any], neutralizeFields: Optional[List[str]] = None) -> str:
"""Convert a feature-table row into readable text for embedding.
Skips internal fields (starting with '_' or 'sys') and produces
``key: value`` lines that embed well semantically.
"""
neutralizeSet = set(neutralizeFields or [])
lines = []
for key, value in row.items():
if key.startswith("_") or key.startswith("sys"):
continue
if key == "id":
continue
if value is None or value == "" or value == []:
continue
if key in neutralizeSet:
value = "[REDACTED]"
elif isinstance(value, (dict, list)):
value = json.dumps(value, ensure_ascii=False, default=str)
else:
value = str(value)
lines.append(f"{key}: {value}")
return "\n".join(lines)
def _getFeatureDbConnector(featureCode: str):
"""Create a lightweight DB connector to the feature database."""
from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG
dbName = f"poweron_{featureCode.lower()}"
return DatabaseConnector(
dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
dbDatabase=dbName,
dbUser=APP_CONFIG.get("DB_USER"),
dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"),
dbPort=int(APP_CONFIG.get("DB_PORT", 5432)),
userId="system.feature_bootstrap",
)
async def _featureBootstrapHandler(
job: Dict[str, Any],
progressCb,
) -> Dict[str, Any]:
"""Walk RAG-enabled FeatureDataSources and index their rows."""
payload = job.get("payload") or {}
workspaceInstanceId = payload.get("workspaceInstanceId")
featureDataSourceIds = payload.get("featureDataSourceIds")
if not workspaceInstanceId:
raise ValueError("feature.bootstrap requires payload.workspaceInstanceId")
progressCb(5, messageKey="Feature-Datenquellen werden geladen...")
fdsList = _loadRagEnabledFds(workspaceInstanceId, featureDataSourceIds)
if not fdsList:
logger.info(
"feature.bootstrap.skipped — no rag-enabled FDS for workspace %s",
workspaceInstanceId,
)
return {"workspaceInstanceId": workspaceInstanceId, "skipped": True, "reason": "no_rag_enabled_fds"}
from modules.serviceCenter.services.serviceAgent.featureDataProvider import FeatureDataProvider
from modules.serviceCenter.services.serviceKnowledge.mainServiceKnowledge import IngestionJob
from modules.serviceCenter.context import ServiceCenterContext
from modules.serviceCenter import getService
from modules.security.rootAccess import getRootUser
totalIndexed = 0
totalSkipped = 0
totalFailed = 0
fdsResults = []
for fdsIdx, fds in enumerate(fdsList):
fdsId = fds.get("id", "")
featureCode = fds.get("featureCode", "")
tableName = fds.get("tableName", "")
featureInstanceId = fds.get("featureInstanceId", "")
mandateId = fds.get("mandateId", "")
neutralizeFields = fds.get("neutralizeFields") or []
recordFilter = fds.get("recordFilter") or {}
effectiveScope = fds.get("_effectiveScope", "featureInstance")
effectiveNeutralize = bool(fds.get("_effectiveNeutralize", False))
progressPct = 5 + int(90 * fdsIdx / len(fdsList))
progressCb(
progressPct,
messageKey="Indexiere {table} ({n}/{total})...",
messageParams={"table": tableName, "n": fdsIdx + 1, "total": len(fdsList)},
)
if not featureCode or not tableName or not featureInstanceId:
logger.warning("feature.bootstrap: skipping FDS %s — missing featureCode/tableName/fiId", fdsId)
continue
try:
dbConnector = _getFeatureDbConnector(featureCode)
provider = FeatureDataProvider(dbConnector)
rootUser = getRootUser()
ctx = ServiceCenterContext(
user=rootUser,
mandate_id=mandateId,
feature_instance_id=workspaceInstanceId,
)
knowledgeService = getService("knowledge", ctx)
extraFilters = [
{"field": k, "op": "=", "value": v}
for k, v in recordFilter.items()
] if recordFilter else None
batchSize = 200
offset = 0
fdsIndexed = 0
fdsSkipped = 0
fdsFailed = 0
while True:
result = provider.browseTable(
tableName=tableName,
featureInstanceId=featureInstanceId,
mandateId=mandateId,
limit=batchSize,
offset=offset,
extraFilters=extraFilters,
)
rows = result.get("rows", [])
if not rows:
break
for row in rows:
rowId = row.get("id", "")
if not rowId:
continue
textContent = _serializeRowToText(row, neutralizeFields if effectiveNeutralize else None)
if not textContent.strip():
fdsSkipped += 1
continue
contentVersion = str(row.get("sysUpdatedAt") or row.get("sysCreatedAt") or "")
ingestionJob = IngestionJob(
sourceKind="feature_record",
sourceId=f"{workspaceInstanceId}:{tableName}:{rowId}",
fileName=f"{tableName}-{rowId}",
mimeType="application/vnd.poweron.feature-record+json",
userId=fds.get("userId") or "system",
featureInstanceId=workspaceInstanceId,
mandateId=mandateId,
contentObjects=[{
"contentType": "text",
"data": textContent,
"contextRef": {
"table": tableName,
"featureCode": featureCode,
"featureInstanceId": featureInstanceId,
"rowId": rowId,
},
"contentObjectId": f"{tableName}:{rowId}",
}],
structure={"sourceTable": tableName, "featureCode": featureCode},
contentVersion=contentVersion,
provenance={
"featureDataSourceId": fdsId,
"tableName": tableName,
"featureCode": featureCode,
"featureInstanceId": featureInstanceId,
},
neutralize=effectiveNeutralize,
)
try:
handle = await knowledgeService.requestIngestion(ingestionJob)
if handle.status == "failed":
fdsFailed += 1
logger.warning(
"feature.bootstrap: ingestion failed fds=%s table=%s row=%s error=%s",
fdsId, tableName, rowId, handle.error,
)
elif handle.status == "duplicate":
fdsSkipped += 1
else:
fdsIndexed += 1
except Exception as ingErr:
fdsFailed += 1
logger.error(
"feature.bootstrap: ingestion error fds=%s row=%s: %s",
fdsId, rowId, ingErr,
)
offset += batchSize
if len(rows) < batchSize:
break
totalIndexed += fdsIndexed
totalSkipped += fdsSkipped
totalFailed += fdsFailed
fdsResults.append({
"featureDataSourceId": fdsId,
"tableName": tableName,
"featureCode": featureCode,
"indexed": fdsIndexed,
"skippedDuplicate": fdsSkipped,
"failed": fdsFailed,
})
except Exception as fdsErr:
logger.error(
"feature.bootstrap: error processing FDS %s (%s.%s): %s",
fdsId, featureCode, tableName, fdsErr, exc_info=True,
)
fdsResults.append({
"featureDataSourceId": fdsId,
"tableName": tableName,
"featureCode": featureCode,
"error": str(fdsErr),
})
progressCb(100, messageKey="Feature-Daten-Sync abgeschlossen.")
return {
"workspaceInstanceId": workspaceInstanceId,
"indexed": totalIndexed,
"skippedDuplicate": totalSkipped,
"failed": totalFailed,
"dataSources": fdsResults,
}

View file

@ -1,32 +0,0 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""DEPRECATED: Use `_inheritFlags.getEffectiveFlag()` directly.
Thin shim to the new cascade-inherit helper. Kept so external callers don't
break on import internal walkers consume pre-resolved dicts via
`_loadRagEnabledDataSources`.
"""
from __future__ import annotations
from typing import Any, Dict, List
from modules.serviceCenter.services.serviceKnowledge._inheritFlags import getEffectiveFlag
def resolveEffectiveNeutralize(
ds: Dict[str, Any],
allDataSources: List[Dict[str, Any]],
) -> bool:
"""DEPRECATED: use `getEffectiveFlag(ds, 'neutralize', allDataSources)`."""
value = getEffectiveFlag(ds, "neutralize", allDataSources)
return bool(value)
def resolveEffectiveRagIndexEnabled(
ds: Dict[str, Any],
allDataSources: List[Dict[str, Any]],
) -> bool:
"""DEPRECATED: use `getEffectiveFlag(ds, 'ragIndexEnabled', allDataSources)`."""
value = getEffectiveFlag(ds, "ragIndexEnabled", allDataSources)
return bool(value)

View file

@ -15,8 +15,9 @@ up with "Job stuck at 10% for 10h" zombies.
These helpers wrap each phase in `asyncio.wait_for`. Sync extraction runs These helpers wrap each phase in `asyncio.wait_for`. Sync extraction runs
on a worker thread so the loop stays responsive. Every wrapped call also on a worker thread so the loop stays responsive. Every wrapped call also
emits a short start/done log line, so when something hangs we know the emits start/done log lines at DEBUG so normal INFO logs stay quiet; for
exact item that caused it (path, size, mime). stuck-job triage, enable DEBUG for this module the last
``walker.item.start`` before a hang still pinpoints the item (path, size, mime).
""" """
from __future__ import annotations from __future__ import annotations
@ -48,7 +49,7 @@ async def downloadWithTimeout(
used in log messages so we can pinpoint the offending item in case of a used in log messages so we can pinpoint the offending item in case of a
hang or timeout. hang or timeout.
""" """
logger.info("walker.download.start %s timeout=%ds", label, timeoutSeconds) logger.debug("walker.download.start %s timeout=%ds", label, timeoutSeconds)
try: try:
result = await asyncio.wait_for(awaitable, timeout=timeoutSeconds) result = await asyncio.wait_for(awaitable, timeout=timeoutSeconds)
logger.debug("walker.download.done %s", label) logger.debug("walker.download.done %s", label)
@ -71,7 +72,7 @@ async def extractWithTimeout(
keep running until the process exits but at least the walker proceeds keep running until the process exits but at least the walker proceeds
to the next item instead of freezing forever. to the next item instead of freezing forever.
""" """
logger.info("walker.extract.start %s timeout=%ds", label, timeoutSeconds) logger.debug("walker.extract.start %s timeout=%ds", label, timeoutSeconds)
try: try:
result = await asyncio.wait_for( result = await asyncio.wait_for(
asyncio.to_thread(syncFn, *args), asyncio.to_thread(syncFn, *args),
@ -102,15 +103,15 @@ async def ingestWithTimeout(
def logItemStart(service: str, label: str, *, sizeBytes: Optional[int] = None, mime: Optional[str] = None) -> None: def logItemStart(service: str, label: str, *, sizeBytes: Optional[int] = None, mime: Optional[str] = None) -> None:
"""Log that processing of one item is about to begin. """Log that processing of one item is about to begin (DEBUG).
When the worker hangs, the LAST `walker.item.start` line in the log When the worker hangs, the LAST `walker.item.start` line in the log
points to the exact item that caused the freeze. This is the single points to the exact item that caused the freeze. Enable DEBUG for this
most valuable diagnostic for stuck-job triage. module during triage.
""" """
parts = [f"walker.item.start service={service} path={label}"] parts = [f"walker.item.start service={service} path={label}"]
if sizeBytes is not None: if sizeBytes is not None:
parts.append(f"size={sizeBytes}") parts.append(f"size={sizeBytes}")
if mime: if mime:
parts.append(f"mime={mime}") parts.append(f"mime={mime}")
logger.info(" ".join(parts)) logger.debug(" ".join(parts))

View file

@ -0,0 +1,274 @@
#!/usr/bin/env python3
"""One-time migration: Reassign all DB references from an old user UID to a new UID.
When a user is re-created in PORTA (same username, new UUID), all existing records
still reference the old UUID. This script scans ALL registered databases and tables
for VARCHAR columns containing the old UID and updates them to the new UID.
Affected columns include:
- sysCreatedBy / sysModifiedBy (on every table via PowerOnModel)
- userId, revokedBy, createdByUserId, publishedBy, triggeredBy, assignedTo, etc.
The script auto-detects the new UID from the UserInDB table by username.
Usage:
# Dry-run (default) — shows what would change, no writes:
python scripts/script_migrate_user_uid.py --username patrick.helvetia --old-uid <OLD_UUID>
# Execute for real:
python scripts/script_migrate_user_uid.py --username patrick.helvetia --old-uid <OLD_UUID> --execute
"""
import argparse
import logging
import os
import sys
from pathlib import Path
from typing import List, Optional, Tuple
scriptPath = Path(__file__).resolve()
gatewayPath = scriptPath.parent.parent
sys.path.insert(0, str(gatewayPath))
os.chdir(str(gatewayPath))
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s", force=True)
logger = logging.getLogger(__name__)
import psycopg2
import psycopg2.extras
from modules.shared.configuration import APP_CONFIG
ALL_DATABASES = [
"poweron_app",
"poweron_chat",
"poweron_management",
"poweron_knowledge",
"poweron_billing",
"poweron_workspace",
"poweron_graphicaleditor",
"poweron_chatbot",
"poweron_trustee",
"poweron_commcoach",
"poweron_neutralization",
"poweron_realestate",
"poweron_teamsbot",
]
def _getConnection(dbName: str):
return psycopg2.connect(
host=APP_CONFIG.get("DB_HOST", "localhost"),
port=int(APP_CONFIG.get("DB_PORT", "5432")),
database=dbName,
user=APP_CONFIG.get("DB_USER"),
password=APP_CONFIG.get("DB_PASSWORD_SECRET"),
client_encoding="utf8",
)
def _getTablesInDb(conn) -> List[str]:
with conn.cursor() as cur:
cur.execute("""
SELECT table_name FROM information_schema.tables
WHERE table_schema = 'public'
AND table_type = 'BASE TABLE'
AND table_name NOT LIKE '\\_%%'
ORDER BY table_name
""")
return [row[0] for row in cur.fetchall()]
def _getVarcharColumns(conn, tableName: str) -> List[str]:
"""Get all VARCHAR/TEXT columns for a table (potential user-ID carriers)."""
with conn.cursor() as cur:
cur.execute("""
SELECT column_name FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = %s
AND data_type IN ('character varying', 'text')
ORDER BY ordinal_position
""", (tableName,))
return [row[0] for row in cur.fetchall()]
def _countMatches(conn, tableName: str, columnName: str, oldUid: str) -> int:
with conn.cursor() as cur:
cur.execute(
f'SELECT COUNT(*) FROM "{tableName}" WHERE "{columnName}" = %s',
(oldUid,),
)
return cur.fetchone()[0]
def _updateColumn(conn, tableName: str, columnName: str, oldUid: str, newUid: str) -> int:
with conn.cursor() as cur:
cur.execute(
f'UPDATE "{tableName}" SET "{columnName}" = %s WHERE "{columnName}" = %s',
(newUid, oldUid),
)
return cur.rowcount
def _lookupNewUid(username: str) -> Optional[str]:
"""Find the current UID for a username in poweron_app.UserInDB."""
conn = _getConnection("poweron_app")
try:
with conn.cursor() as cur:
cur.execute(
'SELECT "id" FROM "UserInDB" WHERE "username" = %s',
(username,),
)
row = cur.fetchone()
return row[0] if row else None
finally:
conn.close()
def _scanJsonbForUid(conn, tableName: str, columnName: str, oldUid: str) -> int:
"""Count JSONB fields that contain the old UID as a text value anywhere."""
with conn.cursor() as cur:
cur.execute(
f"""SELECT COUNT(*) FROM "{tableName}"
WHERE "{columnName}"::text LIKE %s""",
(f"%{oldUid}%",),
)
return cur.fetchone()[0]
def _updateJsonbColumn(conn, tableName: str, columnName: str, oldUid: str, newUid: str) -> int:
"""Replace old UID inside JSONB columns using text replacement."""
with conn.cursor() as cur:
cur.execute(
f"""UPDATE "{tableName}"
SET "{columnName}" = REPLACE("{columnName}"::text, %s, %s)::jsonb
WHERE "{columnName}"::text LIKE %s""",
(oldUid, newUid, f"%{oldUid}%"),
)
return cur.rowcount
def _getJsonbColumns(conn, tableName: str) -> List[str]:
"""Get all JSONB columns for a table."""
with conn.cursor() as cur:
cur.execute("""
SELECT column_name FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = %s
AND data_type = 'jsonb'
ORDER BY ordinal_position
""", (tableName,))
return [row[0] for row in cur.fetchall()]
def migrate(username: str, oldUid: str, execute: bool = False):
newUid = _lookupNewUid(username)
if not newUid:
logger.error(f"User '{username}' not found in UserInDB. Cannot determine new UID.")
sys.exit(1)
if newUid == oldUid:
logger.error(f"Old UID and new UID are identical ({oldUid}). Nothing to migrate.")
sys.exit(1)
logger.info(f"Migration: user '{username}'")
logger.info(f" Old UID: {oldUid}")
logger.info(f" New UID: {newUid}")
logger.info(f" Mode: {'EXECUTE' if execute else 'DRY-RUN'}")
logger.info("")
totalUpdated = 0
findings: List[Tuple[str, str, str, int]] = []
for dbName in ALL_DATABASES:
try:
conn = _getConnection(dbName)
except Exception as e:
logger.warning(f" Cannot connect to {dbName}: {e}")
continue
try:
conn.autocommit = False
tables = _getTablesInDb(conn)
for tableName in tables:
varcharCols = _getVarcharColumns(conn, tableName)
for col in varcharCols:
count = _countMatches(conn, tableName, col, oldUid)
if count > 0:
findings.append((dbName, tableName, col, count))
if execute:
updated = _updateColumn(conn, tableName, col, oldUid, newUid)
totalUpdated += updated
logger.info(f" [UPDATED] {dbName}.{tableName}.{col}: {updated} rows")
else:
logger.info(f" [DRY-RUN] {dbName}.{tableName}.{col}: {count} rows would be updated")
jsonbCols = _getJsonbColumns(conn, tableName)
for col in jsonbCols:
count = _scanJsonbForUid(conn, tableName, col, oldUid)
if count > 0:
findings.append((dbName, tableName, f"{col} (JSONB)", count))
if execute:
_updateJsonbColumn(conn, tableName, col, oldUid, newUid)
totalUpdated += count
logger.info(f" [UPDATED] {dbName}.{tableName}.{col} (JSONB): {count} rows")
else:
logger.info(f" [DRY-RUN] {dbName}.{tableName}.{col} (JSONB): {count} rows would be updated")
if execute:
conn.commit()
else:
conn.rollback()
except Exception as e:
conn.rollback()
logger.error(f" Error processing {dbName}: {e}")
finally:
conn.close()
logger.info("")
logger.info("=" * 70)
logger.info("SUMMARY")
logger.info("=" * 70)
if not findings:
logger.info(" No references to old UID found in any database.")
else:
logger.info(f" Found {len(findings)} column(s) with references to old UID:")
for dbName, tableName, col, count in findings:
logger.info(f" {dbName}.{tableName}.{col}: {count} rows")
logger.info("")
if execute:
logger.info(f" Total rows updated: {totalUpdated}")
else:
logger.info(f" Total rows that would be updated: {sum(c for _, _, _, c in findings)}")
logger.info("")
logger.info(" To apply changes, re-run with --execute")
def main():
parser = argparse.ArgumentParser(
description="Migrate all DB references from old user UID to new UID."
)
parser.add_argument(
"--username",
required=True,
help="Username to migrate (e.g. 'patrick.helvetia'). Used to look up the new UID.",
)
parser.add_argument(
"--old-uid",
required=True,
help="The old UUID that is orphaned in the database.",
)
parser.add_argument(
"--execute",
action="store_true",
default=False,
help="Actually perform the migration. Without this flag, only a dry-run is done.",
)
args = parser.parse_args()
migrate(username=args.username, oldUid=args.old_uid, execute=args.execute)
if __name__ == "__main__":
main()

View file

@ -30,6 +30,7 @@ import psycopg2.errors
from modules.connectors.connectorDbPostgre import ( from modules.connectors.connectorDbPostgre import (
DatabaseConnector, DatabaseConnector,
DatabaseQueryError, DatabaseQueryError,
_stripNulBytesFromStr,
) )
@ -164,3 +165,12 @@ class TestGetRecordFailLoud:
assert excinfo.value.table == "DummyTable" assert excinfo.value.table == "DummyTable"
conn.rollback.assert_called_once() conn.rollback.assert_called_once()
class TestStripNulBytesFromStr:
def test_removesNul(self):
assert _stripNulBytesFromStr("a\x00b") == "ab"
def test_passthroughNonStr(self):
assert _stripNulBytesFromStr(None) is None
assert _stripNulBytesFromStr(7) == 7

View file

@ -0,0 +1,359 @@
"""Unit tests for the generic UDB tree builder.
Verifies key encoding/decoding and that children for parent keys with
existing handlers (top-level, conn, mgrp, feat) are produced with the
correct effective-flag triplet.
"""
from __future__ import annotations
import asyncio
import unittest
from unittest.mock import MagicMock, patch
from modules.serviceCenter.services.serviceKnowledge import _buildTree
class TestKeyCoding(unittest.TestCase):
def test_encode_decode_roundtrip(self):
key = _buildTree._encode("ds", "conn-1", "sharepointFolder", "/sites/x")
kind, parts = _buildTree._decode(key)
self.assertEqual(kind, "ds")
self.assertEqual(parts, ["conn-1", "sharepointFolder", "/sites/x"])
def test_top_level_kinds(self):
self.assertEqual(_buildTree._decode("conn|abc")[0], "conn")
self.assertEqual(_buildTree._decode("mgrp|m1")[0], "mgrp")
self.assertEqual(_buildTree._decode("feat|m1|trustee|fi-1")[1], ["m1", "trustee", "fi-1"])
class TestEffectiveTriplets(unittest.TestCase):
def test_ds_triplet_no_record_returns_defaults(self):
result = _buildTree._effectiveTripletDs("c", "msft", "/", [])
self.assertEqual(result, {
"effectiveNeutralize": False,
"effectiveScope": "personal",
"effectiveRagIndexEnabled": False,
})
def test_ds_triplet_inherits_from_root(self):
root = {
"id": "r", "connectionId": "c", "sourceType": "msft", "path": "/",
"neutralize": True, "scope": "mandate", "ragIndexEnabled": True,
}
result = _buildTree._effectiveTripletDs("c", "sharepointFolder", "/sites/x", [root])
self.assertEqual(result["effectiveNeutralize"], True)
self.assertEqual(result["effectiveScope"], "mandate")
self.assertEqual(result["effectiveRagIndexEnabled"], True)
def test_fds_triplet_inherits_from_workspace_wildcard(self):
ws = {
"id": "ws", "workspaceInstanceId": "ws-inst", "featureInstanceId": "fi1",
"tableName": "*", "recordFilter": None, "neutralize": True,
"scope": "mandate", "ragIndexEnabled": True,
}
result = _buildTree._effectiveTripletFds("fi1", "Pos", None, [ws])
self.assertEqual(result["effectiveNeutralize"], True)
self.assertEqual(result["effectiveScope"], "mandate")
self.assertEqual(result["effectiveRagIndexEnabled"], True)
class TestRecordLookup(unittest.TestCase):
def test_finds_ds_record_by_normalised_path(self):
rec = {"id": "x", "connectionId": "c", "sourceType": "msft", "path": "/folder"}
self.assertEqual(_buildTree._findDsRecord([rec], "c", "msft", "/folder/").get("id"), "x")
self.assertIsNone(_buildTree._findDsRecord([rec], "c", "msft", "/other"))
def test_finds_fds_record_with_matching_filter(self):
rec = {"id": "f", "workspaceInstanceId": "ws", "featureInstanceId": "fi1", "tableName": "Pos", "recordFilter": {"id": "5"}}
self.assertEqual(_buildTree._findFdsRecord([rec], "fi1", "Pos", {"id": "5"}).get("id"), "f")
self.assertIsNone(_buildTree._findFdsRecord([rec], "fi1", "Pos", {"id": "99"}))
def test_fds_record_with_none_filter_matches_only_none(self):
rec = {"id": "f", "workspaceInstanceId": "ws", "featureInstanceId": "fi1", "tableName": "*", "recordFilter": None}
self.assertEqual(_buildTree._findFdsRecord([rec], "fi1", "*", None).get("id"), "f")
self.assertIsNone(_buildTree._findFdsRecord([rec], "fi1", "*", {"id": "1"}))
class TestGetChildrenForParents(unittest.TestCase):
"""End-to-end orchestrator test with mocked dependencies."""
def _runAsync(self, coro):
return asyncio.get_event_loop().run_until_complete(coro)
def test_unknown_parent_key_returns_empty_list(self):
with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot:
rootIf = MagicMock()
rootIf.db.getRecordset.return_value = []
mockRoot.return_value = rootIf
ctx = MagicMock()
ctx.user.id = "u1"
ctx.mandateId = "m1"
result = self._runAsync(
_buildTree.getChildrenForParents("inst-1", ["bogus|key"], ctx)
)
self.assertEqual(result["bogus|key"], [])
def test_top_level_emits_personal_root_first(self):
"""Top-level emits personalRoot first, then mandate-group nodes inline."""
with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot:
rootIf = MagicMock()
rootIf.db.getRecordset.return_value = []
rootIf.getUserMandates.return_value = []
mockRoot.return_value = rootIf
ctx = MagicMock()
ctx.user.id = "u1"
ctx.mandateId = "m1"
result = self._runAsync(
_buildTree.getChildrenForParents("inst-1", [None], ctx)
)
children = result["__root__"]
self.assertGreaterEqual(len(children), 1)
personalRoot = children[0]
self.assertEqual(personalRoot["key"], "personalRoot")
self.assertEqual(personalRoot["kind"], "synthRoot")
self.assertIsNone(personalRoot["parentKey"])
self.assertTrue(personalRoot["hasChildren"])
self.assertTrue(personalRoot["defaultExpanded"])
class TestTopLevelLayout(unittest.TestCase):
"""Tests for the flat top-level layout (personalRoot + mandate groups)."""
def _runAsync(self, coro):
return asyncio.get_event_loop().run_until_complete(coro)
def test_personal_root_carries_neutral_default_triplet(self):
with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot:
rootIf = MagicMock()
rootIf.db.getRecordset.return_value = []
rootIf.getUserMandates.return_value = []
mockRoot.return_value = rootIf
ctx = MagicMock()
ctx.user.id = "u1"
ctx.mandateId = "m1"
result = self._runAsync(
_buildTree.getChildrenForParents("inst-1", [None], ctx)
)
personalRoot = result["__root__"][0]
self.assertFalse(personalRoot["effectiveNeutralize"])
self.assertEqual(personalRoot["effectiveScope"], "personal")
self.assertFalse(personalRoot["effectiveRagIndexEnabled"])
self.assertFalse(personalRoot["supportsRag"])
self.assertFalse(personalRoot["canBeAdded"])
self.assertIsNone(personalRoot["dataSourceId"])
self.assertIsNone(personalRoot["modelType"])
def test_personal_root_emits_active_connection_with_correct_parent(self):
with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot, \
patch("modules.serviceCenter.getService") as mockGetService:
rootIf = MagicMock()
rootIf.db.getRecordset.return_value = []
mockRoot.return_value = rootIf
chatService = MagicMock()
chatService.getUserConnections.return_value = [{
"id": "conn-1",
"status": "active",
"authority": "msft",
"externalEmail": "user@example.com",
}]
mockGetService.return_value = chatService
ctx = MagicMock()
ctx.user.id = "u1"
ctx.mandateId = "m1"
result = self._runAsync(
_buildTree.getChildrenForParents("inst-1", ["personalRoot"], ctx)
)
children = result["personalRoot"]
self.assertEqual(len(children), 1)
self.assertEqual(children[0]["key"], "conn|conn-1")
self.assertEqual(children[0]["kind"], "connection")
self.assertEqual(children[0]["parentKey"], "personalRoot")
self.assertEqual(children[0]["label"], "user@example.com")
self.assertTrue(children[0]["supportsRag"])
def test_personal_root_skips_inactive_connection(self):
with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot, \
patch("modules.serviceCenter.getService") as mockGetService:
rootIf = MagicMock()
rootIf.db.getRecordset.return_value = []
mockRoot.return_value = rootIf
chatService = MagicMock()
chatService.getUserConnections.return_value = [
{"id": "c1", "status": "active", "authority": "msft", "externalEmail": "a"},
{"id": "c2", "status": "expired", "authority": "google", "externalEmail": "b"},
]
mockGetService.return_value = chatService
ctx = MagicMock()
ctx.user.id = "u1"
ctx.mandateId = "m1"
result = self._runAsync(
_buildTree.getChildrenForParents("inst-1", ["personalRoot"], ctx)
)
self.assertEqual(len(result["personalRoot"]), 1)
self.assertEqual(result["personalRoot"][0]["connectionId"], "c1")
def test_mandate_groups_emitted_inline_at_top_level(self):
with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot, \
patch("modules.security.rbacCatalog.getCatalogService") as mockCatalog:
rootIf = MagicMock()
rootIf.db.getRecordset.return_value = []
userMandate = MagicMock()
userMandate.mandateId = "m1"
rootIf.getUserMandates.return_value = [userMandate]
featureInst = MagicMock()
featureInst.id = "fi-1"
featureInst.featureCode = "trustee"
featureInst.enabled = True
rootIf.getFeatureInstancesByMandate.return_value = [featureInst]
featureAccess = MagicMock()
featureAccess.enabled = True
rootIf.getFeatureAccess.return_value = featureAccess
mockRoot.return_value = rootIf
catalog = MagicMock()
catalog.getFeaturesWithDataObjects.return_value = ["trustee"]
mockCatalog.return_value = catalog
ctx = MagicMock()
ctx.user.id = "u1"
ctx.mandateId = None
result = self._runAsync(
_buildTree.getChildrenForParents("inst-1", [None], ctx)
)
children = result["__root__"]
byKey = {c["key"]: c for c in children}
self.assertIn("personalRoot", byKey)
self.assertIn("mgrp|m1", byKey)
mgroup = byKey["mgrp|m1"]
self.assertEqual(mgroup["kind"], "mandateGroup")
self.assertIsNone(mgroup["parentKey"])
self.assertEqual(mgroup["mandateId"], "m1")
self.assertTrue(mgroup["defaultExpanded"])
self.assertFalse(mgroup["supportsRag"])
def test_top_level_omits_mandates_without_data_features(self):
with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot, \
patch("modules.security.rbacCatalog.getCatalogService") as mockCatalog:
rootIf = MagicMock()
rootIf.db.getRecordset.return_value = []
userMandate = MagicMock()
userMandate.mandateId = "m1"
rootIf.getUserMandates.return_value = [userMandate]
rootIf.getFeatureInstancesByMandate.return_value = []
mockRoot.return_value = rootIf
catalog = MagicMock()
catalog.getFeaturesWithDataObjects.return_value = ["trustee"]
mockCatalog.return_value = catalog
ctx = MagicMock()
ctx.user.id = "u1"
ctx.mandateId = None
result = self._runAsync(
_buildTree.getChildrenForParents("inst-1", [None], ctx)
)
keys = [c["key"] for c in result["__root__"]]
self.assertEqual(keys, ["personalRoot"])
def test_personal_root_listed_first_via_display_order(self):
with patch("modules.interfaces.interfaceDbApp.getRootInterface") as mockRoot, \
patch("modules.security.rbacCatalog.getCatalogService") as mockCatalog:
rootIf = MagicMock()
rootIf.db.getRecordset.return_value = []
userMandate = MagicMock()
userMandate.mandateId = "m1"
rootIf.getUserMandates.return_value = [userMandate]
featureInst = MagicMock()
featureInst.id = "fi-1"
featureInst.featureCode = "trustee"
featureInst.enabled = True
rootIf.getFeatureInstancesByMandate.return_value = [featureInst]
featureAccess = MagicMock()
featureAccess.enabled = True
rootIf.getFeatureAccess.return_value = featureAccess
mockRoot.return_value = rootIf
catalog = MagicMock()
catalog.getFeaturesWithDataObjects.return_value = ["trustee"]
mockCatalog.return_value = catalog
ctx = MagicMock()
ctx.user.id = "u1"
ctx.mandateId = None
result = self._runAsync(
_buildTree.getChildrenForParents("inst-1", [None], ctx)
)
children = result["__root__"]
self.assertEqual(children[0]["key"], "personalRoot")
self.assertEqual(children[0]["displayOrder"], 0)
class TestFeatureTableFields(unittest.TestCase):
"""Per-column field expansion under a feature data-source table."""
def test_emits_one_node_per_field(self):
nodes = _buildTree._featureTableFields(
parentKey="fdstbl|fi-1|TrusteePosition",
featureInstanceId="fi-1",
tableName="TrusteePosition",
fieldNames=["id", "valuta", "company"],
allFds=[],
)
self.assertEqual(len(nodes), 3)
self.assertEqual(nodes[0]["kind"], "fdsField")
self.assertEqual(nodes[0]["fieldName"], "id")
self.assertEqual(nodes[0]["parentKey"], "fdstbl|fi-1|TrusteePosition")
self.assertEqual(nodes[0]["key"], "fdsfld|fi-1|TrusteePosition|id")
self.assertFalse(nodes[0]["hasChildren"])
self.assertFalse(nodes[0]["supportsRag"])
def test_field_neutralize_inherits_from_table_blanket(self):
rec = {"id": "f", "workspaceInstanceId": "ws-1", "featureInstanceId": "fi-1",
"tableName": "TrusteePosition", "recordFilter": None,
"neutralize": True, "neutralizeFields": None,
"scope": None, "ragIndexEnabled": False}
nodes = _buildTree._featureTableFields(
parentKey="fdstbl|fi-1|TrusteePosition",
featureInstanceId="fi-1",
tableName="TrusteePosition",
fieldNames=["email", "company"],
allFds=[rec],
)
self.assertTrue(nodes[0]["effectiveNeutralize"])
self.assertTrue(nodes[1]["effectiveNeutralize"])
def test_field_neutralize_explicit_via_neutralize_fields(self):
rec = {"id": "f", "workspaceInstanceId": "ws-1", "featureInstanceId": "fi-1",
"tableName": "TrusteePosition", "recordFilter": None,
"neutralize": False, "neutralizeFields": ["email"],
"scope": None, "ragIndexEnabled": False}
nodes = _buildTree._featureTableFields(
parentKey="fdstbl|fi-1|TrusteePosition",
featureInstanceId="fi-1",
tableName="TrusteePosition",
fieldNames=["email", "company"],
allFds=[rec],
)
byField = {n["fieldName"]: n for n in nodes}
self.assertTrue(byField["email"]["effectiveNeutralize"])
self.assertFalse(byField["company"]["effectiveNeutralize"])
if __name__ == "__main__":
unittest.main()

View file

@ -1,12 +1,12 @@
"""Unit tests for `_inheritFlags` cascade-inherit helpers. """Unit tests for `_inheritFlags` cascade-inherit helpers.
Verifies: Verifies:
- getEffectiveFlag walks ancestors via path-prefix matching - getEffectiveFlag mode='walk': walks ancestors via path-prefix matching
- root default is False (or 'personal' for scope) when nothing explicit in chain - getEffectiveFlag mode='aggregate': returns 'mixed' when subtree diverges
- only same-connectionId AND same-sourceType ancestors are considered - cascadeResetDescendants: bottom-up reset returning List[str]
- cascadeResetDescendants only touches descendants with explicit values for THAT flag - cascadeResetDescendantsFds: same for FeatureDataSource
- '/' is treated as ancestor of every non-root path - collectAncestorChain / collectAncestorChainFds: ancestor discovery
- '/foo' is NOT ancestor of '/foobar' (must require '/' separator) - buildEffectiveByConnection / buildEffectiveByWorkspaceFds: batch compute
""" """
from __future__ import annotations from __future__ import annotations
@ -33,7 +33,26 @@ def _ds(idVal: str, path: str, **flags) -> dict:
return base return base
class TestEffectiveFlag(unittest.TestCase): def _fds(idVal: str, *, tableName: str, recordFilter=None, featureInstanceId="fi-1", **flags) -> dict:
"""Build a FeatureDataSource dict fixture."""
base = {
"id": idVal,
"workspaceInstanceId": "ws-1",
"featureInstanceId": featureInstanceId,
"tableName": tableName,
"recordFilter": recordFilter,
"neutralize": None,
"scope": None,
}
base.update(flags)
return base
# ===========================================================================
# DataSource: getEffectiveFlag mode='walk'
# ===========================================================================
class TestEffectiveFlagWalk(unittest.TestCase):
def test_explicit_own_value_wins(self): def test_explicit_own_value_wins(self):
root = _ds("r", "/", neutralize=False) root = _ds("r", "/", neutralize=False)
leaf = _ds("l", "/folder/sub", neutralize=True) leaf = _ds("l", "/folder/sub", neutralize=True)
@ -65,7 +84,6 @@ class TestEffectiveFlag(unittest.TestCase):
self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [otherType, leaf])) self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [otherType, leaf]))
def test_path_separator_required(self): def test_path_separator_required(self):
"""`/foo` must NOT be ancestor of `/foobar` (no shared `/` boundary)."""
notAncestor = _ds("a", "/foo", neutralize=True) notAncestor = _ds("a", "/foo", neutralize=True)
leaf = _ds("l", "/foobar") leaf = _ds("l", "/foobar")
self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [notAncestor, leaf])) self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [notAncestor, leaf]))
@ -90,32 +108,101 @@ class TestEffectiveFlag(unittest.TestCase):
_inheritFlags.getEffectiveFlag(leaf, "unknownFlag", [leaf]) _inheritFlags.getEffectiveFlag(leaf, "unknownFlag", [leaf])
def test_explicit_false_overrides_inherited_true(self): def test_explicit_false_overrides_inherited_true(self):
"""Explicit False on a child must NOT cascade up to True from an ancestor."""
root = _ds("r", "/", neutralize=True) root = _ds("r", "/", neutralize=True)
leaf = _ds("l", "/folder", neutralize=False) leaf = _ds("l", "/folder", neutralize=False)
self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [root, leaf])) self.assertFalse(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [root, leaf]))
def test_connection_root_inherits_cross_sourcetype(self): def test_connection_root_inherits_cross_sourcetype(self):
"""Connection-root (sourceType=authority, path='/') is ancestor of all DS in that connection."""
connRoot = _ds("conn", "/", sourceType="msft", neutralize=True) connRoot = _ds("conn", "/", sourceType="msft", neutralize=True)
spService = _ds("sp", "/", sourceType="sharepointFolder") spService = _ds("sp", "/", sourceType="sharepointFolder")
olService = _ds("ol", "/", sourceType="outlookFolder") olService = _ds("ol", "/", sourceType="outlookFolder")
self.assertTrue(_inheritFlags.getEffectiveFlag(spService, "neutralize", [connRoot, spService, olService])) allDs = [connRoot, spService, olService]
self.assertTrue(_inheritFlags.getEffectiveFlag(olService, "neutralize", [connRoot, spService, olService])) self.assertTrue(_inheritFlags.getEffectiveFlag(spService, "neutralize", allDs))
self.assertTrue(_inheritFlags.getEffectiveFlag(olService, "neutralize", allDs))
def test_same_sourcetype_ancestor_wins_over_connection_root(self): def test_same_sourcetype_ancestor_wins_over_connection_root(self):
"""A same-sourceType service-root ancestor beats the connection-root."""
connRoot = _ds("conn", "/", sourceType="msft", neutralize=True) connRoot = _ds("conn", "/", sourceType="msft", neutralize=True)
spRoot = _ds("sp", "/", sourceType="sharepointFolder", neutralize=False) spRoot = _ds("sp", "/", sourceType="sharepointFolder", neutralize=False)
spLeaf = _ds("spl", "/sites/x", sourceType="sharepointFolder") spLeaf = _ds("spl", "/sites/x", sourceType="sharepointFolder")
self.assertFalse(_inheritFlags.getEffectiveFlag(spLeaf, "neutralize", [connRoot, spRoot, spLeaf])) self.assertFalse(_inheritFlags.getEffectiveFlag(spLeaf, "neutralize", [connRoot, spRoot, spLeaf]))
def test_connection_root_does_not_self_inherit(self): def test_connection_root_does_not_self_inherit(self):
"""Connection-root has no ancestor — does not infinite-loop on itself."""
connRoot = _ds("conn", "/", sourceType="msft") connRoot = _ds("conn", "/", sourceType="msft")
self.assertFalse(_inheritFlags.getEffectiveFlag(connRoot, "neutralize", [connRoot])) self.assertFalse(_inheritFlags.getEffectiveFlag(connRoot, "neutralize", [connRoot]))
# ===========================================================================
# DataSource: getEffectiveFlag mode='aggregate'
# ===========================================================================
class TestEffectiveFlagAggregate(unittest.TestCase):
def test_leaf_without_descendants_returns_concrete(self):
leaf = _ds("l", "/folder", neutralize=True)
self.assertTrue(_inheritFlags.getEffectiveFlag(leaf, "neutralize", [leaf], mode="aggregate"))
def test_all_descendants_same_returns_concrete(self):
root = _ds("r", "/", neutralize=True)
child1 = _ds("c1", "/a", neutralize=True)
child2 = _ds("c2", "/b") # inherits True from root
allDs = [root, child1, child2]
self.assertTrue(_inheritFlags.getEffectiveFlag(root, "neutralize", allDs, mode="aggregate"))
def test_divergent_descendants_returns_mixed(self):
root = _ds("r", "/", neutralize=True)
child1 = _ds("c1", "/a", neutralize=False)
child2 = _ds("c2", "/b") # inherits True from root
allDs = [root, child1, child2]
self.assertEqual(_inheritFlags.getEffectiveFlag(root, "neutralize", allDs, mode="aggregate"), "mixed")
def test_mixed_scope(self):
root = _ds("r", "/", scope="personal")
child1 = _ds("c1", "/a", scope="team")
child2 = _ds("c2", "/b") # inherits personal from root
allDs = [root, child1, child2]
self.assertEqual(_inheritFlags.getEffectiveFlag(root, "scope", allDs, mode="aggregate"), "mixed")
def test_all_scope_same_explicit_returns_concrete(self):
root = _ds("r", "/", scope="team")
child1 = _ds("c1", "/a", scope="team")
child2 = _ds("c2", "/b") # inherits team
allDs = [root, child1, child2]
self.assertEqual(_inheritFlags.getEffectiveFlag(root, "scope", allDs, mode="aggregate"), "team")
def test_connection_root_aggregate_cross_sourcetype(self):
connRoot = _ds("conn", "/", sourceType="msft", neutralize=True)
spExplicit = _ds("sp", "/", sourceType="sharepointFolder", neutralize=False)
olInherit = _ds("ol", "/", sourceType="outlookFolder") # inherits True
allDs = [connRoot, spExplicit, olInherit]
self.assertEqual(
_inheritFlags.getEffectiveFlag(connRoot, "neutralize", allDs, mode="aggregate"),
"mixed",
)
def test_mid_level_aggregate_only_considers_own_subtree(self):
root = _ds("r", "/", neutralize=True)
mid = _ds("m", "/folder", neutralize=True)
midChild = _ds("mc", "/folder/sub", neutralize=True)
sibling = _ds("s", "/other", neutralize=False) # not under mid
allDs = [root, mid, midChild, sibling]
# mid's subtree is just midChild(True) + mid(True) = uniform
self.assertTrue(_inheritFlags.getEffectiveFlag(mid, "neutralize", allDs, mode="aggregate"))
# root's subtree includes sibling(False) = mixed
self.assertEqual(
_inheritFlags.getEffectiveFlag(root, "neutralize", allDs, mode="aggregate"),
"mixed",
)
def test_walk_mode_never_returns_mixed(self):
root = _ds("r", "/", neutralize=True)
child = _ds("c", "/a", neutralize=False)
allDs = [root, child]
self.assertTrue(_inheritFlags.getEffectiveFlag(root, "neutralize", allDs, mode="walk"))
# ===========================================================================
# DataSource: cascadeResetDescendants (bottom-up, List[str])
# ===========================================================================
class TestCascadeReset(unittest.TestCase): class TestCascadeReset(unittest.TestCase):
def _makeRootIf(self, dataSources: List[dict]): def _makeRootIf(self, dataSources: List[dict]):
rootIf = MagicMock() rootIf = MagicMock()
@ -127,54 +214,76 @@ class TestCascadeReset(unittest.TestCase):
rootIf.db.recordModify = MagicMock(side_effect=_modify) rootIf.db.recordModify = MagicMock(side_effect=_modify)
return rootIf, modified return rootIf, modified
def test_returns_list_of_ids(self):
parent = _ds("p", "/sites", neutralize=True)
child = _ds("c1", "/sites/folder1", neutralize=False)
rootIf, _ = self._makeRootIf([parent, child])
result = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize")
self.assertIsInstance(result, list)
self.assertEqual(result, ["c1"])
def test_resets_only_explicit_descendants(self): def test_resets_only_explicit_descendants(self):
parent = _ds("p", "/sites", neutralize=True) parent = _ds("p", "/sites", neutralize=True)
explicitChild = _ds("c1", "/sites/folder1", neutralize=False) explicitChild = _ds("c1", "/sites/folder1", neutralize=False)
inheritChild = _ds("c2", "/sites/folder2") # inherit -> not touched inheritChild = _ds("c2", "/sites/folder2")
sibling = _ds("s", "/other", neutralize=True) # NOT a descendant sibling = _ds("s", "/other", neutralize=True)
rootIf, modified = self._makeRootIf([parent, explicitChild, inheritChild, sibling]) rootIf, modified = self._makeRootIf([parent, explicitChild, inheritChild, sibling])
affected = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize") result = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize")
self.assertEqual(affected, 1) self.assertEqual(result, ["c1"])
self.assertEqual(modified, [("c1", {"neutralize": None})]) self.assertEqual(modified, [("c1", {"neutralize": None})])
def test_does_not_touch_other_flags(self): def test_bottom_up_order(self):
parent = _ds("p", "/sites", neutralize=True) """Deepest items are reset first."""
child = _ds("c", "/sites/sub", neutralize=False, ragIndexEnabled=True) parent = _ds("p", "/", neutralize=True)
level1 = _ds("l1", "/a", neutralize=False)
level2 = _ds("l2", "/a/b", neutralize=False)
level3 = _ds("l3", "/a/b/c", neutralize=False)
rootIf, modified = self._makeRootIf([parent, level1, level2, level3])
result = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize")
self.assertEqual(result, ["l3", "l2", "l1"])
def test_deep_cascade_through_null_items(self):
"""null items are skipped (no DB write) but cascade continues deeper."""
parent = _ds("p", "/", neutralize=True)
nullChild = _ds("n", "/a") # null — no write, but not a barrier
deepExplicit = _ds("d", "/a/b", neutralize=False)
rootIf, modified = self._makeRootIf([parent, nullChild, deepExplicit])
result = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize")
self.assertEqual(result, ["d"])
self.assertEqual(modified, [("d", {"neutralize": None})])
def test_does_not_modify_parent(self):
parent = _ds("p", "/", neutralize=True)
child = _ds("c", "/a", neutralize=False)
rootIf, modified = self._makeRootIf([parent, child]) rootIf, modified = self._makeRootIf([parent, child])
_inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize") _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize")
self.assertNotIn("p", [m[0] for m in modified])
self.assertEqual(modified, [("c", {"neutralize": None})])
# ragIndexEnabled and scope on the child must remain untouched.
def test_does_not_cross_sourcetype(self):
"""Non-connection-root parents stay within their sourceType for cascade."""
parent = _ds("p", "/", neutralize=True, sourceType="sharepointFolder")
otherTypeDescendant = _ds("o", "/anything", neutralize=False, sourceType="outlookFolder")
rootIf, modified = self._makeRootIf([parent, otherTypeDescendant])
affected = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize")
self.assertEqual(affected, 0)
self.assertEqual(modified, [])
def test_connection_root_cascades_cross_sourcetype(self): def test_connection_root_cascades_cross_sourcetype(self):
"""Toggle on connection-root cascades into every explicit DS of that connection."""
connRoot = _ds("conn", "/", sourceType="msft", neutralize=True) connRoot = _ds("conn", "/", sourceType="msft", neutralize=True)
spExplicit = _ds("sp", "/", sourceType="sharepointFolder", neutralize=False) spExplicit = _ds("sp", "/", sourceType="sharepointFolder", neutralize=False)
olInherit = _ds("ol", "/", sourceType="outlookFolder") olInherit = _ds("ol", "/", sourceType="outlookFolder")
spLeafExplicit = _ds("sp-leaf", "/sites/x", sourceType="sharepointFolder", neutralize=True) spLeaf = _ds("sp-leaf", "/sites/x", sourceType="sharepointFolder", neutralize=True)
rootIf, modified = self._makeRootIf([connRoot, spExplicit, olInherit, spLeafExplicit]) rootIf, modified = self._makeRootIf([connRoot, spExplicit, olInherit, spLeaf])
affected = _inheritFlags.cascadeResetDescendants(rootIf, connRoot, "neutralize") result = _inheritFlags.cascadeResetDescendants(rootIf, connRoot, "neutralize")
# spExplicit and spLeafExplicit had explicit values → reset. olInherit untouched. self.assertEqual(set(result), {"sp", "sp-leaf"})
self.assertEqual(affected, 2) # sp-leaf is deeper, should come first
self.assertEqual({m[0] for m in modified}, {"sp", "sp-leaf"}) self.assertEqual(result[0], "sp-leaf")
for _, fields in modified:
self.assertEqual(fields, {"neutralize": None}) def test_does_not_cross_sourcetype_for_non_authority(self):
parent = _ds("p", "/", neutralize=True, sourceType="sharepointFolder")
otherType = _ds("o", "/anything", neutralize=False, sourceType="outlookFolder")
rootIf, modified = self._makeRootIf([parent, otherType])
result = _inheritFlags.cascadeResetDescendants(rootIf, parent, "neutralize")
self.assertEqual(result, [])
def test_unknown_flag_raises(self): def test_unknown_flag_raises(self):
parent = _ds("p", "/", neutralize=True) parent = _ds("p", "/", neutralize=True)
@ -183,57 +292,59 @@ class TestCascadeReset(unittest.TestCase):
_inheritFlags.cascadeResetDescendants(rootIf, parent, "unknownFlag") _inheritFlags.cascadeResetDescendants(rootIf, parent, "unknownFlag")
def _fds(idVal: str, *, tableName: str, recordFilter=None, **flags) -> dict: # ===========================================================================
"""Build a FeatureDataSource dict fixture.""" # DataSource: collectAncestorChain
base = { # ===========================================================================
"id": idVal,
"workspaceInstanceId": "ws-1", class TestCollectAncestorChain(unittest.TestCase):
"tableName": tableName, def test_returns_nearest_first(self):
"recordFilter": recordFilter, root = _ds("r", "/", neutralize=True)
"neutralize": None, mid = _ds("m", "/a")
"scope": None, leaf = _ds("l", "/a/b")
} chain = _inheritFlags.collectAncestorChain(leaf, [root, mid, leaf])
base.update(flags) self.assertEqual([_inheritFlags._getRecordValue(c, "id") for c in chain], ["m", "r"])
return base
def test_connection_root_is_last(self):
connRoot = _ds("conn", "/", sourceType="msft")
spRoot = _ds("sp", "/", sourceType="sharepointFolder")
spLeaf = _ds("spl", "/sub", sourceType="sharepointFolder")
chain = _inheritFlags.collectAncestorChain(spLeaf, [connRoot, spRoot, spLeaf])
ids = [_inheritFlags._getRecordValue(c, "id") for c in chain]
self.assertEqual(ids, ["sp", "conn"])
def test_root_has_no_ancestors(self):
root = _ds("r", "/")
chain = _inheritFlags.collectAncestorChain(root, [root])
self.assertEqual(chain, [])
class TestFdsClassifyAndAncestry(unittest.TestCase): # ===========================================================================
def test_classify_workspace_wildcard(self): # DataSource: buildEffectiveByConnection
self.assertEqual(_inheritFlags._fdsClassify(_fds("a", tableName="*")), "workspace") # ===========================================================================
def test_classify_table_wildcard(self): class TestBuildEffectiveByConnection(unittest.TestCase):
self.assertEqual(_inheritFlags._fdsClassify(_fds("a", tableName="Pos")), "table") def test_walk_mode(self):
root = _ds("r", "/", neutralize=True)
child = _ds("c", "/a", neutralize=False)
leaf = _ds("l", "/a/b") # inherits False from child
result = _inheritFlags.buildEffectiveByConnection([root, child, leaf], "neutralize", mode="walk")
self.assertEqual(result, {"r": True, "c": False, "l": False})
def test_classify_record_specific(self): def test_aggregate_mode(self):
rec = _fds("a", tableName="Pos", recordFilter={"id": "r-1"}) root = _ds("r", "/", neutralize=True)
self.assertEqual(_inheritFlags._fdsClassify(rec), "record") child = _ds("c", "/a", neutralize=False)
leaf = _ds("l", "/a/b") # inherits False from child
def test_workspace_is_ancestor_of_table_and_record(self): result = _inheritFlags.buildEffectiveByConnection([root, child, leaf], "neutralize", mode="aggregate")
ws = _fds("ws", tableName="*") self.assertEqual(result["r"], "mixed")
tbl = _fds("t", tableName="Pos") self.assertEqual(result["c"], False)
rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) self.assertEqual(result["l"], False)
self.assertTrue(_inheritFlags._fdsIsAncestor(ws, tbl))
self.assertTrue(_inheritFlags._fdsIsAncestor(ws, rec))
def test_table_is_ancestor_of_record_same_table_only(self):
tbl = _fds("t", tableName="Pos")
recSame = _fds("r1", tableName="Pos", recordFilter={"id": "1"})
recOther = _fds("r2", tableName="Other", recordFilter={"id": "1"})
self.assertTrue(_inheritFlags._fdsIsAncestor(tbl, recSame))
self.assertFalse(_inheritFlags._fdsIsAncestor(tbl, recOther))
def test_record_has_no_descendants(self):
rec = _fds("r", tableName="Pos", recordFilter={"id": "1"})
tbl = _fds("t", tableName="Pos")
self.assertFalse(_inheritFlags._fdsIsAncestor(rec, tbl))
def test_no_cross_workspace_ancestry(self):
ws = _fds("ws", tableName="*", workspaceInstanceId="ws-A")
rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, workspaceInstanceId="ws-B")
self.assertFalse(_inheritFlags._fdsIsAncestor(ws, rec))
class TestFdsEffectiveFlag(unittest.TestCase): # ===========================================================================
# FeatureDataSource: getEffectiveFlagFds
# ===========================================================================
class TestFdsEffectiveFlagWalk(unittest.TestCase):
def test_own_explicit_wins(self): def test_own_explicit_wins(self):
ws = _fds("ws", tableName="*", neutralize=False) ws = _fds("ws", tableName="*", neutralize=False)
rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True) rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True)
@ -262,9 +373,50 @@ class TestFdsEffectiveFlag(unittest.TestCase):
def test_unknown_flag_raises(self): def test_unknown_flag_raises(self):
rec = _fds("r", tableName="*") rec = _fds("r", tableName="*")
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_inheritFlags.getEffectiveFlagFds(rec, "ragIndexEnabled", [rec]) _inheritFlags.getEffectiveFlagFds(rec, "doesNotExist", [rec])
class TestFdsEffectiveFlagAggregate(unittest.TestCase):
def test_leaf_without_descendants(self):
rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True)
self.assertTrue(_inheritFlags.getEffectiveFlagFds(rec, "neutralize", [rec], mode="aggregate"))
def test_all_descendants_same(self):
ws = _fds("ws", tableName="*", neutralize=True)
tbl = _fds("t", tableName="Pos") # inherits True
rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) # inherits True
allFds = [ws, tbl, rec]
self.assertTrue(_inheritFlags.getEffectiveFlagFds(ws, "neutralize", allFds, mode="aggregate"))
def test_divergent_descendants_returns_mixed(self):
ws = _fds("ws", tableName="*", neutralize=True)
tbl = _fds("t", tableName="Pos", neutralize=False)
rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) # inherits False from tbl
allFds = [ws, tbl, rec]
self.assertEqual(
_inheritFlags.getEffectiveFlagFds(ws, "neutralize", allFds, mode="aggregate"),
"mixed",
)
def test_table_aggregate_own_subtree_only(self):
ws = _fds("ws", tableName="*", neutralize=True)
tblA = _fds("tA", tableName="A", neutralize=True)
recA = _fds("rA", tableName="A", recordFilter={"id": "1"}, neutralize=True)
tblB = _fds("tB", tableName="B", neutralize=False)
allFds = [ws, tblA, recA, tblB]
# tblA subtree: all True
self.assertTrue(_inheritFlags.getEffectiveFlagFds(tblA, "neutralize", allFds, mode="aggregate"))
# ws subtree: mixed (tblB is False)
self.assertEqual(
_inheritFlags.getEffectiveFlagFds(ws, "neutralize", allFds, mode="aggregate"),
"mixed",
)
# ===========================================================================
# FeatureDataSource: cascadeResetDescendantsFds (bottom-up, List[str])
# ===========================================================================
class TestFdsCascadeReset(unittest.TestCase): class TestFdsCascadeReset(unittest.TestCase):
def _makeRootIf(self, fdses): def _makeRootIf(self, fdses):
rootIf = MagicMock() rootIf = MagicMock()
@ -276,6 +428,14 @@ class TestFdsCascadeReset(unittest.TestCase):
rootIf.db.recordModify = MagicMock(side_effect=_modify) rootIf.db.recordModify = MagicMock(side_effect=_modify)
return rootIf, modified return rootIf, modified
def test_returns_list_of_ids(self):
ws = _fds("ws", tableName="*", neutralize=True)
tbl = _fds("t", tableName="Pos", neutralize=False)
rootIf, _ = self._makeRootIf([ws, tbl])
result = _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "neutralize")
self.assertIsInstance(result, list)
self.assertEqual(result, ["t"])
def test_workspace_cascades_to_all_explicit_descendants(self): def test_workspace_cascades_to_all_explicit_descendants(self):
ws = _fds("ws", tableName="*", neutralize=True) ws = _fds("ws", tableName="*", neutralize=True)
tblExplicit = _fds("t", tableName="Pos", neutralize=False) tblExplicit = _fds("t", tableName="Pos", neutralize=False)
@ -283,10 +443,11 @@ class TestFdsCascadeReset(unittest.TestCase):
recExplicit = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True) recExplicit = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True)
rootIf, modified = self._makeRootIf([ws, tblExplicit, tblInherit, recExplicit]) rootIf, modified = self._makeRootIf([ws, tblExplicit, tblInherit, recExplicit])
affected = _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "neutralize") result = _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "neutralize")
self.assertEqual(affected, 2) self.assertEqual(set(result), {"t", "r"})
self.assertEqual({m[0] for m in modified}, {"t", "r"}) # record is deeper (depth 2) than table (depth 1), should come first
self.assertEqual(result[0], "r")
def test_table_cascades_only_to_same_table_records(self): def test_table_cascades_only_to_same_table_records(self):
tbl = _fds("t", tableName="Pos", neutralize=True) tbl = _fds("t", tableName="Pos", neutralize=True)
@ -294,25 +455,189 @@ class TestFdsCascadeReset(unittest.TestCase):
recOther = _fds("r2", tableName="Other", recordFilter={"id": "1"}, neutralize=False) recOther = _fds("r2", tableName="Other", recordFilter={"id": "1"}, neutralize=False)
rootIf, modified = self._makeRootIf([tbl, recSame, recOther]) rootIf, modified = self._makeRootIf([tbl, recSame, recOther])
affected = _inheritFlags.cascadeResetDescendantsFds(rootIf, tbl, "neutralize") result = _inheritFlags.cascadeResetDescendantsFds(rootIf, tbl, "neutralize")
self.assertEqual(affected, 1) self.assertEqual(result, ["r1"])
self.assertEqual(modified, [("r1", {"neutralize": None})]) self.assertEqual(modified, [("r1", {"neutralize": None})])
def test_record_has_no_cascade(self): def test_record_has_no_cascade(self):
rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True) rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}, neutralize=True)
rootIf, modified = self._makeRootIf([rec]) rootIf, modified = self._makeRootIf([rec])
affected = _inheritFlags.cascadeResetDescendantsFds(rootIf, rec, "neutralize") result = _inheritFlags.cascadeResetDescendantsFds(rootIf, rec, "neutralize")
self.assertEqual(affected, 0) self.assertEqual(result, [])
self.assertEqual(modified, [])
def test_unknown_flag_raises(self): def test_unknown_flag_raises(self):
ws = _fds("ws", tableName="*", neutralize=True) ws = _fds("ws", tableName="*", neutralize=True)
rootIf, _ = self._makeRootIf([ws]) rootIf, _ = self._makeRootIf([ws])
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
_inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "ragIndexEnabled") _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "doesNotExist")
# ===========================================================================
# FeatureDataSource: collectAncestorChainFds
# ===========================================================================
class TestCollectAncestorChainFds(unittest.TestCase):
def test_record_has_table_then_workspace(self):
ws = _fds("ws", tableName="*")
tbl = _fds("t", tableName="Pos")
rec = _fds("r", tableName="Pos", recordFilter={"id": "1"})
chain = _inheritFlags.collectAncestorChainFds(rec, [ws, tbl, rec])
ids = [c["id"] for c in chain]
self.assertEqual(ids, ["t", "ws"])
def test_table_has_only_workspace(self):
ws = _fds("ws", tableName="*")
tbl = _fds("t", tableName="Pos")
chain = _inheritFlags.collectAncestorChainFds(tbl, [ws, tbl])
self.assertEqual([c["id"] for c in chain], ["ws"])
def test_workspace_has_no_ancestors(self):
ws = _fds("ws", tableName="*")
chain = _inheritFlags.collectAncestorChainFds(ws, [ws])
self.assertEqual(chain, [])
# ===========================================================================
# FeatureDataSource: buildEffectiveByWorkspaceFds
# ===========================================================================
class TestBuildEffectiveByWorkspaceFds(unittest.TestCase):
def test_walk_mode(self):
ws = _fds("ws", tableName="*", neutralize=True)
tbl = _fds("t", tableName="Pos", neutralize=False)
rec = _fds("r", tableName="Pos", recordFilter={"id": "1"}) # inherits False from tbl
result = _inheritFlags.buildEffectiveByWorkspaceFds([ws, tbl, rec], "neutralize", mode="walk")
self.assertEqual(result, {"ws": True, "t": False, "r": False})
def test_aggregate_mode(self):
ws = _fds("ws", tableName="*", neutralize=True)
tbl = _fds("t", tableName="Pos", neutralize=False)
rec = _fds("r", tableName="Pos", recordFilter={"id": "1"})
result = _inheritFlags.buildEffectiveByWorkspaceFds([ws, tbl, rec], "neutralize", mode="aggregate")
self.assertEqual(result["ws"], "mixed")
self.assertEqual(result["t"], False)
self.assertEqual(result["r"], False)
# ===========================================================================
# resolveEffectiveForPath (with and without own record)
# ===========================================================================
class TestResolveEffectiveForPath(unittest.TestCase):
def test_with_exact_record(self):
root = _ds("r", "/", neutralize=True, scope="mandate", ragIndexEnabled=False)
leaf = _ds("l", "/folder/sub", neutralize=False)
allDs = [root, leaf]
result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/folder/sub", allDs)
self.assertEqual(result["effectiveNeutralize"], False)
self.assertEqual(result["effectiveScope"], "mandate")
self.assertEqual(result["effectiveRagIndexEnabled"], False)
def test_without_record_inherits_from_ancestor(self):
root = _ds("r", "/", neutralize=True, scope="mandate", ragIndexEnabled=True)
allDs = [root]
result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/deep/path/file.txt", allDs)
self.assertEqual(result["effectiveNeutralize"], True)
self.assertEqual(result["effectiveScope"], "mandate")
self.assertEqual(result["effectiveRagIndexEnabled"], True)
def test_without_record_inherits_from_closest_ancestor(self):
root = _ds("r", "/", neutralize=True, ragIndexEnabled=True)
mid = _ds("m", "/folder", neutralize=False, ragIndexEnabled=False)
allDs = [root, mid]
result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/folder/sub/file.txt", allDs)
self.assertEqual(result["effectiveNeutralize"], False)
self.assertEqual(result["effectiveRagIndexEnabled"], False)
def test_without_record_no_ancestors_returns_defaults(self):
allDs: list = []
result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/path", allDs)
self.assertEqual(result["effectiveNeutralize"], False)
self.assertEqual(result["effectiveScope"], "personal")
self.assertEqual(result["effectiveRagIndexEnabled"], False)
def test_connection_root_covers_service_subtree(self):
connRoot = _ds("cr", "/", neutralize=True, sourceType="msft")
allDs = [connRoot]
result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/sites/intranet", allDs)
self.assertEqual(result["effectiveNeutralize"], True)
def test_exact_record_with_aggregate_mixed(self):
root = _ds("r", "/", neutralize=True)
leaf = _ds("l", "/sub", neutralize=False)
allDs = [root, leaf]
result = _inheritFlags.resolveEffectiveForPath("conn-1", "sharepointFolder", "/", allDs, mode="aggregate")
self.assertEqual(result["effectiveNeutralize"], "mixed")
class TestResolveEffectiveForFds(unittest.TestCase):
def test_with_exact_record(self):
ws = _fds("ws", tableName="*", neutralize=True, scope="mandate")
tbl = _fds("t", tableName="Pos", neutralize=False, scope="personal")
allFds = [ws, tbl]
result = _inheritFlags.resolveEffectiveForFds("fi-1", "Pos", None, allFds)
self.assertEqual(result["effectiveNeutralize"], False)
self.assertEqual(result["effectiveScope"], "personal")
self.assertEqual(result["effectiveRagIndexEnabled"], False)
def test_without_record_inherits_from_workspace_wildcard(self):
ws = _fds("ws", tableName="*", neutralize=True, scope="mandate", ragIndexEnabled=True)
allFds = [ws]
result = _inheritFlags.resolveEffectiveForFds("fi-1", "Unknown", None, allFds)
self.assertEqual(result["effectiveNeutralize"], True)
self.assertEqual(result["effectiveScope"], "mandate")
self.assertEqual(result["effectiveRagIndexEnabled"], True)
def test_without_record_no_ancestors_returns_defaults(self):
allFds: list = []
result = _inheritFlags.resolveEffectiveForFds("fi-1", "Pos", None, allFds)
self.assertEqual(result["effectiveNeutralize"], False)
self.assertEqual(result["effectiveScope"], "personal")
self.assertEqual(result["effectiveRagIndexEnabled"], False)
def test_rag_inherits_when_table_overrides_neutralize_only(self):
"""Tables that override only neutralize must still inherit RAG from parent."""
ws = _fds("ws", tableName="*", ragIndexEnabled=True)
tbl = _fds("t", tableName="Pos", neutralize=False)
allFds = [ws, tbl]
result = _inheritFlags.resolveEffectiveForFds("fi-1", "Pos", None, allFds)
self.assertEqual(result["effectiveRagIndexEnabled"], True)
def test_rag_aggregate_mixed_when_descendants_diverge(self):
ws = _fds("ws", tableName="*", ragIndexEnabled=True)
tbl = _fds("t", tableName="Pos", ragIndexEnabled=False)
allFds = [ws, tbl]
result = _inheritFlags.resolveEffectiveForFds("fi-1", "*", None, allFds, mode="aggregate")
self.assertEqual(result["effectiveRagIndexEnabled"], "mixed")
def test_inheritable_fds_flags_includes_rag(self):
self.assertIn("ragIndexEnabled", _inheritFlags._INHERITABLE_FDS_FLAGS)
self.assertIn("neutralize", _inheritFlags._INHERITABLE_FDS_FLAGS)
self.assertIn("scope", _inheritFlags._INHERITABLE_FDS_FLAGS)
# ===========================================================================
# FDS cascade resets RAG (in addition to neutralize and scope)
# ===========================================================================
class TestCascadeResetFdsRag(unittest.TestCase):
def test_cascade_resets_rag_on_descendants(self):
ws = _fds("ws", tableName="*")
tbl = _fds("t", tableName="Pos", ragIndexEnabled=False)
allFds = [ws, tbl]
rootIf = MagicMock()
rootIf.db.getRecordset.return_value = allFds
rootIf.db.recordModify = MagicMock()
result = _inheritFlags.cascadeResetDescendantsFds(rootIf, ws, "ragIndexEnabled")
self.assertIn("t", result)
rootIf.db.recordModify.assert_called()
# ===========================================================================
# Path normalization
# ===========================================================================
class TestPathNormalization(unittest.TestCase): class TestPathNormalization(unittest.TestCase):
def test_empty_path_normalises_to_root(self): def test_empty_path_normalises_to_root(self):
self.assertEqual(_inheritFlags._normalisePath(""), "/") self.assertEqual(_inheritFlags._normalisePath(""), "/")

View file

@ -42,7 +42,7 @@ from modules.features.teamsbot.datamodelTeamsbot import (
from modules.features.teamsbot.service import ( from modules.features.teamsbot.service import (
TeamsbotService, TeamsbotService,
_activeServices, _activeServices,
_sessionEvents, sessionEvents,
getActiveService, getActiveService,
) )
@ -152,10 +152,10 @@ def _buildService() -> TeamsbotService:
def _resetGlobals(): def _resetGlobals():
"""Avoid cross-test bleed in module-level globals.""" """Avoid cross-test bleed in module-level globals."""
_activeServices.clear() _activeServices.clear()
_sessionEvents.clear() sessionEvents.clear()
yield yield
_activeServices.clear() _activeServices.clear()
_sessionEvents.clear() sessionEvents.clear()
# ============================================================================ # ============================================================================
@ -251,7 +251,7 @@ class TestBuildPersistentDirectorContext:
] ]
rendered = svc._buildPersistentDirectorContext() rendered = svc._buildPersistentDirectorContext()
assert "OPERATOR_DIRECTIVES" in rendered assert "OPERATOR_DIRECTIVES" in rendered
assert "- Antworte immer in Englisch." in rendered assert "Antworte immer in Englisch." in rendered
assert "private" in rendered assert "private" in rendered
def test_skipsBlankText(self): def test_skipsBlankText(self):
@ -261,7 +261,7 @@ class TestBuildPersistentDirectorContext:
{"id": "p2", "text": "Sei hoeflich."}, {"id": "p2", "text": "Sei hoeflich."},
] ]
rendered = svc._buildPersistentDirectorContext() rendered = svc._buildPersistentDirectorContext()
assert "- Sei hoeflich." in rendered assert "Sei hoeflich." in rendered
assert "p1" not in rendered # the blank one is filtered out assert "p1" not in rendered # the blank one is filtered out
def test_allBlankPromptsResultInEmpty(self): def test_allBlankPromptsResultInEmpty(self):