platform-core/tests/eval/runTrusteeBenchmark.py
ValueOn AG 4f8473bd70
Some checks failed
Deploy Plattform-Core (Int) / test (push) Failing after 1m2s
Deploy Plattform-Core (Int) / deploy (push) Has been skipped
cleaned servicebag and removed servicehub
2026-06-08 23:35:31 +02:00

754 lines
27 KiB
Python

# Copyright (c) 2026 Patrick Motsch
# All rights reserved.
"""Trustee Sub-Agent Eval Harness (Phase 1.5).
Standalone runner that fires real AI calls against the Feature Data
Sub-Agent in three configurations:
* ``baseline`` -- production code without the pre-execute validator
(Repair-Loop disabled, Trustee domain hints active).
* ``phase1`` -- pre-execute validator on (Repair-Loop active),
domain hints active, no ontology yet.
* ``phase2`` -- validator on, ontology-driven schema context +
constraints (replaces hand-written domain hints).
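For each mode we run every gold-standard question (19 at the time of
writing) from ``tests/fixtures/trusteeBenchmark/questions.yaml`` against an
in-memory :class:`FakeFeatureDataProvider`, capture the agent's tool
calls and final answer, and score them against the gold standard. A
Markdown report plus a raw JSON dump are written to ``local/notes/`` (next
to the gateway directory) for analysis.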
Usage::
cd gateway
python -m tests.eval.runTrusteeBenchmark # all 3 modes
python -m tests.eval.runTrusteeBenchmark phase1 # one mode only
python -m tests.eval.runTrusteeBenchmark baseline phase1
python -m tests.eval.runTrusteeBenchmark phase2 --only=<questionId>  # one question
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import re
import sys
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
# ---------------------------------------------------------------------------
# Path setup so `python -m tests.eval.runTrusteeBenchmark` works from gateway/
# ---------------------------------------------------------------------------
# The gateway root sits two directories above this file (tests/eval/ -> root);
# prepend it to the import path once so ``modules`` and ``tests`` resolve when
# running ``python -m tests.eval.runTrusteeBenchmark``.
_GATEWAY_DIR = Path(__file__).resolve().parents[2]
if not any(entry == str(_GATEWAY_DIR) for entry in sys.path):
    sys.path[:0] = [str(_GATEWAY_DIR)]
import yaml # noqa: E402
from modules.serviceCenter.services.serviceAgent.datamodelAgent import ( # noqa: E402
AgentConfig,
AgentEventTypeEnum,
)
from modules.datamodels.datamodelAi import ( # noqa: E402
AiCallRequest,
AiCallResponse,
OperationTypeEnum,
)
from modules.serviceCenter.services.serviceAgent.agentLoop import runAgentLoop # noqa: E402
from modules.serviceCenter.services.serviceAgent.featureDataAgent import ( # noqa: E402
_buildSubAgentTools,
_buildSchemaContext,
)
from modules.serviceCenter.services.serviceAgent.datamodelOntology import ( # noqa: E402
QueryValidationError,
)
from modules.serviceCenter.services.serviceAgent.queryValidator import ( # noqa: E402
QueryValidator,
)
from tests.eval.fakeFeatureDataProvider import ( # noqa: E402
FakeFeatureDataProvider,
)
from tests.fixtures.trusteeBenchmark.loadTrusteeBenchmarkFixture import ( # noqa: E402
buildTrusteeBenchmarkFixture,
BenchmarkFixture,
)
# Named logger used by all benchmark helpers in this module.
logger = logging.getLogger(name="trusteeBenchmark")
# ---------------------------------------------------------------------------
# NoOpValidator -- baseline mode (Repair-Loop OFF)
# ---------------------------------------------------------------------------
class _NoOpValidator(QueryValidator):
    """Accept-everything validator for the baseline mode.

    Every ``validate*`` hook below unconditionally returns ``None``. The
    original intent is that nothing gets rejected, so the Repair-Loop never
    kicks in and the raw LLM behaviour is measured.
    """

    def validateBrowseQuery(self, tableName, args):  # noqa: ARG002
        """Accept any browse query."""
        return None

    def validateQueryTable(self, tableName, args):  # noqa: ARG002
        """Accept any table query."""
        return None

    def validateAggregateQuery(self, tableName, args):  # noqa: ARG002
        """Accept any aggregate query."""
        return None
# ---------------------------------------------------------------------------
# Mode-specific tool/prompt building
# ---------------------------------------------------------------------------
@dataclass
class _ModeConfig:
    """Static description of one benchmark mode.

    ``useValidator``/``useOntology`` drive :func:`_buildValidator` and
    :func:`_applyEnvForMode`; ``name`` and ``label`` identify the mode in
    workflow ids, console output and the reports.
    """

    name: str  # CLI key; also embedded in workflow ids and the JSON dump
    label: str  # human-readable name for console output and Markdown headings
    useValidator: bool  # False -> _NoOpValidator (Repair-Loop disabled)
    useOntology: bool  # True -> ontology-backed validator and prompt block
# Registry of runnable modes keyed by their CLI name; insertion order is the
# order shown in ``_parseArgs``' "Allowed modes" message.
_MODES: Dict[str, _ModeConfig] = {
    cfg.name: cfg
    for cfg in (
        _ModeConfig(name="baseline", label="Baseline (no validator)", useValidator=False, useOntology=False),
        _ModeConfig(name="phase1", label="Phase 1 (validator on)", useValidator=True, useOntology=False),
        _ModeConfig(name="phase2", label="Phase 2 (validator + ontology)", useValidator=True, useOntology=True),
    )
}
def _buildValidator(mode: _ModeConfig) -> QueryValidator:
    """Return the validator to plug into the sub-agent tools for ``mode``.

    * ``useValidator`` off (baseline): :class:`_NoOpValidator`, so nothing
      is rejected and the Repair-Loop stays idle.
    * ``useValidator`` on, ``useOntology`` off (phase1): a plain
      :class:`QueryValidator` with its built-in defaults.
    * both on (phase2): a :class:`QueryValidator` built with the trustee
      ontology. If importing or loading the ontology fails for any reason, a
      warning is logged and the plain validator is returned instead. The run
      then continues with phase1-style validation.
    """
    if not mode.useValidator:
        return _NoOpValidator()
    if not mode.useOntology:
        return QueryValidator()
    try:
        from modules.features.trustee.trusteeOntology import getTrusteeOntology
        return QueryValidator(ontology=getTrusteeOntology())
    except Exception as e:
        logger.warning("Could not load trustee ontology, falling back: %s", e)
        return QueryValidator()
def _applyEnvForMode(mode: _ModeConfig) -> None:
    """Set or clear ``POWERON_DISABLE_FEATURE_ONTOLOGY`` for ``mode``.

    Non-ontology modes (baseline/phase1) set the flag to ``"1"``. The intent
    is that ``featureDataAgent._buildSchemaContext`` then falls back to the
    legacy ``getAgentDomainHints()`` block, i.e. the prompt production ships
    today. Ontology mode (phase2) removes the flag so the builder uses the
    ontology prompt block. The change is process-wide and persists until the
    next call.
    """
    flagName = "POWERON_DISABLE_FEATURE_ONTOLOGY"
    if not mode.useOntology:
        os.environ[flagName] = "1"
        return
    os.environ.pop(flagName, None)
def _buildSystemPrompt(featureCode: str, instanceLabel: str, selectedTables: List[Dict[str, Any]]) -> str:
    """Produce the sub-agent system prompt through the production builder.

    All modes call the same builder with German as the request language.
    Whether the legacy hints or the ontology block end up in the prompt is
    decided by the env flag that :func:`_applyEnvForMode` sets. The benchmark
    therefore measures the production prompt rather than an eval-only variant.
    """
    prompt = _buildSchemaContext(
        featureCode,
        instanceLabel,
        selectedTables,
        requestLang="de",
    )
    return prompt
# ---------------------------------------------------------------------------
# Question loading + per-question evaluation
# ---------------------------------------------------------------------------
@dataclass
class _Question:
    """One gold-standard question as loaded by :func:`_loadQuestions`."""

    id: str  # unique id; matched by ``--only=`` and shown in the reports
    question: str  # prompt text sent to the sub-agent
    intent: str  # free-form category, printed in the failure notes
    expectedTools: List[str]  # at least one of these tools must be called
    expectedTable: Optional[str]  # required ``tableName`` argument, if set
    expectedAggregate: Optional[str]  # aggregate name, compared case-insensitively
    expectedAggregateField: Optional[str]  # required ``field`` of the aggregate call
    requiredFilters: Dict[str, Any]  # field -> value; ``<field>Like`` keys expect LIKE/ILIKE
    forbiddenTools: List[str]  # tools that must not execute successfully
    expectedNumbers: List[float]  # numbers that must appear in the final answer
    expectedAnswerContains: List[str]  # snippets the answer must contain (case-insensitive)
    numericTolerance: float  # relative tolerance; the scorer floors it at 0.5 absolute
def _loadQuestions(yamlPath: Path) -> List[_Question]:
    """Parse the gold-standard question file into :class:`_Question` objects.

    Args:
        yamlPath: YAML file whose top level is a list of mappings. ``id`` and
            ``question`` are required; all other keys are optional.

    Returns:
        The questions in file order. An empty YAML document yields ``[]``.

    Raises:
        ValueError: if the document's top level is not a list.
        KeyError: if an entry lacks ``id`` or ``question``.
    """
    with open(yamlPath, "r", encoding="utf-8") as f:
        rawList = yaml.safe_load(f)
    # safe_load returns None for an empty document: treat it as "no questions"
    # instead of crashing with a TypeError while iterating.
    if rawList is None:
        return []
    if not isinstance(rawList, list):
        raise ValueError(
            f"{yamlPath}: expected a list of questions at the top level, "
            f"got {type(rawList).__name__}"
        )
    questions: List[_Question] = []
    for raw in rawList:
        rawTolerance = raw.get("numericTolerance")
        questions.append(_Question(
            id=raw["id"],
            question=raw["question"],
            # ``intent: null`` must not leak None into a str field.
            intent=raw.get("intent") or "",
            expectedTools=list(raw.get("expectedTools") or []),
            expectedTable=raw.get("expectedTable"),
            expectedAggregate=raw.get("expectedAggregate"),
            expectedAggregateField=raw.get("expectedAggregateField"),
            requiredFilters=dict(raw.get("requiredFilters") or {}),
            forbiddenTools=list(raw.get("forbiddenTools") or []),
            expectedNumbers=[float(x) for x in (raw.get("expectedNumbers") or [])],
            expectedAnswerContains=[str(x) for x in (raw.get("expectedAnswerContains") or [])],
            # Default only when absent/null: an explicit 0 must stay 0 rather
            # than being widened to the 0.5 % default by a falsy ``or``.
            numericTolerance=0.005 if rawTolerance is None else float(rawTolerance),
        ))
    return questions
@dataclass
class _RunResult:
    """Raw outcome of one sub-agent run, filled from the agent-loop events.

    The numeric properties read from ``summary`` (the ``AGENT_SUMMARY``
    payload); missing, ``None`` or otherwise falsy entries count as zero.
    """

    questionId: str
    finalText: str
    toolCalls: List[Dict[str, Any]] = field(default_factory=list)  # TOOL_CALL payloads
    toolResults: List[Dict[str, Any]] = field(default_factory=list)  # TOOL_RESULT payloads
    summary: Dict[str, Any] = field(default_factory=dict)  # AGENT_SUMMARY payload
    durationS: float = 0.0  # elapsed seconds for the whole run
    error: Optional[str] = None  # set when the run failed or reported errors

    def _summaryOrZero(self, key: str) -> Any:
        """Summary entry for ``key``, replaced by ``0`` when missing or falsy."""
        return self.summary.get(key) or 0

    @property
    def costCHF(self) -> float:
        return float(self._summaryOrZero("costCHF"))

    @property
    def rounds(self) -> int:
        return int(self._summaryOrZero("rounds"))

    @property
    def validationFailures(self) -> int:
        return int(self._summaryOrZero("validationFailures"))

    @property
    def repairAttempts(self) -> int:
        return int(self._summaryOrZero("repairAttempts"))

    @property
    def successAfterRepair(self) -> int:
        return int(self._summaryOrZero("successAfterRepair"))
@dataclass
class _Score:
    """Pass/fail verdicts for one run; every check starts out as failed."""

    patternOk: bool = False  # expected tool/table/aggregate/filters were used
    forbidOk: bool = False  # no forbidden tool executed successfully
    numericOk: bool = False  # answer contains the expected numbers/snippets
    accuracyOk: bool = False  # conjunction of the three checks above
    notes: List[str] = field(default_factory=list)  # free-form remarks for the report
def _scoreRun(question: _Question, run: _RunResult) -> _Score:
    """Score one run against its gold-standard question.

    A run that recorded an error is not inspected further: every check stays
    failed and the error is attached as a note. Otherwise the pattern,
    forbidden-tool and numeric checks run in that order, and accuracy is
    their conjunction.
    """
    if run.error:
        failed = _Score()
        failed.notes.append(f"Sub-agent error: {run.error}")
        return failed
    pattern = _checkPattern(question, run)
    forbid = _checkForbid(question, run)
    numeric = _checkNumeric(question, run)
    return _Score(
        patternOk=pattern,
        forbidOk=forbid,
        numericOk=numeric,
        accuracyOk=pattern and forbid and numeric,
    )
def _callArgs(call: Dict[str, Any]) -> Dict[str, Any]:
    """Return the ``args`` dict of a recorded tool call (``{}`` if missing or null)."""
    return call.get("args") or {}


def _collectCallFilters(args: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Return all filters of a tool call, taken from ``filters`` and ``extraFilters``.

    Both keys are concatenated (``filters`` first). The same call is therefore
    scored identically in the aggregate and non-aggregate branches of
    :func:`_checkPattern`, and a required filter is found no matter which key
    the agent used. Values that are not lists are ignored.
    """
    collected: List[Dict[str, Any]] = []
    for key in ("filters", "extraFilters"):
        value = args.get(key)
        if isinstance(value, list):
            collected.extend(value)
    return collected


def _checkPattern(question: _Question, run: _RunResult) -> bool:
    """Did the agent call one of the expected tools on the expected table with required filters?

    * No ``expectedTools``: always ``True``.
    * Otherwise at least one call must use an expected tool and, if
      ``expectedTable`` is set, target that ``tableName``.
    * With ``expectedAggregate`` set, one of those calls must also be an
      ``aggregateTable`` call with that aggregate (case-insensitive), the
      expected ``field`` (if given) and all required filters.
    * Otherwise, with non-empty ``requiredFilters``, one matching call must
      carry them.

    Calls are read from ``run.toolCalls``; their results are not consulted
    (unlike :func:`_checkForbid`).
    """
    if not question.expectedTools:
        return True
    matchingCalls = [
        c for c in run.toolCalls
        if c.get("toolName") in question.expectedTools
        and (not question.expectedTable or _callArgs(c).get("tableName") == question.expectedTable)
    ]
    if not matchingCalls:
        return False

    if question.expectedAggregate:
        wantAgg = question.expectedAggregate.upper()
        wantField = question.expectedAggregateField
        for c in matchingCalls:
            if c.get("toolName") != "aggregateTable":
                continue
            args = _callArgs(c)
            # str() guards against a non-string aggregate value from the LLM.
            if str(args.get("aggregate") or "").upper() != wantAgg:
                continue
            if wantField and args.get("field") != wantField:
                continue
            if _filtersSatisfied(question.requiredFilters, _collectCallFilters(args)):
                return True
        return False

    if question.requiredFilters:
        return any(
            _filtersSatisfied(question.requiredFilters, _collectCallFilters(_callArgs(c)))
            for c in matchingCalls
        )
    return True
def _filtersSatisfied(required: Dict[str, Any], actualFilters: List[Dict[str, Any]]) -> bool:
    """Check that every required filter is present among ``actualFilters``.

    Args:
        required: field name -> expected value. A key ending in ``Like``
            (e.g. ``nameLike``) requires a ``LIKE``/``ILIKE`` filter on the
            field without the suffix whose value, as a string, equals the
            expected value exactly. Any other key requires a filter on that
            field whose value matches per :func:`_filterValueEqual`; the
            operator is not checked for those.
        actualFilters: filter dicts with ``field``/``op``/``value`` keys as
            emitted by the agent. ``None`` and non-dict entries are ignored.

    Returns:
        ``True`` if ``required`` is empty or every entry is satisfied.
    """
    if not required:
        return True
    candidates = [f for f in (actualFilters or []) if isinstance(f, dict)]
    for reqField, reqValue in required.items():
        if reqField.endswith("Like"):
            # Named ``targetField`` so it doesn't shadow ``dataclasses.field``.
            targetField = reqField[:-4]
            wanted = str(reqValue)
            ok = any(
                f.get("field") == targetField
                # ``op`` may be missing or null; both count as "not LIKE".
                and str(f.get("op") or "").upper() in ("LIKE", "ILIKE")
                and str(f.get("value")) == wanted
                for f in candidates
            )
        else:
            ok = any(
                f.get("field") == reqField and _filterValueEqual(f.get("value"), reqValue)
                for f in candidates
            )
        if not ok:
            return False
    return True
def _filterValueEqual(a: Any, b: Any) -> bool:
    """Loose equality for filter values.

    Values match if they compare equal, or if their ``str()`` forms are equal
    after stripping surrounding whitespace (so ``2024`` matches ``"2024 "``).
    Any exception while stringifying counts as "not equal".
    """
    if a == b:
        return True
    try:
        left, right = str(a).strip(), str(b).strip()
    except Exception:
        return False
    return left == right
def _checkForbid(question: _Question, run: _RunResult) -> bool:
    """Return ``True`` unless a forbidden tool was executed successfully.

    Only tool *names* from ``question.forbiddenTools`` are compared. Tool
    results whose ``success`` flag is falsy are ignored. For example, calls
    the validator rejected do not count: steering the agent away is the
    Repair-Loop working as intended, not a violation.
    """
    if not question.forbiddenTools:
        return True
    forbidden = frozenset(question.forbiddenTools)
    return not any(
        result.get("success") and result.get("toolName") in forbidden
        for result in run.toolResults
    )
def _checkNumeric(question: _Question, run: _RunResult) -> bool:
    """Check the final answer text against expected numbers and snippets.

    Each expected number must match some number in the text within
    ``max(|expected| * numericTolerance, 0.5)``. Each
    ``expectedAnswerContains`` snippet must occur case-insensitively.
    Questions without either expectation pass trivially.
    """
    answer = run.finalText or ""
    if question.expectedNumbers:
        found = _extractNumbers(answer)
        for target in question.expectedNumbers:
            allowance = max(abs(target) * question.numericTolerance, 0.5)
            if not any(abs(candidate - target) <= allowance for candidate in found):
                return False
    if not question.expectedAnswerContains:
        return True
    haystack = answer.lower()
    return all(snippet.lower() in haystack for snippet in question.expectedAnswerContains)
def _extractNumbers(text: str) -> List[float]:
    """Return every number found in a free-text answer, in order of appearance.

    Handles Swiss thousand separators (ASCII apostrophe and U+2019), German
    decimal commas (``1.234,50``), English thousand commas (``1,234.50``),
    plain integers/floats and JSON numbers. Trailing ``,``, ``;`` and any
    number of ``.`` are stripped before parsing. Both ``"180500.0,"`` and a
    sentence-final ``"180'500.00."`` therefore parse to 180500.0.

    A single comma without a dot is read as a decimal separator
    (``"1,5"`` -> 1.5). Several commas without a dot are treated as thousand
    separators. Tokens that still don't parse (e.g. ``1.2.3``) are skipped.
    """
    cleaned = text.replace("\u2019", "'")
    out: List[float] = []
    for tok in re.findall(r"-?\d[\d'.,]*", cleaned):
        # Strip sentence punctuation, including every trailing dot: a number
        # with a decimal point at the end of a sentence ("1'234.50.") used to
        # keep its final dot, fail float() and be silently dropped.
        tok = tok.rstrip(",;.")
        norm = tok.replace("'", "")
        commas, dots = norm.count(","), norm.count(".")
        if commas == 1 and dots == 0:
            norm = norm.replace(",", ".")
        elif commas and dots:
            # Whichever separator comes last is the decimal separator.
            if norm.rfind(",") > norm.rfind("."):
                norm = norm.replace(".", "").replace(",", ".")
            else:
                norm = norm.replace(",", "")
        else:
            norm = norm.replace(",", "")
        try:
            out.append(float(norm))
        except ValueError:
            continue
    return out
# ---------------------------------------------------------------------------
# AI call wiring
# ---------------------------------------------------------------------------
def _bootstrapServices() -> Tuple[Any, str, str]:
    """Build a minimal services bag for the root user and the initial mandate.

    Returns:
        ``(services, userId, mandateId)``. ``services`` resolves unknown
        public attributes (e.g. ``services.ai``) via ``getService`` on first
        access and caches them on the instance. ``mandateId`` is the initial
        mandate, which the original notes is used for billing.

    Raises:
        RuntimeError: if no initial mandate exists.
    """
    from modules.interfaces.interfaceDbApp import getRootInterface
    from modules.datamodels.datamodelUam import Mandate
    from modules.serviceCenter import getService
    from modules.serviceCenter.context import ServiceCenterContext

    rootInterface = getRootInterface()
    rootUser = rootInterface.currentUser
    initialMandateId = rootInterface.getInitialId(Mandate)
    if not initialMandateId:
        raise RuntimeError("No initial mandate available -- run bootstrap loader first.")
    context = ServiceCenterContext(user=rootUser, mandate_id=initialMandateId)

    class _BenchmarkServicesBag:
        """Attribute bag exposing context fields plus lazily-resolved services."""

        def __init__(self, ctx):
            self._ctx = ctx
            self.user = ctx.user
            self.mandateId = ctx.mandate_id
            self.featureInstanceId = ctx.feature_instance_id
            self.workflow = ctx.workflow

        def __getattr__(self, name):
            # Only reached when normal lookup fails. Private/dunder names
            # (including ``_ctx`` before __init__ has set it) raise instead of
            # recursing into getService.
            if name.startswith("_"):
                raise AttributeError(name)
            resolved = getService(name, self._ctx)
            # Cache on the instance so later accesses skip __getattr__.
            setattr(self, name, resolved)
            return resolved

    return _BenchmarkServicesBag(context), rootUser.id, initialMandateId
def _recordAgentEvent(run: _RunResult, event: Any) -> None:
    """Fold one agent-loop event into ``run`` (answer text, tool traffic, summary, errors)."""
    eventType = event.type
    if eventType == AgentEventTypeEnum.FINAL:
        # The FINAL content replaces streamed text unless it is empty.
        run.finalText = event.content or run.finalText
    elif eventType == AgentEventTypeEnum.MESSAGE and event.content:
        run.finalText += event.content
    elif eventType == AgentEventTypeEnum.TOOL_CALL:
        run.toolCalls.append(dict(event.data or {}))
    elif eventType == AgentEventTypeEnum.TOOL_RESULT:
        run.toolResults.append(dict(event.data or {}))
    elif eventType == AgentEventTypeEnum.AGENT_SUMMARY:
        run.summary = dict(event.data or {})
    elif eventType == AgentEventTypeEnum.ERROR:
        message = event.content or ""
        if run.error and message:
            # Separate consecutive errors instead of gluing them together.
            run.error = f"{run.error}; {message}"
        else:
            run.error = (run.error or "") + message


async def _runOneQuestion(
    *,
    services: Any,
    userId: str,
    mandateId: str,
    fixture: BenchmarkFixture,
    question: _Question,
    mode: _ModeConfig,
) -> _RunResult:
    """Execute a single sub-agent run for one question under one mode.

    A fresh fake provider, validator, tool registry and system prompt are
    built for every call. AI calls go through ``services.ai.callAi``; their
    ``priceCHF`` is summed locally and reported back to the agent loop as
    the workflow cost. Exceptions raised by the loop are recorded in
    ``run.error`` (and logged) instead of propagating.

    Returns:
        The populated :class:`_RunResult`, including the elapsed duration.
    """
    provider = FakeFeatureDataProvider(
        rowsByTable=fixture.rowsByTable,
        availableTables=fixture.selectedTables,
    )
    validator = _buildValidator(mode)
    registry = _buildSubAgentTools(
        provider=provider,
        featureInstanceId=fixture.featureInstanceId,
        mandateId=fixture.mandateId,
        tableFilters={},
        validator=validator,
    )
    systemPrompt = _buildSystemPrompt(
        featureCode="trustee",
        instanceLabel="Demo AG",
        selectedTables=fixture.selectedTables,
    )
    cost = 0.0

    async def _aiCallFn(req: AiCallRequest) -> AiCallResponse:
        # Forward to the real AI service and keep a running CHF total;
        # presumably the loop compares it against ``maxCostCHF`` via _getCost.
        nonlocal cost
        resp = await services.ai.callAi(req)
        cost += float(getattr(resp, "priceCHF", 0.0) or 0.0)
        return resp

    async def _getCost() -> float:
        return cost

    config = AgentConfig(
        maxRounds=6,
        maxCostCHF=0.50,
        operationType=OperationTypeEnum.DATA_QUERY,
    )
    run = _RunResult(questionId=question.id, finalText="")
    # perf_counter is monotonic; time.time() follows system clock
    # adjustments and can yield wrong or even negative durations.
    t0 = time.perf_counter()
    try:
        async for event in runAgentLoop(
            prompt=question.question,
            toolRegistry=registry,
            config=config,
            aiCallFn=_aiCallFn,
            getWorkflowCostFn=_getCost,
            workflowId=f"eval-{mode.name}-{question.id}-{uuid.uuid4().hex[:6]}",
            userId=userId,
            featureInstanceId=fixture.featureInstanceId,
            mandateId=mandateId,
            systemPromptOverride=systemPrompt,
        ):
            _recordAgentEvent(run, event)
    except Exception as e:
        run.error = f"{type(e).__name__}: {e}"
        logger.exception("Sub-agent run failed for %s/%s", mode.name, question.id)
    run.durationS = time.perf_counter() - t0
    return run
# ---------------------------------------------------------------------------
# Report
# ---------------------------------------------------------------------------
@dataclass
class _ModeReport:
    """All scored runs of one mode plus the aggregate metrics for the report."""

    mode: _ModeConfig
    perQuestion: List[Tuple[_Question, _RunResult, _Score]] = field(default_factory=list)

    @property
    def total(self) -> int:
        """Number of scored questions."""
        return len(self.perQuestion)

    def _count(self, attr: str) -> int:
        """How many scores have a truthy ``attr`` (e.g. ``"accuracyOk"``)."""
        return len([score for *_, score in self.perQuestion if getattr(score, attr)])

    def _share(self, attr: str) -> float:
        """Fraction of scores with ``attr`` set; 0.0 when nothing was scored."""
        return self._count(attr) / max(self.total, 1)

    @property
    def accuracy(self) -> float:
        return self._share("accuracyOk")

    @property
    def patternCompliance(self) -> float:
        return self._share("patternOk")

    @property
    def repairConversionRate(self) -> float:
        """Successful repairs per repair attempt; 0.0 when nothing was repaired."""
        runs = [run for _, run, _ in self.perQuestion]
        attempts = sum(run.repairAttempts for run in runs)
        if attempts == 0:
            return 0.0
        return sum(run.successAfterRepair for run in runs) / attempts

    @property
    def totalCostCHF(self) -> float:
        return sum(run.costCHF for _, run, _ in self.perQuestion)

    @property
    def totalRounds(self) -> int:
        return sum(run.rounds for _, run, _ in self.perQuestion)

    @property
    def totalValidationFailures(self) -> int:
        return sum(run.validationFailures for _, run, _ in self.perQuestion)
def _reportSnippet(text: Optional[str], limit: int = 240) -> str:
    """First non-blank line of ``text``, Markdown-safe and cut to ``limit`` chars.

    Pipes become ``/`` and backticks become ``'`` so the snippet can sit in
    inline code. Empty, blank or ``None`` input yields ``""``.
    """
    for line in (text or "").splitlines():
        line = line.strip()
        if line:
            return line.replace("|", "/").replace("`", "'")[:limit]
    return ""


def _writeReport(reports: List[_ModeReport], outputPath: Path) -> None:
    """Render the benchmark results as Markdown and write them to ``outputPath``.

    The report contains:

    * a summary table with one row per mode;
    * a per-question table for each mode;
    * for each mode, a list of the questions that failed, with the first
      answer line, any error and the score notes.

    Parent directories are created as needed.
    """
    lines: List[str] = [
        "# Trustee Sub-Agent Benchmark Report",
        "",
        f"Generated: {time.strftime('%Y-%m-%d %H:%M:%S')}",
        "",
        "## Summary",
        "",
        "| Mode | Questions | Accuracy | Pattern compliance | Repair conversion | Validator rejects | Rounds | Cost (CHF) |",
        "|---|---|---|---|---|---|---|---|",
    ]
    for rep in reports:
        lines.append(
            f"| {rep.mode.label} | {rep.total} | {rep.accuracy:.1%} | {rep.patternCompliance:.1%} | "
            f"{rep.repairConversionRate:.1%} | {rep.totalValidationFailures} | {rep.totalRounds} | "
            f"{rep.totalCostCHF:.4f} |"
        )
    lines.append("")
    lines.append("## Per-question detail")
    for rep in reports:
        lines.extend([
            "",
            f"### {rep.mode.label}",
            "",
            "| id | acc | pattern | forbid | numeric | rounds | val-fail | repairs | cost CHF | duration | tools |",
            "|---|---|---|---|---|---|---|---|---|---|---|",
        ])
        for q, r, s in rep.perQuestion:
            # ``args`` may be missing or null on a recorded call.
            toolList = ",".join(
                f"{c.get('toolName')}({(c.get('args') or {}).get('tableName', '?')})"
                for c in r.toolCalls
            )
            lines.append(
                f"| {q.id} | {_yn(s.accuracyOk)} | {_yn(s.patternOk)} | {_yn(s.forbidOk)} | {_yn(s.numericOk)} | "
                f"{r.rounds} | {r.validationFailures} | {r.repairAttempts}/{r.successAfterRepair} | "
                f"{r.costCHF:.4f} | {r.durationS:.1f}s | {toolList} |"
            )
        lines.append("")
        lines.append("#### Notes & failures")
        for q, r, s in rep.perQuestion:
            if s.accuracyOk:
                continue
            lines.append(f"- **{q.id}** ({q.intent}): pattern={s.patternOk} forbid={s.forbidOk} numeric={s.numericOk}")
            # Two-space indent nests these bullets under the question item.
            if r.error:
                # Collapse whitespace so multi-line errors stay inside one bullet.
                errorText = " ".join(str(r.error).split()).replace("`", "'")
                lines.append(f"  - error: `{errorText}`")
            # Failed runs often have no answer at all; _reportSnippet handles
            # that instead of raising IndexError on ``"".splitlines()[0]``.
            lines.append(f"  - answer: `{_reportSnippet(r.finalText)}`")
            for note in s.notes:
                lines.append(f"  - note: {note}")
    outputPath.parent.mkdir(parents=True, exist_ok=True)
    # Terminate the file with a newline like any well-formed text file.
    outputPath.write_text("\n".join(lines) + "\n", encoding="utf-8")
def _yn(b: bool) -> str:
    """Render a check result as ``"OK"`` or ``"FAIL"`` for tables and console."""
    return ("FAIL", "OK")[bool(b)]
# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------
def _modeReportToJson(rep: _ModeReport) -> Dict[str, Any]:
    """Flatten one mode's metrics and per-question results into JSON-ready data.

    ``finalText`` is truncated to 600 characters per question.
    """
    return {
        "mode": rep.mode.name,
        "accuracy": rep.accuracy,
        "patternCompliance": rep.patternCompliance,
        "repairConversionRate": rep.repairConversionRate,
        "totalCostCHF": rep.totalCostCHF,
        "totalRounds": rep.totalRounds,
        "totalValidationFailures": rep.totalValidationFailures,
        "items": [
            {
                "questionId": q.id,
                "intent": q.intent,
                "accuracyOk": s.accuracyOk,
                "patternOk": s.patternOk,
                "forbidOk": s.forbidOk,
                "numericOk": s.numericOk,
                "rounds": r.rounds,
                "validationFailures": r.validationFailures,
                "repairAttempts": r.repairAttempts,
                "successAfterRepair": r.successAfterRepair,
                "costCHF": r.costCHF,
                "durationS": r.durationS,
                "finalText": (r.finalText or "")[:600],
                "toolCalls": r.toolCalls,
                "error": r.error,
            }
            for q, r, s in rep.perQuestion
        ],
    }


async def _runMain(modesToRun: List[str], onlyQuestionId: Optional[str] = None) -> None:
    """Run the benchmark for ``modesToRun`` and write the JSON and Markdown reports.

    Args:
        modesToRun: keys of ``_MODES``, executed in the given order.
        onlyQuestionId: if set, run only the question with this id. When no
            question matches, a message is printed and nothing is run.
    """
    logging.basicConfig(
        level=logging.WARNING,
        format="%(asctime)s %(levelname)s %(name)s -- %(message)s",
    )
    logger.setLevel(logging.INFO)
    fixture = buildTrusteeBenchmarkFixture()
    questionsPath = _GATEWAY_DIR / "tests" / "fixtures" / "trusteeBenchmark" / "questions.yaml"
    allQuestions = _loadQuestions(questionsPath)
    if onlyQuestionId:
        allQuestions = [q for q in allQuestions if q.id == onlyQuestionId]
        if not allQuestions:
            print(f"No question matches id={onlyQuestionId!r}")
            return
    print(f"Loaded {len(allQuestions)} questions, {len(modesToRun)} modes -> {len(allQuestions) * len(modesToRun)} sub-agent runs.")
    services, userId, mandateId = _bootstrapServices()
    print(f"Bootstrap OK: user={userId}, mandate={mandateId}")

    reports: List[_ModeReport] = []
    for modeName in modesToRun:
        mode = _MODES[modeName]
        _applyEnvForMode(mode)
        rep = _ModeReport(mode=mode)
        print(f"\n=== Mode: {mode.label} ===")
        for idx, question in enumerate(allQuestions, start=1):
            print(f" [{idx:>2}/{len(allQuestions)}] {question.id}: {question.question[:80]} ...", flush=True)
            run = await _runOneQuestion(
                services=services,
                userId=userId,
                mandateId=mandateId,
                fixture=fixture,
                question=question,
                mode=mode,
            )
            score = _scoreRun(question, run)
            rep.perQuestion.append((question, run, score))
            print(
                f" -> acc={_yn(score.accuracyOk)} "
                f"pattern={_yn(score.patternOk)} forbid={_yn(score.forbidOk)} "
                f"numeric={_yn(score.numericOk)} rounds={run.rounds} cost={run.costCHF:.4f} "
                f"val-fail={run.validationFailures} repairs={run.repairAttempts}/{run.successAfterRepair}",
                flush=True,
            )
        reports.append(rep)

    timestamp = time.strftime("%Y%m%d-%H%M%S")
    outDir = _GATEWAY_DIR.parent / "local" / "notes"
    outDir.mkdir(parents=True, exist_ok=True)
    reportPath = outDir / f"trustee-benchmark-{timestamp}.md"
    rawJsonPath = outDir / f"trustee-benchmark-{timestamp}.json"
    # Persist the raw data first: the AI calls above cost real money, and a
    # failure while rendering the Markdown must not lose the results.
    rawJsonPath.write_text(
        json.dumps(
            [_modeReportToJson(rep) for rep in reports],
            indent=2,
            ensure_ascii=False,
            # Tool-call payloads come from the agent loop and may hold values
            # json cannot encode natively; stringify them rather than crash.
            default=str,
        ),
        encoding="utf-8",
    )
    _writeReport(reports, reportPath)
    print(f"\nReport written: {reportPath}")
    print(f"Raw JSON: {rawJsonPath}")
    for rep in reports:
        print(f" {rep.mode.label}: acc={rep.accuracy:.1%} pattern={rep.patternCompliance:.1%} cost={rep.totalCostCHF:.4f}")
def _parseArgs(argv: List[str]) -> Tuple[List[str], Optional[str]]:
    """Parse CLI arguments into ``(modes, onlyQuestionId)``.

    Positional arguments must be mode names from ``_MODES``. Repeats are
    ignored so a mode is never run (and billed) twice. ``--only=<id>``
    restricts the run to one question; the last occurrence wins. Without
    any mode argument, all three modes run. Any other argument prints an
    error and exits with status 2.
    """
    modes: List[str] = []
    only: Optional[str] = None
    for arg in argv:
        if arg.startswith("--only="):
            only = arg.split("=", 1)[1]
        elif arg in _MODES:
            if arg not in modes:
                modes.append(arg)
        else:
            print(f"Unknown argument: {arg!r}. Allowed modes: {list(_MODES)}, plus --only=<questionId>")
            sys.exit(2)
    if not modes:
        modes = ["baseline", "phase1", "phase2"]
    return modes, only
def main() -> None:
    """CLI entry point: parse ``sys.argv`` and run the benchmark to completion."""
    selectedModes, onlyQuestionId = _parseArgs(sys.argv[1:])
    asyncio.run(_runMain(selectedModes, onlyQuestionId=onlyQuestionId))


if __name__ == "__main__":
    main()