gateway/modules/serviceCenter/services/serviceAgent/featureDataProvider.py
2026-04-16 23:13:05 +02:00

455 lines
18 KiB
Python

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Generic data provider for querying feature-instance tables.
Uses the RBAC catalog's DATA_OBJECTS metadata (table name, fields) and the
DB connector to execute scoped, read-only queries against any registered
feature table. All queries are automatically filtered by featureInstanceId
(and by mandateId whenever one is supplied) to keep data isolated per
feature instance and mandate.
"""
import hashlib
import logging
import json
import os
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
logger = logging.getLogger(__name__)
_DEBUG_DIR = Path("D:/Athi/Local/Web/poweron/local/debug")
def _isDebugEnabled() -> bool:
try:
from modules.shared.configuration import APP_CONFIG
val = APP_CONFIG.get("APP_LOGGING_FILE_ENABLED", False)
return val is True or str(val).lower() == "true"
except Exception:
return False
def _debugQueryLog(method: str, tableName: str, params: dict, result: dict, elapsed: float):
    """Append a query and its result to ``debug_queryTable.log`` under ``_DEBUG_DIR``.

    Does nothing unless ``_isDebugEnabled()`` returns True. Writing is
    best-effort: any failure is reported at DEBUG level and otherwise ignored,
    so the debug log can never break the query that produced it.

    Args:
        method: Name of the calling provider method (e.g. ``"queryTable"``).
        tableName: Table that was queried.
        params: Query parameters to record; serialized with ``default=str``.
        result: Result dict; ``rows``, ``total`` and ``error`` are read if present.
        elapsed: Query duration in milliseconds.
    """
    if not _isDebugEnabled():
        return
    try:
        _DEBUG_DIR.mkdir(parents=True, exist_ok=True)
        logPath = _DEBUG_DIR / "debug_queryTable.log"
        ts = time.strftime("%Y-%m-%d %H:%M:%S")
        rows = result.get("rows", [])
        # Aggregate results carry no "total"; fall back to the row count.
        total = result.get("total", len(rows))
        err = result.get("error")
        header = f"[{ts}] {method}({tableName}) — {len(rows)} rows returned, total={total}, elapsed={elapsed:.0f}ms"
        if err:
            header += f", ERROR={err}"
        lines = [header, f" params: {json.dumps(params, ensure_ascii=False, default=str)}"]
        lines.extend(
            f" [{i}] {json.dumps(row, ensure_ascii=False, default=str)}"
            for i, row in enumerate(rows)
        )
        lines.append("")
        with open(logPath, "a", encoding="utf-8") as fh:
            fh.write("\n".join(lines) + "\n")
    except Exception as e:
        # Best-effort: report, but never let debug logging break the actual query.
        logger.debug(f"_debugQueryLog: could not write debug log for {method}({tableName}): {e}")
_ALLOWED_OPERATORS = {"=", "!=", ">", "<", ">=", "<=", "LIKE", "ILIKE", "IS NULL", "IS NOT NULL"}
_ALLOWED_AGGREGATES = {"SUM", "COUNT", "AVG", "MIN", "MAX"}
class FeatureDataProvider:
    """Reads feature-instance data from the DB using DATA_OBJECTS metadata.

    Every read method builds its WHERE clause from ``_buildScopeFilter`` (the
    feature-instance column, plus ``mandateId`` when one is given) and ANDs any
    caller filters onto it. Values of columns listed in ``neutralizeFields``
    are replaced with placeholders before results are returned. This includes
    the ``groupValue``/``result`` columns produced by ``aggregateTable``.

    NOTE(review): filters may still reference neutralized columns (e.g. a LIKE
    on ``lastName``), so a caller can probe hidden values by trial and error.
    Consider rejecting such filters if that matters for your threat model.
    """

    def __init__(self, dbConnector, neutralizeFields: Optional[Dict[str, List[str]]] = None):
        """
        Args:
            dbConnector: A connectorDbPostgre.DatabaseConnector with an open
                connection, accessed as ``dbConnector.connection``.
            neutralizeFields: Per-table field names whose values must be replaced
                with placeholders before returning to the AI, e.g.
                ``{"TrusteePosition": ["firstName", "lastName", "address"]}``.
        """
        self._db = dbConnector
        self._neutralizeFields: Dict[str, Set[str]] = {
            tbl: set(fields) for tbl, fields in (neutralizeFields or {}).items()
        }

    # ------------------------------------------------------------------
    # public API (called by FeatureDataAgent tools)
    # ------------------------------------------------------------------

    def getAvailableTables(self, featureCode: str) -> List[Dict[str, Any]]:
        """Return the DATA_OBJECTS registered for *featureCode* in the RBAC catalog."""
        from modules.security.rbacCatalog import getCatalogService
        catalog = getCatalogService()
        return catalog.getDataObjects(featureCode)

    def getTableSchema(self, featureCode: str, tableName: str) -> Optional[Dict[str, Any]]:
        """Return the DATA_OBJECT entry whose ``meta.table`` equals *tableName*, else ``None``.

        Entries whose ``meta`` is missing or ``None`` are skipped instead of crashing.
        """
        for obj in self.getAvailableTables(featureCode):
            meta = obj.get("meta") or {}
            if meta.get("table") == tableName:
                return obj
        return None

    def getActualColumns(self, tableName: str) -> List[str]:
        """Read the real column names of *tableName* from PostgreSQL ``information_schema``.

        Columns whose name starts with ``_`` are omitted. On error the failure
        is logged, the transaction is rolled back (so the shared connection is
        not left in an aborted state, consistent with the other read methods),
        and ``[]`` is returned.
        """
        conn = None
        try:
            conn = self._db.connection
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT column_name FROM information_schema.columns "
                    "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
                    "ORDER BY ordinal_position",
                    [tableName],
                )
                cols = [row["column_name"] for row in cur.fetchall()]
            return [c for c in cols if not c.startswith("_")]
        except Exception as e:
            logger.warning(f"getActualColumns({tableName}) failed: {e}")
            self._safeRollback(conn)
            return []

    def _applyFieldNeutralization(self, tableName: str, rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Neutralize sensitive column values in plain row results before they reach the AI."""
        fieldsToNeut = self._neutralizeFields.get(tableName)
        if not fieldsToNeut:
            return rows
        return [_neutralizeRowFields(row, fieldsToNeut) for row in rows]

    def _neutralizeAggregateRows(
        self,
        tableName: str,
        rows: List[Dict[str, Any]],
        aggregate: str,
        field: str,
        groupBy: Optional[str],
    ) -> List[Dict[str, Any]]:
        """Neutralize the ``groupValue`` / ``result`` columns of aggregate output.

        Aggregate rows are keyed ``groupValue``/``result`` instead of by source
        column name, so ``_applyFieldNeutralization`` would never match them.
        The group value is masked when *groupBy* is a neutralized column. The
        result is masked when *field* is one and the aggregate is not COUNT:
        MIN/MAX return a real value, and SUM/AVG do too for single-row groups.
        Placeholders are built with the source column name, so equal values get
        the same token as in browseTable/queryTable output.
        """
        sensitive = self._neutralizeFields.get(tableName)
        if not sensitive:
            return rows
        maskGroup = bool(groupBy) and groupBy in sensitive
        maskResult = field in sensitive and aggregate != "COUNT"
        if not (maskGroup or maskResult):
            return rows
        for row in rows:
            if maskGroup:
                row["groupValue"] = _neutralizeRowFields({groupBy: row.get("groupValue")}, {groupBy})[groupBy]
            if maskResult:
                row["result"] = _neutralizeRowFields({field: row.get("result")}, {field})[field]
        return rows

    def browseTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        fields: Optional[List[str]] = None,
        limit: int = 50,
        offset: int = 0,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """List rows from a feature table with pagination, ordered by ``"id"``.

        Returns ``{"rows": [...], "total": N, "limit": L, "offset": O}``; on
        failure the same shape with empty rows and an ``"error"`` message.

        Raises:
            ValueError: if *tableName* is not a valid identifier.
        """
        return self._selectRows(
            "browseTable", tableName, featureInstanceId, mandateId,
            filters=extraFilters,
            fields=fields,
            orderBy=None,
            limit=limit,
            offset=offset,
            logParams={"fields": fields, "limit": limit, "offset": offset},
        )

    def aggregateTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        aggregate: str,
        field: str,
        groupBy: Optional[str] = None,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Run an aggregate query (SUM, COUNT, AVG, MIN, MAX) on a feature table.

        Returns ``{"rows": [{"groupValue": ..., "result": ...}], "aggregate": ..., "field": ..., "groupBy": ...}``.
        ``groupValue`` is only present when *groupBy* is given; grouped results
        are ordered by ``result`` descending. Invalid input or a DB failure
        yields a dict with an ``"error"`` message instead.

        Raises:
            ValueError: if *tableName* is not a valid identifier.
        """
        _validateTableName(tableName)
        aggregate = str(aggregate or "").upper()
        if aggregate not in _ALLOWED_AGGREGATES:
            return {"rows": [], "error": f"Unsupported aggregate: {aggregate}. Allowed: {', '.join(sorted(_ALLOWED_AGGREGATES))}"}
        if not (isinstance(field, str) and _isValidIdentifier(field)):
            return {"rows": [], "error": f"Invalid field name: {field}"}
        if groupBy and not (isinstance(groupBy, str) and _isValidIdentifier(groupBy)):
            return {"rows": [], "error": f"Invalid groupBy field: {groupBy}"}
        conn = self._db.connection
        fullWhere, allParams = self._scopedWhere(tableName, featureInstanceId, mandateId, extraFilters, conn)
        if groupBy:
            sql = (
                f'SELECT "{groupBy}" AS "groupValue", {aggregate}("{field}") AS "result" '
                f'FROM "{tableName}" WHERE {fullWhere} '
                f'GROUP BY "{groupBy}" ORDER BY "result" DESC'
            )
        else:
            sql = (
                f'SELECT {aggregate}("{field}") AS "result" '
                f'FROM "{tableName}" WHERE {fullWhere}'
            )
        logParams = {"aggregate": aggregate, "field": field, "groupBy": groupBy}
        t0 = time.time()
        try:
            with conn.cursor() as cur:
                cur.execute(sql, allParams)
                rows = [_serializeRow(dict(r)) for r in cur.fetchall()]
            rows = self._neutralizeAggregateRows(tableName, rows, aggregate, field, groupBy)
            result = {
                "rows": rows,
                "aggregate": aggregate,
                "field": field,
                "groupBy": groupBy,
            }
            _debugQueryLog("aggregateTable", tableName, logParams, result, (time.time() - t0) * 1000)
            return result
        except Exception as e:
            logger.error(f"aggregateTable({tableName}, {aggregate}({field})) failed: {e}")
            elapsed = (time.time() - t0) * 1000
            errResult = {"rows": [], "error": str(e), "aggregate": aggregate, "field": field, "groupBy": groupBy}
            _debugQueryLog("aggregateTable", tableName, logParams, errResult, elapsed)
            self._safeRollback(conn)
            return errResult

    def queryTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        filters: Optional[List[Dict[str, Any]]] = None,
        fields: Optional[List[str]] = None,
        orderBy: Optional[str] = None,
        limit: int = 50,
        offset: int = 0,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Query a feature table with optional filters.

        ``filters`` is a list of ``{"field": "x", "op": "=", "value": "y"}``.
        ``extraFilters`` are mandatory record-level scoping filters injected by the pipeline.
        An invalid or missing *orderBy* falls back to ordering by ``"id"``.
        Returns the same shape as :meth:`browseTable`.

        Raises:
            ValueError: if *tableName* is not a valid identifier.
        """
        combinedFilters = list(filters or []) + list(extraFilters or [])
        return self._selectRows(
            "queryTable", tableName, featureInstanceId, mandateId,
            filters=combinedFilters,
            fields=fields,
            orderBy=orderBy,
            limit=limit,
            offset=offset,
            logParams={
                "filters": filters, "fields": fields, "orderBy": orderBy,
                "limit": limit, "offset": offset,
            },
        )

    # ------------------------------------------------------------------
    # internal helpers
    # ------------------------------------------------------------------

    def _selectRows(
        self,
        method: str,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        *,
        filters: Optional[List[Dict[str, Any]]],
        fields: Optional[List[str]],
        orderBy: Optional[str],
        limit: int,
        offset: int,
        logParams: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Run the scoped COUNT + paginated SELECT shared by browseTable and queryTable."""
        _validateTableName(tableName)
        conn = self._db.connection
        if fields:
            invalid = [f for f in fields if not (isinstance(f, str) and _isValidIdentifier(f))]
            if invalid:
                return {
                    "rows": [], "total": 0, "limit": limit, "offset": offset,
                    "error": f"Invalid field name(s): {', '.join(str(f) for f in invalid)}. Use getActualColumns to discover valid column names.",
                }
        fullWhere, allParams = self._scopedWhere(tableName, featureInstanceId, mandateId, filters, conn)
        selectCols = ", ".join(f'"{f}"' for f in fields) if fields else "*"
        # Unknown/invalid sort keys fall back to the primary key so paging stays stable.
        orderCol = orderBy if isinstance(orderBy, str) and _isValidIdentifier(orderBy) else "id"
        t0 = time.time()
        try:
            with conn.cursor() as cur:
                cur.execute(f'SELECT COUNT(*) FROM "{tableName}" WHERE {fullWhere}', allParams)
                total = cur.fetchone()["count"] if cur.rowcount else 0
                cur.execute(
                    f'SELECT {selectCols} FROM "{tableName}" '
                    f'WHERE {fullWhere} ORDER BY "{orderCol}" LIMIT %s OFFSET %s',
                    allParams + [limit, offset],
                )
                rows = [_serializeRow(dict(r)) for r in cur.fetchall()]
            rows = self._applyFieldNeutralization(tableName, rows)
            result = {"rows": rows, "total": total, "limit": limit, "offset": offset}
            _debugQueryLog(method, tableName, logParams, result, (time.time() - t0) * 1000)
            return result
        except Exception as e:
            logger.error(f"{method}({tableName}) failed: {e}")
            elapsed = (time.time() - t0) * 1000
            errResult = {"rows": [], "total": 0, "limit": limit, "offset": offset, "error": str(e)}
            _debugQueryLog(method, tableName, logParams, errResult, elapsed)
            self._safeRollback(conn)
            return errResult

    @staticmethod
    def _scopedWhere(tableName: str, featureInstanceId: str, mandateId: str, filters, conn) -> tuple:
        """Return ``(whereSql, params)``: the mandatory scope filter ANDed with *filters*."""
        scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, dbConnection=conn)
        extraWhere, extraParams = _buildFilterClauses(filters)
        fullWhere = scopeFilter["where"]
        allParams = list(scopeFilter["params"])
        if extraWhere:
            fullWhere += " AND " + extraWhere
            allParams.extend(extraParams)
        return fullWhere, allParams

    @staticmethod
    def _safeRollback(conn) -> None:
        """Best-effort rollback after a failed statement; never raises."""
        if conn is None:
            return
        try:
            conn.rollback()
        except Exception as e:
            logger.debug(f"rollback after failed query also failed: {e}")
# ------------------------------------------------------------------
# helpers
# ------------------------------------------------------------------
_instanceColCache: Dict[str, str] = {}
def _resolveInstanceColumn(tableName: str, dbConnection=None) -> str:
"""Detect whether the table uses ``instanceId`` or ``featureInstanceId``."""
if tableName in _instanceColCache:
return _instanceColCache[tableName]
if dbConnection:
try:
with dbConnection.cursor() as cur:
cur.execute(
"SELECT column_name FROM information_schema.columns "
"WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
"AND column_name IN ('featureInstanceId', 'instanceId')",
[tableName],
)
cols = [row["column_name"] for row in cur.fetchall()]
if "featureInstanceId" in cols:
_instanceColCache[tableName] = "featureInstanceId"
return "featureInstanceId"
if "instanceId" in cols:
_instanceColCache[tableName] = "instanceId"
return "instanceId"
except Exception:
pass
return "instanceId"
def _validateTableName(tableName: str):
if not tableName or not _isValidIdentifier(tableName):
raise ValueError(f"Invalid table name: {tableName}")
def _isValidIdentifier(name: str) -> bool:
"""Only allow alphanumeric + underscore to prevent SQL injection."""
return name.isidentifier()
def _buildScopeFilter(tableName: str, featureInstanceId: str, mandateId: str, dbConnection=None) -> Dict[str, Any]:
    """Build the mandatory WHERE fragment that pins rows to one feature instance.

    Feature tables reference their instance through either ``instanceId``
    (commcoach, teamsbot) or ``featureInstanceId`` (trustee). The column is
    chosen by ``_resolveInstanceColumn``, which inspects ``information_schema``
    when *dbConnection* is given. A ``"mandateId" = %s`` condition is appended
    only when *mandateId* is truthy.

    Returns:
        ``{"where": <conditions joined by AND>, "params": <bind values>}``.
    """
    instanceCol = _resolveInstanceColumn(tableName, dbConnection)
    clauses = [f'"{instanceCol}" = %s']
    values = [featureInstanceId]
    if mandateId:
        clauses += ['"mandateId" = %s']
        values += [mandateId]
    return {"where": " AND ".join(clauses), "params": values}
def _buildFilterClauses(filters: Optional[List[Dict[str, Any]]]) -> tuple:
"""Convert agent-provided filter dicts into safe SQL."""
if not filters:
return "", []
parts = []
params = []
for f in filters:
field = f.get("field", "")
op = (f.get("op") or "=").upper()
value = f.get("value")
if not field or not _isValidIdentifier(field):
continue
if op not in _ALLOWED_OPERATORS:
continue
if op in ("IS NULL", "IS NOT NULL"):
parts.append(f'"{field}" {op}')
else:
parts.append(f'"{field}" {op} %s')
params.append(value)
return " AND ".join(parts), params
def _serializeRow(row: Dict[str, Any]) -> Dict[str, Any]:
"""Ensure all values are JSON-serializable."""
for k, v in row.items():
if isinstance(v, (bytes, bytearray)):
row[k] = f"<binary {len(v)} bytes>"
elif hasattr(v, "isoformat"):
row[k] = v.isoformat()
return row
_PLACEHOLDER_PREFIX = "NEUT"
def _neutralizeRowFields(row: Dict[str, Any], fieldsToNeutralize: Set[str]) -> Dict[str, Any]:
"""Replace values in sensitive fields with stable, deterministic placeholders.
The placeholder format ``[NEUT.<field>.<short-hash>]`` is stable for the same
value so that identical values in different rows produce the same token.
This allows the AI to reason about equality without seeing the real data.
"""
for field in fieldsToNeutralize:
val = row.get(field)
if val is None or val == "":
continue
shortHash = hashlib.sha256(str(val).encode()).hexdigest()[:8]
row[field] = f"[{_PLACEHOLDER_PREFIX}.{field}.{shortHash}]"
return row