# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Interface for Billing operations.
Manages billing accounts, transactions, and usage statistics.

All billing data is stored in the poweron_billing database.
"""

import logging
from typing import Dict, Any, List, Optional
from datetime import date, datetime, timedelta, timezone
import uuid

from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG
from modules.shared.timeUtils import getUtcTimestamp
from modules.datamodels.datamodelUam import User
from modules.datamodels.datamodelBilling import (
    BillingAccount,
    BillingTransaction,
    BillingSettings,
    UsageStatistics,
    BillingAddress,
    BillingModelEnum,
    AccountTypeEnum,
    TransactionTypeEnum,
    ReferenceTypeEnum,
    PeriodTypeEnum,
    BillingBalanceResponse,
    BillingCheckResult,
)

logger = logging.getLogger(__name__)

# Singleton factory for BillingObjects instances
_billingInterfaces: Dict[str, "BillingObjects"] = {}

# Database name for billing
BILLING_DATABASE = "poweron_billing"


def _coerceToDate(value: Any) -> Optional[date]:
    """
    Best-effort conversion of a stored creation timestamp to a date.

    Args:
        value: datetime, date, numeric epoch value or ISO-8601 string

    Returns:
        The calendar date, or None when the value cannot be interpreted
    """
    # datetime must be tested first: it is a subclass of date
    if isinstance(value, datetime):
        return value.date()
    if isinstance(value, date):
        return value
    if isinstance(value, bool):
        return None
    if isinstance(value, (int, float)):
        # NOTE(review): assumes epoch *seconds* in UTC - confirm against the
        # format the database connector uses for "_createdAt".
        try:
            return datetime.fromtimestamp(value, tz=timezone.utc).date()
        except (OverflowError, OSError, ValueError):
            return None
    if isinstance(value, str):
        try:
            return datetime.fromisoformat(value).date()
        except ValueError:
            return None
    return None


def _createdAtSortKey(record: Dict[str, Any]) -> tuple:
    """
    Sort key on "_createdAt" that tolerates missing or None values.

    Records without a creation timestamp get the lowest key, so they end up
    last when sorting in descending order.
    """
    createdAt = record.get("_createdAt")
    if createdAt is None or createdAt == "":
        return (0, 0)
    return (1, createdAt)


def getInterface(currentUser: User, mandateId: Optional[str] = None) -> "BillingObjects":
    """
    Factory function to get or create a BillingObjects instance.

    Instances are cached per (user ID, mandate ID) pair; a cached instance
    gets its user context refreshed before it is returned.

    Args:
        currentUser: Current user object
        mandateId: Mandate ID for context

    Returns:
        BillingObjects instance
    """
    cacheKey = f"{currentUser.id}_{mandateId}"

    if cacheKey not in _billingInterfaces:
        _billingInterfaces[cacheKey] = BillingObjects(currentUser, mandateId)
    else:
        _billingInterfaces[cacheKey].setUserContext(currentUser, mandateId)

    return _billingInterfaces[cacheKey]


def _getRootInterface() -> "BillingObjects":
    """Get interface with system access for bootstrap operations."""
    # Imported locally, as in the original code
    from modules.security.rootAccess import getRootUser

    rootUser = getRootUser()
    # Not cached: every call builds a fresh instance
    return BillingObjects(rootUser, mandateId=None)


class BillingObjects:
    """
    Interface for billing operations.
    Manages accounts, transactions, settings, and statistics.
    """

    def __init__(self, currentUser: Optional[User] = None, mandateId: Optional[str] = None):
        """
        Initialize the billing interface.

        Args:
            currentUser: Current user object
            mandateId: Mandate ID for context
        """
        self.currentUser = currentUser
        self.userId = currentUser.id if currentUser else None
        self.mandateId = mandateId

        # Initialize database connection
        self._initializeDatabase()

    def _initializeDatabase(self):
        """Initialize database connection."""
        self.db = DatabaseConnector(
            dbDatabase=BILLING_DATABASE,
            dbHost=APP_CONFIG.get('DB_HOST', 'localhost'),
            dbPort=int(APP_CONFIG.get('DB_PORT', '5432')),
            dbUser=APP_CONFIG.get('DB_USER'),
            dbPassword=APP_CONFIG.get('DB_PASSWORD_SECRET')
        )

    def setUserContext(self, currentUser: User, mandateId: Optional[str] = None):
        """
        Update user context.

        Args:
            currentUser: Current user object
            mandateId: Mandate ID for context
        """
        self.currentUser = currentUser
        self.userId = currentUser.id if currentUser else None
        self.mandateId = mandateId

    # =========================================================================
    # BillingSettings Operations
    # =========================================================================

    def getSettings(self, mandateId: str) -> Optional[Dict[str, Any]]:
        """
        Get billing settings for a mandate.

        Args:
            mandateId: Mandate ID

        Returns:
            BillingSettings dict or None if not found (also None on a
            database error, which is logged)
        """
        try:
            results = self.db.getRecordset(
                BillingSettings,
                recordFilter={"mandateId": mandateId}
            )
            return results[0] if results else None
        except Exception as e:
            logger.error(f"Error getting billing settings: {e}")
            return None

    def createSettings(self, settings: BillingSettings) -> Dict[str, Any]:
        """
        Create billing settings for a mandate.

        Args:
            settings: BillingSettings object

        Returns:
            Created settings dict
        """
        settingsDict = settings.model_dump(exclude_none=True)

        # Handle nested BillingAddress
        if settings.billingAddress:
            settingsDict["billingAddress"] = settings.billingAddress.model_dump()

        return self.db.recordCreate(BillingSettings, settingsDict)

    def updateSettings(self, settingsId: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """
        Update billing settings.

        Args:
            settingsId: Settings ID
            updates: Fields to update

        Returns:
            Updated settings dict or None
        """
        return self.db.recordModify(BillingSettings, settingsId, updates)

    def getOrCreateSettings(
        self,
        mandateId: str,
        defaultModel: BillingModelEnum = BillingModelEnum.UNLIMITED
    ) -> Dict[str, Any]:
        """
        Get or create billing settings for a mandate.

        Args:
            mandateId: Mandate ID
            defaultModel: Default billing model if creating

        Returns:
            BillingSettings dict
        """
        existing = self.getSettings(mandateId)
        if existing:
            return existing

        settings = BillingSettings(
            mandateId=mandateId,
            billingModel=defaultModel,
            defaultUserCredit=10.0,
            warningThresholdPercent=10.0,
            blockOnZeroBalance=True,
            notifyOnWarning=True
        )
        return self.createSettings(settings)

    # =========================================================================
    # BillingAccount Operations
    # =========================================================================

    def getAccount(self, accountId: str) -> Optional[Dict[str, Any]]:
        """
        Get a billing account by ID.

        Args:
            accountId: Account ID

        Returns:
            BillingAccount dict or None (also None on a logged database error)
        """
        try:
            results = self.db.getRecordset(
                BillingAccount,
                recordFilter={"id": accountId}
            )
            return results[0] if results else None
        except Exception as e:
            logger.error(f"Error getting billing account: {e}")
            return None

    def getMandateAccount(self, mandateId: str) -> Optional[Dict[str, Any]]:
        """
        Get the mandate-level billing account.

        Args:
            mandateId: Mandate ID

        Returns:
            BillingAccount dict or None
        """
        try:
            results = self.db.getRecordset(
                BillingAccount,
                recordFilter={
                    "mandateId": mandateId,
                    "accountType": AccountTypeEnum.MANDATE.value
                }
            )
            return results[0] if results else None
        except Exception as e:
            logger.error(f"Error getting mandate account: {e}")
            return None

    def getUserAccount(self, mandateId: str, userId: str) -> Optional[Dict[str, Any]]:
        """
        Get a user-level billing account within a mandate.

        Args:
            mandateId: Mandate ID
            userId: User ID

        Returns:
            BillingAccount dict or None
        """
        try:
            results = self.db.getRecordset(
                BillingAccount,
                recordFilter={
                    "mandateId": mandateId,
                    "userId": userId,
                    "accountType": AccountTypeEnum.USER.value
                }
            )
            return results[0] if results else None
        except Exception as e:
            logger.error(f"Error getting user account: {e}")
            return None

    def getAccountsByMandate(self, mandateId: str) -> List[Dict[str, Any]]:
        """
        Get all billing accounts for a mandate.

        Args:
            mandateId: Mandate ID

        Returns:
            List of BillingAccount dicts (empty on a logged database error)
        """
        try:
            return self.db.getRecordset(
                BillingAccount,
                recordFilter={"mandateId": mandateId}
            )
        except Exception as e:
            logger.error(f"Error getting accounts for mandate: {e}")
            return []

    def createAccount(self, account: BillingAccount) -> Dict[str, Any]:
        """
        Create a new billing account.

        Args:
            account: BillingAccount object

        Returns:
            Created account dict
        """
        accountDict = account.model_dump(exclude_none=True)
        return self.db.recordCreate(BillingAccount, accountDict)

    def updateAccountBalance(self, accountId: str, newBalance: float) -> Optional[Dict[str, Any]]:
        """
        Overwrite the stored balance of an account.

        This is a plain write of the given value. Nothing visible in this
        module makes it atomic with respect to concurrent updates.

        Args:
            accountId: Account ID
            newBalance: New balance value

        Returns:
            Updated account dict or None
        """
        return self.db.recordModify(BillingAccount, accountId, {"balance": newBalance})

    def getOrCreateMandateAccount(self, mandateId: str, initialBalance: float = 0.0) -> Dict[str, Any]:
        """
        Get or create a mandate-level billing account.

        Args:
            mandateId: Mandate ID
            initialBalance: Initial balance if creating

        Returns:
            BillingAccount dict
        """
        existing = self.getMandateAccount(mandateId)
        if existing:
            return existing

        account = BillingAccount(
            mandateId=mandateId,
            accountType=AccountTypeEnum.MANDATE,
            balance=initialBalance,
            enabled=True
        )
        return self.createAccount(account)

    def getOrCreateUserAccount(
        self,
        mandateId: str,
        userId: str,
        initialBalance: float = 0.0
    ) -> Dict[str, Any]:
        """
        Get or create a user-level billing account.

        A positive initial balance is booked as a SYSTEM credit transaction.
        In that case the account is opened with a zero balance, because
        createTransaction() adds the credited amount to the stored balance;
        opening with initialBalance as well would credit the amount twice.

        Args:
            mandateId: Mandate ID
            userId: User ID
            initialBalance: Initial balance if creating

        Returns:
            BillingAccount dict
        """
        existing = self.getUserAccount(mandateId, userId)
        if existing:
            return existing

        bookInitialCredit = initialBalance > 0
        account = BillingAccount(
            mandateId=mandateId,
            userId=userId,
            accountType=AccountTypeEnum.USER,
            balance=0.0 if bookInitialCredit else initialBalance,
            enabled=True
        )
        created = self.createAccount(account)

        if not bookInitialCredit:
            return created

        # The credit transaction raises the balance from 0.0 to initialBalance
        self.createTransaction(BillingTransaction(
            accountId=created["id"],
            transactionType=TransactionTypeEnum.CREDIT,
            amount=initialBalance,
            description="Initial credit for new user",
            referenceType=ReferenceTypeEnum.SYSTEM
        ))

        # Return the account as stored after the credit, not the stale dict
        refreshed = self.getAccount(created["id"])
        if refreshed:
            return refreshed
        return {**created, "balance": initialBalance}

    # =========================================================================
    # BillingTransaction Operations
    # =========================================================================

    def createTransaction(self, transaction: BillingTransaction) -> Dict[str, Any]:
        """
        Create a new billing transaction and update account balance.

        CREDIT adds the amount, DEBIT subtracts it, and any other type
        (ADJUSTMENT) adds the signed amount.

        Args:
            transaction: BillingTransaction object

        Returns:
            Created transaction dict

        Raises:
            ValueError: If the referenced account cannot be loaded
        """
        # Get current account
        account = self.getAccount(transaction.accountId)
        if not account:
            raise ValueError(f"Account {transaction.accountId} not found")

        # A stored None balance is treated as 0.0
        currentBalance = account.get("balance") or 0.0

        # Calculate new balance
        if transaction.transactionType == TransactionTypeEnum.DEBIT:
            newBalance = currentBalance - transaction.amount
        else:
            # CREDIT, or ADJUSTMENT whose amount can be positive or negative
            newBalance = currentBalance + transaction.amount

        # Create transaction
        transactionDict = transaction.model_dump(exclude_none=True)
        created = self.db.recordCreate(BillingTransaction, transactionDict)

        # Update account balance
        # NOTE(review): this is a read-modify-write across separate calls;
        # concurrent transactions on one account could lose an update.
        self.updateAccountBalance(transaction.accountId, newBalance)

        logger.info(f"Billing transaction created: {transaction.transactionType.value} {transaction.amount} CHF, "
                    f"balance: {currentBalance} -> {newBalance}")

        return created

    def getTransactions(
        self,
        accountId: str,
        limit: int = 100,
        offset: int = 0,
        startDate: Optional[date] = None,
        endDate: Optional[date] = None
    ) -> List[Dict[str, Any]]:
        """
        Get transactions for an account.

        Date filtering, sorting and pagination are applied in memory after
        loading all transactions of the account.

        Args:
            accountId: Account ID
            limit: Maximum number of results
            offset: Offset for pagination
            startDate: Filter by start date (inclusive)
            endDate: Filter by end date (inclusive)

        Returns:
            List of transaction dicts, newest first (empty on a logged error)
        """
        try:
            results = self.db.getRecordset(
                BillingTransaction,
                recordFilter={"accountId": accountId}
            )

            # Apply date filters if provided
            if startDate or endDate:
                filtered = []
                for t in results:
                    tDate = _coerceToDate(t.get("_createdAt"))
                    # Records without a usable creation date are not filtered out
                    if tDate is not None:
                        if startDate and tDate < startDate:
                            continue
                        if endDate and tDate > endDate:
                            continue
                    filtered.append(t)
                results = filtered

            # Sort by creation date descending
            results.sort(key=_createdAtSortKey, reverse=True)

            # Apply pagination
            return results[offset:offset + limit]
        except Exception as e:
            logger.error(f"Error getting transactions: {e}")
            return []

    def getTransactionsByMandate(self, mandateId: str, limit: int = 100) -> List[Dict[str, Any]]:
        """
        Get all transactions for a mandate (across all accounts).

        Args:
            mandateId: Mandate ID
            limit: Maximum number of results

        Returns:
            List of transaction dicts, newest first
        """
        # Uses the shared lookup so that a database error is logged and
        # yields an empty list, like the other read operations.
        accounts = self.getAccountsByMandate(mandateId)

        allTransactions = []
        for account in accounts:
            # The newest `limit` rows per account are enough to build the
            # newest `limit` rows overall.
            allTransactions.extend(self.getTransactions(account["id"], limit=limit))

        # Sort by creation date descending and limit
        allTransactions.sort(key=_createdAtSortKey, reverse=True)
        return allTransactions[:limit]

    # =========================================================================
    # Balance Check Operations
    # =========================================================================

    def checkBalance(self, mandateId: str, userId: str, estimatedCost: float) -> BillingCheckResult:
        """
        Check if there's sufficient balance for an operation.

        Args:
            mandateId: Mandate ID
            userId: User ID
            estimatedCost: Estimated cost of the operation

        Returns:
            BillingCheckResult
        """
        settings = self.getSettings(mandateId)
        if not settings:
            # No settings = no billing = allowed
            return BillingCheckResult(allowed=True, billingModel=BillingModelEnum.UNLIMITED)

        billingModel = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value))

        # UNLIMITED = always allowed
        if billingModel == BillingModelEnum.UNLIMITED:
            return BillingCheckResult(allowed=True, billingModel=billingModel)

        # Get the relevant account
        if billingModel == BillingModelEnum.PREPAY_USER:
            account = self.getUserAccount(mandateId, userId)
        else:
            account = self.getMandateAccount(mandateId)

        blockOnZeroBalance = settings.get("blockOnZeroBalance", True)

        if not account:
            # No account = no balance = potentially blocked
            if blockOnZeroBalance:
                return BillingCheckResult(
                    allowed=False,
                    reason="NO_ACCOUNT",
                    currentBalance=0.0,
                    requiredAmount=estimatedCost,
                    billingModel=billingModel
                )
            return BillingCheckResult(allowed=True, currentBalance=0.0, billingModel=billingModel)

        # A stored None balance is treated as 0.0
        currentBalance = account.get("balance") or 0.0

        # CREDIT_POSTPAY with credit limit check
        if billingModel == BillingModelEnum.CREDIT_POSTPAY:
            creditLimit = account.get("creditLimit")
            # Debits drive the balance negative, so only a negative balance
            # is outstanding debt; a positive balance must not count.
            outstanding = max(0.0, -currentBalance)
            if creditLimit and outstanding + estimatedCost > creditLimit:
                return BillingCheckResult(
                    allowed=False,
                    reason="CREDIT_LIMIT_EXCEEDED",
                    currentBalance=currentBalance,
                    requiredAmount=estimatedCost,
                    billingModel=billingModel
                )
            return BillingCheckResult(allowed=True, currentBalance=currentBalance, billingModel=billingModel)

        # PREPAY models - check balance
        if currentBalance < estimatedCost and blockOnZeroBalance:
            return BillingCheckResult(
                allowed=False,
                reason="INSUFFICIENT_BALANCE",
                currentBalance=currentBalance,
                requiredAmount=estimatedCost,
                billingModel=billingModel
            )

        return BillingCheckResult(allowed=True, currentBalance=currentBalance, billingModel=billingModel)

    def recordUsage(
        self,
        mandateId: str,
        userId: str,
        priceCHF: float,
        workflowId: Optional[str] = None,
        featureInstanceId: Optional[str] = None,
        featureCode: Optional[str] = None,
        aicoreProvider: Optional[str] = None,
        description: str = "AI Usage"
    ) -> Optional[Dict[str, Any]]:
        """
        Record usage cost as a billing transaction.

        Nothing is recorded when the price is not positive, when the mandate
        has no billing settings, or when its billing model is UNLIMITED.

        Args:
            mandateId: Mandate ID
            userId: User ID
            priceCHF: Cost in CHF
            workflowId: Optional workflow ID
            featureInstanceId: Optional feature instance ID
            featureCode: Optional feature code
            aicoreProvider: Optional AICore provider name
            description: Transaction description

        Returns:
            Created transaction dict or None
        """
        if priceCHF <= 0:
            return None

        settings = self.getSettings(mandateId)
        if not settings:
            logger.debug(f"No billing settings for mandate {mandateId}, skipping usage recording")
            return None

        billingModel = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value))

        # UNLIMITED = no transaction recording
        if billingModel == BillingModelEnum.UNLIMITED:
            return None

        # Get or create the relevant account
        if billingModel == BillingModelEnum.PREPAY_USER:
            account = self.getOrCreateUserAccount(mandateId, userId)
        else:
            account = self.getOrCreateMandateAccount(mandateId)

        # Create debit transaction
        transaction = BillingTransaction(
            accountId=account["id"],
            transactionType=TransactionTypeEnum.DEBIT,
            amount=priceCHF,
            description=description,
            referenceType=ReferenceTypeEnum.WORKFLOW,
            workflowId=workflowId,
            featureInstanceId=featureInstanceId,
            featureCode=featureCode,
            aicoreProvider=aicoreProvider
        )

        return self.createTransaction(transaction)

    # =========================================================================
    # Statistics Operations
    # =========================================================================

    def getUsageStatistics(
        self,
        accountId: str,
        periodType: PeriodTypeEnum,
        year: int,
        month: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Get usage statistics for an account.

        Args:
            accountId: Account ID
            periodType: Period type (DAY, MONTH, YEAR)
            year: Year
            month: Month (only applied for the DAY period type)

        Returns:
            List of statistics dicts, ordered by period start
        """
        filterDict = {
            "accountId": accountId,
            "periodType": periodType.value
        }

        results = self.db.getRecordset(UsageStatistics, recordFilter=filterDict)

        # Filter by year (records without a period start are dropped)
        filtered = [s for s in results if s.get("periodStart") and s["periodStart"].year == year]

        # Filter by month if specified
        if month and periodType == PeriodTypeEnum.DAY:
            filtered = [s for s in filtered if s["periodStart"].month == month]

        return sorted(filtered, key=lambda x: x.get("periodStart", date.min))

    def calculateStatisticsFromTransactions(
        self,
        accountId: str,
        startDate: date,
        endDate: date
    ) -> Dict[str, Any]:
        """
        Calculate statistics from transactions for a period.

        Only DEBIT transactions (usage) are counted. Transactions without a
        provider or feature code are grouped under "unknown".

        Args:
            accountId: Account ID
            startDate: Start date
            endDate: End date

        Returns:
            Statistics dict with totalCostCHF, transactionCount,
            costByProvider and costByFeature
        """
        transactions = self.getTransactions(accountId, limit=10000, startDate=startDate, endDate=endDate)

        totalCost = 0
        transactionCount = 0
        costByProvider: Dict[str, Any] = {}
        costByFeature: Dict[str, Any] = {}

        for t in transactions:
            # Filter only DEBIT transactions (usage)
            if t.get("transactionType") != TransactionTypeEnum.DEBIT.value:
                continue

            # `or` also covers keys that are present but hold None
            amount = t.get("amount") or 0
            provider = t.get("aicoreProvider") or "unknown"
            feature = t.get("featureCode") or "unknown"

            totalCost += amount
            transactionCount += 1
            costByProvider[provider] = costByProvider.get(provider, 0) + amount
            costByFeature[feature] = costByFeature.get(feature, 0) + amount

        return {
            "totalCostCHF": totalCost,
            "transactionCount": transactionCount,
            "costByProvider": costByProvider,
            "costByFeature": costByFeature
        }

    # =========================================================================
    # Utility Methods
    # =========================================================================

    def getBalancesForUser(self, userId: str) -> List[BillingBalanceResponse]:
        """
        Get all billing balances for a user across mandates.

        Mandates that cannot be loaded, have no billing settings, use the
        UNLIMITED model, or have no matching account are skipped.

        Args:
            userId: User ID

        Returns:
            List of BillingBalanceResponse (the entries collected so far if
            an error occurs; the error is logged)
        """
        # Imported locally, as in the original code
        from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

        balances = []

        # Get all mandates the user belongs to
        try:
            appInterface = getAppInterface(self.currentUser)
            userMandates = appInterface.getUserMandates(userId)

            for um in userMandates:
                mandateId = um.get("mandateId")
                mandate = appInterface.getMandate(mandateId)
                if not mandate:
                    continue

                settings = self.getSettings(mandateId)
                if not settings:
                    continue

                billingModel = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value))

                # Get the relevant account
                if billingModel == BillingModelEnum.PREPAY_USER:
                    account = self.getUserAccount(mandateId, userId)
                elif billingModel in [BillingModelEnum.PREPAY_MANDATE, BillingModelEnum.CREDIT_POSTPAY]:
                    account = self.getMandateAccount(mandateId)
                else:
                    continue

                if not account:
                    continue

                # `or 0.0` guards against stored None values, which would
                # otherwise break the comparison and abort the whole loop
                balance = account.get("balance") or 0.0
                warningThreshold = account.get("warningThreshold") or 0.0

                balances.append(BillingBalanceResponse(
                    mandateId=mandateId,
                    mandateName=mandate.get("name", ""),
                    billingModel=billingModel,
                    balance=balance,
                    warningThreshold=warningThreshold,
                    isWarning=balance <= warningThreshold,
                    creditLimit=account.get("creditLimit")
                ))
        except Exception as e:
            logger.error(f"Error getting balances for user: {e}")

        return balances