343 lines
13 KiB
Python
343 lines
13 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""Feature Data Sub-Agent.
|
|
|
|
Specialized mini-agent that queries feature-instance data tables. Receives
schema context (table names, labels and fields) for the selected tables and
has three read-only tools: browseTable, queryTable and aggregateTable. Runs
its own agent loop (bounded by _MAX_ROUNDS rounds and a _MAX_COST_CHF budget)
and returns its answer text to the main agent.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
from typing import Any, Callable, Awaitable, Dict, List, Optional
|
|
|
|
from modules.datamodels.datamodelAi import (
|
|
AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum,
|
|
)
|
|
from modules.serviceCenter.services.serviceAgent.agentLoop import runAgentLoop
|
|
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
|
|
AgentConfig, AgentEvent, AgentEventTypeEnum, ToolResult,
|
|
)
|
|
from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry
|
|
from modules.serviceCenter.services.serviceAgent.featureDataProvider import FeatureDataProvider
|
|
from modules.shared.i18nRegistry import resolveText
|
|
|
|
logger = logging.getLogger(__name__)

# Maximum number of agent-loop rounds per sub-agent run (fed to AgentConfig.maxRounds).
_MAX_ROUNDS = 8

# Cost ceiling in CHF per sub-agent run (fed to AgentConfig.maxCostCHF).
_MAX_COST_CHF = 0.15
|
|
|
|
|
|
async def runFeatureDataAgent(
    question: str,
    featureInstanceId: str,
    featureCode: str,
    selectedTables: List[Dict[str, Any]],
    mandateId: str,
    userId: str,
    aiCallFn: Callable[[AiCallRequest], Awaitable[AiCallResponse]],
    dbConnector,
    instanceLabel: str = "",
    tableFilters: Optional[Dict[str, Dict[str, str]]] = None,
    requestLang: Optional[str] = None,
) -> str:
    """Run the feature data sub-agent and return the textual result.

    The sub-agent gets the read-only tools browseTable, queryTable and
    aggregateTable. They are restricted to the tables listed in
    ``selectedTables``. The sub-agent runs its own agent loop, bounded by
    ``_MAX_ROUNDS`` rounds and a ``_MAX_COST_CHF`` cost ceiling.

    Args:
        question: The user/main-agent question to answer using feature data.
        featureInstanceId: Feature instance to scope queries.
        featureCode: Feature code (trustee, commcoach, ...).
        selectedTables: List of DATA_OBJECT dicts the user selected. These
            dicts are not mutated; the real column list is applied to copies.
        mandateId: Mandate scope.
        userId: Calling user ID.
        aiCallFn: AI call function (with billing).
        dbConnector: DatabaseConnector for queries.
        instanceLabel: Human-readable instance name for context.
        tableFilters: Per-table record filters from FeatureDataSource.recordFilter.
        requestLang: ISO 639-1 code for resolving multilingual table labels in
            the schema prompt.

    Returns:
        Plain-text answer produced by the sub-agent, or
        "(no data returned by feature agent)" if the loop produced no content.
    """
    provider = FeatureDataProvider(dbConnector)
    tables = _withActualColumns(provider, selectedTables)

    # Only the tables the user selected may be touched by the tools. Without
    # this, the sub-agent could name some other table and sidestep the
    # per-table recordFilter scoping.
    allowedTables = sorted(
        {tbl["meta"]["table"] for tbl in tables if tbl["meta"].get("table")}
    )
    registry = _buildSubAgentTools(
        provider,
        featureInstanceId,
        mandateId,
        tableFilters or {},
        allowedTables=allowedTables,
    )

    systemPrompt = _buildSchemaContext(featureCode, instanceLabel, tables, requestLang)

    config = AgentConfig(
        maxRounds=_MAX_ROUNDS,
        maxCostCHF=_MAX_COST_CHF,
        operationType=OperationTypeEnum.DATA_QUERY,
    )

    costAccumulator = 0.0

    async def _trackingAiCallFn(req: AiCallRequest) -> AiCallResponse:
        """Forward to aiCallFn and add the call's price to costAccumulator."""
        nonlocal costAccumulator
        resp = await aiCallFn(req)
        # Treat a missing/None price as zero so cost tracking cannot crash the loop.
        costAccumulator += getattr(resp, "priceCHF", None) or 0.0
        return resp

    async def _getWorkflowCost() -> float:
        """Return the CHF accumulated so far (handed to runAgentLoop as getWorkflowCostFn)."""
        return costAccumulator

    result = ""
    async for event in runAgentLoop(
        prompt=question,
        toolRegistry=registry,
        config=config,
        aiCallFn=_trackingAiCallFn,
        getWorkflowCostFn=_getWorkflowCost,
        workflowId=f"fda-{featureInstanceId[:8]}",
        userId=userId,
        featureInstanceId=featureInstanceId,
        mandateId=mandateId,
        systemPromptOverride=systemPrompt,
    ):
        # A FINAL event replaces everything collected so far; MESSAGE content is appended.
        if event.type == AgentEventTypeEnum.FINAL and event.content:
            result = event.content
        elif event.type == AgentEventTypeEnum.MESSAGE and event.content:
            result += event.content

    return result or "(no data returned by feature agent)"


