gateway/modules/serviceCenter/services/serviceAgent/featureDataAgent.py
2026-04-27 08:07:37 +02:00

541 lines
22 KiB
Python

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Feature Data Sub-Agent.
Specialized mini-agent that queries feature-instance data tables. Receives
schema context (fields, descriptions) for the selected tables and has its own
tools: browseTable, queryTable, aggregateTable. Runs its own agent loop
and returns structured results back to the main agent.
Round/cost budgets are inherited from the parent agent (workspace user
setting `maxAgentRounds` -> `AgentConfig.maxRounds`) and propagated through
the tool-call context. Defaults below are only used when the sub-agent is
invoked outside an agent loop (e.g. in tests).
"""
import json
import logging
from typing import Any, Callable, Awaitable, Dict, List, Optional
from modules.datamodels.datamodelAi import (
AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum,
)
from modules.datamodels.datamodelBase import MODEL_REGISTRY
from modules.serviceCenter.services.serviceAgent.agentLoop import runAgentLoop
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
AgentConfig, AgentEvent, AgentEventTypeEnum, ToolResult,
)
from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry
from modules.serviceCenter.services.serviceAgent.featureDataProvider import FeatureDataProvider
from modules.shared.i18nRegistry import resolveText
from modules.shared.timeUtils import getRequestNow, getRequestTimezone
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Round budget applied by runFeatureDataAgent when the caller passes no
# positive `maxRounds`.
_DEFAULT_MAX_ROUNDS = 8
# Per-round CHF cap. Multiplied by `maxRounds` so the cost guard scales
# with the configured round budget instead of cutting the loop short.
# 0.15 / 8 ≈ 0.019 — round up to 0.02 for some headroom.
# NOTE(review): 0.15 is presumably the former fixed total cap in CHF — confirm.
_MAX_COST_CHF_PER_ROUND = 0.02
async def runFeatureDataAgent(
    question: str,
    featureInstanceId: str,
    featureCode: str,
    selectedTables: List[Dict[str, Any]],
    mandateId: str,
    userId: str,
    aiCallFn: Callable[[AiCallRequest], Awaitable[AiCallResponse]],
    dbConnector,
    instanceLabel: str = "",
    tableFilters: Optional[Dict[str, Dict[str, str]]] = None,
    requestLang: Optional[str] = None,
    neutralizeFields: Optional[Dict[str, List[str]]] = None,
    maxRounds: Optional[int] = None,
    maxCostCHF: Optional[float] = None,
) -> str:
    """Answer ``question`` by running the feature data sub-agent loop.

    The function wires a ``FeatureDataProvider`` into a fresh tool registry,
    builds the schema system prompt for the selected tables, drives
    ``runAgentLoop`` and collects the emitted text.

    Note: each entry of ``selectedTables`` that names a table has its
    ``meta["fields"]`` overwritten in place with the columns reported by
    ``provider.getActualColumns`` (when that call returns a non-empty value).

    Args:
        question: The user/main-agent question, passed to the loop as prompt.
        featureInstanceId: Feature instance the tools are scoped to; its first
            eight characters also form the workflow id (``fda-<prefix>``).
        featureCode: Feature code used in the system prompt (trustee, ...).
        selectedTables: DATA_OBJECT dicts describing the selected tables.
        mandateId: Mandate scope forwarded to the tools and the loop.
        userId: Calling user ID forwarded to the loop.
        aiCallFn: Awaitable AI call function; every response's ``priceCHF``
            is added to the running cost reported to the loop.
        dbConnector: Connector handed to ``FeatureDataProvider``.
        instanceLabel: Human-readable instance name for the system prompt.
        tableFilters: Per-table record filters applied to every tool call.
        requestLang: Language code used to resolve table labels.
        neutralizeFields: Per-table field names handed to the provider for
            masking.
        maxRounds: Round budget; non-positive or missing values fall back to
            ``_DEFAULT_MAX_ROUNDS``.
        maxCostCHF: Cost cap; non-positive or missing values fall back to
            ``rounds * _MAX_COST_CHF_PER_ROUND``.

    Returns:
        The content of the last FINAL event (plus any MESSAGE content that
        arrived after it), or a fixed placeholder when nothing was produced.
    """
    dataProvider = FeatureDataProvider(dbConnector, neutralizeFields=neutralizeFields)
    toolRegistry = _buildSubAgentTools(
        dataProvider, featureInstanceId, mandateId, tableFilters or {}
    )

    # Prefer the columns the provider actually reports over the declared ones,
    # so the schema prompt below only advertises queryable fields.
    for tableObj in selectedTables:
        tableMeta = tableObj.get("meta", {})
        name = tableMeta.get("table", "")
        if not name:
            continue
        actualColumns = dataProvider.getActualColumns(name)
        if actualColumns:
            tableMeta["fields"] = actualColumns

    schemaPrompt = _buildSchemaContext(featureCode, instanceLabel, selectedTables, requestLang)

    # Resolve the round budget first; the default cost cap is derived from it.
    if maxRounds and maxRounds > 0:
        roundBudget = int(maxRounds)
    else:
        roundBudget = _DEFAULT_MAX_ROUNDS

    if maxCostCHF is not None and maxCostCHF > 0:
        costBudget = float(maxCostCHF)
    else:
        costBudget = roundBudget * _MAX_COST_CHF_PER_ROUND

    agentConfig = AgentConfig(
        maxRounds=roundBudget,
        maxCostCHF=costBudget,
        operationType=OperationTypeEnum.DATA_QUERY,
    )
    logger.info(
        "Feature data sub-agent starting: featureInstanceId=%s, maxRounds=%d, maxCostCHF=%.4f",
        featureInstanceId, roundBudget, costBudget,
    )

    # Single-slot list used as a mutable cell shared by the two closures below.
    spentCHF = [0.0]

    async def _billedAiCall(req):
        # Forward the call unchanged and book its price on the running total.
        resp = await aiCallFn(req)
        spentCHF[0] += resp.priceCHF
        return resp

    async def _currentCost() -> float:
        return spentCHF[0]

    answer = ""
    eventStream = runAgentLoop(
        prompt=question,
        toolRegistry=toolRegistry,
        config=agentConfig,
        aiCallFn=_billedAiCall,
        getWorkflowCostFn=_currentCost,
        workflowId=f"fda-{featureInstanceId[:8]}",
        userId=userId,
        featureInstanceId=featureInstanceId,
        mandateId=mandateId,
        systemPromptOverride=schemaPrompt,
    )
    async for event in eventStream:
        eventKind = event.type
        if eventKind == AgentEventTypeEnum.FINAL:
            # A non-empty FINAL replaces everything gathered so far.
            if event.content:
                answer = event.content
        elif eventKind == AgentEventTypeEnum.MESSAGE:
            # MESSAGE chunks are concatenated onto the current answer.
            if event.content:
                answer += event.content

    return answer or "(no data returned by feature agent)"
# ------------------------------------------------------------------
# tool registration
# ------------------------------------------------------------------
def _buildSubAgentTools(
    provider: FeatureDataProvider,
    featureInstanceId: str,
    mandateId: str,
    tableFilters: Optional[Dict[str, Dict[str, str]]] = None,
) -> ToolRegistry:
    """Register browseTable, queryTable and aggregateTable as sub-agent tools.

    Every tool handler is a closure over ``provider``, ``featureInstanceId``
    and ``mandateId``, and always appends the per-table record filter from
    ``tableFilters`` as ``extraFilters``.

    Tool arguments come from the language model and are therefore not
    trusted to match the JSON schema: ``limit`` and ``offset`` are coerced to
    non-negative integers (``limit`` additionally capped at 200) so that a
    ``null`` or string value no longer raises ``TypeError`` inside ``min()``.

    Args:
        provider: Data provider executing the actual table operations.
        featureInstanceId: Feature instance every query is scoped to.
        mandateId: Mandate every query is scoped to.
        tableFilters: Optional mapping ``tableName -> {field: value}`` that is
            turned into equality filters for that table.

    Returns:
        A ``ToolRegistry`` with the three read-only tools registered.
    """
    registry = ToolRegistry()
    _tableFilters = tableFilters or {}

    defaultRowLimit = 50
    maxRowLimit = 200
    # Serialized tool payloads are cut at this many characters.
    maxPayloadChars = 30000

    def _recordFilterToList(tableName: str) -> Optional[List[Dict[str, Any]]]:
        """Convert a recordFilter dict to a list of {field, op, value} filter dicts."""
        rf = _tableFilters.get(tableName)
        if not rf:
            return None
        return [{"field": k, "op": "=", "value": v} for k, v in rf.items()]

    def _boundedInt(value: Any, default: int, upper: Optional[int] = None) -> int:
        """Coerce a model-supplied number to an int in [0, upper]; bad input yields ``default``."""
        try:
            number = int(value)
        except (TypeError, ValueError, OverflowError):
            number = default
        number = max(number, 0)
        return number if upper is None else min(number, upper)

    def _failure(toolName: str, message: str):
        """Build the ToolResult for a rejected call (missing required argument)."""
        return ToolResult(toolCallId="", toolName=toolName, success=False, error=message)

    def _wrapResult(toolName: str, result: Dict[str, Any]):
        """Wrap a provider result dict; an 'error' key marks the call as failed."""
        return ToolResult(
            toolCallId="", toolName=toolName,
            success="error" not in result,
            data=json.dumps(result, default=str, ensure_ascii=False)[:maxPayloadChars],
            error=result.get("error"),
        )

    async def _browseTable(args: Dict[str, Any], context: Dict[str, Any]):
        tableName = args.get("tableName", "")
        if not tableName:
            return _failure("browseTable", "tableName required")
        result = provider.browseTable(
            tableName=tableName,
            featureInstanceId=featureInstanceId,
            mandateId=mandateId,
            fields=args.get("fields"),
            limit=_boundedInt(args.get("limit", defaultRowLimit), defaultRowLimit, maxRowLimit),
            offset=_boundedInt(args.get("offset", 0), 0),
            extraFilters=_recordFilterToList(tableName),
        )
        return _wrapResult("browseTable", result)

    async def _queryTable(args: Dict[str, Any], context: Dict[str, Any]):
        tableName = args.get("tableName", "")
        if not tableName:
            return _failure("queryTable", "tableName required")
        result = provider.queryTable(
            tableName=tableName,
            featureInstanceId=featureInstanceId,
            mandateId=mandateId,
            # An explicit null from the model is treated like "no filters".
            filters=args.get("filters") or [],
            fields=args.get("fields"),
            orderBy=args.get("orderBy"),
            limit=_boundedInt(args.get("limit", defaultRowLimit), defaultRowLimit, maxRowLimit),
            offset=_boundedInt(args.get("offset", 0), 0),
            extraFilters=_recordFilterToList(tableName),
        )
        return _wrapResult("queryTable", result)

    async def _aggregateTable(args: Dict[str, Any], context: Dict[str, Any]):
        tableName = args.get("tableName", "")
        aggregate = args.get("aggregate", "")
        field = args.get("field", "")
        if not tableName:
            return _failure("aggregateTable", "tableName required")
        if not aggregate:
            return _failure("aggregateTable", "aggregate required (SUM, COUNT, AVG, MIN, MAX)")
        if not field:
            return _failure("aggregateTable", "field required")
        result = provider.aggregateTable(
            tableName=tableName,
            featureInstanceId=featureInstanceId,
            mandateId=mandateId,
            aggregate=aggregate,
            field=field,
            groupBy=args.get("groupBy"),
            extraFilters=_recordFilterToList(tableName),
        )
        return _wrapResult("aggregateTable", result)

    registry.register(
        "aggregateTable", _aggregateTable,
        description=(
            "Run an aggregate query on a feature data table. "
            "Supports SUM, COUNT, AVG, MIN, MAX with optional GROUP BY. "
            "Example: aggregateTable(tableName='TrusteeDataJournalLine', aggregate='SUM', field='debitAmount', groupBy='costCenter')"
        ),
        parameters={
            "type": "object",
            "properties": {
                "tableName": {"type": "string", "description": "Name of the table to aggregate"},
                "aggregate": {"type": "string", "enum": ["SUM", "COUNT", "AVG", "MIN", "MAX"], "description": "Aggregate function"},
                "field": {"type": "string", "description": "Field to aggregate (e.g. debitAmount, creditAmount)"},
                "groupBy": {"type": "string", "description": "Optional field to group by (e.g. costCenter, accountNumber)"},
            },
            "required": ["tableName", "aggregate", "field"],
        },
        readOnly=True,
    )
    registry.register(
        "browseTable", _browseTable,
        description="List rows from a feature data table with pagination.",
        parameters={
            "type": "object",
            "properties": {
                "tableName": {"type": "string", "description": "Name of the table to browse"},
                "fields": {
                    "type": "array", "items": {"type": "string"},
                    "description": "Optional list of fields to return (default: all)",
                },
                "limit": {"type": "integer", "description": "Max rows to return (default 50, max 200)"},
                "offset": {"type": "integer", "description": "Row offset for pagination"},
            },
            "required": ["tableName"],
        },
        readOnly=True,
    )
    registry.register(
        "queryTable", _queryTable,
        description=(
            "Query a feature data table with filters, field selection, and ordering. "
            "Filters: [{\"field\": \"status\", \"op\": \"=\", \"value\": \"active\"}]. "
            "Operators: =, !=, >, <, >=, <=, LIKE, ILIKE, IS NULL, IS NOT NULL."
        ),
        parameters={
            "type": "object",
            "properties": {
                "tableName": {"type": "string", "description": "Name of the table to query"},
                "filters": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "field": {"type": "string"},
                            "op": {"type": "string"},
                            "value": {},
                        },
                    },
                    "description": "Filter conditions",
                },
                "fields": {
                    "type": "array", "items": {"type": "string"},
                    "description": "Optional list of fields to return",
                },
                "orderBy": {"type": "string", "description": "Field name to order by"},
                "limit": {"type": "integer", "description": "Max rows (default 50, max 200)"},
                "offset": {"type": "integer", "description": "Row offset"},
            },
            "required": ["tableName"],
        },
        readOnly=True,
    )
    return registry
# ------------------------------------------------------------------
# context building
# ------------------------------------------------------------------
def _buildSchemaContext(
    featureCode: str,
    instanceLabel: str,
    selectedTables: List[Dict[str, Any]],
    requestLang: Optional[str] = None,
) -> str:
    """Build the system prompt describing available tables and query strategy.

    The prompt consists of, in order:

    * a header naming the feature (and the instance label, when given),
    * the current request date/time and timezone, so relative phrases such
      as "last month" are resolved against the real date,
    * one schema block per selected table (see ``_buildTableSchemaBlock``),
    * the list of valid table names, a query strategy and a fixed rule set,
    * optional feature-specific domain hints returned by
      ``_loadFeatureDomainHints`` (empty hints add nothing).

    Args:
        featureCode: Feature code shown in the header and used to look up
            domain hints.
        instanceLabel: Human-readable instance name; omitted from the header
            when empty.
        selectedTables: Table descriptor dicts. ``meta.table`` supplies the
            table name ("?" when absent), ``meta.fields`` the column list and
            ``label`` the table label. A missing or ``None`` ``meta`` is
            treated as an empty dict.
        requestLang: Language code passed to ``resolveText`` for the label.

    Returns:
        The complete prompt as a single newline-joined string.
    """
    tableNames: List[str] = []
    tableBlocks: List[str] = []
    for obj in selectedTables:
        # `or {}` also covers an explicit `"meta": None`, which `.get("meta", {})`
        # would hand through and crash on the next line.
        meta = obj.get("meta") or {}
        tbl = meta.get("table", "?")
        fields = list(meta.get("fields") or [])
        labelStr = resolveText(obj.get("label"), requestLang)
        tableNames.append(tbl)
        tableBlocks.append(_buildTableSchemaBlock(tbl, labelStr, fields))

    header = f"You are a data query assistant for the '{featureCode}' feature"
    if instanceLabel:
        header += f' (instance: "{instanceLabel}")'
    header += "."

    tz = getRequestTimezone()
    now = getRequestNow()
    temporalLines = [
        "CURRENT DATE & TIME (use this for relative time references in filters):",
        f" Today: {now.strftime('%Y-%m-%d (%A)')}",
        f" Now: {now.strftime('%H:%M')} ({tz})",
        " Resolve phrases like 'today', 'last month', 'Q1', 'this year' against THIS date.",
        " Do NOT use your training cutoff for date filters.",
    ]

    parts = [
        header,
        "",
        *temporalLines,
        "",
        "AVAILABLE TABLES (use EXACTLY these names as tableName parameter):",
        *tableBlocks,
        "",
        f"Valid tableName values: {tableNames}",
        "Field names are plain column names (e.g. 'accountNumber', 'periodYear').",
        "",
        "QUERY STRATEGY:",
        "1. If unsure about columns, call browseTable(tableName) first to inspect the schema.",
        "2. Use queryTable with filters for targeted lookups.",
        "3. Use aggregateTable for SUM/COUNT/AVG/MIN/MAX with optional GROUP BY.",
        "4. Combine what you need into as few tool calls as possible.",
        "",
        "RULES:",
        "- Do NOT invent table or field names. Do NOT prefix fields with UUIDs or dots.",
        # The worked example must be numerically right, because the model copies
        # the pattern: 2025-12-31 00:00:00 UTC is 1767139200. The value used
        # previously, 1735603200, is 2024-12-31 (exactly one year too early).
        "- Float fields whose description mentions 'unix timestamp' (e.g. bookingDate, lastSyncAt) "
        "store seconds since epoch. Convert dates to a unix-seconds float before filtering "
        "(e.g. '2025-12-31' -> 1767139200.0); never compare such fields against ISO strings.",
        "- The query tools operate on ONE table at a time and CANNOT JOIN. To combine related "
        "tables (FK target shown in [FK -> Table.field]), query each separately and reason "
        "about the link in your answer.",
        "- When a table has period-bucketed aggregates (opening/closing balances or totals per "
        "period), prefer it over recomputing the same aggregate from raw transactional rows.",
        "- NEVER apply SUM/AVG to columns that already represent a balance, closing/opening "
        "total or aggregate (e.g. closingBalance, openingBalance, debitTotal, creditTotal, "
        "*Balance, *Total). These are already aggregated per period — summing them across "
        "periods produces meaningless numbers. Use queryTable with explicit period filters "
        "instead, then pick the single matching row.",
        "- CRITICAL: Return data as compact JSON, NOT as markdown tables or prose.",
        "- Do NOT reformat, rewrite, or narrate the tool results. Return the raw data directly.",
        "- If the question asks for rows, return them as a JSON array. Do NOT generate a markdown table.",
        "- Keep your answer SHORT. The caller is a machine, not a human.",
    ]

    domainHints = _loadFeatureDomainHints(featureCode)
    if domainHints:
        parts.extend(["", domainHints.strip()])
    return "\n".join(parts)
def _loadFeatureDomainHints(featureCode: str) -> str:
    """Return the optional domain-hints text exported by a feature module.

    The feature's main module is looked up via ``loadFeatureMainModules()``
    under ``featureCode`` and, failing that, its lower-cased form. When that
    module exposes a callable ``getAgentDomainHints``, its string result is
    returned unchanged.

    The lookup is strictly best-effort: an empty ``featureCode``, a failing
    registry import, a failing module load (logged at debug level), an
    unknown feature, a missing or non-callable hook, a raising hook (logged
    as a warning) or a non-string result all produce ``""``.
    """
    if not featureCode:
        return ""

    # Imported lazily; any failure here simply disables domain hints.
    try:
        from modules.system.registry import loadFeatureMainModules
    except Exception:
        return ""

    try:
        registeredModules = loadFeatureMainModules() or {}
    except Exception as exc:
        logger.debug("Domain-hints lookup: cannot load main modules (%s)", exc)
        return ""

    featureModule = (
        registeredModules.get(featureCode)
        or registeredModules.get(featureCode.lower())
    )
    # getattr on None yields the default, so an unknown feature falls out
    # through the same "not callable" exit as a module without the hook.
    hintsHook = getattr(featureModule, "getAgentDomainHints", None)
    if not callable(hintsHook):
        return ""

    try:
        hintsText = hintsHook()
    except Exception as exc:
        logger.warning("Feature '%s' getAgentDomainHints() raised: %s", featureCode, exc)
        return ""

    return hintsText if isinstance(hintsText, str) else ""
def _buildTableSchemaBlock(tableName: str, tableLabel: str, fields: List[str]) -> str:
    """Render the prompt block for one table.

    The block always starts with the table header line. When a model class
    is registered for ``tableName`` in ``MODEL_REGISTRY`` and has a
    docstring, a whitespace-normalized description line follows.

    Field rendering depends on what is available:

    * no ``fields``: header (and description) only;
    * ``fields`` but no registered model: one flat, comma-separated line;
    * ``fields`` and a model: one line per selected model field, in model
      declaration order, followed by the selected names the model does not
      declare, sorted and marked ``(unknown)``.
    """
    lines: List[str] = [f' Table: {tableName} "{tableLabel}"']

    model = MODEL_REGISTRY.get(tableName)
    if model is not None:
        docText = (model.__doc__ or "").strip()
        if docText:
            # Collapse newlines and indentation inside the docstring to single spaces.
            lines.append(" Description: " + " ".join(docText.split()))

    if not fields:
        return "\n".join(lines)

    if model is None:
        lines.append(f" Fields: {', '.join(fields)}")
        return "\n".join(lines)

    selected = set(fields)
    declared = model.model_fields

    lines.append(" Fields:")
    lines.extend(
        " - " + _formatFieldLine(name, info)
        for name, info in declared.items()
        if name in selected
    )
    lines.extend(
        f" - {name} (unknown)"
        for name in sorted(selected.difference(declared.keys()))
    )
    return "\n".join(lines)
def _formatFieldLine(fieldName: str, fieldInfo: Any) -> str:
    """Render one field as '<name> (<type>) "<label>": <description> [FK -> Table.field]'.

    Only the name and type are always present. The label is added when
    ``json_schema_extra["label"]`` is a non-empty string different from the
    field name, the description when it is a non-blank string, and the FK
    suffix when ``json_schema_extra["fk_target"]`` is a dict with a truthy
    ``table`` (its ``targetField`` defaults to ``id``).
    """
    typeName = _summarizePythonType(getattr(fieldInfo, "annotation", None))

    schemaExtra = getattr(fieldInfo, "json_schema_extra", None)
    if not isinstance(schemaExtra, dict):
        # Anything that is not a plain dict is treated as "no extras".
        schemaExtra = {}

    segments = [f"{fieldName} ({typeName})"]

    labelValue = schemaExtra.get("label")
    if isinstance(labelValue, str) and labelValue and labelValue != fieldName:
        segments.append(f' "{labelValue}"')

    descriptionValue = getattr(fieldInfo, "description", None)
    if isinstance(descriptionValue, str):
        trimmed = descriptionValue.strip()
        if trimmed:
            segments.append(f": {trimmed}")

    foreignKey = schemaExtra.get("fk_target")
    if isinstance(foreignKey, dict) and foreignKey.get("table"):
        targetTable = foreignKey["table"]
        targetColumn = foreignKey.get("targetField") or "id"
        segments.append(f" [FK -> {targetTable}.{targetColumn}]")

    return "".join(segments)
def _summarizePythonType(annotation: Any) -> str:
    """Return a short textual form of a field annotation for use in AI prompts.

    ``None`` becomes ``"any"``. Otherwise the annotation's ``str()`` form is
    used with every ``typing.`` occurrence removed, and the ``<class '...'>``
    wrapper of plain classes is stripped so only the dotted name remains.
    """
    if annotation is None:
        return "any"

    text = str(annotation).replace("typing.", "")

    classPrefix, classSuffix = "<class '", "'>"
    if text.startswith(classPrefix) and text.endswith(classSuffix):
        return text[len(classPrefix):-len(classSuffix)]
    return text