# Copyright (c) 2025 Patrick Motsch # All rights reserved. """Feature Data Sub-Agent. Specialized mini-agent that queries feature-instance data tables. Receives schema context (fields, descriptions) for the selected tables and has its tools: browseTable, queryTable, aggregateTable. Runs its own agent loop and returns structured results back to the main agent. Round/cost budgets are inherited from the parent agent (workspace user setting `maxAgentRounds` -> `AgentConfig.maxRounds`) and propagated through the tool-call context. Defaults below are only used when the sub-agent is invoked outside an agent loop (e.g. in tests). """ import json import logging from typing import Any, Callable, Awaitable, Dict, List, Optional from modules.datamodels.datamodelAi import ( AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum, ) from modules.datamodels.datamodelBase import MODEL_REGISTRY from modules.serviceCenter.services.serviceAgent.agentLoop import runAgentLoop from modules.serviceCenter.services.serviceAgent.datamodelAgent import ( AgentConfig, AgentEvent, AgentEventTypeEnum, ToolResult, ) from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry from modules.serviceCenter.services.serviceAgent.featureDataProvider import FeatureDataProvider from modules.shared.i18nRegistry import resolveText from modules.shared.timeUtils import getRequestNow, getRequestTimezone logger = logging.getLogger(__name__) _DEFAULT_MAX_ROUNDS = 8 # Per-round CHF cap. Multiplied by `maxRounds` so the cost guard scales # with the configured round budget instead of cutting the loop short. # 0.15 / 8 ≈ 0.019 — round up to 0.02 for some headroom. 
_MAX_COST_CHF_PER_ROUND = 0.02


async def runFeatureDataAgent(
    question: str,
    featureInstanceId: str,
    featureCode: str,
    selectedTables: List[Dict[str, Any]],
    mandateId: str,
    userId: str,
    aiCallFn: Callable[[AiCallRequest], Awaitable[AiCallResponse]],
    dbConnector,
    instanceLabel: str = "",
    tableFilters: Optional[Dict[str, Dict[str, str]]] = None,
    requestLang: Optional[str] = None,
    neutralizeFields: Optional[Dict[str, List[str]]] = None,
    maxRounds: Optional[int] = None,
    maxCostCHF: Optional[float] = None,
) -> str:
    """Run the feature data sub-agent and return its plain-text answer.

    Args:
        question: Question the sub-agent should answer using feature data.
        featureInstanceId: Feature instance every query is scoped to.
        featureCode: Feature code (trustee, commcoach, ...).
        selectedTables: DATA_OBJECT dicts the user selected.
        mandateId: Mandate scope.
        userId: Calling user ID.
        aiCallFn: Billing-aware AI call function.
        dbConnector: DatabaseConnector used for the actual table queries.
        instanceLabel: Human-readable instance name shown in the prompt.
        tableFilters: Per-table record filters (FeatureDataSource.recordFilter).
        requestLang: ISO 639-1 code for resolving multilingual table labels
            in the schema prompt.
        neutralizeFields: Per-table field names to mask with placeholders
            before data is returned to the AI.
        maxRounds: Round budget inherited from the parent agent
            (`maxAgentRounds` -> `AgentConfig.maxRounds`); falls back to the
            legacy 8-round default when not supplied (tests, direct callers).
        maxCostCHF: Hard cost cap for the run; when omitted it scales with
            the round budget so the per-round budget stays constant.

    Returns:
        Plain-text answer produced by the sub-agent.
    """
    provider = FeatureDataProvider(dbConnector, neutralizeFields=neutralizeFields)
    registry = _buildSubAgentTools(provider, featureInstanceId, mandateId, tableFilters or {})

    # Replace each selected table's declared field list with the columns that
    # actually exist in the database, so the schema prompt never drifts.
    for entry in selectedTables:
        meta = entry.get("meta", {})
        tableName = meta.get("table", "")
        if not tableName:
            continue
        actualColumns = provider.getActualColumns(tableName)
        if actualColumns:
            meta["fields"] = actualColumns

    systemPrompt = _buildSchemaContext(featureCode, instanceLabel, selectedTables, requestLang)

    if maxRounds and maxRounds > 0:
        roundBudget = int(maxRounds)
    else:
        roundBudget = _DEFAULT_MAX_ROUNDS
    if maxCostCHF is not None and maxCostCHF > 0:
        costBudget = float(maxCostCHF)
    else:
        # Scale with the round budget so the per-round cap stays constant.
        costBudget = roundBudget * _MAX_COST_CHF_PER_ROUND

    config = AgentConfig(
        maxRounds=roundBudget,
        maxCostCHF=costBudget,
        operationType=OperationTypeEnum.DATA_QUERY,
    )
    logger.info(
        "Feature data sub-agent starting: featureInstanceId=%s, maxRounds=%d, maxCostCHF=%.4f",
        featureInstanceId,
        roundBudget,
        costBudget,
    )

    spentCHF = 0.0

    async def _billedAiCall(req):
        # Wrap the AI call so every response's price feeds the local tally
        # that the loop's cost guard reads back via _currentCost.
        nonlocal spentCHF
        resp = await aiCallFn(req)
        spentCHF += resp.priceCHF
        return resp

    async def _currentCost() -> float:
        return spentCHF

    answer = ""
    async for event in runAgentLoop(
        prompt=question,
        toolRegistry=registry,
        config=config,
        aiCallFn=_billedAiCall,
        getWorkflowCostFn=_currentCost,
        workflowId=f"fda-{featureInstanceId[:8]}",
        userId=userId,
        featureInstanceId=featureInstanceId,
        mandateId=mandateId,
        systemPromptOverride=systemPrompt,
    ):
        if not event.content:
            continue
        # FINAL replaces any accumulated text; MESSAGE chunks are appended.
        if event.type == AgentEventTypeEnum.FINAL:
            answer = event.content
        elif event.type == AgentEventTypeEnum.MESSAGE:
            answer += event.content

    return answer or "(no data returned by feature agent)"


