735 lines
26 KiB
Python
735 lines
26 KiB
Python
# Copyright (c) 2026 Patrick Motsch
|
|
# All rights reserved.
|
|
"""Trustee Sub-Agent Eval Harness (Phase 1.5).
|
|
|
|
Standalone runner that fires real AI calls against the Feature Data
|
|
Sub-Agent in three configurations:
|
|
|
|
* ``baseline`` -- production code without the pre-execute validator
|
|
(Repair-Loop disabled, Trustee domain hints active).
|
|
* ``phase1`` -- pre-execute validator on (Repair-Loop active),
|
|
domain hints active, no ontology yet.
|
|
* ``phase2`` -- validator on, ontology-driven schema context +
|
|
constraints (replaces hand-written domain hints).
|
|
|
|
For each mode we run all 19 gold-standard questions against an
|
|
in-memory :class:`FakeFeatureDataProvider`, capture the agent's tool
|
|
calls and final answer, score them against the gold standard, and
|
|
write a Markdown report and a raw JSON dump to ``local/notes/`` (next to ``gateway/``) for analysis.
|
|
|
|
Usage::
|
|
|
|
cd gateway
|
|
python -m tests.eval.runTrusteeBenchmark # all 3 modes
|
|
python -m tests.eval.runTrusteeBenchmark phase1 # one mode only
|
|
python -m tests.eval.runTrusteeBenchmark baseline phase1
python -m tests.eval.runTrusteeBenchmark phase1 --only=QUESTION_ID   # single question
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import sys
|
|
import time
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Path setup so `python -m tests.eval.runTrusteeBenchmark` works from gateway/
|
|
# ---------------------------------------------------------------------------
|
|
# The gateway directory is the third parent of this file
# (gateway/tests/eval/<this file>). It is prepended to sys.path only when it
# is not already present, so the project-local ``modules``/``tests`` imports
# below resolve when the script is launched from gateway/.
_GATEWAY_DIR = Path(__file__).resolve().parents[2]
if str(_GATEWAY_DIR) not in sys.path:
    sys.path.insert(0, str(_GATEWAY_DIR))
|
|
|
|
import yaml # noqa: E402
|
|
|
|
from modules.serviceCenter.services.serviceAgent.datamodelAgent import ( # noqa: E402
|
|
AgentConfig,
|
|
AgentEventTypeEnum,
|
|
)
|
|
from modules.datamodels.datamodelAi import ( # noqa: E402
|
|
AiCallRequest,
|
|
AiCallResponse,
|
|
OperationTypeEnum,
|
|
)
|
|
from modules.serviceCenter.services.serviceAgent.agentLoop import runAgentLoop # noqa: E402
|
|
from modules.serviceCenter.services.serviceAgent.featureDataAgent import ( # noqa: E402
|
|
_buildSubAgentTools,
|
|
_buildSchemaContext,
|
|
)
|
|
from modules.serviceCenter.services.serviceAgent.datamodelOntology import ( # noqa: E402
|
|
QueryValidationError,
|
|
)
|
|
from modules.serviceCenter.services.serviceAgent.queryValidator import ( # noqa: E402
|
|
QueryValidator,
|
|
)
|
|
|
|
from tests.eval.fakeFeatureDataProvider import ( # noqa: E402
|
|
FakeFeatureDataProvider,
|
|
)
|
|
from tests.fixtures.trusteeBenchmark.loadTrusteeBenchmarkFixture import ( # noqa: E402
|
|
buildTrusteeBenchmarkFixture,
|
|
BenchmarkFixture,
|
|
)
|
|
|
|
|
|
# Harness logger; _runMain raises it to INFO while basicConfig keeps the
# root logger at WARNING.
logger = logging.getLogger("trusteeBenchmark")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# NoOpValidator -- baseline mode (Repair-Loop OFF)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class _NoOpValidator(QueryValidator):
    """Accept-everything validator used for the baseline measurement.

    All three ``validate*`` hooks share one implementation that returns
    ``None`` for every table and argument set, so no query is ever
    rejected and the Repair-Loop is effectively disabled.
    """

    def _acceptAnything(self, tableName, args):  # noqa: ARG002
        # Nothing is inspected; every query passes.
        return None

    validateBrowseQuery = _acceptAnything
    validateQueryTable = _acceptAnything
    validateAggregateQuery = _acceptAnything
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Mode-specific tool/prompt building
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@dataclass
class _ModeConfig:
    """Static description of one benchmark mode.

    ``useValidator`` toggles the pre-execute validator (Repair-Loop);
    ``useOntology`` switches both the validator and the production prompt
    builder to the ontology-driven path.
    """

    name: str  # key used on the command line and in ``_MODES``
    label: str  # human-readable title for console output and reports
    useValidator: bool
    useOntology: bool


# Registry of all runnable modes, keyed by ``_ModeConfig.name``; insertion
# order (baseline, phase1, phase2) is the default run order.
_MODES: Dict[str, _ModeConfig] = {
    cfg.name: cfg
    for cfg in (
        _ModeConfig(name="baseline", label="Baseline (no validator)", useValidator=False, useOntology=False),
        _ModeConfig(name="phase1", label="Phase 1 (validator on)", useValidator=True, useOntology=False),
        _ModeConfig(name="phase2", label="Phase 2 (validator + ontology)", useValidator=True, useOntology=True),
    )
}
|
|
|
|
|
|
def _buildValidator(mode: _ModeConfig) -> QueryValidator:
    """Return the validator instance for *mode*.

    * Validator off (baseline): :class:`_NoOpValidator`, so nothing is
      rejected and the raw LLM accuracy is measured.
    * Validator on, no ontology (phase1): a plain convention-based
      :class:`QueryValidator` (e.g. NEVER_AGGREGATE on ``*Balance`` /
      ``*Total`` suffixes).
    * Validator + ontology (phase2): a :class:`QueryValidator` built from
      the trustee ontology. If importing or constructing it fails, a
      warning is logged and the plain convention-based validator is used.
    """
    if not mode.useValidator:
        return _NoOpValidator()
    if not mode.useOntology:
        return QueryValidator()

    try:
        from modules.features.trustee.trusteeOntology import getTrusteeOntology

        return QueryValidator(ontology=getTrusteeOntology())
    except Exception as e:
        # Best effort: keep the benchmark running on the convention rules.
        logger.warning("Could not load trustee ontology, falling back: %s", e)
    return QueryValidator()
|
|
|
|
|
|
def _applyEnvForMode(mode: _ModeConfig) -> None:
    """Set or clear the ``POWERON_DISABLE_FEATURE_ONTOLOGY`` environment flag.

    Modes with ``useOntology`` remove the flag, so the production prompt
    builder (``featureDataAgent._buildSchemaContext``) pulls its block from
    ``getAgentOntology()``. All other modes set the flag to ``"1"`` so the
    builder falls back to the legacy ``getAgentDomainHints()`` block, which
    is the prompt that ships today. The change is process-wide and is not
    reverted here.
    """
    flagName = "POWERON_DISABLE_FEATURE_ONTOLOGY"
    if not mode.useOntology:
        os.environ[flagName] = "1"
        return
    os.environ.pop(flagName, None)
|
|
|
|
|
|
def _buildSystemPrompt(featureCode: str, instanceLabel: str, selectedTables: List[Dict[str, Any]]) -> str:
    """Return the sub-agent system prompt exactly as production builds it.

    There is no mode-specific branching here. Whether the prompt carries
    the legacy domain hints or the ontology block is decided inside
    ``_buildSchemaContext``, based on the ``POWERON_DISABLE_FEATURE_ONTOLOGY``
    flag that :func:`_applyEnvForMode` sets per mode. Because every mode
    uses the same builder call, the benchmark has no eval-only prompt fork.
    The request language is always German (``"de"``).
    """
    fixedLanguage = "de"
    prompt = _buildSchemaContext(featureCode, instanceLabel, selectedTables, requestLang=fixedLanguage)
    return prompt
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Question loading + per-question evaluation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@dataclass
class _Question:
    """One gold-standard benchmark question as loaded from ``questions.yaml``."""

    # Question id; matched against ``--only=`` and embedded in workflow ids / reports.
    id: str
    # Prompt text sent to the sub-agent.
    question: str
    # Free-text label shown in the failure notes and the JSON dump.
    intent: str
    # Tool names that count as a pattern match (empty list = pattern not checked).
    expectedTools: List[str]
    # Required ``tableName`` argument of a matching call (None = any table).
    expectedTable: Optional[str]
    # Aggregate a matching ``aggregateTable`` call must use (compared case-insensitively).
    expectedAggregate: Optional[str]
    # Required ``field`` argument of that aggregate call (None = any field).
    expectedAggregateField: Optional[str]
    # Field -> value filters that must be present; keys ending in "Like" expect a LIKE/ILIKE filter.
    requiredFilters: Dict[str, Any]
    # Tool names that must not be executed successfully.
    forbiddenTools: List[str]
    # Numbers that must appear in the final answer (within tolerance).
    expectedNumbers: List[float]
    # Substrings that must appear in the final answer (case-insensitive).
    expectedAnswerContains: List[str]
    # Relative tolerance for expectedNumbers; the effective tolerance never drops below 0.5.
    numericTolerance: float
|
|
|
|
|
|
def _loadQuestions(yamlPath: Path) -> List[_Question]:
    """Load the gold-standard questions from a YAML file.

    The document must be a list of mappings. ``id`` and ``question`` are
    required. Every other key is optional and falls back to an empty value
    or ``None``, except ``numericTolerance``, which defaults to ``0.005``.

    ``id`` is coerced to ``str`` so numeric YAML ids (``id: 7``) still match
    the string given via ``--only=``. An explicit ``numericTolerance: 0`` is
    honoured rather than being replaced by the default. An empty file
    yields an empty list.

    Args:
        yamlPath: path to the questions file (read as UTF-8).

    Returns:
        The questions in file order.

    Raises:
        KeyError: if an entry lacks ``id`` or ``question``.
        ValueError: if the document is not a list.
    """
    with open(yamlPath, "r", encoding="utf-8") as f:
        rawList = yaml.safe_load(f)
    if rawList is None:
        return []
    if not isinstance(rawList, list):
        raise ValueError(
            f"{yamlPath}: expected a YAML list of questions, got {type(rawList).__name__}"
        )

    questions: List[_Question] = []
    for raw in rawList:
        # `or`-defaulting would turn an explicit 0 into the default.
        tolerance = raw.get("numericTolerance")
        questions.append(_Question(
            id=str(raw["id"]),
            question=raw["question"],
            intent=raw.get("intent") or "",
            expectedTools=list(raw.get("expectedTools") or []),
            expectedTable=raw.get("expectedTable"),
            expectedAggregate=raw.get("expectedAggregate"),
            expectedAggregateField=raw.get("expectedAggregateField"),
            requiredFilters=dict(raw.get("requiredFilters") or {}),
            forbiddenTools=list(raw.get("forbiddenTools") or []),
            expectedNumbers=[float(x) for x in (raw.get("expectedNumbers") or [])],
            expectedAnswerContains=[str(x) for x in (raw.get("expectedAnswerContains") or [])],
            numericTolerance=0.005 if tolerance is None else float(tolerance),
        ))
    return questions
|
|
|
|
|
|
@dataclass
class _RunResult:
    """Everything captured from one sub-agent run (one question under one mode)."""

    questionId: str
    finalText: str
    toolCalls: List[Dict[str, Any]] = field(default_factory=list)  # TOOL_CALL event payloads
    toolResults: List[Dict[str, Any]] = field(default_factory=list)  # TOOL_RESULT event payloads
    summary: Dict[str, Any] = field(default_factory=dict)  # AGENT_SUMMARY event payload
    durationS: float = 0.0
    error: Optional[str] = None

    def _summaryNumber(self, key: str, cast):
        """Read ``summary[key]`` through *cast*; missing or falsy values count as 0."""
        rawValue = self.summary.get(key)
        return cast(rawValue or 0)

    @property
    def costCHF(self) -> float:
        return self._summaryNumber("costCHF", float)

    @property
    def rounds(self) -> int:
        return self._summaryNumber("rounds", int)

    @property
    def validationFailures(self) -> int:
        return self._summaryNumber("validationFailures", int)

    @property
    def repairAttempts(self) -> int:
        return self._summaryNumber("repairAttempts", int)

    @property
    def successAfterRepair(self) -> int:
        return self._summaryNumber("successAfterRepair", int)
|
|
|
|
|
|
@dataclass
class _Score:
    """Pass/fail verdicts for one run; every check defaults to False."""

    patternOk: bool = False  # expected tool/table/aggregate/filters were used (_checkPattern)
    forbidOk: bool = False  # no forbidden tool succeeded (_checkForbid)
    numericOk: bool = False  # answer contains the expected numbers/substrings (_checkNumeric)
    accuracyOk: bool = False  # set by _scoreRun to the conjunction of the three checks
    notes: List[str] = field(default_factory=list)  # free-text remarks listed in the report
|
|
|
|
|
|
def _scoreRun(question: _Question, run: _RunResult) -> _Score:
    """Grade one run against its gold-standard question.

    A run that ended with an error keeps the all-``False`` default score and
    gets a note carrying the error text. Otherwise the three independent
    checks are evaluated, and ``accuracyOk`` is true only when all pass.
    """
    result = _Score()
    if not run.error:
        result.patternOk = _checkPattern(question, run)
        result.forbidOk = _checkForbid(question, run)
        result.numericOk = _checkNumeric(question, run)
        result.accuracyOk = all((result.patternOk, result.forbidOk, result.numericOk))
        return result

    result.notes.append(f"Sub-agent error: {run.error}")
    return result
|
|
|
|
|
|
def _checkPattern(question: _Question, run: _RunResult) -> bool:
    """Did the agent call one of the expected tools on the expected table with required filters?

    A tool call *matches* when its ``toolName`` is in ``expectedTools`` and,
    if ``expectedTable`` is set, its ``args["tableName"]`` equals it. Then:

    * with ``expectedAggregate``: some matching ``aggregateTable`` call must
      use that aggregate (case-insensitive), the ``expectedAggregateField``
      (if set), and satisfy ``requiredFilters``;
    * otherwise, with ``requiredFilters``: some matching call must satisfy them;
    * otherwise any matching call is enough.

    A call's filters are gathered from both its ``filters`` and
    ``extraFilters`` arguments. Calls whose ``args`` is missing, ``None`` or
    not a dict are treated as having no arguments. No expected tools means
    the check always passes.
    """
    if not question.expectedTools:
        return True

    matching: List[Tuple[Any, Dict[str, Any]]] = []
    for call in run.toolCalls:
        args = call.get("args")
        if not isinstance(args, dict):
            args = {}
        if call.get("toolName") not in question.expectedTools:
            continue
        if question.expectedTable and args.get("tableName") != question.expectedTable:
            continue
        matching.append((call.get("toolName"), args))
    if not matching:
        return False

    required = question.requiredFilters
    if question.expectedAggregate:
        wantAgg = question.expectedAggregate.upper()
        wantField = question.expectedAggregateField
        for toolName, args in matching:
            if toolName != "aggregateTable":
                continue
            if str(args.get("aggregate") or "").upper() != wantAgg:
                continue
            if wantField and args.get("field") != wantField:
                continue
            if _filtersSatisfied(required, _callFilters(args)):
                return True
        return False

    if required:
        return any(_filtersSatisfied(required, _callFilters(args)) for _, args in matching)
    return True


def _callFilters(args: Dict[str, Any]) -> List[Any]:
    """Return a tool call's filters, merging ``filters`` and ``extraFilters`` (non-lists ignored)."""
    merged: List[Any] = []
    for key in ("filters", "extraFilters"):
        value = args.get(key)
        if isinstance(value, (list, tuple)):
            merged.extend(value)
    return merged
