Merge pull request #164 from valueonag/feat/demo-system-readieness
rag enhancements
This commit is contained in:
commit
f5aba4bf99
41 changed files with 4809 additions and 398 deletions
5
app.py
5
app.py
|
|
@ -404,8 +404,10 @@ async def lifespan(app: FastAPI):
|
||||||
try:
|
try:
|
||||||
from modules.serviceCenter.services.serviceBackgroundJobs.mainBackgroundJobService import (
|
from modules.serviceCenter.services.serviceBackgroundJobs.mainBackgroundJobService import (
|
||||||
recoverInterruptedJobs,
|
recoverInterruptedJobs,
|
||||||
|
registerZombieKillerScheduler,
|
||||||
)
|
)
|
||||||
recoverInterruptedJobs()
|
recoverInterruptedJobs()
|
||||||
|
registerZombieKillerScheduler(intervalMinutes=5)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"BackgroundJob recovery failed (non-critical): {e}")
|
logger.warning(f"BackgroundJob recovery failed (non-critical): {e}")
|
||||||
|
|
||||||
|
|
@ -607,6 +609,9 @@ app.include_router(connectionsRouter)
|
||||||
from modules.routes.routeRagInventory import router as ragInventoryRouter
|
from modules.routes.routeRagInventory import router as ragInventoryRouter
|
||||||
app.include_router(ragInventoryRouter)
|
app.include_router(ragInventoryRouter)
|
||||||
|
|
||||||
|
from modules.routes.routeAdminSttBenchmark import router as sttBenchmarkRouter
|
||||||
|
app.include_router(sttBenchmarkRouter)
|
||||||
|
|
||||||
from modules.routes.routeTableViews import router as tableViewsRouter
|
from modules.routes.routeTableViews import router as tableViewsRouter
|
||||||
app.include_router(tableViewsRouter)
|
app.include_router(tableViewsRouter)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -319,25 +319,24 @@ class AiOpenai(BaseConnectorAi):
|
||||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.00013
|
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.00013
|
||||||
),
|
),
|
||||||
AiModel(
|
AiModel(
|
||||||
name="dall-e-3",
|
name="gpt-image-1",
|
||||||
displayName="OpenAI DALL-E 3",
|
displayName="OpenAI GPT Image",
|
||||||
connectorType="openai",
|
connectorType="openai",
|
||||||
apiUrl="https://api.openai.com/v1/images/generations",
|
apiUrl="https://api.openai.com/v1/images/generations",
|
||||||
temperature=0.0, # Image generation doesn't use temperature
|
temperature=0.0,
|
||||||
maxTokens=0, # Image generation doesn't use tokens
|
maxTokens=0,
|
||||||
contextLength=0,
|
contextLength=0,
|
||||||
costPer1kTokensInput=0.04,
|
costPer1kTokensInput=0.04,
|
||||||
costPer1kTokensOutput=0.0,
|
costPer1kTokensOutput=0.0,
|
||||||
speedRating=5, # Slow for image generation
|
speedRating=5,
|
||||||
qualityRating=9, # High quality art generation
|
qualityRating=9,
|
||||||
# capabilities removed (not used in business logic)
|
|
||||||
functionCall=self.generateImage,
|
functionCall=self.generateImage,
|
||||||
priority=PriorityEnum.QUALITY,
|
priority=PriorityEnum.QUALITY,
|
||||||
processingMode=ProcessingModeEnum.DETAILED,
|
processingMode=ProcessingModeEnum.DETAILED,
|
||||||
operationTypes=createOperationTypeRatings(
|
operationTypes=createOperationTypeRatings(
|
||||||
(OperationTypeEnum.IMAGE_GENERATE, 10)
|
(OperationTypeEnum.IMAGE_GENERATE, 10)
|
||||||
),
|
),
|
||||||
version="dall-e-3",
|
version="gpt-image-1",
|
||||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.04
|
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: (bytesSent / 4 / 1000) * 0.04
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
@ -653,105 +652,82 @@ class AiOpenai(BaseConnectorAi):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def generateImage(self, modelCall: AiModelCall) -> AiModelResponse:
|
async def generateImage(self, modelCall: AiModelCall) -> AiModelResponse:
|
||||||
"""
|
"""Generate an image using GPT Image model (gpt-image-1)."""
|
||||||
Generate an image using DALL-E 3 using standardized pattern.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
modelCall: AiModelCall with messages and generation options
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AiModelResponse with generated image data
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# Extract parameters from modelCall
|
|
||||||
messages = modelCall.messages
|
|
||||||
model = modelCall.model
|
|
||||||
options = modelCall.options
|
|
||||||
|
|
||||||
# Get prompt from messages
|
|
||||||
promptContent = messages[0]["content"] if messages else ""
|
|
||||||
|
|
||||||
# Parse prompt using AiCallPromptImage model
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
messages = modelCall.messages
|
||||||
|
options = modelCall.options
|
||||||
|
promptContent = messages[0]["content"] if messages else ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Try to parse as JSON
|
|
||||||
promptData = json.loads(promptContent)
|
promptData = json.loads(promptContent)
|
||||||
promptModel = AiCallPromptImage(**promptData)
|
promptModel = AiCallPromptImage(**promptData)
|
||||||
except:
|
except Exception:
|
||||||
# If not JSON, use plain text prompt
|
|
||||||
promptModel = AiCallPromptImage(
|
promptModel = AiCallPromptImage(
|
||||||
prompt=promptContent,
|
prompt=promptContent,
|
||||||
size=options.size if options and hasattr(options, 'size') else "1024x1024",
|
size=options.size if options and hasattr(options, "size") else "1024x1024",
|
||||||
quality=options.quality if options and hasattr(options, 'quality') else "standard",
|
quality=options.quality if options and hasattr(options, "quality") else "auto",
|
||||||
style=options.style if options and hasattr(options, 'style') else "vivid"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract parameters from Pydantic model
|
|
||||||
prompt = promptModel.prompt
|
prompt = promptModel.prompt
|
||||||
size = promptModel.size or "1024x1024"
|
size = promptModel.size or "1024x1024"
|
||||||
quality = promptModel.quality or "standard"
|
rawQuality = promptModel.quality or "auto"
|
||||||
style = promptModel.style or "vivid"
|
quality = {"standard": "auto", "hd": "high"}.get(rawQuality, rawQuality)
|
||||||
|
|
||||||
logger.debug(f"Starting image generation with prompt: '{prompt[:100]}...'")
|
logger.debug(f"Starting image generation with prompt: '{prompt[:100]}...'")
|
||||||
|
|
||||||
# DALL-E 3 API endpoint
|
|
||||||
dalle_url = "https://api.openai.com/v1/images/generations"
|
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"model": "dall-e-3",
|
"model": "gpt-image-1",
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"size": size,
|
"size": size,
|
||||||
"quality": quality,
|
"quality": quality,
|
||||||
"style": style,
|
|
||||||
"n": 1,
|
"n": 1,
|
||||||
"response_format": "b64_json" # Get base64 data directly instead of URLs
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Use existing httpClient to benefit from connection pooling
|
|
||||||
# This avoids TLS connection issues that can occur with fresh clients
|
|
||||||
response = await self.httpClient.post(
|
response = await self.httpClient.post(
|
||||||
dalle_url,
|
"https://api.openai.com/v1/images/generations",
|
||||||
json=payload
|
json=payload,
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
logger.error(f"DALL-E API error: {response.status_code} - {response.text}")
|
logger.error(f"Image generation API error: {response.status_code} - {response.text}")
|
||||||
return AiModelResponse(
|
return AiModelResponse(
|
||||||
content="",
|
content="",
|
||||||
success=False,
|
success=False,
|
||||||
error=f"DALL-E API error: {response.status_code} - {response.text}"
|
error=f"Image generation API error: {response.status_code} - {response.text}",
|
||||||
)
|
)
|
||||||
|
|
||||||
responseJson = response.json()
|
responseJson = response.json()
|
||||||
|
|
||||||
if "data" in responseJson and len(responseJson["data"]) > 0:
|
if "data" in responseJson and len(responseJson["data"]) > 0:
|
||||||
image_data = responseJson["data"][0]["b64_json"]
|
imageData = responseJson["data"][0].get("b64_json", "")
|
||||||
|
if not imageData:
|
||||||
logger.info(f"Successfully generated image: {len(image_data)} characters")
|
imageData = responseJson["data"][0].get("url", "")
|
||||||
|
|
||||||
|
logger.info(f"Successfully generated image: {len(imageData)} characters")
|
||||||
return AiModelResponse(
|
return AiModelResponse(
|
||||||
content=image_data,
|
content=imageData,
|
||||||
success=True,
|
success=True,
|
||||||
modelId="dall-e-3",
|
modelId="gpt-image-1",
|
||||||
metadata={
|
metadata={
|
||||||
"size": size,
|
"size": size,
|
||||||
"quality": quality,
|
"quality": quality,
|
||||||
"style": style,
|
"response_id": responseJson.get("id", ""),
|
||||||
"response_id": responseJson.get("id", "")
|
},
|
||||||
}
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.error("No image data in DALL-E response")
|
logger.error("No image data in generation response")
|
||||||
return AiModelResponse(
|
return AiModelResponse(
|
||||||
content="",
|
content="",
|
||||||
success=False,
|
success=False,
|
||||||
error="No image data in DALL-E response"
|
error="No image data in generation response",
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error during image generation: {str(e)}", exc_info=True)
|
logger.error(f"Error during image generation: {str(e)}", exc_info=True)
|
||||||
return AiModelResponse(
|
return AiModelResponse(
|
||||||
content="",
|
content="",
|
||||||
success=False,
|
success=False,
|
||||||
error=f"Error during image generation: {str(e)}"
|
error=f"Error during image generation: {str(e)}",
|
||||||
)
|
)
|
||||||
|
|
@ -311,7 +311,10 @@ class DatabaseConnector:
|
||||||
# Establish connection to the database
|
# Establish connection to the database
|
||||||
self._connect()
|
self._connect()
|
||||||
|
|
||||||
logger.info("PostgreSQL database system initialized successfully")
|
logger.debug(
|
||||||
|
"PostgreSQL database system initialized (db=%s, host=%s, port=%s)",
|
||||||
|
self.dbDatabase, self.dbHost, self.dbPort,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"FATAL ERROR: Database system initialization failed: {e}")
|
logger.error(f"FATAL ERROR: Database system initialization failed: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
|
||||||
|
|
@ -245,11 +245,10 @@ class AiCallPromptWebCrawl(BaseModel):
|
||||||
|
|
||||||
class AiCallPromptImage(BaseModel):
|
class AiCallPromptImage(BaseModel):
|
||||||
"""Structured prompt format for image generation."""
|
"""Structured prompt format for image generation."""
|
||||||
|
|
||||||
prompt: str = Field(description="Text description of the image to generate")
|
prompt: str = Field(description="Text description of the image to generate")
|
||||||
size: Optional[str] = Field(default="1024x1024", description="Image size (1024x1024, 1792x1024, 1024x1792)")
|
size: Optional[str] = Field(default="1024x1024", description="Image size (1024x1024, 1536x1024, 1024x1536)")
|
||||||
quality: Optional[str] = Field(default="standard", description="Image quality (standard, hd)")
|
quality: Optional[str] = Field(default="auto", description="Image quality (auto, high, medium, low)")
|
||||||
style: Optional[str] = Field(default="vivid", description="Image style (vivid, natural)")
|
|
||||||
|
|
||||||
|
|
||||||
class AiProcessParameters(BaseModel):
|
class AiProcessParameters(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -754,14 +754,35 @@ ANTI-PATTERNS (do NOT do this):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# Parked for one release as a fallback while the ontology-based path rolls
|
||||||
|
# out (see `trusteeOntology.getTrusteeOntology()`). Remove together with the
|
||||||
|
# legacy ``_loadFeatureDomainHints`` path once Phase 2 is the only supplier
|
||||||
|
# of the trustee prompt block.
|
||||||
|
_AGENT_DOMAIN_HINTS_LEGACY = _AGENT_DOMAIN_HINTS
|
||||||
|
|
||||||
|
|
||||||
def getAgentDomainHints() -> str:
|
def getAgentDomainHints() -> str:
|
||||||
"""Return Trustee-specific guidance for the Feature Data Sub-Agent.
|
"""Return Trustee-specific guidance for the Feature Data Sub-Agent.
|
||||||
|
|
||||||
The text is appended verbatim to the sub-agent's system prompt by
|
Deprecated as of Phase 2 (2026-05). Prefer ``getAgentOntology()`` ->
|
||||||
``featureDataAgent._buildSchemaContext``. Keep it concise and
|
``ontologyToPromptCompiler.compileOntologyToPrompt(...)``. The legacy
|
||||||
pattern-driven — every line costs tokens on every sub-agent call.
|
text remains available so callers that still go through
|
||||||
|
``_buildSchemaContext()`` keep working during the migration window.
|
||||||
"""
|
"""
|
||||||
return _AGENT_DOMAIN_HINTS
|
return _AGENT_DOMAIN_HINTS_LEGACY
|
||||||
|
|
||||||
|
|
||||||
|
def getAgentOntology():
|
||||||
|
"""Return the structured ontology used by the Feature Data Sub-Agent.
|
||||||
|
|
||||||
|
Discovered by ``featureDataAgent._buildSchemaContext`` (Phase 2 path):
|
||||||
|
when this hook is present, the agent compiles its domain block from
|
||||||
|
the ontology instead of using the legacy free-text hints. The same
|
||||||
|
descriptor feeds the validator's NEVER_AGGREGATE constraints, so
|
||||||
|
prompt and validator stay in sync.
|
||||||
|
"""
|
||||||
|
from modules.features.trustee.trusteeOntology import getTrusteeOntology
|
||||||
|
return getTrusteeOntology()
|
||||||
|
|
||||||
|
|
||||||
def registerFeature(catalogService) -> bool:
|
def registerFeature(catalogService) -> bool:
|
||||||
|
|
|
||||||
295
modules/features/trustee/trusteeOntology.py
Normal file
295
modules/features/trustee/trusteeOntology.py
Normal file
|
|
@ -0,0 +1,295 @@
|
||||||
|
# Copyright (c) 2026 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Trustee feature ontology (Phase 2 pilot).
|
||||||
|
|
||||||
|
Replaces the hand-written ``_AGENT_DOMAIN_HINTS`` block with a structured
|
||||||
|
ontology so the Feature Data Sub-Agent's QueryValidator AND the prompt
|
||||||
|
compiler share the same source of truth: account-group conventions,
|
||||||
|
period-bucket semantics, the NEVER_AGGREGATE constraints on already-
|
||||||
|
aggregated columns, and canonical tool-call templates for the most
|
||||||
|
frequent user intents.
|
||||||
|
|
||||||
|
Both the validator (deterministic enforcement) and the prompt compiler
|
||||||
|
(LLM steering) read from this descriptor, so an LLM that follows the
|
||||||
|
prompt patterns will never trigger a validator failure -- and one that
|
||||||
|
ignores them gets a structured repair hint pointing back at the same
|
||||||
|
constraint.
|
||||||
|
|
||||||
|
The legacy ``_AGENT_DOMAIN_HINTS_LEGACY`` block stays parked in
|
||||||
|
``mainTrustee.py`` for one release as a fallback during rollout.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from modules.serviceCenter.services.serviceAgent.datamodelOntology import (
|
||||||
|
CanonicalQueryPattern,
|
||||||
|
Cardinality,
|
||||||
|
Constraint,
|
||||||
|
ConstraintRule,
|
||||||
|
Entity,
|
||||||
|
Invariant,
|
||||||
|
OntologyDescriptor,
|
||||||
|
Relation,
|
||||||
|
SemanticType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Entities
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_ENTITIES = [
|
||||||
|
Entity(
|
||||||
|
name="Account",
|
||||||
|
pythonClass="TrusteeDataAccount",
|
||||||
|
semanticType=SemanticType.ACCOUNT,
|
||||||
|
description=(
|
||||||
|
"Chart-of-accounts row (Konto). One row per accountNumber per "
|
||||||
|
"mandate. Identifies the account, never holds balances."
|
||||||
|
),
|
||||||
|
invariants=[
|
||||||
|
Invariant(description="accountNumber is a stable string identifier (e.g. '1020', '5400')."),
|
||||||
|
Invariant(description="accountType is one of: asset / liability / revenue / expense."),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
Entity(
|
||||||
|
name="BankAccount",
|
||||||
|
pythonClass="TrusteeDataAccount",
|
||||||
|
semanticType=SemanticType.ACCOUNT,
|
||||||
|
parentEntity="Account",
|
||||||
|
description="Account subgroup with accountNumber LIKE '102%' (ZKB, PostFinance, UBS, ...).",
|
||||||
|
),
|
||||||
|
Entity(
|
||||||
|
name="CashAccount",
|
||||||
|
pythonClass="TrusteeDataAccount",
|
||||||
|
semanticType=SemanticType.ACCOUNT,
|
||||||
|
parentEntity="Account",
|
||||||
|
description="Account subgroup with accountNumber LIKE '100%' (Hauptkasse, Nebenkassen).",
|
||||||
|
),
|
||||||
|
Entity(
|
||||||
|
name="AccountBalance",
|
||||||
|
pythonClass="TrusteeDataAccountBalance",
|
||||||
|
semanticType=SemanticType.BALANCE_SNAPSHOT,
|
||||||
|
description=(
|
||||||
|
"Period-bucketed snapshot: one row per (account, year, month). "
|
||||||
|
"closingBalance is THE balance at end of period -- already aggregated."
|
||||||
|
),
|
||||||
|
invariants=[
|
||||||
|
Invariant(description="periodMonth=0 means annual total of periodYear (use for 'per 31.12.YYYY')."),
|
||||||
|
Invariant(description="periodMonth in 1..12 means month-end snapshot."),
|
||||||
|
Invariant(description="closingBalance is the balance at period end; openingBalance at period start."),
|
||||||
|
Invariant(description="debitTotal/creditTotal are turnovers for the period, NOT balances."),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
Entity(
|
||||||
|
name="JournalEntry",
|
||||||
|
pythonClass="TrusteeDataJournalEntry",
|
||||||
|
semanticType=SemanticType.TRANSACTION,
|
||||||
|
description="One booking header (Beleg). Has a bookingDate (unix seconds float) and totalAmount.",
|
||||||
|
invariants=[
|
||||||
|
Invariant(description="bookingDate is a UTC unix-seconds float; never compare against ISO strings."),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
Entity(
|
||||||
|
name="JournalLine",
|
||||||
|
pythonClass="TrusteeDataJournalLine",
|
||||||
|
semanticType=SemanticType.TRANSACTION,
|
||||||
|
description="One booking line of a JournalEntry. Each line debits or credits exactly one account.",
|
||||||
|
invariants=[
|
||||||
|
Invariant(description="Per line either debitAmount > 0 (Soll) or creditAmount > 0 (Haben), not both."),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Relations
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_RELATIONS = [
|
||||||
|
Relation(fromEntity="AccountBalance", toEntity="Account", cardinality=Cardinality.MANY_TO_ONE, via="accountNumber"),
|
||||||
|
Relation(fromEntity="JournalLine", toEntity="JournalEntry", cardinality=Cardinality.MANY_TO_ONE, via="journalEntryId"),
|
||||||
|
Relation(fromEntity="JournalLine", toEntity="Account", cardinality=Cardinality.MANY_TO_ONE, via="accountNumber"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Constraints (validator-enforced)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_CONSTRAINTS = [
|
||||||
|
# closingBalance is the single biggest hallucination magnet -- it's a
|
||||||
|
# balance per period, summing it across periods or accounts is meaningless.
|
||||||
|
Constraint(
|
||||||
|
appliesTo="TrusteeDataAccountBalance.closingBalance",
|
||||||
|
rule=ConstraintRule.NEVER_AGGREGATE,
|
||||||
|
message=(
|
||||||
|
"closingBalance is per-period already; query with periodYear+periodMonth, never SUM/AVG it."
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Constraint(
|
||||||
|
appliesTo="TrusteeDataAccountBalance.openingBalance",
|
||||||
|
rule=ConstraintRule.NEVER_AGGREGATE,
|
||||||
|
message="openingBalance is already a balance per period; do not SUM/AVG it across rows.",
|
||||||
|
),
|
||||||
|
Constraint(
|
||||||
|
appliesTo="TrusteeDataAccountBalance.debitTotal",
|
||||||
|
rule=ConstraintRule.NEVER_AGGREGATE,
|
||||||
|
message=(
|
||||||
|
"debitTotal is the period's debit TURNOVER; do not SUM it without an explicit period filter."
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Constraint(
|
||||||
|
appliesTo="TrusteeDataAccountBalance.creditTotal",
|
||||||
|
rule=ConstraintRule.NEVER_AGGREGATE,
|
||||||
|
message="creditTotal is a per-period turnover; do not SUM it across periods without an explicit period filter.",
|
||||||
|
),
|
||||||
|
# AccountBalance queries without a period filter are almost always wrong --
|
||||||
|
# they conflate annual and monthly snapshots. Phase 2 (REQUIRES_FILTER_ON)
|
||||||
|
# is wired through to the validator in a later iteration; for now this
|
||||||
|
# rule is rendered into the prompt compiler so the LLM sees it explicitly.
|
||||||
|
Constraint(
|
||||||
|
appliesTo="TrusteeDataAccountBalance",
|
||||||
|
rule=ConstraintRule.REQUIRES_FILTER_ON,
|
||||||
|
message=(
|
||||||
|
"Always filter on periodYear AND periodMonth (use periodMonth=0 for end-of-year)."
|
||||||
|
),
|
||||||
|
params={"requiredFields": ["periodYear", "periodMonth"]},
|
||||||
|
),
|
||||||
|
Constraint(
|
||||||
|
appliesTo="TrusteeDataAccountBalance",
|
||||||
|
rule=ConstraintRule.PREFERRED_TABLE_FOR_INTENT,
|
||||||
|
message="For 'Saldo per <date>' and 'Stand <year>' questions, prefer AccountBalance over JournalLine.",
|
||||||
|
params={"intents": ["BANK_BALANCE_AT_DATE", "BALANCE_AT_YEAR_END"]},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Canonical query patterns (worked examples for the LLM)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_CANONICAL_PATTERNS = [
|
||||||
|
CanonicalQueryPattern(
|
||||||
|
intent="BANK_BALANCE_AT_DATE",
|
||||||
|
description="Saldo eines Bankkontos per Jahresende.",
|
||||||
|
pattern={
|
||||||
|
"tool": "queryTable",
|
||||||
|
"tableName": "TrusteeDataAccountBalance",
|
||||||
|
"filters": [
|
||||||
|
{"field": "accountNumber", "op": "=", "value": "<accountNumber>"},
|
||||||
|
{"field": "periodYear", "op": "=", "value": "<year>"},
|
||||||
|
{"field": "periodMonth", "op": "=", "value": 0},
|
||||||
|
],
|
||||||
|
"fields": ["closingBalance", "currency"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
CanonicalQueryPattern(
|
||||||
|
intent="BANK_GROUP_TOTAL_AT_DATE",
|
||||||
|
description="Summe einer Kontogruppe (z. B. alle Bankkonten 102%) per Jahresende.",
|
||||||
|
pattern={
|
||||||
|
"tool": "queryTable",
|
||||||
|
"tableName": "TrusteeDataAccountBalance",
|
||||||
|
"filters": [
|
||||||
|
{"field": "accountNumber", "op": "LIKE", "value": "<prefix>%"},
|
||||||
|
{"field": "periodYear", "op": "=", "value": "<year>"},
|
||||||
|
{"field": "periodMonth", "op": "=", "value": 0},
|
||||||
|
],
|
||||||
|
"fields": ["accountNumber", "closingBalance", "currency"],
|
||||||
|
"_postProcessing": "Sum closingBalance values in your final answer; do NOT SUM via aggregateTable.",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
CanonicalQueryPattern(
|
||||||
|
intent="BALANCE_HISTORY_PER_YEAR",
|
||||||
|
description="Saldo-Verlauf eines Kontos ueber mehrere Jahre.",
|
||||||
|
pattern={
|
||||||
|
"tool": "queryTable",
|
||||||
|
"tableName": "TrusteeDataAccountBalance",
|
||||||
|
"filters": [
|
||||||
|
{"field": "accountNumber", "op": "=", "value": "<accountNumber>"},
|
||||||
|
{"field": "periodMonth", "op": "=", "value": 0},
|
||||||
|
],
|
||||||
|
"fields": ["periodYear", "closingBalance", "currency"],
|
||||||
|
"orderBy": "periodYear",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
CanonicalQueryPattern(
|
||||||
|
intent="MONTHLY_BALANCE_SNAPSHOT",
|
||||||
|
description="Saldo per Ende eines bestimmten Monats.",
|
||||||
|
pattern={
|
||||||
|
"tool": "queryTable",
|
||||||
|
"tableName": "TrusteeDataAccountBalance",
|
||||||
|
"filters": [
|
||||||
|
{"field": "accountNumber", "op": "=", "value": "<accountNumber>"},
|
||||||
|
{"field": "periodYear", "op": "=", "value": "<year>"},
|
||||||
|
{"field": "periodMonth", "op": "=", "value": "<month 1..12>"},
|
||||||
|
],
|
||||||
|
"fields": ["closingBalance", "currency"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
CanonicalQueryPattern(
|
||||||
|
intent="ACCOUNT_LIST_BY_TYPE_OR_PREFIX",
|
||||||
|
description="Welche Konten gehoeren zu einer Gruppe (Typ oder Nummern-Prefix)?",
|
||||||
|
pattern={
|
||||||
|
"tool": "queryTable",
|
||||||
|
"tableName": "TrusteeDataAccount",
|
||||||
|
"filters": [
|
||||||
|
{"field": "accountNumber", "op": "LIKE", "value": "<prefix>%"},
|
||||||
|
],
|
||||||
|
"fields": ["accountNumber", "label", "accountType"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
CanonicalQueryPattern(
|
||||||
|
intent="JOURNAL_SUM_AT_ACCOUNT",
|
||||||
|
description="Summe der Soll- oder Haben-Buchungen auf einem Konto.",
|
||||||
|
pattern={
|
||||||
|
"tool": "aggregateTable",
|
||||||
|
"tableName": "TrusteeDataJournalLine",
|
||||||
|
"aggregate": "SUM",
|
||||||
|
"field": "debitAmount",
|
||||||
|
"filters": [
|
||||||
|
{"field": "accountNumber", "op": "=", "value": "<accountNumber>"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
CanonicalQueryPattern(
|
||||||
|
intent="COUNT_ROWS",
|
||||||
|
description="Anzahl Buchungen / Buchungszeilen / Konten.",
|
||||||
|
pattern={
|
||||||
|
"tool": "aggregateTable",
|
||||||
|
"tableName": "<table>",
|
||||||
|
"aggregate": "COUNT",
|
||||||
|
"field": "id",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
CanonicalQueryPattern(
|
||||||
|
intent="JOURNAL_LINES_BY_AMOUNT",
|
||||||
|
description="Buchungszeilen mit einem Betrag groesser/kleiner als einer Schwelle.",
|
||||||
|
pattern={
|
||||||
|
"tool": "queryTable",
|
||||||
|
"tableName": "TrusteeDataJournalLine",
|
||||||
|
"filters": [
|
||||||
|
{"field": "debitAmount", "op": ">", "value": "<amount>"},
|
||||||
|
],
|
||||||
|
"fields": ["accountNumber", "debitAmount", "description"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
_TRUSTEE_ONTOLOGY = OntologyDescriptor(
|
||||||
|
featureCode="trustee",
|
||||||
|
entities=_ENTITIES,
|
||||||
|
relations=_RELATIONS,
|
||||||
|
constraints=_CONSTRAINTS,
|
||||||
|
canonicalPatterns=_CANONICAL_PATTERNS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def getTrusteeOntology() -> OntologyDescriptor:
|
||||||
|
"""Public accessor for the trustee ontology.
|
||||||
|
|
||||||
|
Cached as a module-level singleton -- the descriptor is immutable and
|
||||||
|
has no per-call state.
|
||||||
|
"""
|
||||||
|
return _TRUSTEE_ONTOLOGY
|
||||||
217
modules/routes/routeAdminSttBenchmark.py
Normal file
217
modules/routes/routeAdminSttBenchmark.py
Normal file
|
|
@ -0,0 +1,217 @@
|
||||||
|
# Copyright (c) 2025 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""STT Benchmark route — compare Speech-to-Text v1 (latest_long) vs v2 (Chirp 2).
|
||||||
|
|
||||||
|
Sysadmin-only page for evaluating STT model quality and latency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Depends, Request, UploadFile, File, Form
|
||||||
|
from modules.auth import limiter, getCurrentUser
|
||||||
|
from modules.datamodels.datamodelUam import User
|
||||||
|
from modules.shared.configuration import APP_CONFIG
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/api/admin/stt-benchmark",
|
||||||
|
tags=["Admin STT Benchmark"],
|
||||||
|
responses={401: {"description": "Unauthorized"}, 403: {"description": "Forbidden"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _requireSysAdmin(currentUser: User = Depends(getCurrentUser)) -> User:
|
||||||
|
if not getattr(currentUser, "isSysAdmin", False) and not getattr(currentUser, "isPlatformAdmin", False):
|
||||||
|
raise HTTPException(status_code=403, detail="SysAdmin required")
|
||||||
|
return currentUser
|
||||||
|
|
||||||
|
|
||||||
|
def _getCredentials():
|
||||||
|
apiKey = APP_CONFIG.get("Connector_GoogleSpeech_API_KEY_SECRET")
|
||||||
|
if not apiKey or apiKey.startswith("YOUR_"):
|
||||||
|
raise HTTPException(status_code=500, detail="Google Speech API key not configured")
|
||||||
|
from google.oauth2 import service_account
|
||||||
|
return service_account.Credentials.from_service_account_info(json.loads(apiKey))
|
||||||
|
|
||||||
|
|
||||||
|
def _runV1(audioBytes: bytes, language: str, model: str) -> Dict[str, Any]:
|
||||||
|
"""Run Speech-to-Text v1 recognition."""
|
||||||
|
from google.cloud import speech
|
||||||
|
credentials = _getCredentials()
|
||||||
|
client = speech.SpeechClient(credentials=credentials)
|
||||||
|
|
||||||
|
config = speech.RecognitionConfig(
|
||||||
|
encoding=speech.RecognitionConfig.AudioEncoding.ENCODING_UNSPECIFIED,
|
||||||
|
language_code=language,
|
||||||
|
model=model,
|
||||||
|
enable_automatic_punctuation=True,
|
||||||
|
enable_word_time_offsets=True,
|
||||||
|
enable_word_confidence=True,
|
||||||
|
max_alternatives=3,
|
||||||
|
use_enhanced=True,
|
||||||
|
)
|
||||||
|
audio = speech.RecognitionAudio(content=audioBytes)
|
||||||
|
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
response = client.recognize(config=config, audio=audio)
|
||||||
|
elapsed = time.perf_counter() - t0
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for r in response.results:
|
||||||
|
for alt in r.alternatives:
|
||||||
|
results.append({
|
||||||
|
"transcript": alt.transcript,
|
||||||
|
"confidence": round(alt.confidence, 4),
|
||||||
|
"words": len(alt.words) if alt.words else 0,
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"api": "v1",
|
||||||
|
"model": model,
|
||||||
|
"latencyMs": round(elapsed * 1000, 1),
|
||||||
|
"results": results,
|
||||||
|
"resultCount": len(response.results),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _runV2(audioBytes: bytes, language: str, model: str, location: str) -> Dict[str, Any]:
|
||||||
|
"""Run Speech-to-Text v2 recognition (Chirp 2)."""
|
||||||
|
from google.cloud.speech_v2 import SpeechClient
|
||||||
|
from google.cloud.speech_v2.types import cloud_speech
|
||||||
|
|
||||||
|
credentials = _getCredentials()
|
||||||
|
credInfo = json.loads(APP_CONFIG.get("Connector_GoogleSpeech_API_KEY_SECRET"))
|
||||||
|
projectId = credInfo.get("project_id", "")
|
||||||
|
|
||||||
|
client = SpeechClient(
|
||||||
|
credentials=credentials,
|
||||||
|
client_options={"api_endpoint": f"{location}-speech.googleapis.com"},
|
||||||
|
)
|
||||||
|
|
||||||
|
config = cloud_speech.RecognitionConfig(
|
||||||
|
auto_decoding_config=cloud_speech.AutoDetectDecodingConfig(),
|
||||||
|
language_codes=[language],
|
||||||
|
model=model,
|
||||||
|
features=cloud_speech.RecognitionFeatures(
|
||||||
|
enable_automatic_punctuation=True,
|
||||||
|
enable_word_time_offsets=True,
|
||||||
|
enable_word_confidence=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
recognizer = f"projects/{projectId}/locations/{location}/recognizers/_"
|
||||||
|
|
||||||
|
request = cloud_speech.RecognizeRequest(
|
||||||
|
recognizer=recognizer,
|
||||||
|
config=config,
|
||||||
|
content=audioBytes,
|
||||||
|
)
|
||||||
|
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
response = client.recognize(request=request)
|
||||||
|
elapsed = time.perf_counter() - t0
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for r in response.results:
|
||||||
|
for alt in r.alternatives:
|
||||||
|
results.append({
|
||||||
|
"transcript": alt.transcript,
|
||||||
|
"confidence": round(alt.confidence, 4),
|
||||||
|
"words": len(alt.words) if alt.words else 0,
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"api": "v2",
|
||||||
|
"model": model,
|
||||||
|
"location": location,
|
||||||
|
"latencyMs": round(elapsed * 1000, 1),
|
||||||
|
"results": results,
|
||||||
|
"resultCount": len(getattr(response, "results", [])),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/run")
|
||||||
|
@limiter.limit("10/minute")
|
||||||
|
async def runBenchmark(
|
||||||
|
request: Request,
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
language: str = Form(default="de-DE"),
|
||||||
|
v1Model: str = Form(default="latest_long"),
|
||||||
|
v2Model: str = Form(default="chirp_2"),
|
||||||
|
v2Location: str = Form(default="europe-west4"),
|
||||||
|
currentUser: User = Depends(_requireSysAdmin),
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Upload audio and compare v1 vs v2 STT results."""
|
||||||
|
audioBytes = await file.read()
|
||||||
|
if len(audioBytes) > 10 * 1024 * 1024:
|
||||||
|
raise HTTPException(status_code=400, detail="Audio file too large (max 10 MB)")
|
||||||
|
if len(audioBytes) < 100:
|
||||||
|
raise HTTPException(status_code=400, detail="Audio file too small")
|
||||||
|
|
||||||
|
logger.info("STT benchmark: %s, %d bytes, language=%s, v1=%s, v2=%s@%s",
|
||||||
|
file.filename, len(audioBytes), language, v1Model, v2Model, v2Location)
|
||||||
|
|
||||||
|
v1Result = None
|
||||||
|
v1Error = None
|
||||||
|
try:
|
||||||
|
v1Result = _runV1(audioBytes, language, v1Model)
|
||||||
|
except Exception as e:
|
||||||
|
v1Error = str(e)
|
||||||
|
logger.warning("STT v1 benchmark failed: %s", e)
|
||||||
|
|
||||||
|
v2Result = None
|
||||||
|
v2Error = None
|
||||||
|
try:
|
||||||
|
v2Result = _runV2(audioBytes, language, v2Model, v2Location)
|
||||||
|
except Exception as e:
|
||||||
|
v2Error = str(e)
|
||||||
|
logger.warning("STT v2 benchmark failed: %s", e)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"filename": file.filename,
|
||||||
|
"fileSizeBytes": len(audioBytes),
|
||||||
|
"language": language,
|
||||||
|
"v1": v1Result or {"error": v1Error},
|
||||||
|
"v2": v2Result or {"error": v2Error},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/models")
|
||||||
|
@limiter.limit("30/minute")
|
||||||
|
async def getAvailableModels(
|
||||||
|
request: Request,
|
||||||
|
currentUser: User = Depends(_requireSysAdmin),
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Return available STT models for the benchmark UI."""
|
||||||
|
return {
|
||||||
|
"v1Models": [
|
||||||
|
{"value": "latest_long", "label": "latest_long (default)"},
|
||||||
|
{"value": "latest_short", "label": "latest_short"},
|
||||||
|
{"value": "phone_call", "label": "phone_call"},
|
||||||
|
{"value": "video", "label": "video"},
|
||||||
|
{"value": "command_and_search", "label": "command_and_search"},
|
||||||
|
],
|
||||||
|
"v2Models": [
|
||||||
|
{"value": "chirp_2", "label": "Chirp 2 (recommended)"},
|
||||||
|
{"value": "chirp", "label": "Chirp (original)"},
|
||||||
|
{"value": "long", "label": "long"},
|
||||||
|
{"value": "short", "label": "short"},
|
||||||
|
],
|
||||||
|
"locations": [
|
||||||
|
{"value": "europe-west4", "label": "Europe West (NL)"},
|
||||||
|
{"value": "us-central1", "label": "US Central"},
|
||||||
|
{"value": "asia-southeast1", "label": "Asia Southeast"},
|
||||||
|
],
|
||||||
|
"languages": [
|
||||||
|
{"value": "de-DE", "label": "Deutsch (DE)"},
|
||||||
|
{"value": "de-CH", "label": "Deutsch (CH)"},
|
||||||
|
{"value": "en-US", "label": "English (US)"},
|
||||||
|
{"value": "en-GB", "label": "English (GB)"},
|
||||||
|
{"value": "fr-FR", "label": "Francais (FR)"},
|
||||||
|
{"value": "it-IT", "label": "Italiano (IT)"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
@ -745,7 +745,7 @@ def _findOwnConnection(interface, userId: str, connectionId: str):
|
||||||
|
|
||||||
@router.patch("/{connectionId}/knowledge-consent")
|
@router.patch("/{connectionId}/knowledge-consent")
|
||||||
@limiter.limit("10/minute")
|
@limiter.limit("10/minute")
|
||||||
def _updateKnowledgeConsent(
|
async def _updateKnowledgeConsent(
|
||||||
request: Request,
|
request: Request,
|
||||||
connectionId: str = Path(..., description="Connection ID"),
|
connectionId: str = Path(..., description="Connection ID"),
|
||||||
enabled: bool = Body(..., embed=True),
|
enabled: bool = Body(..., embed=True),
|
||||||
|
|
@ -780,24 +780,13 @@ def _updateKnowledgeConsent(
|
||||||
from modules.datamodels.datamodelDataSource import DataSource
|
from modules.datamodels.datamodelDataSource import DataSource
|
||||||
dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId, "ragIndexEnabled": True})
|
dataSources = rootIf.db.getRecordset(DataSource, recordFilter={"connectionId": connectionId, "ragIndexEnabled": True})
|
||||||
if dataSources:
|
if dataSources:
|
||||||
import asyncio
|
|
||||||
from modules.serviceCenter.services.serviceBackgroundJobs import startJob
|
from modules.serviceCenter.services.serviceBackgroundJobs import startJob
|
||||||
authority = connection.authority.value if hasattr(connection.authority, "value") else str(connection.authority or "")
|
authority = connection.authority.value if hasattr(connection.authority, "value") else str(connection.authority or "")
|
||||||
|
await startJob(
|
||||||
async def _enqueue():
|
"connection.bootstrap",
|
||||||
await startJob(
|
{"connectionId": connectionId, "authority": authority.lower()},
|
||||||
"connection.bootstrap",
|
triggeredBy=str(currentUser.id),
|
||||||
{"connectionId": connectionId, "authority": authority.lower()},
|
)
|
||||||
triggeredBy=str(currentUser.id),
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
if loop.is_running():
|
|
||||||
loop.create_task(_enqueue())
|
|
||||||
else:
|
|
||||||
loop.run_until_complete(_enqueue())
|
|
||||||
except RuntimeError:
|
|
||||||
asyncio.run(_enqueue())
|
|
||||||
bootstrapEnqueued = True
|
bootstrapEnqueued = True
|
||||||
|
|
||||||
import json as _json
|
import json as _json
|
||||||
|
|
|
||||||
|
|
@ -129,7 +129,7 @@ def _updateNeutralizeFields(
|
||||||
|
|
||||||
@router.patch("/{sourceId}/rag-index")
|
@router.patch("/{sourceId}/rag-index")
|
||||||
@limiter.limit("30/minute")
|
@limiter.limit("30/minute")
|
||||||
def _updateDataSourceRagIndex(
|
async def _updateDataSourceRagIndex(
|
||||||
request: Request,
|
request: Request,
|
||||||
sourceId: str = Path(..., description="ID of the DataSource"),
|
sourceId: str = Path(..., description="ID of the DataSource"),
|
||||||
ragIndexEnabled: bool = Body(..., embed=True),
|
ragIndexEnabled: bool = Body(..., embed=True),
|
||||||
|
|
@ -139,6 +139,10 @@ def _updateDataSourceRagIndex(
|
||||||
|
|
||||||
true: sets flag + enqueues mini-bootstrap for this DataSource only.
|
true: sets flag + enqueues mini-bootstrap for this DataSource only.
|
||||||
false: sets flag + synchronously purges all chunks from this DataSource.
|
false: sets flag + synchronously purges all chunks from this DataSource.
|
||||||
|
|
||||||
|
Must be `async def` so `await startJob(...)` registers `_runJob` in the
|
||||||
|
main event loop. Sync route → worker thread → temporary loop closes
|
||||||
|
before the task runs → job stays stuck forever.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||||
|
|
@ -152,7 +156,6 @@ def _updateDataSourceRagIndex(
|
||||||
|
|
||||||
if ragIndexEnabled:
|
if ragIndexEnabled:
|
||||||
from modules.serviceCenter.services.serviceBackgroundJobs import startJob
|
from modules.serviceCenter.services.serviceBackgroundJobs import startJob
|
||||||
import asyncio
|
|
||||||
|
|
||||||
connectionId = rec.get("connectionId") or rec.get("connection_id") or ""
|
connectionId = rec.get("connectionId") or rec.get("connection_id") or ""
|
||||||
conn = rootIf.getUserConnectionById(connectionId) if connectionId else None
|
conn = rootIf.getUserConnectionById(connectionId) if connectionId else None
|
||||||
|
|
@ -160,20 +163,11 @@ def _updateDataSourceRagIndex(
|
||||||
if conn:
|
if conn:
|
||||||
authority = conn.authority.value if hasattr(conn.authority, "value") else str(conn.authority or "")
|
authority = conn.authority.value if hasattr(conn.authority, "value") else str(conn.authority or "")
|
||||||
|
|
||||||
async def _enqueue():
|
await startJob(
|
||||||
await startJob(
|
"connection.bootstrap",
|
||||||
"connection.bootstrap",
|
{"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": [sourceId]},
|
||||||
{"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": [sourceId]},
|
triggeredBy=str(context.user.id),
|
||||||
triggeredBy=str(context.user.id),
|
)
|
||||||
)
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
if loop.is_running():
|
|
||||||
loop.create_task(_enqueue())
|
|
||||||
else:
|
|
||||||
loop.run_until_complete(_enqueue())
|
|
||||||
except RuntimeError:
|
|
||||||
asyncio.run(_enqueue())
|
|
||||||
else:
|
else:
|
||||||
from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
|
from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
|
||||||
purgeResult = getKnowledgeInterface(None).deleteFileContentIndexByDataSource(sourceId)
|
purgeResult = getKnowledgeInterface(None).deleteFileContentIndexByDataSource(sourceId)
|
||||||
|
|
|
||||||
|
|
@ -39,20 +39,27 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
|
||||||
chunksByDs: Dict[str, int] = {}
|
chunksByDs: Dict[str, int] = {}
|
||||||
unassigned = 0
|
unassigned = 0
|
||||||
for idx in connIndexRows:
|
for idx in connIndexRows:
|
||||||
prov = (idx.get("provenance") if isinstance(idx, dict) else getattr(idx, "provenance", None)) or {}
|
struct = (idx.get("structure") if isinstance(idx, dict) else getattr(idx, "structure", None)) or {}
|
||||||
|
ingestion = struct.get("_ingestion") or {} if isinstance(struct, dict) else {}
|
||||||
|
prov = ingestion.get("provenance") or {} if isinstance(ingestion, dict) else {}
|
||||||
dsIdRef = prov.get("dataSourceId", "") if isinstance(prov, dict) else ""
|
dsIdRef = prov.get("dataSourceId", "") if isinstance(prov, dict) else ""
|
||||||
if dsIdRef:
|
if dsIdRef:
|
||||||
chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + 1
|
chunksByDs[dsIdRef] = chunksByDs.get(dsIdRef, 0) + 1
|
||||||
else:
|
else:
|
||||||
unassigned += 1
|
unassigned += 1
|
||||||
|
|
||||||
|
seen: Dict[str, bool] = {}
|
||||||
dsItems = []
|
dsItems = []
|
||||||
for ds in dataSources:
|
for ds in dataSources:
|
||||||
dsId = ds.get("id") if isinstance(ds, dict) else getattr(ds, "id", "")
|
dsId = ds.get("id") if isinstance(ds, dict) else getattr(ds, "id", "")
|
||||||
|
dsPath = ds.get("path") if isinstance(ds, dict) else getattr(ds, "path", "")
|
||||||
|
if dsPath in seen:
|
||||||
|
continue
|
||||||
|
seen[dsPath] = True
|
||||||
dsItems.append({
|
dsItems.append({
|
||||||
"id": dsId,
|
"id": dsId,
|
||||||
"label": ds.get("label") if isinstance(ds, dict) else getattr(ds, "label", ""),
|
"label": ds.get("label") if isinstance(ds, dict) else getattr(ds, "label", ""),
|
||||||
"path": ds.get("path") if isinstance(ds, dict) else getattr(ds, "path", ""),
|
"path": dsPath,
|
||||||
"sourceType": ds.get("sourceType") if isinstance(ds, dict) else getattr(ds, "sourceType", ""),
|
"sourceType": ds.get("sourceType") if isinstance(ds, dict) else getattr(ds, "sourceType", ""),
|
||||||
"ragIndexEnabled": ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False),
|
"ragIndexEnabled": ds.get("ragIndexEnabled") if isinstance(ds, dict) else getattr(ds, "ragIndexEnabled", False),
|
||||||
"neutralize": ds.get("neutralize") if isinstance(ds, dict) else getattr(ds, "neutralize", False),
|
"neutralize": ds.get("neutralize") if isinstance(ds, dict) else getattr(ds, "neutralize", False),
|
||||||
|
|
@ -60,20 +67,43 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
|
||||||
"chunkCount": chunksByDs.get(dsId, 0),
|
"chunkCount": chunksByDs.get(dsId, 0),
|
||||||
})
|
})
|
||||||
|
|
||||||
if unassigned > 0 and len(dsItems) == 1:
|
if unassigned > 0 and len(dsItems) > 0:
|
||||||
dsItems[0]["chunkCount"] += unassigned
|
perDs = unassigned // len(dsItems)
|
||||||
|
remainder = unassigned % len(dsItems)
|
||||||
|
for i, item in enumerate(dsItems):
|
||||||
|
item["chunkCount"] += perDs + (1 if i < remainder else 0)
|
||||||
|
|
||||||
jobs = jobService.listJobs(jobType="connection.bootstrap", limit=5)
|
# Pull a wider window than the previous 5 so the "last successful
|
||||||
|
# sync" is found even if a connection has many recent jobs queued.
|
||||||
|
jobs = jobService.listJobs(jobType="connection.bootstrap", limit=50)
|
||||||
connJobs = [j for j in jobs if (j.get("payload") or {}).get("connectionId") == connectionId]
|
connJobs = [j for j in jobs if (j.get("payload") or {}).get("connectionId") == connectionId]
|
||||||
runningJobs = [
|
runningJobs = [
|
||||||
{"jobId": j["id"], "progress": j.get("progress", 0), "progressMessage": j.get("progressMessage", "")}
|
{"jobId": j["id"], "progress": j.get("progress", 0), "progressMessage": j.get("progressMessage", "")}
|
||||||
for j in connJobs
|
for j in connJobs
|
||||||
if j.get("status") in ("PENDING", "RUNNING")
|
if j.get("status") in ("PENDING", "RUNNING")
|
||||||
]
|
]
|
||||||
lastError = None
|
lastError: Optional[Dict[str, Any]] = None
|
||||||
|
lastSuccess: Optional[Dict[str, Any]] = None
|
||||||
for j in connJobs:
|
for j in connJobs:
|
||||||
if j.get("status") == "ERROR":
|
status = j.get("status")
|
||||||
lastError = {"jobId": j["id"], "errorMessage": j.get("errorMessage", "")}
|
if status == "ERROR" and lastError is None:
|
||||||
|
lastError = {
|
||||||
|
"jobId": j["id"],
|
||||||
|
"errorMessage": j.get("errorMessage", ""),
|
||||||
|
"finishedAt": j.get("finishedAt"),
|
||||||
|
}
|
||||||
|
elif status == "SUCCESS" and lastSuccess is None:
|
||||||
|
result = j.get("result") or {}
|
||||||
|
lastSuccess = {
|
||||||
|
"jobId": j["id"],
|
||||||
|
"finishedAt": j.get("finishedAt"),
|
||||||
|
"indexed": result.get("indexed", 0),
|
||||||
|
"skippedDuplicate": result.get("skippedDuplicate", 0),
|
||||||
|
"skippedPolicy": result.get("skippedPolicy", 0),
|
||||||
|
"failed": result.get("failed", 0),
|
||||||
|
"durationMs": result.get("durationMs", 0),
|
||||||
|
}
|
||||||
|
if lastError and lastSuccess:
|
||||||
break
|
break
|
||||||
|
|
||||||
out.append({
|
out.append({
|
||||||
|
|
@ -86,6 +116,7 @@ def _buildConnectionInventory(connections, rootIf, knowledgeIf, jobService) -> L
|
||||||
"totalChunks": connChunkTotal,
|
"totalChunks": connChunkTotal,
|
||||||
"runningJobs": runningJobs,
|
"runningJobs": runningJobs,
|
||||||
"lastError": lastError,
|
"lastError": lastError,
|
||||||
|
"lastSuccess": lastSuccess,
|
||||||
})
|
})
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
@ -182,7 +213,7 @@ def _getInventoryPlatform(
|
||||||
|
|
||||||
@router.post("/reindex/{connectionId}")
|
@router.post("/reindex/{connectionId}")
|
||||||
@limiter.limit("10/minute")
|
@limiter.limit("10/minute")
|
||||||
def _reindexConnection(
|
async def _reindexConnection(
|
||||||
request: Request,
|
request: Request,
|
||||||
connectionId: str,
|
connectionId: str,
|
||||||
currentUser: User = Depends(getCurrentUser),
|
currentUser: User = Depends(getCurrentUser),
|
||||||
|
|
@ -190,12 +221,16 @@ def _reindexConnection(
|
||||||
"""Re-trigger bootstrap for a connection (re-index all ragIndexEnabled DataSources).
|
"""Re-trigger bootstrap for a connection (re-index all ragIndexEnabled DataSources).
|
||||||
|
|
||||||
Submits a new connection.bootstrap job, regardless of previous failures.
|
Submits a new connection.bootstrap job, regardless of previous failures.
|
||||||
|
|
||||||
|
Must be `async def` so `await startJob(...)` registers the `_runJob` task
|
||||||
|
in FastAPI's main event loop. A sync route would land in the worker
|
||||||
|
threadpool and `asyncio.run` would tear down the temporary loop right
|
||||||
|
after `create_task`, leaving the job stuck in PENDING forever.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||||
from modules.serviceCenter.services.serviceBackgroundJobs import startJob
|
from modules.serviceCenter.services.serviceBackgroundJobs import startJob
|
||||||
from modules.datamodels.datamodelDataSource import DataSource
|
from modules.datamodels.datamodelDataSource import DataSource
|
||||||
import asyncio
|
|
||||||
|
|
||||||
rootIf = getRootInterface()
|
rootIf = getRootInterface()
|
||||||
conn = rootIf.getUserConnectionById(connectionId)
|
conn = rootIf.getUserConnectionById(connectionId)
|
||||||
|
|
@ -213,23 +248,13 @@ def _reindexConnection(
|
||||||
authority = conn.authority.value if hasattr(conn.authority, "value") else str(conn.authority or "")
|
authority = conn.authority.value if hasattr(conn.authority, "value") else str(conn.authority or "")
|
||||||
dsIds = [(ds.get("id") if isinstance(ds, dict) else getattr(ds, "id", "")) for ds in ragDs]
|
dsIds = [(ds.get("id") if isinstance(ds, dict) else getattr(ds, "id", "")) for ds in ragDs]
|
||||||
|
|
||||||
async def _enqueue():
|
jobId = await startJob(
|
||||||
return await startJob(
|
"connection.bootstrap",
|
||||||
"connection.bootstrap",
|
{"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": dsIds},
|
||||||
{"connectionId": connectionId, "authority": authority.lower(), "dataSourceIds": dsIds},
|
triggeredBy=str(currentUser.id),
|
||||||
triggeredBy=str(currentUser.id),
|
)
|
||||||
)
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
if loop.is_running():
|
|
||||||
future = asyncio.ensure_future(_enqueue())
|
|
||||||
jobId = None
|
|
||||||
else:
|
|
||||||
jobId = loop.run_until_complete(_enqueue())
|
|
||||||
except RuntimeError:
|
|
||||||
jobId = asyncio.run(_enqueue())
|
|
||||||
|
|
||||||
logger.info("Reindex triggered for connection %s (%d DataSources)", connectionId, len(dsIds))
|
logger.info("Reindex triggered for connection %s (%d DataSources, jobId=%s)", connectionId, len(dsIds), jobId)
|
||||||
return {"status": "queued", "connectionId": connectionId, "dataSourceCount": len(dsIds), "jobId": jobId}
|
return {"status": "queued", "connectionId": connectionId, "dataSourceCount": len(dsIds), "jobId": jobId}
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import logging
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import List, Dict, Any, Optional, AsyncGenerator, Callable, Awaitable
|
from typing import List, Dict, Any, Optional, AsyncGenerator, Callable, Awaitable, Tuple
|
||||||
|
|
||||||
from modules.datamodels.datamodelAi import (
|
from modules.datamodels.datamodelAi import (
|
||||||
AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum
|
AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum
|
||||||
|
|
@ -360,12 +360,18 @@ async def runAgentLoop(
|
||||||
state.totalToolCalls += len(results)
|
state.totalToolCalls += len(results)
|
||||||
|
|
||||||
for result in results:
|
for result in results:
|
||||||
|
validationCode = None
|
||||||
|
if isinstance(result.errorDetails, dict):
|
||||||
|
code = result.errorDetails.get("code")
|
||||||
|
if isinstance(code, str):
|
||||||
|
validationCode = code
|
||||||
roundLog.toolCalls.append(ToolCallLog(
|
roundLog.toolCalls.append(ToolCallLog(
|
||||||
toolName=result.toolName,
|
toolName=result.toolName,
|
||||||
args=next((tc.args for tc in toolCalls if tc.id == result.toolCallId), {}),
|
args=next((tc.args for tc in toolCalls if tc.id == result.toolCallId), {}),
|
||||||
success=result.success,
|
success=result.success,
|
||||||
durationMs=result.durationMs,
|
durationMs=result.durationMs,
|
||||||
error=result.error,
|
error=result.error,
|
||||||
|
validationFailureCode=validationCode,
|
||||||
resultData=result.data[:300] if result.data else "",
|
resultData=result.data[:300] if result.data else "",
|
||||||
))
|
))
|
||||||
if not result.success:
|
if not result.success:
|
||||||
|
|
@ -443,6 +449,11 @@ async def runAgentLoop(
|
||||||
trace.totalCostCHF = state.totalCostCHF
|
trace.totalCostCHF = state.totalCostCHF
|
||||||
trace.abortReason = state.abortReason
|
trace.abortReason = state.abortReason
|
||||||
|
|
||||||
|
validationFailures, repairAttempts, successAfterRepair = _computeRepairCounters(trace.rounds)
|
||||||
|
trace.validationFailures = validationFailures
|
||||||
|
trace.repairAttempts = repairAttempts
|
||||||
|
trace.successAfterRepair = successAfterRepair
|
||||||
|
|
||||||
artifactSummary = _buildArtifactSummary(trace.rounds)
|
artifactSummary = _buildArtifactSummary(trace.rounds)
|
||||||
|
|
||||||
yield AgentEvent(
|
yield AgentEvent(
|
||||||
|
|
@ -456,6 +467,9 @@ async def runAgentLoop(
|
||||||
"status": state.status.value,
|
"status": state.status.value,
|
||||||
"abortReason": state.abortReason,
|
"abortReason": state.abortReason,
|
||||||
"artifacts": artifactSummary,
|
"artifacts": artifactSummary,
|
||||||
|
"validationFailures": validationFailures,
|
||||||
|
"repairAttempts": repairAttempts,
|
||||||
|
"successAfterRepair": successAfterRepair,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -720,6 +734,41 @@ def classifyToolResult(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _computeRepairCounters(rounds: List[AgentRoundLog]) -> Tuple[int, int, int]:
|
||||||
|
"""Aggregate repair-loop telemetry across all rounds.
|
||||||
|
|
||||||
|
Returns ``(validationFailures, repairAttempts, successAfterRepair)``.
|
||||||
|
|
||||||
|
* `validationFailures` -- total tool calls rejected by a pre-execute
|
||||||
|
validator (any round, counts every occurrence).
|
||||||
|
* `repairAttempts` -- tool calls in **later** rounds whose `toolName`
|
||||||
|
had been rejected in some **earlier** round. Multiple retries of the
|
||||||
|
same tool count multiple times. We intentionally do not count
|
||||||
|
sibling calls within the same round, since the LLM has not yet seen
|
||||||
|
the first one's result when emitting the second.
|
||||||
|
* `successAfterRepair` -- the subset of `repairAttempts` that passed
|
||||||
|
the validator (``validationFailureCode is None``).
|
||||||
|
"""
|
||||||
|
validationFailures = 0
|
||||||
|
repairAttempts = 0
|
||||||
|
successAfterRepair = 0
|
||||||
|
rejectedTools: set = set()
|
||||||
|
|
||||||
|
for roundLog in rounds:
|
||||||
|
rejectedFromPriorRounds = set(rejectedTools)
|
||||||
|
for tc in roundLog.toolCalls:
|
||||||
|
wasRejectedBefore = tc.toolName in rejectedFromPriorRounds
|
||||||
|
if tc.validationFailureCode is not None:
|
||||||
|
validationFailures += 1
|
||||||
|
if wasRejectedBefore:
|
||||||
|
repairAttempts += 1
|
||||||
|
rejectedTools.add(tc.toolName)
|
||||||
|
elif wasRejectedBefore:
|
||||||
|
repairAttempts += 1
|
||||||
|
successAfterRepair += 1
|
||||||
|
return validationFailures, repairAttempts, successAfterRepair
|
||||||
|
|
||||||
|
|
||||||
_ARTIFACT_TOOLS = {"writeFile", "replaceInFile", "deleteFile", "renameFile", "copyFile",
|
_ARTIFACT_TOOLS = {"writeFile", "replaceInFile", "deleteFile", "renameFile", "copyFile",
|
||||||
"createFolder", "deleteFolder", "renderDocument", "generateImage"}
|
"createFolder", "deleteFolder", "renderDocument", "generateImage"}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,20 @@ from modules.serviceCenter.services.serviceAgent.coreTools._helpers import (
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_STALE_EXTRACTION_PATTERNS = (
|
||||||
|
"requires the extract-msg package",
|
||||||
|
"extraction requires the",
|
||||||
|
"will be treated as binary",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _isStaleExtractionResult(text: str) -> bool:
|
||||||
|
"""Detect cached extraction results that are just error/warning placeholders."""
|
||||||
|
if len(text) > 500:
|
||||||
|
return False
|
||||||
|
textLower = text.lower()
|
||||||
|
return any(p in textLower for p in _STALE_EXTRACTION_PATTERNS)
|
||||||
|
|
||||||
|
|
||||||
import uuid as _uuid
|
import uuid as _uuid
|
||||||
|
|
||||||
|
|
@ -62,15 +76,16 @@ def _registerWorkspaceTools(registry: ToolRegistry, services):
|
||||||
]
|
]
|
||||||
if textChunks:
|
if textChunks:
|
||||||
assembled = "\n\n".join(c["data"] for c in textChunks)
|
assembled = "\n\n".join(c["data"] for c in textChunks)
|
||||||
chunked = _applyOffsetLimit(assembled, offset, limit)
|
if not _isStaleExtractionResult(assembled):
|
||||||
if chunked is not None:
|
chunked = _applyOffsetLimit(assembled, offset, limit)
|
||||||
return ToolResult(toolCallId="", toolName="readFile", success=True, data=chunked)
|
if chunked is not None:
|
||||||
if len(assembled) > _MAX_TOOL_RESULT_CHARS:
|
return ToolResult(toolCallId="", toolName="readFile", success=True, data=chunked)
|
||||||
assembled = assembled[:_MAX_TOOL_RESULT_CHARS] + f"\n\n[Truncated – showing first {_MAX_TOOL_RESULT_CHARS} chars of {len(assembled)}. Use offset/limit to read specific sections.]"
|
if len(assembled) > _MAX_TOOL_RESULT_CHARS:
|
||||||
return ToolResult(
|
assembled = assembled[:_MAX_TOOL_RESULT_CHARS] + f"\n\n[Truncated – showing first {_MAX_TOOL_RESULT_CHARS} chars of {len(assembled)}. Use offset/limit to read specific sections.]"
|
||||||
toolCallId="", toolName="readFile", success=True,
|
return ToolResult(
|
||||||
data=assembled,
|
toolCallId="", toolName="readFile", success=True,
|
||||||
)
|
data=assembled,
|
||||||
|
)
|
||||||
elif fileStatus in ("processing", "embedding", "extracted"):
|
elif fileStatus in ("processing", "embedding", "extracted"):
|
||||||
return ToolResult(
|
return ToolResult(
|
||||||
toolCallId="", toolName="readFile", success=True,
|
toolCallId="", toolName="readFile", success=True,
|
||||||
|
|
@ -101,12 +116,31 @@ def _registerWorkspaceTools(registry: ToolRegistry, services):
|
||||||
isBinary = _looksLikeBinary(rawBytes)
|
isBinary = _looksLikeBinary(rawBytes)
|
||||||
|
|
||||||
if isBinary:
|
if isBinary:
|
||||||
|
extractionService = services.getService("extraction") if hasattr(services, "getService") else None
|
||||||
|
if extractionService:
|
||||||
|
try:
|
||||||
|
extracted = extractionService.extractContentFromBytes(
|
||||||
|
rawBytes, fileName, mimeType, documentId=fileId,
|
||||||
|
)
|
||||||
|
textParts = [
|
||||||
|
p.data for p in (extracted.parts or [])
|
||||||
|
if getattr(p, "contentType", "") != "image" and getattr(p, "data", None)
|
||||||
|
]
|
||||||
|
if textParts:
|
||||||
|
assembled = "\n\n".join(textParts)
|
||||||
|
chunked = _applyOffsetLimit(assembled, offset, limit)
|
||||||
|
if chunked is not None:
|
||||||
|
return ToolResult(toolCallId="", toolName="readFile", success=True, data=chunked)
|
||||||
|
if len(assembled) > _MAX_TOOL_RESULT_CHARS:
|
||||||
|
assembled = assembled[:_MAX_TOOL_RESULT_CHARS] + f"\n\n[Truncated – showing first {_MAX_TOOL_RESULT_CHARS} chars of {len(assembled)}. Use offset/limit to read specific sections.]"
|
||||||
|
return ToolResult(toolCallId="", toolName="readFile", success=True, data=assembled)
|
||||||
|
except Exception as extractErr:
|
||||||
|
logger.warning("readFile: inline extraction failed for %s: %s", fileId, extractErr)
|
||||||
return ToolResult(
|
return ToolResult(
|
||||||
toolCallId="", toolName="readFile", success=True,
|
toolCallId="", toolName="readFile", success=True,
|
||||||
data=(
|
data=(
|
||||||
f"[File '{fileName}' ({mimeType}) is not yet indexed "
|
f"[File '{fileName}' ({mimeType}) is binary and could not be extracted "
|
||||||
f"(status: {fileStatus or 'unknown'}). Indexing runs automatically "
|
f"(status: {fileStatus or 'unknown'}). "
|
||||||
f"on upload. Please wait a few seconds and retry, or re-upload the file. "
|
|
||||||
f"For visual content use describeImage(fileId='{fileId}').]"
|
f"For visual content use describeImage(fileId='{fileId}').]"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -79,6 +79,14 @@ class ToolResult(BaseModel):
|
||||||
success: bool = True
|
success: bool = True
|
||||||
data: str = ""
|
data: str = ""
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
|
errorDetails: Optional[Dict[str, Any]] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"Structured, machine-readable error payload for the LLM (e.g. validation "
|
||||||
|
"repair hints with code/field/suggestion/hint). `error` remains the short "
|
||||||
|
"human-readable text for logs and audit."
|
||||||
|
),
|
||||||
|
)
|
||||||
durationMs: int = 0
|
durationMs: int = 0
|
||||||
sideEvents: Optional[List[Dict[str, Any]]] = None
|
sideEvents: Optional[List[Dict[str, Any]]] = None
|
||||||
|
|
||||||
|
|
@ -141,6 +149,14 @@ class ToolCallLog(BaseModel):
|
||||||
success: bool = True
|
success: bool = True
|
||||||
durationMs: int = 0
|
durationMs: int = 0
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
|
validationFailureCode: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"If the tool call was rejected by a pre-execute validator (e.g. "
|
||||||
|
"QueryValidator), the structured error code (e.g. FIELD_NOT_FOUND). "
|
||||||
|
"None when the call ran cleanly or failed for other reasons."
|
||||||
|
),
|
||||||
|
)
|
||||||
resultData: str = Field(default="", description="Short result summary for artifact tracking")
|
resultData: str = Field(default="", description="Short result summary for artifact tracking")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -167,6 +183,24 @@ class AgentTrace(BaseModel):
|
||||||
totalToolCalls: int = 0
|
totalToolCalls: int = 0
|
||||||
totalCostCHF: float = 0.0
|
totalCostCHF: float = 0.0
|
||||||
abortReason: Optional[str] = None
|
abortReason: Optional[str] = None
|
||||||
|
validationFailures: int = Field(
|
||||||
|
default=0,
|
||||||
|
description="Total tool calls rejected by a pre-execute validator across the run.",
|
||||||
|
)
|
||||||
|
repairAttempts: int = Field(
|
||||||
|
default=0,
|
||||||
|
description=(
|
||||||
|
"Number of times the LLM retried a previously rejected tool (same toolName) "
|
||||||
|
"in a later round. Counted by `agentLoop` from per-round ToolCallLog entries."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
successAfterRepair: int = Field(
|
||||||
|
default=0,
|
||||||
|
description=(
|
||||||
|
"Number of repair attempts that produced a clean (validationFailureCode=None) "
|
||||||
|
"result. Combined with `repairAttempts` this gives the repair conversion rate."
|
||||||
|
),
|
||||||
|
)
|
||||||
rounds: List[AgentRoundLog] = Field(default_factory=list)
|
rounds: List[AgentRoundLog] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
203
modules/serviceCenter/services/serviceAgent/datamodelOntology.py
Normal file
203
modules/serviceCenter/services/serviceAgent/datamodelOntology.py
Normal file
|
|
@ -0,0 +1,203 @@
|
||||||
|
# Copyright (c) 2026 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Ontology data model for feature data sub-agents.
|
||||||
|
|
||||||
|
This module defines the data structures that describe a feature's data
|
||||||
|
ontology -- entities, relations, constraints, canonical query patterns --
|
||||||
|
plus the validation error payload used by the QueryValidator.
|
||||||
|
|
||||||
|
Phase 1 (Repair-Loop) only needs `QueryValidationError`, `Constraint`,
|
||||||
|
`ConstraintRule` and `ValidationErrorCode`; the richer `Entity`/`Relation`/
|
||||||
|
`OntologyDescriptor` types are defined here so Phase 2 (Trustee ontology
|
||||||
|
pilot) can plug in without a second data-model change.
|
||||||
|
|
||||||
|
See `wiki/c-work/2-build/2026-05-feature-data-agent-ontology-and-repair.md`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationErrorCode(str, Enum):
|
||||||
|
"""Stable codes for validator failures.
|
||||||
|
|
||||||
|
The LLM sees these codes verbatim in `ToolResult.errorDetails["code"]`
|
||||||
|
and is expected to react to them deterministically (e.g. inspect the
|
||||||
|
schema via browseTable when FIELD_NOT_FOUND, drop the SUM when
|
||||||
|
INVALID_AGGREGATE_TARGET, add a period filter when MISSING_REQUIRED_FILTER).
|
||||||
|
"""
|
||||||
|
FIELD_NOT_FOUND = "FIELD_NOT_FOUND"
|
||||||
|
INVALID_AGGREGATE_TARGET = "INVALID_AGGREGATE_TARGET"
|
||||||
|
WRONG_TABLE_FOR_PURPOSE = "WRONG_TABLE_FOR_PURPOSE"
|
||||||
|
TYPE_MISMATCH = "TYPE_MISMATCH"
|
||||||
|
OPERATOR_INCOMPATIBLE = "OPERATOR_INCOMPATIBLE"
|
||||||
|
MISSING_REQUIRED_FILTER = "MISSING_REQUIRED_FILTER"
|
||||||
|
ORDER_BY_INVALID = "ORDER_BY_INVALID"
|
||||||
|
|
||||||
|
|
||||||
|
class QueryValidationError(BaseModel):
|
||||||
|
"""Structured pre-execute validation error.
|
||||||
|
|
||||||
|
Serialized into `ToolResult.errorDetails` (machine-readable) and
|
||||||
|
summarized into `ToolResult.error` (short human-readable string).
|
||||||
|
"""
|
||||||
|
code: ValidationErrorCode
|
||||||
|
field: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The offending field name (when applicable).",
|
||||||
|
)
|
||||||
|
suggestion: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"Best-effort suggestion (e.g. fuzzy-matched valid field name). "
|
||||||
|
"None when no useful suggestion exists."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
hint: str = Field(
|
||||||
|
description="Short corrective hint, max ~80 chars. Surfaced to the LLM verbatim.",
|
||||||
|
max_length=160,
|
||||||
|
)
|
||||||
|
|
||||||
|
def toShortError(self) -> str:
|
||||||
|
"""Build the short `error` string for logs/audit.
|
||||||
|
|
||||||
|
Format: `<CODE>: <hint>` (or with field when present).
|
||||||
|
"""
|
||||||
|
if self.field:
|
||||||
|
return f"{self.code.value}: {self.field}: {self.hint}"
|
||||||
|
return f"{self.code.value}: {self.hint}"
|
||||||
|
|
||||||
|
def toErrorDetails(self) -> Dict[str, Any]:
|
||||||
|
"""Build the dict for `ToolResult.errorDetails`."""
|
||||||
|
return {
|
||||||
|
"code": self.code.value,
|
||||||
|
"field": self.field,
|
||||||
|
"suggestion": self.suggestion,
|
||||||
|
"hint": self.hint,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ConstraintRule(str, Enum):
|
||||||
|
"""High-level rule kinds that can be attached to a field or table."""
|
||||||
|
NEVER_AGGREGATE = "NEVER_AGGREGATE"
|
||||||
|
REQUIRES_FILTER_ON = "REQUIRES_FILTER_ON"
|
||||||
|
TYPE_MISMATCH_GUARD = "TYPE_MISMATCH_GUARD"
|
||||||
|
PREFERRED_TABLE_FOR_INTENT = "PREFERRED_TABLE_FOR_INTENT"
|
||||||
|
|
||||||
|
|
||||||
|
class Constraint(BaseModel):
|
||||||
|
"""A single rule the validator and the prompt compiler both consume.
|
||||||
|
|
||||||
|
Phase 1 uses constraints declared inline by the validator (defaults
|
||||||
|
derived from naming conventions like ``*Balance`` / ``*Total``).
|
||||||
|
Phase 2 sources them from feature ontologies, replacing the
|
||||||
|
convention-based defaults.
|
||||||
|
"""
|
||||||
|
appliesTo: str = Field(
|
||||||
|
description=(
|
||||||
|
"Target identifier, format depends on rule: `<Table>.<field>` for "
|
||||||
|
"field-level constraints, `<Table>` for table-level."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
rule: ConstraintRule
|
||||||
|
message: str = Field(
|
||||||
|
description="Short hint forwarded to the LLM if the constraint fires.",
|
||||||
|
max_length=160,
|
||||||
|
)
|
||||||
|
params: Dict[str, Any] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description=(
|
||||||
|
"Rule-specific extras, e.g. {'requiredFields': ['periodYear', 'periodMonth']} "
|
||||||
|
"for REQUIRES_FILTER_ON."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SemanticType(str, Enum):
|
||||||
|
"""High-level semantic category an entity belongs to.
|
||||||
|
|
||||||
|
Coarser than the underlying Pydantic type -- used so the prompt compiler
|
||||||
|
can group entities ("here are your ACCOUNT-like tables") without the LLM
|
||||||
|
having to read the full schema.
|
||||||
|
"""
|
||||||
|
ACCOUNT = "ACCOUNT"
|
||||||
|
BALANCE_SNAPSHOT = "BALANCE_SNAPSHOT"
|
||||||
|
TRANSACTION = "TRANSACTION"
|
||||||
|
DOCUMENT = "DOCUMENT"
|
||||||
|
PARTY = "PARTY"
|
||||||
|
PERIOD = "PERIOD"
|
||||||
|
OTHER = "OTHER"
|
||||||
|
|
||||||
|
|
||||||
|
class Cardinality(str, Enum):
|
||||||
|
ONE_TO_ONE = "ONE_TO_ONE"
|
||||||
|
ONE_TO_MANY = "ONE_TO_MANY"
|
||||||
|
MANY_TO_ONE = "MANY_TO_ONE"
|
||||||
|
MANY_TO_MANY = "MANY_TO_MANY"
|
||||||
|
|
||||||
|
|
||||||
|
class Invariant(BaseModel):
|
||||||
|
"""Free-form invariant attached to an entity.
|
||||||
|
|
||||||
|
Phase 1 leaves these as opaque text consumed by the prompt compiler.
|
||||||
|
Future phases may add a structured rule kind.
|
||||||
|
"""
|
||||||
|
description: str = Field(max_length=200)
|
||||||
|
|
||||||
|
|
||||||
|
class Entity(BaseModel):
|
||||||
|
"""One semantic entity in the ontology (often backed by a Pydantic table)."""
|
||||||
|
name: str
|
||||||
|
pythonClass: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="MODEL_REGISTRY key when the entity is DB-backed (e.g. 'TrusteeDataAccountBalance').",
|
||||||
|
)
|
||||||
|
semanticType: SemanticType = SemanticType.OTHER
|
||||||
|
parentEntity: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Name of a broader entity this one specializes (e.g. 'BankAccount' parentEntity 'Account').",
|
||||||
|
)
|
||||||
|
description: str = ""
|
||||||
|
invariants: List[Invariant] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class Relation(BaseModel):
|
||||||
|
fromEntity: str
|
||||||
|
toEntity: str
|
||||||
|
cardinality: Cardinality
|
||||||
|
via: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="FK-Feldname auf der fromEntity-Seite (z. B. 'journalEntryId').",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CanonicalQueryPattern(BaseModel):
|
||||||
|
"""Tool-call skeleton for a recurring user intent.
|
||||||
|
|
||||||
|
The prompt compiler renders these as worked examples so the LLM has a
|
||||||
|
template to mimic instead of inventing a query shape.
|
||||||
|
"""
|
||||||
|
intent: str = Field(description="Short label, e.g. 'BANK_BALANCE_AT_DATE'.")
|
||||||
|
description: str = Field(default="", description="Human-readable when to use this pattern.")
|
||||||
|
pattern: Dict[str, Any] = Field(
|
||||||
|
description="Tool-call shape with placeholders, e.g. {'tool': 'queryTable', 'tableName': '...', 'filters': [...]}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OntologyDescriptor(BaseModel):
|
||||||
|
"""Top-level container exported by `getAgentOntology()` per feature."""
|
||||||
|
featureCode: str
|
||||||
|
entities: List[Entity] = Field(default_factory=list)
|
||||||
|
relations: List[Relation] = Field(default_factory=list)
|
||||||
|
constraints: List[Constraint] = Field(default_factory=list)
|
||||||
|
canonicalPatterns: List[CanonicalQueryPattern] = Field(default_factory=list)
|
||||||
|
|
||||||
|
def constraintsForTable(self, tableName: str) -> List[Constraint]:
|
||||||
|
"""Return constraints whose ``appliesTo`` targets the given table or one of its fields."""
|
||||||
|
prefix = f"{tableName}."
|
||||||
|
return [
|
||||||
|
c for c in self.constraints
|
||||||
|
if c.appliesTo == tableName or c.appliesTo.startswith(prefix)
|
||||||
|
]
|
||||||
|
|
@ -15,6 +15,7 @@ invoked outside an agent loop (e.g. in tests).
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from typing import Any, Callable, Awaitable, Dict, List, Optional
|
from typing import Any, Callable, Awaitable, Dict, List, Optional
|
||||||
|
|
||||||
from modules.datamodels.datamodelAi import (
|
from modules.datamodels.datamodelAi import (
|
||||||
|
|
@ -25,6 +26,10 @@ from modules.serviceCenter.services.serviceAgent.agentLoop import runAgentLoop
|
||||||
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
|
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
|
||||||
AgentConfig, AgentEvent, AgentEventTypeEnum, ToolResult,
|
AgentConfig, AgentEvent, AgentEventTypeEnum, ToolResult,
|
||||||
)
|
)
|
||||||
|
from modules.serviceCenter.services.serviceAgent.datamodelOntology import (
|
||||||
|
QueryValidationError,
|
||||||
|
)
|
||||||
|
from modules.serviceCenter.services.serviceAgent.queryValidator import QueryValidator
|
||||||
from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry
|
from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry
|
||||||
from modules.serviceCenter.services.serviceAgent.featureDataProvider import FeatureDataProvider
|
from modules.serviceCenter.services.serviceAgent.featureDataProvider import FeatureDataProvider
|
||||||
from modules.shared.i18nRegistry import resolveText
|
from modules.shared.i18nRegistry import resolveText
|
||||||
|
|
@ -83,7 +88,8 @@ async def runFeatureDataAgent(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
provider = FeatureDataProvider(dbConnector, neutralizeFields=neutralizeFields)
|
provider = FeatureDataProvider(dbConnector, neutralizeFields=neutralizeFields)
|
||||||
registry = _buildSubAgentTools(provider, featureInstanceId, mandateId, tableFilters or {})
|
validator = _buildValidatorForFeature(featureCode)
|
||||||
|
registry = _buildSubAgentTools(provider, featureInstanceId, mandateId, tableFilters or {}, validator=validator)
|
||||||
|
|
||||||
for tbl in selectedTables:
|
for tbl in selectedTables:
|
||||||
meta = tbl.get("meta", {})
|
meta = tbl.get("meta", {})
|
||||||
|
|
@ -153,10 +159,19 @@ def _buildSubAgentTools(
|
||||||
featureInstanceId: str,
|
featureInstanceId: str,
|
||||||
mandateId: str,
|
mandateId: str,
|
||||||
tableFilters: Dict[str, Dict[str, str]] = None,
|
tableFilters: Dict[str, Dict[str, str]] = None,
|
||||||
|
validator: Optional[QueryValidator] = None,
|
||||||
) -> ToolRegistry:
|
) -> ToolRegistry:
|
||||||
"""Register browseTable and queryTable as sub-agent tools."""
|
"""Register browseTable and queryTable as sub-agent tools.
|
||||||
|
|
||||||
|
The optional ``validator`` runs **before** the provider on every call.
|
||||||
|
When it returns a structured error, the tool result carries
|
||||||
|
``errorDetails`` (machine-readable repair hint for the LLM) plus the
|
||||||
|
short ``error`` string for logs/audit. No provider call happens in that
|
||||||
|
case, so the database is never reached with a known-bad query.
|
||||||
|
"""
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
_tableFilters = tableFilters or {}
|
_tableFilters = tableFilters or {}
|
||||||
|
_validator = validator or QueryValidator()
|
||||||
|
|
||||||
def _recordFilterToList(tableName: str) -> Optional[List[Dict[str, Any]]]:
|
def _recordFilterToList(tableName: str) -> Optional[List[Dict[str, Any]]]:
|
||||||
"""Convert a recordFilter dict to a list of {field, op, value} filter dicts."""
|
"""Convert a recordFilter dict to a list of {field, op, value} filter dicts."""
|
||||||
|
|
@ -165,6 +180,14 @@ def _buildSubAgentTools(
|
||||||
return None
|
return None
|
||||||
return [{"field": k, "op": "=", "value": v} for k, v in rf.items()]
|
return [{"field": k, "op": "=", "value": v} for k, v in rf.items()]
|
||||||
|
|
||||||
|
def _validationToolResult(toolName: str, err: QueryValidationError) -> ToolResult:
|
||||||
|
return ToolResult(
|
||||||
|
toolCallId="", toolName=toolName,
|
||||||
|
success=False,
|
||||||
|
error=err.toShortError(),
|
||||||
|
errorDetails=err.toErrorDetails(),
|
||||||
|
)
|
||||||
|
|
||||||
async def _browseTable(args: Dict[str, Any], context: Dict[str, Any]):
|
async def _browseTable(args: Dict[str, Any], context: Dict[str, Any]):
|
||||||
tableName = args.get("tableName", "")
|
tableName = args.get("tableName", "")
|
||||||
limit = args.get("limit", 50)
|
limit = args.get("limit", 50)
|
||||||
|
|
@ -172,6 +195,9 @@ def _buildSubAgentTools(
|
||||||
fields = args.get("fields")
|
fields = args.get("fields")
|
||||||
if not tableName:
|
if not tableName:
|
||||||
return ToolResult(toolCallId="", toolName="browseTable", success=False, error="tableName required")
|
return ToolResult(toolCallId="", toolName="browseTable", success=False, error="tableName required")
|
||||||
|
validationErr = _validator.validateBrowseQuery(tableName, args)
|
||||||
|
if validationErr is not None:
|
||||||
|
return _validationToolResult("browseTable", validationErr)
|
||||||
result = provider.browseTable(
|
result = provider.browseTable(
|
||||||
tableName=tableName,
|
tableName=tableName,
|
||||||
featureInstanceId=featureInstanceId,
|
featureInstanceId=featureInstanceId,
|
||||||
|
|
@ -197,6 +223,9 @@ def _buildSubAgentTools(
|
||||||
offset = args.get("offset", 0)
|
offset = args.get("offset", 0)
|
||||||
if not tableName:
|
if not tableName:
|
||||||
return ToolResult(toolCallId="", toolName="queryTable", success=False, error="tableName required")
|
return ToolResult(toolCallId="", toolName="queryTable", success=False, error="tableName required")
|
||||||
|
validationErr = _validator.validateQueryTable(tableName, args)
|
||||||
|
if validationErr is not None:
|
||||||
|
return _validationToolResult("queryTable", validationErr)
|
||||||
result = provider.queryTable(
|
result = provider.queryTable(
|
||||||
tableName=tableName,
|
tableName=tableName,
|
||||||
featureInstanceId=featureInstanceId,
|
featureInstanceId=featureInstanceId,
|
||||||
|
|
@ -220,12 +249,19 @@ def _buildSubAgentTools(
|
||||||
aggregate = args.get("aggregate", "")
|
aggregate = args.get("aggregate", "")
|
||||||
field = args.get("field", "")
|
field = args.get("field", "")
|
||||||
groupBy = args.get("groupBy")
|
groupBy = args.get("groupBy")
|
||||||
|
filters = args.get("filters") or []
|
||||||
if not tableName:
|
if not tableName:
|
||||||
return ToolResult(toolCallId="", toolName="aggregateTable", success=False, error="tableName required")
|
return ToolResult(toolCallId="", toolName="aggregateTable", success=False, error="tableName required")
|
||||||
if not aggregate:
|
if not aggregate:
|
||||||
return ToolResult(toolCallId="", toolName="aggregateTable", success=False, error="aggregate required (SUM, COUNT, AVG, MIN, MAX)")
|
return ToolResult(toolCallId="", toolName="aggregateTable", success=False, error="aggregate required (SUM, COUNT, AVG, MIN, MAX)")
|
||||||
if not field:
|
if not field:
|
||||||
return ToolResult(toolCallId="", toolName="aggregateTable", success=False, error="field required")
|
return ToolResult(toolCallId="", toolName="aggregateTable", success=False, error="field required")
|
||||||
|
validationErr = _validator.validateAggregateQuery(tableName, args)
|
||||||
|
if validationErr is not None:
|
||||||
|
return _validationToolResult("aggregateTable", validationErr)
|
||||||
|
combinedFilters = list(filters)
|
||||||
|
recordFilters = _recordFilterToList(tableName) or []
|
||||||
|
combinedFilters.extend(recordFilters)
|
||||||
result = provider.aggregateTable(
|
result = provider.aggregateTable(
|
||||||
tableName=tableName,
|
tableName=tableName,
|
||||||
featureInstanceId=featureInstanceId,
|
featureInstanceId=featureInstanceId,
|
||||||
|
|
@ -233,7 +269,7 @@ def _buildSubAgentTools(
|
||||||
aggregate=aggregate,
|
aggregate=aggregate,
|
||||||
field=field,
|
field=field,
|
||||||
groupBy=groupBy,
|
groupBy=groupBy,
|
||||||
extraFilters=_recordFilterToList(tableName),
|
extraFilters=combinedFilters or None,
|
||||||
)
|
)
|
||||||
return ToolResult(
|
return ToolResult(
|
||||||
toolCallId="", toolName="aggregateTable",
|
toolCallId="", toolName="aggregateTable",
|
||||||
|
|
@ -246,8 +282,12 @@ def _buildSubAgentTools(
|
||||||
"aggregateTable", _aggregateTable,
|
"aggregateTable", _aggregateTable,
|
||||||
description=(
|
description=(
|
||||||
"Run an aggregate query on a feature data table. "
|
"Run an aggregate query on a feature data table. "
|
||||||
"Supports SUM, COUNT, AVG, MIN, MAX with optional GROUP BY. "
|
"Supports SUM, COUNT, AVG, MIN, MAX with optional GROUP BY and filters. "
|
||||||
"Example: aggregateTable(tableName='TrusteeDataJournalLine', aggregate='SUM', field='debitAmount', groupBy='costCenter')"
|
"Example: aggregateTable(tableName='TrusteeDataJournalLine', aggregate='SUM', "
|
||||||
|
"field='debitAmount', filters=[{'field':'accountNumber','op':'=','value':'5400'}]). "
|
||||||
|
"On validation failure the tool returns success=False with errorDetails={code, field, suggestion, hint} -- "
|
||||||
|
"read errorDetails and correct the next call (e.g. drop the SUM, switch to queryTable with period filters, "
|
||||||
|
"or use the suggested field name)."
|
||||||
),
|
),
|
||||||
parameters={
|
parameters={
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
|
@ -256,6 +296,22 @@ def _buildSubAgentTools(
|
||||||
"aggregate": {"type": "string", "enum": ["SUM", "COUNT", "AVG", "MIN", "MAX"], "description": "Aggregate function"},
|
"aggregate": {"type": "string", "enum": ["SUM", "COUNT", "AVG", "MIN", "MAX"], "description": "Aggregate function"},
|
||||||
"field": {"type": "string", "description": "Field to aggregate (e.g. debitAmount, creditAmount)"},
|
"field": {"type": "string", "description": "Field to aggregate (e.g. debitAmount, creditAmount)"},
|
||||||
"groupBy": {"type": "string", "description": "Optional field to group by (e.g. costCenter, accountNumber)"},
|
"groupBy": {"type": "string", "description": "Optional field to group by (e.g. costCenter, accountNumber)"},
|
||||||
|
"filters": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"field": {"type": "string"},
|
||||||
|
"op": {"type": "string"},
|
||||||
|
"value": {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"description": (
|
||||||
|
"Optional filter conditions applied before the aggregate. Same shape as queryTable's "
|
||||||
|
"filters. Required whenever you want to aggregate only a subset (e.g. SUM debits on "
|
||||||
|
"ONE account, COUNT rows in ONE year)."
|
||||||
|
),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"required": ["tableName", "aggregate", "field"],
|
"required": ["tableName", "aggregate", "field"],
|
||||||
},
|
},
|
||||||
|
|
@ -264,7 +320,11 @@ def _buildSubAgentTools(
|
||||||
|
|
||||||
registry.register(
|
registry.register(
|
||||||
"browseTable", _browseTable,
|
"browseTable", _browseTable,
|
||||||
description="List rows from a feature data table with pagination.",
|
description=(
|
||||||
|
"List rows from a feature data table with pagination. "
|
||||||
|
"On validation failure the tool returns success=False with errorDetails={code, field, suggestion, hint} -- "
|
||||||
|
"use errorDetails to correct the next call."
|
||||||
|
),
|
||||||
parameters={
|
parameters={
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
|
@ -286,7 +346,10 @@ def _buildSubAgentTools(
|
||||||
description=(
|
description=(
|
||||||
"Query a feature data table with filters, field selection, and ordering. "
|
"Query a feature data table with filters, field selection, and ordering. "
|
||||||
"Filters: [{\"field\": \"status\", \"op\": \"=\", \"value\": \"active\"}]. "
|
"Filters: [{\"field\": \"status\", \"op\": \"=\", \"value\": \"active\"}]. "
|
||||||
"Operators: =, !=, >, <, >=, <=, LIKE, ILIKE, IS NULL, IS NOT NULL."
|
"Operators: =, !=, >, <, >=, <=, LIKE, ILIKE, IS NULL, IS NOT NULL. "
|
||||||
|
"On validation failure the tool returns success=False with errorDetails={code, field, suggestion, hint} -- "
|
||||||
|
"common codes: FIELD_NOT_FOUND (use the suggestion or call browseTable), OPERATOR_INCOMPATIBLE "
|
||||||
|
"(switch to a compatible operator for that field type), ORDER_BY_INVALID."
|
||||||
),
|
),
|
||||||
parameters={
|
parameters={
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
|
@ -410,13 +473,94 @@ def _buildSchemaContext(
|
||||||
"- Keep your answer SHORT. The caller is a machine, not a human.",
|
"- Keep your answer SHORT. The caller is a machine, not a human.",
|
||||||
]
|
]
|
||||||
|
|
||||||
domainHints = _loadFeatureDomainHints(featureCode)
|
domainBlock = ""
|
||||||
if domainHints:
|
if not _isOntologyDisabled():
|
||||||
parts.extend(["", domainHints.strip()])
|
domainBlock = _loadFeatureOntologyBlock(featureCode)
|
||||||
|
if not domainBlock:
|
||||||
|
domainBlock = _loadFeatureDomainHints(featureCode)
|
||||||
|
if domainBlock:
|
||||||
|
parts.extend(["", domainBlock.strip()])
|
||||||
|
|
||||||
return "\n".join(parts)
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _isOntologyDisabled() -> bool:
|
||||||
|
"""Eval-only escape hatch.
|
||||||
|
|
||||||
|
Set ``POWERON_DISABLE_FEATURE_ONTOLOGY=1`` in the environment to force
|
||||||
|
``_buildSchemaContext`` back onto the legacy ``getAgentDomainHints()``
|
||||||
|
path. Used by the Phase 1.5 benchmark to measure ``baseline`` and
|
||||||
|
``phase1`` accuracy WITHOUT the ontology-driven prompt block. Never
|
||||||
|
set this flag in production.
|
||||||
|
"""
|
||||||
|
return os.environ.get("POWERON_DISABLE_FEATURE_ONTOLOGY", "").strip() in ("1", "true", "TRUE", "yes")
|
||||||
|
|
||||||
|
|
||||||
|
def _buildValidatorForFeature(featureCode: str) -> QueryValidator:
|
||||||
|
"""Construct a QueryValidator wired with the feature ontology (when present).
|
||||||
|
|
||||||
|
Without an ontology the validator falls back to its convention-based
|
||||||
|
constraints (``*Balance`` / ``*Total`` are NEVER_AGGREGATE). With an
|
||||||
|
ontology the descriptor's constraints take precedence -- the validator
|
||||||
|
and the prompt block then share the same source of truth.
|
||||||
|
"""
|
||||||
|
ontology = _loadFeatureOntology(featureCode)
|
||||||
|
return QueryValidator(ontology=ontology)
|
||||||
|
|
||||||
|
|
||||||
|
def _loadFeatureOntology(featureCode: str):
|
||||||
|
"""Return the feature's OntologyDescriptor or None when no hook is exposed."""
|
||||||
|
if not featureCode:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
from modules.system.registry import loadFeatureMainModules
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
mainModules = loadFeatureMainModules() or {}
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("Ontology lookup: cannot load main modules (%s)", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
module = mainModules.get(featureCode) or mainModules.get(featureCode.lower())
|
||||||
|
if module is None:
|
||||||
|
return None
|
||||||
|
hook = getattr(module, "getAgentOntology", None)
|
||||||
|
if not callable(hook):
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return hook()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Feature '%s' getAgentOntology() raised: %s", featureCode, exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _loadFeatureOntologyBlock(featureCode: str) -> str:
|
||||||
|
"""Return the ontology-derived prompt block when the feature exposes one.
|
||||||
|
|
||||||
|
Each feature can expose ``getAgentOntology() -> OntologyDescriptor`` in
|
||||||
|
its ``mainXxx.py``. When present, the descriptor is compiled via
|
||||||
|
:func:`ontologyToPromptCompiler.compileOntologyToPrompt` and the result
|
||||||
|
replaces the legacy ``getAgentDomainHints()`` text block. This keeps
|
||||||
|
one single source of truth for the validator AND the prompt.
|
||||||
|
|
||||||
|
Failures are swallowed (missing hook, exceptions in compilation) so the
|
||||||
|
caller can fall back to the legacy domain-hints path.
|
||||||
|
"""
|
||||||
|
ontology = _loadFeatureOntology(featureCode)
|
||||||
|
if ontology is None:
|
||||||
|
return ""
|
||||||
|
try:
|
||||||
|
from modules.serviceCenter.services.serviceAgent.ontologyToPromptCompiler import (
|
||||||
|
compileOntologyToPrompt,
|
||||||
|
)
|
||||||
|
return compileOntologyToPrompt(ontology)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Ontology compile failed for '%s': %s", featureCode, exc)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def _loadFeatureDomainHints(featureCode: str) -> str:
|
def _loadFeatureDomainHints(featureCode: str) -> str:
|
||||||
"""Pull optional domain-specific hints from the feature's main module.
|
"""Pull optional domain-specific hints from the feature's main module.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,140 @@
|
||||||
|
# Copyright (c) 2026 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Deterministic compiler: OntologyDescriptor -> sub-agent prompt block.
|
||||||
|
|
||||||
|
Phase 2 replaces a feature's hand-written ``_AGENT_DOMAIN_HINTS`` text
|
||||||
|
with a structured :class:`OntologyDescriptor`. This compiler renders the
|
||||||
|
descriptor into a stable, terse Markdown-ish block that the sub-agent
|
||||||
|
appends to its system prompt -- the same source of truth the
|
||||||
|
:class:`QueryValidator` consults.
|
||||||
|
|
||||||
|
The output is intentionally:
|
||||||
|
* short (every token costs every call)
|
||||||
|
* deterministic (no f-string ordering bugs, no Python dict iteration)
|
||||||
|
* free of internal jargon ('canonicalQueryPattern' is rendered as
|
||||||
|
'CANONICAL PATTERN' for the LLM)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Iterable, List
|
||||||
|
|
||||||
|
from modules.serviceCenter.services.serviceAgent.datamodelOntology import (
|
||||||
|
CanonicalQueryPattern,
|
||||||
|
Constraint,
|
||||||
|
ConstraintRule,
|
||||||
|
Entity,
|
||||||
|
OntologyDescriptor,
|
||||||
|
Relation,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def compileOntologyToPrompt(ontology: OntologyDescriptor) -> str:
|
||||||
|
"""Render *ontology* into a sub-agent prompt block.
|
||||||
|
|
||||||
|
The output starts with a stable marker line (``DOMAIN ONTOLOGY (...)``)
|
||||||
|
so downstream tooling can find/replace it deterministically.
|
||||||
|
"""
|
||||||
|
lines: List[str] = []
|
||||||
|
lines.append(f"DOMAIN ONTOLOGY ({ontology.featureCode}):")
|
||||||
|
lines.append("")
|
||||||
|
lines.extend(_renderEntities(ontology.entities))
|
||||||
|
relationLines = _renderRelations(ontology.relations)
|
||||||
|
if relationLines:
|
||||||
|
lines.append("")
|
||||||
|
lines.extend(relationLines)
|
||||||
|
constraintLines = _renderConstraints(ontology.constraints)
|
||||||
|
if constraintLines:
|
||||||
|
lines.append("")
|
||||||
|
lines.extend(constraintLines)
|
||||||
|
patternLines = _renderPatterns(ontology.canonicalPatterns)
|
||||||
|
if patternLines:
|
||||||
|
lines.append("")
|
||||||
|
lines.extend(patternLines)
|
||||||
|
return "\n".join(lines).rstrip() + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _renderEntities(entities: Iterable[Entity]) -> List[str]:
|
||||||
|
out: List[str] = ["ENTITIES:"]
|
||||||
|
for e in entities:
|
||||||
|
head = f"- {e.name}"
|
||||||
|
if e.parentEntity:
|
||||||
|
head += f" (specializes {e.parentEntity})"
|
||||||
|
if e.pythonClass:
|
||||||
|
head += f" [table: {e.pythonClass}]"
|
||||||
|
out.append(head)
|
||||||
|
if e.description:
|
||||||
|
out.append(f" {e.description}")
|
||||||
|
for inv in e.invariants:
|
||||||
|
out.append(f" * {inv.description}")
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _renderRelations(relations: Iterable[Relation]) -> List[str]:
|
||||||
|
rels = list(relations)
|
||||||
|
if not rels:
|
||||||
|
return []
|
||||||
|
out: List[str] = ["RELATIONS:"]
|
||||||
|
for r in rels:
|
||||||
|
line = f"- {r.fromEntity} -> {r.toEntity} ({r.cardinality.value}"
|
||||||
|
if r.via:
|
||||||
|
line += f" via {r.via}"
|
||||||
|
line += ")"
|
||||||
|
out.append(line)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _renderConstraints(constraints: Iterable[Constraint]) -> List[str]:
|
||||||
|
cons = list(constraints)
|
||||||
|
if not cons:
|
||||||
|
return []
|
||||||
|
out: List[str] = ["CONSTRAINTS (validator-enforced):"]
|
||||||
|
for c in cons:
|
||||||
|
rule = _ruleLabel(c.rule)
|
||||||
|
line = f"- {rule} on {c.appliesTo}: {c.message}"
|
||||||
|
params = c.params or {}
|
||||||
|
required = params.get("requiredFields")
|
||||||
|
if isinstance(required, list) and required:
|
||||||
|
line += f" (required filters: {', '.join(required)})"
|
||||||
|
intents = params.get("intents")
|
||||||
|
if isinstance(intents, list) and intents:
|
||||||
|
line += f" (intents: {', '.join(intents)})"
|
||||||
|
out.append(line)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _ruleLabel(rule: ConstraintRule) -> str:
|
||||||
|
return rule.value.replace("_", " ").lower()
|
||||||
|
|
||||||
|
|
||||||
|
def _renderPatterns(patterns: Iterable[CanonicalQueryPattern]) -> List[str]:
|
||||||
|
pats = list(patterns)
|
||||||
|
if not pats:
|
||||||
|
return []
|
||||||
|
out: List[str] = ["CANONICAL QUERY PATTERNS (mimic these tool calls):"]
|
||||||
|
for i, p in enumerate(pats, start=1):
|
||||||
|
out.append(f"{i}) intent={p.intent}: {p.description}")
|
||||||
|
out.append(f" call: {_renderPatternCall(p.pattern)}")
|
||||||
|
extra = p.pattern.get("_postProcessing") if isinstance(p.pattern, dict) else None
|
||||||
|
if isinstance(extra, str):
|
||||||
|
out.append(f" note: {extra}")
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _renderPatternCall(pattern: dict) -> str:
|
||||||
|
"""Render the pattern as a compact one-line tool call signature."""
|
||||||
|
tool = pattern.get("tool", "?")
|
||||||
|
parts: List[str] = []
|
||||||
|
for key in ("tableName", "aggregate", "field", "groupBy", "orderBy"):
|
||||||
|
if key in pattern and pattern[key] is not None and not str(key).startswith("_"):
|
||||||
|
parts.append(f"{key}={pattern[key]!r}")
|
||||||
|
if "fields" in pattern and pattern["fields"]:
|
||||||
|
parts.append(f"fields={pattern['fields']}")
|
||||||
|
if "filters" in pattern and pattern["filters"]:
|
||||||
|
compact = ", ".join(
|
||||||
|
f"{f.get('field')}{f.get('op','=')}{f.get('value')!r}"
|
||||||
|
for f in pattern["filters"]
|
||||||
|
if isinstance(f, dict)
|
||||||
|
)
|
||||||
|
parts.append(f"filters=[{compact}]")
|
||||||
|
return f"{tool}({', '.join(parts)})"
|
||||||
311
modules/serviceCenter/services/serviceAgent/queryValidator.py
Normal file
311
modules/serviceCenter/services/serviceAgent/queryValidator.py
Normal file
|
|
@ -0,0 +1,311 @@
|
||||||
|
# Copyright (c) 2026 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Pre-execute query validator for the Feature Data Sub-Agent.
|
||||||
|
|
||||||
|
Sits between the LLM tool call and `FeatureDataProvider`. Catches the four
|
||||||
|
high-impact hallucination classes deterministically so the LLM gets an
|
||||||
|
actionable repair hint instead of a raw SQL exception:
|
||||||
|
|
||||||
|
* invented field names -> FIELD_NOT_FOUND (+ fuzzy suggestion)
|
||||||
|
* operator/type mismatches -> OPERATOR_INCOMPATIBLE
|
||||||
|
* SUM/AVG on already-aggregated -> INVALID_AGGREGATE_TARGET
|
||||||
|
balance/total columns
|
||||||
|
* orderBy on invented fields -> ORDER_BY_INVALID
|
||||||
|
|
||||||
|
The validator reads the canonical schema from
|
||||||
|
`modules.datamodels.datamodelBase.MODEL_REGISTRY`. When an
|
||||||
|
`OntologyDescriptor` is provided (Phase 2), its constraints override the
|
||||||
|
convention-based defaults (e.g. NEVER_AGGREGATE on closingBalance).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import difflib
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import typing
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from modules.datamodels.datamodelBase import MODEL_REGISTRY
|
||||||
|
from modules.serviceCenter.services.serviceAgent.datamodelOntology import (
|
||||||
|
Constraint,
|
||||||
|
ConstraintRule,
|
||||||
|
OntologyDescriptor,
|
||||||
|
QueryValidationError,
|
||||||
|
ValidationErrorCode,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
_STRING_ONLY_OPERATORS = {"LIKE", "ILIKE"}
|
||||||
|
_COMPARISON_OPERATORS = {">", "<", ">=", "<="}
|
||||||
|
_VALUELESS_OPERATORS = {"IS NULL", "IS NOT NULL"}
|
||||||
|
_AGGREGATES_THAT_SUM = {"SUM", "AVG"}
|
||||||
|
_AGGREGATE_BLACKLIST_SUFFIXES_DEFAULT: Tuple[str, ...] = ("Balance", "Total")
|
||||||
|
|
||||||
|
|
||||||
|
class QueryValidator:
|
||||||
|
"""Validate sub-agent tool arguments against the schema (+ optional ontology).
|
||||||
|
|
||||||
|
Stateless per call -- holding only the optional ontology. Each
|
||||||
|
`validateXxx` method returns ``None`` on success or a
|
||||||
|
:class:`QueryValidationError` to be surfaced to the LLM.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, ontology: Optional[OntologyDescriptor] = None):
|
||||||
|
self._ontology = ontology
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# public API: one method per sub-agent tool
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def validateBrowseQuery(
|
||||||
|
self, tableName: str, args: Dict[str, Any]
|
||||||
|
) -> Optional[QueryValidationError]:
|
||||||
|
"""Validate browseTable arguments.
|
||||||
|
|
||||||
|
Phase 1 scope: only `fields` (whitelist) is LLM-driven; `limit`/`offset`
|
||||||
|
are sanitized by the tool wrapper.
|
||||||
|
"""
|
||||||
|
modelFields = _getModelFields(tableName)
|
||||||
|
if modelFields is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
fieldsErr = self._validateFieldList(args.get("fields"), modelFields)
|
||||||
|
if fieldsErr is not None:
|
||||||
|
return fieldsErr
|
||||||
|
return None
|
||||||
|
|
||||||
|
def validateQueryTable(
|
||||||
|
self, tableName: str, args: Dict[str, Any]
|
||||||
|
) -> Optional[QueryValidationError]:
|
||||||
|
"""Validate queryTable arguments (filters + fields + orderBy)."""
|
||||||
|
modelFields = _getModelFields(tableName)
|
||||||
|
if modelFields is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
fieldsErr = self._validateFieldList(args.get("fields"), modelFields)
|
||||||
|
if fieldsErr is not None:
|
||||||
|
return fieldsErr
|
||||||
|
|
||||||
|
for f in args.get("filters") or []:
|
||||||
|
filterErr = self._validateFilter(f, modelFields)
|
||||||
|
if filterErr is not None:
|
||||||
|
return filterErr
|
||||||
|
|
||||||
|
orderBy = args.get("orderBy")
|
||||||
|
if orderBy is not None and not _isPlainNone(orderBy):
|
||||||
|
if orderBy not in modelFields:
|
||||||
|
return QueryValidationError(
|
||||||
|
code=ValidationErrorCode.ORDER_BY_INVALID,
|
||||||
|
field=orderBy,
|
||||||
|
suggestion=_suggestFieldName(orderBy, modelFields),
|
||||||
|
hint="orderBy must be a real field of this table.",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def validateAggregateQuery(
|
||||||
|
self, tableName: str, args: Dict[str, Any]
|
||||||
|
) -> Optional[QueryValidationError]:
|
||||||
|
"""Validate aggregateTable arguments.
|
||||||
|
|
||||||
|
Catches the highest-impact hallucination in the codebase:
|
||||||
|
``SUM(closingBalance)`` (and friends) across periods -- closing
|
||||||
|
balances are already per-period, summing them produces nonsense.
|
||||||
|
"""
|
||||||
|
modelFields = _getModelFields(tableName)
|
||||||
|
if modelFields is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
field = args.get("field")
|
||||||
|
aggregate = (args.get("aggregate") or "").upper()
|
||||||
|
|
||||||
|
if not field:
|
||||||
|
return None # tool wrapper rejects empty field already
|
||||||
|
|
||||||
|
if field not in modelFields:
|
||||||
|
return QueryValidationError(
|
||||||
|
code=ValidationErrorCode.FIELD_NOT_FOUND,
|
||||||
|
field=field,
|
||||||
|
suggestion=_suggestFieldName(field, modelFields),
|
||||||
|
hint="Use browseTable to inspect this table's columns.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if aggregate in _AGGREGATES_THAT_SUM and self._isAggregateBlacklisted(tableName, field):
|
||||||
|
return QueryValidationError(
|
||||||
|
code=ValidationErrorCode.INVALID_AGGREGATE_TARGET,
|
||||||
|
field=field,
|
||||||
|
suggestion=None,
|
||||||
|
hint=(
|
||||||
|
f"{field} is already aggregated per period; do not {aggregate} it "
|
||||||
|
"across rows. Use queryTable with period filters instead."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if aggregate in _AGGREGATES_THAT_SUM and not _isNumericAnnotation(modelFields[field]):
|
||||||
|
return QueryValidationError(
|
||||||
|
code=ValidationErrorCode.TYPE_MISMATCH,
|
||||||
|
field=field,
|
||||||
|
suggestion=None,
|
||||||
|
hint=f"{aggregate} requires a numeric field; {field} is not numeric.",
|
||||||
|
)
|
||||||
|
|
||||||
|
groupBy = args.get("groupBy")
|
||||||
|
if groupBy is not None and not _isPlainNone(groupBy):
|
||||||
|
if groupBy not in modelFields:
|
||||||
|
return QueryValidationError(
|
||||||
|
code=ValidationErrorCode.FIELD_NOT_FOUND,
|
||||||
|
field=groupBy,
|
||||||
|
suggestion=_suggestFieldName(groupBy, modelFields),
|
||||||
|
hint="groupBy must be a real field of this table.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# filters validation matches queryTable so the LLM gets consistent
|
||||||
|
# repair hints regardless of which tool it picked.
|
||||||
|
for f in args.get("filters") or []:
|
||||||
|
filterErr = self._validateFilter(f, modelFields)
|
||||||
|
if filterErr is not None:
|
||||||
|
return filterErr
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# internals
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _validateFieldList(
|
||||||
|
self, fields: Optional[List[str]], modelFields: Dict[str, Any]
|
||||||
|
) -> Optional[QueryValidationError]:
|
||||||
|
if not fields:
|
||||||
|
return None
|
||||||
|
for f in fields:
|
||||||
|
if not isinstance(f, str):
|
||||||
|
continue
|
||||||
|
if f not in modelFields:
|
||||||
|
return QueryValidationError(
|
||||||
|
code=ValidationErrorCode.FIELD_NOT_FOUND,
|
||||||
|
field=f,
|
||||||
|
suggestion=_suggestFieldName(f, modelFields),
|
||||||
|
hint="Use browseTable to inspect this table's columns.",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _validateFilter(
|
||||||
|
self, filterEntry: Any, modelFields: Dict[str, Any]
|
||||||
|
) -> Optional[QueryValidationError]:
|
||||||
|
if not isinstance(filterEntry, dict):
|
||||||
|
return None
|
||||||
|
field = filterEntry.get("field")
|
||||||
|
op = (filterEntry.get("op") or "=").upper()
|
||||||
|
|
||||||
|
if not isinstance(field, str) or not field:
|
||||||
|
return None # tool wrapper passes these straight through
|
||||||
|
|
||||||
|
if field not in modelFields:
|
||||||
|
return QueryValidationError(
|
||||||
|
code=ValidationErrorCode.FIELD_NOT_FOUND,
|
||||||
|
field=field,
|
||||||
|
suggestion=_suggestFieldName(field, modelFields),
|
||||||
|
hint="Use browseTable to inspect this table's columns.",
|
||||||
|
)
|
||||||
|
|
||||||
|
annotation = modelFields[field]
|
||||||
|
|
||||||
|
if op in _STRING_ONLY_OPERATORS and not _isStringAnnotation(annotation):
|
||||||
|
return QueryValidationError(
|
||||||
|
code=ValidationErrorCode.OPERATOR_INCOMPATIBLE,
|
||||||
|
field=field,
|
||||||
|
suggestion=None,
|
||||||
|
hint=f"{op} only works on string fields; {field} is not a string.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if op in _COMPARISON_OPERATORS and not _isComparableAnnotation(annotation):
|
||||||
|
return QueryValidationError(
|
||||||
|
code=ValidationErrorCode.OPERATOR_INCOMPATIBLE,
|
||||||
|
field=field,
|
||||||
|
suggestion=None,
|
||||||
|
hint=f"{op} requires a numeric or date field; {field} is not comparable.",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _isAggregateBlacklisted(self, tableName: str, fieldName: str) -> bool:
|
||||||
|
"""Check whether a field is marked NEVER_AGGREGATE.
|
||||||
|
|
||||||
|
Phase 2 (ontology present): consult the descriptor.
|
||||||
|
Phase 1 fallback: naming convention (``*Balance`` / ``*Total``).
|
||||||
|
"""
|
||||||
|
if self._ontology is not None:
|
||||||
|
target = f"{tableName}.{fieldName}"
|
||||||
|
for c in self._ontology.constraintsForTable(tableName):
|
||||||
|
if c.rule == ConstraintRule.NEVER_AGGREGATE and c.appliesTo == target:
|
||||||
|
return True
|
||||||
|
|
||||||
|
for suffix in _AGGREGATE_BLACKLIST_SUFFIXES_DEFAULT:
|
||||||
|
if fieldName.endswith(suffix):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _getModelFields(tableName: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Return ``{fieldName: annotation}`` for a registered Pydantic table model.
|
||||||
|
|
||||||
|
None when the table is not in MODEL_REGISTRY (e.g. pure UDB tables in
|
||||||
|
early-startup contexts). The validator is a best-effort layer -- when
|
||||||
|
the schema is unknown we let the request through and rely on the
|
||||||
|
downstream SQL layer for safety.
|
||||||
|
"""
|
||||||
|
modelClass = MODEL_REGISTRY.get(tableName)
|
||||||
|
if modelClass is None:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
name: info.annotation for name, info in modelClass.model_fields.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _suggestFieldName(badName: str, modelFields: Dict[str, Any]) -> Optional[str]:
|
||||||
|
"""Return the closest valid field name, or None if nothing reasonable."""
|
||||||
|
if not badName or not modelFields:
|
||||||
|
return None
|
||||||
|
matches = difflib.get_close_matches(badName, list(modelFields.keys()), n=1, cutoff=0.6)
|
||||||
|
return matches[0] if matches else None
|
||||||
|
|
||||||
|
|
||||||
|
def _isPlainNone(value: Any) -> bool:
|
||||||
|
"""LLMs sometimes pass the literal string 'None' -- treat both as None."""
|
||||||
|
return value is None or (isinstance(value, str) and value.strip().lower() == "none")
|
||||||
|
|
||||||
|
|
||||||
|
def _unwrapAnnotation(annotation: Any) -> Tuple[Any, ...]:
|
||||||
|
"""Flatten Optional/Union annotations into their constituent types."""
|
||||||
|
origin = typing.get_origin(annotation)
|
||||||
|
if origin is None:
|
||||||
|
return (annotation,)
|
||||||
|
return tuple(a for a in typing.get_args(annotation) if a is not type(None))
|
||||||
|
|
||||||
|
|
||||||
|
def _isStringAnnotation(annotation: Any) -> bool:
|
||||||
|
return any(a is str for a in _unwrapAnnotation(annotation))
|
||||||
|
|
||||||
|
|
||||||
|
def _isNumericAnnotation(annotation: Any) -> bool:
|
||||||
|
numericTypes = (int, float)
|
||||||
|
return any(a in numericTypes for a in _unwrapAnnotation(annotation))
|
||||||
|
|
||||||
|
|
||||||
|
def _isComparableAnnotation(annotation: Any) -> bool:
|
||||||
|
"""Numeric types are the comparable shape we see in feature tables.
|
||||||
|
|
||||||
|
Booleans count as int in Python's type hierarchy but the comparison
|
||||||
|
operators ``>``/``<`` on bool columns are almost never meaningful, so we
|
||||||
|
treat bool as non-comparable for validator purposes.
|
||||||
|
"""
|
||||||
|
for a in _unwrapAnnotation(annotation):
|
||||||
|
if a is bool:
|
||||||
|
continue
|
||||||
|
if a in (int, float):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
@ -98,14 +98,17 @@ class _VirtualFS:
|
||||||
|
|
||||||
def _makeReadFile(services):
|
def _makeReadFile(services):
|
||||||
"""Create a readFile(fileId) closure bound to the current services context."""
|
"""Create a readFile(fileId) closure bound to the current services context."""
|
||||||
def readFile(fileId: str) -> str:
|
def readFile(fileId: str, encoding: str = "utf-8") -> str:
|
||||||
mgmt = getattr(services, 'interfaceDbComponent', None) if services else None
|
mgmt = getattr(services, 'interfaceDbComponent', None) if services else None
|
||||||
if not mgmt:
|
if not mgmt:
|
||||||
raise RuntimeError("readFile: no file store available in this session")
|
raise RuntimeError("readFile: no file store available in this session")
|
||||||
data = mgmt.getFileData(str(fileId))
|
data = mgmt.getFileData(str(fileId))
|
||||||
if data is None:
|
if data is None:
|
||||||
raise FileNotFoundError(f"File '{fileId}' not found in workspace")
|
raise FileNotFoundError(f"File '{fileId}' not found in workspace")
|
||||||
return data.decode("utf-8")
|
try:
|
||||||
|
return data.decode(encoding)
|
||||||
|
except (UnicodeDecodeError, LookupError):
|
||||||
|
return data.decode("utf-8", errors="replace")
|
||||||
return readFile
|
return readFile
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -60,6 +60,7 @@ from modules.shared.jsonContinuation import getContexts
|
||||||
from modules.shared.jsonUtils import buildContinuationContext, tryParseJson
|
from modules.shared.jsonUtils import buildContinuationContext, tryParseJson
|
||||||
from modules.shared.jsonUtils import closeJsonStructures
|
from modules.shared.jsonUtils import closeJsonStructures
|
||||||
from modules.shared.jsonUtils import stripCodeFences, normalizeJsonText
|
from modules.shared.jsonUtils import stripCodeFences, normalizeJsonText
|
||||||
|
from modules.shared.jsonUtils import extractJsonString, repairBrokenJson
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -447,7 +448,6 @@ class AiCallLooper:
|
||||||
extracted = extractJsonString(contexts.completePart)
|
extracted = extractJsonString(contexts.completePart)
|
||||||
parsed, parseErr, _ = tryParseJson(extracted)
|
parsed, parseErr, _ = tryParseJson(extracted)
|
||||||
if parseErr is not None:
|
if parseErr is not None:
|
||||||
from modules.shared.jsonUtils import repairBrokenJson
|
|
||||||
repaired = repairBrokenJson(extracted)
|
repaired = repairBrokenJson(extracted)
|
||||||
if repaired:
|
if repaired:
|
||||||
parsed = repaired
|
parsed = repaired
|
||||||
|
|
@ -470,9 +470,10 @@ class AiCallLooper:
|
||||||
return useCase.finalResultHandler(
|
return useCase.finalResultHandler(
|
||||||
result, normalized, extracted, debugPrefix, self.services
|
result, normalized, extracted, debugPrefix, self.services
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except (json.JSONDecodeError, KeyError, TypeError) as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Iteration {iteration}: completePart not serializable after getContexts success: {e}"
|
f"Iteration {iteration}: completePart not serializable after getContexts success: "
|
||||||
|
f"{type(e).__name__}: {e}"
|
||||||
)
|
)
|
||||||
mergeFailCount += 1
|
mergeFailCount += 1
|
||||||
if mergeFailCount >= MAX_MERGE_FAILS:
|
if mergeFailCount >= MAX_MERGE_FAILS:
|
||||||
|
|
@ -491,6 +492,15 @@ class AiCallLooper:
|
||||||
)
|
)
|
||||||
self.services.chat.progressLogFinish(iterationOperationId, True)
|
self.services.chat.progressLogFinish(iterationOperationId, True)
|
||||||
continue
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Iteration {iteration}: unexpected error during completePart processing "
|
||||||
|
f"(re-raising, NOT a pipeline-mismatch retry): {type(e).__name__}: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
if iterationOperationId:
|
||||||
|
self.services.chat.progressLogFinish(iterationOperationId, False)
|
||||||
|
raise
|
||||||
|
|
||||||
elif contexts.jsonParsingSuccess and contexts.overlapContext != "":
|
elif contexts.jsonParsingSuccess and contexts.overlapContext != "":
|
||||||
# JSON parseable but has cut point - CONTINUE to next iteration
|
# JSON parseable but has cut point - CONTINUE to next iteration
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ import time
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||||
|
|
||||||
from modules.connectors.connectorDbPostgre import DatabaseConnector
|
from modules.connectors.connectorDbPostgre import DatabaseConnector, getCachedConnector
|
||||||
from modules.shared.configuration import APP_CONFIG
|
from modules.shared.configuration import APP_CONFIG
|
||||||
from modules.shared.dbRegistry import registerDatabase
|
from modules.shared.dbRegistry import registerDatabase
|
||||||
from modules.datamodels.datamodelBackgroundJob import (
|
from modules.datamodels.datamodelBackgroundJob import (
|
||||||
|
|
@ -104,7 +104,13 @@ def registerJobHandler(jobType: str, handler: JobHandler) -> None:
|
||||||
|
|
||||||
|
|
||||||
def _getDb() -> DatabaseConnector:
|
def _getDb() -> DatabaseConnector:
|
||||||
return DatabaseConnector(
|
"""Return the shared cached connector for the jobs DB.
|
||||||
|
|
||||||
|
Reuses the same connector across all job CRUD calls instead of opening a
|
||||||
|
fresh psycopg2 connection (and re-running `_create_database_if_not_exists`
|
||||||
|
+ `_create_tables` + `_initializeSystemTable`) on every operation.
|
||||||
|
"""
|
||||||
|
return getCachedConnector(
|
||||||
dbDatabase=JOBS_DATABASE,
|
dbDatabase=JOBS_DATABASE,
|
||||||
dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
|
dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
|
||||||
dbPort=int(APP_CONFIG.get("DB_PORT", "5432")),
|
dbPort=int(APP_CONFIG.get("DB_PORT", "5432")),
|
||||||
|
|
@ -290,12 +296,12 @@ def cancelJobsByConnection(connectionId: str, *, jobType: str = "connection.boot
|
||||||
|
|
||||||
|
|
||||||
def recoverInterruptedJobs() -> int:
|
def recoverInterruptedJobs() -> int:
|
||||||
"""Flip any RUNNING jobs to ERROR and re-queue bootstrap jobs (called at worker boot).
|
"""Flip any RUNNING jobs to ERROR (called at worker boot).
|
||||||
|
|
||||||
A RUNNING job in the DB after process restart means the previous worker
|
A RUNNING job in the DB after process restart means the previous worker
|
||||||
died mid-execution; the asyncio task is gone and the job will never
|
died mid-execution; the asyncio task is gone and the job will never
|
||||||
finish on its own. For connection.bootstrap jobs, a fresh job is
|
finish on its own. The daily scheduler or manual "Neu indexieren"
|
||||||
automatically re-queued so the user doesn't have to manually retry.
|
button handles retry — no automatic re-queue to avoid infinite loops.
|
||||||
"""
|
"""
|
||||||
db = _getDb()
|
db = _getDb()
|
||||||
try:
|
try:
|
||||||
|
|
@ -304,34 +310,70 @@ def recoverInterruptedJobs() -> int:
|
||||||
logger.warning("recoverInterruptedJobs: failed to scan RUNNING jobs: %s", ex)
|
logger.warning("recoverInterruptedJobs: failed to scan RUNNING jobs: %s", ex)
|
||||||
return 0
|
return 0
|
||||||
count = 0
|
count = 0
|
||||||
requeued = 0
|
|
||||||
for row in rows:
|
for row in rows:
|
||||||
try:
|
try:
|
||||||
_markError(row["id"], "Interrupted by worker restart")
|
_markError(row["id"], "Interrupted by worker restart")
|
||||||
count += 1
|
count += 1
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
logger.warning("recoverInterruptedJobs: could not mark %s as ERROR: %s", row.get("id"), ex)
|
logger.warning("recoverInterruptedJobs: could not mark %s as ERROR: %s", row.get("id"), ex)
|
||||||
continue
|
|
||||||
|
|
||||||
if row.get("jobType") == "connection.bootstrap":
|
|
||||||
payload = row.get("payload") or {}
|
|
||||||
if payload.get("connectionId"):
|
|
||||||
try:
|
|
||||||
newJob = BackgroundJob(
|
|
||||||
jobType="connection.bootstrap",
|
|
||||||
payload=payload,
|
|
||||||
triggeredBy="recovery.requeue",
|
|
||||||
)
|
|
||||||
record = db.recordCreate(BackgroundJob, _serialiseDatetimes(newJob.model_dump()))
|
|
||||||
asyncio.create_task(_runJob(record["id"]))
|
|
||||||
requeued += 1
|
|
||||||
logger.info(
|
|
||||||
"recoverInterruptedJobs: re-queued bootstrap for connectionId=%s (new jobId=%s)",
|
|
||||||
payload["connectionId"], record["id"],
|
|
||||||
)
|
|
||||||
except Exception as reqEx:
|
|
||||||
logger.warning("recoverInterruptedJobs: re-queue failed for %s: %s", row.get("id"), reqEx)
|
|
||||||
|
|
||||||
if count:
|
if count:
|
||||||
logger.warning("Recovered %d interrupted background job(s) after restart (re-queued %d)", count, requeued)
|
logger.warning("Recovered %d interrupted background job(s) after restart", count)
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
_ZOMBIE_MAX_AGE_SECONDS = 30 * 60
|
||||||
|
|
||||||
|
|
||||||
|
def killZombieJobs(maxAgeSeconds: int = _ZOMBIE_MAX_AGE_SECONDS) -> int:
|
||||||
|
"""Kill RUNNING jobs that have not been updated within `maxAgeSeconds`.
|
||||||
|
|
||||||
|
Detects walkers that are stuck in a sync call without progress updates.
|
||||||
|
A live job updates progress at least every few seconds via JobProgressCallback.
|
||||||
|
Anything older than maxAgeSeconds without finishing is considered hung.
|
||||||
|
"""
|
||||||
|
db = _getDb()
|
||||||
|
try:
|
||||||
|
rows = db.getRecordset(BackgroundJob, recordFilter={"status": BackgroundJobStatusEnum.RUNNING.value})
|
||||||
|
except Exception as ex:
|
||||||
|
logger.warning("killZombieJobs: failed to scan RUNNING jobs: %s", ex)
|
||||||
|
return 0
|
||||||
|
now = time.time()
|
||||||
|
threshold = now - maxAgeSeconds
|
||||||
|
count = 0
|
||||||
|
for row in rows:
|
||||||
|
started = row.get("startedAt") or row.get("createdAt")
|
||||||
|
if not started or started > threshold:
|
||||||
|
continue
|
||||||
|
ageMin = (now - started) / 60
|
||||||
|
try:
|
||||||
|
_markError(row["id"], f"Zombie killed (stuck >{maxAgeSeconds // 60}min, no progress)")
|
||||||
|
count += 1
|
||||||
|
payload = row.get("payload") or {}
|
||||||
|
logger.warning(
|
||||||
|
"killZombieJobs: killed %s (type=%s connId=%s ageMin=%.1f)",
|
||||||
|
row["id"], row.get("jobType"), payload.get("connectionId", "")[:12], ageMin,
|
||||||
|
)
|
||||||
|
except Exception as ex:
|
||||||
|
logger.warning("killZombieJobs: could not kill %s: %s", row.get("id"), ex)
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
def registerZombieKillerScheduler(*, intervalMinutes: int = 5) -> None:
|
||||||
|
"""Register a recurring cron job that kills stuck RUNNING jobs.
|
||||||
|
|
||||||
|
Idempotent. Runs every `intervalMinutes` minutes.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from modules.shared.eventManagement import eventManager
|
||||||
|
|
||||||
|
async def _runKiller():
|
||||||
|
killZombieJobs()
|
||||||
|
|
||||||
|
eventManager.registerCron(
|
||||||
|
jobId="background_jobs.zombie_killer",
|
||||||
|
func=_runKiller,
|
||||||
|
cronKwargs={"minute": f"*/{intervalMinutes}"},
|
||||||
|
)
|
||||||
|
logger.info("Zombie-killer scheduler registered (every %d min)", intervalMinutes)
|
||||||
|
except Exception as ex:
|
||||||
|
logger.warning("Zombie-killer scheduler registration failed (non-critical): %s", ex)
|
||||||
|
|
|
||||||
|
|
@ -532,8 +532,16 @@ class ChatService:
|
||||||
self, connectionId: str, sourceType: str, path: str, label: str,
|
self, connectionId: str, sourceType: str, path: str, label: str,
|
||||||
featureInstanceId: str = None, displayPath: str = None,
|
featureInstanceId: str = None, displayPath: str = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Create a new external data source reference."""
|
"""Create a new external data source reference.
|
||||||
|
|
||||||
|
Returns existing record if connectionId + path already exists (upsert semantics).
|
||||||
|
"""
|
||||||
from modules.datamodels.datamodelDataSource import DataSource
|
from modules.datamodels.datamodelDataSource import DataSource
|
||||||
|
existing = self.interfaceDbApp.db.getRecordset(
|
||||||
|
DataSource, recordFilter={"connectionId": connectionId, "path": path}
|
||||||
|
)
|
||||||
|
if existing:
|
||||||
|
return existing[0] if isinstance(existing[0], dict) else existing[0].model_dump()
|
||||||
ds = DataSource(
|
ds = DataSource(
|
||||||
connectionId=connectionId,
|
connectionId=connectionId,
|
||||||
sourceType=sourceType,
|
sourceType=sourceType,
|
||||||
|
|
|
||||||
|
|
@ -132,10 +132,10 @@ _SOURCE_TYPE_MAP = {
|
||||||
"gmail": ("gmailFolder",),
|
"gmail": ("gmailFolder",),
|
||||||
},
|
},
|
||||||
"clickup": {
|
"clickup": {
|
||||||
"clickup": ("clickupList",),
|
"clickup": ("clickupList", "clickup"),
|
||||||
},
|
},
|
||||||
"infomaniak": {
|
"infomaniak": {
|
||||||
"kdrive": ("kdriveFolder",),
|
"kdrive": ("kdriveFolder", "infomaniak"),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -225,7 +225,7 @@ async def _bootstrapJobHandler(
|
||||||
bootstrapOutlook,
|
bootstrapOutlook,
|
||||||
)
|
)
|
||||||
|
|
||||||
progressCb(10, "sharepoint + outlook")
|
progressCb(0, "Synchronisierung läuft...")
|
||||||
spDs = _filterDs("sharepoint")
|
spDs = _filterDs("sharepoint")
|
||||||
olDs = _filterDs("outlook")
|
olDs = _filterDs("outlook")
|
||||||
async def _noopResult():
|
async def _noopResult():
|
||||||
|
|
@ -251,7 +251,7 @@ async def _bootstrapJobHandler(
|
||||||
bootstrapGmail,
|
bootstrapGmail,
|
||||||
)
|
)
|
||||||
|
|
||||||
progressCb(10, "drive + gmail")
|
progressCb(0, "Synchronisierung läuft...")
|
||||||
gdDs = _filterDs("drive")
|
gdDs = _filterDs("drive")
|
||||||
gmDs = _filterDs("gmail")
|
gmDs = _filterDs("gmail")
|
||||||
async def _noopResult():
|
async def _noopResult():
|
||||||
|
|
@ -274,7 +274,7 @@ async def _bootstrapJobHandler(
|
||||||
bootstrapClickup,
|
bootstrapClickup,
|
||||||
)
|
)
|
||||||
|
|
||||||
progressCb(10, "clickup tasks")
|
progressCb(0, "Synchronisierung läuft...")
|
||||||
cuDs = _filterDs("clickup")
|
cuDs = _filterDs("clickup")
|
||||||
cuResult = await bootstrapClickup(connectionId=connectionId, progressCb=progressCb, dataSources=cuDs) if cuDs else {"skipped": True, "reason": "no_datasources"}
|
cuResult = await bootstrapClickup(connectionId=connectionId, progressCb=progressCb, dataSources=cuDs) if cuDs else {"skipped": True, "reason": "no_datasources"}
|
||||||
return {
|
return {
|
||||||
|
|
@ -283,6 +283,20 @@ async def _bootstrapJobHandler(
|
||||||
"clickup": _normalize(cuResult, "clickup"),
|
"clickup": _normalize(cuResult, "clickup"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if authority == "infomaniak":
|
||||||
|
from modules.serviceCenter.services.serviceKnowledge.subConnectorSyncKdrive import (
|
||||||
|
bootstrapKdrive,
|
||||||
|
)
|
||||||
|
|
||||||
|
progressCb(0, "Synchronisierung läuft...")
|
||||||
|
kdDs = _filterDs("kdrive")
|
||||||
|
kdResult = await bootstrapKdrive(connectionId=connectionId, progressCb=progressCb, dataSources=kdDs) if kdDs else {"skipped": True, "reason": "no_datasources"}
|
||||||
|
return {
|
||||||
|
"connectionId": connectionId,
|
||||||
|
"authority": authority,
|
||||||
|
"kdrive": _normalize(kdResult, "kdrive"),
|
||||||
|
}
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"ingestion.connection.bootstrap.skipped reason=unsupported_authority authority=%s connectionId=%s",
|
"ingestion.connection.bootstrap.skipped reason=unsupported_authority authority=%s connectionId=%s",
|
||||||
authority, connectionId,
|
authority, connectionId,
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,12 @@ from dataclasses import dataclass, field
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import (
|
||||||
|
WalkerTimeout,
|
||||||
|
ingestWithTimeout,
|
||||||
|
logItemStart,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
MAX_TASKS_DEFAULT = 500
|
MAX_TASKS_DEFAULT = 500
|
||||||
|
|
@ -449,36 +455,44 @@ async def _ingestTask(
|
||||||
name = task.get("name") or f"Task {taskId}"
|
name = task.get("name") or f"Task {taskId}"
|
||||||
syntheticId = _syntheticTaskId(connectionId, taskId)
|
syntheticId = _syntheticTaskId(connectionId, taskId)
|
||||||
fileName = f"{name[:80].strip() or taskId}.task.json"
|
fileName = f"{name[:80].strip() or taskId}.task.json"
|
||||||
|
logItemStart("clickup", f"{teamId}/{taskId}")
|
||||||
|
|
||||||
contentObjects = _buildContentObjects(task, limits)
|
contentObjects = _buildContentObjects(task, limits)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
handle = await knowledgeService.requestIngestion(
|
handle = await ingestWithTimeout(
|
||||||
IngestionJob(
|
knowledgeService.requestIngestion(
|
||||||
sourceKind="clickup_task",
|
IngestionJob(
|
||||||
sourceId=syntheticId,
|
sourceKind="clickup_task",
|
||||||
fileName=fileName,
|
sourceId=syntheticId,
|
||||||
mimeType="application/vnd.clickup.task+json",
|
fileName=fileName,
|
||||||
userId=userId,
|
mimeType="application/vnd.clickup.task+json",
|
||||||
mandateId=mandateId,
|
userId=userId,
|
||||||
contentObjects=contentObjects,
|
mandateId=mandateId,
|
||||||
contentVersion=revision or None,
|
contentObjects=contentObjects,
|
||||||
neutralize=limits.neutralize,
|
contentVersion=revision or None,
|
||||||
provenance={
|
neutralize=limits.neutralize,
|
||||||
"connectionId": connectionId,
|
provenance={
|
||||||
"dataSourceId": dataSourceId,
|
"connectionId": connectionId,
|
||||||
"authority": "clickup",
|
"dataSourceId": dataSourceId,
|
||||||
"service": "clickup",
|
"authority": "clickup",
|
||||||
"externalItemId": taskId,
|
"service": "clickup",
|
||||||
"teamId": teamId,
|
"externalItemId": taskId,
|
||||||
"listId": ((task.get("list") or {}).get("id")),
|
"teamId": teamId,
|
||||||
"spaceId": ((task.get("space") or {}).get("id")),
|
"listId": ((task.get("list") or {}).get("id")),
|
||||||
"url": task.get("url"),
|
"spaceId": ((task.get("space") or {}).get("id")),
|
||||||
"status": ((task.get("status") or {}).get("status")),
|
"url": task.get("url"),
|
||||||
"tier": limits.clickupScope,
|
"status": ((task.get("status") or {}).get("status")),
|
||||||
},
|
"tier": limits.clickupScope,
|
||||||
)
|
},
|
||||||
|
)
|
||||||
|
),
|
||||||
|
label=taskId,
|
||||||
)
|
)
|
||||||
|
except WalkerTimeout as exc:
|
||||||
|
result.failed += 1
|
||||||
|
result.errors.append(str(exc))
|
||||||
|
return
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("clickup ingestion %s failed: %s", taskId, exc, exc_info=True)
|
logger.error("clickup ingestion %s failed: %s", taskId, exc, exc_info=True)
|
||||||
result.failed += 1
|
result.failed += 1
|
||||||
|
|
@ -493,18 +507,16 @@ async def _ingestTask(
|
||||||
result.failed += 1
|
result.failed += 1
|
||||||
|
|
||||||
processed = result.indexed + result.skippedDuplicate
|
processed = result.indexed + result.skippedDuplicate
|
||||||
if progressCb is not None and processed % 50 == 0:
|
if progressCb is not None and processed % 5 == 0:
|
||||||
if hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
|
if hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
progressCb(
|
progressCb(0, f"{processed} Tasks verarbeitet, {result.indexed} indexiert")
|
||||||
min(90, 10 + int(80 * processed / max(1, limits.maxTasks))),
|
|
||||||
f"clickup processed={processed}",
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
logger.info(
|
if processed % 50 == 0:
|
||||||
"ingestion.connection.bootstrap.progress part=clickup processed=%d skippedDup=%d failed=%d",
|
logger.info(
|
||||||
|
"ingestion.connection.bootstrap.progress part=clickup processed=%d skippedDup=%d failed=%d",
|
||||||
processed, result.skippedDuplicate, result.failed,
|
processed, result.skippedDuplicate, result.failed,
|
||||||
extra={
|
extra={
|
||||||
"event": "ingestion.connection.bootstrap.progress",
|
"event": "ingestion.connection.bootstrap.progress",
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,13 @@ from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
from modules.datamodels.datamodelExtraction import ExtractionOptions
|
from modules.datamodels.datamodelExtraction import ExtractionOptions
|
||||||
|
from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import (
|
||||||
|
WalkerTimeout,
|
||||||
|
downloadWithTimeout,
|
||||||
|
extractWithTimeout,
|
||||||
|
ingestWithTimeout,
|
||||||
|
logItemStart,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -342,9 +349,15 @@ async def _ingestOne(
|
||||||
|
|
||||||
syntheticFileId = _syntheticFileId(connectionId, externalItemId)
|
syntheticFileId = _syntheticFileId(connectionId, externalItemId)
|
||||||
fileName = getattr(entry, "name", "") or externalItemId
|
fileName = getattr(entry, "name", "") or externalItemId
|
||||||
|
declaredSize = int(getattr(entry, "size", 0) or 0) or None
|
||||||
|
logItemStart("gdrive", entryPath, sizeBytes=declaredSize, mime=mimeType)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
downloaded = await adapter.download(entryPath)
|
downloaded = await downloadWithTimeout(adapter.download(entryPath), label=entryPath)
|
||||||
|
except WalkerTimeout as exc:
|
||||||
|
result.failed += 1
|
||||||
|
result.errors.append(str(exc))
|
||||||
|
return
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("gdrive download %s failed: %s", entryPath, exc)
|
logger.warning("gdrive download %s failed: %s", entryPath, exc)
|
||||||
result.failed += 1
|
result.failed += 1
|
||||||
|
|
@ -368,10 +381,16 @@ async def _ingestOne(
|
||||||
result.bytesProcessed += len(fileBytes)
|
result.bytesProcessed += len(fileBytes)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
extracted = runExtractionFn(
|
extracted = await extractWithTimeout(
|
||||||
|
runExtractionFn,
|
||||||
fileBytes, fileName, mimeType,
|
fileBytes, fileName, mimeType,
|
||||||
ExtractionOptions(mergeStrategy=None),
|
ExtractionOptions(mergeStrategy=None),
|
||||||
|
label=entryPath,
|
||||||
)
|
)
|
||||||
|
except WalkerTimeout as exc:
|
||||||
|
result.failed += 1
|
||||||
|
result.errors.append(str(exc))
|
||||||
|
return
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("gdrive extraction %s failed: %s", entryPath, exc)
|
logger.warning("gdrive extraction %s failed: %s", entryPath, exc)
|
||||||
result.failed += 1
|
result.failed += 1
|
||||||
|
|
@ -393,20 +412,27 @@ async def _ingestOne(
|
||||||
"tier": "body",
|
"tier": "body",
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
handle = await knowledgeService.requestIngestion(
|
handle = await ingestWithTimeout(
|
||||||
IngestionJob(
|
knowledgeService.requestIngestion(
|
||||||
sourceKind="gdrive_item",
|
IngestionJob(
|
||||||
sourceId=syntheticFileId,
|
sourceKind="gdrive_item",
|
||||||
fileName=fileName,
|
sourceId=syntheticFileId,
|
||||||
mimeType=mimeType,
|
fileName=fileName,
|
||||||
userId=userId,
|
mimeType=mimeType,
|
||||||
mandateId=mandateId,
|
userId=userId,
|
||||||
contentObjects=contentObjects,
|
mandateId=mandateId,
|
||||||
contentVersion=revision,
|
contentObjects=contentObjects,
|
||||||
neutralize=limits.neutralize,
|
contentVersion=revision,
|
||||||
provenance=provenance,
|
neutralize=limits.neutralize,
|
||||||
)
|
provenance=provenance,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
label=entryPath,
|
||||||
)
|
)
|
||||||
|
except WalkerTimeout as exc:
|
||||||
|
result.failed += 1
|
||||||
|
result.errors.append(str(exc))
|
||||||
|
return
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("gdrive ingestion %s failed: %s", entryPath, exc, exc_info=True)
|
logger.error("gdrive ingestion %s failed: %s", entryPath, exc, exc_info=True)
|
||||||
result.failed += 1
|
result.failed += 1
|
||||||
|
|
@ -422,13 +448,10 @@ async def _ingestOne(
|
||||||
if handle.error:
|
if handle.error:
|
||||||
result.errors.append(f"ingest({entryPath}): {handle.error}")
|
result.errors.append(f"ingest({entryPath}): {handle.error}")
|
||||||
|
|
||||||
if progressCb is not None and (result.indexed + result.skippedDuplicate) % 50 == 0:
|
processed = result.indexed + result.skippedDuplicate
|
||||||
processed = result.indexed + result.skippedDuplicate
|
if progressCb is not None and processed % 5 == 0:
|
||||||
try:
|
try:
|
||||||
progressCb(
|
progressCb(0, f"{processed} Dateien verarbeitet, {result.indexed} indexiert")
|
||||||
min(90, 10 + int(80 * processed / max(1, limits.maxItems))),
|
|
||||||
f"gdrive processed={processed}",
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,11 @@ from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
from modules.serviceCenter.services.serviceKnowledge.subTextClean import cleanEmailBody
|
from modules.serviceCenter.services.serviceKnowledge.subTextClean import cleanEmailBody
|
||||||
|
from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import (
|
||||||
|
WalkerTimeout,
|
||||||
|
ingestWithTimeout,
|
||||||
|
logItemStart,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -399,34 +404,42 @@ async def _ingestMessage(
|
||||||
subject = headers.get("subject") or "(no subject)"
|
subject = headers.get("subject") or "(no subject)"
|
||||||
syntheticId = _syntheticMessageId(connectionId, messageId)
|
syntheticId = _syntheticMessageId(connectionId, messageId)
|
||||||
fileName = f"{subject[:80].strip()}.eml" if subject else f"{messageId}.eml"
|
fileName = f"{subject[:80].strip()}.eml" if subject else f"{messageId}.eml"
|
||||||
|
logItemStart("gmail", f"{labelId}/{messageId}", mime="message/rfc822")
|
||||||
|
|
||||||
contentObjects = _buildContentObjects(
|
contentObjects = _buildContentObjects(
|
||||||
message, limits.maxBodyChars, mailContentDepth=limits.mailContentDepth
|
message, limits.maxBodyChars, mailContentDepth=limits.mailContentDepth
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
handle = await knowledgeService.requestIngestion(
|
handle = await ingestWithTimeout(
|
||||||
IngestionJob(
|
knowledgeService.requestIngestion(
|
||||||
sourceKind="gmail_message",
|
IngestionJob(
|
||||||
sourceId=syntheticId,
|
sourceKind="gmail_message",
|
||||||
fileName=fileName,
|
sourceId=syntheticId,
|
||||||
mimeType="message/rfc822",
|
fileName=fileName,
|
||||||
userId=userId,
|
mimeType="message/rfc822",
|
||||||
mandateId=mandateId,
|
userId=userId,
|
||||||
contentObjects=contentObjects,
|
mandateId=mandateId,
|
||||||
contentVersion=str(revision) if revision else None,
|
contentObjects=contentObjects,
|
||||||
neutralize=limits.neutralize,
|
contentVersion=str(revision) if revision else None,
|
||||||
provenance={
|
neutralize=limits.neutralize,
|
||||||
"connectionId": connectionId,
|
provenance={
|
||||||
"dataSourceId": dataSourceId,
|
"connectionId": connectionId,
|
||||||
"authority": "google",
|
"dataSourceId": dataSourceId,
|
||||||
"service": "gmail",
|
"authority": "google",
|
||||||
"externalItemId": messageId,
|
"service": "gmail",
|
||||||
"label": labelId,
|
"externalItemId": messageId,
|
||||||
"threadId": message.get("threadId"),
|
"label": labelId,
|
||||||
"tier": limits.mailContentDepth,
|
"threadId": message.get("threadId"),
|
||||||
},
|
"tier": limits.mailContentDepth,
|
||||||
)
|
},
|
||||||
|
)
|
||||||
|
),
|
||||||
|
label=messageId,
|
||||||
)
|
)
|
||||||
|
except WalkerTimeout as exc:
|
||||||
|
result.failed += 1
|
||||||
|
result.errors.append(str(exc))
|
||||||
|
return
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("gmail ingestion %s failed: %s", messageId, exc, exc_info=True)
|
logger.error("gmail ingestion %s failed: %s", messageId, exc, exc_info=True)
|
||||||
result.failed += 1
|
result.failed += 1
|
||||||
|
|
@ -458,18 +471,16 @@ async def _ingestMessage(
|
||||||
logger.warning("gmail attachments %s failed: %s", messageId, exc)
|
logger.warning("gmail attachments %s failed: %s", messageId, exc)
|
||||||
result.errors.append(f"attachments({messageId}): {exc}")
|
result.errors.append(f"attachments({messageId}): {exc}")
|
||||||
|
|
||||||
if progressCb is not None and (result.indexed + result.skippedDuplicate) % 50 == 0:
|
processed = result.indexed + result.skippedDuplicate
|
||||||
processed = result.indexed + result.skippedDuplicate
|
if progressCb is not None and processed % 5 == 0:
|
||||||
try:
|
try:
|
||||||
progressCb(
|
progressCb(0, f"{processed} Mails verarbeitet, {result.indexed} indexiert")
|
||||||
min(90, 10 + int(80 * processed / max(1, limits.maxMessages))),
|
|
||||||
f"gmail processed={processed}",
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
logger.info(
|
if processed % 50 == 0:
|
||||||
"ingestion.connection.bootstrap.progress part=gmail processed=%d skippedDup=%d failed=%d",
|
logger.info(
|
||||||
processed, result.skippedDuplicate, result.failed,
|
"ingestion.connection.bootstrap.progress part=gmail processed=%d skippedDup=%d failed=%d",
|
||||||
|
processed, result.skippedDuplicate, result.failed,
|
||||||
extra={
|
extra={
|
||||||
"event": "ingestion.connection.bootstrap.progress",
|
"event": "ingestion.connection.bootstrap.progress",
|
||||||
"part": "gmail",
|
"part": "gmail",
|
||||||
|
|
@ -546,13 +557,26 @@ async def _ingestAttachments(
|
||||||
fileName = stub["filename"]
|
fileName = stub["filename"]
|
||||||
mimeType = stub["mimeType"]
|
mimeType = stub["mimeType"]
|
||||||
syntheticId = _syntheticAttachmentId(connectionId, messageId, stub["attachmentId"])
|
syntheticId = _syntheticAttachmentId(connectionId, messageId, stub["attachmentId"])
|
||||||
|
attLabel = f"{messageId}/att:{stub['attachmentId']}/{fileName}"
|
||||||
|
logItemStart("gmail-attachment", attLabel, sizeBytes=stub.get("size") or None, mime=mimeType)
|
||||||
|
|
||||||
try:
|
from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import (
|
||||||
extracted = runExtraction(
|
extractWithTimeout as _extractWithTimeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _runAttExtraction():
|
||||||
|
return runExtraction(
|
||||||
extractorRegistry, chunkerRegistry,
|
extractorRegistry, chunkerRegistry,
|
||||||
rawBytes, fileName, mimeType,
|
rawBytes, fileName, mimeType,
|
||||||
ExtractionOptions(mergeStrategy=None),
|
ExtractionOptions(mergeStrategy=None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
extracted = await _extractWithTimeout(_runAttExtraction, label=attLabel)
|
||||||
|
except WalkerTimeout as exc:
|
||||||
|
result.failed += 1
|
||||||
|
result.errors.append(str(exc))
|
||||||
|
continue
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("gmail attachment extract %s failed: %s", stub["attachmentId"], exc)
|
logger.warning("gmail attachment extract %s failed: %s", stub["attachmentId"], exc)
|
||||||
result.failed += 1
|
result.failed += 1
|
||||||
|
|
@ -584,27 +608,33 @@ async def _ingestAttachments(
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await knowledgeService.requestIngestion(
|
await ingestWithTimeout(
|
||||||
IngestionJob(
|
knowledgeService.requestIngestion(
|
||||||
sourceKind="gmail_attachment",
|
IngestionJob(
|
||||||
sourceId=syntheticId,
|
sourceKind="gmail_attachment",
|
||||||
fileName=fileName,
|
sourceId=syntheticId,
|
||||||
mimeType=mimeType,
|
fileName=fileName,
|
||||||
userId=userId,
|
mimeType=mimeType,
|
||||||
mandateId=mandateId,
|
userId=userId,
|
||||||
contentObjects=contentObjects,
|
mandateId=mandateId,
|
||||||
provenance={
|
contentObjects=contentObjects,
|
||||||
"connectionId": connectionId,
|
provenance={
|
||||||
"dataSourceId": dataSourceId,
|
"connectionId": connectionId,
|
||||||
"authority": "google",
|
"dataSourceId": dataSourceId,
|
||||||
"service": "gmail",
|
"authority": "google",
|
||||||
"parentId": parentSyntheticId,
|
"service": "gmail",
|
||||||
"externalItemId": stub["attachmentId"],
|
"parentId": parentSyntheticId,
|
||||||
"parentMessageId": messageId,
|
"externalItemId": stub["attachmentId"],
|
||||||
},
|
"parentMessageId": messageId,
|
||||||
)
|
},
|
||||||
|
)
|
||||||
|
),
|
||||||
|
label=attLabel,
|
||||||
)
|
)
|
||||||
result.attachmentsIndexed += 1
|
result.attachmentsIndexed += 1
|
||||||
|
except WalkerTimeout as exc:
|
||||||
|
result.failed += 1
|
||||||
|
result.errors.append(str(exc))
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("gmail attachment ingest %s failed: %s", stub["attachmentId"], exc)
|
logger.warning("gmail attachment ingest %s failed: %s", stub["attachmentId"], exc)
|
||||||
result.failed += 1
|
result.failed += 1
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,439 @@
|
||||||
|
# Copyright (c) 2025 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""kDrive bootstrap for the unified knowledge ingestion lane.
|
||||||
|
|
||||||
|
Walks every ragIndexEnabled kDrive DataSource, downloads file items and
|
||||||
|
hands them to KnowledgeService.requestIngestion. Idempotency is provided
|
||||||
|
by the ingestion facade (content-hash dedup).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
from modules.datamodels.datamodelExtraction import ExtractionOptions
|
||||||
|
from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import (
|
||||||
|
WalkerTimeout,
|
||||||
|
downloadWithTimeout,
|
||||||
|
extractWithTimeout,
|
||||||
|
ingestWithTimeout,
|
||||||
|
logItemStart,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MAX_ITEMS_DEFAULT = 500
|
||||||
|
MAX_BYTES_DEFAULT = 200 * 1024 * 1024
|
||||||
|
MAX_FILE_SIZE_DEFAULT = 25 * 1024 * 1024
|
||||||
|
SKIP_MIME_PREFIXES_DEFAULT = ("video/", "audio/")
|
||||||
|
MAX_DEPTH_DEFAULT = 4
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class KdriveBootstrapLimits:
|
||||||
|
maxItems: int = MAX_ITEMS_DEFAULT
|
||||||
|
maxBytes: int = MAX_BYTES_DEFAULT
|
||||||
|
maxFileSize: int = MAX_FILE_SIZE_DEFAULT
|
||||||
|
skipMimePrefixes: tuple = SKIP_MIME_PREFIXES_DEFAULT
|
||||||
|
maxDepth: int = MAX_DEPTH_DEFAULT
|
||||||
|
neutralize: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class KdriveBootstrapResult:
|
||||||
|
connectionId: str
|
||||||
|
indexed: int = 0
|
||||||
|
skippedDuplicate: int = 0
|
||||||
|
skippedPolicy: int = 0
|
||||||
|
failed: int = 0
|
||||||
|
bytesProcessed: int = 0
|
||||||
|
errors: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
def _syntheticFileId(connectionId: str, externalItemId: str) -> str:
|
||||||
|
token = hashlib.sha256(f"{connectionId}:{externalItemId}".encode("utf-8")).hexdigest()[:16]
|
||||||
|
return f"kd:{connectionId[:8]}:{token}"
|
||||||
|
|
||||||
|
|
||||||
|
def _toContentObjects(extracted, fileName: str) -> List[Dict[str, Any]]:
|
||||||
|
parts = getattr(extracted, "parts", None) or []
|
||||||
|
out: List[Dict[str, Any]] = []
|
||||||
|
for part in parts:
|
||||||
|
data = getattr(part, "data", None) or ""
|
||||||
|
if not data or not str(data).strip():
|
||||||
|
continue
|
||||||
|
typeGroup = getattr(part, "typeGroup", "text") or "text"
|
||||||
|
contentType = "text"
|
||||||
|
if typeGroup == "image":
|
||||||
|
contentType = "image"
|
||||||
|
elif typeGroup in ("binary", "container"):
|
||||||
|
contentType = "other"
|
||||||
|
out.append({
|
||||||
|
"contentObjectId": getattr(part, "id", ""),
|
||||||
|
"contentType": contentType,
|
||||||
|
"data": data,
|
||||||
|
"contextRef": {
|
||||||
|
"containerPath": fileName,
|
||||||
|
"location": getattr(part, "label", None) or "file",
|
||||||
|
**(getattr(part, "metadata", None) or {}),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
async def bootstrapKdrive(
|
||||||
|
connectionId: str,
|
||||||
|
*,
|
||||||
|
dataSources: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
progressCb: Optional[Any] = None,
|
||||||
|
adapter: Any = None,
|
||||||
|
connection: Any = None,
|
||||||
|
knowledgeService: Any = None,
|
||||||
|
limits: Optional[KdriveBootstrapLimits] = None,
|
||||||
|
runExtractionFn: Optional[Callable[..., Any]] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Enumerate kDrive folders and ingest files via the facade."""
|
||||||
|
if not dataSources:
|
||||||
|
return {"connectionId": connectionId, "skipped": True, "reason": "no_datasources"}
|
||||||
|
|
||||||
|
if not limits:
|
||||||
|
limits = KdriveBootstrapLimits()
|
||||||
|
|
||||||
|
startMs = time.time()
|
||||||
|
result = KdriveBootstrapResult(connectionId=connectionId)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"ingestion.connection.bootstrap.started part=kdrive connectionId=%s dataSources=%d",
|
||||||
|
connectionId, len(dataSources),
|
||||||
|
extra={"event": "ingestion.connection.bootstrap.started", "part": "kdrive",
|
||||||
|
"connectionId": connectionId, "dataSourceCount": len(dataSources)},
|
||||||
|
)
|
||||||
|
|
||||||
|
if adapter is None or knowledgeService is None or connection is None:
|
||||||
|
adapter, connection, knowledgeService = await _resolveDependencies(connectionId)
|
||||||
|
if runExtractionFn is None:
|
||||||
|
from modules.serviceCenter.services.serviceExtraction.subPipeline import runExtraction
|
||||||
|
from modules.serviceCenter.services.serviceExtraction.subRegistry import (
|
||||||
|
ExtractorRegistry, ChunkerRegistry,
|
||||||
|
)
|
||||||
|
extractorRegistry = ExtractorRegistry()
|
||||||
|
chunkerRegistry = ChunkerRegistry()
|
||||||
|
|
||||||
|
def runExtractionFn(bytesData, name, mime, options):
|
||||||
|
return runExtraction(extractorRegistry, chunkerRegistry, bytesData, name, mime, options)
|
||||||
|
|
||||||
|
mandateId = str(getattr(connection, "mandateId", "") or "") if connection is not None else ""
|
||||||
|
userId = str(getattr(connection, "userId", "") or "") if connection is not None else ""
|
||||||
|
|
||||||
|
cancelled = False
|
||||||
|
for ds in dataSources:
|
||||||
|
if result.indexed + result.skippedDuplicate >= limits.maxItems:
|
||||||
|
break
|
||||||
|
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
|
||||||
|
cancelled = True
|
||||||
|
break
|
||||||
|
|
||||||
|
dsPath = ds.get("path", "")
|
||||||
|
dsId = ds.get("id", "")
|
||||||
|
dsNeutralize = ds.get("neutralize", False)
|
||||||
|
dsLimits = KdriveBootstrapLimits(
|
||||||
|
maxItems=limits.maxItems,
|
||||||
|
maxBytes=limits.maxBytes,
|
||||||
|
maxFileSize=limits.maxFileSize,
|
||||||
|
skipMimePrefixes=limits.skipMimePrefixes,
|
||||||
|
maxDepth=limits.maxDepth,
|
||||||
|
neutralize=dsNeutralize,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await _walkFolder(
|
||||||
|
adapter=adapter,
|
||||||
|
knowledgeService=knowledgeService,
|
||||||
|
runExtractionFn=runExtractionFn,
|
||||||
|
connectionId=connectionId,
|
||||||
|
mandateId=mandateId,
|
||||||
|
userId=userId,
|
||||||
|
folderPath=dsPath,
|
||||||
|
depth=0,
|
||||||
|
limits=dsLimits,
|
||||||
|
result=result,
|
||||||
|
progressCb=progressCb,
|
||||||
|
dataSourceId=dsId,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("kdrive walk failed for ds %s path %s: %s", dsId, dsPath, exc, exc_info=True)
|
||||||
|
result.errors.append(f"walk({dsPath}): {exc}")
|
||||||
|
|
||||||
|
finalResult = _finalizeResult(connectionId, result, startMs)
|
||||||
|
if cancelled:
|
||||||
|
finalResult["cancelled"] = True
|
||||||
|
return finalResult
|
||||||
|
|
||||||
|
|
||||||
|
async def _resolveDependencies(connectionId: str):
|
||||||
|
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||||
|
from modules.auth import TokenManager
|
||||||
|
from modules.connectors.providerInfomaniak.connectorInfomaniak import InfomaniakConnector
|
||||||
|
from modules.serviceCenter import getService
|
||||||
|
from modules.serviceCenter.context import ServiceCenterContext
|
||||||
|
from modules.security.rootAccess import getRootUser
|
||||||
|
|
||||||
|
rootInterface = getRootInterface()
|
||||||
|
connection = rootInterface.getUserConnectionById(connectionId)
|
||||||
|
if connection is None:
|
||||||
|
raise ValueError(f"UserConnection not found: {connectionId}")
|
||||||
|
|
||||||
|
token = TokenManager().getFreshToken(connectionId)
|
||||||
|
if not token or not token.tokenAccess:
|
||||||
|
raise ValueError(f"No valid token for connection {connectionId}")
|
||||||
|
|
||||||
|
provider = InfomaniakConnector(connection, token.tokenAccess)
|
||||||
|
adapter = provider.getServiceAdapter("kdrive")
|
||||||
|
|
||||||
|
rootUser = getRootUser()
|
||||||
|
ctx = ServiceCenterContext(
|
||||||
|
user=rootUser,
|
||||||
|
mandate_id=str(getattr(connection, "mandateId", "") or ""),
|
||||||
|
)
|
||||||
|
knowledgeService = getService("knowledge", ctx)
|
||||||
|
return adapter, connection, knowledgeService
|
||||||
|
|
||||||
|
|
||||||
|
async def _walkFolder(
|
||||||
|
*,
|
||||||
|
adapter,
|
||||||
|
knowledgeService,
|
||||||
|
runExtractionFn,
|
||||||
|
connectionId: str,
|
||||||
|
mandateId: str,
|
||||||
|
userId: str,
|
||||||
|
folderPath: str,
|
||||||
|
depth: int,
|
||||||
|
limits: KdriveBootstrapLimits,
|
||||||
|
result: KdriveBootstrapResult,
|
||||||
|
progressCb: Optional[Any],
|
||||||
|
dataSourceId: str = "",
|
||||||
|
) -> None:
|
||||||
|
if depth > limits.maxDepth:
|
||||||
|
return
|
||||||
|
if progressCb and hasattr(progressCb, "isCancelled") and progressCb.isCancelled():
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
entries = await adapter.browse(folderPath)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("kdrive browse %s failed: %s", folderPath, exc)
|
||||||
|
result.errors.append(f"browse({folderPath}): {exc}")
|
||||||
|
return
|
||||||
|
|
||||||
|
for entry in entries:
|
||||||
|
if result.indexed + result.skippedDuplicate >= limits.maxItems:
|
||||||
|
return
|
||||||
|
if result.bytesProcessed >= limits.maxBytes:
|
||||||
|
return
|
||||||
|
if progressCb and hasattr(progressCb, "isCancelled") and (result.indexed + result.skippedDuplicate) % 50 == 0 and progressCb.isCancelled():
|
||||||
|
return
|
||||||
|
|
||||||
|
entryPath = getattr(entry, "path", "") or ""
|
||||||
|
if getattr(entry, "isFolder", False):
|
||||||
|
await _walkFolder(
|
||||||
|
adapter=adapter,
|
||||||
|
knowledgeService=knowledgeService,
|
||||||
|
runExtractionFn=runExtractionFn,
|
||||||
|
connectionId=connectionId,
|
||||||
|
mandateId=mandateId,
|
||||||
|
userId=userId,
|
||||||
|
folderPath=entryPath,
|
||||||
|
depth=depth + 1,
|
||||||
|
limits=limits,
|
||||||
|
result=result,
|
||||||
|
progressCb=progressCb,
|
||||||
|
dataSourceId=dataSourceId,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
mimeType = getattr(entry, "mimeType", None) or "application/octet-stream"
|
||||||
|
if any(mimeType.startswith(prefix) for prefix in limits.skipMimePrefixes):
|
||||||
|
result.skippedPolicy += 1
|
||||||
|
continue
|
||||||
|
size = int(getattr(entry, "size", 0) or 0)
|
||||||
|
if size and size > limits.maxFileSize:
|
||||||
|
result.skippedPolicy += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
metadata = getattr(entry, "metadata", {}) or {}
|
||||||
|
externalItemId = metadata.get("id") or entryPath
|
||||||
|
revision = metadata.get("revision") or metadata.get("lastModified")
|
||||||
|
|
||||||
|
await _ingestOne(
|
||||||
|
adapter=adapter,
|
||||||
|
knowledgeService=knowledgeService,
|
||||||
|
runExtractionFn=runExtractionFn,
|
||||||
|
connectionId=connectionId,
|
||||||
|
mandateId=mandateId,
|
||||||
|
userId=userId,
|
||||||
|
entry=entry,
|
||||||
|
entryPath=entryPath,
|
||||||
|
mimeType=mimeType,
|
||||||
|
externalItemId=externalItemId,
|
||||||
|
revision=revision,
|
||||||
|
limits=limits,
|
||||||
|
result=result,
|
||||||
|
progressCb=progressCb,
|
||||||
|
dataSourceId=dataSourceId,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _ingestOne(
|
||||||
|
*,
|
||||||
|
adapter,
|
||||||
|
knowledgeService,
|
||||||
|
runExtractionFn,
|
||||||
|
connectionId: str,
|
||||||
|
mandateId: str,
|
||||||
|
userId: str,
|
||||||
|
entry,
|
||||||
|
entryPath: str,
|
||||||
|
mimeType: str,
|
||||||
|
externalItemId: str,
|
||||||
|
revision: Optional[str],
|
||||||
|
limits: KdriveBootstrapLimits,
|
||||||
|
result: KdriveBootstrapResult,
|
||||||
|
progressCb: Optional[Any],
|
||||||
|
dataSourceId: str = "",
|
||||||
|
) -> None:
|
||||||
|
from modules.serviceCenter.services.serviceKnowledge.mainServiceKnowledge import IngestionJob
|
||||||
|
|
||||||
|
syntheticFileId = _syntheticFileId(connectionId, externalItemId)
|
||||||
|
fileName = getattr(entry, "name", "") or externalItemId
|
||||||
|
declaredSize = int(getattr(entry, "size", 0) or 0) or None
|
||||||
|
logItemStart("kdrive", entryPath, sizeBytes=declaredSize, mime=mimeType)
|
||||||
|
|
||||||
|
try:
|
||||||
|
downloadResult = await downloadWithTimeout(adapter.download(entryPath), label=entryPath)
|
||||||
|
fileBytes = getattr(downloadResult, "data", None)
|
||||||
|
dlFileName = getattr(downloadResult, "fileName", None)
|
||||||
|
dlMimeType = getattr(downloadResult, "mimeType", None)
|
||||||
|
if dlFileName:
|
||||||
|
fileName = dlFileName
|
||||||
|
if dlMimeType:
|
||||||
|
mimeType = dlMimeType
|
||||||
|
except WalkerTimeout as exc:
|
||||||
|
result.failed += 1
|
||||||
|
result.errors.append(str(exc))
|
||||||
|
return
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("kdrive download %s failed: %s", entryPath, exc)
|
||||||
|
result.failed += 1
|
||||||
|
result.errors.append(f"download({entryPath}): {exc}")
|
||||||
|
return
|
||||||
|
if not fileBytes:
|
||||||
|
result.failed += 1
|
||||||
|
return
|
||||||
|
|
||||||
|
result.bytesProcessed += len(fileBytes)
|
||||||
|
|
||||||
|
try:
|
||||||
|
extracted = await extractWithTimeout(
|
||||||
|
runExtractionFn,
|
||||||
|
fileBytes, fileName, mimeType,
|
||||||
|
ExtractionOptions(mergeStrategy=None),
|
||||||
|
label=entryPath,
|
||||||
|
)
|
||||||
|
except WalkerTimeout as exc:
|
||||||
|
result.failed += 1
|
||||||
|
result.errors.append(str(exc))
|
||||||
|
return
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("kdrive extraction %s failed: %s", entryPath, exc)
|
||||||
|
result.failed += 1
|
||||||
|
result.errors.append(f"extract({entryPath}): {exc}")
|
||||||
|
return
|
||||||
|
|
||||||
|
contentObjects = _toContentObjects(extracted, fileName)
|
||||||
|
if not contentObjects:
|
||||||
|
result.skippedPolicy += 1
|
||||||
|
return
|
||||||
|
|
||||||
|
provenance: Dict[str, Any] = {
|
||||||
|
"connectionId": connectionId,
|
||||||
|
"dataSourceId": dataSourceId,
|
||||||
|
"authority": "infomaniak",
|
||||||
|
"service": "kdrive",
|
||||||
|
"externalItemId": externalItemId,
|
||||||
|
"externalPath": entryPath,
|
||||||
|
"revision": revision,
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
handle = await ingestWithTimeout(
|
||||||
|
knowledgeService.requestIngestion(
|
||||||
|
IngestionJob(
|
||||||
|
sourceKind="kdrive_item",
|
||||||
|
sourceId=syntheticFileId,
|
||||||
|
fileName=fileName,
|
||||||
|
mimeType=mimeType,
|
||||||
|
userId=userId,
|
||||||
|
mandateId=mandateId,
|
||||||
|
contentObjects=contentObjects,
|
||||||
|
contentVersion=revision,
|
||||||
|
neutralize=limits.neutralize,
|
||||||
|
provenance=provenance,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
label=entryPath,
|
||||||
|
)
|
||||||
|
except WalkerTimeout as exc:
|
||||||
|
result.failed += 1
|
||||||
|
result.errors.append(str(exc))
|
||||||
|
return
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("kdrive ingestion %s failed: %s", entryPath, exc, exc_info=True)
|
||||||
|
result.failed += 1
|
||||||
|
result.errors.append(f"ingest({entryPath}): {exc}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if handle.status == "duplicate":
|
||||||
|
result.skippedDuplicate += 1
|
||||||
|
elif handle.status == "indexed":
|
||||||
|
result.indexed += 1
|
||||||
|
else:
|
||||||
|
result.failed += 1
|
||||||
|
if handle.error:
|
||||||
|
result.errors.append(f"ingest({entryPath}): {handle.error}")
|
||||||
|
|
||||||
|
processed = result.indexed + result.skippedDuplicate
|
||||||
|
if progressCb is not None and processed % 5 == 0:
|
||||||
|
try:
|
||||||
|
progressCb(0, f"{processed} Dateien verarbeitet, {result.indexed} indexiert")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
|
||||||
|
def _finalizeResult(connectionId: str, result: KdriveBootstrapResult, startMs: float) -> Dict[str, Any]:
|
||||||
|
durationMs = int((time.time() - startMs) * 1000)
|
||||||
|
logger.info(
|
||||||
|
"ingestion.connection.bootstrap.done part=kdrive connectionId=%s indexed=%d skippedDup=%d skippedPolicy=%d failed=%d durationMs=%d",
|
||||||
|
connectionId,
|
||||||
|
result.indexed, result.skippedDuplicate, result.skippedPolicy, result.failed,
|
||||||
|
durationMs,
|
||||||
|
extra={"event": "ingestion.connection.bootstrap.done", "part": "kdrive",
|
||||||
|
"connectionId": connectionId, "indexed": result.indexed,
|
||||||
|
"skippedDup": result.skippedDuplicate, "skippedPolicy": result.skippedPolicy,
|
||||||
|
"failed": result.failed, "durationMs": durationMs},
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"connectionId": result.connectionId,
|
||||||
|
"indexed": result.indexed,
|
||||||
|
"skippedDuplicate": result.skippedDuplicate,
|
||||||
|
"skippedPolicy": result.skippedPolicy,
|
||||||
|
"failed": result.failed,
|
||||||
|
"bytesProcessed": result.bytesProcessed,
|
||||||
|
"durationMs": durationMs,
|
||||||
|
"errors": result.errors[:20],
|
||||||
|
}
|
||||||
|
|
@ -21,6 +21,12 @@ from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from modules.serviceCenter.services.serviceKnowledge.subTextClean import cleanEmailBody
|
from modules.serviceCenter.services.serviceKnowledge.subTextClean import cleanEmailBody
|
||||||
|
from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import (
|
||||||
|
WalkerTimeout,
|
||||||
|
extractWithTimeout,
|
||||||
|
ingestWithTimeout,
|
||||||
|
logItemStart,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -384,34 +390,42 @@ async def _ingestMessage(
|
||||||
subject = message.get("subject") or "(no subject)"
|
subject = message.get("subject") or "(no subject)"
|
||||||
syntheticId = _syntheticMessageId(connectionId, messageId)
|
syntheticId = _syntheticMessageId(connectionId, messageId)
|
||||||
fileName = f"{subject[:80].strip()}.eml" if subject else f"{messageId}.eml"
|
fileName = f"{subject[:80].strip()}.eml" if subject else f"{messageId}.eml"
|
||||||
|
logItemStart("outlook", messageId, mime="message/rfc822")
|
||||||
|
|
||||||
contentObjects = _buildContentObjects(
|
contentObjects = _buildContentObjects(
|
||||||
message, limits.maxBodyChars, mailContentDepth=limits.mailContentDepth
|
message, limits.maxBodyChars, mailContentDepth=limits.mailContentDepth
|
||||||
)
|
)
|
||||||
# Always at least the header is emitted, so `contentObjects` is non-empty.
|
# Always at least the header is emitted, so `contentObjects` is non-empty.
|
||||||
try:
|
try:
|
||||||
handle = await knowledgeService.requestIngestion(
|
handle = await ingestWithTimeout(
|
||||||
IngestionJob(
|
knowledgeService.requestIngestion(
|
||||||
sourceKind="outlook_message",
|
IngestionJob(
|
||||||
sourceId=syntheticId,
|
sourceKind="outlook_message",
|
||||||
fileName=fileName,
|
sourceId=syntheticId,
|
||||||
mimeType="message/rfc822",
|
fileName=fileName,
|
||||||
userId=userId,
|
mimeType="message/rfc822",
|
||||||
mandateId=mandateId,
|
userId=userId,
|
||||||
contentObjects=contentObjects,
|
mandateId=mandateId,
|
||||||
contentVersion=revision,
|
contentObjects=contentObjects,
|
||||||
neutralize=limits.neutralize,
|
contentVersion=revision,
|
||||||
provenance={
|
neutralize=limits.neutralize,
|
||||||
"connectionId": connectionId,
|
provenance={
|
||||||
"dataSourceId": dataSourceId,
|
"connectionId": connectionId,
|
||||||
"authority": "msft",
|
"dataSourceId": dataSourceId,
|
||||||
"service": "outlook",
|
"authority": "msft",
|
||||||
"externalItemId": messageId,
|
"service": "outlook",
|
||||||
"internetMessageId": message.get("internetMessageId"),
|
"externalItemId": messageId,
|
||||||
"tier": limits.mailContentDepth,
|
"internetMessageId": message.get("internetMessageId"),
|
||||||
},
|
"tier": limits.mailContentDepth,
|
||||||
)
|
},
|
||||||
|
)
|
||||||
|
),
|
||||||
|
label=messageId,
|
||||||
)
|
)
|
||||||
|
except WalkerTimeout as exc:
|
||||||
|
result.failed += 1
|
||||||
|
result.errors.append(str(exc))
|
||||||
|
return
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("outlook ingestion %s failed: %s", messageId, exc, exc_info=True)
|
logger.error("outlook ingestion %s failed: %s", messageId, exc, exc_info=True)
|
||||||
result.failed += 1
|
result.failed += 1
|
||||||
|
|
@ -443,18 +457,16 @@ async def _ingestMessage(
|
||||||
logger.warning("outlook attachments %s failed: %s", messageId, exc)
|
logger.warning("outlook attachments %s failed: %s", messageId, exc)
|
||||||
result.errors.append(f"attachments({messageId}): {exc}")
|
result.errors.append(f"attachments({messageId}): {exc}")
|
||||||
|
|
||||||
if progressCb is not None and (result.indexed + result.skippedDuplicate) % 50 == 0:
|
processed = result.indexed + result.skippedDuplicate
|
||||||
processed = result.indexed + result.skippedDuplicate
|
if progressCb is not None and processed % 5 == 0:
|
||||||
try:
|
try:
|
||||||
progressCb(
|
progressCb(0, f"{processed} Mails verarbeitet, {result.indexed} indexiert")
|
||||||
min(90, 10 + int(80 * processed / max(1, limits.maxMessages))),
|
|
||||||
f"outlook processed={processed}",
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
logger.info(
|
if processed % 50 == 0:
|
||||||
"ingestion.connection.bootstrap.progress part=outlook processed=%d skippedDup=%d failed=%d",
|
logger.info(
|
||||||
processed, result.skippedDuplicate, result.failed,
|
"ingestion.connection.bootstrap.progress part=outlook processed=%d skippedDup=%d failed=%d",
|
||||||
|
processed, result.skippedDuplicate, result.failed,
|
||||||
extra={
|
extra={
|
||||||
"event": "ingestion.connection.bootstrap.progress",
|
"event": "ingestion.connection.bootstrap.progress",
|
||||||
"part": "outlook",
|
"part": "outlook",
|
||||||
|
|
@ -518,13 +530,22 @@ async def _ingestAttachments(
|
||||||
mimeType = attachment.get("contentType") or "application/octet-stream"
|
mimeType = attachment.get("contentType") or "application/octet-stream"
|
||||||
attachmentId = attachment.get("id") or fileName
|
attachmentId = attachment.get("id") or fileName
|
||||||
syntheticId = _syntheticAttachmentId(connectionId, messageId, attachmentId)
|
syntheticId = _syntheticAttachmentId(connectionId, messageId, attachmentId)
|
||||||
|
attLabel = f"{messageId}/att:{attachmentId}/{fileName}"
|
||||||
|
logItemStart("outlook-attachment", attLabel, sizeBytes=size or None, mime=mimeType)
|
||||||
|
|
||||||
try:
|
def _runAttExtraction():
|
||||||
extracted = runExtraction(
|
return runExtraction(
|
||||||
extractorRegistry, chunkerRegistry,
|
extractorRegistry, chunkerRegistry,
|
||||||
rawBytes, fileName, mimeType,
|
rawBytes, fileName, mimeType,
|
||||||
ExtractionOptions(mergeStrategy=None),
|
ExtractionOptions(mergeStrategy=None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
extracted = await extractWithTimeout(_runAttExtraction, label=attLabel)
|
||||||
|
except WalkerTimeout as exc:
|
||||||
|
result.failed += 1
|
||||||
|
result.errors.append(str(exc))
|
||||||
|
continue
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("outlook attachment extract %s failed: %s", attachmentId, exc)
|
logger.warning("outlook attachment extract %s failed: %s", attachmentId, exc)
|
||||||
result.failed += 1
|
result.failed += 1
|
||||||
|
|
@ -556,28 +577,34 @@ async def _ingestAttachments(
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await knowledgeService.requestIngestion(
|
await ingestWithTimeout(
|
||||||
IngestionJob(
|
knowledgeService.requestIngestion(
|
||||||
sourceKind="outlook_attachment",
|
IngestionJob(
|
||||||
sourceId=syntheticId,
|
sourceKind="outlook_attachment",
|
||||||
fileName=fileName,
|
sourceId=syntheticId,
|
||||||
mimeType=mimeType,
|
fileName=fileName,
|
||||||
userId=userId,
|
mimeType=mimeType,
|
||||||
mandateId=mandateId,
|
userId=userId,
|
||||||
contentObjects=contentObjects,
|
mandateId=mandateId,
|
||||||
neutralize=limits.neutralize,
|
contentObjects=contentObjects,
|
||||||
provenance={
|
neutralize=limits.neutralize,
|
||||||
"connectionId": connectionId,
|
provenance={
|
||||||
"dataSourceId": dataSourceId,
|
"connectionId": connectionId,
|
||||||
"authority": "msft",
|
"dataSourceId": dataSourceId,
|
||||||
"service": "outlook",
|
"authority": "msft",
|
||||||
"parentId": parentSyntheticId,
|
"service": "outlook",
|
||||||
"externalItemId": attachmentId,
|
"parentId": parentSyntheticId,
|
||||||
"parentMessageId": messageId,
|
"externalItemId": attachmentId,
|
||||||
},
|
"parentMessageId": messageId,
|
||||||
)
|
},
|
||||||
|
)
|
||||||
|
),
|
||||||
|
label=attLabel,
|
||||||
)
|
)
|
||||||
result.attachmentsIndexed += 1
|
result.attachmentsIndexed += 1
|
||||||
|
except WalkerTimeout as exc:
|
||||||
|
result.failed += 1
|
||||||
|
result.errors.append(str(exc))
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("outlook attachment ingest %s failed: %s", attachmentId, exc)
|
logger.warning("outlook attachment ingest %s failed: %s", attachmentId, exc)
|
||||||
result.failed += 1
|
result.failed += 1
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,13 @@ from dataclasses import dataclass, field
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
from modules.datamodels.datamodelExtraction import ExtractionOptions
|
from modules.datamodels.datamodelExtraction import ExtractionOptions
|
||||||
|
from modules.serviceCenter.services.serviceKnowledge.subWalkerHelpers import (
|
||||||
|
WalkerTimeout,
|
||||||
|
downloadWithTimeout,
|
||||||
|
extractWithTimeout,
|
||||||
|
ingestWithTimeout,
|
||||||
|
logItemStart,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -330,9 +337,15 @@ async def _ingestOne(
|
||||||
|
|
||||||
syntheticFileId = _syntheticFileId(connectionId, externalItemId)
|
syntheticFileId = _syntheticFileId(connectionId, externalItemId)
|
||||||
fileName = getattr(entry, "name", "") or externalItemId
|
fileName = getattr(entry, "name", "") or externalItemId
|
||||||
|
declaredSize = int(getattr(entry, "size", 0) or 0) or None
|
||||||
|
logItemStart("sharepoint", entryPath, sizeBytes=declaredSize, mime=mimeType)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
fileBytes = await adapter.download(entryPath)
|
fileBytes = await downloadWithTimeout(adapter.download(entryPath), label=entryPath)
|
||||||
|
except WalkerTimeout as exc:
|
||||||
|
result.failed += 1
|
||||||
|
result.errors.append(str(exc))
|
||||||
|
return
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("sharepoint download %s failed: %s", entryPath, exc)
|
logger.warning("sharepoint download %s failed: %s", entryPath, exc)
|
||||||
result.failed += 1
|
result.failed += 1
|
||||||
|
|
@ -345,10 +358,16 @@ async def _ingestOne(
|
||||||
result.bytesProcessed += len(fileBytes)
|
result.bytesProcessed += len(fileBytes)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
extracted = runExtractionFn(
|
extracted = await extractWithTimeout(
|
||||||
|
runExtractionFn,
|
||||||
fileBytes, fileName, mimeType,
|
fileBytes, fileName, mimeType,
|
||||||
ExtractionOptions(mergeStrategy=None),
|
ExtractionOptions(mergeStrategy=None),
|
||||||
|
label=entryPath,
|
||||||
)
|
)
|
||||||
|
except WalkerTimeout as exc:
|
||||||
|
result.failed += 1
|
||||||
|
result.errors.append(str(exc))
|
||||||
|
return
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("sharepoint extraction %s failed: %s", entryPath, exc)
|
logger.warning("sharepoint extraction %s failed: %s", entryPath, exc)
|
||||||
result.failed += 1
|
result.failed += 1
|
||||||
|
|
@ -370,20 +389,27 @@ async def _ingestOne(
|
||||||
"revision": revision,
|
"revision": revision,
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
handle = await knowledgeService.requestIngestion(
|
handle = await ingestWithTimeout(
|
||||||
IngestionJob(
|
knowledgeService.requestIngestion(
|
||||||
sourceKind="sharepoint_item",
|
IngestionJob(
|
||||||
sourceId=syntheticFileId,
|
sourceKind="sharepoint_item",
|
||||||
fileName=fileName,
|
sourceId=syntheticFileId,
|
||||||
mimeType=mimeType,
|
fileName=fileName,
|
||||||
userId=userId,
|
mimeType=mimeType,
|
||||||
mandateId=mandateId,
|
userId=userId,
|
||||||
contentObjects=contentObjects,
|
mandateId=mandateId,
|
||||||
contentVersion=revision,
|
contentObjects=contentObjects,
|
||||||
neutralize=limits.neutralize,
|
contentVersion=revision,
|
||||||
provenance=provenance,
|
neutralize=limits.neutralize,
|
||||||
)
|
provenance=provenance,
|
||||||
|
)
|
||||||
|
),
|
||||||
|
label=entryPath,
|
||||||
)
|
)
|
||||||
|
except WalkerTimeout as exc:
|
||||||
|
result.failed += 1
|
||||||
|
result.errors.append(str(exc))
|
||||||
|
return
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("sharepoint ingestion %s failed: %s", entryPath, exc, exc_info=True)
|
logger.error("sharepoint ingestion %s failed: %s", entryPath, exc, exc_info=True)
|
||||||
result.failed += 1
|
result.failed += 1
|
||||||
|
|
@ -399,27 +425,17 @@ async def _ingestOne(
|
||||||
if handle.error:
|
if handle.error:
|
||||||
result.errors.append(f"ingest({entryPath}): {handle.error}")
|
result.errors.append(f"ingest({entryPath}): {handle.error}")
|
||||||
|
|
||||||
if progressCb is not None and (result.indexed + result.skippedDuplicate) % 50 == 0:
|
processed = result.indexed + result.skippedDuplicate
|
||||||
processed = result.indexed + result.skippedDuplicate
|
if progressCb is not None and processed % 5 == 0:
|
||||||
try:
|
try:
|
||||||
progressCb(
|
progressCb(0, f"{processed} Dateien verarbeitet, {result.indexed} indexiert")
|
||||||
min(90, 10 + int(80 * processed / max(1, limits.maxItems))),
|
|
||||||
f"sharepoint processed={processed}",
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
logger.info(
|
if processed % 50 == 0:
|
||||||
"ingestion.connection.bootstrap.progress part=sharepoint processed=%d skippedDup=%d failed=%d",
|
logger.info(
|
||||||
processed, result.skippedDuplicate, result.failed,
|
"ingestion.connection.bootstrap.progress part=sharepoint processed=%d indexed=%d failed=%d",
|
||||||
extra={
|
processed, result.indexed, result.failed,
|
||||||
"event": "ingestion.connection.bootstrap.progress",
|
)
|
||||||
"part": "sharepoint",
|
|
||||||
"connectionId": connectionId,
|
|
||||||
"processed": processed,
|
|
||||||
"skippedDup": result.skippedDuplicate,
|
|
||||||
"failed": result.failed,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Yield so the event loop can interleave other tasks (download/extract are
|
# Yield so the event loop can interleave other tasks (download/extract are
|
||||||
# CPU-ish and extraction uses sync libs; cooperative scheduling prevents
|
# CPU-ish and extraction uses sync libs; cooperative scheduling prevents
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,116 @@
|
||||||
|
# Copyright (c) 2025 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Shared helpers for ingestion walkers (timeouts, per-item logging).
|
||||||
|
|
||||||
|
Walkers (sharepoint, gdrive, gmail, outlook, clickup, kdrive) all face the
|
||||||
|
same risks:
|
||||||
|
|
||||||
|
- A single `adapter.download()` call can hang on the network for hours.
|
||||||
|
- A single `runExtraction()` call can hang on a corrupt PDF/Office doc inside
|
||||||
|
a sync extractor library, blocking the asyncio loop.
|
||||||
|
- A single `requestIngestion()` call can stall on the embedding API.
|
||||||
|
|
||||||
|
Without timeouts, one bad item freezes the whole bootstrap job and we end
|
||||||
|
up with "Job stuck at 10% for 10h" zombies.
|
||||||
|
|
||||||
|
These helpers wrap each phase in `asyncio.wait_for`. Sync extraction runs
|
||||||
|
on a worker thread so the loop stays responsive. Every wrapped call also
|
||||||
|
emits a short start/done log line, so when something hangs we know the
|
||||||
|
exact item that caused it (path, size, mime).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any, Awaitable, Callable, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DOWNLOAD_TIMEOUT_S = 60
|
||||||
|
EXTRACTION_TIMEOUT_S = 90
|
||||||
|
INGEST_TIMEOUT_S = 60
|
||||||
|
|
||||||
|
|
||||||
|
class WalkerTimeout(Exception):
|
||||||
|
"""Raised when a walker phase exceeds its timeout budget."""
|
||||||
|
|
||||||
|
|
||||||
|
async def downloadWithTimeout(
|
||||||
|
awaitable: Awaitable[Any],
|
||||||
|
*,
|
||||||
|
label: str,
|
||||||
|
timeoutSeconds: int = DOWNLOAD_TIMEOUT_S,
|
||||||
|
) -> Any:
|
||||||
|
"""Run a download awaitable with a hard timeout.
|
||||||
|
|
||||||
|
`label` is a short human-readable identifier (typically the external path)
|
||||||
|
used in log messages so we can pinpoint the offending item in case of a
|
||||||
|
hang or timeout.
|
||||||
|
"""
|
||||||
|
logger.info("walker.download.start %s timeout=%ds", label, timeoutSeconds)
|
||||||
|
try:
|
||||||
|
result = await asyncio.wait_for(awaitable, timeout=timeoutSeconds)
|
||||||
|
logger.debug("walker.download.done %s", label)
|
||||||
|
return result
|
||||||
|
except asyncio.TimeoutError as ex:
|
||||||
|
logger.warning("walker.download.timeout %s after %ds", label, timeoutSeconds)
|
||||||
|
raise WalkerTimeout(f"download timeout after {timeoutSeconds}s: {label}") from ex
|
||||||
|
|
||||||
|
|
||||||
|
async def extractWithTimeout(
|
||||||
|
syncFn: Callable[..., Any],
|
||||||
|
*args: Any,
|
||||||
|
label: str,
|
||||||
|
timeoutSeconds: int = EXTRACTION_TIMEOUT_S,
|
||||||
|
) -> Any:
|
||||||
|
"""Run a synchronous extraction function on a worker thread with timeout.
|
||||||
|
|
||||||
|
Sync extractors (PDF, OCR, MS Office) cannot be cancelled cleanly from
|
||||||
|
asyncio; `wait_for` only protects the awaiter. The underlying thread may
|
||||||
|
keep running until the process exits — but at least the walker proceeds
|
||||||
|
to the next item instead of freezing forever.
|
||||||
|
"""
|
||||||
|
logger.info("walker.extract.start %s timeout=%ds", label, timeoutSeconds)
|
||||||
|
try:
|
||||||
|
result = await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(syncFn, *args),
|
||||||
|
timeout=timeoutSeconds,
|
||||||
|
)
|
||||||
|
logger.debug("walker.extract.done %s", label)
|
||||||
|
return result
|
||||||
|
except asyncio.TimeoutError as ex:
|
||||||
|
logger.warning("walker.extract.timeout %s after %ds", label, timeoutSeconds)
|
||||||
|
raise WalkerTimeout(f"extract timeout after {timeoutSeconds}s: {label}") from ex
|
||||||
|
|
||||||
|
|
||||||
|
async def ingestWithTimeout(
|
||||||
|
awaitable: Awaitable[Any],
|
||||||
|
*,
|
||||||
|
label: str,
|
||||||
|
timeoutSeconds: int = INGEST_TIMEOUT_S,
|
||||||
|
) -> Any:
|
||||||
|
"""Run an ingestion request with a hard timeout."""
|
||||||
|
logger.debug("walker.ingest.start %s timeout=%ds", label, timeoutSeconds)
|
||||||
|
try:
|
||||||
|
result = await asyncio.wait_for(awaitable, timeout=timeoutSeconds)
|
||||||
|
logger.debug("walker.ingest.done %s", label)
|
||||||
|
return result
|
||||||
|
except asyncio.TimeoutError as ex:
|
||||||
|
logger.warning("walker.ingest.timeout %s after %ds", label, timeoutSeconds)
|
||||||
|
raise WalkerTimeout(f"ingest timeout after {timeoutSeconds}s: {label}") from ex
|
||||||
|
|
||||||
|
|
||||||
|
def logItemStart(service: str, label: str, *, sizeBytes: Optional[int] = None, mime: Optional[str] = None) -> None:
|
||||||
|
"""Log that processing of one item is about to begin.
|
||||||
|
|
||||||
|
When the worker hangs, the LAST `walker.item.start` line in the log
|
||||||
|
points to the exact item that caused the freeze. This is the single
|
||||||
|
most valuable diagnostic for stuck-job triage.
|
||||||
|
"""
|
||||||
|
parts = [f"walker.item.start service={service} path={label}"]
|
||||||
|
if sizeBytes is not None:
|
||||||
|
parts.append(f"size={sizeBytes}")
|
||||||
|
if mime:
|
||||||
|
parts.append(f"mime={mime}")
|
||||||
|
logger.info(" ".join(parts))
|
||||||
|
|
@ -85,6 +85,11 @@ class AiAuditLogger:
|
||||||
try:
|
try:
|
||||||
from modules.datamodels.datamodelAiAudit import AiAuditLogEntry
|
from modules.datamodels.datamodelAiAudit import AiAuditLogEntry
|
||||||
|
|
||||||
|
if contentInput:
|
||||||
|
contentInput = contentInput.replace("\x00", "")
|
||||||
|
if contentOutput:
|
||||||
|
contentOutput = contentOutput.replace("\x00", "")
|
||||||
|
|
||||||
inputPreview = (contentInput or "")[:_PREVIEW_LENGTH] or None
|
inputPreview = (contentInput or "")[:_PREVIEW_LENGTH] or None
|
||||||
outputPreview = (contentOutput or "")[:_PREVIEW_LENGTH] or None
|
outputPreview = (contentOutput or "")[:_PREVIEW_LENGTH] or None
|
||||||
inputHash = hashlib.sha256(contentInput.encode("utf-8")).hexdigest() if contentInput else None
|
inputHash = hashlib.sha256(contentInput.encode("utf-8")).hexdigest() if contentInput else None
|
||||||
|
|
|
||||||
|
|
@ -330,6 +330,16 @@ NAVIGATION_SECTIONS = [
|
||||||
"adminOnly": True,
|
"adminOnly": True,
|
||||||
"sysAdminOnly": True,
|
"sysAdminOnly": True,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"id": "admin-stt-benchmark",
|
||||||
|
"objectKey": "ui.admin.sttBenchmark",
|
||||||
|
"label": t("STT Benchmark"),
|
||||||
|
"icon": "FaMicrophone",
|
||||||
|
"path": "/admin/stt-benchmark",
|
||||||
|
"order": 92,
|
||||||
|
"adminOnly": True,
|
||||||
|
"sysAdminOnly": True,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"id": "admin-languages",
|
"id": "admin-languages",
|
||||||
"objectKey": "ui.admin.languages",
|
"objectKey": "ui.admin.languages",
|
||||||
|
|
|
||||||
3
tests/eval/__init__.py
Normal file
3
tests/eval/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
# Copyright (c) 2026 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Eval harness for the Feature Data Sub-Agent (Phase 1.5)."""
|
||||||
246
tests/eval/fakeFeatureDataProvider.py
Normal file
246
tests/eval/fakeFeatureDataProvider.py
Normal file
|
|
@ -0,0 +1,246 @@
|
||||||
|
# Copyright (c) 2026 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""In-memory drop-in for FeatureDataProvider used by the eval harness.
|
||||||
|
|
||||||
|
Implements the same three public methods (browseTable / queryTable /
|
||||||
|
aggregateTable) plus the small surface the Sub-Agent reads (getActualColumns),
|
||||||
|
but runs all filters/aggregations in Python over the BenchmarkFixture rows.
|
||||||
|
|
||||||
|
This keeps the eval hermetic: no DB connection, no fixtures to insert/clean,
|
||||||
|
no flakiness from shared test schemas. Only the LLM call is real.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
_ALLOWED_AGGREGATES = {"SUM", "COUNT", "AVG", "MIN", "MAX"}
|
||||||
|
|
||||||
|
|
||||||
|
class FakeFeatureDataProvider:
|
||||||
|
"""In-memory provider compatible with :class:`FeatureDataProvider`."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
rowsByTable: Dict[str, List[Dict[str, Any]]],
|
||||||
|
availableTables: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
) -> None:
|
||||||
|
self._rowsByTable = {name: list(rows) for name, rows in rowsByTable.items()}
|
||||||
|
self._availableTables = list(availableTables or [])
|
||||||
|
self.callLog: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
def getAvailableTables(self, featureCode: str) -> List[Dict[str, Any]]: # noqa: ARG002
|
||||||
|
return list(self._availableTables)
|
||||||
|
|
||||||
|
def getTableSchema(self, featureCode: str, tableName: str) -> Optional[Dict[str, Any]]: # noqa: ARG002
|
||||||
|
for obj in self._availableTables:
|
||||||
|
if obj.get("meta", {}).get("table") == tableName:
|
||||||
|
return obj
|
||||||
|
return None
|
||||||
|
|
||||||
|
def getActualColumns(self, tableName: str) -> List[str]:
|
||||||
|
rows = self._rowsByTable.get(tableName, [])
|
||||||
|
if not rows:
|
||||||
|
return []
|
||||||
|
seen: List[str] = []
|
||||||
|
seenSet: set = set()
|
||||||
|
for row in rows:
|
||||||
|
for key in row.keys():
|
||||||
|
if key not in seenSet:
|
||||||
|
seen.append(key)
|
||||||
|
seenSet.add(key)
|
||||||
|
return seen
|
||||||
|
|
||||||
|
def browseTable(
|
||||||
|
self,
|
||||||
|
tableName: str,
|
||||||
|
featureInstanceId: str,
|
||||||
|
mandateId: str,
|
||||||
|
fields: List[str] = None,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
|
extraFilters: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
self.callLog.append({"method": "browseTable", "table": tableName, "fields": fields, "limit": limit})
|
||||||
|
rows = self._scopeRows(tableName, featureInstanceId, mandateId)
|
||||||
|
rows = _applyFilters(rows, extraFilters)
|
||||||
|
total = len(rows)
|
||||||
|
rows = rows[offset : offset + limit]
|
||||||
|
if fields:
|
||||||
|
rows = [{k: v for k, v in row.items() if k in fields} for row in rows]
|
||||||
|
return {"rows": rows, "total": total, "limit": limit, "offset": offset}
|
||||||
|
|
||||||
|
def queryTable(
|
||||||
|
self,
|
||||||
|
tableName: str,
|
||||||
|
featureInstanceId: str,
|
||||||
|
mandateId: str,
|
||||||
|
filters: List[Dict[str, Any]] = None,
|
||||||
|
fields: List[str] = None,
|
||||||
|
orderBy: str = None,
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0,
|
||||||
|
extraFilters: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
self.callLog.append({
|
||||||
|
"method": "queryTable", "table": tableName, "filters": filters,
|
||||||
|
"fields": fields, "orderBy": orderBy, "limit": limit,
|
||||||
|
})
|
||||||
|
rows = self._scopeRows(tableName, featureInstanceId, mandateId)
|
||||||
|
combined = list(filters or []) + list(extraFilters or [])
|
||||||
|
rows = _applyFilters(rows, combined)
|
||||||
|
if orderBy:
|
||||||
|
try:
|
||||||
|
rows = sorted(rows, key=lambda r: (r.get(orderBy) is None, r.get(orderBy)))
|
||||||
|
except TypeError:
|
||||||
|
rows = sorted(rows, key=lambda r: str(r.get(orderBy)))
|
||||||
|
total = len(rows)
|
||||||
|
rows = rows[offset : offset + limit]
|
||||||
|
if fields:
|
||||||
|
rows = [{k: v for k, v in row.items() if k in fields} for row in rows]
|
||||||
|
return {"rows": rows, "total": total, "limit": limit, "offset": offset}
|
||||||
|
|
||||||
|
def aggregateTable(
|
||||||
|
self,
|
||||||
|
tableName: str,
|
||||||
|
featureInstanceId: str,
|
||||||
|
mandateId: str,
|
||||||
|
aggregate: str,
|
||||||
|
field: str,
|
||||||
|
groupBy: str = None,
|
||||||
|
extraFilters: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
self.callLog.append({
|
||||||
|
"method": "aggregateTable", "table": tableName,
|
||||||
|
"aggregate": aggregate, "field": field, "groupBy": groupBy,
|
||||||
|
})
|
||||||
|
aggregate = aggregate.upper()
|
||||||
|
if aggregate not in _ALLOWED_AGGREGATES:
|
||||||
|
return {"rows": [], "error": f"Unsupported aggregate: {aggregate}"}
|
||||||
|
rows = self._scopeRows(tableName, featureInstanceId, mandateId)
|
||||||
|
rows = _applyFilters(rows, extraFilters)
|
||||||
|
|
||||||
|
if groupBy:
|
||||||
|
groups: Dict[Any, List[Dict[str, Any]]] = {}
|
||||||
|
for row in rows:
|
||||||
|
groups.setdefault(row.get(groupBy), []).append(row)
|
||||||
|
outRows = [
|
||||||
|
{"groupValue": key, "result": _aggregate(aggregate, [r.get(field) for r in grp])}
|
||||||
|
for key, grp in groups.items()
|
||||||
|
]
|
||||||
|
outRows.sort(key=lambda r: (r["result"] is None, -(r["result"] or 0)))
|
||||||
|
else:
|
||||||
|
outRows = [{"result": _aggregate(aggregate, [r.get(field) for r in rows])}]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"rows": outRows,
|
||||||
|
"aggregate": aggregate,
|
||||||
|
"field": field,
|
||||||
|
"groupBy": groupBy,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _scopeRows(self, tableName: str, featureInstanceId: str, mandateId: str) -> List[Dict[str, Any]]:
|
||||||
|
rows = self._rowsByTable.get(tableName, [])
|
||||||
|
return [
|
||||||
|
row for row in rows
|
||||||
|
if (row.get("featureInstanceId") in (None, featureInstanceId))
|
||||||
|
and (row.get("mandateId") in (None, mandateId))
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _applyFilters(rows: List[Dict[str, Any]], filters: Optional[List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
|
||||||
|
if not filters:
|
||||||
|
return rows
|
||||||
|
out = rows
|
||||||
|
for f in filters:
|
||||||
|
field = f.get("field")
|
||||||
|
op = (f.get("op") or "=").upper()
|
||||||
|
value = f.get("value")
|
||||||
|
out = [r for r in out if _matchesFilter(r.get(field), op, value)]
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _matchesFilter(rowValue: Any, op: str, filterValue: Any) -> bool:
|
||||||
|
if op in ("IS NULL",):
|
||||||
|
return rowValue is None
|
||||||
|
if op in ("IS NOT NULL",):
|
||||||
|
return rowValue is not None
|
||||||
|
if rowValue is None:
|
||||||
|
return False
|
||||||
|
if op == "=":
|
||||||
|
return _coerceEqual(rowValue, filterValue)
|
||||||
|
if op == "!=":
|
||||||
|
return not _coerceEqual(rowValue, filterValue)
|
||||||
|
if op == ">":
|
||||||
|
return _coerceFloat(rowValue) > _coerceFloat(filterValue)
|
||||||
|
if op == "<":
|
||||||
|
return _coerceFloat(rowValue) < _coerceFloat(filterValue)
|
||||||
|
if op == ">=":
|
||||||
|
return _coerceFloat(rowValue) >= _coerceFloat(filterValue)
|
||||||
|
if op == "<=":
|
||||||
|
return _coerceFloat(rowValue) <= _coerceFloat(filterValue)
|
||||||
|
if op in ("LIKE", "ILIKE"):
|
||||||
|
pattern = str(filterValue or "")
|
||||||
|
target = str(rowValue)
|
||||||
|
if op == "ILIKE":
|
||||||
|
pattern = pattern.lower()
|
||||||
|
target = target.lower()
|
||||||
|
return _sqlLike(target, pattern)
|
||||||
|
if op == "IN":
|
||||||
|
if isinstance(filterValue, (list, tuple, set)):
|
||||||
|
return any(_coerceEqual(rowValue, v) for v in filterValue)
|
||||||
|
return _coerceEqual(rowValue, filterValue)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _coerceEqual(a: Any, b: Any) -> bool:
|
||||||
|
if a == b:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
return str(a) == str(b)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _coerceFloat(value: Any) -> float:
|
||||||
|
if value is None:
|
||||||
|
return 0.0
|
||||||
|
try:
|
||||||
|
return float(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def _sqlLike(value: str, pattern: str) -> bool:
|
||||||
|
"""Approximate SQL LIKE -- only % and _ wildcards."""
|
||||||
|
import re
|
||||||
|
regex = ""
|
||||||
|
i = 0
|
||||||
|
while i < len(pattern):
|
||||||
|
ch = pattern[i]
|
||||||
|
if ch == "%":
|
||||||
|
regex += ".*"
|
||||||
|
elif ch == "_":
|
||||||
|
regex += "."
|
||||||
|
else:
|
||||||
|
regex += re.escape(ch)
|
||||||
|
i += 1
|
||||||
|
return re.fullmatch(regex, value or "") is not None
|
||||||
|
|
||||||
|
|
||||||
|
def _aggregate(op: str, values: List[Any]) -> Any:
|
||||||
|
if op == "COUNT":
|
||||||
|
return sum(1 for v in values if v is not None)
|
||||||
|
nums = [_coerceFloat(v) for v in values if v is not None]
|
||||||
|
if not nums:
|
||||||
|
return 0 if op == "SUM" else None
|
||||||
|
if op == "SUM":
|
||||||
|
return round(sum(nums), 4)
|
||||||
|
if op == "AVG":
|
||||||
|
return round(sum(nums) / len(nums), 4)
|
||||||
|
if op == "MIN":
|
||||||
|
return min(nums)
|
||||||
|
if op == "MAX":
|
||||||
|
return max(nums)
|
||||||
|
return None
|
||||||
735
tests/eval/runTrusteeBenchmark.py
Normal file
735
tests/eval/runTrusteeBenchmark.py
Normal file
|
|
@ -0,0 +1,735 @@
|
||||||
|
# Copyright (c) 2026 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Trustee Sub-Agent Eval Harness (Phase 1.5).
|
||||||
|
|
||||||
|
Standalone runner that fires real AI calls against the Feature Data
|
||||||
|
Sub-Agent in three configurations:
|
||||||
|
|
||||||
|
* ``baseline`` -- production code without the pre-execute validator
|
||||||
|
(Repair-Loop disabled, Trustee domain hints active).
|
||||||
|
* ``phase1`` -- pre-execute validator on (Repair-Loop active),
|
||||||
|
domain hints active, no ontology yet.
|
||||||
|
* ``phase2`` -- validator on, ontology-driven schema context +
|
||||||
|
constraints (replaces hand-written domain hints).
|
||||||
|
|
||||||
|
For each mode we run all 19 gold-standard questions against an
|
||||||
|
in-memory :class:`FakeFeatureDataProvider`, capture the agent's tool
|
||||||
|
calls and final answer, score them against the gold standard, and
|
||||||
|
write a Markdown report to ``local/notes/`` for analysis.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
cd gateway
|
||||||
|
python -m tests.eval.runTrusteeBenchmark # all 3 modes
|
||||||
|
python -m tests.eval.runTrusteeBenchmark phase1 # one mode only
|
||||||
|
python -m tests.eval.runTrusteeBenchmark baseline phase1
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Path setup so `python -m tests.eval.runTrusteeBenchmark` works from gateway/
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
_GATEWAY_DIR = Path(__file__).resolve().parents[2]
|
||||||
|
if str(_GATEWAY_DIR) not in sys.path:
|
||||||
|
sys.path.insert(0, str(_GATEWAY_DIR))
|
||||||
|
|
||||||
|
import yaml # noqa: E402
|
||||||
|
|
||||||
|
from modules.serviceCenter.services.serviceAgent.datamodelAgent import ( # noqa: E402
|
||||||
|
AgentConfig,
|
||||||
|
AgentEventTypeEnum,
|
||||||
|
)
|
||||||
|
from modules.datamodels.datamodelAi import ( # noqa: E402
|
||||||
|
AiCallRequest,
|
||||||
|
AiCallResponse,
|
||||||
|
OperationTypeEnum,
|
||||||
|
)
|
||||||
|
from modules.serviceCenter.services.serviceAgent.agentLoop import runAgentLoop # noqa: E402
|
||||||
|
from modules.serviceCenter.services.serviceAgent.featureDataAgent import ( # noqa: E402
|
||||||
|
_buildSubAgentTools,
|
||||||
|
_buildSchemaContext,
|
||||||
|
)
|
||||||
|
from modules.serviceCenter.services.serviceAgent.datamodelOntology import ( # noqa: E402
|
||||||
|
QueryValidationError,
|
||||||
|
)
|
||||||
|
from modules.serviceCenter.services.serviceAgent.queryValidator import ( # noqa: E402
|
||||||
|
QueryValidator,
|
||||||
|
)
|
||||||
|
|
||||||
|
from tests.eval.fakeFeatureDataProvider import ( # noqa: E402
|
||||||
|
FakeFeatureDataProvider,
|
||||||
|
)
|
||||||
|
from tests.fixtures.trusteeBenchmark.loadTrusteeBenchmarkFixture import ( # noqa: E402
|
||||||
|
buildTrusteeBenchmarkFixture,
|
||||||
|
BenchmarkFixture,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger("trusteeBenchmark")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# NoOpValidator -- baseline mode (Repair-Loop OFF)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _NoOpValidator(QueryValidator):
|
||||||
|
"""Validator that never rejects anything (used for baseline measurement)."""
|
||||||
|
|
||||||
|
def validateBrowseQuery(self, tableName, args): # noqa: ARG002
|
||||||
|
return None
|
||||||
|
|
||||||
|
def validateQueryTable(self, tableName, args): # noqa: ARG002
|
||||||
|
return None
|
||||||
|
|
||||||
|
def validateAggregateQuery(self, tableName, args): # noqa: ARG002
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Mode-specific tool/prompt building
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _ModeConfig:
|
||||||
|
name: str
|
||||||
|
label: str
|
||||||
|
useValidator: bool
|
||||||
|
useOntology: bool
|
||||||
|
|
||||||
|
|
||||||
|
_MODES: Dict[str, _ModeConfig] = {
|
||||||
|
"baseline": _ModeConfig(name="baseline", label="Baseline (no validator)", useValidator=False, useOntology=False),
|
||||||
|
"phase1": _ModeConfig(name="phase1", label="Phase 1 (validator on)", useValidator=True, useOntology=False),
|
||||||
|
"phase2": _ModeConfig(name="phase2", label="Phase 2 (validator + ontology)", useValidator=True, useOntology=True),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _buildValidator(mode: _ModeConfig) -> QueryValidator:
|
||||||
|
"""Construct the per-mode validator.
|
||||||
|
|
||||||
|
* baseline: no-op (Repair-Loop disabled, used to measure raw LLM
|
||||||
|
accuracy against today's prompt path).
|
||||||
|
* phase1: convention-based QueryValidator (NEVER_AGGREGATE on
|
||||||
|
``*Balance``/``*Total`` suffixes; no ontology).
|
||||||
|
* phase2: ontology-driven QueryValidator (constraints from the
|
||||||
|
trustee ontology override the convention defaults).
|
||||||
|
"""
|
||||||
|
if not mode.useValidator:
|
||||||
|
return _NoOpValidator()
|
||||||
|
if mode.useOntology:
|
||||||
|
try:
|
||||||
|
from modules.features.trustee.trusteeOntology import getTrusteeOntology
|
||||||
|
return QueryValidator(ontology=getTrusteeOntology())
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Could not load trustee ontology, falling back: %s", e)
|
||||||
|
return QueryValidator()
|
||||||
|
|
||||||
|
|
||||||
|
def _applyEnvForMode(mode: _ModeConfig) -> None:
|
||||||
|
"""Set the ontology toggle for the production prompt builder.
|
||||||
|
|
||||||
|
The Phase 2 path uses ``featureDataAgent._buildSchemaContext`` to pull
|
||||||
|
the prompt block from ``getAgentOntology()`` automatically. For
|
||||||
|
baseline/phase1 we set ``POWERON_DISABLE_FEATURE_ONTOLOGY=1`` so the
|
||||||
|
builder falls back to the legacy ``getAgentDomainHints()`` block --
|
||||||
|
measuring exactly the production prompt that ships today.
|
||||||
|
"""
|
||||||
|
if mode.useOntology:
|
||||||
|
os.environ.pop("POWERON_DISABLE_FEATURE_ONTOLOGY", None)
|
||||||
|
else:
|
||||||
|
os.environ["POWERON_DISABLE_FEATURE_ONTOLOGY"] = "1"
|
||||||
|
|
||||||
|
|
||||||
|
def _buildSystemPrompt(featureCode: str, instanceLabel: str, selectedTables: List[Dict[str, Any]]) -> str:
|
||||||
|
"""Build the sub-agent system prompt via the production path.
|
||||||
|
|
||||||
|
Mode-specific behaviour (legacy hints vs ontology block) is controlled
|
||||||
|
by the ``POWERON_DISABLE_FEATURE_ONTOLOGY`` env flag set per mode in
|
||||||
|
:func:`_applyEnvForMode`. Keeping the builder call identical for all
|
||||||
|
three modes means the benchmark measures the EXACT prompt the agent
|
||||||
|
would see in production -- no eval-only forks.
|
||||||
|
"""
|
||||||
|
return _buildSchemaContext(featureCode, instanceLabel, selectedTables, requestLang="de")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Question loading + per-question evaluation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _Question:
|
||||||
|
id: str
|
||||||
|
question: str
|
||||||
|
intent: str
|
||||||
|
expectedTools: List[str]
|
||||||
|
expectedTable: Optional[str]
|
||||||
|
expectedAggregate: Optional[str]
|
||||||
|
expectedAggregateField: Optional[str]
|
||||||
|
requiredFilters: Dict[str, Any]
|
||||||
|
forbiddenTools: List[str]
|
||||||
|
expectedNumbers: List[float]
|
||||||
|
expectedAnswerContains: List[str]
|
||||||
|
numericTolerance: float
|
||||||
|
|
||||||
|
|
||||||
|
def _loadQuestions(yamlPath: Path) -> List[_Question]:
|
||||||
|
with open(yamlPath, "r", encoding="utf-8") as f:
|
||||||
|
rawList = yaml.safe_load(f)
|
||||||
|
questions: List[_Question] = []
|
||||||
|
for raw in rawList:
|
||||||
|
questions.append(_Question(
|
||||||
|
id=raw["id"],
|
||||||
|
question=raw["question"],
|
||||||
|
intent=raw.get("intent", ""),
|
||||||
|
expectedTools=list(raw.get("expectedTools") or []),
|
||||||
|
expectedTable=raw.get("expectedTable"),
|
||||||
|
expectedAggregate=raw.get("expectedAggregate"),
|
||||||
|
expectedAggregateField=raw.get("expectedAggregateField"),
|
||||||
|
requiredFilters=dict(raw.get("requiredFilters") or {}),
|
||||||
|
forbiddenTools=list(raw.get("forbiddenTools") or []),
|
||||||
|
expectedNumbers=[float(x) for x in (raw.get("expectedNumbers") or [])],
|
||||||
|
expectedAnswerContains=[str(x) for x in (raw.get("expectedAnswerContains") or [])],
|
||||||
|
numericTolerance=float(raw.get("numericTolerance") or 0.005),
|
||||||
|
))
|
||||||
|
return questions
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _RunResult:
|
||||||
|
questionId: str
|
||||||
|
finalText: str
|
||||||
|
toolCalls: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
|
toolResults: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
|
summary: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
durationS: float = 0.0
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def costCHF(self) -> float:
|
||||||
|
return float(self.summary.get("costCHF") or 0.0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rounds(self) -> int:
|
||||||
|
return int(self.summary.get("rounds") or 0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def validationFailures(self) -> int:
|
||||||
|
return int(self.summary.get("validationFailures") or 0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def repairAttempts(self) -> int:
|
||||||
|
return int(self.summary.get("repairAttempts") or 0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def successAfterRepair(self) -> int:
|
||||||
|
return int(self.summary.get("successAfterRepair") or 0)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _Score:
|
||||||
|
patternOk: bool = False
|
||||||
|
forbidOk: bool = False
|
||||||
|
numericOk: bool = False
|
||||||
|
accuracyOk: bool = False
|
||||||
|
notes: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
def _scoreRun(question: _Question, run: _RunResult) -> _Score:
|
||||||
|
score = _Score()
|
||||||
|
if run.error:
|
||||||
|
score.notes.append(f"Sub-agent error: {run.error}")
|
||||||
|
return score
|
||||||
|
|
||||||
|
score.patternOk = _checkPattern(question, run)
|
||||||
|
score.forbidOk = _checkForbid(question, run)
|
||||||
|
score.numericOk = _checkNumeric(question, run)
|
||||||
|
score.accuracyOk = score.patternOk and score.forbidOk and score.numericOk
|
||||||
|
return score
|
||||||
|
|
||||||
|
|
||||||
|
def _checkPattern(question: _Question, run: _RunResult) -> bool:
|
||||||
|
"""Did the agent call one of the expected tools on the expected table with required filters?"""
|
||||||
|
if not question.expectedTools:
|
||||||
|
return True
|
||||||
|
matchingCalls = [
|
||||||
|
c for c in run.toolCalls
|
||||||
|
if c.get("toolName") in question.expectedTools
|
||||||
|
and (not question.expectedTable or c.get("args", {}).get("tableName") == question.expectedTable)
|
||||||
|
]
|
||||||
|
if not matchingCalls:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if question.expectedAggregate:
|
||||||
|
wantAgg = question.expectedAggregate.upper()
|
||||||
|
wantField = question.expectedAggregateField
|
||||||
|
for c in matchingCalls:
|
||||||
|
args = c.get("args", {})
|
||||||
|
if c.get("toolName") != "aggregateTable":
|
||||||
|
continue
|
||||||
|
if (args.get("aggregate") or "").upper() != wantAgg:
|
||||||
|
continue
|
||||||
|
if wantField and args.get("field") != wantField:
|
||||||
|
continue
|
||||||
|
if not _filtersSatisfied(question.requiredFilters, args.get("extraFilters") or args.get("filters") or []):
|
||||||
|
continue
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
if question.requiredFilters:
|
||||||
|
for c in matchingCalls:
|
||||||
|
args = c.get("args", {})
|
||||||
|
filters = args.get("filters") or args.get("extraFilters") or []
|
||||||
|
if _filtersSatisfied(question.requiredFilters, filters):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _filtersSatisfied(required: Dict[str, Any], actualFilters: List[Dict[str, Any]]) -> bool:
|
||||||
|
if not required:
|
||||||
|
return True
|
||||||
|
for reqField, reqValue in required.items():
|
||||||
|
if reqField.endswith("Like"):
|
||||||
|
field = reqField[:-4]
|
||||||
|
wanted = str(reqValue)
|
||||||
|
ok = any(
|
||||||
|
(f.get("field") == field) and (f.get("op", "").upper() in ("LIKE", "ILIKE"))
|
||||||
|
and str(f.get("value")) == wanted
|
||||||
|
for f in actualFilters
|
||||||
|
)
|
||||||
|
if not ok:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
ok = any(
|
||||||
|
f.get("field") == reqField and _filterValueEqual(f.get("value"), reqValue)
|
||||||
|
for f in actualFilters
|
||||||
|
)
|
||||||
|
if not ok:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _filterValueEqual(a: Any, b: Any) -> bool:
|
||||||
|
if a == b:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
return str(a).strip() == str(b).strip()
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _checkForbid(question: _Question, run: _RunResult) -> bool:
|
||||||
|
"""Did the agent AVOID forbidden tool/op combinations?
|
||||||
|
|
||||||
|
Forbidden hits only count if the call actually went through to the
|
||||||
|
provider (success=True). Validator-rejected calls don't count -- the
|
||||||
|
Repair-Loop is doing its job and steering the agent away.
|
||||||
|
"""
|
||||||
|
if not question.forbiddenTools:
|
||||||
|
return True
|
||||||
|
forbiddenSet = set(question.forbiddenTools)
|
||||||
|
for r in run.toolResults:
|
||||||
|
if not r.get("success"):
|
||||||
|
continue
|
||||||
|
if r.get("toolName") in forbiddenSet:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _checkNumeric(question: _Question, run: _RunResult) -> bool:
|
||||||
|
text = (run.finalText or "")
|
||||||
|
if question.expectedNumbers:
|
||||||
|
textNumbers = _extractNumbers(text)
|
||||||
|
for expected in question.expectedNumbers:
|
||||||
|
tol = max(abs(expected) * question.numericTolerance, 0.5)
|
||||||
|
if not any(abs(n - expected) <= tol for n in textNumbers):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if question.expectedAnswerContains:
|
||||||
|
lowered = text.lower()
|
||||||
|
for needle in question.expectedAnswerContains:
|
||||||
|
if needle.lower() not in lowered:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _extractNumbers(text: str) -> List[float]:
|
||||||
|
"""Pick out all numbers from a free-text answer.
|
||||||
|
|
||||||
|
Handles Swiss thousand separators (apostrophe and U+2019), German
|
||||||
|
decimals (comma), plain integers/floats, and JSON numbers. Trailing
|
||||||
|
punctuation (``,``, ``;``, ``.`` from end-of-sentence) is stripped
|
||||||
|
before parsing so ``"180500.0,"`` parses cleanly to 180500.0.
|
||||||
|
"""
|
||||||
|
cleaned = text.replace("\u2019", "'")
|
||||||
|
tokens = re.findall(r"-?\d[\d'.,]*", cleaned)
|
||||||
|
out: List[float] = []
|
||||||
|
for tok in tokens:
|
||||||
|
tok = tok.rstrip(",;")
|
||||||
|
if tok.endswith(".") and tok.count(".") == 1:
|
||||||
|
tok = tok[:-1]
|
||||||
|
norm = tok.replace("'", "")
|
||||||
|
if norm.count(",") == 1 and norm.count(".") == 0:
|
||||||
|
norm = norm.replace(",", ".")
|
||||||
|
elif norm.count(",") >= 1 and norm.count(".") >= 1:
|
||||||
|
if norm.rfind(",") > norm.rfind("."):
|
||||||
|
norm = norm.replace(".", "").replace(",", ".")
|
||||||
|
else:
|
||||||
|
norm = norm.replace(",", "")
|
||||||
|
else:
|
||||||
|
norm = norm.replace(",", "")
|
||||||
|
try:
|
||||||
|
out.append(float(norm))
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# AI call wiring
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _bootstrapServices() -> Tuple[Any, str, str]:
|
||||||
|
"""Spin up a minimal service hub bound to the root user + initial mandate.
|
||||||
|
|
||||||
|
Returns the ServiceHub, the user id, and the mandate id used for billing.
|
||||||
|
"""
|
||||||
|
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||||
|
from modules.datamodels.datamodelUam import Mandate
|
||||||
|
from modules.serviceHub import getInterface as getServices
|
||||||
|
|
||||||
|
rootInterface = getRootInterface()
|
||||||
|
user = rootInterface.currentUser
|
||||||
|
mandateId = rootInterface.getInitialId(Mandate)
|
||||||
|
if not mandateId:
|
||||||
|
raise RuntimeError("No initial mandate available -- run bootstrap loader first.")
|
||||||
|
services = getServices(user, workflow=None, mandateId=mandateId, featureInstanceId=None)
|
||||||
|
return services, user.id, mandateId
|
||||||
|
|
||||||
|
|
||||||
|
async def _runOneQuestion(
|
||||||
|
*,
|
||||||
|
services: Any,
|
||||||
|
userId: str,
|
||||||
|
mandateId: str,
|
||||||
|
fixture: BenchmarkFixture,
|
||||||
|
question: _Question,
|
||||||
|
mode: _ModeConfig,
|
||||||
|
) -> _RunResult:
|
||||||
|
"""Execute a single sub-agent run for one question under one mode."""
|
||||||
|
provider = FakeFeatureDataProvider(
|
||||||
|
rowsByTable=fixture.rowsByTable,
|
||||||
|
availableTables=fixture.selectedTables,
|
||||||
|
)
|
||||||
|
validator = _buildValidator(mode)
|
||||||
|
registry = _buildSubAgentTools(
|
||||||
|
provider=provider,
|
||||||
|
featureInstanceId=fixture.featureInstanceId,
|
||||||
|
mandateId=fixture.mandateId,
|
||||||
|
tableFilters={},
|
||||||
|
validator=validator,
|
||||||
|
)
|
||||||
|
|
||||||
|
systemPrompt = _buildSystemPrompt(
|
||||||
|
featureCode="trustee",
|
||||||
|
instanceLabel="Demo AG",
|
||||||
|
selectedTables=fixture.selectedTables,
|
||||||
|
)
|
||||||
|
|
||||||
|
cost = 0.0
|
||||||
|
|
||||||
|
async def _aiCallFn(req: AiCallRequest) -> AiCallResponse:
|
||||||
|
nonlocal cost
|
||||||
|
resp = await services.ai.callAi(req)
|
||||||
|
cost += float(getattr(resp, "priceCHF", 0.0) or 0.0)
|
||||||
|
return resp
|
||||||
|
|
||||||
|
async def _getCost() -> float:
|
||||||
|
return cost
|
||||||
|
|
||||||
|
config = AgentConfig(
|
||||||
|
maxRounds=6,
|
||||||
|
maxCostCHF=0.50,
|
||||||
|
operationType=OperationTypeEnum.DATA_QUERY,
|
||||||
|
)
|
||||||
|
|
||||||
|
run = _RunResult(questionId=question.id, finalText="")
|
||||||
|
t0 = time.time()
|
||||||
|
try:
|
||||||
|
async for event in runAgentLoop(
|
||||||
|
prompt=question.question,
|
||||||
|
toolRegistry=registry,
|
||||||
|
config=config,
|
||||||
|
aiCallFn=_aiCallFn,
|
||||||
|
getWorkflowCostFn=_getCost,
|
||||||
|
workflowId=f"eval-{mode.name}-{question.id}-{uuid.uuid4().hex[:6]}",
|
||||||
|
userId=userId,
|
||||||
|
featureInstanceId=fixture.featureInstanceId,
|
||||||
|
mandateId=mandateId,
|
||||||
|
systemPromptOverride=systemPrompt,
|
||||||
|
):
|
||||||
|
if event.type == AgentEventTypeEnum.FINAL:
|
||||||
|
run.finalText = event.content or run.finalText
|
||||||
|
elif event.type == AgentEventTypeEnum.MESSAGE and event.content:
|
||||||
|
run.finalText += event.content
|
||||||
|
elif event.type == AgentEventTypeEnum.TOOL_CALL:
|
||||||
|
run.toolCalls.append(dict(event.data or {}))
|
||||||
|
elif event.type == AgentEventTypeEnum.TOOL_RESULT:
|
||||||
|
run.toolResults.append(dict(event.data or {}))
|
||||||
|
elif event.type == AgentEventTypeEnum.AGENT_SUMMARY:
|
||||||
|
run.summary = dict(event.data or {})
|
||||||
|
elif event.type == AgentEventTypeEnum.ERROR:
|
||||||
|
run.error = (run.error or "") + (event.content or "")
|
||||||
|
except Exception as e:
|
||||||
|
run.error = f"{type(e).__name__}: {e}"
|
||||||
|
logger.exception("Sub-agent run failed for %s/%s", mode.name, question.id)
|
||||||
|
run.durationS = time.time() - t0
|
||||||
|
return run
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Report
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _ModeReport:
|
||||||
|
mode: _ModeConfig
|
||||||
|
perQuestion: List[Tuple[_Question, _RunResult, _Score]] = field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total(self) -> int:
|
||||||
|
return len(self.perQuestion)
|
||||||
|
|
||||||
|
def _count(self, attr: str) -> int:
|
||||||
|
return sum(1 for _, _, s in self.perQuestion if getattr(s, attr))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def accuracy(self) -> float:
|
||||||
|
return self._count("accuracyOk") / max(self.total, 1)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def patternCompliance(self) -> float:
|
||||||
|
return self._count("patternOk") / max(self.total, 1)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def repairConversionRate(self) -> float:
|
||||||
|
attempts = sum(r.repairAttempts for _, r, _ in self.perQuestion)
|
||||||
|
succeeded = sum(r.successAfterRepair for _, r, _ in self.perQuestion)
|
||||||
|
if attempts == 0:
|
||||||
|
return 0.0
|
||||||
|
return succeeded / attempts
|
||||||
|
|
||||||
|
@property
|
||||||
|
def totalCostCHF(self) -> float:
|
||||||
|
return sum(r.costCHF for _, r, _ in self.perQuestion)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def totalRounds(self) -> int:
|
||||||
|
return sum(r.rounds for _, r, _ in self.perQuestion)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def totalValidationFailures(self) -> int:
|
||||||
|
return sum(r.validationFailures for _, r, _ in self.perQuestion)
|
||||||
|
|
||||||
|
|
||||||
|
def _writeReport(reports: List[_ModeReport], outputPath: Path) -> None:
|
||||||
|
lines: List[str] = []
|
||||||
|
lines.append("# Trustee Sub-Agent Benchmark Report")
|
||||||
|
lines.append("")
|
||||||
|
lines.append(f"Generated: {time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
|
lines.append("")
|
||||||
|
lines.append("## Summary")
|
||||||
|
lines.append("")
|
||||||
|
lines.append("| Mode | Questions | Accuracy | Pattern compliance | Repair conversion | Validator rejects | Rounds | Cost (CHF) |")
|
||||||
|
lines.append("|---|---|---|---|---|---|---|---|")
|
||||||
|
for rep in reports:
|
||||||
|
lines.append(
|
||||||
|
f"| {rep.mode.label} | {rep.total} | {rep.accuracy:.1%} | {rep.patternCompliance:.1%} | "
|
||||||
|
f"{rep.repairConversionRate:.1%} | {rep.totalValidationFailures} | {rep.totalRounds} | "
|
||||||
|
f"{rep.totalCostCHF:.4f} |"
|
||||||
|
)
|
||||||
|
lines.append("")
|
||||||
|
lines.append("## Per-question detail")
|
||||||
|
for rep in reports:
|
||||||
|
lines.append("")
|
||||||
|
lines.append(f"### {rep.mode.label}")
|
||||||
|
lines.append("")
|
||||||
|
lines.append("| id | acc | pattern | forbid | numeric | rounds | val-fail | repairs | cost CHF | duration | tools |")
|
||||||
|
lines.append("|---|---|---|---|---|---|---|---|---|---|---|")
|
||||||
|
for q, r, s in rep.perQuestion:
|
||||||
|
toolList = ",".join(
|
||||||
|
f"{c.get('toolName')}({c.get('args',{}).get('tableName','?')})"
|
||||||
|
for c in r.toolCalls
|
||||||
|
)
|
||||||
|
lines.append(
|
||||||
|
f"| {q.id} | {_yn(s.accuracyOk)} | {_yn(s.patternOk)} | {_yn(s.forbidOk)} | {_yn(s.numericOk)} | "
|
||||||
|
f"{r.rounds} | {r.validationFailures} | {r.repairAttempts}/{r.successAfterRepair} | "
|
||||||
|
f"{r.costCHF:.4f} | {r.durationS:.1f}s | {toolList} |"
|
||||||
|
)
|
||||||
|
lines.append("")
|
||||||
|
lines.append("#### Notes & failures")
|
||||||
|
for q, r, s in rep.perQuestion:
|
||||||
|
if s.accuracyOk:
|
||||||
|
continue
|
||||||
|
lines.append(f"- **{q.id}** ({q.intent}): pattern={s.patternOk} forbid={s.forbidOk} numeric={s.numericOk}")
|
||||||
|
if r.error:
|
||||||
|
lines.append(f" - error: `{r.error}`")
|
||||||
|
lines.append(f" - answer: `{(r.finalText or '').strip().replace('|', '/').splitlines()[0][:240]}`")
|
||||||
|
for note in s.notes:
|
||||||
|
lines.append(f" - note: {note}")
|
||||||
|
outputPath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
outputPath.write_text("\n".join(lines), encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def _yn(b: bool) -> str:
|
||||||
|
return "OK" if b else "FAIL"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Main entry point
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def _runMain(modesToRun: List[str], onlyQuestionId: Optional[str] = None) -> None:
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.WARNING,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s -- %(message)s",
|
||||||
|
)
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
fixture = buildTrusteeBenchmarkFixture()
|
||||||
|
questionsPath = _GATEWAY_DIR / "tests" / "fixtures" / "trusteeBenchmark" / "questions.yaml"
|
||||||
|
allQuestions = _loadQuestions(questionsPath)
|
||||||
|
if onlyQuestionId:
|
||||||
|
allQuestions = [q for q in allQuestions if q.id == onlyQuestionId]
|
||||||
|
if not allQuestions:
|
||||||
|
print(f"No question matches id={onlyQuestionId!r}")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"Loaded {len(allQuestions)} questions, {len(modesToRun)} modes -> {len(allQuestions) * len(modesToRun)} sub-agent runs.")
|
||||||
|
|
||||||
|
services, userId, mandateId = _bootstrapServices()
|
||||||
|
print(f"Bootstrap OK: user={userId}, mandate={mandateId}")
|
||||||
|
|
||||||
|
reports: List[_ModeReport] = []
|
||||||
|
for modeName in modesToRun:
|
||||||
|
mode = _MODES[modeName]
|
||||||
|
_applyEnvForMode(mode)
|
||||||
|
rep = _ModeReport(mode=mode)
|
||||||
|
print(f"\n=== Mode: {mode.label} ===")
|
||||||
|
for idx, question in enumerate(allQuestions, start=1):
|
||||||
|
print(f" [{idx:>2}/{len(allQuestions)}] {question.id}: {question.question[:80]} ...", flush=True)
|
||||||
|
run = await _runOneQuestion(
|
||||||
|
services=services,
|
||||||
|
userId=userId,
|
||||||
|
mandateId=mandateId,
|
||||||
|
fixture=fixture,
|
||||||
|
question=question,
|
||||||
|
mode=mode,
|
||||||
|
)
|
||||||
|
score = _scoreRun(question, run)
|
||||||
|
rep.perQuestion.append((question, run, score))
|
||||||
|
print(
|
||||||
|
f" -> acc={_yn(score.accuracyOk)} "
|
||||||
|
f"pattern={_yn(score.patternOk)} forbid={_yn(score.forbidOk)} "
|
||||||
|
f"numeric={_yn(score.numericOk)} rounds={run.rounds} cost={run.costCHF:.4f} "
|
||||||
|
f"val-fail={run.validationFailures} repairs={run.repairAttempts}/{run.successAfterRepair}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
reports.append(rep)
|
||||||
|
|
||||||
|
timestamp = time.strftime("%Y%m%d-%H%M%S")
|
||||||
|
outDir = _GATEWAY_DIR.parent / "local" / "notes"
|
||||||
|
reportPath = outDir / f"trustee-benchmark-{timestamp}.md"
|
||||||
|
_writeReport(reports, reportPath)
|
||||||
|
|
||||||
|
rawJsonPath = outDir / f"trustee-benchmark-{timestamp}.json"
|
||||||
|
rawJsonPath.write_text(
|
||||||
|
json.dumps(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"mode": rep.mode.name,
|
||||||
|
"accuracy": rep.accuracy,
|
||||||
|
"patternCompliance": rep.patternCompliance,
|
||||||
|
"repairConversionRate": rep.repairConversionRate,
|
||||||
|
"totalCostCHF": rep.totalCostCHF,
|
||||||
|
"totalRounds": rep.totalRounds,
|
||||||
|
"totalValidationFailures": rep.totalValidationFailures,
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"questionId": q.id,
|
||||||
|
"intent": q.intent,
|
||||||
|
"accuracyOk": s.accuracyOk,
|
||||||
|
"patternOk": s.patternOk,
|
||||||
|
"forbidOk": s.forbidOk,
|
||||||
|
"numericOk": s.numericOk,
|
||||||
|
"rounds": r.rounds,
|
||||||
|
"validationFailures": r.validationFailures,
|
||||||
|
"repairAttempts": r.repairAttempts,
|
||||||
|
"successAfterRepair": r.successAfterRepair,
|
||||||
|
"costCHF": r.costCHF,
|
||||||
|
"durationS": r.durationS,
|
||||||
|
"finalText": (r.finalText or "")[:600],
|
||||||
|
"toolCalls": r.toolCalls,
|
||||||
|
"error": r.error,
|
||||||
|
}
|
||||||
|
for q, r, s in rep.perQuestion
|
||||||
|
],
|
||||||
|
}
|
||||||
|
for rep in reports
|
||||||
|
],
|
||||||
|
indent=2,
|
||||||
|
ensure_ascii=False,
|
||||||
|
),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\nReport written: {reportPath}")
|
||||||
|
print(f"Raw JSON: {rawJsonPath}")
|
||||||
|
for rep in reports:
|
||||||
|
print(f" {rep.mode.label}: acc={rep.accuracy:.1%} pattern={rep.patternCompliance:.1%} cost={rep.totalCostCHF:.4f}")
|
||||||
|
|
||||||
|
|
||||||
|
def _parseArgs(argv: List[str]) -> Tuple[List[str], Optional[str]]:
|
||||||
|
modes: List[str] = []
|
||||||
|
only: Optional[str] = None
|
||||||
|
for arg in argv:
|
||||||
|
if arg.startswith("--only="):
|
||||||
|
only = arg.split("=", 1)[1]
|
||||||
|
elif arg in _MODES:
|
||||||
|
modes.append(arg)
|
||||||
|
else:
|
||||||
|
print(f"Unknown argument: {arg!r}. Allowed modes: {list(_MODES)}")
|
||||||
|
sys.exit(2)
|
||||||
|
if not modes:
|
||||||
|
modes = ["baseline", "phase1", "phase2"]
|
||||||
|
return modes, only
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
modes, only = _parseArgs(sys.argv[1:])
|
||||||
|
asyncio.run(_runMain(modes, onlyQuestionId=only))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
16
tests/fixtures/trusteeBenchmark/__init__.py
vendored
Normal file
16
tests/fixtures/trusteeBenchmark/__init__.py
vendored
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Copyright (c) 2026 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Trustee benchmark fixture: synthetic but realistic Swiss KMU accounting data.
|
||||||
|
|
||||||
|
Used by the Feature Data Sub-Agent eval harness (Phase 1.5) to measure
|
||||||
|
hallucination rates against a fixed gold standard. Data is built in-memory
|
||||||
|
via Pydantic models -- no SQL, no DB connection -- so the harness stays
|
||||||
|
hermetic and reproducible.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from tests.fixtures.trusteeBenchmark.loadTrusteeBenchmarkFixture import (
|
||||||
|
buildTrusteeBenchmarkFixture,
|
||||||
|
BenchmarkFixture,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ["buildTrusteeBenchmarkFixture", "BenchmarkFixture"]
|
||||||
275
tests/fixtures/trusteeBenchmark/loadTrusteeBenchmarkFixture.py
vendored
Normal file
275
tests/fixtures/trusteeBenchmark/loadTrusteeBenchmarkFixture.py
vendored
Normal file
|
|
@ -0,0 +1,275 @@
|
||||||
|
# Copyright (c) 2026 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Synthetic Trustee benchmark fixture for the Feature Data Sub-Agent eval.
|
||||||
|
|
||||||
|
Builds an in-memory snapshot of one fictional Swiss KMU mandate
|
||||||
|
("Demo AG") with:
|
||||||
|
|
||||||
|
* 3 fiscal years (2023, 2024, 2025) of `TrusteeDataAccountBalance` rows
|
||||||
|
-- both annual totals (periodMonth=0) and monthly snapshots.
|
||||||
|
* 8 representative accounts spanning all major chart-of-accounts blocks
|
||||||
|
(cash, banks, receivables, payables, revenue, materials, personnel,
|
||||||
|
operating expenses).
|
||||||
|
* Per-month `TrusteeDataJournalEntry` + multiple `TrusteeDataJournalLine`
|
||||||
|
rows so debit/credit/COUNT aggregations have meaningful answers.
|
||||||
|
|
||||||
|
The data is deterministic (no RNG) so a question's gold-standard answer
|
||||||
|
is stable across runs.
|
||||||
|
|
||||||
|
This module deliberately stays decoupled from the production DB pipeline
|
||||||
|
-- the harness uses :class:`FakeFeatureDataProvider` (see
|
||||||
|
``gateway/tests/eval/fakeFeatureDataProvider.py``) to serve queries
|
||||||
|
against this in-memory snapshot, mirroring the public methods of
|
||||||
|
``FeatureDataProvider`` (browseTable / queryTable / aggregateTable).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
|
||||||
|
_MANDATE_ID = "m-demo-ag"
|
||||||
|
_FEATURE_INSTANCE_ID = "fi-demo-ag-trustee"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Account master data
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_ACCOUNT_MASTER: List[Dict[str, Any]] = [
|
||||||
|
{"accountNumber": "1000", "label": "Hauptkasse", "accountType": "asset", "currency": "CHF"},
|
||||||
|
{"accountNumber": "1020", "label": "ZKB Geschaeftskonto", "accountType": "asset", "currency": "CHF"},
|
||||||
|
{"accountNumber": "1021", "label": "PostFinance", "accountType": "asset", "currency": "CHF"},
|
||||||
|
{"accountNumber": "1100", "label": "Forderungen aus Lieferungen und Leistungen", "accountType": "asset", "currency": "CHF"},
|
||||||
|
{"accountNumber": "2000", "label": "Verbindlichkeiten aus Lieferungen", "accountType": "liability", "currency": "CHF"},
|
||||||
|
{"accountNumber": "3000", "label": "Ertrag aus Beratung", "accountType": "revenue", "currency": "CHF"},
|
||||||
|
{"accountNumber": "5400", "label": "Materialaufwand", "accountType": "expense", "currency": "CHF"},
|
||||||
|
{"accountNumber": "6000", "label": "Mietaufwand", "accountType": "expense", "currency": "CHF"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Annual closing balances per (year, accountNumber) -- the canonical reference.
|
||||||
|
# Asset/expense balances are positive, liability/revenue balances are stored
|
||||||
|
# as positive numbers (sign by accountType, like most accounting systems).
|
||||||
|
_ANNUAL_CLOSING: Dict[int, Dict[str, float]] = {
|
||||||
|
2023: {
|
||||||
|
"1000": 4_800.00,
|
||||||
|
"1020": 132_500.00,
|
||||||
|
"1021": 22_400.00,
|
||||||
|
"1100": 58_200.00,
|
||||||
|
"2000": 41_300.00,
|
||||||
|
"3000": 410_000.00,
|
||||||
|
"5400": 92_000.00,
|
||||||
|
"6000": 36_000.00,
|
||||||
|
},
|
||||||
|
2024: {
|
||||||
|
"1000": 5_200.00,
|
||||||
|
"1020": 148_900.00,
|
||||||
|
"1021": 26_750.00,
|
||||||
|
"1100": 61_400.00,
|
||||||
|
"2000": 44_100.00,
|
||||||
|
"3000": 462_500.00,
|
||||||
|
"5400": 104_300.00,
|
||||||
|
"6000": 39_000.00,
|
||||||
|
},
|
||||||
|
2025: {
|
||||||
|
"1000": 5_900.00,
|
||||||
|
"1020": 152_400.00,
|
||||||
|
"1021": 28_100.00,
|
||||||
|
"1100": 66_800.00,
|
||||||
|
"2000": 47_900.00,
|
||||||
|
"3000": 488_700.00,
|
||||||
|
"5400": 112_100.00,
|
||||||
|
"6000": 42_000.00,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _openingFromPriorYear(year: int, accountNumber: str) -> float:
|
||||||
|
"""Opening balance of year N = closing balance of year N-1 (0 if N-1 is unknown)."""
|
||||||
|
prior = year - 1
|
||||||
|
return float(_ANNUAL_CLOSING.get(prior, {}).get(accountNumber, 0.0))
|
||||||
|
|
||||||
|
|
||||||
|
def _monthlyProgression(opening: float, closing: float, month: int) -> float:
|
||||||
|
"""Linear interpolation between opening and closing for monthly snapshots.
|
||||||
|
|
||||||
|
Not realistic in detail but deterministic and monotonic per account, so
|
||||||
|
questions about "Stand per Ende März" produce stable answers.
|
||||||
|
"""
|
||||||
|
if month <= 0:
|
||||||
|
return float(closing)
|
||||||
|
frac = month / 12.0
|
||||||
|
return round(float(opening) + (float(closing) - float(opening)) * frac, 2)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Journal entries / lines -- minimal but realistic
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_JOURNAL_ENTRIES_2025: List[Dict[str, Any]] = [
|
||||||
|
{"month": 3, "day": 15, "reference": "RG-2025-0042", "description": "Beratung Kunde ACME AG", "amount": 18_500.00, "debit": "1100", "credit": "3000"},
|
||||||
|
{"month": 3, "day": 22, "reference": "EK-2025-0017", "description": "Materialeinkauf Buehler AG", "amount": 9_200.00, "debit": "5400", "credit": "2000"},
|
||||||
|
{"month": 3, "day": 28, "reference": "MIETE-2025-03", "description": "Mietzins Buero Maerz", "amount": 3_000.00, "debit": "6000", "credit": "1020"},
|
||||||
|
{"month": 4, "day": 5, "reference": "RG-2025-0051", "description": "Beratung Kunde Bell AG", "amount": 24_300.00, "debit": "1100", "credit": "3000"},
|
||||||
|
{"month": 4, "day": 18, "reference": "EK-2025-0024", "description": "Materialeinkauf Industriebedarf", "amount": 7_800.00, "debit": "5400", "credit": "2000"},
|
||||||
|
{"month": 6, "day": 12, "reference": "RG-2025-0079", "description": "Beratung Kunde Bell AG", "amount": 32_100.00, "debit": "1100", "credit": "3000"},
|
||||||
|
{"month": 6, "day": 30, "reference": "MIETE-2025-Q2", "description": "Mietzins Buero Q2-Abrechnung", "amount": 3_500.00, "debit": "6000", "credit": "1020"},
|
||||||
|
{"month": 9, "day": 4, "reference": "RG-2025-0114", "description": "Beratung Kunde Migros", "amount": 41_500.00, "debit": "1100", "credit": "3000"},
|
||||||
|
{"month": 9, "day": 25, "reference": "EK-2025-0061", "description": "Materialeinkauf Buehler AG", "amount": 12_400.00, "debit": "5400", "credit": "2000"},
|
||||||
|
{"month": 11, "day": 14, "reference": "RG-2025-0188", "description": "Beratung Kunde ACME AG", "amount": 28_700.00, "debit": "1100", "credit": "3000"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Snapshot containers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BenchmarkFixture:
|
||||||
|
"""In-memory rows that mimic feature DB tables.
|
||||||
|
|
||||||
|
Each ``rowsByTable[tableName]`` is a list of column dicts compatible
|
||||||
|
with the Pydantic feature data models (TrusteeDataAccountBalance, etc.).
|
||||||
|
"""
|
||||||
|
mandateId: str
|
||||||
|
featureInstanceId: str
|
||||||
|
rowsByTable: Dict[str, List[Dict[str, Any]]] = field(default_factory=dict)
|
||||||
|
selectedTables: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
def _buildSelectedTables() -> List[Dict[str, Any]]:
|
||||||
|
"""Return the DATA_OBJECT-shaped descriptors the sub-agent expects.
|
||||||
|
|
||||||
|
Mirrors what the catalog would return for the trustee feature; the
|
||||||
|
real `getDataObjects("trustee")` call would yield the same shape but
|
||||||
|
we hard-code the three tables we actually populate.
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"objectKey": "data.feature.trustee.TrusteeDataAccount",
|
||||||
|
"label": {"de": "Kontenplan", "en": "Chart of accounts"},
|
||||||
|
"meta": {
|
||||||
|
"table": "TrusteeDataAccount",
|
||||||
|
"fields": ["id", "accountNumber", "label", "accountType", "currency", "isActive"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"objectKey": "data.feature.trustee.TrusteeDataAccountBalance",
|
||||||
|
"label": {"de": "Kontosalden", "en": "Account balances"},
|
||||||
|
"meta": {
|
||||||
|
"table": "TrusteeDataAccountBalance",
|
||||||
|
"fields": [
|
||||||
|
"id", "accountNumber", "periodYear", "periodMonth",
|
||||||
|
"openingBalance", "debitTotal", "creditTotal",
|
||||||
|
"closingBalance", "currency",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"objectKey": "data.feature.trustee.TrusteeDataJournalLine",
|
||||||
|
"label": {"de": "Buchungszeilen", "en": "Journal lines"},
|
||||||
|
"meta": {
|
||||||
|
"table": "TrusteeDataJournalLine",
|
||||||
|
"fields": [
|
||||||
|
"id", "journalEntryId", "accountNumber",
|
||||||
|
"debitAmount", "creditAmount", "currency", "description",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def buildTrusteeBenchmarkFixture() -> BenchmarkFixture:
|
||||||
|
"""Materialize the full in-memory benchmark snapshot.
|
||||||
|
|
||||||
|
All rows include ``mandateId`` and ``featureInstanceId`` columns so the
|
||||||
|
fake provider can scope them the same way the real one does.
|
||||||
|
"""
|
||||||
|
accountRows: List[Dict[str, Any]] = []
|
||||||
|
for i, acc in enumerate(_ACCOUNT_MASTER):
|
||||||
|
accountRows.append({
|
||||||
|
"id": f"acc-{i:03d}",
|
||||||
|
"accountNumber": acc["accountNumber"],
|
||||||
|
"label": acc["label"],
|
||||||
|
"accountType": acc["accountType"],
|
||||||
|
"currency": acc["currency"],
|
||||||
|
"isActive": True,
|
||||||
|
"mandateId": _MANDATE_ID,
|
||||||
|
"featureInstanceId": _FEATURE_INSTANCE_ID,
|
||||||
|
})
|
||||||
|
|
||||||
|
balanceRows: List[Dict[str, Any]] = []
|
||||||
|
rowIdx = 0
|
||||||
|
for year, closings in _ANNUAL_CLOSING.items():
|
||||||
|
for accountNumber, closing in closings.items():
|
||||||
|
opening = _openingFromPriorYear(year, accountNumber)
|
||||||
|
balanceRows.append({
|
||||||
|
"id": f"bal-{rowIdx:04d}",
|
||||||
|
"accountNumber": accountNumber,
|
||||||
|
"periodYear": year,
|
||||||
|
"periodMonth": 0,
|
||||||
|
"openingBalance": opening,
|
||||||
|
"debitTotal": round(max(closing - opening, 0.0) * 1.2, 2),
|
||||||
|
"creditTotal": round(max(closing - opening, 0.0) * 0.2, 2),
|
||||||
|
"closingBalance": float(closing),
|
||||||
|
"currency": "CHF",
|
||||||
|
"mandateId": _MANDATE_ID,
|
||||||
|
"featureInstanceId": _FEATURE_INSTANCE_ID,
|
||||||
|
})
|
||||||
|
rowIdx += 1
|
||||||
|
for month in range(1, 13):
|
||||||
|
monthly = _monthlyProgression(opening, closing, month)
|
||||||
|
balanceRows.append({
|
||||||
|
"id": f"bal-{rowIdx:04d}",
|
||||||
|
"accountNumber": accountNumber,
|
||||||
|
"periodYear": year,
|
||||||
|
"periodMonth": month,
|
||||||
|
"openingBalance": opening,
|
||||||
|
"debitTotal": round((monthly - opening) * 1.2, 2) if monthly > opening else 0.0,
|
||||||
|
"creditTotal": round((monthly - opening) * 0.2, 2) if monthly > opening else 0.0,
|
||||||
|
"closingBalance": monthly,
|
||||||
|
"currency": "CHF",
|
||||||
|
"mandateId": _MANDATE_ID,
|
||||||
|
"featureInstanceId": _FEATURE_INSTANCE_ID,
|
||||||
|
})
|
||||||
|
rowIdx += 1
|
||||||
|
|
||||||
|
lineRows: List[Dict[str, Any]] = []
|
||||||
|
for j, entry in enumerate(_JOURNAL_ENTRIES_2025):
|
||||||
|
entryId = f"je-2025-{j:03d}"
|
||||||
|
lineRows.append({
|
||||||
|
"id": f"jl-{j*2:04d}",
|
||||||
|
"journalEntryId": entryId,
|
||||||
|
"accountNumber": entry["debit"],
|
||||||
|
"debitAmount": float(entry["amount"]),
|
||||||
|
"creditAmount": 0.0,
|
||||||
|
"currency": "CHF",
|
||||||
|
"description": entry["description"],
|
||||||
|
"mandateId": _MANDATE_ID,
|
||||||
|
"featureInstanceId": _FEATURE_INSTANCE_ID,
|
||||||
|
})
|
||||||
|
lineRows.append({
|
||||||
|
"id": f"jl-{j*2+1:04d}",
|
||||||
|
"journalEntryId": entryId,
|
||||||
|
"accountNumber": entry["credit"],
|
||||||
|
"debitAmount": 0.0,
|
||||||
|
"creditAmount": float(entry["amount"]),
|
||||||
|
"currency": "CHF",
|
||||||
|
"description": entry["description"],
|
||||||
|
"mandateId": _MANDATE_ID,
|
||||||
|
"featureInstanceId": _FEATURE_INSTANCE_ID,
|
||||||
|
})
|
||||||
|
|
||||||
|
fixture = BenchmarkFixture(
|
||||||
|
mandateId=_MANDATE_ID,
|
||||||
|
featureInstanceId=_FEATURE_INSTANCE_ID,
|
||||||
|
rowsByTable={
|
||||||
|
"TrusteeDataAccount": accountRows,
|
||||||
|
"TrusteeDataAccountBalance": balanceRows,
|
||||||
|
"TrusteeDataJournalLine": lineRows,
|
||||||
|
},
|
||||||
|
selectedTables=_buildSelectedTables(),
|
||||||
|
)
|
||||||
|
return fixture
|
||||||
226
tests/fixtures/trusteeBenchmark/questions.yaml
vendored
Normal file
226
tests/fixtures/trusteeBenchmark/questions.yaml
vendored
Normal file
|
|
@ -0,0 +1,226 @@
|
||||||
|
# Trustee Sub-Agent Benchmark -- 19 questions analog Hein 2025
|
||||||
|
#
|
||||||
|
# Each question covers ONE expected hallucination class so we can attribute
|
||||||
|
# accuracy gains to specific phases (validator / ontology).
|
||||||
|
#
|
||||||
|
# Scoring per question (all binary unless noted):
|
||||||
|
# patternOk -- did the agent call the right tool(s) with the right filters?
|
||||||
|
# forbidOk -- did it AVOID the forbidden tool/op (e.g. SUM closingBalance)?
|
||||||
|
# numericOk -- does the final answer contain the expected number(s)?
|
||||||
|
# accuracyOk -- patternOk AND forbidOk AND numericOk
|
||||||
|
#
|
||||||
|
# tolerance: relative tolerance for numeric comparison (default 0.005 = 0.5 %).
|
||||||
|
|
||||||
|
- id: q01
|
||||||
|
question: "Was ist der Banksaldo per 31.12.2025 fuer das ZKB-Konto 1020?"
|
||||||
|
intent: BANK_BALANCE_AT_DATE
|
||||||
|
expectedTools: [queryTable]
|
||||||
|
expectedTable: TrusteeDataAccountBalance
|
||||||
|
requiredFilters:
|
||||||
|
accountNumber: "1020"
|
||||||
|
periodYear: 2025
|
||||||
|
periodMonth: 0
|
||||||
|
forbiddenTools: [aggregateTable]
|
||||||
|
expectedNumbers: [152400.0]
|
||||||
|
|
||||||
|
- id: q02
|
||||||
|
question: "Wie hoch ist die Hauptkasse (Konto 1000) per Ende 2024?"
|
||||||
|
intent: CASH_BALANCE_AT_DATE
|
||||||
|
expectedTools: [queryTable]
|
||||||
|
expectedTable: TrusteeDataAccountBalance
|
||||||
|
requiredFilters:
|
||||||
|
accountNumber: "1000"
|
||||||
|
periodYear: 2024
|
||||||
|
periodMonth: 0
|
||||||
|
forbiddenTools: [aggregateTable]
|
||||||
|
expectedNumbers: [5200.0]
|
||||||
|
|
||||||
|
- id: q03
|
||||||
|
question: "Summiere alle Bankkonten (102x) per 31.12.2025."
|
||||||
|
intent: BANK_GROUP_TOTAL_AT_DATE
|
||||||
|
expectedTools: [queryTable]
|
||||||
|
expectedTable: TrusteeDataAccountBalance
|
||||||
|
requiredFilters:
|
||||||
|
periodYear: 2025
|
||||||
|
periodMonth: 0
|
||||||
|
accountNumberLike: "102%"
|
||||||
|
forbiddenTools: [aggregateTable]
|
||||||
|
expectedNumbers: [180500.0]
|
||||||
|
numericTolerance: 0.01
|
||||||
|
|
||||||
|
- id: q04
|
||||||
|
question: "Wie hat sich der Schlusssaldo des ZKB-Kontos 1020 ueber die Jahre 2023 bis 2025 entwickelt?"
|
||||||
|
intent: BALANCE_HISTORY_PER_YEAR
|
||||||
|
expectedTools: [queryTable]
|
||||||
|
expectedTable: TrusteeDataAccountBalance
|
||||||
|
requiredFilters:
|
||||||
|
accountNumber: "1020"
|
||||||
|
periodMonth: 0
|
||||||
|
forbiddenTools: [aggregateTable]
|
||||||
|
expectedNumbers: [132500.0, 148900.0, 152400.0]
|
||||||
|
|
||||||
|
- id: q05
|
||||||
|
question: "Welches Konto hatte 2025 den hoechsten Schlusssaldo bei den Aktiven (1xxx)?"
|
||||||
|
intent: TOP_ASSET_AT_DATE
|
||||||
|
expectedTools: [queryTable]
|
||||||
|
expectedTable: TrusteeDataAccountBalance
|
||||||
|
requiredFilters:
|
||||||
|
periodYear: 2025
|
||||||
|
periodMonth: 0
|
||||||
|
accountNumberLike: "1%"
|
||||||
|
forbiddenTools: [aggregateTable]
|
||||||
|
expectedAnswerContains: ["1020"]
|
||||||
|
expectedNumbers: [152400.0]
|
||||||
|
|
||||||
|
- id: q06
|
||||||
|
question: "Welche Konten gehoeren zu den Bankkonten (102x)?"
|
||||||
|
intent: ACCOUNT_LIST_FILTER
|
||||||
|
expectedTools: [queryTable]
|
||||||
|
expectedTable: TrusteeDataAccount
|
||||||
|
requiredFilters:
|
||||||
|
accountNumberLike: "102%"
|
||||||
|
forbiddenTools: [aggregateTable]
|
||||||
|
expectedAnswerContains: ["1020", "1021"]
|
||||||
|
|
||||||
|
- id: q07
|
||||||
|
question: "Wie hoch war der Materialaufwand (Konto 5400) im Jahr 2025?"
|
||||||
|
intent: EXPENSE_AT_YEAR
|
||||||
|
expectedTools: [queryTable]
|
||||||
|
expectedTable: TrusteeDataAccountBalance
|
||||||
|
requiredFilters:
|
||||||
|
accountNumber: "5400"
|
||||||
|
periodYear: 2025
|
||||||
|
periodMonth: 0
|
||||||
|
forbiddenTools: [aggregateTable]
|
||||||
|
expectedNumbers: [112100.0]
|
||||||
|
|
||||||
|
- id: q08
|
||||||
|
question: "Wie viele Buchungszeilen gibt es insgesamt im System?"
|
||||||
|
intent: COUNT_ROWS
|
||||||
|
expectedTools: [aggregateTable]
|
||||||
|
expectedTable: TrusteeDataJournalLine
|
||||||
|
expectedAggregate: COUNT
|
||||||
|
forbiddenTools: []
|
||||||
|
expectedNumbers: [20]
|
||||||
|
|
||||||
|
- id: q09
|
||||||
|
question: "Wie hoch ist der gesamte Beratungsertrag (Konto 3000) im Jahr 2025?"
|
||||||
|
intent: REVENUE_AT_YEAR
|
||||||
|
expectedTools: [queryTable]
|
||||||
|
expectedTable: TrusteeDataAccountBalance
|
||||||
|
requiredFilters:
|
||||||
|
accountNumber: "3000"
|
||||||
|
periodYear: 2025
|
||||||
|
periodMonth: 0
|
||||||
|
forbiddenTools: [aggregateTable]
|
||||||
|
expectedNumbers: [488700.0]
|
||||||
|
|
||||||
|
- id: q10
|
||||||
|
question: "Wie viel wurde 2025 auf das Materialaufwand-Konto 5400 gebucht (Soll-Summe ueber Buchungszeilen)?"
|
||||||
|
intent: JOURNAL_SUM_AT_ACCOUNT
|
||||||
|
expectedTools: [aggregateTable]
|
||||||
|
expectedTable: TrusteeDataJournalLine
|
||||||
|
expectedAggregate: SUM
|
||||||
|
expectedAggregateField: debitAmount
|
||||||
|
requiredFilters:
|
||||||
|
accountNumber: "5400"
|
||||||
|
forbiddenTools: []
|
||||||
|
expectedNumbers: [29400.0]
|
||||||
|
numericTolerance: 0.01
|
||||||
|
|
||||||
|
- id: q11
|
||||||
|
question: "Welche Buchungen im 1. Quartal 2025 (Januar bis Maerz) wurden auf Konto 3000 gebucht?"
|
||||||
|
intent: JOURNAL_LINES_BY_ACCOUNT
|
||||||
|
expectedTools: [queryTable]
|
||||||
|
expectedTable: TrusteeDataJournalLine
|
||||||
|
requiredFilters:
|
||||||
|
accountNumber: "3000"
|
||||||
|
forbiddenTools: [aggregateTable]
|
||||||
|
expectedAnswerContains: ["18500", "ACME"]
|
||||||
|
|
||||||
|
- id: q12
|
||||||
|
question: "Wie hoch war die Hauptkasse (Konto 1000) jeweils per Ende Maerz 2025 und per Ende Juni 2025?"
|
||||||
|
intent: MULTI_MONTH_SNAPSHOT
|
||||||
|
expectedTools: [queryTable]
|
||||||
|
expectedTable: TrusteeDataAccountBalance
|
||||||
|
requiredFilters:
|
||||||
|
accountNumber: "1000"
|
||||||
|
periodYear: 2025
|
||||||
|
forbiddenTools: [aggregateTable]
|
||||||
|
expectedNumbers: [5375.0, 5550.0]
|
||||||
|
numericTolerance: 0.01
|
||||||
|
|
||||||
|
- id: q13
|
||||||
|
question: "Wie hoch ist die Summe aller Aufwandskonten (5xxx und 6xxx) per Ende 2025?"
|
||||||
|
intent: EXPENSE_GROUP_TOTAL
|
||||||
|
expectedTools: [queryTable]
|
||||||
|
expectedTable: TrusteeDataAccountBalance
|
||||||
|
requiredFilters:
|
||||||
|
periodYear: 2025
|
||||||
|
periodMonth: 0
|
||||||
|
forbiddenTools: [aggregateTable]
|
||||||
|
expectedNumbers: [154100.0]
|
||||||
|
numericTolerance: 0.01
|
||||||
|
|
||||||
|
- id: q14
|
||||||
|
question: "Welches Konto hat den hoechsten openingBalance fuer 2025?"
|
||||||
|
intent: TOP_OPENING_BALANCE
|
||||||
|
# Both routes are legitimate: queryTable+orderBy+limit=1, or
|
||||||
|
# aggregateTable(MAX) followed by queryTable lookup. We only insist that
|
||||||
|
# the final answer names the right account and (optionally) the value.
|
||||||
|
expectedTools: [queryTable, aggregateTable]
|
||||||
|
expectedTable: TrusteeDataAccountBalance
|
||||||
|
forbiddenTools: []
|
||||||
|
expectedAnswerContains: ["3000"]
|
||||||
|
expectedNumbers: [462500.0]
|
||||||
|
|
||||||
|
- id: q15
|
||||||
|
question: "Liste alle Konten vom Typ asset auf."
|
||||||
|
intent: ACCOUNTS_BY_TYPE
|
||||||
|
expectedTools: [queryTable]
|
||||||
|
expectedTable: TrusteeDataAccount
|
||||||
|
requiredFilters:
|
||||||
|
accountType: "asset"
|
||||||
|
forbiddenTools: [aggregateTable]
|
||||||
|
expectedAnswerContains: ["1000", "1020", "1021", "1100"]
|
||||||
|
|
||||||
|
- id: q16
|
||||||
|
question: "Wie hoch ist der Schlusssaldo der Forderungen aus Lieferungen und Leistungen (Konto 1100) per Ende 2025?"
|
||||||
|
intent: BALANCE_BY_NAME_LOOKUP
|
||||||
|
expectedTools: [queryTable]
|
||||||
|
expectedTable: TrusteeDataAccountBalance
|
||||||
|
requiredFilters:
|
||||||
|
accountNumber: "1100"
|
||||||
|
periodYear: 2025
|
||||||
|
periodMonth: 0
|
||||||
|
forbiddenTools: [aggregateTable]
|
||||||
|
expectedNumbers: [66800.0]
|
||||||
|
|
||||||
|
- id: q17
|
||||||
|
question: "Wie hoch waren die Verbindlichkeiten (Konto 2000) jeweils per Ende 2023, 2024 und 2025?"
|
||||||
|
intent: LIABILITY_HISTORY
|
||||||
|
expectedTools: [queryTable]
|
||||||
|
expectedTable: TrusteeDataAccountBalance
|
||||||
|
requiredFilters:
|
||||||
|
accountNumber: "2000"
|
||||||
|
periodMonth: 0
|
||||||
|
forbiddenTools: [aggregateTable]
|
||||||
|
expectedNumbers: [41300.0, 44100.0, 47900.0]
|
||||||
|
|
||||||
|
- id: q18
|
||||||
|
question: "Wie viele Bankkonten gibt es im Kontenplan (102x)?"
|
||||||
|
intent: ACCOUNT_COUNT_BY_PREFIX
|
||||||
|
expectedTools: [queryTable, aggregateTable]
|
||||||
|
expectedTable: TrusteeDataAccount
|
||||||
|
requiredFilters:
|
||||||
|
accountNumberLike: "102%"
|
||||||
|
forbiddenTools: []
|
||||||
|
expectedNumbers: [2]
|
||||||
|
|
||||||
|
- id: q19
|
||||||
|
question: "Gib mir alle Buchungszeilen mit einem Sollbetrag groesser als 20'000 CHF."
|
||||||
|
intent: JOURNAL_LINES_BY_AMOUNT
|
||||||
|
expectedTools: [queryTable]
|
||||||
|
expectedTable: TrusteeDataJournalLine
|
||||||
|
forbiddenTools: [aggregateTable]
|
||||||
|
expectedAnswerContains: ["24300", "32100", "41500", "28700"]
|
||||||
112
tests/unit/serviceAgent/test_agentTrace_repairCounters.py
Normal file
112
tests/unit/serviceAgent/test_agentTrace_repairCounters.py
Normal file
|
|
@ -0,0 +1,112 @@
|
||||||
|
# Copyright (c) 2026 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Unit tests for the repair-loop telemetry aggregation in agentLoop.
|
||||||
|
|
||||||
|
These counters (`validationFailures`, `repairAttempts`, `successAfterRepair`)
|
||||||
|
land on `AgentTrace` and are surfaced via the `AGENT_SUMMARY` event. The
|
||||||
|
Eval-Harness (Phase 1.5) reads them to compute the repair conversion rate.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from modules.serviceCenter.services.serviceAgent.agentLoop import _computeRepairCounters
|
||||||
|
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
|
||||||
|
AgentRoundLog, ToolCallLog,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _round(*toolCalls: ToolCallLog) -> AgentRoundLog:
|
||||||
|
return AgentRoundLog(roundNumber=1, toolCalls=list(toolCalls))
|
||||||
|
|
||||||
|
|
||||||
|
def _failed(toolName: str, code: str) -> ToolCallLog:
|
||||||
|
return ToolCallLog(
|
||||||
|
toolName=toolName,
|
||||||
|
success=False,
|
||||||
|
validationFailureCode=code,
|
||||||
|
error=f"{code}: ...",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ok(toolName: str) -> ToolCallLog:
|
||||||
|
return ToolCallLog(toolName=toolName, success=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_computeRepairCounters_emptyTrace():
|
||||||
|
fails, attempts, succeeded = _computeRepairCounters([])
|
||||||
|
assert (fails, attempts, succeeded) == (0, 0, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_computeRepairCounters_allCleanRunsHaveZeroCounters():
|
||||||
|
rounds = [
|
||||||
|
_round(_ok("queryTable"), _ok("browseTable")),
|
||||||
|
_round(_ok("aggregateTable")),
|
||||||
|
]
|
||||||
|
fails, attempts, succeeded = _computeRepairCounters(rounds)
|
||||||
|
assert (fails, attempts, succeeded) == (0, 0, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_computeRepairCounters_singleFailureCountsButNoRepairYet():
|
||||||
|
"""One failure in round 1, no follow-up call -- counts the failure but
|
||||||
|
nothing else. Repair only counts when the LLM tries again."""
|
||||||
|
rounds = [_round(_failed("queryTable", "FIELD_NOT_FOUND"))]
|
||||||
|
fails, attempts, succeeded = _computeRepairCounters(rounds)
|
||||||
|
assert (fails, attempts, succeeded) == (1, 0, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_computeRepairCounters_repairThatSucceeds():
|
||||||
|
"""Round 1 fails, round 2 retries same tool successfully."""
|
||||||
|
rounds = [
|
||||||
|
_round(_failed("queryTable", "FIELD_NOT_FOUND")),
|
||||||
|
_round(_ok("queryTable")),
|
||||||
|
]
|
||||||
|
fails, attempts, succeeded = _computeRepairCounters(rounds)
|
||||||
|
assert (fails, attempts, succeeded) == (1, 1, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_computeRepairCounters_repairThatFailsAgain():
|
||||||
|
"""Round 1 fails, round 2 retries same tool but fails validation again."""
|
||||||
|
rounds = [
|
||||||
|
_round(_failed("queryTable", "FIELD_NOT_FOUND")),
|
||||||
|
_round(_failed("queryTable", "FIELD_NOT_FOUND")),
|
||||||
|
]
|
||||||
|
fails, attempts, succeeded = _computeRepairCounters(rounds)
|
||||||
|
assert (fails, attempts, succeeded) == (2, 1, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_computeRepairCounters_siblingCallsInSameRoundAreNotRepairs():
|
||||||
|
"""When the LLM emits two queryTable calls in the same round, the
|
||||||
|
second is NOT a repair attempt -- it had no way to see the first
|
||||||
|
one's rejection yet (parallel dispatch within a round)."""
|
||||||
|
rounds = [
|
||||||
|
_round(
|
||||||
|
_failed("queryTable", "FIELD_NOT_FOUND"),
|
||||||
|
_failed("queryTable", "FIELD_NOT_FOUND"),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
fails, attempts, succeeded = _computeRepairCounters(rounds)
|
||||||
|
assert (fails, attempts, succeeded) == (2, 0, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_computeRepairCounters_differentToolNamesAreIndependent():
|
||||||
|
"""A queryTable failure does not flag a later browseTable as a repair."""
|
||||||
|
rounds = [
|
||||||
|
_round(_failed("queryTable", "FIELD_NOT_FOUND")),
|
||||||
|
_round(_ok("browseTable")),
|
||||||
|
]
|
||||||
|
fails, attempts, succeeded = _computeRepairCounters(rounds)
|
||||||
|
assert (fails, attempts, succeeded) == (1, 0, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_computeRepairCounters_multiToolMix():
|
||||||
|
"""Trustee-like sequence: SUM(closingBalance) rejected, LLM switches to
|
||||||
|
queryTable with a typo (rejected), then fixes the typo (success)."""
|
||||||
|
rounds = [
|
||||||
|
_round(_failed("aggregateTable", "INVALID_AGGREGATE_TARGET")),
|
||||||
|
_round(_failed("queryTable", "FIELD_NOT_FOUND")),
|
||||||
|
_round(_ok("queryTable")),
|
||||||
|
]
|
||||||
|
fails, attempts, succeeded = _computeRepairCounters(rounds)
|
||||||
|
# 2 validation failures total, 1 prior-rejected queryTable retry that
|
||||||
|
# succeeded; aggregateTable was never retried so no attempt counted for it.
|
||||||
|
assert (fails, attempts, succeeded) == (2, 1, 1)
|
||||||
|
|
@ -19,11 +19,18 @@ asked for the closing balance per period).
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from modules.shared import fkRegistry
|
from modules.shared import fkRegistry
|
||||||
|
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
|
||||||
|
ToolCallRequest, ToolResult,
|
||||||
|
)
|
||||||
from modules.serviceCenter.services.serviceAgent.featureDataAgent import (
|
from modules.serviceCenter.services.serviceAgent.featureDataAgent import (
|
||||||
_buildSchemaContext,
|
_buildSchemaContext,
|
||||||
|
_buildSubAgentTools,
|
||||||
_buildTableSchemaBlock,
|
_buildTableSchemaBlock,
|
||||||
_formatFieldLine,
|
_formatFieldLine,
|
||||||
_summarizePythonType,
|
_summarizePythonType,
|
||||||
|
|
@ -152,10 +159,29 @@ def test_buildSchemaContext_forbidsSummingAggregateFields():
|
||||||
assert "closingBalance" in prompt
|
assert "closingBalance" in prompt
|
||||||
|
|
||||||
|
|
||||||
def test_buildSchemaContext_appendsTrusteeDomainHints():
|
def test_buildSchemaContext_appendsTrusteeOntologyBlock(monkeypatch):
|
||||||
"""When the feature module exposes getAgentDomainHints(), the schema prompt
|
"""When the feature exposes getAgentOntology(), the schema prompt must
|
||||||
must include those hints so the sub-agent knows e.g. that 102x are bank
|
include the compiled ontology block (Phase 2 path)."""
|
||||||
accounts and periodMonth=0 is the annual total."""
|
monkeypatch.delenv("POWERON_DISABLE_FEATURE_ONTOLOGY", raising=False)
|
||||||
|
selected = [_trusteeAccountBalanceObj()]
|
||||||
|
prompt = _buildSchemaContext(
|
||||||
|
featureCode="trustee",
|
||||||
|
instanceLabel="Demo AG",
|
||||||
|
selectedTables=selected,
|
||||||
|
requestLang="de",
|
||||||
|
)
|
||||||
|
assert "DOMAIN ONTOLOGY (trustee):" in prompt
|
||||||
|
assert "BankAccount" in prompt
|
||||||
|
assert "NEVER_AGGREGATE on TrusteeDataAccountBalance.closingBalance" in prompt.replace("never aggregate", "NEVER_AGGREGATE")
|
||||||
|
assert "BANK_BALANCE_AT_DATE" in prompt
|
||||||
|
|
||||||
|
|
||||||
|
def test_buildSchemaContext_fallsBackToLegacyHints_whenOntologyDisabled(monkeypatch):
|
||||||
|
"""With POWERON_DISABLE_FEATURE_ONTOLOGY=1 the builder must fall back to
|
||||||
|
the legacy `getAgentDomainHints()` block. This is the path used by the
|
||||||
|
eval harness to measure `baseline` and `phase1` accuracy without the
|
||||||
|
ontology-driven prompt."""
|
||||||
|
monkeypatch.setenv("POWERON_DISABLE_FEATURE_ONTOLOGY", "1")
|
||||||
selected = [_trusteeAccountBalanceObj()]
|
selected = [_trusteeAccountBalanceObj()]
|
||||||
prompt = _buildSchemaContext(
|
prompt = _buildSchemaContext(
|
||||||
featureCode="trustee",
|
featureCode="trustee",
|
||||||
|
|
@ -164,16 +190,14 @@ def test_buildSchemaContext_appendsTrusteeDomainHints():
|
||||||
requestLang="de",
|
requestLang="de",
|
||||||
)
|
)
|
||||||
assert "TRUSTEE DOMAIN HINTS" in prompt
|
assert "TRUSTEE DOMAIN HINTS" in prompt
|
||||||
|
assert "DOMAIN ONTOLOGY" not in prompt
|
||||||
assert "102x Bank / Post" in prompt
|
assert "102x Bank / Post" in prompt
|
||||||
assert "periodMonth = 0" in prompt
|
|
||||||
assert "ANTI-PATTERNS" in prompt
|
|
||||||
assert 'LIKE \'102%\'' in prompt or "LIKE '102%'" in prompt
|
|
||||||
|
|
||||||
|
|
||||||
def test_buildSchemaContext_skipsHintsForFeaturesWithoutHook():
|
def test_buildSchemaContext_skipsHintsForFeaturesWithoutHook(monkeypatch):
|
||||||
"""Features that don't export getAgentDomainHints() should produce a prompt
|
"""Features that don't export getAgentDomainHints()/getAgentOntology()
|
||||||
without the trailing hints block. Verified by using a feature code that
|
should produce a prompt without any trailing hints block."""
|
||||||
cannot resolve to a main module (registry returns None)."""
|
monkeypatch.delenv("POWERON_DISABLE_FEATURE_ONTOLOGY", raising=False)
|
||||||
selected = [_trusteeAccountBalanceObj()]
|
selected = [_trusteeAccountBalanceObj()]
|
||||||
prompt = _buildSchemaContext(
|
prompt = _buildSchemaContext(
|
||||||
featureCode="nosuchfeature",
|
featureCode="nosuchfeature",
|
||||||
|
|
@ -182,4 +206,90 @@ def test_buildSchemaContext_skipsHintsForFeaturesWithoutHook():
|
||||||
requestLang="de",
|
requestLang="de",
|
||||||
)
|
)
|
||||||
assert "TRUSTEE DOMAIN HINTS" not in prompt
|
assert "TRUSTEE DOMAIN HINTS" not in prompt
|
||||||
|
assert "DOMAIN ONTOLOGY" not in prompt
|
||||||
assert "Keep your answer SHORT" in prompt
|
assert "Keep your answer SHORT" in prompt
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Validator integration (Phase 1: Repair-Loop)
|
||||||
|
#
|
||||||
|
# These tests guard that pre-execute validation fires BEFORE the provider
|
||||||
|
# is touched, and that the structured error payload reaches the LLM via
|
||||||
|
# `ToolResult.errorDetails` -- the contract the LLM relies on for repair.
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _buildRegistryWithMockProvider():
|
||||||
|
"""Build a sub-agent ToolRegistry where the provider is a MagicMock.
|
||||||
|
|
||||||
|
The mock records calls so we can assert the validator short-circuits
|
||||||
|
before the DB layer is reached."""
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.browseTable.return_value = {"rows": [], "total": 0, "limit": 50, "offset": 0}
|
||||||
|
provider.queryTable.return_value = {"rows": [], "total": 0, "limit": 50, "offset": 0}
|
||||||
|
provider.aggregateTable.return_value = {"rows": [], "aggregate": "SUM", "field": "x"}
|
||||||
|
registry = _buildSubAgentTools(
|
||||||
|
provider=provider,
|
||||||
|
featureInstanceId="fi-test",
|
||||||
|
mandateId="m-test",
|
||||||
|
tableFilters=None,
|
||||||
|
validator=None,
|
||||||
|
)
|
||||||
|
return registry, provider
|
||||||
|
|
||||||
|
|
||||||
|
def _dispatchSync(registry, toolName, args):
|
||||||
|
"""Synchronously dispatch a tool call through the registry."""
|
||||||
|
call = ToolCallRequest(name=toolName, args=args)
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
try:
|
||||||
|
return loop.run_until_complete(registry.dispatch(call, context={}))
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_subAgentTools_invalidFieldShortCircuitsBeforeProvider():
|
||||||
|
"""A queryTable call with an unknown field must NOT reach the provider."""
|
||||||
|
registry, provider = _buildRegistryWithMockProvider()
|
||||||
|
result = _dispatchSync(registry, "queryTable", {
|
||||||
|
"tableName": "TrusteeDataAccountBalance",
|
||||||
|
"filters": [{"field": "klosingBalance", "op": "=", "value": 1}],
|
||||||
|
})
|
||||||
|
assert isinstance(result, ToolResult)
|
||||||
|
assert result.success is False
|
||||||
|
assert result.errorDetails is not None
|
||||||
|
assert result.errorDetails["code"] == "FIELD_NOT_FOUND"
|
||||||
|
assert result.errorDetails["suggestion"] == "closingBalance"
|
||||||
|
assert result.error and result.error.startswith("FIELD_NOT_FOUND:")
|
||||||
|
provider.queryTable.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_subAgentTools_sumClosingBalanceShortCircuits():
|
||||||
|
"""The flagship hallucination -- SUM(closingBalance) -- must be blocked
|
||||||
|
by the pre-execute validator before the DB is touched."""
|
||||||
|
registry, provider = _buildRegistryWithMockProvider()
|
||||||
|
result = _dispatchSync(registry, "aggregateTable", {
|
||||||
|
"tableName": "TrusteeDataAccountBalance",
|
||||||
|
"aggregate": "SUM",
|
||||||
|
"field": "closingBalance",
|
||||||
|
})
|
||||||
|
assert result.success is False
|
||||||
|
assert result.errorDetails["code"] == "INVALID_AGGREGATE_TARGET"
|
||||||
|
assert result.errorDetails["field"] == "closingBalance"
|
||||||
|
provider.aggregateTable.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_subAgentTools_validCallReachesProvider():
|
||||||
|
"""Sanity: a valid call passes the validator and hits the provider."""
|
||||||
|
registry, provider = _buildRegistryWithMockProvider()
|
||||||
|
result = _dispatchSync(registry, "queryTable", {
|
||||||
|
"tableName": "TrusteeDataAccountBalance",
|
||||||
|
"filters": [
|
||||||
|
{"field": "periodYear", "op": "=", "value": 2025},
|
||||||
|
{"field": "periodMonth", "op": "=", "value": 0},
|
||||||
|
],
|
||||||
|
"fields": ["accountNumber", "closingBalance"],
|
||||||
|
})
|
||||||
|
assert result.success is True
|
||||||
|
assert result.errorDetails is None
|
||||||
|
provider.queryTable.assert_called_once()
|
||||||
|
|
|
||||||
295
tests/unit/services/test_queryValidator.py
Normal file
295
tests/unit/services/test_queryValidator.py
Normal file
|
|
@ -0,0 +1,295 @@
|
||||||
|
# Copyright (c) 2026 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Unit tests for the Feature Data Sub-Agent QueryValidator.
|
||||||
|
|
||||||
|
Each constraint is exercised with both a Happy and a Sad path so a future
|
||||||
|
refactor that silently drops a check is caught immediately.
|
||||||
|
|
||||||
|
Test fixture is the real ``TrusteeDataAccountBalance`` / ``TrusteeDataJournalLine``
|
||||||
|
Pydantic models -- both are perfectly suited because they cover all four
|
||||||
|
constraint classes in production-realistic shape (string fields, numeric
|
||||||
|
fields, fields named ``closingBalance`` / ``debitTotal``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from modules.shared import fkRegistry
|
||||||
|
from modules.serviceCenter.services.serviceAgent.datamodelOntology import (
|
||||||
|
Constraint,
|
||||||
|
ConstraintRule,
|
||||||
|
OntologyDescriptor,
|
||||||
|
ValidationErrorCode,
|
||||||
|
)
|
||||||
|
from modules.serviceCenter.services.serviceAgent.queryValidator import QueryValidator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", autouse=True)
|
||||||
|
def _ensureModels():
|
||||||
|
fkRegistry._ensureModelsLoaded()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def validator() -> QueryValidator:
|
||||||
|
return QueryValidator()
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# FieldExists -- browseTable / queryTable / aggregateTable
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_browseQuery_happyPath_returnsNone(validator):
|
||||||
|
err = validator.validateBrowseQuery(
|
||||||
|
"TrusteeDataAccountBalance",
|
||||||
|
{"fields": ["accountNumber", "closingBalance"]},
|
||||||
|
)
|
||||||
|
assert err is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_browseQuery_invalidField_returnsFieldNotFound(validator):
|
||||||
|
err = validator.validateBrowseQuery(
|
||||||
|
"TrusteeDataAccountBalance",
|
||||||
|
{"fields": ["closingBlance"]}, # typo
|
||||||
|
)
|
||||||
|
assert err is not None
|
||||||
|
assert err.code == ValidationErrorCode.FIELD_NOT_FOUND
|
||||||
|
assert err.field == "closingBlance"
|
||||||
|
assert err.suggestion == "closingBalance"
|
||||||
|
|
||||||
|
|
||||||
|
def test_queryTable_filterOnInvalidField_returnsFieldNotFound(validator):
|
||||||
|
err = validator.validateQueryTable(
|
||||||
|
"TrusteeDataAccountBalance",
|
||||||
|
{"filters": [{"field": "klosingBalance", "op": "=", "value": 100}]},
|
||||||
|
)
|
||||||
|
assert err is not None
|
||||||
|
assert err.code == ValidationErrorCode.FIELD_NOT_FOUND
|
||||||
|
assert err.suggestion == "closingBalance"
|
||||||
|
|
||||||
|
|
||||||
|
def test_queryTable_unknownTable_isLenient(validator):
|
||||||
|
"""When the table isn't in MODEL_REGISTRY we skip validation -- relying on
|
||||||
|
the SQL layer to surface schema errors. Prevents false positives for
|
||||||
|
pure UDB tables not exposed via Pydantic."""
|
||||||
|
err = validator.validateQueryTable(
|
||||||
|
"NoSuchTable123",
|
||||||
|
{"filters": [{"field": "anything", "op": "=", "value": 1}]},
|
||||||
|
)
|
||||||
|
assert err is None
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# OperatorCompatible
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_queryTable_likeOnStringField_isOk(validator):
|
||||||
|
err = validator.validateQueryTable(
|
||||||
|
"TrusteeDataAccountBalance",
|
||||||
|
{"filters": [{"field": "accountNumber", "op": "LIKE", "value": "102%"}]},
|
||||||
|
)
|
||||||
|
assert err is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_queryTable_likeOnNumericField_isOperatorIncompatible(validator):
|
||||||
|
err = validator.validateQueryTable(
|
||||||
|
"TrusteeDataAccountBalance",
|
||||||
|
{"filters": [{"field": "closingBalance", "op": "LIKE", "value": "100%"}]},
|
||||||
|
)
|
||||||
|
assert err is not None
|
||||||
|
assert err.code == ValidationErrorCode.OPERATOR_INCOMPATIBLE
|
||||||
|
assert err.field == "closingBalance"
|
||||||
|
|
||||||
|
|
||||||
|
def test_queryTable_gteOnNumericField_isOk(validator):
|
||||||
|
err = validator.validateQueryTable(
|
||||||
|
"TrusteeDataAccountBalance",
|
||||||
|
{"filters": [{"field": "closingBalance", "op": ">=", "value": 100}]},
|
||||||
|
)
|
||||||
|
assert err is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_queryTable_gteOnStringField_isOperatorIncompatible(validator):
|
||||||
|
err = validator.validateQueryTable(
|
||||||
|
"TrusteeDataAccountBalance",
|
||||||
|
{"filters": [{"field": "currency", "op": ">=", "value": "CHF"}]},
|
||||||
|
)
|
||||||
|
assert err is not None
|
||||||
|
assert err.code == ValidationErrorCode.OPERATOR_INCOMPATIBLE
|
||||||
|
|
||||||
|
|
||||||
|
def test_queryTable_equalsOnAnyField_isOk(validator):
|
||||||
|
"""`=` and `!=` work on any field type."""
|
||||||
|
err = validator.validateQueryTable(
|
||||||
|
"TrusteeDataAccountBalance",
|
||||||
|
{"filters": [{"field": "currency", "op": "=", "value": "CHF"}]},
|
||||||
|
)
|
||||||
|
assert err is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_queryTable_isNullOnAnyField_isOk(validator):
|
||||||
|
err = validator.validateQueryTable(
|
||||||
|
"TrusteeDataAccountBalance",
|
||||||
|
{"filters": [{"field": "mandateId", "op": "IS NULL", "value": None}]},
|
||||||
|
)
|
||||||
|
assert err is None
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# AggregateTarget -- the highest-impact rule
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_sumDebitAmount_isOk(validator):
|
||||||
|
err = validator.validateAggregateQuery(
|
||||||
|
"TrusteeDataJournalLine",
|
||||||
|
{"aggregate": "SUM", "field": "debitAmount"},
|
||||||
|
)
|
||||||
|
assert err is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_sumClosingBalance_isInvalidAggregateTarget(validator):
|
||||||
|
"""The flagship bug: SUM(closingBalance) across periods. Must be blocked."""
|
||||||
|
err = validator.validateAggregateQuery(
|
||||||
|
"TrusteeDataAccountBalance",
|
||||||
|
{"aggregate": "SUM", "field": "closingBalance"},
|
||||||
|
)
|
||||||
|
assert err is not None
|
||||||
|
assert err.code == ValidationErrorCode.INVALID_AGGREGATE_TARGET
|
||||||
|
assert err.field == "closingBalance"
|
||||||
|
assert "already aggregated" in err.hint
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_avgDebitTotal_isInvalidAggregateTarget(validator):
|
||||||
|
"""`*Total` columns are turnovers per period -- AVG across periods is nonsense."""
|
||||||
|
err = validator.validateAggregateQuery(
|
||||||
|
"TrusteeDataAccountBalance",
|
||||||
|
{"aggregate": "AVG", "field": "debitTotal"},
|
||||||
|
)
|
||||||
|
assert err is not None
|
||||||
|
assert err.code == ValidationErrorCode.INVALID_AGGREGATE_TARGET
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_countClosingBalance_isOk(validator):
|
||||||
|
"""COUNT on a balance column is meaningful (how many balance rows exist)."""
|
||||||
|
err = validator.validateAggregateQuery(
|
||||||
|
"TrusteeDataAccountBalance",
|
||||||
|
{"aggregate": "COUNT", "field": "closingBalance"},
|
||||||
|
)
|
||||||
|
assert err is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_sumOnStringField_isTypeMismatch(validator):
|
||||||
|
err = validator.validateAggregateQuery(
|
||||||
|
"TrusteeDataAccountBalance",
|
||||||
|
{"aggregate": "SUM", "field": "currency"},
|
||||||
|
)
|
||||||
|
assert err is not None
|
||||||
|
assert err.code == ValidationErrorCode.TYPE_MISMATCH
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_invalidField_returnsFieldNotFound(validator):
|
||||||
|
err = validator.validateAggregateQuery(
|
||||||
|
"TrusteeDataAccountBalance",
|
||||||
|
{"aggregate": "SUM", "field": "nonExistent"},
|
||||||
|
)
|
||||||
|
assert err is not None
|
||||||
|
assert err.code == ValidationErrorCode.FIELD_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_invalidGroupBy_returnsFieldNotFound(validator):
|
||||||
|
err = validator.validateAggregateQuery(
|
||||||
|
"TrusteeDataJournalLine",
|
||||||
|
{"aggregate": "SUM", "field": "debitAmount", "groupBy": "ghostColumn"},
|
||||||
|
)
|
||||||
|
assert err is not None
|
||||||
|
assert err.code == ValidationErrorCode.FIELD_NOT_FOUND
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# OrderByValid
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_queryTable_orderByValid_isOk(validator):
|
||||||
|
err = validator.validateQueryTable(
|
||||||
|
"TrusteeDataAccountBalance",
|
||||||
|
{"orderBy": "periodYear"},
|
||||||
|
)
|
||||||
|
assert err is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_queryTable_orderByInvalid_returnsOrderByInvalid(validator):
|
||||||
|
err = validator.validateQueryTable(
|
||||||
|
"TrusteeDataAccountBalance",
|
||||||
|
{"orderBy": "periodYr"},
|
||||||
|
)
|
||||||
|
assert err is not None
|
||||||
|
assert err.code == ValidationErrorCode.ORDER_BY_INVALID
|
||||||
|
assert err.suggestion == "periodYear"
|
||||||
|
|
||||||
|
|
||||||
|
def test_queryTable_orderByLiteralStringNone_isOk(validator):
|
||||||
|
"""LLMs sometimes pass the literal string 'None'."""
|
||||||
|
err = validator.validateQueryTable(
|
||||||
|
"TrusteeDataAccountBalance",
|
||||||
|
{"orderBy": "None"},
|
||||||
|
)
|
||||||
|
assert err is None
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Ontology-driven override (Phase 2 readiness check)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_ontologyOverride_blocksAggregateForOntologyField():
|
||||||
|
"""When the ontology marks a field NEVER_AGGREGATE, SUM/AVG is blocked
|
||||||
|
even if the field name doesn't match the convention suffixes."""
|
||||||
|
ontology = OntologyDescriptor(
|
||||||
|
featureCode="trustee",
|
||||||
|
constraints=[
|
||||||
|
Constraint(
|
||||||
|
appliesTo="TrusteeDataJournalLine.debitAmount",
|
||||||
|
rule=ConstraintRule.NEVER_AGGREGATE,
|
||||||
|
message="Synthetic test rule.",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
validatorWithOntology = QueryValidator(ontology=ontology)
|
||||||
|
err = validatorWithOntology.validateAggregateQuery(
|
||||||
|
"TrusteeDataJournalLine",
|
||||||
|
{"aggregate": "SUM", "field": "debitAmount"},
|
||||||
|
)
|
||||||
|
assert err is not None
|
||||||
|
assert err.code == ValidationErrorCode.INVALID_AGGREGATE_TARGET
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# QueryValidationError serialization (consumed by featureDataAgent)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_validationError_toShortErrorIncludesCodeAndField(validator):
|
||||||
|
err = validator.validateAggregateQuery(
|
||||||
|
"TrusteeDataAccountBalance",
|
||||||
|
{"aggregate": "SUM", "field": "closingBalance"},
|
||||||
|
)
|
||||||
|
assert err is not None
|
||||||
|
short = err.toShortError()
|
||||||
|
assert short.startswith("INVALID_AGGREGATE_TARGET:")
|
||||||
|
assert "closingBalance" in short
|
||||||
|
|
||||||
|
|
||||||
|
def test_validationError_toErrorDetailsHasFourKeys(validator):
|
||||||
|
err = validator.validateQueryTable(
|
||||||
|
"TrusteeDataAccountBalance",
|
||||||
|
{"filters": [{"field": "klosingBalance", "op": "=", "value": 0}]},
|
||||||
|
)
|
||||||
|
assert err is not None
|
||||||
|
details = err.toErrorDetails()
|
||||||
|
assert set(details.keys()) == {"code", "field", "suggestion", "hint"}
|
||||||
|
assert details["code"] == "FIELD_NOT_FOUND"
|
||||||
|
assert details["suggestion"] == "closingBalance"
|
||||||
199
tests/unit/services/test_trusteeOntology.py
Normal file
199
tests/unit/services/test_trusteeOntology.py
Normal file
|
|
@ -0,0 +1,199 @@
|
||||||
|
# Copyright (c) 2026 Patrick Motsch
|
||||||
|
# All rights reserved.
|
||||||
|
"""Unit tests for the trustee ontology and the ontology-to-prompt compiler.
|
||||||
|
|
||||||
|
Verifies:
|
||||||
|
|
||||||
|
* the descriptor passes Pydantic validation
|
||||||
|
* `constraintsForTable` correctly scopes by table/field prefix
|
||||||
|
* the compiler emits a stable header + every entity name + every
|
||||||
|
constraint message
|
||||||
|
* the QueryValidator picks up ontology constraints (NEVER_AGGREGATE on
|
||||||
|
closingBalance) over the convention-based defaults
|
||||||
|
* the `getAgentOntology()` hook on `mainTrustee` returns the descriptor
|
||||||
|
* `_buildValidatorForFeature("trustee")` wires the validator with the
|
||||||
|
ontology
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from modules.features.trustee.mainTrustee import getAgentOntology
|
||||||
|
from modules.features.trustee.trusteeOntology import getTrusteeOntology
|
||||||
|
from modules.serviceCenter.services.serviceAgent.datamodelOntology import (
|
||||||
|
ConstraintRule,
|
||||||
|
OntologyDescriptor,
|
||||||
|
SemanticType,
|
||||||
|
ValidationErrorCode,
|
||||||
|
)
|
||||||
|
from modules.serviceCenter.services.serviceAgent.featureDataAgent import (
|
||||||
|
_buildValidatorForFeature,
|
||||||
|
_loadFeatureOntologyBlock,
|
||||||
|
)
|
||||||
|
from modules.serviceCenter.services.serviceAgent.ontologyToPromptCompiler import (
|
||||||
|
compileOntologyToPrompt,
|
||||||
|
)
|
||||||
|
from modules.serviceCenter.services.serviceAgent.queryValidator import QueryValidator
|
||||||
|
from modules.shared import fkRegistry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module", autouse=True)
|
||||||
|
def _ensureModels():
|
||||||
|
fkRegistry._ensureModelsLoaded()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# OntologyDescriptor structure
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_trusteeOntology_returnsValidDescriptor():
|
||||||
|
ont = getTrusteeOntology()
|
||||||
|
assert isinstance(ont, OntologyDescriptor)
|
||||||
|
assert ont.featureCode == "trustee"
|
||||||
|
assert ont.entities and ont.relations and ont.constraints and ont.canonicalPatterns
|
||||||
|
|
||||||
|
|
||||||
|
def test_trusteeOntology_hasBankAccountSpecialization():
|
||||||
|
ont = getTrusteeOntology()
|
||||||
|
bank = next((e for e in ont.entities if e.name == "BankAccount"), None)
|
||||||
|
assert bank is not None
|
||||||
|
assert bank.parentEntity == "Account"
|
||||||
|
assert bank.semanticType == SemanticType.ACCOUNT
|
||||||
|
|
||||||
|
|
||||||
|
def test_trusteeOntology_closingBalanceIsNeverAggregate():
|
||||||
|
ont = getTrusteeOntology()
|
||||||
|
constraints = ont.constraintsForTable("TrusteeDataAccountBalance")
|
||||||
|
matching = [
|
||||||
|
c for c in constraints
|
||||||
|
if c.rule == ConstraintRule.NEVER_AGGREGATE
|
||||||
|
and c.appliesTo == "TrusteeDataAccountBalance.closingBalance"
|
||||||
|
]
|
||||||
|
assert matching, "Expected NEVER_AGGREGATE constraint on closingBalance"
|
||||||
|
|
||||||
|
|
||||||
|
def test_trusteeOntology_requiresPeriodFilterOnBalanceTable():
|
||||||
|
ont = getTrusteeOntology()
|
||||||
|
constraints = ont.constraintsForTable("TrusteeDataAccountBalance")
|
||||||
|
table_level = [c for c in constraints if c.rule == ConstraintRule.REQUIRES_FILTER_ON]
|
||||||
|
assert table_level, "Expected at least one REQUIRES_FILTER_ON constraint"
|
||||||
|
required = table_level[0].params.get("requiredFields") or []
|
||||||
|
assert "periodYear" in required
|
||||||
|
assert "periodMonth" in required
|
||||||
|
|
||||||
|
|
||||||
|
def test_constraintsForTable_filtersScopeCorrectly():
|
||||||
|
ont = getTrusteeOntology()
|
||||||
|
bal = ont.constraintsForTable("TrusteeDataAccountBalance")
|
||||||
|
journal = ont.constraintsForTable("TrusteeDataJournalLine")
|
||||||
|
for c in bal:
|
||||||
|
assert c.appliesTo.startswith("TrusteeDataAccountBalance")
|
||||||
|
for c in journal:
|
||||||
|
assert c.appliesTo.startswith("TrusteeDataJournalLine")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Prompt compiler
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_compiler_emitsExpectedHeader():
|
||||||
|
block = compileOntologyToPrompt(getTrusteeOntology())
|
||||||
|
assert block.startswith("DOMAIN ONTOLOGY (trustee):"), block.splitlines()[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_compiler_includesAllEntityNames():
|
||||||
|
ont = getTrusteeOntology()
|
||||||
|
block = compileOntologyToPrompt(ont)
|
||||||
|
for e in ont.entities:
|
||||||
|
assert e.name in block, f"Entity {e.name} missing from compiled prompt"
|
||||||
|
|
||||||
|
|
||||||
|
def test_compiler_includesAllConstraintMessages():
|
||||||
|
ont = getTrusteeOntology()
|
||||||
|
block = compileOntologyToPrompt(ont)
|
||||||
|
for c in ont.constraints:
|
||||||
|
assert c.message.split(".")[0] in block, f"Constraint message missing: {c.message[:40]}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_compiler_includesCanonicalPatternTools():
|
||||||
|
ont = getTrusteeOntology()
|
||||||
|
block = compileOntologyToPrompt(ont)
|
||||||
|
for p in ont.canonicalPatterns:
|
||||||
|
assert p.intent in block
|
||||||
|
assert p.pattern["tool"] in block
|
||||||
|
|
||||||
|
|
||||||
|
def test_compiler_deterministic():
|
||||||
|
block1 = compileOntologyToPrompt(getTrusteeOntology())
|
||||||
|
block2 = compileOntologyToPrompt(getTrusteeOntology())
|
||||||
|
assert block1 == block2
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# QueryValidator x ontology integration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_validator_picksUpOntologyNeverAggregate():
|
||||||
|
validator = QueryValidator(ontology=getTrusteeOntology())
|
||||||
|
err = validator.validateAggregateQuery(
|
||||||
|
"TrusteeDataAccountBalance",
|
||||||
|
{"aggregate": "SUM", "field": "closingBalance"},
|
||||||
|
)
|
||||||
|
assert err is not None
|
||||||
|
assert err.code == ValidationErrorCode.INVALID_AGGREGATE_TARGET
|
||||||
|
assert err.field == "closingBalance"
|
||||||
|
|
||||||
|
|
||||||
|
def test_validator_ontologyConstraintFiresOnDebitTotal():
|
||||||
|
validator = QueryValidator(ontology=getTrusteeOntology())
|
||||||
|
err = validator.validateAggregateQuery(
|
||||||
|
"TrusteeDataAccountBalance",
|
||||||
|
{"aggregate": "SUM", "field": "debitTotal"},
|
||||||
|
)
|
||||||
|
assert err is not None
|
||||||
|
assert err.code == ValidationErrorCode.INVALID_AGGREGATE_TARGET
|
||||||
|
|
||||||
|
|
||||||
|
def test_validator_allowsLegitimateAggregateOnJournalLine():
|
||||||
|
validator = QueryValidator(ontology=getTrusteeOntology())
|
||||||
|
err = validator.validateAggregateQuery(
|
||||||
|
"TrusteeDataJournalLine",
|
||||||
|
{"aggregate": "SUM", "field": "debitAmount"},
|
||||||
|
)
|
||||||
|
assert err is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# featureDataAgent integration hooks
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_mainTrustee_getAgentOntology_returnsDescriptor():
|
||||||
|
ont = getAgentOntology()
|
||||||
|
assert isinstance(ont, OntologyDescriptor)
|
||||||
|
assert ont.featureCode == "trustee"
|
||||||
|
|
||||||
|
|
||||||
|
def test_loadFeatureOntologyBlock_returnsCompiledBlock():
|
||||||
|
block = _loadFeatureOntologyBlock("trustee")
|
||||||
|
assert block.startswith("DOMAIN ONTOLOGY (trustee):")
|
||||||
|
assert "BankAccount" in block
|
||||||
|
|
||||||
|
|
||||||
|
def test_loadFeatureOntologyBlock_unknownFeatureReturnsEmpty():
|
||||||
|
assert _loadFeatureOntologyBlock("doesNotExist") == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_buildValidatorForFeature_trustee_hasOntology():
|
||||||
|
validator = _buildValidatorForFeature("trustee")
|
||||||
|
assert validator._ontology is not None
|
||||||
|
assert validator._ontology.featureCode == "trustee"
|
||||||
|
|
||||||
|
|
||||||
|
def test_buildValidatorForFeature_unknownFeature_noOntology():
|
||||||
|
validator = _buildValidatorForFeature("doesNotExist")
|
||||||
|
assert validator._ontology is None
|
||||||
Loading…
Reference in a new issue