# ------------------------------------------------------------------
# tool registration
# ------------------------------------------------------------------
def _buildSubAgentTools(
    provider: FeatureDataProvider,
    featureInstanceId: str,
    mandateId: str,
    tableFilters: Optional[Dict[str, Dict[str, str]]] = None,
) -> ToolRegistry:
    """Register browseTable, queryTable and aggregateTable as sub-agent tools.

    Args:
        provider: Data-access layer that executes the actual table queries.
        featureInstanceId: Feature instance every query is scoped to.
        mandateId: Mandate scope forwarded to the provider.
        tableFilters: Optional per-table record filters (table name ->
            {field: value}) that are AND-ed onto every query on that table.

    Returns:
        A ToolRegistry with the three read-only data tools registered.
    """
    registry = ToolRegistry()
    _tableFilters = tableFilters or {}

    def _recordFilterToList(tableName: str) -> Optional[List[Dict[str, Any]]]:
        """Convert a recordFilter dict to a list of {field, op, value} filter dicts."""
        rf = _tableFilters.get(tableName)
        if not rf:
            return None
        return [{"field": k, "op": "=", "value": v} for k, v in rf.items()]

    def _providerResult(toolName: str, result: Dict[str, Any]) -> ToolResult:
        """Map a provider result dict onto a ToolResult.

        Shared post-processing for all three tools: serialize the payload
        (truncated to 30k chars to keep prompts bounded — the tail of an
        over-long payload is intentionally dropped) and treat a present
        "error" key as failure.
        """
        return ToolResult(
            toolCallId="",
            toolName=toolName,
            success="error" not in result,
            data=json.dumps(result, default=str, ensure_ascii=False)[:30000],
            error=result.get("error"),
        )

    async def _browseTable(args: Dict[str, Any], context: Dict[str, Any]):
        tableName = args.get("tableName", "")
        limit = args.get("limit", 50)
        offset = args.get("offset", 0)
        fields = args.get("fields")
        if not tableName:
            return ToolResult(toolCallId="", toolName="browseTable", success=False, error="tableName required")
        result = provider.browseTable(
            tableName=tableName,
            featureInstanceId=featureInstanceId,
            mandateId=mandateId,
            fields=fields,
            limit=min(limit, 200),  # hard server-side row cap
            offset=offset,
            extraFilters=_recordFilterToList(tableName),
        )
        return _providerResult("browseTable", result)

    async def _queryTable(args: Dict[str, Any], context: Dict[str, Any]):
        tableName = args.get("tableName", "")
        filters = args.get("filters", [])
        fields = args.get("fields")
        orderBy = args.get("orderBy")
        limit = args.get("limit", 50)
        offset = args.get("offset", 0)
        if not tableName:
            return ToolResult(toolCallId="", toolName="queryTable", success=False, error="tableName required")
        result = provider.queryTable(
            tableName=tableName,
            featureInstanceId=featureInstanceId,
            mandateId=mandateId,
            filters=filters,
            fields=fields,
            orderBy=orderBy,
            limit=min(limit, 200),  # hard server-side row cap
            offset=offset,
            extraFilters=_recordFilterToList(tableName),
        )
        return _providerResult("queryTable", result)

    async def _aggregateTable(args: Dict[str, Any], context: Dict[str, Any]):
        tableName = args.get("tableName", "")
        aggregate = args.get("aggregate", "")
        field = args.get("field", "")
        groupBy = args.get("groupBy")
        # Validate the three required arguments up front with specific errors.
        if not tableName:
            return ToolResult(toolCallId="", toolName="aggregateTable", success=False, error="tableName required")
        if not aggregate:
            return ToolResult(toolCallId="", toolName="aggregateTable", success=False, error="aggregate required (SUM, COUNT, AVG, MIN, MAX)")
        if not field:
            return ToolResult(toolCallId="", toolName="aggregateTable", success=False, error="field required")
        result = provider.aggregateTable(
            tableName=tableName,
            featureInstanceId=featureInstanceId,
            mandateId=mandateId,
            aggregate=aggregate,
            field=field,
            groupBy=groupBy,
            extraFilters=_recordFilterToList(tableName),
        )
        return _providerResult("aggregateTable", result)

    registry.register(
        "aggregateTable",
        _aggregateTable,
        description=(
            "Run an aggregate query on a feature data table. "
            "Supports SUM, COUNT, AVG, MIN, MAX with optional GROUP BY. "
            "Example: aggregateTable(tableName='TrusteeDataJournalLine', aggregate='SUM', field='debitAmount', groupBy='costCenter')"
        ),
        parameters={
            "type": "object",
            "properties": {
                "tableName": {"type": "string", "description": "Name of the table to aggregate"},
                "aggregate": {"type": "string", "enum": ["SUM", "COUNT", "AVG", "MIN", "MAX"], "description": "Aggregate function"},
                "field": {"type": "string", "description": "Field to aggregate (e.g. debitAmount, creditAmount)"},
                "groupBy": {"type": "string", "description": "Optional field to group by (e.g. costCenter, accountNumber)"},
            },
            "required": ["tableName", "aggregate", "field"],
        },
        readOnly=True,
    )
    registry.register(
        "browseTable",
        _browseTable,
        description="List rows from a feature data table with pagination.",
        parameters={
            "type": "object",
            "properties": {
                "tableName": {"type": "string", "description": "Name of the table to browse"},
                "fields": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Optional list of fields to return (default: all)",
                },
                "limit": {"type": "integer", "description": "Max rows to return (default 50, max 200)"},
                "offset": {"type": "integer", "description": "Row offset for pagination"},
            },
            "required": ["tableName"],
        },
        readOnly=True,
    )
    registry.register(
        "queryTable",
        _queryTable,
        description=(
            "Query a feature data table with filters, field selection, and ordering. "
            "Filters: [{\"field\": \"status\", \"op\": \"=\", \"value\": \"active\"}]. "
            "Operators: =, !=, >, <, >=, <=, LIKE, ILIKE, IS NULL, IS NOT NULL."
        ),
        parameters={
            "type": "object",
            "properties": {
                "tableName": {"type": "string", "description": "Name of the table to query"},
                "filters": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "field": {"type": "string"},
                            "op": {"type": "string"},
                            "value": {},
                        },
                    },
                    "description": "Filter conditions",
                },
                "fields": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Optional list of fields to return",
                },
                "orderBy": {"type": "string", "description": "Field name to order by"},
                "limit": {"type": "integer", "description": "Max rows (default 50, max 200)"},
                "offset": {"type": "integer", "description": "Row offset"},
            },
            "required": ["tableName"],
        },
        readOnly=True,
    )
    return registry