|
|
|
|
|
|
def _filtersSatisfied(required: Dict[str, Any], actualFilters: List[Dict[str, Any]]) -> bool:
    """Check that every required filter appears among a tool call's filters.

    *required* maps a field name to an expected value.

    * A key ending in ``Like`` (e.g. ``nameLike``) requires a filter on the
      field without that suffix, using a ``LIKE``/``ILIKE`` operator
      (case-insensitive) whose value equals the expected pattern as a string.
    * Any other key requires a filter on that field whose value matches per
      :func:`_filterValueEqual`.

    Entries of *actualFilters* that are not dicts are ignored, and a
    ``None`` list counts as empty. A missing or ``None`` operator never
    matches ``LIKE``. Returns True when *required* is empty.
    """
    if not required:
        return True
    usable = [f for f in (actualFilters or []) if isinstance(f, dict)]

    for reqField, reqValue in required.items():
        if reqField.endswith("Like"):
            fieldName = reqField[:-4]
            wanted = str(reqValue)
            ok = any(
                f.get("field") == fieldName
                and str(f.get("op") or "").upper() in ("LIKE", "ILIKE")
                and str(f.get("value")) == wanted
                for f in usable
            )
        else:
            # NOTE(review): the operator is not inspected here, so e.g. a "!="
            # filter on the right field/value also counts -- confirm intended.
            ok = any(
                f.get("field") == reqField and _filterValueEqual(f.get("value"), reqValue)
                for f in usable
            )
        if not ok:
            return False
    return True
