diff --git a/modules/aicore/aicorePluginAnthropic.py b/modules/aicore/aicorePluginAnthropic.py index 8b6ec197..81a2175e 100644 --- a/modules/aicore/aicorePluginAnthropic.py +++ b/modules/aicore/aicorePluginAnthropic.py @@ -4,7 +4,7 @@ import json import logging import httpx import os -from typing import Dict, Any, List, AsyncGenerator, Union +from typing import Dict, Any, List, AsyncGenerator, Optional, Union from fastapi import HTTPException from modules.shared.configuration import APP_CONFIG from .aicoreBase import BaseConnectorAi, RateLimitExceededException @@ -295,6 +295,7 @@ class AiAnthropic(BaseConnectorAi): fullContent = "" toolUseBlocks: Dict[int, Dict[str, Any]] = {} currentToolIdx = -1 + stopReason: Optional[str] = None async with self.httpClient.stream("POST", model.apiUrl, json=payload) as response: if response.status_code != 200: @@ -316,7 +317,16 @@ class AiAnthropic(BaseConnectorAi): eventType = event.get("type", "") - if eventType == "content_block_start": + if eventType == "error": + errDetail = event.get("error", {}) + errMsg = errDetail.get("message", str(errDetail)) + errType = errDetail.get("type", "unknown") + logger.error(f"Anthropic stream error event: type={errType}, message={errMsg}") + if "overloaded" in errMsg.lower() or "overloaded" in errType.lower(): + raise HTTPException(status_code=500, detail=f"Anthropic API is currently overloaded. Please try again in a few minutes.") + raise HTTPException(status_code=500, detail=f"Anthropic stream error: [{errType}] {errMsg}") + + elif eventType == "content_block_start": block = event.get("content_block", {}) idx = event.get("index", 0) if block.get("type") == "tool_use": @@ -338,10 +348,22 @@ class AiAnthropic(BaseConnectorAi): if idx in toolUseBlocks: toolUseBlocks[idx]["arguments"] += delta.get("partial_json", "") + elif eventType == "message_delta": + delta = event.get("delta", {}) + stopReason = delta.get("stop_reason", stopReason) + elif eventType == "message_stop": break + if not fullContent and not toolUseBlocks: + logger.warning( + f"Anthropic stream returned empty response: model={model.name}, " + f"stopReason={stopReason}" + ) + metadata: Dict[str, Any] = {} + if stopReason: + metadata["stopReason"] = stopReason if toolUseBlocks: metadata["toolCalls"] = [ { diff --git a/modules/aicore/aicorePluginMistral.py b/modules/aicore/aicorePluginMistral.py index 8c4fb6d9..885addcf 100644 --- a/modules/aicore/aicorePluginMistral.py +++ b/modules/aicore/aicorePluginMistral.py @@ -174,7 +174,11 @@ class AiMistral(BaseConnectorAi): "temperature": temperature, "max_tokens": maxTokens } - + + if modelCall.tools: + payload["tools"] = modelCall.tools + payload["tool_choice"] = modelCall.toolChoice or "auto" + response = await self.httpClient.post( model.apiUrl, json=payload @@ -214,15 +218,20 @@ class AiMistral(BaseConnectorAi): raise HTTPException(status_code=500, detail=error_message) responseJson = response.json() - content = responseJson["choices"][0]["message"]["content"] - + choiceMessage = responseJson["choices"][0]["message"] + content = choiceMessage.get("content") or "" + + metadata = {"response_id": responseJson.get("id", "")} + if choiceMessage.get("tool_calls"): + metadata["toolCalls"] = choiceMessage["tool_calls"] + return AiModelResponse( content=content, success=True, modelId=model.name, - metadata={"response_id": responseJson.get("id", "")} + metadata=metadata, ) - + except ContextLengthExceededException: # Re-raise context length exceptions without wrapping raise @@ -250,7 +259,12 @@ class 
AiMistral(BaseConnectorAi): "stream": True, } + if modelCall.tools: + payload["tools"] = modelCall.tools + payload["tool_choice"] = modelCall.toolChoice or "auto" + fullContent = "" + toolCallsAccum: Dict[int, Dict[str, Any]] = {} async with self.httpClient.stream("POST", model.apiUrl, json=payload) as response: if response.status_code != 200: @@ -280,11 +294,31 @@ class AiMistral(BaseConnectorAi): fullContent += delta["content"] yield delta["content"] + for tcDelta in delta.get("tool_calls", []): + idx = tcDelta.get("index", 0) + if idx not in toolCallsAccum: + toolCallsAccum[idx] = { + "id": tcDelta.get("id", ""), + "type": "function", + "function": {"name": "", "arguments": ""}, + } + if tcDelta.get("id"): + toolCallsAccum[idx]["id"] = tcDelta["id"] + fn = tcDelta.get("function", {}) + if fn.get("name"): + toolCallsAccum[idx]["function"]["name"] = fn["name"] + if fn.get("arguments"): + toolCallsAccum[idx]["function"]["arguments"] += fn["arguments"] + + metadata: Dict[str, Any] = {} + if toolCallsAccum: + metadata["toolCalls"] = [toolCallsAccum[i] for i in sorted(toolCallsAccum)] + yield AiModelResponse( content=fullContent, success=True, modelId=model.name, - metadata={}, + metadata=metadata, ) except (RateLimitExceededException, ContextLengthExceededException, HTTPException): diff --git a/modules/datamodels/datamodelFeatureDataSource.py b/modules/datamodels/datamodelFeatureDataSource.py new file mode 100644 index 00000000..89b8b372 --- /dev/null +++ b/modules/datamodels/datamodelFeatureDataSource.py @@ -0,0 +1,45 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +"""FeatureDataSource model for exposing feature instance data to the AI workspace. + +A FeatureDataSource links a FeatureInstance table (DATA_OBJECT) to a workspace +so the agent can query structured feature data (e.g. TrusteePosition rows). +""" + +from typing import Optional +from pydantic import BaseModel, Field +from modules.shared.attributeUtils import registerModelLabels +from modules.shared.timeUtils import getUtcTimestamp +import uuid + + +class FeatureDataSource(BaseModel): + """A feature-instance table attached as data source in the AI workspace.""" + id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key") + featureInstanceId: str = Field(description="FK to FeatureInstance") + featureCode: str = Field(description="Feature code (e.g. trustee, commcoach)") + tableName: str = Field(description="Table name from DATA_OBJECTS meta (e.g. TrusteePosition)") + objectKey: str = Field(description="RBAC object key (e.g. 
data.feature.trustee.TrusteePosition)") + label: str = Field(description="User-visible label") + mandateId: str = Field(default="", description="Mandate scope") + userId: str = Field(default="", description="Owner user ID") + workspaceInstanceId: str = Field(description="Workspace instance where this source is used") + createdAt: float = Field(default_factory=getUtcTimestamp, description="Creation timestamp") + + +registerModelLabels( + "FeatureDataSource", + {"en": "Feature Data Source", "de": "Feature-Datenquelle", "fr": "Source de données fonctionnalité"}, + { + "id": {"en": "ID", "de": "ID", "fr": "ID"}, + "featureInstanceId": {"en": "Feature Instance", "de": "Feature-Instanz", "fr": "Instance"}, + "featureCode": {"en": "Feature", "de": "Feature", "fr": "Fonctionnalité"}, + "tableName": {"en": "Table", "de": "Tabelle", "fr": "Table"}, + "objectKey": {"en": "Object Key", "de": "Objekt-Schlüssel", "fr": "Clé objet"}, + "label": {"en": "Label", "de": "Bezeichnung", "fr": "Libellé"}, + "mandateId": {"en": "Mandate", "de": "Mandant", "fr": "Mandat"}, + "userId": {"en": "User", "de": "Benutzer", "fr": "Utilisateur"}, + "workspaceInstanceId": {"en": "Workspace", "de": "Workspace", "fr": "Espace de travail"}, + "createdAt": {"en": "Created At", "de": "Erstellt am", "fr": "Créé le"}, + }, +) diff --git a/modules/features/trustee/accounting/accountingConnectorBase.py b/modules/features/trustee/accounting/accountingConnectorBase.py index 2cfa4a54..44044729 100644 --- a/modules/features/trustee/accounting/accountingConnectorBase.py +++ b/modules/features/trustee/accounting/accountingConnectorBase.py @@ -118,6 +118,13 @@ class BaseAccountingConnector(ABC): """Load the vendor list. Override in connectors that support it.""" return [] + async def getJournalEntries(self, config: Dict[str, Any], dateFrom: Optional[str] = None, dateTo: Optional[str] = None, accountNumbers: Optional[List[str]] = None) -> List[Dict[str, Any]]: + """Read journal entries from the external system. Each entry should contain: + - externalId, bookingDate, reference, description, currency, totalAmount + - lines: list of {accountNumber, debitAmount, creditAmount, currency, taxCode, costCenter, description} + accountNumbers: pre-fetched account numbers (avoids redundant API call). Override in connectors that support it.""" + return [] + async def uploadDocument( self, config: Dict[str, Any], diff --git a/modules/features/trustee/accounting/accountingDataSync.py b/modules/features/trustee/accounting/accountingDataSync.py new file mode 100644 index 00000000..d393de76 --- /dev/null +++ b/modules/features/trustee/accounting/accountingDataSync.py @@ -0,0 +1,306 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +"""Orchestrates importing accounting data from external systems into TrusteeData* tables. + +Flow: load config → resolve connector → fetch data → clear old records → write new records → compute balances. 
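+
+Usage sketch (illustrative; identifiers such as "fi-123" are placeholders, and the
+only assumed wiring is the trustee interface this module already receives):
+
+    sync = AccountingDataSync(trusteeInterface)
+    summary = await sync.importData(
+        featureInstanceId="fi-123",
+        mandateId="m-1",
+        dateFrom="2025-01-01",
+        dateTo="2025-12-31",
+    )
+    if summary["errors"]:
+        logger.warning("Partial import: %s", "; ".join(summary["errors"]))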
+""" + +import logging +import time +from collections import defaultdict +from typing import Dict, Any, Optional + +from .accountingConnectorBase import BaseAccountingConnector +from .accountingRegistry import _getAccountingRegistry + +logger = logging.getLogger(__name__) + + +class AccountingDataSync: + """Imports accounting data (read-only) from an external system into local TrusteeData* tables.""" + + def __init__(self, trusteeInterface): + self._if = trusteeInterface + self._registry = _getAccountingRegistry() + + async def importData( + self, + featureInstanceId: str, + mandateId: str, + dateFrom: Optional[str] = None, + dateTo: Optional[str] = None, + ) -> Dict[str, Any]: + """Run a full data import for a feature instance. + + Returns a summary dict with counts per entity and any errors. + """ + from modules.features.trustee.datamodelFeatureTrustee import ( + TrusteeAccountingConfig, + TrusteeDataAccount, + TrusteeDataJournalEntry, + TrusteeDataJournalLine, + TrusteeDataContact, + TrusteeDataAccountBalance, + ) + from modules.shared.configuration import decryptValue + + summary: Dict[str, Any] = { + "accounts": 0, + "journalEntries": 0, + "journalLines": 0, + "contacts": 0, + "accountBalances": 0, + "errors": [], + "startedAt": time.time(), + } + + cfgRecords = self._if.db.getRecordset( + TrusteeAccountingConfig, + recordFilter={"featureInstanceId": featureInstanceId, "isActive": True}, + ) + if not cfgRecords: + summary["errors"].append("No active accounting configuration found") + return summary + + cfgRecord = cfgRecords[0] + connectorType = cfgRecord.get("connectorType", "") + encryptedConfig = cfgRecord.get("encryptedConfig", "") + + try: + import json + plainJson = decryptValue(encryptedConfig) + connConfig = json.loads(plainJson) if plainJson else {} + except Exception as e: + summary["errors"].append(f"Failed to decrypt config: {e}") + return summary + + connector = self._registry.getConnector(connectorType) + if not connector: + summary["errors"].append(f"Unknown connector type: {connectorType}") + return summary + + scope = {"featureInstanceId": featureInstanceId, "mandateId": mandateId} + logger.info(f"AccountingDataSync starting for {featureInstanceId}, connector={connectorType}, dateFrom={dateFrom}, dateTo={dateTo}") + fetchedAccountNumbers: list = [] + + # 1) Chart of accounts + try: + charts = await connector.getChartOfAccounts(connConfig) + fetchedAccountNumbers = [acc.accountNumber for acc in charts if acc.accountNumber] + self._clearTable(TrusteeDataAccount, featureInstanceId) + for acc in charts: + self._if.db.recordCreate(TrusteeDataAccount, { + "accountNumber": acc.accountNumber, + "label": acc.label, + "accountType": acc.accountType or "", + "currency": "CHF", + "isActive": True, + **scope, + }) + summary["accounts"] = len(charts) + except Exception as e: + logger.error(f"Import accounts failed: {e}", exc_info=True) + summary["errors"].append(f"Accounts: {e}") + + # 2) Journal entries + lines (pass already-fetched chart to avoid redundant API call) + try: + rawEntries = await connector.getJournalEntries(connConfig, dateFrom=dateFrom, dateTo=dateTo, accountNumbers=fetchedAccountNumbers or None) + self._clearTable(TrusteeDataJournalEntry, featureInstanceId) + self._clearTable(TrusteeDataJournalLine, featureInstanceId) + lineCount = 0 + for raw in rawEntries: + import uuid + entryId = str(uuid.uuid4()) + self._if.db.recordCreate(TrusteeDataJournalEntry, { + "id": entryId, + "externalId": raw.get("externalId"), + "bookingDate": raw.get("bookingDate"), + "reference": 
raw.get("reference"), + "description": raw.get("description", ""), + "currency": raw.get("currency", "CHF"), + "totalAmount": float(raw.get("totalAmount", 0)), + **scope, + }) + for line in (raw.get("lines") or []): + self._if.db.recordCreate(TrusteeDataJournalLine, { + "journalEntryId": entryId, + "accountNumber": line.get("accountNumber", ""), + "debitAmount": float(line.get("debitAmount", 0)), + "creditAmount": float(line.get("creditAmount", 0)), + "currency": line.get("currency", "CHF"), + "taxCode": line.get("taxCode"), + "costCenter": line.get("costCenter"), + "description": line.get("description", ""), + **scope, + }) + lineCount += 1 + summary["journalEntries"] = len(rawEntries) + summary["journalLines"] = lineCount + except Exception as e: + logger.error(f"Import journal entries failed: {e}") + summary["errors"].append(f"Journal entries: {e}") + + # 3) Contacts (customers + vendors) + try: + self._clearTable(TrusteeDataContact, featureInstanceId) + contactCount = 0 + + customers = await connector.getCustomers(connConfig) + for c in customers: + self._if.db.recordCreate(TrusteeDataContact, self._mapContact(c, "customer", scope)) + contactCount += 1 + + vendors = await connector.getVendors(connConfig) + for v in vendors: + self._if.db.recordCreate(TrusteeDataContact, self._mapContact(v, "vendor", scope)) + contactCount += 1 + + summary["contacts"] = contactCount + except Exception as e: + logger.error(f"Import contacts failed: {e}", exc_info=True) + summary["errors"].append(f"Contacts: {e}") + + # 4) Compute account balances from journal lines + try: + self._clearTable(TrusteeDataAccountBalance, featureInstanceId) + balanceCount = self._computeBalances(featureInstanceId, mandateId) + summary["accountBalances"] = balanceCount + except Exception as e: + logger.error(f"Compute balances failed: {e}") + summary["errors"].append(f"Balances: {e}") + + # Update config with last import timestamp + try: + cfgId = cfgRecord.get("id") + if cfgId: + self._if.db.recordModify(TrusteeAccountingConfig, cfgId, { + "lastSyncAt": time.time(), + "lastSyncStatus": "success" if not summary["errors"] else "partial", + "lastSyncErrorMessage": "; ".join(summary["errors"])[:500] if summary["errors"] else None, + }) + except Exception: + pass + + summary["finishedAt"] = time.time() + summary["durationSeconds"] = round(summary["finishedAt"] - summary["startedAt"], 1) + logger.info( + f"AccountingDataSync completed for {featureInstanceId}: " + f"{summary['accounts']} accounts, {summary['journalEntries']} entries, " + f"{summary['journalLines']} lines, {summary['contacts']} contacts, " + f"{summary['accountBalances']} balances, {len(summary['errors'])} errors, " + f"{summary['durationSeconds']}s" + ) + return summary + + @staticmethod + def _safeStr(val: Any) -> str: + """Convert a value to a safe string for DB storage, collapsing nested dicts/lists.""" + if val is None: + return "" + if isinstance(val, (dict, list)): + return "" + return str(val) + + def _mapContact(self, raw: Dict[str, Any], contactType: str, scope: Dict[str, Any]) -> Dict[str, Any]: + """Extract contact fields from a raw API dict, handling varying field names across connectors.""" + s = self._safeStr + return { + "externalId": s(raw.get("id") or raw.get("Id") or raw.get("customer_nr") or raw.get("vendor_nr") or ""), + "contactType": contactType, + "contactNumber": s( + raw.get("customernumber") or raw.get("customer_nr") + or raw.get("vendornumber") or raw.get("vendor_nr") + or raw.get("nr") or raw.get("ContactNumber") + or raw.get("id") or "" 
+ ), + "name": s(raw.get("name") or raw.get("Name") or raw.get("name_1") or ""), + "address": s(raw.get("addr1") or raw.get("address") or raw.get("Address") or ""), + "zip": s(raw.get("zipcode") or raw.get("postcode") or raw.get("Zip") or raw.get("zip") or ""), + "city": s(raw.get("city") or raw.get("City") or ""), + "country": s(raw.get("country") or raw.get("country_id") or raw.get("Country") or ""), + "email": s(raw.get("email") or raw.get("mail") or raw.get("Email") or ""), + "phone": s(raw.get("phone") or raw.get("phone_fixed") or raw.get("Phone") or ""), + "vatNumber": s(raw.get("vat_identifier") or raw.get("vatNumber") or ""), + **scope, + } + + def _clearTable(self, model, featureInstanceId: str): + """Delete all records for this feature instance from a TrusteeData* table.""" + records = self._if.db.getRecordset(model, recordFilter={"featureInstanceId": featureInstanceId}) + for r in (records or []): + rid = r.get("id") if isinstance(r, dict) else getattr(r, "id", None) + if rid: + try: + self._if.db.recordDelete(model, rid) + except Exception: + pass + + def _computeBalances(self, featureInstanceId: str, mandateId: str) -> int: + """Aggregate journal lines into monthly + annual account balances.""" + from modules.features.trustee.datamodelFeatureTrustee import ( + TrusteeDataJournalEntry, + TrusteeDataJournalLine, + TrusteeDataAccountBalance, + ) + + entries = self._if.db.getRecordset( + TrusteeDataJournalEntry, + recordFilter={"featureInstanceId": featureInstanceId}, + ) or [] + entryDates = {} + for e in entries: + eid = e.get("id") if isinstance(e, dict) else getattr(e, "id", None) + bdate = e.get("bookingDate") if isinstance(e, dict) else getattr(e, "bookingDate", None) + if eid and bdate: + entryDates[eid] = bdate + + lines = self._if.db.getRecordset( + TrusteeDataJournalLine, + recordFilter={"featureInstanceId": featureInstanceId}, + ) or [] + + # key: (accountNumber, year, month) + buckets: Dict[tuple, Dict[str, float]] = defaultdict(lambda: {"debit": 0.0, "credit": 0.0}) + for ln in lines: + if isinstance(ln, dict): + jeid = ln.get("journalEntryId", "") + accNo = ln.get("accountNumber", "") + debit = float(ln.get("debitAmount", 0)) + credit = float(ln.get("creditAmount", 0)) + else: + jeid = getattr(ln, "journalEntryId", "") + accNo = getattr(ln, "accountNumber", "") + debit = float(getattr(ln, "debitAmount", 0)) + credit = float(getattr(ln, "creditAmount", 0)) + + bdate = entryDates.get(jeid, "") + if not accNo or not bdate: + continue + parts = bdate.split("-") + if len(parts) < 2: + continue + year = int(parts[0]) + month = int(parts[1]) + + buckets[(accNo, year, month)]["debit"] += debit + buckets[(accNo, year, month)]["credit"] += credit + buckets[(accNo, year, 0)]["debit"] += debit + buckets[(accNo, year, 0)]["credit"] += credit + + count = 0 + scope = {"featureInstanceId": featureInstanceId, "mandateId": mandateId} + for (accNo, year, month), totals in buckets.items(): + closing = totals["debit"] - totals["credit"] + self._if.db.recordCreate(TrusteeDataAccountBalance, { + "accountNumber": accNo, + "periodYear": year, + "periodMonth": month, + "openingBalance": 0.0, + "debitTotal": round(totals["debit"], 2), + "creditTotal": round(totals["credit"], 2), + "closingBalance": round(closing, 2), + "currency": "CHF", + **scope, + }) + count += 1 + return count diff --git a/modules/features/trustee/accounting/connectors/accountingConnectorAbacus.py b/modules/features/trustee/accounting/connectors/accountingConnectorAbacus.py index b8d6a03d..eec3fef0 100644 --- 
a/modules/features/trustee/accounting/connectors/accountingConnectorAbacus.py +++ b/modules/features/trustee/accounting/connectors/accountingConnectorAbacus.py @@ -255,6 +255,60 @@ class AccountingConnectorAbacus(BaseAccountingConnector): except Exception as e: return SyncResult(success=False, errorMessage=str(e)) + async def getJournalEntries(self, config: Dict[str, Any], dateFrom: Optional[str] = None, dateTo: Optional[str] = None, accountNumbers: Optional[List[str]] = None) -> List[Dict[str, Any]]: + """Read GeneralJournalEntries from Abacus (OData V4, paginated).""" + headers = await self._buildAuthHeaders(config) + if not headers: + return [] + + filterParts = [] + if dateFrom: + filterParts.append(f"JournalDate ge {dateFrom}") + if dateTo: + filterParts.append(f"JournalDate le {dateTo}") + queryParams = "" + if filterParts: + queryParams = "?$filter=" + " and ".join(filterParts) + + entries: List[Dict[str, Any]] = [] + url: Optional[str] = self._buildEntityUrl(config, f"GeneralJournalEntries{queryParams}") + try: + async with aiohttp.ClientSession() as session: + while url: + async with session.get(url, headers=headers, timeout=aiohttp.ClientTimeout(total=60)) as resp: + if resp.status != 200: + break + data = await resp.json() + + for item in data.get("value", []): + lines = [] + totalAmt = 0.0 + for line in (item.get("Lines") or []): + debit = float(line.get("DebitAmount", 0)) + credit = float(line.get("CreditAmount", 0)) + lines.append({ + "accountNumber": str(line.get("AccountId", "")), + "debitAmount": debit, + "creditAmount": credit, + "description": line.get("Text", ""), + "taxCode": line.get("TaxCode"), + "costCenter": line.get("CostCenterId"), + }) + totalAmt += max(debit, credit) + entries.append({ + "externalId": str(item.get("Id", "")), + "bookingDate": str(item.get("JournalDate", "")).split("T")[0], + "reference": item.get("Reference", ""), + "description": item.get("Text", ""), + "currency": "CHF", + "totalAmount": totalAmt, + "lines": lines, + }) + url = data.get("@odata.nextLink") + except Exception as e: + logger.error(f"Abacus getJournalEntries error: {e}") + return entries + async def getCustomers(self, config: Dict[str, Any]) -> List[Dict[str, Any]]: headers = await self._buildAuthHeaders(config) if not headers: diff --git a/modules/features/trustee/accounting/connectors/accountingConnectorBexio.py b/modules/features/trustee/accounting/connectors/accountingConnectorBexio.py index a5487a82..a1e588d6 100644 --- a/modules/features/trustee/accounting/connectors/accountingConnectorBexio.py +++ b/modules/features/trustee/accounting/connectors/accountingConnectorBexio.py @@ -193,6 +193,62 @@ class AccountingConnectorBexio(BaseAccountingConnector): except Exception as e: return SyncResult(success=False, errorMessage=str(e)) + async def getJournalEntries(self, config: Dict[str, Any], dateFrom: Optional[str] = None, dateTo: Optional[str] = None, accountNumbers: Optional[List[str]] = None) -> List[Dict[str, Any]]: + """Read manual entries from Bexio. 
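+        Each item is flattened into the connector-neutral entry dict declared on
+        BaseAccountingConnector.getJournalEntries; illustrative result (all values
+        hypothetical, debit and credit sides expanded into separate lines):
+
+            {"externalId": "42", "bookingDate": "2025-03-31", "reference": "MA-7",
+             "description": "Accrual", "currency": "CHF", "totalAmount": 250.0,
+             "lines": [{"accountNumber": "1100", "debitAmount": 250.0, "creditAmount": 0.0, ...},
+                       {"accountNumber": "3200", "debitAmount": 0.0, "creditAmount": 250.0, ...}]}
+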
API: GET 3.0/accounting/manual-entries""" + try: + accounts = await self._loadRawAccounts(config) + accMap = {acc.get("id"): str(acc.get("account_no", "")) for acc in accounts} + + async with aiohttp.ClientSession() as session: + url = self._buildUrl(config, "3.0/accounting/manual-entries") + params: Dict[str, str] = {} + if dateFrom: + params["date_from"] = dateFrom + if dateTo: + params["date_to"] = dateTo + async with session.get(url, headers=self._buildHeaders(config), params=params, timeout=aiohttp.ClientTimeout(total=60)) as resp: + if resp.status != 200: + logger.error(f"Bexio getJournalEntries failed: HTTP {resp.status}") + return [] + items = await resp.json() + + entries = [] + for item in (items if isinstance(items, list) else []): + lines = [] + totalAmt = 0.0 + for e in (item.get("entries") or []): + amt = float(e.get("amount", 0)) + debitAccId = e.get("debit_account_id") + creditAccId = e.get("credit_account_id") + lines.append({ + "accountNumber": accMap.get(debitAccId, str(debitAccId or "")), + "debitAmount": amt, + "creditAmount": 0.0, + "description": e.get("description", ""), + "taxCode": str(e.get("tax_id", "")) if e.get("tax_id") else None, + }) + if creditAccId and creditAccId != debitAccId: + lines.append({ + "accountNumber": accMap.get(creditAccId, str(creditAccId or "")), + "debitAmount": 0.0, + "creditAmount": amt, + "description": e.get("description", ""), + }) + totalAmt += amt + entries.append({ + "externalId": str(item.get("id", "")), + "bookingDate": item.get("date", ""), + "reference": item.get("reference_nr", ""), + "description": item.get("text", ""), + "currency": "CHF", + "totalAmount": totalAmt, + "lines": lines, + }) + return entries + except Exception as e: + logger.error(f"Bexio getJournalEntries error: {e}") + return [] + async def getCustomers(self, config: Dict[str, Any]) -> List[Dict[str, Any]]: try: async with aiohttp.ClientSession() as session: diff --git a/modules/features/trustee/accounting/connectors/accountingConnectorRma.py b/modules/features/trustee/accounting/connectors/accountingConnectorRma.py index fa93ff40..15aa7ca9 100644 --- a/modules/features/trustee/accounting/connectors/accountingConnectorRma.py +++ b/modules/features/trustee/accounting/connectors/accountingConnectorRma.py @@ -150,11 +150,11 @@ class AccountingConnectorRma(BaseAccountingConnector): charts = [] items = data if isinstance(data, list) else data.get("chart", data.get("row", [])) if not isinstance(items, list): - items = [] + items = [items] if isinstance(items, dict) else [] for item in items: if isinstance(item, dict): - accNo = str(item.get("accno", item.get("account_number", ""))) - label = str(item.get("description", item.get("label", ""))) + accNo = str(item.get("accno") or item.get("account_number") or item.get("number") or item.get("@accno") or "") + label = str(item.get("description") or item.get("label") or item.get("@description") or "") rmaLink = item.get("link") or "" chartType = item.get("charttype") or item.get("category") or "" if not chartType and rmaLink: @@ -338,6 +338,169 @@ class AccountingConnectorRma(BaseAccountingConnector): logger.debug("RMA isBookingSynced error: %s – trust local", e) return SyncResult(success=True) + async def getJournalEntries(self, config: Dict[str, Any], dateFrom: Optional[str] = None, dateTo: Optional[str] = None, accountNumbers: Optional[List[str]] = None) -> List[Dict[str, Any]]: + """Read GL entries from RMA. + + Strategy: first try GET /gl (bulk), then fall back to iterating + account transactions. 
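+        The fallback sees each booking once per touched account, so rows are regrouped
+        by reference into entries: e.g. (illustrative) account 1020 at +500 and account
+        3200 at -500 under reference "B-17" become one entry with a 500.00 debit line
+        and a 500.00 credit line.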
Uses pre-fetched accountNumbers if provided. + """ + try: + params: Dict[str, str] = {} + if dateFrom: + params["from_date"] = dateFrom + if dateTo: + params["to_date"] = dateTo + + # Try bulk GL endpoint first + bulkEntries = await self._fetchGlBulk(config, params) + if bulkEntries: + return bulkEntries + + # Fallback: iterate accounts and fetch transactions + if accountNumbers: + accNums = accountNumbers + else: + chart = await self.getChartOfAccounts(config) + accNums = [acc.accountNumber for acc in chart if acc.accountNumber] + if not accNums: + return [] + + entriesByRef: Dict[str, Dict[str, Any]] = {} + fetchedCount = 0 + emptyCount = 0 + errorCount = 0 + async with aiohttp.ClientSession() as session: + for accNo in accNums: + url = self._buildUrl(config, f"charts/{accNo}/transactions") + try: + async with session.get(url, headers=self._buildHeaders(config), params=params, timeout=aiohttp.ClientTimeout(total=10)) as resp: + if resp.status != 200: + emptyCount += 1 + continue + body = await resp.text() + if not body.strip(): + emptyCount += 1 + continue + try: + data = json.loads(body) + except Exception: + errorCount += 1 + continue + except (asyncio.TimeoutError, Exception): + errorCount += 1 + continue + fetchedCount += 1 + + if isinstance(data, dict): + transactions = data.get("transaction") or data.get("@transaction") + else: + transactions = data + if isinstance(transactions, dict): + transactions = [transactions] + if not isinstance(transactions, list): + continue + + for t in transactions: + if not isinstance(t, dict): + continue + ref = t.get("reference") or t.get("@reference") or t.get("batch_number") or str(t.get("id") or "") + transDate = str(t.get("transdate") or t.get("@transdate") or "").split("T")[0] + desc = t.get("description") or t.get("memo") or t.get("@description") or "" + + rawAmount = float(t.get("amount") or t.get("@amount") or 0) + debit = rawAmount if rawAmount > 0 else 0.0 + credit = abs(rawAmount) if rawAmount < 0 else 0.0 + + if ref not in entriesByRef: + entriesByRef[ref] = { + "externalId": str(t.get("id") or t.get("@id") or ref), + "bookingDate": transDate, + "reference": ref, + "description": desc, + "currency": "CHF", + "totalAmount": 0.0, + "lines": [], + } + entry = entriesByRef[ref] + entry["lines"].append({ + "accountNumber": accNo, + "debitAmount": debit, + "creditAmount": credit, + "description": desc, + }) + entry["totalAmount"] += max(debit, credit) + + return list(entriesByRef.values()) + except Exception as e: + logger.error(f"RMA getJournalEntries error: {e}", exc_info=True) + return [] + + async def _fetchGlBulk(self, config: Dict[str, Any], params: Dict[str, str]) -> List[Dict[str, Any]]: + """Try GET /gl to fetch journal entries in bulk (not all RMA versions support this).""" + try: + async with aiohttp.ClientSession() as session: + url = self._buildUrl(config, "gl") + async with session.get(url, headers=self._buildHeaders(config), params=params, timeout=aiohttp.ClientTimeout(total=60)) as resp: + if resp.status != 200: + return [] + body = await resp.text() + if not body.strip(): + return [] + try: + data = json.loads(body) + except Exception: + return [] + + items = data if isinstance(data, list) else (data.get("gl_batch") or data.get("gl") or data.get("items") or []) + if isinstance(items, dict): + items = [items] + if not isinstance(items, list): + return [] + + entries = [] + for batch in items: + if not isinstance(batch, dict): + continue + transDate = str(batch.get("transdate") or batch.get("date") or "").split("T")[0] + ref = 
batch.get("batch_number") or batch.get("reference") or str(batch.get("id", "")) + desc = batch.get("description") or batch.get("notes") or "" + + rawTxns = batch.get("gl_transactions", {}) + txnList = rawTxns.get("gl_transaction") if isinstance(rawTxns, dict) else rawTxns + if isinstance(txnList, dict): + txnList = [txnList] + if not isinstance(txnList, list): + txnList = [] + + lines = [] + totalAmt = 0.0 + for t in txnList: + if not isinstance(t, dict): + continue + debit = float(t.get("debit_amount") or 0) + credit = float(t.get("credit_amount") or 0) + lines.append({ + "accountNumber": str(t.get("accno", "")), + "debitAmount": debit, + "creditAmount": credit, + "description": t.get("memo", ""), + }) + totalAmt += max(debit, credit) + + entries.append({ + "externalId": str(batch.get("id", ref)), + "bookingDate": transDate, + "reference": ref, + "description": desc, + "currency": batch.get("currency", "CHF"), + "totalAmount": totalAmt, + "lines": lines, + }) + return entries + except Exception as e: + logger.debug(f"RMA _fetchGlBulk not available: {e}") + return [] + async def pushInvoice(self, config: Dict[str, Any], invoice: Dict[str, Any]) -> SyncResult: try: async with aiohttp.ClientSession() as session: @@ -357,8 +520,8 @@ class AccountingConnectorRma(BaseAccountingConnector): async with session.get(url, headers=self._buildHeaders(config), timeout=aiohttp.ClientTimeout(total=30)) as resp: if resp.status != 200: return [] - data = await resp.json() - return data if isinstance(data, list) else data.get("customer", []) + data = await self._parseJsonOrXmlList(resp, "customer") + return data except Exception as e: logger.error(f"RMA getCustomers error: {e}") return [] @@ -370,12 +533,39 @@ class AccountingConnectorRma(BaseAccountingConnector): async with session.get(url, headers=self._buildHeaders(config), timeout=aiohttp.ClientTimeout(total=30)) as resp: if resp.status != 200: return [] - data = await resp.json() - return data if isinstance(data, list) else data.get("vendor", []) + data = await self._parseJsonOrXmlList(resp, "vendor") + return data except Exception as e: logger.error(f"RMA getVendors error: {e}") return [] + async def _parseJsonOrXmlList(self, resp: aiohttp.ClientResponse, itemKey: str) -> List[Dict[str, Any]]: + """Parse RMA response that may be JSON or XML. 
Returns list of dicts.""" + body = await resp.text() + if not body or not body.strip(): + return [] + try: + data = json.loads(body) + if isinstance(data, list): + return data + if isinstance(data, dict): + items = data.get(itemKey) or data.get("items") or data.get("row") or [] + if isinstance(items, dict): + return [items] + return items if isinstance(items, list) else [] + return [] + except (json.JSONDecodeError, ValueError): + pass + result: List[Dict[str, Any]] = [] + # XML fallback: the <id>/<name> tag names below are a reconstruction; the original patterns had lost their tags + ids = re.findall(r"<id>([^<]+)</id>", body) + names = re.findall(r"<name>([^<]+)</name>", body) + for i, rid in enumerate(ids): + entry: Dict[str, Any] = {"id": rid.strip()} + if i < len(names): + entry["name"] = names[i].strip() + result.append(entry) + return result + async def _findBelegByFilename(self, config: Dict[str, Any], session: aiohttp.ClientSession, fileName: str) -> Optional[str]: """Try GET /belege (undocumented) to find an existing beleg by filename.""" try: diff --git a/modules/features/trustee/datamodelFeatureTrustee.py b/modules/features/trustee/datamodelFeatureTrustee.py index bbad2102..538414a0 100644 --- a/modules/features/trustee/datamodelFeatureTrustee.py +++ b/modules/features/trustee/datamodelFeatureTrustee.py @@ -736,6 +736,177 @@ registerModelLabels( ) + +# ── TrusteeData* tables (synced from external accounting apps for analysis) ── + + +class TrusteeDataAccount(BaseModel): + """Chart of accounts synced from external accounting system.""" + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + accountNumber: str = Field(description="Account number (e.g. '1020')") + label: str = Field(default="", description="Account name") + accountType: Optional[str] = Field(default=None, description="asset / liability / equity / revenue / expense") + accountGroup: Optional[str] = Field(default=None, description="Account group/category") + currency: str = Field(default="CHF", description="Account currency") + isActive: bool = Field(default=True) + mandateId: Optional[str] = Field(default=None) + featureInstanceId: Optional[str] = Field(default=None) + + +registerModelLabels( + "TrusteeDataAccount", + {"en": "Account (Synced)", "de": "Konto (Sync)", "fr": "Compte (Sync)"}, + { + "id": {"en": "ID", "de": "ID", "fr": "ID"}, + "accountNumber": {"en": "Account Number", "de": "Kontonummer", "fr": "Numéro de compte"}, + "label": {"en": "Name", "de": "Bezeichnung", "fr": "Libellé"}, + "accountType": {"en": "Type", "de": "Typ", "fr": "Type"}, + "accountGroup": {"en": "Group", "de": "Gruppe", "fr": "Groupe"}, + "currency": {"en": "Currency", "de": "Währung", "fr": "Devise"}, + "isActive": {"en": "Active", "de": "Aktiv", "fr": "Actif"}, + "mandateId": {"en": "Mandate", "de": "Mandat", "fr": "Mandat"}, + "featureInstanceId": {"en": "Feature Instance", "de": "Feature-Instanz", "fr": "Instance"}, + }, +) + + +class TrusteeDataJournalEntry(BaseModel): + """Journal entry header synced from external accounting system.""" + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + externalId: Optional[str] = Field(default=None, description="ID in the source system") + bookingDate: Optional[str] = Field(default=None, description="Booking date (YYYY-MM-DD)") + reference: Optional[str] = Field(default=None, description="Booking reference / voucher number") + description: str = Field(default="", description="Booking text") + currency: str = Field(default="CHF") + totalAmount: float = Field(default=0.0, description="Total amount of entry") + mandateId: Optional[str] = Field(default=None) + featureInstanceId: Optional[str] = Field(default=None) +
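+# Relationship sketch: one TrusteeDataJournalEntry per booking; its TrusteeDataJournalLine
+# rows (defined below) point back via journalEntryId, and AccountingDataSync._computeBalances
+# folds those lines into TrusteeDataAccountBalance buckets keyed by (accountNumber, year, month),
+# with month 0 holding the annual total.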
+ +registerModelLabels( + "TrusteeDataJournalEntry", + {"en": "Journal Entry (Synced)", "de": "Buchung (Sync)", "fr": "Écriture (Sync)"}, + { + "id": {"en": "ID", "de": "ID", "fr": "ID"}, + "externalId": {"en": "External ID", "de": "Externe ID", "fr": "ID externe"}, + "bookingDate": {"en": "Date", "de": "Datum", "fr": "Date"}, + "reference": {"en": "Reference", "de": "Referenz", "fr": "Référence"}, + "description": {"en": "Description", "de": "Beschreibung", "fr": "Description"}, + "currency": {"en": "Currency", "de": "Währung", "fr": "Devise"}, + "totalAmount": {"en": "Amount", "de": "Betrag", "fr": "Montant"}, + "mandateId": {"en": "Mandate", "de": "Mandat", "fr": "Mandat"}, + "featureInstanceId": {"en": "Feature Instance", "de": "Feature-Instanz", "fr": "Instance"}, + }, +) + + +class TrusteeDataJournalLine(BaseModel): + """Journal entry line (debit/credit) synced from external accounting system.""" + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + journalEntryId: str = Field(description="FK → TrusteeDataJournalEntry.id") + accountNumber: str = Field(description="Account number") + debitAmount: float = Field(default=0.0) + creditAmount: float = Field(default=0.0) + currency: str = Field(default="CHF") + taxCode: Optional[str] = Field(default=None) + costCenter: Optional[str] = Field(default=None) + description: str = Field(default="") + mandateId: Optional[str] = Field(default=None) + featureInstanceId: Optional[str] = Field(default=None) + + +registerModelLabels( + "TrusteeDataJournalLine", + {"en": "Journal Line (Synced)", "de": "Buchungszeile (Sync)", "fr": "Ligne écriture (Sync)"}, + { + "id": {"en": "ID", "de": "ID", "fr": "ID"}, + "journalEntryId": {"en": "Journal Entry", "de": "Buchung", "fr": "Écriture"}, + "accountNumber": {"en": "Account", "de": "Konto", "fr": "Compte"}, + "debitAmount": {"en": "Debit", "de": "Soll", "fr": "Débit"}, + "creditAmount": {"en": "Credit", "de": "Haben", "fr": "Crédit"}, + "currency": {"en": "Currency", "de": "Währung", "fr": "Devise"}, + "taxCode": {"en": "Tax Code", "de": "Steuercode", "fr": "Code TVA"}, + "costCenter": {"en": "Cost Center", "de": "Kostenstelle", "fr": "Centre de coûts"}, + "description": {"en": "Description", "de": "Beschreibung", "fr": "Description"}, + "mandateId": {"en": "Mandate", "de": "Mandat", "fr": "Mandat"}, + "featureInstanceId": {"en": "Feature Instance", "de": "Feature-Instanz", "fr": "Instance"}, + }, +) + + +class TrusteeDataContact(BaseModel): + """Customer or vendor synced from external accounting system.""" + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + externalId: Optional[str] = Field(default=None, description="ID in the source system") + contactType: str = Field(default="customer", description="customer / vendor / both") + contactNumber: Optional[str] = Field(default=None, description="Customer/vendor number") + name: str = Field(default="", description="Name / company") + address: Optional[str] = Field(default=None) + zip: Optional[str] = Field(default=None) + city: Optional[str] = Field(default=None) + country: Optional[str] = Field(default=None) + email: Optional[str] = Field(default=None) + phone: Optional[str] = Field(default=None) + vatNumber: Optional[str] = Field(default=None) + mandateId: Optional[str] = Field(default=None) + featureInstanceId: Optional[str] = Field(default=None) + + +registerModelLabels( + "TrusteeDataContact", + {"en": "Contact (Synced)", "de": "Kontakt (Sync)", "fr": "Contact (Sync)"}, + { + "id": {"en": "ID", "de": "ID", "fr": "ID"}, + "externalId": 
{"en": "External ID", "de": "Externe ID", "fr": "ID externe"}, + "contactType": {"en": "Type", "de": "Typ", "fr": "Type"}, + "contactNumber": {"en": "Number", "de": "Nummer", "fr": "Numéro"}, + "name": {"en": "Name", "de": "Name", "fr": "Nom"}, + "address": {"en": "Address", "de": "Adresse", "fr": "Adresse"}, + "zip": {"en": "ZIP", "de": "PLZ", "fr": "NPA"}, + "city": {"en": "City", "de": "Ort", "fr": "Ville"}, + "country": {"en": "Country", "de": "Land", "fr": "Pays"}, + "email": {"en": "Email", "de": "E-Mail", "fr": "E-mail"}, + "phone": {"en": "Phone", "de": "Telefon", "fr": "Téléphone"}, + "vatNumber": {"en": "VAT Number", "de": "MWST-Nr.", "fr": "N° TVA"}, + "mandateId": {"en": "Mandate", "de": "Mandat", "fr": "Mandat"}, + "featureInstanceId": {"en": "Feature Instance", "de": "Feature-Instanz", "fr": "Instance"}, + }, +) + + +class TrusteeDataAccountBalance(BaseModel): + """Account balance per period, derived from journal lines or directly from accounting system.""" + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + accountNumber: str = Field(description="Account number") + periodYear: int = Field(description="Fiscal year") + periodMonth: int = Field(default=0, description="Month (1-12); 0 = annual total") + openingBalance: float = Field(default=0.0) + debitTotal: float = Field(default=0.0) + creditTotal: float = Field(default=0.0) + closingBalance: float = Field(default=0.0) + currency: str = Field(default="CHF") + mandateId: Optional[str] = Field(default=None) + featureInstanceId: Optional[str] = Field(default=None) + + +registerModelLabels( + "TrusteeDataAccountBalance", + {"en": "Account Balance (Synced)", "de": "Kontosaldo (Sync)", "fr": "Solde compte (Sync)"}, + { + "id": {"en": "ID", "de": "ID", "fr": "ID"}, + "accountNumber": {"en": "Account", "de": "Konto", "fr": "Compte"}, + "periodYear": {"en": "Year", "de": "Jahr", "fr": "Année"}, + "periodMonth": {"en": "Month", "de": "Monat", "fr": "Mois"}, + "openingBalance": {"en": "Opening Balance", "de": "Eröffnungssaldo", "fr": "Solde d'ouverture"}, + "debitTotal": {"en": "Debit Total", "de": "Soll-Umsatz", "fr": "Total débit"}, + "creditTotal": {"en": "Credit Total", "de": "Haben-Umsatz", "fr": "Total crédit"}, + "closingBalance": {"en": "Closing Balance", "de": "Schlusssaldo", "fr": "Solde de clôture"}, + "currency": {"en": "Currency", "de": "Währung", "fr": "Devise"}, + "mandateId": {"en": "Mandate", "de": "Mandat", "fr": "Mandat"}, + "featureInstanceId": {"en": "Feature Instance", "de": "Feature-Instanz", "fr": "Instance"}, + }, +) + + class TrusteeAccountingConfig(BaseModel): """Per-instance accounting system configuration with encrypted credentials. diff --git a/modules/features/trustee/mainTrustee.py b/modules/features/trustee/mainTrustee.py index 1fa1948f..606da308 100644 --- a/modules/features/trustee/mainTrustee.py +++ b/modules/features/trustee/mainTrustee.py @@ -78,6 +78,31 @@ DATA_OBJECTS = [ "label": {"en": "Accounting Sync", "de": "Buchhaltungs-Synchronisation", "fr": "Sync. 
comptable"}, "meta": {"table": "TrusteeAccountingSync", "fields": ["id", "positionId", "syncStatus", "externalId"]} }, + { + "objectKey": "data.feature.trustee.TrusteeDataAccount", + "label": {"en": "Accounts (Synced)", "de": "Kontenplan (Sync)", "fr": "Plan comptable (Sync)"}, + "meta": {"table": "TrusteeDataAccount", "fields": ["id", "accountNumber", "label", "accountType", "accountGroup", "currency", "isActive"]} + }, + { + "objectKey": "data.feature.trustee.TrusteeDataJournalEntry", + "label": {"en": "Journal Entries (Synced)", "de": "Buchungen (Sync)", "fr": "Écritures (Sync)"}, + "meta": {"table": "TrusteeDataJournalEntry", "fields": ["id", "externalId", "bookingDate", "reference", "description", "currency", "totalAmount"]} + }, + { + "objectKey": "data.feature.trustee.TrusteeDataJournalLine", + "label": {"en": "Journal Lines (Synced)", "de": "Buchungszeilen (Sync)", "fr": "Lignes écriture (Sync)"}, + "meta": {"table": "TrusteeDataJournalLine", "fields": ["id", "journalEntryId", "accountNumber", "debitAmount", "creditAmount", "currency", "taxCode", "costCenter", "description"]} + }, + { + "objectKey": "data.feature.trustee.TrusteeDataContact", + "label": {"en": "Contacts (Synced)", "de": "Kontakte (Sync)", "fr": "Contacts (Sync)"}, + "meta": {"table": "TrusteeDataContact", "fields": ["id", "externalId", "contactType", "contactNumber", "name", "address", "zip", "city", "country", "email", "phone", "vatNumber"]} + }, + { + "objectKey": "data.feature.trustee.TrusteeDataAccountBalance", + "label": {"en": "Account Balances (Synced)", "de": "Kontosalden (Sync)", "fr": "Soldes comptes (Sync)"}, + "meta": {"table": "TrusteeDataAccountBalance", "fields": ["id", "accountNumber", "periodYear", "periodMonth", "openingBalance", "debitTotal", "creditTotal", "closingBalance", "currency"]} + }, { "objectKey": "data.feature.trustee.*", "label": {"en": "All Trustee Data", "de": "Alle Treuhand-Daten", "fr": "Toutes les données fiduciaires"}, diff --git a/modules/features/trustee/routeFeatureTrustee.py b/modules/features/trustee/routeFeatureTrustee.py index 9ad41b9d..feb873ae 100644 --- a/modules/features/trustee/routeFeatureTrustee.py +++ b/modules/features/trustee/routeFeatureTrustee.py @@ -1481,6 +1481,63 @@ def get_position_sync_status( return {"items": items} +# ===== Accounting Data Import ===== + +@router.post("/{instanceId}/accounting/import-data") +@limiter.limit("3/minute") +async def import_accounting_data( + request: Request, + instanceId: str = Path(..., description="Feature Instance ID"), + data: Dict[str, Any] = Body(default={}), + context: RequestContext = Depends(getRequestContext) +) -> Dict[str, Any]: + """Import accounting data (chart, journal entries, contacts) from the external system into TrusteeData* tables.""" + mandateId = _validateInstanceAccess(instanceId, context) + interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) + from .accounting.accountingDataSync import AccountingDataSync + sync = AccountingDataSync(interface) + dateFrom = data.get("dateFrom") + dateTo = data.get("dateTo") + result = await sync.importData( + featureInstanceId=instanceId, + mandateId=mandateId, + dateFrom=dateFrom, + dateTo=dateTo, + ) + return result + + +@router.get("/{instanceId}/accounting/import-status") +@limiter.limit("30/minute") +def get_import_status( + request: Request, + instanceId: str = Path(..., description="Feature Instance ID"), + context: RequestContext = Depends(getRequestContext) +) -> Dict[str, Any]: + """Get counts of imported TrusteeData* 
records for this instance.""" + mandateId = _validateInstanceAccess(instanceId, context) + interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) + from .datamodelFeatureTrustee import ( + TrusteeDataAccount, TrusteeDataJournalEntry, TrusteeDataJournalLine, + TrusteeDataContact, TrusteeDataAccountBalance, TrusteeAccountingConfig, + ) + filt = {"featureInstanceId": instanceId} + counts = { + "accounts": len(interface.db.getRecordset(TrusteeDataAccount, recordFilter=filt) or []), + "journalEntries": len(interface.db.getRecordset(TrusteeDataJournalEntry, recordFilter=filt) or []), + "journalLines": len(interface.db.getRecordset(TrusteeDataJournalLine, recordFilter=filt) or []), + "contacts": len(interface.db.getRecordset(TrusteeDataContact, recordFilter=filt) or []), + "accountBalances": len(interface.db.getRecordset(TrusteeDataAccountBalance, recordFilter=filt) or []), + } + cfgRecords = interface.db.getRecordset(TrusteeAccountingConfig, recordFilter={"featureInstanceId": instanceId, "isActive": True}) + if cfgRecords: + cfg = cfgRecords[0] + counts["lastSyncAt"] = cfg.get("lastSyncAt") + counts["lastSyncStatus"] = cfg.get("lastSyncStatus") + counts["lastSyncErrorMessage"] = cfg.get("lastSyncErrorMessage") + return counts + + # ===== Position-Document Query ===== @router.get("/{instanceId}/positions/document/{documentId}", response_model=List[TrusteePosition]) diff --git a/modules/features/workspace/routeFeatureWorkspace.py b/modules/features/workspace/routeFeatureWorkspace.py index 1ea8a93a..7397beb7 100644 --- a/modules/features/workspace/routeFeatureWorkspace.py +++ b/modules/features/workspace/routeFeatureWorkspace.py @@ -72,6 +72,7 @@ class WorkspaceInputRequest(BaseModel): fileIds: List[str] = Field(default_factory=list, description="Referenced file IDs") uploadedFiles: List[str] = Field(default_factory=list, description="Newly uploaded file IDs") dataSourceIds: List[str] = Field(default_factory=list, description="Active DataSource IDs") + featureDataSourceIds: List[str] = Field(default_factory=list, description="Attached FeatureDataSource IDs") voiceMode: bool = Field(default=False, description="Enable voice response") workflowId: Optional[str] = Field(default=None, description="Continue existing workflow") userLanguage: str = Field(default="en", description="User language code") @@ -184,6 +185,63 @@ def _buildDataSourceContext(chatService, dataSourceIds: List[str]) -> str: return "\n".join(parts) if found else "" +def _buildFeatureDataSourceContext(featureDataSourceIds: List[str]) -> str: + """Build a description of attached feature data sources for the agent prompt.""" + from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource + from modules.security.rbacCatalog import getCatalogService + from modules.interfaces.interfaceDbApp import getRootInterface + + parts = [ + "The user has attached data from the following feature instances.", + "Use queryFeatureInstance(featureInstanceId, question) to query this data.", + "", + ] + found = False + catalog = getCatalogService() + rootIf = getRootInterface() + + instanceCache: Dict[str, Any] = {} + for fdsId in featureDataSourceIds: + try: + records = rootIf.db.getRecordset(FeatureDataSource, recordFilter={"id": fdsId}) + if not records: + logger.warning(f"FeatureDataSource {fdsId} not found") + continue + fds = records[0] + found = True + + fiId = fds.get("featureInstanceId", "") + featureCode = fds.get("featureCode", "") + tableName = fds.get("tableName", "") + label = fds.get("label", 
tableName) + + if fiId not in instanceCache: + inst = rootIf.getFeatureInstance(fiId) + instanceCache[fiId] = inst + + inst = instanceCache.get(fiId) + instanceLabel = getattr(inst, "label", fiId) if inst else fiId + + dataObj = catalog.getDataObjects(featureCode) + tableFields = [] + for obj in dataObj: + if obj.get("meta", {}).get("table") == tableName: + tableFields = obj.get("meta", {}).get("fields", []) + break + + parts.append( + f"- featureInstanceId: {fiId}\n" + f" feature: {featureCode}\n" + f" instance: \"{instanceLabel}\"\n" + f" table: {tableName} ({label})\n" + f" fields: {', '.join(tableFields) if tableFields else 'all'}" + ) + except Exception as e: + logger.warning(f"Error loading FeatureDataSource {fdsId}: {e}") + + return "\n".join(parts) if found else "" + + def _loadConversationHistory(chatInterface, workflowId: str, currentPrompt: str) -> List[Dict[str, str]]: """Load prior messages from DB for follow-up context, excluding the current prompt.""" try: @@ -248,7 +306,7 @@ async def _deriveWorkflowName(prompt: str, aiService) -> str: # --------------------------------------------------------------------------- @router.post("/{instanceId}/start/stream") -@limiter.limit("60/minute") +@limiter.limit("300/minute") async def streamWorkspaceStart( request: Request, instanceId: str = Path(..., description="Feature instance ID"), @@ -264,7 +322,13 @@ async def streamWorkspaceStart( if userInput.workflowId: workflow = chatInterface.getWorkflow(userInput.workflowId) if not workflow: - raise HTTPException(status_code=404, detail=f"Workflow {userInput.workflowId} not found") + logger.warning(f"Workflow {userInput.workflowId} not found, creating new one") + workflow = chatInterface.createWorkflow({ + "featureInstanceId": instanceId, + "status": "active", + "name": "", + "workflowMode": "Dynamic", + }) else: workflow = chatInterface.createWorkflow({ "featureInstanceId": instanceId, @@ -290,6 +354,7 @@ async def streamWorkspaceStart( prompt=userInput.prompt, fileIds=userInput.fileIds, dataSourceIds=userInput.dataSourceIds, + featureDataSourceIds=userInput.featureDataSourceIds, voiceMode=userInput.voiceMode, instanceId=instanceId, user=context.user, @@ -344,13 +409,14 @@ async def _runWorkspaceAgent( prompt: str, fileIds: List[str], dataSourceIds: List[str], - voiceMode: bool, - instanceId: str, - user, - mandateId: str, - aiObjects, - chatInterface, - eventManager, + featureDataSourceIds: List[str] = None, + voiceMode: bool = False, + instanceId: str = "", + user=None, + mandateId: str = "", + aiObjects=None, + chatInterface=None, + eventManager=None, userLanguage: str = "en", instanceConfig: Dict[str, Any] = None, allowedProviders: List[str] = None, @@ -396,6 +462,11 @@ async def _runWorkspaceAgent( if dsInfo: enrichedPrompt = f"{prompt}\n\n[Active Data Sources]\n{dsInfo}" + if featureDataSourceIds: + fdsInfo = _buildFeatureDataSourceContext(featureDataSourceIds) + if fdsInfo: + enrichedPrompt = f"{enrichedPrompt}\n\n[Attached Feature Data Sources]\n{fdsInfo}" + conversationHistory = _loadConversationHistory(chatInterface, workflowId, prompt) accumulatedText = "" @@ -525,7 +596,7 @@ async def _runWorkspaceAgent( # --------------------------------------------------------------------------- @router.post("/{instanceId}/{workflowId}/stop") -@limiter.limit("30/minute") +@limiter.limit("120/minute") async def stopWorkspace( request: Request, instanceId: str = Path(...), @@ -549,7 +620,7 @@ async def stopWorkspace( # --------------------------------------------------------------------------- 
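# Illustrative shape of the [Attached Feature Data Sources] block that
# _buildFeatureDataSourceContext (above) appends to the agent prompt; values hypothetical:
#
#   The user has attached data from the following feature instances.
#   Use queryFeatureInstance(featureInstanceId, question) to query this data.
#
#   - featureInstanceId: fi-123
#     feature: trustee
#     instance: "Treuhand Muster AG"
#     table: TrusteeDataAccount (Accounts (Synced))
#     fields: id, accountNumber, label, accountType, accountGroup, currency, isActive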
@router.get("/{instanceId}/workflows") -@limiter.limit("60/minute") +@limiter.limit("300/minute") async def listWorkspaceWorkflows( request: Request, instanceId: str = Path(...), @@ -585,7 +656,7 @@ class UpdateWorkflowRequest(BaseModel): @router.patch("/{instanceId}/workflows/{workflowId}") -@limiter.limit("60/minute") +@limiter.limit("300/minute") async def patchWorkspaceWorkflow( request: Request, instanceId: str = Path(..., description="Feature instance ID"), @@ -620,7 +691,7 @@ async def patchWorkspaceWorkflow( @router.delete("/{instanceId}/workflows/{workflowId}") -@limiter.limit("30/minute") +@limiter.limit("120/minute") async def deleteWorkspaceWorkflow( request: Request, instanceId: str = Path(...), @@ -638,7 +709,7 @@ async def deleteWorkspaceWorkflow( @router.post("/{instanceId}/workflows") -@limiter.limit("30/minute") +@limiter.limit("120/minute") async def createWorkspaceWorkflow( request: Request, instanceId: str = Path(...), @@ -661,7 +732,7 @@ async def createWorkspaceWorkflow( @router.get("/{instanceId}/workflows/{workflowId}/messages") -@limiter.limit("60/minute") +@limiter.limit("300/minute") async def getWorkspaceMessages( request: Request, instanceId: str = Path(...), @@ -691,7 +762,7 @@ async def getWorkspaceMessages( # --------------------------------------------------------------------------- @router.get("/{instanceId}/files") -@limiter.limit("60/minute") +@limiter.limit("300/minute") async def listWorkspaceFiles( request: Request, instanceId: str = Path(...), @@ -723,7 +794,7 @@ async def listWorkspaceFiles( @router.get("/{instanceId}/files/{fileId}/content") -@limiter.limit("60/minute") +@limiter.limit("300/minute") async def getFileContent( request: Request, instanceId: str = Path(...), @@ -751,7 +822,7 @@ async def getFileContent( @router.get("/{instanceId}/folders") -@limiter.limit("60/minute") +@limiter.limit("300/minute") async def listWorkspaceFolders( request: Request, instanceId: str = Path(...), @@ -775,7 +846,7 @@ async def listWorkspaceFolders( @router.get("/{instanceId}/datasources") -@limiter.limit("60/minute") +@limiter.limit("300/minute") async def listWorkspaceDataSources( request: Request, instanceId: str = Path(...), @@ -798,7 +869,7 @@ async def listWorkspaceDataSources( @router.get("/{instanceId}/connections") -@limiter.limit("60/minute") +@limiter.limit("300/minute") async def listWorkspaceConnections( request: Request, instanceId: str = Path(...), @@ -843,7 +914,7 @@ class CreateDataSourceRequest(BaseModel): @router.post("/{instanceId}/datasources") -@limiter.limit("60/minute") +@limiter.limit("300/minute") async def createWorkspaceDataSource( request: Request, instanceId: str = Path(...), @@ -871,7 +942,7 @@ async def createWorkspaceDataSource( @router.delete("/{instanceId}/datasources/{dataSourceId}") -@limiter.limit("60/minute") +@limiter.limit("300/minute") async def deleteWorkspaceDataSource( request: Request, instanceId: str = Path(...), @@ -892,8 +963,204 @@ async def deleteWorkspaceDataSource( return JSONResponse({"success": True}) +# ---- Feature Connections & Feature Data Sources ---- + +@router.get("/{instanceId}/feature-connections") +@limiter.limit("120/minute") +async def listFeatureConnections( + request: Request, + instanceId: str = Path(...), + context: RequestContext = Depends(getRequestContext), +): + """List feature instances the user has access to across ALL mandates.""" + _validateInstanceAccess(instanceId, context) + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.security.rbacCatalog 
import getCatalogService + from modules.datamodels.datamodelUam import Mandate + + rootIf = getRootInterface() + userId = str(context.user.id) + + catalog = getCatalogService() + featureCodesWithData = catalog.getFeaturesWithDataObjects() + + userMandates = rootIf.getUserMandates(userId) + if not userMandates: + return JSONResponse({"featureConnections": []}) + + mandateLabels: dict = {} + for um in userMandates: + try: + rows = rootIf.db.getRecordset(Mandate, recordFilter={"id": um.mandateId}) + if rows: + m = rows[0] + mandateLabels[um.mandateId] = m.get("label") or m.get("name") or um.mandateId + except Exception: + mandateLabels[um.mandateId] = um.mandateId + + items = [] + seenIds: set = set() + for um in userMandates: + allInstances = rootIf.getFeatureInstancesByMandate(um.mandateId) + for inst in allInstances: + if inst.id in seenIds: + continue + seenIds.add(inst.id) + if not inst.enabled: + continue + if inst.featureCode not in featureCodesWithData: + continue + featureAccess = rootIf.getFeatureAccess(userId, inst.id) + if not featureAccess or not featureAccess.enabled: + continue + + featureDef = catalog.getFeatureDefinition(inst.featureCode) or {} + dataObjects = catalog.getDataObjects(inst.featureCode) + mLabel = mandateLabels.get(inst.mandateId, "") + label = inst.label or inst.featureCode + if mLabel: + label = f"{label} ({mLabel})" + items.append({ + "featureInstanceId": inst.id, + "featureCode": inst.featureCode, + "mandateId": inst.mandateId, + "label": label, + "icon": featureDef.get("icon", "mdi-database"), + "tableCount": len(dataObjects), + }) + + return JSONResponse({"featureConnections": items}) + + +@router.get("/{instanceId}/feature-connections/{fiId}/tables") +@limiter.limit("120/minute") +async def listFeatureConnectionTables( + request: Request, + instanceId: str = Path(...), + fiId: str = Path(..., description="Feature instance ID"), + context: RequestContext = Depends(getRequestContext), +): + """List data tables (DATA_OBJECTS) for a feature instance, filtered by RBAC.""" + _validateInstanceAccess(instanceId, context) + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.security.rbacCatalog import getCatalogService + + rootIf = getRootInterface() + inst = rootIf.getFeatureInstance(fiId) + if not inst: + raise HTTPException(status_code=404, detail="Feature instance not found") + + mandateId = str(inst.mandateId) if inst.mandateId else None + catalog = getCatalogService() + + try: + from modules.security.rbac import RbacClass + from modules.security.rootAccess import getRootDbAppConnector + dbApp = getRootDbAppConnector() + rbac = RbacClass(dbApp, dbApp=dbApp) + accessible = catalog.getAccessibleDataObjects( + featureCode=inst.featureCode, + rbacInstance=rbac, + user=context.user, + mandateId=mandateId or "", + featureInstanceId=fiId, + ) + except Exception: + accessible = catalog.getDataObjects(inst.featureCode) + + tables = [] + for obj in accessible: + meta = obj.get("meta", {}) + tables.append({ + "objectKey": obj.get("objectKey", ""), + "tableName": meta.get("table", ""), + "label": obj.get("label", {}), + "fields": meta.get("fields", []), + }) + + return JSONResponse({"tables": tables}) + + +class CreateFeatureDataSourceRequest(BaseModel): + """Request body for adding a feature table as data source.""" + featureInstanceId: str = Field(description="Feature instance ID") + featureCode: str = Field(description="Feature code") + tableName: str = Field(description="Table name from DATA_OBJECTS") + objectKey: str = 
Field(description="RBAC object key") + label: str = Field(description="User-visible label") + + +@router.post("/{instanceId}/feature-datasources") +@limiter.limit("300/minute") +async def createFeatureDataSource( + request: Request, + instanceId: str = Path(...), + body: CreateFeatureDataSourceRequest = Body(...), + context: RequestContext = Depends(getRequestContext), +): + """Create a FeatureDataSource for this workspace instance.""" + _validateInstanceAccess(instanceId, context) + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource + + rootIf = getRootInterface() + inst = rootIf.getFeatureInstance(body.featureInstanceId) + mandateId = str(inst.mandateId) if inst else (str(context.mandateId) if context.mandateId else "") + + fds = FeatureDataSource( + featureInstanceId=body.featureInstanceId, + featureCode=body.featureCode, + tableName=body.tableName, + objectKey=body.objectKey, + label=body.label, + mandateId=mandateId, + userId=str(context.user.id), + workspaceInstanceId=instanceId, + ) + created = rootIf.db.recordCreate(FeatureDataSource, fds.model_dump()) + return JSONResponse(created if isinstance(created, dict) else fds.model_dump()) + + +@router.get("/{instanceId}/feature-datasources") +@limiter.limit("300/minute") +async def listFeatureDataSources( + request: Request, + instanceId: str = Path(...), + context: RequestContext = Depends(getRequestContext), +): + """List active FeatureDataSources for this workspace instance.""" + _validateInstanceAccess(instanceId, context) + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource + + rootIf = getRootInterface() + records = rootIf.db.getRecordset( + FeatureDataSource, + recordFilter={"workspaceInstanceId": instanceId}, + ) + return JSONResponse({"featureDataSources": records or []}) + + +@router.delete("/{instanceId}/feature-datasources/{featureDataSourceId}") +@limiter.limit("300/minute") +async def deleteFeatureDataSource( + request: Request, + instanceId: str = Path(...), + featureDataSourceId: str = Path(...), + context: RequestContext = Depends(getRequestContext), +): + """Delete a FeatureDataSource.""" + _validateInstanceAccess(instanceId, context) + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource + + rootIf = getRootInterface() + rootIf.db.recordDelete(FeatureDataSource, featureDataSourceId) + return JSONResponse({"success": True}) + + @router.get("/{instanceId}/connections/{connectionId}/services") -@limiter.limit("30/minute") +@limiter.limit("120/minute") async def listConnectionServices( request: Request, instanceId: str = Path(...), @@ -950,7 +1217,7 @@ async def listConnectionServices( @router.get("/{instanceId}/connections/{connectionId}/browse") -@limiter.limit("60/minute") +@limiter.limit("300/minute") async def browseConnectionService( request: Request, instanceId: str = Path(...), @@ -997,7 +1264,7 @@ async def browseConnectionService( # --------------------------------------------------------------------------- @router.post("/{instanceId}/voice/transcribe") -@limiter.limit("30/minute") +@limiter.limit("120/minute") async def transcribeVoice( request: Request, instanceId: str = Path(...), @@ -1026,7 +1293,7 @@ async def transcribeVoice( @router.post("/{instanceId}/voice/synthesize") -@limiter.limit("30/minute") +@limiter.limit("120/minute") async def 
synthesizeVoice( request: Request, instanceId: str = Path(...), @@ -1046,7 +1313,7 @@ async def synthesizeVoice( # ========================================================================= @router.get("/{instanceId}/settings/voice") -@limiter.limit("30/minute") +@limiter.limit("120/minute") async def getVoiceSettings( request: Request, instanceId: str = Path(...), @@ -1071,7 +1338,7 @@ async def getVoiceSettings( @router.put("/{instanceId}/settings/voice") -@limiter.limit("30/minute") +@limiter.limit("120/minute") async def updateVoiceSettings( request: Request, instanceId: str = Path(...), @@ -1109,7 +1376,7 @@ async def updateVoiceSettings( @router.get("/{instanceId}/voice/languages") -@limiter.limit("30/minute") +@limiter.limit("120/minute") async def getVoiceLanguages( request: Request, instanceId: str = Path(...), @@ -1125,7 +1392,7 @@ async def getVoiceLanguages( @router.get("/{instanceId}/voice/voices") -@limiter.limit("30/minute") +@limiter.limit("120/minute") async def getVoiceVoices( request: Request, instanceId: str = Path(...), @@ -1142,7 +1409,7 @@ async def getVoiceVoices( @router.post("/{instanceId}/voice/test") -@limiter.limit("10/minute") +@limiter.limit("30/minute") async def testVoice( request: Request, instanceId: str = Path(...), @@ -1180,7 +1447,7 @@ async def testVoice( @router.get("/{instanceId}/pending-edits") -@limiter.limit("30/minute") +@limiter.limit("120/minute") async def getPendingEdits( request: Request, instanceId: str = Path(...), @@ -1193,7 +1460,7 @@ async def getPendingEdits( @router.post("/{instanceId}/edit/{editId}/accept") -@limiter.limit("30/minute") +@limiter.limit("120/minute") async def acceptEdit( request: Request, instanceId: str = Path(...), @@ -1230,7 +1497,7 @@ async def acceptEdit( @router.post("/{instanceId}/edit/{editId}/reject") -@limiter.limit("30/minute") +@limiter.limit("120/minute") async def rejectEdit( request: Request, instanceId: str = Path(...), @@ -1256,7 +1523,7 @@ async def rejectEdit( @router.post("/{instanceId}/edit/accept-all") -@limiter.limit("10/minute") +@limiter.limit("30/minute") async def acceptAllEdits( request: Request, instanceId: str = Path(...), @@ -1287,7 +1554,7 @@ async def acceptAllEdits( @router.post("/{instanceId}/edit/reject-all") -@limiter.limit("10/minute") +@limiter.limit("30/minute") async def rejectAllEdits( request: Request, instanceId: str = Path(...), diff --git a/modules/routes/routeDataFiles.py b/modules/routes/routeDataFiles.py index c3138aed..470793bb 100644 --- a/modules/routes/routeDataFiles.py +++ b/modules/routes/routeDataFiles.py @@ -441,6 +441,72 @@ def move_folder( raise HTTPException(status_code=500, detail=str(e)) +@router.get("/folders/{folderId}/download") +@limiter.limit("10/minute") +def download_folder( + request: Request, + folderId: str = Path(..., description="ID of the folder to download as ZIP"), + currentUser: User = Depends(getCurrentUser), + context: RequestContext = Depends(getRequestContext) +) -> Response: + """Download a folder (including subfolders) as a ZIP archive.""" + import io + import zipfile + import urllib.parse + + try: + mgmt = interfaceDbManagement.getInterface( + currentUser, + mandateId=str(context.mandateId) if context.mandateId else None, + featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None, + ) + + folder = mgmt.getFolder(folderId) + if not folder: + raise HTTPException(status_code=404, detail=f"Folder {folderId} not found") + + folderName = folder.get("name", "download") + + def _collectFiles(parentId: str, 
pathPrefix: str): + """Recursively collect (zipPath, fileId) tuples.""" + entries = [] + for f in mgmt._getFilesByCurrentUser(recordFilter={"folderId": parentId}): + fname = f.get("fileName") or f.get("name") or f.get("id", "file") + entries.append((f"{pathPrefix}{fname}", f["id"])) + for sub in mgmt.listFolders(parentId=parentId): + subName = sub.get("name", sub["id"]) + entries.extend(_collectFiles(sub["id"], f"{pathPrefix}{subName}/")) + return entries + + fileEntries = _collectFiles(folderId, "") + if not fileEntries: + raise HTTPException(status_code=404, detail="Folder is empty") + + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w", zipfile.ZIP_DEFLATED) as zf: + for zipPath, fileId in fileEntries: + data = mgmt.getFileData(fileId) + if data: + zf.writestr(zipPath, data) + + buf.seek(0) + zipBytes = buf.getvalue() + encodedName = urllib.parse.quote(f"{folderName}.zip") + + return Response( + content=zipBytes, + media_type="application/zip", + headers={ + "Content-Disposition": f"attachment; filename*=UTF-8''{encodedName}" + } + ) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error downloading folder as ZIP: {e}") + raise HTTPException(status_code=500, detail=f"Error downloading folder: {str(e)}") + + @router.post("/batch-delete") @limiter.limit("10/minute") def batch_delete_items( diff --git a/modules/security/rbacCatalog.py b/modules/security/rbacCatalog.py index a913a095..14b87534 100644 --- a/modules/security/rbacCatalog.py +++ b/modules/security/rbacCatalog.py @@ -119,7 +119,50 @@ class RbacCatalogService: if featureCode: return [obj for obj in self._dataObjects.values() if obj["featureCode"] == featureCode] return list(self._dataObjects.values()) - + + def getAccessibleDataObjects( + self, + featureCode: str, + rbacInstance, + user, + mandateId: str, + featureInstanceId: str, + ) -> List[Dict[str, Any]]: + """Get DATA objects filtered by RBAC read permission for the user. 
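+
+        A call sketch (``rbac`` and ``user`` come from the caller; the ID
+        values are illustrative, not from a live system)::
+
+            objs = catalog.getAccessibleDataObjects(
+                featureCode="trustee",
+                rbacInstance=rbac,
+                user=user,
+                mandateId="m-5678",
+                featureInstanceId="fi-1234",
+            )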
+ + Args: + featureCode: Feature code to filter by + rbacInstance: RbacClass instance for permission checks + user: User object + mandateId: Mandate scope + featureInstanceId: Feature instance scope + """ + from modules.datamodels.datamodelRbac import AccessRuleContext + allObjects = self.getDataObjects(featureCode) + accessible = [] + for obj in allObjects: + objectKey = obj.get("objectKey", "") + try: + perms = rbacInstance.getUserPermissions( + user=user, + context=AccessRuleContext.DATA, + item=objectKey, + mandateId=mandateId, + featureInstanceId=featureInstanceId, + ) + if perms.view or perms.read.value != "n": + accessible.append(obj) + except Exception: + pass + return accessible + + def getFeaturesWithDataObjects(self) -> List[str]: + """Get feature codes that have at least one registered DATA object.""" + codes = set() + for obj in self._dataObjects.values(): + codes.add(obj["featureCode"]) + return list(codes) + def getAllObjects(self, featureCode: Optional[str] = None) -> List[Dict[str, Any]]: """Get all RBAC objects (UI + RESOURCE + DATA), optionally filtered by feature.""" return self.getUiObjects(featureCode) + self.getResourceObjects(featureCode) + self.getDataObjects(featureCode) diff --git a/modules/serviceCenter/services/serviceAgent/agentLoop.py b/modules/serviceCenter/services/serviceAgent/agentLoop.py index eaa3bd75..1cf74152 100644 --- a/modules/serviceCenter/services/serviceAgent/agentLoop.py +++ b/modules/serviceCenter/services/serviceAgent/agentLoop.py @@ -21,6 +21,7 @@ from modules.serviceCenter.services.serviceAgent.conversationManager import ( ConversationManager, buildSystemPrompt ) from modules.shared.timeUtils import getUtcTimestamp +from modules.shared.jsonUtils import closeJsonStructures logger = logging.getLogger(__name__) @@ -64,7 +65,12 @@ async def runAgentLoop( tools = toolRegistry.getTools() toolDefinitions = toolRegistry.formatToolsForFunctionCalling() - toolsText = toolRegistry.formatToolsForPrompt() + + # Text-based tool descriptions are ONLY used as fallback when native function + # calling is unavailable. Including both creates conflicting instructions + # (text ```tool_call format vs native tool_use blocks) and can cause the model + # to respond with plain text instead of actual tool calls. + toolsText = "" if toolDefinitions else toolRegistry.formatToolsForPrompt() systemPrompt = buildSystemPrompt(tools, toolsText, userLanguage=userLanguage) conversation = ConversationManager(systemPrompt) @@ -192,6 +198,29 @@ async def runAgentLoop( toolCalls = _parseToolCalls(aiResponse) textContent = _extractTextContent(aiResponse) + logger.debug( + f"Round {state.currentRound} AI response: model={aiResponse.modelName}, " + f"toolCalls={len(toolCalls)}, nativeToolCalls={'yes' if aiResponse.toolCalls else 'no'}, " + f"contentLen={len(aiResponse.content)}, streamedLen={len(streamedText)}" + ) + + # Empty response (no content, no tool calls) = model returned nothing useful. + # Burn the round but let the loop continue so the next iteration can retry + # (the failover mechanism in the AI layer will try alternative models). + if not toolCalls and not textContent and not streamedText: + logger.warning( + f"Round {state.currentRound}: AI returned empty response " + f"(model={aiResponse.modelName}). Retrying next round." + ) + conversation.addUserMessage( + "Your previous response was empty. Please use the available tools " + "to accomplish the task. Start by planning the steps, then call the " + "appropriate tools." 
+            )
+            roundLog.durationMs = int((time.time() - roundStartTime) * 1000)
+            trace.rounds.append(roundLog)
+            continue
+
         if textContent and not streamedText:
             yield AgentEvent(type=AgentEventTypeEnum.MESSAGE, content=textContent)
 
@@ -228,7 +257,8 @@ async def runAgentLoop(
                     args=next((tc.args for tc in toolCalls if tc.id == result.toolCallId), {}),
                     success=result.success,
                     durationMs=result.durationMs,
-                    error=result.error
+                    error=result.error,
+                    resultData=result.data[:300] if result.data else "",
                 ))
                 if not result.success:
                     logger.warning(f"Tool '{result.toolName}' failed: {result.error}")
@@ -282,6 +312,8 @@ async def runAgentLoop(
     trace.totalCostCHF = state.totalCostCHF
     trace.abortReason = state.abortReason
 
+    artifactSummary = _buildArtifactSummary(trace.rounds)
+
     yield AgentEvent(
         type=AgentEventTypeEnum.AGENT_SUMMARY,
         data={
@@ -291,7 +323,8 @@ async def runAgentLoop(
             "costCHF": round(state.totalCostCHF, 4),
             "processingTime": round(state.totalProcessingTime, 2),
             "status": state.status.value,
-            "abortReason": state.abortReason
+            "abortReason": state.abortReason,
+            "artifacts": artifactSummary,
         }
     )
 
@@ -351,46 +384,19 @@ async def _executeToolCalls(toolCalls: List[ToolCallRequest],
 
 
 def _repairTruncatedJson(raw: str) -> Optional[Dict[str, Any]]:
-    """Try to repair truncated JSON from LLM output by closing open brackets/braces.
+    """Repair truncated JSON using the shared jsonUtils toolbox.
 
+    Uses closeJsonStructures, which handles open strings, brackets, braces,
+    and trailing commas with stack-based structure tracking.
     Returns parsed dict on success, None if unrecoverable.
     """
     if not raw or not raw.strip().startswith("{"):
         return None
-
-    openBraces = raw.count("{") - raw.count("}")
-    openBrackets = raw.count("[") - raw.count("]")
-
-    inString = False
-    lastQuoteEscaped = False
-    quoteCount = 0
-    for ch in raw:
-        if ch == '"' and not lastQuoteEscaped:
-            quoteCount += 1
-            inString = not inString
-        lastQuoteEscaped = (ch == '\\')
-
-    candidate = raw
-    if quoteCount % 2 != 0:
-        candidate += '"'
-
-    candidate += "]" * max(0, openBrackets)
-    candidate += "}" * max(0, openBraces)
     try:
-        return json.loads(candidate)
-    except json.JSONDecodeError:
-        pass
-
-    lastComma = candidate.rfind(",")
-    if lastComma > 0:
-        trimmed = candidate[:lastComma] + candidate[lastComma + 1:]
-        try:
-            return json.loads(trimmed)
-        except json.JSONDecodeError:
-            pass
-
-    return None
+        closed = closeJsonStructures(raw)
+        return json.loads(closed)
+    except Exception:
+        return None
 
 
 def _parseToolCalls(aiResponse: AiCallResponse) -> List[ToolCallRequest]:
@@ -409,7 +415,14 @@ def _parseToolCalls(aiResponse: AiCallResponse) -> List[ToolCallRequest]:
                 parsedArgs = _repairTruncatedJson(rawArgs)
                 if parsedArgs is None:
                     logger.warning(f"Unrecoverable truncated JSON for '{tc['function']['name']}': {rawArgs[:200]}")
-                    parsedArgs = {"_parseError": f"Truncated JSON arguments – model output was cut off. Raw start: {rawArgs[:120]}"}
+                    parsedArgs = {"_parseError": (
+                        "Your tool call arguments were truncated (output cut off by token limit). "
+                        "The content is too large for a single tool call. Strategies:\n"
+                        "1. For new files: use writeFile(mode='create') with the first part, "
+                        "then writeFile(fileId=..., mode='append') for subsequent parts (~8000 chars each).\n"
+                        "2. For editing existing files: use replaceInFile to change only the specific parts.\n"
+                        "3. For documentation: split into multiple smaller files."
+ )} else: logger.info(f"Repaired truncated JSON for '{tc['function']['name']}'") else: @@ -471,3 +484,24 @@ def _buildProgressSummary(state: AgentState, reason: str) -> str: f"- Cost: {state.totalCostCHF:.4f} CHF\n" f"- Processing time: {state.totalProcessingTime:.1f}s" ) + + +_ARTIFACT_TOOLS = {"writeFile", "replaceInFile", "deleteFile", "renameFile", "copyFile", + "createFolder", "deleteFolder", "renderDocument", "generateImage"} + +def _buildArtifactSummary(roundLogs: List[AgentRoundLog]) -> str: + """Extract file operations and key results from all agent rounds. + + Produces a concise summary persisted as _workflowArtifacts so + follow-up rounds have immediate context (file IDs, names, actions). + """ + ops = [] + for log in roundLogs: + for tc in log.toolCalls: + if tc.toolName not in _ARTIFACT_TOOLS or not tc.success: + continue + ops.append(f"- {tc.resultData}" if tc.resultData else f"- {tc.toolName}") + + if not ops: + return "" + return "File operations in this run:\n" + "\n".join(ops) diff --git a/modules/serviceCenter/services/serviceAgent/conversationManager.py b/modules/serviceCenter/services/serviceAgent/conversationManager.py index 86d6714c..79570c03 100644 --- a/modules/serviceCenter/services/serviceAgent/conversationManager.py +++ b/modules/serviceCenter/services/serviceAgent/conversationManager.py @@ -296,6 +296,29 @@ def buildSystemPrompt( "Think step by step. Call tools when you need information or need to perform actions. " "When you have enough information to answer, respond directly without calling tools.\n\n" ) + + prompt += ( + "## Working Guidelines\n\n" + "### Workflow Context\n" + "When continuing a workflow (follow-up message), the Relevant Knowledge section contains " + "artifacts from previous rounds (file IDs, operations). Use this context instead of " + "re-searching or re-listing files.\n\n" + "### Efficient File Editing\n" + "- Use readFile with offset/limit to read specific line ranges of large files.\n" + "- Use searchInFileContent to find text before editing.\n" + "- Use replaceInFile for targeted edits (preferred over rewriting entire files).\n" + "- Use writeFile(mode='overwrite') only when the entire content must change.\n\n" + "### Large Content Strategy\n" + "- For content larger than ~8000 characters: use writeFile(mode='create') for the first " + "part, then writeFile(fileId=..., mode='append') for subsequent parts.\n" + "- Split large documentation into multiple focused files rather than one huge document.\n" + "- Structure outputs so files reference each other (e.g. 
index.md linking to sections).\n\n" + "### Code Generation\n" + "- Prefer modular file structures over monolithic files.\n" + "- When generating applications, create separate files for logical components.\n" + "- Always plan the structure before writing code.\n\n" + ) + if toolsFormatted: prompt += f"Available Tools:\n{toolsFormatted}\n\n" prompt += ( diff --git a/modules/serviceCenter/services/serviceAgent/datamodelAgent.py b/modules/serviceCenter/services/serviceAgent/datamodelAgent.py index c70d8344..f682e705 100644 --- a/modules/serviceCenter/services/serviceAgent/datamodelAgent.py +++ b/modules/serviceCenter/services/serviceAgent/datamodelAgent.py @@ -111,6 +111,7 @@ class ToolCallLog(BaseModel): success: bool = True durationMs: int = 0 error: Optional[str] = None + resultData: str = Field(default="", description="Short result summary for artifact tracking") class AgentRoundLog(BaseModel): diff --git a/modules/serviceCenter/services/serviceAgent/featureDataAgent.py b/modules/serviceCenter/services/serviceAgent/featureDataAgent.py new file mode 100644 index 00000000..e36745df --- /dev/null +++ b/modules/serviceCenter/services/serviceAgent/featureDataAgent.py @@ -0,0 +1,253 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +"""Feature Data Sub-Agent. + +Specialized mini-agent that queries feature-instance data tables. Receives +schema context (fields, descriptions) for the selected tables and has two +tools: browseTable and queryTable. Runs its own agent loop (max 5 rounds, +low budget) and returns structured results back to the main agent. +""" + +import json +import logging +from typing import Any, Callable, Awaitable, Dict, List, Optional + +from modules.datamodels.datamodelAi import ( + AiCallRequest, AiCallOptions, AiCallResponse, OperationTypeEnum, +) +from modules.serviceCenter.services.serviceAgent.agentLoop import runAgentLoop +from modules.serviceCenter.services.serviceAgent.datamodelAgent import ( + AgentConfig, AgentEvent, AgentEventTypeEnum, ToolResult, +) +from modules.serviceCenter.services.serviceAgent.toolRegistry import ToolRegistry +from modules.serviceCenter.services.serviceAgent.featureDataProvider import FeatureDataProvider + +logger = logging.getLogger(__name__) + +_MAX_ROUNDS = 5 +_MAX_COST_CHF = 0.10 + + +async def runFeatureDataAgent( + question: str, + featureInstanceId: str, + featureCode: str, + selectedTables: List[Dict[str, Any]], + mandateId: str, + userId: str, + aiCallFn: Callable[[AiCallRequest], Awaitable[AiCallResponse]], + dbConnector, + instanceLabel: str = "", +) -> str: + """Run the feature data sub-agent and return the textual result. + + Args: + question: The user/main-agent question to answer using feature data. + featureInstanceId: Feature instance to scope queries. + featureCode: Feature code (trustee, commcoach, ...). + selectedTables: List of DATA_OBJECT dicts the user selected. + mandateId: Mandate scope. + userId: Calling user ID. + aiCallFn: AI call function (with billing). + dbConnector: DatabaseConnector for queries. + instanceLabel: Human-readable instance name for context. + + Returns: + Plain-text answer produced by the sub-agent. 
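+
+    Call sketch (argument values are illustrative; ``catalog``, ``aiCall``
+    and ``db`` are assumed to be supplied by the caller)::
+
+        answer = await runFeatureDataAgent(
+            question="How many active trustee positions are there?",
+            featureInstanceId="fi-1234",
+            featureCode="trustee",
+            selectedTables=catalog.getDataObjects("trustee"),
+            mandateId="m-5678",
+            userId="u-0001",
+            aiCallFn=aiCall,
+            dbConnector=db,
+            instanceLabel="Trustee Zurich",
+        )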
+ """ + + provider = FeatureDataProvider(dbConnector) + registry = _buildSubAgentTools(provider, featureInstanceId, mandateId) + + for tbl in selectedTables: + meta = tbl.get("meta", {}) + tableName = meta.get("table", "") + if tableName: + realCols = provider.getActualColumns(tableName) + if realCols: + meta["fields"] = realCols + + schemaContext = _buildSchemaContext(featureCode, instanceLabel, selectedTables) + prompt = f"{schemaContext}\n\nUser question:\n{question}" + + config = AgentConfig(maxRounds=_MAX_ROUNDS, maxCostCHF=_MAX_COST_CHF) + + async def _getWorkflowCost() -> float: + return 0.0 + + result = "" + async for event in runAgentLoop( + prompt=prompt, + toolRegistry=registry, + config=config, + aiCallFn=aiCallFn, + getWorkflowCostFn=_getWorkflowCost, + workflowId=f"fda-{featureInstanceId[:8]}", + userId=userId, + featureInstanceId=featureInstanceId, + mandateId=mandateId, + ): + if event.type == AgentEventTypeEnum.FINAL and event.content: + result = event.content + elif event.type == AgentEventTypeEnum.MESSAGE and event.content: + result += event.content + + return result or "(no data returned by feature agent)" + + +# ------------------------------------------------------------------ +# tool registration +# ------------------------------------------------------------------ + +def _buildSubAgentTools( + provider: FeatureDataProvider, + featureInstanceId: str, + mandateId: str, +) -> ToolRegistry: + """Register browseTable and queryTable as sub-agent tools.""" + registry = ToolRegistry() + + async def _browseTable(args: Dict[str, Any], context: Dict[str, Any]): + tableName = args.get("tableName", "") + limit = args.get("limit", 50) + offset = args.get("offset", 0) + fields = args.get("fields") + if not tableName: + return ToolResult(toolCallId="", toolName="browseTable", success=False, error="tableName required") + result = provider.browseTable( + tableName=tableName, + featureInstanceId=featureInstanceId, + mandateId=mandateId, + fields=fields, + limit=min(limit, 200), + offset=offset, + ) + return ToolResult( + toolCallId="", toolName="browseTable", + success="error" not in result, + data=json.dumps(result, default=str, ensure_ascii=False)[:30000], + error=result.get("error"), + ) + + async def _queryTable(args: Dict[str, Any], context: Dict[str, Any]): + tableName = args.get("tableName", "") + filters = args.get("filters", []) + fields = args.get("fields") + orderBy = args.get("orderBy") + limit = args.get("limit", 50) + offset = args.get("offset", 0) + if not tableName: + return ToolResult(toolCallId="", toolName="queryTable", success=False, error="tableName required") + result = provider.queryTable( + tableName=tableName, + featureInstanceId=featureInstanceId, + mandateId=mandateId, + filters=filters, + fields=fields, + orderBy=orderBy, + limit=min(limit, 200), + offset=offset, + ) + return ToolResult( + toolCallId="", toolName="queryTable", + success="error" not in result, + data=json.dumps(result, default=str, ensure_ascii=False)[:30000], + error=result.get("error"), + ) + + registry.register( + "browseTable", _browseTable, + description="List rows from a feature data table with pagination.", + parameters={ + "type": "object", + "properties": { + "tableName": {"type": "string", "description": "Name of the table to browse"}, + "fields": { + "type": "array", "items": {"type": "string"}, + "description": "Optional list of fields to return (default: all)", + }, + "limit": {"type": "integer", "description": "Max rows to return (default 50, max 200)"}, + "offset": {"type": 
"integer", "description": "Row offset for pagination"}, + }, + "required": ["tableName"], + }, + readOnly=True, + ) + + registry.register( + "queryTable", _queryTable, + description=( + "Query a feature data table with filters, field selection, and ordering. " + "Filters: [{\"field\": \"status\", \"op\": \"=\", \"value\": \"active\"}]. " + "Operators: =, !=, >, <, >=, <=, LIKE, ILIKE, IS NULL, IS NOT NULL." + ), + parameters={ + "type": "object", + "properties": { + "tableName": {"type": "string", "description": "Name of the table to query"}, + "filters": { + "type": "array", + "items": { + "type": "object", + "properties": { + "field": {"type": "string"}, + "op": {"type": "string"}, + "value": {}, + }, + }, + "description": "Filter conditions", + }, + "fields": { + "type": "array", "items": {"type": "string"}, + "description": "Optional list of fields to return", + }, + "orderBy": {"type": "string", "description": "Field name to order by"}, + "limit": {"type": "integer", "description": "Max rows (default 50, max 200)"}, + "offset": {"type": "integer", "description": "Row offset"}, + }, + "required": ["tableName"], + }, + readOnly=True, + ) + + return registry + + +# ------------------------------------------------------------------ +# context building +# ------------------------------------------------------------------ + +def _buildSchemaContext( + featureCode: str, + instanceLabel: str, + selectedTables: List[Dict[str, Any]], +) -> str: + """Build a system-level context block describing available tables.""" + parts = [ + f"You are a data query assistant for the '{featureCode}' feature", + ] + if instanceLabel: + parts[0] += f' (instance: "{instanceLabel}")' + parts[0] += "." + parts.append( + "You have access to the following data tables. " + "Use browseTable to list rows and queryTable to filter/search." + ) + parts.append("") + + for obj in selectedTables: + meta = obj.get("meta", {}) + tbl = meta.get("table", "?") + fields = meta.get("fields", []) + label = obj.get("label", {}) + labelStr = label.get("en") or label.get("de") or tbl + parts.append(f"Table: {tbl} ({labelStr})") + if fields: + parts.append(f" Fields: {', '.join(fields)}") + parts.append("") + + parts.append( + "Answer the user's question using the data from these tables. " + "Be precise, cite row counts, and format data clearly." + ) + return "\n".join(parts) diff --git a/modules/serviceCenter/services/serviceAgent/featureDataProvider.py b/modules/serviceCenter/services/serviceAgent/featureDataProvider.py new file mode 100644 index 00000000..40bf0c6b --- /dev/null +++ b/modules/serviceCenter/services/serviceAgent/featureDataProvider.py @@ -0,0 +1,215 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +"""Generic data provider for querying feature-instance tables. + +Uses the RBAC catalog's DATA_OBJECTS metadata (table name, fields) and the +DB connector to execute scoped, read-only queries against any registered +feature table. All queries are automatically filtered by featureInstanceId +and mandateId so data isolation is guaranteed. +""" + +import logging +import json +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + +_ALLOWED_OPERATORS = {"=", "!=", ">", "<", ">=", "<=", "LIKE", "ILIKE", "IS NULL", "IS NOT NULL"} + + +class FeatureDataProvider: + """Reads feature-instance data from the DB using DATA_OBJECTS metadata.""" + + def __init__(self, dbConnector): + """ + Args: + dbConnector: A connectorDbPostgre.DatabaseConnector with an open connection. 
+ """ + self._db = dbConnector + + # ------------------------------------------------------------------ + # public API (called by FeatureDataAgent tools) + # ------------------------------------------------------------------ + + def getAvailableTables(self, featureCode: str) -> List[Dict[str, Any]]: + """Return DATA_OBJECTS registered for *featureCode*.""" + from modules.security.rbacCatalog import getCatalogService + catalog = getCatalogService() + return catalog.getDataObjects(featureCode) + + def getTableSchema(self, featureCode: str, tableName: str) -> Optional[Dict[str, Any]]: + """Return the DATA_OBJECT entry for a specific table.""" + for obj in self.getAvailableTables(featureCode): + if obj.get("meta", {}).get("table") == tableName: + return obj + return None + + def getActualColumns(self, tableName: str) -> List[str]: + """Read real column names from PostgreSQL information_schema.""" + try: + conn = self._db.connection + with conn.cursor() as cur: + cur.execute( + "SELECT column_name FROM information_schema.columns " + "WHERE table_schema = 'public' AND LOWER(table_name) = LOWER(%s) " + "ORDER BY ordinal_position", + [tableName], + ) + cols = [row["column_name"] for row in cur.fetchall()] + return [c for c in cols if not c.startswith("_")] + except Exception as e: + logger.warning(f"getActualColumns({tableName}) failed: {e}") + return [] + + def browseTable( + self, + tableName: str, + featureInstanceId: str, + mandateId: str, + fields: List[str] = None, + limit: int = 50, + offset: int = 0, + ) -> Dict[str, Any]: + """List rows from a feature table with pagination. + + Returns ``{"rows": [...], "total": N, "limit": L, "offset": O}``. + """ + _validateTableName(tableName) + scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId) + + try: + conn = self._db.connection + with conn.cursor() as cur: + countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {scopeFilter["where"]}' + cur.execute(countSql, scopeFilter["params"]) + total = cur.fetchone()["count"] if cur.rowcount else 0 + + selectCols = ", ".join(f'"{f}"' for f in fields) if fields else "*" + dataSql = ( + f'SELECT {selectCols} FROM "{tableName}" ' + f'WHERE {scopeFilter["where"]} ' + f'ORDER BY "id" LIMIT %s OFFSET %s' + ) + cur.execute(dataSql, scopeFilter["params"] + [limit, offset]) + rows = [_serializeRow(dict(r)) for r in cur.fetchall()] + + return {"rows": rows, "total": total, "limit": limit, "offset": offset} + except Exception as e: + logger.error(f"browseTable({tableName}) failed: {e}") + return {"rows": [], "total": 0, "limit": limit, "offset": offset, "error": str(e)} + + def queryTable( + self, + tableName: str, + featureInstanceId: str, + mandateId: str, + filters: List[Dict[str, Any]] = None, + fields: List[str] = None, + orderBy: str = None, + limit: int = 50, + offset: int = 0, + ) -> Dict[str, Any]: + """Query a feature table with optional filters. + + ``filters`` is a list of ``{"field": "x", "op": "=", "value": "y"}``. 
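+
+        Call sketch (identifier values are illustrative)::
+
+            provider.queryTable(
+                tableName="TrusteePosition",
+                featureInstanceId="fi-1234",
+                mandateId="m-5678",
+                filters=[{"field": "status", "op": "ILIKE", "value": "act%"}],
+                fields=["id", "status"],
+                orderBy="id",
+                limit=20,
+            )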
+        """
+        _validateTableName(tableName)
+        scopeFilter = _buildScopeFilter(tableName, featureInstanceId, mandateId)
+        extraWhere, extraParams = _buildFilterClauses(filters)
+
+        fullWhere = scopeFilter["where"]
+        allParams = list(scopeFilter["params"])
+        if extraWhere:
+            fullWhere += " AND " + extraWhere
+            allParams.extend(extraParams)
+
+        try:
+            conn = self._db.connection
+            with conn.cursor() as cur:
+                countSql = f'SELECT COUNT(*) FROM "{tableName}" WHERE {fullWhere}'
+                cur.execute(countSql, allParams)
+                total = cur.fetchone()["count"] if cur.rowcount else 0
+
+                selectCols = ", ".join(f'"{f}"' for f in fields) if fields else "*"
+                orderClause = f'ORDER BY "{orderBy}"' if orderBy and _isValidIdentifier(orderBy) else 'ORDER BY "id"'
+                dataSql = (
+                    f'SELECT {selectCols} FROM "{tableName}" '
+                    f'WHERE {fullWhere} {orderClause} LIMIT %s OFFSET %s'
+                )
+                cur.execute(dataSql, allParams + [limit, offset])
+                rows = [_serializeRow(dict(r)) for r in cur.fetchall()]
+
+            return {"rows": rows, "total": total, "limit": limit, "offset": offset}
+        except Exception as e:
+            logger.error(f"queryTable({tableName}) failed: {e}")
+            return {"rows": [], "total": 0, "limit": limit, "offset": offset, "error": str(e)}
+
+
+# ------------------------------------------------------------------
+# helpers
+# ------------------------------------------------------------------
+
+def _validateTableName(tableName: str):
+    if not tableName or not _isValidIdentifier(tableName):
+        raise ValueError(f"Invalid table name: {tableName}")
+
+
+def _isValidIdentifier(name: str) -> bool:
+    """Allow only identifier characters (letters, digits, underscore) to prevent SQL injection."""
+    return name.isidentifier()
+
+
+def _buildScopeFilter(tableName: str, featureInstanceId: str, mandateId: str) -> Dict[str, Any]:
+    """Build the mandatory WHERE clause that scopes rows to the feature instance.
+
+    Feature tables are expected to carry a ``featureInstanceId`` column; rows
+    are always scoped to it, and additionally to ``mandateId`` when one is
+    provided.
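+
+    Example result with both scopes set (values illustrative)::
+
+        {"where": '"featureInstanceId" = %s AND "mandateId" = %s',
+         "params": ["fi-1234", "m-5678"]}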
+    """
+    conditions = []
+    params = []
+
+    conditions.append('"featureInstanceId" = %s')
+    params.append(featureInstanceId)
+
+    if mandateId:
+        conditions.append('"mandateId" = %s')
+        params.append(mandateId)
+
+    return {"where": " AND ".join(conditions), "params": params}
+
+
+def _buildFilterClauses(filters: Optional[List[Dict[str, Any]]]) -> tuple:
+    """Convert agent-provided filter dicts into safe SQL."""
+    if not filters:
+        return "", []
+
+    parts = []
+    params = []
+    for f in filters:
+        field = f.get("field", "")
+        op = (f.get("op") or "=").upper()
+        value = f.get("value")
+
+        if not field or not _isValidIdentifier(field):
+            continue
+        if op not in _ALLOWED_OPERATORS:
+            continue
+
+        if op in ("IS NULL", "IS NOT NULL"):
+            parts.append(f'"{field}" {op}')
+        else:
+            parts.append(f'"{field}" {op} %s')
+            params.append(value)
+
+    return " AND ".join(parts), params
+
+
+def _serializeRow(row: Dict[str, Any]) -> Dict[str, Any]:
+    """Ensure all values are JSON-serializable."""
+    for k, v in row.items():
+        if isinstance(v, (bytes, bytearray)):
+            row[k] = f"<{len(v)} bytes>"
+        elif hasattr(v, "isoformat"):
+            row[k] = v.isoformat()
+    return row
diff --git a/modules/serviceCenter/services/serviceAgent/mainServiceAgent.py b/modules/serviceCenter/services/serviceAgent/mainServiceAgent.py
index 2de2ebd4..905621b9 100644
--- a/modules/serviceCenter/services/serviceAgent/mainServiceAgent.py
+++ b/modules/serviceCenter/services/serviceAgent/mainServiceAgent.py
@@ -280,7 +280,7 @@ class AgentService:
         return registry
 
     async def _persistTrace(self, workflowId: str, summaryData: Dict[str, Any]):
-        """Persist the agent trace as a workflow memory entry in the knowledge store."""
+        """Persist the agent trace and workflow artifacts in the knowledge store."""
        try:
             knowledgeService = self._getService("knowledge")
             userId = self.services.user.id if self.services.user else ""
@@ -297,6 +297,19 @@ class AgentService:
                 value=traceValue,
                 source="agent",
             )
+
+            artifacts = summaryData.get("artifacts", "")
+            if artifacts:
+                await knowledgeService.storeEntity(
+                    workflowId=workflowId,
+                    userId=userId,
+                    featureInstanceId=featureInstanceId,
+                    key="_workflowArtifacts",
+                    value=artifacts,
+                    source="agent",
+                )
+                logger.info(f"Persisted workflow artifacts for workflow {workflowId}")
+
             logger.info(f"Persisted agent trace for workflow {workflowId}")
         except Exception as e:
             logger.warning(f"Could not persist agent trace: {e}")
@@ -372,8 +385,23 @@ def _registerCoreTools(registry: ToolRegistry, services):
 
     # ---- Read-only tools ----
 
+    def _applyOffsetLimit(text: str, offset: int = None, limit: int = None):
+        """Apply line-based offset/limit to text, returning numbered lines (None when neither is set)."""
+        if offset is None and limit is None:
+            return None
+        lines = text.split("\n")
+        totalLines = len(lines)
+        startLine = max(0, (offset or 1) - 1)
+        endLine = min(totalLines, startLine + (limit or 200))
+        selected = lines[startLine:endLine]
+        numbered = "\n".join(f"{i + startLine + 1}|{line}" for i, line in enumerate(selected))
+        header = f"[Lines {startLine + 1}-{endLine} of {totalLines} total]\n"
+        return header + numbered
+
     async def _readFile(args: Dict[str, Any], context: Dict[str, Any]):
         fileId = args.get("fileId", "")
+        offset = args.get("offset")
+        limit = args.get("limit")
         if not fileId:
             return ToolResult(toolCallId="", toolName="readFile", success=False, error="fileId is required")
         try:
@@ -390,8 +418,11 @@ def _registerCoreTools(registry: ToolRegistry, services):
             ]
             if textChunks:
                 assembled = "\n\n".join(c["data"] for c in textChunks)
+                chunked = 
_applyOffsetLimit(assembled, offset, limit) + if chunked is not None: + return ToolResult(toolCallId="", toolName="readFile", success=True, data=chunked) if len(assembled) > _MAX_TOOL_RESULT_CHARS: - assembled = assembled[:_MAX_TOOL_RESULT_CHARS] + f"\n\n[Truncated – showing first {_MAX_TOOL_RESULT_CHARS} chars of {len(assembled)}]" + assembled = assembled[:_MAX_TOOL_RESULT_CHARS] + f"\n\n[Truncated – showing first {_MAX_TOOL_RESULT_CHARS} chars of {len(assembled)}. Use offset/limit to read specific sections.]" return ToolResult( toolCallId="", toolName="readFile", success=True, data=assembled, @@ -466,8 +497,11 @@ def _registerCoreTools(registry: ToolRegistry, services): textParts = [o["data"] for o in contentObjects if o["contentType"] != "image"] if textParts: joined = "\n\n".join(textParts) + chunked = _applyOffsetLimit(joined, offset, limit) + if chunked is not None: + return ToolResult(toolCallId="", toolName="readFile", success=True, data=chunked) if len(joined) > _MAX_TOOL_RESULT_CHARS: - joined = joined[:_MAX_TOOL_RESULT_CHARS] + f"\n\n[Truncated – showing first {_MAX_TOOL_RESULT_CHARS} chars of {len(joined)}]" + joined = joined[:_MAX_TOOL_RESULT_CHARS] + f"\n\n[Truncated – showing first {_MAX_TOOL_RESULT_CHARS} chars of {len(joined)}. Use offset/limit to read specific sections.]" return ToolResult( toolCallId="", toolName="readFile", success=True, data=joined, @@ -493,8 +527,11 @@ def _registerCoreTools(registry: ToolRegistry, services): try: text = rawBytes.decode(encoding) if text.strip(): + chunked = _applyOffsetLimit(text, offset, limit) + if chunked is not None: + return ToolResult(toolCallId="", toolName="readFile", success=True, data=chunked) if len(text) > _MAX_TOOL_RESULT_CHARS: - text = text[:_MAX_TOOL_RESULT_CHARS] + f"\n\n[Truncated – showing first {_MAX_TOOL_RESULT_CHARS} chars of {len(text)}]" + text = text[:_MAX_TOOL_RESULT_CHARS] + f"\n\n[Truncated – showing first {_MAX_TOOL_RESULT_CHARS} chars of {len(text)}. Use offset/limit to read specific sections.]" return ToolResult( toolCallId="", toolName="readFile", success=True, data=text, @@ -527,20 +564,44 @@ def _registerCoreTools(registry: ToolRegistry, services): except Exception as e: return ToolResult(toolCallId="", toolName="listFiles", success=False, error=str(e)) - async def _searchFiles(args: Dict[str, Any], context: Dict[str, Any]): + async def _searchInFileContent(args: Dict[str, Any], context: Dict[str, Any]): + import re as _re + fileId = args.get("fileId", "") query = args.get("query", "") - if not query: - return ToolResult(toolCallId="", toolName="searchFiles", success=False, error="query is required") + contextLines = args.get("contextLines", 2) + if not fileId or not query: + return ToolResult(toolCallId="", toolName="searchInFileContent", success=False, error="fileId and query are required") try: chatService = services.chat - files = chatService.listFiles(search=query, tags=args.get("tags")) - fileList = "\n".join( - f"- {f.get('fileName', 'unknown')} (id: {f.get('id', '?')})" - for f in files - ) if files else "No files matching query." 
- return ToolResult(toolCallId="", toolName="searchFiles", success=True, data=fileList) + rawBytes = chatService.getFileData(fileId) + if not rawBytes: + return ToolResult(toolCallId="", toolName="searchInFileContent", success=False, error="File data not accessible") + try: + content = rawBytes.decode("utf-8") + except UnicodeDecodeError: + content = rawBytes.decode("latin-1", errors="replace") + + lines = content.split("\n") + pattern = _re.compile(_re.escape(query), _re.IGNORECASE) + matches = [] + for i, line in enumerate(lines): + if pattern.search(line): + start = max(0, i - contextLines) + end = min(len(lines), i + contextLines + 1) + snippet = "\n".join(f"{j + 1}|{lines[j]}" for j in range(start, end)) + matches.append(snippet) + + if not matches: + return ToolResult(toolCallId="", toolName="searchInFileContent", success=True, + data=f"No matches for '{query}' in file.") + + shown = matches[:20] + resultText = f"Found {len(matches)} match(es) for '{query}':\n\n" + "\n---\n".join(shown) + if len(matches) > 20: + resultText += f"\n\n... and {len(matches) - 20} more matches" + return ToolResult(toolCallId="", toolName="searchInFileContent", success=True, data=resultText) except Exception as e: - return ToolResult(toolCallId="", toolName="searchFiles", success=False, error=str(e)) + return ToolResult(toolCallId="", toolName="searchInFileContent", success=False, error=str(e)) async def _listFolders(args: Dict[str, Any], context: Dict[str, Any]): try: @@ -621,22 +682,63 @@ def _registerCoreTools(registry: ToolRegistry, services): return ToolResult(toolCallId="", toolName="createFolder", success=False, error=str(e)) async def _writeFile(args: Dict[str, Any], context: Dict[str, Any]): - name = args.get("name", "") content = args.get("content", "") - if not name: - return ToolResult(toolCallId="", toolName="writeFile", success=False, error="name is required") + mode = args.get("mode", "create") + fileId = args.get("fileId", "") + name = args.get("name", "") + + if not content: + return ToolResult(toolCallId="", toolName="writeFile", success=False, error="content is required") + try: chatService = services.chat - fileItem, _ = chatService.interfaceDbComponent.saveUploadedFile( - content.encode("utf-8"), name - ) + dbMgmt = chatService.interfaceDbComponent + + if mode == "append": + if not fileId: + return ToolResult(toolCallId="", toolName="writeFile", success=False, error="fileId is required for mode=append") + file = dbMgmt.getFile(fileId) + if not file: + return ToolResult(toolCallId="", toolName="writeFile", success=False, error=f"File {fileId} not found") + existingData = dbMgmt.getFileData(fileId) or b"" + try: + existingText = existingData.decode("utf-8") + except UnicodeDecodeError: + existingText = existingData.decode("latin-1", errors="replace") + newContent = existingText + content + dbMgmt.updateFileData(fileId, newContent.encode("utf-8")) + dbMgmt.updateFile(fileId, {"fileSize": len(newContent.encode("utf-8"))}) + return ToolResult( + toolCallId="", toolName="writeFile", success=True, + data=f"Appended {len(content)} chars to '{file.fileName}' (id: {fileId}, total: {len(newContent)} chars)", + sideEvents=[{"type": "fileUpdated", "data": {"fileId": fileId, "fileName": file.fileName}}], + ) + + if mode == "overwrite": + if not fileId: + return ToolResult(toolCallId="", toolName="writeFile", success=False, error="fileId is required for mode=overwrite") + file = dbMgmt.getFile(fileId) + if not file: + return ToolResult(toolCallId="", toolName="writeFile", success=False, error=f"File 
{fileId} not found") + dbMgmt.updateFileData(fileId, content.encode("utf-8")) + dbMgmt.updateFile(fileId, {"fileSize": len(content.encode("utf-8"))}) + return ToolResult( + toolCallId="", toolName="writeFile", success=True, + data=f"Overwritten '{file.fileName}' (id: {fileId}, {len(content)} chars)", + sideEvents=[{"type": "fileUpdated", "data": {"fileId": fileId, "fileName": file.fileName}}], + ) + + # mode == "create" (default) + if not name: + return ToolResult(toolCallId="", toolName="writeFile", success=False, error="name is required for mode=create") + fileItem, _ = dbMgmt.saveUploadedFile(content.encode("utf-8"), name) fiId = context.get("featureInstanceId") or (services.featureInstanceId if services else "") if fiId: - chatService.interfaceDbComponent.updateFile(fileItem.id, {"featureInstanceId": fiId}) + dbMgmt.updateFile(fileItem.id, {"featureInstanceId": fiId}) if args.get("folderId"): - chatService.interfaceDbComponent.updateFile(fileItem.id, {"folderId": args["folderId"]}) + dbMgmt.updateFile(fileItem.id, {"folderId": args["folderId"]}) if args.get("tags"): - chatService.interfaceDbComponent.updateFile(fileItem.id, {"tags": args["tags"]}) + dbMgmt.updateFile(fileItem.id, {"tags": args["tags"]}) return ToolResult( toolCallId="", toolName="writeFile", success=True, data=f"File '{name}' created (id: {fileItem.id})", @@ -657,10 +759,18 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "readFile", _readFile, - description="Read the content of a file by its fileId.", + description=( + "Read the content of a file. Returns full content by default. " + "For large files, use offset and limit to read specific line ranges. " + "When truncated, the response tells the total line count so you can paginate." + ), parameters={ "type": "object", - "properties": {"fileId": {"type": "string", "description": "The file ID to read"}}, + "properties": { + "fileId": {"type": "string", "description": "The file ID to read"}, + "offset": {"type": "integer", "description": "Start reading from this line number (1-based). Omit for full file."}, + "limit": {"type": "integer", "description": "Max number of lines to return (default: all). Use with offset for chunked reading."}, + }, "required": ["fileId"] }, readOnly=True @@ -668,7 +778,10 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "listFiles", _listFiles, - description="List LOCAL workspace files (uploaded/generated). NOT for external data sources -- use browseDataSource instead.", + description=( + "List files in the local workspace. Filter by folder, tags, or search term. " + "For external data sources, use browseDataSource instead." + ), parameters={ "type": "object", "properties": { @@ -681,22 +794,27 @@ def _registerCoreTools(registry: ToolRegistry, services): ) registry.register( - "searchFiles", _searchFiles, - description="Search LOCAL workspace files by name, description, or tags. NOT for external data sources -- use searchDataSource instead.", + "searchInFileContent", _searchInFileContent, + description=( + "Search for text within a file's content. Returns matching lines with context. " + "Case-insensitive. Use to locate specific text before using replaceInFile, " + "or to find relevant sections in a large file before reading with offset/limit." 
+ ), parameters={ "type": "object", "properties": { - "query": {"type": "string", "description": "Search query"}, - "tags": {"type": "array", "items": {"type": "string"}, "description": "Additional tag filter"}, + "fileId": {"type": "string", "description": "The file ID to search in"}, + "query": {"type": "string", "description": "Text to search for (case-insensitive)"}, + "contextLines": {"type": "integer", "description": "Number of context lines around each match (default: 2)"}, }, - "required": ["query"] + "required": ["fileId", "query"] }, readOnly=True ) registry.register( "listFolders", _listFolders, - description="List LOCAL workspace folders. NOT for external data sources -- use browseDataSource instead.", + description="List folders in the local workspace. For external data sources, use browseDataSource instead.", parameters={ "type": "object", "properties": { @@ -708,7 +826,7 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "webSearch", _webSearch, - description="Search the web for information.", + description="Search the web for general information. Use readUrl to fetch content from a known URL instead.", parameters={ "type": "object", "properties": {"query": {"type": "string", "description": "Search query"}}, @@ -719,7 +837,7 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "tagFile", _tagFile, - description="Set tags on a file for categorization.", + description="Set or update tags on a file for categorization and filtering via listFiles.", parameters={ "type": "object", "properties": { @@ -733,7 +851,7 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "moveFile", _moveFile, - description="Move a file to a different folder.", + description="Move a file to a different folder in the local workspace.", parameters={ "type": "object", "properties": { @@ -747,7 +865,7 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "createFolder", _createFolder, - description="Create a new file folder.", + description="Create a new folder in the local workspace.", parameters={ "type": "object", "properties": { @@ -761,16 +879,24 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "writeFile", _writeFile, - description="Create a new file with text content.", + description=( + "Create, append, or overwrite a file. Modes:\n" + "- create (default): create a new file (name required).\n" + "- append: append content to an existing file (fileId required). " + "Use for large content that exceeds a single tool call (~8000 chars per call).\n" + "- overwrite: replace entire file content (fileId required)." 
+ ), parameters={ "type": "object", "properties": { - "name": {"type": "string", "description": "File name including extension"}, - "content": {"type": "string", "description": "File content as text"}, - "folderId": {"type": "string", "description": "Target folder ID"}, - "tags": {"type": "array", "items": {"type": "string"}, "description": "Tags"}, + "name": {"type": "string", "description": "File name (required for mode=create)"}, + "content": {"type": "string", "description": "Content to write/append"}, + "mode": {"type": "string", "enum": ["create", "append", "overwrite"], "description": "Write mode (default: create)"}, + "fileId": {"type": "string", "description": "File ID (required for mode=append/overwrite)"}, + "folderId": {"type": "string", "description": "Target folder ID (mode=create only)"}, + "tags": {"type": "array", "items": {"type": "string"}, "description": "Tags (mode=create only)"}, }, - "required": ["name", "content"] + "required": ["content"] }, readOnly=False ) @@ -866,7 +992,7 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "deleteFile", _deleteFile, - description="Delete a file from the workspace. Use when the user asks to remove or delete a file.", + description="Permanently delete a file from the local workspace.", parameters={ "type": "object", "properties": { @@ -879,7 +1005,7 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "renameFile", _renameFile, - description="Rename a file in the workspace.", + description="Rename a file in the local workspace. Include the file extension in the new name.", parameters={ "type": "object", "properties": { @@ -1000,34 +1126,51 @@ def _registerCoreTools(registry: ToolRegistry, services): except Exception as e: return ToolResult(toolCallId="", toolName="copyFile", success=False, error=str(e)) - async def _editFile(args: Dict[str, Any], context: Dict[str, Any]): + async def _replaceInFile(args: Dict[str, Any], context: Dict[str, Any]): fileId = args.get("fileId", "") - content = args.get("content", "") - if not fileId or not content: - return ToolResult(toolCallId="", toolName="editFile", success=False, error="fileId and content are required") + oldText = args.get("oldText", "") + newText = args.get("newText", "") + replaceAll = args.get("replaceAll", False) + if not fileId or not oldText: + return ToolResult(toolCallId="", toolName="replaceInFile", success=False, error="fileId and oldText are required") try: chatService = services.chat dbMgmt = chatService.interfaceDbComponent file = dbMgmt.getFile(fileId) if not file: - return ToolResult(toolCallId="", toolName="editFile", success=False, error=f"File {fileId} not found") + return ToolResult(toolCallId="", toolName="replaceInFile", success=False, error=f"File {fileId} not found") if not dbMgmt.isTextMimeType(file.mimeType): return ToolResult( - toolCallId="", toolName="editFile", success=False, + toolCallId="", toolName="replaceInFile", success=False, error=f"Cannot edit binary file ({file.mimeType}). Only text-based files are supported." 
) - oldContent = "" - oldData = dbMgmt.getFileData(fileId) - if oldData: - try: - oldContent = oldData.decode("utf-8") - except UnicodeDecodeError: - oldContent = "" + rawData = dbMgmt.getFileData(fileId) + if not rawData: + return ToolResult(toolCallId="", toolName="replaceInFile", success=False, error="File has no content") + try: + oldContent = rawData.decode("utf-8") + except UnicodeDecodeError: + return ToolResult(toolCallId="", toolName="replaceInFile", success=False, error="File content is not valid UTF-8 text") + + count = oldContent.count(oldText) + if count == 0: + return ToolResult( + toolCallId="", toolName="replaceInFile", success=False, + error="oldText not found in file. Use readFile or searchInFileContent to verify the exact text." + ) + if count > 1 and not replaceAll: + return ToolResult( + toolCallId="", toolName="replaceInFile", success=False, + error=f"oldText found {count} times. Set replaceAll=true or provide more surrounding context to make it unique." + ) + + newContent = oldContent.replace(oldText, newText) if replaceAll else oldContent.replace(oldText, newText, 1) editId = str(_uuid.uuid4()) + label = f"all {count} occurrences" if replaceAll else "1 occurrence" return ToolResult( - toolCallId="", toolName="editFile", success=True, - data=f"Edit proposed for '{file.fileName}'. Waiting for user review.", + toolCallId="", toolName="replaceInFile", success=True, + data=f"Edit proposed for '{file.fileName}': replaced {label}. Waiting for user review.", sideEvents=[{ "type": "fileEditProposal", "data": { @@ -1036,16 +1179,16 @@ def _registerCoreTools(registry: ToolRegistry, services): "fileName": file.fileName, "mimeType": file.mimeType, "oldContent": oldContent, - "newContent": content, + "newContent": newContent, }, }], ) except Exception as e: - return ToolResult(toolCallId="", toolName="editFile", success=False, error=str(e)) + return ToolResult(toolCallId="", toolName="replaceInFile", success=False, error=str(e)) registry.register( "deleteFolder", _deleteFolder, - description="Delete a folder. Set recursive=true to delete folder with all contents.", + description="Delete a folder from the local workspace. Set recursive=true to delete all contents.", parameters={ "type": "object", "properties": { @@ -1059,7 +1202,7 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "renameFolder", _renameFolder, - description="Rename a folder. Folder names must be unique within their parent.", + description="Rename a folder in the local workspace.", parameters={ "type": "object", "properties": { @@ -1073,7 +1216,7 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "moveFolder", _moveFolder, - description="Move a folder to a different parent folder. Cannot move a folder into its own subtree.", + description="Move a folder to a different parent in the local workspace.", parameters={ "type": "object", "properties": { @@ -1087,7 +1230,7 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "copyFile", _copyFile, - description="Create a full copy of a file. The copy is independent and can be edited separately.", + description="Create an independent copy of a file in the local workspace.", parameters={ "type": "object", "properties": { @@ -1101,19 +1244,22 @@ def _registerCoreTools(registry: ToolRegistry, services): ) registry.register( - "editFile", _editFile, + "replaceInFile", _replaceInFile, description=( - "Propose an edit to an existing text file. 
@@ -1150,79 +1296,13 @@ def _registerCoreTools(registry: ToolRegistry, services):
         except Exception as e:
             return ToolResult(toolCallId="", toolName="listConnections", success=False, error=str(e))

-    async def _externalBrowse(args: Dict[str, Any], context: Dict[str, Any]):
-        connectionId = args.get("connectionId", "")
-        service = args.get("service", "")
-        path = args.get("path", "/")
-        if not connectionId or not service:
-            return ToolResult(toolCallId="", toolName="externalBrowse", success=False, error="connectionId and service are required")
-        try:
-            from modules.connectors.connectorResolver import ConnectorResolver
-            resolver = ConnectorResolver(
-                services.getService("security"),
-                _buildResolverDb(),
-            )
-            adapter = await resolver.resolveService(connectionId, service)
-            entries = await adapter.browse(path, filter=args.get("filter"))
-            entryLines = "\n".join(
-                f"- {'[DIR]' if e.isFolder else '[FILE]'} {e.name} ({e.size or '?'} bytes)"
-                for e in entries
-            ) if entries else "Empty directory."
-            return ToolResult(toolCallId="", toolName="externalBrowse", success=True, data=entryLines)
-        except Exception as e:
-            return ToolResult(toolCallId="", toolName="externalBrowse", success=False, error=str(e))
-
-    async def _externalDownload(args: Dict[str, Any], context: Dict[str, Any]):
-        connectionId = args.get("connectionId", "")
-        service = args.get("service", "")
-        path = args.get("path", "")
-        if not connectionId or not service or not path:
-            return ToolResult(toolCallId="", toolName="externalDownload", success=False, error="connectionId, service, and path are required")
-        try:
-            from modules.connectors.connectorResolver import ConnectorResolver
-            from modules.connectors.connectorProviderBase import DownloadResult as _DR
-            resolver = ConnectorResolver(
-                services.getService("security"),
-                _buildResolverDb(),
-            )
-            adapter = await resolver.resolveService(connectionId, service)
-            result = await adapter.download(path)
-
-            if isinstance(result, _DR):
-                fileBytes = result.data
-                fileName = result.fileName or path.split("/")[-1] or "downloaded_file"
-            else:
-                fileBytes = result
-                fileName = path.split("/")[-1] or "downloaded_file"
-
-            if not fileBytes:
-                return ToolResult(toolCallId="", toolName="externalDownload", success=False, error="Download returned empty")
-
-            chatService = services.chat
-            fileItem, _ = chatService.interfaceDbComponent.saveUploadedFile(fileBytes, fileName)
-            fid = fileItem.id if hasattr(fileItem, "id") else fileItem.get("id", "?")
-            fiId = context.get("featureInstanceId") or (services.featureInstanceId if services else "")
-            if fiId:
-                chatService.interfaceDbComponent.updateFile(fid, {"featureInstanceId": fiId})
-            tempFolderId = _getOrCreateTempFolder(chatService)
-            if tempFolderId:
-                chatService.interfaceDbComponent.updateFile(fid, {"folderId": tempFolderId})
-            ext = fileName.rsplit(".", 1)[-1].lower() if "." in fileName else ""
-            hint = "Use readFile to read text content." if ext in ("doc", "docx", "txt", "csv", "json", "xml", "html", "md", "rtf", "odt", "xls", "xlsx", "pptx", "eml", "msg") else "Use readFile to access the content."
-            return ToolResult(
-                toolCallId="", toolName="externalDownload", success=True,
-                data=f"Downloaded '{fileName}' ({len(fileBytes)} bytes) → local file id: {fid}. {hint}"
-            )
-        except Exception as e:
-            return ToolResult(toolCallId="", toolName="externalDownload", success=False, error=str(e))
-
-    async def _externalUpload(args: Dict[str, Any], context: Dict[str, Any]):
+    async def _uploadToExternal(args: Dict[str, Any], context: Dict[str, Any]):
         connectionId = args.get("connectionId", "")
         service = args.get("service", "")
         path = args.get("path", "")
         fileId = args.get("fileId", "")
         if not connectionId or not service or not path or not fileId:
-            return ToolResult(toolCallId="", toolName="externalUpload", success=False, error="connectionId, service, path, and fileId are required")
+            return ToolResult(toolCallId="", toolName="uploadToExternal", success=False, error="connectionId, service, path, and fileId are required")
         try:
             from modules.connectors.connectorResolver import ConnectorResolver
             resolver = ConnectorResolver(
@@ -1233,37 +1313,15 @@ def _registerCoreTools(registry: ToolRegistry, services):
             chatService = services.chat
             fileContent = chatService.getFileContent(fileId)
             if not fileContent:
-                return ToolResult(toolCallId="", toolName="externalUpload", success=False, error="File not found")
+                return ToolResult(toolCallId="", toolName="uploadToExternal", success=False, error="File not found")
             fileData = fileContent.get("data", b"") if isinstance(fileContent, dict) else b""
             if isinstance(fileData, str):
                 fileData = fileData.encode("utf-8")
             fileName = fileContent.get("fileName", "file") if isinstance(fileContent, dict) else "file"
             result = await adapter.upload(path, fileData, fileName)
-            return ToolResult(toolCallId="", toolName="externalUpload", success=True, data=str(result))
+            return ToolResult(toolCallId="", toolName="uploadToExternal", success=True, data=str(result))
         except Exception as e:
-            return ToolResult(toolCallId="", toolName="externalUpload", success=False, error=str(e))
-
-    async def _externalSearch(args: Dict[str, Any], context: Dict[str, Any]):
-        connectionId = args.get("connectionId", "")
-        service = args.get("service", "")
-        query = args.get("query", "")
-        if not connectionId or not service or not query:
-            return ToolResult(toolCallId="", toolName="externalSearch", success=False, error="connectionId, service, and query are required")
-        try:
-            from modules.connectors.connectorResolver import ConnectorResolver
-            resolver = ConnectorResolver(
-                services.getService("security"),
-                _buildResolverDb(),
-            )
-            adapter = await resolver.resolveService(connectionId, service)
-            entries = await adapter.search(query, path=args.get("path"))
-            resultLines = "\n".join(
-                f"- {e.name} ({e.path})"
-                for e in entries
-            ) if entries else "No results found."
-            return ToolResult(toolCallId="", toolName="externalSearch", success=True, data=resultLines)
-        except Exception as e:
-            return ToolResult(toolCallId="", toolName="externalSearch", success=False, error=str(e))
+            return ToolResult(toolCallId="", toolName="uploadToExternal", success=False, error=str(e))

     async def _sendMail(args: Dict[str, Any], context: Dict[str, Any]):
         connectionId = args.get("connectionId", "")
@@ -1293,48 +1351,22 @@ def _registerCoreTools(registry: ToolRegistry, services):
     registry.register(
         "listConnections", _listConnections,
-        description="List available external connections and their services.",
+        description="List the user's external connections (SharePoint, OneDrive, Outlook, etc.) and their IDs. Use with browseDataSource/uploadToExternal.",
         parameters={"type": "object", "properties": {}},
         readOnly=True,
     )

     registry.register(
-        "externalBrowse", _externalBrowse,
-        description="Browse files in an external source by connectionId+service. For ATTACHED data sources, prefer browseDataSource instead.",
+        "uploadToExternal", _uploadToExternal,
+        description=(
+            "Upload a local file to external storage via connectionId+service. "
+            "Use listConnections to find available connections."
+        ),
         parameters={
             "type": "object",
             "properties": {
                 **_connToolParams,
-                "path": {"type": "string", "description": "Path to browse"},
-                "filter": {"type": "string", "description": "Filter pattern (e.g. '*.pdf')"},
-            },
-            "required": ["connectionId", "service"],
-        },
-        readOnly=True,
-    )
-
-    registry.register(
-        "externalDownload", _externalDownload,
-        description="Download a file from an external source into local storage + auto-index.",
-        parameters={
-            "type": "object",
-            "properties": {
-                **_connToolParams,
-                "path": {"type": "string", "description": "File path to download"},
-            },
-            "required": ["connectionId", "service", "path"],
-        },
-        readOnly=False,
-    )
-
-    registry.register(
-        "externalUpload", _externalUpload,
-        description="Upload a local file to an external data source.",
-        parameters={
-            "type": "object",
-            "properties": {
-                **_connToolParams,
-                "path": {"type": "string", "description": "Destination path"},
+                "path": {"type": "string", "description": "Destination path on the external service"},
                 "fileId": {"type": "string", "description": "Local file ID to upload"},
             },
             "required": ["connectionId", "service", "path", "fileId"],
@@ -1342,24 +1374,9 @@ def _registerCoreTools(registry: ToolRegistry, services):
         readOnly=False,
     )

-    registry.register(
-        "externalSearch", _externalSearch,
-        description="Search files in an external source by connectionId+service. For ATTACHED data sources, prefer searchDataSource instead.",
-        parameters={
-            "type": "object",
-            "properties": {
-                **_connToolParams,
-                "query": {"type": "string", "description": "Search query"},
-                "path": {"type": "string", "description": "Scope to a specific path"},
-            },
-            "required": ["connectionId", "service", "query"],
-        },
-        readOnly=True,
-    )
-
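+    # A usage sketch (hypothetical IDs; take connectionId and service from
+    # listConnections, and fileId from a file in the local workspace):
+    #   uploadToExternal(connectionId="conn-123", service="sharepoint",
+    #                    path="/Shared Documents/report.docx", fileId="<localFileId>")
+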
     registry.register(
         "sendMail", _sendMail,
-        description="Send an email via a connected mail service (Outlook, Gmail).",
+        description="Send an email via a connected mail service (Outlook, Gmail). Use listConnections to find the connectionId.",
         parameters={
             "type": "object",
             "properties": {
@@ -1405,10 +1422,16 @@ def _registerCoreTools(registry: ToolRegistry, services):
     async def _browseDataSource(args: Dict[str, Any], context: Dict[str, Any]):
         dsId = args.get("dataSourceId", "")
         subPath = args.get("subPath", "")
-        if not dsId:
-            return ToolResult(toolCallId="", toolName="browseDataSource", success=False, error="dataSourceId is required")
+        directConnId = args.get("connectionId", "")
+        directService = args.get("service", "")
+        if not dsId and not (directConnId and directService):
+            return ToolResult(toolCallId="", toolName="browseDataSource", success=False,
+                              error="Provide either dataSourceId OR connectionId+service")
         try:
-            connectionId, service, basePath = await _resolveDataSource(dsId)
+            if dsId:
+                connectionId, service, basePath = await _resolveDataSource(dsId)
+            else:
+                connectionId, service, basePath = directConnId, directService, args.get("path", "/")
             if subPath:
                 if subPath.startswith("/"):
                     browsePath = subPath
@@ -1439,11 +1462,19 @@
     async def _searchDataSource(args: Dict[str, Any], context: Dict[str, Any]):
         dsId = args.get("dataSourceId", "")
+        directConnId = args.get("connectionId", "")
+        directService = args.get("service", "")
         query = args.get("query", "")
-        if not dsId or not query:
-            return ToolResult(toolCallId="", toolName="searchDataSource", success=False, error="dataSourceId and query are required")
+        if not query:
+            return ToolResult(toolCallId="", toolName="searchDataSource", success=False, error="query is required")
+        if not dsId and not (directConnId and directService):
+            return ToolResult(toolCallId="", toolName="searchDataSource", success=False,
+                              error="Provide either dataSourceId OR connectionId+service")
         try:
-            connectionId, service, basePath = await _resolveDataSource(dsId)
+            if dsId:
+                connectionId, service, basePath = await _resolveDataSource(dsId)
+            else:
+                connectionId, service, basePath = directConnId, directService, args.get("path", "/")
             from modules.connectors.connectorResolver import ConnectorResolver
             resolver = ConnectorResolver(
                 services.getService("security"),
@@ -1463,14 +1494,22 @@
     async def _downloadFromDataSource(args: Dict[str, Any], context: Dict[str, Any]):
         dsId = args.get("dataSourceId", "")
+        directConnId = args.get("connectionId", "")
+        directService = args.get("service", "")
         filePath = args.get("filePath", "")
         fileName = args.get("fileName", "")
-        if not dsId or not filePath:
-            return ToolResult(toolCallId="", toolName="downloadFromDataSource", success=False, error="dataSourceId and filePath are required")
+        if not filePath:
+            return ToolResult(toolCallId="", toolName="downloadFromDataSource", success=False, error="filePath is required")
+        if not dsId and not (directConnId and directService):
+            return ToolResult(toolCallId="", toolName="downloadFromDataSource", success=False,
+                              error="Provide either dataSourceId OR connectionId+service")
        try:
            from modules.connectors.connectorResolver import ConnectorResolver
            from modules.connectors.connectorProviderBase import DownloadResult as _DR
-            connectionId, service, basePath = await _resolveDataSource(dsId)
+            if dsId:
+                connectionId, service, basePath = await _resolveDataSource(dsId)
+            else:
+                connectionId, service, basePath = directConnId, directService, "/"
             fullPath = filePath if filePath.startswith("/") else f"{basePath.rstrip('/')}/{filePath}"
             resolver = ConnectorResolver(
services.getService("security"), @@ -1525,42 +1564,59 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "browseDataSource", _browseDataSource, - description="Browse files AND folders in an ATTACHED data source by its dataSourceId. This is the PRIMARY tool for listing data source contents.", + description=( + "Browse files and folders in a data source. Accepts either:\n" + "- dataSourceId (for attached data sources shown in the prompt), OR\n" + "- connectionId + service (for direct connection access via listConnections)." + ), parameters={ "type": "object", "properties": { - "dataSourceId": {"type": "string", "description": "DataSource ID (from the attached data sources in the prompt)"}, - "subPath": {"type": "string", "description": "Optional sub-path within the data source to browse"}, - "filter": {"type": "string", "description": "Optional filter pattern (e.g. '*.pdf')"}, + "dataSourceId": {"type": "string", "description": "DataSource ID (from attached data sources)"}, + "connectionId": {"type": "string", "description": "UserConnection ID (alternative to dataSourceId)"}, + "service": {"type": "string", "description": "Service name (alternative to dataSourceId, e.g. sharepoint, onedrive)"}, + "path": {"type": "string", "description": "Root path (used with connectionId+service)"}, + "subPath": {"type": "string", "description": "Sub-path within the data source to browse"}, + "filter": {"type": "string", "description": "Filter pattern (e.g. '*.pdf')"}, }, - "required": ["dataSourceId"], }, readOnly=True, ) registry.register( "searchDataSource", _searchDataSource, - description="Search for files within an attached data source by query.", + description=( + "Search for files within a data source. Accepts either dataSourceId OR connectionId+service." + ), parameters={ "type": "object", "properties": { "dataSourceId": {"type": "string", "description": "DataSource ID"}, + "connectionId": {"type": "string", "description": "UserConnection ID (alternative to dataSourceId)"}, + "service": {"type": "string", "description": "Service name (alternative to dataSourceId)"}, + "path": {"type": "string", "description": "Scope path (used with connectionId+service)"}, "query": {"type": "string", "description": "Search query"}, }, - "required": ["dataSourceId", "query"], + "required": ["query"], }, readOnly=True, ) registry.register( "downloadFromDataSource", _downloadFromDataSource, - description="Download a file or email message from an attached data source into local storage. Returns the local file ID which can then be read with readFile. For email sources (Outlook, Gmail), this downloads the full email content -- browse/search only return subjects. Always provide the fileName if known.", + description=( + "Download a file or email from a data source into local storage. Returns a local file ID " + "to read with readFile. Accepts either dataSourceId OR connectionId+service. " + "For email sources (Outlook, Gmail), browse/search only return subjects -- use this to get full content." + ), parameters={ "type": "object", "properties": { "dataSourceId": {"type": "string", "description": "DataSource ID"}, - "filePath": {"type": "string", "description": "Path of the file to download (as returned by browseDataSource)"}, - "fileName": {"type": "string", "description": "Human-readable file name with extension (e.g. 'report.pdf'). 
Get this from browseDataSource results."}, + "connectionId": {"type": "string", "description": "UserConnection ID (alternative to dataSourceId)"}, + "service": {"type": "string", "description": "Service name (alternative to dataSourceId)"}, + "filePath": {"type": "string", "description": "Path of the file to download (from browseDataSource results)"}, + "fileName": {"type": "string", "description": "File name with extension (e.g. 'report.pdf')"}, }, "required": ["dataSourceId", "filePath"], }, @@ -1697,7 +1753,7 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "browseContainer", _browseContainer, - description="Browse the structural index of a file/container (pages, sections, sheets, slides).", + description="Browse the structural index of a document (pages, sections, sheets, slides). Use before readContentObjects for targeted reading.", parameters={ "type": "object", "properties": {"fileId": {"type": "string", "description": "The file ID to browse"}}, @@ -1708,7 +1764,7 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "readContentObjects", _readContentObjects, - description="Read content objects from a file with optional filters (page, section, type).", + description="Read extracted content objects from a file, optionally filtered by page, section, or type. Use browseContainer first to see the structure.", parameters={ "type": "object", "properties": { @@ -1724,7 +1780,7 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "extractContainerItem", _extractContainerItem, - description="On-demand extraction of a specific item within a container (ZIP, nested file).", + description="Extract a specific item from a container file (ZIP, nested file). Use browseContainer to see available items.", parameters={ "type": "object", "properties": { @@ -1738,7 +1794,7 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "summarizeContent", _summarizeContent, - description="AI-powered summary of content objects from a file, optionally filtered.", + description="Generate an AI-powered summary of a file's content. Optionally filter by section, page, or content type.", parameters={ "type": "object", "properties": { @@ -1891,7 +1947,7 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "describeImage", _describeImage, - description="Analyse an image using AI vision. Works with image files and images extracted from PDFs/DOCX/PPTX.", + description="Analyze an image using AI vision. Works with image files and images extracted from PDFs/DOCX/PPTX. 
         parameters={
             "type": "object",
             "properties": {
@@ -2450,6 +2506,186 @@ def _registerCoreTools(registry: ToolRegistry, services):
         readOnly=False,
     )

+    # ── createChart tool ─────────────────────────────────────────────────
+
+    async def _createChart(args: Dict[str, Any], context: Dict[str, Any]):
+        """Create a data chart as a PNG image using matplotlib.
+
+        Example args (hypothetical values):
+            {"chartType": "bar", "title": "Revenue by Quarter",
+             "labels": ["Q1", "Q2"],
+             "datasets": [{"label": "2025", "values": [10, 12]}]}
+        """
+        import re as _re
+
+        chartType = (args.get("chartType") or "bar").strip().lower()
+        title = (args.get("title") or "Chart").strip()
+        labels = args.get("labels") or []
+        datasets = args.get("datasets") or []
+        xLabel = (args.get("xLabel") or "").strip()
+        yLabel = (args.get("yLabel") or "").strip()
+        width = min(max(args.get("width") or 10, 4), 20)
+        height = min(max(args.get("height") or 6, 3), 14)
+        colors = args.get("colors") or None
+
+        if not datasets:
+            return ToolResult(toolCallId="", toolName="createChart", success=False, error="datasets is required (list of {label, values})")
+
+        try:
+            import matplotlib
+            matplotlib.use("Agg")
+            import logging as _mpllog
+            _mpllog.getLogger("matplotlib").setLevel(_mpllog.WARNING)
+            import matplotlib.pyplot as plt
+            import io
+
+            _DEFAULT_COLORS = [
+                "#4285F4", "#EA4335", "#FBBC04", "#34A853", "#FF6D01",
+                "#46BDC6", "#7B61FF", "#F538A0", "#00ACC1", "#AB47BC",
+            ]
+            usedColors = colors if colors and len(colors) >= len(datasets) else _DEFAULT_COLORS
+
+            fig, ax = plt.subplots(figsize=(width, height))
+            fig.patch.set_facecolor("#FFFFFF")
+            ax.set_facecolor("#FAFAFA")
+
+            if chartType in ("pie", "donut"):
+                # pie/donut charts render only the first dataset
+                values = datasets[0].get("values", []) if datasets else []
+                explode = [0.02] * len(values)
+                # "labels or None": ax.pie raises a length-mismatch ValueError
+                # if an empty labels list is passed alongside non-empty values
+                wedges, texts, autotexts = ax.pie(
+                    values, labels=labels or None, autopct="%1.1f%%",
+                    colors=usedColors[:len(values)], explode=explode,
+                    textprops={"fontsize": 9},
+                )
+                if chartType == "donut":
+                    ax.add_artist(plt.Circle((0, 0), 0.55, fc="white"))
+                ax.set_title(title, fontsize=14, fontweight="bold", pad=16)
+
+            else:
+                import numpy as _np
+                x = _np.arange(len(labels)) if labels else _np.arange(max(len(d.get("values", [])) for d in datasets))
+                barWidth = 0.8 / max(len(datasets), 1)
+
+                for i, ds in enumerate(datasets):
+                    dsLabel = ds.get("label", f"Series {i+1}")
+                    values = ds.get("values", [])
+                    color = usedColors[i % len(usedColors)]
+
+                    if chartType == "bar":
+                        offset = (i - len(datasets) / 2 + 0.5) * barWidth
+                        ax.bar(x + offset, values, barWidth, label=dsLabel, color=color, edgecolor="white", linewidth=0.5)
+                    elif chartType == "horizontalbar":
+                        offset = (i - len(datasets) / 2 + 0.5) * barWidth
+                        ax.barh(x + offset, values, barWidth, label=dsLabel, color=color, edgecolor="white", linewidth=0.5)
+                    elif chartType == "line":
+                        ax.plot(x[:len(values)], values, marker="o", markersize=5, label=dsLabel, color=color, linewidth=2)
+                    elif chartType == "area":
+                        ax.fill_between(x[:len(values)], values, alpha=0.3, color=color)
+                        ax.plot(x[:len(values)], values, label=dsLabel, color=color, linewidth=2)
+                    elif chartType == "scatter":
+                        ax.scatter(x[:len(values)], values, label=dsLabel, color=color, s=50, edgecolors="white", linewidth=0.5)
+                    else:
+                        ax.bar(x, values, label=dsLabel, color=color)
+
+                if labels:
+                    if chartType == "horizontalbar":
+                        ax.set_yticks(x)
+                        ax.set_yticklabels(labels, fontsize=9)
+                    else:
+                        ax.set_xticks(x)
+                        ax.set_xticklabels(labels, fontsize=9, rotation=45 if len(labels) > 6 else 0, ha="right" if len(labels) > 6 else "center")
+
+                ax.set_title(title, fontsize=14, fontweight="bold", pad=12)
+                if xLabel:
+                    ax.set_xlabel(xLabel, fontsize=10)
+                if yLabel:
+                    ax.set_ylabel(yLabel, fontsize=10)
+                if len(datasets) > 1:
+                    ax.legend(fontsize=9, framealpha=0.9)
+                ax.grid(axis="y", alpha=0.3, linestyle="--")
+                ax.spines["top"].set_visible(False)
+                ax.spines["right"].set_visible(False)
+
+            plt.tight_layout()
+            buf = io.BytesIO()
+            fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
+            plt.close(fig)
+            pngData = buf.getvalue()
+
+            chatService = services.chat
+            sanitizedTitle = _re.sub(r'[^\w._-]', '_', title, flags=_re.UNICODE).strip('_') or "chart"
+            fileName = f"{sanitizedTitle}.png"
+
+            if hasattr(chatService.interfaceDbComponent, "saveGeneratedFile"):
+                fileItem = chatService.interfaceDbComponent.saveGeneratedFile(pngData, fileName, "image/png")
+            else:
+                fileItem, _ = chatService.interfaceDbComponent.saveUploadedFile(pngData, fileName)
+
+            fid = fileItem.id if hasattr(fileItem, "id") else fileItem.get("id", "?") if isinstance(fileItem, dict) else "?"
+            fiId = context.get("featureInstanceId") or (services.featureInstanceId if services else "")
+            if fiId and fid != "?":
+                chatService.interfaceDbComponent.updateFile(fid, {"featureInstanceId": fiId})
+            tempFolderId = _getOrCreateTempFolder(chatService)
+            if tempFolderId and fid != "?":
+                chatService.interfaceDbComponent.updateFile(fid, {"folderId": tempFolderId})
+
+            sideEvents = [{"type": "fileCreated", "data": {
+                "fileId": fid, "fileName": fileName,
+                "mimeType": "image/png", "fileSize": len(pngData),
+            }}]
+            return ToolResult(
+                toolCallId="", toolName="createChart", success=True,
+                data=f"Chart saved as '{fileName}' (id: {fid}, {len(pngData)} bytes). "
+                     f"Embed in documents with: ![{title}](file:{fid})",
+                sideEvents=sideEvents,
+            )
+
+        except Exception as e:
+            logger.error(f"createChart failed: {e}", exc_info=True)
+            return ToolResult(toolCallId="", toolName="createChart", success=False, error=str(e))
+
+    registry.register(
+        "createChart", _createChart,
+        description=(
+            "Create a data chart/graph as a PNG image using matplotlib. "
+            "Supported types: bar, horizontalBar, line, area, scatter, pie, donut. "
+            "The chart is saved as a file in the workspace. "
+            "Use the returned fileId to embed in documents via renderDocument: ![title](file:fileId). "
+            "Provide structured data with labels and datasets."
+        ),
+        parameters={
+            "type": "object",
+            "properties": {
+                "chartType": {
+                    "type": "string",
+                    "enum": ["bar", "horizontalBar", "line", "area", "scatter", "pie", "donut"],
+                    "description": "Chart type (default: bar)",
+                },
+                "title": {"type": "string", "description": "Chart title"},
+                "labels": {
+                    "type": "array", "items": {"type": "string"},
+                    "description": "X-axis labels / category names",
+                },
+                "datasets": {
+                    "type": "array",
+                    "items": {
+                        "type": "object",
+                        "properties": {
+                            "label": {"type": "string", "description": "Series name (legend)"},
+                            "values": {"type": "array", "items": {"type": "number"}, "description": "Data values"},
+                        },
+                        "required": ["values"],
+                    },
+                    "description": "Data series to plot",
+                },
+                "xLabel": {"type": "string", "description": "X-axis label"},
+                "yLabel": {"type": "string", "description": "Y-axis label"},
+                "colors": {
+                    "type": "array", "items": {"type": "string"},
+                    "description": "Custom hex colors for series (e.g. ['#4285F4', '#EA4335'])",
['#4285F4', '#EA4335'])", + }, + "width": {"type": "number", "description": "Figure width in inches (4-20, default 10)"}, + "height": {"type": "number", "description": "Figure height in inches (3-14, default 6)"}, + }, + "required": ["datasets"], + }, + readOnly=False, + ) + # ── Phase 3: speechToText, detectLanguage, neutralizeData, executeCode ── async def _speechToText(args: Dict[str, Any], context: Dict[str, Any]): @@ -2534,7 +2770,7 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "speechToText", _speechToText, - description="Transcribe an audio file to text. Provide the fileId of an audio file from the workspace.", + description="Transcribe an audio file to text using speech recognition. Returns the transcript with confidence score.", parameters={ "type": "object", "properties": { @@ -2548,7 +2784,7 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "detectLanguage", _detectLanguage, - description="Detect the language of a text.", + description="Detect the language of a text snippet. Returns ISO 639-1 code (e.g. 'de', 'en').", parameters={ "type": "object", "properties": { @@ -2561,7 +2797,7 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "neutralizeData", _neutralizeData, - description="Anonymize/neutralize text or file content. Replaces personal data (names, addresses, etc.) with placeholders. Does not modify the original.", + description="Anonymize text or file content by replacing personal data (names, addresses, etc.) with placeholders. Non-destructive -- returns the anonymized copy.", parameters={ "type": "object", "properties": { @@ -2590,3 +2826,129 @@ def _registerCoreTools(registry: ToolRegistry, services): }, readOnly=True ) + + # ---- Feature Data Sub-Agent tool ---- + + async def _queryFeatureInstance(args: Dict[str, Any], context: Dict[str, Any]): + """Delegate a question to the Feature Data Sub-Agent.""" + featureInstanceId = args.get("featureInstanceId", "") + question = args.get("question", "") + if not featureInstanceId or not question: + return ToolResult( + toolCallId="", toolName="queryFeatureInstance", + success=False, error="featureInstanceId and question are required", + ) + try: + from modules.serviceCenter.services.serviceAgent.featureDataAgent import runFeatureDataAgent + from modules.datamodels.datamodelFeatureDataSource import FeatureDataSource + from modules.interfaces.interfaceDbApp import getRootInterface + + rootIf = getRootInterface() + instance = rootIf.getFeatureInstance(featureInstanceId) + if not instance: + return ToolResult( + toolCallId="", toolName="queryFeatureInstance", + success=False, error=f"Feature instance {featureInstanceId} not found", + ) + + featureCode = instance.featureCode + mandateId = instance.mandateId or "" + instanceLabel = instance.label or "" + userId = context.get("userId", "") + workspaceInstanceId = context.get("featureInstanceId", "") + + rootDbConn = rootIf.db if hasattr(rootIf, "db") else None + if rootDbConn is None: + return ToolResult( + toolCallId="", toolName="queryFeatureInstance", + success=False, error="No database connector available", + ) + + featureDataSources = rootDbConn.getRecordset( + FeatureDataSource, + recordFilter={"featureInstanceId": featureInstanceId, "workspaceInstanceId": workspaceInstanceId}, + ) + + from modules.security.rbacCatalog import getCatalogService + catalog = getCatalogService() + if not featureDataSources: + selectedTables = catalog.getDataObjects(featureCode) + else: + allObjs = 
{o["meta"]["table"]: o for o in catalog.getDataObjects(featureCode) if "meta" in o and "table" in o.get("meta", {})} + selectedTables = [allObjs[ds["tableName"]] for ds in featureDataSources if ds.get("tableName") in allObjs] + + if not selectedTables: + return ToolResult( + toolCallId="", toolName="queryFeatureInstance", + success=False, error=f"No data tables available for feature '{featureCode}'", + ) + + from modules.connectors.connectorDbPostgre import DatabaseConnector + from modules.shared.configuration import APP_CONFIG + featureDbName = f"poweron_{featureCode.lower()}" + featureDbConn = DatabaseConnector( + dbHost=APP_CONFIG.get("DB_HOST", "localhost"), + dbDatabase=featureDbName, + dbUser=APP_CONFIG.get("DB_USER"), + dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"), + dbPort=int(APP_CONFIG.get("DB_PORT", 5432)), + userId=userId or "agent", + ) + + aiService = services.ai if hasattr(services, "ai") else None + if aiService is None: + return ToolResult( + toolCallId="", toolName="queryFeatureInstance", + success=False, error="AI service not available for sub-agent", + ) + + async def _subAgentAiCall(req): + return await aiService.callAi(req) + + try: + answer = await runFeatureDataAgent( + question=question, + featureInstanceId=featureInstanceId, + featureCode=featureCode, + selectedTables=selectedTables, + mandateId=mandateId, + userId=userId, + aiCallFn=_subAgentAiCall, + dbConnector=featureDbConn, + instanceLabel=instanceLabel, + ) + finally: + try: + featureDbConn.close() + except Exception: + pass + + return ToolResult( + toolCallId="", toolName="queryFeatureInstance", + success=True, data=answer, + ) + except Exception as e: + logger.error(f"queryFeatureInstance failed: {e}", exc_info=True) + return ToolResult( + toolCallId="", toolName="queryFeatureInstance", + success=False, error=str(e), + ) + + registry.register( + "queryFeatureInstance", _queryFeatureInstance, + description=( + "Query data from a feature instance (e.g. Trustee, CommCoach). " + "Delegates to a specialized sub-agent that knows the feature's data schema " + "and can browse/query its tables. Use this when the user has attached " + "feature data sources or asks about feature-specific data." + ), + parameters={ + "type": "object", + "properties": { + "featureInstanceId": {"type": "string", "description": "ID of the feature instance to query"}, + "question": {"type": "string", "description": "What data to find or analyze from this feature instance"}, + }, + "required": ["featureInstanceId", "question"] + }, + readOnly=True + )