# ------------------------------------------------------------------
# context building
# ------------------------------------------------------------------
prompt describing available tables and query strategy. Per table the prompt now lists every selected column with its Python type, German label, description and FK target (when available via the registered Pydantic model). This gives the sub-agent enough context to: * pick the right table (e.g. period-bucketed *AccountBalance over raw JournalLine for "Saldo per "), * format date filters as UNIX timestamps when the type is float, * recognise FK relations even though the tools cannot JOIN. On top of the generic schema block we append feature-specific domain hints (e.g. Swiss KMU chart-of-accounts conventions for trustee) when the feature module exports a ``getAgentDomainHints()`` function. This lets each feature teach the sub-agent its own jargon without polluting the generic agent code. """ tableNames: List[str] = [] tableBlocks: List[str] = [] for obj in selectedTables: meta = obj.get("meta", {}) tbl = meta.get("table", "?") fields = list(meta.get("fields") or []) labelStr = resolveText(obj.get("label"), requestLang) tableNames.append(tbl) tableBlocks.append(_buildTableSchemaBlock(tbl, labelStr, fields)) header = f"You are a data query assistant for the '{featureCode}' feature" if instanceLabel: header += f' (instance: "{instanceLabel}")' header += "." tz = getRequestTimezone() now = getRequestNow() temporalLines = [ "CURRENT DATE & TIME (use this for relative time references in filters):", f" Today: {now.strftime('%Y-%m-%d (%A)')}", f" Now: {now.strftime('%H:%M')} ({tz})", " Resolve phrases like 'today', 'last month', 'Q1', 'this year' against THIS date.", " Do NOT use your training cutoff for date filters.", ] parts = [ header, "", *temporalLines, "", "AVAILABLE TABLES (use EXACTLY these names as tableName parameter):", *tableBlocks, "", f"Valid tableName values: {tableNames}", "Field names are plain column names (e.g. 'accountNumber', 'periodYear').", "", "QUERY STRATEGY:", "1. 
If unsure about columns, call browseTable(tableName) first to inspect the schema.", "2. Use queryTable with filters for targeted lookups.", "3. Use aggregateTable for SUM/COUNT/AVG/MIN/MAX with optional GROUP BY.", "4. Combine what you need into as few tool calls as possible.", "", "RULES:", "- Do NOT invent table or field names. Do NOT prefix fields with UUIDs or dots.", "- Float fields whose description mentions 'unix timestamp' (e.g. bookingDate, lastSyncAt) " "store seconds since epoch. Convert dates to a unix-seconds float before filtering " "(e.g. '2025-12-31' -> 1735603200.0); never compare such fields against ISO strings.", "- The query tools operate on ONE table at a time and CANNOT JOIN. To combine related " "tables (FK target shown in [FK -> Table.field]), query each separately and reason " "about the link in your answer.", "- When a table has period-bucketed aggregates (opening/closing balances or totals per " "period), prefer it over recomputing the same aggregate from raw transactional rows.", "- NEVER apply SUM/AVG to columns that already represent a balance, closing/opening " "total or aggregate (e.g. closingBalance, openingBalance, debitTotal, creditTotal, " "*Balance, *Total). These are already aggregated per period — summing them across " "periods produces meaningless numbers. Use queryTable with explicit period filters " "instead, then pick the single matching row.", "- CRITICAL: Return data as compact JSON, NOT as markdown tables or prose.", "- Do NOT reformat, rewrite, or narrate the tool results. Return the raw data directly.", "- If the question asks for rows, return them as a JSON array. Do NOT generate a markdown table.", "- Keep your answer SHORT. 
The caller is a machine, not a human.", ] domainHints = _loadFeatureDomainHints(featureCode) if domainHints: parts.extend(["", domainHints.strip()]) return "\n".join(parts) def _loadFeatureDomainHints(featureCode: str) -> str: """Pull optional domain-specific hints from the feature's main module. Each feature can expose ``getAgentDomainHints() -> str`` in its ``mainXxx.py``. The returned text is appended verbatim to the sub-agent's system prompt, so features can teach the agent their own domain jargon (chart-of-accounts conventions, period semantics, query patterns) without coupling the generic agent code to any one feature. Failures (missing module, missing hook, exception inside the hook) are swallowed and just yield an empty hints block — domain hints are a best-effort enhancement, not a hard dependency. """ if not featureCode: return "" try: from modules.system.registry import loadFeatureMainModules except Exception: return "" try: mainModules = loadFeatureMainModules() or {} except Exception as exc: logger.debug("Domain-hints lookup: cannot load main modules (%s)", exc) return "" module = mainModules.get(featureCode) or mainModules.get(featureCode.lower()) if module is None: return "" hook = getattr(module, "getAgentDomainHints", None) if not callable(hook): return "" try: hints = hook() except Exception as exc: logger.warning("Feature '%s' getAgentDomainHints() raised: %s", featureCode, exc) return "" if not isinstance(hints, str): return "" return hints def _buildTableSchemaBlock(tableName: str, tableLabel: str, fields: List[str]) -> str: """Render a single table's schema block, enriched from its Pydantic model. Falls back to a flat field list when the model isn't registered (e.g. pure UDB tables, or in early-startup contexts before datamodels are imported). 
""" headerLine = f' Table: {tableName} "{tableLabel}"' modelClass = MODEL_REGISTRY.get(tableName) docLine = "" if modelClass is not None: rawDoc = (modelClass.__doc__ or "").strip() if rawDoc: docLine = " Description: " + " ".join(rawDoc.split()) if not fields: return headerLine + (("\n" + docLine) if docLine else "") if modelClass is None: return headerLine + f"\n Fields: {', '.join(fields)}" fieldSet = set(fields) fieldLines: List[str] = [] for fieldName, fieldInfo in modelClass.model_fields.items(): if fieldName not in fieldSet: continue fieldLines.append(" - " + _formatFieldLine(fieldName, fieldInfo)) extras = sorted(fieldSet.difference(modelClass.model_fields.keys())) for extra in extras: fieldLines.append(f" - {extra} (unknown)") parts = [headerLine] if docLine: parts.append(docLine) parts.append(" Fields:") parts.extend(fieldLines) return "\n".join(parts) def _formatFieldLine(fieldName: str, fieldInfo: Any) -> str: """Format one field as: ' () "