480 lines
19 KiB
Python
480 lines
19 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""Feature Data Sub-Agent.
|
|
|
|
Specialized mini-agent that queries feature-instance data tables. Receives
|
|
schema context (fields, descriptions) for the selected tables and has its
|
|
tools: browseTable, queryTable, aggregateTable. Runs its own agent loop
|
|
and returns structured results back to the main agent.
|
|
|
|
Round/cost budgets are inherited from the parent agent (workspace user
|
|
setting `maxAgentRounds` -> `AgentConfig.maxRounds`) and propagated through
|
|
the tool-call context. Defaults below are only used when the sub-agent is
|
|
invoked outside an agent loop (e.g. in tests).
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
from typing import Any, Callable, Awaitable, Dict, List, Optional
|
|
|
|
from modules.datamodels.datamodelAi import (
|
|
AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum,
|
|
)
|
|
from modules.datamodels.datamodelBase import MODEL_REGISTRY
|
|
from modules.serviceCenter.services.serviceAgent.agentLoop import runAgentLoop
|
|
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
|
|
AgentConfig, AgentEvent, AgentEventTypeEnum, ToolResult,
|
|
)
|
|
from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry
|
|
from modules.serviceCenter.services.serviceAgent.featureDataProvider import FeatureDataProvider
|
|
from modules.shared.i18nRegistry import resolveText
|
|
from modules.shared.timeUtils import getRequestNow, getRequestTimezone
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_DEFAULT_MAX_ROUNDS = 8
|
|
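# Per-round CHF cap. It is multiplied by the effective `maxRounds`, so the
# cost guard scales with the configured round budget instead of cutting the
# loop short.
# Derivation: 0.15 CHF (presumably the intended total budget for the default
# 8 rounds) / 8 ≈ 0.019 CHF per round, rounded up to 0.02 for some headroom.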
|
|
_MAX_COST_CHF_PER_ROUND = 0.02
|
|
|
|
|
|
async def runFeatureDataAgent(
    question: str,
    featureInstanceId: str,
    featureCode: str,
    selectedTables: List[Dict[str, Any]],
    mandateId: str,
    userId: str,
    aiCallFn: Callable[[AiCallRequest], Awaitable[AiCallResponse]],
    dbConnector,
    instanceLabel: str = "",
    tableFilters: Optional[Dict[str, Dict[str, str]]] = None,
    requestLang: Optional[str] = None,
    neutralizeFields: Optional[Dict[str, List[str]]] = None,
    maxRounds: Optional[int] = None,
    maxCostCHF: Optional[float] = None,
) -> str:
    """Run the feature data sub-agent and return the textual result.

    The function builds a tool registry (browseTable / queryTable /
    aggregateTable), renders a schema-aware system prompt for the selected
    tables, then drives ``runAgentLoop`` and collects its output.

    Args:
        question: The user/main-agent question to answer using feature data.
        featureInstanceId: Feature instance to scope queries.
        featureCode: Feature code (trustee, commcoach, ...).
        selectedTables: List of DATA_OBJECT dicts the user selected. The
            dicts are NOT modified; column enrichment happens on copies.
        mandateId: Mandate scope.
        userId: Calling user ID.
        aiCallFn: AI call function (with billing).
        dbConnector: DatabaseConnector for queries.
        instanceLabel: Human-readable instance name for context.
        tableFilters: Per-table record filters from FeatureDataSource.recordFilter.
        requestLang: ISO 639-1 code for resolving multilingual table labels
            in the schema prompt.
        neutralizeFields: Per-table list of field names to mask with
            placeholders before returning data to the AI.
        maxRounds: Inherited from the parent agent's configured `maxRounds`
            (workspace user setting `maxAgentRounds` -> `AgentConfig.maxRounds`).
            Falls back to `_DEFAULT_MAX_ROUNDS` when missing or not positive
            so direct callers / tests still work.
        maxCostCHF: Hard cost cap for the sub-agent run. When missing or not
            positive, it is derived as
            `effectiveMaxRounds * _MAX_COST_CHF_PER_ROUND` to keep the
            per-round budget constant.

    Returns:
        Plain-text answer produced by the sub-agent: the content of the last
        FINAL event, otherwise the concatenated MESSAGE contents, otherwise
        a fixed "(no data returned by feature agent)" placeholder.
    """
    provider = FeatureDataProvider(dbConnector, neutralizeFields=neutralizeFields)
    registry = _buildSubAgentTools(provider, featureInstanceId, mandateId, tableFilters or {})

    # Replace the configured field list with the columns that really exist in
    # the table. This is done on shallow copies so the caller's
    # `selectedTables` (and their `meta` dicts) are never mutated as a side
    # effect of running the sub-agent.
    promptTables: List[Dict[str, Any]] = []
    for tbl in selectedTables:
        meta = dict(tbl.get("meta") or {})  # tolerate a missing or None "meta"
        tableName = meta.get("table", "")
        if tableName:
            realCols = provider.getActualColumns(tableName)
            if realCols:
                meta["fields"] = realCols
        promptTables.append({**tbl, "meta": meta})

    systemPrompt = _buildSchemaContext(featureCode, instanceLabel, promptTables, requestLang)

    effectiveMaxRounds = int(maxRounds) if maxRounds and maxRounds > 0 else _DEFAULT_MAX_ROUNDS
    if maxCostCHF is not None and maxCostCHF > 0:
        effectiveMaxCost = float(maxCostCHF)
    else:
        # No explicit cap: scale the budget with the number of rounds.
        effectiveMaxCost = effectiveMaxRounds * _MAX_COST_CHF_PER_ROUND

    config = AgentConfig(
        maxRounds=effectiveMaxRounds,
        maxCostCHF=effectiveMaxCost,
        operationType=OperationTypeEnum.DATA_QUERY,
    )
    logger.info(
        "Feature data sub-agent starting: featureInstanceId=%s, maxRounds=%d, maxCostCHF=%.4f",
        featureInstanceId, effectiveMaxRounds, effectiveMaxCost,
    )

    costAccumulator = 0.0

    async def _trackingAiCallFn(req):
        """Forward to `aiCallFn` and add the call's price to the running total."""
        nonlocal costAccumulator
        resp = await aiCallFn(req)
        # A missing or None price must not break the loop; count it as 0.
        costAccumulator += getattr(resp, "priceCHF", 0.0) or 0.0
        return resp

    async def _getWorkflowCost() -> float:
        """Report the cost accumulated by this sub-agent run so far."""
        return costAccumulator

    result = ""
    async for event in runAgentLoop(
        prompt=question,
        toolRegistry=registry,
        config=config,
        aiCallFn=_trackingAiCallFn,
        getWorkflowCostFn=_getWorkflowCost,
        workflowId=f"fda-{featureInstanceId[:8]}",
        userId=userId,
        featureInstanceId=featureInstanceId,
        mandateId=mandateId,
        systemPromptOverride=systemPrompt,
    ):
        if event.type == AgentEventTypeEnum.FINAL and event.content:
            # A FINAL event replaces anything streamed before it.
            result = event.content
        elif event.type == AgentEventTypeEnum.MESSAGE and event.content:
            result += event.content

    return result or "(no data returned by feature agent)"
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# tool registration
|
|
# ------------------------------------------------------------------
|
|
|
|
def _buildSubAgentTools(
    provider: FeatureDataProvider,
    featureInstanceId: str,
    mandateId: str,
    tableFilters: Optional[Dict[str, Dict[str, str]]] = None,
) -> ToolRegistry:
    """Register browseTable, queryTable and aggregateTable as sub-agent tools.

    Every tool is bound to the given feature instance / mandate and always
    applies the per-table record filter from `tableFilters` (if any) as
    `extraFilters`, in addition to whatever the model asks for.

    Args:
        provider: Data provider that executes the table operations.
        featureInstanceId: Feature instance passed to every provider call.
        mandateId: Mandate passed to every provider call.
        tableFilters: Optional mapping tableName -> {field: value}; each
            entry becomes an equality filter on that table.

    Returns:
        A ToolRegistry with the three read-only tools registered.
    """
    registry = ToolRegistry()
    _tableFilters = tableFilters or {}

    defaultRowLimit = 50
    maxRowLimit = 200
    maxResultChars = 30000  # serialized tool output is cut off at this length

    def _recordFilterToList(tableName: str) -> Optional[List[Dict[str, Any]]]:
        """Convert a recordFilter dict to a list of {field, op, value} filter dicts."""
        rf = _tableFilters.get(tableName)
        if not rf:
            return None
        return [{"field": k, "op": "=", "value": v} for k, v in rf.items()]

    def _intArg(args: Dict[str, Any], key: str, default: int, upper: Optional[int] = None) -> int:
        """Read a non-negative integer tool argument supplied by the model.

        Tool arguments are model-generated, so the value may be missing,
        null, a numeric string or negative. Anything that cannot be read as
        a non-negative integer falls back to `default`; `upper` caps it.
        """
        try:
            value = int(args.get(key))
        except (TypeError, ValueError, OverflowError):
            return default
        if value < 0:
            return default
        return value if upper is None else min(value, upper)

    def _failure(toolName: str, message: str) -> ToolResult:
        """Build the ToolResult for a rejected call (argument validation)."""
        return ToolResult(toolCallId="", toolName=toolName, success=False, error=message)

    def _toToolResult(toolName: str, result: Dict[str, Any]) -> ToolResult:
        """Wrap a provider result dict; an "error" key marks the call as failed."""
        return ToolResult(
            toolCallId="", toolName=toolName,
            success="error" not in result,
            data=json.dumps(result, default=str, ensure_ascii=False)[:maxResultChars],
            error=result.get("error"),
        )

    async def _browseTable(args: Dict[str, Any], context: Dict[str, Any]):
        tableName = args.get("tableName", "")
        if not tableName:
            return _failure("browseTable", "tableName required")
        result = provider.browseTable(
            tableName=tableName,
            featureInstanceId=featureInstanceId,
            mandateId=mandateId,
            fields=args.get("fields"),
            limit=_intArg(args, "limit", defaultRowLimit, maxRowLimit),
            offset=_intArg(args, "offset", 0),
            extraFilters=_recordFilterToList(tableName),
        )
        return _toToolResult("browseTable", result)

    async def _queryTable(args: Dict[str, Any], context: Dict[str, Any]):
        tableName = args.get("tableName", "")
        if not tableName:
            return _failure("queryTable", "tableName required")
        result = provider.queryTable(
            tableName=tableName,
            featureInstanceId=featureInstanceId,
            mandateId=mandateId,
            filters=args.get("filters") or [],  # an explicit null means "no filters"
            fields=args.get("fields"),
            orderBy=args.get("orderBy"),
            limit=_intArg(args, "limit", defaultRowLimit, maxRowLimit),
            offset=_intArg(args, "offset", 0),
            extraFilters=_recordFilterToList(tableName),
        )
        return _toToolResult("queryTable", result)

    async def _aggregateTable(args: Dict[str, Any], context: Dict[str, Any]):
        tableName = args.get("tableName", "")
        aggregate = args.get("aggregate", "")
        field = args.get("field", "")
        if not tableName:
            return _failure("aggregateTable", "tableName required")
        if not aggregate:
            return _failure("aggregateTable", "aggregate required (SUM, COUNT, AVG, MIN, MAX)")
        if not field:
            return _failure("aggregateTable", "field required")
        result = provider.aggregateTable(
            tableName=tableName,
            featureInstanceId=featureInstanceId,
            mandateId=mandateId,
            aggregate=aggregate,
            field=field,
            groupBy=args.get("groupBy"),
            extraFilters=_recordFilterToList(tableName),
        )
        return _toToolResult("aggregateTable", result)

    registry.register(
        "aggregateTable", _aggregateTable,
        description=(
            "Run an aggregate query on a feature data table. "
            "Supports SUM, COUNT, AVG, MIN, MAX with optional GROUP BY. "
            "Example: aggregateTable(tableName='TrusteeDataJournalLine', aggregate='SUM', field='debitAmount', groupBy='costCenter')"
        ),
        parameters={
            "type": "object",
            "properties": {
                "tableName": {"type": "string", "description": "Name of the table to aggregate"},
                "aggregate": {"type": "string", "enum": ["SUM", "COUNT", "AVG", "MIN", "MAX"], "description": "Aggregate function"},
                "field": {"type": "string", "description": "Field to aggregate (e.g. debitAmount, creditAmount)"},
                "groupBy": {"type": "string", "description": "Optional field to group by (e.g. costCenter, accountNumber)"},
            },
            "required": ["tableName", "aggregate", "field"],
        },
        readOnly=True,
    )

    registry.register(
        "browseTable", _browseTable,
        description="List rows from a feature data table with pagination.",
        parameters={
            "type": "object",
            "properties": {
                "tableName": {"type": "string", "description": "Name of the table to browse"},
                "fields": {
                    "type": "array", "items": {"type": "string"},
                    "description": "Optional list of fields to return (default: all)",
                },
                "limit": {"type": "integer", "description": "Max rows to return (default 50, max 200)"},
                "offset": {"type": "integer", "description": "Row offset for pagination"},
            },
            "required": ["tableName"],
        },
        readOnly=True,
    )

    registry.register(
        "queryTable", _queryTable,
        description=(
            "Query a feature data table with filters, field selection, and ordering. "
            "Filters: [{\"field\": \"status\", \"op\": \"=\", \"value\": \"active\"}]. "
            "Operators: =, !=, >, <, >=, <=, LIKE, ILIKE, IS NULL, IS NOT NULL."
        ),
        parameters={
            "type": "object",
            "properties": {
                "tableName": {"type": "string", "description": "Name of the table to query"},
                "filters": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "field": {"type": "string"},
                            "op": {"type": "string"},
                            "value": {},
                        },
                    },
                    "description": "Filter conditions",
                },
                "fields": {
                    "type": "array", "items": {"type": "string"},
                    "description": "Optional list of fields to return",
                },
                "orderBy": {"type": "string", "description": "Field name to order by"},
                "limit": {"type": "integer", "description": "Max rows (default 50, max 200)"},
                "offset": {"type": "integer", "description": "Row offset"},
            },
            "required": ["tableName"],
        },
        readOnly=True,
    )

    return registry
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# context building
|
|
# ------------------------------------------------------------------
|
|
|
|
def _buildSchemaContext(
    featureCode: str,
    instanceLabel: str,
    selectedTables: List[Dict[str, Any]],
    requestLang: Optional[str] = None,
) -> str:
    """Build a system prompt describing available tables and query strategy.

    Per table the prompt contains the schema block produced by
    `_buildTableSchemaBlock` (table name, resolved label and the selected
    fields). On top of that it states the current request date/time, the
    list of valid table names, a query strategy and the answer rules. The
    intent is to give the sub-agent enough context to:
      * pick the right table (e.g. period-bucketed *AccountBalance over raw
        JournalLine for "Saldo per <date>"),
      * format date filters as UNIX timestamps when the type is float,
      * recognise FK relations even though the tools cannot JOIN.

    Args:
        featureCode: Feature code shown in the prompt header.
        instanceLabel: Optional instance name appended to the header.
        selectedTables: DATA_OBJECT dicts; `meta.table` and `meta.fields`
            are read, `label` is resolved via `resolveText`.
        requestLang: Language code passed to `resolveText` for the label.

    Returns:
        The complete system prompt as one newline-joined string.
    """
    tableNames: List[str] = []
    tableBlocks: List[str] = []

    for obj in selectedTables:
        meta = obj.get("meta") or {}  # tolerate a missing or None "meta"
        tbl = meta.get("table", "?")
        fields = list(meta.get("fields") or [])
        labelStr = resolveText(obj.get("label"), requestLang)
        tableNames.append(tbl)
        tableBlocks.append(_buildTableSchemaBlock(tbl, labelStr, fields))

    header = f"You are a data query assistant for the '{featureCode}' feature"
    if instanceLabel:
        header += f' (instance: "{instanceLabel}")'
    header += "."

    tz = getRequestTimezone()
    now = getRequestNow()
    temporalLines = [
        "CURRENT DATE & TIME (use this for relative time references in filters):",
        f" Today: {now.strftime('%Y-%m-%d (%A)')}",
        f" Now: {now.strftime('%H:%M')} ({tz})",
        " Resolve phrases like 'today', 'last month', 'Q1', 'this year' against THIS date.",
        " Do NOT use your training cutoff for date filters.",
    ]

    parts = [
        header,
        "",
        *temporalLines,
        "",
        "AVAILABLE TABLES (use EXACTLY these names as tableName parameter):",
        *tableBlocks,
        "",
        f"Valid tableName values: {tableNames}",
        "Field names are plain column names (e.g. 'accountNumber', 'periodYear').",
        "",
        "QUERY STRATEGY:",
        "1. If unsure about columns, call browseTable(tableName) first to inspect the schema.",
        "2. Use queryTable with filters for targeted lookups.",
        "3. Use aggregateTable for SUM/COUNT/AVG/MIN/MAX with optional GROUP BY.",
        "4. Combine what you need into as few tool calls as possible.",
        "",
        "RULES:",
        "- Do NOT invent table or field names. Do NOT prefix fields with UUIDs or dots.",
        # The worked example must be numerically right, otherwise the model
        # copies a wrong conversion: 2025-12-31 00:00:00 UTC is 1767139200
        # (1735603200 would be 2024-12-31).
        "- Float fields whose description mentions 'unix timestamp' (e.g. bookingDate, lastSyncAt) "
        "store seconds since epoch. Convert dates to a unix-seconds float before filtering "
        "(e.g. '2025-12-31' -> 1767139200.0); never compare such fields against ISO strings.",
        "- The query tools operate on ONE table at a time and CANNOT JOIN. To combine related "
        "tables (FK target shown in [FK -> Table.field]), query each separately and reason "
        "about the link in your answer.",
        "- When a table has period-bucketed aggregates (opening/closing balances or totals per "
        "period), prefer it over recomputing the same aggregate from raw transactional rows.",
        "- CRITICAL: Return data as compact JSON, NOT as markdown tables or prose.",
        "- Do NOT reformat, rewrite, or narrate the tool results. Return the raw data directly.",
        "- If the question asks for rows, return them as a JSON array. Do NOT generate a markdown table.",
        "- Keep your answer SHORT. The caller is a machine, not a human.",
    ]
    return "\n".join(parts)
|
|
|
|
|
|
def _buildTableSchemaBlock(tableName: str, tableLabel: str, fields: List[str]) -> str:
    """Render the schema block of one table for the system prompt.

    The block always starts with the table header line. When a Pydantic
    model is registered for `tableName` in `MODEL_REGISTRY`, its docstring
    (whitespace-collapsed) is added as a description and every selected
    field is rendered via `_formatFieldLine` in model declaration order;
    selected fields the model does not know are appended alphabetically and
    marked "(unknown)". Without a registered model the fields are listed
    flat on a single line. Without fields only header (and description) are
    returned.
    """
    lines: List[str] = [f' Table: {tableName} "{tableLabel}"']

    model = MODEL_REGISTRY.get(tableName)
    if model is not None:
        # Collapse the docstring to one line; skip it when it is blank.
        docText = " ".join((model.__doc__ or "").split())
        if docText:
            lines.append(" Description: " + docText)

    # Nothing selected: header plus optional description is all we can say.
    if not fields:
        return "\n".join(lines)

    # No model registered: only the header is in `lines`, add a flat list.
    if model is None:
        return lines[0] + f"\n Fields: {', '.join(fields)}"

    wanted = set(fields)
    modelFields = model.model_fields

    lines.append(" Fields:")
    lines.extend(
        " - " + _formatFieldLine(name, info)
        for name, info in modelFields.items()
        if name in wanted
    )
    lines.extend(
        f" - {name} (unknown)"
        for name in sorted(name for name in wanted if name not in modelFields)
    )
    return "\n".join(lines)
|
|
|
|
|
|
def _formatFieldLine(fieldName: str, fieldInfo: Any) -> str:
    """Render one field as '<name> (<type>) "<label>": <description> [FK -> Table.field]'.

    All parts after the type are optional: the label is taken from
    `json_schema_extra["label"]` (only if it is a non-empty string different
    from the field name), the description from `fieldInfo.description`, and
    the FK hint from `json_schema_extra["fk_target"]` (target field defaults
    to "id").
    """
    typeName = _summarizePythonType(getattr(fieldInfo, "annotation", None))

    schemaExtra = getattr(fieldInfo, "json_schema_extra", None)
    schemaExtra = schemaExtra if isinstance(schemaExtra, dict) else {}

    segments = [f"{fieldName} ({typeName})"]

    label = schemaExtra.get("label")
    if isinstance(label, str) and label and label != fieldName:
        segments.append(f' "{label}"')

    description = getattr(fieldInfo, "description", None)
    description = description.strip() if isinstance(description, str) else ""
    if description:
        segments.append(f": {description}")

    fk = schemaExtra.get("fk_target")
    if isinstance(fk, dict) and fk.get("table"):
        fkField = fk.get("targetField") or "id"
        segments.append(f" [FK -> {fk['table']}.{fkField}]")

    return "".join(segments)
|
|
|
|
|
|
def _summarizePythonType(annotation: Any) -> str:
    """Return a short, prompt-friendly string for a field annotation.

    `None` becomes "any". Otherwise the annotation's `str()` form is used
    with every "typing." occurrence removed, and a plain class repr
    ("<class 'x.Y'>") is reduced to the dotted name inside the quotes.
    """
    if annotation is None:
        return "any"

    text = str(annotation).replace("typing.", "")

    prefix, suffix = "<class '", "'>"
    looksLikeClassRepr = text.startswith(prefix) and text.endswith(suffix)
    return text[len(prefix):-len(suffix)] if looksLikeClassRepr else text
|