# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Interface for Billing operations.
Manages billing accounts, transactions, and usage statistics.

All billing data is stored in the poweron_billing database.
"""
|
|
|
|
import logging
|
|
from typing import Dict, Any, List, Optional
|
|
from datetime import date, datetime, timedelta
|
|
import uuid
|
|
|
|
from modules.connectors.connectorDbPostgre import DatabaseConnector
|
|
from modules.shared.configuration import APP_CONFIG
|
|
from modules.shared.timeUtils import getUtcTimestamp
|
|
from modules.datamodels.datamodelUam import User, Mandate
|
|
from modules.datamodels.datamodelMembership import UserMandate
|
|
from modules.datamodels.datamodelBilling import (
|
|
BillingAccount,
|
|
BillingTransaction,
|
|
BillingSettings,
|
|
UsageStatistics,
|
|
BillingAddress,
|
|
BillingModelEnum,
|
|
AccountTypeEnum,
|
|
TransactionTypeEnum,
|
|
ReferenceTypeEnum,
|
|
PeriodTypeEnum,
|
|
BillingBalanceResponse,
|
|
BillingCheckResult,
|
|
)
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Singleton factory for BillingObjects instances.
# getInterface() stores one instance per "<userId>_<mandateId>" key here.
_billingInterfaces: Dict[str, "BillingObjects"] = {}

# Database name for billing
BILLING_DATABASE = "poweron_billing"
|
|
|
|
|
|
def getInterface(currentUser: User, mandateId: str = None) -> "BillingObjects":
    """
    Return the cached BillingObjects instance for a user/mandate pair.

    Instances are kept in the module-level cache under the key
    "<userId>_<mandateId>". The first request for a key builds a new
    instance; later requests refresh the user context of the cached one.

    Args:
        currentUser: Current user object
        mandateId: Mandate ID for context

    Returns:
        BillingObjects instance
    """
    cacheKey = f"{currentUser.id}_{mandateId}"

    try:
        instance = _billingInterfaces[cacheKey]
    except KeyError:
        # First request for this user/mandate pair: build and remember it
        instance = BillingObjects(currentUser, mandateId)
        _billingInterfaces[cacheKey] = instance
    else:
        # Known pair: keep the instance, refresh its context
        instance.setUserContext(currentUser, mandateId)

    return instance
|
|
|
|
|
|
def _getRootInterface() -> "BillingObjects":
    """
    Get interface with system access for bootstrap operations.

    Builds a fresh BillingObjects for the root user without a mandate.
    The instance is not stored in the getInterface() cache.
    """
    # Function-scope import, as in the original code
    # NOTE(review): presumably avoids an import cycle - confirm
    from modules.security.rootAccess import getRootUser

    return BillingObjects(getRootUser(), mandateId=None)
|
|
|
|
|
|
class BillingObjects:
|
|
"""
|
|
Interface for billing operations.
|
|
Manages accounts, transactions, settings, and statistics.
|
|
"""
|
|
|
|
def __init__(self, currentUser: Optional[User] = None, mandateId: str = None):
    """
    Set up the billing interface for a user/mandate context.

    Stores the context on the instance and opens the billing database
    connection.

    Args:
        currentUser: Current user object (may be None)
        mandateId: Mandate ID for context
    """
    self.currentUser = currentUser
    if currentUser:
        self.userId = currentUser.id
    else:
        # No user given: there is no user id to remember
        self.userId = None
    self.mandateId = mandateId

    # Connect to the billing database
    self._initializeDatabase()
|
|
|
|
def _initializeDatabase(self):
    """
    Create the connector for the billing database and store it as self.db.

    Host, port, user and password are read from APP_CONFIG. Host and port
    fall back to 'localhost' and '5432' when not configured.
    """
    host = APP_CONFIG.get('DB_HOST', 'localhost')
    port = int(APP_CONFIG.get('DB_PORT', '5432'))
    user = APP_CONFIG.get('DB_USER')
    password = APP_CONFIG.get('DB_PASSWORD_SECRET')

    self.db = DatabaseConnector(
        dbDatabase=BILLING_DATABASE,
        dbHost=host,
        dbPort=port,
        dbUser=user,
        dbPassword=password,
    )
|
|
|
|
def setUserContext(self, currentUser: User, mandateId: str = None):
    """
    Replace the user/mandate context of this interface.

    Args:
        currentUser: Current user object (a falsy value clears the user id)
        mandateId: Mandate ID for context
    """
    self.currentUser = currentUser
    if currentUser:
        self.userId = currentUser.id
    else:
        self.userId = None
    self.mandateId = mandateId
|
|
|
|
# =========================================================================
|
|
# BillingSettings Operations
|
|
# =========================================================================
|
|
|
|
def getSettings(self, mandateId: str) -> Optional[Dict[str, Any]]:
    """
    Look up the billing settings record of a mandate.

    Args:
        mandateId: Mandate ID

    Returns:
        BillingSettings dict, or None when no record exists or the lookup
        failed (failures are logged, not raised)
    """
    try:
        matches = self.db.getRecordset(
            BillingSettings,
            recordFilter={"mandateId": mandateId},
        )
        if not matches:
            return None
        # Only the first matching record is used
        return matches[0]
    except Exception as e:
        logger.error(f"Error getting billing settings: {e}")
        return None
|
|
|
|
def createSettings(self, settings: BillingSettings) -> Dict[str, Any]:
    """
    Persist a new billing settings record.

    Args:
        settings: BillingSettings object

    Returns:
        Created settings dict
    """
    payload = settings.model_dump(exclude_none=True)

    # Nested BillingAddress: replace the entry with a dump made without
    # exclude_none
    address = settings.billingAddress
    if address:
        payload["billingAddress"] = address.model_dump()

    return self.db.recordCreate(BillingSettings, payload)
|
|
|
|
def updateSettings(self, settingsId: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """
    Apply field updates to an existing billing settings record.

    Args:
        settingsId: Settings ID
        updates: Fields to update

    Returns:
        Whatever the database connector returns for the modification
        (updated settings dict or None)
    """
    modified = self.db.recordModify(BillingSettings, settingsId, updates)
    return modified
|
|
|
|
def getOrCreateSettings(self, mandateId: str, defaultModel: BillingModelEnum = BillingModelEnum.UNLIMITED) -> Dict[str, Any]:
    """
    Return the billing settings of a mandate, creating defaults if missing.

    Args:
        mandateId: Mandate ID
        defaultModel: Billing model used when a new record is created

    Returns:
        BillingSettings dict
    """
    current = self.getSettings(mandateId)
    if not current:
        # getSettings() found nothing (or failed): create the default record
        defaults = BillingSettings(
            mandateId=mandateId,
            billingModel=defaultModel,
            defaultUserCredit=10.0,
            warningThresholdPercent=10.0,
            blockOnZeroBalance=True,
            notifyOnWarning=True,
        )
        return self.createSettings(defaults)

    return current
|
|
|
|
# =========================================================================
|
|
# BillingAccount Operations
|
|
# =========================================================================
|
|
|
|
def getAccount(self, accountId: str) -> Optional[Dict[str, Any]]:
    """
    Get a billing account by ID.

    Args:
        accountId: Account ID

    Returns:
        BillingAccount dict, or None when not found or the lookup failed
        (failures are logged, not raised)
    """
    try:
        matches = self.db.getRecordset(
            BillingAccount,
            recordFilter={"id": accountId},
        )
        if not matches:
            return None
        return matches[0]
    except Exception as e:
        logger.error(f"Error getting billing account: {e}")
        return None
|
|
|
|
def getMandateAccount(self, mandateId: str) -> Optional[Dict[str, Any]]:
    """
    Look up the MANDATE-type billing account of a mandate.

    Args:
        mandateId: Mandate ID

    Returns:
        BillingAccount dict, or None when not found or the lookup failed
        (failures are logged, not raised)
    """
    criteria = {
        "mandateId": mandateId,
        "accountType": AccountTypeEnum.MANDATE.value,
    }
    try:
        matches = self.db.getRecordset(BillingAccount, recordFilter=criteria)
        if not matches:
            return None
        return matches[0]
    except Exception as e:
        logger.error(f"Error getting mandate account: {e}")
        return None
|
|
|
|
def getUserAccount(self, mandateId: str, userId: str) -> Optional[Dict[str, Any]]:
    """
    Look up the USER-type billing account of a user within a mandate.

    Args:
        mandateId: Mandate ID
        userId: User ID

    Returns:
        BillingAccount dict, or None when not found or the lookup failed
        (failures are logged, not raised)
    """
    criteria = {
        "mandateId": mandateId,
        "userId": userId,
        "accountType": AccountTypeEnum.USER.value,
    }
    try:
        matches = self.db.getRecordset(BillingAccount, recordFilter=criteria)
        if not matches:
            return None
        return matches[0]
    except Exception as e:
        logger.error(f"Error getting user account: {e}")
        return None
|
|
|
|
def getAccountsByMandate(self, mandateId: str) -> List[Dict[str, Any]]:
    """
    List every billing account (any account type) of a mandate.

    Args:
        mandateId: Mandate ID

    Returns:
        List of BillingAccount dicts; an empty list when the lookup failed
        (failures are logged, not raised)
    """
    try:
        accounts = self.db.getRecordset(
            BillingAccount,
            recordFilter={"mandateId": mandateId},
        )
    except Exception as e:
        logger.error(f"Error getting accounts for mandate: {e}")
        return []

    return accounts
|
|
|
|
def createAccount(self, account: BillingAccount) -> Dict[str, Any]:
    """
    Persist a new billing account.

    Args:
        account: BillingAccount object

    Returns:
        Created account dict
    """
    # Fields that are None are left out of the stored record
    payload = account.model_dump(exclude_none=True)
    created = self.db.recordCreate(BillingAccount, payload)
    return created
|
|
|
|
def updateAccountBalance(self, accountId: str, newBalance: float) -> Optional[Dict[str, Any]]:
    """
    Overwrite the stored balance of an account.

    The new value is written as-is; this method does not read the current
    balance first.

    Args:
        accountId: Account ID
        newBalance: New balance value

    Returns:
        Updated account dict or None
    """
    balanceUpdate = {"balance": newBalance}
    return self.db.recordModify(BillingAccount, accountId, balanceUpdate)
|
|
|
|
def getOrCreateMandateAccount(self, mandateId: str, initialBalance: float = 0.0) -> Dict[str, Any]:
    """
    Return the mandate-level billing account, creating it if missing.

    Args:
        mandateId: Mandate ID
        initialBalance: Balance stored on the account when it is created

    Returns:
        BillingAccount dict
    """
    current = self.getMandateAccount(mandateId)
    if not current:
        # No mandate account yet: create an enabled one
        newAccount = BillingAccount(
            mandateId=mandateId,
            accountType=AccountTypeEnum.MANDATE,
            balance=initialBalance,
            enabled=True,
        )
        return self.createAccount(newAccount)

    return current
|
|
|
|
def getOrCreateUserAccount(self, mandateId: str, userId: str, initialBalance: float = 0.0) -> Dict[str, Any]:
    """
    Get or create a user-level billing account.

    A positive initial balance is booked through a SYSTEM credit
    transaction, so the account history explains where the balance came
    from. The account itself is therefore created with a zero balance and
    the transaction raises it to ``initialBalance``.

    Args:
        mandateId: Mandate ID
        userId: User ID
        initialBalance: Initial balance if creating

    Returns:
        BillingAccount dict (reflecting the balance after the initial credit)
    """
    existing = self.getUserAccount(mandateId, userId)
    if existing:
        return existing

    grantInitialCredit = initialBalance > 0

    account = BillingAccount(
        mandateId=mandateId,
        userId=userId,
        accountType=AccountTypeEnum.USER,
        # createTransaction() adds the credit on top of the stored balance.
        # Starting at initialBalance AND booking the credit would grant the
        # amount twice, so start at 0 whenever a credit will be booked.
        balance=0.0 if grantInitialCredit else initialBalance,
        enabled=True
    )
    created = self.createAccount(account)

    if not grantInitialCredit:
        return created

    # Book the initial balance as a SYSTEM credit transaction
    self.createTransaction(BillingTransaction(
        accountId=created["id"],
        transactionType=TransactionTypeEnum.CREDIT,
        amount=initialBalance,
        description="Initial credit for new user",
        referenceType=ReferenceTypeEnum.SYSTEM
    ))

    # Re-read the account so callers see the balance after the credit.
    # If the re-read fails, fall back to the created record with the
    # balance the credit transaction has just produced.
    refreshed = self.getAccount(created["id"])
    if refreshed:
        return refreshed
    return {**created, "balance": initialBalance}
|
|
|
|
def ensureAllMandateSettingsExist(self) -> int:
    """
    Create default billing settings for every enabled mandate lacking them.

    Existing settings and enabled mandates are each read with one query.
    New settings use the PREPAY_USER billing model. Any error is logged and
    ends the run with a return value of 0.

    Returns:
        Number of settings created
    """
    try:
        # Step 1: mandate ids that already have settings (billing DB)
        existingMandateIds = {
            row.get("mandateId")
            for row in self.db.getRecordset(BillingSettings)
            if row.get("mandateId")
        }

        # Step 2: all enabled mandates from the APP database (separate connection)
        appDb = DatabaseConnector(
            dbDatabase=APP_CONFIG.get('DB_DATABASE', 'poweron_app'),
            dbHost=APP_CONFIG.get('DB_HOST', 'localhost'),
            dbPort=int(APP_CONFIG.get('DB_PORT', '5432')),
            dbUser=APP_CONFIG.get('DB_USER'),
            dbPassword=APP_CONFIG.get('DB_PASSWORD_SECRET'),
        )
        enabledMandates = appDb.getRecordset(Mandate, recordFilter={"enabled": True})

        # Step 3: fill the gaps
        settingsCreated = 0
        for mandate in enabledMandates:
            mandateId = mandate.get("id")
            if mandateId and mandateId not in existingMandateIds:
                self.createSettings(BillingSettings(
                    mandateId=mandateId,
                    billingModel=BillingModelEnum.PREPAY_USER,
                    defaultUserCredit=10.0,
                    warningThresholdPercent=10.0,
                    blockOnZeroBalance=True,
                    notifyOnWarning=True,
                ))
                # Remember it so a repeated mandate id is not created twice
                existingMandateIds.add(mandateId)
                settingsCreated += 1

        if settingsCreated > 0:
            logger.info(f"Created {settingsCreated} missing billing settings for mandates")

        return settingsCreated

    except Exception as e:
        logger.error(f"Error ensuring mandate settings exist: {e}")
        return 0
|
|
|
|
def ensureAllUserAccountsExist(self) -> int:
    """
    Ensure every enabled user/mandate membership of a PREPAY_USER mandate
    has a user billing account.

    Settings, existing accounts and memberships are each read with one
    query; the comparison happens in memory. A missing account is created
    with a zero balance and the mandate's default user credit is then
    booked as a SYSTEM credit transaction, so the credit is granted exactly
    once. Any error is logged and ends the run with a return value of 0.

    Returns:
        Number of accounts created
    """
    try:
        accountsCreated = 0

        # Step 1: default credit per PREPAY_USER mandate (only those need user accounts)
        prepayUserMandates = {}
        for s in self.db.getRecordset(BillingSettings):
            if s.get("billingModel") == BillingModelEnum.PREPAY_USER.value:
                credit = s.get("defaultUserCredit")
                # A stored NULL falls back to the same default as a missing key
                prepayUserMandates[s.get("mandateId")] = 10.0 if credit is None else credit

        if not prepayUserMandates:
            logger.debug("No PREPAY_USER mandates found, skipping account check")
            return 0

        # Step 2: existing USER accounts as (mandateId, userId) pairs (billing DB)
        allAccounts = self.db.getRecordset(
            BillingAccount,
            recordFilter={"accountType": AccountTypeEnum.USER.value}
        )
        existingAccountKeys = {
            (acc.get("mandateId"), acc.get("userId")) for acc in allAccounts
        }

        # Step 3: all enabled user-mandate combinations from the APP database
        # (separate connection)
        appDb = DatabaseConnector(
            dbDatabase=APP_CONFIG.get('DB_DATABASE', 'poweron_app'),
            dbHost=APP_CONFIG.get('DB_HOST', 'localhost'),
            dbPort=int(APP_CONFIG.get('DB_PORT', '5432')),
            dbUser=APP_CONFIG.get('DB_USER'),
            dbPassword=APP_CONFIG.get('DB_PASSWORD_SECRET')
        )
        allUserMandates = appDb.getRecordset(
            UserMandate,
            recordFilter={"enabled": True}
        )

        # Step 4: create the missing accounts
        for um in allUserMandates:
            mandateId = um.get("mandateId")
            userId = um.get("userId")

            if not mandateId or not userId:
                continue

            # Only process mandates with PREPAY_USER billing
            if mandateId not in prepayUserMandates:
                continue

            # Check if account already exists (in memory, no DB call)
            key = (mandateId, userId)
            if key in existingAccountKeys:
                continue

            defaultCredit = prepayUserMandates[mandateId]
            grantInitialCredit = defaultCredit > 0

            account = BillingAccount(
                mandateId=mandateId,
                userId=userId,
                accountType=AccountTypeEnum.USER,
                # createTransaction() adds the credit on top of the stored
                # balance; starting at defaultCredit as well would grant the
                # credit twice.
                balance=0.0 if grantInitialCredit else defaultCredit,
                enabled=True
            )
            created = self.createAccount(account)

            # Book the initial credit
            if grantInitialCredit:
                self.createTransaction(BillingTransaction(
                    accountId=created["id"],
                    transactionType=TransactionTypeEnum.CREDIT,
                    amount=defaultCredit,
                    description="Initial credit for new user",
                    referenceType=ReferenceTypeEnum.SYSTEM
                ))

            existingAccountKeys.add(key)  # Track newly created
            accountsCreated += 1

        if accountsCreated > 0:
            logger.info(f"Created {accountsCreated} missing billing accounts")

        return accountsCreated

    except Exception as e:
        logger.error(f"Error ensuring user accounts exist: {e}")
        return 0
|
|
|
|
# =========================================================================
|
|
# BillingTransaction Operations
|
|
# =========================================================================
|
|
|
|
def createTransaction(self, transaction: BillingTransaction) -> Dict[str, Any]:
    """
    Store a billing transaction and apply it to the account balance.

    The current balance is read, the transaction record is created, and the
    recomputed balance is then written back. DEBIT subtracts the amount;
    CREDIT and any other type (ADJUSTMENT, whose amount may be negative)
    add it.

    Args:
        transaction: BillingTransaction object

    Returns:
        Created transaction dict

    Raises:
        ValueError: If the referenced account does not exist
    """
    # NOTE(review): read, create and write-back are three separate calls;
    # whether concurrent transactions on one account are safe depends on
    # the connector - confirm
    account = self.getAccount(transaction.accountId)
    if not account:
        raise ValueError(f"Account {transaction.accountId} not found")

    currentBalance = account.get("balance", 0.0)

    # Only DEBIT reduces the balance; every other type adds the amount
    if transaction.transactionType == TransactionTypeEnum.DEBIT:
        newBalance = currentBalance - transaction.amount
    else:
        newBalance = currentBalance + transaction.amount

    # Store the transaction first, then the new balance
    payload = transaction.model_dump(exclude_none=True)
    created = self.db.recordCreate(BillingTransaction, payload)
    self.updateAccountBalance(transaction.accountId, newBalance)

    logger.info(f"Billing transaction created: {transaction.transactionType.value} {transaction.amount} CHF, "
                f"balance: {currentBalance} -> {newBalance}")

    return created
|
|
|
|
def getTransactions(
    self,
    accountId: str,
    limit: int = 100,
    offset: int = 0,
    startDate: date = None,
    endDate: date = None
) -> List[Dict[str, Any]]:
    """
    Get transactions for an account, newest first.

    The optional date range is inclusive on both ends and compared against
    the calendar date of each record's ``_createdAt``. When a range is
    given, records whose ``_createdAt`` is missing or cannot be read as a
    date are left out, because they cannot be placed in the range. Records
    without ``_createdAt`` sort after all dated ones.

    Args:
        accountId: Account ID
        limit: Maximum number of results
        offset: Offset for pagination
        startDate: Filter by start date (date or datetime)
        endDate: Filter by end date (date or datetime)

    Returns:
        List of transaction dicts; an empty list when the lookup failed
        (failures are logged, not raised)
    """
    # Local import: only needed to interpret numeric timestamps
    from datetime import timezone

    def _toDate(value):
        """Return the calendar date of a stored timestamp, or None if unusable."""
        if isinstance(value, datetime):
            return value.date()
        if isinstance(value, date):
            return value
        if isinstance(value, bool):
            return None
        if isinstance(value, (int, float)):
            # NOTE(review): assumes numeric timestamps are epoch seconds in
            # UTC - confirm against the connector
            try:
                return datetime.fromtimestamp(value, tz=timezone.utc).date()
            except (OverflowError, OSError, ValueError):
                return None
        if isinstance(value, str):
            try:
                return datetime.fromisoformat(value.replace("Z", "+00:00")).date()
            except ValueError:
                return None
        return None

    def _sortKey(record):
        """Order by _createdAt without ever comparing a missing value to a real one."""
        value = record.get("_createdAt")
        return (value is not None, value)

    try:
        filterDict = {"accountId": accountId}
        results = self.db.getRecordset(BillingTransaction, recordFilter=filterDict)

        # Apply date filters if provided
        if startDate or endDate:
            # A datetime bound cannot be compared with a date: reduce it to its date
            firstDay = _toDate(startDate) if startDate else None
            lastDay = _toDate(endDate) if endDate else None

            filtered = []
            for t in results:
                tDate = _toDate(t.get("_createdAt"))
                if tDate is None:
                    continue
                if firstDay and tDate < firstDay:
                    continue
                if lastDay and tDate > lastDay:
                    continue
                filtered.append(t)
            results = filtered

        # Sort by creation date descending
        results.sort(key=_sortKey, reverse=True)

        # Apply pagination
        return results[offset:offset + limit]
    except Exception as e:
        logger.error(f"Error getting transactions: {e}")
        return []
|
|
|
|
def getTransactionsByMandate(self, mandateId: str, limit: int = 100) -> List[Dict[str, Any]]:
    """
    Collect the newest transactions of all accounts of a mandate.

    Up to ``limit`` transactions are fetched per account; the merged list
    is sorted newest first and cut to ``limit``.

    Args:
        mandateId: Mandate ID
        limit: Maximum number of results

    Returns:
        List of transaction dicts
    """
    # Every account of the mandate, regardless of account type
    mandateAccounts = self.db.getRecordset(BillingAccount, recordFilter={"mandateId": mandateId})

    merged = [
        txn
        for account in mandateAccounts
        for txn in self.getTransactions(account["id"], limit=limit)
    ]

    # Newest first, then apply the overall limit
    newestFirst = sorted(merged, key=lambda x: x.get("_createdAt", ""), reverse=True)
    return newestFirst[:limit]
|
|
|
|
# =========================================================================
|
|
# Balance Check Operations
|
|
# =========================================================================
|
|
|
|
def checkBalance(self, mandateId: str, userId: str, estimatedCost: float) -> BillingCheckResult:
    """
    Check if there's sufficient balance for an operation.

    Rules, in order:
      * no settings for the mandate, or billing model UNLIMITED: allowed
      * PREPAY_USER: the user's account is used; a missing account is
        created with the default user credit from the settings
      * other models: the mandate account is used; if it is missing the
        operation is blocked only when ``blockOnZeroBalance`` is set
      * CREDIT_POSTPAY: blocked when the debt after the operation would
        exceed the account's credit limit (no limit set = always allowed)
      * PREPAY models: blocked when the balance is below the estimated cost
        and ``blockOnZeroBalance`` is set

    Args:
        mandateId: Mandate ID
        userId: User ID
        estimatedCost: Estimated cost of the operation

    Returns:
        BillingCheckResult
    """
    settings = self.getSettings(mandateId)
    if not settings:
        # No settings = no billing = allowed
        return BillingCheckResult(allowed=True, billingModel=BillingModelEnum.UNLIMITED)

    billingModel = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value))

    # UNLIMITED = always allowed
    if billingModel == BillingModelEnum.UNLIMITED:
        return BillingCheckResult(allowed=True, billingModel=billingModel)

    # Get the relevant account
    if billingModel == BillingModelEnum.PREPAY_USER:
        account = self.getUserAccount(mandateId, userId)
        # Auto-create user account if not exists (with default credit from settings)
        if not account:
            defaultCredit = settings.get("defaultUserCredit", 10.0)
            if defaultCredit is None:
                # Stored NULL: use the same default as a missing key
                defaultCredit = 10.0
            logger.info(f"Auto-creating billing account for user {userId} in mandate {mandateId} with {defaultCredit} CHF initial credit")
            account = self.getOrCreateUserAccount(mandateId, userId, initialBalance=defaultCredit)
    else:
        account = self.getMandateAccount(mandateId)

    if not account:
        # No account (only happens for mandate-level accounts) = potentially blocked
        if settings.get("blockOnZeroBalance", True):
            return BillingCheckResult(
                allowed=False,
                reason="NO_ACCOUNT",
                currentBalance=0.0,
                requiredAmount=estimatedCost,
                billingModel=billingModel
            )
        return BillingCheckResult(allowed=True, currentBalance=0.0, billingModel=billingModel)

    # A stored NULL balance counts as zero
    currentBalance = account.get("balance") or 0.0

    # CREDIT_POSTPAY with credit limit check
    if billingModel == BillingModelEnum.CREDIT_POSTPAY:
        creditLimit = account.get("creditLimit")
        # Debt the account would carry after this operation. Usage is
        # debited, so debt shows as a negative balance; a positive balance
        # is money in hand and must offset the cost, not count as debt.
        projectedDebt = estimatedCost - currentBalance
        if creditLimit and projectedDebt > creditLimit:
            return BillingCheckResult(
                allowed=False,
                reason="CREDIT_LIMIT_EXCEEDED",
                currentBalance=currentBalance,
                requiredAmount=estimatedCost,
                billingModel=billingModel
            )
        return BillingCheckResult(allowed=True, currentBalance=currentBalance, billingModel=billingModel)

    # PREPAY models - check balance
    if currentBalance < estimatedCost and settings.get("blockOnZeroBalance", True):
        return BillingCheckResult(
            allowed=False,
            reason="INSUFFICIENT_BALANCE",
            currentBalance=currentBalance,
            requiredAmount=estimatedCost,
            billingModel=billingModel
        )

    return BillingCheckResult(allowed=True, currentBalance=currentBalance, billingModel=billingModel)
|
|
|
|
def recordUsage(
    self,
    mandateId: str,
    userId: str,
    priceCHF: float,
    workflowId: str = None,
    featureInstanceId: str = None,
    featureCode: str = None,
    aicoreProvider: str = None,
    aicoreModel: str = None,
    description: str = "AI Usage"
) -> Optional[Dict[str, Any]]:
    """
    Book a usage cost as a DEBIT transaction on the responsible account.

    Nothing is recorded when the price is not positive, when the mandate
    has no billing settings, or when its billing model is UNLIMITED.
    PREPAY_USER books on the user's account; every other model books on
    the mandate account. A missing account is created on the fly.

    Args:
        mandateId: Mandate ID
        userId: User ID
        priceCHF: Cost in CHF
        workflowId: Optional workflow ID
        featureInstanceId: Optional feature instance ID
        featureCode: Optional feature code
        aicoreProvider: AICore provider name (e.g., 'anthropic', 'openai')
        aicoreModel: AICore model name (e.g., 'claude-4-sonnet', 'gpt-4o')
        description: Transaction description

    Returns:
        Created transaction dict or None
    """
    # Free or negative usage is not booked
    if priceCHF <= 0:
        return None

    settings = self.getSettings(mandateId)
    if not settings:
        logger.debug(f"No billing settings for mandate {mandateId}, skipping usage recording")
        return None

    configuredModel = settings.get("billingModel", BillingModelEnum.UNLIMITED.value)
    billingModel = BillingModelEnum(configuredModel)

    # UNLIMITED = no transaction recording
    if billingModel == BillingModelEnum.UNLIMITED:
        return None

    # Pick (or create) the account that carries the cost
    if billingModel != BillingModelEnum.PREPAY_USER:
        account = self.getOrCreateMandateAccount(mandateId)
    else:
        account = self.getOrCreateUserAccount(mandateId, userId)

    debit = BillingTransaction(
        accountId=account["id"],
        transactionType=TransactionTypeEnum.DEBIT,
        amount=priceCHF,
        description=description,
        referenceType=ReferenceTypeEnum.WORKFLOW,
        workflowId=workflowId,
        featureInstanceId=featureInstanceId,
        featureCode=featureCode,
        aicoreProvider=aicoreProvider,
        aicoreModel=aicoreModel,
        createdByUserId=userId,
    )
    return self.createTransaction(debit)
|
|
|
|
# =========================================================================
|
|
# Statistics Operations
|
|
# =========================================================================
|
|
|
|
def getUsageStatistics(
    self,
    accountId: str,
    periodType: PeriodTypeEnum,
    year: int,
    month: int = None
) -> List[Dict[str, Any]]:
    """
    Read stored usage statistics of an account for one year.

    Records without a ``periodStart`` are ignored. The month filter is
    applied only for the DAY period type. The result is sorted by
    ``periodStart`` ascending.

    Args:
        accountId: Account ID
        periodType: Period type (DAY, MONTH, YEAR)
        year: Year
        month: Month (for DAY period type)

    Returns:
        List of statistics dicts
    """
    rows = self.db.getRecordset(
        UsageStatistics,
        recordFilter={
            "accountId": accountId,
            "periodType": periodType.value,
        },
    )

    selected = []
    for row in rows:
        # Keep only records that start in the requested year
        if not (row.get("periodStart") and row["periodStart"].year == year):
            continue
        selected.append(row)

    # Narrow to one month when asked for daily figures
    if month and periodType == PeriodTypeEnum.DAY:
        selected = [row for row in selected if row["periodStart"].month == month]

    selected.sort(key=lambda x: x.get("periodStart", date.min))
    return selected
|
|
|
|
def calculateStatisticsFromTransactions(
    self,
    accountId: str,
    startDate: date,
    endDate: date
) -> Dict[str, Any]:
    """
    Calculate usage statistics from the DEBIT transactions of a period.

    At most 10000 transactions of the period are considered. Transactions
    whose provider, model or feature code is missing or empty are grouped
    under "unknown"; a missing amount counts as 0.

    Args:
        accountId: Account ID
        startDate: Start date
        endDate: End date

    Returns:
        Statistics dict with the keys totalCostCHF, transactionCount,
        costByProvider, costByModel and costByFeature
    """
    transactions = self.getTransactions(accountId, limit=10000, startDate=startDate, endDate=endDate)

    debitType = TransactionTypeEnum.DEBIT.value

    totalCost = 0
    transactionCount = 0
    costByProvider = {}
    costByModel = {}
    costByFeature = {}

    for t in transactions:
        # Only DEBIT transactions represent usage
        if t.get("transactionType") != debitType:
            continue

        # `or` rather than a .get() default: a field that is present but
        # None must fall back as well
        amount = t.get("amount") or 0
        provider = t.get("aicoreProvider") or "unknown"
        model = t.get("aicoreModel") or "unknown"
        feature = t.get("featureCode") or "unknown"

        totalCost += amount
        transactionCount += 1
        costByProvider[provider] = costByProvider.get(provider, 0) + amount
        costByModel[model] = costByModel.get(model, 0) + amount
        costByFeature[feature] = costByFeature.get(feature, 0) + amount

    return {
        "totalCostCHF": totalCost,
        "transactionCount": transactionCount,
        "costByProvider": costByProvider,
        "costByModel": costByModel,
        "costByFeature": costByFeature
    }
|
|
|
|
# =========================================================================
|
|
# Utility Methods
|
|
# =========================================================================
|
|
|
|
def getBalancesForUser(self, userId: str) -> List[BillingBalanceResponse]:
    """
    Collect the billing balance of a user for each of their mandates.

    A mandate is skipped when it cannot be loaded, has no billing settings,
    uses a billing model without an account (e.g. UNLIMITED), or has no
    matching account. PREPAY_USER reads the user's own account;
    PREPAY_MANDATE and CREDIT_POSTPAY read the mandate account. An error
    is logged and ends the loop; balances collected so far are returned.

    Args:
        userId: User ID

    Returns:
        List of BillingBalanceResponse
    """
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    mandateAccountModels = [BillingModelEnum.PREPAY_MANDATE, BillingModelEnum.CREDIT_POSTPAY]
    balances = []

    try:
        appInterface = getAppInterface(self.currentUser)

        # Walk over all mandates the user belongs to
        for um in appInterface.getUserMandates(userId):
            # Membership may be a Pydantic model or a dict
            mandateId = getattr(um, 'mandateId', None)
            if not mandateId and isinstance(um, dict):
                mandateId = um.get("mandateId")
            if not mandateId:
                continue

            mandate = appInterface.getMandate(mandateId)
            if not mandate:
                continue

            # Mandate may be a Pydantic model or a dict as well
            mandateName = getattr(mandate, 'name', None)
            if not mandateName:
                mandateName = mandate.get("name", "") if isinstance(mandate, dict) else ""

            settings = self.getSettings(mandateId)
            if not settings:
                continue

            billingModel = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value))

            # Pick the account that carries the balance for this model
            if billingModel == BillingModelEnum.PREPAY_USER:
                account = self.getUserAccount(mandateId, userId)
            elif billingModel in mandateAccountModels:
                account = self.getMandateAccount(mandateId)
            else:
                continue

            if not account:
                continue

            balance = account.get("balance", 0.0)
            warningThreshold = account.get("warningThreshold", 0.0)

            entry = BillingBalanceResponse(
                mandateId=mandateId,
                mandateName=mandateName,
                billingModel=billingModel,
                balance=balance,
                warningThreshold=warningThreshold,
                isWarning=balance <= warningThreshold,
                creditLimit=account.get("creditLimit"),
            )
            balances.append(entry)
    except Exception as e:
        logger.error(f"Error getting balances for user: {e}")

    return balances
|
|
|
|
def getTransactionsForUser(self, userId: str, limit: int = 100) -> List[Dict[str, Any]]:
    """
    Collect transactions of all mandates a user belongs to.

    Mandates without billing settings are skipped. Every transaction is
    tagged with ``mandateId`` and ``mandateName``. An error is logged and
    ends the loop; transactions collected so far are still sorted and
    returned.

    Args:
        userId: User ID
        limit: Maximum number of results

    Returns:
        List of transaction dicts, newest first
    """
    # NOTE(review): transactions are fetched per mandate via
    # getTransactionsByMandate(), i.e. for every account of the mandate,
    # not only the user's own account - confirm this is intended
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    allTransactions = []

    try:
        appInterface = getAppInterface(self.currentUser)

        for um in appInterface.getUserMandates(userId):
            # Membership may be a Pydantic model or a dict
            mandateId = getattr(um, 'mandateId', None)
            if not mandateId and isinstance(um, dict):
                mandateId = um.get("mandateId")
            if not mandateId:
                continue

            # Only mandates with billing settings take part
            if not self.getSettings(mandateId):
                continue

            transactions = self.getTransactionsByMandate(mandateId, limit=limit)

            # Resolve the mandate name once per mandate
            mandateName = ""
            mandate = appInterface.getMandate(mandateId)
            if mandate:
                mandateName = getattr(mandate, 'name', None)
                if not mandateName:
                    mandateName = mandate.get("name", "") if isinstance(mandate, dict) else ""

            # Tag each transaction with its mandate context
            for t in transactions:
                t["mandateId"] = mandateId
                t["mandateName"] = mandateName
            allTransactions.extend(transactions)

    except Exception as e:
        logger.error(f"Error getting transactions for user: {e}")

    # Newest first, then apply the overall limit
    allTransactions.sort(key=lambda x: x.get("_createdAt", ""), reverse=True)
    return allTransactions[:limit]
|
|
|
|
# =========================================================================
|
|
# Mandate View Operations (Admin-Level)
|
|
# =========================================================================
|
|
|
|
def getMandateBalances(self, mandateIds: List[str] = None) -> List[Dict[str, Any]]:
    """
    Get mandate-level balances.

    Balance per billing model:
      * PREPAY_MANDATE and CREDIT_POSTPAY: balance of the mandate account
        (the account usage is booked on for these models)
      * PREPAY_USER: sum of all user account balances, with the number of
        user accounts as ``userCount``
      * any other model: 0.0

    An error is logged and ends the loop; balances collected so far are
    returned.

    Args:
        mandateIds: Optional list of mandate IDs to filter. If None, returns all.

    Returns:
        List of mandate balance dicts
    """
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    # Models whose balance lives on the single mandate account
    mandateAccountModels = (BillingModelEnum.PREPAY_MANDATE, BillingModelEnum.CREDIT_POSTPAY)

    balances = []

    try:
        appInterface = getAppInterface(self.currentUser)

        # Settings of the requested mandates, or of all mandates
        if mandateIds:
            allSettings = [s for s in (self.getSettings(mId) for mId in mandateIds) if s]
        else:
            allSettings = self.db.getRecordset(BillingSettings)

        for settings in allSettings:
            mandateId = settings.get("mandateId")
            if not mandateId:
                continue

            billingModel = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value))

            # Get mandate info (Pydantic model or dict)
            mandate = appInterface.getMandate(mandateId)
            mandateName = ""
            if mandate:
                mandateName = getattr(mandate, 'name', None) or (mandate.get("name", "") if isinstance(mandate, dict) else "")

            if billingModel in mandateAccountModels:
                account = self.getMandateAccount(mandateId)
                # A missing account or a stored NULL balance counts as zero
                totalBalance = (account.get("balance") or 0.0) if account else 0.0
                userCount = 0
            elif billingModel == BillingModelEnum.PREPAY_USER:
                # Aggregate all user accounts of this mandate
                userAccounts = self.db.getRecordset(
                    BillingAccount,
                    recordFilter={"mandateId": mandateId, "accountType": AccountTypeEnum.USER.value}
                )
                totalBalance = sum((acc.get("balance") or 0.0) for acc in userAccounts)
                userCount = len(userAccounts)
            else:
                totalBalance = 0.0
                userCount = 0

            balances.append({
                "mandateId": mandateId,
                "mandateName": mandateName,
                "billingModel": billingModel.value,
                "totalBalance": totalBalance,
                "userCount": userCount,
                "defaultUserCredit": settings.get("defaultUserCredit", 0.0),
                "warningThresholdPercent": settings.get("warningThresholdPercent", 10.0),
                "blockOnZeroBalance": settings.get("blockOnZeroBalance", True)
            })

    except Exception as e:
        logger.error(f"Error getting mandate balances: {e}")

    return balances
|
|
|
|
def getMandateTransactions(self, mandateIds: Optional[List[str]] = None, limit: int = 100) -> List[Dict[str, Any]]:
    """
    Get all transactions for specified mandates.

    Every transaction dict returned by getTransactionsByMandate() is enriched
    in place with "mandateId" and "mandateName". The merged list is sorted by
    "_createdAt" (newest first) and truncated to ``limit`` entries.

    Args:
        mandateIds: Optional list of mandate IDs to filter. If None or empty,
            every mandate that has a BillingSettings record is queried.
        limit: Maximum number of results. Passed to each per-mandate query and
            applied again to the merged result.

    Returns:
        List of transaction dicts with mandate context. If an error occurs while
        collecting, it is logged and the transactions gathered so far are returned.
    """
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    def _createdAtKey(transaction: Dict[str, Any]):
        # The sort runs outside the try block, so it must not raise. Records
        # without a usable "_createdAt" (missing, None or "") get a key that ranks
        # below every real value and is never compared against one; in the
        # descending result they therefore come last, in their original order.
        createdAt = transaction.get("_createdAt")
        if createdAt is None or createdAt == "":
            return (False, "")
        return (True, createdAt)

    allTransactions: List[Dict[str, Any]] = []

    try:
        appInterface = getAppInterface(self.currentUser)

        # Determine which mandates to query
        if mandateIds:
            targetMandateIds = mandateIds
        else:
            allSettings = self.db.getRecordset(BillingSettings)
            targetMandateIds = [s.get("mandateId") for s in allSettings if s.get("mandateId")]

        for mandateId in targetMandateIds:
            transactions = self.getTransactionsByMandate(mandateId, limit=limit)

            # Get mandate name (the mandate may be an object with a name attribute or a dict)
            mandate = appInterface.getMandate(mandateId)
            mandateName = ""
            if mandate:
                mandateName = getattr(mandate, 'name', None) or (mandate.get("name", "") if isinstance(mandate, dict) else "")

            for t in transactions:
                t["mandateId"] = mandateId
                t["mandateName"] = mandateName
                allTransactions.append(t)

    except Exception as e:
        logger.error(f"Error getting mandate transactions: {e}")

    # Sort by creation date descending and limit
    allTransactions.sort(key=_createdAtKey, reverse=True)
    return allTransactions[:limit]
|
|
|
|
# =========================================================================
|
|
# User View Operations (User-Level with RBAC)
|
|
# =========================================================================
|
|
|
|
def getUserBalancesForMandates(self, mandateIds: Optional[List[str]] = None) -> List[Dict[str, Any]]:
    """
    Get all user-level balances for specified mandates.

    Reads every BillingAccount of type USER, optionally restricted to the given
    mandates, and returns one entry per account. Accounts are skipped when they
    have no mandateId or userId, or when no BillingSettings record exists for
    their mandate.

    Args:
        mandateIds: Optional list of mandate IDs to filter. If None or empty,
            accounts of all mandates are returned.

    Returns:
        List of user balance dicts with the keys accountId, mandateId,
        mandateName, userId, userName, balance, warningThreshold, isWarning
        and enabled. If an error occurs, it is logged and the entries built
        so far are returned.
    """
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    balances: List[Dict[str, Any]] = []

    try:
        appInterface = getAppInterface(self.currentUser)

        # Get all user accounts
        accountFilter = {"accountType": AccountTypeEnum.USER.value}
        allAccounts = self.db.getRecordset(BillingAccount, recordFilter=accountFilter)

        # Filter by mandate if specified (set gives constant-time membership tests)
        if mandateIds:
            wantedMandateIds = set(mandateIds)
            allAccounts = [acc for acc in allAccounts if acc.get("mandateId") in wantedMandateIds]

        # Only the existence of settings per mandate matters below
        mandatesWithSettings = {s.get("mandateId") for s in self.db.getRecordset(BillingSettings)}

        # Resolve each distinct user once; users that cannot be loaded fall back to their ID
        userMap = {}
        for userId in {acc.get("userId") for acc in allAccounts if acc.get("userId")}:
            user = appInterface.getUser(userId)
            if user:
                displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isinstance(user, dict) else None)
                username = getattr(user, 'username', None) or (user.get("username") if isinstance(user, dict) else None)
                userMap[userId] = displayName or username or userId

        # Resolve each distinct mandate once
        mandateMap = {}
        for mandateId in {acc.get("mandateId") for acc in allAccounts if acc.get("mandateId")}:
            mandate = appInterface.getMandate(mandateId)
            if mandate:
                mandateMap[mandateId] = getattr(mandate, 'name', None) or (mandate.get("name", "") if isinstance(mandate, dict) else "")

        for account in allAccounts:
            mandateId = account.get("mandateId")
            userId = account.get("userId")

            if not mandateId or not userId:
                continue

            if mandateId not in mandatesWithSettings:
                continue

            # A stored None would make the comparison below raise TypeError and,
            # via the broad except, silently drop all remaining accounts.
            balance = account.get("balance", 0.0)
            if balance is None:
                balance = 0.0
            warningThreshold = account.get("warningThreshold", 0.0)
            if warningThreshold is None:
                warningThreshold = 0.0

            balances.append({
                "accountId": account.get("id"),
                "mandateId": mandateId,
                "mandateName": mandateMap.get(mandateId, ""),
                "userId": userId,
                "userName": userMap.get(userId, userId),
                "balance": balance,
                "warningThreshold": warningThreshold,
                "isWarning": balance <= warningThreshold,
                "enabled": account.get("enabled", True)
            })

    except Exception as e:
        logger.error(f"Error getting user balances for mandates: {e}")

    return balances
|
|
|
|
def getUserTransactionsForMandates(self, mandateIds: Optional[List[str]] = None, limit: int = 100) -> List[Dict[str, Any]]:
    """
    Get all transactions for specified mandates (both USER and MANDATE accounts).

    Transactions are fetched per billing account via getTransactions() and each
    transaction dict is enriched in place with "mandateId", "mandateName",
    "userId" and "userName". "userId" is the transaction's "createdByUserId"
    when set, otherwise the userId of the owning account. The merged list is
    sorted by "_createdAt" (newest first) and truncated to ``limit`` entries.

    Args:
        mandateIds: Optional list of mandate IDs to filter. If None or empty,
            accounts of all mandates are included.
        limit: Maximum number of results. Passed to each per-account query and
            applied again to the merged result.

    Returns:
        List of transaction dicts with mandate and user context. If an error
        occurs while collecting, it is logged; transactions are only added to the
        result in the final enrichment step, so an earlier failure yields an
        empty list.
    """
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    def _createdAtKey(transaction: Dict[str, Any]):
        # The sort runs outside the try block, so it must not raise. Records
        # without a usable "_createdAt" (missing, None or "") get a key that ranks
        # below every real value and is never compared against one; in the
        # descending result they therefore come last, in their original order.
        createdAt = transaction.get("_createdAt")
        if createdAt is None or createdAt == "":
            return (False, "")
        return (True, createdAt)

    allTransactions: List[Dict[str, Any]] = []

    try:
        appInterface = getAppInterface(self.currentUser)

        userMap: Dict[Any, Any] = {}

        def _rememberUser(userId) -> None:
            # Store a display label for the user; users that cannot be loaded
            # stay out of the map so callers fall back to the raw ID.
            user = appInterface.getUser(userId)
            if user:
                displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isinstance(user, dict) else None)
                username = getattr(user, 'username', None) or (user.get("username") if isinstance(user, dict) else None)
                userMap[userId] = displayName or username or userId

        # Get ALL accounts (both USER and MANDATE types) to cover all billing models
        allAccounts = self.db.getRecordset(BillingAccount)

        # Filter by mandate if specified (set gives constant-time membership tests)
        if mandateIds:
            wantedMandateIds = set(mandateIds)
            allAccounts = [acc for acc in allAccounts if acc.get("mandateId") in wantedMandateIds]

        # Resolve each distinct account owner once
        for userId in {acc.get("userId") for acc in allAccounts if acc.get("userId")}:
            _rememberUser(userId)

        # Resolve each distinct mandate once
        mandateMap = {}
        for mandateId in {acc.get("mandateId") for acc in allAccounts if acc.get("mandateId")}:
            mandate = appInterface.getMandate(mandateId)
            if mandate:
                mandateMap[mandateId] = getattr(mandate, 'name', None) or (mandate.get("name", "") if isinstance(mandate, dict) else "")

        # Fetch transactions per account, keeping the owning account's context
        # alongside each transaction instead of writing temporary keys into it
        collected = []
        for account in allAccounts:
            accountId = account.get("id")
            if not accountId:
                continue

            accountMandateId = account.get("mandateId")
            accountUserId = account.get("userId")
            for t in self.getTransactions(accountId, limit=limit):
                collected.append((t, accountMandateId, accountUserId))

        # Resolve createdByUserIds that are not yet in userMap
        extraUserIds = set()
        for t, _, _ in collected:
            cbUserId = t.get("createdByUserId")
            if cbUserId and cbUserId not in userMap:
                extraUserIds.add(cbUserId)
        for uid in extraUserIds:
            _rememberUser(uid)

        # Enrich transactions
        for t, accountMandateId, accountUserId in collected:
            t["mandateId"] = accountMandateId
            t["mandateName"] = mandateMap.get(accountMandateId, "")
            # Prefer createdByUserId (per-transaction) over account-derived userId
            txUserId = t.get("createdByUserId") or accountUserId
            t["userId"] = txUserId
            t["userName"] = userMap.get(txUserId, txUserId) if txUserId else None
            allTransactions.append(t)

    except Exception as e:
        logger.error(f"Error getting user transactions for mandates: {e}")

    # Sort by creation date descending and limit
    allTransactions.sort(key=_createdAtKey, reverse=True)
    return allTransactions[:limit]
|