gateway/modules/serviceCenter/services/serviceAgent/featureDataAgent.py

314 lines
12 KiB
Python

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""Feature Data Sub-Agent.
Specialized mini-agent that queries feature-instance data tables. Receives
schema context (fields, descriptions) for the selected tables and has three
tools: browseTable, queryTable and aggregateTable. Runs its own agent loop
(max 5 rounds, low budget) and returns a plain-text answer to the main agent.
"""
import json
import logging
from typing import Any, Callable, Awaitable, Dict, List, Optional
from modules.datamodels.datamodelAi import (
AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum,
)
from modules.serviceCenter.services.serviceAgent.agentLoop import runAgentLoop
from modules.serviceCenter.services.serviceAgent.datamodelAgent import (
AgentConfig, AgentEvent, AgentEventTypeEnum, ToolResult,
)
from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry
from modules.serviceCenter.services.serviceAgent.featureDataProvider import FeatureDataProvider
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Round limit for the sub-agent loop; passed to AgentConfig.maxRounds.
_MAX_ROUNDS = 5
# Cost ceiling in CHF for one sub-agent run; passed to AgentConfig.maxCostCHF.
_MAX_COST_CHF = 0.10
async def runFeatureDataAgent(
    question: str,
    featureInstanceId: str,
    featureCode: str,
    selectedTables: List[Dict[str, Any]],
    mandateId: str,
    userId: str,
    aiCallFn: Callable[[AiCallRequest], Awaitable[AiCallResponse]],
    dbConnector,
    instanceLabel: str = "",
    tableFilters: Optional[Dict[str, Dict[str, str]]] = None,
) -> str:
    """Answer a question from feature data by running a dedicated sub-agent.

    The sub-agent gets a prompt made of the schema context for the selected
    tables followed by the question, plus the table tools built by
    ``_buildSubAgentTools``. Its event stream is reduced to a single string.

    Args:
        question: The user/main-agent question to answer using feature data.
        featureInstanceId: Feature instance that scopes every query.
        featureCode: Feature code (trustee, commcoach, ...).
        selectedTables: DATA_OBJECT dicts the user selected. When an entry
            names a table and the provider reports columns for it, the
            entry's existing ``meta`` dict gets its ``fields`` overwritten
            in place.
        mandateId: Mandate scope.
        userId: Calling user ID.
        aiCallFn: AI call function (with billing).
        dbConnector: DatabaseConnector handed to FeatureDataProvider.
        instanceLabel: Human-readable instance name for the prompt.
        tableFilters: Per-table record filters from
            FeatureDataSource.recordFilter.

    Returns:
        The content of the FINAL event if one with content arrives (it
        replaces anything collected before it); otherwise the concatenated
        MESSAGE contents; otherwise a fixed placeholder text.
    """
    provider = FeatureDataProvider(dbConnector)
    registry = _buildSubAgentTools(provider, featureInstanceId, mandateId, tableFilters or {})

    # Prefer the columns the provider reports over whatever the selection carried.
    for tableEntry in selectedTables:
        tableMeta = tableEntry.get("meta", {})
        name = tableMeta.get("table", "")
        if not name:
            continue
        liveColumns = provider.getActualColumns(name)
        if liveColumns:
            tableMeta["fields"] = liveColumns

    schemaContext = _buildSchemaContext(featureCode, instanceLabel, selectedTables)
    prompt = f"{schemaContext}\n\nUser question:\n{question}"
    config = AgentConfig(maxRounds=_MAX_ROUNDS, maxCostCHF=_MAX_COST_CHF)

    async def _reportZeroCost() -> float:
        # Always reports 0.0 as the workflow cost seen by the agent loop.
        return 0.0

    eventStream = runAgentLoop(
        prompt=prompt,
        toolRegistry=registry,
        config=config,
        aiCallFn=aiCallFn,
        getWorkflowCostFn=_reportZeroCost,
        workflowId=f"fda-{featureInstanceId[:8]}",
        userId=userId,
        featureInstanceId=featureInstanceId,
        mandateId=mandateId,
    )

    answer = ""
    async for event in eventStream:
        eventType = event.type
        if eventType == AgentEventTypeEnum.FINAL:
            if event.content:
                # A final answer supersedes any partial messages gathered so far.
                answer = event.content
        elif eventType == AgentEventTypeEnum.MESSAGE:
            if event.content:
                answer += event.content

    if answer:
        return answer
    return "(no data returned by feature agent)"
# ------------------------------------------------------------------
# tool registration
# ------------------------------------------------------------------
def _buildSubAgentTools(
    provider: FeatureDataProvider,
    featureInstanceId: str,
    mandateId: str,
    tableFilters: Optional[Dict[str, Dict[str, str]]] = None,
) -> ToolRegistry:
    """Build the tool registry used by the feature data sub-agent.

    Registers three read-only tools, in this order: aggregateTable,
    browseTable and queryTable. Every tool call is forwarded to ``provider``
    with the given ``featureInstanceId`` / ``mandateId`` and, when
    ``tableFilters`` has an entry for the requested table, with that entry
    converted to equality filters and passed as ``extraFilters``.

    Args:
        provider: Data provider that executes the table operations.
        featureInstanceId: Feature instance passed to every provider call.
        mandateId: Mandate passed to every provider call.
        tableFilters: Optional mapping ``tableName -> {field: value}`` of
            record filters to enforce per table.

    Returns:
        A ToolRegistry with the three tools registered.
    """
    registry = ToolRegistry()
    _tableFilters = tableFilters or {}

    defaultRows = 50
    maxRows = 200
    # Serialized results are cut at this many characters; the cut is a plain
    # string slice, so an oversized payload may end mid-JSON.
    maxPayloadChars = 30000

    def _recordFilterToList(tableName: str) -> Optional[List[Dict[str, Any]]]:
        """Convert a recordFilter dict to a list of {field, op, value} filter dicts."""
        rf = _tableFilters.get(tableName)
        if not rf:
            return None
        return [{"field": k, "op": "=", "value": v} for k, v in rf.items()]

    def _boundedInt(value: Any, default: int, minimum: int, maximum: Optional[int] = None) -> int:
        """Coerce a tool argument to an int within [minimum, maximum].

        Tool arguments come from the model, so they may be missing, null,
        strings or negative. Anything that cannot be converted falls back to
        ``default``; the result is then clamped to the allowed range.
        """
        try:
            number = int(value)
        except (TypeError, ValueError, OverflowError):
            number = default
        number = max(number, minimum)
        if maximum is not None:
            number = min(number, maximum)
        return number

    def _failure(toolName: str, message: str) -> ToolResult:
        """Build the ToolResult for a rejected call (argument validation failed)."""
        return ToolResult(toolCallId="", toolName=toolName, success=False, error=message)

    def _toToolResult(toolName: str, result: Dict[str, Any]) -> ToolResult:
        """Wrap a provider result dict; an "error" key marks the call as failed."""
        return ToolResult(
            toolCallId="", toolName=toolName,
            success="error" not in result,
            data=json.dumps(result, default=str, ensure_ascii=False)[:maxPayloadChars],
            error=result.get("error"),
        )

    async def _browseTable(args: Dict[str, Any], context: Dict[str, Any]):
        tableName = args.get("tableName", "")
        if not tableName:
            return _failure("browseTable", "tableName required")
        result = provider.browseTable(
            tableName=tableName,
            featureInstanceId=featureInstanceId,
            mandateId=mandateId,
            fields=args.get("fields"),
            limit=_boundedInt(args.get("limit"), defaultRows, 0, maxRows),
            offset=_boundedInt(args.get("offset"), 0, 0),
            extraFilters=_recordFilterToList(tableName),
        )
        return _toToolResult("browseTable", result)

    async def _queryTable(args: Dict[str, Any], context: Dict[str, Any]):
        tableName = args.get("tableName", "")
        if not tableName:
            return _failure("queryTable", "tableName required")
        result = provider.queryTable(
            tableName=tableName,
            featureInstanceId=featureInstanceId,
            mandateId=mandateId,
            # An explicit null from the model is treated like "no filters".
            filters=args.get("filters") or [],
            fields=args.get("fields"),
            orderBy=args.get("orderBy"),
            limit=_boundedInt(args.get("limit"), defaultRows, 0, maxRows),
            offset=_boundedInt(args.get("offset"), 0, 0),
            extraFilters=_recordFilterToList(tableName),
        )
        return _toToolResult("queryTable", result)

    async def _aggregateTable(args: Dict[str, Any], context: Dict[str, Any]):
        tableName = args.get("tableName", "")
        aggregate = args.get("aggregate", "")
        field = args.get("field", "")
        if not tableName:
            return _failure("aggregateTable", "tableName required")
        if not aggregate:
            return _failure("aggregateTable", "aggregate required (SUM, COUNT, AVG, MIN, MAX)")
        if not field:
            return _failure("aggregateTable", "field required")
        result = provider.aggregateTable(
            tableName=tableName,
            featureInstanceId=featureInstanceId,
            mandateId=mandateId,
            aggregate=aggregate,
            field=field,
            groupBy=args.get("groupBy"),
            extraFilters=_recordFilterToList(tableName),
        )
        return _toToolResult("aggregateTable", result)

    registry.register(
        "aggregateTable", _aggregateTable,
        description=(
            "Run an aggregate query on a feature data table. "
            "Supports SUM, COUNT, AVG, MIN, MAX with optional GROUP BY. "
            "Example: aggregateTable(tableName='TrusteeDataJournalLine', aggregate='SUM', field='debitAmount', groupBy='costCenter')"
        ),
        parameters={
            "type": "object",
            "properties": {
                "tableName": {"type": "string", "description": "Name of the table to aggregate"},
                "aggregate": {"type": "string", "enum": ["SUM", "COUNT", "AVG", "MIN", "MAX"], "description": "Aggregate function"},
                "field": {"type": "string", "description": "Field to aggregate (e.g. debitAmount, creditAmount)"},
                "groupBy": {"type": "string", "description": "Optional field to group by (e.g. costCenter, accountNumber)"},
            },
            "required": ["tableName", "aggregate", "field"],
        },
        readOnly=True,
    )
    registry.register(
        "browseTable", _browseTable,
        description="List rows from a feature data table with pagination.",
        parameters={
            "type": "object",
            "properties": {
                "tableName": {"type": "string", "description": "Name of the table to browse"},
                "fields": {
                    "type": "array", "items": {"type": "string"},
                    "description": "Optional list of fields to return (default: all)",
                },
                "limit": {"type": "integer", "description": "Max rows to return (default 50, max 200)"},
                "offset": {"type": "integer", "description": "Row offset for pagination"},
            },
            "required": ["tableName"],
        },
        readOnly=True,
    )
    registry.register(
        "queryTable", _queryTable,
        description=(
            "Query a feature data table with filters, field selection, and ordering. "
            "Filters: [{\"field\": \"status\", \"op\": \"=\", \"value\": \"active\"}]. "
            "Operators: =, !=, >, <, >=, <=, LIKE, ILIKE, IS NULL, IS NOT NULL."
        ),
        parameters={
            "type": "object",
            "properties": {
                "tableName": {"type": "string", "description": "Name of the table to query"},
                "filters": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "field": {"type": "string"},
                            "op": {"type": "string"},
                            "value": {},
                        },
                    },
                    "description": "Filter conditions",
                },
                "fields": {
                    "type": "array", "items": {"type": "string"},
                    "description": "Optional list of fields to return",
                },
                "orderBy": {"type": "string", "description": "Field name to order by"},
                "limit": {"type": "integer", "description": "Max rows (default 50, max 200)"},
                "offset": {"type": "integer", "description": "Row offset"},
            },
            "required": ["tableName"],
        },
        readOnly=True,
    )
    return registry
# ------------------------------------------------------------------
# context building
# ------------------------------------------------------------------
def _buildSchemaContext(
featureCode: str,
instanceLabel: str,
selectedTables: List[Dict[str, Any]],
) -> str:
"""Build a system-level context block describing available tables."""
parts = [
f"You are a data query assistant for the '{featureCode}' feature",
]
if instanceLabel:
parts[0] += f' (instance: "{instanceLabel}")'
parts[0] += "."
parts.append(
"You have access to the following data tables. "
"Use browseTable to list rows, queryTable to filter/search, "
"and aggregateTable for SUM/COUNT/AVG/MIN/MAX with optional GROUP BY."
)
parts.append("")
for obj in selectedTables:
meta = obj.get("meta", {})
tbl = meta.get("table", "?")
fields = meta.get("fields", [])
label = obj.get("label", {})
labelStr = label.get("en") or label.get("de") or tbl
parts.append(f"Table: {tbl} ({labelStr})")
if fields:
parts.append(f" Fields: {', '.join(fields)}")
parts.append("")
parts.append(
"Answer the user's question using the data from these tables. "
"Be precise, cite row counts, and format data clearly."
)
return "\n".join(parts)