|
|
|
|
|
|
def _filterValueEqual(a: Any, b: Any) -> bool:
    """Loosely compare a filter value with an expected value.

    The values are equal when ``a == b`` holds or, failing that, when their
    whitespace-stripped string forms match (so ``2024`` equals ``"2024"``).
    If converting either value to ``str`` raises, they are treated as
    different.
    """
    if a == b:
        return True
    try:
        left = str(a).strip()
        right = str(b).strip()
    except Exception:
        return False
    return left == right
|
|
|
|
|
|
def _checkForbid(question: _Question, run: _RunResult) -> bool:
    """Return True when no forbidden tool was executed successfully.

    Only tool results with a truthy ``success`` flag are considered. A call
    the validator rejected never went through, which means the Repair-Loop
    steered the agent away, so it does not count against the run. Only tool
    names are compared; arguments and operators are not inspected. No
    forbidden tools means the check always passes.
    """
    if not question.forbiddenTools:
        return True
    forbidden = frozenset(question.forbiddenTools)
    executedForbidden = (
        result.get("toolName") in forbidden
        for result in run.toolResults
        if result.get("success")
    )
    return not any(executedForbidden)
|
|
|
|
|
|
def _checkNumeric(question: _Question, run: _RunResult) -> bool:
    """Check the final answer text against expected numbers and substrings.

    Each number in ``expectedNumbers`` must match some number parsed from
    the answer by :func:`_extractNumbers`, within
    ``max(|expected| * numericTolerance, 0.5)``. Each string in
    ``expectedAnswerContains`` must occur in the answer case-insensitively.
    When both lists are empty, the check passes.
    """
    answer = run.finalText or ""

    if question.expectedNumbers:
        found = _extractNumbers(answer)
        for target in question.expectedNumbers:
            allowed = max(abs(target) * question.numericTolerance, 0.5)
            hit = any(abs(candidate - target) <= allowed for candidate in found)
            if not hit:
                return False

    needles = question.expectedAnswerContains
    if not needles:
        return True
    haystack = answer.lower()
    return all(needle.lower() in haystack for needle in needles)
