platform-core/tests/eval/fakeFeatureDataProvider.py
ValueOn AG 4a60086c80
Some checks failed
Deploy Plattform-Core (Int) / test (push) Failing after 15s
Deploy Plattform-Core (Int) / deploy (push) Has been skipped
cp adapted to 2026 poweron
2026-06-09 09:53:31 +02:00

246 lines
8.3 KiB
Python

# Copyright (c) 2026 PowerOn AG
# All rights reserved.
"""In-memory drop-in for FeatureDataProvider used by the eval harness.
Implements the same three public methods (browseTable / queryTable /
aggregateTable) plus the small surface the Sub-Agent reads (getActualColumns),
but runs all filters/aggregations in Python over the BenchmarkFixture rows.
This keeps the eval hermetic: no DB connection, no fixtures to insert/clean,
no flakiness from shared test schemas. Only the LLM call is real.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
# Aggregate function names accepted by FakeFeatureDataProvider.aggregateTable;
# the requested name is upper-cased before the membership check.
_ALLOWED_AGGREGATES = {"SUM", "COUNT", "AVG", "MIN", "MAX"}
class FakeFeatureDataProvider:
    """In-memory provider compatible with :class:`FeatureDataProvider`.

    All data lives in ``rowsByTable`` (table name -> list of row dicts) and
    every filter, sort and aggregation is evaluated in Python. The three
    query methods (``browseTable``, ``queryTable``, ``aggregateTable``)
    append one entry per call to ``callLog`` so a test can inspect what was
    requested.
    """

    def __init__(
        self,
        rowsByTable: Dict[str, List[Dict[str, Any]]],
        availableTables: Optional[List[Dict[str, Any]]] = None,
    ) -> None:
        """Store the fixture rows and the table catalogue.

        Args:
            rowsByTable: Rows keyed by table name.
            availableTables: Table descriptors returned by
                ``getAvailableTables``; ``getTableSchema`` looks one up by
                its ``meta.table`` entry.
        """
        # Copy the outer containers so later changes to the caller's lists do
        # not alter the provider (the row dicts themselves are shared).
        self._rowsByTable = {name: list(rows) for name, rows in rowsByTable.items()}
        self._availableTables = list(availableTables or [])
        self.callLog: List[Dict[str, Any]] = []

    def getAvailableTables(self, featureCode: str) -> List[Dict[str, Any]]:  # noqa: ARG002
        """Return a copy of the table catalogue; ``featureCode`` is ignored."""
        return list(self._availableTables)

    def getTableSchema(self, featureCode: str, tableName: str) -> Optional[Dict[str, Any]]:  # noqa: ARG002
        """Return the first descriptor whose ``meta.table`` equals ``tableName``, else ``None``."""
        for obj in self._availableTables:
            if obj.get("meta", {}).get("table") == tableName:
                return obj
        return None

    def getActualColumns(self, tableName: str) -> List[str]:
        """Return every key used by the table's rows, in first-seen order.

        An unknown or empty table yields an empty list.
        """
        rows = self._rowsByTable.get(tableName, [])
        # dict.fromkeys de-duplicates while keeping insertion order.
        return list(dict.fromkeys(key for row in rows for key in row))

    def browseTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        fields: Optional[List[str]] = None,
        limit: int = 50,
        offset: int = 0,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Return one page of the scoped rows of ``tableName``.

        Args:
            tableName: Table to read.
            featureInstanceId: Scope value, see ``_scopeRows``.
            mandateId: Scope value, see ``_scopeRows``.
            fields: When non-empty, only these keys are kept in each row.
            limit: Page size.
            offset: Index of the first row of the page.
            extraFilters: Optional filter dicts applied before paging.

        Returns:
            ``{"rows", "total", "limit", "offset"}`` where ``total`` counts
            the rows before paging.
        """
        self.callLog.append({"method": "browseTable", "table": tableName, "fields": fields, "limit": limit})
        rows = self._scopeRows(tableName, featureInstanceId, mandateId)
        if extraFilters:
            rows = _applyFilters(rows, extraFilters)
        return self._pageResult(rows, fields, limit, offset)

    def queryTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        filters: Optional[List[Dict[str, Any]]] = None,
        fields: Optional[List[str]] = None,
        orderBy: Optional[str] = None,
        limit: int = 50,
        offset: int = 0,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Return one page of the scoped rows that pass all filters.

        ``filters`` and ``extraFilters`` are concatenated and every entry
        must match. When ``orderBy`` is set the rows are sorted ascending by
        that key with ``None`` values last.

        Returns:
            ``{"rows", "total", "limit", "offset"}`` where ``total`` counts
            the matching rows before paging.
        """
        self.callLog.append({
            "method": "queryTable", "table": tableName, "filters": filters,
            "fields": fields, "orderBy": orderBy, "limit": limit,
        })
        rows = self._scopeRows(tableName, featureInstanceId, mandateId)
        combined = list(filters or []) + list(extraFilters or [])
        if combined:
            rows = _applyFilters(rows, combined)
        if orderBy:
            rows = self._sortRows(rows, orderBy)
        return self._pageResult(rows, fields, limit, offset)

    def aggregateTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        aggregate: str,
        field: str,
        groupBy: Optional[str] = None,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Aggregate ``field`` over the scoped rows, optionally per group.

        Args:
            aggregate: Function name, matched case-insensitively against
                ``_ALLOWED_AGGREGATES``.
            field: Row key whose values are aggregated.
            groupBy: Optional row key to group by.
            extraFilters: Optional filter dicts applied before aggregating.

        Returns:
            For an unsupported function, ``{"rows": [], "error": ...}``.
            Otherwise ``{"rows", "aggregate", "field", "groupBy"}``. Without
            ``groupBy`` there is a single ``{"result": ...}`` row; with it,
            one ``{"groupValue", "result"}`` row per group, ordered by
            result descending with ``None`` results last.
        """
        # The call is logged with the function name exactly as supplied.
        self.callLog.append({
            "method": "aggregateTable", "table": tableName,
            "aggregate": aggregate, "field": field, "groupBy": groupBy,
        })
        aggregate = aggregate.upper()
        if aggregate not in _ALLOWED_AGGREGATES:
            return {"rows": [], "error": f"Unsupported aggregate: {aggregate}"}
        rows = self._scopeRows(tableName, featureInstanceId, mandateId)
        if extraFilters:
            rows = _applyFilters(rows, extraFilters)
        if groupBy:
            groups: Dict[Any, List[Dict[str, Any]]] = {}
            for row in rows:
                groups.setdefault(row.get(groupBy), []).append(row)
            outRows = [
                {"groupValue": key, "result": _aggregate(aggregate, [r.get(field) for r in grp])}
                for key, grp in groups.items()
            ]
            # Largest result first; groups without a result go to the end.
            outRows.sort(key=lambda r: (r["result"] is None, -(r["result"] or 0)))
        else:
            outRows = [{"result": _aggregate(aggregate, [r.get(field) for r in rows])}]
        return {
            "rows": outRows,
            "aggregate": aggregate,
            "field": field,
            "groupBy": groupBy,
        }

    def _scopeRows(self, tableName: str, featureInstanceId: str, mandateId: str) -> List[Dict[str, Any]]:
        """Return the rows of ``tableName`` visible to the given scope.

        A row is visible when its ``featureInstanceId`` and ``mandateId`` are
        each either missing/``None`` or equal to the requested value.
        """
        rows = self._rowsByTable.get(tableName, [])
        return [
            row for row in rows
            if (row.get("featureInstanceId") in (None, featureInstanceId))
            and (row.get("mandateId") in (None, mandateId))
        ]

    @staticmethod
    def _sortRows(rows: List[Dict[str, Any]], orderBy: str) -> List[Dict[str, Any]]:
        """Sort rows ascending by ``orderBy`` with ``None`` values last."""
        try:
            return sorted(rows, key=lambda r: (r.get(orderBy) is None, r.get(orderBy)))
        except TypeError:
            # Values of different types cannot be ordered directly; compare
            # their string forms instead, still keeping None at the end.
            return sorted(rows, key=lambda r: (r.get(orderBy) is None, str(r.get(orderBy))))

    @staticmethod
    def _pageResult(
        rows: List[Dict[str, Any]],
        fields: Optional[List[str]],
        limit: int,
        offset: int,
    ) -> Dict[str, Any]:
        """Slice ``rows`` to the requested page, project ``fields`` and build the result dict."""
        total = len(rows)
        page = rows[offset : offset + limit]
        if fields:
            page = [{k: v for k, v in row.items() if k in fields} for row in page]
        return {"rows": page, "total": total, "limit": limit, "offset": offset}
