317 lines
12 KiB
Python
317 lines
12 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""Generic data provider for querying feature-instance tables.
|
|
|
|
Uses the RBAC catalog's DATA_OBJECTS metadata (table name, fields) and the
|
|
DB connector to execute scoped, read-only queries against any registered
|
|
feature table. All queries are automatically filtered by featureInstanceId
|
|
and mandateId so data isolation is guaranteed.
|
|
"""
|
|
|
|
import json
import logging
from decimal import Decimal
from typing import Any, Dict, List, Optional
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_ALLOWED_OPERATORS = {"=", "!=", ">", "<", ">=", "<=", "LIKE", "ILIKE", "IS NULL", "IS NOT NULL"}
|
|
_ALLOWED_AGGREGATES = {"SUM", "COUNT", "AVG", "MIN", "MAX"}
|
|
|
|
|
|
class FeatureDataProvider:
    """Reads feature-instance data from the DB using DATA_OBJECTS metadata.

    Every query is scoped by ``featureInstanceId`` (and ``mandateId`` when
    provided) via :func:`_buildScopeFilter`, so callers can never read rows
    belonging to another feature instance. All table/column identifiers are
    validated before being interpolated into SQL; row values are passed as
    bind parameters.
    """

    def __init__(self, dbConnector):
        """
        Args:
            dbConnector: A connectorDbPostgre.DatabaseConnector with an open connection.
        """
        self._db = dbConnector

    # ------------------------------------------------------------------
    # public API (called by FeatureDataAgent tools)
    # ------------------------------------------------------------------

    def getAvailableTables(self, featureCode: str) -> List[Dict[str, Any]]:
        """Return DATA_OBJECTS registered for *featureCode*."""
        # Imported lazily — presumably to avoid a circular import with the
        # security package (TODO confirm).
        from modules.security.rbacCatalog import getCatalogService
        catalog = getCatalogService()
        return catalog.getDataObjects(featureCode)

    def getTableSchema(self, featureCode: str, tableName: str) -> Optional[Dict[str, Any]]:
        """Return the DATA_OBJECT entry for a specific table, or ``None`` if unknown."""
        for obj in self.getAvailableTables(featureCode):
            if obj.get("meta", {}).get("table") == tableName:
                return obj
        return None

    def getActualColumns(self, tableName: str) -> List[str]:
        """Read real column names from PostgreSQL information_schema.

        Columns starting with an underscore are treated as internal and
        excluded. Best-effort: returns an empty list on any DB error.
        """
        try:
            conn = self._db.connection
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT column_name FROM information_schema.columns "
                    "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
                    "ORDER BY ordinal_position",
                    [tableName],
                )
                cols = [row["column_name"] for row in cur.fetchall()]
                return [c for c in cols if not c.startswith("_")]
        except Exception as e:
            logger.warning(f"getActualColumns({tableName}) failed: {e}")
            return []

    # ------------------------------------------------------------------
    # internal helpers
    # ------------------------------------------------------------------

    def _composeWhere(self, tableName, featureInstanceId, mandateId, filters, conn):
        """Combine the mandatory scope filter with optional extra filters.

        Returns ``(whereClause, params)``. Shared by browseTable,
        aggregateTable and queryTable so the scoping logic exists once.
        """
        scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, dbConnection=conn)
        fullWhere = scopeFilter["where"]
        allParams = list(scopeFilter["params"])
        extraWhere, extraParams = _buildFilterClauses(filters)
        if extraWhere:
            fullWhere += " AND " + extraWhere
            allParams.extend(extraParams)
        return fullWhere, allParams

    @staticmethod
    def _invalidFields(fields: Optional[List[str]]) -> List[str]:
        """Return the subset of *fields* that are not safe SQL identifiers."""
        return [f for f in (fields or []) if not _isValidIdentifier(f)]

    @staticmethod
    def _selectColumns(fields: Optional[List[str]]) -> str:
        """Build the projection list: ``*`` or quoted, pre-validated names."""
        return ", ".join(f'"{f}"' for f in fields) if fields else "*"

    # ------------------------------------------------------------------
    # query methods
    # ------------------------------------------------------------------

    def browseTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        fields: List[str] = None,
        limit: int = 50,
        offset: int = 0,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """List rows from a feature table with pagination.

        Returns ``{"rows": [...], "total": N, "limit": L, "offset": O}``.
        """
        _validateTableName(tableName)
        # Field names are interpolated into SQL (quoted), so they must be
        # validated — quoting alone does not make an arbitrary string safe.
        bad = self._invalidFields(fields)
        if bad:
            return {"rows": [], "total": 0, "limit": limit, "offset": offset,
                    "error": f"Invalid field name(s): {', '.join(bad)}"}

        conn = self._db.connection
        fullWhere, allParams = self._composeWhere(
            tableName, featureInstanceId, mandateId, extraFilters, conn
        )

        try:
            with conn.cursor() as cur:
                countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {fullWhere}'
                cur.execute(countSql, allParams)
                total = cur.fetchone()["count"] if cur.rowcount else 0

                dataSql = (
                    f'SELECT {self._selectColumns(fields)} FROM "{tableName}" '
                    f'WHERE {fullWhere} '
                    f'ORDER BY "id" LIMIT %s OFFSET %s'
                )
                cur.execute(dataSql, allParams + [limit, offset])
                rows = [_serializeRow(dict(r)) for r in cur.fetchall()]

                return {"rows": rows, "total": total, "limit": limit, "offset": offset}
        except Exception as e:
            logger.error(f"browseTable({tableName}) failed: {e}")
            return {"rows": [], "total": 0, "limit": limit, "offset": offset, "error": str(e)}

    def aggregateTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        aggregate: str,
        field: str,
        groupBy: str = None,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Run an aggregate query (SUM, COUNT, AVG, MIN, MAX) on a feature table.

        Returns ``{"rows": [{"groupValue": ..., "result": ...}], "aggregate": ..., "field": ..., "groupBy": ...}``.
        """
        _validateTableName(tableName)
        aggregate = aggregate.upper()
        if aggregate not in _ALLOWED_AGGREGATES:
            return {"rows": [], "error": f"Unsupported aggregate: {aggregate}. Allowed: {', '.join(sorted(_ALLOWED_AGGREGATES))}"}
        if not _isValidIdentifier(field):
            return {"rows": [], "error": f"Invalid field name: {field}"}
        if groupBy and not _isValidIdentifier(groupBy):
            return {"rows": [], "error": f"Invalid groupBy field: {groupBy}"}

        conn = self._db.connection
        fullWhere, allParams = self._composeWhere(
            tableName, featureInstanceId, mandateId, extraFilters, conn
        )

        try:
            with conn.cursor() as cur:
                if groupBy:
                    sql = (
                        f'SELECT "{groupBy}" AS "groupValue", {aggregate}("{field}") AS "result" '
                        f'FROM "{tableName}" WHERE {fullWhere} '
                        f'GROUP BY "{groupBy}" ORDER BY "result" DESC'
                    )
                else:
                    sql = (
                        f'SELECT {aggregate}("{field}") AS "result" '
                        f'FROM "{tableName}" WHERE {fullWhere}'
                    )
                cur.execute(sql, allParams)
                rows = [_serializeRow(dict(r)) for r in cur.fetchall()]

                return {
                    "rows": rows,
                    "aggregate": aggregate,
                    "field": field,
                    "groupBy": groupBy,
                }
        except Exception as e:
            logger.error(f"aggregateTable({tableName}, {aggregate}({field})) failed: {e}")
            return {"rows": [], "error": str(e), "aggregate": aggregate, "field": field, "groupBy": groupBy}

    def queryTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        filters: List[Dict[str, Any]] = None,
        fields: List[str] = None,
        orderBy: str = None,
        limit: int = 50,
        offset: int = 0,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Query a feature table with optional filters.

        ``filters`` is a list of ``{"field": "x", "op": "=", "value": "y"}``.
        ``extraFilters`` are mandatory record-level scoping filters injected by the pipeline.
        """
        _validateTableName(tableName)
        # Same identifier rule as browseTable: projected columns must be safe.
        bad = self._invalidFields(fields)
        if bad:
            return {"rows": [], "total": 0, "limit": limit, "offset": offset,
                    "error": f"Invalid field name(s): {', '.join(bad)}"}

        conn = self._db.connection
        combinedFilters = list(filters or []) + list(extraFilters or [])
        fullWhere, allParams = self._composeWhere(
            tableName, featureInstanceId, mandateId, combinedFilters, conn
        )

        try:
            with conn.cursor() as cur:
                countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {fullWhere}'
                cur.execute(countSql, allParams)
                total = cur.fetchone()["count"] if cur.rowcount else 0

                # An invalid orderBy silently falls back to "id" (original behavior).
                orderClause = f'ORDER BY "{orderBy}"' if orderBy and _isValidIdentifier(orderBy) else 'ORDER BY "id"'
                dataSql = (
                    f'SELECT {self._selectColumns(fields)} FROM "{tableName}" '
                    f'WHERE {fullWhere} {orderClause} LIMIT %s OFFSET %s'
                )
                cur.execute(dataSql, allParams + [limit, offset])
                rows = [_serializeRow(dict(r)) for r in cur.fetchall()]

                return {"rows": rows, "total": total, "limit": limit, "offset": offset}
        except Exception as e:
            logger.error(f"queryTable({tableName}) failed: {e}")
            return {"rows": [], "total": 0, "limit": limit, "offset": offset, "error": str(e)}
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# helpers
|
|
# ------------------------------------------------------------------
|
|
|
|
_instanceColCache: Dict[str, str] = {}
|
|
|
|
|
|
def _resolveInstanceColumn(tableName: str, dbConnection=None) -> str:
|
|
"""Detect whether the table uses ``instanceId`` or ``featureInstanceId``."""
|
|
if tableName in _instanceColCache:
|
|
return _instanceColCache[tableName]
|
|
if dbConnection:
|
|
try:
|
|
with dbConnection.cursor() as cur:
|
|
cur.execute(
|
|
"SELECT column_name FROM information_schema.columns "
|
|
"WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
|
|
"AND column_name IN ('featureInstanceId', 'instanceId')",
|
|
[tableName],
|
|
)
|
|
cols = [row["column_name"] for row in cur.fetchall()]
|
|
if "featureInstanceId" in cols:
|
|
_instanceColCache[tableName] = "featureInstanceId"
|
|
return "featureInstanceId"
|
|
if "instanceId" in cols:
|
|
_instanceColCache[tableName] = "instanceId"
|
|
return "instanceId"
|
|
except Exception:
|
|
pass
|
|
return "instanceId"
|
|
|
|
|
|
def _validateTableName(tableName: str):
|
|
if not tableName or not _isValidIdentifier(tableName):
|
|
raise ValueError(f"Invalid table name: {tableName}")
|
|
|
|
|
|
def _isValidIdentifier(name: str) -> bool:
|
|
"""Only allow alphanumeric + underscore to prevent SQL injection."""
|
|
return name.isidentifier()
|
|
|
|
|
|
def _buildScopeFilter(tableName: str, featureInstanceId: str, mandateId: str, dbConnection=None) -> Dict[str, Any]:
    """Build the mandatory WHERE clause that scopes rows to the feature instance.

    Feature tables use either ``instanceId`` (commcoach, teamsbot) or
    ``featureInstanceId`` (trustee) as the FK. We detect the actual column
    from ``information_schema`` when a DB connection is provided.
    """
    instanceCol = _resolveInstanceColumn(tableName, dbConnection)

    where = f'"{instanceCol}" = %s'
    params: List[Any] = [featureInstanceId]

    # mandateId is optional; when present it narrows the scope further.
    if mandateId:
        where += ' AND "mandateId" = %s'
        params.append(mandateId)

    return {"where": where, "params": params}
|
|
|
|
|
|
def _buildFilterClauses(filters: Optional[List[Dict[str, Any]]]) -> tuple:
|
|
"""Convert agent-provided filter dicts into safe SQL."""
|
|
if not filters:
|
|
return "", []
|
|
|
|
parts = []
|
|
params = []
|
|
for f in filters:
|
|
field = f.get("field", "")
|
|
op = (f.get("op") or "=").upper()
|
|
value = f.get("value")
|
|
|
|
if not field or not _isValidIdentifier(field):
|
|
continue
|
|
if op not in _ALLOWED_OPERATORS:
|
|
continue
|
|
|
|
if op in ("IS NULL", "IS NOT NULL"):
|
|
parts.append(f'"{field}" {op}')
|
|
else:
|
|
parts.append(f'"{field}" {op} %s')
|
|
params.append(value)
|
|
|
|
return " AND ".join(parts), params
|
|
|
|
|
|
def _serializeRow(row: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Ensure all values are JSON-serializable."""
|
|
for k, v in row.items():
|
|
if isinstance(v, (bytes, bytearray)):
|
|
row[k] = f"<binary {len(v)} bytes>"
|
|
elif hasattr(v, "isoformat"):
|
|
row[k] = v.isoformat()
|
|
return row
|