|
|
|
|
|
|
def _extractNumbers(text: str) -> List[float]:
    """Pick out all numbers from a free-text answer.

    Handles the following formats:

    * Swiss thousand separators (apostrophe and U+2019).
    * German decimal commas.
    * German dot-grouped thousands (``1.234.567``).
    * Mixed separators (``1.234,50`` / ``1,234.50``; the last separator is
      the decimal mark).
    * Plain integers/floats.
    * The Unicode minus sign (U+2212).

    Trailing ``.``, ``,`` and ``;`` swallowed by the greedy token regex are
    stripped before parsing, so both ``"180500.0,"`` and an end-of-sentence
    ``"1.5."`` parse cleanly.

    Ambiguous inputs keep their historical reading. A single comma without
    a dot is a decimal comma (``12,5`` -> 12.5, ``1,234`` -> 1.234), and a
    single dot is a decimal point. Tokens that still do not parse (e.g.
    ``1.5.2`` or dates like ``31.12.2024``) are skipped. A hyphen directly
    before a digit is read as a minus sign.

    Args:
        text: answer text; ``None`` is treated as empty.

    Returns:
        The parsed numbers in order of appearance.
    """
    cleaned = (text or "").replace("\u2019", "'").replace("\u2212", "-")
    out: List[float] = []
    for tok in re.findall(r"-?\d[\d'.,]*", cleaned):
        # The greedy class also swallows sentence punctuation ("1.5." /
        # "1'234,"); no valid number ends in one of these characters.
        tok = tok.rstrip(".,;")
        norm = tok.replace("'", "")
        commas = norm.count(",")
        dots = norm.count(".")
        if commas == 1 and dots == 0:
            # German decimal comma: "12,5" -> 12.5
            norm = norm.replace(",", ".")
        elif commas >= 1 and dots >= 1:
            # Both separators present: whichever comes last is the decimal mark.
            if norm.rfind(",") > norm.rfind("."):
                norm = norm.replace(".", "").replace(",", ".")
            else:
                norm = norm.replace(",", "")
        elif dots >= 2 and re.fullmatch(r"-?\d{1,3}(?:\.\d{3})+", norm):
            # German thousands grouping: "1.234.567" -> 1234567
            norm = norm.replace(".", "")
        else:
            norm = norm.replace(",", "")
        try:
            out.append(float(norm))
        except ValueError:
            continue
    return out
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# AI call wiring
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _bootstrapServices() -> Tuple[Any, str, str]:
    """Create a service hub for the root user and the initial mandate.

    Returns:
        ``(services, userId, mandateId)``: the hub from
        ``modules.serviceHub.getInterface`` plus the user id and the mandate
        id it was built for (the mandate id is also used for billing).

    Raises:
        RuntimeError: if no initial mandate exists yet.
    """
    from modules.interfaces.interfaceDbApp import getRootInterface
    from modules.datamodels.datamodelUam import Mandate
    from modules.serviceHub import getInterface as getServices

    root = getRootInterface()
    rootUser = root.currentUser
    initialMandateId = root.getInitialId(Mandate)
    if not initialMandateId:
        raise RuntimeError("No initial mandate available -- run bootstrap loader first.")

    hub = getServices(rootUser, workflow=None, mandateId=initialMandateId, featureInstanceId=None)
    return hub, rootUser.id, initialMandateId
