# Copyright (c) 2026 PowerOn AG
# All rights reserved.
"""In-memory drop-in for FeatureDataProvider used by the eval harness.

Implements the same three public methods (browseTable / queryTable /
aggregateTable) plus the small surface the Sub-Agent reads
(getActualColumns), but runs all filters/aggregations in Python over the
BenchmarkFixture rows.

This keeps the eval hermetic: no DB connection, no fixtures to insert/clean,
no flakiness from shared test schemas. Only the LLM call is real.
"""

from __future__ import annotations

import operator
import re
from typing import Any, Callable, Dict, List, Optional, Tuple

_ALLOWED_AGGREGATES = {"SUM", "COUNT", "AVG", "MIN", "MAX"}

# Ordering operators understood by ``_matchesFilter``.
_ORDERING_OPS: Dict[str, Callable[[Any, Any], bool]] = {
    ">": operator.gt,
    "<": operator.lt,
    ">=": operator.ge,
    "<=": operator.le,
}

# SQL LIKE wildcards and their regular-expression equivalents.
_LIKE_WILDCARDS = {"%": ".*", "_": "."}


class FakeFeatureDataProvider:
    """In-memory provider compatible with :class:`FeatureDataProvider`.

    Every public call is appended to ``callLog`` so a test can inspect which
    tables, filters and aggregates were requested.
    """

    def __init__(
        self,
        rowsByTable: Dict[str, List[Dict[str, Any]]],
        availableTables: Optional[List[Dict[str, Any]]] = None,
    ) -> None:
        """Store the fixture rows and the table catalogue.

        Args:
            rowsByTable: Mapping of table name to its rows (one dict per row).
                The row lists are copied; the row dicts themselves are shared.
            availableTables: Table descriptors returned by
                ``getAvailableTables`` / ``getTableSchema``.
        """
        self._rowsByTable = {name: list(rows) for name, rows in rowsByTable.items()}
        self._availableTables = list(availableTables or [])
        self.callLog: List[Dict[str, Any]] = []

    def getAvailableTables(self, featureCode: str) -> List[Dict[str, Any]]:  # noqa: ARG002
        """Return a copy of the table catalogue; ``featureCode`` is ignored."""
        return list(self._availableTables)

    def getTableSchema(self, featureCode: str, tableName: str) -> Optional[Dict[str, Any]]:  # noqa: ARG002
        """Return the first catalogue entry whose ``meta.table`` is ``tableName``.

        Returns ``None`` when no entry matches. ``featureCode`` is ignored.
        """
        for obj in self._availableTables:
            # ``or {}`` tolerates entries whose "meta" is explicitly None.
            if (obj.get("meta") or {}).get("table") == tableName:
                return obj
        return None

    def getActualColumns(self, tableName: str) -> List[str]:
        """Return the union of keys over all rows, in first-seen order."""
        seen: List[str] = []
        seenSet: set = set()
        for row in self._rowsByTable.get(tableName, []):
            for key in row:
                if key not in seenSet:
                    seen.append(key)
                    seenSet.add(key)
        return seen

    def browseTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        fields: Optional[List[str]] = None,
        limit: int = 50,
        offset: int = 0,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Return one page of scoped rows, optionally projected to ``fields``.

        Returns:
            ``{"rows", "total", "limit", "offset"}`` where ``total`` is the
            number of matching rows before paging.
        """
        self.callLog.append({"method": "browseTable", "table": tableName, "fields": fields, "limit": limit})
        rows = self._scopeRows(tableName, featureInstanceId, mandateId)
        rows = _applyFilters(rows, extraFilters)
        total = len(rows)
        rows = rows[offset : offset + limit]
        if fields:
            rows = [{k: v for k, v in row.items() if k in fields} for row in rows]
        return {"rows": rows, "total": total, "limit": limit, "offset": offset}

    def queryTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        filters: Optional[List[Dict[str, Any]]] = None,
        fields: Optional[List[str]] = None,
        orderBy: Optional[str] = None,
        limit: int = 50,
        offset: int = 0,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Filter, optionally sort, page and project the scoped rows.

        ``filters`` and ``extraFilters`` are AND-combined. ``orderBy`` is used
        as a plain row key; rows sort ascending with ``None`` values last.

        Returns:
            ``{"rows", "total", "limit", "offset"}`` where ``total`` is the
            number of matching rows before paging.
        """
        self.callLog.append({
            "method": "queryTable",
            "table": tableName,
            "filters": filters,
            "fields": fields,
            "orderBy": orderBy,
            "limit": limit,
        })
        rows = self._scopeRows(tableName, featureInstanceId, mandateId)
        combined = list(filters or []) + list(extraFilters or [])
        rows = _applyFilters(rows, combined)
        if orderBy:
            try:
                rows = sorted(rows, key=lambda r: (r.get(orderBy) is None, r.get(orderBy)))
            except TypeError:
                # Mixed, mutually unorderable types: fall back to string order.
                rows = sorted(rows, key=lambda r: str(r.get(orderBy)))
        total = len(rows)
        rows = rows[offset : offset + limit]
        if fields:
            rows = [{k: v for k, v in row.items() if k in fields} for row in rows]
        return {"rows": rows, "total": total, "limit": limit, "offset": offset}

    def aggregateTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        aggregate: str,
        field: str,
        groupBy: Optional[str] = None,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Aggregate ``field`` over the scoped rows, optionally per group.

        ``aggregate`` is case-insensitive and must be one of SUM, COUNT, AVG,
        MIN or MAX; any other value yields ``{"rows": [], "error": ...}``.
        Grouped results are sorted by ``result`` descending, ``None`` last.

        Returns:
            ``{"rows", "aggregate", "field", "groupBy"}``; each row carries
            ``result`` and, when grouped, ``groupValue``.
        """
        self.callLog.append({
            "method": "aggregateTable",
            "table": tableName,
            "aggregate": aggregate,
            "field": field,
            "groupBy": groupBy,
        })
        aggregate = aggregate.upper()
        if aggregate not in _ALLOWED_AGGREGATES:
            return {"rows": [], "error": f"Unsupported aggregate: {aggregate}"}
        rows = self._scopeRows(tableName, featureInstanceId, mandateId)
        rows = _applyFilters(rows, extraFilters)
        if groupBy:
            groups: Dict[Any, List[Dict[str, Any]]] = {}
            for row in rows:
                groups.setdefault(row.get(groupBy), []).append(row)
            outRows = [
                {"groupValue": key, "result": _aggregate(aggregate, [r.get(field) for r in grp])}
                for key, grp in groups.items()
            ]
            outRows.sort(key=lambda r: (r["result"] is None, -(r["result"] or 0)))
        else:
            outRows = [{"result": _aggregate(aggregate, [r.get(field) for r in rows])}]
        return {
            "rows": outRows,
            "aggregate": aggregate,
            "field": field,
            "groupBy": groupBy,
        }

    def _scopeRows(self, tableName: str, featureInstanceId: str, mandateId: str) -> List[Dict[str, Any]]:
        """Return the rows of ``tableName`` visible to the given scope.

        A row is visible when its ``featureInstanceId`` / ``mandateId`` is
        either absent (``None``) or equal to the requested value. Always
        returns a new list.
        """
        rows = self._rowsByTable.get(tableName, [])
        return [
            row
            for row in rows
            if (row.get("featureInstanceId") in (None, featureInstanceId))
            and (row.get("mandateId") in (None, mandateId))
        ]


def _applyFilters(rows: List[Dict[str, Any]], filters: Optional[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
    """Return the rows that satisfy every filter (AND semantics).

    Each filter is a dict with ``field``, ``op`` (defaults to ``=``,
    case-insensitive) and ``value``. With no filters the input list is
    returned unchanged.
    """
    if not filters:
        return rows
    out = rows
    for f in filters:
        field = f.get("field")
        op = (f.get("op") or "=").upper()
        value = f.get("value")
        out = [r for r in out if _matchesFilter(r.get(field), op, value)]
    return out


def _matchesFilter(rowValue: Any, op: str, filterValue: Any) -> bool:
    """Evaluate one ``rowValue <op> filterValue`` predicate.

    ``op`` is expected upper-cased. A ``None`` row value only matches
    ``IS NULL``; unknown operators match nothing.
    """
    if op == "IS NULL":
        return rowValue is None
    if op == "IS NOT NULL":
        return rowValue is not None
    if rowValue is None:
        return False
    if op == "=":
        return _coerceEqual(rowValue, filterValue)
    if op == "!=":
        return not _coerceEqual(rowValue, filterValue)
    compare = _ORDERING_OPS.get(op)
    if compare is not None:
        left, right = _orderingOperands(rowValue, filterValue)
        return compare(left, right)
    if op in ("LIKE", "ILIKE"):
        pattern = str(filterValue or "")
        target = str(rowValue)
        if op == "ILIKE":
            pattern = pattern.lower()
            target = target.lower()
        return _sqlLike(target, pattern)
    if op == "IN":
        if isinstance(filterValue, (list, tuple, set)):
            return any(_coerceEqual(rowValue, v) for v in filterValue)
        return _coerceEqual(rowValue, filterValue)
    return False


def _coerceEqual(a: Any, b: Any) -> bool:
    """Return True when ``a == b`` or their string forms are equal."""
    if a == b:
        return True
    try:
        return str(a) == str(b)
    except Exception:
        return False


def _coerceFloat(value: Any) -> float:
    """Convert ``value`` to float; ``None`` and unparsable values become 0.0."""
    if value is None:
        return 0.0
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def _tryFloat(value: Any) -> Optional[float]:
    """Convert ``value`` to float, or return ``None`` when it is not numeric."""
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        return None


def _orderingOperands(rowValue: Any, filterValue: Any) -> Tuple[Any, Any]:
    """Return a mutually comparable pair for an ordering operator.

    Both values are compared as floats when both parse as numbers. Otherwise
    they are compared as strings, so non-numeric values such as ISO dates
    order lexicographically instead of all collapsing to 0.0. A ``None``
    filter value counts as 0.
    """
    left = _tryFloat(rowValue)
    right = 0.0 if filterValue is None else _tryFloat(filterValue)
    if left is None or right is None:
        return str(rowValue), str(filterValue)
    return left, right


def _sqlLike(value: str, pattern: str) -> bool:
    """Approximate SQL LIKE -- only % and _ wildcards.

    The whole value must match the pattern. Wildcards also match newlines,
    so multi-line text behaves like any other text.
    """
    regex = "".join(_LIKE_WILDCARDS.get(ch) or re.escape(ch) for ch in pattern)
    return re.fullmatch(regex, value or "", flags=re.DOTALL) is not None


def _aggregate(op: str, values: List[Any]) -> Any:
    """Apply the aggregate ``op`` to ``values``, ignoring ``None`` entries.

    COUNT returns the number of non-None values. The other operators work on
    float-coerced values; SUM and AVG are rounded to 4 decimals. With no
    non-None values SUM returns 0 and AVG/MIN/MAX return ``None``.
    """
    if op == "COUNT":
        return sum(1 for v in values if v is not None)
    nums = [_coerceFloat(v) for v in values if v is not None]
    if not nums:
        return 0 if op == "SUM" else None
    if op == "SUM":
        return round(sum(nums), 4)
    if op == "AVG":
        return round(sum(nums) / len(nums), 4)
    if op == "MIN":
        return min(nums)
    if op == "MAX":
        return max(nums)
    return None