gateway/modules/serviceCenter/services/serviceAgent/featureDataProvider.py

317 lines
12 KiB
Python

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Generic data provider for querying feature-instance tables.
Uses the RBAC catalog's DATA_OBJECTS metadata (table name, fields) and the
DB connector to execute scoped, read-only queries against any registered
feature table. All queries are automatically filtered by featureInstanceId
and mandateId so data isolation is guaranteed.
"""
import logging
import json
from typing import Any, Dict, List, Optional
# Module-level logger for this data provider.
logger = logging.getLogger(__name__)
# Whitelist of SQL comparison operators accepted in agent-supplied filter dicts.
_ALLOWED_OPERATORS = {"=", "!=", ">", "<", ">=", "<=", "LIKE", "ILIKE", "IS NULL", "IS NOT NULL"}
# Whitelist of aggregate functions accepted by aggregateTable.
_ALLOWED_AGGREGATES = {"SUM", "COUNT", "AVG", "MIN", "MAX"}
class FeatureDataProvider:
    """Reads feature-instance data from the DB using DATA_OBJECTS metadata.

    Every public query method scopes its WHERE clause by featureInstanceId
    (and mandateId when given), so callers can never read rows belonging to
    another feature instance or mandate. Table, column, operator and
    aggregate names supplied by the agent are validated/whitelisted before
    being embedded in SQL; all values travel as bound parameters.
    """

    def __init__(self, dbConnector):
        """
        Args:
            dbConnector: A connectorDbPostgre.DatabaseConnector with an open connection.
        """
        self._db = dbConnector

    # ------------------------------------------------------------------
    # public API (called by FeatureDataAgent tools)
    # ------------------------------------------------------------------
    def getAvailableTables(self, featureCode: str) -> List[Dict[str, Any]]:
        """Return DATA_OBJECTS registered for *featureCode*."""
        # NOTE(review): imported lazily, presumably to avoid a circular
        # import with the security module -- confirm before moving to top.
        from modules.security.rbacCatalog import getCatalogService
        catalog = getCatalogService()
        return catalog.getDataObjects(featureCode)

    def getTableSchema(self, featureCode: str, tableName: str) -> Optional[Dict[str, Any]]:
        """Return the DATA_OBJECT entry for a specific table, or None if unknown."""
        for obj in self.getAvailableTables(featureCode):
            if obj.get("meta", {}).get("table") == tableName:
                return obj
        return None

    def getActualColumns(self, tableName: str) -> List[str]:
        """Read real column names from PostgreSQL information_schema.

        Columns starting with "_" are treated as internal and excluded.
        Returns an empty list on any DB error (logged as a warning).
        Assumes the cursor yields dict-like rows (e.g. RealDictCursor) --
        consistent with the other methods in this module.
        """
        try:
            conn = self._db.connection
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT column_name FROM information_schema.columns "
                    "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
                    "ORDER BY ordinal_position",
                    [tableName],
                )
                cols = [row["column_name"] for row in cur.fetchall()]
                return [c for c in cols if not c.startswith("_")]
        except Exception as e:
            logger.warning(f"getActualColumns({tableName}) failed: {e}")
            return []

    def browseTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        fields: Optional[List[str]] = None,
        limit: int = 50,
        offset: int = 0,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """List rows from a feature table with pagination.

        Args:
            tableName: Target table; must be a valid SQL identifier.
            featureInstanceId: Mandatory instance scope value.
            mandateId: Mandatory mandate scope value (skipped if falsy).
            fields: Optional projection; invalid identifiers are dropped,
                and an all-invalid/empty projection falls back to "*".
            limit: Max rows to return.
            offset: Pagination offset.
            extraFilters: Record-level filter dicts ({"field","op","value"}).

        Returns:
            ``{"rows": [...], "total": N, "limit": L, "offset": O}``; on a
            DB error, rows is empty and an "error" key is added.
        """
        _validateTableName(tableName)
        conn = self._db.connection
        scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, dbConnection=conn)
        extraWhere, extraParams = _buildFilterClauses(extraFilters)
        fullWhere, allParams = self._mergeWhere(scopeFilter, extraWhere, extraParams)
        try:
            with conn.cursor() as cur:
                countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {fullWhere}'
                cur.execute(countSql, allParams)
                total = cur.fetchone()["count"] if cur.rowcount else 0
                # FIX: field names were previously interpolated unvalidated,
                # so a field containing a double quote could break out of the
                # quoted identifier (SQL injection). Now whitelisted.
                selectCols = self._safeSelectColumns(fields)
                dataSql = (
                    f'SELECT {selectCols} FROM "{tableName}" '
                    f'WHERE {fullWhere} '
                    f'ORDER BY "id" LIMIT %s OFFSET %s'
                )
                cur.execute(dataSql, allParams + [limit, offset])
                rows = [_serializeRow(dict(r)) for r in cur.fetchall()]
                return {"rows": rows, "total": total, "limit": limit, "offset": offset}
        except Exception as e:
            logger.error(f"browseTable({tableName}) failed: {e}")
            return {"rows": [], "total": 0, "limit": limit, "offset": offset, "error": str(e)}

    def aggregateTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        aggregate: str,
        field: str,
        groupBy: str = None,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Run an aggregate query (SUM, COUNT, AVG, MIN, MAX) on a feature table.

        Args:
            tableName: Target table; must be a valid SQL identifier.
            featureInstanceId: Mandatory instance scope value.
            mandateId: Mandatory mandate scope value (skipped if falsy).
            aggregate: One of _ALLOWED_AGGREGATES (case-insensitive).
            field: Column the aggregate is applied to.
            groupBy: Optional grouping column; results ordered by result DESC.
            extraFilters: Record-level filter dicts ({"field","op","value"}).

        Returns:
            ``{"rows": [{"groupValue": ..., "result": ...}], "aggregate": ...,
            "field": ..., "groupBy": ...}``; an "error" key replaces data on
            validation or DB failure.
        """
        _validateTableName(tableName)
        aggregate = aggregate.upper()
        if aggregate not in _ALLOWED_AGGREGATES:
            return {"rows": [], "error": f"Unsupported aggregate: {aggregate}. Allowed: {', '.join(sorted(_ALLOWED_AGGREGATES))}"}
        if not _isValidIdentifier(field):
            return {"rows": [], "error": f"Invalid field name: {field}"}
        if groupBy and not _isValidIdentifier(groupBy):
            return {"rows": [], "error": f"Invalid groupBy field: {groupBy}"}
        conn = self._db.connection
        scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, dbConnection=conn)
        extraWhere, extraParams = _buildFilterClauses(extraFilters)
        fullWhere, allParams = self._mergeWhere(scopeFilter, extraWhere, extraParams)
        try:
            with conn.cursor() as cur:
                if groupBy:
                    sql = (
                        f'SELECT "{groupBy}" AS "groupValue", {aggregate}("{field}") AS "result" '
                        f'FROM "{tableName}" WHERE {fullWhere} '
                        f'GROUP BY "{groupBy}" ORDER BY "result" DESC'
                    )
                else:
                    sql = (
                        f'SELECT {aggregate}("{field}") AS "result" '
                        f'FROM "{tableName}" WHERE {fullWhere}'
                    )
                cur.execute(sql, allParams)
                rows = [_serializeRow(dict(r)) for r in cur.fetchall()]
                return {
                    "rows": rows,
                    "aggregate": aggregate,
                    "field": field,
                    "groupBy": groupBy,
                }
        except Exception as e:
            logger.error(f"aggregateTable({tableName}, {aggregate}({field})) failed: {e}")
            return {"rows": [], "error": str(e), "aggregate": aggregate, "field": field, "groupBy": groupBy}

    def queryTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        filters: Optional[List[Dict[str, Any]]] = None,
        fields: Optional[List[str]] = None,
        orderBy: str = None,
        limit: int = 50,
        offset: int = 0,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Query a feature table with optional filters.

        ``filters`` is a list of ``{"field": "x", "op": "=", "value": "y"}``.
        ``extraFilters`` are mandatory record-level scoping filters injected
        by the pipeline; both lists are merged and sanitized identically.
        An invalid ``orderBy`` silently falls back to ordering by "id".

        Returns:
            ``{"rows": [...], "total": N, "limit": L, "offset": O}``; on a
            DB error, rows is empty and an "error" key is added.
        """
        _validateTableName(tableName)
        conn = self._db.connection
        scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, dbConnection=conn)
        combinedFilters = list(filters or []) + list(extraFilters or [])
        extraWhere, extraParams = _buildFilterClauses(combinedFilters if combinedFilters else None)
        fullWhere, allParams = self._mergeWhere(scopeFilter, extraWhere, extraParams)
        try:
            with conn.cursor() as cur:
                countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {fullWhere}'
                cur.execute(countSql, allParams)
                total = cur.fetchone()["count"] if cur.rowcount else 0
                # FIX: same unvalidated-fields injection vector as browseTable.
                selectCols = self._safeSelectColumns(fields)
                orderClause = f'ORDER BY "{orderBy}"' if orderBy and _isValidIdentifier(orderBy) else 'ORDER BY "id"'
                dataSql = (
                    f'SELECT {selectCols} FROM "{tableName}" '
                    f'WHERE {fullWhere} {orderClause} LIMIT %s OFFSET %s'
                )
                cur.execute(dataSql, allParams + [limit, offset])
                rows = [_serializeRow(dict(r)) for r in cur.fetchall()]
                return {"rows": rows, "total": total, "limit": limit, "offset": offset}
        except Exception as e:
            logger.error(f"queryTable({tableName}) failed: {e}")
            return {"rows": [], "total": 0, "limit": limit, "offset": offset, "error": str(e)}

    # ------------------------------------------------------------------
    # internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _mergeWhere(scopeFilter: Dict[str, Any], extraWhere: str, extraParams: List[Any]) -> tuple:
        """Append optional extra-filter SQL/params to the mandatory scope filter.

        Returns:
            (whereSql, params) with the scope conditions always first.
        """
        fullWhere = scopeFilter["where"]
        allParams = list(scopeFilter["params"])
        if extraWhere:
            fullWhere += " AND " + extraWhere
            allParams.extend(extraParams)
        return fullWhere, allParams

    @staticmethod
    def _safeSelectColumns(fields: Optional[List[str]]) -> str:
        """Build a safe SELECT column list from caller-supplied names.

        Invalid identifiers are silently dropped (mirroring the lenient
        policy of _buildFilterClauses); if nothing valid remains, fall back
        to "*" rather than emitting broken SQL.
        """
        if not fields:
            return "*"
        safe = [f for f in fields if _isValidIdentifier(f)]
        if not safe:
            return "*"
        return ", ".join(f'"{f}"' for f in safe)
# ------------------------------------------------------------------
# helpers
# ------------------------------------------------------------------
_instanceColCache: Dict[str, str] = {}
def _resolveInstanceColumn(tableName: str, dbConnection=None) -> str:
"""Detect whether the table uses ``instanceId`` or ``featureInstanceId``."""
if tableName in _instanceColCache:
return _instanceColCache[tableName]
if dbConnection:
try:
with dbConnection.cursor() as cur:
cur.execute(
"SELECT column_name FROM information_schema.columns "
"WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
"AND column_name IN ('featureInstanceId', 'instanceId')",
[tableName],
)
cols = [row["column_name"] for row in cur.fetchall()]
if "featureInstanceId" in cols:
_instanceColCache[tableName] = "featureInstanceId"
return "featureInstanceId"
if "instanceId" in cols:
_instanceColCache[tableName] = "instanceId"
return "instanceId"
except Exception:
pass
return "instanceId"
def _validateTableName(tableName: str) -> None:
    """Raise ValueError unless *tableName* is a safe SQL identifier."""
    if tableName and _isValidIdentifier(tableName):
        return
    raise ValueError(f"Invalid table name: {tableName}")
def _isValidIdentifier(name: str) -> bool:
"""Only allow alphanumeric + underscore to prevent SQL injection."""
return name.isidentifier()
def _buildScopeFilter(tableName: str, featureInstanceId: str, mandateId: str, dbConnection=None) -> Dict[str, Any]:
    """Build the mandatory WHERE clause scoping rows to one feature instance.

    Feature tables use either ``instanceId`` (commcoach, teamsbot) or
    ``featureInstanceId`` (trustee) as the FK; the actual column is detected
    via ``information_schema`` when a DB connection is provided.

    Returns:
        ``{"where": "<sql>", "params": [...]}`` — conditions joined with AND,
        values as bound parameters.
    """
    fkColumn = _resolveInstanceColumn(tableName, dbConnection)
    whereSql = f'"{fkColumn}" = %s'
    params: List[Any] = [featureInstanceId]
    if mandateId:
        whereSql += ' AND "mandateId" = %s'
        params.append(mandateId)
    return {"where": whereSql, "params": params}
def _buildFilterClauses(filters: Optional[List[Dict[str, Any]]]) -> tuple:
    """Convert agent-provided filter dicts into safe SQL.

    Each entry looks like ``{"field": "x", "op": "=", "value": "y"}``.
    Entries with an invalid field name or a non-whitelisted operator are
    silently skipped; NULL-test operators take no bound parameter.

    Returns:
        (whereSql, params) — clauses joined with AND, values as parameters.
    """
    clauses: List[str] = []
    values: List[Any] = []
    for spec in filters or []:
        column = spec.get("field", "")
        operator = (spec.get("op") or "=").upper()
        if not column or not _isValidIdentifier(column):
            continue
        if operator not in _ALLOWED_OPERATORS:
            continue
        if operator in ("IS NULL", "IS NOT NULL"):
            clauses.append(f'"{column}" {operator}')
            continue
        clauses.append(f'"{column}" {operator} %s')
        values.append(spec.get("value"))
    return " AND ".join(clauses), values
def _serializeRow(row: Dict[str, Any]) -> Dict[str, Any]:
"""Ensure all values are JSON-serializable."""
for k, v in row.items():
if isinstance(v, (bytes, bytearray)):
row[k] = f"<binary {len(v)} bytes>"
elif hasattr(v, "isoformat"):
row[k] = v.isoformat()
return row