# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Generic data provider for querying feature-instance tables.

Uses the RBAC catalog's DATA_OBJECTS metadata (table name, fields) and the
DB connector to execute scoped, read-only queries against any registered
feature table.  Every query is filtered by the feature-instance column and,
when a mandateId is supplied, by ``mandateId`` as well.
"""

import logging
import json
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)

# Comparison operators an agent-provided filter may use.  Anything else is
# dropped by _buildFilterClauses before SQL is assembled.
_ALLOWED_OPERATORS = {"=", "!=", ">", "<", ">=", "<=", "LIKE", "ILIKE", "IS NULL", "IS NOT NULL"}


class FeatureDataProvider:
    """Reads feature-instance data from the DB using DATA_OBJECTS metadata."""

    def __init__(self, dbConnector):
        """
        Args:
            dbConnector: A connectorDbPostgre.DatabaseConnector with an open connection.
        """
        self._db = dbConnector

    # ------------------------------------------------------------------
    # public API (called by FeatureDataAgent tools)
    # ------------------------------------------------------------------

    def getAvailableTables(self, featureCode: str) -> List[Dict[str, Any]]:
        """Return DATA_OBJECTS registered for *featureCode*."""
        # Imported lazily, as in the original code (presumably to avoid an
        # import cycle with the security package -- confirm before hoisting).
        from modules.security.rbacCatalog import getCatalogService

        catalog = getCatalogService()
        return catalog.getDataObjects(featureCode)

    def getTableSchema(self, featureCode: str, tableName: str) -> Optional[Dict[str, Any]]:
        """Return the DATA_OBJECT entry whose ``meta.table`` equals *tableName*, or None."""
        for obj in self.getAvailableTables(featureCode):
            # ``or {}`` also covers an explicit ``"meta": None`` entry.
            if (obj.get("meta") or {}).get("table") == tableName:
                return obj
        return None

    def getActualColumns(self, tableName: str) -> List[str]:
        """Read real column names from PostgreSQL information_schema.

        Columns whose name starts with an underscore are omitted.  Any
        failure is logged as a warning and yields an empty list.
        """
        try:
            conn = self._db.connection
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT column_name FROM information_schema.columns "
                    "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
                    "ORDER BY ordinal_position",
                    [tableName],
                )
                cols = [row["column_name"] for row in cur.fetchall()]
            return [c for c in cols if not c.startswith("_")]
        except Exception as e:
            logger.warning("getActualColumns(%s) failed: %s", tableName, e)
            return []

    def browseTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        fields: Optional[List[str]] = None,
        limit: int = 50,
        offset: int = 0,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """List rows from a feature table with pagination, ordered by ``id``.

        Returns ``{"rows": [...], "total": N, "limit": L, "offset": O}``.
        On failure the result has empty ``rows``, ``total`` 0 and an
        additional ``"error"`` message.

        Raises:
            ValueError: if *tableName* is not a valid identifier.
        """
        return self._runScopedSelect(
            "browseTable",
            tableName,
            featureInstanceId,
            mandateId,
            filters=extraFilters,
            fields=fields,
            orderBy=None,
            limit=limit,
            offset=offset,
        )

    def queryTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        filters: Optional[List[Dict[str, Any]]] = None,
        fields: Optional[List[str]] = None,
        orderBy: Optional[str] = None,
        limit: int = 50,
        offset: int = 0,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Query a feature table with optional filters.

        ``filters`` is a list of ``{"field": "x", "op": "=", "value": "y"}``.
        ``extraFilters`` are mandatory record-level scoping filters injected
        by the pipeline; both lists are AND-combined.  An ``orderBy`` that is
        missing or not a valid identifier falls back to ordering by ``id``.

        Returns the same shape as :meth:`browseTable`.

        Raises:
            ValueError: if *tableName* is not a valid identifier.
        """
        combinedFilters = list(filters or []) + list(extraFilters or [])
        return self._runScopedSelect(
            "queryTable",
            tableName,
            featureInstanceId,
            mandateId,
            filters=combinedFilters,
            fields=fields,
            orderBy=orderBy,
            limit=limit,
            offset=offset,
        )

    def _runScopedSelect(
        self,
        opName: str,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        filters: Optional[List[Dict[str, Any]]],
        fields: Optional[List[str]],
        orderBy: Optional[str],
        limit: int,
        offset: int,
    ) -> Dict[str, Any]:
        """Run the COUNT + paginated SELECT shared by browseTable and queryTable.

        *opName* is only used to label the error log line.
        """
        _validateTableName(tableName)
        conn = self._db.connection
        scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId, dbConnection=conn)
        extraWhere, extraParams = _buildFilterClauses(filters)

        fullWhere = scopeFilter["where"]
        allParams = list(scopeFilter["params"])
        if extraWhere:
            fullWhere += " AND " + extraWhere
            allParams.extend(extraParams)

        if orderBy and _isValidIdentifier(orderBy):
            orderClause = f'ORDER BY "{orderBy}"'
        else:
            orderClause = 'ORDER BY "id"'

        try:
            # Validate column names before any SQL is sent: they are
            # interpolated into the statement text and cannot be bound as
            # parameters, so an unchecked value would be an injection vector.
            selectCols = _buildSelectColumns(fields)
            with conn.cursor() as cur:
                countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {fullWhere}'
                cur.execute(countSql, allParams)
                # Rows are accessed by column name, so the cursor is expected
                # to return dict-like rows.
                total = cur.fetchone()["count"] if cur.rowcount else 0

                dataSql = (
                    f'SELECT {selectCols} FROM "{tableName}" '
                    f'WHERE {fullWhere} {orderClause} LIMIT %s OFFSET %s'
                )
                cur.execute(dataSql, allParams + [limit, offset])
                rows = [_serializeRow(dict(r)) for r in cur.fetchall()]

            return {"rows": rows, "total": total, "limit": limit, "offset": offset}
        except Exception as e:
            logger.error("%s(%s) failed: %s", opName, tableName, e)
            return {"rows": [], "total": 0, "limit": limit, "offset": offset, "error": str(e)}