|
|
|
|
|
|
async def _runOneQuestion(
    *,
    services: Any,
    userId: str,
    mandateId: str,
    fixture: BenchmarkFixture,
    question: _Question,
    mode: _ModeConfig,
) -> _RunResult:
    """Execute a single sub-agent run for one question under one mode.

    Each call builds fresh components from the fixture: a
    :class:`FakeFeatureDataProvider` over the fixture rows, the mode's
    validator, the production tool registry and the production system
    prompt. It then drives :func:`runAgentLoop` with real AI calls made
    through ``services.ai.callAi``, configured with
    ``AgentConfig(maxRounds=6, maxCostCHF=0.50)``.

    Exceptions raised while iterating the loop are logged and stored in
    ``run.error`` rather than propagated. ``durationS`` is measured with a
    monotonic clock.

    Args:
        services: service hub from :func:`_bootstrapServices`.
        userId: user id passed to the agent loop.
        mandateId: mandate id passed to the agent loop.
        fixture: benchmark fixture (tables, rows, instance ids).
        question: the question to ask.
        mode: benchmark mode (validator / ontology switches).

    Returns:
        The collected :class:`_RunResult`.
    """
    provider = FakeFeatureDataProvider(
        rowsByTable=fixture.rowsByTable,
        availableTables=fixture.selectedTables,
    )
    validator = _buildValidator(mode)
    registry = _buildSubAgentTools(
        provider=provider,
        featureInstanceId=fixture.featureInstanceId,
        mandateId=fixture.mandateId,
        tableFilters={},
        validator=validator,
    )
    systemPrompt = _buildSystemPrompt(
        featureCode="trustee",
        instanceLabel="Demo AG",
        selectedTables=fixture.selectedTables,
    )

    # Running total of the reported AI price for this run; handed to the
    # agent loop through _getCost as getWorkflowCostFn.
    cost = 0.0

    async def _aiCallFn(req: AiCallRequest) -> AiCallResponse:
        nonlocal cost
        resp = await services.ai.callAi(req)
        # Responses without a (truthy) priceCHF attribute count as free.
        cost += float(getattr(resp, "priceCHF", 0.0) or 0.0)
        return resp

    async def _getCost() -> float:
        return cost

    config = AgentConfig(
        maxRounds=6,
        maxCostCHF=0.50,
        operationType=OperationTypeEnum.DATA_QUERY,
    )

    run = _RunResult(questionId=question.id, finalText="")
    startedAt = time.perf_counter()
    try:
        async for event in runAgentLoop(
            prompt=question.question,
            toolRegistry=registry,
            config=config,
            aiCallFn=_aiCallFn,
            getWorkflowCostFn=_getCost,
            workflowId=f"eval-{mode.name}-{question.id}-{uuid.uuid4().hex[:6]}",
            userId=userId,
            featureInstanceId=fixture.featureInstanceId,
            mandateId=mandateId,
            systemPromptOverride=systemPrompt,
        ):
            _applyAgentEvent(run, event)
    except Exception as e:
        run.error = f"{type(e).__name__}: {e}"
        logger.exception("Sub-agent run failed for %s/%s", mode.name, question.id)
    run.durationS = time.perf_counter() - startedAt
    return run


