# Copyright (c) 2025 Patrick Motsch # All rights reserved. """Feature Data Sub-Agent. Specialized mini-agent that queries feature-instance data tables. Receives schema context (fields, descriptions) for the selected tables and has its tools: browseTable, queryTable, aggregateTable. Runs its own agent loop and returns structured results back to the main agent. Round/cost budgets are inherited from the parent agent (workspace user setting `maxAgentRounds` -> `AgentConfig.maxRounds`) and propagated through the tool-call context. Defaults below are only used when the sub-agent is invoked outside an agent loop (e.g. in tests). """ import json import logging import os from typing import Any, Callable, Awaitable, Dict, List, Optional from modules.datamodels.datamodelAi import ( AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum, ) from modules.datamodels.datamodelBase import MODEL_REGISTRY from modules.serviceCenter.services.serviceAgent.agentLoop import runAgentLoop from modules.serviceCenter.services.serviceAgent.datamodelAgent import ( AgentConfig, AgentEvent, AgentEventTypeEnum, ToolResult, ) from modules.serviceCenter.services.serviceAgent.datamodelOntology import ( QueryValidationError, ) from modules.serviceCenter.services.serviceAgent.queryValidator import QueryValidator from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry from modules.serviceCenter.services.serviceAgent.featureDataProvider import FeatureDataProvider from modules.shared.i18nRegistry import resolveText from modules.shared.timeUtils import getRequestNow, getRequestTimezone logger = logging.getLogger(__name__) _DEFAULT_MAX_ROUNDS = 8 # Per-round CHF cap. Multiplied by `maxRounds` so the cost guard scales # with the configured round budget instead of cutting the loop short. # 0.15 / 8 ≈ 0.019 — round up to 0.02 for some headroom. _MAX_COST_CHF_PER_ROUND = 0.02 async def runFeatureDataAgent( question: str, featureInstanceId: str, featureCode: str, selectedTables: List[Dict[str, Any]], mandateId: str, userId: str, aiCallFn: Callable[[AiCallRequest], Awaitable[AiCallResponse]], dbConnector, instanceLabel: str = "", tableFilters: Optional[Dict[str, Dict[str, str]]] = None, requestLang: Optional[str] = None, neutralizeFields: Optional[Dict[str, List[str]]] = None, neutralizePolicy: Optional[Dict[str, Dict[str, Any]]] = None, neutralizationService: Optional[Any] = None, maxRounds: Optional[int] = None, maxCostCHF: Optional[float] = None, ) -> str: """Run the feature data sub-agent and return the textual result. Args: question: The user/main-agent question to answer using feature data. featureInstanceId: Feature instance to scope queries. featureCode: Feature code (trustee, commcoach, ...). selectedTables: List of DATA_OBJECT dicts the user selected. mandateId: Mandate scope. userId: Calling user ID. aiCallFn: AI call function (with billing). dbConnector: DatabaseConnector for queries. instanceLabel: Human-readable instance name for context. tableFilters: Per-table record filters from FeatureDataSource.recordFilter. requestLang: ISO 639-1 code for resolving multilingual table labels in the schema prompt. neutralizeFields: LEGACY per-table list of field names for whole-value masking. neutralizePolicy: Per-table type/inheritance-aware neutralization policy ({"tableActive": bool, "explicitFields": set}) applied via the provider's finalizeRowsAsync (A2 rules: strings substring-neutralized when effective, binary dropped, other scalars only when explicit). neutralizationService: Mandate/instance-scoped NeutralizationService used for substring neutralization of string cells. maxRounds: Inherited from the parent agent's configured `maxRounds` (workspace user setting `maxAgentRounds` -> `AgentConfig.maxRounds`). Falls back to the legacy 8-round default when not provided so direct callers / tests still work. maxCostCHF: Hard cost cap for the sub-agent run. When omitted, scales with `maxRounds` to keep per-round budget constant. Returns: Plain-text answer produced by the sub-agent. """ provider = FeatureDataProvider( dbConnector, neutralizeFields=neutralizeFields, neutralizePolicy=neutralizePolicy, neutralizationService=neutralizationService, ) validator = _buildValidatorForFeature(featureCode) registry = _buildSubAgentTools(provider, featureInstanceId, mandateId, tableFilters or {}, validator=validator) for tbl in selectedTables: meta = tbl.get("meta", {}) tableName = meta.get("table", "") if tableName: realCols = provider.getActualColumns(tableName) if realCols: meta["fields"] = realCols systemPrompt = _buildSchemaContext(featureCode, instanceLabel, selectedTables, requestLang) effectiveMaxRounds = int(maxRounds) if maxRounds and maxRounds > 0 else _DEFAULT_MAX_ROUNDS effectiveMaxCost = ( float(maxCostCHF) if maxCostCHF is not None and maxCostCHF > 0 else effectiveMaxRounds * _MAX_COST_CHF_PER_ROUND ) config = AgentConfig( maxRounds=effectiveMaxRounds, maxCostCHF=effectiveMaxCost, operationType=OperationTypeEnum.DATA_QUERY, ) logger.info( "Feature data sub-agent starting: featureInstanceId=%s, maxRounds=%d, maxCostCHF=%.4f", featureInstanceId, effectiveMaxRounds, effectiveMaxCost, ) costAccumulator = 0.0 async def _trackingAiCallFn(req): nonlocal costAccumulator resp = await aiCallFn(req) costAccumulator += resp.priceCHF return resp async def _getWorkflowCost() -> float: return costAccumulator result = "" async for event in runAgentLoop( prompt=question, toolRegistry=registry, config=config, aiCallFn=_trackingAiCallFn, getWorkflowCostFn=_getWorkflowCost, workflowId=f"fda-{featureInstanceId[:8]}", userId=userId, featureInstanceId=featureInstanceId, mandateId=mandateId, systemPromptOverride=systemPrompt, ): if event.type == AgentEventTypeEnum.FINAL and event.content: result = event.content elif event.type == AgentEventTypeEnum.MESSAGE and event.content: result += event.content return result or "(no data returned by feature agent)" # ------------------------------------------------------------------ # tool registration # ------------------------------------------------------------------ def _buildSubAgentTools( provider: FeatureDataProvider, featureInstanceId: str, mandateId: str, tableFilters: Dict[str, Dict[str, str]] = None, validator: Optional[QueryValidator] = None, ) -> ToolRegistry: """Register browseTable and queryTable as sub-agent tools. The optional ``validator`` runs **before** the provider on every call. When it returns a structured error, the tool result carries ``errorDetails`` (machine-readable repair hint for the LLM) plus the short ``error`` string for logs/audit. No provider call happens in that case, so the database is never reached with a known-bad query. """ registry = ToolRegistry() _tableFilters = tableFilters or {} _validator = validator or QueryValidator() def _recordFilterToList(tableName: str) -> Optional[List[Dict[str, Any]]]: """Convert a recordFilter dict to a list of {field, op, value} filter dicts.""" rf = _tableFilters.get(tableName) if not rf: return None return [{"field": k, "op": "=", "value": v} for k, v in rf.items()] def _validationToolResult(toolName: str, err: QueryValidationError) -> ToolResult: return ToolResult( toolCallId="", toolName=toolName, success=False, error=err.toShortError(), errorDetails=err.toErrorDetails(), ) async def _browseTable(args: Dict[str, Any], context: Dict[str, Any]): tableName = args.get("tableName", "") limit = args.get("limit", 50) offset = args.get("offset", 0) fields = args.get("fields") if not tableName: return ToolResult(toolCallId="", toolName="browseTable", success=False, error="tableName required") validationErr = _validator.validateBrowseQuery(tableName, args) if validationErr is not None: return _validationToolResult("browseTable", validationErr) result = provider.browseTable( tableName=tableName, featureInstanceId=featureInstanceId, mandateId=mandateId, fields=fields, limit=min(limit, 200), offset=offset, extraFilters=_recordFilterToList(tableName), ) if hasattr(provider, "finalizeRowsAsync") and "rows" in result: result["rows"] = await provider.finalizeRowsAsync(tableName, result["rows"]) return ToolResult( toolCallId="", toolName="browseTable", success="error" not in result, data=json.dumps(result, default=str, ensure_ascii=False)[:30000], error=result.get("error"), ) async def _queryTable(args: Dict[str, Any], context: Dict[str, Any]): tableName = args.get("tableName", "") filters = args.get("filters", []) fields = args.get("fields") orderBy = args.get("orderBy") limit = args.get("limit", 50) offset = args.get("offset", 0) if not tableName: return ToolResult(toolCallId="", toolName="queryTable", success=False, error="tableName required") validationErr = _validator.validateQueryTable(tableName, args) if validationErr is not None: return _validationToolResult("queryTable", validationErr) result = provider.queryTable( tableName=tableName, featureInstanceId=featureInstanceId, mandateId=mandateId, filters=filters, fields=fields, orderBy=orderBy, limit=min(limit, 200), offset=offset, extraFilters=_recordFilterToList(tableName), ) if hasattr(provider, "finalizeRowsAsync") and "rows" in result: result["rows"] = await provider.finalizeRowsAsync(tableName, result["rows"]) return ToolResult( toolCallId="", toolName="queryTable", success="error" not in result, data=json.dumps(result, default=str, ensure_ascii=False)[:30000], error=result.get("error"), ) async def _aggregateTable(args: Dict[str, Any], context: Dict[str, Any]): tableName = args.get("tableName", "") aggregate = args.get("aggregate", "") field = args.get("field", "") groupBy = args.get("groupBy") filters = args.get("filters") or [] if not tableName: return ToolResult(toolCallId="", toolName="aggregateTable", success=False, error="tableName required") if not aggregate: return ToolResult(toolCallId="", toolName="aggregateTable", success=False, error="aggregate required (SUM, COUNT, AVG, MIN, MAX)") if not field: return ToolResult(toolCallId="", toolName="aggregateTable", success=False, error="field required") validationErr = _validator.validateAggregateQuery(tableName, args) if validationErr is not None: return _validationToolResult("aggregateTable", validationErr) combinedFilters = list(filters) recordFilters = _recordFilterToList(tableName) or [] combinedFilters.extend(recordFilters) result = provider.aggregateTable( tableName=tableName, featureInstanceId=featureInstanceId, mandateId=mandateId, aggregate=aggregate, field=field, groupBy=groupBy, extraFilters=combinedFilters or None, ) if hasattr(provider, "finalizeRowsAsync") and "rows" in result: result["rows"] = await provider.finalizeRowsAsync(tableName, result["rows"]) return ToolResult( toolCallId="", toolName="aggregateTable", success="error" not in result, data=json.dumps(result, default=str, ensure_ascii=False)[:30000], error=result.get("error"), ) registry.register( "aggregateTable", _aggregateTable, description=( "Run an aggregate query on a feature data table. " "Supports SUM, COUNT, AVG, MIN, MAX with optional GROUP BY and filters. " "Example: aggregateTable(tableName='TrusteeDataJournalLine', aggregate='SUM', " "field='debitAmount', filters=[{'field':'accountNumber','op':'=','value':'5400'}]). " "On validation failure the tool returns success=False with errorDetails={code, field, suggestion, hint} -- " "read errorDetails and correct the next call (e.g. drop the SUM, switch to queryTable with period filters, " "or use the suggested field name)." ), parameters={ "type": "object", "properties": { "tableName": {"type": "string", "description": "Name of the table to aggregate"}, "aggregate": {"type": "string", "enum": ["SUM", "COUNT", "AVG", "MIN", "MAX"], "description": "Aggregate function"}, "field": {"type": "string", "description": "Field to aggregate (e.g. debitAmount, creditAmount)"}, "groupBy": {"type": "string", "description": "Optional field to group by (e.g. costCenter, accountNumber)"}, "filters": { "type": "array", "items": { "type": "object", "properties": { "field": {"type": "string"}, "op": {"type": "string"}, "value": {}, }, }, "description": ( "Optional filter conditions applied before the aggregate. Same shape as queryTable's " "filters. Required whenever you want to aggregate only a subset (e.g. SUM debits on " "ONE account, COUNT rows in ONE year)." ), }, }, "required": ["tableName", "aggregate", "field"], }, readOnly=True, ) registry.register( "browseTable", _browseTable, description=( "List rows from a feature data table with pagination. " "On validation failure the tool returns success=False with errorDetails={code, field, suggestion, hint} -- " "use errorDetails to correct the next call." ), parameters={ "type": "object", "properties": { "tableName": {"type": "string", "description": "Name of the table to browse"}, "fields": { "type": "array", "items": {"type": "string"}, "description": "Optional list of fields to return (default: all)", }, "limit": {"type": "integer", "description": "Max rows to return (default 50, max 200)"}, "offset": {"type": "integer", "description": "Row offset for pagination"}, }, "required": ["tableName"], }, readOnly=True, ) registry.register( "queryTable", _queryTable, description=( "Query a feature data table with filters, field selection, and ordering. " "Filters: [{\"field\": \"status\", \"op\": \"=\", \"value\": \"active\"}]. " "Operators: =, !=, >, <, >=, <=, LIKE, ILIKE, IS NULL, IS NOT NULL. " "On validation failure the tool returns success=False with errorDetails={code, field, suggestion, hint} -- " "common codes: FIELD_NOT_FOUND (use the suggestion or call browseTable), OPERATOR_INCOMPATIBLE " "(switch to a compatible operator for that field type), ORDER_BY_INVALID." ), parameters={ "type": "object", "properties": { "tableName": {"type": "string", "description": "Name of the table to query"}, "filters": { "type": "array", "items": { "type": "object", "properties": { "field": {"type": "string"}, "op": {"type": "string"}, "value": {}, }, }, "description": "Filter conditions", }, "fields": { "type": "array", "items": {"type": "string"}, "description": "Optional list of fields to return", }, "orderBy": {"type": "string", "description": "Field name to order by"}, "limit": {"type": "integer", "description": "Max rows (default 50, max 200)"}, "offset": {"type": "integer", "description": "Row offset"}, }, "required": ["tableName"], }, readOnly=True, ) return registry # ------------------------------------------------------------------ # context building # ------------------------------------------------------------------ def _buildSchemaContext( featureCode: str, instanceLabel: str, selectedTables: List[Dict[str, Any]], requestLang: Optional[str] = None, ) -> str: """Build a system prompt describing available tables and query strategy. Per table the prompt now lists every selected column with its Python type, German label, description and FK target (when available via the registered Pydantic model). This gives the sub-agent enough context to: * pick the right table (e.g. period-bucketed *AccountBalance over raw JournalLine for "Saldo per "), * format date filters as UNIX timestamps when the type is float, * recognise FK relations even though the tools cannot JOIN. On top of the generic schema block we append feature-specific domain hints (e.g. Swiss KMU chart-of-accounts conventions for trustee) when the feature module exports a ``getAgentDomainHints()`` function. This lets each feature teach the sub-agent its own jargon without polluting the generic agent code. """ tableNames: List[str] = [] tableBlocks: List[str] = [] for obj in selectedTables: meta = obj.get("meta", {}) tbl = meta.get("table", "?") fields = list(meta.get("fields") or []) labelStr = resolveText(obj.get("label"), requestLang) tableNames.append(tbl) tableBlocks.append(_buildTableSchemaBlock(tbl, labelStr, fields)) header = f"You are a data query assistant for the '{featureCode}' feature" if instanceLabel: header += f' (instance: "{instanceLabel}")' header += "." tz = getRequestTimezone() now = getRequestNow() temporalLines = [ "CURRENT DATE & TIME (use this for relative time references in filters):", f" Today: {now.strftime('%Y-%m-%d (%A)')}", f" Now: {now.strftime('%H:%M')} ({tz})", " Resolve phrases like 'today', 'last month', 'Q1', 'this year' against THIS date.", " Do NOT use your training cutoff for date filters.", ] parts = [ header, "", *temporalLines, "", "AVAILABLE TABLES (use EXACTLY these names as tableName parameter):", *tableBlocks, "", f"Valid tableName values: {tableNames}", "Field names are plain column names (e.g. 'accountNumber', 'periodYear').", "", "QUERY STRATEGY:", "1. If unsure about columns, call browseTable(tableName) first to inspect the schema.", "2. Use queryTable with filters for targeted lookups.", "3. Use aggregateTable for SUM/COUNT/AVG/MIN/MAX with optional GROUP BY.", "4. Combine what you need into as few tool calls as possible.", "", "RULES:", "- Do NOT invent table or field names. Do NOT prefix fields with UUIDs or dots.", "- Float fields whose description mentions 'unix timestamp' (e.g. bookingDate, lastSyncAt) " "store seconds since epoch. Convert dates to a unix-seconds float before filtering " "(e.g. '2025-12-31' -> 1735603200.0); never compare such fields against ISO strings.", "- The query tools operate on ONE table at a time and CANNOT JOIN. To combine related " "tables (FK target shown in [FK -> Table.field]), query each separately and reason " "about the link in your answer.", "- When a table has period-bucketed aggregates (opening/closing balances or totals per " "period), prefer it over recomputing the same aggregate from raw transactional rows.", "- NEVER apply SUM/AVG to columns that already represent a balance, closing/opening " "total or aggregate (e.g. closingBalance, openingBalance, debitTotal, creditTotal, " "*Balance, *Total). These are already aggregated per period — summing them across " "periods produces meaningless numbers. Use queryTable with explicit period filters " "instead, then pick the single matching row.", "- CRITICAL: Return data as compact JSON, NOT as markdown tables or prose.", "- Do NOT reformat, rewrite, or narrate the tool results. Return the raw data directly.", "- If the question asks for rows, return them as a JSON array. Do NOT generate a markdown table.", "- Keep your answer SHORT. The caller is a machine, not a human.", ] domainBlock = "" if not _isOntologyDisabled(): domainBlock = _loadFeatureOntologyBlock(featureCode) if not domainBlock: domainBlock = _loadFeatureDomainHints(featureCode) if domainBlock: parts.extend(["", domainBlock.strip()]) return "\n".join(parts) def _isOntologyDisabled() -> bool: """Eval-only escape hatch. Set ``POWERON_DISABLE_FEATURE_ONTOLOGY=1`` in the environment to force ``_buildSchemaContext`` back onto the legacy ``getAgentDomainHints()`` path. Used by the Phase 1.5 benchmark to measure ``baseline`` and ``phase1`` accuracy WITHOUT the ontology-driven prompt block. Never set this flag in production. """ return os.environ.get("POWERON_DISABLE_FEATURE_ONTOLOGY", "").strip() in ("1", "true", "TRUE", "yes") def _buildValidatorForFeature(featureCode: str) -> QueryValidator: """Construct a QueryValidator wired with the feature ontology (when present). Without an ontology the validator falls back to its convention-based constraints (``*Balance`` / ``*Total`` are NEVER_AGGREGATE). With an ontology the descriptor's constraints take precedence -- the validator and the prompt block then share the same source of truth. """ ontology = _loadFeatureOntology(featureCode) return QueryValidator(ontology=ontology) def _loadFeatureOntology(featureCode: str): """Return the feature's OntologyDescriptor or None when no hook is exposed.""" if not featureCode: return None try: from modules.system.registry import loadFeatureMainModules except Exception: return None try: mainModules = loadFeatureMainModules() or {} except Exception as exc: logger.debug("Ontology lookup: cannot load main modules (%s)", exc) return None module = mainModules.get(featureCode) or mainModules.get(featureCode.lower()) if module is None: return None hook = getattr(module, "getAgentOntology", None) if not callable(hook): return None try: return hook() except Exception as exc: logger.warning("Feature '%s' getAgentOntology() raised: %s", featureCode, exc) return None def _loadFeatureOntologyBlock(featureCode: str) -> str: """Return the ontology-derived prompt block when the feature exposes one. Each feature can expose ``getAgentOntology() -> OntologyDescriptor`` in its ``mainXxx.py``. When present, the descriptor is compiled via :func:`ontologyToPromptCompiler.compileOntologyToPrompt` and the result replaces the legacy ``getAgentDomainHints()`` text block. This keeps one single source of truth for the validator AND the prompt. Failures are swallowed (missing hook, exceptions in compilation) so the caller can fall back to the legacy domain-hints path. """ ontology = _loadFeatureOntology(featureCode) if ontology is None: return "" try: from modules.serviceCenter.services.serviceAgent.ontologyToPromptCompiler import ( compileOntologyToPrompt, ) return compileOntologyToPrompt(ontology) except Exception as exc: logger.warning("Ontology compile failed for '%s': %s", featureCode, exc) return "" def _loadFeatureDomainHints(featureCode: str) -> str: """Pull optional domain-specific hints from the feature's main module. Each feature can expose ``getAgentDomainHints() -> str`` in its ``mainXxx.py``. The returned text is appended verbatim to the sub-agent's system prompt, so features can teach the agent their own domain jargon (chart-of-accounts conventions, period semantics, query patterns) without coupling the generic agent code to any one feature. Failures (missing module, missing hook, exception inside the hook) are swallowed and just yield an empty hints block — domain hints are a best-effort enhancement, not a hard dependency. """ if not featureCode: return "" try: from modules.system.registry import loadFeatureMainModules except Exception: return "" try: mainModules = loadFeatureMainModules() or {} except Exception as exc: logger.debug("Domain-hints lookup: cannot load main modules (%s)", exc) return "" module = mainModules.get(featureCode) or mainModules.get(featureCode.lower()) if module is None: return "" hook = getattr(module, "getAgentDomainHints", None) if not callable(hook): return "" try: hints = hook() except Exception as exc: logger.warning("Feature '%s' getAgentDomainHints() raised: %s", featureCode, exc) return "" if not isinstance(hints, str): return "" return hints def _buildTableSchemaBlock(tableName: str, tableLabel: str, fields: List[str]) -> str: """Render a single table's schema block, enriched from its Pydantic model. Falls back to a flat field list when the model isn't registered (e.g. pure UDB tables, or in early-startup contexts before datamodels are imported). """ headerLine = f' Table: {tableName} "{tableLabel}"' modelClass = MODEL_REGISTRY.get(tableName) docLine = "" if modelClass is not None: rawDoc = (modelClass.__doc__ or "").strip() if rawDoc: docLine = " Description: " + " ".join(rawDoc.split()) if not fields: return headerLine + (("\n" + docLine) if docLine else "") if modelClass is None: return headerLine + f"\n Fields: {', '.join(fields)}" fieldSet = set(fields) fieldLines: List[str] = [] for fieldName, fieldInfo in modelClass.model_fields.items(): if fieldName not in fieldSet: continue fieldLines.append(" - " + _formatFieldLine(fieldName, fieldInfo)) extras = sorted(fieldSet.difference(modelClass.model_fields.keys())) for extra in extras: fieldLines.append(f" - {extra} (unknown)") parts = [headerLine] if docLine: parts.append(docLine) parts.append(" Fields:") parts.extend(fieldLines) return "\n".join(parts) def _formatFieldLine(fieldName: str, fieldInfo: Any) -> str: """Format one field as: ' () "