Merge pull request #164 from valueonag/feat/demo-system-readieness

rag enhancements
This commit is contained in:
Patrick Motsch 2026-05-16 23:02:23 +02:00 committed by GitHub
commit f5aba4bf99
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
41 changed files with 4809 additions and 398 deletions

5
app.py
View file

@ -404,8 +404,10 @@ async def lifespan(app: FastAPI):
try:
from modules.serviceCenter.services.serviceBackgroundJobs.mainBackgroundJobService import (
recoverInterruptedJobs,
registerZombieKillerScheduler,
)
recoverInterruptedJobs()
registerZombieKillerScheduler(intervalMinutes=5)
except Exception as e:
logger.warning(f"BackgroundJob recovery failed (non-critical): {e}")
@ -607,6 +609,9 @@ app.include_router(connectionsRouter)
from modules.routes.routeRagInventory import router as ragInventoryRouter
app.include_router(ragInventoryRouter)
from modules.routes.routeAdminSttBenchmark import router as sttBenchmarkRouter
app.include_router(sttBenchmarkRouter)
from modules.routes.routeTableViews import router as tableViewsRouter
app.include_router(tableViewsRouter)

View file

@ -319,25 +319,24 @@ class AiOpenai(BaseConnectorAi):
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.00013
),
AiModel(
name="dall-e-3",
displayName="OpenAI DALL-E 3",
name="gpt-image-1",
displayName="OpenAI GPT Image",
connectorType="openai",
apiUrl="https://api.openai.com/v1/images/generations",
temperature=0.0, # Image generation doesn't use temperature
maxTokens=0, # Image generation doesn't use tokens
temperature=0.0,
maxTokens=0,
contextLength=0,
costPer1kTokensInput=0.04,
costPer1kTokensOutput=0.0,
speedRating=5, # Slow for image generation
qualityRating=9, # High quality art generation
# capabilities removed (not used in business logic)
speedRating=5,
qualityRating=9,
functionCall=self.generateImage,
priority=PriorityEnum.QUALITY,
processingMode=ProcessingModeEnum.DETAILED,
operationTypes=createOperationTypeRatings(
(OperationTypeEnum.IMAGE_GENERATE, 10)
),
version="dall-e-3",
version="gpt-image-1",
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.04
)
]
@ -653,105 +652,82 @@ class AiOpenai(BaseConnectorAi):
)
async def generateImage(self, modelCall: AiModelCall) -> AiModelResponse:
"""
Generate an image using DALL-E 3 using standardized pattern.
Args:
modelCall: AiModelCall with messages and generation options
Returns:
AiModelResponse with generated image data
"""
"""Generate an image using GPT Image model (gpt-image-1)."""
try:
# Extract parameters from modelCall
messages = modelCall.messages
model = modelCall.model
options = modelCall.options
# Get prompt from messages
promptContent = messages[0]["content"] if messages else ""
# Parse prompt using AiCallPromptImage model
import json
messages = modelCall.messages
options = modelCall.options
promptContent = messages[0]["content"] if messages else ""
try:
# Try to parse as JSON
promptData = json.loads(promptContent)
promptModel = AiCallPromptImage(**promptData)
except:
# If not JSON, use plain text prompt
except Exception:
promptModel = AiCallPromptImage(
prompt=promptContent,
size=options.size if options and hasattr(options, 'size') else "1024x1024",
quality=options.quality if options and hasattr(options, 'quality') else "standard",
style=options.style if options and hasattr(options, 'style') else "vivid"
size=options.size if options and hasattr(options, "size") else "1024x1024",
quality=options.quality if options and hasattr(options, "quality") else "auto",
)
# Extract parameters from Pydantic model
prompt = promptModel.prompt
size = promptModel.size or "1024x1024"
quality = promptModel.quality or "standard"
style = promptModel.style or "vivid"
rawQuality = promptModel.quality or "auto"
quality = {"standard": "auto", "hd": "high"}.get(rawQuality, rawQuality)
logger.debug(f"Starting image generation with prompt: '{prompt[:100]}...'")
# DALL-E 3 API endpoint
dalle_url = "https://api.openai.com/v1/images/generations"
payload = {
"model": "dall-e-3",
"model": "gpt-image-1",
"prompt": prompt,
"size": size,
"quality": quality,
"style": style,
"n": 1,
"response_format": "b64_json" # Get base64 data directly instead of URLs
}
# Use existing httpClient to benefit from connection pooling
# This avoids TLS connection issues that can occur with fresh clients
response = await self.httpClient.post(
dalle_url,
json=payload
"https://api.openai.com/v1/images/generations",
json=payload,
)
if response.status_code != 200:
logger.error(f"DALL-E API error: {response.status_code} - {response.text}")
logger.error(f"Image generation API error: {response.status_code} - {response.text}")
return AiModelResponse(
content="",
success=False,
error=f"DALL-E API error: {response.status_code} - {response.text}"
error=f"Image generation API error: {response.status_code} - {response.text}",
)
responseJson = response.json()
if "data" in responseJson and len(responseJson["data"]) > 0:
image_data = responseJson["data"][0]["b64_json"]
logger.info(f"Successfully generated image: {len(image_data)} characters")
imageData = responseJson["data"][0].get("b64_json", "")
if not imageData:
imageData = responseJson["data"][0].get("url", "")
logger.info(f"Successfully generated image: {len(imageData)} characters")
return AiModelResponse(
content=image_data,
content=imageData,
success=True,
modelId="dall-e-3",
modelId="gpt-image-1",
metadata={
"size": size,
"quality": quality,
"style": style,
"response_id": responseJson.get("id", "")
}
"response_id": responseJson.get("id", ""),
},
)
else:
logger.error("No image data in DALL-E response")
logger.error("No image data in generation response")
return AiModelResponse(
content="",
success=False,
error="No image data in DALL-E response"
error="No image data in generation response",
)
except Exception as e:
logger.error(f"Error during image generation: {str(e)}", exc_info=True)
return AiModelResponse(
content="",
success=False,
error=f"Error during image generation: {str(e)}"
error=f"Error during image generation: {str(e)}",
)

View file

@ -311,7 +311,10 @@ class DatabaseConnector:
# Establish connection to the database
self._connect()
logger.info("PostgreSQL database system initialized successfully")
logger.debug(
"PostgreSQL database system initialized (db=%s, host=%s, port=%s)",
self.dbDatabase, self.dbHost, self.dbPort,
)
except Exception as e:
logger.error(f"FATAL ERROR: Database system initialization failed: {e}")
raise

View file

@ -245,11 +245,10 @@ class AiCallPromptWebCrawl(BaseModel):
class AiCallPromptImage(BaseModel):
"""Structured prompt format for image generation."""
prompt: str = Field(description="Text description of the image to generate")
size: Optional[str] = Field(default="1024x1024", description="Image size (1024x1024, 1792x1024, 1024x1792)")
quality: Optional[str] = Field(default="standard", description="Image quality (standard, hd)")
style: Optional[str] = Field(default="vivid", description="Image style (vivid, natural)")
size: Optional[str] = Field(default="1024x1024", description="Image size (1024x1024, 1536x1024, 1024x1536)")
quality: Optional[str] = Field(default="auto", description="Image quality (auto, high, medium, low)")
class AiProcessParameters(BaseModel):

View file

@ -754,14 +754,35 @@ ANTI-PATTERNS (do NOT do this):
"""
# Parked for one release as a fallback while the ontology-based path rolls
# out (see `trusteeOntology.getTrusteeOntology()`). Remove together with the
# legacy ``_loadFeatureDomainHints`` path once Phase 2 is the only supplier
# of the trustee prompt block.
_AGENT_DOMAIN_HINTS_LEGACY = _AGENT_DOMAIN_HINTS
def getAgentDomainHints() -> str:
"""Return Trustee-specific guidance for the Feature Data Sub-Agent.
The text is appended verbatim to the sub-agent's system prompt by
``featureDataAgent._buildSchemaContext``. Keep it concise and
pattern-driven every line costs tokens on every sub-agent call.
Deprecated as of Phase 2 (2026-05). Prefer ``getAgentOntology()`` ->
``ontologyToPromptCompiler.compileOntologyToPrompt(...)``. The legacy
text remains available so callers that still go through
``_buildSchemaContext()`` keep working during the migration window.
"""
return _AGENT_DOMAIN_HINTS
return _AGENT_DOMAIN_HINTS_LEGACY
def getAgentOntology():
"""Return the structured ontology used by the Feature Data Sub-Agent.
Discovered by ``featureDataAgent._buildSchemaContext`` (Phase 2 path):
when this hook is present, the agent compiles its domain block from
the ontology instead of using the legacy free-text hints. The same
descriptor feeds the validator's NEVER_AGGREGATE constraints, so
prompt and validator stay in sync.
"""
from modules.features.trustee.trusteeOntology import getTrusteeOntology
return getTrusteeOntology()
def registerFeature(catalogService) -> bool:

View file

@ -0,0 +1,295 @@
# Copyright (c) 2026 Patrick Motsch
# All rights reserved.
"""Trustee feature ontology (Phase 2 pilot).
Replaces the hand-written ``_AGENT_DOMAIN_HINTS`` block with a structured
ontology so the Feature Data Sub-Agent's QueryValidator AND the prompt
compiler share the same source of truth: account-group conventions,
period-bucket semantics, the NEVER_AGGREGATE constraints on already-
aggregated columns, and canonical tool-call templates for the most
frequent user intents.
Both the validator (deterministic enforcement) and the prompt compiler
(LLM steering) read from this descriptor, so an LLM that follows the
prompt patterns will never trigger a validator failure -- and one that
ignores them gets a structured repair hint pointing back at the same
constraint.
The legacy ``_AGENT_DOMAIN_HINTS_LEGACY`` block stays parked in
``mainTrustee.py`` for one release as a fallback during rollout.
"""
from __future__ import annotations
from modules.serviceCenter.services.serviceAgent.datamodelOntology import (
CanonicalQueryPattern,
Cardinality,
Constraint,
ConstraintRule,
Entity,
Invariant,
OntologyDescriptor,
Relation,
SemanticType,
)
# ---------------------------------------------------------------------------
# Entities
# ---------------------------------------------------------------------------
_ENTITIES = [
Entity(
name="Account",
pythonClass="TrusteeDataAccount",
semanticType=SemanticType.ACCOUNT,
description=(
"Chart-of-accounts row (Konto). One row per accountNumber per "
"mandate. Identifies the account, never holds balances."
),
invariants=[
Invariant(description="accountNumber is a stable string identifier (e.g. '1020', '5400')."),
Invariant(description="accountType is one of: asset / liability / revenue / expense."),
],
),
Entity(
name="BankAccount",
pythonClass="TrusteeDataAccount",
semanticType=SemanticType.ACCOUNT,
parentEntity="Account",
description="Account subgroup with accountNumber LIKE '102%' (ZKB, PostFinance, UBS, ...).",
),
Entity(
name="CashAccount",
pythonClass="TrusteeDataAccount",
semanticType=SemanticType.ACCOUNT,
parentEntity="Account",
description="Account subgroup with accountNumber LIKE '100%' (Hauptkasse, Nebenkassen).",
),
Entity(
name="AccountBalance",
pythonClass="TrusteeDataAccountBalance",
semanticType=SemanticType.BALANCE_SNAPSHOT,
description=(
"Period-bucketed snapshot: one row per (account, year, month). "
"closingBalance is THE balance at end of period -- already aggregated."
),
invariants=[
Invariant(description="periodMonth=0 means annual total of periodYear (use for 'per 31.12.YYYY')."),
Invariant(description="periodMonth in 1..12 means month-end snapshot."),
Invariant(description="closingBalance is the balance at period end; openingBalance at period start."),
Invariant(description="debitTotal/creditTotal are turnovers for the period, NOT balances."),
],
),
Entity(
name="JournalEntry",
pythonClass="TrusteeDataJournalEntry",
semanticType=SemanticType.TRANSACTION,
description="One booking header (Beleg). Has a bookingDate (unix seconds float) and totalAmount.",
invariants=[
Invariant(description="bookingDate is a UTC unix-seconds float; never compare against ISO strings."),
],
),
Entity(
name="JournalLine",
pythonClass="TrusteeDataJournalLine",
semanticType=SemanticType.TRANSACTION,
description="One booking line of a JournalEntry. Each line debits or credits exactly one account.",
invariants=[
Invariant(description="Per line either debitAmount > 0 (Soll) or creditAmount > 0 (Haben), not both."),
],
),
]
# ---------------------------------------------------------------------------
# Relations
# ---------------------------------------------------------------------------
_RELATIONS = [
Relation(fromEntity="AccountBalance", toEntity="Account", cardinality=Cardinality.MANY_TO_ONE, via="accountNumber"),
Relation(fromEntity="JournalLine", toEntity="JournalEntry", cardinality=Cardinality.MANY_TO_ONE, via="journalEntryId"),
Relation(fromEntity="JournalLine", toEntity="Account", cardinality=Cardinality.MANY_TO_ONE, via="accountNumber"),
]
# ---------------------------------------------------------------------------
# Constraints (validator-enforced)
# ---------------------------------------------------------------------------
_CONSTRAINTS = [
# closingBalance is the single biggest hallucination magnet -- it's a
# balance per period, summing it across periods or accounts is meaningless.
Constraint(
appliesTo="TrusteeDataAccountBalance.closingBalance",
rule=ConstraintRule.NEVER_AGGREGATE,
message=(
"closingBalance is per-period already; query with periodYear+periodMonth, never SUM/AVG it."
),
),
Constraint(
appliesTo="TrusteeDataAccountBalance.openingBalance",
rule=ConstraintRule.NEVER_AGGREGATE,
message="openingBalance is already a balance per period; do not SUM/AVG it across rows.",
),
Constraint(
appliesTo="TrusteeDataAccountBalance.debitTotal",
rule=ConstraintRule.NEVER_AGGREGATE,
message=(
"debitTotal is the period's debit TURNOVER; do not SUM it without an explicit period filter."
),
),
Constraint(
appliesTo="TrusteeDataAccountBalance.creditTotal",
rule=ConstraintRule.NEVER_AGGREGATE,
message="creditTotal is a per-period turnover; do not SUM it across periods without an explicit period filter.",
),
# AccountBalance queries without a period filter are almost always wrong --
# they conflate annual and monthly snapshots. Phase 2 (REQUIRES_FILTER_ON)
# is wired through to the validator in a later iteration; for now this
# rule is rendered into the prompt compiler so the LLM sees it explicitly.
Constraint(
appliesTo="TrusteeDataAccountBalance",
rule=ConstraintRule.REQUIRES_FILTER_ON,
message=(
"Always filter on periodYear AND periodMonth (use periodMonth=0 for end-of-year)."
),
params={"requiredFields": ["periodYear", "periodMonth"]},
),
Constraint(
appliesTo="TrusteeDataAccountBalance",
rule=ConstraintRule.PREFERRED_TABLE_FOR_INTENT,
message="For 'Saldo per <date>' and 'Stand <year>' questions, prefer AccountBalance over JournalLine.",
params={"intents": ["BANK_BALANCE_AT_DATE", "BALANCE_AT_YEAR_END"]},
),
]
# ---------------------------------------------------------------------------
# Canonical query patterns (worked examples for the LLM)
# ---------------------------------------------------------------------------
_CANONICAL_PATTERNS = [
CanonicalQueryPattern(
intent="BANK_BALANCE_AT_DATE",
description="Saldo eines Bankkontos per Jahresende.",
pattern={
"tool": "queryTable",
"tableName": "TrusteeDataAccountBalance",
"filters": [
{"field": "accountNumber", "op": "=", "value": "<accountNumber>"},
{"field": "periodYear", "op": "=", "value": "<year>"},
{"field": "periodMonth", "op": "=", "value": 0},
],
"fields": ["closingBalance", "currency"],
},
),
CanonicalQueryPattern(
intent="BANK_GROUP_TOTAL_AT_DATE",
description="Summe einer Kontogruppe (z. B. alle Bankkonten 102%) per Jahresende.",
pattern={
"tool": "queryTable",
"tableName": "TrusteeDataAccountBalance",
"filters": [
{"field": "accountNumber", "op": "LIKE", "value": "<prefix>%"},
{"field": "periodYear", "op": "=", "value": "<year>"},
{"field": "periodMonth", "op": "=", "value": 0},
],
"fields": ["accountNumber", "closingBalance", "currency"],
"_postProcessing": "Sum closingBalance values in your final answer; do NOT SUM via aggregateTable.",
},
),
CanonicalQueryPattern(
intent="BALANCE_HISTORY_PER_YEAR",
description="Saldo-Verlauf eines Kontos ueber mehrere Jahre.",
pattern={
"tool": "queryTable",
"tableName": "TrusteeDataAccountBalance",
"filters": [
{"field": "accountNumber", "op": "=", "value": "<accountNumber>"},
{"field": "periodMonth", "op": "=", "value": 0},
],
"fields": ["periodYear", "closingBalance", "currency"],
"orderBy": "periodYear",
},
),
CanonicalQueryPattern(
intent="MONTHLY_BALANCE_SNAPSHOT",
description="Saldo per Ende eines bestimmten Monats.",
pattern={
"tool": "queryTable",
"tableName": "TrusteeDataAccountBalance",
"filters": [
{"field": "accountNumber", "op": "=", "value": "<accountNumber>"},
{"field": "periodYear", "op": "=", "value": "<year>"},
{"field": "periodMonth", "op": "=", "value": "<month 1..12>"},
],
"fields": ["closingBalance", "currency"],
},
),
CanonicalQueryPattern(
intent="ACCOUNT_LIST_BY_TYPE_OR_PREFIX",
description="Welche Konten gehoeren zu einer Gruppe (Typ oder Nummern-Prefix)?",
pattern={
"tool": "queryTable",
"tableName": "TrusteeDataAccount",
"filters": [
{"field": "accountNumber", "op": "LIKE", "value": "<prefix>%"},
],
"fields": ["accountNumber", "label", "accountType"],
},
),
CanonicalQueryPattern(
intent="JOURNAL_SUM_AT_ACCOUNT",
description="Summe der Soll- oder Haben-Buchungen auf einem Konto.",
pattern={
"tool": "aggregateTable",
"tableName": "TrusteeDataJournalLine",
"aggregate": "SUM",
"field": "debitAmount",
"filters": [
{"field": "accountNumber", "op": "=", "value": "<accountNumber>"},
],
},
),
CanonicalQueryPattern(
intent="COUNT_ROWS",
description="Anzahl Buchungen / Buchungszeilen / Konten.",
pattern={
"tool": "aggregateTable",
"tableName": "<table>",
"aggregate": "COUNT",
"field": "id",
},
),
CanonicalQueryPattern(
intent="JOURNAL_LINES_BY_AMOUNT",
description="Buchungszeilen mit einem Betrag groesser/kleiner als einer Schwelle.",
pattern={
"tool": "queryTable",
"tableName": "TrusteeDataJournalLine",
"filters": [
{"field": "debitAmount", "op": ">", "value": "<amount>"},
],
"fields": ["accountNumber", "debitAmount", "description"],
},
),
]
_TRUSTEE_ONTOLOGY = OntologyDescriptor(
featureCode="trustee",
entities=_ENTITIES,
relations=_RELATIONS,
constraints=_CONSTRAINTS,
canonicalPatterns=_CANONICAL_PATTERNS,
)
def getTrusteeOntology() -> OntologyDescriptor:
"""Public accessor for the trustee ontology.
Cached as a module-level singleton -- the descriptor is immutable and
has no per-call state.
"""
return _TRUSTEE_ONTOLOGY

View file

@ -0,0 +1,217 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""STT Benchmark route — compare Speech-to-Text v1 (latest_long) vs v2 (Chirp 2).
Sysadmin-only page for evaluating STT model quality and latency.
"""
import json
import time
import logging
from typing import Any, Dict
from fastapi import APIRouter, HTTPException, Depends, Request, UploadFile, File, Form
from modules.auth import limiter, getCurrentUser
from modules.datamodels.datamodelUam import User
from modules.shared.configuration import APP_CONFIG
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="/api/admin/stt-benchmark",
tags=["Admin STT Benchmark"],
responses={401: {"description": "Unauthorized"}, 403: {"description": "Forbidden"}},
)
def _requireSysAdmin(currentUser: User = Depends(getCurrentUser)) -> User:
if not getattr(currentUser, "isSysAdmin", False) and not getattr(currentUser, "isPlatformAdmin", False):
raise HTTPException(status_code=403, detail="SysAdmin required")
return currentUser
def _getCredentials():
apiKey = APP_CONFIG.get("Connector_GoogleSpeech_API_KEY_SECRET")
if not apiKey or apiKey.startswith("YOUR_"):
raise HTTPException(status_code=500, detail="Google Speech API key not configured")
from google.oauth2 import service_account
return service_account.Credentials.from_service_account_info(json.loads(apiKey))
def _runV1(audioBytes: bytes, language: str, model: str) -> Dict[str, Any]:
"""Run Speech-to-Text v1 recognition."""
from google.cloud import speech
credentials = _getCredentials()
client = speech.SpeechClient(credentials=credentials)
config = speech.RecognitionConfig(
encoding=speech.RecognitionConfig.AudioEncoding.ENCODING_UNSPECIFIED,
language_code=language,
model=model,
enable_automatic_punctuation=True,
enable_word_time_offsets=True,
enable_word_confidence=True,
max_alternatives=3,
use_enhanced=True,
)
audio = speech.RecognitionAudio(content=audioBytes)
t0 = time.perf_counter()
response = client.recognize(config=config, audio=audio)
elapsed = time.perf_counter() - t0
results = []
for r in response.results:
for alt in r.alternatives:
results.append({
"transcript": alt.transcript,
"confidence": round(alt.confidence, 4),
"words": len(alt.words) if alt.words else 0,
})
return {
"api": "v1",
"model": model,
"latencyMs": round(elapsed * 1000, 1),
"results": results,
"resultCount": len(response.results),
}
def _runV2(audioBytes: bytes, language: str, model: str, location: str) -> Dict[str, Any]:
"""Run Speech-to-Text v2 recognition (Chirp 2)."""
from google.cloud.speech_v2 import SpeechClient
from google.cloud.speech_v2.types import cloud_speech
credentials = _getCredentials()
credInfo = json.loads(APP_CONFIG.get("Connector_GoogleSpeech_API_KEY_SECRET"))
projectId = credInfo.get("project_id", "")
client = SpeechClient(
credentials=credentials,
client_options={"api_endpoint": f"{location}-speech.googleapis.com"},
)
config = cloud_speech.RecognitionConfig(
auto_decoding_config=cloud_speech.AutoDetectDecodingConfig(),
language_codes=[language],
model=model,
features=cloud_speech.RecognitionFeatures(
enable_automatic_punctuation=True,
enable_word_time_offsets=True,
enable_word_confidence=True,
),
)
recognizer = f"projects/{projectId}/locations/{location}/recognizers/_"
request = cloud_speech.RecognizeRequest(
recognizer=recognizer,
config=config,
content=audioBytes,
)
t0 = time.perf_counter()
response = client.recognize(request=request)
elapsed = time.perf_counter() - t0
results = []
for r in response.results:
for alt in r.alternatives:
results.append({
"transcript": alt.transcript,
"confidence": round(alt.confidence, 4),
"words": len(alt.words) if alt.words else 0,
})
return {
"api": "v2",
"model": model,
"location": location,
"latencyMs": round(elapsed * 1000, 1),
"results": results,
"resultCount": len(getattr(response, "results", [])),
}
@router.post("/run")
@limiter.limit("10/minute")
async def runBenchmark(
request: Request,
file: UploadFile = File(...),
language: str = Form(default="de-DE"),
v1Model: str = Form(default="latest_long"),
v2Model: str = Form(default="chirp_2"),
v2Location: str = Form(default="europe-west4"),
currentUser: User = Depends(_requireSysAdmin),
) -> Dict[str, Any]:
"""Upload audio and compare v1 vs v2 STT results."""
audioBytes = await file.read()
if len(audioBytes) > 10 * 1024 * 1024:
raise HTTPException(status_code=400, detail="Audio file too large (max 10 MB)")
if len(audioBytes) < 100:
raise HTTPException(status_code=400, detail="Audio file too small")
logger.info("STT benchmark: %s, %d bytes, language=%s, v1=%s, v2=%s@%s",
file.filename, len(audioBytes), language, v1Model, v2Model, v2Location)
v1Result = None
v1Error = None
try:
v1Result = _runV1(audioBytes, language, v1Model)
except Exception as e:
v1Error = str(e)
logger.warning("STT v1 benchmark failed: %s", e)
v2Result = None
v2Error = None
try:
v2Result = _runV2(audioBytes, language, v2Model, v2Location)
except Exception as e:
v2Error = str(e)
logger.warning("STT v2 benchmark failed: %s", e)
return {
"filename": file.filename,
"fileSizeBytes": len(audioBytes),
"language": language,
"v1": v1Result or {"error": v1Error},
"v2": v2Result or {"error": v2Error},
}
@router.get("/models")
@limiter.limit("30/minute")
async def getAvailableModels(
request: Request,
currentUser: User = Depends(_requireSysAdmin),
) -> Dict[str, Any]:
"""Return available STT models for the benchmark UI."""
return {
"v1Models": [
{"value": "latest_long", "label": "latest_long (default)"},
{"value": "latest_short", "label": "latest_short"},
{"value": "phone_call", "label": "phone_call"},
{"value": "video", "label": "video"},
{"value": "command_and_search", "label": "command_and_search"},
],
"v2Models": [
{"value": "chirp_2", "label": "Chirp 2 (recommended)"},
{"value": "chirp", "label": "Chirp (original)"},
{"value": "long", "label": "long"},
{"value": "short", "label": "short"},
],
"locations": [
{"value": "europe-west4", "label": "Europe West (NL)"},
{"value": "us-central1", "label": "US Central"},
{"value": "asia-southeast1", "label": "Asia Southeast"},
],
"languages": [
{"value": "de-DE", "label": "Deutsch (DE)"},
{"value": "de-CH", "label": "Deutsch (CH)"},
{"value": "en-US", "label": "English (US)"},
{"value": "en-GB", "label": "English (GB)"},
{"value": "fr-FR", "label": "Francais (FR)"},
{"value": "it-IT", "label": "Italiano (IT)"},
],
}