def _applyFilters(rows: List[Dict[str, Any]], filters: Optional[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
if not filters:
return rows
out = rows
for f in filters:
field = f.get("field")
op = (f.get("op") or "=").upper()
value = f.get("value")
out = [r for r in out if _matchesFilter(r.get(field), op, value)]
return out
def _matchesFilter(rowValue: Any, op: str, filterValue: Any) -> bool:
if op in ("IS NULL",):
return rowValue is None
if op in ("IS NOT NULL",):
return rowValue is not None
if rowValue is None:
return False
if op == "=":
return _coerceEqual(rowValue, filterValue)
if op == "!=":
return not _coerceEqual(rowValue, filterValue)
if op == ">":
return _coerceFloat(rowValue) > _coerceFloat(filterValue)
if op == "<":
return _coerceFloat(rowValue) < _coerceFloat(filterValue)
if op == ">=":
return _coerceFloat(rowValue) >= _coerceFloat(filterValue)
if op == "<=":
return _coerceFloat(rowValue) <= _coerceFloat(filterValue)
if op in ("LIKE", "ILIKE"):
pattern = str(filterValue or "")
target = str(rowValue)
if op == "ILIKE":
pattern = pattern.lower()
target = target.lower()
return _sqlLike(target, pattern)
if op == "IN":
if isinstance(filterValue, (list, tuple, set)):
return any(_coerceEqual(rowValue, v) for v in filterValue)
return _coerceEqual(rowValue, filterValue)
return False
def _coerceEqual(a: Any, b: Any) -> bool:
if a == b:
return True
try:
return str(a) == str(b)
except Exception:
return False
def _coerceFloat(value: Any) -> float:
if value is None:
return 0.0
try:
return float(value)
except (TypeError, ValueError):
return 0.0
def _sqlLike(value: str, pattern: str) -> bool:
"""Approximate SQL LIKE -- only % and _ wildcards."""
import re
regex = ""
i = 0
while i < len(pattern):
ch = pattern[i]
if ch == "%":
regex += ".*"
elif ch == "_":
regex += "."
else:
regex += re.escape(ch)
i += 1
return re.fullmatch(regex, value or "") is not None
def _aggregate(op: str, values: List[Any]) -> Any:
if op == "COUNT":
return sum(1 for v in values if v is not None)
nums = [_coerceFloat(v) for v in values if v is not None]
if not nums:
return 0 if op == "SUM" else None
if op == "SUM":
return round(sum(nums), 4)
if op == "AVG":
return round(sum(nums) / len(nums), 4)
if op == "MIN":
return min(nums)
if op == "MAX":
return max(nums)
return None