# Copyright (c) 2025 Patrick Motsch
# All rights reserved.

"""Generic data provider for querying feature-instance tables.

Uses the RBAC catalog's DATA_OBJECTS metadata (table name, fields) and the
DB connector to execute scoped, read-only queries against any registered
feature table. All queries are automatically filtered by featureInstanceId
and mandateId so data isolation is guaranteed.
"""

import hashlib
import json
import logging
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple

logger = logging.getLogger(__name__)

_DEBUG_DIR = Path("D:/Athi/Local/Web/poweron/local/debug")


def _isDebugEnabled() -> bool:
    try:
        from modules.shared.configuration import APP_CONFIG
        val = APP_CONFIG.get("APP_LOGGING_FILE_ENABLED", False)
        return val is True or str(val).lower() == "true"
    except Exception:
        return False


def _debugQueryLog(method: str, tableName: str, params: dict, result: dict, elapsed: float):
    """Append query + result to local/debug/debug_queryTable.log."""
    if not _isDebugEnabled():
        return
    debugDir = _DEBUG_DIR
    try:
        debugDir.mkdir(parents=True, exist_ok=True)
        logPath = debugDir / "debug_queryTable.log"
        ts = time.strftime("%Y-%m-%d %H:%M:%S")
        rows = result.get("rows", [])
        total = result.get("total", len(rows))
        err = result.get("error")
        header = f"[{ts}] {method}({tableName}) — {len(rows)} rows returned, total={total}, elapsed={elapsed:.0f}ms"
        if err:
            header += f", ERROR={err}"
        lines = [header]
        lines.append(f" params: {json.dumps(params, ensure_ascii=False, default=str)}")
        for i, row in enumerate(rows):
            lines.append(f" [{i}] {json.dumps(row, ensure_ascii=False, default=str)}")
        lines.append("")
        with open(logPath, "a", encoding="utf-8") as f:
            f.write("\n".join(lines) + "\n")
    except Exception:
        pass


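# Example of one appended entry (values illustrative; the shape follows the
# f-strings above):
#
#   [2025-01-01 12:00:00] browseTable(TrusteePosition) — 2 rows returned, total=2, elapsed=14ms
#    params: {"fields": null, "limit": 50, "offset": 0}
#    [0] {"id": "...", "status": "..."}
#    [1] {"id": "...", "status": "..."}

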
_ALLOWED_OPERATORS = {"=", "!=", ">", "<", ">=", "<=", "LIKE", "ILIKE", "IS NULL", "IS NOT NULL"}
_ALLOWED_AGGREGATES = {"SUM", "COUNT", "AVG", "MIN", "MAX"}


class FeatureDataProvider:
    """Reads feature-instance data from the DB using DATA_OBJECTS metadata."""

    def __init__(self, dbConnector, neutralizeFields: Optional[Dict[str, List[str]]] = None):
        """
        Args:
            dbConnector: A connectorDbPostgre.DatabaseConnector with an open connection.
            neutralizeFields: Per-table field names whose values must be replaced
                with placeholders before returning to the AI, e.g.
                ``{"TrusteePosition": ["firstName", "lastName", "address"]}``.
        """
        self._db = dbConnector
        self._neutralizeFields: Dict[str, Set[str]] = {
            tbl: set(fields) for tbl, fields in (neutralizeFields or {}).items()
        }

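    # Minimal usage sketch (illustrative; the connector variable, table name,
    # and ID values are assumptions, not fixed by this module):
    #
    #     provider = FeatureDataProvider(
    #         dbConnector,
    #         neutralizeFields={"TrusteePosition": ["firstName", "lastName"]},
    #     )
    #     page = provider.browseTable(
    #         "TrusteePosition", featureInstanceId="fi-001", mandateId="m-001",
    #         fields=["id", "status"], limit=20,
    #     )
    #     # -> {"rows": [...], "total": N, "limit": 20, "offset": 0}
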
    # ------------------------------------------------------------------
    # public API (called by FeatureDataAgent tools)
    # ------------------------------------------------------------------

    def getAvailableTables(self, featureCode: str) -> List[Dict[str, Any]]:
        """Return DATA_OBJECTS registered for *featureCode*."""
        from modules.security.rbacCatalog import getCatalogService
        catalog = getCatalogService()
        return catalog.getDataObjects(featureCode)

    def getTableSchema(self, featureCode: str, tableName: str) -> Optional[Dict[str, Any]]:
        """Return the DATA_OBJECT entry for a specific table."""
        for obj in self.getAvailableTables(featureCode):
            if obj.get("meta", {}).get("table") == tableName:
                return obj
        return None

    def getActualColumns(self, tableName: str) -> List[str]:
        """Read real column names from PostgreSQL information_schema."""
        try:
            conn = self._db.connection
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT column_name FROM information_schema.columns "
                    "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
                    "ORDER BY ordinal_position",
                    [tableName],
                )
                cols = [row["column_name"] for row in cur.fetchall()]
                return [c for c in cols if not c.startswith("_")]
        except Exception as e:
            logger.warning(f"getActualColumns({tableName}) failed: {e}")
            return []

    def _applyFieldNeutralization(self, tableName: str, rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Neutralize sensitive field values in query results before they reach the AI."""
        fieldsToNeut = self._neutralizeFields.get(tableName)
        if not fieldsToNeut:
            return rows
        return [_neutralizeRowFields(row, fieldsToNeut) for row in rows]

    def browseTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        fields: Optional[List[str]] = None,
        limit: int = 50,
        offset: int = 0,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """List rows from a feature table with pagination.

        Returns ``{"rows": [...], "total": N, "limit": L, "offset": O}``.
        """
        _validateTableName(tableName)
        conn = self._db.connection

        if fields:
            invalid = [f for f in fields if not _isValidIdentifier(f)]
            if invalid:
                return {
                    "rows": [], "total": 0, "limit": limit, "offset": offset,
                    "error": f"Invalid field name(s): {', '.join(invalid)}. Use getActualColumns to discover valid column names.",
                }

        scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, dbConnection=conn)
        extraWhere, extraParams = _buildFilterClauses(extraFilters)

        fullWhere = scopeFilter["where"]
        allParams = list(scopeFilter["params"])
        if extraWhere:
            fullWhere += " AND " + extraWhere
            allParams.extend(extraParams)

        t0 = time.time()
        try:
            with conn.cursor() as cur:
                countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {fullWhere}'
                cur.execute(countSql, allParams)
                total = cur.fetchone()["count"] if cur.rowcount else 0

                selectCols = ", ".join(f'"{f}"' for f in fields) if fields else "*"
                dataSql = (
                    f'SELECT {selectCols} FROM "{tableName}" '
                    f'WHERE {fullWhere} '
                    f'ORDER BY "id" LIMIT %s OFFSET %s'
                )
                cur.execute(dataSql, allParams + [limit, offset])
                rows = [_serializeRow(dict(r)) for r in cur.fetchall()]

            rows = self._applyFieldNeutralization(tableName, rows)
            result = {"rows": rows, "total": total, "limit": limit, "offset": offset}
            _debugQueryLog("browseTable", tableName, {
                "fields": fields, "limit": limit, "offset": offset,
            }, result, (time.time() - t0) * 1000)
            return result
        except Exception as e:
            logger.error(f"browseTable({tableName}) failed: {e}")
            elapsed = (time.time() - t0) * 1000
            errResult = {"rows": [], "total": 0, "limit": limit, "offset": offset, "error": str(e)}
            _debugQueryLog("browseTable", tableName, {
                "fields": fields, "limit": limit, "offset": offset,
            }, errResult, elapsed)
            try:
                conn.rollback()
            except Exception:
                pass
            return errResult

    def aggregateTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        aggregate: str,
        field: str,
        groupBy: Optional[str] = None,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Run an aggregate query (SUM, COUNT, AVG, MIN, MAX) on a feature table.

        Returns ``{"rows": [{"groupValue": ..., "result": ...}], "aggregate": ..., "field": ..., "groupBy": ...}``.
        """
        _validateTableName(tableName)
        aggregate = aggregate.upper()
        if aggregate not in _ALLOWED_AGGREGATES:
            return {"rows": [], "error": f"Unsupported aggregate: {aggregate}. Allowed: {', '.join(sorted(_ALLOWED_AGGREGATES))}"}
        if not _isValidIdentifier(field):
            return {"rows": [], "error": f"Invalid field name: {field}"}
        if groupBy and not _isValidIdentifier(groupBy):
            return {"rows": [], "error": f"Invalid groupBy field: {groupBy}"}

        conn = self._db.connection
        scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, dbConnection=conn)
        extraWhere, extraParams = _buildFilterClauses(extraFilters)

        fullWhere = scopeFilter["where"]
        allParams = list(scopeFilter["params"])
        if extraWhere:
            fullWhere += " AND " + extraWhere
            allParams.extend(extraParams)

        t0 = time.time()
        try:
            with conn.cursor() as cur:
                if groupBy:
                    sql = (
                        f'SELECT "{groupBy}" AS "groupValue", {aggregate}("{field}") AS "result" '
                        f'FROM "{tableName}" WHERE {fullWhere} '
                        f'GROUP BY "{groupBy}" ORDER BY "result" DESC'
                    )
                else:
                    sql = (
                        f'SELECT {aggregate}("{field}") AS "result" '
                        f'FROM "{tableName}" WHERE {fullWhere}'
                    )
                cur.execute(sql, allParams)
                rows = [_serializeRow(dict(r)) for r in cur.fetchall()]

            rows = self._applyFieldNeutralization(tableName, rows)
            result = {
                "rows": rows,
                "aggregate": aggregate,
                "field": field,
                "groupBy": groupBy,
            }
            _debugQueryLog("aggregateTable", tableName, {
                "aggregate": aggregate, "field": field, "groupBy": groupBy,
            }, result, (time.time() - t0) * 1000)
            return result
        except Exception as e:
            logger.error(f"aggregateTable({tableName}, {aggregate}({field})) failed: {e}")
            elapsed = (time.time() - t0) * 1000
            errResult = {"rows": [], "error": str(e), "aggregate": aggregate, "field": field, "groupBy": groupBy}
            _debugQueryLog("aggregateTable", tableName, {
                "aggregate": aggregate, "field": field, "groupBy": groupBy,
            }, errResult, elapsed)
            try:
                conn.rollback()
            except Exception:
                pass
            return errResult

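    # Illustrative call (table, field, and ID values are assumptions):
    #
    #     provider.aggregateTable(
    #         "Donation", featureInstanceId="fi-001", mandateId="m-001",
    #         aggregate="sum", field="amount", groupBy="currency",
    #     )
    #     # -> {"rows": [{"groupValue": "CHF", "result": ...}, ...],
    #     #     "aggregate": "SUM", "field": "amount", "groupBy": "currency"}
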
    def queryTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        filters: Optional[List[Dict[str, Any]]] = None,
        fields: Optional[List[str]] = None,
        orderBy: Optional[str] = None,
        limit: int = 50,
        offset: int = 0,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Query a feature table with optional filters.

        ``filters`` is a list of ``{"field": "x", "op": "=", "value": "y"}``.
        ``extraFilters`` are mandatory record-level scoping filters injected by the pipeline.
        """
        _validateTableName(tableName)
        conn = self._db.connection

        if fields:
            invalid = [f for f in fields if not _isValidIdentifier(f)]
            if invalid:
                return {
                    "rows": [], "total": 0, "limit": limit, "offset": offset,
                    "error": f"Invalid field name(s): {', '.join(invalid)}. Use getActualColumns to discover valid column names.",
                }

        scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, dbConnection=conn)

        combinedFilters = list(filters or []) + list(extraFilters or [])
        extraWhere, extraParams = _buildFilterClauses(combinedFilters if combinedFilters else None)

        fullWhere = scopeFilter["where"]
        allParams = list(scopeFilter["params"])
        if extraWhere:
            fullWhere += " AND " + extraWhere
            allParams.extend(extraParams)

        t0 = time.time()
        try:
            with conn.cursor() as cur:
                countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {fullWhere}'
                cur.execute(countSql, allParams)
                total = cur.fetchone()["count"] if cur.rowcount else 0

                selectCols = ", ".join(f'"{f}"' for f in fields) if fields else "*"
                orderClause = f'ORDER BY "{orderBy}"' if orderBy and _isValidIdentifier(orderBy) else 'ORDER BY "id"'
                dataSql = (
                    f'SELECT {selectCols} FROM "{tableName}" '
                    f'WHERE {fullWhere} {orderClause} LIMIT %s OFFSET %s'
                )
                cur.execute(dataSql, allParams + [limit, offset])
                rows = [_serializeRow(dict(r)) for r in cur.fetchall()]

            rows = self._applyFieldNeutralization(tableName, rows)
            result = {"rows": rows, "total": total, "limit": limit, "offset": offset}
            _debugQueryLog("queryTable", tableName, {
                "filters": filters, "fields": fields, "orderBy": orderBy,
                "limit": limit, "offset": offset,
            }, result, (time.time() - t0) * 1000)
            return result
        except Exception as e:
            logger.error(f"queryTable({tableName}) failed: {e}")
            elapsed = (time.time() - t0) * 1000
            errResult = {"rows": [], "total": 0, "limit": limit, "offset": offset, "error": str(e)}
            _debugQueryLog("queryTable", tableName, {
                "filters": filters, "fields": fields, "orderBy": orderBy,
                "limit": limit, "offset": offset,
            }, errResult, elapsed)
            try:
                conn.rollback()
            except Exception:
                pass
            return errResult


# ------------------------------------------------------------------
# helpers
# ------------------------------------------------------------------

_instanceColCache: Dict[str, str] = {}


def _resolveInstanceColumn(tableName: str, dbConnection=None) -> str:
    """Detect whether the table uses ``instanceId`` or ``featureInstanceId``."""
    if tableName in _instanceColCache:
        return _instanceColCache[tableName]
    if dbConnection:
        try:
            with dbConnection.cursor() as cur:
                cur.execute(
                    "SELECT column_name FROM information_schema.columns "
                    "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
                    "AND column_name IN ('featureInstanceId', 'instanceId')",
                    [tableName],
                )
                cols = [row["column_name"] for row in cur.fetchall()]
                if "featureInstanceId" in cols:
                    _instanceColCache[tableName] = "featureInstanceId"
                    return "featureInstanceId"
                if "instanceId" in cols:
                    _instanceColCache[tableName] = "instanceId"
                    return "instanceId"
        except Exception:
            pass
    return "instanceId"


def _validateTableName(tableName: str):
    if not tableName or not _isValidIdentifier(tableName):
        raise ValueError(f"Invalid table name: {tableName}")


def _isValidIdentifier(name: str) -> bool:
    """Accept only strings that are valid Python identifiers, which excludes
    the quoting, whitespace, and punctuation characters needed for SQL injection."""
    return name.isidentifier()


def _buildScopeFilter(tableName: str, featureInstanceId: str, mandateId: str, dbConnection=None) -> Dict[str, Any]:
    """Build the mandatory WHERE clause that scopes rows to the feature instance.

    Feature tables use either ``instanceId`` (commcoach, teamsbot) or
    ``featureInstanceId`` (trustee) as the FK. We detect the actual column
    from ``information_schema`` when a DB connection is provided.
    """
    instanceCol = _resolveInstanceColumn(tableName, dbConnection)

    conditions = []
    params = []

    conditions.append(f'"{instanceCol}" = %s')
    params.append(featureInstanceId)

    if mandateId:
        conditions.append('"mandateId" = %s')
        params.append(mandateId)

    return {"where": " AND ".join(conditions), "params": params}


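# Example (table name and IDs illustrative; with no DB connection the column
# falls back to "instanceId"):
#
#     _buildScopeFilter("Session", "fi-001", "m-001")
#     # -> {"where": '"instanceId" = %s AND "mandateId" = %s',
#     #     "params": ["fi-001", "m-001"]}

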
def _buildFilterClauses(filters: Optional[List[Dict[str, Any]]]) -> Tuple[str, List[Any]]:
    """Convert agent-provided filter dicts into safe SQL."""
    if not filters:
        return "", []

    parts = []
    params = []
    for f in filters:
        field = f.get("field", "")
        op = (f.get("op") or "=").upper()
        value = f.get("value")

        if not field or not _isValidIdentifier(field):
            continue
        if op not in _ALLOWED_OPERATORS:
            continue

        if op in ("IS NULL", "IS NOT NULL"):
            parts.append(f'"{field}" {op}')
        else:
            parts.append(f'"{field}" {op} %s')
            params.append(value)

    return " AND ".join(parts), params


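# Example (field names illustrative):
#
#     _buildFilterClauses([
#         {"field": "status", "op": "=", "value": "open"},
#         {"field": "closedAt", "op": "IS NULL"},
#     ])
#     # -> ('"status" = %s AND "closedAt" IS NULL', ["open"])

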
def _serializeRow(row: Dict[str, Any]) -> Dict[str, Any]:
    """Ensure all values are JSON-serializable."""
    for k, v in row.items():
        if isinstance(v, (bytes, bytearray)):
            row[k] = f"<binary {len(v)} bytes>"
        elif hasattr(v, "isoformat"):
            row[k] = v.isoformat()
    return row


_PLACEHOLDER_PREFIX = "NEUT"


def _neutralizeRowFields(row: Dict[str, Any], fieldsToNeutralize: Set[str]) -> Dict[str, Any]:
    """Replace values in sensitive fields with stable, deterministic placeholders.

    The placeholder format ``[NEUT.<field>.<short-hash>]`` is stable for the same
    value so that identical values in different rows produce the same token.
    This allows the AI to reason about equality without seeing the real data.
    """
    for field in fieldsToNeutralize:
        val = row.get(field)
        if val is None or val == "":
            continue
        shortHash = hashlib.sha256(str(val).encode()).hexdigest()[:8]
        row[field] = f"[{_PLACEHOLDER_PREFIX}.{field}.{shortHash}]"
    return row
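

# Example (field and value illustrative):
#
#     _neutralizeRowFields({"firstName": "Ada", "role": "chair"}, {"firstName"})
#     # -> {"firstName": "[NEUT.firstName.<8 hex chars>]", "role": "chair"}
#     # Identical inputs always yield the same token, so cross-row equality
#     # checks remain possible without exposing the underlying value.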