246 lines
8.3 KiB
Python
246 lines
8.3 KiB
Python
# Copyright (c) 2026 Patrick Motsch
|
|
# All rights reserved.
|
|
"""In-memory drop-in for FeatureDataProvider used by the eval harness.
|
|
|
|
Implements the same three public query methods (browseTable / queryTable /
aggregateTable), the small surface the Sub-Agent reads (getActualColumns), and
the table-descriptor lookups getAvailableTables / getTableSchema, but evaluates
all filters and aggregations in Python over the BenchmarkFixture rows.

This keeps the eval hermetic: no DB connection, no fixtures to insert or clean
up, and no flakiness from shared test schemas. Only the LLM call is real.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
_ALLOWED_AGGREGATES = {"SUM", "COUNT", "AVG", "MIN", "MAX"}
|
|
|
|
|
|
class FakeFeatureDataProvider:
    """In-memory provider compatible with :class:`FeatureDataProvider`.

    Rows are held per table name and every query is evaluated in Python.
    Each browseTable / queryTable / aggregateTable call is appended to
    ``callLog`` so a test can inspect what was requested.
    """

    # Field spellings that mean "count rows" (SQL COUNT(*)) rather than
    # "count non-null values of a column" when aggregateTable gets COUNT.
    _COUNT_ALL_FIELDS = (None, "", "*")

    def __init__(
        self,
        rowsByTable: Dict[str, List[Dict[str, Any]]],
        availableTables: Optional[List[Dict[str, Any]]] = None,
    ) -> None:
        """Store shallow copies of the fixture data.

        Args:
            rowsByTable: Table name -> list of row dicts (column -> value).
            availableTables: Table descriptor dicts returned by
                getAvailableTables and searched by ``meta.table`` in
                getTableSchema.
        """
        self._rowsByTable = {name: list(rows) for name, rows in rowsByTable.items()}
        self._availableTables = list(availableTables or [])
        # One dict per query call, keyed by "method", "table" and call arguments.
        self.callLog: List[Dict[str, Any]] = []

    def getAvailableTables(self, featureCode: str) -> List[Dict[str, Any]]:  # noqa: ARG002
        """Return a copy of the configured table descriptors; featureCode is ignored."""
        return list(self._availableTables)

    def getTableSchema(self, featureCode: str, tableName: str) -> Optional[Dict[str, Any]]:  # noqa: ARG002
        """Return the first descriptor whose ``meta.table`` equals *tableName*, else None."""
        for obj in self._availableTables:
            # ``meta`` may be absent or explicitly None in hand-written descriptors.
            if (obj.get("meta") or {}).get("table") == tableName:
                return obj
        return None

    def getActualColumns(self, tableName: str) -> List[str]:
        """Return every column name appearing in *tableName*'s rows, in first-seen order."""
        # dict keeps insertion order and re-inserting a key does not move it.
        columns: Dict[str, None] = {}
        for row in self._rowsByTable.get(tableName, []):
            columns.update(dict.fromkeys(row))
        return list(columns)

    def browseTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        fields: Optional[List[str]] = None,
        limit: Optional[int] = 50,
        offset: Optional[int] = 0,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Return one page of scoped rows, optionally filtered and projected.

        Args:
            tableName: Table to read.
            featureInstanceId: Scope; see _scopeRows.
            mandateId: Scope; see _scopeRows.
            fields: Columns to keep in each returned row; all columns if empty.
            limit: Page size; ``None`` means no limit.
            offset: Number of matching rows to skip; ``None`` is treated as 0.
            extraFilters: Filter specs applied with _applyFilters.

        Returns:
            ``{"rows", "total", "limit", "offset"}`` where ``total`` is the
            match count before paging and limit/offset echo the arguments.
        """
        self.callLog.append({"method": "browseTable", "table": tableName, "fields": fields, "limit": limit})
        rows = self._scopeRows(tableName, featureInstanceId, mandateId)
        if extraFilters:
            rows = _applyFilters(rows, extraFilters)
        return self._page(rows, fields, limit, offset)

    def queryTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        filters: Optional[List[Dict[str, Any]]] = None,
        fields: Optional[List[str]] = None,
        orderBy: Optional[str] = None,
        limit: Optional[int] = 50,
        offset: Optional[int] = 0,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Like browseTable, plus caller filters and an optional ascending sort.

        ``filters`` and ``extraFilters`` are combined (AND). ``orderBy`` names a
        single column; rows sort ascending with missing/None values last.
        Returns the same shape as browseTable.
        """
        self.callLog.append({
            "method": "queryTable", "table": tableName, "filters": filters,
            "fields": fields, "orderBy": orderBy, "limit": limit,
        })
        rows = self._scopeRows(tableName, featureInstanceId, mandateId)
        combined = list(filters or []) + list(extraFilters or [])
        if combined:
            rows = _applyFilters(rows, combined)
        if orderBy:
            rows = self._sortRows(rows, orderBy)
        return self._page(rows, fields, limit, offset)

    def aggregateTable(
        self,
        tableName: str,
        featureInstanceId: str,
        mandateId: str,
        aggregate: str,
        field: str,
        groupBy: Optional[str] = None,
        extraFilters: Optional[List[Dict[str, Any]]] = None,
    ) -> Dict[str, Any]:
        """Compute one aggregate over scoped rows, optionally grouped by a column.

        Args:
            aggregate: Aggregate name (case-insensitive); must be in
                _ALLOWED_AGGREGATES.
            field: Column to aggregate. With COUNT, ``None``, ``""`` or ``"*"``
                counts rows, like SQL ``COUNT(*)``.
            groupBy: Optional column to group by.
            extraFilters: Filter specs applied before aggregating.

        Returns:
            ``{"rows", "aggregate", "field", "groupBy"}``. Grouped rows are
            ``{"groupValue", "result"}`` ordered numeric results first
            (descending), then other non-null results, then NULL results.
            For an unsupported aggregate: ``{"rows": [], "error": ...}``.
        """
        self.callLog.append({
            "method": "aggregateTable", "table": tableName,
            "aggregate": aggregate, "field": field, "groupBy": groupBy,
        })
        aggregate = (aggregate or "").strip().upper()
        if aggregate not in _ALLOWED_AGGREGATES:
            return {"rows": [], "error": f"Unsupported aggregate: {aggregate}"}
        rows = self._scopeRows(tableName, featureInstanceId, mandateId)
        if extraFilters:
            rows = _applyFilters(rows, extraFilters)

        if groupBy:
            groups: Dict[Any, List[Dict[str, Any]]] = {}
            for row in rows:
                groups.setdefault(row.get(groupBy), []).append(row)
            outRows = self._sortGroups([
                {"groupValue": key, "result": self._aggregateRows(aggregate, field, grp)}
                for key, grp in groups.items()
            ])
        else:
            outRows = [{"result": self._aggregateRows(aggregate, field, rows)}]

        return {
            "rows": outRows,
            "aggregate": aggregate,
            "field": field,
            "groupBy": groupBy,
        }

    def _scopeRows(self, tableName: str, featureInstanceId: str, mandateId: str) -> List[Dict[str, Any]]:
        """Return *tableName*'s rows visible to the given instance and mandate.

        A row whose ``featureInstanceId`` / ``mandateId`` is missing or None is
        treated as unscoped for that key and is visible to every caller.
        """
        rows = self._rowsByTable.get(tableName, [])
        return [
            row for row in rows
            if (row.get("featureInstanceId") in (None, featureInstanceId))
            and (row.get("mandateId") in (None, mandateId))
        ]

    @staticmethod
    def _page(
        rows: List[Dict[str, Any]],
        fields: Optional[List[str]],
        limit: Optional[int],
        offset: Optional[int],
    ) -> Dict[str, Any]:
        """Slice *rows* to one page, project to *fields* if given, and wrap the result."""
        total = len(rows)
        start = offset or 0
        page = rows[start:] if limit is None else rows[start : start + limit]
        if fields:
            # A bare string would otherwise act as a substring test in ``k in fields``.
            wanted = {fields} if isinstance(fields, str) else set(fields)
            page = [{k: v for k, v in row.items() if k in wanted} for row in page]
        return {"rows": page, "total": total, "limit": limit, "offset": offset}

    @staticmethod
    def _sortRows(rows: List[Dict[str, Any]], orderBy: str) -> List[Dict[str, Any]]:
        """Sort ascending by *orderBy* with None last; mixed types fall back to string order."""
        try:
            return sorted(rows, key=lambda r: (r.get(orderBy) is None, r.get(orderBy)))
        except TypeError:
            # Keep None last in the fallback too; str(None) would land mid-list.
            return sorted(rows, key=lambda r: (r.get(orderBy) is None, str(r.get(orderBy))))

    @classmethod
    def _aggregateRows(cls, aggregate: str, field: Any, rows: List[Dict[str, Any]]) -> Any:
        """Aggregate *field* over *rows*, treating COUNT on a row-count field as COUNT(*)."""
        if aggregate == "COUNT" and field in cls._COUNT_ALL_FIELDS:
            return len(rows)
        return _aggregate(aggregate, [row.get(field) for row in rows])

    @staticmethod
    def _sortGroups(groupRows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Order grouped results: numeric descending, then other values descending by text, then None."""
        numeric: List[Dict[str, Any]] = []
        other: List[Dict[str, Any]] = []
        nulls: List[Dict[str, Any]] = []
        for row in groupRows:
            result = row["result"]
            if result is None:
                nulls.append(row)
            elif isinstance(result, (int, float)):
                numeric.append(row)
            else:
                # e.g. MIN/MAX over dates or text; cannot be negated for sorting.
                other.append(row)
        numeric.sort(key=lambda r: -r["result"])
        other.sort(key=lambda r: str(r["result"]), reverse=True)
        return numeric + other + nulls
|
|
|
|
|
|
def _applyFilters(rows: List[Dict[str, Any]], filters: Optional[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
|
|
if not filters:
|
|
return rows
|
|
out = rows
|
|
for f in filters:
|
|
field = f.get("field")
|
|
op = (f.get("op") or "=").upper()
|
|
value = f.get("value")
|
|
out = [r for r in out if _matchesFilter(r.get(field), op, value)]
|
|
return out
|
|
|
|
|
|
def _matchesFilter(rowValue: Any, op: str, filterValue: Any) -> bool:
|
|
if op in ("IS NULL",):
|
|
return rowValue is None
|
|
if op in ("IS NOT NULL",):
|
|
return rowValue is not None
|
|
if rowValue is None:
|
|
return False
|
|
if op == "=":
|
|
return _coerceEqual(rowValue, filterValue)
|
|
if op == "!=":
|
|
return not _coerceEqual(rowValue, filterValue)
|
|
if op == ">":
|
|
return _coerceFloat(rowValue) > _coerceFloat(filterValue)
|
|
if op == "<":
|
|
return _coerceFloat(rowValue) < _coerceFloat(filterValue)
|
|
if op == ">=":
|
|
return _coerceFloat(rowValue) >= _coerceFloat(filterValue)
|
|
if op == "<=":
|
|
return _coerceFloat(rowValue) <= _coerceFloat(filterValue)
|
|
if op in ("LIKE", "ILIKE"):
|
|
pattern = str(filterValue or "")
|
|
target = str(rowValue)
|
|
if op == "ILIKE":
|
|
pattern = pattern.lower()
|
|
target = target.lower()
|
|
return _sqlLike(target, pattern)
|
|
if op == "IN":
|
|
if isinstance(filterValue, (list, tuple, set)):
|
|
return any(_coerceEqual(rowValue, v) for v in filterValue)
|
|
return _coerceEqual(rowValue, filterValue)
|
|
return False
|
|
|
|
|
|
def _coerceEqual(a: Any, b: Any) -> bool:
|
|
if a == b:
|
|
return True
|
|
try:
|
|
return str(a) == str(b)
|
|
except Exception:
|
|
return False
|
|
|
|
|
|
def _coerceFloat(value: Any) -> float:
|
|
if value is None:
|
|
return 0.0
|
|
try:
|
|
return float(value)
|
|
except (TypeError, ValueError):
|
|
return 0.0
|
|
|
|
|
|
def _sqlLike(value: str, pattern: str) -> bool:
|
|
"""Approximate SQL LIKE -- only % and _ wildcards."""
|
|
import re
|
|
regex = ""
|
|
i = 0
|
|
while i < len(pattern):
|
|
ch = pattern[i]
|
|
if ch == "%":
|
|
regex += ".*"
|
|
elif ch == "_":
|
|
regex += "."
|
|
else:
|
|
regex += re.escape(ch)
|
|
i += 1
|
|
return re.fullmatch(regex, value or "") is not None
|
|
|
|
|
|
def _aggregate(op: str, values: List[Any]) -> Any:
|
|
if op == "COUNT":
|
|
return sum(1 for v in values if v is not None)
|
|
nums = [_coerceFloat(v) for v in values if v is not None]
|
|
if not nums:
|
|
return 0 if op == "SUM" else None
|
|
if op == "SUM":
|
|
return round(sum(nums), 4)
|
|
if op == "AVG":
|
|
return round(sum(nums) / len(nums), 4)
|
|
if op == "MIN":
|
|
return min(nums)
|
|
if op == "MAX":
|
|
return max(nums)
|
|
return None
|