def _applyAgentEvent(run: _RunResult, event: Any) -> None:
    """Fold one agent-loop event into *run* (answer text, tool calls/results, summary, errors)."""
    eventType = event.type
    if eventType == AgentEventTypeEnum.FINAL:
        # A non-empty FINAL replaces streamed MESSAGE text; an empty one keeps it.
        run.finalText = event.content or run.finalText
    elif eventType == AgentEventTypeEnum.MESSAGE and event.content:
        run.finalText += event.content
    elif eventType == AgentEventTypeEnum.TOOL_CALL:
        run.toolCalls.append(dict(event.data or {}))
    elif eventType == AgentEventTypeEnum.TOOL_RESULT:
        run.toolResults.append(dict(event.data or {}))
    elif eventType == AgentEventTypeEnum.AGENT_SUMMARY:
        run.summary = dict(event.data or {})
    elif eventType == AgentEventTypeEnum.ERROR:
        # Even a message-less ERROR must mark the run as failed, and several
        # errors must stay readable, hence the placeholder and separator.
        message = event.content or "agent loop emitted an ERROR event without message"
        run.error = f"{run.error}; {message}" if run.error else message
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Report
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@dataclass
class _ModeReport:
    """All per-question results of one mode plus the derived summary metrics."""

    mode: _ModeConfig
    perQuestion: List[Tuple[_Question, _RunResult, _Score]] = field(default_factory=list)

    @property
    def total(self) -> int:
        """Number of questions run in this mode."""
        return len(self.perQuestion)

    def _count(self, attr: str) -> int:
        """Number of scores whose attribute *attr* is truthy."""
        return len([score for _q, _r, score in self.perQuestion if getattr(score, attr)])

    def _runs(self) -> List[_RunResult]:
        """The run results in question order."""
        return [run for _q, run, _s in self.perQuestion]

    @property
    def accuracy(self) -> float:
        """Share of fully accurate runs (0.0 when nothing was run)."""
        return self._count("accuracyOk") / (self.total or 1)

    @property
    def patternCompliance(self) -> float:
        """Share of runs that passed the pattern check (0.0 when nothing was run)."""
        return self._count("patternOk") / (self.total or 1)

    @property
    def repairConversionRate(self) -> float:
        """Successful repairs per repair attempt, summed over all runs (0.0 without attempts)."""
        runs = self._runs()
        attempts = sum(r.repairAttempts for r in runs)
        if not attempts:
            return 0.0
        return sum(r.successAfterRepair for r in runs) / attempts

    @property
    def totalCostCHF(self) -> float:
        return sum(r.costCHF for r in self._runs())

    @property
    def totalRounds(self) -> int:
        return sum(r.rounds for r in self._runs())

    @property
    def totalValidationFailures(self) -> int:
        return sum(r.validationFailures for r in self._runs())
