# Copyright (c) 2026 PowerOn AG
# All rights reserved.
"""Trustee Sub-Agent Eval Harness (Phase 1.5).

Standalone runner that fires real AI calls against the Feature Data Sub-Agent
in three configurations:

  * ``baseline`` -- production code without the pre-execute validator
    (Repair-Loop disabled, Trustee domain hints active).
  * ``phase1`` -- pre-execute validator on (Repair-Loop active), domain hints
    active, no ontology yet.
  * ``phase2`` -- validator on, ontology-driven schema context + constraints
    (replaces hand-written domain hints).

For each mode we run all 19 gold-standard questions against an in-memory
:class:`FakeFeatureDataProvider`, capture the agent's tool calls and final
answer, score them against the gold standard, and write a Markdown report to
``local/notes/`` for analysis.

Usage::

    cd gateway
    python -m tests.eval.runTrusteeBenchmark                  # all 3 modes
    python -m tests.eval.runTrusteeBenchmark phase1           # one mode only
    python -m tests.eval.runTrusteeBenchmark baseline phase1
    python -m tests.eval.runTrusteeBenchmark --only=<id>      # one question only
"""

from __future__ import annotations

import asyncio
import json
import logging
import os
import re
import sys
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

# ---------------------------------------------------------------------------
# Path setup so `python -m tests.eval.runTrusteeBenchmark` works from gateway/
# ---------------------------------------------------------------------------
_GATEWAY_DIR = Path(__file__).resolve().parents[2]
if str(_GATEWAY_DIR) not in sys.path:
    sys.path.insert(0, str(_GATEWAY_DIR))

import yaml  # noqa: E402

from modules.serviceCenter.services.serviceAgent.datamodelAgent import (  # noqa: E402
    AgentConfig,
    AgentEventTypeEnum,
)
from modules.datamodels.datamodelAi import (  # noqa: E402
    AiCallRequest,
    AiCallResponse,
    OperationTypeEnum,
)
from modules.serviceCenter.services.serviceAgent.agentLoop import runAgentLoop  # noqa: E402
from modules.serviceCenter.services.serviceAgent.featureDataAgent import (  # noqa: E402
    _buildSubAgentTools,
    _buildSchemaContext,
)
from modules.serviceCenter.services.serviceAgent.datamodelOntology import (  # noqa: E402
    QueryValidationError,
)
from modules.serviceCenter.services.serviceAgent.queryValidator import (  # noqa: E402
    QueryValidator,
)
from tests.eval.fakeFeatureDataProvider import (  # noqa: E402
    FakeFeatureDataProvider,
)
from tests.fixtures.trusteeBenchmark.loadTrusteeBenchmarkFixture import (  # noqa: E402
    buildTrusteeBenchmarkFixture,
    BenchmarkFixture,
)

logger = logging.getLogger("trusteeBenchmark")


# ---------------------------------------------------------------------------
# NoOpValidator -- baseline mode (Repair-Loop OFF)
# ---------------------------------------------------------------------------
class _NoOpValidator(QueryValidator):
    """Validator that never rejects anything (used for baseline measurement)."""

    def validateBrowseQuery(self, tableName, args):  # noqa: ARG002
        return None

    def validateQueryTable(self, tableName, args):  # noqa: ARG002
        return None

    def validateAggregateQuery(self, tableName, args):  # noqa: ARG002
        return None


# ---------------------------------------------------------------------------
# Mode-specific tool/prompt building
# ---------------------------------------------------------------------------
@dataclass
class _ModeConfig:
    """Static description of one benchmark configuration."""

    name: str  # key used on the command line and in output file contents
    label: str  # human-readable label used in console output and the report
    useValidator: bool  # False -> _NoOpValidator is used
    useOntology: bool  # True -> ontology-backed validator and prompt block


