billing rbac
This commit is contained in:
parent
8d28f6d77b
commit
887867acd0
12 changed files with 506 additions and 197 deletions
|
|
@ -241,8 +241,10 @@ class AiPrivateLlm(BaseConnectorAi):
|
|||
priority=PriorityEnum.COST,
|
||||
processingMode=ProcessingModeEnum.BASIC,
|
||||
operationTypes=createOperationTypeRatings(
|
||||
(OperationTypeEnum.DATA_EXTRACT, 9),
|
||||
(OperationTypeEnum.DATA_ANALYSE, 9),
|
||||
(OperationTypeEnum.PLAN, 7),
|
||||
(OperationTypeEnum.DATA_ANALYSE, 8),
|
||||
(OperationTypeEnum.DATA_GENERATE, 8),
|
||||
(OperationTypeEnum.DATA_EXTRACT, 8),
|
||||
),
|
||||
version="qwen2.5:7b",
|
||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: PRICE_TEXT_PER_CALL
|
||||
|
|
@ -268,7 +270,6 @@ class AiPrivateLlm(BaseConnectorAi):
|
|||
processingMode=ProcessingModeEnum.ADVANCED,
|
||||
operationTypes=createOperationTypeRatings(
|
||||
(OperationTypeEnum.IMAGE_ANALYSE, 9),
|
||||
(OperationTypeEnum.DATA_EXTRACT, 8),
|
||||
),
|
||||
version="qwen2.5vl:7b",
|
||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: PRICE_VISION_PER_CALL
|
||||
|
|
@ -294,8 +295,6 @@ class AiPrivateLlm(BaseConnectorAi):
|
|||
processingMode=ProcessingModeEnum.DETAILED,
|
||||
operationTypes=createOperationTypeRatings(
|
||||
(OperationTypeEnum.IMAGE_ANALYSE, 9),
|
||||
(OperationTypeEnum.DATA_EXTRACT, 9),
|
||||
(OperationTypeEnum.DATA_ANALYSE, 8),
|
||||
),
|
||||
version="granite3.2-vision",
|
||||
calculatepriceCHF=lambda processingTime, bytesSent, bytesReceived: PRICE_VISION_PER_CALL
|
||||
|
|
|
|||
|
|
@ -121,6 +121,8 @@ class BillingTransaction(BaseModel):
|
|||
featureInstanceId: Optional[str] = Field(None, description="Feature instance ID")
|
||||
featureCode: Optional[str] = Field(None, description="Feature code (e.g., chatplayground, automation)")
|
||||
aicoreProvider: Optional[str] = Field(None, description="AICore provider (anthropic, openai, etc.)")
|
||||
aicoreModel: Optional[str] = Field(None, description="AICore model name (e.g., claude-4-sonnet, gpt-4o)")
|
||||
createdByUserId: Optional[str] = Field(None, description="User who created/caused this transaction")
|
||||
|
||||
|
||||
registerModelLabels(
|
||||
|
|
@ -138,6 +140,8 @@ registerModelLabels(
|
|||
"featureInstanceId": {"en": "Feature Instance ID", "de": "Feature-Instanz-ID"},
|
||||
"featureCode": {"en": "Feature Code", "de": "Feature-Code"},
|
||||
"aicoreProvider": {"en": "AI Provider", "de": "AI-Anbieter"},
|
||||
"aicoreModel": {"en": "AI Model", "de": "AI-Modell"},
|
||||
"createdByUserId": {"en": "Created By User", "de": "Erstellt von Benutzer"},
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -367,6 +367,10 @@ async def execute_automation_route(
|
|||
try:
|
||||
from modules.services import getInterface as getServices
|
||||
services = getServices(context.user, context.mandateId)
|
||||
# Propagate feature context for billing
|
||||
if context.featureInstanceId:
|
||||
services.featureInstanceId = str(context.featureInstanceId)
|
||||
services.featureCode = 'automation'
|
||||
workflow = await executeAutomation(automationId, services)
|
||||
return workflow
|
||||
except HTTPException:
|
||||
|
|
|
|||
|
|
@ -93,6 +93,12 @@ async def chatProcess(
|
|||
try:
|
||||
# Get services with mandate context
|
||||
services = getServices(currentUser, mandateId)
|
||||
|
||||
# Set feature context for billing
|
||||
if featureInstanceId:
|
||||
services.featureInstanceId = featureInstanceId
|
||||
services.featureCode = 'chatbot'
|
||||
|
||||
interfaceDbChat = services.interfaceDbChat
|
||||
|
||||
# Get event manager and create queue if needed
|
||||
|
|
@ -698,7 +704,7 @@ async def _convert_file_ids_to_document_references(
|
|||
# Search database if not found in messages
|
||||
if not document_id:
|
||||
try:
|
||||
from modules.shared.databaseUtils import getRecordsetWithRBAC
|
||||
from modules.interfaces.interfaceRbac import getRecordsetWithRBAC
|
||||
documents = getRecordsetWithRBAC(
|
||||
services.interfaceDbChat.db,
|
||||
ChatDocument,
|
||||
|
|
|
|||
|
|
@ -102,7 +102,8 @@ async def start_workflow(
|
|||
workflowMode,
|
||||
workflowId,
|
||||
mandateId=mandateId,
|
||||
featureInstanceId=instanceId
|
||||
featureInstanceId=instanceId,
|
||||
featureCode='chatplayground'
|
||||
)
|
||||
|
||||
return workflow
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ import logging
|
|||
import asyncio
|
||||
import uuid
|
||||
import base64
|
||||
from typing import Dict, Any, List, Union, Tuple, Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any, List, Union, Tuple, Optional, Callable
|
||||
from dataclasses import dataclass, field
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -29,7 +29,13 @@ from modules.datamodels.datamodelExtraction import ContentPart, MergeStrategy
|
|||
|
||||
@dataclass(slots=True)
|
||||
class AiObjects:
|
||||
"""Centralized AI interface: dynamically discovers and uses AI models. Includes web functionality."""
|
||||
"""Centralized AI interface: dynamically discovers and uses AI models.
|
||||
|
||||
billingCallback: Set by serviceAi before AI calls. Called after EVERY individual
|
||||
model call with the AiCallResponse. This ensures per-model-call billing with
|
||||
exact provider + model name. The callback handles billing recording.
|
||||
"""
|
||||
billingCallback: Optional[Callable] = field(default=None, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Auto-discover and register all available connectors
|
||||
|
|
@ -226,7 +232,7 @@ class AiObjects:
|
|||
# Calculate price using model's own price calculation method
|
||||
priceCHF = model.calculatepriceCHF(processingTime, inputBytes, outputBytes)
|
||||
|
||||
return AiCallResponse(
|
||||
response = AiCallResponse(
|
||||
content=content,
|
||||
modelName=model.name,
|
||||
provider=model.connectorType,
|
||||
|
|
@ -236,6 +242,17 @@ class AiObjects:
|
|||
bytesReceived=outputBytes,
|
||||
errorCount=0
|
||||
)
|
||||
|
||||
# BILLING: Record billing for THIS specific model call
|
||||
# billingCallback is set by serviceAi and records one billing transaction
|
||||
# per model call with exact provider + model name
|
||||
if self.billingCallback:
|
||||
try:
|
||||
self.billingCallback(response)
|
||||
except Exception as e:
|
||||
logger.error(f"BILLING: Failed to record billing for model {model.name}: {e}")
|
||||
|
||||
return response
|
||||
|
||||
|
||||
# Utility methods
|
||||
|
|
|
|||
|
|
@ -710,6 +710,7 @@ class BillingObjects:
|
|||
featureInstanceId: str = None,
|
||||
featureCode: str = None,
|
||||
aicoreProvider: str = None,
|
||||
aicoreModel: str = None,
|
||||
description: str = "AI Usage"
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
|
|
@ -722,7 +723,8 @@ class BillingObjects:
|
|||
workflowId: Optional workflow ID
|
||||
featureInstanceId: Optional feature instance ID
|
||||
featureCode: Optional feature code
|
||||
aicoreProvider: Optional AICore provider name
|
||||
aicoreProvider: AICore provider name (e.g., 'anthropic', 'openai')
|
||||
aicoreModel: AICore model name (e.g., 'claude-4-sonnet', 'gpt-4o')
|
||||
description: Transaction description
|
||||
|
||||
Returns:
|
||||
|
|
@ -758,7 +760,9 @@ class BillingObjects:
|
|||
workflowId=workflowId,
|
||||
featureInstanceId=featureInstanceId,
|
||||
featureCode=featureCode,
|
||||
aicoreProvider=aicoreProvider
|
||||
aicoreProvider=aicoreProvider,
|
||||
aicoreModel=aicoreModel,
|
||||
createdByUserId=userId
|
||||
)
|
||||
|
||||
return self.createTransaction(transaction)
|
||||
|
|
@ -828,9 +832,13 @@ class BillingObjects:
|
|||
|
||||
# Calculate by provider
|
||||
costByProvider = {}
|
||||
costByModel = {}
|
||||
for t in debits:
|
||||
provider = t.get("aicoreProvider", "unknown")
|
||||
costByProvider[provider] = costByProvider.get(provider, 0) + t.get("amount", 0)
|
||||
|
||||
model = t.get("aicoreModel", "unknown")
|
||||
costByModel[model] = costByModel.get(model, 0) + t.get("amount", 0)
|
||||
|
||||
# Calculate by feature
|
||||
costByFeature = {}
|
||||
|
|
@ -842,6 +850,7 @@ class BillingObjects:
|
|||
"totalCostCHF": totalCost,
|
||||
"transactionCount": len(debits),
|
||||
"costByProvider": costByProvider,
|
||||
"costByModel": costByModel,
|
||||
"costByFeature": costByFeature
|
||||
}
|
||||
|
||||
|
|
@ -1129,8 +1138,8 @@ class BillingObjects:
|
|||
user = appInterface.getUser(userId)
|
||||
if user:
|
||||
displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isinstance(user, dict) else None)
|
||||
email = getattr(user, 'email', None) or (user.get("email") if isinstance(user, dict) else None)
|
||||
userMap[userId] = displayName or email or userId
|
||||
username = getattr(user, 'username', None) or (user.get("username") if isinstance(user, dict) else None)
|
||||
userMap[userId] = displayName or username or userId
|
||||
|
||||
# Get mandate info efficiently
|
||||
mandateMap = {}
|
||||
|
|
@ -1212,8 +1221,8 @@ class BillingObjects:
|
|||
user = appInterface.getUser(userId)
|
||||
if user:
|
||||
displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isinstance(user, dict) else None)
|
||||
email = getattr(user, 'email', None) or (user.get("email") if isinstance(user, dict) else None)
|
||||
userMap[userId] = displayName or email or userId
|
||||
username = getattr(user, 'username', None) or (user.get("username") if isinstance(user, dict) else None)
|
||||
userMap[userId] = displayName or username or userId
|
||||
|
||||
# Get mandate info efficiently
|
||||
mandateMap = {}
|
||||
|
|
@ -1224,7 +1233,8 @@ class BillingObjects:
|
|||
mandateName = getattr(mandate, 'name', None) or (mandate.get("name", "") if isinstance(mandate, dict) else "")
|
||||
mandateMap[mandateId] = mandateName
|
||||
|
||||
# Get transactions for all accounts
|
||||
# Get transactions for all accounts and collect createdByUserIds
|
||||
rawTransactions = []
|
||||
for account in allAccounts:
|
||||
accountId = account.get("id")
|
||||
if not accountId:
|
||||
|
|
@ -1233,14 +1243,38 @@ class BillingObjects:
|
|||
transactions = self.getTransactions(accountId, limit=limit)
|
||||
accountInfo = accountMap.get(accountId, {})
|
||||
mandateId = accountInfo.get("mandateId")
|
||||
userId = accountInfo.get("userId")
|
||||
accountUserId = accountInfo.get("userId")
|
||||
|
||||
for t in transactions:
|
||||
t["mandateId"] = mandateId
|
||||
t["mandateName"] = mandateMap.get(mandateId, "")
|
||||
t["userId"] = userId
|
||||
t["userName"] = userMap.get(userId, userId)
|
||||
allTransactions.append(t)
|
||||
t["_accountUserId"] = accountUserId
|
||||
t["_accountMandateId"] = mandateId
|
||||
rawTransactions.append(t)
|
||||
|
||||
# Resolve createdByUserIds that are not yet in userMap
|
||||
extraUserIds = set()
|
||||
for t in rawTransactions:
|
||||
cbUserId = t.get("createdByUserId")
|
||||
if cbUserId and cbUserId not in userMap:
|
||||
extraUserIds.add(cbUserId)
|
||||
|
||||
for uid in extraUserIds:
|
||||
user = appInterface.getUser(uid)
|
||||
if user:
|
||||
displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isinstance(user, dict) else None)
|
||||
username = getattr(user, 'username', None) or (user.get("username") if isinstance(user, dict) else None)
|
||||
userMap[uid] = displayName or username or uid
|
||||
|
||||
# Enrich transactions
|
||||
for t in rawTransactions:
|
||||
mandateId = t.pop("_accountMandateId", None)
|
||||
accountUserId = t.pop("_accountUserId", None)
|
||||
t["mandateId"] = mandateId
|
||||
t["mandateName"] = mandateMap.get(mandateId, "")
|
||||
# Prefer createdByUserId (per-transaction) over account-derived userId
|
||||
txUserId = t.get("createdByUserId") or accountUserId
|
||||
t["userId"] = txUserId
|
||||
t["userName"] = userMap.get(txUserId, txUserId) if txUserId else None
|
||||
allTransactions.append(t)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user transactions for mandates: {e}")
|
||||
|
|
|
|||
|
|
@ -42,6 +42,143 @@ from modules.datamodels.datamodelBilling import (
|
|||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Billing RBAC Data Scope
|
||||
# =============================================================================
|
||||
#
|
||||
# RBAC rules for billing data visibility:
|
||||
#
|
||||
# SysAdmin → ALL transactions and statistics across all mandates
|
||||
# Mandate-Admin → ALL user data within their administrated mandates
|
||||
# Feature-Instance-Admin→ Data for their administrated feature instances
|
||||
# Regular User → ONLY their own data within their mandates
|
||||
#
|
||||
|
||||
class BillingDataScope:
|
||||
"""
|
||||
Determines what billing data a user can see based on RBAC roles.
|
||||
|
||||
Evaluated once per request and used to filter transactions/statistics.
|
||||
"""
|
||||
__slots__ = ('isGlobalAdmin', 'adminMandateIds', 'adminFeatureInstanceIds',
|
||||
'memberMandateIds', 'userId')
|
||||
|
||||
def __init__(self, userId: str):
|
||||
self.isGlobalAdmin: bool = False
|
||||
self.adminMandateIds: list = []
|
||||
self.adminFeatureInstanceIds: list = []
|
||||
self.memberMandateIds: list = []
|
||||
self.userId: str = userId
|
||||
|
||||
|
||||
def _getBillingDataScope(user) -> BillingDataScope:
|
||||
"""
|
||||
Determine what billing data a user can see based on RBAC.
|
||||
|
||||
Uses rootInterface (privileged) to check roles across all mandates
|
||||
and feature instances without RBAC restrictions on the lookup itself.
|
||||
|
||||
Returns:
|
||||
BillingDataScope with the user's visibility boundaries.
|
||||
"""
|
||||
scope = BillingDataScope(userId=user.id)
|
||||
|
||||
if user.isSysAdmin:
|
||||
scope.isGlobalAdmin = True
|
||||
return scope
|
||||
|
||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||
rootInterface = getRootInterface()
|
||||
|
||||
# --- Mandate roles ---
|
||||
userMandates = rootInterface.getUserMandates(user.id)
|
||||
for um in userMandates:
|
||||
mandateId = getattr(um, 'mandateId', None)
|
||||
umId = getattr(um, 'id', None)
|
||||
if not mandateId or not umId:
|
||||
continue
|
||||
|
||||
roleIds = rootInterface.getRoleIdsForUserMandate(umId)
|
||||
isAdmin = False
|
||||
for roleId in roleIds:
|
||||
role = rootInterface.getRole(roleId)
|
||||
if role and role.roleLabel == "admin" and not role.featureInstanceId:
|
||||
isAdmin = True
|
||||
break
|
||||
|
||||
if isAdmin:
|
||||
scope.adminMandateIds.append(mandateId)
|
||||
else:
|
||||
scope.memberMandateIds.append(mandateId)
|
||||
|
||||
# --- Feature instance roles ---
|
||||
featureAccesses = rootInterface.getFeatureAccessesForUser(user.id)
|
||||
for fa in featureAccesses:
|
||||
fiId = getattr(fa, 'featureInstanceId', None)
|
||||
faId = getattr(fa, 'id', None)
|
||||
if not fiId or not faId:
|
||||
continue
|
||||
|
||||
roleIds = rootInterface.getRoleIdsForFeatureAccess(faId)
|
||||
for roleId in roleIds:
|
||||
role = rootInterface.getRole(roleId)
|
||||
if role and role.roleLabel == "admin":
|
||||
scope.adminFeatureInstanceIds.append(fiId)
|
||||
break
|
||||
|
||||
logger.debug(
|
||||
f"BillingDataScope for user {user.id}: "
|
||||
f"globalAdmin={scope.isGlobalAdmin}, "
|
||||
f"adminMandates={scope.adminMandateIds}, "
|
||||
f"adminInstances={scope.adminFeatureInstanceIds}, "
|
||||
f"memberMandates={scope.memberMandateIds}"
|
||||
)
|
||||
return scope
|
||||
|
||||
|
||||
def _filterTransactionsByScope(transactions: list, scope: BillingDataScope) -> list:
|
||||
"""
|
||||
Filter a list of transaction dicts based on the user's BillingDataScope.
|
||||
|
||||
Rules:
|
||||
- SysAdmin: no filter
|
||||
- Mandate-Admin: all transactions in their admin mandates
|
||||
- Feature-Instance-Admin: transactions for their admin feature instances
|
||||
- Regular user: only transactions where createdByUserId/userId matches
|
||||
"""
|
||||
if scope.isGlobalAdmin:
|
||||
return transactions
|
||||
|
||||
adminMandateSet = set(scope.adminMandateIds)
|
||||
adminFiSet = set(scope.adminFeatureInstanceIds)
|
||||
memberMandateSet = set(scope.memberMandateIds)
|
||||
|
||||
result = []
|
||||
for t in transactions:
|
||||
mandateId = t.get("mandateId")
|
||||
fiId = t.get("featureInstanceId")
|
||||
txUserId = t.get("createdByUserId") or t.get("userId")
|
||||
|
||||
# Mandate admin → sees all transactions in their mandate
|
||||
if mandateId and mandateId in adminMandateSet:
|
||||
result.append(t)
|
||||
continue
|
||||
|
||||
# Feature instance admin → sees all transactions for their instances
|
||||
if fiId and fiId in adminFiSet:
|
||||
result.append(t)
|
||||
continue
|
||||
|
||||
# Regular member → only own transactions
|
||||
if mandateId and mandateId in memberMandateSet:
|
||||
if txUserId and txUserId == scope.userId:
|
||||
result.append(t)
|
||||
continue
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Request/Response Models
|
||||
# =============================================================================
|
||||
|
|
@ -74,7 +211,10 @@ class TransactionResponse(BaseModel):
|
|||
referenceType: Optional[ReferenceTypeEnum]
|
||||
workflowId: Optional[str]
|
||||
featureCode: Optional[str]
|
||||
featureInstanceId: Optional[str] = None
|
||||
aicoreProvider: Optional[str]
|
||||
aicoreModel: Optional[str] = None
|
||||
createdByUserId: Optional[str] = None
|
||||
createdAt: Optional[datetime]
|
||||
mandateId: Optional[str] = None
|
||||
mandateName: Optional[str] = None
|
||||
|
|
@ -98,6 +238,7 @@ class UsageReportResponse(BaseModel):
|
|||
totalCost: float
|
||||
transactionCount: int
|
||||
costByProvider: Dict[str, float]
|
||||
costByModel: Dict[str, float] = {}
|
||||
costByFeature: Dict[str, float]
|
||||
|
||||
|
||||
|
|
@ -140,7 +281,10 @@ class UserTransactionResponse(BaseModel):
|
|||
referenceType: Optional[ReferenceTypeEnum]
|
||||
workflowId: Optional[str]
|
||||
featureCode: Optional[str]
|
||||
featureInstanceId: Optional[str] = None
|
||||
aicoreProvider: Optional[str]
|
||||
aicoreModel: Optional[str] = None
|
||||
createdByUserId: Optional[str] = None
|
||||
createdAt: Optional[datetime]
|
||||
mandateId: Optional[str] = None
|
||||
mandateName: Optional[str] = None
|
||||
|
|
@ -261,7 +405,10 @@ def getTransactions(
|
|||
referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None,
|
||||
workflowId=t.get("workflowId"),
|
||||
featureCode=t.get("featureCode"),
|
||||
featureInstanceId=t.get("featureInstanceId"),
|
||||
aicoreProvider=t.get("aicoreProvider"),
|
||||
aicoreModel=t.get("aicoreModel"),
|
||||
createdByUserId=t.get("createdByUserId"),
|
||||
createdAt=t.get("_createdAt"),
|
||||
mandateId=t.get("mandateId"),
|
||||
mandateName=t.get("mandateName")
|
||||
|
|
@ -349,6 +496,7 @@ def getStatistics(
|
|||
totalCost=stats.get("totalCostCHF", 0.0),
|
||||
transactionCount=stats.get("transactionCount", 0),
|
||||
costByProvider=stats.get("costByProvider", {}),
|
||||
costByModel=stats.get("costByModel", {}),
|
||||
costByFeature=stats.get("costByFeature", {})
|
||||
)
|
||||
|
||||
|
|
@ -564,6 +712,7 @@ def getAccounts(
|
|||
class MandateUserSummary(BaseModel):
|
||||
"""Summary of a user for billing admin purposes."""
|
||||
id: str
|
||||
username: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
firstName: Optional[str] = None
|
||||
lastName: Optional[str] = None
|
||||
|
|
@ -600,18 +749,21 @@ def getUsersForMandate(
|
|||
|
||||
# Handle both Pydantic models and dicts
|
||||
if isinstance(user, dict):
|
||||
username = user.get("username", "")
|
||||
firstName = user.get("firstName", "")
|
||||
lastName = user.get("lastName", "")
|
||||
email = user.get("email", "")
|
||||
else:
|
||||
username = getattr(user, "username", "") or ""
|
||||
firstName = getattr(user, "firstName", "") or ""
|
||||
lastName = getattr(user, "lastName", "") or ""
|
||||
email = getattr(user, "email", "") or ""
|
||||
|
||||
displayName = f"{firstName} {lastName}".strip() or email or userId
|
||||
displayName = f"{firstName} {lastName}".strip() or username or userId
|
||||
|
||||
result.append(MandateUserSummary(
|
||||
id=userId,
|
||||
username=username,
|
||||
email=email,
|
||||
firstName=firstName,
|
||||
lastName=lastName,
|
||||
|
|
@ -652,7 +804,10 @@ def getTransactionsAdmin(
|
|||
referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None,
|
||||
workflowId=t.get("workflowId"),
|
||||
featureCode=t.get("featureCode"),
|
||||
featureInstanceId=t.get("featureInstanceId"),
|
||||
aicoreProvider=t.get("aicoreProvider"),
|
||||
aicoreModel=t.get("aicoreModel"),
|
||||
createdByUserId=t.get("createdByUserId"),
|
||||
createdAt=t.get("_createdAt")
|
||||
))
|
||||
|
||||
|
|
@ -715,7 +870,10 @@ def getMandateViewTransactions(
|
|||
referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None,
|
||||
workflowId=t.get("workflowId"),
|
||||
featureCode=t.get("featureCode"),
|
||||
featureInstanceId=t.get("featureInstanceId"),
|
||||
aicoreProvider=t.get("aicoreProvider"),
|
||||
aicoreModel=t.get("aicoreModel"),
|
||||
createdByUserId=t.get("createdByUserId"),
|
||||
createdAt=t.get("_createdAt"),
|
||||
mandateId=t.get("mandateId"),
|
||||
mandateName=t.get("mandateName")
|
||||
|
|
@ -740,39 +898,35 @@ def getUserViewBalances(
|
|||
):
|
||||
"""
|
||||
Get user-level balances.
|
||||
|
||||
RBAC filtering:
|
||||
- SysAdmin: sees all user balances across all mandates
|
||||
- MandateAdmin: sees user balances for mandates they manage
|
||||
- Mandate-Admin: sees user balances for mandates they administrate
|
||||
- Regular user: sees only their own balances
|
||||
"""
|
||||
try:
|
||||
billingInterface = getBillingInterface(ctx.user, ctx.mandateId)
|
||||
|
||||
# Determine which mandates the user has access to
|
||||
if ctx.user.isSysAdmin:
|
||||
# SysAdmin sees all
|
||||
# Evaluate RBAC scope
|
||||
scope = _getBillingDataScope(ctx.user)
|
||||
|
||||
# Determine mandate IDs for data loading
|
||||
if scope.isGlobalAdmin:
|
||||
mandateIds = None
|
||||
else:
|
||||
# Get mandates where user is admin or has billing access
|
||||
from modules.interfaces.interfaceDbApp import getInterface as getAppInterface
|
||||
appInterface = getAppInterface(ctx.user)
|
||||
userMandates = appInterface.getUserMandates(ctx.user.id)
|
||||
|
||||
# Filter to only mandates where user has admin role
|
||||
# For simplicity, we'll check if user is admin in any mandate
|
||||
mandateIds = []
|
||||
for um in userMandates:
|
||||
mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None)
|
||||
if mandateId:
|
||||
mandateIds.append(mandateId)
|
||||
|
||||
mandateIds = scope.adminMandateIds + scope.memberMandateIds
|
||||
if not mandateIds:
|
||||
return []
|
||||
|
||||
allBalances = billingInterface.getUserBalancesForMandates(mandateIds)
|
||||
|
||||
# Non-admin users only see their own balances
|
||||
if not ctx.user.isSysAdmin:
|
||||
allBalances = [b for b in allBalances if b.get("userId") == ctx.user.id]
|
||||
# RBAC filter: mandate admins see all in their mandates, regular users only own
|
||||
if not scope.isGlobalAdmin:
|
||||
adminMandateSet = set(scope.adminMandateIds)
|
||||
allBalances = [
|
||||
b for b in allBalances
|
||||
if b.get("mandateId") in adminMandateSet or b.get("userId") == scope.userId
|
||||
]
|
||||
|
||||
return [UserBalanceResponse(**b) for b in allBalances]
|
||||
|
||||
|
|
@ -786,6 +940,7 @@ class ViewStatisticsResponse(BaseModel):
|
|||
totalCost: float = 0.0
|
||||
transactionCount: int = 0
|
||||
costByProvider: Dict[str, float] = {}
|
||||
costByModel: Dict[str, float] = {}
|
||||
costByFeature: Dict[str, float] = {}
|
||||
costByMandate: Dict[str, float] = {}
|
||||
timeSeries: List[Dict[str, Any]] = []
|
||||
|
|
@ -802,6 +957,13 @@ def getUserViewStatistics(
|
|||
) -> ViewStatisticsResponse:
|
||||
"""
|
||||
Get aggregated usage statistics across all user's mandates.
|
||||
|
||||
RBAC filtering:
|
||||
- SysAdmin: statistics across all mandates
|
||||
- Mandate-Admin: statistics for mandates they administrate
|
||||
- Feature-Instance-Admin: statistics for their feature instances
|
||||
- Regular user: only their own usage statistics
|
||||
|
||||
- period='month': returns monthly time series for the given year
|
||||
- period='day': returns daily time series for the given month/year
|
||||
"""
|
||||
|
|
@ -816,25 +978,25 @@ def getUserViewStatistics(
|
|||
|
||||
billingInterface = getBillingInterface(ctx.user, ctx.mandateId)
|
||||
|
||||
# Get all mandates the user has access to
|
||||
if ctx.user.isSysAdmin:
|
||||
# Evaluate RBAC scope
|
||||
scope = _getBillingDataScope(ctx.user)
|
||||
|
||||
# Determine mandate IDs for data loading
|
||||
if scope.isGlobalAdmin:
|
||||
mandateIds = None
|
||||
else:
|
||||
from modules.interfaces.interfaceDbApp import getInterface as getAppInterface
|
||||
appInterface = getAppInterface(ctx.user)
|
||||
userMandates = appInterface.getUserMandates(ctx.user.id)
|
||||
mandateIds = []
|
||||
for um in userMandates:
|
||||
mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None)
|
||||
if mandateId:
|
||||
mandateIds.append(mandateId)
|
||||
mandateIds = scope.adminMandateIds + scope.memberMandateIds
|
||||
if not mandateIds:
|
||||
logger.warning("No mandate IDs found for user")
|
||||
return ViewStatisticsResponse()
|
||||
|
||||
# Get all transactions
|
||||
allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=10000)
|
||||
logger.info(f"View statistics: {len(allTransactions)} total transactions fetched for period={period}, year={year}, month={month}")
|
||||
|
||||
# Apply RBAC filter
|
||||
allTransactions = _filterTransactionsByScope(allTransactions, scope)
|
||||
|
||||
logger.info(f"View statistics: {len(allTransactions)} RBAC-filtered transactions for period={period}, year={year}, month={month}")
|
||||
|
||||
# Calculate date range
|
||||
if period == "day":
|
||||
|
|
@ -905,6 +1067,7 @@ def getUserViewStatistics(
|
|||
totalCost = sum(t.get("amount", 0) for t in debits)
|
||||
|
||||
costByProvider: Dict[str, float] = {}
|
||||
costByModel: Dict[str, float] = {}
|
||||
costByFeature: Dict[str, float] = {}
|
||||
costByMandate: Dict[str, float] = {}
|
||||
|
||||
|
|
@ -912,6 +1075,9 @@ def getUserViewStatistics(
|
|||
provider = t.get("aicoreProvider") or "unknown"
|
||||
costByProvider[provider] = costByProvider.get(provider, 0) + t.get("amount", 0)
|
||||
|
||||
model = t.get("aicoreModel") or "unknown"
|
||||
costByModel[model] = costByModel.get(model, 0) + t.get("amount", 0)
|
||||
|
||||
mandate = t.get("mandateName") or t.get("mandateId") or "unknown"
|
||||
featureCode = t.get("featureCode") or "unknown"
|
||||
featureKey = f"{mandate} / {featureCode}"
|
||||
|
|
@ -950,6 +1116,7 @@ def getUserViewStatistics(
|
|||
totalCost=round(totalCost, 4),
|
||||
transactionCount=len(debits),
|
||||
costByProvider=costByProvider,
|
||||
costByModel=costByModel,
|
||||
costByFeature=costByFeature,
|
||||
costByMandate=costByMandate,
|
||||
timeSeries=timeSeries
|
||||
|
|
@ -969,8 +1136,11 @@ def getUserViewTransactions(
|
|||
) -> PaginatedResponse[UserTransactionResponse]:
|
||||
"""
|
||||
Get user-level transactions with pagination support.
|
||||
|
||||
RBAC filtering:
|
||||
- SysAdmin: sees all user transactions across all mandates
|
||||
- MandateAdmin: sees user transactions for mandates they manage
|
||||
- Mandate-Admin: sees all user transactions for mandates they administrate
|
||||
- Feature-Instance-Admin: sees transactions for their feature instances
|
||||
- Regular user: sees only their own transactions
|
||||
|
||||
Query Parameters:
|
||||
|
|
@ -987,28 +1157,24 @@ def getUserViewTransactions(
|
|||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
paginationParams = PaginationParams(**paginationDict)
|
||||
|
||||
# Determine which mandates the user has access to
|
||||
if ctx.user.isSysAdmin:
|
||||
# SysAdmin sees all
|
||||
mandateIds = None
|
||||
# Evaluate RBAC scope
|
||||
scope = _getBillingDataScope(ctx.user)
|
||||
|
||||
# Determine mandate IDs for data loading
|
||||
if scope.isGlobalAdmin:
|
||||
mandateIds = None # Load all
|
||||
else:
|
||||
# Get mandates where user has access
|
||||
from modules.interfaces.interfaceDbApp import getInterface as getAppInterface
|
||||
appInterface = getAppInterface(ctx.user)
|
||||
userMandates = appInterface.getUserMandates(ctx.user.id)
|
||||
|
||||
mandateIds = []
|
||||
for um in userMandates:
|
||||
mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None)
|
||||
if mandateId:
|
||||
mandateIds.append(mandateId)
|
||||
|
||||
# Load data for all mandates the user belongs to (admin + member)
|
||||
mandateIds = scope.adminMandateIds + scope.memberMandateIds
|
||||
if not mandateIds:
|
||||
return PaginatedResponse(items=[], pagination=None)
|
||||
|
||||
allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=10000)
|
||||
|
||||
logger.debug(f"Found {len(allTransactions)} transactions for mandates {mandateIds}")
|
||||
# Apply RBAC filter
|
||||
allTransactions = _filterTransactionsByScope(allTransactions, scope)
|
||||
|
||||
logger.debug(f"RBAC-filtered {len(allTransactions)} transactions for user {ctx.user.id}")
|
||||
|
||||
# Convert to response objects as dicts for filtering/sorting
|
||||
transactionDicts = []
|
||||
|
|
@ -1022,7 +1188,10 @@ def getUserViewTransactions(
|
|||
"referenceType": t.get("referenceType"),
|
||||
"workflowId": t.get("workflowId"),
|
||||
"featureCode": t.get("featureCode"),
|
||||
"featureInstanceId": t.get("featureInstanceId"),
|
||||
"aicoreProvider": t.get("aicoreProvider"),
|
||||
"aicoreModel": t.get("aicoreModel"),
|
||||
"createdByUserId": t.get("createdByUserId"),
|
||||
"createdAt": t.get("_createdAt"),
|
||||
"mandateId": t.get("mandateId"),
|
||||
"mandateName": t.get("mandateName"),
|
||||
|
|
@ -1044,7 +1213,10 @@ def getUserViewTransactions(
|
|||
referenceType=ReferenceTypeEnum(d["referenceType"]) if d.get("referenceType") else None,
|
||||
workflowId=d.get("workflowId"),
|
||||
featureCode=d.get("featureCode"),
|
||||
featureInstanceId=d.get("featureInstanceId"),
|
||||
aicoreProvider=d.get("aicoreProvider"),
|
||||
aicoreModel=d.get("aicoreModel"),
|
||||
createdByUserId=d.get("createdByUserId"),
|
||||
createdAt=d.get("createdAt"),
|
||||
mandateId=d.get("mandateId"),
|
||||
mandateName=d.get("mandateName"),
|
||||
|
|
|
|||
|
|
@ -21,7 +21,8 @@ from modules.datamodels.datamodelAi import JsonAccumulationState
|
|||
from modules.services.serviceBilling.mainServiceBilling import (
|
||||
getService as getBillingService,
|
||||
InsufficientBalanceException,
|
||||
ProviderNotAllowedException
|
||||
ProviderNotAllowedException,
|
||||
BillingContextError
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -88,41 +89,104 @@ class AiService:
|
|||
async def callAi(self, request: AiCallRequest, progressCallback=None):
|
||||
"""Router: handles content parts via extractionService, text context via interface.
|
||||
|
||||
Replaces direct calls to self.aiObjects.call() to route content parts processing
|
||||
through serviceExtraction layer.
|
||||
|
||||
Includes billing checks:
|
||||
- Balance check before AI call
|
||||
- Provider permission check (via RBAC)
|
||||
|
||||
Also stores workflow stats after each successful AI call.
|
||||
FAIL-SAFE BILLING at the source:
|
||||
1. Pre-flight check: validates billing context is complete (RAISES if not)
|
||||
2. Balance & provider check before AI call
|
||||
3. billingCallback on aiObjects: records one billing transaction per model call
|
||||
with exact provider + model name (set before AI call, invoked by _callWithModel)
|
||||
"""
|
||||
# Billing check before AI call (validates RBAC permissions)
|
||||
# FAIL-SAFE: Pre-flight billing validation (like 0 CHF credit card check)
|
||||
self._preflightBillingCheck()
|
||||
|
||||
# Balance & provider permission checks
|
||||
await self._checkBillingBeforeAiCall()
|
||||
|
||||
# Calculate effective allowedProviders: RBAC ∩ Workflow
|
||||
# RBAC is master - only RBAC-permitted providers can ever be used
|
||||
effectiveProviders = self._calculateEffectiveProviders()
|
||||
if effectiveProviders and request.options:
|
||||
request.options = request.options.model_copy(update={'allowedProviders': effectiveProviders})
|
||||
logger.debug(f"Effective allowedProviders for AI request: {effectiveProviders}")
|
||||
|
||||
if hasattr(request, 'contentParts') and request.contentParts:
|
||||
response = await self.extractionService.processContentPartsWithAi(
|
||||
request, self.aiObjects, progressCallback
|
||||
)
|
||||
else:
|
||||
response = await self.aiObjects.callWithTextContext(request)
|
||||
# Set billing callback on aiObjects BEFORE the AI call
|
||||
# This callback is invoked by _callWithModel() after EVERY individual model call
|
||||
# For parallel content parts (e.g., 200 MB doc), each model call creates its own transaction
|
||||
self.aiObjects.billingCallback = self._createBillingCallback()
|
||||
|
||||
# Store workflow stats after each AI call
|
||||
try:
|
||||
if hasattr(request, 'contentParts') and request.contentParts:
|
||||
response = await self.extractionService.processContentPartsWithAi(
|
||||
request, self.aiObjects, progressCallback
|
||||
)
|
||||
else:
|
||||
response = await self.aiObjects.callWithTextContext(request)
|
||||
finally:
|
||||
# Clear callback after call completes
|
||||
self.aiObjects.billingCallback = None
|
||||
|
||||
# Store workflow stats for analytics
|
||||
self._storeAiCallStats(response, request)
|
||||
|
||||
return response
|
||||
|
||||
def _preflightBillingCheck(self) -> None:
|
||||
"""
|
||||
Pre-flight billing validation - like a 0 CHF credit card authorization check.
|
||||
|
||||
Validates that ALL required billing context is present and that a billing
|
||||
transaction CAN be recorded. This dry-run check catches missing context
|
||||
BEFORE an expensive AI call starts.
|
||||
|
||||
FAIL-SAFE: This method RAISES if billing context is incomplete.
|
||||
An AI call without billing context MUST NOT proceed.
|
||||
|
||||
Raises:
|
||||
BillingContextError: If billing context is incomplete or invalid
|
||||
"""
|
||||
if not self.services:
|
||||
raise BillingContextError("No service context available - cannot bill AI call")
|
||||
|
||||
user = getattr(self.services, 'user', None)
|
||||
if not user:
|
||||
raise BillingContextError("No user context - cannot bill AI call")
|
||||
|
||||
mandateId = getattr(self.services, 'mandateId', None)
|
||||
if not mandateId:
|
||||
raise BillingContextError(
|
||||
f"No mandateId in service context for user {user.id} - cannot bill AI call. "
|
||||
"Every AI call MUST have a mandate context for billing."
|
||||
)
|
||||
|
||||
# Validate billing service can be created
|
||||
featureInstanceId = getattr(self.services, 'featureInstanceId', None)
|
||||
featureCode = getattr(self.services, 'featureCode', None)
|
||||
|
||||
try:
|
||||
billingService = getBillingService(user, mandateId, featureInstanceId, featureCode)
|
||||
except Exception as e:
|
||||
raise BillingContextError(
|
||||
f"Cannot create billing service for user {user.id}, mandate {mandateId}: {e}"
|
||||
)
|
||||
|
||||
# Dry-run: verify billing service can check balance (DB accessible)
|
||||
try:
|
||||
billingService.checkBalance(0.0)
|
||||
except Exception as e:
|
||||
raise BillingContextError(
|
||||
f"Billing system not accessible for mandate {mandateId}: {e}"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Pre-flight billing check PASSED: user={user.id}, mandate={mandateId}, "
|
||||
f"feature={featureCode or 'none'}, instance={featureInstanceId or 'none'}"
|
||||
)
|
||||
|
||||
async def _checkBillingBeforeAiCall(self) -> None:
|
||||
"""
|
||||
Check billing status before making an AI call.
|
||||
|
||||
FAIL-SAFE: Context validation is done in _preflightBillingCheck() which is
|
||||
called first. This method handles balance and provider permission checks.
|
||||
|
||||
Verifies:
|
||||
1. User has sufficient balance (for prepay models)
|
||||
2. Provider is allowed for the user (via RBAC)
|
||||
|
|
@ -130,34 +194,19 @@ class AiService:
|
|||
Raises:
|
||||
InsufficientBalanceException: If balance is insufficient
|
||||
ProviderNotAllowedException: If provider is not allowed
|
||||
BillingContextError: If billing check fails unexpectedly
|
||||
"""
|
||||
# Context is already validated by _preflightBillingCheck()
|
||||
user = self.services.user
|
||||
mandateId = self.services.mandateId
|
||||
featureInstanceId = getattr(self.services, 'featureInstanceId', None)
|
||||
featureCode = getattr(self.services, 'featureCode', None)
|
||||
|
||||
try:
|
||||
# Get context from services
|
||||
if not self.services:
|
||||
logger.debug("No service center - skipping billing check")
|
||||
return
|
||||
|
||||
user = getattr(self.services, 'user', None)
|
||||
mandateId = getattr(self.services, 'mandateId', None)
|
||||
|
||||
if not user or not mandateId:
|
||||
logger.debug("No user or mandate context - skipping billing check")
|
||||
return
|
||||
|
||||
# Get feature context
|
||||
featureInstanceId = getattr(self.services, 'featureInstanceId', None)
|
||||
featureCode = getattr(self.services, 'featureCode', None)
|
||||
|
||||
# Get billing service
|
||||
billingService = getBillingService(
|
||||
user,
|
||||
mandateId,
|
||||
featureInstanceId,
|
||||
featureCode
|
||||
)
|
||||
billingService = getBillingService(user, mandateId, featureInstanceId, featureCode)
|
||||
|
||||
# Check balance (estimate typical AI call cost)
|
||||
# We use a small estimate here; actual cost is recorded after the call
|
||||
estimatedCost = 0.01 # ~1 cent CHF minimum
|
||||
balanceCheck = billingService.checkBalance(estimatedCost)
|
||||
|
||||
|
|
@ -170,7 +219,7 @@ class AiService:
|
|||
raise InsufficientBalanceException(
|
||||
currentBalance=balanceCheck.currentBalance or 0.0,
|
||||
requiredAmount=estimatedCost,
|
||||
message=f"Ungenügendes Guthaben. Aktuell: CHF {balanceCheck.currentBalance:.2f}"
|
||||
message=f"Ungenugendes Guthaben. Aktuell: CHF {balanceCheck.currentBalance:.2f}"
|
||||
)
|
||||
|
||||
logger.debug(f"Billing check passed: Balance {balanceCheck.currentBalance:.2f} CHF")
|
||||
|
|
@ -181,13 +230,12 @@ class AiService:
|
|||
logger.warning(f"No AI providers allowed for user {user.id} in mandate {mandateId}")
|
||||
raise ProviderNotAllowedException(
|
||||
provider="any",
|
||||
message="Keine AI-Provider für Ihre Rolle freigegeben. Kontaktieren Sie Ihren Administrator."
|
||||
message="Keine AI-Provider fuer Ihre Rolle freigegeben. Kontaktieren Sie Ihren Administrator."
|
||||
)
|
||||
|
||||
# Check automation-level allowedProviders restriction
|
||||
automationAllowedProviders = getattr(self.services, 'allowedProviders', None)
|
||||
if automationAllowedProviders:
|
||||
# Filter by both RBAC and automation-level restrictions
|
||||
effectiveProviders = [p for p in automationAllowedProviders if p in rbacAllowedProviders]
|
||||
if not effectiveProviders:
|
||||
logger.warning(f"No providers available after automation restriction. "
|
||||
|
|
@ -195,7 +243,7 @@ class AiService:
|
|||
f"RBAC allows: {rbacAllowedProviders}")
|
||||
raise ProviderNotAllowedException(
|
||||
provider="any",
|
||||
message="Die konfigurierten AI-Provider dieser Automation sind für Ihre Rolle nicht freigegeben."
|
||||
message="Die konfigurierten AI-Provider dieser Automation sind fuer Ihre Rolle nicht freigegeben."
|
||||
)
|
||||
logger.debug(f"Automation provider check passed: {effectiveProviders}")
|
||||
|
||||
|
|
@ -207,19 +255,78 @@ class AiService:
|
|||
logger.warning(f"Preferred provider {provider} not allowed for user {user.id}")
|
||||
raise ProviderNotAllowedException(
|
||||
provider=provider,
|
||||
message=f"Der gewählte Provider '{provider}' ist für Ihre Rolle nicht freigegeben."
|
||||
message=f"Der gewaehlte Provider '{provider}' ist fuer Ihre Rolle nicht freigegeben."
|
||||
)
|
||||
logger.debug(f"All preferred providers are allowed: {preferredProviders}")
|
||||
|
||||
logger.debug(f"Provider check passed: {len(rbacAllowedProviders)} providers allowed")
|
||||
|
||||
except InsufficientBalanceException:
|
||||
raise # Re-raise billing exceptions
|
||||
raise
|
||||
except ProviderNotAllowedException:
|
||||
raise # Re-raise provider exceptions
|
||||
raise
|
||||
except BillingContextError:
|
||||
raise
|
||||
except Exception as e:
|
||||
# Log but don't block on billing check errors
|
||||
logger.warning(f"Billing check failed with error (non-blocking): {e}")
|
||||
# FAIL-SAFE: Don't silently swallow errors - log at ERROR level
|
||||
logger.error(f"BILLING FAIL-SAFE: Billing check failed with unexpected error: {e}")
|
||||
raise BillingContextError(f"Billing check failed: {e}")
|
||||
|
||||
def _createBillingCallback(self):
|
||||
"""
|
||||
Create a billing callback for interfaceAiObjects._callWithModel().
|
||||
|
||||
Returns a function that records one billing transaction per individual model call.
|
||||
Each transaction contains the exact provider name AND model name.
|
||||
|
||||
For a 200 MB document processed with N parallel AI calls (possibly different models),
|
||||
this creates N separate billing transactions - one per model call.
|
||||
"""
|
||||
user = self.services.user
|
||||
mandateId = self.services.mandateId
|
||||
featureInstanceId = getattr(self.services, 'featureInstanceId', None)
|
||||
featureCode = getattr(self.services, 'featureCode', None)
|
||||
|
||||
# Get workflow ID if available
|
||||
workflowId = None
|
||||
workflow = getattr(self.services, 'workflow', None)
|
||||
if workflow and hasattr(workflow, 'id'):
|
||||
workflowId = workflow.id
|
||||
|
||||
billingService = getBillingService(user, mandateId, featureInstanceId, featureCode)
|
||||
|
||||
def _billingCallback(response) -> None:
|
||||
"""Record billing for a single AI model call."""
|
||||
if not response or getattr(response, 'errorCount', 0) > 0:
|
||||
return
|
||||
|
||||
priceCHF = getattr(response, 'priceCHF', 0.0)
|
||||
if not priceCHF or priceCHF <= 0:
|
||||
return
|
||||
|
||||
provider = getattr(response, 'provider', None) or 'unknown'
|
||||
modelName = getattr(response, 'modelName', None) or 'unknown'
|
||||
|
||||
try:
|
||||
billingService.recordUsage(
|
||||
priceCHF=priceCHF,
|
||||
workflowId=workflowId,
|
||||
aicoreProvider=provider,
|
||||
aicoreModel=modelName,
|
||||
description=f"AI: {modelName}"
|
||||
)
|
||||
logger.debug(
|
||||
f"Billed model call: {priceCHF:.4f} CHF, "
|
||||
f"provider={provider}, model={modelName}, mandate={mandateId}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"BILLING: Failed to record transaction! "
|
||||
f"Cost={priceCHF:.4f} CHF, user={user.id}, mandate={mandateId}, "
|
||||
f"provider={provider}, model={modelName}, error={e}"
|
||||
)
|
||||
|
||||
return _billingCallback
|
||||
|
||||
def _calculateEffectiveProviders(self) -> Optional[List[str]]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -192,6 +192,7 @@ class BillingService:
|
|||
priceCHF: float,
|
||||
workflowId: str = None,
|
||||
aicoreProvider: str = None,
|
||||
aicoreModel: str = None,
|
||||
description: str = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
|
|
@ -206,6 +207,7 @@ class BillingService:
|
|||
priceCHF: Base price from AI model (before markup)
|
||||
workflowId: Optional workflow ID
|
||||
aicoreProvider: AICore provider name (e.g., 'anthropic', 'openai')
|
||||
aicoreModel: AICore model name (e.g., 'claude-4-sonnet', 'gpt-4o')
|
||||
description: Optional description
|
||||
|
||||
Returns:
|
||||
|
|
@ -222,7 +224,7 @@ class BillingService:
|
|||
|
||||
# Build description
|
||||
if not description:
|
||||
description = f"AI Usage: {aicoreProvider or 'unknown'}"
|
||||
description = f"AI Usage: {aicoreModel or aicoreProvider or 'unknown'}"
|
||||
|
||||
return self._billingInterface.recordUsage(
|
||||
mandateId=self.mandateId,
|
||||
|
|
@ -232,6 +234,7 @@ class BillingService:
|
|||
featureInstanceId=self.featureInstanceId,
|
||||
featureCode=self.featureCode,
|
||||
aicoreProvider=aicoreProvider,
|
||||
aicoreModel=aicoreModel,
|
||||
description=description
|
||||
)
|
||||
|
||||
|
|
@ -399,3 +402,16 @@ class ProviderNotAllowedException(Exception):
|
|||
self.provider = provider
|
||||
self.message = message or f"Provider '{provider}' is not allowed for your role"
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class BillingContextError(Exception):
|
||||
"""Raised when billing context is incomplete (missing mandateId, user, etc.).
|
||||
|
||||
This is a FAIL-SAFE error: AI calls MUST NOT proceed without valid billing context.
|
||||
Acts like a 0 CHF credit card pre-authorization check - validates that billing
|
||||
CAN be recorded before any expensive AI operation starts.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str = None):
|
||||
self.message = message or "Billing context incomplete - AI call blocked"
|
||||
super().__init__(self.message)
|
||||
|
|
|
|||
|
|
@ -675,9 +675,11 @@ class ChatService:
|
|||
|
||||
def storeWorkflowStat(self, workflow: Any, aiResponse: Any, process: str) -> ChatStat:
|
||||
"""Persist workflow-level ChatStat from AiCallResponse and append to workflow stats list.
|
||||
Also records the usage cost to the billing system if configured."""
|
||||
|
||||
Billing is handled at the AI call source (interfaceAiObjects._callWithModel)
|
||||
via billingCallback - not here. This method only handles workflow stats.
|
||||
"""
|
||||
try:
|
||||
# Create ChatStat from AiCallResponse data
|
||||
statData = {
|
||||
"workflowId": workflow.id,
|
||||
"process": process,
|
||||
|
|
@ -689,76 +691,16 @@ class ChatService:
|
|||
"errorCount": aiResponse.errorCount
|
||||
}
|
||||
|
||||
# Create the stat record in the database
|
||||
stat = self.interfaceDbChat.createStat(statData)
|
||||
|
||||
# Append to workflow stats list in memory
|
||||
if not hasattr(workflow, 'stats') or workflow.stats is None:
|
||||
workflow.stats = []
|
||||
workflow.stats.append(stat)
|
||||
|
||||
# Record billing transaction if mandateId is available
|
||||
self._recordBillingUsage(workflow, aiResponse, process)
|
||||
|
||||
return stat
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store workflow stat: {e}")
|
||||
raise
|
||||
|
||||
def _recordBillingUsage(self, workflow: Any, aiResponse: Any, process: str) -> None:
|
||||
"""Record AI usage to the billing system.
|
||||
|
||||
This method:
|
||||
1. Gets the mandate context from services
|
||||
2. Records the usage cost with 50% markup via BillingService
|
||||
|
||||
Args:
|
||||
workflow: ChatWorkflow object
|
||||
aiResponse: AI call response with cost information
|
||||
process: Process identifier for the AI call
|
||||
"""
|
||||
try:
|
||||
# Check if we have mandate context
|
||||
mandateId = getattr(self.services, 'mandateId', None)
|
||||
if not mandateId:
|
||||
logger.debug("No mandate context, skipping billing recording")
|
||||
return
|
||||
|
||||
# Check if there's a cost to record
|
||||
priceCHF = getattr(aiResponse, 'priceCHF', 0.0)
|
||||
if not priceCHF or priceCHF <= 0:
|
||||
return
|
||||
|
||||
# Get provider from AiCallResponse (set from model.connectorType)
|
||||
aicoreProvider = getattr(aiResponse, 'provider', None) or 'unknown'
|
||||
|
||||
# Get feature context if available
|
||||
featureInstanceId = getattr(self.services, 'featureInstanceId', None)
|
||||
featureCode = getattr(self.services, 'featureCode', None)
|
||||
|
||||
# Import and use BillingService
|
||||
from modules.services.serviceBilling.mainServiceBilling import getService as getBillingService
|
||||
|
||||
billingService = getBillingService(
|
||||
self.user,
|
||||
mandateId,
|
||||
featureInstanceId=featureInstanceId,
|
||||
featureCode=featureCode
|
||||
)
|
||||
|
||||
# Record the usage (includes 50% markup)
|
||||
billingService.recordUsage(
|
||||
priceCHF=priceCHF,
|
||||
workflowId=workflow.id,
|
||||
aicoreProvider=aicoreProvider,
|
||||
description=f"AI Usage: {process}"
|
||||
)
|
||||
|
||||
logger.debug(f"Recorded billing usage: {priceCHF} CHF for {process} (provider: {aicoreProvider})")
|
||||
|
||||
except Exception as e:
|
||||
# Don't fail the main operation if billing recording fails
|
||||
logger.warning(f"Failed to record billing usage (non-critical): {e}")
|
||||
|
||||
def updateMessage(self, messageId: str, messageData: Dict[str, Any]):
|
||||
"""Update message by delegating to the chat interface"""
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ from .subAutomationUtils import parseScheduleToCron, planToPrompt, replacePlaceh
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode: WorkflowModeEnum, workflowId: Optional[str] = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> ChatWorkflow:
|
||||
async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode: WorkflowModeEnum, workflowId: Optional[str] = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None, featureCode: Optional[str] = None) -> ChatWorkflow:
|
||||
"""
|
||||
Starts a new chat or continues an existing one, then launches processing asynchronously.
|
||||
|
||||
|
|
@ -32,12 +32,10 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode
|
|||
currentUser: Current user
|
||||
userInput: User input request
|
||||
workflowId: Optional workflow ID to continue existing workflow
|
||||
workflowMode: "Dynamic" for iterative dynamic-style processing, "Automation" for automated workflow execution
|
||||
mandateId: Mandate ID from request context (required for proper data isolation)
|
||||
featureInstanceId: Feature instance ID for context
|
||||
|
||||
Example usage for Dynamic mode:
|
||||
workflow = await chatStart(currentUser, userInput, workflowMode=WorkflowModeEnum.WORKFLOW_DYNAMIC, mandateId=mandateId)
|
||||
workflowMode: Workflow mode (Dynamic, Automation, etc.)
|
||||
mandateId: Mandate ID (required for billing)
|
||||
featureInstanceId: Feature instance ID (required for billing)
|
||||
featureCode: Feature code (e.g., 'chatplayground', 'automation')
|
||||
"""
|
||||
try:
|
||||
services = getServices(currentUser, mandateId=mandateId)
|
||||
|
|
@ -47,10 +45,11 @@ async def chatStart(currentUser: User, userInput: UserInputRequest, workflowMode
|
|||
services.allowedProviders = userInput.allowedProviders
|
||||
logger.info(f"AI provider filter active: {userInput.allowedProviders}")
|
||||
|
||||
# Store feature instance ID in services context
|
||||
# Store feature context in services (for billing and RBAC)
|
||||
if featureInstanceId:
|
||||
services.featureInstanceId = featureInstanceId
|
||||
services.featureCode = 'chatplayground'
|
||||
if featureCode:
|
||||
services.featureCode = featureCode
|
||||
|
||||
workflowManager = WorkflowManager(services)
|
||||
workflow = await workflowManager.workflowStart(userInput, workflowMode, workflowId)
|
||||
|
|
@ -105,8 +104,11 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow:
|
|||
services.allowedProviders = automation.allowedProviders
|
||||
logger.debug(f"Automation {automationId} restricted to providers: {automation.allowedProviders}")
|
||||
|
||||
# Store feature context for billing
|
||||
# Context comes EXCLUSIVELY from the automation definition
|
||||
services.mandateId = str(automation.mandateId)
|
||||
services.featureInstanceId = str(automation.featureInstanceId)
|
||||
services.featureCode = 'automation'
|
||||
featureInstanceId = services.featureInstanceId
|
||||
|
||||
# 2. Replace placeholders in template to generate plan
|
||||
template = automation.template or ""
|
||||
|
|
@ -159,11 +161,16 @@ async def executeAutomation(automationId: str, services) -> ChatWorkflow:
|
|||
executionLog["messages"].append("Starting workflow execution")
|
||||
|
||||
# 5. Start workflow using chatStart
|
||||
# Pass mandateId, featureInstanceId, and featureCode from original services context
|
||||
# so billing is recorded correctly with full feature context
|
||||
workflow = await chatStart(
|
||||
currentUser=creatorUser,
|
||||
userInput=userInput,
|
||||
workflowMode=WorkflowModeEnum.WORKFLOW_AUTOMATION,
|
||||
workflowId=None
|
||||
workflowId=None,
|
||||
mandateId=services.mandateId,
|
||||
featureInstanceId=featureInstanceId,
|
||||
featureCode='automation'
|
||||
)
|
||||
|
||||
executionLog["workflowId"] = workflow.id
|
||||
|
|
|
|||
Loading…
Reference in a new issue