|
|
|
|
|
|
def _writeReport(reports: List[_ModeReport], outputPath: Path) -> None:
    """Render all mode reports as one Markdown file at *outputPath*.

    The report contains a summary table with one row per mode, then for
    each mode a per-question table followed by a "Notes & failures" list
    covering every question whose run was not fully accurate. Free text
    (answers, errors, tool labels) is flattened and sanitised so it cannot
    break the tables or inline-code spans. A run with an empty answer no
    longer aborts report writing. Missing parent directories are created.
    """
    lines: List[str] = [
        "# Trustee Sub-Agent Benchmark Report",
        "",
        f"Generated: {time.strftime('%Y-%m-%d %H:%M:%S')}",
        "",
        "## Summary",
        "",
        "| Mode | Questions | Accuracy | Pattern compliance | Repair conversion | Validator rejects | Rounds | Cost (CHF) |",
        "|---|---|---|---|---|---|---|---|",
    ]
    for rep in reports:
        lines.append(
            f"| {rep.mode.label} | {rep.total} | {rep.accuracy:.1%} | {rep.patternCompliance:.1%} | "
            f"{rep.repairConversionRate:.1%} | {rep.totalValidationFailures} | {rep.totalRounds} | "
            f"{rep.totalCostCHF:.4f} |"
        )

    lines.append("")
    lines.append("## Per-question detail")
    for rep in reports:
        lines.extend([
            "",
            f"### {rep.mode.label}",
            "",
            "| id | acc | pattern | forbid | numeric | rounds | val-fail | repairs | cost CHF | duration | tools |",
            "|---|---|---|---|---|---|---|---|---|---|---|",
        ])
        for q, r, s in rep.perQuestion:
            toolList = _mdInline(",".join(_toolCallLabel(c) for c in r.toolCalls))
            lines.append(
                f"| {q.id} | {_yn(s.accuracyOk)} | {_yn(s.patternOk)} | {_yn(s.forbidOk)} | {_yn(s.numericOk)} | "
                f"{r.rounds} | {r.validationFailures} | {r.repairAttempts}/{r.successAfterRepair} | "
                f"{r.costCHF:.4f} | {r.durationS:.1f}s | {toolList} |"
            )

        lines.append("")
        lines.append("#### Notes & failures")
        for q, r, s in rep.perQuestion:
            if s.accuracyOk:
                continue
            lines.append(f"- **{q.id}** ({q.intent}): pattern={s.patternOk} forbid={s.forbidOk} numeric={s.numericOk}")
            # Two-space indent nests the details under the question's list item.
            if r.error:
                lines.append(f"  - error: `{_mdInline(r.error)}`")
            lines.append(f"  - answer: `{_mdInline(_firstLine(r.finalText))[:240]}`")
            lines.extend(f"  - note: {note}" for note in s.notes)

    outputPath.parent.mkdir(parents=True, exist_ok=True)
    outputPath.write_text("\n".join(lines) + "\n", encoding="utf-8")


def _firstLine(text: Optional[str]) -> str:
    """First line of *text* after stripping surrounding whitespace ("" for empty/None)."""
    textLines = (text or "").strip().splitlines()
    return textLines[0] if textLines else ""


def _mdInline(text: Optional[str]) -> str:
    """Flatten *text* to one line that is safe inside a Markdown table cell or inline-code span."""
    flat = " ".join((text or "").splitlines())
    return flat.replace("|", "/").replace("`", "'")


def _toolCallLabel(call: Dict[str, Any]) -> str:
    """Render a tool call as ``toolName(tableName)``; ``?`` when it has no table argument."""
    args = call.get("args")
    tableName = args.get("tableName", "?") if isinstance(args, dict) else "?"
    return f"{call.get('toolName')}({tableName})"