def _withActualColumns(
    provider: FeatureDataProvider,
    selectedTables: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Return shallow copies of selectedTables with meta.fields set to the real DB columns.

    The caller's dicts are left untouched. ``meta["fields"]`` is replaced only
    when ``provider.getActualColumns`` returns a non-empty result. Every
    returned dict has a ``"meta"`` dict, possibly empty.
    """
    enriched: List[Dict[str, Any]] = []
    for tbl in selectedTables:
        meta = dict(tbl.get("meta") or {})
        tableName = meta.get("table", "")
        if tableName:
            realCols = provider.getActualColumns(tableName)
            if realCols:
                meta["fields"] = realCols
        enriched.append({**tbl, "meta": meta})
    return enriched


# ------------------------------------------------------------------
# tool registration
# ------------------------------------------------------------------

_DEFAULT_ROW_LIMIT = 50
_MAX_ROW_LIMIT = 200
_TOOL_RESULT_MAX_CHARS = 30000
_AGGREGATE_FUNCTIONS = ("SUM", "COUNT", "AVG", "MIN", "MAX")


def _coerceInt(value: Any, default: int, minimum: int, maximum: Optional[int] = None) -> int:
    """Parse an LLM-supplied integer argument, falling back to default, then clamp it.

    Args:
        value: Raw argument value (may be None, a string, a float, ...).
        default: Used when value cannot be converted with ``int()``.
        minimum: Lower bound applied after conversion.
        maximum: Optional upper bound applied after conversion.

    Returns:
        An int in ``[minimum, maximum]`` (or ``>= minimum`` if maximum is None).
    """
    try:
        number = int(value)
    except (TypeError, ValueError, OverflowError):
        number = default
    number = max(number, minimum)
    if maximum is not None:
        number = min(number, maximum)
    return number


def _toToolResult(toolName: str, result: Dict[str, Any]) -> ToolResult:
    """Wrap a provider result dict as a ToolResult with a size-capped JSON payload.

    ``success`` is False whenever the provider dict contains an ``"error"`` key.
    """
    payload = json.dumps(result, default=str, ensure_ascii=False)
    if len(payload) > _TOOL_RESULT_MAX_CHARS:
        # Tell the model the JSON was cut rather than silently handing it a
        # fragment, so it can narrow the query or page through results.
        payload = (
            payload[:_TOOL_RESULT_MAX_CHARS]
            + f"\n...[truncated: {len(payload)} chars total; narrow the query "
            "(fields/filters/groupBy) or page with limit/offset]"
        )
    return ToolResult(
        toolCallId="",
        toolName=toolName,
        success="error" not in result,
        data=payload,
        error=result.get("error"),
    )


def _buildSubAgentTools(
    provider: FeatureDataProvider,
    featureInstanceId: str,
    mandateId: str,
    tableFilters: Optional[Dict[str, Dict[str, str]]] = None,
    allowedTables: Optional[List[str]] = None,
) -> ToolRegistry:
    """Register browseTable, queryTable and aggregateTable as sub-agent tools.

    Args:
        provider: Data provider that executes the queries.
        featureInstanceId: Feature instance every query is scoped to.
        mandateId: Mandate every query is scoped to.
        tableFilters: Per-table ``{field: value}`` record filters. Each pair is
            passed to the provider as an ``=`` extra filter for that table.
        allowedTables: If given, tool calls naming any other table are rejected
            with an error that lists the valid names. ``None`` (the default)
            applies no restriction.

    Returns:
        A ToolRegistry with the three read-only tools registered.
    """
    registry = ToolRegistry()
    _tableFilters = tableFilters or {}
    _allowed = None if allowedTables is None else set(allowedTables)

    def _recordFilterToList(tableName: str) -> Optional[List[Dict[str, Any]]]:
        """Convert a recordFilter dict to a list of {field, op, value} filter dicts."""
        rf = _tableFilters.get(tableName)
        if not rf:
            return None
        return [{"field": k, "op": "=", "value": v} for k, v in rf.items()]

    def _rejectTable(toolName: str, tableName: str) -> Optional[ToolResult]:
        """Return an error ToolResult if tableName is missing or not allowed, else None."""
        if not tableName:
            return ToolResult(toolCallId="", toolName=toolName, success=False, error="tableName required")
        if _allowed is not None and tableName not in _allowed:
            logger.warning("Feature data sub-agent rejected table %r (tool %s)", tableName, toolName)
            return ToolResult(
                toolCallId="",
                toolName=toolName,
                success=False,
                error=f"tableName '{tableName}' is not available; valid values: {sorted(_allowed)}",
            )
        return None

    async def _browseTable(args: Dict[str, Any], context: Dict[str, Any]) -> ToolResult:
        tableName = str(args.get("tableName") or "")
        rejected = _rejectTable("browseTable", tableName)
        if rejected is not None:
            return rejected
        result = provider.browseTable(
            tableName=tableName,
            featureInstanceId=featureInstanceId,
            mandateId=mandateId,
            fields=args.get("fields"),
            limit=_coerceInt(args.get("limit"), _DEFAULT_ROW_LIMIT, 1, _MAX_ROW_LIMIT),
            offset=_coerceInt(args.get("offset"), 0, 0),
            extraFilters=_recordFilterToList(tableName),
        )
        return _toToolResult("browseTable", result)

    async def _queryTable(args: Dict[str, Any], context: Dict[str, Any]) -> ToolResult:
        tableName = str(args.get("tableName") or "")
        rejected = _rejectTable("queryTable", tableName)
        if rejected is not None:
            return rejected
        result = provider.queryTable(
            tableName=tableName,
            featureInstanceId=featureInstanceId,
            mandateId=mandateId,
            filters=args.get("filters") or [],
            fields=args.get("fields"),
            orderBy=args.get("orderBy"),
            limit=_coerceInt(args.get("limit"), _DEFAULT_ROW_LIMIT, 1, _MAX_ROW_LIMIT),
            offset=_coerceInt(args.get("offset"), 0, 0),
            extraFilters=_recordFilterToList(tableName),
        )
        return _toToolResult("queryTable", result)

    async def _aggregateTable(args: Dict[str, Any], context: Dict[str, Any]) -> ToolResult:
        tableName = str(args.get("tableName") or "")
        # Accept "sum"/" Sum " etc. from the model, but never forward anything
        # outside the advertised enum to the provider.
        aggregate = str(args.get("aggregate") or "").strip().upper()
        field = args.get("field", "")
        groupBy = args.get("groupBy")
        rejected = _rejectTable("aggregateTable", tableName)
        if rejected is not None:
            return rejected
        if not aggregate:
            return ToolResult(toolCallId="", toolName="aggregateTable", success=False, error="aggregate required (SUM, COUNT, AVG, MIN, MAX)")
        if aggregate not in _AGGREGATE_FUNCTIONS:
            return ToolResult(
                toolCallId="",
                toolName="aggregateTable",
                success=False,
                error=f"unsupported aggregate '{aggregate}' (SUM, COUNT, AVG, MIN, MAX)",
            )
        if not field:
            return ToolResult(toolCallId="", toolName="aggregateTable", success=False, error="field required")
        result = provider.aggregateTable(
            tableName=tableName,
            featureInstanceId=featureInstanceId,
            mandateId=mandateId,
            aggregate=aggregate,
            field=field,
            groupBy=groupBy,
            extraFilters=_recordFilterToList(tableName),
        )
        return _toToolResult("aggregateTable", result)

    registry.register(
        "aggregateTable", _aggregateTable,
        description=(
            "Run an aggregate query on a feature data table. "
            "Supports SUM, COUNT, AVG, MIN, MAX with optional GROUP BY. "
            "Example: aggregateTable(tableName='TrusteeDataJournalLine', aggregate='SUM', field='debitAmount', groupBy='costCenter')"
        ),
        parameters={
            "type": "object",
            "properties": {
                "tableName": {"type": "string", "description": "Name of the table to aggregate"},
                "aggregate": {"type": "string", "enum": list(_AGGREGATE_FUNCTIONS), "description": "Aggregate function"},
                "field": {"type": "string", "description": "Field to aggregate (e.g. debitAmount, creditAmount)"},
                "groupBy": {"type": "string", "description": "Optional field to group by (e.g. costCenter, accountNumber)"},
            },
            "required": ["tableName", "aggregate", "field"],
        },
        readOnly=True,
    )

    registry.register(
        "browseTable", _browseTable,
        description="List rows from a feature data table with pagination.",
        parameters={
            "type": "object",
            "properties": {
                "tableName": {"type": "string", "description": "Name of the table to browse"},
                "fields": {
                    "type": "array", "items": {"type": "string"},
                    "description": "Optional list of fields to return (default: all)",
                },
                "limit": {"type": "integer", "description": "Max rows to return (default 50, max 200)"},
                "offset": {"type": "integer", "description": "Row offset for pagination"},
            },
            "required": ["tableName"],
        },
        readOnly=True,
    )

    registry.register(
        "queryTable", _queryTable,
        description=(
            "Query a feature data table with filters, field selection, and ordering. "
            "Filters: [{\"field\": \"status\", \"op\": \"=\", \"value\": \"active\"}]. "
            "Operators: =, !=, >, <, >=, <=, LIKE, ILIKE, IS NULL, IS NOT NULL."
        ),
        parameters={
            "type": "object",
            "properties": {
                "tableName": {"type": "string", "description": "Name of the table to query"},
                "filters": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "field": {"type": "string"},
                            "op": {"type": "string"},
                            "value": {},
                        },
                    },
                    "description": "Filter conditions",
                },
                "fields": {
                    "type": "array", "items": {"type": "string"},
                    "description": "Optional list of fields to return",
                },
                "orderBy": {"type": "string", "description": "Field name to order by"},
                "limit": {"type": "integer", "description": "Max rows (default 50, max 200)"},
                "offset": {"type": "integer", "description": "Row offset"},
            },
            "required": ["tableName"],
        },
        readOnly=True,
    )

    return registry
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# context building
|
|
# ------------------------------------------------------------------
|
|
|
|
def _buildSchemaContext(
    featureCode: str,
    instanceLabel: str,
    selectedTables: List[Dict[str, Any]],
    requestLang: Optional[str] = None,
) -> str:
    """Build a system prompt describing available tables and query strategy.

    Args:
        featureCode: Feature code shown in the prompt header.
        instanceLabel: Optional instance name appended to the header.
        selectedTables: DATA_OBJECT dicts. ``meta.table`` and ``meta.fields``
            describe each table; ``label`` is resolved via ``resolveText``.
        requestLang: Language passed to ``resolveText`` for the table labels.

    Returns:
        The newline-joined system prompt.
    """
    tableNames: List[str] = []
    tableBlocks: List[str] = []

    for obj in selectedTables:
        meta = obj.get("meta") or {}
        realName = meta.get("table")
        fields = meta.get("fields") or []
        labelStr = resolveText(obj.get("label"), requestLang)
        # "?" is only a display placeholder; never advertise it as a valid tableName.
        if realName:
            tableNames.append(realName)
        block = f" Table: {realName or '?'}"
        if labelStr:
            block += f" ({labelStr})"
        if fields:
            # str() keeps prompt building from crashing on non-string field entries.
            block += f"\n Fields: {', '.join(str(f) for f in fields)}"
        tableBlocks.append(block)

    header = f"You are a data query assistant for the '{featureCode}' feature"
    if instanceLabel:
        header += f' (instance: "{instanceLabel}")'
    header += "."

    parts = [
        header,
        "",
        "AVAILABLE TABLES (use EXACTLY these names as tableName parameter):",
        *tableBlocks,
        "",
        f"Valid tableName values: {tableNames}",
        "Field names are plain column names (e.g. 'accountNumber', 'periodYear').",
        "",
        "QUERY STRATEGY:",
        "1. If unsure about columns, call browseTable(tableName) first to inspect the schema.",
        "2. Use queryTable with filters for targeted lookups.",
        "3. Use aggregateTable for SUM/COUNT/AVG/MIN/MAX with optional GROUP BY.",
        "4. Combine what you need into as few tool calls as possible.",
        "",
        "RULES:",
        "- Do NOT invent table or field names. Do NOT prefix fields with UUIDs or dots.",
        "- CRITICAL: Return data as compact JSON, NOT as markdown tables or prose.",
        "- Do NOT reformat, rewrite, or narrate the tool results. Return the raw data directly.",
        "- If the question asks for rows, return them as a JSON array. Do NOT generate a markdown table.",
        "- Keep your answer SHORT. The caller is a machine, not a human.",
    ]
    return "\n".join(parts)
|