View file

@ -745,7 +745,7 @@ def _findOwnConnection(interface, userId: str, connectionId: str):
@router.patch("/{connectionId}/knowledge-consent")
@limiter.limit("10/minute")
def _updateKnowledgeConsent(
async def _updateKnowledgeConsent(
request: Request,
connectionId: str = Path(..., description="Connection ID"),
enabled: bool = Body(..., embed=True),
@ -780,24 +780,13 @@ def _updateKnowledgeConsent(
from modules.datamodels.datamodelDataSource import DataSource
dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId, "ragIndexEnabled": True})
if dataSources:
import asyncio
from modules.serviceCenter.services.serviceBackgroundJobs import startJob
authority = connection.authority.value if hasattr(connection.authority, "value") else str(connection.authority or "")
async def _enqueue():
await startJob(
"connection.bootstrap",
{"connectionId": connectionId, "authority": authority.lower()},
triggeredBy=str(currentUser.id),
)
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.create_task(_enqueue())
else:
loop.run_until_complete(_enqueue())
except RuntimeError:
asyncio.run(_enqueue())
await startJob(
"connection.bootstrap",
{"connectionId": connectionId, "authority": authority.lower()},
triggeredBy=str(currentUser.id),
)
bootstrapEnqueued = True
import json as _json

View file

@ -129,7 +129,7 @@ def _updateNeutralizeFields(
@router.patch("/{sourceId}/rag-index")
@limiter.limit("30/minute")
def _updateDataSourceRagIndex(
async def _updateDataSourceRagIndex(
request: Request,
sourceId: str = Path(..., description="ID of the DataSource"),
ragIndexEnabled: bool = Body(..., embed=True),
@ -139,6 +139,10 @@ def _updateDataSourceRagIndex(
true: sets flag + enqueues mini-bootstrap for this DataSource only.
false: sets flag + synchronously purges all chunks from this DataSource.
Must be `async def` so `await startJob(...)` registers `_runJob` in the
main event loop. Sync route worker thread temporary loop closes
before the task runs job stays stuck forever.
"""
try:
from modules.interfaces.interfaceDbApp import getRootInterface
@ -152,7 +156,6 @@ def _updateDataSourceRagIndex(
if ragIndexEnabled:
from modules.serviceCenter.services.serviceBackgroundJobs import startJob
import asyncio
connectionId = rec.get("connectionId") or rec.get("connection_id") or ""
conn = rootIf.getUserConnectionById(connectionId) if connectionId else None
@ -160,20 +163,11 @@ def _updateDataSourceRagIndex(
if conn:
authority = conn.authority.value if hasattr(conn.authority, "value") else str(conn.authority or "")
async def _enqueue():
await startJob(
"connection.bootstrap",
{"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": [sourceId]},
triggeredBy=str(context.user.id),
)
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.create_task(_enqueue())
else:
loop.run_until_complete(_enqueue())
except RuntimeError:
asyncio.run(_enqueue())
await startJob(
"connection.bootstrap",
{"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": [sourceId]},
triggeredBy=str(context.user.id),
)
else:
from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
purgeResult = getKnowledgeInterface(None).deleteFileContentIndexByDataSource(sourceId)

View file

@ -39,20 +39,27 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
chunksByDs: Dict[str, int] = {}
unassigned = 0
for idx in connIndexRows:
prov = (idx.get("provenance") if isinstance(idx, dict) else getattr(idx, "provenance", None)) or {}
struct = (idx.get("structure") if isinstance(idx, dict) else getattr(idx, "structure", None)) or {}
ingestion = struct.get("_ingestion") or {} if isinstance(struct, dict) else {}
prov = ingestion.get("provenance") or {} if isinstance(ingestion, dict) else {}
dsIdRef = prov.get("dataSourceId", "") if isinstance(prov, dict) else ""
if dsIdRef:
chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + 1
else:
unassigned += 1
seen: Dict[str, bool] = {}
dsItems = []
for ds in dataSources:
dsId = ds.get("id") if isinstance(ds, dict) else getattr(ds, "id", "")
dsPath = ds.get("path") if isinstance(ds, dict) else getattr(ds, "path", "")
if dsPath in seen:
continue
seen[dsPath] = True
dsItems.append({
"id": dsId,
"label": ds.get("label") if isinstance(ds, dict) else getattr(ds, "label", ""),
"path": ds.get("path") if isinstance(ds, dict) else getattr(ds, "path", ""),
"path": dsPath,
"sourceType": ds.get("sourceType") if isinstance(ds, dict) else getattr(ds, "sourceType", ""),
"ragIndexEnabled": ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False),
"neutralize": ds.get("neutralize") if isinstance(ds, dict) else getattr(ds, "neutralize", False),
@ -60,20 +67,43 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"chunkCount": chunksByDs.get(dsId, 0),
})
if unassigned > 0 and len(dsItems) == 1:
dsItems[0]["chunkCount"] += unassigned
if unassigned > 0 and len(dsItems) > 0:
perDs = unassigned // len(dsItems)
remainder = unassigned % len(dsItems)
for i, item in enumerate(dsItems):
item["chunkCount"] += perDs + (1 if i < remainder else 0)
jobs = jobService.listJobs(jobType="connection.bootstrap", limit=5)
# Pull a wider window than the previous 5 so the "last successful
# sync" is found even if a connection has many recent jobs queued.
jobs = jobService.listJobs(jobType="connection.bootstrap", limit=50)
connJobs = [j for j in jobs if (j.get("payload") or {}).get("connectionId") == connectionId]
runningJobs = [
{"jobId": j["id"], "progress": j.get("progress", 0), "progressMessage": j.get("progressMessage", "")}
for j in connJobs
if j.get("status") in ("PENDING", "RUNNING")
]
lastError = None
lastError: Optional[Dict[str, Any]] = None
lastSuccess: Optional[Dict[str, Any]] = None
for j in connJobs:
if j.get("status") == "ERROR":
lastError = {"jobId": j["id"], "errorMessage": j.get("errorMessage", "")}
status = j.get("status")
if status == "ERROR" and lastError is None:
lastError = {
"jobId": j["id"],
"errorMessage": j.get("errorMessage", ""),
"finishedAt": j.get("finishedAt"),
}
elif status == "SUCCESS" and lastSuccess is None:
result = j.get("result") or {}
lastSuccess = {
"jobId": j["id"],
"finishedAt": j.get("finishedAt"),
"indexed": result.get("indexed", 0),
"skippedDuplicate": result.get("skippedDuplicate", 0),
"skippedPolicy": result.get("skippedPolicy", 0),
"failed": result.get("failed", 0),
"durationMs": result.get("durationMs", 0),
}
if lastError and lastSuccess:
break
out.append({
@ -86,6 +116,7 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
"totalChunks": connChunkTotal,
"runningJobs": runningJobs,
"lastError": lastError,
"lastSuccess": lastSuccess,
})
return out
@ -182,7 +213,7 @@ def _getInventoryPlatform(
@router.post("/reindex/{connectionId}")
@limiter.limit("10/minute")
def _reindexConnection(
async def _reindexConnection(
request: Request,
connectionId: str,
currentUser: User = Depends(getCurrentUser),
@ -190,12 +221,16 @@ def _reindexConnection(
"""Re-trigger bootstrap for a connection (re-index all ragIndexEnabled DataSources).
Submits a new connection.bootstrap job, regardless of previous failures.
Must be `async def` so `await startJob(...)` registers the `_runJob` task
in FastAPI's main event loop. A sync route would land in the worker
threadpool and `asyncio.run` would tear down the temporary loop right
after `create_task`, leaving the job stuck in PENDING forever.
"""
try:
from modules.interfaces.interfaceDbApp import getRootInterface
from modules.serviceCenter.services.serviceBackgroundJobs import startJob
from modules.datamodels.datamodelDataSource import DataSource
import asyncio
rootIf = getRootInterface()
conn = rootIf.getUserConnectionById(connectionId)
@ -213,23 +248,13 @@ def _reindexConnection(
authority = conn.authority.value if hasattr(conn.authority, "value") else str(conn.authority or "")
dsIds = [(ds.get("id") if isinstance(ds, dict) else getattr(ds, "id", "")) for ds in ragDs]
async def _enqueue():
return await startJob(
"connection.bootstrap",
{"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": dsIds},
triggeredBy=str(currentUser.id),
)
try:
loop = asyncio.get_event_loop()
if loop.is_running():
future = asyncio.ensure_future(_enqueue())
jobId = None
else:
jobId = loop.run_until_complete(_enqueue())
except RuntimeError:
jobId = asyncio.run(_enqueue())
jobId = await startJob(
"connection.bootstrap",
{"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": dsIds},
triggeredBy=str(currentUser.id),
)
logger.info("Reindex triggered for connection %s (%d DataSources)", connectionId, len(dsIds))
logger.info("Reindex triggered for connection %s (%d DataSources, jobId=%s)", connectionId, len(dsIds), jobId)
return {"status": "queued", "connectionId": connectionId, "dataSourceCount": len(dsIds), "jobId": jobId}
except HTTPException:
raise

View file

@ -7,7 +7,7 @@ import logging
import time
import json
import re
from typing import List, Dict, Any, Optional, AsyncGenerator, Callable, Awaitable
from typing import List, Dict, Any, Optional, AsyncGenerator, Callable, Awaitable, Tuple
from modules.datamodels.datamodelAi import (
AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum
@ -360,12 +360,18 @@ async def runAgentLoop(
state.totalToolCalls += len(results)
for result in results:
validationCode = None
if isinstance(result.errorDetails, dict):
code = result.errorDetails.get("code")
if isinstance(code, str):
validationCode = code
roundLog.toolCalls.append(ToolCallLog(
toolName=result.toolName,
args=next((tc.args for tc in toolCalls if tc.id == result.toolCallId), {}),
success=result.success,
durationMs=result.durationMs,
error=result.error,
validationFailureCode=validationCode,
resultData=result.data[:300] if result.data else "",
))
if not result.success:
@ -443,6 +449,11 @@ async def runAgentLoop(
trace.totalCostCHF = state.totalCostCHF
trace.abortReason = state.abortReason
validationFailures, repairAttempts, successAfterRepair = _computeRepairCounters(trace.rounds)
trace.validationFailures = validationFailures
trace.repairAttempts = repairAttempts
trace.successAfterRepair = successAfterRepair
artifactSummary = _buildArtifactSummary(trace.rounds)
yield AgentEvent(
@ -456,6 +467,9 @@ async def runAgentLoop(
"status": state.status.value,
"abortReason": state.abortReason,
"artifacts": artifactSummary,
"validationFailures": validationFailures,
"repairAttempts": repairAttempts,
"successAfterRepair": successAfterRepair,
}
)
@ -720,6 +734,41 @@ def classifyToolResult(
return None
def _computeRepairCounters(rounds: List[AgentRoundLog]) -> Tuple[int, int, int]:
"""Aggregate repair-loop telemetry across all rounds.
Returns ``(validationFailures, repairAttempts, successAfterRepair)``.
* `validationFailures` -- total tool calls rejected by a pre-execute
validator (any round, counts every occurrence).
* `repairAttempts` -- tool calls in **later** rounds whose `toolName`
had been rejected in some **earlier** round. Multiple retries of the
same tool count multiple times. We intentionally do not count
sibling calls within the same round, since the LLM has not yet seen
the first one's result when emitting the second.
* `successAfterRepair` -- the subset of `repairAttempts` that passed
the validator (``validationFailureCode is None``).
"""
validationFailures = 0
repairAttempts = 0
successAfterRepair = 0
rejectedTools: set = set()
for roundLog in rounds:
rejectedFromPriorRounds = set(rejectedTools)
for tc in roundLog.toolCalls:
wasRejectedBefore = tc.toolName in rejectedFromPriorRounds
if tc.validationFailureCode is not None:
validationFailures += 1
if wasRejectedBefore:
repairAttempts += 1
rejectedTools.add(tc.toolName)
elif wasRejectedBefore:
repairAttempts += 1
successAfterRepair += 1
return validationFailures, repairAttempts, successAfterRepair
_ARTIFACT_TOOLS = {"writeFile", "replaceInFile", "deleteFile", "renameFile", "copyFile",
"createFolder", "deleteFolder", "renderDocument", "generateImage"}

View file

@ -19,6 +19,20 @@ from modules.serviceCenter.services.serviceAgent.coreTools._helpers import (
logger = logging.getLogger(__name__)
_STALE_EXTRACTION_PATTERNS = (
"requires the extract-msg package",
"extraction requires the",
"will be treated as binary",
)
def _isStaleExtractionResult(text: str) -> bool:
"""Detect cached extraction results that are just error/warning placeholders."""
if len(text) > 500:
return False
textLower = text.lower()
return any(p in textLower for p in _STALE_EXTRACTION_PATTERNS)
import uuid as _uuid
@ -62,15 +76,16 @@ def _registerWorkspaceTools(registry: ToolRegistry, services):
]
if textChunks:
assembled = "\n\n".join(c["data"] for c in textChunks)
chunked = _applyOffsetLimit(assembled, offset, limit)
if chunked is not None:
return ToolResult(toolCallId="", toolName="readFile", success=True, data=chunked)
if len(assembled) > _MAX_TOOL_RESULT_CHARS:
assembled = assembled[:_MAX_TOOL_RESULT_CHARS] + f"\n\n[Truncated showing first {_MAX_TOOL_RESULT_CHARS} chars of {len(assembled)}. Use offset/limit to read specific sections.]"
return ToolResult(
toolCallId="", toolName="readFile", success=True,
data=assembled,
)
if not _isStaleExtractionResult(assembled):
chunked = _applyOffsetLimit(assembled, offset, limit)
if chunked is not None:
return ToolResult(toolCallId="", toolName="readFile", success=True, data=chunked)
if len(assembled) > _MAX_TOOL_RESULT_CHARS:
assembled = assembled[:_MAX_TOOL_RESULT_CHARS] + f"\n\n[Truncated showing first {_MAX_TOOL_RESULT_CHARS} chars of {len(assembled)}. Use offset/limit to read specific sections.]"
return ToolResult(
toolCallId="", toolName="readFile", success=True,
data=assembled,
)
elif fileStatus in ("processing", "embedding", "extracted"):
return ToolResult(
toolCallId="", toolName="readFile", success=True,
@ -101,12 +116,31 @@ def _registerWorkspaceTools(registry: ToolRegistry, services):
isBinary = _looksLikeBinary(rawBytes)
if isBinary:
extractionService = services.getService("extraction") if hasattr(services, "getService") else None
if extractionService:
try:
extracted = extractionService.extractContentFromBytes(
rawBytes, fileName, mimeType, documentId=fileId,
)
textParts = [
p.data for p in (extracted.parts or [])
if getattr(p, "contentType", "") != "image" and getattr(p, "data", None)
]
if textParts:
assembled = "\n\n".join(textParts)
chunked = _applyOffsetLimit(assembled, offset, limit)
if chunked is not None:
return ToolResult(toolCallId="", toolName="readFile", success=True, data=chunked)
if len(assembled) > _MAX_TOOL_RESULT_CHARS:
assembled = assembled[:_MAX_TOOL_RESULT_CHARS] + f"\n\n[Truncated showing first {_MAX_TOOL_RESULT_CHARS} chars of {len(assembled)}. Use offset/limit to read specific sections.]"
return ToolResult(toolCallId="", toolName="readFile", success=True, data=assembled)
except Exception as extractErr:
logger.warning("readFile: inline extraction failed for %s: %s", fileId, extractErr)
return ToolResult(
toolCallId="", toolName="readFile", success=True,
data=(
f"[File '{fileName}' ({mimeType}) is not yet indexed "
f"(status: {fileStatus or 'unknown'}). Indexing runs automatically "
f"on upload. Please wait a few seconds and retry, or re-upload the file. "
f"[File '{fileName}' ({mimeType}) is binary and could not be extracted "
f"(status: {fileStatus or 'unknown'}). "
f"For visual content use describeImage(fileId='{fileId}').]"
),
)

View file

@ -79,6 +79,14 @@ class ToolResult(BaseModel):
success: bool = True
data: str = ""
error: Optional[str] = None
errorDetails: Optional[Dict[str, Any]] = Field(
default=None,
description=(
"Structured, machine-readable error payload for the LLM (e.g. validation "
"repair hints with code/field/suggestion/hint). `error` remains the short "
"human-readable text for logs and audit."
),
)
durationMs: int = 0
sideEvents: Optional[List[Dict[str, Any]]] = None
@ -141,6 +149,14 @@ class ToolCallLog(BaseModel):
success: bool = True
durationMs: int = 0
error: Optional[str] = None
validationFailureCode: Optional[str] = Field(
default=None,
description=(
"If the tool call was rejected by a pre-execute validator (e.g. "
"QueryValidator), the structured error code (e.g. FIELD_NOT_FOUND). "
"None when the call ran cleanly or failed for other reasons."
),
)
resultData: str = Field(default="", description="Short result summary for artifact tracking")
@ -167,6 +183,24 @@ class AgentTrace(BaseModel):
totalToolCalls: int = 0
totalCostCHF: float = 0.0
abortReason: Optional[str] = None
validationFailures: int = Field(
default=0,
description="Total tool calls rejected by a pre-execute validator across the run.",
)
repairAttempts: int = Field(
default=0,
description=(
"Number of times the LLM retried a previously rejected tool (same toolName) "
"in a later round. Counted by `agentLoop` from per-round ToolCallLog entries."
),
)
successAfterRepair: int = Field(
default=0,
description=(
"Number of repair attempts that produced a clean (validationFailureCode=None) "
"result. Combined with `repairAttempts` this gives the repair conversion rate."
),
)
rounds: List[AgentRoundLog] = Field(default_factory=list)

View file

@ -0,0 +1,203 @@
# Copyright (c) 2026 Patrick Motsch
# All rights reserved.
"""Ontology data model for feature data sub-agents.
This module defines the data structures that describe a feature's data
ontology -- entities, relations, constraints, canonical query patterns --
plus the validation error payload used by the QueryValidator.
Phase 1 (Repair-Loop) only needs `QueryValidationError`, `Constraint`,
`ConstraintRule` and `ValidationErrorCode`; the richer `Entity`/`Relation`/
`OntologyDescriptor` types are defined here so Phase 2 (Trustee ontology
pilot) can plug in without a second data-model change.
See `wiki/c-work/2-build/2026-05-feature-data-agent-ontology-and-repair.md`.
"""
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
class ValidationErrorCode(str, Enum):
"""Stable codes for validator failures.
The LLM sees these codes verbatim in `ToolResult.errorDetails["code"]`
and is expected to react to them deterministically (e.g. inspect the
schema via browseTable when FIELD_NOT_FOUND, drop the SUM when
INVALID_AGGREGATE_TARGET, add a period filter when MISSING_REQUIRED_FILTER).
"""
FIELD_NOT_FOUND = "FIELD_NOT_FOUND"
INVALID_AGGREGATE_TARGET = "INVALID_AGGREGATE_TARGET"
WRONG_TABLE_FOR_PURPOSE = "WRONG_TABLE_FOR_PURPOSE"
TYPE_MISMATCH = "TYPE_MISMATCH"
OPERATOR_INCOMPATIBLE = "OPERATOR_INCOMPATIBLE"
MISSING_REQUIRED_FILTER = "MISSING_REQUIRED_FILTER"
ORDER_BY_INVALID = "ORDER_BY_INVALID"
class QueryValidationError(BaseModel):
"""Structured pre-execute validation error.
Serialized into `ToolResult.errorDetails` (machine-readable) and
summarized into `ToolResult.error` (short human-readable string).
"""
code: ValidationErrorCode
field: Optional[str] = Field(
default=None,
description="The offending field name (when applicable).",
)
suggestion: Optional[str] = Field(
default=None,
description=(
"Best-effort suggestion (e.g. fuzzy-matched valid field name). "
"None when no useful suggestion exists."
),
)
hint: str = Field(
description="Short corrective hint, max ~80 chars. Surfaced to the LLM verbatim.",
max_length=160,
)
def toShortError(self) -> str:
"""Build the short `error` string for logs/audit.
Format: `<CODE>: <hint>` (or with field when present).
"""
if self.field:
return f"{self.code.value}: {self.field}: {self.hint}"
return f"{self.code.value}: {self.hint}"
def toErrorDetails(self) -> Dict[str, Any]:
"""Build the dict for `ToolResult.errorDetails`."""
return {
"code": self.code.value,
"field": self.field,
"suggestion": self.suggestion,
"hint": self.hint,
}
class ConstraintRule(str, Enum):
"""High-level rule kinds that can be attached to a field or table."""
NEVER_AGGREGATE = "NEVER_AGGREGATE"
REQUIRES_FILTER_ON = "REQUIRES_FILTER_ON"
TYPE_MISMATCH_GUARD = "TYPE_MISMATCH_GUARD"
PREFERRED_TABLE_FOR_INTENT = "PREFERRED_TABLE_FOR_INTENT"
class Constraint(BaseModel):
"""A single rule the validator and the prompt compiler both consume.
Phase 1 uses constraints declared inline by the validator (defaults
derived from naming conventions like ``*Balance`` / ``*Total``).
Phase 2 sources them from feature ontologies, replacing the
convention-based defaults.
"""
appliesTo: str = Field(
description=(
"Target identifier, format depends on rule: `<Table>.<field>` for "
"field-level constraints, `<Table>` for table-level."
),
)
rule: ConstraintRule
message: str = Field(
description="Short hint forwarded to the LLM if the constraint fires.",
max_length=160,
)
params: Dict[str, Any] = Field(
default_factory=dict,
description=(
"Rule-specific extras, e.g. {'requiredFields': ['periodYear', 'periodMonth']} "
"for REQUIRES_FILTER_ON."
),
)
class SemanticType(str, Enum):
"""High-level semantic category an entity belongs to.
Coarser than the underlying Pydantic type -- used so the prompt compiler
can group entities ("here are your ACCOUNT-like tables") without the LLM
having to read the full schema.
"""
ACCOUNT = "ACCOUNT"
BALANCE_SNAPSHOT = "BALANCE_SNAPSHOT"
TRANSACTION = "TRANSACTION"
DOCUMENT = "DOCUMENT"
PARTY = "PARTY"
PERIOD = "PERIOD"
OTHER = "OTHER"
class Cardinality(str, Enum):
ONE_TO_ONE = "ONE_TO_ONE"
ONE_TO_MANY = "ONE_TO_MANY"
MANY_TO_ONE = "MANY_TO_ONE"
MANY_TO_MANY = "MANY_TO_MANY"
class Invariant(BaseModel):
"""Free-form invariant attached to an entity.
Phase 1 leaves these as opaque text consumed by the prompt compiler.
Future phases may add a structured rule kind.
"""
description: str = Field(max_length=200)
class Entity(BaseModel):
"""One semantic entity in the ontology (often backed by a Pydantic table)."""
name: str
pythonClass: Optional[str] = Field(
default=None,
description="MODEL_REGISTRY key when the entity is DB-backed (e.g. 'TrusteeDataAccountBalance').",
)
semanticType: SemanticType = SemanticType.OTHER
parentEntity: Optional[str] = Field(
default=None,
description="Name of a broader entity this one specializes (e.g. 'BankAccount' parentEntity 'Account').",
)
description: str = ""
invariants: List[Invariant] = Field(default_factory=list)
class Relation(BaseModel):
fromEntity: str
toEntity: str
cardinality: Cardinality
via: Optional[str] = Field(
default=None,
description="FK-Feldname auf der fromEntity-Seite (z. B. 'journalEntryId').",
)
class CanonicalQueryPattern(BaseModel):
"""Tool-call skeleton for a recurring user intent.
The prompt compiler renders these as worked examples so the LLM has a
template to mimic instead of inventing a query shape.
"""
intent: str = Field(description="Short label, e.g. 'BANK_BALANCE_AT_DATE'.")
description: str = Field(default="", description="Human-readable when to use this pattern.")
pattern: Dict[str, Any] = Field(
description="Tool-call shape with placeholders, e.g. {'tool': 'queryTable', 'tableName': '...', 'filters': [...]}",
)
class OntologyDescriptor(BaseModel):
"""Top-level container exported by `getAgentOntology()` per feature."""
featureCode: str
entities: List[Entity] = Field(default_factory=list)
relations: List[Relation] = Field(default_factory=list)
constraints: List[Constraint] = Field(default_factory=list)
canonicalPatterns: List[CanonicalQueryPattern] = Field(default_factory=list)
def constraintsForTable(self, tableName: str) -> List[Constraint]:
"""Return constraints whose ``appliesTo`` targets the given table or one of its fields."""
prefix = f"{tableName}."
return [
c for c in self.constraints
if c.appliesTo == tableName or c.appliesTo.startswith(prefix)
]

View file

@ -15,6 +15,7 @@ invoked outside an agent loop (e.g. in tests).
import json
import logging
import os
from typing import Any, Callable, Awaitable, Dict, List, Optional
from modules.datamodels.datamodelAi import (
@ -25,6 +26,10 @@ from modules.serviceCenter.services.serviceAgent.agentLoop import runAgentLoop
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
AgentConfig, AgentEvent, AgentEventTypeEnum, ToolResult,
)
from modules.serviceCenter.services.serviceAgent.datamodelOntology import (
QueryValidationError,
)
from modules.serviceCenter.services.serviceAgent.queryValidator import QueryValidator
from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry
from modules.serviceCenter.services.serviceAgent.featureDataProvider import FeatureDataProvider
from modules.shared.i18nRegistry import resolveText
@ -83,7 +88,8 @@ async def runFeatureDataAgent(
"""
provider = FeatureDataProvider(dbConnector, neutralizeFields=neutralizeFields)
registry = _buildSubAgentTools(provider, featureInstanceId, mandateId, tableFilters or {})
validator = _buildValidatorForFeature(featureCode)
registry = _buildSubAgentTools(provider, featureInstanceId, mandateId, tableFilters or {}, validator=validator)
for tbl in selectedTables:
meta = tbl.get("meta", {})
@ -153,10 +159,19 @@ def _buildSubAgentTools(
featureInstanceId: str,
mandateId: str,
tableFilters: Dict[str, Dict[str, str]] = None,
validator: Optional[QueryValidator] = None,
) -> ToolRegistry:
"""Register browseTable and queryTable as sub-agent tools."""
"""Register browseTable and queryTable as sub-agent tools.
The optional ``validator`` runs **before** the provider on every call.
When it returns a structured error, the tool result carries
``errorDetails`` (machine-readable repair hint for the LLM) plus the
short ``error`` string for logs/audit. No provider call happens in that
case, so the database is never reached with a known-bad query.
"""
registry = ToolRegistry()
_tableFilters = tableFilters or {}
_validator = validator or QueryValidator()
def _recordFilterToList(tableName: str) -> Optional[List[Dict[str, Any]]]:
"""Convert a recordFilter dict to a list of {field, op, value} filter dicts."""
@ -165,6 +180,14 @@ def _buildSubAgentTools(
return None
return [{"field": k, "op": "=", "value": v} for k, v in rf.items()]
def _validationToolResult(toolName: str, err: QueryValidationError) -> ToolResult:
return ToolResult(
toolCallId="", toolName=toolName,
success=False,
error=err.toShortError(),
errorDetails=err.toErrorDetails(),
)
async def _browseTable(args: Dict[str, Any], context: Dict[str, Any]):
tableName = args.get("tableName", "")
limit = args.get("limit", 50)
@ -172,6 +195,9 @@ def _buildSubAgentTools(
fields = args.get("fields")
if not tableName:
return ToolResult(toolCallId="", toolName="browseTable", success=False, error="tableName required")
validationErr = _validator.validateBrowseQuery(tableName, args)
if validationErr is not None:
return _validationToolResult("browseTable", validationErr)
result = provider.browseTable(
tableName=tableName,
featureInstanceId=featureInstanceId,
@ -197,6 +223,9 @@ def _buildSubAgentTools(
offset = args.get("offset", 0)
if not tableName:
return ToolResult(toolCallId="", toolName="queryTable", success=False, error="tableName required")
validationErr = _validator.validateQueryTable(tableName, args)
if validationErr is not None:
return _validationToolResult("queryTable", validationErr)
result = provider.queryTable(
tableName=tableName,
featureInstanceId=featureInstanceId,
@ -220,12 +249,19 @@ def _buildSubAgentTools(
aggregate = args.get("aggregate", "")
field = args.get("field", "")
groupBy = args.get("groupBy")
filters = args.get("filters") or []
if not tableName:
return ToolResult(toolCallId="", toolName="aggregateTable", success=False, error="tableName required")
if not aggregate:
return ToolResult(toolCallId="", toolName="aggregateTable", success=False, error="aggregate required (SUM, COUNT, AVG, MIN, MAX)")
if not field:
return ToolResult(toolCallId="", toolName="aggregateTable", success=False, error="field required")
validationErr = _validator.validateAggregateQuery(tableName, args)
if validationErr is not None:
return _validationToolResult("aggregateTable", validationErr)
combinedFilters = list(filters)
recordFilters = _recordFilterToList(tableName) or []
combinedFilters.extend(recordFilters)
result = provider.aggregateTable(
tableName=tableName,
featureInstanceId=featureInstanceId,
@ -233,7 +269,7 @@ def _buildSubAgentTools(
aggregate=aggregate,
field=field,
groupBy=groupBy,
extraFilters=_recordFilterToList(tableName),
extraFilters=combinedFilters or None,
)
return ToolResult(
toolCallId="", toolName="aggregateTable",
@ -246,8 +282,12 @@ def _buildSubAgentTools(
"aggregateTable", _aggregateTable,
description=(
"Run an aggregate query on a feature data table. "
"Supports SUM, COUNT, AVG, MIN, MAX with optional GROUP BY. "
"Example: aggregateTable(tableName='TrusteeDataJournalLine', aggregate='SUM', field='debitAmount', groupBy='costCenter')"
"Supports SUM, COUNT, AVG, MIN, MAX with optional GROUP BY and filters. "
"Example: aggregateTable(tableName='TrusteeDataJournalLine', aggregate='SUM', "
"field='debitAmount', filters=[{'field':'accountNumber','op':'=','value':'5400'}]). "
"On validation failure the tool returns success=False with errorDetails={code, field, suggestion, hint} -- "
"read errorDetails and correct the next call (e.g. drop the SUM, switch to queryTable with period filters, "
"or use the suggested field name)."
),
parameters={
"type": "object",
@ -256,6 +296,22 @@ def _buildSubAgentTools(
"aggregate": {"type": "string", "enum": ["SUM", "COUNT", "AVG", "MIN", "MAX"], "description": "Aggregate function"},
"field": {"type": "string", "description": "Field to aggregate (e.g. debitAmount, creditAmount)"},
"groupBy": {"type": "string", "description": "Optional field to group by (e.g. costCenter, accountNumber)"},
"filters": {
"type": "array",
"items": {
"type": "object",
"properties": {
"field": {"type": "string"},
"op": {"type": "string"},
"value": {},
},
},
"description": (
"Optional filter conditions applied before the aggregate. Same shape as queryTable's "
"filters. Required whenever you want to aggregate only a subset (e.g. SUM debits on "
"ONE account, COUNT rows in ONE year)."
),
},
},
"required": ["tableName", "aggregate", "field"],
},
@ -264,7 +320,11 @@ def _buildSubAgentTools(
registry.register(
"browseTable", _browseTable,
description="List rows from a feature data table with pagination.",
description=(
"List rows from a feature data table with pagination. "
"On validation failure the tool returns success=False with errorDetails={code, field, suggestion, hint} -- "
"use errorDetails to correct the next call."
),
parameters={
"type": "object",
"properties": {
@ -286,7 +346,10 @@ def _buildSubAgentTools(
description=(
"Query a feature data table with filters, field selection, and ordering. "
"Filters: [{\"field\": \"status\", \"op\": \"=\", \"value\": \"active\"}]. "
"Operators: =, !=, >, <, >=, <=, LIKE, ILIKE, IS NULL, IS NOT NULL."
"Operators: =, !=, >, <, >=, <=, LIKE, ILIKE, IS NULL, IS NOT NULL. "
"On validation failure the tool returns success=False with errorDetails={code, field, suggestion, hint} -- "
"common codes: FIELD_NOT_FOUND (use the suggestion or call browseTable), OPERATOR_INCOMPATIBLE "
"(switch to a compatible operator for that field type), ORDER_BY_INVALID."
),
parameters={
"type": "object",
@ -410,13 +473,94 @@ def _buildSchemaContext(
"- Keep your answer SHORT. The caller is a machine, not a human.",
]
domainHints = _loadFeatureDomainHints(featureCode)
if domainHints:
parts.extend(["", domainHints.strip()])
domainBlock = ""
if not _isOntologyDisabled():
domainBlock = _loadFeatureOntologyBlock(featureCode)
if not domainBlock:
domainBlock = _loadFeatureDomainHints(featureCode)
if domainBlock:
parts.extend(["", domainBlock.strip()])
return "\n".join(parts)
def _isOntologyDisabled() -> bool:
"""Eval-only escape hatch.
Set ``POWERON_DISABLE_FEATURE_ONTOLOGY=1`` in the environment to force
``_buildSchemaContext`` back onto the legacy ``getAgentDomainHints()``
path. Used by the Phase 1.5 benchmark to measure ``baseline`` and
``phase1`` accuracy WITHOUT the ontology-driven prompt block. Never
set this flag in production.
"""
return os.environ.get("POWERON_DISABLE_FEATURE_ONTOLOGY", "").strip() in ("1", "true", "TRUE", "yes")
def _buildValidatorForFeature(featureCode: str) -> QueryValidator:
"""Construct a QueryValidator wired with the feature ontology (when present).
Without an ontology the validator falls back to its convention-based
constraints (``*Balance`` / ``*Total`` are NEVER_AGGREGATE). With an
ontology the descriptor's constraints take precedence -- the validator
and the prompt block then share the same source of truth.
"""
ontology = _loadFeatureOntology(featureCode)
return QueryValidator(ontology=ontology)
def _loadFeatureOntology(featureCode: str):
"""Return the feature's OntologyDescriptor or None when no hook is exposed."""
if not featureCode:
return None
try:
from modules.system.registry import loadFeatureMainModules
except Exception:
return None
try:
mainModules = loadFeatureMainModules() or {}
except Exception as exc:
logger.debug("Ontology lookup: cannot load main modules (%s)", exc)
return None
module = mainModules.get(featureCode) or mainModules.get(featureCode.lower())
if module is None:
return None
hook = getattr(module, "getAgentOntology", None)
if not callable(hook):
return None
try:
return hook()
except Exception as exc:
logger.warning("Feature '%s' getAgentOntology() raised: %s", featureCode, exc)
return None
def _loadFeatureOntologyBlock(featureCode: str) -> str:
"""Return the ontology-derived prompt block when the feature exposes one.
Each feature can expose ``getAgentOntology() -> OntologyDescriptor`` in
its ``mainXxx.py``. When present, the descriptor is compiled via
:func:`ontologyToPromptCompiler.compileOntologyToPrompt` and the result
replaces the legacy ``getAgentDomainHints()`` text block. This keeps
one single source of truth for the validator AND the prompt.
Failures are swallowed (missing hook, exceptions in compilation) so the
caller can fall back to the legacy domain-hints path.
"""
ontology = _loadFeatureOntology(featureCode)
if ontology is None:
return ""
try:
from modules.serviceCenter.services.serviceAgent.ontologyToPromptCompiler import (
compileOntologyToPrompt,
)
return compileOntologyToPrompt(ontology)
except Exception as exc:
logger.warning("Ontology compile failed for '%s': %s", featureCode, exc)
return ""
def _loadFeatureDomainHints(featureCode: str) -> str:
"""Pull optional domain-specific hints from the feature's main module.

View file

@ -0,0 +1,140 @@
# Copyright (c) 2026 Patrick Motsch
# All rights reserved.
"""Deterministic compiler: OntologyDescriptor -> sub-agent prompt block.
Phase 2 replaces a feature's hand-written ``_AGENT_DOMAIN_HINTS`` text
with a structured :class:`OntologyDescriptor`. This compiler renders the
descriptor into a stable, terse Markdown-ish block that the sub-agent
appends to its system prompt -- the same source of truth the
:class:`QueryValidator` consults.
The output is intentionally:
* short (every token costs every call)
* deterministic (no f-string ordering bugs, no Python dict iteration)
* free of internal jargon ('canonicalQueryPattern' is rendered as
'CANONICAL PATTERN' for the LLM)
"""
from __future__ import annotations
from typing import Iterable, List
from modules.serviceCenter.services.serviceAgent.datamodelOntology import (
CanonicalQueryPattern,
Constraint,
ConstraintRule,
Entity,
OntologyDescriptor,
Relation,
)
def compileOntologyToPrompt(ontology: OntologyDescriptor) -> str:
"""Render *ontology* into a sub-agent prompt block.
The output starts with a stable marker line (``DOMAIN ONTOLOGY (...)``)
so downstream tooling can find/replace it deterministically.
"""
lines: List[str] = []
lines.append(f"DOMAIN ONTOLOGY ({ontology.featureCode}):")
lines.append("")
lines.extend(_renderEntities(ontology.entities))
relationLines = _renderRelations(ontology.relations)
if relationLines:
lines.append("")
lines.extend(relationLines)
constraintLines = _renderConstraints(ontology.constraints)
if constraintLines:
lines.append("")
lines.extend(constraintLines)
patternLines = _renderPatterns(ontology.canonicalPatterns)
if patternLines:
lines.append("")
lines.extend(patternLines)
return "\n".join(lines).rstrip() + "\n"
def _renderEntities(entities: Iterable[Entity]) -> List[str]:
out: List[str] = ["ENTITIES:"]
for e in entities:
head = f"- {e.name}"
if e.parentEntity:
head += f" (specializes {e.parentEntity})"
if e.pythonClass:
head += f" [table: {e.pythonClass}]"
out.append(head)
if e.description:
out.append(f" {e.description}")
for inv in e.invariants:
out.append(f" * {inv.description}")
return out
def _renderRelations(relations: Iterable[Relation]) -> List[str]:
rels = list(relations)
if not rels:
return []
out: List[str] = ["RELATIONS:"]
for r in rels:
line = f"- {r.fromEntity} -> {r.toEntity} ({r.cardinality.value}"
if r.via:
line += f" via {r.via}"
line += ")"
out.append(line)
return out
def _renderConstraints(constraints: Iterable[Constraint]) -> List[str]:
cons = list(constraints)
if not cons:
return []
out: List[str] = ["CONSTRAINTS (validator-enforced):"]
for c in cons:
rule = _ruleLabel(c.rule)
line = f"- {rule} on {c.appliesTo}: {c.message}"
params = c.params or {}
required = params.get("requiredFields")
if isinstance(required, list) and required:
line += f" (required filters: {', '.join(required)})"
intents = params.get("intents")
if isinstance(intents, list) and intents:
line += f" (intents: {', '.join(intents)})"
out.append(line)
return out
def _ruleLabel(rule: ConstraintRule) -> str:
return rule.value.replace("_", " ").lower()
def _renderPatterns(patterns: Iterable[CanonicalQueryPattern]) -> List[str]:
pats = list(patterns)
if not pats:
return []
out: List[str] = ["CANONICAL QUERY PATTERNS (mimic these tool calls):"]
for i, p in enumerate(pats, start=1):
out.append(f"{i}) intent={p.intent}: {p.description}")
out.append(f" call: {_renderPatternCall(p.pattern)}")
extra = p.pattern.get("_postProcessing") if isinstance(p.pattern, dict) else None
if isinstance(extra, str):
out.append(f" note: {extra}")
return out
def _renderPatternCall(pattern: dict) -> str:
"""Render the pattern as a compact one-line tool call signature."""
tool = pattern.get("tool", "?")
parts: List[str] = []
for key in ("tableName", "aggregate", "field", "groupBy", "orderBy"):
if key in pattern and pattern[key] is not None and not str(key).startswith("_"):
parts.append(f"{key}={pattern[key]!r}")
if "fields" in pattern and pattern["fields"]:
parts.append(f"fields={pattern['fields']}")
if "filters" in pattern and pattern["filters"]:
compact = ", ".join(
f"{f.get('field')}{f.get('op','=')}{f.get('value')!r}"
for f in pattern["filters"]
if isinstance(f, dict)
)
parts.append(f"filters=[{compact}]")
return f"{tool}({', '.join(parts)})"

View file

@ -0,0 +1,311 @@
# Copyright (c) 2026 Patrick Motsch
# All rights reserved.
"""Pre-execute query validator for the Feature Data Sub-Agent.
Sits between the LLM tool call and `FeatureDataProvider`. Catches the four
high-impact hallucination classes deterministically so the LLM gets an
actionable repair hint instead of a raw SQL exception:
* invented field names -> FIELD_NOT_FOUND (+ fuzzy suggestion)
* operator/type mismatches -> OPERATOR_INCOMPATIBLE
* SUM/AVG on already-aggregated -> INVALID_AGGREGATE_TARGET
balance/total columns
* orderBy on invented fields -> ORDER_BY_INVALID
The validator reads the canonical schema from
`modules.datamodels.datamodelBase.MODEL_REGISTRY`. When an
`OntologyDescriptor` is provided (Phase 2), its constraints override the
convention-based defaults (e.g. NEVER_AGGREGATE on closingBalance).
"""
from __future__ import annotations
import difflib
import logging
import re
import typing
from typing import Any, Dict, List, Optional, Tuple
from modules.datamodels.datamodelBase import MODEL_REGISTRY
from modules.serviceCenter.services.serviceAgent.datamodelOntology import (
Constraint,
ConstraintRule,
OntologyDescriptor,
QueryValidationError,
ValidationErrorCode,
)
logger = logging.getLogger(__name__)
_STRING_ONLY_OPERATORS = {"LIKE", "ILIKE"}
_COMPARISON_OPERATORS = {">", "<", ">=", "<="}
_VALUELESS_OPERATORS = {"IS NULL", "IS NOT NULL"}
_AGGREGATES_THAT_SUM = {"SUM", "AVG"}
_AGGREGATE_BLACKLIST_SUFFIXES_DEFAULT: Tuple[str, ...] = ("Balance", "Total")
class QueryValidator:
"""Validate sub-agent tool arguments against the schema (+ optional ontology).
Stateless per call -- holding only the optional ontology. Each
`validateXxx` method returns ``None`` on success or a
:class:`QueryValidationError` to be surfaced to the LLM.
"""
def __init__(self, ontology: Optional[OntologyDescriptor] = None):
self._ontology = ontology
# ------------------------------------------------------------------
# public API: one method per sub-agent tool
# ------------------------------------------------------------------
def validateBrowseQuery(
self, tableName: str, args: Dict[str, Any]
) -> Optional[QueryValidationError]:
"""Validate browseTable arguments.
Phase 1 scope: only `fields` (whitelist) is LLM-driven; `limit`/`offset`
are sanitized by the tool wrapper.
"""
modelFields = _getModelFields(tableName)
if modelFields is None:
return None
fieldsErr = self._validateFieldList(args.get("fields"), modelFields)
if fieldsErr is not None:
return fieldsErr
return None
def validateQueryTable(
self, tableName: str, args: Dict[str, Any]
) -> Optional[QueryValidationError]:
"""Validate queryTable arguments (filters + fields + orderBy)."""
modelFields = _getModelFields(tableName)
if modelFields is None:
return None
fieldsErr = self._validateFieldList(args.get("fields"), modelFields)
if fieldsErr is not None:
return fieldsErr
for f in args.get("filters") or []:
filterErr = self._validateFilter(f, modelFields)
if filterErr is not None:
return filterErr
orderBy = args.get("orderBy")
if orderBy is not None and not _isPlainNone(orderBy):
if orderBy not in modelFields:
return QueryValidationError(
code=ValidationErrorCode.ORDER_BY_INVALID,
field=orderBy,
suggestion=_suggestFieldName(orderBy, modelFields),
hint="orderBy must be a real field of this table.",
)
return None
def validateAggregateQuery(
self, tableName: str, args: Dict[str, Any]
) -> Optional[QueryValidationError]:
"""Validate aggregateTable arguments.
Catches the highest-impact hallucination in the codebase:
``SUM(closingBalance)`` (and friends) across periods -- closing
balances are already per-period, summing them produces nonsense.
"""
modelFields = _getModelFields(tableName)
if modelFields is None:
return None
field = args.get("field")
aggregate = (args.get("aggregate") or "").upper()
if not field:
return None # tool wrapper rejects empty field already
if field not in modelFields:
return QueryValidationError(
code=ValidationErrorCode.FIELD_NOT_FOUND,
field=field,
suggestion=_suggestFieldName(field, modelFields),
hint="Use browseTable to inspect this table's columns.",
)
if aggregate in _AGGREGATES_THAT_SUM and self._isAggregateBlacklisted(tableName, field):
return QueryValidationError(
code=ValidationErrorCode.INVALID_AGGREGATE_TARGET,
field=field,
suggestion=None,
hint=(
f"{field} is already aggregated per period; do not {aggregate} it "
"across rows. Use queryTable with period filters instead."
),
)
if aggregate in _AGGREGATES_THAT_SUM and not _isNumericAnnotation(modelFields[field]):
return QueryValidationError(
code=ValidationErrorCode.TYPE_MISMATCH,
field=field,
suggestion=None,
hint=f"{aggregate} requires a numeric field; {field} is not numeric.",
)
groupBy = args.get("groupBy")
if groupBy is not None and not _isPlainNone(groupBy):
if groupBy not in modelFields:
return QueryValidationError(
code=ValidationErrorCode.FIELD_NOT_FOUND,
field=groupBy,
suggestion=_suggestFieldName(groupBy, modelFields),
hint="groupBy must be a real field of this table.",
)
# filters validation matches queryTable so the LLM gets consistent
# repair hints regardless of which tool it picked.
for f in args.get("filters") or []:
filterErr = self._validateFilter(f, modelFields)
if filterErr is not None:
return filterErr
return None
# ------------------------------------------------------------------
# internals
# ------------------------------------------------------------------
def _validateFieldList(
self, fields: Optional[List[str]], modelFields: Dict[str, Any]
) -> Optional[QueryValidationError]:
if not fields:
return None
for f in fields:
if not isinstance(f, str):
continue
if f not in modelFields:
return QueryValidationError(
code=ValidationErrorCode.FIELD_NOT_FOUND,
field=f,
suggestion=_suggestFieldName(f, modelFields),
hint="Use browseTable to inspect this table's columns.",
)
return None
def _validateFilter(
self, filterEntry: Any, modelFields: Dict[str, Any]
) -> Optional[QueryValidationError]:
if not isinstance(filterEntry, dict):
return None
field = filterEntry.get("field")
op = (filterEntry.get("op") or "=").upper()
if not isinstance(field, str) or not field:
return None # tool wrapper passes these straight through
if field not in modelFields:
return QueryValidationError(
code=ValidationErrorCode.FIELD_NOT_FOUND,
field=field,
suggestion=_suggestFieldName(field, modelFields),
hint="Use browseTable to inspect this table's columns.",
)
annotation = modelFields[field]
if op in _STRING_ONLY_OPERATORS and not _isStringAnnotation(annotation):
return QueryValidationError(
code=ValidationErrorCode.OPERATOR_INCOMPATIBLE,
field=field,
suggestion=None,
hint=f"{op} only works on string fields; {field} is not a string.",
)
if op in _COMPARISON_OPERATORS and not _isComparableAnnotation(annotation):
return QueryValidationError(
code=ValidationErrorCode.OPERATOR_INCOMPATIBLE,
field=field,
suggestion=None,
hint=f"{op} requires a numeric or date field; {field} is not comparable.",
)
return None
def _isAggregateBlacklisted(self, tableName: str, fieldName: str) -> bool:
"""Check whether a field is marked NEVER_AGGREGATE.
Phase 2 (ontology present): consult the descriptor.
Phase 1 fallback: naming convention (``*Balance`` / ``*Total``).
"""
if self._ontology is not None:
target = f"{tableName}.{fieldName}"
for c in self._ontology.constraintsForTable(tableName):
if c.rule == ConstraintRule.NEVER_AGGREGATE and c.appliesTo == target:
return True
for suffix in _AGGREGATE_BLACKLIST_SUFFIXES_DEFAULT:
if fieldName.endswith(suffix):
return True
return False
# ------------------------------------------------------------------
# helpers
# ------------------------------------------------------------------
def _getModelFields(tableName: str) -> Optional[Dict[str, Any]]:
"""Return ``{fieldName: annotation}`` for a registered Pydantic table model.
None when the table is not in MODEL_REGISTRY (e.g. pure UDB tables in
early-startup contexts). The validator is a best-effort layer -- when
the schema is unknown we let the request through and rely on the
downstream SQL layer for safety.
"""
modelClass = MODEL_REGISTRY.get(tableName)
if modelClass is None:
return None
return {
name: info.annotation for name, info in modelClass.model_fields.items()
}
def _suggestFieldName(badName: str, modelFields: Dict[str, Any]) -> Optional[str]:
"""Return the closest valid field name, or None if nothing reasonable."""
if not badName or not modelFields:
return None
matches = difflib.get_close_matches(badName, list(modelFields.keys()), n=1, cutoff=0.6)
return matches[0] if matches else None
def _isPlainNone(value: Any) -> bool:
"""LLMs sometimes pass the literal string 'None' -- treat both as None."""
return value is None or (isinstance(value, str) and value.strip().lower() == "none")
def _unwrapAnnotation(annotation: Any) -> Tuple[Any, ...]:
"""Flatten Optional/Union annotations into their constituent types."""
origin = typing.get_origin(annotation)
if origin is None:
return (annotation,)
return tuple(a for a in typing.get_args(annotation) if a is not type(None))
def _isStringAnnotation(annotation: Any) -> bool:
return any(a is str for a in _unwrapAnnotation(annotation))
def _isNumericAnnotation(annotation: Any) -> bool:
numericTypes = (int, float)
return any(a in numericTypes for a in _unwrapAnnotation(annotation))
def _isComparableAnnotation(annotation: Any) -> bool:
"""Numeric types are the comparable shape we see in feature tables.
Booleans count as int in Python's type hierarchy but the comparison
operators ``>``/``<`` on bool columns are almost never meaningful, so we
treat bool as non-comparable for validator purposes.
"""
for a in _unwrapAnnotation(annotation):
if a is bool:
continue
if a in (int, float):
return True
return False

View file

@ -98,14 +98,17 @@ class _VirtualFS:
def _makeReadFile(services):
"""Create a readFile(fileId) closure bound to the current services context."""
def readFile(fileId: str) -> str:
def readFile(fileId: str, encoding: str = "utf-8") -> str:
mgmt = getattr(services, 'interfaceDbComponent', None) if services else None
if not mgmt:
raise RuntimeError("readFile: no file store available in this session")
data = mgmt.getFileData(str(fileId))
if data is None:
raise FileNotFoundError(f"File '{fileId}' not found in workspace")
return data.decode("utf-8")
try:
return data.decode(encoding)
except (UnicodeDecodeError, LookupError):
return data.decode("utf-8", errors="replace")
return readFile

View file

@ -60,6 +60,7 @@ from modules.shared.jsonContinuation import getContexts
from modules.shared.jsonUtils import buildContinuationContext, tryParseJson
from modules.shared.jsonUtils import closeJsonStructures
from modules.shared.jsonUtils import stripCodeFences, normalizeJsonText
from modules.shared.jsonUtils import extractJsonString, repairBrokenJson
logger = logging.getLogger(__name__)
@ -447,7 +448,6 @@ class AiCallLooper:
extracted = extractJsonString(contexts.completePart)
parsed, parseErr, _ = tryParseJson(extracted)
if parseErr is not None:
from modules.shared.jsonUtils import repairBrokenJson
repaired = repairBrokenJson(extracted)
if repaired:
parsed = repaired
@ -470,9 +470,10 @@ class AiCallLooper:
return useCase.finalResultHandler(
result, normalized, extracted, debugPrefix, self.services
)
except Exception as e:
except (json.JSONDecodeError, KeyError, TypeError) as e:
logger.warning(
f"Iteration {iteration}: completePart not serializable after getContexts success: {e}"
f"Iteration {iteration}: completePart not serializable after getContexts success: "
f"{type(e).__name__}: {e}"
)
mergeFailCount += 1
if mergeFailCount >= MAX_MERGE_FAILS:
@ -491,6 +492,15 @@ class AiCallLooper:
)
self.services.chat.progressLogFinish(iterationOperationId, True)
continue
except Exception as e:
logger.error(
f"Iteration {iteration}: unexpected error during completePart processing "
f"(re-raising, NOT a pipeline-mismatch retry): {type(e).__name__}: {e}",
exc_info=True,
)
if iterationOperationId:
self.services.chat.progressLogFinish(iterationOperationId, False)
raise
elif contexts.jsonParsingSuccess and contexts.overlapContext != "":
# JSON parseable but has cut point - CONTINUE to next iteration

View file

@ -34,7 +34,7 @@ import time
from datetime import datetime, timezone
from typing import Any, Awaitable, Callable, Dict, List, Optional
from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.connectors.connectorDbPostgre import DatabaseConnector, getCachedConnector
from modules.shared.configuration import APP_CONFIG
from modules.shared.dbRegistry import registerDatabase
from modules.datamodels.datamodelBackgroundJob import (
@ -104,7 +104,13 @@ def registerJobHandler(jobType: str, handler: JobHandler) -> None:
def _getDb() -> DatabaseConnector:
return DatabaseConnector(
"""Return the shared cached connector for the jobs DB.
Reuses the same connector across all job CRUD calls instead of opening a
fresh psycopg2 connection (and re-running `_create_database_if_not_exists`
+ `_create_tables` + `_initializeSystemTable`) on every operation.
"""
return getCachedConnector(
dbDatabase=JOBS_DATABASE,
dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
dbPort=int(APP_CONFIG.get("DB_PORT", "5432")),
@ -290,12 +296,12 @@ def cancelJobsByConnection(connectionId: str, *, jobType: str = "connection.boot
def recoverInterruptedJobs() -> int:
"""Flip any RUNNING jobs to ERROR and re-queue bootstrap jobs (called at worker boot).
"""Flip any RUNNING jobs to ERROR (called at worker boot).
A RUNNING job in the DB after process restart means the previous worker
died mid-execution; the asyncio task is gone and the job will never
finish on its own. For connection.bootstrap jobs, a fresh job is
automatically re-queued so the user doesn't have to manually retry.
finish on its own. The daily scheduler or manual "Neu indexieren"
button handles retry no automatic re-queue to avoid infinite loops.
"""
db = _getDb()
try:
@ -304,34 +310,70 @@ def recoverInterruptedJobs() -> int:
logger.warning("recoverInterruptedJobs: failed to scan RUNNING jobs: %s", ex)
return 0
count = 0
requeued = 0
for row in rows:
try:
_markError(row["id"], "Interrupted by worker restart")
count += 1
except Exception as ex:
logger.warning("recoverInterruptedJobs: could not mark %s as ERROR: %s", row.get("id"), ex)
continue
if row.get("jobType") == "connection.bootstrap":
payload = row.get("payload") or {}
if payload.get("connectionId"):
try:
newJob = BackgroundJob(
jobType="connection.bootstrap",
payload=payload,
triggeredBy="recovery.requeue",
)
record = db.recordCreate(BackgroundJob, _serialiseDatetimes(newJob.model_dump()))
asyncio.create_task(_runJob(record["id"]))
requeued += 1
logger.info(
"recoverInterruptedJobs: re-queued bootstrap for connectionId=%s (new jobId=%s)",
payload["connectionId"], record["id"],
)
except Exception as reqEx:
logger.warning("recoverInterruptedJobs: re-queue failed for %s: %s", row.get("id"), reqEx)
if count:
logger.warning("Recovered %d interrupted background job(s) after restart (re-queued %d)", count, requeued)
logger.warning("Recovered %d interrupted background job(s) after restart", count)
return count
_ZOMBIE_MAX_AGE_SECONDS = 30 * 60
def killZombieJobs(maxAgeSeconds: int = _ZOMBIE_MAX_AGE_SECONDS) -> int:
"""Kill RUNNING jobs that have not been updated within `maxAgeSeconds`.
Detects walkers that are stuck in a sync call without progress updates.
A live job updates progress at least every few seconds via JobProgressCallback.
Anything older than maxAgeSeconds without finishing is considered hung.
"""
db = _getDb()
try:
rows = db.getRecordset(BackgroundJob, recordFilter={"status": BackgroundJobStatusEnum.RUNNING.value})
except Exception as ex:
logger.warning("killZombieJobs: failed to scan RUNNING jobs: %s", ex)
return 0
now = time.time()
threshold = now - maxAgeSeconds
count = 0
for row in rows:
started = row.get("startedAt") or row.get("createdAt")
if not started or started > threshold:
continue
ageMin = (now - started) / 60
try:
_markError(row["id"], f"Zombie killed (stuck >{maxAgeSeconds // 60}min, no progress)")
count += 1
payload = row.get("payload") or {}
logger.warning(
"killZombieJobs: killed %s (type=%s connId=%s ageMin=%.1f)",
row["id"], row.get("jobType"), payload.get("connectionId", "")[:12], ageMin,
)
except Exception as ex:
logger.warning("killZombieJobs: could not kill %s: %s", row.get("id"), ex)
return count
def registerZombieKillerScheduler(*, intervalMinutes: int = 5) -> None:
"""Register a recurring cron job that kills stuck RUNNING jobs.
Idempotent. Runs every `intervalMinutes` minutes.
"""
try:
from modules.shared.eventManagement import eventManager
async def _runKiller():
killZombieJobs()
eventManager.registerCron(
jobId="background_jobs.zombie_killer",
func=_runKiller,
cronKwargs={"minute": f"*/{intervalMinutes}"},
)
logger.info("Zombie-killer scheduler registered (every %d min)", intervalMinutes)
except Exception as ex:
logger.warning("Zombie-killer scheduler registration failed (non-critical): %s", ex)

View file

@ -532,8 +532,16 @@ class ChatService:
self, connectionId: str, sourceType: str, path: str, label: str,
featureInstanceId: str = None, displayPath: str = None,
) -> Dict[str, Any]:
"""Create a new external data source reference."""
"""Create a new external data source reference.
Returns existing record if connectionId + path already exists (upsert semantics).
"""
from modules.datamodels.datamodelDataSource import DataSource
existing = self.interfaceDbApp.db.getRecordset(
DataSource, recordFilter={"connectionId": connectionId, "path": path}
)
if existing:
return existing[0] if isinstance(existing[0], dict) else existing[0].model_dump()
ds = DataSource(
connectionId=connectionId,
sourceType=sourceType,

View file

@ -132,10 +132,10 @@ _SOURCE_TYPE_MAP = {
"gmail": ("gmailFolder",),
},
"clickup": {
"clickup": ("clickupList",),
"clickup": ("clickupList", "clickup"),
},
"infomaniak": {
"kdrive": ("kdriveFolder",),
"kdrive": ("kdriveFolder", "infomaniak"),
},
}
@ -225,7 +225,7 @@ async def _bootstrapJobHandler(
bootstrapOutlook,
)
progressCb(10, "sharepoint + outlook")
progressCb(0, "Synchronisierung läuft...")
spDs = _filterDs("sharepoint")
olDs = _filterDs("outlook")
async def _noopResult():
@ -251,7 +251,7 @@ async def _bootstrapJobHandler(
bootstrapGmail,
)
progressCb(10, "drive + gmail")
progressCb(0, "Synchronisierung läuft...")
gdDs = _filterDs("drive")
gmDs = _filterDs("gmail")
async def _noopResult():
@ -274,7 +274,7 @@ async def _bootstrapJobHandler(
bootstrapClickup,
)
progressCb(10, "clickup tasks")
progressCb(0, "Synchronisierung läuft...")
cuDs = _filterDs("clickup")
cuResult = await bootstrapClickup(connectionId=connectionId, progressCb=progressCb, dataSources=cuDs) if cuDs else {"skipped": True, "reason": "no_datasources"}
return {
@ -283,6 +283,20 @@ async def _bootstrapJobHandler(
"clickup": _normalize(cuResult, "clickup"),
}
if authority == "infomaniak":
from modules.serviceCenter.services.serviceKnowledge.subConnectorSyncKdrive import (
bootstrapKdrive,
)
progressCb(0, "Synchronisierung läuft...")
kdDs = _filterDs("kdrive")
kdResult = await bootstrapKdrive(connectionId=connectionId, progressCb=progressCb, dataSources=kdDs) if kdDs else {"skipped": True, "reason": "no_datasources"}
return {
"connectionId": connectionId,
"authority": authority,
"kdrive": _normalize(kdResult, "kdrive"),
}
logger.info(
"ingestion.connection.bootstrap.skipped reason=unsupported_authority authority=%s connectionId=%s",
authority, connectionId,

View file

@ -25,6 +25,12 @@ from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional
from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import (
WalkerTimeout,
ingestWithTimeout,
logItemStart,
)
logger = logging.getLogger(__name__)
MAX_TASKS_DEFAULT = 500
@ -449,36 +455,44 @@ async def _ingestTask(
name = task.get("name") or f"Task {taskId}"
syntheticId = _syntheticTaskId(connectionId, taskId)
fileName = f"{name[:80].strip() or taskId}.task.json"
logItemStart("clickup", f"{teamId}/{taskId}")
contentObjects = _buildContentObjects(task, limits)
try:
handle = await knowledgeService.requestIngestion(
IngestionJob(
sourceKind="clickup_task",
sourceId=syntheticId,
fileName=fileName,
mimeType="application/vnd.clickup.task+json",
userId=userId,
mandateId=mandateId,
contentObjects=contentObjects,
contentVersion=revision or None,
neutralize=limits.neutralize,
provenance={
"connectionId": connectionId,
"dataSourceId": dataSourceId,
"authority": "clickup",
"service": "clickup",
"externalItemId": taskId,
"teamId": teamId,
"listId": ((task.get("list") or {}).get("id")),
"spaceId": ((task.get("space") or {}).get("id")),
"url": task.get("url"),
"status": ((task.get("status") or {}).get("status")),
"tier": limits.clickupScope,
},
)
handle = await ingestWithTimeout(
knowledgeService.requestIngestion(
IngestionJob(
sourceKind="clickup_task",
sourceId=syntheticId,
fileName=fileName,
mimeType="application/vnd.clickup.task+json",
userId=userId,
mandateId=mandateId,
contentObjects=contentObjects,
contentVersion=revision or None,
neutralize=limits.neutralize,
provenance={
"connectionId": connectionId,
"dataSourceId": dataSourceId,
"authority": "clickup",
"service": "clickup",
"externalItemId": taskId,
"teamId": teamId,
"listId": ((task.get("list") or {}).get("id")),
"spaceId": ((task.get("space") or {}).get("id")),
"url": task.get("url"),
"status": ((task.get("status") or {}).get("status")),
"tier": limits.clickupScope,
},
)
),
label=taskId,
)
except WalkerTimeout as exc:
result.failed += 1
result.errors.append(str(exc))
return
except Exception as exc:
logger.error("clickup ingestion %s failed: %s", taskId, exc, exc_info=True)
result.failed += 1
@ -493,18 +507,16 @@ async def _ingestTask(
result.failed += 1
processed = result.indexed + result.skippedDuplicate
if progressCb is not None and processed % 50 == 0:
if progressCb is not None and processed % 5 == 0:
if hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
return
try:
progressCb(
min(90, 10 + int(80 * processed / max(1, limits.maxTasks))),
f"clickup processed={processed}",
)
progressCb(0, f"{processed} Tasks verarbeitet, {result.indexed} indexiert")
except Exception:
pass
logger.info(
"ingestion.connection.bootstrap.progress part=clickup processed=%d skippedDup=%d failed=%d",
if processed % 50 == 0:
logger.info(
"ingestion.connection.bootstrap.progress part=clickup processed=%d skippedDup=%d failed=%d",
processed, result.skippedDuplicate, result.failed,
extra={
"event": "ingestion.connection.bootstrap.progress",

View file

@ -21,6 +21,13 @@ from datetime import datetime, timedelta, timezone
from typing import Any, Callable, Dict, List, Optional
from modules.datamodels.datamodelExtraction import ExtractionOptions
from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import (
WalkerTimeout,
downloadWithTimeout,
extractWithTimeout,
ingestWithTimeout,
logItemStart,
)
logger = logging.getLogger(__name__)
@ -342,9 +349,15 @@ async def _ingestOne(
syntheticFileId = _syntheticFileId(connectionId, externalItemId)
fileName = getattr(entry, "name", "") or externalItemId
declaredSize = int(getattr(entry, "size", 0) or 0) or None
logItemStart("gdrive", entryPath, sizeBytes=declaredSize, mime=mimeType)
try:
downloaded = await adapter.download(entryPath)
downloaded = await downloadWithTimeout(adapter.download(entryPath), label=entryPath)
except WalkerTimeout as exc:
result.failed += 1
result.errors.append(str(exc))
return
except Exception as exc:
logger.warning("gdrive download %s failed: %s", entryPath, exc)
result.failed += 1
@ -368,10 +381,16 @@ async def _ingestOne(
result.bytesProcessed += len(fileBytes)
try:
extracted = runExtractionFn(
extracted = await extractWithTimeout(
runExtractionFn,
fileBytes, fileName, mimeType,
ExtractionOptions(mergeStrategy=None),
label=entryPath,
)
except WalkerTimeout as exc:
result.failed += 1
result.errors.append(str(exc))
return
except Exception as exc:
logger.warning("gdrive extraction %s failed: %s", entryPath, exc)
result.failed += 1
@ -393,20 +412,27 @@ async def _ingestOne(
"tier": "body",
}
try:
handle = await knowledgeService.requestIngestion(
IngestionJob(
sourceKind="gdrive_item",
sourceId=syntheticFileId,
fileName=fileName,
mimeType=mimeType,
userId=userId,
mandateId=mandateId,
contentObjects=contentObjects,
contentVersion=revision,
neutralize=limits.neutralize,
provenance=provenance,
)
handle = await ingestWithTimeout(
knowledgeService.requestIngestion(
IngestionJob(
sourceKind="gdrive_item",
sourceId=syntheticFileId,
fileName=fileName,
mimeType=mimeType,
userId=userId,
mandateId=mandateId,
contentObjects=contentObjects,
contentVersion=revision,
neutralize=limits.neutralize,
provenance=provenance,
)
),
label=entryPath,
)
except WalkerTimeout as exc:
result.failed += 1
result.errors.append(str(exc))
return
except Exception as exc:
logger.error("gdrive ingestion %s failed: %s", entryPath, exc, exc_info=True)
result.failed += 1
@ -422,13 +448,10 @@ async def _ingestOne(
if handle.error:
result.errors.append(f"ingest({entryPath}): {handle.error}")
if progressCb is not None and (result.indexed + result.skippedDuplicate) % 50 == 0:
processed = result.indexed + result.skippedDuplicate
processed = result.indexed + result.skippedDuplicate
if progressCb is not None and processed % 5 == 0:
try:
progressCb(
min(90, 10 + int(80 * processed / max(1, limits.maxItems))),
f"gdrive processed={processed}",
)
progressCb(0, f"{processed} Dateien verarbeitet, {result.indexed} indexiert")
except Exception:
pass
logger.info(

View file

@ -24,6 +24,11 @@ from datetime import datetime, timedelta, timezone
from typing import Any, Callable, Dict, List, Optional
from modules.serviceCenter.services.serviceKnowledge.subTextClean import cleanEmailBody
from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import (
WalkerTimeout,
ingestWithTimeout,
logItemStart,
)
logger = logging.getLogger(__name__)
@ -399,34 +404,42 @@ async def _ingestMessage(
subject = headers.get("subject") or "(no subject)"
syntheticId = _syntheticMessageId(connectionId, messageId)
fileName = f"{subject[:80].strip()}.eml" if subject else f"{messageId}.eml"
logItemStart("gmail", f"{labelId}/{messageId}", mime="message/rfc822")
contentObjects = _buildContentObjects(
message, limits.maxBodyChars, mailContentDepth=limits.mailContentDepth
)
try:
handle = await knowledgeService.requestIngestion(
IngestionJob(
sourceKind="gmail_message",
sourceId=syntheticId,
fileName=fileName,
mimeType="message/rfc822",
userId=userId,
mandateId=mandateId,
contentObjects=contentObjects,
contentVersion=str(revision) if revision else None,
neutralize=limits.neutralize,
provenance={
"connectionId": connectionId,
"dataSourceId": dataSourceId,
"authority": "google",
"service": "gmail",
"externalItemId": messageId,
"label": labelId,
"threadId": message.get("threadId"),
"tier": limits.mailContentDepth,
},
)
handle = await ingestWithTimeout(
knowledgeService.requestIngestion(
IngestionJob(
sourceKind="gmail_message",
sourceId=syntheticId,
fileName=fileName,
mimeType="message/rfc822",
userId=userId,
mandateId=mandateId,
contentObjects=contentObjects,
contentVersion=str(revision) if revision else None,
neutralize=limits.neutralize,
provenance={
"connectionId": connectionId,
"dataSourceId": dataSourceId,
"authority": "google",
"service": "gmail",
"externalItemId": messageId,
"label": labelId,
"threadId": message.get("threadId"),
"tier": limits.mailContentDepth,
},
)
),
label=messageId,
)
except WalkerTimeout as exc:
result.failed += 1
result.errors.append(str(exc))
return
except Exception as exc:
logger.error("gmail ingestion %s failed: %s", messageId, exc, exc_info=True)
result.failed += 1
@ -458,18 +471,16 @@ async def _ingestMessage(
logger.warning("gmail attachments %s failed: %s", messageId, exc)
result.errors.append(f"attachments({messageId}): {exc}")
if progressCb is not None and (result.indexed + result.skippedDuplicate) % 50 == 0:
processed = result.indexed + result.skippedDuplicate
processed = result.indexed + result.skippedDuplicate
if progressCb is not None and processed % 5 == 0:
try:
progressCb(
min(90, 10 + int(80 * processed / max(1, limits.maxMessages))),
f"gmail processed={processed}",
)
progressCb(0, f"{processed} Mails verarbeitet, {result.indexed} indexiert")
except Exception:
pass
logger.info(
"ingestion.connection.bootstrap.progress part=gmail processed=%d skippedDup=%d failed=%d",
processed, result.skippedDuplicate, result.failed,
if processed % 50 == 0:
logger.info(
"ingestion.connection.bootstrap.progress part=gmail processed=%d skippedDup=%d failed=%d",
processed, result.skippedDuplicate, result.failed,
extra={
"event": "ingestion.connection.bootstrap.progress",
"part": "gmail",
@ -546,13 +557,26 @@ async def _ingestAttachments(
fileName = stub["filename"]
mimeType = stub["mimeType"]
syntheticId = _syntheticAttachmentId(connectionId, messageId, stub["attachmentId"])
attLabel = f"{messageId}/att:{stub['attachmentId']}/{fileName}"
logItemStart("gmail-attachment", attLabel, sizeBytes=stub.get("size") or None, mime=mimeType)
try:
extracted = runExtraction(
from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import (
extractWithTimeout as _extractWithTimeout,
)
def _runAttExtraction():
return runExtraction(
extractorRegistry, chunkerRegistry,
rawBytes, fileName, mimeType,
ExtractionOptions(mergeStrategy=None),
)
try:
extracted = await _extractWithTimeout(_runAttExtraction, label=attLabel)
except WalkerTimeout as exc:
result.failed += 1
result.errors.append(str(exc))
continue
except Exception as exc:
logger.warning("gmail attachment extract %s failed: %s", stub["attachmentId"], exc)
result.failed += 1
@ -584,27 +608,33 @@ async def _ingestAttachments(
continue
try:
await knowledgeService.requestIngestion(
IngestionJob(
sourceKind="gmail_attachment",
sourceId=syntheticId,
fileName=fileName,
mimeType=mimeType,
userId=userId,
mandateId=mandateId,
contentObjects=contentObjects,
provenance={
"connectionId": connectionId,
"dataSourceId": dataSourceId,
"authority": "google",
"service": "gmail",
"parentId": parentSyntheticId,
"externalItemId": stub["attachmentId"],
"parentMessageId": messageId,
},
)
await ingestWithTimeout(
knowledgeService.requestIngestion(
IngestionJob(
sourceKind="gmail_attachment",
sourceId=syntheticId,
fileName=fileName,
mimeType=mimeType,
userId=userId,
mandateId=mandateId,
contentObjects=contentObjects,
provenance={
"connectionId": connectionId,
"dataSourceId": dataSourceId,
"authority": "google",
"service": "gmail",
"parentId": parentSyntheticId,
"externalItemId": stub["attachmentId"],
"parentMessageId": messageId,
},
)
),
label=attLabel,
)
result.attachmentsIndexed += 1
except WalkerTimeout as exc:
result.failed += 1
result.errors.append(str(exc))
except Exception as exc:
logger.warning("gmail attachment ingest %s failed: %s", stub["attachmentId"], exc)
result.failed += 1

View file

@ -0,0 +1,439 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""kDrive bootstrap for the unified knowledge ingestion lane.
Walks every ragIndexEnabled kDrive DataSource, downloads file items and
hands them to KnowledgeService.requestIngestion. Idempotency is provided
by the ingestion facade (content-hash dedup).
"""
from __future__ import annotations
import asyncio
import hashlib
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional
from modules.datamodels.datamodelExtraction import ExtractionOptions
from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import (
WalkerTimeout,
downloadWithTimeout,
extractWithTimeout,
ingestWithTimeout,
logItemStart,
)
logger = logging.getLogger(__name__)
MAX_ITEMS_DEFAULT = 500
MAX_BYTES_DEFAULT = 200 * 1024 * 1024
MAX_FILE_SIZE_DEFAULT = 25 * 1024 * 1024
SKIP_MIME_PREFIXES_DEFAULT = ("video/", "audio/")
MAX_DEPTH_DEFAULT = 4
@dataclass
class KdriveBootstrapLimits:
maxItems: int = MAX_ITEMS_DEFAULT
maxBytes: int = MAX_BYTES_DEFAULT
maxFileSize: int = MAX_FILE_SIZE_DEFAULT
skipMimePrefixes: tuple = SKIP_MIME_PREFIXES_DEFAULT
maxDepth: int = MAX_DEPTH_DEFAULT
neutralize: bool = False
@dataclass
class KdriveBootstrapResult:
connectionId: str
indexed: int = 0
skippedDuplicate: int = 0
skippedPolicy: int = 0
failed: int = 0
bytesProcessed: int = 0
errors: List[str] = field(default_factory=list)
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
token = hashlib.sha256(f"{connectionId}:{externalItemId}".encode("utf-8")).hexdigest()[:16]
return f"kd:{connectionId[:8]}:{token}"
def _toContentObjects(extracted, fileName: str) -> List[Dict[str, Any]]:
parts = getattr(extracted, "parts", None) or []
out: List[Dict[str, Any]] = []
for part in parts:
data = getattr(part, "data", None) or ""
if not data or not str(data).strip():
continue
typeGroup = getattr(part, "typeGroup", "text") or "text"
contentType = "text"
if typeGroup == "image":
contentType = "image"
elif typeGroup in ("binary", "container"):
contentType = "other"
out.append({
"contentObjectId": getattr(part, "id", ""),
"contentType": contentType,
"data": data,
"contextRef": {
"containerPath": fileName,
"location": getattr(part, "label", None) or "file",
**(getattr(part, "metadata", None) or {}),
},
})
return out
async def bootstrapKdrive(
connectionId: str,
*,
dataSources: Optional[List[Dict[str, Any]]] = None,
progressCb: Optional[Any] = None,
adapter: Any = None,
connection: Any = None,
knowledgeService: Any = None,
limits: Optional[KdriveBootstrapLimits] = None,
runExtractionFn: Optional[Callable[..., Any]] = None,
) -> Dict[str, Any]:
"""Enumerate kDrive folders and ingest files via the facade."""
if not dataSources:
return {"connectionId": connectionId, "skipped": True, "reason": "no_datasources"}
if not limits:
limits = KdriveBootstrapLimits()
startMs = time.time()
result = KdriveBootstrapResult(connectionId=connectionId)
logger.info(
"ingestion.connection.bootstrap.started part=kdrive connectionId=%s dataSources=%d",
connectionId, len(dataSources),
extra={"event": "ingestion.connection.bootstrap.started", "part": "kdrive",
"connectionId": connectionId, "dataSourceCount": len(dataSources)},
)
if adapter is None or knowledgeService is None or connection is None:
adapter, connection, knowledgeService = await _resolveDependencies(connectionId)
if runExtractionFn is None:
from modules.serviceCenter.services.serviceExtraction.subPipeline import runExtraction
from modules.serviceCenter.services.serviceExtraction.subRegistry import (
ExtractorRegistry, ChunkerRegistry,
)
extractorRegistry = ExtractorRegistry()
chunkerRegistry = ChunkerRegistry()
def runExtractionFn(bytesData, name, mime, options):
return runExtraction(extractorRegistry, chunkerRegistry, bytesData, name, mime, options)
mandateId = str(getattr(connection, "mandateId", "") or "") if connection is not None else ""
userId = str(getattr(connection, "userId", "") or "") if connection is not None else ""
cancelled = False
for ds in dataSources:
if result.indexed + result.skippedDuplicate >= limits.maxItems:
break
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
cancelled = True
break
dsPath = ds.get("path", "")
dsId = ds.get("id", "")
dsNeutralize = ds.get("neutralize", False)
dsLimits = KdriveBootstrapLimits(
maxItems=limits.maxItems,
maxBytes=limits.maxBytes,
maxFileSize=limits.maxFileSize,
skipMimePrefixes=limits.skipMimePrefixes,
maxDepth=limits.maxDepth,
neutralize=dsNeutralize,
)
try:
await _walkFolder(
adapter=adapter,
knowledgeService=knowledgeService,
runExtractionFn=runExtractionFn,
connectionId=connectionId,
mandateId=mandateId,
userId=userId,
folderPath=dsPath,
depth=0,
limits=dsLimits,
result=result,
progressCb=progressCb,
dataSourceId=dsId,
)
except Exception as exc:
logger.error("kdrive walk failed for ds %s path %s: %s", dsId, dsPath, exc, exc_info=True)
result.errors.append(f"walk({dsPath}): {exc}")
finalResult = _finalizeResult(connectionId, result, startMs)
if cancelled:
finalResult["cancelled"] = True
return finalResult
async def _resolveDependencies(connectionId: str):
from modules.interfaces.interfaceDbApp import getRootInterface
from modules.auth import TokenManager
from modules.connectors.providerInfomaniak.connectorInfomaniak import InfomaniakConnector
from modules.serviceCenter import getService
from modules.serviceCenter.context import ServiceCenterContext
from modules.security.rootAccess import getRootUser
rootInterface = getRootInterface()
connection = rootInterface.getUserConnectionById(connectionId)
if connection is None:
raise ValueError(f"UserConnection not found: {connectionId}")
token = TokenManager().getFreshToken(connectionId)
if not token or not token.tokenAccess:
raise ValueError(f"No valid token for connection {connectionId}")
provider = InfomaniakConnector(connection, token.tokenAccess)
adapter = provider.getServiceAdapter("kdrive")
rootUser = getRootUser()
ctx = ServiceCenterContext(
user=rootUser,
mandate_id=str(getattr(connection, "mandateId", "") or ""),
)
knowledgeService = getService("knowledge", ctx)
return adapter, connection, knowledgeService
async def _walkFolder(
*,
adapter,
knowledgeService,
runExtractionFn,
connectionId: str,
mandateId: str,
userId: str,
folderPath: str,
depth: int,
limits: KdriveBootstrapLimits,
result: KdriveBootstrapResult,
progressCb: Optional[Any],
dataSourceId: str = "",
) -> None:
if depth > limits.maxDepth:
return
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
return
try:
entries = await adapter.browse(folderPath)
except Exception as exc:
logger.warning("kdrive browse %s failed: %s", folderPath, exc)
result.errors.append(f"browse({folderPath}): {exc}")
return
for entry in entries:
if result.indexed + result.skippedDuplicate >= limits.maxItems:
return
if result.bytesProcessed >= limits.maxBytes:
return
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
return
entryPath = getattr(entry, "path", "") or ""
if getattr(entry, "isFolder", False):
await _walkFolder(
adapter=adapter,
knowledgeService=knowledgeService,
runExtractionFn=runExtractionFn,
connectionId=connectionId,
mandateId=mandateId,
userId=userId,
folderPath=entryPath,
depth=depth + 1,
limits=limits,
result=result,
progressCb=progressCb,
dataSourceId=dataSourceId,
)
continue
mimeType = getattr(entry, "mimeType", None) or "application/octet-stream"
if any(mimeType.startswith(prefix) for prefix in limits.skipMimePrefixes):
result.skippedPolicy += 1
continue
size = int(getattr(entry, "size", 0) or 0)
if size and size > limits.maxFileSize:
result.skippedPolicy += 1
continue
metadata = getattr(entry, "metadata", {}) or {}
externalItemId = metadata.get("id") or entryPath
revision = metadata.get("revision") or metadata.get("lastModified")
await _ingestOne(
adapter=adapter,
knowledgeService=knowledgeService,
runExtractionFn=runExtractionFn,
connectionId=connectionId,
mandateId=mandateId,
userId=userId,
entry=entry,
entryPath=entryPath,
mimeType=mimeType,
externalItemId=externalItemId,
revision=revision,
limits=limits,
result=result,
progressCb=progressCb,
dataSourceId=dataSourceId,
)
async def _ingestOne(
*,
adapter,
knowledgeService,
runExtractionFn,
connectionId: str,
mandateId: str,
userId: str,
entry,
entryPath: str,
mimeType: str,
externalItemId: str,
revision: Optional[str],
limits: KdriveBootstrapLimits,
result: KdriveBootstrapResult,
progressCb: Optional[Any],
dataSourceId: str = "",
) -> None:
from modules.serviceCenter.services.serviceKnowledge.mainServiceKnowledge import IngestionJob
syntheticFileId = _syntheticFileId(connectionId, externalItemId)
fileName = getattr(entry, "name", "") or externalItemId
declaredSize = int(getattr(entry, "size", 0) or 0) or None
logItemStart("kdrive", entryPath, sizeBytes=declaredSize, mime=mimeType)
try:
downloadResult = await downloadWithTimeout(adapter.download(entryPath), label=entryPath)
fileBytes = getattr(downloadResult, "data", None)
dlFileName = getattr(downloadResult, "fileName", None)
dlMimeType = getattr(downloadResult, "mimeType", None)
if dlFileName:
fileName = dlFileName
if dlMimeType:
mimeType = dlMimeType
except WalkerTimeout as exc:
result.failed += 1
result.errors.append(str(exc))
return
except Exception as exc:
logger.warning("kdrive download %s failed: %s", entryPath, exc)
result.failed += 1
result.errors.append(f"download({entryPath}): {exc}")
return
if not fileBytes:
result.failed += 1
return
result.bytesProcessed += len(fileBytes)
try:
extracted = await extractWithTimeout(
runExtractionFn,
fileBytes, fileName, mimeType,
ExtractionOptions(mergeStrategy=None),
label=entryPath,
)
except WalkerTimeout as exc:
result.failed += 1
result.errors.append(str(exc))
return
except Exception as exc:
logger.warning("kdrive extraction %s failed: %s", entryPath, exc)
result.failed += 1
result.errors.append(f"extract({entryPath}): {exc}")
return
contentObjects = _toContentObjects(extracted, fileName)
if not contentObjects:
result.skippedPolicy += 1
return
provenance: Dict[str, Any] = {
"connectionId": connectionId,
"dataSourceId": dataSourceId,
"authority": "infomaniak",
"service": "kdrive",
"externalItemId": externalItemId,
"externalPath": entryPath,
"revision": revision,
}
try:
handle = await ingestWithTimeout(
knowledgeService.requestIngestion(
IngestionJob(
sourceKind="kdrive_item",
sourceId=syntheticFileId,
fileName=fileName,
mimeType=mimeType,
userId=userId,
mandateId=mandateId,
contentObjects=contentObjects,
contentVersion=revision,
neutralize=limits.neutralize,
provenance=provenance,
)
),
label=entryPath,
)
except WalkerTimeout as exc:
result.failed += 1
result.errors.append(str(exc))
return
except Exception as exc:
logger.error("kdrive ingestion %s failed: %s", entryPath, exc, exc_info=True)
result.failed += 1
result.errors.append(f"ingest({entryPath}): {exc}")
return
if handle.status == "duplicate":
result.skippedDuplicate += 1
elif handle.status == "indexed":
result.indexed += 1
else:
result.failed += 1
if handle.error:
result.errors.append(f"ingest({entryPath}): {handle.error}")
processed = result.indexed + result.skippedDuplicate
if progressCb is not None and processed % 5 == 0:
try:
progressCb(0, f"{processed} Dateien verarbeitet, {result.indexed} indexiert")
except Exception:
pass
await asyncio.sleep(0)
def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: float) -> Dict[str, Any]:
durationMs = int((time.time() - startMs) * 1000)
logger.info(
"ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d",
connectionId,
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed,
durationMs,
extra={"event": "ingestion.connection.bootstrap.done", "part": "kdrive",
"connectionId": connectionId, "indexed": result.indexed,
"skippedDup": result.skippedDuplicate, "skippedPolicy": result.skippedPolicy,
"failed": result.failed, "durationMs": durationMs},
)
return {
"connectionId": result.connectionId,
"indexed": result.indexed,
"skippedDuplicate": result.skippedDuplicate,
"skippedPolicy": result.skippedPolicy,
"failed": result.failed,
"bytesProcessed": result.bytesProcessed,
"durationMs": durationMs,
"errors": result.errors[:20],
}

View file

@ -21,6 +21,12 @@ from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from modules.serviceCenter.services.serviceKnowledge.subTextClean import cleanEmailBody
from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import (
WalkerTimeout,
extractWithTimeout,
ingestWithTimeout,
logItemStart,
)
logger = logging.getLogger(__name__)
@ -384,34 +390,42 @@ async def _ingestMessage(
subject = message.get("subject") or "(no subject)"
syntheticId = _syntheticMessageId(connectionId, messageId)
fileName = f"{subject[:80].strip()}.eml" if subject else f"{messageId}.eml"
logItemStart("outlook", messageId, mime="message/rfc822")
contentObjects = _buildContentObjects(
message, limits.maxBodyChars, mailContentDepth=limits.mailContentDepth
)
# Always at least the header is emitted, so `contentObjects` is non-empty.
try:
handle = await knowledgeService.requestIngestion(
IngestionJob(
sourceKind="outlook_message",
sourceId=syntheticId,
fileName=fileName,
mimeType="message/rfc822",
userId=userId,
mandateId=mandateId,
contentObjects=contentObjects,
contentVersion=revision,
neutralize=limits.neutralize,
provenance={
"connectionId": connectionId,
"dataSourceId": dataSourceId,
"authority": "msft",
"service": "outlook",
"externalItemId": messageId,
"internetMessageId": message.get("internetMessageId"),
"tier": limits.mailContentDepth,
},
)
handle = await ingestWithTimeout(
knowledgeService.requestIngestion(
IngestionJob(
sourceKind="outlook_message",
sourceId=syntheticId,
fileName=fileName,
mimeType="message/rfc822",
userId=userId,
mandateId=mandateId,
contentObjects=contentObjects,
contentVersion=revision,
neutralize=limits.neutralize,
provenance={
"connectionId": connectionId,
"dataSourceId": dataSourceId,
"authority": "msft",
"service": "outlook",
"externalItemId": messageId,
"internetMessageId": message.get("internetMessageId"),
"tier": limits.mailContentDepth,
},
)
),
label=messageId,
)
except WalkerTimeout as exc:
result.failed += 1
result.errors.append(str(exc))
return
except Exception as exc:
logger.error("outlook ingestion %s failed: %s", messageId, exc, exc_info=True)
result.failed += 1
@ -443,18 +457,16 @@ async def _ingestMessage(
logger.warning("outlook attachments %s failed: %s", messageId, exc)
result.errors.append(f"attachments({messageId}): {exc}")
if progressCb is not None and (result.indexed + result.skippedDuplicate) % 50 == 0:
processed = result.indexed + result.skippedDuplicate
processed = result.indexed + result.skippedDuplicate
if progressCb is not None and processed % 5 == 0:
try:
progressCb(
min(90, 10 + int(80 * processed / max(1, limits.maxMessages))),
f"outlook processed={processed}",
)
progressCb(0, f"{processed} Mails verarbeitet, {result.indexed} indexiert")
except Exception:
pass
logger.info(
"ingestion.connection.bootstrap.progress part=outlook processed=%d skippedDup=%d failed=%d",
processed, result.skippedDuplicate, result.failed,
if processed % 50 == 0:
logger.info(
"ingestion.connection.bootstrap.progress part=outlook processed=%d skippedDup=%d failed=%d",
processed, result.skippedDuplicate, result.failed,
extra={
"event": "ingestion.connection.bootstrap.progress",
"part": "outlook",
@ -518,13 +530,22 @@ async def _ingestAttachments(
mimeType = attachment.get("contentType") or "application/octet-stream"
attachmentId = attachment.get("id") or fileName
syntheticId = _syntheticAttachmentId(connectionId, messageId, attachmentId)
attLabel = f"{messageId}/att:{attachmentId}/{fileName}"
logItemStart("outlook-attachment", attLabel, sizeBytes=size or None, mime=mimeType)
try:
extracted = runExtraction(
def _runAttExtraction():
return runExtraction(
extractorRegistry, chunkerRegistry,
rawBytes, fileName, mimeType,
ExtractionOptions(mergeStrategy=None),
)
try:
extracted = await extractWithTimeout(_runAttExtraction, label=attLabel)
except WalkerTimeout as exc:
result.failed += 1
result.errors.append(str(exc))
continue
except Exception as exc:
logger.warning("outlook attachment extract %s failed: %s", attachmentId, exc)
result.failed += 1
@ -556,28 +577,34 @@ async def _ingestAttachments(
continue
try:
await knowledgeService.requestIngestion(
IngestionJob(
sourceKind="outlook_attachment",
sourceId=syntheticId,
fileName=fileName,
mimeType=mimeType,
userId=userId,
mandateId=mandateId,
contentObjects=contentObjects,
neutralize=limits.neutralize,
provenance={
"connectionId": connectionId,
"dataSourceId": dataSourceId,
"authority": "msft",
"service": "outlook",
"parentId": parentSyntheticId,
"externalItemId": attachmentId,
"parentMessageId": messageId,
},
)
await ingestWithTimeout(
knowledgeService.requestIngestion(
IngestionJob(
sourceKind="outlook_attachment",
sourceId=syntheticId,
fileName=fileName,
mimeType=mimeType,
userId=userId,
mandateId=mandateId,
contentObjects=contentObjects,
neutralize=limits.neutralize,
provenance={
"connectionId": connectionId,
"dataSourceId": dataSourceId,
"authority": "msft",
"service": "outlook",
"parentId": parentSyntheticId,
"externalItemId": attachmentId,
"parentMessageId": messageId,
},
)
),
label=attLabel,
)
result.attachmentsIndexed += 1
except WalkerTimeout as exc:
result.failed += 1
result.errors.append(str(exc))
except Exception as exc:
logger.warning("outlook attachment ingest %s failed: %s", attachmentId, exc)
result.failed += 1

View file

@ -20,6 +20,13 @@ from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional
from modules.datamodels.datamodelExtraction import ExtractionOptions
from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import (
WalkerTimeout,
downloadWithTimeout,
extractWithTimeout,
ingestWithTimeout,
logItemStart,
)
logger = logging.getLogger(__name__)
@ -330,9 +337,15 @@ async def _ingestOne(
syntheticFileId = _syntheticFileId(connectionId, externalItemId)
fileName = getattr(entry, "name", "") or externalItemId
declaredSize = int(getattr(entry, "size", 0) or 0) or None
logItemStart("sharepoint", entryPath, sizeBytes=declaredSize, mime=mimeType)
try:
fileBytes = await adapter.download(entryPath)
fileBytes = await downloadWithTimeout(adapter.download(entryPath), label=entryPath)
except WalkerTimeout as exc:
result.failed += 1
result.errors.append(str(exc))
return
except Exception as exc:
logger.warning("sharepoint download %s failed: %s", entryPath, exc)
result.failed += 1
@ -345,10 +358,16 @@ async def _ingestOne(
result.bytesProcessed += len(fileBytes)
try:
extracted = runExtractionFn(
extracted = await extractWithTimeout(
runExtractionFn,
fileBytes, fileName, mimeType,
ExtractionOptions(mergeStrategy=None),
label=entryPath,
)
except WalkerTimeout as exc:
result.failed += 1
result.errors.append(str(exc))
return
except Exception as exc:
logger.warning("sharepoint extraction %s failed: %s", entryPath, exc)
result.failed += 1
@ -370,20 +389,27 @@ async def _ingestOne(
"revision": revision,
}
try:
handle = await knowledgeService.requestIngestion(
IngestionJob(
sourceKind="sharepoint_item",
sourceId=syntheticFileId,
fileName=fileName,
mimeType=mimeType,
userId=userId,
mandateId=mandateId,
contentObjects=contentObjects,
contentVersion=revision,
neutralize=limits.neutralize,
provenance=provenance,
)
handle = await ingestWithTimeout(
knowledgeService.requestIngestion(
IngestionJob(
sourceKind="sharepoint_item",
sourceId=syntheticFileId,
fileName=fileName,
mimeType=mimeType,
userId=userId,
mandateId=mandateId,
contentObjects=contentObjects,
contentVersion=revision,
neutralize=limits.neutralize,
provenance=provenance,
)
),
label=entryPath,
)
except WalkerTimeout as exc:
result.failed += 1
result.errors.append(str(exc))
return
except Exception as exc:
logger.error("sharepoint ingestion %s failed: %s", entryPath, exc, exc_info=True)
result.failed += 1
@ -399,27 +425,17 @@ async def _ingestOne(
if handle.error:
result.errors.append(f"ingest({entryPath}): {handle.error}")
if progressCb is not None and (result.indexed + result.skippedDuplicate) % 50 == 0:
processed = result.indexed + result.skippedDuplicate
processed = result.indexed + result.skippedDuplicate
if progressCb is not None and processed % 5 == 0:
try:
progressCb(
min(90, 10 + int(80 * processed / max(1, limits.maxItems))),
f"sharepoint processed={processed}",
)
progressCb(0, f"{processed} Dateien verarbeitet, {result.indexed} indexiert")
except Exception:
pass
logger.info(
"ingestion.connection.bootstrap.progress part=sharepoint processed=%d skippedDup=%d failed=%d",
processed, result.skippedDuplicate, result.failed,
extra={
"event": "ingestion.connection.bootstrap.progress",
"part": "sharepoint",
"connectionId": connectionId,
"processed": processed,
"skippedDup": result.skippedDuplicate,
"failed": result.failed,
},
)
if processed % 50 == 0:
logger.info(
"ingestion.connection.bootstrap.progress part=sharepoint processed=%d indexed=%d failed=%d",
processed, result.indexed, result.failed,
)
# Yield so the event loop can interleave other tasks (download/extract are
# CPU-ish and extraction uses sync libs; cooperative scheduling prevents

View file

@ -0,0 +1,116 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Shared helpers for ingestion walkers (timeouts, per-item logging).
Walkers (sharepoint, gdrive, gmail, outlook, clickup, kdrive) all face the
same risks:
- A single `adapter.download()` call can hang on the network for hours.
- A single `runExtraction()` call can hang on a corrupt PDF/Office doc inside
a sync extractor library, blocking the asyncio loop.
- A single `requestIngestion()` call can stall on the embedding API.
Without timeouts, one bad item freezes the whole bootstrap job and we end
up with "Job stuck at 10% for 10h" zombies.
These helpers wrap each phase in `asyncio.wait_for`. Sync extraction runs
on a worker thread so the loop stays responsive. Every wrapped call also
emits a short start/done log line, so when something hangs we know the
exact item that caused it (path, size, mime).
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Awaitable, Callable, Optional
logger = logging.getLogger(__name__)
DOWNLOAD_TIMEOUT_S = 60
EXTRACTION_TIMEOUT_S = 90
INGEST_TIMEOUT_S = 60
class WalkerTimeout(Exception):
"""Raised when a walker phase exceeds its timeout budget."""
async def downloadWithTimeout(
awaitable: Awaitable[Any],
*,
label: str,
timeoutSeconds: int = DOWNLOAD_TIMEOUT_S,
) -> Any:
"""Run a download awaitable with a hard timeout.
`label` is a short human-readable identifier (typically the external path)
used in log messages so we can pinpoint the offending item in case of a
hang or timeout.
"""
logger.info("walker.download.start %s timeout=%ds", label, timeoutSeconds)
try:
result = await asyncio.wait_for(awaitable, timeout=timeoutSeconds)
logger.debug("walker.download.done %s", label)
return result
except asyncio.TimeoutError as ex:
logger.warning("walker.download.timeout %s after %ds", label, timeoutSeconds)
raise WalkerTimeout(f"download timeout after {timeoutSeconds}s: {label}") from ex
async def extractWithTimeout(
syncFn: Callable[..., Any],
*args: Any,
label: str,
timeoutSeconds: int = EXTRACTION_TIMEOUT_S,
) -> Any:
"""Run a synchronous extraction function on a worker thread with timeout.
Sync extractors (PDF, OCR, MS Office) cannot be cancelled cleanly from
asyncio; `wait_for` only protects the awaiter. The underlying thread may
keep running until the process exits but at least the walker proceeds
to the next item instead of freezing forever.
"""
logger.info("walker.extract.start %s timeout=%ds", label, timeoutSeconds)
try:
result = await asyncio.wait_for(
asyncio.to_thread(syncFn, *args),
timeout=timeoutSeconds,
)
logger.debug("walker.extract.done %s", label)
return result
except asyncio.TimeoutError as ex:
logger.warning("walker.extract.timeout %s after %ds", label, timeoutSeconds)
raise WalkerTimeout(f"extract timeout after {timeoutSeconds}s: {label}") from ex
async def ingestWithTimeout(
awaitable: Awaitable[Any],
*,
label: str,
timeoutSeconds: int = INGEST_TIMEOUT_S,
) -> Any:
"""Run an ingestion request with a hard timeout."""
logger.debug("walker.ingest.start %s timeout=%ds", label, timeoutSeconds)
try:
result = await asyncio.wait_for(awaitable, timeout=timeoutSeconds)
logger.debug("walker.ingest.done %s", label)
return result
except asyncio.TimeoutError as ex:
logger.warning("walker.ingest.timeout %s after %ds", label, timeoutSeconds)
raise WalkerTimeout(f"ingest timeout after {timeoutSeconds}s: {label}") from ex
def logItemStart(service: str, label: str, *, sizeBytes: Optional[int] = None, mime: Optional[str] = None) -> None:
"""Log that processing of one item is about to begin.
When the worker hangs, the LAST `walker.item.start` line in the log
points to the exact item that caused the freeze. This is the single
most valuable diagnostic for stuck-job triage.
"""
parts = [f"walker.item.start service={service} path={label}"]
if sizeBytes is not None:
parts.append(f"size={sizeBytes}")
if mime:
parts.append(f"mime={mime}")
logger.info(" ".join(parts))

View file

@ -85,6 +85,11 @@ class AiAuditLogger:
try:
from modules.datamodels.datamodelAiAudit import AiAuditLogEntry
if contentInput:
contentInput = contentInput.replace("\x00", "")
if contentOutput:
contentOutput = contentOutput.replace("\x00", "")
inputPreview = (contentInput or "")[:_PREVIEW_LENGTH] or None
outputPreview = (contentOutput or "")[:_PREVIEW_LENGTH] or None
inputHash = hashlib.sha256(contentInput.encode("utf-8")).hexdigest() if contentInput else None

View file

@ -330,6 +330,16 @@ NAVIGATION_SECTIONS = [
"adminOnly": True,
"sysAdminOnly": True,
},
{
"id": "admin-stt-benchmark",
"objectKey": "ui.admin.sttBenchmark",
"label": t("STT Benchmark"),
"icon": "FaMicrophone",
"path": "/admin/stt-benchmark",
"order": 92,
"adminOnly": True,
"sysAdminOnly": True,
},
{
"id": "admin-languages",
"objectKey": "ui.admin.languages",

3
tests/eval/__init__.py Normal file
View file

@ -0,0 +1,3 @@
# Copyright (c) 2026 Patrick Motsch
# All rights reserved.
"""Eval harness for the Feature Data Sub-Agent (Phase 1.5)."""

View file

@ -0,0 +1,246 @@
# Copyright (c) 2026 Patrick Motsch
# All rights reserved.
"""In-memory drop-in for FeatureDataProvider used by the eval harness.
Implements the same three public methods (browseTable / queryTable /
aggregateTable) plus the small surface the Sub-Agent reads (getActualColumns),
but runs all filters/aggregations in Python over the BenchmarkFixture rows.
This keeps the eval hermetic: no DB connection, no fixtures to insert/clean,
no flakiness from shared test schemas. Only the LLM call is real.
"""
from __future__ import annotations
from typing import Any, Dict, List, Optional
_ALLOWED_AGGREGATES = {"SUM", "COUNT", "AVG", "MIN", "MAX"}
class FakeFeatureDataProvider:
"""In-memory provider compatible with :class:`FeatureDataProvider`."""
def __init__(
self,
rowsByTable: Dict[str, List[Dict[str, Any]]],
availableTables: Optional[List[Dict[str, Any]]] = None,
) -> None:
self._rowsByTable = {name: list(rows) for name, rows in rowsByTable.items()}
self._availableTables = list(availableTables or [])
self.callLog: List[Dict[str, Any]] = []
def getAvailableTables(self, featureCode: str) -> List[Dict[str, Any]]: # noqa: ARG002
return list(self._availableTables)
def getTableSchema(self, featureCode: str, tableName: str) -> Optional[Dict[str, Any]]: # noqa: ARG002
for obj in self._availableTables:
if obj.get("meta", {}).get("table") == tableName:
return obj
return None
def getActualColumns(self, tableName: str) -> List[str]:
rows = self._rowsByTable.get(tableName, [])
if not rows:
return []
seen: List[str] = []
seenSet: set = set()
for row in rows:
for key in row.keys():
if key not in seenSet:
seen.append(key)
seenSet.add(key)
return seen
def browseTable(
self,
tableName: str,
featureInstanceId: str,
mandateId: str,
fields: List[str] = None,
limit: int = 50,
offset: int = 0,
extraFilters: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
self.callLog.append({"method": "browseTable", "table": tableName, "fields": fields, "limit": limit})
rows = self._scopeRows(tableName, featureInstanceId, mandateId)
rows = _applyFilters(rows, extraFilters)
total = len(rows)
rows = rows[offset : offset + limit]
if fields:
rows = [{k: v for k, v in row.items() if k in fields} for row in rows]
return {"rows": rows, "total": total, "limit": limit, "offset": offset}
def queryTable(
self,
tableName: str,
featureInstanceId: str,
mandateId: str,
filters: List[Dict[str, Any]] = None,
fields: List[str] = None,
orderBy: str = None,
limit: int = 50,
offset: int = 0,
extraFilters: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
self.callLog.append({
"method": "queryTable", "table": tableName, "filters": filters,
"fields": fields, "orderBy": orderBy, "limit": limit,
})
rows = self._scopeRows(tableName, featureInstanceId, mandateId)
combined = list(filters or []) + list(extraFilters or [])
rows = _applyFilters(rows, combined)
if orderBy:
try:
rows = sorted(rows, key=lambda r: (r.get(orderBy) is None, r.get(orderBy)))
except TypeError:
rows = sorted(rows, key=lambda r: str(r.get(orderBy)))
total = len(rows)
rows = rows[offset : offset + limit]
if fields:
rows = [{k: v for k, v in row.items() if k in fields} for row in rows]
return {"rows": rows, "total": total, "limit": limit, "offset": offset}
def aggregateTable(
self,
tableName: str,
featureInstanceId: str,
mandateId: str,
aggregate: str,
field: str,
groupBy: str = None,
extraFilters: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
self.callLog.append({
"method": "aggregateTable", "table": tableName,
"aggregate": aggregate, "field": field, "groupBy": groupBy,
})
aggregate = aggregate.upper()
if aggregate not in _ALLOWED_AGGREGATES:
return {"rows": [], "error": f"Unsupported aggregate: {aggregate}"}
rows = self._scopeRows(tableName, featureInstanceId, mandateId)
rows = _applyFilters(rows, extraFilters)
if groupBy:
groups: Dict[Any, List[Dict[str, Any]]] = {}
for row in rows:
groups.setdefault(row.get(groupBy), []).append(row)
outRows = [
{"groupValue": key, "result": _aggregate(aggregate, [r.get(field) for r in grp])}
for key, grp in groups.items()
]
outRows.sort(key=lambda r: (r["result"] is None, -(r["result"] or 0)))
else:
outRows = [{"result": _aggregate(aggregate, [r.get(field) for r in rows])}]
return {
"rows": outRows,
"aggregate": aggregate,
"field": field,
"groupBy": groupBy,
}
def _scopeRows(self, tableName: str, featureInstanceId: str, mandateId: str) -> List[Dict[str, Any]]:
rows = self._rowsByTable.get(tableName, [])
return [
row for row in rows
if (row.get("featureInstanceId") in (None, featureInstanceId))
and (row.get("mandateId") in (None, mandateId))
]
def _applyFilters(rows: List[Dict[str, Any]], filters: Optional[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
if not filters:
return rows
out = rows
for f in filters:
field = f.get("field")
op = (f.get("op") or "=").upper()
value = f.get("value")
out = [r for r in out if _matchesFilter(r.get(field), op, value)]
return out
def _matchesFilter(rowValue: Any, op: str, filterValue: Any) -> bool:
if op in ("IS NULL",):
return rowValue is None
if op in ("IS NOT NULL",):
return rowValue is not None
if rowValue is None:
return False
if op == "=":
return _coerceEqual(rowValue, filterValue)
if op == "!=":
return not _coerceEqual(rowValue, filterValue)
if op == ">":
return _coerceFloat(rowValue) > _coerceFloat(filterValue)
if op == "<":
return _coerceFloat(rowValue) < _coerceFloat(filterValue)
if op == ">=":
return _coerceFloat(rowValue) >= _coerceFloat(filterValue)
if op == "<=":
return _coerceFloat(rowValue) <= _coerceFloat(filterValue)
if op in ("LIKE", "ILIKE"):
pattern = str(filterValue or "")
target = str(rowValue)
if op == "ILIKE":
pattern = pattern.lower()
target = target.lower()
return _sqlLike(target, pattern)
if op == "IN":
if isinstance(filterValue, (list, tuple, set)):
return any(_coerceEqual(rowValue, v) for v in filterValue)
return _coerceEqual(rowValue, filterValue)
return False
def _coerceEqual(a: Any, b: Any) -> bool:
if a == b:
return True
try:
return str(a) == str(b)
except Exception:
return False
def _coerceFloat(value: Any) -> float:
if value is None:
return 0.0
try:
return float(value)
except (TypeError, ValueError):
return 0.0
def _sqlLike(value: str, pattern: str) -> bool:
"""Approximate SQL LIKE -- only % and _ wildcards."""
import re
regex = ""
i = 0
while i < len(pattern):
ch = pattern[i]
if ch == "%":
regex += ".*"
elif ch == "_":
regex += "."
else:
regex += re.escape(ch)
i += 1
return re.fullmatch(regex, value or "") is not None
def _aggregate(op: str, values: List[Any]) -> Any:
if op == "COUNT":
return sum(1 for v in values if v is not None)
nums = [_coerceFloat(v) for v in values if v is not None]
if not nums:
return 0 if op == "SUM" else None
if op == "SUM":
return round(sum(nums), 4)
if op == "AVG":
return round(sum(nums) / len(nums), 4)
if op == "MIN":
return min(nums)
if op == "MAX":
return max(nums)
return None

View file

@ -0,0 +1,735 @@
# Copyright (c) 2026 Patrick Motsch
# All rights reserved.
"""Trustee Sub-Agent Eval Harness (Phase 1.5).
Standalone runner that fires real AI calls against the Feature Data
Sub-Agent in three configurations:
* ``baseline`` -- production code without the pre-execute validator
(Repair-Loop disabled, Trustee domain hints active).
* ``phase1`` -- pre-execute validator on (Repair-Loop active),
domain hints active, no ontology yet.
* ``phase2`` -- validator on, ontology-driven schema context +
constraints (replaces hand-written domain hints).
For each mode we run all 19 gold-standard questions against an
in-memory :class:`FakeFeatureDataProvider`, capture the agent's tool
calls and final answer, score them against the gold standard, and
write a Markdown report to ``local/notes/`` for analysis.
Usage::
cd gateway
python -m tests.eval.runTrusteeBenchmark # all 3 modes
python -m tests.eval.runTrusteeBenchmark phase1 # one mode only
python -m tests.eval.runTrusteeBenchmark baseline phase1
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import re
import sys
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
# ---------------------------------------------------------------------------
# Path setup so `python -m tests.eval.runTrusteeBenchmark` works from gateway/
# ---------------------------------------------------------------------------
_GATEWAY_DIR = Path(__file__).resolve().parents[2]
if str(_GATEWAY_DIR) not in sys.path:
sys.path.insert(0, str(_GATEWAY_DIR))
import yaml # noqa: E402
from modules.serviceCenter.services.serviceAgent.datamodelAgent import ( # noqa: E402
AgentConfig,
AgentEventTypeEnum,
)
from modules.datamodels.datamodelAi import ( # noqa: E402
AiCallRequest,
AiCallResponse,
OperationTypeEnum,
)
from modules.serviceCenter.services.serviceAgent.agentLoop import runAgentLoop # noqa: E402
from modules.serviceCenter.services.serviceAgent.featureDataAgent import ( # noqa: E402
_buildSubAgentTools,
_buildSchemaContext,
)
from modules.serviceCenter.services.serviceAgent.datamodelOntology import ( # noqa: E402
QueryValidationError,
)
from modules.serviceCenter.services.serviceAgent.queryValidator import ( # noqa: E402
QueryValidator,
)
from tests.eval.fakeFeatureDataProvider import ( # noqa: E402
FakeFeatureDataProvider,
)
from tests.fixtures.trusteeBenchmark.loadTrusteeBenchmarkFixture import ( # noqa: E402
buildTrusteeBenchmarkFixture,
BenchmarkFixture,
)
logger = logging.getLogger("trusteeBenchmark")
# ---------------------------------------------------------------------------
# NoOpValidator -- baseline mode (Repair-Loop OFF)
# ---------------------------------------------------------------------------
class _NoOpValidator(QueryValidator):
"""Validator that never rejects anything (used for baseline measurement)."""
def validateBrowseQuery(self, tableName, args): # noqa: ARG002
return None
def validateQueryTable(self, tableName, args): # noqa: ARG002
return None
def validateAggregateQuery(self, tableName, args): # noqa: ARG002
return None
# ---------------------------------------------------------------------------
# Mode-specific tool/prompt building
# ---------------------------------------------------------------------------
@dataclass
class _ModeConfig:
name: str
label: str
useValidator: bool
useOntology: bool
_MODES: Dict[str, _ModeConfig] = {
"baseline": _ModeConfig(name="baseline", label="Baseline (no validator)", useValidator=False, useOntology=False),
"phase1": _ModeConfig(name="phase1", label="Phase 1 (validator on)", useValidator=True, useOntology=False),
"phase2": _ModeConfig(name="phase2", label="Phase 2 (validator + ontology)", useValidator=True, useOntology=True),
}
def _buildValidator(mode: _ModeConfig) -> QueryValidator:
"""Construct the per-mode validator.
* baseline: no-op (Repair-Loop disabled, used to measure raw LLM
accuracy against today's prompt path).
* phase1: convention-based QueryValidator (NEVER_AGGREGATE on
``*Balance``/``*Total`` suffixes; no ontology).
* phase2: ontology-driven QueryValidator (constraints from the
trustee ontology override the convention defaults).
"""
if not mode.useValidator:
return _NoOpValidator()
if mode.useOntology:
try:
from modules.features.trustee.trusteeOntology import getTrusteeOntology
return QueryValidator(ontology=getTrusteeOntology())
except Exception as e:
logger.warning("Could not load trustee ontology, falling back: %s", e)
return QueryValidator()
def _applyEnvForMode(mode: _ModeConfig) -> None:
"""Set the ontology toggle for the production prompt builder.
The Phase 2 path uses ``featureDataAgent._buildSchemaContext`` to pull
the prompt block from ``getAgentOntology()`` automatically. For
baseline/phase1 we set ``POWERON_DISABLE_FEATURE_ONTOLOGY=1`` so the
builder falls back to the legacy ``getAgentDomainHints()`` block --
measuring exactly the production prompt that ships today.
"""
if mode.useOntology:
os.environ.pop("POWERON_DISABLE_FEATURE_ONTOLOGY", None)
else:
os.environ["POWERON_DISABLE_FEATURE_ONTOLOGY"] = "1"
def _buildSystemPrompt(featureCode: str, instanceLabel: str, selectedTables: List[Dict[str, Any]]) -> str:
"""Build the sub-agent system prompt via the production path.
Mode-specific behaviour (legacy hints vs ontology block) is controlled
by the ``POWERON_DISABLE_FEATURE_ONTOLOGY`` env flag set per mode in
:func:`_applyEnvForMode`. Keeping the builder call identical for all
three modes means the benchmark measures the EXACT prompt the agent
would see in production -- no eval-only forks.
"""
return _buildSchemaContext(featureCode, instanceLabel, selectedTables, requestLang="de")
# ---------------------------------------------------------------------------
# Question loading + per-question evaluation
# ---------------------------------------------------------------------------
@dataclass
class _Question:
id: str
question: str
intent: str
expectedTools: List[str]
expectedTable: Optional[str]
expectedAggregate: Optional[str]
expectedAggregateField: Optional[str]
requiredFilters: Dict[str, Any]
forbiddenTools: List[str]
expectedNumbers: List[float]
expectedAnswerContains: List[str]
numericTolerance: float
def _loadQuestions(yamlPath: Path) -> List[_Question]:
with open(yamlPath, "r", encoding="utf-8") as f:
rawList = yaml.safe_load(f)
questions: List[_Question] = []
for raw in rawList:
questions.append(_Question(
id=raw["id"],
question=raw["question"],
intent=raw.get("intent", ""),
expectedTools=list(raw.get("expectedTools") or []),
expectedTable=raw.get("expectedTable"),
expectedAggregate=raw.get("expectedAggregate"),
expectedAggregateField=raw.get("expectedAggregateField"),
requiredFilters=dict(raw.get("requiredFilters") or {}),
forbiddenTools=list(raw.get("forbiddenTools") or []),
expectedNumbers=[float(x) for x in (raw.get("expectedNumbers") or [])],
expectedAnswerContains=[str(x) for x in (raw.get("expectedAnswerContains") or [])],
numericTolerance=float(raw.get("numericTolerance") or 0.005),
))
return questions
@dataclass
class _RunResult:
questionId: str
finalText: str
toolCalls: List[Dict[str, Any]] = field(default_factory=list)
toolResults: List[Dict[str, Any]] = field(default_factory=list)
summary: Dict[str, Any] = field(default_factory=dict)
durationS: float = 0.0
error: Optional[str] = None
@property
def costCHF(self) -> float:
return float(self.summary.get("costCHF") or 0.0)
@property
def rounds(self) -> int:
return int(self.summary.get("rounds") or 0)
@property
def validationFailures(self) -> int:
return int(self.summary.get("validationFailures") or 0)
@property
def repairAttempts(self) -> int:
return int(self.summary.get("repairAttempts") or 0)
@property
def successAfterRepair(self) -> int:
return int(self.summary.get("successAfterRepair") or 0)
@dataclass
class _Score:
patternOk: bool = False
forbidOk: bool = False
numericOk: bool = False
accuracyOk: bool = False
notes: List[str] = field(default_factory=list)
def _scoreRun(question: _Question, run: _RunResult) -> _Score:
score = _Score()
if run.error:
score.notes.append(f"Sub-agent error: {run.error}")
return score
score.patternOk = _checkPattern(question, run)
score.forbidOk = _checkForbid(question, run)
score.numericOk = _checkNumeric(question, run)
score.accuracyOk = score.patternOk and score.forbidOk and score.numericOk
return score
def _checkPattern(question: _Question, run: _RunResult) -> bool:
"""Did the agent call one of the expected tools on the expected table with required filters?"""
if not question.expectedTools:
return True
matchingCalls = [
c for c in run.toolCalls
if c.get("toolName") in question.expectedTools
and (not question.expectedTable or c.get("args", {}).get("tableName") == question.expectedTable)
]
if not matchingCalls:
return False
if question.expectedAggregate:
wantAgg = question.expectedAggregate.upper()
wantField = question.expectedAggregateField
for c in matchingCalls:
args = c.get("args", {})
if c.get("toolName") != "aggregateTable":
continue
if (args.get("aggregate") or "").upper() != wantAgg:
continue
if wantField and args.get("field") != wantField:
continue
if not _filtersSatisfied(question.requiredFilters, args.get("extraFilters") or args.get("filters") or []):
continue
return True
return False
if question.requiredFilters:
for c in matchingCalls:
args = c.get("args", {})
filters = args.get("filters") or args.get("extraFilters") or []
if _filtersSatisfied(question.requiredFilters, filters):
return True
return False
return True
def _filtersSatisfied(required: Dict[str, Any], actualFilters: List[Dict[str, Any]]) -> bool:
if not required:
return True
for reqField, reqValue in required.items():
if reqField.endswith("Like"):
field = reqField[:-4]
wanted = str(reqValue)
ok = any(
(f.get("field") == field) and (f.get("op", "").upper() in ("LIKE", "ILIKE"))
and str(f.get("value")) == wanted
for f in actualFilters
)
if not ok:
return False
else:
ok = any(
f.get("field") == reqField and _filterValueEqual(f.get("value"), reqValue)
for f in actualFilters
)
if not ok:
return False
return True
def _filterValueEqual(a: Any, b: Any) -> bool:
if a == b:
return True
try:
return str(a).strip() == str(b).strip()
except Exception:
return False
def _checkForbid(question: _Question, run: _RunResult) -> bool:
"""Did the agent AVOID forbidden tool/op combinations?
Forbidden hits only count if the call actually went through to the
provider (success=True). Validator-rejected calls don't count -- the
Repair-Loop is doing its job and steering the agent away.
"""
if not question.forbiddenTools:
return True
forbiddenSet = set(question.forbiddenTools)
for r in run.toolResults:
if not r.get("success"):
continue
if r.get("toolName") in forbiddenSet:
return False
return True
def _checkNumeric(question: _Question, run: _RunResult) -> bool:
text = (run.finalText or "")
if question.expectedNumbers:
textNumbers = _extractNumbers(text)
for expected in question.expectedNumbers:
tol = max(abs(expected) * question.numericTolerance, 0.5)
if not any(abs(n - expected) <= tol for n in textNumbers):
return False
if question.expectedAnswerContains:
lowered = text.lower()
for needle in question.expectedAnswerContains:
if needle.lower() not in lowered:
return False
return True
def _extractNumbers(text: str) -> List[float]:
"""Pick out all numbers from a free-text answer.
Handles Swiss thousand separators (apostrophe and U+2019), German
decimals (comma), plain integers/floats, and JSON numbers. Trailing
punctuation (``,``, ``;``, ``.`` from end-of-sentence) is stripped
before parsing so ``"180500.0,"`` parses cleanly to 180500.0.
"""
cleaned = text.replace("\u2019", "'")
tokens = re.findall(r"-?\d[\d'.,]*", cleaned)
out: List[float] = []
for tok in tokens:
tok = tok.rstrip(",;")
if tok.endswith(".") and tok.count(".") == 1:
tok = tok[:-1]
norm = tok.replace("'", "")
if norm.count(",") == 1 and norm.count(".") == 0:
norm = norm.replace(",", ".")
elif norm.count(",") >= 1 and norm.count(".") >= 1:
if norm.rfind(",") > norm.rfind("."):
norm = norm.replace(".", "").replace(",", ".")
else:
norm = norm.replace(",", "")
else:
norm = norm.replace(",", "")
try:
out.append(float(norm))
except ValueError:
continue
return out
# ---------------------------------------------------------------------------
# AI call wiring
# ---------------------------------------------------------------------------
def _bootstrapServices() -> Tuple[Any, str, str]:
"""Spin up a minimal service hub bound to the root user + initial mandate.
Returns the ServiceHub, the user id, and the mandate id used for billing.
"""
from modules.interfaces.interfaceDbApp import getRootInterface
from modules.datamodels.datamodelUam import Mandate
from modules.serviceHub import getInterface as getServices
rootInterface = getRootInterface()
user = rootInterface.currentUser
mandateId = rootInterface.getInitialId(Mandate)
if not mandateId:
raise RuntimeError("No initial mandate available -- run bootstrap loader first.")
services = getServices(user, workflow=None, mandateId=mandateId, featureInstanceId=None)
return services, user.id, mandateId
async def _runOneQuestion(
*,
services: Any,
userId: str,
mandateId: str,
fixture: BenchmarkFixture,
question: _Question,
mode: _ModeConfig,
) -> _RunResult:
"""Execute a single sub-agent run for one question under one mode."""
provider = FakeFeatureDataProvider(
rowsByTable=fixture.rowsByTable,
availableTables=fixture.selectedTables,
)
validator = _buildValidator(mode)
registry = _buildSubAgentTools(
provider=provider,
featureInstanceId=fixture.featureInstanceId,
mandateId=fixture.mandateId,
tableFilters={},
validator=validator,
)
systemPrompt = _buildSystemPrompt(
featureCode="trustee",
instanceLabel="Demo AG",
selectedTables=fixture.selectedTables,
)
cost = 0.0
async def _aiCallFn(req: AiCallRequest) -> AiCallResponse:
nonlocal cost
resp = await services.ai.callAi(req)
cost += float(getattr(resp, "priceCHF", 0.0) or 0.0)
return resp
async def _getCost() -> float:
return cost
config = AgentConfig(
maxRounds=6,
maxCostCHF=0.50,
operationType=OperationTypeEnum.DATA_QUERY,
)
run = _RunResult(questionId=question.id, finalText="")
t0 = time.time()
try:
async for event in runAgentLoop(
prompt=question.question,
toolRegistry=registry,
config=config,
aiCallFn=_aiCallFn,
getWorkflowCostFn=_getCost,
workflowId=f"eval-{mode.name}-{question.id}-{uuid.uuid4().hex[:6]}",
userId=userId,
featureInstanceId=fixture.featureInstanceId,
mandateId=mandateId,
systemPromptOverride=systemPrompt,
):
if event.type == AgentEventTypeEnum.FINAL:
run.finalText = event.content or run.finalText
elif event.type == AgentEventTypeEnum.MESSAGE and event.content:
run.finalText += event.content
elif event.type == AgentEventTypeEnum.TOOL_CALL:
run.toolCalls.append(dict(event.data or {}))
elif event.type == AgentEventTypeEnum.TOOL_RESULT:
run.toolResults.append(dict(event.data or {}))
elif event.type == AgentEventTypeEnum.AGENT_SUMMARY:
run.summary = dict(event.data or {})
elif event.type == AgentEventTypeEnum.ERROR:
run.error = (run.error or "") + (event.content or "")
except Exception as e:
run.error = f"{type(e).__name__}: {e}"
logger.exception("Sub-agent run failed for %s/%s", mode.name, question.id)
run.durationS = time.time() - t0
return run
# ---------------------------------------------------------------------------
# Report
# ---------------------------------------------------------------------------
@dataclass
class _ModeReport:
mode: _ModeConfig
perQuestion: List[Tuple[_Question, _RunResult, _Score]] = field(default_factory=list)
@property
def total(self) -> int:
return len(self.perQuestion)
def _count(self, attr: str) -> int:
return sum(1 for _, _, s in self.perQuestion if getattr(s, attr))
@property
def accuracy(self) -> float:
return self._count("accuracyOk") / max(self.total, 1)
@property
def patternCompliance(self) -> float:
return self._count("patternOk") / max(self.total, 1)
@property
def repairConversionRate(self) -> float:
attempts = sum(r.repairAttempts for _, r, _ in self.perQuestion)
succeeded = sum(r.successAfterRepair for _, r, _ in self.perQuestion)
if attempts == 0:
return 0.0
return succeeded / attempts
@property
def totalCostCHF(self) -> float:
return sum(r.costCHF for _, r, _ in self.perQuestion)
@property
def totalRounds(self) -> int:
return sum(r.rounds for _, r, _ in self.perQuestion)
@property
def totalValidationFailures(self) -> int:
return sum(r.validationFailures for _, r, _ in self.perQuestion)
def _writeReport(reports: List[_ModeReport], outputPath: Path) -> None:
lines: List[str] = []
lines.append("# Trustee Sub-Agent Benchmark Report")
lines.append("")
lines.append(f"Generated: {time.strftime('%Y-%m-%d %H:%M:%S')}")
lines.append("")
lines.append("## Summary")
lines.append("")
lines.append("| Mode | Questions | Accuracy | Pattern compliance | Repair conversion | Validator rejects | Rounds | Cost (CHF) |")
lines.append("|---|---|---|---|---|---|---|---|")
for rep in reports:
lines.append(
f"| {rep.mode.label} | {rep.total} | {rep.accuracy:.1%} | {rep.patternCompliance:.1%} | "
f"{rep.repairConversionRate:.1%} | {rep.totalValidationFailures} | {rep.totalRounds} | "
f"{rep.totalCostCHF:.4f} |"
)
lines.append("")
lines.append("## Per-question detail")
for rep in reports:
lines.append("")
lines.append(f"### {rep.mode.label}")
lines.append("")
lines.append("| id | acc | pattern | forbid | numeric | rounds | val-fail | repairs | cost CHF | duration | tools |")
lines.append("|---|---|---|---|---|---|---|---|---|---|---|")
for q, r, s in rep.perQuestion:
toolList = ",".join(
f"{c.get('toolName')}({c.get('args',{}).get('tableName','?')})"
for c in r.toolCalls
)
lines.append(
f"| {q.id} | {_yn(s.accuracyOk)} | {_yn(s.patternOk)} | {_yn(s.forbidOk)} | {_yn(s.numericOk)} | "
f"{r.rounds} | {r.validationFailures} | {r.repairAttempts}/{r.successAfterRepair} | "
f"{r.costCHF:.4f} | {r.durationS:.1f}s | {toolList} |"
)
lines.append("")
lines.append("#### Notes & failures")
for q, r, s in rep.perQuestion:
if s.accuracyOk:
continue
lines.append(f"- **{q.id}** ({q.intent}): pattern={s.patternOk} forbid={s.forbidOk} numeric={s.numericOk}")
if r.error:
lines.append(f" - error: `{r.error}`")
lines.append(f" - answer: `{(r.finalText or '').strip().replace('|', '/').splitlines()[0][:240]}`")
for note in s.notes:
lines.append(f" - note: {note}")
outputPath.parent.mkdir(parents=True, exist_ok=True)
outputPath.write_text("\n".join(lines), encoding="utf-8")
def _yn(b: bool) -> str:
return "OK" if b else "FAIL"
# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------
async def _runMain(modesToRun: List[str], onlyQuestionId: Optional[str] = None) -> None:
logging.basicConfig(
level=logging.WARNING,
format="%(asctime)s %(levelname)s %(name)s -- %(message)s",
)
logger.setLevel(logging.INFO)
fixture = buildTrusteeBenchmarkFixture()
questionsPath = _GATEWAY_DIR / "tests" / "fixtures" / "trusteeBenchmark" / "questions.yaml"
allQuestions = _loadQuestions(questionsPath)
if onlyQuestionId:
allQuestions = [q for q in allQuestions if q.id == onlyQuestionId]
if not allQuestions:
print(f"No question matches id={onlyQuestionId!r}")
return
print(f"Loaded {len(allQuestions)} questions, {len(modesToRun)} modes -> {len(allQuestions) * len(modesToRun)} sub-agent runs.")
services, userId, mandateId = _bootstrapServices()
print(f"Bootstrap OK: user={userId}, mandate={mandateId}")
reports: List[_ModeReport] = []
for modeName in modesToRun:
mode = _MODES[modeName]
_applyEnvForMode(mode)
rep = _ModeReport(mode=mode)
print(f"\n=== Mode: {mode.label} ===")
for idx, question in enumerate(allQuestions, start=1):
print(f" [{idx:>2}/{len(allQuestions)}] {question.id}: {question.question[:80]} ...", flush=True)
run = await _runOneQuestion(
services=services,
userId=userId,
mandateId=mandateId,
fixture=fixture,
question=question,
mode=mode,
)
score = _scoreRun(question, run)
rep.perQuestion.append((question, run, score))
print(
f" -> acc={_yn(score.accuracyOk)} "
f"pattern={_yn(score.patternOk)} forbid={_yn(score.forbidOk)} "
f"numeric={_yn(score.numericOk)} rounds={run.rounds} cost={run.costCHF:.4f} "
f"val-fail={run.validationFailures} repairs={run.repairAttempts}/{run.successAfterRepair}",
flush=True,
)
reports.append(rep)
timestamp = time.strftime("%Y%m%d-%H%M%S")
outDir = _GATEWAY_DIR.parent / "local" / "notes"
reportPath = outDir / f"trustee-benchmark-{timestamp}.md"
_writeReport(reports, reportPath)
rawJsonPath = outDir / f"trustee-benchmark-{timestamp}.json"
rawJsonPath.write_text(
json.dumps(
[
{
"mode": rep.mode.name,
"accuracy": rep.accuracy,
"patternCompliance": rep.patternCompliance,
"repairConversionRate": rep.repairConversionRate,
"totalCostCHF": rep.totalCostCHF,
"totalRounds": rep.totalRounds,
"totalValidationFailures": rep.totalValidationFailures,
"items": [
{
"questionId": q.id,
"intent": q.intent,
"accuracyOk": s.accuracyOk,
"patternOk": s.patternOk,
"forbidOk": s.forbidOk,
"numericOk": s.numericOk,
"rounds": r.rounds,
"validationFailures": r.validationFailures,
"repairAttempts": r.repairAttempts,
"successAfterRepair": r.successAfterRepair,
"costCHF": r.costCHF,
"durationS": r.durationS,
"finalText": (r.finalText or "")[:600],
"toolCalls": r.toolCalls,
"error": r.error,
}
for q, r, s in rep.perQuestion
],
}
for rep in reports
],
indent=2,
ensure_ascii=False,
),
encoding="utf-8",
)
print(f"\nReport written: {reportPath}")
print(f"Raw JSON: {rawJsonPath}")
for rep in reports:
print(f" {rep.mode.label}: acc={rep.accuracy:.1%} pattern={rep.patternCompliance:.1%} cost={rep.totalCostCHF:.4f}")
def _parseArgs(argv: List[str]) -> Tuple[List[str], Optional[str]]:
modes: List[str] = []
only: Optional[str] = None
for arg in argv:
if arg.startswith("--only="):
only = arg.split("=", 1)[1]
elif arg in _MODES:
modes.append(arg)
else:
print(f"Unknown argument: {arg!r}. Allowed modes: {list(_MODES)}")
sys.exit(2)
if not modes:
modes = ["baseline", "phase1", "phase2"]
return modes, only
def main() -> None:
modes, only = _parseArgs(sys.argv[1:])
asyncio.run(_runMain(modes, onlyQuestionId=only))
if __name__ == "__main__":
main()

View file

@ -0,0 +1,16 @@
# Copyright (c) 2026 Patrick Motsch
# All rights reserved.
"""Trustee benchmark fixture: synthetic but realistic Swiss KMU accounting data.
Used by the Feature Data Sub-Agent eval harness (Phase 1.5) to measure
hallucination rates against a fixed gold standard. Data is built in-memory
via Pydantic models -- no SQL, no DB connection -- so the harness stays
hermetic and reproducible.
"""
from tests.fixtures.trusteeBenchmark.loadTrusteeBenchmarkFixture import (
buildTrusteeBenchmarkFixture,
BenchmarkFixture,
)
__all__ = ["buildTrusteeBenchmarkFixture", "BenchmarkFixture"]

View file

@ -0,0 +1,275 @@
# Copyright (c) 2026 Patrick Motsch
# All rights reserved.
"""Synthetic Trustee benchmark fixture for the Feature Data Sub-Agent eval.
Builds an in-memory snapshot of one fictional Swiss KMU mandate
("Demo AG") with:
* 3 fiscal years (2023, 2024, 2025) of `TrusteeDataAccountBalance` rows
-- both annual totals (periodMonth=0) and monthly snapshots.
* 8 representative accounts spanning all major chart-of-accounts blocks
(cash, banks, receivables, payables, revenue, materials, personnel,
operating expenses).
* Per-month `TrusteeDataJournalEntry` + multiple `TrusteeDataJournalLine`
rows so debit/credit/COUNT aggregations have meaningful answers.
The data is deterministic (no RNG) so a question's gold-standard answer
is stable across runs.
This module deliberately stays decoupled from the production DB pipeline
-- the harness uses :class:`FakeFeatureDataProvider` (see
``gateway/tests/eval/fakeFeatureDataProvider.py``) to serve queries
against this in-memory snapshot, mirroring the public methods of
``FeatureDataProvider`` (browseTable / queryTable / aggregateTable).
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List
_MANDATE_ID = "m-demo-ag"
_FEATURE_INSTANCE_ID = "fi-demo-ag-trustee"
# ---------------------------------------------------------------------------
# Account master data
# ---------------------------------------------------------------------------
_ACCOUNT_MASTER: List[Dict[str, Any]] = [
{"accountNumber": "1000", "label": "Hauptkasse", "accountType": "asset", "currency": "CHF"},
{"accountNumber": "1020", "label": "ZKB Geschaeftskonto", "accountType": "asset", "currency": "CHF"},
{"accountNumber": "1021", "label": "PostFinance", "accountType": "asset", "currency": "CHF"},
{"accountNumber": "1100", "label": "Forderungen aus Lieferungen und Leistungen", "accountType": "asset", "currency": "CHF"},
{"accountNumber": "2000", "label": "Verbindlichkeiten aus Lieferungen", "accountType": "liability", "currency": "CHF"},
{"accountNumber": "3000", "label": "Ertrag aus Beratung", "accountType": "revenue", "currency": "CHF"},
{"accountNumber": "5400", "label": "Materialaufwand", "accountType": "expense", "currency": "CHF"},
{"accountNumber": "6000", "label": "Mietaufwand", "accountType": "expense", "currency": "CHF"},
]
# Annual closing balances per (year, accountNumber) -- the canonical reference.
# Asset/expense balances are positive, liability/revenue balances are stored
# as positive numbers (sign by accountType, like most accounting systems).
_ANNUAL_CLOSING: Dict[int, Dict[str, float]] = {
2023: {
"1000": 4_800.00,
"1020": 132_500.00,
"1021": 22_400.00,
"1100": 58_200.00,
"2000": 41_300.00,
"3000": 410_000.00,
"5400": 92_000.00,
"6000": 36_000.00,
},
2024: {
"1000": 5_200.00,
"1020": 148_900.00,
"1021": 26_750.00,
"1100": 61_400.00,
"2000": 44_100.00,
"3000": 462_500.00,
"5400": 104_300.00,
"6000": 39_000.00,
},
2025: {
"1000": 5_900.00,
"1020": 152_400.00,
"1021": 28_100.00,
"1100": 66_800.00,
"2000": 47_900.00,
"3000": 488_700.00,
"5400": 112_100.00,
"6000": 42_000.00,
},
}
def _openingFromPriorYear(year: int, accountNumber: str) -> float:
"""Opening balance of year N = closing balance of year N-1 (0 if N-1 is unknown)."""
prior = year - 1
return float(_ANNUAL_CLOSING.get(prior, {}).get(accountNumber, 0.0))
def _monthlyProgression(opening: float, closing: float, month: int) -> float:
"""Linear interpolation between opening and closing for monthly snapshots.
Not realistic in detail but deterministic and monotonic per account, so
questions about "Stand per Ende März" produce stable answers.
"""
if month <= 0:
return float(closing)
frac = month / 12.0
return round(float(opening) + (float(closing) - float(opening)) * frac, 2)
# ---------------------------------------------------------------------------
# Journal entries / lines -- minimal but realistic
# ---------------------------------------------------------------------------
_JOURNAL_ENTRIES_2025: List[Dict[str, Any]] = [
{"month": 3, "day": 15, "reference": "RG-2025-0042", "description": "Beratung Kunde ACME AG", "amount": 18_500.00, "debit": "1100", "credit": "3000"},
{"month": 3, "day": 22, "reference": "EK-2025-0017", "description": "Materialeinkauf Buehler AG", "amount": 9_200.00, "debit": "5400", "credit": "2000"},
{"month": 3, "day": 28, "reference": "MIETE-2025-03", "description": "Mietzins Buero Maerz", "amount": 3_000.00, "debit": "6000", "credit": "1020"},
{"month": 4, "day": 5, "reference": "RG-2025-0051", "description": "Beratung Kunde Bell AG", "amount": 24_300.00, "debit": "1100", "credit": "3000"},
{"month": 4, "day": 18, "reference": "EK-2025-0024", "description": "Materialeinkauf Industriebedarf", "amount": 7_800.00, "debit": "5400", "credit": "2000"},
{"month": 6, "day": 12, "reference": "RG-2025-0079", "description": "Beratung Kunde Bell AG", "amount": 32_100.00, "debit": "1100", "credit": "3000"},
{"month": 6, "day": 30, "reference": "MIETE-2025-Q2", "description": "Mietzins Buero Q2-Abrechnung", "amount": 3_500.00, "debit": "6000", "credit": "1020"},
{"month": 9, "day": 4, "reference": "RG-2025-0114", "description": "Beratung Kunde Migros", "amount": 41_500.00, "debit": "1100", "credit": "3000"},
{"month": 9, "day": 25, "reference": "EK-2025-0061", "description": "Materialeinkauf Buehler AG", "amount": 12_400.00, "debit": "5400", "credit": "2000"},
{"month": 11, "day": 14, "reference": "RG-2025-0188", "description": "Beratung Kunde ACME AG", "amount": 28_700.00, "debit": "1100", "credit": "3000"},
]
# ---------------------------------------------------------------------------
# Snapshot containers
# ---------------------------------------------------------------------------
@dataclass
class BenchmarkFixture:
"""In-memory rows that mimic feature DB tables.
Each ``rowsByTable[tableName]`` is a list of column dicts compatible
with the Pydantic feature data models (TrusteeDataAccountBalance, etc.).
"""
mandateId: str
featureInstanceId: str
rowsByTable: Dict[str, List[Dict[str, Any]]] = field(default_factory=dict)
selectedTables: List[Dict[str, Any]] = field(default_factory=list)
def _buildSelectedTables() -> List[Dict[str, Any]]:
"""Return the DATA_OBJECT-shaped descriptors the sub-agent expects.
Mirrors what the catalog would return for the trustee feature; the
real `getDataObjects("trustee")` call would yield the same shape but
we hard-code the three tables we actually populate.
"""
return [
{
"objectKey": "data.feature.trustee.TrusteeDataAccount",
"label": {"de": "Kontenplan", "en": "Chart of accounts"},
"meta": {
"table": "TrusteeDataAccount",
"fields": ["id", "accountNumber", "label", "accountType", "currency", "isActive"],
},
},
{
"objectKey": "data.feature.trustee.TrusteeDataAccountBalance",
"label": {"de": "Kontosalden", "en": "Account balances"},
"meta": {
"table": "TrusteeDataAccountBalance",
"fields": [
"id", "accountNumber", "periodYear", "periodMonth",
"openingBalance", "debitTotal", "creditTotal",
"closingBalance", "currency",
],
},
},
{
"objectKey": "data.feature.trustee.TrusteeDataJournalLine",
"label": {"de": "Buchungszeilen", "en": "Journal lines"},
"meta": {
"table": "TrusteeDataJournalLine",
"fields": [
"id", "journalEntryId", "accountNumber",
"debitAmount", "creditAmount", "currency", "description",
],
},
},
]
def buildTrusteeBenchmarkFixture() -> BenchmarkFixture:
"""Materialize the full in-memory benchmark snapshot.
All rows include ``mandateId`` and ``featureInstanceId`` columns so the
fake provider can scope them the same way the real one does.
"""
accountRows: List[Dict[str, Any]] = []
for i, acc in enumerate(_ACCOUNT_MASTER):
accountRows.append({
"id": f"acc-{i:03d}",
"accountNumber": acc["accountNumber"],
"label": acc["label"],
"accountType": acc["accountType"],
"currency": acc["currency"],
"isActive": True,
"mandateId": _MANDATE_ID,
"featureInstanceId": _FEATURE_INSTANCE_ID,
})
balanceRows: List[Dict[str, Any]] = []
rowIdx = 0
for year, closings in _ANNUAL_CLOSING.items():
for accountNumber, closing in closings.items():
opening = _openingFromPriorYear(year, accountNumber)
balanceRows.append({
"id": f"bal-{rowIdx:04d}",
"accountNumber": accountNumber,
"periodYear": year,
"periodMonth": 0,
"openingBalance": opening,
"debitTotal": round(max(closing - opening, 0.0) * 1.2, 2),
"creditTotal": round(max(closing - opening, 0.0) * 0.2, 2),
"closingBalance": float(closing),
"currency": "CHF",
"mandateId": _MANDATE_ID,
"featureInstanceId": _FEATURE_INSTANCE_ID,
})
rowIdx += 1
for month in range(1, 13):
monthly = _monthlyProgression(opening, closing, month)
balanceRows.append({
"id": f"bal-{rowIdx:04d}",
"accountNumber": accountNumber,
"periodYear": year,
"periodMonth": month,
"openingBalance": opening,
"debitTotal": round((monthly - opening) * 1.2, 2) if monthly > opening else 0.0,
"creditTotal": round((monthly - opening) * 0.2, 2) if monthly > opening else 0.0,
"closingBalance": monthly,
"currency": "CHF",
"mandateId": _MANDATE_ID,
"featureInstanceId": _FEATURE_INSTANCE_ID,
})
rowIdx += 1
lineRows: List[Dict[str, Any]] = []
for j, entry in enumerate(_JOURNAL_ENTRIES_2025):
entryId = f"je-2025-{j:03d}"
lineRows.append({
"id": f"jl-{j*2:04d}",
"journalEntryId": entryId,
"accountNumber": entry["debit"],
"debitAmount": float(entry["amount"]),
"creditAmount": 0.0,
"currency": "CHF",
"description": entry["description"],
"mandateId": _MANDATE_ID,
"featureInstanceId": _FEATURE_INSTANCE_ID,
})
lineRows.append({
"id": f"jl-{j*2+1:04d}",
"journalEntryId": entryId,
"accountNumber": entry["credit"],
"debitAmount": 0.0,
"creditAmount": float(entry["amount"]),
"currency": "CHF",
"description": entry["description"],
"mandateId": _MANDATE_ID,
"featureInstanceId": _FEATURE_INSTANCE_ID,
})
fixture = BenchmarkFixture(
mandateId=_MANDATE_ID,
featureInstanceId=_FEATURE_INSTANCE_ID,
rowsByTable={
"TrusteeDataAccount": accountRows,
"TrusteeDataAccountBalance": balanceRows,
"TrusteeDataJournalLine": lineRows,
},
selectedTables=_buildSelectedTables(),
)
return fixture

View file

@ -0,0 +1,226 @@
# Trustee Sub-Agent Benchmark -- 19 questions analog Hein 2025
#
# Each question covers ONE expected hallucination class so we can attribute
# accuracy gains to specific phases (validator / ontology).
#
# Scoring per question (all binary unless noted):
# patternOk -- did the agent call the right tool(s) with the right filters?
# forbidOk -- did it AVOID the forbidden tool/op (e.g. SUM closingBalance)?
# numericOk -- does the final answer contain the expected number(s)?
# accuracyOk -- patternOk AND forbidOk AND numericOk
#
# tolerance: relative tolerance for numeric comparison (default 0.005 = 0.5 %).
- id: q01
question: "Was ist der Banksaldo per 31.12.2025 fuer das ZKB-Konto 1020?"
intent: BANK_BALANCE_AT_DATE
expectedTools: [queryTable]
expectedTable: TrusteeDataAccountBalance
requiredFilters:
accountNumber: "1020"
periodYear: 2025
periodMonth: 0
forbiddenTools: [aggregateTable]
expectedNumbers: [152400.0]
- id: q02
question: "Wie hoch ist die Hauptkasse (Konto 1000) per Ende 2024?"
intent: CASH_BALANCE_AT_DATE
expectedTools: [queryTable]
expectedTable: TrusteeDataAccountBalance
requiredFilters:
accountNumber: "1000"
periodYear: 2024
periodMonth: 0
forbiddenTools: [aggregateTable]
expectedNumbers: [5200.0]
- id: q03
question: "Summiere alle Bankkonten (102x) per 31.12.2025."
intent: BANK_GROUP_TOTAL_AT_DATE
expectedTools: [queryTable]
expectedTable: TrusteeDataAccountBalance
requiredFilters:
periodYear: 2025
periodMonth: 0
accountNumberLike: "102%"
forbiddenTools: [aggregateTable]
expectedNumbers: [180500.0]
numericTolerance: 0.01
- id: q04
question: "Wie hat sich der Schlusssaldo des ZKB-Kontos 1020 ueber die Jahre 2023 bis 2025 entwickelt?"
intent: BALANCE_HISTORY_PER_YEAR
expectedTools: [queryTable]
expectedTable: TrusteeDataAccountBalance
requiredFilters:
accountNumber: "1020"
periodMonth: 0
forbiddenTools: [aggregateTable]
expectedNumbers: [132500.0, 148900.0, 152400.0]
- id: q05
question: "Welches Konto hatte 2025 den hoechsten Schlusssaldo bei den Aktiven (1xxx)?"
intent: TOP_ASSET_AT_DATE
expectedTools: [queryTable]
expectedTable: TrusteeDataAccountBalance
requiredFilters:
periodYear: 2025
periodMonth: 0
accountNumberLike: "1%"
forbiddenTools: [aggregateTable]
expectedAnswerContains: ["1020"]
expectedNumbers: [152400.0]
- id: q06
question: "Welche Konten gehoeren zu den Bankkonten (102x)?"
intent: ACCOUNT_LIST_FILTER
expectedTools: [queryTable]
expectedTable: TrusteeDataAccount
requiredFilters:
accountNumberLike: "102%"
forbiddenTools: [aggregateTable]
expectedAnswerContains: ["1020", "1021"]
- id: q07
question: "Wie hoch war der Materialaufwand (Konto 5400) im Jahr 2025?"
intent: EXPENSE_AT_YEAR
expectedTools: [queryTable]
expectedTable: TrusteeDataAccountBalance
requiredFilters:
accountNumber: "5400"
periodYear: 2025
periodMonth: 0
forbiddenTools: [aggregateTable]
expectedNumbers: [112100.0]
- id: q08
question: "Wie viele Buchungszeilen gibt es insgesamt im System?"
intent: COUNT_ROWS
expectedTools: [aggregateTable]
expectedTable: TrusteeDataJournalLine
expectedAggregate: COUNT
forbiddenTools: []
expectedNumbers: [20]
- id: q09
question: "Wie hoch ist der gesamte Beratungsertrag (Konto 3000) im Jahr 2025?"
intent: REVENUE_AT_YEAR
expectedTools: [queryTable]
expectedTable: TrusteeDataAccountBalance
requiredFilters:
accountNumber: "3000"
periodYear: 2025
periodMonth: 0
forbiddenTools: [aggregateTable]
expectedNumbers: [488700.0]
- id: q10
question: "Wie viel wurde 2025 auf das Materialaufwand-Konto 5400 gebucht (Soll-Summe ueber Buchungszeilen)?"
intent: JOURNAL_SUM_AT_ACCOUNT
expectedTools: [aggregateTable]
expectedTable: TrusteeDataJournalLine
expectedAggregate: SUM
expectedAggregateField: debitAmount
requiredFilters:
accountNumber: "5400"
forbiddenTools: []
expectedNumbers: [29400.0]
numericTolerance: 0.01
- id: q11
question: "Welche Buchungen im 1. Quartal 2025 (Januar bis Maerz) wurden auf Konto 3000 gebucht?"
intent: JOURNAL_LINES_BY_ACCOUNT
expectedTools: [queryTable]
expectedTable: TrusteeDataJournalLine
requiredFilters:
accountNumber: "3000"
forbiddenTools: [aggregateTable]
expectedAnswerContains: ["18500", "ACME"]
- id: q12
question: "Wie hoch war die Hauptkasse (Konto 1000) jeweils per Ende Maerz 2025 und per Ende Juni 2025?"
intent: MULTI_MONTH_SNAPSHOT
expectedTools: [queryTable]
expectedTable: TrusteeDataAccountBalance
requiredFilters:
accountNumber: "1000"
periodYear: 2025
forbiddenTools: [aggregateTable]
expectedNumbers: [5375.0, 5550.0]
numericTolerance: 0.01
- id: q13
question: "Wie hoch ist die Summe aller Aufwandskonten (5xxx und 6xxx) per Ende 2025?"
intent: EXPENSE_GROUP_TOTAL
expectedTools: [queryTable]
expectedTable: TrusteeDataAccountBalance
requiredFilters:
periodYear: 2025
periodMonth: 0
forbiddenTools: [aggregateTable]
expectedNumbers: [154100.0]
numericTolerance: 0.01
- id: q14
question: "Welches Konto hat den hoechsten openingBalance fuer 2025?"
intent: TOP_OPENING_BALANCE
# Both routes are legitimate: queryTable+orderBy+limit=1, or
# aggregateTable(MAX) followed by queryTable lookup. We only insist that
# the final answer names the right account and (optionally) the value.
expectedTools: [queryTable, aggregateTable]
expectedTable: TrusteeDataAccountBalance
forbiddenTools: []
expectedAnswerContains: ["3000"]
expectedNumbers: [462500.0]
- id: q15
question: "Liste alle Konten vom Typ asset auf."
intent: ACCOUNTS_BY_TYPE
expectedTools: [queryTable]
expectedTable: TrusteeDataAccount
requiredFilters:
accountType: "asset"
forbiddenTools: [aggregateTable]
expectedAnswerContains: ["1000", "1020", "1021", "1100"]
- id: q16
question: "Wie hoch ist der Schlusssaldo der Forderungen aus Lieferungen und Leistungen (Konto 1100) per Ende 2025?"
intent: BALANCE_BY_NAME_LOOKUP
expectedTools: [queryTable]
expectedTable: TrusteeDataAccountBalance
requiredFilters:
accountNumber: "1100"
periodYear: 2025
periodMonth: 0
forbiddenTools: [aggregateTable]
expectedNumbers: [66800.0]
- id: q17
question: "Wie hoch waren die Verbindlichkeiten (Konto 2000) jeweils per Ende 2023, 2024 und 2025?"
intent: LIABILITY_HISTORY
expectedTools: [queryTable]
expectedTable: TrusteeDataAccountBalance
requiredFilters:
accountNumber: "2000"
periodMonth: 0
forbiddenTools: [aggregateTable]
expectedNumbers: [41300.0, 44100.0, 47900.0]
- id: q18
question: "Wie viele Bankkonten gibt es im Kontenplan (102x)?"
intent: ACCOUNT_COUNT_BY_PREFIX
expectedTools: [queryTable, aggregateTable]
expectedTable: TrusteeDataAccount
requiredFilters:
accountNumberLike: "102%"
forbiddenTools: []
expectedNumbers: [2]
- id: q19
question: "Gib mir alle Buchungszeilen mit einem Sollbetrag groesser als 20'000 CHF."
intent: JOURNAL_LINES_BY_AMOUNT
expectedTools: [queryTable]
expectedTable: TrusteeDataJournalLine
forbiddenTools: [aggregateTable]
expectedAnswerContains: ["24300", "32100", "41500", "28700"]

View file

@ -0,0 +1,112 @@
# Copyright (c) 2026 Patrick Motsch
# All rights reserved.
"""Unit tests for the repair-loop telemetry aggregation in agentLoop.
These counters (`validationFailures`, `repairAttempts`, `successAfterRepair`)
land on `AgentTrace` and are surfaced via the `AGENT_SUMMARY` event. The
Eval-Harness (Phase 1.5) reads them to compute the repair conversion rate.
"""
from __future__ import annotations
from modules.serviceCenter.services.serviceAgent.agentLoop import _computeRepairCounters
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
AgentRoundLog, ToolCallLog,
)
def _round(*toolCalls: ToolCallLog) -> AgentRoundLog:
return AgentRoundLog(roundNumber=1, toolCalls=list(toolCalls))
def _failed(toolName: str, code: str) -> ToolCallLog:
return ToolCallLog(
toolName=toolName,
success=False,
validationFailureCode=code,
error=f"{code}: ...",
)
def _ok(toolName: str) -> ToolCallLog:
return ToolCallLog(toolName=toolName, success=True)
def test_computeRepairCounters_emptyTrace():
fails, attempts, succeeded = _computeRepairCounters([])
assert (fails, attempts, succeeded) == (0, 0, 0)
def test_computeRepairCounters_allCleanRunsHaveZeroCounters():
rounds = [
_round(_ok("queryTable"), _ok("browseTable")),
_round(_ok("aggregateTable")),
]
fails, attempts, succeeded = _computeRepairCounters(rounds)
assert (fails, attempts, succeeded) == (0, 0, 0)
def test_computeRepairCounters_singleFailureCountsButNoRepairYet():
"""One failure in round 1, no follow-up call -- counts the failure but
nothing else. Repair only counts when the LLM tries again."""
rounds = [_round(_failed("queryTable", "FIELD_NOT_FOUND"))]
fails, attempts, succeeded = _computeRepairCounters(rounds)
assert (fails, attempts, succeeded) == (1, 0, 0)
def test_computeRepairCounters_repairThatSucceeds():
"""Round 1 fails, round 2 retries same tool successfully."""
rounds = [
_round(_failed("queryTable", "FIELD_NOT_FOUND")),
_round(_ok("queryTable")),
]
fails, attempts, succeeded = _computeRepairCounters(rounds)
assert (fails, attempts, succeeded) == (1, 1, 1)
def test_computeRepairCounters_repairThatFailsAgain():
"""Round 1 fails, round 2 retries same tool but fails validation again."""
rounds = [
_round(_failed("queryTable", "FIELD_NOT_FOUND")),
_round(_failed("queryTable", "FIELD_NOT_FOUND")),
]
fails, attempts, succeeded = _computeRepairCounters(rounds)
assert (fails, attempts, succeeded) == (2, 1, 0)
def test_computeRepairCounters_siblingCallsInSameRoundAreNotRepairs():
"""When the LLM emits two queryTable calls in the same round, the
second is NOT a repair attempt -- it had no way to see the first
one's rejection yet (parallel dispatch within a round)."""
rounds = [
_round(
_failed("queryTable", "FIELD_NOT_FOUND"),
_failed("queryTable", "FIELD_NOT_FOUND"),
),
]
fails, attempts, succeeded = _computeRepairCounters(rounds)
assert (fails, attempts, succeeded) == (2, 0, 0)
def test_computeRepairCounters_differentToolNamesAreIndependent():
"""A queryTable failure does not flag a later browseTable as a repair."""
rounds = [
_round(_failed("queryTable", "FIELD_NOT_FOUND")),
_round(_ok("browseTable")),
]
fails, attempts, succeeded = _computeRepairCounters(rounds)
assert (fails, attempts, succeeded) == (1, 0, 0)
def test_computeRepairCounters_multiToolMix():
"""Trustee-like sequence: SUM(closingBalance) rejected, LLM switches to
queryTable with a typo (rejected), then fixes the typo (success)."""
rounds = [
_round(_failed("aggregateTable", "INVALID_AGGREGATE_TARGET")),
_round(_failed("queryTable", "FIELD_NOT_FOUND")),
_round(_ok("queryTable")),
]
fails, attempts, succeeded = _computeRepairCounters(rounds)
# 2 validation failures total, 1 prior-rejected queryTable retry that
# succeeded; aggregateTable was never retried so no attempt counted for it.
assert (fails, attempts, succeeded) == (2, 1, 1)

View file

@ -19,11 +19,18 @@ asked for the closing balance per period).
from __future__ import annotations
import asyncio
from unittest.mock import MagicMock
import pytest
from modules.shared import fkRegistry
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
ToolCallRequest, ToolResult,
)
from modules.serviceCenter.services.serviceAgent.featureDataAgent import (
_buildSchemaContext,
_buildSubAgentTools,
_buildTableSchemaBlock,
_formatFieldLine,
_summarizePythonType,
@ -152,10 +159,29 @@ def test_buildSchemaContext_forbidsSummingAggregateFields():
assert "closingBalance" in prompt
def test_buildSchemaContext_appendsTrusteeDomainHints():
"""When the feature module exposes getAgentDomainHints(), the schema prompt
must include those hints so the sub-agent knows e.g. that 102x are bank
accounts and periodMonth=0 is the annual total."""
def test_buildSchemaContext_appendsTrusteeOntologyBlock(monkeypatch):
"""When the feature exposes getAgentOntology(), the schema prompt must
include the compiled ontology block (Phase 2 path)."""
monkeypatch.delenv("POWERON_DISABLE_FEATURE_ONTOLOGY", raising=False)
selected = [_trusteeAccountBalanceObj()]
prompt = _buildSchemaContext(
featureCode="trustee",
instanceLabel="Demo AG",
selectedTables=selected,
requestLang="de",
)
assert "DOMAIN ONTOLOGY (trustee):" in prompt
assert "BankAccount" in prompt
assert "NEVER_AGGREGATE on TrusteeDataAccountBalance.closingBalance" in prompt.replace("never aggregate", "NEVER_AGGREGATE")
assert "BANK_BALANCE_AT_DATE" in prompt
def test_buildSchemaContext_fallsBackToLegacyHints_whenOntologyDisabled(monkeypatch):
"""With POWERON_DISABLE_FEATURE_ONTOLOGY=1 the builder must fall back to
the legacy `getAgentDomainHints()` block. This is the path used by the
eval harness to measure `baseline` and `phase1` accuracy without the
ontology-driven prompt."""
monkeypatch.setenv("POWERON_DISABLE_FEATURE_ONTOLOGY", "1")
selected = [_trusteeAccountBalanceObj()]
prompt = _buildSchemaContext(
featureCode="trustee",
@ -164,16 +190,14 @@ def test_buildSchemaContext_appendsTrusteeDomainHints():
requestLang="de",
)
assert "TRUSTEE DOMAIN HINTS" in prompt
assert "DOMAIN ONTOLOGY" not in prompt
assert "102x Bank / Post" in prompt
assert "periodMonth = 0" in prompt
assert "ANTI-PATTERNS" in prompt
assert 'LIKE \'102%\'' in prompt or "LIKE '102%'" in prompt
def test_buildSchemaContext_skipsHintsForFeaturesWithoutHook():
"""Features that don't export getAgentDomainHints() should produce a prompt
without the trailing hints block. Verified by using a feature code that
cannot resolve to a main module (registry returns None)."""
def test_buildSchemaContext_skipsHintsForFeaturesWithoutHook(monkeypatch):
"""Features that don't export getAgentDomainHints()/getAgentOntology()
should produce a prompt without any trailing hints block."""
monkeypatch.delenv("POWERON_DISABLE_FEATURE_ONTOLOGY", raising=False)
selected = [_trusteeAccountBalanceObj()]
prompt = _buildSchemaContext(
featureCode="nosuchfeature",
@ -182,4 +206,90 @@ def test_buildSchemaContext_skipsHintsForFeaturesWithoutHook():
requestLang="de",
)
assert "TRUSTEE DOMAIN HINTS" not in prompt
assert "DOMAIN ONTOLOGY" not in prompt
assert "Keep your answer SHORT" in prompt
# ------------------------------------------------------------------
# Validator integration (Phase 1: Repair-Loop)
#
# These tests guard that pre-execute validation fires BEFORE the provider
# is touched, and that the structured error payload reaches the LLM via
# `ToolResult.errorDetails` -- the contract the LLM relies on for repair.
# ------------------------------------------------------------------
def _buildRegistryWithMockProvider():
"""Build a sub-agent ToolRegistry where the provider is a MagicMock.
The mock records calls so we can assert the validator short-circuits
before the DB layer is reached."""
provider = MagicMock()
provider.browseTable.return_value = {"rows": [], "total": 0, "limit": 50, "offset": 0}
provider.queryTable.return_value = {"rows": [], "total": 0, "limit": 50, "offset": 0}
provider.aggregateTable.return_value = {"rows": [], "aggregate": "SUM", "field": "x"}
registry = _buildSubAgentTools(
provider=provider,
featureInstanceId="fi-test",
mandateId="m-test",
tableFilters=None,
validator=None,
)
return registry, provider
def _dispatchSync(registry, toolName, args):
"""Synchronously dispatch a tool call through the registry."""
call = ToolCallRequest(name=toolName, args=args)
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(registry.dispatch(call, context={}))
finally:
loop.close()
def test_subAgentTools_invalidFieldShortCircuitsBeforeProvider():
"""A queryTable call with an unknown field must NOT reach the provider."""
registry, provider = _buildRegistryWithMockProvider()
result = _dispatchSync(registry, "queryTable", {
"tableName": "TrusteeDataAccountBalance",
"filters": [{"field": "klosingBalance", "op": "=", "value": 1}],
})
assert isinstance(result, ToolResult)
assert result.success is False
assert result.errorDetails is not None
assert result.errorDetails["code"] == "FIELD_NOT_FOUND"
assert result.errorDetails["suggestion"] == "closingBalance"
assert result.error and result.error.startswith("FIELD_NOT_FOUND:")
provider.queryTable.assert_not_called()
def test_subAgentTools_sumClosingBalanceShortCircuits():
"""The flagship hallucination -- SUM(closingBalance) -- must be blocked
by the pre-execute validator before the DB is touched."""
registry, provider = _buildRegistryWithMockProvider()
result = _dispatchSync(registry, "aggregateTable", {
"tableName": "TrusteeDataAccountBalance",
"aggregate": "SUM",
"field": "closingBalance",
})
assert result.success is False
assert result.errorDetails["code"] == "INVALID_AGGREGATE_TARGET"
assert result.errorDetails["field"] == "closingBalance"
provider.aggregateTable.assert_not_called()
def test_subAgentTools_validCallReachesProvider():
"""Sanity: a valid call passes the validator and hits the provider."""
registry, provider = _buildRegistryWithMockProvider()
result = _dispatchSync(registry, "queryTable", {
"tableName": "TrusteeDataAccountBalance",
"filters": [
{"field": "periodYear", "op": "=", "value": 2025},
{"field": "periodMonth", "op": "=", "value": 0},
],
"fields": ["accountNumber", "closingBalance"],
})
assert result.success is True
assert result.errorDetails is None
provider.queryTable.assert_called_once()

View file

@ -0,0 +1,295 @@
# Copyright (c) 2026 Patrick Motsch
# All rights reserved.
"""Unit tests for the Feature Data Sub-Agent QueryValidator.
Each constraint is exercised with both a Happy and a Sad path so a future
refactor that silently drops a check is caught immediately.
Test fixture is the real ``TrusteeDataAccountBalance`` / ``TrusteeDataJournalLine``
Pydantic models -- both are perfectly suited because they cover all four
constraint classes in production-realistic shape (string fields, numeric
fields, fields named ``closingBalance`` / ``debitTotal``).
"""
from __future__ import annotations
import pytest
from modules.shared import fkRegistry
from modules.serviceCenter.services.serviceAgent.datamodelOntology import (
Constraint,
ConstraintRule,
OntologyDescriptor,
ValidationErrorCode,
)
from modules.serviceCenter.services.serviceAgent.queryValidator import QueryValidator
@pytest.fixture(scope="module", autouse=True)
def _ensureModels():
fkRegistry._ensureModelsLoaded()
@pytest.fixture()
def validator() -> QueryValidator:
return QueryValidator()
# ------------------------------------------------------------------
# FieldExists -- browseTable / queryTable / aggregateTable
# ------------------------------------------------------------------
def test_browseQuery_happyPath_returnsNone(validator):
err = validator.validateBrowseQuery(
"TrusteeDataAccountBalance",
{"fields": ["accountNumber", "closingBalance"]},
)
assert err is None
def test_browseQuery_invalidField_returnsFieldNotFound(validator):
err = validator.validateBrowseQuery(
"TrusteeDataAccountBalance",
{"fields": ["closingBlance"]}, # typo
)
assert err is not None
assert err.code == ValidationErrorCode.FIELD_NOT_FOUND
assert err.field == "closingBlance"
assert err.suggestion == "closingBalance"
def test_queryTable_filterOnInvalidField_returnsFieldNotFound(validator):
err = validator.validateQueryTable(
"TrusteeDataAccountBalance",
{"filters": [{"field": "klosingBalance", "op": "=", "value": 100}]},
)
assert err is not None
assert err.code == ValidationErrorCode.FIELD_NOT_FOUND
assert err.suggestion == "closingBalance"
def test_queryTable_unknownTable_isLenient(validator):
"""When the table isn't in MODEL_REGISTRY we skip validation -- relying on
the SQL layer to surface schema errors. Prevents false positives for
pure UDB tables not exposed via Pydantic."""
err = validator.validateQueryTable(
"NoSuchTable123",
{"filters": [{"field": "anything", "op": "=", "value": 1}]},
)
assert err is None
# ------------------------------------------------------------------
# OperatorCompatible
# ------------------------------------------------------------------
def test_queryTable_likeOnStringField_isOk(validator):
err = validator.validateQueryTable(
"TrusteeDataAccountBalance",
{"filters": [{"field": "accountNumber", "op": "LIKE", "value": "102%"}]},
)
assert err is None
def test_queryTable_likeOnNumericField_isOperatorIncompatible(validator):
err = validator.validateQueryTable(
"TrusteeDataAccountBalance",
{"filters": [{"field": "closingBalance", "op": "LIKE", "value": "100%"}]},
)
assert err is not None
assert err.code == ValidationErrorCode.OPERATOR_INCOMPATIBLE
assert err.field == "closingBalance"
def test_queryTable_gteOnNumericField_isOk(validator):
err = validator.validateQueryTable(
"TrusteeDataAccountBalance",
{"filters": [{"field": "closingBalance", "op": ">=", "value": 100}]},
)
assert err is None
def test_queryTable_gteOnStringField_isOperatorIncompatible(validator):
err = validator.validateQueryTable(
"TrusteeDataAccountBalance",
{"filters": [{"field": "currency", "op": ">=", "value": "CHF"}]},
)
assert err is not None
assert err.code == ValidationErrorCode.OPERATOR_INCOMPATIBLE
def test_queryTable_equalsOnAnyField_isOk(validator):
"""`=` and `!=` work on any field type."""
err = validator.validateQueryTable(
"TrusteeDataAccountBalance",
{"filters": [{"field": "currency", "op": "=", "value": "CHF"}]},
)
assert err is None
def test_queryTable_isNullOnAnyField_isOk(validator):
err = validator.validateQueryTable(
"TrusteeDataAccountBalance",
{"filters": [{"field": "mandateId", "op": "IS NULL", "value": None}]},
)
assert err is None
# ------------------------------------------------------------------
# AggregateTarget -- the highest-impact rule
# ------------------------------------------------------------------
def test_aggregate_sumDebitAmount_isOk(validator):
err = validator.validateAggregateQuery(
"TrusteeDataJournalLine",
{"aggregate": "SUM", "field": "debitAmount"},
)
assert err is None
def test_aggregate_sumClosingBalance_isInvalidAggregateTarget(validator):
"""The flagship bug: SUM(closingBalance) across periods. Must be blocked."""
err = validator.validateAggregateQuery(
"TrusteeDataAccountBalance",
{"aggregate": "SUM", "field": "closingBalance"},
)
assert err is not None
assert err.code == ValidationErrorCode.INVALID_AGGREGATE_TARGET
assert err.field == "closingBalance"
assert "already aggregated" in err.hint
def test_aggregate_avgDebitTotal_isInvalidAggregateTarget(validator):
"""`*Total` columns are turnovers per period -- AVG across periods is nonsense."""
err = validator.validateAggregateQuery(
"TrusteeDataAccountBalance",
{"aggregate": "AVG", "field": "debitTotal"},
)
assert err is not None
assert err.code == ValidationErrorCode.INVALID_AGGREGATE_TARGET
def test_aggregate_countClosingBalance_isOk(validator):
"""COUNT on a balance column is meaningful (how many balance rows exist)."""
err = validator.validateAggregateQuery(
"TrusteeDataAccountBalance",
{"aggregate": "COUNT", "field": "closingBalance"},
)
assert err is None
def test_aggregate_sumOnStringField_isTypeMismatch(validator):
err = validator.validateAggregateQuery(
"TrusteeDataAccountBalance",
{"aggregate": "SUM", "field": "currency"},
)
assert err is not None
assert err.code == ValidationErrorCode.TYPE_MISMATCH
def test_aggregate_invalidField_returnsFieldNotFound(validator):
err = validator.validateAggregateQuery(
"TrusteeDataAccountBalance",
{"aggregate": "SUM", "field": "nonExistent"},
)
assert err is not None
assert err.code == ValidationErrorCode.FIELD_NOT_FOUND
def test_aggregate_invalidGroupBy_returnsFieldNotFound(validator):
err = validator.validateAggregateQuery(
"TrusteeDataJournalLine",
{"aggregate": "SUM", "field": "debitAmount", "groupBy": "ghostColumn"},
)
assert err is not None
assert err.code == ValidationErrorCode.FIELD_NOT_FOUND
# ------------------------------------------------------------------
# OrderByValid
# ------------------------------------------------------------------
def test_queryTable_orderByValid_isOk(validator):
err = validator.validateQueryTable(
"TrusteeDataAccountBalance",
{"orderBy": "periodYear"},
)
assert err is None
def test_queryTable_orderByInvalid_returnsOrderByInvalid(validator):
err = validator.validateQueryTable(
"TrusteeDataAccountBalance",
{"orderBy": "periodYr"},
)
assert err is not None
assert err.code == ValidationErrorCode.ORDER_BY_INVALID
assert err.suggestion == "periodYear"
def test_queryTable_orderByLiteralStringNone_isOk(validator):
"""LLMs sometimes pass the literal string 'None'."""
err = validator.validateQueryTable(
"TrusteeDataAccountBalance",
{"orderBy": "None"},
)
assert err is None
# ------------------------------------------------------------------
# Ontology-driven override (Phase 2 readiness check)
# ------------------------------------------------------------------
def test_ontologyOverride_blocksAggregateForOntologyField():
"""When the ontology marks a field NEVER_AGGREGATE, SUM/AVG is blocked
even if the field name doesn't match the convention suffixes."""
ontology = OntologyDescriptor(
featureCode="trustee",
constraints=[
Constraint(
appliesTo="TrusteeDataJournalLine.debitAmount",
rule=ConstraintRule.NEVER_AGGREGATE,
message="Synthetic test rule.",
)
],
)
validatorWithOntology = QueryValidator(ontology=ontology)
err = validatorWithOntology.validateAggregateQuery(
"TrusteeDataJournalLine",
{"aggregate": "SUM", "field": "debitAmount"},
)
assert err is not None
assert err.code == ValidationErrorCode.INVALID_AGGREGATE_TARGET
# ------------------------------------------------------------------
# QueryValidationError serialization (consumed by featureDataAgent)
# ------------------------------------------------------------------
def test_validationError_toShortErrorIncludesCodeAndField(validator):
err = validator.validateAggregateQuery(
"TrusteeDataAccountBalance",
{"aggregate": "SUM", "field": "closingBalance"},
)
assert err is not None
short = err.toShortError()
assert short.startswith("INVALID_AGGREGATE_TARGET:")
assert "closingBalance" in short
def test_validationError_toErrorDetailsHasFourKeys(validator):
err = validator.validateQueryTable(
"TrusteeDataAccountBalance",
{"filters": [{"field": "klosingBalance", "op": "=", "value": 0}]},
)
assert err is not None
details = err.toErrorDetails()
assert set(details.keys()) == {"code", "field", "suggestion", "hint"}
assert details["code"] == "FIELD_NOT_FOUND"
assert details["suggestion"] == "closingBalance"

View file

@ -0,0 +1,199 @@
# Copyright (c) 2026 Patrick Motsch
# All rights reserved.
"""Unit tests for the trustee ontology and the ontology-to-prompt compiler.
Verifies:
* the descriptor passes Pydantic validation
* `constraintsForTable` correctly scopes by table/field prefix
* the compiler emits a stable header + every entity name + every
constraint message
* the QueryValidator picks up ontology constraints (NEVER_AGGREGATE on
closingBalance) over the convention-based defaults
* the `getAgentOntology()` hook on `mainTrustee` returns the descriptor
* `_buildValidatorForFeature("trustee")` wires the validator with the
ontology
"""
from __future__ import annotations
import pytest
from modules.features.trustee.mainTrustee import getAgentOntology
from modules.features.trustee.trusteeOntology import getTrusteeOntology
from modules.serviceCenter.services.serviceAgent.datamodelOntology import (
ConstraintRule,
OntologyDescriptor,
SemanticType,
ValidationErrorCode,
)
from modules.serviceCenter.services.serviceAgent.featureDataAgent import (
_buildValidatorForFeature,
_loadFeatureOntologyBlock,
)
from modules.serviceCenter.services.serviceAgent.ontologyToPromptCompiler import (
compileOntologyToPrompt,
)
from modules.serviceCenter.services.serviceAgent.queryValidator import QueryValidator
from modules.shared import fkRegistry
@pytest.fixture(scope="module", autouse=True)
def _ensureModels():
fkRegistry._ensureModelsLoaded()
# ---------------------------------------------------------------------------
# OntologyDescriptor structure
# ---------------------------------------------------------------------------
def test_trusteeOntology_returnsValidDescriptor():
ont = getTrusteeOntology()
assert isinstance(ont, OntologyDescriptor)
assert ont.featureCode == "trustee"
assert ont.entities and ont.relations and ont.constraints and ont.canonicalPatterns
def test_trusteeOntology_hasBankAccountSpecialization():
ont = getTrusteeOntology()
bank = next((e for e in ont.entities if e.name == "BankAccount"), None)
assert bank is not None
assert bank.parentEntity == "Account"
assert bank.semanticType == SemanticType.ACCOUNT
def test_trusteeOntology_closingBalanceIsNeverAggregate():
ont = getTrusteeOntology()
constraints = ont.constraintsForTable("TrusteeDataAccountBalance")
matching = [
c for c in constraints
if c.rule == ConstraintRule.NEVER_AGGREGATE
and c.appliesTo == "TrusteeDataAccountBalance.closingBalance"
]
assert matching, "Expected NEVER_AGGREGATE constraint on closingBalance"
def test_trusteeOntology_requiresPeriodFilterOnBalanceTable():
ont = getTrusteeOntology()
constraints = ont.constraintsForTable("TrusteeDataAccountBalance")
table_level = [c for c in constraints if c.rule == ConstraintRule.REQUIRES_FILTER_ON]
assert table_level, "Expected at least one REQUIRES_FILTER_ON constraint"
required = table_level[0].params.get("requiredFields") or []
assert "periodYear" in required
assert "periodMonth" in required
def test_constraintsForTable_filtersScopeCorrectly():
ont = getTrusteeOntology()
bal = ont.constraintsForTable("TrusteeDataAccountBalance")
journal = ont.constraintsForTable("TrusteeDataJournalLine")
for c in bal:
assert c.appliesTo.startswith("TrusteeDataAccountBalance")
for c in journal:
assert c.appliesTo.startswith("TrusteeDataJournalLine")
# ---------------------------------------------------------------------------
# Prompt compiler
# ---------------------------------------------------------------------------
def test_compiler_emitsExpectedHeader():
block = compileOntologyToPrompt(getTrusteeOntology())
assert block.startswith("DOMAIN ONTOLOGY (trustee):"), block.splitlines()[0]
def test_compiler_includesAllEntityNames():
ont = getTrusteeOntology()
block = compileOntologyToPrompt(ont)
for e in ont.entities:
assert e.name in block, f"Entity {e.name} missing from compiled prompt"
def test_compiler_includesAllConstraintMessages():
ont = getTrusteeOntology()
block = compileOntologyToPrompt(ont)
for c in ont.constraints:
assert c.message.split(".")[0] in block, f"Constraint message missing: {c.message[:40]}"
def test_compiler_includesCanonicalPatternTools():
ont = getTrusteeOntology()
block = compileOntologyToPrompt(ont)
for p in ont.canonicalPatterns:
assert p.intent in block
assert p.pattern["tool"] in block
def test_compiler_deterministic():
block1 = compileOntologyToPrompt(getTrusteeOntology())
block2 = compileOntologyToPrompt(getTrusteeOntology())
assert block1 == block2
# ---------------------------------------------------------------------------
# QueryValidator x ontology integration
# ---------------------------------------------------------------------------
def test_validator_picksUpOntologyNeverAggregate():
validator = QueryValidator(ontology=getTrusteeOntology())
err = validator.validateAggregateQuery(
"TrusteeDataAccountBalance",
{"aggregate": "SUM", "field": "closingBalance"},
)
assert err is not None
assert err.code == ValidationErrorCode.INVALID_AGGREGATE_TARGET
assert err.field == "closingBalance"
def test_validator_ontologyConstraintFiresOnDebitTotal():
validator = QueryValidator(ontology=getTrusteeOntology())
err = validator.validateAggregateQuery(
"TrusteeDataAccountBalance",
{"aggregate": "SUM", "field": "debitTotal"},
)
assert err is not None
assert err.code == ValidationErrorCode.INVALID_AGGREGATE_TARGET
def test_validator_allowsLegitimateAggregateOnJournalLine():
validator = QueryValidator(ontology=getTrusteeOntology())
err = validator.validateAggregateQuery(
"TrusteeDataJournalLine",
{"aggregate": "SUM", "field": "debitAmount"},
)
assert err is None
# ---------------------------------------------------------------------------
# featureDataAgent integration hooks
# ---------------------------------------------------------------------------
def test_mainTrustee_getAgentOntology_returnsDescriptor():
ont = getAgentOntology()
assert isinstance(ont, OntologyDescriptor)
assert ont.featureCode == "trustee"
def test_loadFeatureOntologyBlock_returnsCompiledBlock():
block = _loadFeatureOntologyBlock("trustee")
assert block.startswith("DOMAIN ONTOLOGY (trustee):")
assert "BankAccount" in block
def test_loadFeatureOntologyBlock_unknownFeatureReturnsEmpty():
assert _loadFeatureOntologyBlock("doesNotExist") == ""
def test_buildValidatorForFeature_trustee_hasOntology():
validator = _buildValidatorForFeature("trustee")
assert validator._ontology is not None
assert validator._ontology.featureCode == "trustee"
def test_buildValidatorForFeature_unknownFeature_noOntology():
validator = _buildValidatorForFeature("doesNotExist")
assert validator._ontology is None