def _yn(b: bool) -> str:
    """Render a check result as ``OK`` / ``FAIL``."""
    return "OK" if b else "FAIL"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Main entry point
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def _runMain(modesToRun: List[str], onlyQuestionId: Optional[str] = None) -> None:
    """Run every requested mode over the question set and write the reports.

    Loads the fixture and ``questions.yaml``, bootstraps the real service
    hub, then runs each question once per mode sequentially, printing
    progress to stdout. Finally it writes ``trustee-benchmark-TIMESTAMP.md``
    and ``.json`` into ``local/notes/`` next to the gateway directory.

    Args:
        modesToRun: keys of ``_MODES``, processed in the given order.
        onlyQuestionId: restrict the run to the question with this id. If no
            question matches, a message is printed and nothing is run.
    """
    logging.basicConfig(
        level=logging.WARNING,
        format="%(asctime)s %(levelname)s %(name)s -- %(message)s",
    )
    logger.setLevel(logging.INFO)

    fixture = buildTrusteeBenchmarkFixture()
    questionsPath = _GATEWAY_DIR / "tests" / "fixtures" / "trusteeBenchmark" / "questions.yaml"
    allQuestions = _loadQuestions(questionsPath)
    if onlyQuestionId:
        allQuestions = [q for q in allQuestions if q.id == onlyQuestionId]
        if not allQuestions:
            print(f"No question matches id={onlyQuestionId!r}")
            return

    print(f"Loaded {len(allQuestions)} questions, {len(modesToRun)} modes -> {len(allQuestions) * len(modesToRun)} sub-agent runs.")

    services, userId, mandateId = _bootstrapServices()
    print(f"Bootstrap OK: user={userId}, mandate={mandateId}")

    reports: List[_ModeReport] = []
    for modeName in modesToRun:
        mode = _MODES[modeName]
        # The env flag must be set before any prompt of this mode is built.
        _applyEnvForMode(mode)
        rep = _ModeReport(mode=mode)
        print(f"\n=== Mode: {mode.label} ===")
        for idx, question in enumerate(allQuestions, start=1):
            print(f" [{idx:>2}/{len(allQuestions)}] {question.id}: {question.question[:80]} ...", flush=True)
            run = await _runOneQuestion(
                services=services,
                userId=userId,
                mandateId=mandateId,
                fixture=fixture,
                question=question,
                mode=mode,
            )
            score = _scoreRun(question, run)
            rep.perQuestion.append((question, run, score))
            print(
                f" -> acc={_yn(score.accuracyOk)} "
                f"pattern={_yn(score.patternOk)} forbid={_yn(score.forbidOk)} "
                f"numeric={_yn(score.numericOk)} rounds={run.rounds} cost={run.costCHF:.4f} "
                f"val-fail={run.validationFailures} repairs={run.repairAttempts}/{run.successAfterRepair}",
                flush=True,
            )
        reports.append(rep)

    timestamp = time.strftime("%Y%m%d-%H%M%S")
    outDir = _GATEWAY_DIR.parent / "local" / "notes"
    outDir.mkdir(parents=True, exist_ok=True)
    reportPath = outDir / f"trustee-benchmark-{timestamp}.md"
    _writeReport(reports, reportPath)

    rawJsonPath = outDir / f"trustee-benchmark-{timestamp}.json"
    rawJsonPath.write_text(
        json.dumps(
            [_reportToJson(rep) for rep in reports],
            indent=2,
            ensure_ascii=False,
            # Tool-call args come straight from agent events. Stringify
            # anything JSON cannot encode rather than losing the dump after
            # the Markdown report has already been written.
            default=str,
        ),
        encoding="utf-8",
    )

    print(f"\nReport written: {reportPath}")
    print(f"Raw JSON: {rawJsonPath}")
    for rep in reports:
        print(f" {rep.mode.label}: acc={rep.accuracy:.1%} pattern={rep.patternCompliance:.1%} cost={rep.totalCostCHF:.4f}")


def _reportToJson(rep: _ModeReport) -> Dict[str, Any]:
    """Raw dump of one mode: aggregate metrics plus one item per question (answer truncated to 600 chars)."""
    return {
        "mode": rep.mode.name,
        "accuracy": rep.accuracy,
        "patternCompliance": rep.patternCompliance,
        "repairConversionRate": rep.repairConversionRate,
        "totalCostCHF": rep.totalCostCHF,
        "totalRounds": rep.totalRounds,
        "totalValidationFailures": rep.totalValidationFailures,
        "items": [
            {
                "questionId": q.id,
                "intent": q.intent,
                "accuracyOk": s.accuracyOk,
                "patternOk": s.patternOk,
                "forbidOk": s.forbidOk,
                "numericOk": s.numericOk,
                "rounds": r.rounds,
                "validationFailures": r.validationFailures,
                "repairAttempts": r.repairAttempts,
                "successAfterRepair": r.successAfterRepair,
                "costCHF": r.costCHF,
                "durationS": r.durationS,
                "finalText": (r.finalText or "")[:600],
                "toolCalls": r.toolCalls,
                "error": r.error,
            }
            for q, r, s in rep.perQuestion
        ],
    }
|
|
|
|
|
|
def _parseArgs(argv: List[str]) -> Tuple[List[str], Optional[str]]:
    """Parse command-line arguments into ``(modes, onlyQuestionId)``.

    * Positional arguments must be keys of ``_MODES``. Repeats are ignored
      (first occurrence wins), so no mode runs twice.
    * ``--only=ID`` restricts the run to one question; if given more than
      once, the last occurrence wins.
    * Without positional modes, all three modes run in the default order.
    * Any other argument prints an error to stderr and exits with status 2.
    """
    modes: List[str] = []
    only: Optional[str] = None
    for arg in argv:
        if arg.startswith("--only="):
            only = arg.split("=", 1)[1]
        elif arg in _MODES:
            if arg not in modes:
                modes.append(arg)
        else:
            print(f"Unknown argument: {arg!r}. Allowed modes: {list(_MODES)}", file=sys.stderr)
            sys.exit(2)
    if not modes:
        modes = ["baseline", "phase1", "phase2"]
    return modes, only
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point: parse ``sys.argv`` and run the benchmark to completion."""
    selectedModes, questionFilter = _parseArgs(sys.argv[1:])
    asyncio.run(_runMain(selectedModes, onlyQuestionId=questionFilter))


if __name__ == "__main__":
    # Supports ``python -m tests.eval.runTrusteeBenchmark [modes] [--only=ID]``.
    main()
|