# ------------------------------------------------------------------
# helpers
# ------------------------------------------------------------------

# tableName -> name of the column holding the feature-instance FK.
_instanceColCache: Dict[str, str] = {}


def _resolveInstanceColumn(tableName: str, dbConnection=None) -> str:
    """Detect whether the table uses ``instanceId`` or ``featureInstanceId``.

    Successful lookups are cached per table name.  Without a connection, on
    lookup failure, or when neither column exists, ``"instanceId"`` is
    returned and nothing is cached, so a later call retries the lookup.
    """
    if tableName in _instanceColCache:
        return _instanceColCache[tableName]

    if dbConnection:
        try:
            with dbConnection.cursor() as cur:
                cur.execute(
                    "SELECT column_name FROM information_schema.columns "
                    "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) "
                    "AND column_name IN ('featureInstanceId', 'instanceId')",
                    [tableName],
                )
                cols = [row["column_name"] for row in cur.fetchall()]
            # featureInstanceId wins when a table has both columns.
            for candidate in ("featureInstanceId", "instanceId"):
                if candidate in cols:
                    _instanceColCache[tableName] = candidate
                    return candidate
        except Exception as e:
            # Best effort: keep the default, but make the failure visible.
            logger.warning(
                "_resolveInstanceColumn(%s) failed, defaulting to instanceId: %s", tableName, e
            )

    return "instanceId"


def _validateTableName(tableName: str):
    """Raise ValueError unless *tableName* is a non-empty valid identifier."""
    if not tableName or not _isValidIdentifier(tableName):
        raise ValueError(f"Invalid table name: {tableName}")


def _isValidIdentifier(name: str) -> bool:
    """Return True if *name* is a string accepted by ``str.isidentifier()``.

    Such names consist of letters, digits and underscores and do not start
    with a digit, so they cannot contain quotes, whitespace or punctuation.
    That is what makes them safe to embed in a double-quoted SQL identifier.
    Non-string values are rejected instead of raising.
    """
    return isinstance(name, str) and name.isidentifier()


def _buildSelectColumns(fields: Optional[List[str]]) -> str:
    """Return the quoted SELECT column list, or ``*`` when no fields are given.

    Raises:
        ValueError: if any entry is not a valid identifier.
    """
    if not fields:
        return "*"
    invalid = [f for f in fields if not _isValidIdentifier(f)]
    if invalid:
        raise ValueError(f"Invalid field name(s): {invalid}")
    return ", ".join(f'"{f}"' for f in fields)


def _buildScopeFilter(tableName: str, featureInstanceId: str, mandateId: str, dbConnection=None) -> Dict[str, Any]:
    """Build the mandatory WHERE clause that scopes rows to the feature instance.

    Feature tables use either ``instanceId`` (commcoach, teamsbot) or
    ``featureInstanceId`` (trustee) as the FK.  We detect the actual column
    from ``information_schema`` when a DB connection is provided.

    Returns ``{"where": <sql>, "params": [...]}``.  The ``mandateId``
    condition is only added when *mandateId* is truthy.
    """
    instanceCol = _resolveInstanceColumn(tableName, dbConnection)
    conditions = [f'"{instanceCol}" = %s']
    params = [featureInstanceId]
    if mandateId:
        conditions.append('"mandateId" = %s')
        params.append(mandateId)
    return {"where": " AND ".join(conditions), "params": params}


def _buildFilterClauses(filters: Optional[List[Dict[str, Any]]]) -> Tuple[str, List[Any]]:
    """Convert agent-provided filter dicts into safe SQL.

    Returns ``(whereFragment, params)`` with the clauses AND-joined, or
    ``("", [])`` when nothing usable was supplied.  Filters with an invalid
    field name or an operator outside ``_ALLOWED_OPERATORS`` are skipped
    (and logged) rather than raising.
    """
    if not filters:
        return "", []

    parts = []
    params = []
    for f in filters:
        field = f.get("field", "")
        # str() keeps a non-string "op" from crashing on .upper().
        op = str(f.get("op") or "=").upper()
        value = f.get("value")

        if not field or not _isValidIdentifier(field):
            logger.warning("Skipping filter with invalid field: %r", field)
            continue
        if op not in _ALLOWED_OPERATORS:
            logger.warning("Skipping filter on %s with unsupported operator: %r", field, op)
            continue

        if op in ("IS NULL", "IS NOT NULL"):
            # Unary operators take no bound value.
            parts.append(f'"{field}" {op}')
        else:
            parts.append(f'"{field}" {op} %s')
            params.append(value)

    return " AND ".join(parts), params


def _serializeRow(row: Dict[str, Any]) -> Dict[str, Any]:
    """Replace binary and date/time values with JSON-friendly strings.

    Mutates and returns *row*.  Binary values become an empty string and any
    value exposing ``isoformat()`` is replaced by its ISO string; all other
    values are passed through untouched.
    """
    for k, v in row.items():
        # memoryview is included because psycopg2 returns bytea columns as
        # memoryview objects rather than bytes.
        if isinstance(v, (bytes, bytearray, memoryview)):
            # NOTE(review): the placeholder is an empty string; this looks
            # like a lost "<binary ...>" marker -- confirm the intended text.
            row[k] = ""
        elif hasattr(v, "isoformat"):
            row[k] = v.isoformat()
    return row