# gateway/modules/routes/routeHelpers.py — 857 lines, 31 KiB, Python
# (file-listing metadata from extraction, kept as a comment so the module parses)

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Shared helpers for route handlers.
Provides unified logic for:
- mode=filterValues: distinct column values for filter dropdowns (cross-filtered)
- mode=ids: all IDs matching current filters (for bulk selection)
- In-memory equivalents for enriched/non-SQL routes
"""
import copy
import json
import logging
from typing import Any, Dict, List, Optional, Callable, Union
from fastapi.responses import JSONResponse
from modules.datamodels.datamodelPagination import (
PaginationParams,
normalize_pagination_dict,
)
from modules.shared.i18nRegistry import resolveText
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Central FK label resolvers (cross-DB)
# ---------------------------------------------------------------------------
def resolveMandateLabels(ids: List[str]) -> Dict[str, Optional[str]]:
    """Map mandate IDs to their display labels.

    Unresolvable IDs map to ``None`` (never the raw ID) so callers can tell
    a resolved label apart from a missing one.
    """
    from modules.interfaces.interfaceDbApp import getRootInterface

    mandateMap = getRootInterface().getMandatesByIds(ids)
    labels: Dict[str, Optional[str]] = {}
    for mandateId in ids:
        mandate = mandateMap.get(mandateId)
        resolved = None
        if mandate is not None:
            resolved = getattr(mandate, "label", None) or getattr(mandate, "name", None)
        if not resolved:
            logger.warning("resolveMandateLabels: no label for id=%s (found=%s)", mandateId, mandate is not None)
        labels[mandateId] = resolved or None
    return labels
def resolveInstanceLabels(ids: List[str]) -> Dict[str, Optional[str]]:
    """Map feature-instance IDs to labels; unresolvable IDs map to ``None``."""
    from modules.interfaces.interfaceDbApp import getRootInterface
    from modules.interfaces.interfaceFeatures import getFeatureInterface

    featureIface = getFeatureInterface(getRootInterface().db)
    labels: Dict[str, Optional[str]] = {}
    for instanceId in ids:
        instance = featureIface.getFeatureInstance(instanceId)
        # Label is only taken when the instance exists AND has a truthy label
        label = instance.label if (instance and instance.label) else None
        if not label:
            logger.warning("resolveInstanceLabels: no label for id=%s (found=%s)", instanceId, instance is not None)
        labels[instanceId] = label
    return labels
def resolveUserLabels(ids: List[str]) -> Dict[str, Optional[str]]:
    """Map user IDs to display names; unresolvable IDs map to ``None``.

    Falls back displayName → username → email; all-empty records yield None.
    """
    from modules.interfaces.interfaceDbApp import getRootInterface
    from modules.datamodels.datamodelUam import UserInDB as _UserInDB

    rootIface = getRootInterface()
    records = rootIface.db.getRecordset(
        _UserInDB,
        recordFilter={"id": list(set(ids))},
    )
    recordById: Dict[str, dict] = {rec.get("id", ""): rec for rec in (records or [])}
    labels: Dict[str, Optional[str]] = {}
    for userId in ids:
        rec = recordById.get(userId)
        if rec:
            labels[userId] = rec.get("displayName") or rec.get("username") or rec.get("email") or None
        else:
            labels[userId] = None
    return labels
def resolveRoleLabels(ids: List[str]) -> Dict[str, Optional[str]]:
    """Map Role.id values to roleLabel; unresolvable IDs map to ``None``."""
    if not ids:
        return {}
    from modules.interfaces.interfaceDbApp import getRootInterface
    from modules.datamodels.datamodelRbac import Role as _Role

    records = getRootInterface().db.getRecordset(
        _Role,
        recordFilter={"id": list(set(ids))},
    ) or []
    # Pre-seed every requested ID with None so missing records stay visible
    labels: Dict[str, Optional[str]] = dict.fromkeys(ids, None)
    for rec in records:
        roleId = rec.get("id")
        if roleId:
            labels[roleId] = rec.get("roleLabel") or None
    for roleId in ids:
        if labels.get(roleId) is None:
            logger.warning("resolveRoleLabels: no label for id=%s", roleId)
    return labels
# Registry mapping FK-target model/table names (as written in a field's
# ``fk_target.table`` metadata) to the cross-DB label resolvers above.
# Consumed by _buildLabelResolversFromModel.
_BUILTIN_FK_RESOLVERS: Dict[str, Callable[[List[str]], Dict[str, str]]] = {
    "Mandate": resolveMandateLabels,
    "FeatureInstance": resolveInstanceLabels,
    "UserInDB": resolveUserLabels,
    "Role": resolveRoleLabels,
}
def _buildLabelResolversFromModel(modelClass: type) -> Dict[str, Callable[[List[str]], Dict[str, str]]]:
"""
Auto-build labelResolvers dict from ``json_schema_extra.fk_target`` on a Pydantic model.
Maps field names to resolver functions when the target table has a registered builtin
resolver and ``fk_target.labelField`` is set (non-None).
"""
resolvers: Dict[str, Callable[[List[str]], Dict[str, str]]] = {}
for name, fieldInfo in modelClass.model_fields.items():
extra = fieldInfo.json_schema_extra
if not extra or not isinstance(extra, dict):
continue
tgt = extra.get("fk_target")
if not isinstance(tgt, dict):
continue
if tgt.get("labelField") is None:
continue
fkModel = tgt.get("table")
if fkModel and fkModel in _BUILTIN_FK_RESOLVERS:
resolvers[name] = _BUILTIN_FK_RESOLVERS[fkModel]
return resolvers
def enrichRowsWithFkLabels(
    rows: List[Dict[str, Any]],
    modelClass: type = None,
    *,
    labelResolvers: Optional[Dict[str, Callable[[List[str]], Dict[str, Optional[str]]]]] = None,
    extraResolvers: Optional[Dict[str, Callable[[List[str]], Dict[str, Optional[str]]]]] = None,
) -> List[Dict[str, Any]]:
    """Attach a ``{field}Label`` column to every row for each resolvable FK field.

    Resolver selection:
    - ``labelResolvers`` given → used as-is (takes precedence over the model).
    - otherwise, ``modelClass`` given → resolvers auto-built from its
      ``fk_target`` annotations.
    - ``extraResolvers`` are merged on top in both cases (for ad-hoc fields
      not FK-annotated on the model).

    Unresolvable FK values produce ``None`` labels — never the raw ID, which
    would reintroduce the silent-truncation bug. Rows are mutated in place
    and the same list is returned.
    """
    active: Dict[str, Callable] = {}
    if labelResolvers is not None:
        active = dict(labelResolvers)
    elif modelClass is not None:
        active = _buildLabelResolversFromModel(modelClass)
    if extraResolvers:
        active.update(extraResolvers)
    if not active or not rows:
        return rows
    for field, resolve in active.items():
        fkIds = list({str(row.get(field)) for row in rows if row.get(field)})
        if not fkIds:
            continue
        try:
            labels = resolve(fkIds)
        except Exception as e:
            logger.error("enrichRowsWithFkLabels: resolver for '%s' raised: %s", field, e)
            labels = {}
        labelColumn = f"{field}Label"
        for row in rows:
            value = row.get(field)
            row[labelColumn] = labels.get(str(value)) if value else None
    return rows
# ---------------------------------------------------------------------------
# Cross-filter pagination parsing
# ---------------------------------------------------------------------------
def parseCrossFilterPagination(
    column: str,
    paginationJson: Optional[str],
) -> Optional[PaginationParams]:
    """
    Build PaginationParams for a filter-values request.

    The requested column's own filter entry is removed (cross-filtering) and
    any sort is discarded. Returns None for missing, empty, or invalid JSON.
    """
    if not paginationJson:
        return None
    try:
        payload = json.loads(paginationJson)
        if not payload:
            return None
        payload = normalize_pagination_dict(payload)
        activeFilters = payload.get("filters", {})
        activeFilters.pop(column, None)
        payload["filters"] = activeFilters
        payload.pop("sort", None)
        return PaginationParams(**payload)
    except (json.JSONDecodeError, ValueError, TypeError):
        return None
def parsePaginationForIds(
    paginationJson: Optional[str],
) -> Optional[PaginationParams]:
    """
    Parse pagination JSON for mode=ids: filters are kept, sort is dropped.
    page/pageSize pass through unchanged (ID queries apply no LIMIT/OFFSET).
    Returns None for missing, empty, or invalid input.
    """
    if not paginationJson:
        return None
    try:
        payload = json.loads(paginationJson)
        if not payload:
            return None
        payload = normalize_pagination_dict(payload)
        payload.pop("sort", None)
        return PaginationParams(**payload)
    except (json.JSONDecodeError, ValueError, TypeError):
        return None
# ---------------------------------------------------------------------------
# SQL-based helpers (delegate to DB connector)
# ---------------------------------------------------------------------------
def handleFilterValuesMode(
    db,
    modelClass: type,
    column: str,
    paginationJson: Optional[str] = None,
    recordFilter: Optional[Dict[str, Any]] = None,
    enrichFn: Optional[Callable[[str, Optional[PaginationParams], Optional[Dict[str, Any]]], List[str]]] = None,
) -> "JSONResponse":
    """
    SQL-based distinct column values with cross-filtering.

    If ``enrichFn`` is provided and returns a non-None result for the column
    (computed/joined columns), that result is used instead of SQL DISTINCT.
    When enrichFn returns None or raises, the SQL path runs as a fallback.

    Returns a JSONResponse wrapping the value list; errors degrade to an
    empty list rather than propagating.
    """
    # NOTE: the return annotation previously claimed List[str] although every
    # path returns a JSONResponse — corrected (quoted to avoid eager eval).
    crossPagination = parseCrossFilterPagination(column, paginationJson)
    if enrichFn:
        try:
            result = enrichFn(column, crossPagination, recordFilter)
            if result is not None:
                return JSONResponse(content=result)
        except Exception as e:
            # Best-effort: fall through to the SQL path on enrichment failure
            logger.warning("handleFilterValuesMode enrichFn failed for %s: %s", column, e)
    try:
        values = db.getDistinctColumnValues(
            modelClass, column,
            pagination=crossPagination,
            recordFilter=recordFilter,
        ) or []
        return JSONResponse(content=values)
    except Exception as e:
        logger.error("handleFilterValuesMode SQL failed for %s.%s: %s", modelClass.__name__, column, e)
        return JSONResponse(content=[])
def handleIdsMode(
    db,
    modelClass: type,
    paginationJson: Optional[str] = None,
    recordFilter: Optional[Dict[str, Any]] = None,
    idField: str = "id",
) -> "JSONResponse":
    """
    Return all IDs matching the current filters (no LIMIT/OFFSET) as a
    JSONResponse wrapping a list of strings.

    Builds the same WHERE clause as getRecordsetPaginated (via the connector's
    _buildPaginationClauses) so the ID set matches what the paginated listing
    shows. Errors degrade to an empty list rather than propagating.
    """
    # NOTE: the return annotation previously claimed List[str] although every
    # path returns a JSONResponse — corrected (quoted to avoid eager eval).
    pagination = parsePaginationForIds(paginationJson)
    table = modelClass.__name__
    try:
        if not db._ensureTableExists(modelClass):
            return JSONResponse(content=[])
        where_clause, _, _, values, _ = db._buildPaginationClauses(
            modelClass, pagination, recordFilter,
        )
        # idField/table come from trusted code (model class name + caller
        # literal), not user input; filter values are bound parameters.
        sql = f'SELECT "{idField}"::TEXT AS val FROM "{table}"{where_clause} ORDER BY "{idField}"'
        with db.connection.cursor() as cursor:
            cursor.execute(sql, values)
            return JSONResponse(content=[row["val"] for row in cursor.fetchall()])
    except Exception as e:
        logger.error("handleIdsMode failed for %s: %s", table, e)
        return JSONResponse(content=[])
# ---------------------------------------------------------------------------
# In-memory helpers (for enriched / non-SQL routes)
# ---------------------------------------------------------------------------
def applyFiltersAndSort(
    items: List[Dict[str, Any]],
    paginationParams: Optional[PaginationParams],
) -> List[Dict[str, Any]]:
    """
    Filter and sort a list of dicts in-memory (no page/pageSize slicing).

    Filtering: an optional free-text "search" entry matches any stringified
    value; every other filter field goes through _matchesFilter. A filter
    value of None selects rows where the field is null/empty; "" is ignored.
    Sorting: stable multi-field sort (applied in reverse spec order);
    numeric/bool values sort before strings; None values always sort last.
    """
    if not paginationParams:
        return items
    result = list(items)
    activeFilters = paginationParams.filters
    if activeFilters:
        rawSearch = activeFilters.get("search")
        needle = rawSearch.lower() if rawSearch else None
        if needle:
            result = [
                row for row in result
                if any(
                    needle in str(v).lower()
                    for v in row.values()
                    if v is not None
                )
            ]
        for fieldName, rawFilter in activeFilters.items():
            if fieldName == "search":
                continue
            if isinstance(rawFilter, dict) and "operator" in rawFilter:
                op = rawFilter.get("operator", "equals")
                target = rawFilter.get("value")
            else:
                op, target = "equals", rawFilter
            if target is None:
                # None means "match empty": keep null or empty-string cells
                result = [
                    row for row in result
                    if row.get(fieldName) is None or row.get(fieldName) == ""
                ]
                continue
            if target == "":
                continue
            result = [
                row for row in result
                if _matchesFilter(row, fieldName, op, target)
            ]
    if paginationParams.sort:
        for spec in reversed(paginationParams.sort):
            key = spec.field
            isAsc = spec.direction == "asc"
            withValue = [row for row in result if row.get(key) is not None]
            withoutValue = [row for row in result if row.get(key) is None]

            def orderKey(row, _k=key):
                v = row.get(_k)
                if isinstance(v, bool):
                    return (0, int(v), "")
                if isinstance(v, (int, float)):
                    return (0, v, "")
                return (1, 0, str(v).lower())

            withValue.sort(key=orderKey, reverse=not isAsc)
            result = withValue + withoutValue
    return result
def _matchesFilter(item: Dict[str, Any], field: str, operator: str, value: Any) -> bool:
"""Single-field filter match for in-memory filtering."""
itemValue = item.get(field)
if itemValue is None:
return False
itemStr = str(itemValue).lower()
valueStr = str(value).lower()
if operator in ("equals", "eq"):
return itemStr == valueStr
if operator == "contains":
return valueStr in itemStr
if operator == "startsWith":
return itemStr.startswith(valueStr)
if operator == "endsWith":
return itemStr.endswith(valueStr)
if operator in ("gt", "gte", "lt", "lte"):
try:
itemNum = float(itemValue)
valueNum = float(value)
if operator == "gt":
return itemNum > valueNum
if operator == "gte":
return itemNum >= valueNum
if operator == "lt":
return itemNum < valueNum
return itemNum <= valueNum
except (ValueError, TypeError):
return False
if operator == "between":
return _matchesBetween(itemValue, itemStr, value)
if operator == "in":
if isinstance(value, list):
return itemStr in [str(x).lower() for x in value]
return False
if operator == "notIn":
if isinstance(value, list):
return itemStr not in [str(x).lower() for x in value]
return True
return True
def _matchesBetween(itemValue: Any, itemStr: str, value: Any) -> bool:
"""Handle 'between' operator for date ranges and numeric ranges."""
if not isinstance(value, dict):
return True
fromVal = value.get("from", "")
toVal = value.get("to", "")
if not fromVal and not toVal:
return True
try:
from datetime import datetime, timezone
fromTs = None
toTs = None
if fromVal:
fromTs = datetime.strptime(str(fromVal), "%Y-%m-%d").replace(tzinfo=timezone.utc).timestamp()
if toVal:
toTs = datetime.strptime(str(toVal), "%Y-%m-%d").replace(
hour=23, minute=59, second=59, tzinfo=timezone.utc
).timestamp()
itemNum = float(itemValue) if not isinstance(itemValue, (int, float)) else itemValue
if itemNum > 10000000000:
itemNum = itemNum / 1000
if fromTs is not None and toTs is not None:
return fromTs <= itemNum <= toTs
if fromTs is not None:
return itemNum >= fromTs
if toTs is not None:
return itemNum <= toTs
except (ValueError, TypeError):
# Numeric range (e.g. FormGeneratorTable column filters on INTEGER/FLOAT)
try:
itemNum = float(itemValue)
fromNum = float(fromVal) if fromVal not in (None, "") else None
toNum = float(toVal) if toVal not in (None, "") else None
if fromNum is not None and toNum is not None:
return fromNum <= itemNum <= toNum
if fromNum is not None:
return itemNum >= fromNum
if toNum is not None:
return itemNum <= toNum
except (ValueError, TypeError):
pass
fromStr = str(fromVal).lower() if fromVal else ""
toStr = str(toVal).lower() if toVal else ""
if fromStr and toStr:
return fromStr <= itemStr <= toStr
if fromStr:
return itemStr >= fromStr
if toStr:
return itemStr <= toStr
return True
def _extractDistinctValues(
items: List[Dict[str, Any]],
columnKey: str,
requestLang: Optional[str] = None,
) -> list:
"""Extract sorted distinct display values for a column from enriched items.
When the items contain a ``{columnKey}Label`` field (FK enrichment convention),
returns ``{value, label}`` objects so the frontend shows human-readable
labels in filter dropdowns. Otherwise returns plain strings.
Includes ``None`` as the last entry when at least one row has a null/empty
value — this enables the "(Leer)" filter option in the frontend.
"""
_MISSING = object()
labelKey = f"{columnKey}Label"
hasFkLabels = any(labelKey in item for item in items[:20])
if hasFkLabels:
byVal: Dict[str, str] = {}
hasEmpty = False
for item in items:
val = item.get(columnKey, _MISSING)
if val is _MISSING:
continue
if val is None or val == "":
hasEmpty = True
continue
strVal = str(val)
if strVal not in byVal:
label = item.get(labelKey)
byVal[strVal] = str(label) if label else f"NA({strVal[:8]})"
result: list = sorted(
[{"value": v, "label": l} for v, l in byVal.items()],
key=lambda x: x["label"].lower(),
)
if hasEmpty:
result.append(None)
return result
values = set()
hasEmpty = False
for item in items:
val = item.get(columnKey, _MISSING)
if val is _MISSING:
continue
if val is None or val == "":
hasEmpty = True
continue
if isinstance(val, bool):
values.add("true" if val else "false")
elif isinstance(val, (int, float)):
values.add(str(val))
elif isinstance(val, dict):
text = resolveText(val, requestLang)
if text:
values.add(text)
else:
values.add(str(val))
result = sorted(values, key=lambda v: v.lower())
if hasEmpty:
result.append(None)
return result
def handleFilterValuesInMemory(
    items: List[Dict[str, Any]],
    column: str,
    paginationJson: Optional[str] = None,
    requestLang: Optional[str] = None,
) -> JSONResponse:
    """
    In-memory filter-values for enriched routes: apply every filter except the
    requested column's own (cross-filtering), then collect distinct values.
    Wrapped in JSONResponse to bypass FastAPI response_model validation.
    """
    crossParams = parseCrossFilterPagination(column, paginationJson)
    scopedItems = applyFiltersAndSort(items, crossParams)
    distinct = _extractDistinctValues(scopedItems, column, requestLang)
    return JSONResponse(content=distinct)
def handleIdsInMemory(
    items: List[Dict[str, Any]],
    paginationJson: Optional[str] = None,
    idField: str = "id",
) -> JSONResponse:
    """
    In-memory mode=ids for enriched routes: apply the current filters and
    return every matching ID (stringified), ignoring sort and paging.
    Wrapped in JSONResponse to bypass FastAPI response_model validation.
    """
    params = parsePaginationForIds(paginationJson)
    matching = applyFiltersAndSort(items, params)
    idValues = [
        str(row[idField])
        for row in matching
        if row.get(idField) is not None
    ]
    return JSONResponse(content=idValues)
def getRecordsetPaginatedWithFkSort(
    db,
    modelClass: type,
    pagination,
    recordFilter: Optional[Dict[str, Any]] = None,
    labelResolvers: Optional[Dict[str, Callable[[List[str]], Dict[str, str]]]] = None,
    fieldFilter: Optional[List[str]] = None,
    idField: str = "id",
) -> Dict[str, Any]:
    """
    Wrapper around db.getRecordsetPaginated that handles FK-label sorting.

    If the current sort field is a FK with a registered labelResolver, the
    function fetches all filtered IDs + FK values, resolves labels cross-DB,
    sorts in-memory by label, and returns only the requested page.
    If no FK sort is active (or anything fails), delegates directly to
    db.getRecordsetPaginated.

    BUGFIX: the sort key previously did ``labelMap.get(id, id).lower()`` —
    resolvers return None for unresolvable IDs, so a present-but-None entry
    crashed with AttributeError inside sort, which the broad except turned
    into a silent fallback (losing FK sorting entirely). The key now falls
    back to the raw ID string for None labels.
    """
    import math
    if not pagination or not pagination.sort:
        return db.getRecordsetPaginated(modelClass, pagination, recordFilter, fieldFilter)
    if labelResolvers is None:
        labelResolvers = _buildLabelResolversFromModel(modelClass)
    if not labelResolvers:
        return db.getRecordsetPaginated(modelClass, pagination, recordFilter, fieldFilter)
    # Find the first sort entry that targets an FK field with a resolver
    fkSortField = None
    fkSortDir = "asc"
    for sf in pagination.sort:
        sfField = sf.get("field") if isinstance(sf, dict) else getattr(sf, "field", None)
        sfDir = sf.get("direction", "asc") if isinstance(sf, dict) else getattr(sf, "direction", "asc")
        if sfField and sfField in labelResolvers:
            fkSortField = sfField
            fkSortDir = str(sfDir).lower()
            break
    if not fkSortField:
        return db.getRecordsetPaginated(modelClass, pagination, recordFilter, fieldFilter)
    try:
        # Resolve labels for every distinct FK value under the current filters
        distinctIds = db.getDistinctColumnValues(
            modelClass, fkSortField, recordFilter=recordFilter,
        ) or []
        labelMap = {}
        if distinctIds:
            try:
                labelMap = labelResolvers[fkSortField](distinctIds)
            except Exception as e:
                logger.warning("getRecordsetPaginatedWithFkSort: resolver for %s failed: %s", fkSortField, e)
        # Fetch ALL filtered rows, but only id + FK columns, no sort/paging
        filterOnlyPagination = copy.deepcopy(pagination)
        filterOnlyPagination.sort = []
        filterOnlyPagination.page = 1
        filterOnlyPagination.pageSize = 999999
        lightRows = db.getRecordsetPaginated(
            modelClass, filterOnlyPagination, recordFilter,
            fieldFilter=[idField, fkSortField],
        )
        allRows = lightRows.get("items", [])
        totalItems = len(allRows)
        if totalItems == 0:
            return {"items": [], "totalItems": 0, "totalPages": 0}

        def _sortKey(row):
            fkVal = row.get(fkSortField, "") or ""
            # Resolvers return None for unresolvable IDs — fall back to the
            # raw ID so sorting never crashes on None.lower().
            return (labelMap.get(str(fkVal)) or str(fkVal)).lower()

        allRows.sort(key=_sortKey, reverse=(fkSortDir == "desc"))
        pageSize = pagination.pageSize
        offset = (pagination.page - 1) * pageSize
        pageSlice = allRows[offset:offset + pageSize]
        pageIds = [row[idField] for row in pageSlice if row.get(idField)]
        if not pageIds:
            return {"items": [], "totalItems": totalItems, "totalPages": math.ceil(totalItems / pageSize)}
        # Re-fetch full records for just this page, restore the label order
        pageItems = db.getRecordset(modelClass, recordFilter={idField: pageIds}, fieldFilter=fieldFilter)
        idOrder = {pid: idx for idx, pid in enumerate(pageIds)}
        pageItems.sort(key=lambda r: idOrder.get(r.get(idField), 999999))
        enrichRowsWithFkLabels(pageItems, modelClass)
        totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0
        return {"items": pageItems, "totalItems": totalItems, "totalPages": totalPages}
    except Exception as e:
        logger.error("getRecordsetPaginatedWithFkSort failed for %s: %s", modelClass.__name__, e)
        return db.getRecordsetPaginated(modelClass, pagination, recordFilter, fieldFilter)
def paginateInMemory(
    items: List[Dict[str, Any]],
    paginationParams: Optional[PaginationParams],
) -> tuple:
    """
    Slice an already-filtered+sorted list into the requested page.

    Returns (pageItems, totalItems); with no params the full list is returned.
    """
    total = len(items)
    if not paginationParams:
        return items, total
    start = (paginationParams.page - 1) * paginationParams.pageSize
    end = start + paginationParams.pageSize
    return items[start:end], total
# ---------------------------------------------------------------------------
# Table Grouping helpers
# ---------------------------------------------------------------------------
from dataclasses import dataclass, field as dc_field
# NOTE(review): ``dc_field`` is imported but not used in this section —
# confirm nothing later in the file relies on it before removing.


@dataclass
class GroupingContext:
    """
    Result of handleGroupingInRequest.
    Carries the group tree for the response and the resolved item-ID set for
    group-scope filtering (None = no active group scope).
    """
    # List[TableGroupNode] serialised as plain dicts — embedded in the response
    groupTree: Optional[list]
    # Set[str] of item IDs when a groupId was requested, else None (no scope)
    itemIds: Optional[set]
def _collectItemIds(nodes: list, groupId: str) -> Optional[set]:
"""
Recursively search *nodes* for a node whose id == groupId and collect
all itemIds from it and all its descendant subGroups.
Returns None if the group is not found.
"""
for node in nodes:
nodeId = node.get("id") if isinstance(node, dict) else getattr(node, "id", None)
if nodeId == groupId:
ids: set = set()
_collectAllIds(node, ids)
return ids
subGroups = node.get("subGroups", []) if isinstance(node, dict) else getattr(node, "subGroups", [])
result = _collectItemIds(subGroups, groupId)
if result is not None:
return result
return None
def _collectAllIds(node, ids: set) -> None:
"""Collect itemIds from a node and all its descendants into ids."""
nodeItemIds = node.get("itemIds", []) if isinstance(node, dict) else getattr(node, "itemIds", [])
for iid in nodeItemIds:
ids.add(str(iid))
subGroups = node.get("subGroups", []) if isinstance(node, dict) else getattr(node, "subGroups", [])
for child in subGroups:
_collectAllIds(child, ids)
def _removeGroupFromTree(nodes: list, groupId: str) -> list:
"""Remove a group node (and all descendants) from the tree by id."""
result = []
for node in nodes:
nodeId = node.get("id") if isinstance(node, dict) else getattr(node, "id", None)
if nodeId == groupId:
continue # skip this node (remove it)
subGroups = node.get("subGroups", []) if isinstance(node, dict) else getattr(node, "subGroups", [])
filtered_sub = _removeGroupFromTree(subGroups, groupId)
if isinstance(node, dict):
node = {**node, "subGroups": filtered_sub}
result.append(node)
return result
def handleGroupingInRequest(
    paginationParams: Optional[PaginationParams],
    interface,
    contextKey: str,
) -> GroupingContext:
    """
    Central grouping handler — call at the start of every list route that
    supports table grouping.

    Steps (in order):
    1. If paginationParams.saveGroupTree is set:
       persist the new tree via interface.upsertTableGrouping, then clear
       saveGroupTree from paginationParams so it is not treated as a filter.
    2. Load the current group tree from the DB (used in step 3 and response).
    3. If paginationParams.groupId is set:
       resolve it to a Set[str] of itemIds (including all sub-groups),
       then clear groupId from paginationParams so it is not treated as a
       normal filter field.
    4. Return a GroupingContext with groupTree (for the response) and itemIds
       (for applyGroupScopeFilter).

    The caller does NOT need to handle any grouping logic itself — just call
    applyGroupScopeFilter(items, groupCtx.itemIds) and embed groupCtx.groupTree
    in the response dict.
    """
    # NOTE: a previously present local import of TableGroupNode was unused
    # and has been removed.
    def _dumpNodes(rootGroups) -> list:
        # Serialise TableGroupNode models (or pass dicts through) for the response
        return [n.model_dump() if hasattr(n, "model_dump") else n for n in rootGroups]

    groupTree = None
    itemIds = None
    if paginationParams is None:
        try:
            existing = interface.getTableGrouping(contextKey)
            if existing:
                groupTree = _dumpNodes(existing.rootGroups)
        except Exception as e:
            logger.warning(f"handleGroupingInRequest: getTableGrouping failed: {e}")
        return GroupingContext(groupTree=groupTree, itemIds=None)
    # Step 1: persist saveGroupTree if present
    if paginationParams.saveGroupTree is not None:
        try:
            saved = interface.upsertTableGrouping(contextKey, paginationParams.saveGroupTree)
            groupTree = _dumpNodes(saved.rootGroups)
        except Exception as e:
            logger.error(f"handleGroupingInRequest: upsertTableGrouping failed: {e}")
        paginationParams.saveGroupTree = None
    # Step 2: load current tree (only if not already set from save above)
    if groupTree is None:
        try:
            existing = interface.getTableGrouping(contextKey)
            if existing:
                groupTree = _dumpNodes(existing.rootGroups)
        except Exception as e:
            logger.warning(f"handleGroupingInRequest: getTableGrouping failed: {e}")
    # Step 3: resolve groupId to itemIds set
    if paginationParams.groupId is not None:
        targetGroupId = paginationParams.groupId
        paginationParams.groupId = None  # remove so it is not treated as a normal filter
        if groupTree:
            itemIds = _collectItemIds(groupTree, targetGroupId)
            if itemIds is None:
                logger.warning(
                    f"handleGroupingInRequest: groupId={targetGroupId!r} not found in tree "
                    f"for contextKey={contextKey!r} — returning empty set"
                )
                itemIds = set()  # unknown group → show nothing rather than everything
        else:
            # groupId sent but no tree saved yet → return empty (nothing belongs to any group)
            logger.warning(
                f"handleGroupingInRequest: groupId={targetGroupId!r} set but no tree exists "
                f"for contextKey={contextKey!r} — returning empty set"
            )
            itemIds = set()
    return GroupingContext(groupTree=groupTree, itemIds=itemIds)
def applyGroupScopeFilter(items: List[Dict[str, Any]], itemIds: Optional[set]) -> List[Dict[str, Any]]:
    """
    Restrict *items* to those whose stringified "id" field is in *itemIds*.

    A None scope means "no active group filter" and returns the list
    untouched. Safe to call before handleIdsInMemory /
    handleFilterValuesInMemory as well as for normal list responses.
    """
    if itemIds is None:
        return items
    scoped = []
    for entry in items:
        if str(entry.get("id", "")) in itemIds:
            scoped.append(entry)
    return scoped