_MODES: Dict[str, _ModeConfig] = {
    "baseline": _ModeConfig(
        name="baseline",
        label="Baseline (no validator)",
        useValidator=False,
        useOntology=False,
    ),
    "phase1": _ModeConfig(
        name="phase1",
        label="Phase 1 (validator on)",
        useValidator=True,
        useOntology=False,
    ),
    "phase2": _ModeConfig(
        name="phase2",
        label="Phase 2 (validator + ontology)",
        useValidator=True,
        useOntology=True,
    ),
}


def _buildValidator(mode: _ModeConfig) -> QueryValidator:
    """Construct the per-mode validator.

    * baseline: no-op (Repair-Loop disabled, used to measure raw LLM accuracy
      against today's prompt path).
    * phase1: convention-based QueryValidator (NEVER_AGGREGATE on
      ``*Balance``/``*Total`` suffixes; no ontology).
    * phase2: ontology-driven QueryValidator (constraints from the trustee
      ontology override the convention defaults).

    If the trustee ontology cannot be loaded in phase2, a warning is logged
    and the plain ``QueryValidator()`` is returned instead.
    """
    if not mode.useValidator:
        return _NoOpValidator()
    if mode.useOntology:
        # Deliberate best-effort: a missing/broken ontology must not abort the
        # whole benchmark run, so we log and fall through to the default.
        try:
            from modules.features.trustee.trusteeOntology import getTrusteeOntology
            return QueryValidator(ontology=getTrusteeOntology())
        except Exception as e:
            logger.warning("Could not load trustee ontology, falling back: %s", e)
    return QueryValidator()


def _applyEnvForMode(mode: _ModeConfig) -> None:
    """Set the ontology toggle for the production prompt builder.

    The Phase 2 path uses ``featureDataAgent._buildSchemaContext`` to pull the
    prompt block from ``getAgentOntology()`` automatically. For
    baseline/phase1 we set ``POWERON_DISABLE_FEATURE_ONTOLOGY=1`` so the
    builder falls back to the legacy ``getAgentDomainHints()`` block --
    measuring exactly the production prompt that ships today.

    Note: this mutates ``os.environ`` of the current process and does not
    restore the previous value.
    """
    if mode.useOntology:
        os.environ.pop("POWERON_DISABLE_FEATURE_ONTOLOGY", None)
    else:
        os.environ["POWERON_DISABLE_FEATURE_ONTOLOGY"] = "1"


def _buildSystemPrompt(featureCode: str, instanceLabel: str, selectedTables: List[Dict[str, Any]]) -> str:
    """Build the sub-agent system prompt via the production path.

    Mode-specific behaviour (legacy hints vs ontology block) is controlled by
    the ``POWERON_DISABLE_FEATURE_ONTOLOGY`` env flag set per mode in
    :func:`_applyEnvForMode`. Keeping the builder call identical for all three
    modes means the benchmark measures the EXACT prompt the agent would see in
    production -- no eval-only forks.
    """
    return _buildSchemaContext(featureCode, instanceLabel, selectedTables, requestLang="de")


# ---------------------------------------------------------------------------
# Question loading + per-question evaluation
# ---------------------------------------------------------------------------
@dataclass
class _Question:
    """One gold-standard question as loaded from ``questions.yaml``."""

    id: str
    question: str
    intent: str
    expectedTools: List[str]
    expectedTable: Optional[str]
    expectedAggregate: Optional[str]
    expectedAggregateField: Optional[str]
    requiredFilters: Dict[str, Any]
    forbiddenTools: List[str]
    expectedNumbers: List[float]
    expectedAnswerContains: List[str]
    numericTolerance: float


