From 887867acd0477d09b72b5a47ac5a59ab9e2778c5 Mon Sep 17 00:00:00 2001
From: patrick-motsch
Date: Sun, 8 Feb 2026 16:14:01 +0100
Subject: [PATCH] billing rbac
---
modules/aicore/aicorePluginPrivateLlm.py | 9 +-
modules/datamodels/datamodelBilling.py | 4 +
.../automation/routeFeatureAutomation.py | 4 +
modules/features/chatbot/service.py | 8 +-
.../routeFeatureChatplayground.py | 3 +-
modules/interfaces/interfaceAiObjects.py | 25 +-
modules/interfaces/interfaceDbBilling.py | 60 +++-
modules/routes/routeBilling.py | 270 ++++++++++++++----
modules/services/serviceAi/mainServiceAi.py | 207 ++++++++++----
.../serviceBilling/mainServiceBilling.py | 18 +-
.../services/serviceChat/mainServiceChat.py | 66 +----
modules/workflows/automation/mainWorkflow.py | 29 +-
12 files changed, 506 insertions(+), 197 deletions(-)
diff --git a/modules/aicore/aicorePluginPrivateLlm.py b/modules/aicore/aicorePluginPrivateLlm.py
index dfcb7eaf..718c5905 100644
--- a/modules/aicore/aicorePluginPrivateLlm.py
+++ b/modules/aicore/aicorePluginPrivateLlm.py
@@ -241,8 +241,10 @@ class AiPrivateLlm(BaseConnectorAi):
priority=PriorityEnum.COST,
processingMode=ProcessingModeEnum.BASIC,
operationTypes=createOperationTypeRatings(
- (OperationTypeEnum.DATA_EXTRACT, 9),
- (OperationTypeEnum.DATA_ANALYSE, 9),
+ (OperationTypeEnum.PLAN, 7),
+ (OperationTypeEnum.DATA_ANALYSE, 8),
+ (OperationTypeEnum.DATA_GENERATE, 8),
+ (OperationTypeEnum.DATA_EXTRACT, 8),
),
version="qwen2.5:7b",
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: PRICE_TEXT_PER_CALL
@@ -268,7 +270,6 @@ class AiPrivateLlm(BaseConnectorAi):
processingMode=ProcessingModeEnum.ADVANCED,
operationTypes=createOperationTypeRatings(
(OperationTypeEnum.IMAGE_ANALYSE, 9),
- (OperationTypeEnum.DATA_EXTRACT, 8),
),
version="qwen2.5vl:7b",
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: PRICE_VISION_PER_CALL
@@ -294,8 +295,6 @@ class AiPrivateLlm(BaseConnectorAi):
processingMode=ProcessingModeEnum.DETAILED,
operationTypes=createOperationTypeRatings(
(OperationTypeEnum.IMAGE_ANALYSE, 9),
- (OperationTypeEnum.DATA_EXTRACT, 9),
- (OperationTypeEnum.DATA_ANALYSE, 8),
),
version="granite3.2-vision",
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: PRICE_VISION_PER_CALL
diff --git a/modules/datamodels/datamodelBilling.py b/modules/datamodels/datamodelBilling.py
index e7e59eb4..ea5a4c00 100644
--- a/modules/datamodels/datamodelBilling.py
+++ b/modules/datamodels/datamodelBilling.py
@@ -121,6 +121,8 @@ class BillingTransaction(BaseModel):
featureInstanceId: Optional[str] = Field(None, description="Feature instance ID")
featureCode: Optional[str] = Field(None, description="Feature code (e.g., chatplayground, automation)")
aicoreProvider: Optional[str] = Field(None, description="AICore provider (anthropic, openai, etc.)")
+ aicoreModel: Optional[str] = Field(None, description="AICore model name (e.g., claude-4-sonnet, gpt-4o)")
+ createdByUserId: Optional[str] = Field(None, description="User who created/caused this transaction")
registerModelLabels(
@@ -138,6 +140,8 @@ registerModelLabels(
"featureInstanceId": {"en": "Feature Instance ID", "de": "Feature-Instanz-ID"},
"featureCode": {"en": "Feature Code", "de": "Feature-Code"},
"aicoreProvider": {"en": "AI Provider", "de": "AI-Anbieter"},
+ "aicoreModel": {"en": "AI Model", "de": "AI-Modell"},
+ "createdByUserId": {"en": "Created By User", "de": "Erstellt von Benutzer"},
},
)
diff --git a/modules/features/automation/routeFeatureAutomation.py b/modules/features/automation/routeFeatureAutomation.py
index d39a3358..8a5fee1d 100644
--- a/modules/features/automation/routeFeatureAutomation.py
+++ b/modules/features/automation/routeFeatureAutomation.py
@@ -367,6 +367,10 @@ async def execute_automation_route(
try:
from modules.services import getInterface as getServices
services = getServices(context.user, context.mandateId)
+ # Propagate feature context for billing
+ if context.featureInstanceId:
+ services.featureInstanceId = str(context.featureInstanceId)
+ services.featureCode = 'automation'
workflow = await executeAutomation(automationId, services)
return workflow
except HTTPException:
diff --git a/modules/features/chatbot/service.py b/modules/features/chatbot/service.py
index fd7f3344..3afd1632 100644
--- a/modules/features/chatbot/service.py
+++ b/modules/features/chatbot/service.py
@@ -93,6 +93,12 @@ async def chatProcess(
try:
# Get services with mandate context
services = getServices(currentUser, mandateId)
+
+ # Set feature context for billing
+ if featureInstanceId:
+ services.featureInstanceId = featureInstanceId
+ services.featureCode = 'chatbot'
+
interfaceDbChat = services.interfaceDbChat
# Get event manager and create queue if needed
@@ -698,7 +704,7 @@ async def _convert_file_ids_to_document_references(
# Search database if not found in messages
if not document_id:
try:
- from modules.shared.databaseUtils import getRecordsetWithRBAC
+ from modules.interfaces.interfaceRbac import getRecordsetWithRBAC
documents = getRecordsetWithRBAC(
services.interfaceDbChat.db,
ChatDocument,
diff --git a/modules/features/chatplayground/routeFeatureChatplayground.py b/modules/features/chatplayground/routeFeatureChatplayground.py
index ce1611ea..e3787904 100644
--- a/modules/features/chatplayground/routeFeatureChatplayground.py
+++ b/modules/features/chatplayground/routeFeatureChatplayground.py
@@ -102,7 +102,8 @@ async def start_workflow(
workflowMode,
workflowId,
mandateId=mandateId,
- featureInstanceId=instanceId
+ featureInstanceId=instanceId,
+ featureCode='chatplayground'
)
return workflow
diff --git a/modules/interfaces/interfaceAiObjects.py b/modules/interfaces/interfaceAiObjects.py
index 3976350d..0214d231 100644
--- a/modules/interfaces/interfaceAiObjects.py
+++ b/modules/interfaces/interfaceAiObjects.py
@@ -4,8 +4,8 @@ import logging
import asyncio
import uuid
import base64
-from typing import Dict, Any, List, Union, Tuple, Optional
-from dataclasses import dataclass
+from typing import Dict, Any, List, Union, Tuple, Optional, Callable
+from dataclasses import dataclass, field
import time
logger = logging.getLogger(__name__)
@@ -29,7 +29,13 @@ from modules.datamodels.datamodelExtraction import ContentPart, MergeStrategy
@dataclass(slots=True)
class AiObjects:
- """Centralized AI interface: dynamically discovers and uses AI models. Includes web functionality."""
+ """Centralized AI interface: dynamically discovers and uses AI models.
+
+ billingCallback: Set by serviceAi before AI calls. Called after EVERY individual
+ model call with the AiCallResponse. This ensures per-model-call billing with
+ exact provider + model name. The callback handles billing recording.
+ """
+ billingCallback: Optional[Callable] = field(default=None, repr=False)
def __post_init__(self) -> None:
# Auto-discover and register all available connectors
@@ -226,7 +232,7 @@ class AiObjects:
# Calculate price using model's own price calculation method
priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes)
- return AiCallResponse(
+ response = AiCallResponse(
content=content,
modelName=model.name,
provider=model.connectorType,
@@ -236,6 +242,17 @@ class AiObjects:
bytesReceived=outputBytes,
errorCount=0
)
+
+ # BILLING: Record billing for THIS specific model call
+ # billingCallback is set by serviceAi and records one billing transaction
+ # per model call with exact provider + model name
+ if self.billingCallback:
+ try:
+ self.billingCallback(response)
+ except Exception as e:
+ logger.error(f"BILLING: Failed to record billing for model {model.name}: {e}")
+
+ return response
# Utility methods
diff --git a/modules/interfaces/interfaceDbBilling.py b/modules/interfaces/interfaceDbBilling.py
index ae8b13ec..a8cbd61b 100644
--- a/modules/interfaces/interfaceDbBilling.py
+++ b/modules/interfaces/interfaceDbBilling.py
@@ -710,6 +710,7 @@ class BillingObjects:
featureInstanceId: str = None,
featureCode: str = None,
aicoreProvider: str = None,
+ aicoreModel: str = None,
description: str = "AI Usage"
) -> Optional[Dict[str, Any]]:
"""
@@ -722,7 +723,8 @@ class BillingObjects:
workflowId: Optional workflow ID
featureInstanceId: Optional feature instance ID
featureCode: Optional feature code
- aicoreProvider: Optional AICore provider name
+ aicoreProvider: AICore provider name (e.g., 'anthropic', 'openai')
+ aicoreModel: AICore model name (e.g., 'claude-4-sonnet', 'gpt-4o')
description: Transaction description
Returns:
@@ -758,7 +760,9 @@ class BillingObjects:
workflowId=workflowId,
featureInstanceId=featureInstanceId,
featureCode=featureCode,
- aicoreProvider=aicoreProvider
+ aicoreProvider=aicoreProvider,
+ aicoreModel=aicoreModel,
+ createdByUserId=userId
)
return self.createTransaction(transaction)
@@ -828,9 +832,13 @@ class BillingObjects:
# Calculate by provider
costByProvider = {}
+ costByModel = {}
for t in debits:
provider = t.get("aicoreProvider", "unknown")
costByProvider[provider] = costByProvider.get(provider, 0) + t.get("amount", 0)
+
+ model = t.get("aicoreModel", "unknown")
+ costByModel[model] = costByModel.get(model, 0) + t.get("amount", 0)
# Calculate by feature
costByFeature = {}
@@ -842,6 +850,7 @@ class BillingObjects:
"totalCostCHF": totalCost,
"transactionCount": len(debits),
"costByProvider": costByProvider,
+ "costByModel": costByModel,
"costByFeature": costByFeature
}
@@ -1129,8 +1138,8 @@ class BillingObjects:
user = appInterface.getUser(userId)
if user:
displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isinstance(user, dict) else None)
- email = getattr(user, 'email', None) or (user.get("email") if isinstance(user, dict) else None)
- userMap[userId] = displayName or email or userId
+ username = getattr(user, 'username', None) or (user.get("username") if isinstance(user, dict) else None)
+ userMap[userId] = displayName or username or userId
# Get mandate info efficiently
mandateMap = {}
@@ -1212,8 +1221,8 @@ class BillingObjects:
user = appInterface.getUser(userId)
if user:
displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isinstance(user, dict) else None)
- email = getattr(user, 'email', None) or (user.get("email") if isinstance(user, dict) else None)
- userMap[userId] = displayName or email or userId
+ username = getattr(user, 'username', None) or (user.get("username") if isinstance(user, dict) else None)
+ userMap[userId] = displayName or username or userId
# Get mandate info efficiently
mandateMap = {}
@@ -1224,7 +1233,8 @@ class BillingObjects:
mandateName = getattr(mandate, 'name', None) or (mandate.get("name", "") if isinstance(mandate, dict) else "")
mandateMap[mandateId] = mandateName
- # Get transactions for all accounts
+ # Get transactions for all accounts and collect createdByUserIds
+ rawTransactions = []
for account in allAccounts:
accountId = account.get("id")
if not accountId:
@@ -1233,14 +1243,38 @@ class BillingObjects:
transactions = self.getTransactions(accountId, limit=limit)
accountInfo = accountMap.get(accountId, {})
mandateId = accountInfo.get("mandateId")
- userId = accountInfo.get("userId")
+ accountUserId = accountInfo.get("userId")
for t in transactions:
- t["mandateId"] = mandateId
- t["mandateName"] = mandateMap.get(mandateId, "")
- t["userId"] = userId
- t["userName"] = userMap.get(userId, userId)
- allTransactions.append(t)
+ t["_accountUserId"] = accountUserId
+ t["_accountMandateId"] = mandateId
+ rawTransactions.append(t)
+
+ # Resolve createdByUserIds that are not yet in userMap
+ extraUserIds = set()
+ for t in rawTransactions:
+ cbUserId = t.get("createdByUserId")
+ if cbUserId and cbUserId not in userMap:
+ extraUserIds.add(cbUserId)
+
+ for uid in extraUserIds:
+ user = appInterface.getUser(uid)
+ if user:
+ displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isinstance(user, dict) else None)
+ username = getattr(user, 'username', None) or (user.get("username") if isinstance(user, dict) else None)
+ userMap[uid] = displayName or username or uid
+
+ # Enrich transactions
+ for t in rawTransactions:
+ mandateId = t.pop("_accountMandateId", None)
+ accountUserId = t.pop("_accountUserId", None)
+ t["mandateId"] = mandateId
+ t["mandateName"] = mandateMap.get(mandateId, "")
+ # Prefer createdByUserId (per-transaction) over account-derived userId
+ txUserId = t.get("createdByUserId") or accountUserId
+ t["userId"] = txUserId
+ t["userName"] = userMap.get(txUserId, txUserId) if txUserId else None
+ allTransactions.append(t)
except Exception as e:
logger.error(f"Error getting user transactions for mandates: {e}")
diff --git a/modules/routes/routeBilling.py b/modules/routes/routeBilling.py
index 26133704..3e87cd83 100644
--- a/modules/routes/routeBilling.py
+++ b/modules/routes/routeBilling.py
@@ -42,6 +42,143 @@ from modules.datamodels.datamodelBilling import (
# Configure logger
logger = logging.getLogger(__name__)
+
+# =============================================================================
+# Billing RBAC Data Scope
+# =============================================================================
+#
+# RBAC rules for billing data visibility:
+#
+# SysAdmin → ALL transactions and statistics across all mandates
+# Mandate-Admin → ALL user data within their administrated mandates
+# Feature-Instance-Admin→ Data for their administrated feature instances
+# Regular User → ONLY their own data within their mandates
+#
+
+class BillingDataScope:
+ """
+ Determines what billing data a user can see based on RBAC roles.
+
+ Evaluated once per request and used to filter transactions/statistics.
+ """
+ __slots__ = ('isGlobalAdmin', 'adminMandateIds', 'adminFeatureInstanceIds',
+ 'memberMandateIds', 'userId')
+
+ def __init__(self, userId: str):
+ self.isGlobalAdmin: bool = False
+ self.adminMandateIds: list = []
+ self.adminFeatureInstanceIds: list = []
+ self.memberMandateIds: list = []
+ self.userId: str = userId
+
+
+def _getBillingDataScope(user) -> BillingDataScope:
+ """
+ Determine what billing data a user can see based on RBAC.
+
+ Uses rootInterface (privileged) to check roles across all mandates
+ and feature instances without RBAC restrictions on the lookup itself.
+
+ Returns:
+ BillingDataScope with the user's visibility boundaries.
+ """
+ scope = BillingDataScope(userId=user.id)
+
+ if user.isSysAdmin:
+ scope.isGlobalAdmin = True
+ return scope
+
+ from modules.interfaces.interfaceDbApp import getRootInterface
+ rootInterface = getRootInterface()
+
+ # --- Mandate roles ---
+ userMandates = rootInterface.getUserMandates(user.id)
+ for um in userMandates:
+ mandateId = getattr(um, 'mandateId', None)
+ umId = getattr(um, 'id', None)
+ if not mandateId or not umId:
+ continue
+
+ roleIds = rootInterface.getRoleIdsForUserMandate(umId)
+ isAdmin = False
+ for roleId in roleIds:
+ role = rootInterface.getRole(roleId)
+ if role and role.roleLabel == "admin" and not role.featureInstanceId:
+ isAdmin = True
+ break
+
+ if isAdmin:
+ scope.adminMandateIds.append(mandateId)
+ else:
+ scope.memberMandateIds.append(mandateId)
+
+ # --- Feature instance roles ---
+ featureAccesses = rootInterface.getFeatureAccessesForUser(user.id)
+ for fa in featureAccesses:
+ fiId = getattr(fa, 'featureInstanceId', None)
+ faId = getattr(fa, 'id', None)
+ if not fiId or not faId:
+ continue
+
+ roleIds = rootInterface.getRoleIdsForFeatureAccess(faId)
+ for roleId in roleIds:
+ role = rootInterface.getRole(roleId)
+ if role and role.roleLabel == "admin":
+ scope.adminFeatureInstanceIds.append(fiId)
+ break
+
+ logger.debug(
+ f"BillingDataScope for user {user.id}: "
+ f"globalAdmin={scope.isGlobalAdmin}, "
+ f"adminMandates={scope.adminMandateIds}, "
+ f"adminInstances={scope.adminFeatureInstanceIds}, "
+ f"memberMandates={scope.memberMandateIds}"
+ )
+ return scope
+
+
+def _filterTransactionsByScope(transactions: list, scope: BillingDataScope) -> list:
+ """
+ Filter a list of transaction dicts based on the user's BillingDataScope.
+
+ Rules:
+ - SysAdmin: no filter
+ - Mandate-Admin: all transactions in their admin mandates
+ - Feature-Instance-Admin: transactions for their admin feature instances
+ - Regular user: only transactions where createdByUserId/userId matches
+ """
+ if scope.isGlobalAdmin:
+ return transactions
+
+ adminMandateSet = set(scope.adminMandateIds)
+ adminFiSet = set(scope.adminFeatureInstanceIds)
+ memberMandateSet = set(scope.memberMandateIds)
+
+ result = []
+ for t in transactions:
+ mandateId = t.get("mandateId")
+ fiId = t.get("featureInstanceId")
+ txUserId = t.get("createdByUserId") or t.get("userId")
+
+ # Mandate admin → sees all transactions in their mandate
+ if mandateId and mandateId in adminMandateSet:
+ result.append(t)
+ continue
+
+ # Feature instance admin → sees all transactions for their instances
+ if fiId and fiId in adminFiSet:
+ result.append(t)
+ continue
+
+ # Regular member → only own transactions
+ if mandateId and mandateId in memberMandateSet:
+ if txUserId and txUserId == scope.userId:
+ result.append(t)
+ continue
+
+ return result
+
+
# =============================================================================
# Request/Response Models
# =============================================================================
@@ -74,7 +211,10 @@ class TransactionResponse(BaseModel):
referenceType: Optional[ReferenceTypeEnum]
workflowId: Optional[str]
featureCode: Optional[str]
+ featureInstanceId: Optional[str] = None
aicoreProvider: Optional[str]
+ aicoreModel: Optional[str] = None
+ createdByUserId: Optional[str] = None
createdAt: Optional[datetime]
mandateId: Optional[str] = None
mandateName: Optional[str] = None
@@ -98,6 +238,7 @@ class UsageReportResponse(BaseModel):
totalCost: float
transactionCount: int
costByProvider: Dict[str, float]
+ costByModel: Dict[str, float] = {}
costByFeature: Dict[str, float]
@@ -140,7 +281,10 @@ class UserTransactionResponse(BaseModel):
referenceType: Optional[ReferenceTypeEnum]
workflowId: Optional[str]
featureCode: Optional[str]
+ featureInstanceId: Optional[str] = None
aicoreProvider: Optional[str]
+ aicoreModel: Optional[str] = None
+ createdByUserId: Optional[str] = None
createdAt: Optional[datetime]
mandateId: Optional[str] = None
mandateName: Optional[str] = None
@@ -261,7 +405,10 @@ def getTransactions(
referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None,
workflowId=t.get("workflowId"),
featureCode=t.get("featureCode"),
+ featureInstanceId=t.get("featureInstanceId"),
aicoreProvider=t.get("aicoreProvider"),
+ aicoreModel=t.get("aicoreModel"),
+ createdByUserId=t.get("createdByUserId"),
createdAt=t.get("_createdAt"),
mandateId=t.get("mandateId"),
mandateName=t.get("mandateName")
@@ -349,6 +496,7 @@ def getStatistics(
totalCost=stats.get("totalCostCHF", 0.0),
transactionCount=stats.get("transactionCount", 0),
costByProvider=stats.get("costByProvider", {}),
+ costByModel=stats.get("costByModel", {}),
costByFeature=stats.get("costByFeature", {})
)
@@ -564,6 +712,7 @@ def getAccounts(
class MandateUserSummary(BaseModel):
"""Summary of a user for billing admin purposes."""
id: str
+ username: Optional[str] = None
email: Optional[str] = None
firstName: Optional[str] = None
lastName: Optional[str] = None
@@ -600,18 +749,21 @@ def getUsersForMandate(
# Handle both Pydantic models and dicts
if isinstance(user, dict):
+ username = user.get("username", "")
firstName = user.get("firstName", "")
lastName = user.get("lastName", "")
email = user.get("email", "")
else:
+ username = getattr(user, "username", "") or ""
firstName = getattr(user, "firstName", "") or ""
lastName = getattr(user, "lastName", "") or ""
email = getattr(user, "email", "") or ""
- displayName = f"{firstName} {lastName}".strip() or email or userId
+ displayName = f"{firstName} {lastName}".strip() or username or userId
result.append(MandateUserSummary(
id=userId,
+ username=username,
email=email,
firstName=firstName,
lastName=lastName,
@@ -652,7 +804,10 @@ def getTransactionsAdmin(
referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None,
workflowId=t.get("workflowId"),
featureCode=t.get("featureCode"),
+ featureInstanceId=t.get("featureInstanceId"),
aicoreProvider=t.get("aicoreProvider"),
+ aicoreModel=t.get("aicoreModel"),
+ createdByUserId=t.get("createdByUserId"),
createdAt=t.get("_createdAt")
))
@@ -715,7 +870,10 @@ def getMandateViewTransactions(
referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None,
workflowId=t.get("workflowId"),
featureCode=t.get("featureCode"),
+ featureInstanceId=t.get("featureInstanceId"),
aicoreProvider=t.get("aicoreProvider"),
+ aicoreModel=t.get("aicoreModel"),
+ createdByUserId=t.get("createdByUserId"),
createdAt=t.get("_createdAt"),
mandateId=t.get("mandateId"),
mandateName=t.get("mandateName")
@@ -740,39 +898,35 @@ def getUserViewBalances(
):
"""
Get user-level balances.
+
+ RBAC filtering:
- SysAdmin: sees all user balances across all mandates
- - MandateAdmin: sees user balances for mandates they manage
+ - Mandate-Admin: sees user balances for mandates they administrate
- Regular user: sees only their own balances
"""
try:
billingInterface = getBillingInterface(ctx.user, ctx.mandateId)
- # Determine which mandates the user has access to
- if ctx.user.isSysAdmin:
- # SysAdmin sees all
+ # Evaluate RBAC scope
+ scope = _getBillingDataScope(ctx.user)
+
+ # Determine mandate IDs for data loading
+ if scope.isGlobalAdmin:
mandateIds = None
else:
- # Get mandates where user is admin or has billing access
- from modules.interfaces.interfaceDbApp import getInterface as getAppInterface
- appInterface = getAppInterface(ctx.user)
- userMandates = appInterface.getUserMandates(ctx.user.id)
-
- # Filter to only mandates where user has admin role
- # For simplicity, we'll check if user is admin in any mandate
- mandateIds = []
- for um in userMandates:
- mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None)
- if mandateId:
- mandateIds.append(mandateId)
-
+ mandateIds = scope.adminMandateIds + scope.memberMandateIds
if not mandateIds:
return []
allBalances = billingInterface.getUserBalancesForMandates(mandateIds)
- # Non-admin users only see their own balances
- if not ctx.user.isSysAdmin:
- allBalances = [b for b in allBalances if b.get("userId") == ctx.user.id]
+ # RBAC filter: mandate admins see all in their mandates, regular users only own
+ if not scope.isGlobalAdmin:
+ adminMandateSet = set(scope.adminMandateIds)
+ allBalances = [
+ b for b in allBalances
+ if b.get("mandateId") in adminMandateSet or b.get("userId") == scope.userId
+ ]
return [UserBalanceResponse(**b) for b in allBalances]
@@ -786,6 +940,7 @@ class ViewStatisticsResponse(BaseModel):
totalCost: float = 0.0
transactionCount: int = 0
costByProvider: Dict[str, float] = {}
+ costByModel: Dict[str, float] = {}
costByFeature: Dict[str, float] = {}
costByMandate: Dict[str, float] = {}
timeSeries: List[Dict[str, Any]] = []
@@ -802,6 +957,13 @@ def getUserViewStatistics(
) -> ViewStatisticsResponse:
"""
Get aggregated usage statistics across all user's mandates.
+
+ RBAC filtering:
+ - SysAdmin: statistics across all mandates
+ - Mandate-Admin: statistics for mandates they administrate
+ - Feature-Instance-Admin: statistics for their feature instances
+ - Regular user: only their own usage statistics
+
- period='month': returns monthly time series for the given year
- period='day': returns daily time series for the given month/year
"""
@@ -816,25 +978,25 @@ def getUserViewStatistics(
billingInterface = getBillingInterface(ctx.user, ctx.mandateId)
- # Get all mandates the user has access to
- if ctx.user.isSysAdmin:
+ # Evaluate RBAC scope
+ scope = _getBillingDataScope(ctx.user)
+
+ # Determine mandate IDs for data loading
+ if scope.isGlobalAdmin:
mandateIds = None
else:
- from modules.interfaces.interfaceDbApp import getInterface as getAppInterface
- appInterface = getAppInterface(ctx.user)
- userMandates = appInterface.getUserMandates(ctx.user.id)
- mandateIds = []
- for um in userMandates:
- mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None)
- if mandateId:
- mandateIds.append(mandateId)
+ mandateIds = scope.adminMandateIds + scope.memberMandateIds
if not mandateIds:
logger.warning("No mandate IDs found for user")
return ViewStatisticsResponse()
# Get all transactions
allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=10000)
- logger.info(f"View statistics: {len(allTransactions)} total transactions fetched for period={period}, year={year}, month={month}")
+
+ # Apply RBAC filter
+ allTransactions = _filterTransactionsByScope(allTransactions, scope)
+
+ logger.info(f"View statistics: {len(allTransactions)} RBAC-filtered transactions for period={period}, year={year}, month={month}")
# Calculate date range
if period == "day":
@@ -905,6 +1067,7 @@ def getUserViewStatistics(
totalCost = sum(t.get("amount", 0) for t in debits)
costByProvider: Dict[str, float] = {}
+ costByModel: Dict[str, float] = {}
costByFeature: Dict[str, float] = {}
costByMandate: Dict[str, float] = {}
@@ -912,6 +1075,9 @@ def getUserViewStatistics(
provider = t.get("aicoreProvider") or "unknown"
costByProvider[provider] = costByProvider.get(provider, 0) + t.get("amount", 0)
+ model = t.get("aicoreModel") or "unknown"
+ costByModel[model] = costByModel.get(model, 0) + t.get("amount", 0)
+
mandate = t.get("mandateName") or t.get("mandateId") or "unknown"
featureCode = t.get("featureCode") or "unknown"
featureKey = f"{mandate} / {featureCode}"
@@ -950,6 +1116,7 @@ def getUserViewStatistics(
totalCost=round(totalCost, 4),
transactionCount=len(debits),
costByProvider=costByProvider,
+ costByModel=costByModel,
costByFeature=costByFeature,
costByMandate=costByMandate,
timeSeries=timeSeries
@@ -969,8 +1136,11 @@ def getUserViewTransactions(
) -> PaginatedResponse[UserTransactionResponse]:
"""
Get user-level transactions with pagination support.
+
+ RBAC filtering:
- SysAdmin: sees all user transactions across all mandates
- - MandateAdmin: sees user transactions for mandates they manage
+ - Mandate-Admin: sees all user transactions for mandates they administrate
+ - Feature-Instance-Admin: sees transactions for their feature instances
- Regular user: sees only their own transactions
Query Parameters:
@@ -987,28 +1157,24 @@ def getUserViewTransactions(
paginationDict = normalize_pagination_dict(paginationDict)
paginationParams = PaginationParams(**paginationDict)
- # Determine which mandates the user has access to
- if ctx.user.isSysAdmin:
- # SysAdmin sees all
- mandateIds = None
+ # Evaluate RBAC scope
+ scope = _getBillingDataScope(ctx.user)
+
+ # Determine mandate IDs for data loading
+ if scope.isGlobalAdmin:
+ mandateIds = None # Load all
else:
- # Get mandates where user has access
- from modules.interfaces.interfaceDbApp import getInterface as getAppInterface
- appInterface = getAppInterface(ctx.user)
- userMandates = appInterface.getUserMandates(ctx.user.id)
-
- mandateIds = []
- for um in userMandates:
- mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None)
- if mandateId:
- mandateIds.append(mandateId)
-
+ # Load data for all mandates the user belongs to (admin + member)
+ mandateIds = scope.adminMandateIds + scope.memberMandateIds
if not mandateIds:
return PaginatedResponse(items=[], pagination=None)
allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=10000)
- logger.debug(f"Found {len(allTransactions)} transactions for mandates {mandateIds}")
+ # Apply RBAC filter
+ allTransactions = _filterTransactionsByScope(allTransactions, scope)
+
+ logger.debug(f"RBAC-filtered {len(allTransactions)} transactions for user {ctx.user.id}")
# Convert to response objects as dicts for filtering/sorting
transactionDicts = []
@@ -1022,7 +1188,10 @@ def getUserViewTransactions(
"referenceType": t.get("referenceType"),
"workflowId": t.get("workflowId"),
"featureCode": t.get("featureCode"),
+ "featureInstanceId": t.get("featureInstanceId"),
"aicoreProvider": t.get("aicoreProvider"),
+ "aicoreModel": t.get("aicoreModel"),
+ "createdByUserId": t.get("createdByUserId"),
"createdAt": t.get("_createdAt"),
"mandateId": t.get("mandateId"),
"mandateName": t.get("mandateName"),
@@ -1044,7 +1213,10 @@ def getUserViewTransactions(
referenceType=ReferenceTypeEnum(d["referenceType"]) if d.get("referenceType") else None,
workflowId=d.get("workflowId"),
featureCode=d.get("featureCode"),
+ featureInstanceId=d.get("featureInstanceId"),
aicoreProvider=d.get("aicoreProvider"),
+ aicoreModel=d.get("aicoreModel"),
+ createdByUserId=d.get("createdByUserId"),
createdAt=d.get("createdAt"),
mandateId=d.get("mandateId"),
mandateName=d.get("mandateName"),
diff --git a/modules/services/serviceAi/mainServiceAi.py b/modules/services/serviceAi/mainServiceAi.py
index 5fdf32a5..bd3ef6b4 100644
--- a/modules/services/serviceAi/mainServiceAi.py
+++ b/modules/services/serviceAi/mainServiceAi.py
@@ -21,7 +21,8 @@ from modules.datamodels.datamodelAi import JsonAccumulationState
from modules.services.serviceBilling.mainServiceBilling import (
getService as getBillingService,
InsufficientBalanceException,
- ProviderNotAllowedException
+ ProviderNotAllowedException,
+ BillingContextError
)
logger = logging.getLogger(__name__)
@@ -88,41 +89,104 @@ class AiService:
async def callAi(self, request: AiCallRequest, progressCallback=None):
"""Router: handles content parts via extractionService, text context via interface.
- Replaces direct calls to self.aiObjects.call() to route content parts processing
- through serviceExtraction layer.
-
- Includes billing checks:
- - Balance check before AI call
- - Provider permission check (via RBAC)
-
- Also stores workflow stats after each successful AI call.
+ FAIL-SAFE BILLING at the source:
+ 1. Pre-flight check: validates billing context is complete (RAISES if not)
+ 2. Balance & provider check before AI call
+ 3. billingCallback on aiObjects: records one billing transaction per model call
+ with exact provider + model name (set before AI call, invoked by _callWithModel)
"""
- # Billing check before AI call (validates RBAC permissions)
+ # FAIL-SAFE: Pre-flight billing validation (like 0 CHF credit card check)
+ self._preflightBillingCheck()
+
+ # Balance & provider permission checks
await self._checkBillingBeforeAiCall()
# Calculate effective allowedProviders: RBAC ∩ Workflow
- # RBAC is master - only RBAC-permitted providers can ever be used
effectiveProviders = self._calculateEffectiveProviders()
if effectiveProviders and request.options:
request.options = request.options.model_copy(update={'allowedProviders': effectiveProviders})
logger.debug(f"Effective allowedProviders for AI request: {effectiveProviders}")
- if hasattr(request, 'contentParts') and request.contentParts:
- response = await self.extractionService.processContentPartsWithAi(
- request, self.aiObjects, progressCallback
- )
- else:
- response = await self.aiObjects.callWithTextContext(request)
+ # Set billing callback on aiObjects BEFORE the AI call
+ # This callback is invoked by _callWithModel() after EVERY individual model call
+ # For parallel content parts (e.g., 200 MB doc), each model call creates its own transaction
+ self.aiObjects.billingCallback = self._createBillingCallback()
- # Store workflow stats after each AI call
+ try:
+ if hasattr(request, 'contentParts') and request.contentParts:
+ response = await self.extractionService.processContentPartsWithAi(
+ request, self.aiObjects, progressCallback
+ )
+ else:
+ response = await self.aiObjects.callWithTextContext(request)
+ finally:
+ # Clear callback after call completes
+ self.aiObjects.billingCallback = None
+
+ # Store workflow stats for analytics
self._storeAiCallStats(response, request)
return response
+ def _preflightBillingCheck(self) -> None:
+ """
+ Pre-flight billing validation - like a 0 CHF credit card authorization check.
+
+ Validates that ALL required billing context is present and that a billing
+ transaction CAN be recorded. This dry-run check catches missing context
+ BEFORE an expensive AI call starts.
+
+ FAIL-SAFE: This method RAISES if billing context is incomplete.
+ An AI call without billing context MUST NOT proceed.
+
+ Raises:
+ BillingContextError: If billing context is incomplete or invalid
+ """
+ if not self.services:
+ raise BillingContextError("No service context available - cannot bill AI call")
+
+ user = getattr(self.services, 'user', None)
+ if not user:
+ raise BillingContextError("No user context - cannot bill AI call")
+
+ mandateId = getattr(self.services, 'mandateId', None)
+ if not mandateId:
+ raise BillingContextError(
+ f"No mandateId in service context for user {user.id} - cannot bill AI call. "
+ "Every AI call MUST have a mandate context for billing."
+ )
+
+ # Validate billing service can be created
+ featureInstanceId = getattr(self.services, 'featureInstanceId', None)
+ featureCode = getattr(self.services, 'featureCode', None)
+
+ try:
+ billingService = getBillingService(user, mandateId, featureInstanceId, featureCode)
+ except Exception as e:
+ raise BillingContextError(
+ f"Cannot create billing service for user {user.id}, mandate {mandateId}: {e}"
+ )
+
+ # Dry-run: verify billing service can check balance (DB accessible)
+ try:
+ billingService.checkBalance(0.0)
+ except Exception as e:
+ raise BillingContextError(
+ f"Billing system not accessible for mandate {mandateId}: {e}"
+ )
+
+ logger.debug(
+ f"Pre-flight billing check PASSED: user={user.id}, mandate={mandateId}, "
+ f"feature={featureCode or 'none'}, instance={featureInstanceId or 'none'}"
+ )
+
async def _checkBillingBeforeAiCall(self) -> None:
"""
Check billing status before making an AI call.
+ FAIL-SAFE: Context validation is done in _preflightBillingCheck() which is
+ called first. This method handles balance and provider permission checks.
+
Verifies:
1. User has sufficient balance (for prepay models)
2. Provider is allowed for the user (via RBAC)
@@ -130,34 +194,19 @@ class AiService:
Raises:
InsufficientBalanceException: If balance is insufficient
ProviderNotAllowedException: If provider is not allowed
+ BillingContextError: If billing check fails unexpectedly
"""
+ # Context is already validated by _preflightBillingCheck()
+ user = self.services.user
+ mandateId = self.services.mandateId
+ featureInstanceId = getattr(self.services, 'featureInstanceId', None)
+ featureCode = getattr(self.services, 'featureCode', None)
+
try:
- # Get context from services
- if not self.services:
- logger.debug("No service center - skipping billing check")
- return
-
- user = getattr(self.services, 'user', None)
- mandateId = getattr(self.services, 'mandateId', None)
-
- if not user or not mandateId:
- logger.debug("No user or mandate context - skipping billing check")
- return
-
- # Get feature context
- featureInstanceId = getattr(self.services, 'featureInstanceId', None)
- featureCode = getattr(self.services, 'featureCode', None)
-
# Get billing service
- billingService = getBillingService(
- user,
- mandateId,
- featureInstanceId,
- featureCode
- )
+ billingService = getBillingService(user, mandateId, featureInstanceId, featureCode)
# Check balance (estimate typical AI call cost)
- # We use a small estimate here; actual cost is recorded after the call
estimatedCost = 0.01 # ~1 cent CHF minimum
balanceCheck = billingService.checkBalance(estimatedCost)
@@ -170,7 +219,7 @@ class AiService:
raise InsufficientBalanceException(
currentBalance=balanceCheck.currentBalance or 0.0,
requiredAmount=estimatedCost,
- message=f"Ungenügendes Guthaben. Aktuell: CHF {balanceCheck.currentBalance:.2f}"
+ message=f"Ungenugendes Guthaben. Aktuell: CHF {balanceCheck.currentBalance:.2f}"
)
logger.debug(f"Billing check passed: Balance {balanceCheck.currentBalance:.2f} CHF")
@@ -181,13 +230,12 @@ class AiService:
logger.warning(f"No AI providers allowed for user {user.id} in mandate {mandateId}")
raise ProviderNotAllowedException(
provider="any",
- message="Keine AI-Provider für Ihre Rolle freigegeben. Kontaktieren Sie Ihren Administrator."
+ message="Keine AI-Provider fuer Ihre Rolle freigegeben. Kontaktieren Sie Ihren Administrator."
)
# Check automation-level allowedProviders restriction
automationAllowedProviders = getattr(self.services, 'allowedProviders', None)
if automationAllowedProviders:
- # Filter by both RBAC and automation-level restrictions
effectiveProviders = [p for p in automationAllowedProviders if p in rbacAllowedProviders]
if not effectiveProviders:
logger.warning(f"No providers available after automation restriction. "
@@ -195,7 +243,7 @@ class AiService:
f"RBAC allows: {rbacAllowedProviders}")
raise ProviderNotAllowedException(
provider="any",
- message="Die konfigurierten AI-Provider dieser Automation sind für Ihre Rolle nicht freigegeben."
+ message="Die konfigurierten AI-Provider dieser Automation sind fuer Ihre Rolle nicht freigegeben."
)
logger.debug(f"Automation provider check passed: {effectiveProviders}")
@@ -207,19 +255,78 @@ class AiService:
logger.warning(f"Preferred provider {provider} not allowed for user {user.id}")
raise ProviderNotAllowedException(
provider=provider,
- message=f"Der gewählte Provider '{provider}' ist für Ihre Rolle nicht freigegeben."
+ message=f"Der gewaehlte Provider '{provider}' ist fuer Ihre Rolle nicht freigegeben."
)
logger.debug(f"All preferred providers are allowed: {preferredProviders}")
logger.debug(f"Provider check passed: {len(rbacAllowedProviders)} providers allowed")
except InsufficientBalanceException:
- raise # Re-raise billing exceptions
+ raise
except ProviderNotAllowedException:
- raise # Re-raise provider exceptions
+ raise
+ except BillingContextError:
+ raise
except Exception as e:
- # Log but don't block on billing check errors
- logger.warning(f"Billing check failed with error (non-blocking): {e}")
+ # FAIL-SAFE: Don't silently swallow errors - log at ERROR level
+ logger.error(f"BILLING FAIL-SAFE: Billing check failed with unexpected error: {e}")
+ raise BillingContextError(f"Billing check failed: {e}")
+
+ def _createBillingCallback(self):
+ """
+ Create a billing callback for interfaceAiObjects._callWithModel().
+
+ Returns a function that records one billing transaction per individual model call.
+ Each transaction contains the exact provider name AND model name.
+
+ For a 200 MB document processed with N parallel AI calls (possibly different models),
+ this creates N separate billing transactions - one per model call.
+ """
+ user = self.services.user
+ mandateId = self.services.mandateId
+ featureInstanceId = getattr(self.services, 'featureInstanceId', None)
+ featureCode = getattr(self.services, 'featureCode', None)
+
+ # Get workflow ID if available
+ workflowId = None
+ workflow = getattr(self.services, 'workflow', None)
+ if workflow and hasattr(workflow, 'id'):
+ workflowId = workflow.id
+
+ billingService = getBillingService(user, mandateId, featureInstanceId, featureCode)
+
+ def _billingCallback(response) -> None:
+ """Record billing for a single AI model call."""
+ if not response or getattr(response, 'errorCount', 0) > 0:
+ return
+
+ priceCHF = getattr(response, 'priceCHF', 0.0)
+ if not priceCHF or priceCHF <= 0:
+ return
+
+ provider = getattr(response, 'provider', None) or 'unknown'
+ modelName = getattr(response, 'modelName', None) or 'unknown'
+
+ try:
+ billingService.recordUsage(
+ priceCHF=priceCHF,
+ workflowId=workflowId,
+ aicoreProvider=provider,
+ aicoreModel=modelName,
+ description=f"AI: {modelName}"
+ )
+ logger.debug(
+ f"Billed model call: {priceCHF:.4f} CHF, "
+ f"provider={provider}, model={modelName}, mandate={mandateId}"
+ )
+ except Exception as e:
+ logger.error(
+ f"BILLING: Failed to record transaction! "
+ f"Cost={priceCHF:.4f} CHF, user={user.id}, mandate={mandateId}, "
+ f"provider={provider}, model={modelName}, error={e}"
+ )
+
+ return _billingCallback
def _calculateEffectiveProviders(self) -> Optional[List[str]]:
"""
diff --git a/modules/services/serviceBilling/mainServiceBilling.py b/modules/services/serviceBilling/mainServiceBilling.py
index 472e0b58..8bf1c2c4 100644
--- a/modules/services/serviceBilling/mainServiceBilling.py
+++ b/modules/services/serviceBilling/mainServiceBilling.py
@@ -192,6 +192,7 @@ class BillingService:
priceCHF: float,
workflowId: str = None,
aicoreProvider: str = None,
+ aicoreModel: str = None,
description: str = None
) -> Optional[Dict[str, Any]]:
"""
@@ -206,6 +207,7 @@ class BillingService:
priceCHF: Base price from AI model (before markup)
workflowId: Optional workflow ID
aicoreProvider: AICore provider name (e.g., 'anthropic', 'openai')
+ aicoreModel: AICore model name (e.g., 'claude-4-sonnet', 'gpt-4o')
description: Optional description
Returns:
@@ -222,7 +224,7 @@ class BillingService:
# Build description
if not description:
- description = f"AI Usage: {aicoreProvider or 'unknown'}"
+ description = f"AI Usage: {aicoreModel or aicoreProvider or 'unknown'}"
return self._billingInterface.recordUsage(
mandateId=self.mandateId,
@@ -232,6 +234,7 @@ class BillingService:
featureInstanceId=self.featureInstanceId,
featureCode=self.featureCode,
aicoreProvider=aicoreProvider,
+ aicoreModel=aicoreModel,
description=description
)
@@ -399,3 +402,16 @@ class ProviderNotAllowedException(Exception):
self.provider = provider
self.message = message or f"Provider '{provider}' is not allowed for your role"
super().__init__(self.message)
+
+
+class BillingContextError(Exception):
+ """Raised when billing context is incomplete (missing mandateId, user, etc.).
+
+ This is a FAIL-SAFE error: AI calls MUST NOT proceed without valid billing context.
+ Acts like a 0 CHF credit card pre-authorization check - validates that billing
+ CAN be recorded before any expensive AI operation starts.
+ """
+
+ def __init__(self, message: str = None):
+ self.message = message or "Billing context incomplete - AI call blocked"
+ super().__init__(self.message)
diff --git a/modules/services/serviceChat/mainServiceChat.py b/modules/services/serviceChat/mainServiceChat.py
index 055e34cd..b1e4a7ae 100644
--- a/modules/services/serviceChat/mainServiceChat.py
+++ b/modules/services/serviceChat/mainServiceChat.py
@@ -675,9 +675,11 @@ class ChatService:
def storeWorkflowStat(self, workflow: Any, aiResponse: Any, process: str) -> ChatStat:
"""Persist workflow-level ChatStat from AiCallResponse and append to workflow stats list.
- Also records the usage cost to the billing system if configured."""
+
+ Billing is handled at the AI call source (interfaceAiObjects._callWithModel)
+ via billingCallback - not here. This method only handles workflow stats.
+ """
try:
- # Create ChatStat from AiCallResponse data
statData = {
"workflowId": workflow.id,
"process": process,
@@ -689,76 +691,16 @@ class ChatService:
"errorCount": aiResponse.errorCount
}
- # Create the stat record in the database
stat = self.interfaceDbChat.createStat(statData)
- # Append to workflow stats list in memory
if not hasattr(workflow, 'stats') or workflow.stats is None:
workflow.stats = []
workflow.stats.append(stat)
- # Record billing transaction if mandateId is available
- self._recordBillingUsage(workflow, aiResponse, process)
-
return stat
except Exception as e:
logger.error(f"Failed to store workflow stat: {e}")
raise
-
- def _recordBillingUsage(self, workflow: Any, aiResponse: Any, process: str) -> None:
- """Record AI usage to the billing system.
-
- This method:
- 1. Gets the mandate context from services
- 2. Records the usage cost with 50% markup via BillingService
-
- Args:
- workflow: ChatWorkflow object
- aiResponse: AI call response with cost information
- process: Process identifier for the AI call
- """
- try:
- # Check if we have mandate context
- mandateId = getattr(self.services, 'mandateId', None)
- if not mandateId:
- logger.debug("No mandate context, skipping billing recording")
- return
-
- # Check if there's a cost to record
- priceCHF = getattr(aiResponse, 'priceCHF', 0.0)
- if not priceCHF or priceCHF <= 0:
- return
-
- # Get provider from AiCallResponse (set from model.connectorType)
- aicoreProvider = getattr(aiResponse, 'provider', None) or 'unknown'
-
- # Get feature context if available
- featureInstanceId = getattr(self.services, 'featureInstanceId', None)
- featureCode = getattr(self.services, 'featureCode', None)
-
- # Import and use BillingService
- from modules.services.serviceBilling.mainServiceBilling import getService as getBillingService
-
- billingService = getBillingService(
- self.user,
- mandateId,
- featureInstanceId=featureInstanceId,
- featureCode=featureCode
- )
-
- # Record the usage (includes 50% markup)
- billingService.recordUsage(
- priceCHF=priceCHF,
- workflowId=workflow.id,
- aicoreProvider=aicoreProvider,
- description=f"AI Usage: {process}"
- )
-
- logger.debug(f"Recorded billing usage: {priceCHF} CHF for {process} (provider: {aicoreProvider})")
-
- except Exception as e:
- # Don't fail the main operation if billing recording fails
- logger.warning(f"Failed to record billing usage (non-critical): {e}")
def updateMessage(self, messageId: str, messageData: Dict[str, Any]):
"""Update message by delegating to the chat interface"""
diff --git a/modules/workflows/automation/mainWorkflow.py b/modules/workflows/automation/mainWorkflow.py
index 172fc977..06d36dae 100644
--- a/modules/workflows/automation/mainWorkflow.py
+++ b/modules/workflows/automation/mainWorkflow.py
@@ -24,7 +24,7 @@ from .subAutomationUtils import parseScheduleToCron, planToPrompt, replacePlaceh
logger = logging.getLogger(__name__)
-async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode: WorkflowModeEnum, workflowId: Optional[str] = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> ChatWorkflow:
+async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode: WorkflowModeEnum, workflowId: Optional[str] = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None, featureCode: Optional[str] = None) -> ChatWorkflow:
"""
Starts a new chat or continues an existing one, then launches processing asynchronously.
@@ -32,12 +32,10 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode
currentUser: Current user
userInput: User input request
workflowId: Optional workflow ID to continue existing workflow
- workflowMode: "Dynamic" for iterative dynamic-style processing, "Automation" for automated workflow execution
- mandateId: Mandate ID from request context (required for proper data isolation)
- featureInstanceId: Feature instance ID for context
-
- Example usage for Dynamic mode:
- workflow = await chatStart(currentUser, userInput, workflowMode=WorkflowModeEnum.WORKFLOW_DYNAMIC, mandateId=mandateId)
+ workflowMode: Workflow mode (Dynamic, Automation, etc.)
+ mandateId: Mandate ID (required for billing)
+ featureInstanceId: Feature instance ID (required for billing)
+ featureCode: Feature code (e.g., 'chatplayground', 'automation')
"""
try:
services = getServices(currentUser, mandateId=mandateId)
@@ -47,10 +45,11 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode
services.allowedProviders = userInput.allowedProviders
logger.info(f"AI provider filter active: {userInput.allowedProviders}")
- # Store feature instance ID in services context
+ # Store feature context in services (for billing and RBAC)
if featureInstanceId:
services.featureInstanceId = featureInstanceId
- services.featureCode = 'chatplayground'
+ if featureCode:
+ services.featureCode = featureCode
workflowManager = WorkflowManager(services)
workflow = await workflowManager.workflowStart(userInput, workflowMode, workflowId)
@@ -105,8 +104,11 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow:
services.allowedProviders = automation.allowedProviders
logger.debug(f"Automation {automationId} restricted to providers: {automation.allowedProviders}")
- # Store feature context for billing
+ # Context comes EXCLUSIVELY from the automation definition
+ services.mandateId = str(automation.mandateId)
+ services.featureInstanceId = str(automation.featureInstanceId)
services.featureCode = 'automation'
+ featureInstanceId = services.featureInstanceId
# 2. Replace placeholders in template to generate plan
template = automation.template or ""
@@ -159,11 +161,16 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow:
executionLog["messages"].append("Starting workflow execution")
# 5. Start workflow using chatStart
+ # Pass mandateId, featureInstanceId, and featureCode from original services context
+ # so billing is recorded correctly with full feature context
workflow = await chatStart(
currentUser=creatorUser,
userInput=userInput,
workflowMode=WorkflowModeEnum.WORKFLOW_AUTOMATION,
- workflowId=None
+ workflowId=None,
+ mandateId=services.mandateId,
+ featureInstanceId=featureInstanceId,
+ featureCode='automation'
)
executionLog["workflowId"] = workflow.id