From 7c4c5e079a748c022462b1636940124b0202ba34 Mon Sep 17 00:00:00 2001 From: ValueOn AG Date: Sat, 16 May 2026 22:55:43 +0200 Subject: [PATCH] rag enhancements --- app.py | 5 + modules/aicore/aicorePluginOpenai.py | 110 +-- modules/connectors/connectorDbPostgre.py | 5 +- modules/datamodels/datamodelAi.py | 7 +- modules/features/trustee/mainTrustee.py | 29 +- modules/features/trustee/trusteeOntology.py | 295 +++++++ modules/routes/routeAdminSttBenchmark.py | 217 ++++++ modules/routes/routeDataConnections.py | 23 +- modules/routes/routeDataSources.py | 26 +- modules/routes/routeRagInventory.py | 77 +- .../services/serviceAgent/agentLoop.py | 51 +- .../serviceAgent/coreTools/_workspaceTools.py | 58 +- .../services/serviceAgent/datamodelAgent.py | 34 + .../serviceAgent/datamodelOntology.py | 203 +++++ .../services/serviceAgent/featureDataAgent.py | 164 +++- .../serviceAgent/ontologyToPromptCompiler.py | 140 ++++ .../services/serviceAgent/queryValidator.py | 311 ++++++++ .../services/serviceAgent/sandboxExecutor.py | 7 +- .../services/serviceAi/subAiCallLooping.py | 16 +- .../mainBackgroundJobService.py | 98 ++- .../services/serviceChat/mainServiceChat.py | 10 +- .../subConnectorIngestConsumer.py | 24 +- .../subConnectorSyncClickup.py | 76 +- .../subConnectorSyncGdrive.py | 65 +- .../serviceKnowledge/subConnectorSyncGmail.py | 134 ++-- .../subConnectorSyncKdrive.py | 439 +++++++++++ .../subConnectorSyncOutlook.py | 131 ++-- .../subConnectorSyncSharepoint.py | 82 +- .../serviceKnowledge/subWalkerHelpers.py | 116 +++ modules/shared/aiAuditLogger.py | 5 + modules/system/mainSystem.py | 10 + tests/eval/__init__.py | 3 + tests/eval/fakeFeatureDataProvider.py | 246 ++++++ tests/eval/runTrusteeBenchmark.py | 735 ++++++++++++++++++ tests/fixtures/trusteeBenchmark/__init__.py | 16 + .../loadTrusteeBenchmarkFixture.py | 275 +++++++ .../fixtures/trusteeBenchmark/questions.yaml | 226 ++++++ .../test_agentTrace_repairCounters.py | 112 +++ .../services/test_featureDataAgent_schema.py | 132 +++- tests/unit/services/test_queryValidator.py | 295 +++++++ tests/unit/services/test_trusteeOntology.py | 199 +++++ 41 files changed, 4809 insertions(+), 398 deletions(-) create mode 100644 modules/features/trustee/trusteeOntology.py create mode 100644 modules/routes/routeAdminSttBenchmark.py create mode 100644 modules/serviceCenter/services/serviceAgent/datamodelOntology.py create mode 100644 modules/serviceCenter/services/serviceAgent/ontologyToPromptCompiler.py create mode 100644 modules/serviceCenter/services/serviceAgent/queryValidator.py create mode 100644 modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py create mode 100644 modules/serviceCenter/services/serviceKnowledge/subWalkerHelpers.py create mode 100644 tests/eval/__init__.py create mode 100644 tests/eval/fakeFeatureDataProvider.py create mode 100644 tests/eval/runTrusteeBenchmark.py create mode 100644 tests/fixtures/trusteeBenchmark/__init__.py create mode 100644 tests/fixtures/trusteeBenchmark/loadTrusteeBenchmarkFixture.py create mode 100644 tests/fixtures/trusteeBenchmark/questions.yaml create mode 100644 tests/unit/serviceAgent/test_agentTrace_repairCounters.py create mode 100644 tests/unit/services/test_queryValidator.py create mode 100644 tests/unit/services/test_trusteeOntology.py diff --git a/app.py b/app.py index 73a64064..7a4ed4d4 100644 --- a/app.py +++ b/app.py @@ -404,8 +404,10 @@ async def lifespan(app: FastAPI): try: from modules.serviceCenter.services.serviceBackgroundJobs.mainBackgroundJobService import ( recoverInterruptedJobs, + registerZombieKillerScheduler, ) recoverInterruptedJobs() + registerZombieKillerScheduler(intervalMinutes=5) except Exception as e: logger.warning(f"BackgroundJob recovery failed (non-critical): {e}") @@ -607,6 +609,9 @@ app.include_router(connectionsRouter) from modules.routes.routeRagInventory import router as ragInventoryRouter app.include_router(ragInventoryRouter) +from modules.routes.routeAdminSttBenchmark import router as sttBenchmarkRouter +app.include_router(sttBenchmarkRouter) + from modules.routes.routeTableViews import router as tableViewsRouter app.include_router(tableViewsRouter) diff --git a/modules/aicore/aicorePluginOpenai.py b/modules/aicore/aicorePluginOpenai.py index 259ca117..bfea82f7 100644 --- a/modules/aicore/aicorePluginOpenai.py +++ b/modules/aicore/aicorePluginOpenai.py @@ -319,25 +319,24 @@ class AiOpenai(BaseConnectorAi): calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.00013 ), AiModel( - name="dall-e-3", - displayName="OpenAI DALL-E 3", + name="gpt-image-1", + displayName="OpenAI GPT Image", connectorType="openai", apiUrl="https://api.openai.com/v1/images/generations", - temperature=0.0, # Image generation doesn't use temperature - maxTokens=0, # Image generation doesn't use tokens + temperature=0.0, + maxTokens=0, contextLength=0, costPer1kTokensInput=0.04, costPer1kTokensOutput=0.0, - speedRating=5, # Slow for image generation - qualityRating=9, # High quality art generation - # capabilities removed (not used in business logic) + speedRating=5, + qualityRating=9, functionCall=self.generateImage, priority=PriorityEnum.QUALITY, processingMode=ProcessingModeEnum.DETAILED, operationTypes=createOperationTypeRatings( (OperationTypeEnum.IMAGE_GENERATE, 10) ), - version="dall-e-3", + version="gpt-image-1", calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.04 ) ] @@ -653,105 +652,82 @@ class AiOpenai(BaseConnectorAi): ) async def generateImage(self, modelCall: AiModelCall) -> AiModelResponse: - """ - Generate an image using DALL-E 3 using standardized pattern. - - Args: - modelCall: AiModelCall with messages and generation options - - Returns: - AiModelResponse with generated image data - """ + """Generate an image using GPT Image model (gpt-image-1).""" try: - # Extract parameters from modelCall - messages = modelCall.messages - model = modelCall.model - options = modelCall.options - - # Get prompt from messages - promptContent = messages[0]["content"] if messages else "" - - # Parse prompt using AiCallPromptImage model import json - + + messages = modelCall.messages + options = modelCall.options + promptContent = messages[0]["content"] if messages else "" + try: - # Try to parse as JSON promptData = json.loads(promptContent) promptModel = AiCallPromptImage(**promptData) - except: - # If not JSON, use plain text prompt + except Exception: promptModel = AiCallPromptImage( prompt=promptContent, - size=options.size if options and hasattr(options, 'size') else "1024x1024", - quality=options.quality if options and hasattr(options, 'quality') else "standard", - style=options.style if options and hasattr(options, 'style') else "vivid" + size=options.size if options and hasattr(options, "size") else "1024x1024", + quality=options.quality if options and hasattr(options, "quality") else "auto", ) - - # Extract parameters from Pydantic model + prompt = promptModel.prompt size = promptModel.size or "1024x1024" - quality = promptModel.quality or "standard" - style = promptModel.style or "vivid" - + rawQuality = promptModel.quality or "auto" + quality = {"standard": "auto", "hd": "high"}.get(rawQuality, rawQuality) + logger.debug(f"Starting image generation with prompt: '{prompt[:100]}...'") - - # DALL-E 3 API endpoint - dalle_url = "https://api.openai.com/v1/images/generations" - + payload = { - "model": "dall-e-3", + "model": "gpt-image-1", "prompt": prompt, "size": size, "quality": quality, - "style": style, "n": 1, - "response_format": "b64_json" # Get base64 data directly instead of URLs } - - # Use existing httpClient to benefit from connection pooling - # This avoids TLS connection issues that can occur with fresh clients + response = await self.httpClient.post( - dalle_url, - json=payload + "https://api.openai.com/v1/images/generations", + json=payload, ) - + if response.status_code != 200: - logger.error(f"DALL-E API error: {response.status_code} - {response.text}") + logger.error(f"Image generation API error: {response.status_code} - {response.text}") return AiModelResponse( content="", success=False, - error=f"DALL-E API error: {response.status_code} - {response.text}" + error=f"Image generation API error: {response.status_code} - {response.text}", ) - + responseJson = response.json() - + if "data" in responseJson and len(responseJson["data"]) > 0: - image_data = responseJson["data"][0]["b64_json"] - - logger.info(f"Successfully generated image: {len(image_data)} characters") + imageData = responseJson["data"][0].get("b64_json", "") + if not imageData: + imageData = responseJson["data"][0].get("url", "") + + logger.info(f"Successfully generated image: {len(imageData)} characters") return AiModelResponse( - content=image_data, + content=imageData, success=True, - modelId="dall-e-3", + modelId="gpt-image-1", metadata={ "size": size, "quality": quality, - "style": style, - "response_id": responseJson.get("id", "") - } + "response_id": responseJson.get("id", ""), + }, ) else: - logger.error("No image data in DALL-E response") + logger.error("No image data in generation response") return AiModelResponse( content="", success=False, - error="No image data in DALL-E response" + error="No image data in generation response", ) - + except Exception as e: logger.error(f"Error during image generation: {str(e)}", exc_info=True) return AiModelResponse( content="", success=False, - error=f"Error during image generation: {str(e)}" + error=f"Error during image generation: {str(e)}", ) \ No newline at end of file diff --git a/modules/connectors/connectorDbPostgre.py b/modules/connectors/connectorDbPostgre.py index 9f16b1f4..a6893396 100644 --- a/modules/connectors/connectorDbPostgre.py +++ b/modules/connectors/connectorDbPostgre.py @@ -311,7 +311,10 @@ class DatabaseConnector: # Establish connection to the database self._connect() - logger.info("PostgreSQL database system initialized successfully") + logger.debug( + "PostgreSQL database system initialized (db=%s, host=%s, port=%s)", + self.dbDatabase, self.dbHost, self.dbPort, + ) except Exception as e: logger.error(f"FATAL ERROR: Database system initialization failed: {e}") raise diff --git a/modules/datamodels/datamodelAi.py b/modules/datamodels/datamodelAi.py index 786eea7d..cd481c9a 100644 --- a/modules/datamodels/datamodelAi.py +++ b/modules/datamodels/datamodelAi.py @@ -245,11 +245,10 @@ class AiCallPromptWebCrawl(BaseModel): class AiCallPromptImage(BaseModel): """Structured prompt format for image generation.""" - + prompt: str = Field(description="Text description of the image to generate") - size: Optional[str] = Field(default="1024x1024", description="Image size (1024x1024, 1792x1024, 1024x1792)") - quality: Optional[str] = Field(default="standard", description="Image quality (standard, hd)") - style: Optional[str] = Field(default="vivid", description="Image style (vivid, natural)") + size: Optional[str] = Field(default="1024x1024", description="Image size (1024x1024, 1536x1024, 1024x1536)") + quality: Optional[str] = Field(default="auto", description="Image quality (auto, high, medium, low)") class AiProcessParameters(BaseModel): diff --git a/modules/features/trustee/mainTrustee.py b/modules/features/trustee/mainTrustee.py index b8ab853d..8f725d2f 100644 --- a/modules/features/trustee/mainTrustee.py +++ b/modules/features/trustee/mainTrustee.py @@ -754,14 +754,35 @@ ANTI-PATTERNS (do NOT do this): """ +# Parked for one release as a fallback while the ontology-based path rolls +# out (see `trusteeOntology.getTrusteeOntology()`). Remove together with the +# legacy ``_loadFeatureDomainHints`` path once Phase 2 is the only supplier +# of the trustee prompt block. +_AGENT_DOMAIN_HINTS_LEGACY = _AGENT_DOMAIN_HINTS + + def getAgentDomainHints() -> str: """Return Trustee-specific guidance for the Feature Data Sub-Agent. - The text is appended verbatim to the sub-agent's system prompt by - ``featureDataAgent._buildSchemaContext``. Keep it concise and - pattern-driven — every line costs tokens on every sub-agent call. + Deprecated as of Phase 2 (2026-05). Prefer ``getAgentOntology()`` -> + ``ontologyToPromptCompiler.compileOntologyToPrompt(...)``. The legacy + text remains available so callers that still go through + ``_buildSchemaContext()`` keep working during the migration window. """ - return _AGENT_DOMAIN_HINTS + return _AGENT_DOMAIN_HINTS_LEGACY + + +def getAgentOntology(): + """Return the structured ontology used by the Feature Data Sub-Agent. + + Discovered by ``featureDataAgent._buildSchemaContext`` (Phase 2 path): + when this hook is present, the agent compiles its domain block from + the ontology instead of using the legacy free-text hints. The same + descriptor feeds the validator's NEVER_AGGREGATE constraints, so + prompt and validator stay in sync. + """ + from modules.features.trustee.trusteeOntology import getTrusteeOntology + return getTrusteeOntology() def registerFeature(catalogService) -> bool: diff --git a/modules/features/trustee/trusteeOntology.py b/modules/features/trustee/trusteeOntology.py new file mode 100644 index 00000000..c5b117d7 --- /dev/null +++ b/modules/features/trustee/trusteeOntology.py @@ -0,0 +1,295 @@ +# Copyright (c) 2026 Patrick Motsch +# All rights reserved. +"""Trustee feature ontology (Phase 2 pilot). + +Replaces the hand-written ``_AGENT_DOMAIN_HINTS`` block with a structured +ontology so the Feature Data Sub-Agent's QueryValidator AND the prompt +compiler share the same source of truth: account-group conventions, +period-bucket semantics, the NEVER_AGGREGATE constraints on already- +aggregated columns, and canonical tool-call templates for the most +frequent user intents. + +Both the validator (deterministic enforcement) and the prompt compiler +(LLM steering) read from this descriptor, so an LLM that follows the +prompt patterns will never trigger a validator failure -- and one that +ignores them gets a structured repair hint pointing back at the same +constraint. + +The legacy ``_AGENT_DOMAIN_HINTS_LEGACY`` block stays parked in +``mainTrustee.py`` for one release as a fallback during rollout. +""" + +from __future__ import annotations + +from modules.serviceCenter.services.serviceAgent.datamodelOntology import ( + CanonicalQueryPattern, + Cardinality, + Constraint, + ConstraintRule, + Entity, + Invariant, + OntologyDescriptor, + Relation, + SemanticType, +) + + +# --------------------------------------------------------------------------- +# Entities +# --------------------------------------------------------------------------- + +_ENTITIES = [ + Entity( + name="Account", + pythonClass="TrusteeDataAccount", + semanticType=SemanticType.ACCOUNT, + description=( + "Chart-of-accounts row (Konto). One row per accountNumber per " + "mandate. Identifies the account, never holds balances." + ), + invariants=[ + Invariant(description="accountNumber is a stable string identifier (e.g. '1020', '5400')."), + Invariant(description="accountType is one of: asset / liability / revenue / expense."), + ], + ), + Entity( + name="BankAccount", + pythonClass="TrusteeDataAccount", + semanticType=SemanticType.ACCOUNT, + parentEntity="Account", + description="Account subgroup with accountNumber LIKE '102%' (ZKB, PostFinance, UBS, ...).", + ), + Entity( + name="CashAccount", + pythonClass="TrusteeDataAccount", + semanticType=SemanticType.ACCOUNT, + parentEntity="Account", + description="Account subgroup with accountNumber LIKE '100%' (Hauptkasse, Nebenkassen).", + ), + Entity( + name="AccountBalance", + pythonClass="TrusteeDataAccountBalance", + semanticType=SemanticType.BALANCE_SNAPSHOT, + description=( + "Period-bucketed snapshot: one row per (account, year, month). " + "closingBalance is THE balance at end of period -- already aggregated." + ), + invariants=[ + Invariant(description="periodMonth=0 means annual total of periodYear (use for 'per 31.12.YYYY')."), + Invariant(description="periodMonth in 1..12 means month-end snapshot."), + Invariant(description="closingBalance is the balance at period end; openingBalance at period start."), + Invariant(description="debitTotal/creditTotal are turnovers for the period, NOT balances."), + ], + ), + Entity( + name="JournalEntry", + pythonClass="TrusteeDataJournalEntry", + semanticType=SemanticType.TRANSACTION, + description="One booking header (Beleg). Has a bookingDate (unix seconds float) and totalAmount.", + invariants=[ + Invariant(description="bookingDate is a UTC unix-seconds float; never compare against ISO strings."), + ], + ), + Entity( + name="JournalLine", + pythonClass="TrusteeDataJournalLine", + semanticType=SemanticType.TRANSACTION, + description="One booking line of a JournalEntry. Each line debits or credits exactly one account.", + invariants=[ + Invariant(description="Per line either debitAmount > 0 (Soll) or creditAmount > 0 (Haben), not both."), + ], + ), +] + + +# --------------------------------------------------------------------------- +# Relations +# --------------------------------------------------------------------------- + +_RELATIONS = [ + Relation(fromEntity="AccountBalance", toEntity="Account", cardinality=Cardinality.MANY_TO_ONE, via="accountNumber"), + Relation(fromEntity="JournalLine", toEntity="JournalEntry", cardinality=Cardinality.MANY_TO_ONE, via="journalEntryId"), + Relation(fromEntity="JournalLine", toEntity="Account", cardinality=Cardinality.MANY_TO_ONE, via="accountNumber"), +] + + +# --------------------------------------------------------------------------- +# Constraints (validator-enforced) +# --------------------------------------------------------------------------- + +_CONSTRAINTS = [ + # closingBalance is the single biggest hallucination magnet -- it's a + # balance per period, summing it across periods or accounts is meaningless. + Constraint( + appliesTo="TrusteeDataAccountBalance.closingBalance", + rule=ConstraintRule.NEVER_AGGREGATE, + message=( + "closingBalance is per-period already; query with periodYear+periodMonth, never SUM/AVG it." + ), + ), + Constraint( + appliesTo="TrusteeDataAccountBalance.openingBalance", + rule=ConstraintRule.NEVER_AGGREGATE, + message="openingBalance is already a balance per period; do not SUM/AVG it across rows.", + ), + Constraint( + appliesTo="TrusteeDataAccountBalance.debitTotal", + rule=ConstraintRule.NEVER_AGGREGATE, + message=( + "debitTotal is the period's debit TURNOVER; do not SUM it without an explicit period filter." + ), + ), + Constraint( + appliesTo="TrusteeDataAccountBalance.creditTotal", + rule=ConstraintRule.NEVER_AGGREGATE, + message="creditTotal is a per-period turnover; do not SUM it across periods without an explicit period filter.", + ), + # AccountBalance queries without a period filter are almost always wrong -- + # they conflate annual and monthly snapshots. Phase 2 (REQUIRES_FILTER_ON) + # is wired through to the validator in a later iteration; for now this + # rule is rendered into the prompt compiler so the LLM sees it explicitly. + Constraint( + appliesTo="TrusteeDataAccountBalance", + rule=ConstraintRule.REQUIRES_FILTER_ON, + message=( + "Always filter on periodYear AND periodMonth (use periodMonth=0 for end-of-year)." + ), + params={"requiredFields": ["periodYear", "periodMonth"]}, + ), + Constraint( + appliesTo="TrusteeDataAccountBalance", + rule=ConstraintRule.PREFERRED_TABLE_FOR_INTENT, + message="For 'Saldo per ' and 'Stand ' questions, prefer AccountBalance over JournalLine.", + params={"intents": ["BANK_BALANCE_AT_DATE", "BALANCE_AT_YEAR_END"]}, + ), +] + + +# --------------------------------------------------------------------------- +# Canonical query patterns (worked examples for the LLM) +# --------------------------------------------------------------------------- + +_CANONICAL_PATTERNS = [ + CanonicalQueryPattern( + intent="BANK_BALANCE_AT_DATE", + description="Saldo eines Bankkontos per Jahresende.", + pattern={ + "tool": "queryTable", + "tableName": "TrusteeDataAccountBalance", + "filters": [ + {"field": "accountNumber", "op": "=", "value": ""}, + {"field": "periodYear", "op": "=", "value": ""}, + {"field": "periodMonth", "op": "=", "value": 0}, + ], + "fields": ["closingBalance", "currency"], + }, + ), + CanonicalQueryPattern( + intent="BANK_GROUP_TOTAL_AT_DATE", + description="Summe einer Kontogruppe (z. B. alle Bankkonten 102%) per Jahresende.", + pattern={ + "tool": "queryTable", + "tableName": "TrusteeDataAccountBalance", + "filters": [ + {"field": "accountNumber", "op": "LIKE", "value": "%"}, + {"field": "periodYear", "op": "=", "value": ""}, + {"field": "periodMonth", "op": "=", "value": 0}, + ], + "fields": ["accountNumber", "closingBalance", "currency"], + "_postProcessing": "Sum closingBalance values in your final answer; do NOT SUM via aggregateTable.", + }, + ), + CanonicalQueryPattern( + intent="BALANCE_HISTORY_PER_YEAR", + description="Saldo-Verlauf eines Kontos ueber mehrere Jahre.", + pattern={ + "tool": "queryTable", + "tableName": "TrusteeDataAccountBalance", + "filters": [ + {"field": "accountNumber", "op": "=", "value": ""}, + {"field": "periodMonth", "op": "=", "value": 0}, + ], + "fields": ["periodYear", "closingBalance", "currency"], + "orderBy": "periodYear", + }, + ), + CanonicalQueryPattern( + intent="MONTHLY_BALANCE_SNAPSHOT", + description="Saldo per Ende eines bestimmten Monats.", + pattern={ + "tool": "queryTable", + "tableName": "TrusteeDataAccountBalance", + "filters": [ + {"field": "accountNumber", "op": "=", "value": ""}, + {"field": "periodYear", "op": "=", "value": ""}, + {"field": "periodMonth", "op": "=", "value": ""}, + ], + "fields": ["closingBalance", "currency"], + }, + ), + CanonicalQueryPattern( + intent="ACCOUNT_LIST_BY_TYPE_OR_PREFIX", + description="Welche Konten gehoeren zu einer Gruppe (Typ oder Nummern-Prefix)?", + pattern={ + "tool": "queryTable", + "tableName": "TrusteeDataAccount", + "filters": [ + {"field": "accountNumber", "op": "LIKE", "value": "%"}, + ], + "fields": ["accountNumber", "label", "accountType"], + }, + ), + CanonicalQueryPattern( + intent="JOURNAL_SUM_AT_ACCOUNT", + description="Summe der Soll- oder Haben-Buchungen auf einem Konto.", + pattern={ + "tool": "aggregateTable", + "tableName": "TrusteeDataJournalLine", + "aggregate": "SUM", + "field": "debitAmount", + "filters": [ + {"field": "accountNumber", "op": "=", "value": ""}, + ], + }, + ), + CanonicalQueryPattern( + intent="COUNT_ROWS", + description="Anzahl Buchungen / Buchungszeilen / Konten.", + pattern={ + "tool": "aggregateTable", + "tableName": "", + "aggregate": "COUNT", + "field": "id", + }, + ), + CanonicalQueryPattern( + intent="JOURNAL_LINES_BY_AMOUNT", + description="Buchungszeilen mit einem Betrag groesser/kleiner als einer Schwelle.", + pattern={ + "tool": "queryTable", + "tableName": "TrusteeDataJournalLine", + "filters": [ + {"field": "debitAmount", "op": ">", "value": ""}, + ], + "fields": ["accountNumber", "debitAmount", "description"], + }, + ), +] + + +_TRUSTEE_ONTOLOGY = OntologyDescriptor( + featureCode="trustee", + entities=_ENTITIES, + relations=_RELATIONS, + constraints=_CONSTRAINTS, + canonicalPatterns=_CANONICAL_PATTERNS, +) + + +def getTrusteeOntology() -> OntologyDescriptor: + """Public accessor for the trustee ontology. + + Cached as a module-level singleton -- the descriptor is immutable and + has no per-call state. + """ + return _TRUSTEE_ONTOLOGY diff --git a/modules/routes/routeAdminSttBenchmark.py b/modules/routes/routeAdminSttBenchmark.py new file mode 100644 index 00000000..ae24e792 --- /dev/null +++ b/modules/routes/routeAdminSttBenchmark.py @@ -0,0 +1,217 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +"""STT Benchmark route — compare Speech-to-Text v1 (latest_long) vs v2 (Chirp 2). + +Sysadmin-only page for evaluating STT model quality and latency. +""" + +import json +import time +import logging +from typing import Any, Dict + +from fastapi import APIRouter, HTTPException, Depends, Request, UploadFile, File, Form +from modules.auth import limiter, getCurrentUser +from modules.datamodels.datamodelUam import User +from modules.shared.configuration import APP_CONFIG + +logger = logging.getLogger(__name__) + +router = APIRouter( + prefix="/api/admin/stt-benchmark", + tags=["Admin STT Benchmark"], + responses={401: {"description": "Unauthorized"}, 403: {"description": "Forbidden"}}, +) + + +def _requireSysAdmin(currentUser: User = Depends(getCurrentUser)) -> User: + if not getattr(currentUser, "isSysAdmin", False) and not getattr(currentUser, "isPlatformAdmin", False): + raise HTTPException(status_code=403, detail="SysAdmin required") + return currentUser + + +def _getCredentials(): + apiKey = APP_CONFIG.get("Connector_GoogleSpeech_API_KEY_SECRET") + if not apiKey or apiKey.startswith("YOUR_"): + raise HTTPException(status_code=500, detail="Google Speech API key not configured") + from google.oauth2 import service_account + return service_account.Credentials.from_service_account_info(json.loads(apiKey)) + + +def _runV1(audioBytes: bytes, language: str, model: str) -> Dict[str, Any]: + """Run Speech-to-Text v1 recognition.""" + from google.cloud import speech + credentials = _getCredentials() + client = speech.SpeechClient(credentials=credentials) + + config = speech.RecognitionConfig( + encoding=speech.RecognitionConfig.AudioEncoding.ENCODING_UNSPECIFIED, + language_code=language, + model=model, + enable_automatic_punctuation=True, + enable_word_time_offsets=True, + enable_word_confidence=True, + max_alternatives=3, + use_enhanced=True, + ) + audio = speech.RecognitionAudio(content=audioBytes) + + t0 = time.perf_counter() + response = client.recognize(config=config, audio=audio) + elapsed = time.perf_counter() - t0 + + results = [] + for r in response.results: + for alt in r.alternatives: + results.append({ + "transcript": alt.transcript, + "confidence": round(alt.confidence, 4), + "words": len(alt.words) if alt.words else 0, + }) + + return { + "api": "v1", + "model": model, + "latencyMs": round(elapsed * 1000, 1), + "results": results, + "resultCount": len(response.results), + } + + +def _runV2(audioBytes: bytes, language: str, model: str, location: str) -> Dict[str, Any]: + """Run Speech-to-Text v2 recognition (Chirp 2).""" + from google.cloud.speech_v2 import SpeechClient + from google.cloud.speech_v2.types import cloud_speech + + credentials = _getCredentials() + credInfo = json.loads(APP_CONFIG.get("Connector_GoogleSpeech_API_KEY_SECRET")) + projectId = credInfo.get("project_id", "") + + client = SpeechClient( + credentials=credentials, + client_options={"api_endpoint": f"{location}-speech.googleapis.com"}, + ) + + config = cloud_speech.RecognitionConfig( + auto_decoding_config=cloud_speech.AutoDetectDecodingConfig(), + language_codes=[language], + model=model, + features=cloud_speech.RecognitionFeatures( + enable_automatic_punctuation=True, + enable_word_time_offsets=True, + enable_word_confidence=True, + ), + ) + + recognizer = f"projects/{projectId}/locations/{location}/recognizers/_" + + request = cloud_speech.RecognizeRequest( + recognizer=recognizer, + config=config, + content=audioBytes, + ) + + t0 = time.perf_counter() + response = client.recognize(request=request) + elapsed = time.perf_counter() - t0 + + results = [] + for r in response.results: + for alt in r.alternatives: + results.append({ + "transcript": alt.transcript, + "confidence": round(alt.confidence, 4), + "words": len(alt.words) if alt.words else 0, + }) + + return { + "api": "v2", + "model": model, + "location": location, + "latencyMs": round(elapsed * 1000, 1), + "results": results, + "resultCount": len(getattr(response, "results", [])), + } + + +@router.post("/run") +@limiter.limit("10/minute") +async def runBenchmark( + request: Request, + file: UploadFile = File(...), + language: str = Form(default="de-DE"), + v1Model: str = Form(default="latest_long"), + v2Model: str = Form(default="chirp_2"), + v2Location: str = Form(default="europe-west4"), + currentUser: User = Depends(_requireSysAdmin), +) -> Dict[str, Any]: + """Upload audio and compare v1 vs v2 STT results.""" + audioBytes = await file.read() + if len(audioBytes) > 10 * 1024 * 1024: + raise HTTPException(status_code=400, detail="Audio file too large (max 10 MB)") + if len(audioBytes) < 100: + raise HTTPException(status_code=400, detail="Audio file too small") + + logger.info("STT benchmark: %s, %d bytes, language=%s, v1=%s, v2=%s@%s", + file.filename, len(audioBytes), language, v1Model, v2Model, v2Location) + + v1Result = None + v1Error = None + try: + v1Result = _runV1(audioBytes, language, v1Model) + except Exception as e: + v1Error = str(e) + logger.warning("STT v1 benchmark failed: %s", e) + + v2Result = None + v2Error = None + try: + v2Result = _runV2(audioBytes, language, v2Model, v2Location) + except Exception as e: + v2Error = str(e) + logger.warning("STT v2 benchmark failed: %s", e) + + return { + "filename": file.filename, + "fileSizeBytes": len(audioBytes), + "language": language, + "v1": v1Result or {"error": v1Error}, + "v2": v2Result or {"error": v2Error}, + } + + +@router.get("/models") +@limiter.limit("30/minute") +async def getAvailableModels( + request: Request, + currentUser: User = Depends(_requireSysAdmin), +) -> Dict[str, Any]: + """Return available STT models for the benchmark UI.""" + return { + "v1Models": [ + {"value": "latest_long", "label": "latest_long (default)"}, + {"value": "latest_short", "label": "latest_short"}, + {"value": "phone_call", "label": "phone_call"}, + {"value": "video", "label": "video"}, + {"value": "command_and_search", "label": "command_and_search"}, + ], + "v2Models": [ + {"value": "chirp_2", "label": "Chirp 2 (recommended)"}, + {"value": "chirp", "label": "Chirp (original)"}, + {"value": "long", "label": "long"}, + {"value": "short", "label": "short"}, + ], + "locations": [ + {"value": "europe-west4", "label": "Europe West (NL)"}, + {"value": "us-central1", "label": "US Central"}, + {"value": "asia-southeast1", "label": "Asia Southeast"}, + ], + "languages": [ + {"value": "de-DE", "label": "Deutsch (DE)"}, + {"value": "de-CH", "label": "Deutsch (CH)"}, + {"value": "en-US", "label": "English (US)"}, + {"value": "en-GB", "label": "English (GB)"}, + {"value": "fr-FR", "label": "Francais (FR)"}, + {"value": "it-IT", "label": "Italiano (IT)"}, + ], + } diff --git a/modules/routes/routeDataConnections.py b/modules/routes/routeDataConnections.py index 04f652fb..e2b08461 100644 --- a/modules/routes/routeDataConnections.py +++ b/modules/routes/routeDataConnections.py @@ -745,7 +745,7 @@ def _findOwnConnection(interface, userId: str, connectionId: str): @router.patch("/{connectionId}/knowledge-consent") @limiter.limit("10/minute") -def _updateKnowledgeConsent( +async def _updateKnowledgeConsent( request: Request, connectionId: str = Path(..., description="Connection ID"), enabled: bool = Body(..., embed=True), @@ -780,24 +780,13 @@ def _updateKnowledgeConsent( from modules.datamodels.datamodelDataSource import DataSource dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId, "ragIndexEnabled": True}) if dataSources: - import asyncio from modules.serviceCenter.services.serviceBackgroundJobs import startJob authority = connection.authority.value if hasattr(connection.authority, "value") else str(connection.authority or "") - - async def _enqueue(): - await startJob( - "connection.bootstrap", - {"connectionId": connectionId, "authority": authority.lower()}, - triggeredBy=str(currentUser.id), - ) - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - loop.create_task(_enqueue()) - else: - loop.run_until_complete(_enqueue()) - except RuntimeError: - asyncio.run(_enqueue()) + await startJob( + "connection.bootstrap", + {"connectionId": connectionId, "authority": authority.lower()}, + triggeredBy=str(currentUser.id), + ) bootstrapEnqueued = True import json as _json diff --git a/modules/routes/routeDataSources.py b/modules/routes/routeDataSources.py index f7e5425d..ba398008 100644 --- a/modules/routes/routeDataSources.py +++ b/modules/routes/routeDataSources.py @@ -129,7 +129,7 @@ def _updateNeutralizeFields( @router.patch("/{sourceId}/rag-index") @limiter.limit("30/minute") -def _updateDataSourceRagIndex( +async def _updateDataSourceRagIndex( request: Request, sourceId: str = Path(..., description="ID of the DataSource"), ragIndexEnabled: bool = Body(..., embed=True), @@ -139,6 +139,10 @@ def _updateDataSourceRagIndex( true: sets flag + enqueues mini-bootstrap for this DataSource only. false: sets flag + synchronously purges all chunks from this DataSource. + + Must be `async def` so `await startJob(...)` registers `_runJob` in the + main event loop. Sync route → worker thread → temporary loop closes + before the task runs → job stays stuck forever. """ try: from modules.interfaces.interfaceDbApp import getRootInterface @@ -152,7 +156,6 @@ def _updateDataSourceRagIndex( if ragIndexEnabled: from modules.serviceCenter.services.serviceBackgroundJobs import startJob - import asyncio connectionId = rec.get("connectionId") or rec.get("connection_id") or "" conn = rootIf.getUserConnectionById(connectionId) if connectionId else None @@ -160,20 +163,11 @@ def _updateDataSourceRagIndex( if conn: authority = conn.authority.value if hasattr(conn.authority, "value") else str(conn.authority or "") - async def _enqueue(): - await startJob( - "connection.bootstrap", - {"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": [sourceId]}, - triggeredBy=str(context.user.id), - ) - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - loop.create_task(_enqueue()) - else: - loop.run_until_complete(_enqueue()) - except RuntimeError: - asyncio.run(_enqueue()) + await startJob( + "connection.bootstrap", + {"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": [sourceId]}, + triggeredBy=str(context.user.id), + ) else: from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface purgeResult = getKnowledgeInterface(None).deleteFileContentIndexByDataSource(sourceId) diff --git a/modules/routes/routeRagInventory.py b/modules/routes/routeRagInventory.py index 37fb330b..074b5b85 100644 --- a/modules/routes/routeRagInventory.py +++ b/modules/routes/routeRagInventory.py @@ -39,20 +39,27 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L chunksByDs: Dict[str, int] = {} unassigned = 0 for idx in connIndexRows: - prov = (idx.get("provenance") if isinstance(idx, dict) else getattr(idx, "provenance", None)) or {} + struct = (idx.get("structure") if isinstance(idx, dict) else getattr(idx, "structure", None)) or {} + ingestion = struct.get("_ingestion") or {} if isinstance(struct, dict) else {} + prov = ingestion.get("provenance") or {} if isinstance(ingestion, dict) else {} dsIdRef = prov.get("dataSourceId", "") if isinstance(prov, dict) else "" if dsIdRef: chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + 1 else: unassigned += 1 + seen: Dict[str, bool] = {} dsItems = [] for ds in dataSources: dsId = ds.get("id") if isinstance(ds, dict) else getattr(ds, "id", "") + dsPath = ds.get("path") if isinstance(ds, dict) else getattr(ds, "path", "") + if dsPath in seen: + continue + seen[dsPath] = True dsItems.append({ "id": dsId, "label": ds.get("label") if isinstance(ds, dict) else getattr(ds, "label", ""), - "path": ds.get("path") if isinstance(ds, dict) else getattr(ds, "path", ""), + "path": dsPath, "sourceType": ds.get("sourceType") if isinstance(ds, dict) else getattr(ds, "sourceType", ""), "ragIndexEnabled": ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False), "neutralize": ds.get("neutralize") if isinstance(ds, dict) else getattr(ds, "neutralize", False), @@ -60,20 +67,43 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L "chunkCount": chunksByDs.get(dsId, 0), }) - if unassigned > 0 and len(dsItems) == 1: - dsItems[0]["chunkCount"] += unassigned + if unassigned > 0 and len(dsItems) > 0: + perDs = unassigned // len(dsItems) + remainder = unassigned % len(dsItems) + for i, item in enumerate(dsItems): + item["chunkCount"] += perDs + (1 if i < remainder else 0) - jobs = jobService.listJobs(jobType="connection.bootstrap", limit=5) + # Pull a wider window than the previous 5 so the "last successful + # sync" is found even if a connection has many recent jobs queued. + jobs = jobService.listJobs(jobType="connection.bootstrap", limit=50) connJobs = [j for j in jobs if (j.get("payload") or {}).get("connectionId") == connectionId] runningJobs = [ {"jobId": j["id"], "progress": j.get("progress", 0), "progressMessage": j.get("progressMessage", "")} for j in connJobs if j.get("status") in ("PENDING", "RUNNING") ] - lastError = None + lastError: Optional[Dict[str, Any]] = None + lastSuccess: Optional[Dict[str, Any]] = None for j in connJobs: - if j.get("status") == "ERROR": - lastError = {"jobId": j["id"], "errorMessage": j.get("errorMessage", "")} + status = j.get("status") + if status == "ERROR" and lastError is None: + lastError = { + "jobId": j["id"], + "errorMessage": j.get("errorMessage", ""), + "finishedAt": j.get("finishedAt"), + } + elif status == "SUCCESS" and lastSuccess is None: + result = j.get("result") or {} + lastSuccess = { + "jobId": j["id"], + "finishedAt": j.get("finishedAt"), + "indexed": result.get("indexed", 0), + "skippedDuplicate": result.get("skippedDuplicate", 0), + "skippedPolicy": result.get("skippedPolicy", 0), + "failed": result.get("failed", 0), + "durationMs": result.get("durationMs", 0), + } + if lastError and lastSuccess: break out.append({ @@ -86,6 +116,7 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L "totalChunks": connChunkTotal, "runningJobs": runningJobs, "lastError": lastError, + "lastSuccess": lastSuccess, }) return out @@ -182,7 +213,7 @@ def _getInventoryPlatform( @router.post("/reindex/{connectionId}") @limiter.limit("10/minute") -def _reindexConnection( +async def _reindexConnection( request: Request, connectionId: str, currentUser: User = Depends(getCurrentUser), @@ -190,12 +221,16 @@ def _reindexConnection( """Re-trigger bootstrap for a connection (re-index all ragIndexEnabled DataSources). Submits a new connection.bootstrap job, regardless of previous failures. + + Must be `async def` so `await startJob(...)` registers the `_runJob` task + in FastAPI's main event loop. A sync route would land in the worker + threadpool and `asyncio.run` would tear down the temporary loop right + after `create_task`, leaving the job stuck in PENDING forever. """ try: from modules.interfaces.interfaceDbApp import getRootInterface from modules.serviceCenter.services.serviceBackgroundJobs import startJob from modules.datamodels.datamodelDataSource import DataSource - import asyncio rootIf = getRootInterface() conn = rootIf.getUserConnectionById(connectionId) @@ -213,23 +248,13 @@ def _reindexConnection( authority = conn.authority.value if hasattr(conn.authority, "value") else str(conn.authority or "") dsIds = [(ds.get("id") if isinstance(ds, dict) else getattr(ds, "id", "")) for ds in ragDs] - async def _enqueue(): - return await startJob( - "connection.bootstrap", - {"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": dsIds}, - triggeredBy=str(currentUser.id), - ) - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - future = asyncio.ensure_future(_enqueue()) - jobId = None - else: - jobId = loop.run_until_complete(_enqueue()) - except RuntimeError: - jobId = asyncio.run(_enqueue()) + jobId = await startJob( + "connection.bootstrap", + {"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": dsIds}, + triggeredBy=str(currentUser.id), + ) - logger.info("Reindex triggered for connection %s (%d DataSources)", connectionId, len(dsIds)) + logger.info("Reindex triggered for connection %s (%d DataSources, jobId=%s)", connectionId, len(dsIds), jobId) return {"status": "queued", "connectionId": connectionId, "dataSourceCount": len(dsIds), "jobId": jobId} except HTTPException: raise diff --git a/modules/serviceCenter/services/serviceAgent/agentLoop.py b/modules/serviceCenter/services/serviceAgent/agentLoop.py index c1571994..99f4dbd7 100644 --- a/modules/serviceCenter/services/serviceAgent/agentLoop.py +++ b/modules/serviceCenter/services/serviceAgent/agentLoop.py @@ -7,7 +7,7 @@ import logging import time import json import re -from typing import List, Dict, Any, Optional, AsyncGenerator, Callable, Awaitable +from typing import List, Dict, Any, Optional, AsyncGenerator, Callable, Awaitable, Tuple from modules.datamodels.datamodelAi import ( AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum @@ -360,12 +360,18 @@ async def runAgentLoop( state.totalToolCalls += len(results) for result in results: + validationCode = None + if isinstance(result.errorDetails, dict): + code = result.errorDetails.get("code") + if isinstance(code, str): + validationCode = code roundLog.toolCalls.append(ToolCallLog( toolName=result.toolName, args=next((tc.args for tc in toolCalls if tc.id == result.toolCallId), {}), success=result.success, durationMs=result.durationMs, error=result.error, + validationFailureCode=validationCode, resultData=result.data[:300] if result.data else "", )) if not result.success: @@ -443,6 +449,11 @@ async def runAgentLoop( trace.totalCostCHF = state.totalCostCHF trace.abortReason = state.abortReason + validationFailures, repairAttempts, successAfterRepair = _computeRepairCounters(trace.rounds) + trace.validationFailures = validationFailures + trace.repairAttempts = repairAttempts + trace.successAfterRepair = successAfterRepair + artifactSummary = _buildArtifactSummary(trace.rounds) yield AgentEvent( @@ -456,6 +467,9 @@ async def runAgentLoop( "status": state.status.value, "abortReason": state.abortReason, "artifacts": artifactSummary, + "validationFailures": validationFailures, + "repairAttempts": repairAttempts, + "successAfterRepair": successAfterRepair, } ) @@ -720,6 +734,41 @@ def classifyToolResult( return None +def _computeRepairCounters(rounds: List[AgentRoundLog]) -> Tuple[int, int, int]: + """Aggregate repair-loop telemetry across all rounds. + + Returns ``(validationFailures, repairAttempts, successAfterRepair)``. + + * `validationFailures` -- total tool calls rejected by a pre-execute + validator (any round, counts every occurrence). + * `repairAttempts` -- tool calls in **later** rounds whose `toolName` + had been rejected in some **earlier** round. Multiple retries of the + same tool count multiple times. We intentionally do not count + sibling calls within the same round, since the LLM has not yet seen + the first one's result when emitting the second. + * `successAfterRepair` -- the subset of `repairAttempts` that passed + the validator (``validationFailureCode is None``). + """ + validationFailures = 0 + repairAttempts = 0 + successAfterRepair = 0 + rejectedTools: set = set() + + for roundLog in rounds: + rejectedFromPriorRounds = set(rejectedTools) + for tc in roundLog.toolCalls: + wasRejectedBefore = tc.toolName in rejectedFromPriorRounds + if tc.validationFailureCode is not None: + validationFailures += 1 + if wasRejectedBefore: + repairAttempts += 1 + rejectedTools.add(tc.toolName) + elif wasRejectedBefore: + repairAttempts += 1 + successAfterRepair += 1 + return validationFailures, repairAttempts, successAfterRepair + + _ARTIFACT_TOOLS = {"writeFile", "replaceInFile", "deleteFile", "renameFile", "copyFile", "createFolder", "deleteFolder", "renderDocument", "generateImage"} diff --git a/modules/serviceCenter/services/serviceAgent/coreTools/_workspaceTools.py b/modules/serviceCenter/services/serviceAgent/coreTools/_workspaceTools.py index ed30538a..8aa83732 100644 --- a/modules/serviceCenter/services/serviceAgent/coreTools/_workspaceTools.py +++ b/modules/serviceCenter/services/serviceAgent/coreTools/_workspaceTools.py @@ -19,6 +19,20 @@ from modules.serviceCenter.services.serviceAgent.coreTools._helpers import ( logger = logging.getLogger(__name__) +_STALE_EXTRACTION_PATTERNS = ( + "requires the extract-msg package", + "extraction requires the", + "will be treated as binary", +) + + +def _isStaleExtractionResult(text: str) -> bool: + """Detect cached extraction results that are just error/warning placeholders.""" + if len(text) > 500: + return False + textLower = text.lower() + return any(p in textLower for p in _STALE_EXTRACTION_PATTERNS) + import uuid as _uuid @@ -62,15 +76,16 @@ def _registerWorkspaceTools(registry: ToolRegistry, services): ] if textChunks: assembled = "\n\n".join(c["data"] for c in textChunks) - chunked = _applyOffsetLimit(assembled, offset, limit) - if chunked is not None: - return ToolResult(toolCallId="", toolName="readFile", success=True, data=chunked) - if len(assembled) > _MAX_TOOL_RESULT_CHARS: - assembled = assembled[:_MAX_TOOL_RESULT_CHARS] + f"\n\n[Truncated – showing first {_MAX_TOOL_RESULT_CHARS} chars of {len(assembled)}. Use offset/limit to read specific sections.]" - return ToolResult( - toolCallId="", toolName="readFile", success=True, - data=assembled, - ) + if not _isStaleExtractionResult(assembled): + chunked = _applyOffsetLimit(assembled, offset, limit) + if chunked is not None: + return ToolResult(toolCallId="", toolName="readFile", success=True, data=chunked) + if len(assembled) > _MAX_TOOL_RESULT_CHARS: + assembled = assembled[:_MAX_TOOL_RESULT_CHARS] + f"\n\n[Truncated – showing first {_MAX_TOOL_RESULT_CHARS} chars of {len(assembled)}. Use offset/limit to read specific sections.]" + return ToolResult( + toolCallId="", toolName="readFile", success=True, + data=assembled, + ) elif fileStatus in ("processing", "embedding", "extracted"): return ToolResult( toolCallId="", toolName="readFile", success=True, @@ -101,12 +116,31 @@ def _registerWorkspaceTools(registry: ToolRegistry, services): isBinary = _looksLikeBinary(rawBytes) if isBinary: + extractionService = services.getService("extraction") if hasattr(services, "getService") else None + if extractionService: + try: + extracted = extractionService.extractContentFromBytes( + rawBytes, fileName, mimeType, documentId=fileId, + ) + textParts = [ + p.data for p in (extracted.parts or []) + if getattr(p, "contentType", "") != "image" and getattr(p, "data", None) + ] + if textParts: + assembled = "\n\n".join(textParts) + chunked = _applyOffsetLimit(assembled, offset, limit) + if chunked is not None: + return ToolResult(toolCallId="", toolName="readFile", success=True, data=chunked) + if len(assembled) > _MAX_TOOL_RESULT_CHARS: + assembled = assembled[:_MAX_TOOL_RESULT_CHARS] + f"\n\n[Truncated – showing first {_MAX_TOOL_RESULT_CHARS} chars of {len(assembled)}. Use offset/limit to read specific sections.]" + return ToolResult(toolCallId="", toolName="readFile", success=True, data=assembled) + except Exception as extractErr: + logger.warning("readFile: inline extraction failed for %s: %s", fileId, extractErr) return ToolResult( toolCallId="", toolName="readFile", success=True, data=( - f"[File '{fileName}' ({mimeType}) is not yet indexed " - f"(status: {fileStatus or 'unknown'}). Indexing runs automatically " - f"on upload. Please wait a few seconds and retry, or re-upload the file. " + f"[File '{fileName}' ({mimeType}) is binary and could not be extracted " + f"(status: {fileStatus or 'unknown'}). " f"For visual content use describeImage(fileId='{fileId}').]" ), ) diff --git a/modules/serviceCenter/services/serviceAgent/datamodelAgent.py b/modules/serviceCenter/services/serviceAgent/datamodelAgent.py index 889f31e8..c96265e4 100644 --- a/modules/serviceCenter/services/serviceAgent/datamodelAgent.py +++ b/modules/serviceCenter/services/serviceAgent/datamodelAgent.py @@ -79,6 +79,14 @@ class ToolResult(BaseModel): success: bool = True data: str = "" error: Optional[str] = None + errorDetails: Optional[Dict[str, Any]] = Field( + default=None, + description=( + "Structured, machine-readable error payload for the LLM (e.g. validation " + "repair hints with code/field/suggestion/hint). `error` remains the short " + "human-readable text for logs and audit." + ), + ) durationMs: int = 0 sideEvents: Optional[List[Dict[str, Any]]] = None @@ -141,6 +149,14 @@ class ToolCallLog(BaseModel): success: bool = True durationMs: int = 0 error: Optional[str] = None + validationFailureCode: Optional[str] = Field( + default=None, + description=( + "If the tool call was rejected by a pre-execute validator (e.g. " + "QueryValidator), the structured error code (e.g. FIELD_NOT_FOUND). " + "None when the call ran cleanly or failed for other reasons." + ), + ) resultData: str = Field(default="", description="Short result summary for artifact tracking") @@ -167,6 +183,24 @@ class AgentTrace(BaseModel): totalToolCalls: int = 0 totalCostCHF: float = 0.0 abortReason: Optional[str] = None + validationFailures: int = Field( + default=0, + description="Total tool calls rejected by a pre-execute validator across the run.", + ) + repairAttempts: int = Field( + default=0, + description=( + "Number of times the LLM retried a previously rejected tool (same toolName) " + "in a later round. Counted by `agentLoop` from per-round ToolCallLog entries." + ), + ) + successAfterRepair: int = Field( + default=0, + description=( + "Number of repair attempts that produced a clean (validationFailureCode=None) " + "result. Combined with `repairAttempts` this gives the repair conversion rate." + ), + ) rounds: List[AgentRoundLog] = Field(default_factory=list) diff --git a/modules/serviceCenter/services/serviceAgent/datamodelOntology.py b/modules/serviceCenter/services/serviceAgent/datamodelOntology.py new file mode 100644 index 00000000..30e5b023 --- /dev/null +++ b/modules/serviceCenter/services/serviceAgent/datamodelOntology.py @@ -0,0 +1,203 @@ +# Copyright (c) 2026 Patrick Motsch +# All rights reserved. +"""Ontology data model for feature data sub-agents. + +This module defines the data structures that describe a feature's data +ontology -- entities, relations, constraints, canonical query patterns -- +plus the validation error payload used by the QueryValidator. + +Phase 1 (Repair-Loop) only needs `QueryValidationError`, `Constraint`, +`ConstraintRule` and `ValidationErrorCode`; the richer `Entity`/`Relation`/ +`OntologyDescriptor` types are defined here so Phase 2 (Trustee ontology +pilot) can plug in without a second data-model change. + +See `wiki/c-work/2-build/2026-05-feature-data-agent-ontology-and-repair.md`. +""" + +from enum import Enum +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + + +class ValidationErrorCode(str, Enum): + """Stable codes for validator failures. + + The LLM sees these codes verbatim in `ToolResult.errorDetails["code"]` + and is expected to react to them deterministically (e.g. inspect the + schema via browseTable when FIELD_NOT_FOUND, drop the SUM when + INVALID_AGGREGATE_TARGET, add a period filter when MISSING_REQUIRED_FILTER). + """ + FIELD_NOT_FOUND = "FIELD_NOT_FOUND" + INVALID_AGGREGATE_TARGET = "INVALID_AGGREGATE_TARGET" + WRONG_TABLE_FOR_PURPOSE = "WRONG_TABLE_FOR_PURPOSE" + TYPE_MISMATCH = "TYPE_MISMATCH" + OPERATOR_INCOMPATIBLE = "OPERATOR_INCOMPATIBLE" + MISSING_REQUIRED_FILTER = "MISSING_REQUIRED_FILTER" + ORDER_BY_INVALID = "ORDER_BY_INVALID" + + +class QueryValidationError(BaseModel): + """Structured pre-execute validation error. + + Serialized into `ToolResult.errorDetails` (machine-readable) and + summarized into `ToolResult.error` (short human-readable string). + """ + code: ValidationErrorCode + field: Optional[str] = Field( + default=None, + description="The offending field name (when applicable).", + ) + suggestion: Optional[str] = Field( + default=None, + description=( + "Best-effort suggestion (e.g. fuzzy-matched valid field name). " + "None when no useful suggestion exists." + ), + ) + hint: str = Field( + description="Short corrective hint, max ~80 chars. Surfaced to the LLM verbatim.", + max_length=160, + ) + + def toShortError(self) -> str: + """Build the short `error` string for logs/audit. + + Format: `: ` (or with field when present). + """ + if self.field: + return f"{self.code.value}: {self.field}: {self.hint}" + return f"{self.code.value}: {self.hint}" + + def toErrorDetails(self) -> Dict[str, Any]: + """Build the dict for `ToolResult.errorDetails`.""" + return { + "code": self.code.value, + "field": self.field, + "suggestion": self.suggestion, + "hint": self.hint, + } + + +class ConstraintRule(str, Enum): + """High-level rule kinds that can be attached to a field or table.""" + NEVER_AGGREGATE = "NEVER_AGGREGATE" + REQUIRES_FILTER_ON = "REQUIRES_FILTER_ON" + TYPE_MISMATCH_GUARD = "TYPE_MISMATCH_GUARD" + PREFERRED_TABLE_FOR_INTENT = "PREFERRED_TABLE_FOR_INTENT" + + +class Constraint(BaseModel): + """A single rule the validator and the prompt compiler both consume. + + Phase 1 uses constraints declared inline by the validator (defaults + derived from naming conventions like ``*Balance`` / ``*Total``). + Phase 2 sources them from feature ontologies, replacing the + convention-based defaults. + """ + appliesTo: str = Field( + description=( + "Target identifier, format depends on rule: `
.` for " + "field-level constraints, `
` for table-level." + ), + ) + rule: ConstraintRule + message: str = Field( + description="Short hint forwarded to the LLM if the constraint fires.", + max_length=160, + ) + params: Dict[str, Any] = Field( + default_factory=dict, + description=( + "Rule-specific extras, e.g. {'requiredFields': ['periodYear', 'periodMonth']} " + "for REQUIRES_FILTER_ON." + ), + ) + + +class SemanticType(str, Enum): + """High-level semantic category an entity belongs to. + + Coarser than the underlying Pydantic type -- used so the prompt compiler + can group entities ("here are your ACCOUNT-like tables") without the LLM + having to read the full schema. + """ + ACCOUNT = "ACCOUNT" + BALANCE_SNAPSHOT = "BALANCE_SNAPSHOT" + TRANSACTION = "TRANSACTION" + DOCUMENT = "DOCUMENT" + PARTY = "PARTY" + PERIOD = "PERIOD" + OTHER = "OTHER" + + +class Cardinality(str, Enum): + ONE_TO_ONE = "ONE_TO_ONE" + ONE_TO_MANY = "ONE_TO_MANY" + MANY_TO_ONE = "MANY_TO_ONE" + MANY_TO_MANY = "MANY_TO_MANY" + + +class Invariant(BaseModel): + """Free-form invariant attached to an entity. + + Phase 1 leaves these as opaque text consumed by the prompt compiler. + Future phases may add a structured rule kind. + """ + description: str = Field(max_length=200) + + +class Entity(BaseModel): + """One semantic entity in the ontology (often backed by a Pydantic table).""" + name: str + pythonClass: Optional[str] = Field( + default=None, + description="MODEL_REGISTRY key when the entity is DB-backed (e.g. 'TrusteeDataAccountBalance').", + ) + semanticType: SemanticType = SemanticType.OTHER + parentEntity: Optional[str] = Field( + default=None, + description="Name of a broader entity this one specializes (e.g. 'BankAccount' parentEntity 'Account').", + ) + description: str = "" + invariants: List[Invariant] = Field(default_factory=list) + + +class Relation(BaseModel): + fromEntity: str + toEntity: str + cardinality: Cardinality + via: Optional[str] = Field( + default=None, + description="FK-Feldname auf der fromEntity-Seite (z. B. 'journalEntryId').", + ) + + +class CanonicalQueryPattern(BaseModel): + """Tool-call skeleton for a recurring user intent. + + The prompt compiler renders these as worked examples so the LLM has a + template to mimic instead of inventing a query shape. + """ + intent: str = Field(description="Short label, e.g. 'BANK_BALANCE_AT_DATE'.") + description: str = Field(default="", description="Human-readable when to use this pattern.") + pattern: Dict[str, Any] = Field( + description="Tool-call shape with placeholders, e.g. {'tool': 'queryTable', 'tableName': '...', 'filters': [...]}", + ) + + +class OntologyDescriptor(BaseModel): + """Top-level container exported by `getAgentOntology()` per feature.""" + featureCode: str + entities: List[Entity] = Field(default_factory=list) + relations: List[Relation] = Field(default_factory=list) + constraints: List[Constraint] = Field(default_factory=list) + canonicalPatterns: List[CanonicalQueryPattern] = Field(default_factory=list) + + def constraintsForTable(self, tableName: str) -> List[Constraint]: + """Return constraints whose ``appliesTo`` targets the given table or one of its fields.""" + prefix = f"{tableName}." + return [ + c for c in self.constraints + if c.appliesTo == tableName or c.appliesTo.startswith(prefix) + ] diff --git a/modules/serviceCenter/services/serviceAgent/featureDataAgent.py b/modules/serviceCenter/services/serviceAgent/featureDataAgent.py index aa2d332d..51840575 100644 --- a/modules/serviceCenter/services/serviceAgent/featureDataAgent.py +++ b/modules/serviceCenter/services/serviceAgent/featureDataAgent.py @@ -15,6 +15,7 @@ invoked outside an agent loop (e.g. in tests). import json import logging +import os from typing import Any, Callable, Awaitable, Dict, List, Optional from modules.datamodels.datamodelAi import ( @@ -25,6 +26,10 @@ from modules.serviceCenter.services.serviceAgent.agentLoop import runAgentLoop from modules.serviceCenter.services.serviceAgent.datamodelAgent import ( AgentConfig, AgentEvent, AgentEventTypeEnum, ToolResult, ) +from modules.serviceCenter.services.serviceAgent.datamodelOntology import ( + QueryValidationError, +) +from modules.serviceCenter.services.serviceAgent.queryValidator import QueryValidator from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry from modules.serviceCenter.services.serviceAgent.featureDataProvider import FeatureDataProvider from modules.shared.i18nRegistry import resolveText @@ -83,7 +88,8 @@ async def runFeatureDataAgent( """ provider = FeatureDataProvider(dbConnector, neutralizeFields=neutralizeFields) - registry = _buildSubAgentTools(provider, featureInstanceId, mandateId, tableFilters or {}) + validator = _buildValidatorForFeature(featureCode) + registry = _buildSubAgentTools(provider, featureInstanceId, mandateId, tableFilters or {}, validator=validator) for tbl in selectedTables: meta = tbl.get("meta", {}) @@ -153,10 +159,19 @@ def _buildSubAgentTools( featureInstanceId: str, mandateId: str, tableFilters: Dict[str, Dict[str, str]] = None, + validator: Optional[QueryValidator] = None, ) -> ToolRegistry: - """Register browseTable and queryTable as sub-agent tools.""" + """Register browseTable and queryTable as sub-agent tools. + + The optional ``validator`` runs **before** the provider on every call. + When it returns a structured error, the tool result carries + ``errorDetails`` (machine-readable repair hint for the LLM) plus the + short ``error`` string for logs/audit. No provider call happens in that + case, so the database is never reached with a known-bad query. + """ registry = ToolRegistry() _tableFilters = tableFilters or {} + _validator = validator or QueryValidator() def _recordFilterToList(tableName: str) -> Optional[List[Dict[str, Any]]]: """Convert a recordFilter dict to a list of {field, op, value} filter dicts.""" @@ -165,6 +180,14 @@ def _buildSubAgentTools( return None return [{"field": k, "op": "=", "value": v} for k, v in rf.items()] + def _validationToolResult(toolName: str, err: QueryValidationError) -> ToolResult: + return ToolResult( + toolCallId="", toolName=toolName, + success=False, + error=err.toShortError(), + errorDetails=err.toErrorDetails(), + ) + async def _browseTable(args: Dict[str, Any], context: Dict[str, Any]): tableName = args.get("tableName", "") limit = args.get("limit", 50) @@ -172,6 +195,9 @@ def _buildSubAgentTools( fields = args.get("fields") if not tableName: return ToolResult(toolCallId="", toolName="browseTable", success=False, error="tableName required") + validationErr = _validator.validateBrowseQuery(tableName, args) + if validationErr is not None: + return _validationToolResult("browseTable", validationErr) result = provider.browseTable( tableName=tableName, featureInstanceId=featureInstanceId, @@ -197,6 +223,9 @@ def _buildSubAgentTools( offset = args.get("offset", 0) if not tableName: return ToolResult(toolCallId="", toolName="queryTable", success=False, error="tableName required") + validationErr = _validator.validateQueryTable(tableName, args) + if validationErr is not None: + return _validationToolResult("queryTable", validationErr) result = provider.queryTable( tableName=tableName, featureInstanceId=featureInstanceId, @@ -220,12 +249,19 @@ def _buildSubAgentTools( aggregate = args.get("aggregate", "") field = args.get("field", "") groupBy = args.get("groupBy") + filters = args.get("filters") or [] if not tableName: return ToolResult(toolCallId="", toolName="aggregateTable", success=False, error="tableName required") if not aggregate: return ToolResult(toolCallId="", toolName="aggregateTable", success=False, error="aggregate required (SUM, COUNT, AVG, MIN, MAX)") if not field: return ToolResult(toolCallId="", toolName="aggregateTable", success=False, error="field required") + validationErr = _validator.validateAggregateQuery(tableName, args) + if validationErr is not None: + return _validationToolResult("aggregateTable", validationErr) + combinedFilters = list(filters) + recordFilters = _recordFilterToList(tableName) or [] + combinedFilters.extend(recordFilters) result = provider.aggregateTable( tableName=tableName, featureInstanceId=featureInstanceId, @@ -233,7 +269,7 @@ def _buildSubAgentTools( aggregate=aggregate, field=field, groupBy=groupBy, - extraFilters=_recordFilterToList(tableName), + extraFilters=combinedFilters or None, ) return ToolResult( toolCallId="", toolName="aggregateTable", @@ -246,8 +282,12 @@ def _buildSubAgentTools( "aggregateTable", _aggregateTable, description=( "Run an aggregate query on a feature data table. " - "Supports SUM, COUNT, AVG, MIN, MAX with optional GROUP BY. " - "Example: aggregateTable(tableName='TrusteeDataJournalLine', aggregate='SUM', field='debitAmount', groupBy='costCenter')" + "Supports SUM, COUNT, AVG, MIN, MAX with optional GROUP BY and filters. " + "Example: aggregateTable(tableName='TrusteeDataJournalLine', aggregate='SUM', " + "field='debitAmount', filters=[{'field':'accountNumber','op':'=','value':'5400'}]). " + "On validation failure the tool returns success=False with errorDetails={code, field, suggestion, hint} -- " + "read errorDetails and correct the next call (e.g. drop the SUM, switch to queryTable with period filters, " + "or use the suggested field name)." ), parameters={ "type": "object", @@ -256,6 +296,22 @@ def _buildSubAgentTools( "aggregate": {"type": "string", "enum": ["SUM", "COUNT", "AVG", "MIN", "MAX"], "description": "Aggregate function"}, "field": {"type": "string", "description": "Field to aggregate (e.g. debitAmount, creditAmount)"}, "groupBy": {"type": "string", "description": "Optional field to group by (e.g. costCenter, accountNumber)"}, + "filters": { + "type": "array", + "items": { + "type": "object", + "properties": { + "field": {"type": "string"}, + "op": {"type": "string"}, + "value": {}, + }, + }, + "description": ( + "Optional filter conditions applied before the aggregate. Same shape as queryTable's " + "filters. Required whenever you want to aggregate only a subset (e.g. SUM debits on " + "ONE account, COUNT rows in ONE year)." + ), + }, }, "required": ["tableName", "aggregate", "field"], }, @@ -264,7 +320,11 @@ def _buildSubAgentTools( registry.register( "browseTable", _browseTable, - description="List rows from a feature data table with pagination.", + description=( + "List rows from a feature data table with pagination. " + "On validation failure the tool returns success=False with errorDetails={code, field, suggestion, hint} -- " + "use errorDetails to correct the next call." + ), parameters={ "type": "object", "properties": { @@ -286,7 +346,10 @@ def _buildSubAgentTools( description=( "Query a feature data table with filters, field selection, and ordering. " "Filters: [{\"field\": \"status\", \"op\": \"=\", \"value\": \"active\"}]. " - "Operators: =, !=, >, <, >=, <=, LIKE, ILIKE, IS NULL, IS NOT NULL." + "Operators: =, !=, >, <, >=, <=, LIKE, ILIKE, IS NULL, IS NOT NULL. " + "On validation failure the tool returns success=False with errorDetails={code, field, suggestion, hint} -- " + "common codes: FIELD_NOT_FOUND (use the suggestion or call browseTable), OPERATOR_INCOMPATIBLE " + "(switch to a compatible operator for that field type), ORDER_BY_INVALID." ), parameters={ "type": "object", @@ -410,13 +473,94 @@ def _buildSchemaContext( "- Keep your answer SHORT. The caller is a machine, not a human.", ] - domainHints = _loadFeatureDomainHints(featureCode) - if domainHints: - parts.extend(["", domainHints.strip()]) + domainBlock = "" + if not _isOntologyDisabled(): + domainBlock = _loadFeatureOntologyBlock(featureCode) + if not domainBlock: + domainBlock = _loadFeatureDomainHints(featureCode) + if domainBlock: + parts.extend(["", domainBlock.strip()]) return "\n".join(parts) +def _isOntologyDisabled() -> bool: + """Eval-only escape hatch. + + Set ``POWERON_DISABLE_FEATURE_ONTOLOGY=1`` in the environment to force + ``_buildSchemaContext`` back onto the legacy ``getAgentDomainHints()`` + path. Used by the Phase 1.5 benchmark to measure ``baseline`` and + ``phase1`` accuracy WITHOUT the ontology-driven prompt block. Never + set this flag in production. + """ + return os.environ.get("POWERON_DISABLE_FEATURE_ONTOLOGY", "").strip() in ("1", "true", "TRUE", "yes") + + +def _buildValidatorForFeature(featureCode: str) -> QueryValidator: + """Construct a QueryValidator wired with the feature ontology (when present). + + Without an ontology the validator falls back to its convention-based + constraints (``*Balance`` / ``*Total`` are NEVER_AGGREGATE). With an + ontology the descriptor's constraints take precedence -- the validator + and the prompt block then share the same source of truth. + """ + ontology = _loadFeatureOntology(featureCode) + return QueryValidator(ontology=ontology) + + +def _loadFeatureOntology(featureCode: str): + """Return the feature's OntologyDescriptor or None when no hook is exposed.""" + if not featureCode: + return None + try: + from modules.system.registry import loadFeatureMainModules + except Exception: + return None + + try: + mainModules = loadFeatureMainModules() or {} + except Exception as exc: + logger.debug("Ontology lookup: cannot load main modules (%s)", exc) + return None + + module = mainModules.get(featureCode) or mainModules.get(featureCode.lower()) + if module is None: + return None + hook = getattr(module, "getAgentOntology", None) + if not callable(hook): + return None + try: + return hook() + except Exception as exc: + logger.warning("Feature '%s' getAgentOntology() raised: %s", featureCode, exc) + return None + + +def _loadFeatureOntologyBlock(featureCode: str) -> str: + """Return the ontology-derived prompt block when the feature exposes one. + + Each feature can expose ``getAgentOntology() -> OntologyDescriptor`` in + its ``mainXxx.py``. When present, the descriptor is compiled via + :func:`ontologyToPromptCompiler.compileOntologyToPrompt` and the result + replaces the legacy ``getAgentDomainHints()`` text block. This keeps + one single source of truth for the validator AND the prompt. + + Failures are swallowed (missing hook, exceptions in compilation) so the + caller can fall back to the legacy domain-hints path. + """ + ontology = _loadFeatureOntology(featureCode) + if ontology is None: + return "" + try: + from modules.serviceCenter.services.serviceAgent.ontologyToPromptCompiler import ( + compileOntologyToPrompt, + ) + return compileOntologyToPrompt(ontology) + except Exception as exc: + logger.warning("Ontology compile failed for '%s': %s", featureCode, exc) + return "" + + def _loadFeatureDomainHints(featureCode: str) -> str: """Pull optional domain-specific hints from the feature's main module. diff --git a/modules/serviceCenter/services/serviceAgent/ontologyToPromptCompiler.py b/modules/serviceCenter/services/serviceAgent/ontologyToPromptCompiler.py new file mode 100644 index 00000000..5b162ed3 --- /dev/null +++ b/modules/serviceCenter/services/serviceAgent/ontologyToPromptCompiler.py @@ -0,0 +1,140 @@ +# Copyright (c) 2026 Patrick Motsch +# All rights reserved. +"""Deterministic compiler: OntologyDescriptor -> sub-agent prompt block. + +Phase 2 replaces a feature's hand-written ``_AGENT_DOMAIN_HINTS`` text +with a structured :class:`OntologyDescriptor`. This compiler renders the +descriptor into a stable, terse Markdown-ish block that the sub-agent +appends to its system prompt -- the same source of truth the +:class:`QueryValidator` consults. + +The output is intentionally: +* short (every token costs every call) +* deterministic (no f-string ordering bugs, no Python dict iteration) +* free of internal jargon ('canonicalQueryPattern' is rendered as + 'CANONICAL PATTERN' for the LLM) +""" + +from __future__ import annotations + +from typing import Iterable, List + +from modules.serviceCenter.services.serviceAgent.datamodelOntology import ( + CanonicalQueryPattern, + Constraint, + ConstraintRule, + Entity, + OntologyDescriptor, + Relation, +) + + +def compileOntologyToPrompt(ontology: OntologyDescriptor) -> str: + """Render *ontology* into a sub-agent prompt block. + + The output starts with a stable marker line (``DOMAIN ONTOLOGY (...)``) + so downstream tooling can find/replace it deterministically. + """ + lines: List[str] = [] + lines.append(f"DOMAIN ONTOLOGY ({ontology.featureCode}):") + lines.append("") + lines.extend(_renderEntities(ontology.entities)) + relationLines = _renderRelations(ontology.relations) + if relationLines: + lines.append("") + lines.extend(relationLines) + constraintLines = _renderConstraints(ontology.constraints) + if constraintLines: + lines.append("") + lines.extend(constraintLines) + patternLines = _renderPatterns(ontology.canonicalPatterns) + if patternLines: + lines.append("") + lines.extend(patternLines) + return "\n".join(lines).rstrip() + "\n" + + +def _renderEntities(entities: Iterable[Entity]) -> List[str]: + out: List[str] = ["ENTITIES:"] + for e in entities: + head = f"- {e.name}" + if e.parentEntity: + head += f" (specializes {e.parentEntity})" + if e.pythonClass: + head += f" [table: {e.pythonClass}]" + out.append(head) + if e.description: + out.append(f" {e.description}") + for inv in e.invariants: + out.append(f" * {inv.description}") + return out + + +def _renderRelations(relations: Iterable[Relation]) -> List[str]: + rels = list(relations) + if not rels: + return [] + out: List[str] = ["RELATIONS:"] + for r in rels: + line = f"- {r.fromEntity} -> {r.toEntity} ({r.cardinality.value}" + if r.via: + line += f" via {r.via}" + line += ")" + out.append(line) + return out + + +def _renderConstraints(constraints: Iterable[Constraint]) -> List[str]: + cons = list(constraints) + if not cons: + return [] + out: List[str] = ["CONSTRAINTS (validator-enforced):"] + for c in cons: + rule = _ruleLabel(c.rule) + line = f"- {rule} on {c.appliesTo}: {c.message}" + params = c.params or {} + required = params.get("requiredFields") + if isinstance(required, list) and required: + line += f" (required filters: {', '.join(required)})" + intents = params.get("intents") + if isinstance(intents, list) and intents: + line += f" (intents: {', '.join(intents)})" + out.append(line) + return out + + +def _ruleLabel(rule: ConstraintRule) -> str: + return rule.value.replace("_", " ").lower() + + +def _renderPatterns(patterns: Iterable[CanonicalQueryPattern]) -> List[str]: + pats = list(patterns) + if not pats: + return [] + out: List[str] = ["CANONICAL QUERY PATTERNS (mimic these tool calls):"] + for i, p in enumerate(pats, start=1): + out.append(f"{i}) intent={p.intent}: {p.description}") + out.append(f" call: {_renderPatternCall(p.pattern)}") + extra = p.pattern.get("_postProcessing") if isinstance(p.pattern, dict) else None + if isinstance(extra, str): + out.append(f" note: {extra}") + return out + + +def _renderPatternCall(pattern: dict) -> str: + """Render the pattern as a compact one-line tool call signature.""" + tool = pattern.get("tool", "?") + parts: List[str] = [] + for key in ("tableName", "aggregate", "field", "groupBy", "orderBy"): + if key in pattern and pattern[key] is not None and not str(key).startswith("_"): + parts.append(f"{key}={pattern[key]!r}") + if "fields" in pattern and pattern["fields"]: + parts.append(f"fields={pattern['fields']}") + if "filters" in pattern and pattern["filters"]: + compact = ", ".join( + f"{f.get('field')}{f.get('op','=')}{f.get('value')!r}" + for f in pattern["filters"] + if isinstance(f, dict) + ) + parts.append(f"filters=[{compact}]") + return f"{tool}({', '.join(parts)})" diff --git a/modules/serviceCenter/services/serviceAgent/queryValidator.py b/modules/serviceCenter/services/serviceAgent/queryValidator.py new file mode 100644 index 00000000..2dbbd57e --- /dev/null +++ b/modules/serviceCenter/services/serviceAgent/queryValidator.py @@ -0,0 +1,311 @@ +# Copyright (c) 2026 Patrick Motsch +# All rights reserved. +"""Pre-execute query validator for the Feature Data Sub-Agent. + +Sits between the LLM tool call and `FeatureDataProvider`. Catches the four +high-impact hallucination classes deterministically so the LLM gets an +actionable repair hint instead of a raw SQL exception: + +* invented field names -> FIELD_NOT_FOUND (+ fuzzy suggestion) +* operator/type mismatches -> OPERATOR_INCOMPATIBLE +* SUM/AVG on already-aggregated -> INVALID_AGGREGATE_TARGET + balance/total columns +* orderBy on invented fields -> ORDER_BY_INVALID + +The validator reads the canonical schema from +`modules.datamodels.datamodelBase.MODEL_REGISTRY`. When an +`OntologyDescriptor` is provided (Phase 2), its constraints override the +convention-based defaults (e.g. NEVER_AGGREGATE on closingBalance). +""" + +from __future__ import annotations + +import difflib +import logging +import re +import typing +from typing import Any, Dict, List, Optional, Tuple + +from modules.datamodels.datamodelBase import MODEL_REGISTRY +from modules.serviceCenter.services.serviceAgent.datamodelOntology import ( + Constraint, + ConstraintRule, + OntologyDescriptor, + QueryValidationError, + ValidationErrorCode, +) + +logger = logging.getLogger(__name__) + + +_STRING_ONLY_OPERATORS = {"LIKE", "ILIKE"} +_COMPARISON_OPERATORS = {">", "<", ">=", "<="} +_VALUELESS_OPERATORS = {"IS NULL", "IS NOT NULL"} +_AGGREGATES_THAT_SUM = {"SUM", "AVG"} +_AGGREGATE_BLACKLIST_SUFFIXES_DEFAULT: Tuple[str, ...] = ("Balance", "Total") + + +class QueryValidator: + """Validate sub-agent tool arguments against the schema (+ optional ontology). + + Stateless per call -- holding only the optional ontology. Each + `validateXxx` method returns ``None`` on success or a + :class:`QueryValidationError` to be surfaced to the LLM. + """ + + def __init__(self, ontology: Optional[OntologyDescriptor] = None): + self._ontology = ontology + + # ------------------------------------------------------------------ + # public API: one method per sub-agent tool + # ------------------------------------------------------------------ + + def validateBrowseQuery( + self, tableName: str, args: Dict[str, Any] + ) -> Optional[QueryValidationError]: + """Validate browseTable arguments. + + Phase 1 scope: only `fields` (whitelist) is LLM-driven; `limit`/`offset` + are sanitized by the tool wrapper. + """ + modelFields = _getModelFields(tableName) + if modelFields is None: + return None + + fieldsErr = self._validateFieldList(args.get("fields"), modelFields) + if fieldsErr is not None: + return fieldsErr + return None + + def validateQueryTable( + self, tableName: str, args: Dict[str, Any] + ) -> Optional[QueryValidationError]: + """Validate queryTable arguments (filters + fields + orderBy).""" + modelFields = _getModelFields(tableName) + if modelFields is None: + return None + + fieldsErr = self._validateFieldList(args.get("fields"), modelFields) + if fieldsErr is not None: + return fieldsErr + + for f in args.get("filters") or []: + filterErr = self._validateFilter(f, modelFields) + if filterErr is not None: + return filterErr + + orderBy = args.get("orderBy") + if orderBy is not None and not _isPlainNone(orderBy): + if orderBy not in modelFields: + return QueryValidationError( + code=ValidationErrorCode.ORDER_BY_INVALID, + field=orderBy, + suggestion=_suggestFieldName(orderBy, modelFields), + hint="orderBy must be a real field of this table.", + ) + return None + + def validateAggregateQuery( + self, tableName: str, args: Dict[str, Any] + ) -> Optional[QueryValidationError]: + """Validate aggregateTable arguments. + + Catches the highest-impact hallucination in the codebase: + ``SUM(closingBalance)`` (and friends) across periods -- closing + balances are already per-period, summing them produces nonsense. + """ + modelFields = _getModelFields(tableName) + if modelFields is None: + return None + + field = args.get("field") + aggregate = (args.get("aggregate") or "").upper() + + if not field: + return None # tool wrapper rejects empty field already + + if field not in modelFields: + return QueryValidationError( + code=ValidationErrorCode.FIELD_NOT_FOUND, + field=field, + suggestion=_suggestFieldName(field, modelFields), + hint="Use browseTable to inspect this table's columns.", + ) + + if aggregate in _AGGREGATES_THAT_SUM and self._isAggregateBlacklisted(tableName, field): + return QueryValidationError( + code=ValidationErrorCode.INVALID_AGGREGATE_TARGET, + field=field, + suggestion=None, + hint=( + f"{field} is already aggregated per period; do not {aggregate} it " + "across rows. Use queryTable with period filters instead." + ), + ) + + if aggregate in _AGGREGATES_THAT_SUM and not _isNumericAnnotation(modelFields[field]): + return QueryValidationError( + code=ValidationErrorCode.TYPE_MISMATCH, + field=field, + suggestion=None, + hint=f"{aggregate} requires a numeric field; {field} is not numeric.", + ) + + groupBy = args.get("groupBy") + if groupBy is not None and not _isPlainNone(groupBy): + if groupBy not in modelFields: + return QueryValidationError( + code=ValidationErrorCode.FIELD_NOT_FOUND, + field=groupBy, + suggestion=_suggestFieldName(groupBy, modelFields), + hint="groupBy must be a real field of this table.", + ) + + # filters validation matches queryTable so the LLM gets consistent + # repair hints regardless of which tool it picked. + for f in args.get("filters") or []: + filterErr = self._validateFilter(f, modelFields) + if filterErr is not None: + return filterErr + return None + + # ------------------------------------------------------------------ + # internals + # ------------------------------------------------------------------ + + def _validateFieldList( + self, fields: Optional[List[str]], modelFields: Dict[str, Any] + ) -> Optional[QueryValidationError]: + if not fields: + return None + for f in fields: + if not isinstance(f, str): + continue + if f not in modelFields: + return QueryValidationError( + code=ValidationErrorCode.FIELD_NOT_FOUND, + field=f, + suggestion=_suggestFieldName(f, modelFields), + hint="Use browseTable to inspect this table's columns.", + ) + return None + + def _validateFilter( + self, filterEntry: Any, modelFields: Dict[str, Any] + ) -> Optional[QueryValidationError]: + if not isinstance(filterEntry, dict): + return None + field = filterEntry.get("field") + op = (filterEntry.get("op") or "=").upper() + + if not isinstance(field, str) or not field: + return None # tool wrapper passes these straight through + + if field not in modelFields: + return QueryValidationError( + code=ValidationErrorCode.FIELD_NOT_FOUND, + field=field, + suggestion=_suggestFieldName(field, modelFields), + hint="Use browseTable to inspect this table's columns.", + ) + + annotation = modelFields[field] + + if op in _STRING_ONLY_OPERATORS and not _isStringAnnotation(annotation): + return QueryValidationError( + code=ValidationErrorCode.OPERATOR_INCOMPATIBLE, + field=field, + suggestion=None, + hint=f"{op} only works on string fields; {field} is not a string.", + ) + + if op in _COMPARISON_OPERATORS and not _isComparableAnnotation(annotation): + return QueryValidationError( + code=ValidationErrorCode.OPERATOR_INCOMPATIBLE, + field=field, + suggestion=None, + hint=f"{op} requires a numeric or date field; {field} is not comparable.", + ) + return None + + def _isAggregateBlacklisted(self, tableName: str, fieldName: str) -> bool: + """Check whether a field is marked NEVER_AGGREGATE. + + Phase 2 (ontology present): consult the descriptor. + Phase 1 fallback: naming convention (``*Balance`` / ``*Total``). + """ + if self._ontology is not None: + target = f"{tableName}.{fieldName}" + for c in self._ontology.constraintsForTable(tableName): + if c.rule == ConstraintRule.NEVER_AGGREGATE and c.appliesTo == target: + return True + + for suffix in _AGGREGATE_BLACKLIST_SUFFIXES_DEFAULT: + if fieldName.endswith(suffix): + return True + return False + + +# ------------------------------------------------------------------ +# helpers +# ------------------------------------------------------------------ + +def _getModelFields(tableName: str) -> Optional[Dict[str, Any]]: + """Return ``{fieldName: annotation}`` for a registered Pydantic table model. + + None when the table is not in MODEL_REGISTRY (e.g. pure UDB tables in + early-startup contexts). The validator is a best-effort layer -- when + the schema is unknown we let the request through and rely on the + downstream SQL layer for safety. + """ + modelClass = MODEL_REGISTRY.get(tableName) + if modelClass is None: + return None + return { + name: info.annotation for name, info in modelClass.model_fields.items() + } + + +def _suggestFieldName(badName: str, modelFields: Dict[str, Any]) -> Optional[str]: + """Return the closest valid field name, or None if nothing reasonable.""" + if not badName or not modelFields: + return None + matches = difflib.get_close_matches(badName, list(modelFields.keys()), n=1, cutoff=0.6) + return matches[0] if matches else None + + +def _isPlainNone(value: Any) -> bool: + """LLMs sometimes pass the literal string 'None' -- treat both as None.""" + return value is None or (isinstance(value, str) and value.strip().lower() == "none") + + +def _unwrapAnnotation(annotation: Any) -> Tuple[Any, ...]: + """Flatten Optional/Union annotations into their constituent types.""" + origin = typing.get_origin(annotation) + if origin is None: + return (annotation,) + return tuple(a for a in typing.get_args(annotation) if a is not type(None)) + + +def _isStringAnnotation(annotation: Any) -> bool: + return any(a is str for a in _unwrapAnnotation(annotation)) + + +def _isNumericAnnotation(annotation: Any) -> bool: + numericTypes = (int, float) + return any(a in numericTypes for a in _unwrapAnnotation(annotation)) + + +def _isComparableAnnotation(annotation: Any) -> bool: + """Numeric types are the comparable shape we see in feature tables. + + Booleans count as int in Python's type hierarchy but the comparison + operators ``>``/``<`` on bool columns are almost never meaningful, so we + treat bool as non-comparable for validator purposes. + """ + for a in _unwrapAnnotation(annotation): + if a is bool: + continue + if a in (int, float): + return True + return False diff --git a/modules/serviceCenter/services/serviceAgent/sandboxExecutor.py b/modules/serviceCenter/services/serviceAgent/sandboxExecutor.py index c2e16506..2fbe9c34 100644 --- a/modules/serviceCenter/services/serviceAgent/sandboxExecutor.py +++ b/modules/serviceCenter/services/serviceAgent/sandboxExecutor.py @@ -98,14 +98,17 @@ class _VirtualFS: def _makeReadFile(services): """Create a readFile(fileId) closure bound to the current services context.""" - def readFile(fileId: str) -> str: + def readFile(fileId: str, encoding: str = "utf-8") -> str: mgmt = getattr(services, 'interfaceDbComponent', None) if services else None if not mgmt: raise RuntimeError("readFile: no file store available in this session") data = mgmt.getFileData(str(fileId)) if data is None: raise FileNotFoundError(f"File '{fileId}' not found in workspace") - return data.decode("utf-8") + try: + return data.decode(encoding) + except (UnicodeDecodeError, LookupError): + return data.decode("utf-8", errors="replace") return readFile diff --git a/modules/serviceCenter/services/serviceAi/subAiCallLooping.py b/modules/serviceCenter/services/serviceAi/subAiCallLooping.py index 4285de51..3ef22535 100644 --- a/modules/serviceCenter/services/serviceAi/subAiCallLooping.py +++ b/modules/serviceCenter/services/serviceAi/subAiCallLooping.py @@ -60,6 +60,7 @@ from modules.shared.jsonContinuation import getContexts from modules.shared.jsonUtils import buildContinuationContext, tryParseJson from modules.shared.jsonUtils import closeJsonStructures from modules.shared.jsonUtils import stripCodeFences, normalizeJsonText +from modules.shared.jsonUtils import extractJsonString, repairBrokenJson logger = logging.getLogger(__name__) @@ -447,7 +448,6 @@ class AiCallLooper: extracted = extractJsonString(contexts.completePart) parsed, parseErr, _ = tryParseJson(extracted) if parseErr is not None: - from modules.shared.jsonUtils import repairBrokenJson repaired = repairBrokenJson(extracted) if repaired: parsed = repaired @@ -470,9 +470,10 @@ class AiCallLooper: return useCase.finalResultHandler( result, normalized, extracted, debugPrefix, self.services ) - except Exception as e: + except (json.JSONDecodeError, KeyError, TypeError) as e: logger.warning( - f"Iteration {iteration}: completePart not serializable after getContexts success: {e}" + f"Iteration {iteration}: completePart not serializable after getContexts success: " + f"{type(e).__name__}: {e}" ) mergeFailCount += 1 if mergeFailCount >= MAX_MERGE_FAILS: @@ -491,6 +492,15 @@ class AiCallLooper: ) self.services.chat.progressLogFinish(iterationOperationId, True) continue + except Exception as e: + logger.error( + f"Iteration {iteration}: unexpected error during completePart processing " + f"(re-raising, NOT a pipeline-mismatch retry): {type(e).__name__}: {e}", + exc_info=True, + ) + if iterationOperationId: + self.services.chat.progressLogFinish(iterationOperationId, False) + raise elif contexts.jsonParsingSuccess and contexts.overlapContext != "": # JSON parseable but has cut point - CONTINUE to next iteration diff --git a/modules/serviceCenter/services/serviceBackgroundJobs/mainBackgroundJobService.py b/modules/serviceCenter/services/serviceBackgroundJobs/mainBackgroundJobService.py index 66ca4708..e27dae58 100644 --- a/modules/serviceCenter/services/serviceBackgroundJobs/mainBackgroundJobService.py +++ b/modules/serviceCenter/services/serviceBackgroundJobs/mainBackgroundJobService.py @@ -34,7 +34,7 @@ import time from datetime import datetime, timezone from typing import Any, Awaitable, Callable, Dict, List, Optional -from modules.connectors.connectorDbPostgre import DatabaseConnector +from modules.connectors.connectorDbPostgre import DatabaseConnector, getCachedConnector from modules.shared.configuration import APP_CONFIG from modules.shared.dbRegistry import registerDatabase from modules.datamodels.datamodelBackgroundJob import ( @@ -104,7 +104,13 @@ def registerJobHandler(jobType: str, handler: JobHandler) -> None: def _getDb() -> DatabaseConnector: - return DatabaseConnector( + """Return the shared cached connector for the jobs DB. + + Reuses the same connector across all job CRUD calls instead of opening a + fresh psycopg2 connection (and re-running `_create_database_if_not_exists` + + `_create_tables` + `_initializeSystemTable`) on every operation. + """ + return getCachedConnector( dbDatabase=JOBS_DATABASE, dbHost=APP_CONFIG.get("DB_HOST", "localhost"), dbPort=int(APP_CONFIG.get("DB_PORT", "5432")), @@ -290,12 +296,12 @@ def cancelJobsByConnection(connectionId: str, *, jobType: str = "connection.boot def recoverInterruptedJobs() -> int: - """Flip any RUNNING jobs to ERROR and re-queue bootstrap jobs (called at worker boot). + """Flip any RUNNING jobs to ERROR (called at worker boot). A RUNNING job in the DB after process restart means the previous worker died mid-execution; the asyncio task is gone and the job will never - finish on its own. For connection.bootstrap jobs, a fresh job is - automatically re-queued so the user doesn't have to manually retry. + finish on its own. The daily scheduler or manual "Neu indexieren" + button handles retry — no automatic re-queue to avoid infinite loops. """ db = _getDb() try: @@ -304,34 +310,70 @@ def recoverInterruptedJobs() -> int: logger.warning("recoverInterruptedJobs: failed to scan RUNNING jobs: %s", ex) return 0 count = 0 - requeued = 0 for row in rows: try: _markError(row["id"], "Interrupted by worker restart") count += 1 except Exception as ex: logger.warning("recoverInterruptedJobs: could not mark %s as ERROR: %s", row.get("id"), ex) - continue - - if row.get("jobType") == "connection.bootstrap": - payload = row.get("payload") or {} - if payload.get("connectionId"): - try: - newJob = BackgroundJob( - jobType="connection.bootstrap", - payload=payload, - triggeredBy="recovery.requeue", - ) - record = db.recordCreate(BackgroundJob, _serialiseDatetimes(newJob.model_dump())) - asyncio.create_task(_runJob(record["id"])) - requeued += 1 - logger.info( - "recoverInterruptedJobs: re-queued bootstrap for connectionId=%s (new jobId=%s)", - payload["connectionId"], record["id"], - ) - except Exception as reqEx: - logger.warning("recoverInterruptedJobs: re-queue failed for %s: %s", row.get("id"), reqEx) - if count: - logger.warning("Recovered %d interrupted background job(s) after restart (re-queued %d)", count, requeued) + logger.warning("Recovered %d interrupted background job(s) after restart", count) return count + + +_ZOMBIE_MAX_AGE_SECONDS = 30 * 60 + + +def killZombieJobs(maxAgeSeconds: int = _ZOMBIE_MAX_AGE_SECONDS) -> int: + """Kill RUNNING jobs that have not been updated within `maxAgeSeconds`. + + Detects walkers that are stuck in a sync call without progress updates. + A live job updates progress at least every few seconds via JobProgressCallback. + Anything older than maxAgeSeconds without finishing is considered hung. + """ + db = _getDb() + try: + rows = db.getRecordset(BackgroundJob, recordFilter={"status": BackgroundJobStatusEnum.RUNNING.value}) + except Exception as ex: + logger.warning("killZombieJobs: failed to scan RUNNING jobs: %s", ex) + return 0 + now = time.time() + threshold = now - maxAgeSeconds + count = 0 + for row in rows: + started = row.get("startedAt") or row.get("createdAt") + if not started or started > threshold: + continue + ageMin = (now - started) / 60 + try: + _markError(row["id"], f"Zombie killed (stuck >{maxAgeSeconds // 60}min, no progress)") + count += 1 + payload = row.get("payload") or {} + logger.warning( + "killZombieJobs: killed %s (type=%s connId=%s ageMin=%.1f)", + row["id"], row.get("jobType"), payload.get("connectionId", "")[:12], ageMin, + ) + except Exception as ex: + logger.warning("killZombieJobs: could not kill %s: %s", row.get("id"), ex) + return count + + +def registerZombieKillerScheduler(*, intervalMinutes: int = 5) -> None: + """Register a recurring cron job that kills stuck RUNNING jobs. + + Idempotent. Runs every `intervalMinutes` minutes. + """ + try: + from modules.shared.eventManagement import eventManager + + async def _runKiller(): + killZombieJobs() + + eventManager.registerCron( + jobId="background_jobs.zombie_killer", + func=_runKiller, + cronKwargs={"minute": f"*/{intervalMinutes}"}, + ) + logger.info("Zombie-killer scheduler registered (every %d min)", intervalMinutes) + except Exception as ex: + logger.warning("Zombie-killer scheduler registration failed (non-critical): %s", ex) diff --git a/modules/serviceCenter/services/serviceChat/mainServiceChat.py b/modules/serviceCenter/services/serviceChat/mainServiceChat.py index 7852360c..2ca61d7e 100644 --- a/modules/serviceCenter/services/serviceChat/mainServiceChat.py +++ b/modules/serviceCenter/services/serviceChat/mainServiceChat.py @@ -532,8 +532,16 @@ class ChatService: self, connectionId: str, sourceType: str, path: str, label: str, featureInstanceId: str = None, displayPath: str = None, ) -> Dict[str, Any]: - """Create a new external data source reference.""" + """Create a new external data source reference. + + Returns existing record if connectionId + path already exists (upsert semantics). + """ from modules.datamodels.datamodelDataSource import DataSource + existing = self.interfaceDbApp.db.getRecordset( + DataSource, recordFilter={"connectionId": connectionId, "path": path} + ) + if existing: + return existing[0] if isinstance(existing[0], dict) else existing[0].model_dump() ds = DataSource( connectionId=connectionId, sourceType=sourceType, diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py index 0e2d251f..c86aed86 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorIngestConsumer.py @@ -132,10 +132,10 @@ _SOURCE_TYPE_MAP = { "gmail": ("gmailFolder",), }, "clickup": { - "clickup": ("clickupList",), + "clickup": ("clickupList", "clickup"), }, "infomaniak": { - "kdrive": ("kdriveFolder",), + "kdrive": ("kdriveFolder", "infomaniak"), }, } @@ -225,7 +225,7 @@ async def _bootstrapJobHandler( bootstrapOutlook, ) - progressCb(10, "sharepoint + outlook") + progressCb(0, "Synchronisierung läuft...") spDs = _filterDs("sharepoint") olDs = _filterDs("outlook") async def _noopResult(): @@ -251,7 +251,7 @@ async def _bootstrapJobHandler( bootstrapGmail, ) - progressCb(10, "drive + gmail") + progressCb(0, "Synchronisierung läuft...") gdDs = _filterDs("drive") gmDs = _filterDs("gmail") async def _noopResult(): @@ -274,7 +274,7 @@ async def _bootstrapJobHandler( bootstrapClickup, ) - progressCb(10, "clickup tasks") + progressCb(0, "Synchronisierung läuft...") cuDs = _filterDs("clickup") cuResult = await bootstrapClickup(connectionId=connectionId, progressCb=progressCb, dataSources=cuDs) if cuDs else {"skipped": True, "reason": "no_datasources"} return { @@ -283,6 +283,20 @@ async def _bootstrapJobHandler( "clickup": _normalize(cuResult, "clickup"), } + if authority == "infomaniak": + from modules.serviceCenter.services.serviceKnowledge.subConnectorSyncKdrive import ( + bootstrapKdrive, + ) + + progressCb(0, "Synchronisierung läuft...") + kdDs = _filterDs("kdrive") + kdResult = await bootstrapKdrive(connectionId=connectionId, progressCb=progressCb, dataSources=kdDs) if kdDs else {"skipped": True, "reason": "no_datasources"} + return { + "connectionId": connectionId, + "authority": authority, + "kdrive": _normalize(kdResult, "kdrive"), + } + logger.info( "ingestion.connection.bootstrap.skipped reason=unsupported_authority authority=%s connectionId=%s", authority, connectionId, diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py index 7acbaa19..8bfa2628 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncClickup.py @@ -25,6 +25,12 @@ from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone from typing import Any, Dict, List, Optional +from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import ( + WalkerTimeout, + ingestWithTimeout, + logItemStart, +) + logger = logging.getLogger(__name__) MAX_TASKS_DEFAULT = 500 @@ -449,36 +455,44 @@ async def _ingestTask( name = task.get("name") or f"Task {taskId}" syntheticId = _syntheticTaskId(connectionId, taskId) fileName = f"{name[:80].strip() or taskId}.task.json" + logItemStart("clickup", f"{teamId}/{taskId}") contentObjects = _buildContentObjects(task, limits) try: - handle = await knowledgeService.requestIngestion( - IngestionJob( - sourceKind="clickup_task", - sourceId=syntheticId, - fileName=fileName, - mimeType="application/vnd.clickup.task+json", - userId=userId, - mandateId=mandateId, - contentObjects=contentObjects, - contentVersion=revision or None, - neutralize=limits.neutralize, - provenance={ - "connectionId": connectionId, - "dataSourceId": dataSourceId, - "authority": "clickup", - "service": "clickup", - "externalItemId": taskId, - "teamId": teamId, - "listId": ((task.get("list") or {}).get("id")), - "spaceId": ((task.get("space") or {}).get("id")), - "url": task.get("url"), - "status": ((task.get("status") or {}).get("status")), - "tier": limits.clickupScope, - }, - ) + handle = await ingestWithTimeout( + knowledgeService.requestIngestion( + IngestionJob( + sourceKind="clickup_task", + sourceId=syntheticId, + fileName=fileName, + mimeType="application/vnd.clickup.task+json", + userId=userId, + mandateId=mandateId, + contentObjects=contentObjects, + contentVersion=revision or None, + neutralize=limits.neutralize, + provenance={ + "connectionId": connectionId, + "dataSourceId": dataSourceId, + "authority": "clickup", + "service": "clickup", + "externalItemId": taskId, + "teamId": teamId, + "listId": ((task.get("list") or {}).get("id")), + "spaceId": ((task.get("space") or {}).get("id")), + "url": task.get("url"), + "status": ((task.get("status") or {}).get("status")), + "tier": limits.clickupScope, + }, + ) + ), + label=taskId, ) + except WalkerTimeout as exc: + result.failed += 1 + result.errors.append(str(exc)) + return except Exception as exc: logger.error("clickup ingestion %s failed: %s", taskId, exc, exc_info=True) result.failed += 1 @@ -493,18 +507,16 @@ async def _ingestTask( result.failed += 1 processed = result.indexed + result.skippedDuplicate - if progressCb is not None and processed % 50 == 0: + if progressCb is not None and processed % 5 == 0: if hasattr(progressCb, "isCancelled") and progressCb.isCancelled(): return try: - progressCb( - min(90, 10 + int(80 * processed / max(1, limits.maxTasks))), - f"clickup processed={processed}", - ) + progressCb(0, f"{processed} Tasks verarbeitet, {result.indexed} indexiert") except Exception: pass - logger.info( - "ingestion.connection.bootstrap.progress part=clickup processed=%d skippedDup=%d failed=%d", + if processed % 50 == 0: + logger.info( + "ingestion.connection.bootstrap.progress part=clickup processed=%d skippedDup=%d failed=%d", processed, result.skippedDuplicate, result.failed, extra={ "event": "ingestion.connection.bootstrap.progress", diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py index 398b9af9..5dd1bd8b 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGdrive.py @@ -21,6 +21,13 @@ from datetime import datetime, timedelta, timezone from typing import Any, Callable, Dict, List, Optional from modules.datamodels.datamodelExtraction import ExtractionOptions +from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import ( + WalkerTimeout, + downloadWithTimeout, + extractWithTimeout, + ingestWithTimeout, + logItemStart, +) logger = logging.getLogger(__name__) @@ -342,9 +349,15 @@ async def _ingestOne( syntheticFileId = _syntheticFileId(connectionId, externalItemId) fileName = getattr(entry, "name", "") or externalItemId + declaredSize = int(getattr(entry, "size", 0) or 0) or None + logItemStart("gdrive", entryPath, sizeBytes=declaredSize, mime=mimeType) try: - downloaded = await adapter.download(entryPath) + downloaded = await downloadWithTimeout(adapter.download(entryPath), label=entryPath) + except WalkerTimeout as exc: + result.failed += 1 + result.errors.append(str(exc)) + return except Exception as exc: logger.warning("gdrive download %s failed: %s", entryPath, exc) result.failed += 1 @@ -368,10 +381,16 @@ async def _ingestOne( result.bytesProcessed += len(fileBytes) try: - extracted = runExtractionFn( + extracted = await extractWithTimeout( + runExtractionFn, fileBytes, fileName, mimeType, ExtractionOptions(mergeStrategy=None), + label=entryPath, ) + except WalkerTimeout as exc: + result.failed += 1 + result.errors.append(str(exc)) + return except Exception as exc: logger.warning("gdrive extraction %s failed: %s", entryPath, exc) result.failed += 1 @@ -393,20 +412,27 @@ async def _ingestOne( "tier": "body", } try: - handle = await knowledgeService.requestIngestion( - IngestionJob( - sourceKind="gdrive_item", - sourceId=syntheticFileId, - fileName=fileName, - mimeType=mimeType, - userId=userId, - mandateId=mandateId, - contentObjects=contentObjects, - contentVersion=revision, - neutralize=limits.neutralize, - provenance=provenance, - ) + handle = await ingestWithTimeout( + knowledgeService.requestIngestion( + IngestionJob( + sourceKind="gdrive_item", + sourceId=syntheticFileId, + fileName=fileName, + mimeType=mimeType, + userId=userId, + mandateId=mandateId, + contentObjects=contentObjects, + contentVersion=revision, + neutralize=limits.neutralize, + provenance=provenance, + ) + ), + label=entryPath, ) + except WalkerTimeout as exc: + result.failed += 1 + result.errors.append(str(exc)) + return except Exception as exc: logger.error("gdrive ingestion %s failed: %s", entryPath, exc, exc_info=True) result.failed += 1 @@ -422,13 +448,10 @@ async def _ingestOne( if handle.error: result.errors.append(f"ingest({entryPath}): {handle.error}") - if progressCb is not None and (result.indexed + result.skippedDuplicate) % 50 == 0: - processed = result.indexed + result.skippedDuplicate + processed = result.indexed + result.skippedDuplicate + if progressCb is not None and processed % 5 == 0: try: - progressCb( - min(90, 10 + int(80 * processed / max(1, limits.maxItems))), - f"gdrive processed={processed}", - ) + progressCb(0, f"{processed} Dateien verarbeitet, {result.indexed} indexiert") except Exception: pass logger.info( diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGmail.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGmail.py index f5c345c6..3130e942 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGmail.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncGmail.py @@ -24,6 +24,11 @@ from datetime import datetime, timedelta, timezone from typing import Any, Callable, Dict, List, Optional from modules.serviceCenter.services.serviceKnowledge.subTextClean import cleanEmailBody +from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import ( + WalkerTimeout, + ingestWithTimeout, + logItemStart, +) logger = logging.getLogger(__name__) @@ -399,34 +404,42 @@ async def _ingestMessage( subject = headers.get("subject") or "(no subject)" syntheticId = _syntheticMessageId(connectionId, messageId) fileName = f"{subject[:80].strip()}.eml" if subject else f"{messageId}.eml" + logItemStart("gmail", f"{labelId}/{messageId}", mime="message/rfc822") contentObjects = _buildContentObjects( message, limits.maxBodyChars, mailContentDepth=limits.mailContentDepth ) try: - handle = await knowledgeService.requestIngestion( - IngestionJob( - sourceKind="gmail_message", - sourceId=syntheticId, - fileName=fileName, - mimeType="message/rfc822", - userId=userId, - mandateId=mandateId, - contentObjects=contentObjects, - contentVersion=str(revision) if revision else None, - neutralize=limits.neutralize, - provenance={ - "connectionId": connectionId, - "dataSourceId": dataSourceId, - "authority": "google", - "service": "gmail", - "externalItemId": messageId, - "label": labelId, - "threadId": message.get("threadId"), - "tier": limits.mailContentDepth, - }, - ) + handle = await ingestWithTimeout( + knowledgeService.requestIngestion( + IngestionJob( + sourceKind="gmail_message", + sourceId=syntheticId, + fileName=fileName, + mimeType="message/rfc822", + userId=userId, + mandateId=mandateId, + contentObjects=contentObjects, + contentVersion=str(revision) if revision else None, + neutralize=limits.neutralize, + provenance={ + "connectionId": connectionId, + "dataSourceId": dataSourceId, + "authority": "google", + "service": "gmail", + "externalItemId": messageId, + "label": labelId, + "threadId": message.get("threadId"), + "tier": limits.mailContentDepth, + }, + ) + ), + label=messageId, ) + except WalkerTimeout as exc: + result.failed += 1 + result.errors.append(str(exc)) + return except Exception as exc: logger.error("gmail ingestion %s failed: %s", messageId, exc, exc_info=True) result.failed += 1 @@ -458,18 +471,16 @@ async def _ingestMessage( logger.warning("gmail attachments %s failed: %s", messageId, exc) result.errors.append(f"attachments({messageId}): {exc}") - if progressCb is not None and (result.indexed + result.skippedDuplicate) % 50 == 0: - processed = result.indexed + result.skippedDuplicate + processed = result.indexed + result.skippedDuplicate + if progressCb is not None and processed % 5 == 0: try: - progressCb( - min(90, 10 + int(80 * processed / max(1, limits.maxMessages))), - f"gmail processed={processed}", - ) + progressCb(0, f"{processed} Mails verarbeitet, {result.indexed} indexiert") except Exception: pass - logger.info( - "ingestion.connection.bootstrap.progress part=gmail processed=%d skippedDup=%d failed=%d", - processed, result.skippedDuplicate, result.failed, + if processed % 50 == 0: + logger.info( + "ingestion.connection.bootstrap.progress part=gmail processed=%d skippedDup=%d failed=%d", + processed, result.skippedDuplicate, result.failed, extra={ "event": "ingestion.connection.bootstrap.progress", "part": "gmail", @@ -546,13 +557,26 @@ async def _ingestAttachments( fileName = stub["filename"] mimeType = stub["mimeType"] syntheticId = _syntheticAttachmentId(connectionId, messageId, stub["attachmentId"]) + attLabel = f"{messageId}/att:{stub['attachmentId']}/{fileName}" + logItemStart("gmail-attachment", attLabel, sizeBytes=stub.get("size") or None, mime=mimeType) - try: - extracted = runExtraction( + from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import ( + extractWithTimeout as _extractWithTimeout, + ) + + def _runAttExtraction(): + return runExtraction( extractorRegistry, chunkerRegistry, rawBytes, fileName, mimeType, ExtractionOptions(mergeStrategy=None), ) + + try: + extracted = await _extractWithTimeout(_runAttExtraction, label=attLabel) + except WalkerTimeout as exc: + result.failed += 1 + result.errors.append(str(exc)) + continue except Exception as exc: logger.warning("gmail attachment extract %s failed: %s", stub["attachmentId"], exc) result.failed += 1 @@ -584,27 +608,33 @@ async def _ingestAttachments( continue try: - await knowledgeService.requestIngestion( - IngestionJob( - sourceKind="gmail_attachment", - sourceId=syntheticId, - fileName=fileName, - mimeType=mimeType, - userId=userId, - mandateId=mandateId, - contentObjects=contentObjects, - provenance={ - "connectionId": connectionId, - "dataSourceId": dataSourceId, - "authority": "google", - "service": "gmail", - "parentId": parentSyntheticId, - "externalItemId": stub["attachmentId"], - "parentMessageId": messageId, - }, - ) + await ingestWithTimeout( + knowledgeService.requestIngestion( + IngestionJob( + sourceKind="gmail_attachment", + sourceId=syntheticId, + fileName=fileName, + mimeType=mimeType, + userId=userId, + mandateId=mandateId, + contentObjects=contentObjects, + provenance={ + "connectionId": connectionId, + "dataSourceId": dataSourceId, + "authority": "google", + "service": "gmail", + "parentId": parentSyntheticId, + "externalItemId": stub["attachmentId"], + "parentMessageId": messageId, + }, + ) + ), + label=attLabel, ) result.attachmentsIndexed += 1 + except WalkerTimeout as exc: + result.failed += 1 + result.errors.append(str(exc)) except Exception as exc: logger.warning("gmail attachment ingest %s failed: %s", stub["attachmentId"], exc) result.failed += 1 diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py new file mode 100644 index 00000000..e656abe8 --- /dev/null +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncKdrive.py @@ -0,0 +1,439 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +"""kDrive bootstrap for the unified knowledge ingestion lane. + +Walks every ragIndexEnabled kDrive DataSource, downloads file items and +hands them to KnowledgeService.requestIngestion. Idempotency is provided +by the ingestion facade (content-hash dedup). +""" + +from __future__ import annotations + +import asyncio +import hashlib +import logging +import time +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional + +from modules.datamodels.datamodelExtraction import ExtractionOptions +from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import ( + WalkerTimeout, + downloadWithTimeout, + extractWithTimeout, + ingestWithTimeout, + logItemStart, +) + +logger = logging.getLogger(__name__) + +MAX_ITEMS_DEFAULT = 500 +MAX_BYTES_DEFAULT = 200 * 1024 * 1024 +MAX_FILE_SIZE_DEFAULT = 25 * 1024 * 1024 +SKIP_MIME_PREFIXES_DEFAULT = ("video/", "audio/") +MAX_DEPTH_DEFAULT = 4 + + +@dataclass +class KdriveBootstrapLimits: + maxItems: int = MAX_ITEMS_DEFAULT + maxBytes: int = MAX_BYTES_DEFAULT + maxFileSize: int = MAX_FILE_SIZE_DEFAULT + skipMimePrefixes: tuple = SKIP_MIME_PREFIXES_DEFAULT + maxDepth: int = MAX_DEPTH_DEFAULT + neutralize: bool = False + + +@dataclass +class KdriveBootstrapResult: + connectionId: str + indexed: int = 0 + skippedDuplicate: int = 0 + skippedPolicy: int = 0 + failed: int = 0 + bytesProcessed: int = 0 + errors: List[str] = field(default_factory=list) + + +def _syntheticFileId(connectionId: str, externalItemId: str) -> str: + token = hashlib.sha256(f"{connectionId}:{externalItemId}".encode("utf-8")).hexdigest()[:16] + return f"kd:{connectionId[:8]}:{token}" + + +def _toContentObjects(extracted, fileName: str) -> List[Dict[str, Any]]: + parts = getattr(extracted, "parts", None) or [] + out: List[Dict[str, Any]] = [] + for part in parts: + data = getattr(part, "data", None) or "" + if not data or not str(data).strip(): + continue + typeGroup = getattr(part, "typeGroup", "text") or "text" + contentType = "text" + if typeGroup == "image": + contentType = "image" + elif typeGroup in ("binary", "container"): + contentType = "other" + out.append({ + "contentObjectId": getattr(part, "id", ""), + "contentType": contentType, + "data": data, + "contextRef": { + "containerPath": fileName, + "location": getattr(part, "label", None) or "file", + **(getattr(part, "metadata", None) or {}), + }, + }) + return out + + +async def bootstrapKdrive( + connectionId: str, + *, + dataSources: Optional[List[Dict[str, Any]]] = None, + progressCb: Optional[Any] = None, + adapter: Any = None, + connection: Any = None, + knowledgeService: Any = None, + limits: Optional[KdriveBootstrapLimits] = None, + runExtractionFn: Optional[Callable[..., Any]] = None, +) -> Dict[str, Any]: + """Enumerate kDrive folders and ingest files via the facade.""" + if not dataSources: + return {"connectionId": connectionId, "skipped": True, "reason": "no_datasources"} + + if not limits: + limits = KdriveBootstrapLimits() + + startMs = time.time() + result = KdriveBootstrapResult(connectionId=connectionId) + + logger.info( + "ingestion.connection.bootstrap.started part=kdrive connectionId=%s dataSources=%d", + connectionId, len(dataSources), + extra={"event": "ingestion.connection.bootstrap.started", "part": "kdrive", + "connectionId": connectionId, "dataSourceCount": len(dataSources)}, + ) + + if adapter is None or knowledgeService is None or connection is None: + adapter, connection, knowledgeService = await _resolveDependencies(connectionId) + if runExtractionFn is None: + from modules.serviceCenter.services.serviceExtraction.subPipeline import runExtraction + from modules.serviceCenter.services.serviceExtraction.subRegistry import ( + ExtractorRegistry, ChunkerRegistry, + ) + extractorRegistry = ExtractorRegistry() + chunkerRegistry = ChunkerRegistry() + + def runExtractionFn(bytesData, name, mime, options): + return runExtraction(extractorRegistry, chunkerRegistry, bytesData, name, mime, options) + + mandateId = str(getattr(connection, "mandateId", "") or "") if connection is not None else "" + userId = str(getattr(connection, "userId", "") or "") if connection is not None else "" + + cancelled = False + for ds in dataSources: + if result.indexed + result.skippedDuplicate >= limits.maxItems: + break + if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled(): + cancelled = True + break + + dsPath = ds.get("path", "") + dsId = ds.get("id", "") + dsNeutralize = ds.get("neutralize", False) + dsLimits = KdriveBootstrapLimits( + maxItems=limits.maxItems, + maxBytes=limits.maxBytes, + maxFileSize=limits.maxFileSize, + skipMimePrefixes=limits.skipMimePrefixes, + maxDepth=limits.maxDepth, + neutralize=dsNeutralize, + ) + + try: + await _walkFolder( + adapter=adapter, + knowledgeService=knowledgeService, + runExtractionFn=runExtractionFn, + connectionId=connectionId, + mandateId=mandateId, + userId=userId, + folderPath=dsPath, + depth=0, + limits=dsLimits, + result=result, + progressCb=progressCb, + dataSourceId=dsId, + ) + except Exception as exc: + logger.error("kdrive walk failed for ds %s path %s: %s", dsId, dsPath, exc, exc_info=True) + result.errors.append(f"walk({dsPath}): {exc}") + + finalResult = _finalizeResult(connectionId, result, startMs) + if cancelled: + finalResult["cancelled"] = True + return finalResult + + +async def _resolveDependencies(connectionId: str): + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.auth import TokenManager + from modules.connectors.providerInfomaniak.connectorInfomaniak import InfomaniakConnector + from modules.serviceCenter import getService + from modules.serviceCenter.context import ServiceCenterContext + from modules.security.rootAccess import getRootUser + + rootInterface = getRootInterface() + connection = rootInterface.getUserConnectionById(connectionId) + if connection is None: + raise ValueError(f"UserConnection not found: {connectionId}") + + token = TokenManager().getFreshToken(connectionId) + if not token or not token.tokenAccess: + raise ValueError(f"No valid token for connection {connectionId}") + + provider = InfomaniakConnector(connection, token.tokenAccess) + adapter = provider.getServiceAdapter("kdrive") + + rootUser = getRootUser() + ctx = ServiceCenterContext( + user=rootUser, + mandate_id=str(getattr(connection, "mandateId", "") or ""), + ) + knowledgeService = getService("knowledge", ctx) + return adapter, connection, knowledgeService + + +async def _walkFolder( + *, + adapter, + knowledgeService, + runExtractionFn, + connectionId: str, + mandateId: str, + userId: str, + folderPath: str, + depth: int, + limits: KdriveBootstrapLimits, + result: KdriveBootstrapResult, + progressCb: Optional[Any], + dataSourceId: str = "", +) -> None: + if depth > limits.maxDepth: + return + if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled(): + return + try: + entries = await adapter.browse(folderPath) + except Exception as exc: + logger.warning("kdrive browse %s failed: %s", folderPath, exc) + result.errors.append(f"browse({folderPath}): {exc}") + return + + for entry in entries: + if result.indexed + result.skippedDuplicate >= limits.maxItems: + return + if result.bytesProcessed >= limits.maxBytes: + return + if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled(): + return + + entryPath = getattr(entry, "path", "") or "" + if getattr(entry, "isFolder", False): + await _walkFolder( + adapter=adapter, + knowledgeService=knowledgeService, + runExtractionFn=runExtractionFn, + connectionId=connectionId, + mandateId=mandateId, + userId=userId, + folderPath=entryPath, + depth=depth + 1, + limits=limits, + result=result, + progressCb=progressCb, + dataSourceId=dataSourceId, + ) + continue + + mimeType = getattr(entry, "mimeType", None) or "application/octet-stream" + if any(mimeType.startswith(prefix) for prefix in limits.skipMimePrefixes): + result.skippedPolicy += 1 + continue + size = int(getattr(entry, "size", 0) or 0) + if size and size > limits.maxFileSize: + result.skippedPolicy += 1 + continue + + metadata = getattr(entry, "metadata", {}) or {} + externalItemId = metadata.get("id") or entryPath + revision = metadata.get("revision") or metadata.get("lastModified") + + await _ingestOne( + adapter=adapter, + knowledgeService=knowledgeService, + runExtractionFn=runExtractionFn, + connectionId=connectionId, + mandateId=mandateId, + userId=userId, + entry=entry, + entryPath=entryPath, + mimeType=mimeType, + externalItemId=externalItemId, + revision=revision, + limits=limits, + result=result, + progressCb=progressCb, + dataSourceId=dataSourceId, + ) + + +async def _ingestOne( + *, + adapter, + knowledgeService, + runExtractionFn, + connectionId: str, + mandateId: str, + userId: str, + entry, + entryPath: str, + mimeType: str, + externalItemId: str, + revision: Optional[str], + limits: KdriveBootstrapLimits, + result: KdriveBootstrapResult, + progressCb: Optional[Any], + dataSourceId: str = "", +) -> None: + from modules.serviceCenter.services.serviceKnowledge.mainServiceKnowledge import IngestionJob + + syntheticFileId = _syntheticFileId(connectionId, externalItemId) + fileName = getattr(entry, "name", "") or externalItemId + declaredSize = int(getattr(entry, "size", 0) or 0) or None + logItemStart("kdrive", entryPath, sizeBytes=declaredSize, mime=mimeType) + + try: + downloadResult = await downloadWithTimeout(adapter.download(entryPath), label=entryPath) + fileBytes = getattr(downloadResult, "data", None) + dlFileName = getattr(downloadResult, "fileName", None) + dlMimeType = getattr(downloadResult, "mimeType", None) + if dlFileName: + fileName = dlFileName + if dlMimeType: + mimeType = dlMimeType + except WalkerTimeout as exc: + result.failed += 1 + result.errors.append(str(exc)) + return + except Exception as exc: + logger.warning("kdrive download %s failed: %s", entryPath, exc) + result.failed += 1 + result.errors.append(f"download({entryPath}): {exc}") + return + if not fileBytes: + result.failed += 1 + return + + result.bytesProcessed += len(fileBytes) + + try: + extracted = await extractWithTimeout( + runExtractionFn, + fileBytes, fileName, mimeType, + ExtractionOptions(mergeStrategy=None), + label=entryPath, + ) + except WalkerTimeout as exc: + result.failed += 1 + result.errors.append(str(exc)) + return + except Exception as exc: + logger.warning("kdrive extraction %s failed: %s", entryPath, exc) + result.failed += 1 + result.errors.append(f"extract({entryPath}): {exc}") + return + + contentObjects = _toContentObjects(extracted, fileName) + if not contentObjects: + result.skippedPolicy += 1 + return + + provenance: Dict[str, Any] = { + "connectionId": connectionId, + "dataSourceId": dataSourceId, + "authority": "infomaniak", + "service": "kdrive", + "externalItemId": externalItemId, + "externalPath": entryPath, + "revision": revision, + } + try: + handle = await ingestWithTimeout( + knowledgeService.requestIngestion( + IngestionJob( + sourceKind="kdrive_item", + sourceId=syntheticFileId, + fileName=fileName, + mimeType=mimeType, + userId=userId, + mandateId=mandateId, + contentObjects=contentObjects, + contentVersion=revision, + neutralize=limits.neutralize, + provenance=provenance, + ) + ), + label=entryPath, + ) + except WalkerTimeout as exc: + result.failed += 1 + result.errors.append(str(exc)) + return + except Exception as exc: + logger.error("kdrive ingestion %s failed: %s", entryPath, exc, exc_info=True) + result.failed += 1 + result.errors.append(f"ingest({entryPath}): {exc}") + return + + if handle.status == "duplicate": + result.skippedDuplicate += 1 + elif handle.status == "indexed": + result.indexed += 1 + else: + result.failed += 1 + if handle.error: + result.errors.append(f"ingest({entryPath}): {handle.error}") + + processed = result.indexed + result.skippedDuplicate + if progressCb is not None and processed % 5 == 0: + try: + progressCb(0, f"{processed} Dateien verarbeitet, {result.indexed} indexiert") + except Exception: + pass + + await asyncio.sleep(0) + + +def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: float) -> Dict[str, Any]: + durationMs = int((time.time() - startMs) * 1000) + logger.info( + "ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d", + connectionId, + result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed, + durationMs, + extra={"event": "ingestion.connection.bootstrap.done", "part": "kdrive", + "connectionId": connectionId, "indexed": result.indexed, + "skippedDup": result.skippedDuplicate, "skippedPolicy": result.skippedPolicy, + "failed": result.failed, "durationMs": durationMs}, + ) + return { + "connectionId": result.connectionId, + "indexed": result.indexed, + "skippedDuplicate": result.skippedDuplicate, + "skippedPolicy": result.skippedPolicy, + "failed": result.failed, + "bytesProcessed": result.bytesProcessed, + "durationMs": durationMs, + "errors": result.errors[:20], + } diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncOutlook.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncOutlook.py index 3f4a8afb..17220d97 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncOutlook.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncOutlook.py @@ -21,6 +21,12 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, Optional from modules.serviceCenter.services.serviceKnowledge.subTextClean import cleanEmailBody +from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import ( + WalkerTimeout, + extractWithTimeout, + ingestWithTimeout, + logItemStart, +) logger = logging.getLogger(__name__) @@ -384,34 +390,42 @@ async def _ingestMessage( subject = message.get("subject") or "(no subject)" syntheticId = _syntheticMessageId(connectionId, messageId) fileName = f"{subject[:80].strip()}.eml" if subject else f"{messageId}.eml" + logItemStart("outlook", messageId, mime="message/rfc822") contentObjects = _buildContentObjects( message, limits.maxBodyChars, mailContentDepth=limits.mailContentDepth ) # Always at least the header is emitted, so `contentObjects` is non-empty. try: - handle = await knowledgeService.requestIngestion( - IngestionJob( - sourceKind="outlook_message", - sourceId=syntheticId, - fileName=fileName, - mimeType="message/rfc822", - userId=userId, - mandateId=mandateId, - contentObjects=contentObjects, - contentVersion=revision, - neutralize=limits.neutralize, - provenance={ - "connectionId": connectionId, - "dataSourceId": dataSourceId, - "authority": "msft", - "service": "outlook", - "externalItemId": messageId, - "internetMessageId": message.get("internetMessageId"), - "tier": limits.mailContentDepth, - }, - ) + handle = await ingestWithTimeout( + knowledgeService.requestIngestion( + IngestionJob( + sourceKind="outlook_message", + sourceId=syntheticId, + fileName=fileName, + mimeType="message/rfc822", + userId=userId, + mandateId=mandateId, + contentObjects=contentObjects, + contentVersion=revision, + neutralize=limits.neutralize, + provenance={ + "connectionId": connectionId, + "dataSourceId": dataSourceId, + "authority": "msft", + "service": "outlook", + "externalItemId": messageId, + "internetMessageId": message.get("internetMessageId"), + "tier": limits.mailContentDepth, + }, + ) + ), + label=messageId, ) + except WalkerTimeout as exc: + result.failed += 1 + result.errors.append(str(exc)) + return except Exception as exc: logger.error("outlook ingestion %s failed: %s", messageId, exc, exc_info=True) result.failed += 1 @@ -443,18 +457,16 @@ async def _ingestMessage( logger.warning("outlook attachments %s failed: %s", messageId, exc) result.errors.append(f"attachments({messageId}): {exc}") - if progressCb is not None and (result.indexed + result.skippedDuplicate) % 50 == 0: - processed = result.indexed + result.skippedDuplicate + processed = result.indexed + result.skippedDuplicate + if progressCb is not None and processed % 5 == 0: try: - progressCb( - min(90, 10 + int(80 * processed / max(1, limits.maxMessages))), - f"outlook processed={processed}", - ) + progressCb(0, f"{processed} Mails verarbeitet, {result.indexed} indexiert") except Exception: pass - logger.info( - "ingestion.connection.bootstrap.progress part=outlook processed=%d skippedDup=%d failed=%d", - processed, result.skippedDuplicate, result.failed, + if processed % 50 == 0: + logger.info( + "ingestion.connection.bootstrap.progress part=outlook processed=%d skippedDup=%d failed=%d", + processed, result.skippedDuplicate, result.failed, extra={ "event": "ingestion.connection.bootstrap.progress", "part": "outlook", @@ -518,13 +530,22 @@ async def _ingestAttachments( mimeType = attachment.get("contentType") or "application/octet-stream" attachmentId = attachment.get("id") or fileName syntheticId = _syntheticAttachmentId(connectionId, messageId, attachmentId) + attLabel = f"{messageId}/att:{attachmentId}/{fileName}" + logItemStart("outlook-attachment", attLabel, sizeBytes=size or None, mime=mimeType) - try: - extracted = runExtraction( + def _runAttExtraction(): + return runExtraction( extractorRegistry, chunkerRegistry, rawBytes, fileName, mimeType, ExtractionOptions(mergeStrategy=None), ) + + try: + extracted = await extractWithTimeout(_runAttExtraction, label=attLabel) + except WalkerTimeout as exc: + result.failed += 1 + result.errors.append(str(exc)) + continue except Exception as exc: logger.warning("outlook attachment extract %s failed: %s", attachmentId, exc) result.failed += 1 @@ -556,28 +577,34 @@ async def _ingestAttachments( continue try: - await knowledgeService.requestIngestion( - IngestionJob( - sourceKind="outlook_attachment", - sourceId=syntheticId, - fileName=fileName, - mimeType=mimeType, - userId=userId, - mandateId=mandateId, - contentObjects=contentObjects, - neutralize=limits.neutralize, - provenance={ - "connectionId": connectionId, - "dataSourceId": dataSourceId, - "authority": "msft", - "service": "outlook", - "parentId": parentSyntheticId, - "externalItemId": attachmentId, - "parentMessageId": messageId, - }, - ) + await ingestWithTimeout( + knowledgeService.requestIngestion( + IngestionJob( + sourceKind="outlook_attachment", + sourceId=syntheticId, + fileName=fileName, + mimeType=mimeType, + userId=userId, + mandateId=mandateId, + contentObjects=contentObjects, + neutralize=limits.neutralize, + provenance={ + "connectionId": connectionId, + "dataSourceId": dataSourceId, + "authority": "msft", + "service": "outlook", + "parentId": parentSyntheticId, + "externalItemId": attachmentId, + "parentMessageId": messageId, + }, + ) + ), + label=attLabel, ) result.attachmentsIndexed += 1 + except WalkerTimeout as exc: + result.failed += 1 + result.errors.append(str(exc)) except Exception as exc: logger.warning("outlook attachment ingest %s failed: %s", attachmentId, exc) result.failed += 1 diff --git a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py index f664f1a8..892e41ba 100644 --- a/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py +++ b/modules/serviceCenter/services/serviceKnowledge/subConnectorSyncSharepoint.py @@ -20,6 +20,13 @@ from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional from modules.datamodels.datamodelExtraction import ExtractionOptions +from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import ( + WalkerTimeout, + downloadWithTimeout, + extractWithTimeout, + ingestWithTimeout, + logItemStart, +) logger = logging.getLogger(__name__) @@ -330,9 +337,15 @@ async def _ingestOne( syntheticFileId = _syntheticFileId(connectionId, externalItemId) fileName = getattr(entry, "name", "") or externalItemId + declaredSize = int(getattr(entry, "size", 0) or 0) or None + logItemStart("sharepoint", entryPath, sizeBytes=declaredSize, mime=mimeType) try: - fileBytes = await adapter.download(entryPath) + fileBytes = await downloadWithTimeout(adapter.download(entryPath), label=entryPath) + except WalkerTimeout as exc: + result.failed += 1 + result.errors.append(str(exc)) + return except Exception as exc: logger.warning("sharepoint download %s failed: %s", entryPath, exc) result.failed += 1 @@ -345,10 +358,16 @@ async def _ingestOne( result.bytesProcessed += len(fileBytes) try: - extracted = runExtractionFn( + extracted = await extractWithTimeout( + runExtractionFn, fileBytes, fileName, mimeType, ExtractionOptions(mergeStrategy=None), + label=entryPath, ) + except WalkerTimeout as exc: + result.failed += 1 + result.errors.append(str(exc)) + return except Exception as exc: logger.warning("sharepoint extraction %s failed: %s", entryPath, exc) result.failed += 1 @@ -370,20 +389,27 @@ async def _ingestOne( "revision": revision, } try: - handle = await knowledgeService.requestIngestion( - IngestionJob( - sourceKind="sharepoint_item", - sourceId=syntheticFileId, - fileName=fileName, - mimeType=mimeType, - userId=userId, - mandateId=mandateId, - contentObjects=contentObjects, - contentVersion=revision, - neutralize=limits.neutralize, - provenance=provenance, - ) + handle = await ingestWithTimeout( + knowledgeService.requestIngestion( + IngestionJob( + sourceKind="sharepoint_item", + sourceId=syntheticFileId, + fileName=fileName, + mimeType=mimeType, + userId=userId, + mandateId=mandateId, + contentObjects=contentObjects, + contentVersion=revision, + neutralize=limits.neutralize, + provenance=provenance, + ) + ), + label=entryPath, ) + except WalkerTimeout as exc: + result.failed += 1 + result.errors.append(str(exc)) + return except Exception as exc: logger.error("sharepoint ingestion %s failed: %s", entryPath, exc, exc_info=True) result.failed += 1 @@ -399,27 +425,17 @@ async def _ingestOne( if handle.error: result.errors.append(f"ingest({entryPath}): {handle.error}") - if progressCb is not None and (result.indexed + result.skippedDuplicate) % 50 == 0: - processed = result.indexed + result.skippedDuplicate + processed = result.indexed + result.skippedDuplicate + if progressCb is not None and processed % 5 == 0: try: - progressCb( - min(90, 10 + int(80 * processed / max(1, limits.maxItems))), - f"sharepoint processed={processed}", - ) + progressCb(0, f"{processed} Dateien verarbeitet, {result.indexed} indexiert") except Exception: pass - logger.info( - "ingestion.connection.bootstrap.progress part=sharepoint processed=%d skippedDup=%d failed=%d", - processed, result.skippedDuplicate, result.failed, - extra={ - "event": "ingestion.connection.bootstrap.progress", - "part": "sharepoint", - "connectionId": connectionId, - "processed": processed, - "skippedDup": result.skippedDuplicate, - "failed": result.failed, - }, - ) + if processed % 50 == 0: + logger.info( + "ingestion.connection.bootstrap.progress part=sharepoint processed=%d indexed=%d failed=%d", + processed, result.indexed, result.failed, + ) # Yield so the event loop can interleave other tasks (download/extract are # CPU-ish and extraction uses sync libs; cooperative scheduling prevents diff --git a/modules/serviceCenter/services/serviceKnowledge/subWalkerHelpers.py b/modules/serviceCenter/services/serviceKnowledge/subWalkerHelpers.py new file mode 100644 index 00000000..8e65fd0f --- /dev/null +++ b/modules/serviceCenter/services/serviceKnowledge/subWalkerHelpers.py @@ -0,0 +1,116 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +"""Shared helpers for ingestion walkers (timeouts, per-item logging). + +Walkers (sharepoint, gdrive, gmail, outlook, clickup, kdrive) all face the +same risks: + +- A single `adapter.download()` call can hang on the network for hours. +- A single `runExtraction()` call can hang on a corrupt PDF/Office doc inside + a sync extractor library, blocking the asyncio loop. +- A single `requestIngestion()` call can stall on the embedding API. + +Without timeouts, one bad item freezes the whole bootstrap job and we end +up with "Job stuck at 10% for 10h" zombies. + +These helpers wrap each phase in `asyncio.wait_for`. Sync extraction runs +on a worker thread so the loop stays responsive. Every wrapped call also +emits a short start/done log line, so when something hangs we know the +exact item that caused it (path, size, mime). +""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any, Awaitable, Callable, Optional + +logger = logging.getLogger(__name__) + +DOWNLOAD_TIMEOUT_S = 60 +EXTRACTION_TIMEOUT_S = 90 +INGEST_TIMEOUT_S = 60 + + +class WalkerTimeout(Exception): + """Raised when a walker phase exceeds its timeout budget.""" + + +async def downloadWithTimeout( + awaitable: Awaitable[Any], + *, + label: str, + timeoutSeconds: int = DOWNLOAD_TIMEOUT_S, +) -> Any: + """Run a download awaitable with a hard timeout. + + `label` is a short human-readable identifier (typically the external path) + used in log messages so we can pinpoint the offending item in case of a + hang or timeout. + """ + logger.info("walker.download.start %s timeout=%ds", label, timeoutSeconds) + try: + result = await asyncio.wait_for(awaitable, timeout=timeoutSeconds) + logger.debug("walker.download.done %s", label) + return result + except asyncio.TimeoutError as ex: + logger.warning("walker.download.timeout %s after %ds", label, timeoutSeconds) + raise WalkerTimeout(f"download timeout after {timeoutSeconds}s: {label}") from ex + + +async def extractWithTimeout( + syncFn: Callable[..., Any], + *args: Any, + label: str, + timeoutSeconds: int = EXTRACTION_TIMEOUT_S, +) -> Any: + """Run a synchronous extraction function on a worker thread with timeout. + + Sync extractors (PDF, OCR, MS Office) cannot be cancelled cleanly from + asyncio; `wait_for` only protects the awaiter. The underlying thread may + keep running until the process exits — but at least the walker proceeds + to the next item instead of freezing forever. + """ + logger.info("walker.extract.start %s timeout=%ds", label, timeoutSeconds) + try: + result = await asyncio.wait_for( + asyncio.to_thread(syncFn, *args), + timeout=timeoutSeconds, + ) + logger.debug("walker.extract.done %s", label) + return result + except asyncio.TimeoutError as ex: + logger.warning("walker.extract.timeout %s after %ds", label, timeoutSeconds) + raise WalkerTimeout(f"extract timeout after {timeoutSeconds}s: {label}") from ex + + +async def ingestWithTimeout( + awaitable: Awaitable[Any], + *, + label: str, + timeoutSeconds: int = INGEST_TIMEOUT_S, +) -> Any: + """Run an ingestion request with a hard timeout.""" + logger.debug("walker.ingest.start %s timeout=%ds", label, timeoutSeconds) + try: + result = await asyncio.wait_for(awaitable, timeout=timeoutSeconds) + logger.debug("walker.ingest.done %s", label) + return result + except asyncio.TimeoutError as ex: + logger.warning("walker.ingest.timeout %s after %ds", label, timeoutSeconds) + raise WalkerTimeout(f"ingest timeout after {timeoutSeconds}s: {label}") from ex + + +def logItemStart(service: str, label: str, *, sizeBytes: Optional[int] = None, mime: Optional[str] = None) -> None: + """Log that processing of one item is about to begin. + + When the worker hangs, the LAST `walker.item.start` line in the log + points to the exact item that caused the freeze. This is the single + most valuable diagnostic for stuck-job triage. + """ + parts = [f"walker.item.start service={service} path={label}"] + if sizeBytes is not None: + parts.append(f"size={sizeBytes}") + if mime: + parts.append(f"mime={mime}") + logger.info(" ".join(parts)) diff --git a/modules/shared/aiAuditLogger.py b/modules/shared/aiAuditLogger.py index 04255ce1..5da105a8 100644 --- a/modules/shared/aiAuditLogger.py +++ b/modules/shared/aiAuditLogger.py @@ -85,6 +85,11 @@ class AiAuditLogger: try: from modules.datamodels.datamodelAiAudit import AiAuditLogEntry + if contentInput: + contentInput = contentInput.replace("\x00", "") + if contentOutput: + contentOutput = contentOutput.replace("\x00", "") + inputPreview = (contentInput or "")[:_PREVIEW_LENGTH] or None outputPreview = (contentOutput or "")[:_PREVIEW_LENGTH] or None inputHash = hashlib.sha256(contentInput.encode("utf-8")).hexdigest() if contentInput else None diff --git a/modules/system/mainSystem.py b/modules/system/mainSystem.py index 21d0cbee..aacc6d3c 100644 --- a/modules/system/mainSystem.py +++ b/modules/system/mainSystem.py @@ -330,6 +330,16 @@ NAVIGATION_SECTIONS = [ "adminOnly": True, "sysAdminOnly": True, }, + { + "id": "admin-stt-benchmark", + "objectKey": "ui.admin.sttBenchmark", + "label": t("STT Benchmark"), + "icon": "FaMicrophone", + "path": "/admin/stt-benchmark", + "order": 92, + "adminOnly": True, + "sysAdminOnly": True, + }, { "id": "admin-languages", "objectKey": "ui.admin.languages", diff --git a/tests/eval/__init__.py b/tests/eval/__init__.py new file mode 100644 index 00000000..fde23b13 --- /dev/null +++ b/tests/eval/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2026 Patrick Motsch +# All rights reserved. +"""Eval harness for the Feature Data Sub-Agent (Phase 1.5).""" diff --git a/tests/eval/fakeFeatureDataProvider.py b/tests/eval/fakeFeatureDataProvider.py new file mode 100644 index 00000000..55557e7d --- /dev/null +++ b/tests/eval/fakeFeatureDataProvider.py @@ -0,0 +1,246 @@ +# Copyright (c) 2026 Patrick Motsch +# All rights reserved. +"""In-memory drop-in for FeatureDataProvider used by the eval harness. + +Implements the same three public methods (browseTable / queryTable / +aggregateTable) plus the small surface the Sub-Agent reads (getActualColumns), +but runs all filters/aggregations in Python over the BenchmarkFixture rows. + +This keeps the eval hermetic: no DB connection, no fixtures to insert/clean, +no flakiness from shared test schemas. Only the LLM call is real. +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + + +_ALLOWED_AGGREGATES = {"SUM", "COUNT", "AVG", "MIN", "MAX"} + + +class FakeFeatureDataProvider: + """In-memory provider compatible with :class:`FeatureDataProvider`.""" + + def __init__( + self, + rowsByTable: Dict[str, List[Dict[str, Any]]], + availableTables: Optional[List[Dict[str, Any]]] = None, + ) -> None: + self._rowsByTable = {name: list(rows) for name, rows in rowsByTable.items()} + self._availableTables = list(availableTables or []) + self.callLog: List[Dict[str, Any]] = [] + + def getAvailableTables(self, featureCode: str) -> List[Dict[str, Any]]: # noqa: ARG002 + return list(self._availableTables) + + def getTableSchema(self, featureCode: str, tableName: str) -> Optional[Dict[str, Any]]: # noqa: ARG002 + for obj in self._availableTables: + if obj.get("meta", {}).get("table") == tableName: + return obj + return None + + def getActualColumns(self, tableName: str) -> List[str]: + rows = self._rowsByTable.get(tableName, []) + if not rows: + return [] + seen: List[str] = [] + seenSet: set = set() + for row in rows: + for key in row.keys(): + if key not in seenSet: + seen.append(key) + seenSet.add(key) + return seen + + def browseTable( + self, + tableName: str, + featureInstanceId: str, + mandateId: str, + fields: List[str] = None, + limit: int = 50, + offset: int = 0, + extraFilters: Optional[List[Dict[str, Any]]] = None, + ) -> Dict[str, Any]: + self.callLog.append({"method": "browseTable", "table": tableName, "fields": fields, "limit": limit}) + rows = self._scopeRows(tableName, featureInstanceId, mandateId) + rows = _applyFilters(rows, extraFilters) + total = len(rows) + rows = rows[offset : offset + limit] + if fields: + rows = [{k: v for k, v in row.items() if k in fields} for row in rows] + return {"rows": rows, "total": total, "limit": limit, "offset": offset} + + def queryTable( + self, + tableName: str, + featureInstanceId: str, + mandateId: str, + filters: List[Dict[str, Any]] = None, + fields: List[str] = None, + orderBy: str = None, + limit: int = 50, + offset: int = 0, + extraFilters: Optional[List[Dict[str, Any]]] = None, + ) -> Dict[str, Any]: + self.callLog.append({ + "method": "queryTable", "table": tableName, "filters": filters, + "fields": fields, "orderBy": orderBy, "limit": limit, + }) + rows = self._scopeRows(tableName, featureInstanceId, mandateId) + combined = list(filters or []) + list(extraFilters or []) + rows = _applyFilters(rows, combined) + if orderBy: + try: + rows = sorted(rows, key=lambda r: (r.get(orderBy) is None, r.get(orderBy))) + except TypeError: + rows = sorted(rows, key=lambda r: str(r.get(orderBy))) + total = len(rows) + rows = rows[offset : offset + limit] + if fields: + rows = [{k: v for k, v in row.items() if k in fields} for row in rows] + return {"rows": rows, "total": total, "limit": limit, "offset": offset} + + def aggregateTable( + self, + tableName: str, + featureInstanceId: str, + mandateId: str, + aggregate: str, + field: str, + groupBy: str = None, + extraFilters: Optional[List[Dict[str, Any]]] = None, + ) -> Dict[str, Any]: + self.callLog.append({ + "method": "aggregateTable", "table": tableName, + "aggregate": aggregate, "field": field, "groupBy": groupBy, + }) + aggregate = aggregate.upper() + if aggregate not in _ALLOWED_AGGREGATES: + return {"rows": [], "error": f"Unsupported aggregate: {aggregate}"} + rows = self._scopeRows(tableName, featureInstanceId, mandateId) + rows = _applyFilters(rows, extraFilters) + + if groupBy: + groups: Dict[Any, List[Dict[str, Any]]] = {} + for row in rows: + groups.setdefault(row.get(groupBy), []).append(row) + outRows = [ + {"groupValue": key, "result": _aggregate(aggregate, [r.get(field) for r in grp])} + for key, grp in groups.items() + ] + outRows.sort(key=lambda r: (r["result"] is None, -(r["result"] or 0))) + else: + outRows = [{"result": _aggregate(aggregate, [r.get(field) for r in rows])}] + + return { + "rows": outRows, + "aggregate": aggregate, + "field": field, + "groupBy": groupBy, + } + + def _scopeRows(self, tableName: str, featureInstanceId: str, mandateId: str) -> List[Dict[str, Any]]: + rows = self._rowsByTable.get(tableName, []) + return [ + row for row in rows + if (row.get("featureInstanceId") in (None, featureInstanceId)) + and (row.get("mandateId") in (None, mandateId)) + ] + + +def _applyFilters(rows: List[Dict[str, Any]], filters: Optional[List[Dict[str, Any]]]) -> List[Dict[str, Any]]: + if not filters: + return rows + out = rows + for f in filters: + field = f.get("field") + op = (f.get("op") or "=").upper() + value = f.get("value") + out = [r for r in out if _matchesFilter(r.get(field), op, value)] + return out + + +def _matchesFilter(rowValue: Any, op: str, filterValue: Any) -> bool: + if op in ("IS NULL",): + return rowValue is None + if op in ("IS NOT NULL",): + return rowValue is not None + if rowValue is None: + return False + if op == "=": + return _coerceEqual(rowValue, filterValue) + if op == "!=": + return not _coerceEqual(rowValue, filterValue) + if op == ">": + return _coerceFloat(rowValue) > _coerceFloat(filterValue) + if op == "<": + return _coerceFloat(rowValue) < _coerceFloat(filterValue) + if op == ">=": + return _coerceFloat(rowValue) >= _coerceFloat(filterValue) + if op == "<=": + return _coerceFloat(rowValue) <= _coerceFloat(filterValue) + if op in ("LIKE", "ILIKE"): + pattern = str(filterValue or "") + target = str(rowValue) + if op == "ILIKE": + pattern = pattern.lower() + target = target.lower() + return _sqlLike(target, pattern) + if op == "IN": + if isinstance(filterValue, (list, tuple, set)): + return any(_coerceEqual(rowValue, v) for v in filterValue) + return _coerceEqual(rowValue, filterValue) + return False + + +def _coerceEqual(a: Any, b: Any) -> bool: + if a == b: + return True + try: + return str(a) == str(b) + except Exception: + return False + + +def _coerceFloat(value: Any) -> float: + if value is None: + return 0.0 + try: + return float(value) + except (TypeError, ValueError): + return 0.0 + + +def _sqlLike(value: str, pattern: str) -> bool: + """Approximate SQL LIKE -- only % and _ wildcards.""" + import re + regex = "" + i = 0 + while i < len(pattern): + ch = pattern[i] + if ch == "%": + regex += ".*" + elif ch == "_": + regex += "." + else: + regex += re.escape(ch) + i += 1 + return re.fullmatch(regex, value or "") is not None + + +def _aggregate(op: str, values: List[Any]) -> Any: + if op == "COUNT": + return sum(1 for v in values if v is not None) + nums = [_coerceFloat(v) for v in values if v is not None] + if not nums: + return 0 if op == "SUM" else None + if op == "SUM": + return round(sum(nums), 4) + if op == "AVG": + return round(sum(nums) / len(nums), 4) + if op == "MIN": + return min(nums) + if op == "MAX": + return max(nums) + return None diff --git a/tests/eval/runTrusteeBenchmark.py b/tests/eval/runTrusteeBenchmark.py new file mode 100644 index 00000000..3f298173 --- /dev/null +++ b/tests/eval/runTrusteeBenchmark.py @@ -0,0 +1,735 @@ +# Copyright (c) 2026 Patrick Motsch +# All rights reserved. +"""Trustee Sub-Agent Eval Harness (Phase 1.5). + +Standalone runner that fires real AI calls against the Feature Data +Sub-Agent in three configurations: + +* ``baseline`` -- production code without the pre-execute validator + (Repair-Loop disabled, Trustee domain hints active). +* ``phase1`` -- pre-execute validator on (Repair-Loop active), + domain hints active, no ontology yet. +* ``phase2`` -- validator on, ontology-driven schema context + + constraints (replaces hand-written domain hints). + +For each mode we run all 19 gold-standard questions against an +in-memory :class:`FakeFeatureDataProvider`, capture the agent's tool +calls and final answer, score them against the gold standard, and +write a Markdown report to ``local/notes/`` for analysis. + +Usage:: + + cd gateway + python -m tests.eval.runTrusteeBenchmark # all 3 modes + python -m tests.eval.runTrusteeBenchmark phase1 # one mode only + python -m tests.eval.runTrusteeBenchmark baseline phase1 +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +import re +import sys +import time +import uuid +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +# --------------------------------------------------------------------------- +# Path setup so `python -m tests.eval.runTrusteeBenchmark` works from gateway/ +# --------------------------------------------------------------------------- +_GATEWAY_DIR = Path(__file__).resolve().parents[2] +if str(_GATEWAY_DIR) not in sys.path: + sys.path.insert(0, str(_GATEWAY_DIR)) + +import yaml # noqa: E402 + +from modules.serviceCenter.services.serviceAgent.datamodelAgent import ( # noqa: E402 + AgentConfig, + AgentEventTypeEnum, +) +from modules.datamodels.datamodelAi import ( # noqa: E402 + AiCallRequest, + AiCallResponse, + OperationTypeEnum, +) +from modules.serviceCenter.services.serviceAgent.agentLoop import runAgentLoop # noqa: E402 +from modules.serviceCenter.services.serviceAgent.featureDataAgent import ( # noqa: E402 + _buildSubAgentTools, + _buildSchemaContext, +) +from modules.serviceCenter.services.serviceAgent.datamodelOntology import ( # noqa: E402 + QueryValidationError, +) +from modules.serviceCenter.services.serviceAgent.queryValidator import ( # noqa: E402 + QueryValidator, +) + +from tests.eval.fakeFeatureDataProvider import ( # noqa: E402 + FakeFeatureDataProvider, +) +from tests.fixtures.trusteeBenchmark.loadTrusteeBenchmarkFixture import ( # noqa: E402 + buildTrusteeBenchmarkFixture, + BenchmarkFixture, +) + + +logger = logging.getLogger("trusteeBenchmark") + + +# --------------------------------------------------------------------------- +# NoOpValidator -- baseline mode (Repair-Loop OFF) +# --------------------------------------------------------------------------- + + +class _NoOpValidator(QueryValidator): + """Validator that never rejects anything (used for baseline measurement).""" + + def validateBrowseQuery(self, tableName, args): # noqa: ARG002 + return None + + def validateQueryTable(self, tableName, args): # noqa: ARG002 + return None + + def validateAggregateQuery(self, tableName, args): # noqa: ARG002 + return None + + +# --------------------------------------------------------------------------- +# Mode-specific tool/prompt building +# --------------------------------------------------------------------------- + + +@dataclass +class _ModeConfig: + name: str + label: str + useValidator: bool + useOntology: bool + + +_MODES: Dict[str, _ModeConfig] = { + "baseline": _ModeConfig(name="baseline", label="Baseline (no validator)", useValidator=False, useOntology=False), + "phase1": _ModeConfig(name="phase1", label="Phase 1 (validator on)", useValidator=True, useOntology=False), + "phase2": _ModeConfig(name="phase2", label="Phase 2 (validator + ontology)", useValidator=True, useOntology=True), +} + + +def _buildValidator(mode: _ModeConfig) -> QueryValidator: + """Construct the per-mode validator. + + * baseline: no-op (Repair-Loop disabled, used to measure raw LLM + accuracy against today's prompt path). + * phase1: convention-based QueryValidator (NEVER_AGGREGATE on + ``*Balance``/``*Total`` suffixes; no ontology). + * phase2: ontology-driven QueryValidator (constraints from the + trustee ontology override the convention defaults). + """ + if not mode.useValidator: + return _NoOpValidator() + if mode.useOntology: + try: + from modules.features.trustee.trusteeOntology import getTrusteeOntology + return QueryValidator(ontology=getTrusteeOntology()) + except Exception as e: + logger.warning("Could not load trustee ontology, falling back: %s", e) + return QueryValidator() + + +def _applyEnvForMode(mode: _ModeConfig) -> None: + """Set the ontology toggle for the production prompt builder. + + The Phase 2 path uses ``featureDataAgent._buildSchemaContext`` to pull + the prompt block from ``getAgentOntology()`` automatically. For + baseline/phase1 we set ``POWERON_DISABLE_FEATURE_ONTOLOGY=1`` so the + builder falls back to the legacy ``getAgentDomainHints()`` block -- + measuring exactly the production prompt that ships today. + """ + if mode.useOntology: + os.environ.pop("POWERON_DISABLE_FEATURE_ONTOLOGY", None) + else: + os.environ["POWERON_DISABLE_FEATURE_ONTOLOGY"] = "1" + + +def _buildSystemPrompt(featureCode: str, instanceLabel: str, selectedTables: List[Dict[str, Any]]) -> str: + """Build the sub-agent system prompt via the production path. + + Mode-specific behaviour (legacy hints vs ontology block) is controlled + by the ``POWERON_DISABLE_FEATURE_ONTOLOGY`` env flag set per mode in + :func:`_applyEnvForMode`. Keeping the builder call identical for all + three modes means the benchmark measures the EXACT prompt the agent + would see in production -- no eval-only forks. + """ + return _buildSchemaContext(featureCode, instanceLabel, selectedTables, requestLang="de") + + +# --------------------------------------------------------------------------- +# Question loading + per-question evaluation +# --------------------------------------------------------------------------- + + +@dataclass +class _Question: + id: str + question: str + intent: str + expectedTools: List[str] + expectedTable: Optional[str] + expectedAggregate: Optional[str] + expectedAggregateField: Optional[str] + requiredFilters: Dict[str, Any] + forbiddenTools: List[str] + expectedNumbers: List[float] + expectedAnswerContains: List[str] + numericTolerance: float + + +def _loadQuestions(yamlPath: Path) -> List[_Question]: + with open(yamlPath, "r", encoding="utf-8") as f: + rawList = yaml.safe_load(f) + questions: List[_Question] = [] + for raw in rawList: + questions.append(_Question( + id=raw["id"], + question=raw["question"], + intent=raw.get("intent", ""), + expectedTools=list(raw.get("expectedTools") or []), + expectedTable=raw.get("expectedTable"), + expectedAggregate=raw.get("expectedAggregate"), + expectedAggregateField=raw.get("expectedAggregateField"), + requiredFilters=dict(raw.get("requiredFilters") or {}), + forbiddenTools=list(raw.get("forbiddenTools") or []), + expectedNumbers=[float(x) for x in (raw.get("expectedNumbers") or [])], + expectedAnswerContains=[str(x) for x in (raw.get("expectedAnswerContains") or [])], + numericTolerance=float(raw.get("numericTolerance") or 0.005), + )) + return questions + + +@dataclass +class _RunResult: + questionId: str + finalText: str + toolCalls: List[Dict[str, Any]] = field(default_factory=list) + toolResults: List[Dict[str, Any]] = field(default_factory=list) + summary: Dict[str, Any] = field(default_factory=dict) + durationS: float = 0.0 + error: Optional[str] = None + + @property + def costCHF(self) -> float: + return float(self.summary.get("costCHF") or 0.0) + + @property + def rounds(self) -> int: + return int(self.summary.get("rounds") or 0) + + @property + def validationFailures(self) -> int: + return int(self.summary.get("validationFailures") or 0) + + @property + def repairAttempts(self) -> int: + return int(self.summary.get("repairAttempts") or 0) + + @property + def successAfterRepair(self) -> int: + return int(self.summary.get("successAfterRepair") or 0) + + +@dataclass +class _Score: + patternOk: bool = False + forbidOk: bool = False + numericOk: bool = False + accuracyOk: bool = False + notes: List[str] = field(default_factory=list) + + +def _scoreRun(question: _Question, run: _RunResult) -> _Score: + score = _Score() + if run.error: + score.notes.append(f"Sub-agent error: {run.error}") + return score + + score.patternOk = _checkPattern(question, run) + score.forbidOk = _checkForbid(question, run) + score.numericOk = _checkNumeric(question, run) + score.accuracyOk = score.patternOk and score.forbidOk and score.numericOk + return score + + +def _checkPattern(question: _Question, run: _RunResult) -> bool: + """Did the agent call one of the expected tools on the expected table with required filters?""" + if not question.expectedTools: + return True + matchingCalls = [ + c for c in run.toolCalls + if c.get("toolName") in question.expectedTools + and (not question.expectedTable or c.get("args", {}).get("tableName") == question.expectedTable) + ] + if not matchingCalls: + return False + + if question.expectedAggregate: + wantAgg = question.expectedAggregate.upper() + wantField = question.expectedAggregateField + for c in matchingCalls: + args = c.get("args", {}) + if c.get("toolName") != "aggregateTable": + continue + if (args.get("aggregate") or "").upper() != wantAgg: + continue + if wantField and args.get("field") != wantField: + continue + if not _filtersSatisfied(question.requiredFilters, args.get("extraFilters") or args.get("filters") or []): + continue + return True + return False + + if question.requiredFilters: + for c in matchingCalls: + args = c.get("args", {}) + filters = args.get("filters") or args.get("extraFilters") or [] + if _filtersSatisfied(question.requiredFilters, filters): + return True + return False + + return True + + +def _filtersSatisfied(required: Dict[str, Any], actualFilters: List[Dict[str, Any]]) -> bool: + if not required: + return True + for reqField, reqValue in required.items(): + if reqField.endswith("Like"): + field = reqField[:-4] + wanted = str(reqValue) + ok = any( + (f.get("field") == field) and (f.get("op", "").upper() in ("LIKE", "ILIKE")) + and str(f.get("value")) == wanted + for f in actualFilters + ) + if not ok: + return False + else: + ok = any( + f.get("field") == reqField and _filterValueEqual(f.get("value"), reqValue) + for f in actualFilters + ) + if not ok: + return False + return True + + +def _filterValueEqual(a: Any, b: Any) -> bool: + if a == b: + return True + try: + return str(a).strip() == str(b).strip() + except Exception: + return False + + +def _checkForbid(question: _Question, run: _RunResult) -> bool: + """Did the agent AVOID forbidden tool/op combinations? + + Forbidden hits only count if the call actually went through to the + provider (success=True). Validator-rejected calls don't count -- the + Repair-Loop is doing its job and steering the agent away. + """ + if not question.forbiddenTools: + return True + forbiddenSet = set(question.forbiddenTools) + for r in run.toolResults: + if not r.get("success"): + continue + if r.get("toolName") in forbiddenSet: + return False + return True + + +def _checkNumeric(question: _Question, run: _RunResult) -> bool: + text = (run.finalText or "") + if question.expectedNumbers: + textNumbers = _extractNumbers(text) + for expected in question.expectedNumbers: + tol = max(abs(expected) * question.numericTolerance, 0.5) + if not any(abs(n - expected) <= tol for n in textNumbers): + return False + + if question.expectedAnswerContains: + lowered = text.lower() + for needle in question.expectedAnswerContains: + if needle.lower() not in lowered: + return False + + return True + + +def _extractNumbers(text: str) -> List[float]: + """Pick out all numbers from a free-text answer. + + Handles Swiss thousand separators (apostrophe and U+2019), German + decimals (comma), plain integers/floats, and JSON numbers. Trailing + punctuation (``,``, ``;``, ``.`` from end-of-sentence) is stripped + before parsing so ``"180500.0,"`` parses cleanly to 180500.0. + """ + cleaned = text.replace("\u2019", "'") + tokens = re.findall(r"-?\d[\d'.,]*", cleaned) + out: List[float] = [] + for tok in tokens: + tok = tok.rstrip(",;") + if tok.endswith(".") and tok.count(".") == 1: + tok = tok[:-1] + norm = tok.replace("'", "") + if norm.count(",") == 1 and norm.count(".") == 0: + norm = norm.replace(",", ".") + elif norm.count(",") >= 1 and norm.count(".") >= 1: + if norm.rfind(",") > norm.rfind("."): + norm = norm.replace(".", "").replace(",", ".") + else: + norm = norm.replace(",", "") + else: + norm = norm.replace(",", "") + try: + out.append(float(norm)) + except ValueError: + continue + return out + + +# --------------------------------------------------------------------------- +# AI call wiring +# --------------------------------------------------------------------------- + + +def _bootstrapServices() -> Tuple[Any, str, str]: + """Spin up a minimal service hub bound to the root user + initial mandate. + + Returns the ServiceHub, the user id, and the mandate id used for billing. + """ + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.datamodels.datamodelUam import Mandate + from modules.serviceHub import getInterface as getServices + + rootInterface = getRootInterface() + user = rootInterface.currentUser + mandateId = rootInterface.getInitialId(Mandate) + if not mandateId: + raise RuntimeError("No initial mandate available -- run bootstrap loader first.") + services = getServices(user, workflow=None, mandateId=mandateId, featureInstanceId=None) + return services, user.id, mandateId + + +async def _runOneQuestion( + *, + services: Any, + userId: str, + mandateId: str, + fixture: BenchmarkFixture, + question: _Question, + mode: _ModeConfig, +) -> _RunResult: + """Execute a single sub-agent run for one question under one mode.""" + provider = FakeFeatureDataProvider( + rowsByTable=fixture.rowsByTable, + availableTables=fixture.selectedTables, + ) + validator = _buildValidator(mode) + registry = _buildSubAgentTools( + provider=provider, + featureInstanceId=fixture.featureInstanceId, + mandateId=fixture.mandateId, + tableFilters={}, + validator=validator, + ) + + systemPrompt = _buildSystemPrompt( + featureCode="trustee", + instanceLabel="Demo AG", + selectedTables=fixture.selectedTables, + ) + + cost = 0.0 + + async def _aiCallFn(req: AiCallRequest) -> AiCallResponse: + nonlocal cost + resp = await services.ai.callAi(req) + cost += float(getattr(resp, "priceCHF", 0.0) or 0.0) + return resp + + async def _getCost() -> float: + return cost + + config = AgentConfig( + maxRounds=6, + maxCostCHF=0.50, + operationType=OperationTypeEnum.DATA_QUERY, + ) + + run = _RunResult(questionId=question.id, finalText="") + t0 = time.time() + try: + async for event in runAgentLoop( + prompt=question.question, + toolRegistry=registry, + config=config, + aiCallFn=_aiCallFn, + getWorkflowCostFn=_getCost, + workflowId=f"eval-{mode.name}-{question.id}-{uuid.uuid4().hex[:6]}", + userId=userId, + featureInstanceId=fixture.featureInstanceId, + mandateId=mandateId, + systemPromptOverride=systemPrompt, + ): + if event.type == AgentEventTypeEnum.FINAL: + run.finalText = event.content or run.finalText + elif event.type == AgentEventTypeEnum.MESSAGE and event.content: + run.finalText += event.content + elif event.type == AgentEventTypeEnum.TOOL_CALL: + run.toolCalls.append(dict(event.data or {})) + elif event.type == AgentEventTypeEnum.TOOL_RESULT: + run.toolResults.append(dict(event.data or {})) + elif event.type == AgentEventTypeEnum.AGENT_SUMMARY: + run.summary = dict(event.data or {}) + elif event.type == AgentEventTypeEnum.ERROR: + run.error = (run.error or "") + (event.content or "") + except Exception as e: + run.error = f"{type(e).__name__}: {e}" + logger.exception("Sub-agent run failed for %s/%s", mode.name, question.id) + run.durationS = time.time() - t0 + return run + + +# --------------------------------------------------------------------------- +# Report +# --------------------------------------------------------------------------- + + +@dataclass +class _ModeReport: + mode: _ModeConfig + perQuestion: List[Tuple[_Question, _RunResult, _Score]] = field(default_factory=list) + + @property + def total(self) -> int: + return len(self.perQuestion) + + def _count(self, attr: str) -> int: + return sum(1 for _, _, s in self.perQuestion if getattr(s, attr)) + + @property + def accuracy(self) -> float: + return self._count("accuracyOk") / max(self.total, 1) + + @property + def patternCompliance(self) -> float: + return self._count("patternOk") / max(self.total, 1) + + @property + def repairConversionRate(self) -> float: + attempts = sum(r.repairAttempts for _, r, _ in self.perQuestion) + succeeded = sum(r.successAfterRepair for _, r, _ in self.perQuestion) + if attempts == 0: + return 0.0 + return succeeded / attempts + + @property + def totalCostCHF(self) -> float: + return sum(r.costCHF for _, r, _ in self.perQuestion) + + @property + def totalRounds(self) -> int: + return sum(r.rounds for _, r, _ in self.perQuestion) + + @property + def totalValidationFailures(self) -> int: + return sum(r.validationFailures for _, r, _ in self.perQuestion) + + +def _writeReport(reports: List[_ModeReport], outputPath: Path) -> None: + lines: List[str] = [] + lines.append("# Trustee Sub-Agent Benchmark Report") + lines.append("") + lines.append(f"Generated: {time.strftime('%Y-%m-%d %H:%M:%S')}") + lines.append("") + lines.append("## Summary") + lines.append("") + lines.append("| Mode | Questions | Accuracy | Pattern compliance | Repair conversion | Validator rejects | Rounds | Cost (CHF) |") + lines.append("|---|---|---|---|---|---|---|---|") + for rep in reports: + lines.append( + f"| {rep.mode.label} | {rep.total} | {rep.accuracy:.1%} | {rep.patternCompliance:.1%} | " + f"{rep.repairConversionRate:.1%} | {rep.totalValidationFailures} | {rep.totalRounds} | " + f"{rep.totalCostCHF:.4f} |" + ) + lines.append("") + lines.append("## Per-question detail") + for rep in reports: + lines.append("") + lines.append(f"### {rep.mode.label}") + lines.append("") + lines.append("| id | acc | pattern | forbid | numeric | rounds | val-fail | repairs | cost CHF | duration | tools |") + lines.append("|---|---|---|---|---|---|---|---|---|---|---|") + for q, r, s in rep.perQuestion: + toolList = ",".join( + f"{c.get('toolName')}({c.get('args',{}).get('tableName','?')})" + for c in r.toolCalls + ) + lines.append( + f"| {q.id} | {_yn(s.accuracyOk)} | {_yn(s.patternOk)} | {_yn(s.forbidOk)} | {_yn(s.numericOk)} | " + f"{r.rounds} | {r.validationFailures} | {r.repairAttempts}/{r.successAfterRepair} | " + f"{r.costCHF:.4f} | {r.durationS:.1f}s | {toolList} |" + ) + lines.append("") + lines.append("#### Notes & failures") + for q, r, s in rep.perQuestion: + if s.accuracyOk: + continue + lines.append(f"- **{q.id}** ({q.intent}): pattern={s.patternOk} forbid={s.forbidOk} numeric={s.numericOk}") + if r.error: + lines.append(f" - error: `{r.error}`") + lines.append(f" - answer: `{(r.finalText or '').strip().replace('|', '/').splitlines()[0][:240]}`") + for note in s.notes: + lines.append(f" - note: {note}") + outputPath.parent.mkdir(parents=True, exist_ok=True) + outputPath.write_text("\n".join(lines), encoding="utf-8") + + +def _yn(b: bool) -> str: + return "OK" if b else "FAIL" + + +# --------------------------------------------------------------------------- +# Main entry point +# --------------------------------------------------------------------------- + + +async def _runMain(modesToRun: List[str], onlyQuestionId: Optional[str] = None) -> None: + logging.basicConfig( + level=logging.WARNING, + format="%(asctime)s %(levelname)s %(name)s -- %(message)s", + ) + logger.setLevel(logging.INFO) + + fixture = buildTrusteeBenchmarkFixture() + questionsPath = _GATEWAY_DIR / "tests" / "fixtures" / "trusteeBenchmark" / "questions.yaml" + allQuestions = _loadQuestions(questionsPath) + if onlyQuestionId: + allQuestions = [q for q in allQuestions if q.id == onlyQuestionId] + if not allQuestions: + print(f"No question matches id={onlyQuestionId!r}") + return + + print(f"Loaded {len(allQuestions)} questions, {len(modesToRun)} modes -> {len(allQuestions) * len(modesToRun)} sub-agent runs.") + + services, userId, mandateId = _bootstrapServices() + print(f"Bootstrap OK: user={userId}, mandate={mandateId}") + + reports: List[_ModeReport] = [] + for modeName in modesToRun: + mode = _MODES[modeName] + _applyEnvForMode(mode) + rep = _ModeReport(mode=mode) + print(f"\n=== Mode: {mode.label} ===") + for idx, question in enumerate(allQuestions, start=1): + print(f" [{idx:>2}/{len(allQuestions)}] {question.id}: {question.question[:80]} ...", flush=True) + run = await _runOneQuestion( + services=services, + userId=userId, + mandateId=mandateId, + fixture=fixture, + question=question, + mode=mode, + ) + score = _scoreRun(question, run) + rep.perQuestion.append((question, run, score)) + print( + f" -> acc={_yn(score.accuracyOk)} " + f"pattern={_yn(score.patternOk)} forbid={_yn(score.forbidOk)} " + f"numeric={_yn(score.numericOk)} rounds={run.rounds} cost={run.costCHF:.4f} " + f"val-fail={run.validationFailures} repairs={run.repairAttempts}/{run.successAfterRepair}", + flush=True, + ) + reports.append(rep) + + timestamp = time.strftime("%Y%m%d-%H%M%S") + outDir = _GATEWAY_DIR.parent / "local" / "notes" + reportPath = outDir / f"trustee-benchmark-{timestamp}.md" + _writeReport(reports, reportPath) + + rawJsonPath = outDir / f"trustee-benchmark-{timestamp}.json" + rawJsonPath.write_text( + json.dumps( + [ + { + "mode": rep.mode.name, + "accuracy": rep.accuracy, + "patternCompliance": rep.patternCompliance, + "repairConversionRate": rep.repairConversionRate, + "totalCostCHF": rep.totalCostCHF, + "totalRounds": rep.totalRounds, + "totalValidationFailures": rep.totalValidationFailures, + "items": [ + { + "questionId": q.id, + "intent": q.intent, + "accuracyOk": s.accuracyOk, + "patternOk": s.patternOk, + "forbidOk": s.forbidOk, + "numericOk": s.numericOk, + "rounds": r.rounds, + "validationFailures": r.validationFailures, + "repairAttempts": r.repairAttempts, + "successAfterRepair": r.successAfterRepair, + "costCHF": r.costCHF, + "durationS": r.durationS, + "finalText": (r.finalText or "")[:600], + "toolCalls": r.toolCalls, + "error": r.error, + } + for q, r, s in rep.perQuestion + ], + } + for rep in reports + ], + indent=2, + ensure_ascii=False, + ), + encoding="utf-8", + ) + + print(f"\nReport written: {reportPath}") + print(f"Raw JSON: {rawJsonPath}") + for rep in reports: + print(f" {rep.mode.label}: acc={rep.accuracy:.1%} pattern={rep.patternCompliance:.1%} cost={rep.totalCostCHF:.4f}") + + +def _parseArgs(argv: List[str]) -> Tuple[List[str], Optional[str]]: + modes: List[str] = [] + only: Optional[str] = None + for arg in argv: + if arg.startswith("--only="): + only = arg.split("=", 1)[1] + elif arg in _MODES: + modes.append(arg) + else: + print(f"Unknown argument: {arg!r}. Allowed modes: {list(_MODES)}") + sys.exit(2) + if not modes: + modes = ["baseline", "phase1", "phase2"] + return modes, only + + +def main() -> None: + modes, only = _parseArgs(sys.argv[1:]) + asyncio.run(_runMain(modes, onlyQuestionId=only)) + + +if __name__ == "__main__": + main() diff --git a/tests/fixtures/trusteeBenchmark/__init__.py b/tests/fixtures/trusteeBenchmark/__init__.py new file mode 100644 index 00000000..52f83ff7 --- /dev/null +++ b/tests/fixtures/trusteeBenchmark/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2026 Patrick Motsch +# All rights reserved. +"""Trustee benchmark fixture: synthetic but realistic Swiss KMU accounting data. + +Used by the Feature Data Sub-Agent eval harness (Phase 1.5) to measure +hallucination rates against a fixed gold standard. Data is built in-memory +via Pydantic models -- no SQL, no DB connection -- so the harness stays +hermetic and reproducible. +""" + +from tests.fixtures.trusteeBenchmark.loadTrusteeBenchmarkFixture import ( + buildTrusteeBenchmarkFixture, + BenchmarkFixture, +) + +__all__ = ["buildTrusteeBenchmarkFixture", "BenchmarkFixture"] diff --git a/tests/fixtures/trusteeBenchmark/loadTrusteeBenchmarkFixture.py b/tests/fixtures/trusteeBenchmark/loadTrusteeBenchmarkFixture.py new file mode 100644 index 00000000..5eb77867 --- /dev/null +++ b/tests/fixtures/trusteeBenchmark/loadTrusteeBenchmarkFixture.py @@ -0,0 +1,275 @@ +# Copyright (c) 2026 Patrick Motsch +# All rights reserved. +"""Synthetic Trustee benchmark fixture for the Feature Data Sub-Agent eval. + +Builds an in-memory snapshot of one fictional Swiss KMU mandate +("Demo AG") with: + +* 3 fiscal years (2023, 2024, 2025) of `TrusteeDataAccountBalance` rows + -- both annual totals (periodMonth=0) and monthly snapshots. +* 8 representative accounts spanning all major chart-of-accounts blocks + (cash, banks, receivables, payables, revenue, materials, personnel, + operating expenses). +* Per-month `TrusteeDataJournalEntry` + multiple `TrusteeDataJournalLine` + rows so debit/credit/COUNT aggregations have meaningful answers. + +The data is deterministic (no RNG) so a question's gold-standard answer +is stable across runs. + +This module deliberately stays decoupled from the production DB pipeline +-- the harness uses :class:`FakeFeatureDataProvider` (see +``gateway/tests/eval/fakeFeatureDataProvider.py``) to serve queries +against this in-memory snapshot, mirroring the public methods of +``FeatureDataProvider`` (browseTable / queryTable / aggregateTable). +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List + + +_MANDATE_ID = "m-demo-ag" +_FEATURE_INSTANCE_ID = "fi-demo-ag-trustee" + + +# --------------------------------------------------------------------------- +# Account master data +# --------------------------------------------------------------------------- + +_ACCOUNT_MASTER: List[Dict[str, Any]] = [ + {"accountNumber": "1000", "label": "Hauptkasse", "accountType": "asset", "currency": "CHF"}, + {"accountNumber": "1020", "label": "ZKB Geschaeftskonto", "accountType": "asset", "currency": "CHF"}, + {"accountNumber": "1021", "label": "PostFinance", "accountType": "asset", "currency": "CHF"}, + {"accountNumber": "1100", "label": "Forderungen aus Lieferungen und Leistungen", "accountType": "asset", "currency": "CHF"}, + {"accountNumber": "2000", "label": "Verbindlichkeiten aus Lieferungen", "accountType": "liability", "currency": "CHF"}, + {"accountNumber": "3000", "label": "Ertrag aus Beratung", "accountType": "revenue", "currency": "CHF"}, + {"accountNumber": "5400", "label": "Materialaufwand", "accountType": "expense", "currency": "CHF"}, + {"accountNumber": "6000", "label": "Mietaufwand", "accountType": "expense", "currency": "CHF"}, +] + + +# Annual closing balances per (year, accountNumber) -- the canonical reference. +# Asset/expense balances are positive, liability/revenue balances are stored +# as positive numbers (sign by accountType, like most accounting systems). +_ANNUAL_CLOSING: Dict[int, Dict[str, float]] = { + 2023: { + "1000": 4_800.00, + "1020": 132_500.00, + "1021": 22_400.00, + "1100": 58_200.00, + "2000": 41_300.00, + "3000": 410_000.00, + "5400": 92_000.00, + "6000": 36_000.00, + }, + 2024: { + "1000": 5_200.00, + "1020": 148_900.00, + "1021": 26_750.00, + "1100": 61_400.00, + "2000": 44_100.00, + "3000": 462_500.00, + "5400": 104_300.00, + "6000": 39_000.00, + }, + 2025: { + "1000": 5_900.00, + "1020": 152_400.00, + "1021": 28_100.00, + "1100": 66_800.00, + "2000": 47_900.00, + "3000": 488_700.00, + "5400": 112_100.00, + "6000": 42_000.00, + }, +} + + +def _openingFromPriorYear(year: int, accountNumber: str) -> float: + """Opening balance of year N = closing balance of year N-1 (0 if N-1 is unknown).""" + prior = year - 1 + return float(_ANNUAL_CLOSING.get(prior, {}).get(accountNumber, 0.0)) + + +def _monthlyProgression(opening: float, closing: float, month: int) -> float: + """Linear interpolation between opening and closing for monthly snapshots. + + Not realistic in detail but deterministic and monotonic per account, so + questions about "Stand per Ende März" produce stable answers. + """ + if month <= 0: + return float(closing) + frac = month / 12.0 + return round(float(opening) + (float(closing) - float(opening)) * frac, 2) + + +# --------------------------------------------------------------------------- +# Journal entries / lines -- minimal but realistic +# --------------------------------------------------------------------------- + +_JOURNAL_ENTRIES_2025: List[Dict[str, Any]] = [ + {"month": 3, "day": 15, "reference": "RG-2025-0042", "description": "Beratung Kunde ACME AG", "amount": 18_500.00, "debit": "1100", "credit": "3000"}, + {"month": 3, "day": 22, "reference": "EK-2025-0017", "description": "Materialeinkauf Buehler AG", "amount": 9_200.00, "debit": "5400", "credit": "2000"}, + {"month": 3, "day": 28, "reference": "MIETE-2025-03", "description": "Mietzins Buero Maerz", "amount": 3_000.00, "debit": "6000", "credit": "1020"}, + {"month": 4, "day": 5, "reference": "RG-2025-0051", "description": "Beratung Kunde Bell AG", "amount": 24_300.00, "debit": "1100", "credit": "3000"}, + {"month": 4, "day": 18, "reference": "EK-2025-0024", "description": "Materialeinkauf Industriebedarf", "amount": 7_800.00, "debit": "5400", "credit": "2000"}, + {"month": 6, "day": 12, "reference": "RG-2025-0079", "description": "Beratung Kunde Bell AG", "amount": 32_100.00, "debit": "1100", "credit": "3000"}, + {"month": 6, "day": 30, "reference": "MIETE-2025-Q2", "description": "Mietzins Buero Q2-Abrechnung", "amount": 3_500.00, "debit": "6000", "credit": "1020"}, + {"month": 9, "day": 4, "reference": "RG-2025-0114", "description": "Beratung Kunde Migros", "amount": 41_500.00, "debit": "1100", "credit": "3000"}, + {"month": 9, "day": 25, "reference": "EK-2025-0061", "description": "Materialeinkauf Buehler AG", "amount": 12_400.00, "debit": "5400", "credit": "2000"}, + {"month": 11, "day": 14, "reference": "RG-2025-0188", "description": "Beratung Kunde ACME AG", "amount": 28_700.00, "debit": "1100", "credit": "3000"}, +] + + +# --------------------------------------------------------------------------- +# Snapshot containers +# --------------------------------------------------------------------------- + +@dataclass +class BenchmarkFixture: + """In-memory rows that mimic feature DB tables. + + Each ``rowsByTable[tableName]`` is a list of column dicts compatible + with the Pydantic feature data models (TrusteeDataAccountBalance, etc.). + """ + mandateId: str + featureInstanceId: str + rowsByTable: Dict[str, List[Dict[str, Any]]] = field(default_factory=dict) + selectedTables: List[Dict[str, Any]] = field(default_factory=list) + + +def _buildSelectedTables() -> List[Dict[str, Any]]: + """Return the DATA_OBJECT-shaped descriptors the sub-agent expects. + + Mirrors what the catalog would return for the trustee feature; the + real `getDataObjects("trustee")` call would yield the same shape but + we hard-code the three tables we actually populate. + """ + return [ + { + "objectKey": "data.feature.trustee.TrusteeDataAccount", + "label": {"de": "Kontenplan", "en": "Chart of accounts"}, + "meta": { + "table": "TrusteeDataAccount", + "fields": ["id", "accountNumber", "label", "accountType", "currency", "isActive"], + }, + }, + { + "objectKey": "data.feature.trustee.TrusteeDataAccountBalance", + "label": {"de": "Kontosalden", "en": "Account balances"}, + "meta": { + "table": "TrusteeDataAccountBalance", + "fields": [ + "id", "accountNumber", "periodYear", "periodMonth", + "openingBalance", "debitTotal", "creditTotal", + "closingBalance", "currency", + ], + }, + }, + { + "objectKey": "data.feature.trustee.TrusteeDataJournalLine", + "label": {"de": "Buchungszeilen", "en": "Journal lines"}, + "meta": { + "table": "TrusteeDataJournalLine", + "fields": [ + "id", "journalEntryId", "accountNumber", + "debitAmount", "creditAmount", "currency", "description", + ], + }, + }, + ] + + +def buildTrusteeBenchmarkFixture() -> BenchmarkFixture: + """Materialize the full in-memory benchmark snapshot. + + All rows include ``mandateId`` and ``featureInstanceId`` columns so the + fake provider can scope them the same way the real one does. + """ + accountRows: List[Dict[str, Any]] = [] + for i, acc in enumerate(_ACCOUNT_MASTER): + accountRows.append({ + "id": f"acc-{i:03d}", + "accountNumber": acc["accountNumber"], + "label": acc["label"], + "accountType": acc["accountType"], + "currency": acc["currency"], + "isActive": True, + "mandateId": _MANDATE_ID, + "featureInstanceId": _FEATURE_INSTANCE_ID, + }) + + balanceRows: List[Dict[str, Any]] = [] + rowIdx = 0 + for year, closings in _ANNUAL_CLOSING.items(): + for accountNumber, closing in closings.items(): + opening = _openingFromPriorYear(year, accountNumber) + balanceRows.append({ + "id": f"bal-{rowIdx:04d}", + "accountNumber": accountNumber, + "periodYear": year, + "periodMonth": 0, + "openingBalance": opening, + "debitTotal": round(max(closing - opening, 0.0) * 1.2, 2), + "creditTotal": round(max(closing - opening, 0.0) * 0.2, 2), + "closingBalance": float(closing), + "currency": "CHF", + "mandateId": _MANDATE_ID, + "featureInstanceId": _FEATURE_INSTANCE_ID, + }) + rowIdx += 1 + for month in range(1, 13): + monthly = _monthlyProgression(opening, closing, month) + balanceRows.append({ + "id": f"bal-{rowIdx:04d}", + "accountNumber": accountNumber, + "periodYear": year, + "periodMonth": month, + "openingBalance": opening, + "debitTotal": round((monthly - opening) * 1.2, 2) if monthly > opening else 0.0, + "creditTotal": round((monthly - opening) * 0.2, 2) if monthly > opening else 0.0, + "closingBalance": monthly, + "currency": "CHF", + "mandateId": _MANDATE_ID, + "featureInstanceId": _FEATURE_INSTANCE_ID, + }) + rowIdx += 1 + + lineRows: List[Dict[str, Any]] = [] + for j, entry in enumerate(_JOURNAL_ENTRIES_2025): + entryId = f"je-2025-{j:03d}" + lineRows.append({ + "id": f"jl-{j*2:04d}", + "journalEntryId": entryId, + "accountNumber": entry["debit"], + "debitAmount": float(entry["amount"]), + "creditAmount": 0.0, + "currency": "CHF", + "description": entry["description"], + "mandateId": _MANDATE_ID, + "featureInstanceId": _FEATURE_INSTANCE_ID, + }) + lineRows.append({ + "id": f"jl-{j*2+1:04d}", + "journalEntryId": entryId, + "accountNumber": entry["credit"], + "debitAmount": 0.0, + "creditAmount": float(entry["amount"]), + "currency": "CHF", + "description": entry["description"], + "mandateId": _MANDATE_ID, + "featureInstanceId": _FEATURE_INSTANCE_ID, + }) + + fixture = BenchmarkFixture( + mandateId=_MANDATE_ID, + featureInstanceId=_FEATURE_INSTANCE_ID, + rowsByTable={ + "TrusteeDataAccount": accountRows, + "TrusteeDataAccountBalance": balanceRows, + "TrusteeDataJournalLine": lineRows, + }, + selectedTables=_buildSelectedTables(), + ) + return fixture diff --git a/tests/fixtures/trusteeBenchmark/questions.yaml b/tests/fixtures/trusteeBenchmark/questions.yaml new file mode 100644 index 00000000..7d277cae --- /dev/null +++ b/tests/fixtures/trusteeBenchmark/questions.yaml @@ -0,0 +1,226 @@ +# Trustee Sub-Agent Benchmark -- 19 questions analog Hein 2025 +# +# Each question covers ONE expected hallucination class so we can attribute +# accuracy gains to specific phases (validator / ontology). +# +# Scoring per question (all binary unless noted): +# patternOk -- did the agent call the right tool(s) with the right filters? +# forbidOk -- did it AVOID the forbidden tool/op (e.g. SUM closingBalance)? +# numericOk -- does the final answer contain the expected number(s)? +# accuracyOk -- patternOk AND forbidOk AND numericOk +# +# tolerance: relative tolerance for numeric comparison (default 0.005 = 0.5 %). + +- id: q01 + question: "Was ist der Banksaldo per 31.12.2025 fuer das ZKB-Konto 1020?" + intent: BANK_BALANCE_AT_DATE + expectedTools: [queryTable] + expectedTable: TrusteeDataAccountBalance + requiredFilters: + accountNumber: "1020" + periodYear: 2025 + periodMonth: 0 + forbiddenTools: [aggregateTable] + expectedNumbers: [152400.0] + +- id: q02 + question: "Wie hoch ist die Hauptkasse (Konto 1000) per Ende 2024?" + intent: CASH_BALANCE_AT_DATE + expectedTools: [queryTable] + expectedTable: TrusteeDataAccountBalance + requiredFilters: + accountNumber: "1000" + periodYear: 2024 + periodMonth: 0 + forbiddenTools: [aggregateTable] + expectedNumbers: [5200.0] + +- id: q03 + question: "Summiere alle Bankkonten (102x) per 31.12.2025." + intent: BANK_GROUP_TOTAL_AT_DATE + expectedTools: [queryTable] + expectedTable: TrusteeDataAccountBalance + requiredFilters: + periodYear: 2025 + periodMonth: 0 + accountNumberLike: "102%" + forbiddenTools: [aggregateTable] + expectedNumbers: [180500.0] + numericTolerance: 0.01 + +- id: q04 + question: "Wie hat sich der Schlusssaldo des ZKB-Kontos 1020 ueber die Jahre 2023 bis 2025 entwickelt?" + intent: BALANCE_HISTORY_PER_YEAR + expectedTools: [queryTable] + expectedTable: TrusteeDataAccountBalance + requiredFilters: + accountNumber: "1020" + periodMonth: 0 + forbiddenTools: [aggregateTable] + expectedNumbers: [132500.0, 148900.0, 152400.0] + +- id: q05 + question: "Welches Konto hatte 2025 den hoechsten Schlusssaldo bei den Aktiven (1xxx)?" + intent: TOP_ASSET_AT_DATE + expectedTools: [queryTable] + expectedTable: TrusteeDataAccountBalance + requiredFilters: + periodYear: 2025 + periodMonth: 0 + accountNumberLike: "1%" + forbiddenTools: [aggregateTable] + expectedAnswerContains: ["1020"] + expectedNumbers: [152400.0] + +- id: q06 + question: "Welche Konten gehoeren zu den Bankkonten (102x)?" + intent: ACCOUNT_LIST_FILTER + expectedTools: [queryTable] + expectedTable: TrusteeDataAccount + requiredFilters: + accountNumberLike: "102%" + forbiddenTools: [aggregateTable] + expectedAnswerContains: ["1020", "1021"] + +- id: q07 + question: "Wie hoch war der Materialaufwand (Konto 5400) im Jahr 2025?" + intent: EXPENSE_AT_YEAR + expectedTools: [queryTable] + expectedTable: TrusteeDataAccountBalance + requiredFilters: + accountNumber: "5400" + periodYear: 2025 + periodMonth: 0 + forbiddenTools: [aggregateTable] + expectedNumbers: [112100.0] + +- id: q08 + question: "Wie viele Buchungszeilen gibt es insgesamt im System?" + intent: COUNT_ROWS + expectedTools: [aggregateTable] + expectedTable: TrusteeDataJournalLine + expectedAggregate: COUNT + forbiddenTools: [] + expectedNumbers: [20] + +- id: q09 + question: "Wie hoch ist der gesamte Beratungsertrag (Konto 3000) im Jahr 2025?" + intent: REVENUE_AT_YEAR + expectedTools: [queryTable] + expectedTable: TrusteeDataAccountBalance + requiredFilters: + accountNumber: "3000" + periodYear: 2025 + periodMonth: 0 + forbiddenTools: [aggregateTable] + expectedNumbers: [488700.0] + +- id: q10 + question: "Wie viel wurde 2025 auf das Materialaufwand-Konto 5400 gebucht (Soll-Summe ueber Buchungszeilen)?" + intent: JOURNAL_SUM_AT_ACCOUNT + expectedTools: [aggregateTable] + expectedTable: TrusteeDataJournalLine + expectedAggregate: SUM + expectedAggregateField: debitAmount + requiredFilters: + accountNumber: "5400" + forbiddenTools: [] + expectedNumbers: [29400.0] + numericTolerance: 0.01 + +- id: q11 + question: "Welche Buchungen im 1. Quartal 2025 (Januar bis Maerz) wurden auf Konto 3000 gebucht?" + intent: JOURNAL_LINES_BY_ACCOUNT + expectedTools: [queryTable] + expectedTable: TrusteeDataJournalLine + requiredFilters: + accountNumber: "3000" + forbiddenTools: [aggregateTable] + expectedAnswerContains: ["18500", "ACME"] + +- id: q12 + question: "Wie hoch war die Hauptkasse (Konto 1000) jeweils per Ende Maerz 2025 und per Ende Juni 2025?" + intent: MULTI_MONTH_SNAPSHOT + expectedTools: [queryTable] + expectedTable: TrusteeDataAccountBalance + requiredFilters: + accountNumber: "1000" + periodYear: 2025 + forbiddenTools: [aggregateTable] + expectedNumbers: [5375.0, 5550.0] + numericTolerance: 0.01 + +- id: q13 + question: "Wie hoch ist die Summe aller Aufwandskonten (5xxx und 6xxx) per Ende 2025?" + intent: EXPENSE_GROUP_TOTAL + expectedTools: [queryTable] + expectedTable: TrusteeDataAccountBalance + requiredFilters: + periodYear: 2025 + periodMonth: 0 + forbiddenTools: [aggregateTable] + expectedNumbers: [154100.0] + numericTolerance: 0.01 + +- id: q14 + question: "Welches Konto hat den hoechsten openingBalance fuer 2025?" + intent: TOP_OPENING_BALANCE + # Both routes are legitimate: queryTable+orderBy+limit=1, or + # aggregateTable(MAX) followed by queryTable lookup. We only insist that + # the final answer names the right account and (optionally) the value. + expectedTools: [queryTable, aggregateTable] + expectedTable: TrusteeDataAccountBalance + forbiddenTools: [] + expectedAnswerContains: ["3000"] + expectedNumbers: [462500.0] + +- id: q15 + question: "Liste alle Konten vom Typ asset auf." + intent: ACCOUNTS_BY_TYPE + expectedTools: [queryTable] + expectedTable: TrusteeDataAccount + requiredFilters: + accountType: "asset" + forbiddenTools: [aggregateTable] + expectedAnswerContains: ["1000", "1020", "1021", "1100"] + +- id: q16 + question: "Wie hoch ist der Schlusssaldo der Forderungen aus Lieferungen und Leistungen (Konto 1100) per Ende 2025?" + intent: BALANCE_BY_NAME_LOOKUP + expectedTools: [queryTable] + expectedTable: TrusteeDataAccountBalance + requiredFilters: + accountNumber: "1100" + periodYear: 2025 + periodMonth: 0 + forbiddenTools: [aggregateTable] + expectedNumbers: [66800.0] + +- id: q17 + question: "Wie hoch waren die Verbindlichkeiten (Konto 2000) jeweils per Ende 2023, 2024 und 2025?" + intent: LIABILITY_HISTORY + expectedTools: [queryTable] + expectedTable: TrusteeDataAccountBalance + requiredFilters: + accountNumber: "2000" + periodMonth: 0 + forbiddenTools: [aggregateTable] + expectedNumbers: [41300.0, 44100.0, 47900.0] + +- id: q18 + question: "Wie viele Bankkonten gibt es im Kontenplan (102x)?" + intent: ACCOUNT_COUNT_BY_PREFIX + expectedTools: [queryTable, aggregateTable] + expectedTable: TrusteeDataAccount + requiredFilters: + accountNumberLike: "102%" + forbiddenTools: [] + expectedNumbers: [2] + +- id: q19 + question: "Gib mir alle Buchungszeilen mit einem Sollbetrag groesser als 20'000 CHF." + intent: JOURNAL_LINES_BY_AMOUNT + expectedTools: [queryTable] + expectedTable: TrusteeDataJournalLine + forbiddenTools: [aggregateTable] + expectedAnswerContains: ["24300", "32100", "41500", "28700"] diff --git a/tests/unit/serviceAgent/test_agentTrace_repairCounters.py b/tests/unit/serviceAgent/test_agentTrace_repairCounters.py new file mode 100644 index 00000000..4a0909d1 --- /dev/null +++ b/tests/unit/serviceAgent/test_agentTrace_repairCounters.py @@ -0,0 +1,112 @@ +# Copyright (c) 2026 Patrick Motsch +# All rights reserved. +"""Unit tests for the repair-loop telemetry aggregation in agentLoop. + +These counters (`validationFailures`, `repairAttempts`, `successAfterRepair`) +land on `AgentTrace` and are surfaced via the `AGENT_SUMMARY` event. The +Eval-Harness (Phase 1.5) reads them to compute the repair conversion rate. +""" + +from __future__ import annotations + +from modules.serviceCenter.services.serviceAgent.agentLoop import _computeRepairCounters +from modules.serviceCenter.services.serviceAgent.datamodelAgent import ( + AgentRoundLog, ToolCallLog, +) + + +def _round(*toolCalls: ToolCallLog) -> AgentRoundLog: + return AgentRoundLog(roundNumber=1, toolCalls=list(toolCalls)) + + +def _failed(toolName: str, code: str) -> ToolCallLog: + return ToolCallLog( + toolName=toolName, + success=False, + validationFailureCode=code, + error=f"{code}: ...", + ) + + +def _ok(toolName: str) -> ToolCallLog: + return ToolCallLog(toolName=toolName, success=True) + + +def test_computeRepairCounters_emptyTrace(): + fails, attempts, succeeded = _computeRepairCounters([]) + assert (fails, attempts, succeeded) == (0, 0, 0) + + +def test_computeRepairCounters_allCleanRunsHaveZeroCounters(): + rounds = [ + _round(_ok("queryTable"), _ok("browseTable")), + _round(_ok("aggregateTable")), + ] + fails, attempts, succeeded = _computeRepairCounters(rounds) + assert (fails, attempts, succeeded) == (0, 0, 0) + + +def test_computeRepairCounters_singleFailureCountsButNoRepairYet(): + """One failure in round 1, no follow-up call -- counts the failure but + nothing else. Repair only counts when the LLM tries again.""" + rounds = [_round(_failed("queryTable", "FIELD_NOT_FOUND"))] + fails, attempts, succeeded = _computeRepairCounters(rounds) + assert (fails, attempts, succeeded) == (1, 0, 0) + + +def test_computeRepairCounters_repairThatSucceeds(): + """Round 1 fails, round 2 retries same tool successfully.""" + rounds = [ + _round(_failed("queryTable", "FIELD_NOT_FOUND")), + _round(_ok("queryTable")), + ] + fails, attempts, succeeded = _computeRepairCounters(rounds) + assert (fails, attempts, succeeded) == (1, 1, 1) + + +def test_computeRepairCounters_repairThatFailsAgain(): + """Round 1 fails, round 2 retries same tool but fails validation again.""" + rounds = [ + _round(_failed("queryTable", "FIELD_NOT_FOUND")), + _round(_failed("queryTable", "FIELD_NOT_FOUND")), + ] + fails, attempts, succeeded = _computeRepairCounters(rounds) + assert (fails, attempts, succeeded) == (2, 1, 0) + + +def test_computeRepairCounters_siblingCallsInSameRoundAreNotRepairs(): + """When the LLM emits two queryTable calls in the same round, the + second is NOT a repair attempt -- it had no way to see the first + one's rejection yet (parallel dispatch within a round).""" + rounds = [ + _round( + _failed("queryTable", "FIELD_NOT_FOUND"), + _failed("queryTable", "FIELD_NOT_FOUND"), + ), + ] + fails, attempts, succeeded = _computeRepairCounters(rounds) + assert (fails, attempts, succeeded) == (2, 0, 0) + + +def test_computeRepairCounters_differentToolNamesAreIndependent(): + """A queryTable failure does not flag a later browseTable as a repair.""" + rounds = [ + _round(_failed("queryTable", "FIELD_NOT_FOUND")), + _round(_ok("browseTable")), + ] + fails, attempts, succeeded = _computeRepairCounters(rounds) + assert (fails, attempts, succeeded) == (1, 0, 0) + + +def test_computeRepairCounters_multiToolMix(): + """Trustee-like sequence: SUM(closingBalance) rejected, LLM switches to + queryTable with a typo (rejected), then fixes the typo (success).""" + rounds = [ + _round(_failed("aggregateTable", "INVALID_AGGREGATE_TARGET")), + _round(_failed("queryTable", "FIELD_NOT_FOUND")), + _round(_ok("queryTable")), + ] + fails, attempts, succeeded = _computeRepairCounters(rounds) + # 2 validation failures total, 1 prior-rejected queryTable retry that + # succeeded; aggregateTable was never retried so no attempt counted for it. + assert (fails, attempts, succeeded) == (2, 1, 1) diff --git a/tests/unit/services/test_featureDataAgent_schema.py b/tests/unit/services/test_featureDataAgent_schema.py index ef37753b..616f46cc 100644 --- a/tests/unit/services/test_featureDataAgent_schema.py +++ b/tests/unit/services/test_featureDataAgent_schema.py @@ -19,11 +19,18 @@ asked for the closing balance per period). from __future__ import annotations +import asyncio +from unittest.mock import MagicMock + import pytest from modules.shared import fkRegistry +from modules.serviceCenter.services.serviceAgent.datamodelAgent import ( + ToolCallRequest, ToolResult, +) from modules.serviceCenter.services.serviceAgent.featureDataAgent import ( _buildSchemaContext, + _buildSubAgentTools, _buildTableSchemaBlock, _formatFieldLine, _summarizePythonType, @@ -152,10 +159,29 @@ def test_buildSchemaContext_forbidsSummingAggregateFields(): assert "closingBalance" in prompt -def test_buildSchemaContext_appendsTrusteeDomainHints(): - """When the feature module exposes getAgentDomainHints(), the schema prompt - must include those hints so the sub-agent knows e.g. that 102x are bank - accounts and periodMonth=0 is the annual total.""" +def test_buildSchemaContext_appendsTrusteeOntologyBlock(monkeypatch): + """When the feature exposes getAgentOntology(), the schema prompt must + include the compiled ontology block (Phase 2 path).""" + monkeypatch.delenv("POWERON_DISABLE_FEATURE_ONTOLOGY", raising=False) + selected = [_trusteeAccountBalanceObj()] + prompt = _buildSchemaContext( + featureCode="trustee", + instanceLabel="Demo AG", + selectedTables=selected, + requestLang="de", + ) + assert "DOMAIN ONTOLOGY (trustee):" in prompt + assert "BankAccount" in prompt + assert "NEVER_AGGREGATE on TrusteeDataAccountBalance.closingBalance" in prompt.replace("never aggregate", "NEVER_AGGREGATE") + assert "BANK_BALANCE_AT_DATE" in prompt + + +def test_buildSchemaContext_fallsBackToLegacyHints_whenOntologyDisabled(monkeypatch): + """With POWERON_DISABLE_FEATURE_ONTOLOGY=1 the builder must fall back to + the legacy `getAgentDomainHints()` block. This is the path used by the + eval harness to measure `baseline` and `phase1` accuracy without the + ontology-driven prompt.""" + monkeypatch.setenv("POWERON_DISABLE_FEATURE_ONTOLOGY", "1") selected = [_trusteeAccountBalanceObj()] prompt = _buildSchemaContext( featureCode="trustee", @@ -164,16 +190,14 @@ def test_buildSchemaContext_appendsTrusteeDomainHints(): requestLang="de", ) assert "TRUSTEE DOMAIN HINTS" in prompt + assert "DOMAIN ONTOLOGY" not in prompt assert "102x Bank / Post" in prompt - assert "periodMonth = 0" in prompt - assert "ANTI-PATTERNS" in prompt - assert 'LIKE \'102%\'' in prompt or "LIKE '102%'" in prompt -def test_buildSchemaContext_skipsHintsForFeaturesWithoutHook(): - """Features that don't export getAgentDomainHints() should produce a prompt - without the trailing hints block. Verified by using a feature code that - cannot resolve to a main module (registry returns None).""" +def test_buildSchemaContext_skipsHintsForFeaturesWithoutHook(monkeypatch): + """Features that don't export getAgentDomainHints()/getAgentOntology() + should produce a prompt without any trailing hints block.""" + monkeypatch.delenv("POWERON_DISABLE_FEATURE_ONTOLOGY", raising=False) selected = [_trusteeAccountBalanceObj()] prompt = _buildSchemaContext( featureCode="nosuchfeature", @@ -182,4 +206,90 @@ def test_buildSchemaContext_skipsHintsForFeaturesWithoutHook(): requestLang="de", ) assert "TRUSTEE DOMAIN HINTS" not in prompt + assert "DOMAIN ONTOLOGY" not in prompt assert "Keep your answer SHORT" in prompt + + +# ------------------------------------------------------------------ +# Validator integration (Phase 1: Repair-Loop) +# +# These tests guard that pre-execute validation fires BEFORE the provider +# is touched, and that the structured error payload reaches the LLM via +# `ToolResult.errorDetails` -- the contract the LLM relies on for repair. +# ------------------------------------------------------------------ + + +def _buildRegistryWithMockProvider(): + """Build a sub-agent ToolRegistry where the provider is a MagicMock. + + The mock records calls so we can assert the validator short-circuits + before the DB layer is reached.""" + provider = MagicMock() + provider.browseTable.return_value = {"rows": [], "total": 0, "limit": 50, "offset": 0} + provider.queryTable.return_value = {"rows": [], "total": 0, "limit": 50, "offset": 0} + provider.aggregateTable.return_value = {"rows": [], "aggregate": "SUM", "field": "x"} + registry = _buildSubAgentTools( + provider=provider, + featureInstanceId="fi-test", + mandateId="m-test", + tableFilters=None, + validator=None, + ) + return registry, provider + + +def _dispatchSync(registry, toolName, args): + """Synchronously dispatch a tool call through the registry.""" + call = ToolCallRequest(name=toolName, args=args) + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(registry.dispatch(call, context={})) + finally: + loop.close() + + +def test_subAgentTools_invalidFieldShortCircuitsBeforeProvider(): + """A queryTable call with an unknown field must NOT reach the provider.""" + registry, provider = _buildRegistryWithMockProvider() + result = _dispatchSync(registry, "queryTable", { + "tableName": "TrusteeDataAccountBalance", + "filters": [{"field": "klosingBalance", "op": "=", "value": 1}], + }) + assert isinstance(result, ToolResult) + assert result.success is False + assert result.errorDetails is not None + assert result.errorDetails["code"] == "FIELD_NOT_FOUND" + assert result.errorDetails["suggestion"] == "closingBalance" + assert result.error and result.error.startswith("FIELD_NOT_FOUND:") + provider.queryTable.assert_not_called() + + +def test_subAgentTools_sumClosingBalanceShortCircuits(): + """The flagship hallucination -- SUM(closingBalance) -- must be blocked + by the pre-execute validator before the DB is touched.""" + registry, provider = _buildRegistryWithMockProvider() + result = _dispatchSync(registry, "aggregateTable", { + "tableName": "TrusteeDataAccountBalance", + "aggregate": "SUM", + "field": "closingBalance", + }) + assert result.success is False + assert result.errorDetails["code"] == "INVALID_AGGREGATE_TARGET" + assert result.errorDetails["field"] == "closingBalance" + provider.aggregateTable.assert_not_called() + + +def test_subAgentTools_validCallReachesProvider(): + """Sanity: a valid call passes the validator and hits the provider.""" + registry, provider = _buildRegistryWithMockProvider() + result = _dispatchSync(registry, "queryTable", { + "tableName": "TrusteeDataAccountBalance", + "filters": [ + {"field": "periodYear", "op": "=", "value": 2025}, + {"field": "periodMonth", "op": "=", "value": 0}, + ], + "fields": ["accountNumber", "closingBalance"], + }) + assert result.success is True + assert result.errorDetails is None + provider.queryTable.assert_called_once() diff --git a/tests/unit/services/test_queryValidator.py b/tests/unit/services/test_queryValidator.py new file mode 100644 index 00000000..40c8f444 --- /dev/null +++ b/tests/unit/services/test_queryValidator.py @@ -0,0 +1,295 @@ +# Copyright (c) 2026 Patrick Motsch +# All rights reserved. +"""Unit tests for the Feature Data Sub-Agent QueryValidator. + +Each constraint is exercised with both a Happy and a Sad path so a future +refactor that silently drops a check is caught immediately. + +Test fixture is the real ``TrusteeDataAccountBalance`` / ``TrusteeDataJournalLine`` +Pydantic models -- both are perfectly suited because they cover all four +constraint classes in production-realistic shape (string fields, numeric +fields, fields named ``closingBalance`` / ``debitTotal``). +""" + +from __future__ import annotations + +import pytest + +from modules.shared import fkRegistry +from modules.serviceCenter.services.serviceAgent.datamodelOntology import ( + Constraint, + ConstraintRule, + OntologyDescriptor, + ValidationErrorCode, +) +from modules.serviceCenter.services.serviceAgent.queryValidator import QueryValidator + + +@pytest.fixture(scope="module", autouse=True) +def _ensureModels(): + fkRegistry._ensureModelsLoaded() + + +@pytest.fixture() +def validator() -> QueryValidator: + return QueryValidator() + + +# ------------------------------------------------------------------ +# FieldExists -- browseTable / queryTable / aggregateTable +# ------------------------------------------------------------------ + + +def test_browseQuery_happyPath_returnsNone(validator): + err = validator.validateBrowseQuery( + "TrusteeDataAccountBalance", + {"fields": ["accountNumber", "closingBalance"]}, + ) + assert err is None + + +def test_browseQuery_invalidField_returnsFieldNotFound(validator): + err = validator.validateBrowseQuery( + "TrusteeDataAccountBalance", + {"fields": ["closingBlance"]}, # typo + ) + assert err is not None + assert err.code == ValidationErrorCode.FIELD_NOT_FOUND + assert err.field == "closingBlance" + assert err.suggestion == "closingBalance" + + +def test_queryTable_filterOnInvalidField_returnsFieldNotFound(validator): + err = validator.validateQueryTable( + "TrusteeDataAccountBalance", + {"filters": [{"field": "klosingBalance", "op": "=", "value": 100}]}, + ) + assert err is not None + assert err.code == ValidationErrorCode.FIELD_NOT_FOUND + assert err.suggestion == "closingBalance" + + +def test_queryTable_unknownTable_isLenient(validator): + """When the table isn't in MODEL_REGISTRY we skip validation -- relying on + the SQL layer to surface schema errors. Prevents false positives for + pure UDB tables not exposed via Pydantic.""" + err = validator.validateQueryTable( + "NoSuchTable123", + {"filters": [{"field": "anything", "op": "=", "value": 1}]}, + ) + assert err is None + + +# ------------------------------------------------------------------ +# OperatorCompatible +# ------------------------------------------------------------------ + + +def test_queryTable_likeOnStringField_isOk(validator): + err = validator.validateQueryTable( + "TrusteeDataAccountBalance", + {"filters": [{"field": "accountNumber", "op": "LIKE", "value": "102%"}]}, + ) + assert err is None + + +def test_queryTable_likeOnNumericField_isOperatorIncompatible(validator): + err = validator.validateQueryTable( + "TrusteeDataAccountBalance", + {"filters": [{"field": "closingBalance", "op": "LIKE", "value": "100%"}]}, + ) + assert err is not None + assert err.code == ValidationErrorCode.OPERATOR_INCOMPATIBLE + assert err.field == "closingBalance" + + +def test_queryTable_gteOnNumericField_isOk(validator): + err = validator.validateQueryTable( + "TrusteeDataAccountBalance", + {"filters": [{"field": "closingBalance", "op": ">=", "value": 100}]}, + ) + assert err is None + + +def test_queryTable_gteOnStringField_isOperatorIncompatible(validator): + err = validator.validateQueryTable( + "TrusteeDataAccountBalance", + {"filters": [{"field": "currency", "op": ">=", "value": "CHF"}]}, + ) + assert err is not None + assert err.code == ValidationErrorCode.OPERATOR_INCOMPATIBLE + + +def test_queryTable_equalsOnAnyField_isOk(validator): + """`=` and `!=` work on any field type.""" + err = validator.validateQueryTable( + "TrusteeDataAccountBalance", + {"filters": [{"field": "currency", "op": "=", "value": "CHF"}]}, + ) + assert err is None + + +def test_queryTable_isNullOnAnyField_isOk(validator): + err = validator.validateQueryTable( + "TrusteeDataAccountBalance", + {"filters": [{"field": "mandateId", "op": "IS NULL", "value": None}]}, + ) + assert err is None + + +# ------------------------------------------------------------------ +# AggregateTarget -- the highest-impact rule +# ------------------------------------------------------------------ + + +def test_aggregate_sumDebitAmount_isOk(validator): + err = validator.validateAggregateQuery( + "TrusteeDataJournalLine", + {"aggregate": "SUM", "field": "debitAmount"}, + ) + assert err is None + + +def test_aggregate_sumClosingBalance_isInvalidAggregateTarget(validator): + """The flagship bug: SUM(closingBalance) across periods. Must be blocked.""" + err = validator.validateAggregateQuery( + "TrusteeDataAccountBalance", + {"aggregate": "SUM", "field": "closingBalance"}, + ) + assert err is not None + assert err.code == ValidationErrorCode.INVALID_AGGREGATE_TARGET + assert err.field == "closingBalance" + assert "already aggregated" in err.hint + + +def test_aggregate_avgDebitTotal_isInvalidAggregateTarget(validator): + """`*Total` columns are turnovers per period -- AVG across periods is nonsense.""" + err = validator.validateAggregateQuery( + "TrusteeDataAccountBalance", + {"aggregate": "AVG", "field": "debitTotal"}, + ) + assert err is not None + assert err.code == ValidationErrorCode.INVALID_AGGREGATE_TARGET + + +def test_aggregate_countClosingBalance_isOk(validator): + """COUNT on a balance column is meaningful (how many balance rows exist).""" + err = validator.validateAggregateQuery( + "TrusteeDataAccountBalance", + {"aggregate": "COUNT", "field": "closingBalance"}, + ) + assert err is None + + +def test_aggregate_sumOnStringField_isTypeMismatch(validator): + err = validator.validateAggregateQuery( + "TrusteeDataAccountBalance", + {"aggregate": "SUM", "field": "currency"}, + ) + assert err is not None + assert err.code == ValidationErrorCode.TYPE_MISMATCH + + +def test_aggregate_invalidField_returnsFieldNotFound(validator): + err = validator.validateAggregateQuery( + "TrusteeDataAccountBalance", + {"aggregate": "SUM", "field": "nonExistent"}, + ) + assert err is not None + assert err.code == ValidationErrorCode.FIELD_NOT_FOUND + + +def test_aggregate_invalidGroupBy_returnsFieldNotFound(validator): + err = validator.validateAggregateQuery( + "TrusteeDataJournalLine", + {"aggregate": "SUM", "field": "debitAmount", "groupBy": "ghostColumn"}, + ) + assert err is not None + assert err.code == ValidationErrorCode.FIELD_NOT_FOUND + + +# ------------------------------------------------------------------ +# OrderByValid +# ------------------------------------------------------------------ + + +def test_queryTable_orderByValid_isOk(validator): + err = validator.validateQueryTable( + "TrusteeDataAccountBalance", + {"orderBy": "periodYear"}, + ) + assert err is None + + +def test_queryTable_orderByInvalid_returnsOrderByInvalid(validator): + err = validator.validateQueryTable( + "TrusteeDataAccountBalance", + {"orderBy": "periodYr"}, + ) + assert err is not None + assert err.code == ValidationErrorCode.ORDER_BY_INVALID + assert err.suggestion == "periodYear" + + +def test_queryTable_orderByLiteralStringNone_isOk(validator): + """LLMs sometimes pass the literal string 'None'.""" + err = validator.validateQueryTable( + "TrusteeDataAccountBalance", + {"orderBy": "None"}, + ) + assert err is None + + +# ------------------------------------------------------------------ +# Ontology-driven override (Phase 2 readiness check) +# ------------------------------------------------------------------ + + +def test_ontologyOverride_blocksAggregateForOntologyField(): + """When the ontology marks a field NEVER_AGGREGATE, SUM/AVG is blocked + even if the field name doesn't match the convention suffixes.""" + ontology = OntologyDescriptor( + featureCode="trustee", + constraints=[ + Constraint( + appliesTo="TrusteeDataJournalLine.debitAmount", + rule=ConstraintRule.NEVER_AGGREGATE, + message="Synthetic test rule.", + ) + ], + ) + validatorWithOntology = QueryValidator(ontology=ontology) + err = validatorWithOntology.validateAggregateQuery( + "TrusteeDataJournalLine", + {"aggregate": "SUM", "field": "debitAmount"}, + ) + assert err is not None + assert err.code == ValidationErrorCode.INVALID_AGGREGATE_TARGET + + +# ------------------------------------------------------------------ +# QueryValidationError serialization (consumed by featureDataAgent) +# ------------------------------------------------------------------ + + +def test_validationError_toShortErrorIncludesCodeAndField(validator): + err = validator.validateAggregateQuery( + "TrusteeDataAccountBalance", + {"aggregate": "SUM", "field": "closingBalance"}, + ) + assert err is not None + short = err.toShortError() + assert short.startswith("INVALID_AGGREGATE_TARGET:") + assert "closingBalance" in short + + +def test_validationError_toErrorDetailsHasFourKeys(validator): + err = validator.validateQueryTable( + "TrusteeDataAccountBalance", + {"filters": [{"field": "klosingBalance", "op": "=", "value": 0}]}, + ) + assert err is not None + details = err.toErrorDetails() + assert set(details.keys()) == {"code", "field", "suggestion", "hint"} + assert details["code"] == "FIELD_NOT_FOUND" + assert details["suggestion"] == "closingBalance" diff --git a/tests/unit/services/test_trusteeOntology.py b/tests/unit/services/test_trusteeOntology.py new file mode 100644 index 00000000..887f69a4 --- /dev/null +++ b/tests/unit/services/test_trusteeOntology.py @@ -0,0 +1,199 @@ +# Copyright (c) 2026 Patrick Motsch +# All rights reserved. +"""Unit tests for the trustee ontology and the ontology-to-prompt compiler. + +Verifies: + +* the descriptor passes Pydantic validation +* `constraintsForTable` correctly scopes by table/field prefix +* the compiler emits a stable header + every entity name + every + constraint message +* the QueryValidator picks up ontology constraints (NEVER_AGGREGATE on + closingBalance) over the convention-based defaults +* the `getAgentOntology()` hook on `mainTrustee` returns the descriptor +* `_buildValidatorForFeature("trustee")` wires the validator with the + ontology +""" + +from __future__ import annotations + +import pytest + +from modules.features.trustee.mainTrustee import getAgentOntology +from modules.features.trustee.trusteeOntology import getTrusteeOntology +from modules.serviceCenter.services.serviceAgent.datamodelOntology import ( + ConstraintRule, + OntologyDescriptor, + SemanticType, + ValidationErrorCode, +) +from modules.serviceCenter.services.serviceAgent.featureDataAgent import ( + _buildValidatorForFeature, + _loadFeatureOntologyBlock, +) +from modules.serviceCenter.services.serviceAgent.ontologyToPromptCompiler import ( + compileOntologyToPrompt, +) +from modules.serviceCenter.services.serviceAgent.queryValidator import QueryValidator +from modules.shared import fkRegistry + + +@pytest.fixture(scope="module", autouse=True) +def _ensureModels(): + fkRegistry._ensureModelsLoaded() + + +# --------------------------------------------------------------------------- +# OntologyDescriptor structure +# --------------------------------------------------------------------------- + + +def test_trusteeOntology_returnsValidDescriptor(): + ont = getTrusteeOntology() + assert isinstance(ont, OntologyDescriptor) + assert ont.featureCode == "trustee" + assert ont.entities and ont.relations and ont.constraints and ont.canonicalPatterns + + +def test_trusteeOntology_hasBankAccountSpecialization(): + ont = getTrusteeOntology() + bank = next((e for e in ont.entities if e.name == "BankAccount"), None) + assert bank is not None + assert bank.parentEntity == "Account" + assert bank.semanticType == SemanticType.ACCOUNT + + +def test_trusteeOntology_closingBalanceIsNeverAggregate(): + ont = getTrusteeOntology() + constraints = ont.constraintsForTable("TrusteeDataAccountBalance") + matching = [ + c for c in constraints + if c.rule == ConstraintRule.NEVER_AGGREGATE + and c.appliesTo == "TrusteeDataAccountBalance.closingBalance" + ] + assert matching, "Expected NEVER_AGGREGATE constraint on closingBalance" + + +def test_trusteeOntology_requiresPeriodFilterOnBalanceTable(): + ont = getTrusteeOntology() + constraints = ont.constraintsForTable("TrusteeDataAccountBalance") + table_level = [c for c in constraints if c.rule == ConstraintRule.REQUIRES_FILTER_ON] + assert table_level, "Expected at least one REQUIRES_FILTER_ON constraint" + required = table_level[0].params.get("requiredFields") or [] + assert "periodYear" in required + assert "periodMonth" in required + + +def test_constraintsForTable_filtersScopeCorrectly(): + ont = getTrusteeOntology() + bal = ont.constraintsForTable("TrusteeDataAccountBalance") + journal = ont.constraintsForTable("TrusteeDataJournalLine") + for c in bal: + assert c.appliesTo.startswith("TrusteeDataAccountBalance") + for c in journal: + assert c.appliesTo.startswith("TrusteeDataJournalLine") + + +# --------------------------------------------------------------------------- +# Prompt compiler +# --------------------------------------------------------------------------- + + +def test_compiler_emitsExpectedHeader(): + block = compileOntologyToPrompt(getTrusteeOntology()) + assert block.startswith("DOMAIN ONTOLOGY (trustee):"), block.splitlines()[0] + + +def test_compiler_includesAllEntityNames(): + ont = getTrusteeOntology() + block = compileOntologyToPrompt(ont) + for e in ont.entities: + assert e.name in block, f"Entity {e.name} missing from compiled prompt" + + +def test_compiler_includesAllConstraintMessages(): + ont = getTrusteeOntology() + block = compileOntologyToPrompt(ont) + for c in ont.constraints: + assert c.message.split(".")[0] in block, f"Constraint message missing: {c.message[:40]}" + + +def test_compiler_includesCanonicalPatternTools(): + ont = getTrusteeOntology() + block = compileOntologyToPrompt(ont) + for p in ont.canonicalPatterns: + assert p.intent in block + assert p.pattern["tool"] in block + + +def test_compiler_deterministic(): + block1 = compileOntologyToPrompt(getTrusteeOntology()) + block2 = compileOntologyToPrompt(getTrusteeOntology()) + assert block1 == block2 + + +# --------------------------------------------------------------------------- +# QueryValidator x ontology integration +# --------------------------------------------------------------------------- + + +def test_validator_picksUpOntologyNeverAggregate(): + validator = QueryValidator(ontology=getTrusteeOntology()) + err = validator.validateAggregateQuery( + "TrusteeDataAccountBalance", + {"aggregate": "SUM", "field": "closingBalance"}, + ) + assert err is not None + assert err.code == ValidationErrorCode.INVALID_AGGREGATE_TARGET + assert err.field == "closingBalance" + + +def test_validator_ontologyConstraintFiresOnDebitTotal(): + validator = QueryValidator(ontology=getTrusteeOntology()) + err = validator.validateAggregateQuery( + "TrusteeDataAccountBalance", + {"aggregate": "SUM", "field": "debitTotal"}, + ) + assert err is not None + assert err.code == ValidationErrorCode.INVALID_AGGREGATE_TARGET + + +def test_validator_allowsLegitimateAggregateOnJournalLine(): + validator = QueryValidator(ontology=getTrusteeOntology()) + err = validator.validateAggregateQuery( + "TrusteeDataJournalLine", + {"aggregate": "SUM", "field": "debitAmount"}, + ) + assert err is None + + +# --------------------------------------------------------------------------- +# featureDataAgent integration hooks +# --------------------------------------------------------------------------- + + +def test_mainTrustee_getAgentOntology_returnsDescriptor(): + ont = getAgentOntology() + assert isinstance(ont, OntologyDescriptor) + assert ont.featureCode == "trustee" + + +def test_loadFeatureOntologyBlock_returnsCompiledBlock(): + block = _loadFeatureOntologyBlock("trustee") + assert block.startswith("DOMAIN ONTOLOGY (trustee):") + assert "BankAccount" in block + + +def test_loadFeatureOntologyBlock_unknownFeatureReturnsEmpty(): + assert _loadFeatureOntologyBlock("doesNotExist") == "" + + +def test_buildValidatorForFeature_trustee_hasOntology(): + validator = _buildValidatorForFeature("trustee") + assert validator._ontology is not None + assert validator._ontology.featureCode == "trustee" + + +def test_buildValidatorForFeature_unknownFeature_noOntology(): + validator = _buildValidatorForFeature("doesNotExist") + assert validator._ontology is None