def _loadQuestions(yamlPath: Path) -> List[_Question]:
    """Load the gold-standard questions from *yamlPath*.

    ``id`` and ``question`` are mandatory per entry (``KeyError`` otherwise);
    every other key is optional and defaults to an empty value, with
    ``numericTolerance`` defaulting to ``0.005``. An empty YAML document
    yields an empty list.
    """
    with open(yamlPath, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty document; treat that as "no
        # questions" instead of failing with a TypeError in the loop below.
        rawList = yaml.safe_load(f) or []
    questions: List[_Question] = []
    for raw in rawList:
        questions.append(_Question(
            id=raw["id"],
            question=raw["question"],
            intent=raw.get("intent", ""),
            expectedTools=list(raw.get("expectedTools") or []),
            expectedTable=raw.get("expectedTable"),
            expectedAggregate=raw.get("expectedAggregate"),
            expectedAggregateField=raw.get("expectedAggregateField"),
            requiredFilters=dict(raw.get("requiredFilters") or {}),
            forbiddenTools=list(raw.get("forbiddenTools") or []),
            expectedNumbers=[float(x) for x in (raw.get("expectedNumbers") or [])],
            expectedAnswerContains=[str(x) for x in (raw.get("expectedAnswerContains") or [])],
            numericTolerance=float(raw.get("numericTolerance") or 0.005),
        ))
    return questions


@dataclass
class _RunResult:
    """Everything captured from one sub-agent run of one question."""

    questionId: str
    finalText: str
    toolCalls: List[Dict[str, Any]] = field(default_factory=list)
    toolResults: List[Dict[str, Any]] = field(default_factory=list)
    summary: Dict[str, Any] = field(default_factory=dict)
    durationS: float = 0.0
    error: Optional[str] = None

    # The properties below read counters out of the AGENT_SUMMARY payload and
    # fall back to 0 when a key is missing or falsy.
    @property
    def costCHF(self) -> float:
        return float(self.summary.get("costCHF") or 0.0)

    @property
    def rounds(self) -> int:
        return int(self.summary.get("rounds") or 0)

    @property
    def validationFailures(self) -> int:
        return int(self.summary.get("validationFailures") or 0)

    @property
    def repairAttempts(self) -> int:
        return int(self.summary.get("repairAttempts") or 0)

    @property
    def successAfterRepair(self) -> int:
        return int(self.summary.get("successAfterRepair") or 0)


@dataclass
class _Score:
    """Per-question verdicts; ``accuracyOk`` is the conjunction of the other three."""

    patternOk: bool = False
    forbidOk: bool = False
    numericOk: bool = False
    accuracyOk: bool = False
    notes: List[str] = field(default_factory=list)


def _scoreRun(question: _Question, run: _RunResult) -> _Score:
    """Score *run* against *question*; a run with an error fails every check."""
    score = _Score()
    if run.error:
        score.notes.append(f"Sub-agent error: {run.error}")
        return score
    score.patternOk = _checkPattern(question, run)
    score.forbidOk = _checkForbid(question, run)
    score.numericOk = _checkNumeric(question, run)
    score.accuracyOk = score.patternOk and score.forbidOk and score.numericOk
    return score


def _checkPattern(question: _Question, run: _RunResult) -> bool:
    """Did the agent call one of the expected tools on the expected table with required filters?"""
    if not question.expectedTools:
        return True
    # `or {}` (instead of a .get default) also covers an explicit args=None.
    matchingCalls = [
        c for c in run.toolCalls
        if c.get("toolName") in question.expectedTools
        and (
            not question.expectedTable
            or (c.get("args") or {}).get("tableName") == question.expectedTable
        )
    ]
    if not matchingCalls:
        return False

    if question.expectedAggregate:
        # An aggregate expectation is only met by an aggregateTable call with
        # matching function, (optional) field and required filters.
        wantAgg = question.expectedAggregate.upper()
        wantField = question.expectedAggregateField
        for c in matchingCalls:
            args = c.get("args") or {}
            if c.get("toolName") != "aggregateTable":
                continue
            if (args.get("aggregate") or "").upper() != wantAgg:
                continue
            if wantField and args.get("field") != wantField:
                continue
            if not _filtersSatisfied(
                question.requiredFilters,
                args.get("extraFilters") or args.get("filters") or [],
            ):
                continue
            return True
        return False

    if question.requiredFilters:
        for c in matchingCalls:
            args = c.get("args") or {}
            filters = args.get("filters") or args.get("extraFilters") or []
            if _filtersSatisfied(question.requiredFilters, filters):
                return True
        return False

    return True


def _filtersSatisfied(required: Dict[str, Any], actualFilters: List[Dict[str, Any]]) -> bool:
    """Return True when every required filter is present in *actualFilters*.

    A required key ending in ``Like`` (e.g. ``nameLike``) demands a filter on
    the field without that suffix whose ``op`` is LIKE/ILIKE (case-insensitive)
    and whose value equals the required value as a string. Any other key
    demands a filter on that field with an equal value (see
    :func:`_filterValueEqual`); the operator is not checked in that case.
    """
    if not required:
        return True
    for reqField, reqValue in required.items():
        if reqField.endswith("Like"):
            fieldName = reqField[:-4]  # strip the "Like" suffix
            wanted = str(reqValue)
            ok = any(
                f.get("field") == fieldName
                # `or ""` keeps an explicit op=None from raising on .upper()
                and str(f.get("op") or "").upper() in ("LIKE", "ILIKE")
                and str(f.get("value")) == wanted
                for f in actualFilters
            )
        else:
            ok = any(
                f.get("field") == reqField and _filterValueEqual(f.get("value"), reqValue)
                for f in actualFilters
            )
        if not ok:
            return False
    return True


def _filterValueEqual(a: Any, b: Any) -> bool:
    """Compare two filter values, falling back to their stripped string forms."""
    if a == b:
        return True
    try:
        return str(a).strip() == str(b).strip()
    except Exception:
        return False


def _checkForbid(question: _Question, run: _RunResult) -> bool:
    """Did the agent AVOID forbidden tool/op combinations?

    Forbidden hits only count if the call actually went through to the
    provider (success=True). Validator-rejected calls don't count -- the
    Repair-Loop is doing its job and steering the agent away.
    """
    if not question.forbiddenTools:
        return True
    forbiddenSet = set(question.forbiddenTools)
    for r in run.toolResults:
        if not r.get("success"):
            continue
        if r.get("toolName") in forbiddenSet:
            return False
    return True


def _checkNumeric(question: _Question, run: _RunResult) -> bool:
    """Check expected numbers and expected substrings against the final answer.

    Every expected number must have a number in the answer within
    ``max(|expected| * numericTolerance, 0.5)``; every expected substring must
    occur case-insensitively.
    """
    text = run.finalText or ""
    if question.expectedNumbers:
        textNumbers = _extractNumbers(text)
        for expected in question.expectedNumbers:
            tol = max(abs(expected) * question.numericTolerance, 0.5)
            if not any(abs(n - expected) <= tol for n in textNumbers):
                return False
    if question.expectedAnswerContains:
        lowered = text.lower()
        for needle in question.expectedAnswerContains:
            if needle.lower() not in lowered:
                return False
    return True


def _extractNumbers(text: str) -> List[float]:
    """Pick out all numbers from a free-text answer.

    Handles Swiss thousand separators (apostrophe and U+2019), German decimals
    (comma), plain integers/floats, and JSON numbers. Trailing punctuation
    (``,``, ``;``, ``.`` from end-of-sentence) is stripped before parsing so
    ``"180500.0,"`` parses cleanly to 180500.0 and ``"1234.56."`` to 1234.56.
    Tokens that still do not parse as a float (e.g. dates like
    ``31.12.2024``) are skipped.
    """
    cleaned = text.replace("\u2019", "'")
    tokens = re.findall(r"-?\d[\d'.,]*", cleaned)
    out: List[float] = []
    for tok in tokens:
        # Strip ALL trailing punctuation. Previously a trailing "." was only
        # removed when it was the sole dot, so a decimal at the end of a
        # sentence ("1234.56.") failed float() and was silently dropped.
        tok = tok.rstrip(".,;")
        norm = tok.replace("'", "")
        if norm.count(",") == 1 and norm.count(".") == 0:
            # single comma, no dot -> German decimal comma
            norm = norm.replace(",", ".")
        elif norm.count(",") >= 1 and norm.count(".") >= 1:
            if norm.rfind(",") > norm.rfind("."):
                # "1.234,56": dots are thousand separators, comma is decimal
                norm = norm.replace(".", "").replace(",", ".")
            else:
                # "1,234.56": commas are thousand separators
                norm = norm.replace(",", "")
        else:
            norm = norm.replace(",", "")
        try:
            out.append(float(norm))
        except ValueError:
            continue
    return out


# ---------------------------------------------------------------------------
# AI call wiring
# ---------------------------------------------------------------------------
def _bootstrapServices() -> Tuple[Any, str, str]:
    """Spin up a minimal services bag bound to the root user + initial mandate.

    Returns a services bag, the user id, and the mandate id used for billing.

    Raises:
        RuntimeError: if no initial mandate id is available.
    """
    from modules.interfaces.interfaceDbApp import getRootInterface
    from modules.datamodels.datamodelUam import Mandate
    from modules.serviceCenter import getService
    from modules.serviceCenter.context import ServiceCenterContext

    rootInterface = getRootInterface()
    user = rootInterface.currentUser
    mandateId = rootInterface.getInitialId(Mandate)
    if not mandateId:
        raise RuntimeError("No initial mandate available -- run bootstrap loader first.")
    ctx = ServiceCenterContext(user=user, mandate_id=mandateId)

    class _BenchmarkServicesBag:
        """Resolves services by attribute name on first access and caches them."""

        def __init__(self, ctx):
            self._ctx = ctx
            self.user = ctx.user
            self.mandateId = ctx.mandate_id
            self.featureInstanceId = ctx.feature_instance_id
            self.workflow = ctx.workflow

        def __getattr__(self, name):
            # Only called for attributes not found normally. Private names are
            # never treated as service lookups.
            if name.startswith("_"):
                raise AttributeError(name)
            svc = getService(name, self._ctx)
            setattr(self, name, svc)  # cache so __getattr__ is not hit again
            return svc

    services = _BenchmarkServicesBag(ctx)
    return services, user.id, mandateId


async def _runOneQuestion(
    *,
    services: Any,
    userId: str,
    mandateId: str,
    fixture: BenchmarkFixture,
    question: _Question,
    mode: _ModeConfig,
) -> _RunResult:
    """Execute a single sub-agent run for one question under one mode.

    Builds a fresh fake provider, validator, tool registry and system prompt,
    drives ``runAgentLoop`` and collects its events into a :class:`_RunResult`.
    Exceptions raised by the loop are logged and stored in ``run.error``
    rather than propagated, so one failing question does not abort the run.
    """
    provider = FakeFeatureDataProvider(
        rowsByTable=fixture.rowsByTable,
        availableTables=fixture.selectedTables,
    )
    validator = _buildValidator(mode)
    registry = _buildSubAgentTools(
        provider=provider,
        featureInstanceId=fixture.featureInstanceId,
        mandateId=fixture.mandateId,
        tableFilters={},
        validator=validator,
    )
    systemPrompt = _buildSystemPrompt(
        featureCode="trustee",
        instanceLabel="Demo AG",
        selectedTables=fixture.selectedTables,
    )

    cost = 0.0

    async def _aiCallFn(req: AiCallRequest) -> AiCallResponse:
        # Forward to the real AI service and accumulate the reported price.
        nonlocal cost
        resp = await services.ai.callAi(req)
        cost += float(getattr(resp, "priceCHF", 0.0) or 0.0)
        return resp

    async def _getCost() -> float:
        return cost

    config = AgentConfig(
        maxRounds=6,
        maxCostCHF=0.50,
        operationType=OperationTypeEnum.DATA_QUERY,
    )

    run = _RunResult(questionId=question.id, finalText="")
    t0 = time.time()
    try:
        async for event in runAgentLoop(
            prompt=question.question,
            toolRegistry=registry,
            config=config,
            aiCallFn=_aiCallFn,
            getWorkflowCostFn=_getCost,
            workflowId=f"eval-{mode.name}-{question.id}-{uuid.uuid4().hex[:6]}",
            userId=userId,
            featureInstanceId=fixture.featureInstanceId,
            mandateId=mandateId,
            systemPromptOverride=systemPrompt,
        ):
            if event.type == AgentEventTypeEnum.FINAL:
                # A non-empty FINAL replaces whatever MESSAGE chunks were
                # accumulated; an empty one keeps them.
                run.finalText = event.content or run.finalText
            elif event.type == AgentEventTypeEnum.MESSAGE and event.content:
                run.finalText += event.content
            elif event.type == AgentEventTypeEnum.TOOL_CALL:
                run.toolCalls.append(dict(event.data or {}))
            elif event.type == AgentEventTypeEnum.TOOL_RESULT:
                run.toolResults.append(dict(event.data or {}))
            elif event.type == AgentEventTypeEnum.AGENT_SUMMARY:
                run.summary = dict(event.data or {})
            elif event.type == AgentEventTypeEnum.ERROR:
                run.error = (run.error or "") + (event.content or "")
    except Exception as e:
        run.error = f"{type(e).__name__}: {e}"
        logger.exception("Sub-agent run failed for %s/%s", mode.name, question.id)
    run.durationS = time.time() - t0
    return run


# ---------------------------------------------------------------------------
# Report
# ---------------------------------------------------------------------------
@dataclass
class _ModeReport:
    """All per-question results of one mode plus aggregate metrics."""

    mode: _ModeConfig
    perQuestion: List[Tuple[_Question, _RunResult, _Score]] = field(default_factory=list)

    @property
    def total(self) -> int:
        return len(self.perQuestion)

    def _count(self, attr: str) -> int:
        """Number of questions whose score has a truthy *attr*."""
        return sum(1 for _, _, s in self.perQuestion if getattr(s, attr))

    # Ratios divide by max(total, 1) so an empty report yields 0.0.
    @property
    def accuracy(self) -> float:
        return self._count("accuracyOk") / max(self.total, 1)

    @property
    def patternCompliance(self) -> float:
        return self._count("patternOk") / max(self.total, 1)

    @property
    def repairConversionRate(self) -> float:
        attempts = sum(r.repairAttempts for _, r, _ in self.perQuestion)
        succeeded = sum(r.successAfterRepair for _, r, _ in self.perQuestion)
        if attempts == 0:
            return 0.0
        return succeeded / attempts

    @property
    def totalCostCHF(self) -> float:
        return sum(r.costCHF for _, r, _ in self.perQuestion)

    @property
    def totalRounds(self) -> int:
        return sum(r.rounds for _, r, _ in self.perQuestion)

    @property
    def totalValidationFailures(self) -> int:
        return sum(r.validationFailures for _, r, _ in self.perQuestion)


def _writeReport(reports: List[_ModeReport], outputPath: Path) -> None:
    """Render *reports* as Markdown and write them to *outputPath*.

    The report has a summary table (one row per mode) followed by a
    per-question table and a "Notes & failures" list for each mode. Parent
    directories of *outputPath* are created if missing.
    """
    lines: List[str] = []
    lines.append("# Trustee Sub-Agent Benchmark Report")
    lines.append("")
    lines.append(f"Generated: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    lines.append("")
    lines.append("## Summary")
    lines.append("")
    lines.append("| Mode | Questions | Accuracy | Pattern compliance | Repair conversion | Validator rejects | Rounds | Cost (CHF) |")
    lines.append("|---|---|---|---|---|---|---|---|")
    for rep in reports:
        lines.append(
            f"| {rep.mode.label} | {rep.total} | {rep.accuracy:.1%} | {rep.patternCompliance:.1%} | "
            f"{rep.repairConversionRate:.1%} | {rep.totalValidationFailures} | {rep.totalRounds} | "
            f"{rep.totalCostCHF:.4f} |"
        )
    lines.append("")
    lines.append("## Per-question detail")
    for rep in reports:
        lines.append("")
        lines.append(f"### {rep.mode.label}")
        lines.append("")
        lines.append("| id | acc | pattern | forbid | numeric | rounds | val-fail | repairs | cost CHF | duration | tools |")
        lines.append("|---|---|---|---|---|---|---|---|---|---|---|")
        for q, r, s in rep.perQuestion:
            toolList = ",".join(
                f"{c.get('toolName')}({(c.get('args') or {}).get('tableName', '?')})"
                for c in r.toolCalls
            )
            lines.append(
                f"| {q.id} | {_yn(s.accuracyOk)} | {_yn(s.patternOk)} | {_yn(s.forbidOk)} | {_yn(s.numericOk)} | "
                f"{r.rounds} | {r.validationFailures} | {r.repairAttempts}/{r.successAfterRepair} | "
                f"{r.costCHF:.4f} | {r.durationS:.1f}s | {toolList} |"
            )
        lines.append("")
        lines.append("#### Notes & failures")
        for q, r, s in rep.perQuestion:
            if s.accuracyOk:
                continue
            lines.append(f"- **{q.id}** ({q.intent}): pattern={s.patternOk} forbid={s.forbidOk} numeric={s.numericOk}")
            if r.error:
                lines.append(f" - error: `{r.error}`")
            # "".splitlines() is [], so indexing [0] directly raised IndexError
            # for every failed run without answer text (the common case when
            # the sub-agent errored) and aborted the whole report.
            answerLines = (r.finalText or "").strip().replace("|", "/").splitlines()
            firstLine = answerLines[0][:240] if answerLines else ""
            lines.append(f" - answer: `{firstLine}`")
            for note in s.notes:
                lines.append(f" - note: {note}")
    outputPath.parent.mkdir(parents=True, exist_ok=True)
    outputPath.write_text("\n".join(lines), encoding="utf-8")


def _yn(b: bool) -> str:
    """Render a boolean verdict as ``OK`` / ``FAIL``."""
    return "OK" if b else "FAIL"


# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------
async def _runMain(modesToRun: List[str], onlyQuestionId: Optional[str] = None) -> None:
    """Run the benchmark for *modesToRun* and write the Markdown + JSON output.

    Args:
        modesToRun: keys of ``_MODES`` to execute, in order.
        onlyQuestionId: when given, restrict the run to the question with
            this id; if no question matches, a message is printed and
            nothing is run.
    """
    logging.basicConfig(
        level=logging.WARNING,
        format="%(asctime)s %(levelname)s %(name)s -- %(message)s",
    )
    logger.setLevel(logging.INFO)

    fixture = buildTrusteeBenchmarkFixture()
    questionsPath = _GATEWAY_DIR / "tests" / "fixtures" / "trusteeBenchmark" / "questions.yaml"
    allQuestions = _loadQuestions(questionsPath)
    if onlyQuestionId:
        allQuestions = [q for q in allQuestions if q.id == onlyQuestionId]
        if not allQuestions:
            print(f"No question matches id={onlyQuestionId!r}")
            return

    print(f"Loaded {len(allQuestions)} questions, {len(modesToRun)} modes -> {len(allQuestions) * len(modesToRun)} sub-agent runs.")

    services, userId, mandateId = _bootstrapServices()
    print(f"Bootstrap OK: user={userId}, mandate={mandateId}")

    reports: List[_ModeReport] = []
    for modeName in modesToRun:
        mode = _MODES[modeName]
        _applyEnvForMode(mode)
        rep = _ModeReport(mode=mode)
        print(f"\n=== Mode: {mode.label} ===")
        for idx, question in enumerate(allQuestions, start=1):
            print(f" [{idx:>2}/{len(allQuestions)}] {question.id}: {question.question[:80]} ...", flush=True)
            run = await _runOneQuestion(
                services=services,
                userId=userId,
                mandateId=mandateId,
                fixture=fixture,
                question=question,
                mode=mode,
            )
            score = _scoreRun(question, run)
            rep.perQuestion.append((question, run, score))
            print(
                f" -> acc={_yn(score.accuracyOk)} "
                f"pattern={_yn(score.patternOk)} forbid={_yn(score.forbidOk)} "
                f"numeric={_yn(score.numericOk)} rounds={run.rounds} cost={run.costCHF:.4f} "
                f"val-fail={run.validationFailures} repairs={run.repairAttempts}/{run.successAfterRepair}",
                flush=True,
            )
        reports.append(rep)

    timestamp = time.strftime("%Y%m%d-%H%M%S")
    outDir = _GATEWAY_DIR.parent / "local" / "notes"
    reportPath = outDir / f"trustee-benchmark-{timestamp}.md"
    # _writeReport also creates outDir, which the JSON dump below relies on.
    _writeReport(reports, reportPath)

    rawJsonPath = outDir / f"trustee-benchmark-{timestamp}.json"
    rawJsonPath.write_text(
        json.dumps(
            [
                {
                    "mode": rep.mode.name,
                    "accuracy": rep.accuracy,
                    "patternCompliance": rep.patternCompliance,
                    "repairConversionRate": rep.repairConversionRate,
                    "totalCostCHF": rep.totalCostCHF,
                    "totalRounds": rep.totalRounds,
                    "totalValidationFailures": rep.totalValidationFailures,
                    "items": [
                        {
                            "questionId": q.id,
                            "intent": q.intent,
                            "accuracyOk": s.accuracyOk,
                            "patternOk": s.patternOk,
                            "forbidOk": s.forbidOk,
                            "numericOk": s.numericOk,
                            "rounds": r.rounds,
                            "validationFailures": r.validationFailures,
                            "repairAttempts": r.repairAttempts,
                            "successAfterRepair": r.successAfterRepair,
                            "costCHF": r.costCHF,
                            "durationS": r.durationS,
                            "finalText": (r.finalText or "")[:600],
                            "toolCalls": r.toolCalls,
                            "error": r.error,
                        }
                        for q, r, s in rep.perQuestion
                    ],
                }
                for rep in reports
            ],
            indent=2,
            ensure_ascii=False,
        ),
        encoding="utf-8",
    )

    print(f"\nReport written: {reportPath}")
    print(f"Raw JSON: {rawJsonPath}")
    for rep in reports:
        print(f" {rep.mode.label}: acc={rep.accuracy:.1%} pattern={rep.patternCompliance:.1%} cost={rep.totalCostCHF:.4f}")


def _parseArgs(argv: List[str]) -> Tuple[List[str], Optional[str]]:
    """Parse CLI arguments into ``(modes, onlyQuestionId)``.

    Accepts any number of mode names from ``_MODES`` and an optional
    ``--only=<questionId>``. Any other argument prints a message and exits
    with status 2. With no mode given, all three modes are returned.
    """
    modes: List[str] = []
    only: Optional[str] = None
    for arg in argv:
        if arg.startswith("--only="):
            only = arg.split("=", 1)[1]
        elif arg in _MODES:
            modes.append(arg)
        else:
            print(f"Unknown argument: {arg!r}. Allowed modes: {list(_MODES)}")
            sys.exit(2)
    if not modes:
        modes = ["baseline", "phase1", "phase2"]
    return modes, only


def main() -> None:
    """CLI entry point: parse ``sys.argv`` and run the benchmark."""
    modes, only = _parseArgs(sys.argv[1:])
    asyncio.run(_runMain(modes, onlyQuestionId=only))


if __name__ == "__main__":
    main()