# Copyright (c) 2026 PowerOn AG # All rights reserved. """ Interface for Billing operations. Manages billing accounts, transactions, and usage statistics. All billing data is stored in the poweron_billing database. """ import logging import copy import math from typing import Dict, Any, List, Optional, Union from datetime import date, datetime, timedelta, timezone import uuid from modules.connectors.connectorDbPostgre import DatabaseConnector from modules.shared.configuration import APP_CONFIG from modules.dbHelpers.dbRegistry import registerDatabase from modules.shared.timeUtils import getUtcTimestamp from modules.datamodels.datamodelUam import User, Mandate from modules.datamodels.datamodelMembership import UserMandate from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResult from modules.datamodels.datamodelBilling import ( BillingAccount, BillingTransaction, BillingSettings, StripeWebhookEvent, UsageStatistics, TransactionTypeEnum, ReferenceTypeEnum, PeriodTypeEnum, BillingBalanceResponse, BillingCheckResult, STORAGE_PRICE_PER_GB_CHF, ) logger = logging.getLogger(__name__) def _logBillingTransactionsMissingSysCreatedAt(rows: List[Dict[str, Any]], context: str) -> None: """Log ERROR when sysCreatedAt is missing; does not raise.""" missingIds = [r.get("id") for r in rows if r.get("sysCreatedAt") is None] if not missingIds: return cap = 40 sample = missingIds[:cap] suffix = f"; ... (+{len(missingIds) - cap} more)" if len(missingIds) > cap else "" logger.error( "BillingTransaction missing sysCreatedAt (%s): count=%s; transactionIds=%s%s", context, len(missingIds), sample, suffix, ) def _numericSysCreatedAtForSort(row: Dict[str, Any]) -> float: v = row["sysCreatedAt"] if isinstance(v, datetime): return v.timestamp() return float(v) def _sortBillingTransactionsBySysCreatedAtDesc(rows: List[Dict[str, Any]], context: str) -> None: _logBillingTransactionsMissingSysCreatedAt(rows, context) valid = [r for r in rows if r.get("sysCreatedAt") is not None] invalid = [r for r in rows if r.get("sysCreatedAt") is None] valid.sort(key=_numericSysCreatedAtForSort, reverse=True) rows[:] = valid + invalid def _getAppDatabaseConnector() -> DatabaseConnector: """App DB connector (same config as UserMandate reads in this module).""" return DatabaseConnector( dbDatabase=APP_CONFIG.get("DB_DATABASE", "poweron_app"), dbHost=APP_CONFIG.get("DB_HOST", "localhost"), dbPort=int(APP_CONFIG.get("DB_PORT", "5432")), dbUser=APP_CONFIG.get("DB_USER"), dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"), ) def _getRootMandateIdFromAppDb(appDb: DatabaseConnector) -> Optional[str]: """Resolve root mandate id (name='root', isSystem=True) from app database.""" try: rows = appDb.getRecordset(Mandate, recordFilter={"name": "root", "isSystem": True}) if rows: rid = rows[0].get("id") return str(rid) if rid is not None else None except Exception as e: logger.warning("Could not resolve root mandate id from app DB: %s", e) return None _cachedRootMandateId: Optional[str] = None _rootMandateIdCacheResolved: bool = False def _getCachedRootMandateId() -> Optional[str]: """Lazy-cached root mandate id (name=root, isSystem=True) for hot paths.""" global _cachedRootMandateId, _rootMandateIdCacheResolved if not _rootMandateIdCacheResolved: appDb = _getAppDatabaseConnector() _cachedRootMandateId = _getRootMandateIdFromAppDb(appDb) _rootMandateIdCacheResolved = True return _cachedRootMandateId # Singleton factory for BillingObjects instances _billingInterfaces: Dict[str, "BillingObjects"] = {} # Database name for billing BILLING_DATABASE = "poweron_billing" registerDatabase(BILLING_DATABASE) def getInterface(currentUser: User, mandateId: str = None) -> "BillingObjects": """ Factory function to get or create a BillingObjects instance. Args: currentUser: Current user object mandateId: Mandate ID for context Returns: BillingObjects instance """ cacheKey = f"{currentUser.id}_{mandateId}" if cacheKey not in _billingInterfaces: _billingInterfaces[cacheKey] = BillingObjects(currentUser, mandateId) else: _billingInterfaces[cacheKey].setUserContext(currentUser, mandateId) return _billingInterfaces[cacheKey] def getRootInterface() -> "BillingObjects": """Get interface with system access for bootstrap operations.""" from modules.security.rootAccess import getRootUser rootUser = getRootUser() return BillingObjects(rootUser, mandateId=None) class BillingObjects: """ Interface for billing operations. Manages accounts, transactions, settings, and statistics. """ def __init__(self, currentUser: Optional[User] = None, mandateId: str = None): """ Initialize the billing interface. Args: currentUser: Current user object mandateId: Mandate ID for context """ self.currentUser = currentUser self.userId = currentUser.id if currentUser else None self.mandateId = mandateId # Initialize database connection self._initializeDatabase() def _initializeDatabase(self): """Initialize database connection.""" self.db = DatabaseConnector( dbDatabase=BILLING_DATABASE, dbHost=APP_CONFIG.get('DB_HOST', 'localhost'), dbPort=int(APP_CONFIG.get('DB_PORT', '5432')), dbUser=APP_CONFIG.get('DB_USER'), dbPassword=APP_CONFIG.get('DB_PASSWORD_SECRET') ) def setUserContext(self, currentUser: User, mandateId: str = None): """ Update user context. Args: currentUser: Current user object mandateId: Mandate ID for context """ self.currentUser = currentUser self.userId = currentUser.id if currentUser else None self.mandateId = mandateId # ========================================================================= # BillingSettings Operations # ========================================================================= def getSettings(self, mandateId: str) -> Optional[Dict[str, Any]]: """ Get billing settings for a mandate. Args: mandateId: Mandate ID Returns: BillingSettings dict or None if not found """ try: results = self.db.getRecordset( BillingSettings, recordFilter={"mandateId": mandateId} ) if not results: return None return dict(results[0]) except Exception as e: logger.error(f"Error getting billing settings: {e}") return None def createSettings(self, settings: BillingSettings) -> Dict[str, Any]: """ Create billing settings for a mandate. Args: settings: BillingSettings object Returns: Created settings dict """ settingsDict = settings.model_dump(exclude_none=True) return self.db.recordCreate(BillingSettings, settingsDict) def updateSettings(self, settingsId: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]: """ Update billing settings. Args: settingsId: Settings ID updates: Fields to update Returns: Updated settings dict or None """ return self.db.recordModify(BillingSettings, settingsId, updates) def getOrCreateSettings(self, mandateId: str) -> Dict[str, Any]: """ Get or create billing settings for a mandate. Args: mandateId: Mandate ID Returns: BillingSettings dict """ existing = self.getSettings(mandateId) if existing: return existing settings = BillingSettings( mandateId=mandateId, warningThresholdPercent=10.0, notifyOnWarning=True, ) return self.createSettings(settings) # ========================================================================= # BillingAccount Operations # ========================================================================= def getAccount(self, accountId: str) -> Optional[Dict[str, Any]]: """Get a billing account by ID.""" try: results = self.db.getRecordset( BillingAccount, recordFilter={"id": accountId} ) return results[0] if results else None except Exception as e: logger.error(f"Error getting billing account: {e}") return None def getMandateAccount(self, mandateId: str) -> Optional[Dict[str, Any]]: """ Get the mandate-level billing account. Args: mandateId: Mandate ID Returns: BillingAccount dict or None """ try: results = self.db.getRecordset( BillingAccount, recordFilter={ "mandateId": mandateId, "userId": None } ) return results[0] if results else None except Exception as e: logger.error(f"Error getting mandate account: {e}") return None def getUserAccount(self, mandateId: str, userId: str) -> Optional[Dict[str, Any]]: """ Get a user-level billing account within a mandate. Args: mandateId: Mandate ID userId: User ID Returns: BillingAccount dict or None """ try: results = self.db.getRecordset( BillingAccount, recordFilter={ "mandateId": mandateId, "userId": userId } ) return results[0] if results else None except Exception as e: logger.error(f"Error getting user account: {e}") return None def getAccountsByMandate(self, mandateId: str) -> List[Dict[str, Any]]: """ Get all billing accounts for a mandate. Args: mandateId: Mandate ID Returns: List of BillingAccount dicts """ try: return self.db.getRecordset( BillingAccount, recordFilter={"mandateId": mandateId} ) except Exception as e: logger.error(f"Error getting accounts for mandate: {e}") return [] def createAccount(self, account: BillingAccount) -> Dict[str, Any]: """ Create a new billing account. Args: account: BillingAccount object Returns: Created account dict """ accountDict = account.model_dump(exclude_none=True) return self.db.recordCreate(BillingAccount, accountDict) def updateAccountBalance(self, accountId: str, newBalance: float) -> Optional[Dict[str, Any]]: """ Update account balance atomically. Args: accountId: Account ID newBalance: New balance value Returns: Updated account dict or None """ return self.db.recordModify(BillingAccount, accountId, {"balance": newBalance}) def getOrCreateMandateAccount(self, mandateId: str, initialBalance: float = 0.0) -> Dict[str, Any]: """ Get or create a mandate-level billing account. Args: mandateId: Mandate ID initialBalance: Initial balance if creating Returns: BillingAccount dict """ existing = self.getMandateAccount(mandateId) if existing: return existing account = BillingAccount( mandateId=mandateId, balance=initialBalance, enabled=True ) return self.createAccount(account) def getOrCreateUserAccount(self, mandateId: str, userId: str, initialBalance: float = 0.0) -> Dict[str, Any]: """ Get or create a user-level billing account. Args: mandateId: Mandate ID userId: User ID initialBalance: Initial balance if creating Returns: BillingAccount dict """ existing = self.getUserAccount(mandateId, userId) if existing: return existing account = BillingAccount( mandateId=mandateId, userId=userId, balance=initialBalance, enabled=True ) created = self.createAccount(account) # If initial balance > 0, create a SYSTEM credit transaction if initialBalance > 0: self.createTransaction(BillingTransaction( accountId=created["id"], transactionType=TransactionTypeEnum.CREDIT, amount=initialBalance, description="Initial credit for new user", referenceType=ReferenceTypeEnum.SYSTEM )) return created def ensureAllMandateSettingsExist(self) -> int: """ Efficiently ensure all mandates have billing settings. Creates default settings (0 CHF) for mandates without settings. Uses bulk queries to minimize database connections. Returns: Number of settings created """ try: settingsCreated = 0 # Step 1: Get all existing billing settings in one query (from billing DB) allSettings = self.db.getRecordset(BillingSettings) existingMandateIds = set(s.get("mandateId") for s in allSettings if s.get("mandateId")) # Step 2: Get all mandates from APP database (separate connection) appDb = DatabaseConnector( dbDatabase=APP_CONFIG.get('DB_DATABASE', 'poweron_app'), dbHost=APP_CONFIG.get('DB_HOST', 'localhost'), dbPort=int(APP_CONFIG.get('DB_PORT', '5432')), dbUser=APP_CONFIG.get('DB_USER'), dbPassword=APP_CONFIG.get('DB_PASSWORD_SECRET') ) allMandates = appDb.getRecordset(Mandate, recordFilter={"enabled": True}) # Step 3: Create settings for mandates that don't have them for mandate in allMandates: mandateId = mandate.get("id") if not mandateId or mandateId in existingMandateIds: continue settings = BillingSettings( mandateId=mandateId, warningThresholdPercent=10.0, notifyOnWarning=True, ) self.createSettings(settings) existingMandateIds.add(mandateId) settingsCreated += 1 if settingsCreated > 0: logger.info(f"Created {settingsCreated} missing billing settings for mandates") return settingsCreated except Exception as e: logger.error(f"Error ensuring mandate settings exist: {e}") return 0 def ensureAllUserAccountsExist(self) -> int: """ Ensure all users across all mandates have billing accounts. User accounts are always created for audit trail with initial balance 0.0. Uses bulk queries to minimize database connections. Returns: Number of accounts created """ try: accountsCreated = 0 appDb = _getAppDatabaseConnector() allSettings = self.db.getRecordset(BillingSettings) billingMandateIds = set( s.get("mandateId") for s in allSettings if s.get("mandateId") ) if not billingMandateIds: logger.debug("No billable mandates found, skipping account check") return 0 allAccounts = self.db.getRecordset(BillingAccount) existingAccountKeys = set() for acc in allAccounts: if not acc.get("userId"): continue key = (acc.get("mandateId"), acc.get("userId")) existingAccountKeys.add(key) allUserMandates = appDb.getRecordset( UserMandate, recordFilter={"enabled": True} ) for um in allUserMandates: mandateId = um.get("mandateId") userId = um.get("userId") if not mandateId or not userId: continue if mandateId not in billingMandateIds: continue key = (mandateId, userId) if key in existingAccountKeys: continue account = BillingAccount( mandateId=mandateId, userId=userId, balance=0.0, enabled=True ) self.createAccount(account) existingAccountKeys.add(key) accountsCreated += 1 if accountsCreated > 0: logger.info(f"Created {accountsCreated} missing billing accounts") return accountsCreated except Exception as e: logger.error(f"Error ensuring user accounts exist: {e}") return 0 # ========================================================================= # BillingTransaction Operations # ========================================================================= def createTransaction(self, transaction: BillingTransaction, balanceAccountId: str = None) -> Dict[str, Any]: """ Create a new billing transaction and update account balance. The transaction is always recorded against transaction.accountId (audit trail). The balance is updated on balanceAccountId if provided, otherwise on transaction.accountId. This allows recording a transaction on a user account (audit) while deducting from a mandate pool account (shared budget). Args: transaction: BillingTransaction object balanceAccountId: Optional account ID for balance update (defaults to transaction.accountId) Returns: Created transaction dict """ # Validate that the transaction's account exists txAccount = self.getAccount(transaction.accountId) if not txAccount: raise ValueError(f"Transaction account {transaction.accountId} not found") # Determine which account to update balance on targetBalanceAccountId = balanceAccountId or transaction.accountId if targetBalanceAccountId == transaction.accountId: balanceAccount = txAccount else: balanceAccount = self.getAccount(targetBalanceAccountId) if not balanceAccount: raise ValueError(f"Balance account {targetBalanceAccountId} not found") currentBalance = balanceAccount.get("balance", 0.0) # Calculate new balance if transaction.transactionType == TransactionTypeEnum.CREDIT: newBalance = currentBalance + transaction.amount elif transaction.transactionType == TransactionTypeEnum.DEBIT: newBalance = currentBalance - transaction.amount else: # ADJUSTMENT newBalance = currentBalance + transaction.amount # Create transaction record (always on transaction.accountId for audit) transactionDict = transaction.model_dump(exclude_none=True) created = self.db.recordCreate(BillingTransaction, transactionDict) # Update balance on the target account self.updateAccountBalance(targetBalanceAccountId, newBalance) logger.info(f"Billing transaction created: {transaction.transactionType.value} {transaction.amount} CHF, " f"audit={transaction.accountId}, balance on {targetBalanceAccountId}: {currentBalance} -> {newBalance}") return created def getTransactions( self, accountId: str, limit: int = 100, offset: int = 0, startDate: date = None, endDate: date = None, pagination: PaginationParams = None ) -> Union[List[Dict[str, Any]], PaginatedResult]: """ Get transactions for an account. When pagination is provided, uses database-level pagination and returns PaginatedResult. Otherwise falls back to in-memory filtering/sorting/slicing. Args: accountId: Account ID limit: Maximum number of results (legacy path) offset: Offset for pagination (legacy path) startDate: Filter by start date (legacy path) endDate: Filter by end date (legacy path) pagination: PaginationParams for DB-level pagination Returns: PaginatedResult when pagination is provided, List of dicts otherwise """ try: if pagination: recordFilter = {"accountId": accountId} result = self.db.getRecordsetPaginated( BillingTransaction, pagination=pagination, recordFilter=recordFilter ) from modules.dbHelpers.fkLabelResolver import enrichRowsWithFkLabels enrichRowsWithFkLabels(result.get("items", []), BillingTransaction, db=self.db) _logBillingTransactionsMissingSysCreatedAt( result["items"], "getTransactions(accountId) paginated", ) return PaginatedResult( items=result["items"], totalItems=result["totalItems"], totalPages=result["totalPages"] ) filterDict = {"accountId": accountId} results = self.db.getRecordset(BillingTransaction, recordFilter=filterDict) if startDate or endDate: filtered = [] for t in results: createdAt = t.get("sysCreatedAt") if createdAt: tDate = createdAt.date() if isinstance(createdAt, datetime) else createdAt if startDate and tDate < startDate: continue if endDate and tDate > endDate: continue filtered.append(t) results = filtered _sortBillingTransactionsBySysCreatedAtDesc(results, "getTransactions(accountId)") return results[offset:offset + limit] except Exception as e: logger.error(f"Error getting transactions: {e}") if pagination: return PaginatedResult(items=[], totalItems=0, totalPages=0) return [] def getTransactionsByMandate( self, mandateId: str, limit: int = 100, pagination: PaginationParams = None ) -> Union[List[Dict[str, Any]], PaginatedResult]: """ Get all transactions for a mandate (across all accounts). When pagination is provided, collects all accountIds for the mandate and issues a single DB query with SQL-level filtering, sorting, and pagination. Otherwise falls back to per-account querying and in-memory merging. Args: mandateId: Mandate ID limit: Maximum number of results (legacy path) pagination: PaginationParams for DB-level pagination Returns: PaginatedResult when pagination is provided, List of dicts otherwise """ accounts = self.db.getRecordset(BillingAccount, recordFilter={"mandateId": mandateId}) accountIds = [acc["id"] for acc in accounts if acc.get("id")] if not accountIds: if pagination: return PaginatedResult(items=[], totalItems=0, totalPages=0) return [] if pagination: result = self.db.getRecordsetPaginated( BillingTransaction, pagination=pagination, recordFilter={"accountId": accountIds} ) from modules.dbHelpers.fkLabelResolver import enrichRowsWithFkLabels enrichRowsWithFkLabels(result.get("items", []), BillingTransaction, db=self.db) return PaginatedResult( items=result["items"], totalItems=result["totalItems"], totalPages=result["totalPages"] ) allTransactions = [] for account in accounts: transactions = self.getTransactions(account["id"], limit=limit) allTransactions.extend(transactions) _sortBillingTransactionsBySysCreatedAtDesc( allTransactions, "getTransactionsByMandate", ) return allTransactions[:limit] # ========================================================================= # StripeWebhookEvent Operations (idempotency) # ========================================================================= def getStripeWebhookEventByEventId(self, event_id: str) -> Optional[Dict[str, Any]]: """ Check if a Stripe event has already been processed (idempotency). Args: event_id: Stripe event ID (evt_xxx) Returns: Event record if exists, else None """ try: results = self.db.getRecordset( StripeWebhookEvent, recordFilter={"event_id": event_id} ) return results[0] if results else None except Exception as e: logger.error(f"Error checking Stripe webhook event: {e}") return None def createStripeWebhookEvent(self, event_id: str) -> Dict[str, Any]: """ Record that a Stripe event has been processed. Args: event_id: Stripe event ID (evt_xxx) Returns: Created event record """ record = StripeWebhookEvent(event_id=event_id) return self.db.recordCreate(StripeWebhookEvent, record.model_dump()) def getPaymentTransactionByReferenceId(self, referenceId: str) -> Optional[Dict[str, Any]]: """ Find an existing Stripe payment credit transaction by Checkout Session ID. Args: referenceId: Stripe Checkout Session ID (cs_xxx) Returns: Transaction record if found, else None """ try: results = self.db.getRecordset( BillingTransaction, recordFilter={ "referenceType": ReferenceTypeEnum.PAYMENT.value, "referenceId": referenceId, } ) return results[0] if results else None except Exception as e: logger.error(f"Error checking Stripe payment transaction by referenceId: {e}") return None # ========================================================================= # Balance Check Operations # ========================================================================= def checkBalance(self, mandateId: str, userId: str, estimatedCost: float) -> BillingCheckResult: """ Check if there's sufficient balance for an operation. Checks mandate pool balance against estimatedCost. User accounts are ensured to exist for audit trail. Missing settings: treated as PREPAY_MANDATE with empty pool. """ self.getOrCreateUserAccount(mandateId, userId, initialBalance=0.0) poolAccount = self.getOrCreateMandateAccount(mandateId) currentBalance = poolAccount.get("balance", 0.0) if currentBalance < estimatedCost: return BillingCheckResult( allowed=False, reason="INSUFFICIENT_BALANCE", currentBalance=currentBalance, requiredAmount=estimatedCost, ) return BillingCheckResult(allowed=True, currentBalance=currentBalance) def recordUsage( self, mandateId: str, userId: str, priceCHF: float, workflowId: str = None, featureInstanceId: str = None, featureCode: str = None, aicoreProvider: str = None, aicoreModel: str = None, description: str = "AI Usage", processingTime: float = None, bytesSent: int = None, bytesReceived: int = None, errorCount: int = None ) -> Optional[Dict[str, Any]]: """ Record usage cost as a billing transaction. Transaction is recorded on the user's account (audit trail). Balance is always deducted from the mandate pool account (PREPAY_MANDATE). """ if priceCHF <= 0: return None settings = self.getSettings(mandateId) if not settings: logger.debug(f"No billing settings for mandate {mandateId}, skipping usage recording") return None userAccount = self.getOrCreateUserAccount(mandateId, userId) transaction = BillingTransaction( accountId=userAccount["id"], transactionType=TransactionTypeEnum.DEBIT, amount=priceCHF, description=description, referenceType=ReferenceTypeEnum.WORKFLOW, workflowId=workflowId, featureInstanceId=featureInstanceId, featureCode=featureCode, aicoreProvider=aicoreProvider, aicoreModel=aicoreModel, createdByUserId=userId, processingTime=processingTime, bytesSent=bytesSent, bytesReceived=bytesReceived, errorCount=errorCount ) poolAccount = self.getOrCreateMandateAccount(mandateId) return self.createTransaction(transaction, balanceAccountId=poolAccount["id"]) def _parseSettingsDateTime(self, value: Any) -> Optional[datetime]: """Parse datetime from billing settings row (ISO string or datetime).""" if value is None: return None if isinstance(value, datetime): if value.tzinfo: return value.astimezone(timezone.utc) return value.replace(tzinfo=timezone.utc) if isinstance(value, str): s = value.replace("Z", "+00:00") try: dt = datetime.fromisoformat(s) except ValueError: return None if dt.tzinfo: return dt.astimezone(timezone.utc) return dt.replace(tzinfo=timezone.utc) return None def resetStorageBillingPeriod(self, mandateId: str, periodStartAt: datetime) -> None: """Reset storage watermark state for a new subscription billing period (e.g. Stripe invoice.paid).""" if periodStartAt.tzinfo is None: periodStartAt = periodStartAt.replace(tzinfo=timezone.utc) else: periodStartAt = periodStartAt.astimezone(timezone.utc) periodStartTs = periodStartAt.timestamp() settings = self.getOrCreateSettings(mandateId) prev = settings.get("storagePeriodStartAt") if prev is not None and abs(prev - periodStartTs) < 2: return from modules.interfaces.interfaceDbSubscription import getRootInterface as _getSubRoot usedMB = float(_getSubRoot().getMandateDataVolumeMB(mandateId)) self.updateSettings( settings["id"], { "storageHighWatermarkMB": usedMB, "storageBilledUpToMB": 0.0, "storagePeriodStartAt": periodStartTs, }, ) logger.info( "Storage billing period reset for mandate %s at %s (usedMB=%.2f)", mandateId, periodStartAt.isoformat(), usedMB, ) def reconcileMandateStorageBilling(self, mandateId: str) -> Optional[Dict[str, Any]]: """Debit prepay pool for new storage overage using period high-watermark (no credit on delete). Skipped for enterprise subscriptions (hard-block via assertCapacity instead).""" settings = self.getSettings(mandateId) if not settings: return None from modules.interfaces.interfaceDbSubscription import getRootInterface as _getSubRoot from modules.datamodels.datamodelSubscription import getPlan, getEffectiveLimits subIface = _getSubRoot() usedMB = float(subIface.getMandateDataVolumeMB(mandateId)) sub = subIface.getOperativeForMandate(mandateId) if sub and sub.get("isEnterprise"): return None plan = getPlan(sub.get("planKey", "")) if sub else None limits = getEffectiveLimits(sub, plan) if sub else {} includedMB = limits.get("maxDataVolumeMB") if includedMB is None: return None prevHigh = float(settings.get("storageHighWatermarkMB") or 0.0) high = max(prevHigh, usedMB) overageMB = max(0.0, high - float(includedMB)) billed = float(settings.get("storageBilledUpToMB") or 0.0) deltaOverage = overageMB - billed settingsUpdates: Dict[str, Any] = {} if high != prevHigh: settingsUpdates["storageHighWatermarkMB"] = high if deltaOverage <= 1e-9: if settingsUpdates: self.updateSettings(settings["id"], settingsUpdates) return None costCHF = round((deltaOverage / 1024.0) * float(STORAGE_PRICE_PER_GB_CHF), 4) if costCHF <= 0: if settingsUpdates: self.updateSettings(settings["id"], settingsUpdates) return None poolAccount = self.getOrCreateMandateAccount(mandateId) transaction = BillingTransaction( accountId=poolAccount["id"], transactionType=TransactionTypeEnum.DEBIT, amount=costCHF, description=f"Speicher-Überhang ({deltaOverage:.2f} MB über Plan)", referenceType=ReferenceTypeEnum.STORAGE, referenceId=mandateId, ) created = self.createTransaction(transaction) settingsUpdates["storageBilledUpToMB"] = overageMB self.updateSettings(settings["id"], settingsUpdates) logger.info( "Storage overage billed mandate=%s deltaOverageMB=%.4f costCHF=%s", mandateId, deltaOverage, costCHF, ) return created # ========================================================================= # Subscription AI-Budget Credit # ========================================================================= def creditSubscriptionBudget( self, mandateId: str, planKey: str, periodLabel: str = "", enterpriseBudgetOverride: Optional[float] = None, ) -> Optional[Dict[str, Any]]: """Credit AI budget to the mandate pool account. For standard plans: amount = budgetAiPerUserCHF * activeUsers. For enterprise: uses the fixed ``enterpriseBudgetOverride`` amount. Should be called once per billing period (initial activation + each invoice.paid). Returns the created CREDIT transaction or None if budget is 0.""" if enterpriseBudgetOverride is not None and enterpriseBudgetOverride > 0: amount = enterpriseBudgetOverride description = f"AI-Budget Enterprise ({planKey})" if periodLabel: description += f" – {periodLabel}" else: from modules.datamodels.datamodelSubscription import getPlan plan = getPlan(planKey) if not plan or not plan.budgetAiPerUserCHF or plan.budgetAiPerUserCHF <= 0: return None from modules.interfaces.interfaceDbSubscription import getRootInterface as _getSubRoot subRoot = _getSubRoot() activeUsers = max(subRoot.countActiveUsers(mandateId), 1) amount = plan.budgetAiPerUserCHF * activeUsers description = f"AI-Budget ({planKey}, {activeUsers} User)" if periodLabel: description += f" – {periodLabel}" poolAccount = self.getOrCreateMandateAccount(mandateId) transaction = BillingTransaction( accountId=poolAccount["id"], transactionType=TransactionTypeEnum.CREDIT, amount=amount, description=description, referenceType=ReferenceTypeEnum.SUBSCRIPTION, referenceId=mandateId, ) created = self.createTransaction(transaction) logger.info( "AI-Budget credited mandate=%s plan=%s amount=%.2f CHF", mandateId, planKey, amount, ) return created def ensureActivationBudget(self, mandateId: str, planKey: str) -> Optional[Dict[str, Any]]: """Idempotent: credit the activation budget only if no SUBSCRIPTION credit exists yet.""" poolAccount = self.getMandateAccount(mandateId) if not poolAccount: return self.creditSubscriptionBudget(mandateId, planKey, periodLabel="Erstaktivierung") existing = self.db.getRecordset( BillingTransaction, recordFilter={ "accountId": poolAccount["id"], "transactionType": TransactionTypeEnum.CREDIT.value, "referenceType": ReferenceTypeEnum.SUBSCRIPTION.value, }, ) if existing: return None return self.creditSubscriptionBudget(mandateId, planKey, periodLabel="Erstaktivierung") def adjustAiBudgetForUserChange(self, mandateId: str, planKey: str, delta: int) -> Optional[Dict[str, Any]]: """Pro-rata AI budget adjustment when users are added/removed mid-cycle. delta > 0: user added -> CREDIT pro-rata portion delta < 0: user removed -> DEBIT pro-rata portion Skipped for enterprise subscriptions (fixed budget, no pro-rata).""" from modules.datamodels.datamodelSubscription import getPlan plan = getPlan(planKey) if not plan or not plan.budgetAiPerUserCHF or plan.budgetAiPerUserCHF <= 0: return None from modules.interfaces.interfaceDbSubscription import getRootInterface as _getSubRoot subRoot = _getSubRoot() operative = subRoot.getOperativeForMandate(mandateId) if not operative: return None if operative.get("isEnterprise"): return None periodStart = operative.get("currentPeriodStart") periodEnd = operative.get("currentPeriodEnd") if not periodStart or not periodEnd: return None nowTs = datetime.now(timezone.utc).timestamp() totalSeconds = periodEnd - periodStart remainingSeconds = max(periodEnd - nowTs, 0) proRataFraction = remainingSeconds / totalSeconds if totalSeconds > 0 else 0 amount = round(abs(delta) * plan.budgetAiPerUserCHF * proRataFraction, 2) if amount <= 0: return None poolAccount = self.getOrCreateMandateAccount(mandateId) if delta > 0: txType = TransactionTypeEnum.CREDIT label = f"AI-Budget pro-rata +{abs(delta)} User ({planKey})" else: txType = TransactionTypeEnum.DEBIT label = f"AI-Budget pro-rata -{abs(delta)} User ({planKey})" transaction = BillingTransaction( accountId=poolAccount["id"], transactionType=txType, amount=amount, description=label, referenceType=ReferenceTypeEnum.SUBSCRIPTION, referenceId=mandateId, ) created = self.createTransaction(transaction) logger.info( "AI-Budget pro-rata %s mandate=%s delta=%+d amount=%.2f CHF (fraction=%.4f)", txType.value, mandateId, delta, amount, proRataFraction, ) return created # ========================================================================= # Workflow Cost Query # ========================================================================= def getWorkflowCost(self, workflowId: str) -> float: """Sum of all transaction amounts for a workflow.""" if not workflowId: return 0.0 transactions = self.db.getRecordset( BillingTransaction, recordFilter={"workflowId": workflowId} ) return sum(t.get("amount", 0.0) for t in transactions) # ========================================================================= # Statistics Operations # ========================================================================= def getUsageStatistics( self, accountId: str, periodType: PeriodTypeEnum, year: int, month: int = None ) -> List[Dict[str, Any]]: """ Get usage statistics for an account. Args: accountId: Account ID periodType: Period type (DAY, MONTH, YEAR) year: Year month: Month (for DAY period type) Returns: List of statistics dicts """ filterDict = { "accountId": accountId, "periodType": periodType.value } results = self.db.getRecordset(UsageStatistics, recordFilter=filterDict) # Filter by year filtered = [s for s in results if s.get("periodStart") and s["periodStart"].year == year] # Filter by month if specified if month and periodType == PeriodTypeEnum.DAY: filtered = [s for s in filtered if s["periodStart"].month == month] return sorted(filtered, key=lambda x: x.get("periodStart", date.min)) def calculateStatisticsFromTransactions( self, accountId: str, startDate: date, endDate: date ) -> Dict[str, Any]: """ Calculate statistics from transactions for a period. Args: accountId: Account ID startDate: Start date endDate: End date Returns: Statistics dict """ transactions = self.getTransactions(accountId, limit=10000, startDate=startDate, endDate=endDate) # Filter only DEBIT transactions (usage) debits = [t for t in transactions if t.get("transactionType") == TransactionTypeEnum.DEBIT.value] totalCost = sum(t.get("amount", 0) for t in debits) # Calculate by provider costByProvider = {} costByModel = {} for t in debits: provider = t.get("aicoreProvider", "unknown") costByProvider[provider] = costByProvider.get(provider, 0) + t.get("amount", 0) model = t.get("aicoreModel", "unknown") costByModel[model] = costByModel.get(model, 0) + t.get("amount", 0) # Calculate by feature costByFeature = {} for t in debits: feature = t.get("featureCode", "unknown") costByFeature[feature] = costByFeature.get(feature, 0) + t.get("amount", 0) return { "totalCostCHF": totalCost, "transactionCount": len(debits), "costByProvider": costByProvider, "costByModel": costByModel, "costByFeature": costByFeature } # ========================================================================= # Utility Methods # ========================================================================= def getBalancesForUser(self, userId: str) -> List[BillingBalanceResponse]: """ Get all billing balances for a user across mandates. Shows the mandate pool balance (shared budget visible to user). Args: userId: User ID Returns: List of BillingBalanceResponse """ from modules.interfaces.interfaceDbApp import getRootInterface balances = [] try: # Use rootInterface (privileged, SysAdmin context) to bypass RBAC # for mandate/user lookups. User access is verified via UserMandate membership. rootInterface = getRootInterface() userMandates = rootInterface.getUserMandates(userId) for um in userMandates: mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None) if not mandateId: continue mandate = rootInterface.getMandate(mandateId) if not mandate or not getattr(mandate, "enabled", True): continue mandateName = getattr(mandate, 'label', None) or getattr(mandate, 'name', None) or (mandate.get("label") or mandate.get("name", None) if isinstance(mandate, dict) else None) or f"NA({mandateId})" settings = self.getSettings(mandateId) if not settings: continue poolAccount = self.getOrCreateMandateAccount(mandateId) if not poolAccount: continue balance = poolAccount.get("balance", 0.0) warningThreshold = poolAccount.get("warningThreshold", 0.0) balances.append(BillingBalanceResponse( mandateId=mandateId, mandateName=mandateName, balance=balance, warningThreshold=warningThreshold, isWarning=balance <= warningThreshold, )) except Exception as e: logger.error(f"Error getting balances for user: {e}") return balances def getTransactionsForUser(self, userId: str, limit: int = 100) -> List[Dict[str, Any]]: """ Get all transactions for a user across all mandates they belong to. Since transactions are always recorded on user accounts, we query directly by user account - clean and simple. Args: userId: User ID limit: Maximum number of results Returns: List of transaction dicts """ from modules.interfaces.interfaceDbApp import getInterface as getAppInterface allTransactions = [] try: appInterface = getAppInterface(self.currentUser) userMandates = appInterface.getUserMandates(userId) for um in userMandates: mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None) if not mandateId: continue settings = self.getSettings(mandateId) if not settings: continue # Get user's account in this mandate userAccount = self.getUserAccount(mandateId, userId) if not userAccount: continue transactions = self.getTransactions(userAccount["id"], limit=limit) mandate = appInterface.getMandate(mandateId) mandateName = f"NA({mandateId})" if mandate: mandateName = getattr(mandate, 'label', None) or getattr(mandate, 'name', None) or (mandate.get("label") or mandate.get("name", None) if isinstance(mandate, dict) else None) or f"NA({mandateId})" for t in transactions: t["mandateId"] = mandateId t["mandateName"] = mandateName allTransactions.append(t) except Exception as e: logger.error(f"Error getting transactions for user: {e}") _sortBillingTransactionsBySysCreatedAtDesc(allTransactions, "getTransactionsForUser") return allTransactions[:limit] # ========================================================================= # Mandate View Operations (Admin-Level) # ========================================================================= def getMandateBalances(self, mandateIds: List[str] = None) -> List[Dict[str, Any]]: """ Get mandate-level balances. Args: mandateIds: Optional list of mandate IDs to filter. If None, returns all. Returns: List of mandate balance dicts """ from modules.interfaces.interfaceDbApp import getInterface as getAppInterface balances = [] try: appInterface = getAppInterface(self.currentUser) # Get settings for filtering if mandateIds: allSettings = [self.getSettings(mId) for mId in mandateIds] allSettings = [s for s in allSettings if s] else: allSettings = self.db.getRecordset(BillingSettings) for settings in allSettings: mandateId = settings.get("mandateId") if not mandateId: continue mandate = appInterface.getMandate(mandateId) mandateName = f"NA({mandateId})" if mandate: mandateName = getattr(mandate, 'label', None) or getattr(mandate, 'name', None) or (mandate.get("label") or mandate.get("name", None) if isinstance(mandate, dict) else None) or f"NA({mandateId})" allMandateAccounts = self.db.getRecordset( BillingAccount, recordFilter={"mandateId": mandateId} ) userCount = sum(1 for acc in allMandateAccounts if acc.get("userId")) poolAccount = self.getMandateAccount(mandateId) totalBalance = poolAccount.get("balance", 0.0) if poolAccount else 0.0 balances.append({ "mandateId": mandateId, "mandateName": mandateName, "totalBalance": totalBalance, "userCount": userCount, "warningThresholdPercent": settings.get("warningThresholdPercent", 10.0), }) except Exception as e: logger.error(f"Error getting mandate balances: {e}") return balances def getMandateTransactions(self, mandateIds: List[str] = None, limit: int = 100) -> List[Dict[str, Any]]: """ Get all transactions for specified mandates. Args: mandateIds: Optional list of mandate IDs to filter. If None, returns all. limit: Maximum number of results Returns: List of transaction dicts with mandate context """ from modules.interfaces.interfaceDbApp import getInterface as getAppInterface allTransactions = [] try: appInterface = getAppInterface(self.currentUser) # Determine which mandates to query if mandateIds: targetMandateIds = mandateIds else: allSettings = self.db.getRecordset(BillingSettings) targetMandateIds = [s.get("mandateId") for s in allSettings if s.get("mandateId")] for mandateId in targetMandateIds: transactions = self.getTransactionsByMandate(mandateId, limit=limit) mandate = appInterface.getMandate(mandateId) mandateName = f"NA({mandateId})" if mandate: mandateName = getattr(mandate, 'label', None) or getattr(mandate, 'name', None) or (mandate.get("label") or mandate.get("name", None) if isinstance(mandate, dict) else None) or f"NA({mandateId})" for t in transactions: t["mandateId"] = mandateId t["mandateName"] = mandateName allTransactions.append(t) except Exception as e: logger.error(f"Error getting mandate transactions: {e}") # Sort by creation date descending and limit _sortBillingTransactionsBySysCreatedAtDesc(allTransactions, "getMandateTransactions") return allTransactions[:limit] # ========================================================================= # User View Operations (User-Level with RBAC) # ========================================================================= def getUserBalancesForMandates(self, mandateIds: List[str] = None) -> List[Dict[str, Any]]: """ Get all user-level balances for specified mandates. Args: mandateIds: Optional list of mandate IDs to filter. If None, returns all. Returns: List of user balance dicts with mandate and user context """ from modules.interfaces.interfaceDbApp import getInterface as getAppInterface balances = [] try: appInterface = getAppInterface(self.currentUser) allAccounts = self.db.getRecordset(BillingAccount) allAccounts = [acc for acc in allAccounts if acc.get("userId")] # Filter by mandate if specified if mandateIds: allAccounts = [acc for acc in allAccounts if acc.get("mandateId") in mandateIds] # Get all relevant settings in one query settingsMap = {} allSettings = self.db.getRecordset(BillingSettings) for s in allSettings: settingsMap[s.get("mandateId")] = s userIds = list(set(acc.get("userId") for acc in allAccounts if acc.get("userId"))) userMap = {} for userId in userIds: user = appInterface.getUser(userId) if user: displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isinstance(user, dict) else None) username = getattr(user, 'username', None) or (user.get("username") if isinstance(user, dict) else None) userMap[userId] = displayName or username or f"NA({userId})" mandateMap = {} mandateIdList = list(set(acc.get("mandateId") for acc in allAccounts if acc.get("mandateId"))) for mandateId in mandateIdList: mandate = appInterface.getMandate(mandateId) if mandate: mandateName = getattr(mandate, 'label', None) or getattr(mandate, 'name', None) or (mandate.get("label") or mandate.get("name", None) if isinstance(mandate, dict) else None) mandateMap[mandateId] = mandateName or f"NA({mandateId})" for account in allAccounts: mandateId = account.get("mandateId") userId = account.get("userId") if not mandateId or not userId: continue settings = settingsMap.get(mandateId) if not settings: continue balance = account.get("balance", 0.0) warningThreshold = account.get("warningThreshold", 0.0) balances.append({ "accountId": account.get("id"), "mandateId": mandateId, "mandateName": mandateMap.get(mandateId) or (f"NA({mandateId})" if mandateId else None), "userId": userId, "userName": userMap.get(userId) or (f"NA({userId})" if userId else None), "balance": balance, "warningThreshold": warningThreshold, "isWarning": balance <= warningThreshold, "enabled": account.get("enabled", True) }) except Exception as e: logger.error(f"Error getting user balances for mandates: {e}") return balances @staticmethod def _mapPaginationColumns(pagination: PaginationParams) -> PaginationParams: """Remap frontend column names to DB column names in filters and sort.""" _COL_MAP: dict = {} _ENRICHED_COLS = {"mandateName", "userName", "mandateId", "userId"} p = copy.deepcopy(pagination) if p.filters: mapped = {} for k, v in p.filters.items(): if k in _ENRICHED_COLS: continue mapped[_COL_MAP.get(k, k)] = v p.filters = mapped if p.sort: mapped = [] for s in p.sort: field = s.get("field", "") if isinstance(s, dict) else getattr(s, "field", "") if field in _ENRICHED_COLS: continue newField = _COL_MAP.get(field, field) if isinstance(s, dict): mapped.append({**s, "field": newField}) else: mapped.append({"field": newField, "direction": getattr(s, "direction", "asc")}) p.sort = mapped if mapped else [{"field": "sysCreatedAt", "direction": "desc"}] return p def getTransactionsForMandatesPaginated( self, mandateIds: Optional[List[str]], pagination: PaginationParams, scope: str = "all", userId: Optional[str] = None, ) -> PaginatedResult: """ SQL-level paginated transactions across multiple mandates. Single SQL query with WHERE accountId = ANY(...), ORDER BY, LIMIT/OFFSET. Enrichment (userName, mandateName) only for the returned page. """ from modules.interfaces.interfaceDbApp import getInterface as getAppInterface try: mappedPagination = self._mapPaginationColumns(pagination) allAccounts = self.db.getRecordset(BillingAccount) if mandateIds: allAccounts = [a for a in allAccounts if a.get("mandateId") in set(mandateIds)] accountIds = [a.get("id") for a in allAccounts if a.get("id")] if not accountIds: return PaginatedResult(items=[], totalItems=0, totalPages=0) # Extract free-text search term and run a custom query that covers # enriched columns (mandateName, userName) and the numeric amount # column. The generic SQL search only covers TEXT columns of the # BillingTransaction table, which excludes these fields. searchTerm: Optional[str] = None if mappedPagination and mappedPagination.filters: raw = mappedPagination.filters.get("search") if isinstance(raw, str) and raw.strip(): searchTerm = raw.strip() if searchTerm: searchResult = self._searchTransactionsPaginated( allAccounts=allAccounts, accountIds=accountIds, userId=userId, searchTerm=searchTerm, pagination=mappedPagination, ) pageItems = searchResult["items"] totalItems = searchResult["totalItems"] totalPages = searchResult["totalPages"] else: recordFilter: Dict[str, Any] = {"accountId": accountIds} if userId: recordFilter["createdByUserId"] = userId result = self.db.getRecordsetPaginated( BillingTransaction, pagination=mappedPagination, recordFilter=recordFilter, ) from modules.dbHelpers.fkLabelResolver import enrichRowsWithFkLabels enrichRowsWithFkLabels( result.get("items", []) if isinstance(result, dict) else result.items, BillingTransaction, db=self.db, ) pageItems = result.get("items", []) if isinstance(result, dict) else result.items totalItems = result.get("totalItems", 0) if isinstance(result, dict) else result.totalItems totalPages = result.get("totalPages", 0) if isinstance(result, dict) else result.totalPages accountMap = {a.get("id"): a for a in allAccounts} pageUserIds = set() pageMandateIds = set() for t in pageItems: accId = t.get("accountId") acc = accountMap.get(accId, {}) mid = acc.get("mandateId") uid = t.get("createdByUserId") or acc.get("userId") if uid: pageUserIds.add(uid) if mid: pageMandateIds.add(mid) appInterface = getAppInterface(self.currentUser) userMap: Dict[str, str] = {} if pageUserIds: users = appInterface.getUsersByIds(list(pageUserIds)) for uid, u in users.items(): dn = getattr(u, "displayName", None) or getattr(u, "username", None) or f"NA({uid})" userMap[uid] = dn mandateMap: Dict[str, str] = {} if pageMandateIds: mandates = appInterface.getMandatesByIds(list(pageMandateIds)) for mid, m in mandates.items(): mandateMap[mid] = getattr(m, "label", None) or getattr(m, "name", None) or f"NA({mid})" enriched = [] for t in pageItems: row = dict(t) accId = row.get("accountId") acc = accountMap.get(accId, {}) mid = acc.get("mandateId") txUserId = row.get("createdByUserId") or acc.get("userId") row["mandateId"] = mid row["mandateName"] = mandateMap.get(mid) or (f"NA({mid})" if mid else None) row["userId"] = txUserId row["userName"] = userMap.get(txUserId) or (f"NA({txUserId})" if txUserId else None) enriched.append(row) return PaginatedResult(items=enriched, totalItems=totalItems, totalPages=totalPages) except Exception as e: logger.error(f"Error in getTransactionsForMandatesPaginated: {e}") return PaginatedResult(items=[], totalItems=0, totalPages=0) def _searchTransactionsPaginated( self, allAccounts: List[Dict[str, Any]], accountIds: List[str], userId: Optional[str], searchTerm: str, pagination: PaginationParams, ) -> Dict[str, Any]: """ Custom paginated search for BillingTransaction that also covers the enriched columns `mandateName` and `userName` as well as the numeric `amount` column. Resolves matching mandate/user IDs via the app DB first, then builds a single SQL query with OR-combined conditions. """ from modules.connectors.connectorDbPostgre import getModelFields, parseRecordFields from modules.datamodels.datamodelUam import UserInDB from modules.interfaces.interfaceDbApp import getInterface as getAppInterface table = BillingTransaction.__name__ fields = getModelFields(BillingTransaction) pattern = f"%{searchTerm}%" # Resolve matching user / mandate IDs via the app DB (which is separate # from the billing DB and hosts UserInDB / Mandate tables). matchingUserIds: List[str] = [] matchingMandateIds: List[str] = [] try: appInterface = getAppInterface(self.currentUser) appInterface.db._ensure_connection() with appInterface.db.borrowCursor() as cur: if appInterface.db._ensureTableExists(UserInDB): cur.execute( 'SELECT "id" FROM "UserInDB" WHERE ' 'COALESCE("username", \'\') ILIKE %s OR ' 'COALESCE("fullName", \'\') ILIKE %s OR ' 'COALESCE("email", \'\') ILIKE %s', (pattern, pattern, pattern), ) matchingUserIds = [r["id"] for r in cur.fetchall() if r.get("id")] if appInterface.db._ensureTableExists(Mandate): cur.execute( 'SELECT "id" FROM "Mandate" WHERE ' 'COALESCE("label", \'\') ILIKE %s OR ' 'COALESCE("name", \'\') ILIKE %s', (pattern, pattern), ) matchingMandateIds = [r["id"] for r in cur.fetchall() if r.get("id")] except Exception as e: logger.warning(f"_searchTransactionsPaginated: user/mandate resolution failed: {e}") matchingAccountIds = [ a.get("id") for a in allAccounts if a.get("id") and a.get("mandateId") in set(matchingMandateIds) ] # Try to interpret the search term as a number for amount matching. amountVal: Optional[float] = None try: amountVal = float(searchTerm.replace(",", ".")) except Exception: amountVal = None whereParts: List[str] = ['"accountId" = ANY(%s)'] whereValues: List[Any] = [accountIds] if userId: whereParts.append('"createdByUserId" = %s') whereValues.append(userId) # Apply non-search filters from pagination (reuse existing builder for # everything except the `search` key which we handle explicitly). paginationWithoutSearch = copy.deepcopy(pagination) if pagination else None if paginationWithoutSearch and paginationWithoutSearch.filters: paginationWithoutSearch.filters = { k: v for k, v in paginationWithoutSearch.filters.items() if k != "search" } orParts: List[str] = [] orValues: List[Any] = [] textCols = [c for c, t in fields.items() if t == "TEXT"] for col in textCols: orParts.append(f'COALESCE("{col}"::TEXT, \'\') ILIKE %s') orValues.append(pattern) if matchingUserIds: orParts.append('"createdByUserId" = ANY(%s)') orValues.append(matchingUserIds) if matchingAccountIds: orParts.append('"accountId" = ANY(%s)') orValues.append(matchingAccountIds) orParts.append('"amount"::TEXT ILIKE %s') orValues.append(pattern) if amountVal is not None: orParts.append('"amount" = %s') orValues.append(amountVal) whereParts.append(f"({' OR '.join(orParts)})") whereValues.extend(orValues) # Apply remaining structured filters via the generic helper by feeding # it a dummy pagination that does NOT include LIMIT/OFFSET. We only # need the WHERE contribution for the non-search filters here. extraWhere = "" extraValues: List[Any] = [] if paginationWithoutSearch and paginationWithoutSearch.filters: try: fromPagination = copy.deepcopy(paginationWithoutSearch) fromPagination.sort = [] fromPagination.page = 1 fromPagination.pageSize = 1 ew, _, _, values, _ = self.db._buildPaginationClauses( BillingTransaction, fromPagination, recordFilter=None ) if ew: extraWhere = ew.replace(" WHERE ", " AND ", 1) extraValues = list(values) except Exception as e: logger.warning(f"_searchTransactionsPaginated: extra-filter build failed: {e}") whereClause = " WHERE " + " AND ".join(whereParts) + extraWhere whereValues.extend(extraValues) # Build ORDER BY from pagination.sort validColumns = set(fields.keys()) orderParts: List[str] = [] if pagination and pagination.sort: for sf in pagination.sort: sfField = sf.get("field") if isinstance(sf, dict) else getattr(sf, "field", None) sfDir = sf.get("direction", "asc") if isinstance(sf, dict) else getattr(sf, "direction", "asc") if sfField and sfField in validColumns: direction = "DESC" if str(sfDir).lower() == "desc" else "ASC" colType = fields.get(sfField, "TEXT") if colType == "BOOLEAN": orderParts.append(f'COALESCE("{sfField}", FALSE) {direction}') else: orderParts.append(f'"{sfField}" {direction} NULLS LAST') if not orderParts: orderParts.append('"id"') orderClause = " ORDER BY " + ", ".join(orderParts) pageSize = pagination.pageSize if pagination else 50 page = pagination.page if pagination else 1 offset = (page - 1) * pageSize limitClause = f" LIMIT {pageSize} OFFSET {offset}" try: self.db._ensure_connection() with self.db.borrowCursor() as cur: countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}' cur.execute(countSql, whereValues) totalItems = cur.fetchone()["count"] dataSql = f'SELECT * FROM "{table}"{whereClause}{orderClause}{limitClause}' cur.execute(dataSql, whereValues) records = [dict(row) for row in cur.fetchall()] for rec in records: parseRecordFields(rec, fields, f"search table {table}") totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0 return {"items": records, "totalItems": totalItems, "totalPages": totalPages} except Exception as e: logger.error(f"_searchTransactionsPaginated SQL error: {e}", exc_info=True) # Rollback is handled by `borrowCursor()` context manager on exit. return {"items": [], "totalItems": 0, "totalPages": 0} def _buildScopeFilter( self, mandateIds: Optional[List[str]], scope: str = "all", userId: Optional[str] = None, startTs: Optional[float] = None, endTs: Optional[float] = None, ) -> tuple: """Build WHERE clause parts for scoped transaction queries. Returns (conditions, values, accountIds).""" allAccounts = self.db.getRecordset(BillingAccount) if mandateIds: mandateSet = set(mandateIds) allAccounts = [a for a in allAccounts if a.get("mandateId") in mandateSet] accountIds = [a.get("id") for a in allAccounts if a.get("id")] if not accountIds: return [], [], [], allAccounts conditions = ['"accountId" = ANY(%s)', '"transactionType" = %s'] values: list = [accountIds, "DEBIT"] if userId: conditions.append('"createdByUserId" = %s') values.append(userId) if startTs is not None: conditions.append('"sysCreatedAt" >= %s') values.append(startTs) if endTs is not None: conditions.append('"sysCreatedAt" < %s') values.append(endTs) return conditions, values, accountIds, allAccounts def getTransactionStatisticsAggregated( self, mandateIds: Optional[List[str]], scope: str = "all", userId: Optional[str] = None, startTs: Optional[float] = None, endTs: Optional[float] = None, bucketSize: str = "month", ) -> Dict[str, Any]: """ Pure SQL aggregation for statistics. No row-level loading. `bucketSize` controls only the time-series aggregation granularity (`'day' | 'month' | 'year'`); the date range is set via `startTs`/`endTs`. Returns: totalCost, transactionCount, costByProvider, costByModel, costByFeature, costByAccountId, timeSeries """ table = BillingTransaction.__name__ try: if not self.db._ensureTableExists(BillingTransaction): return self._emptyStats() conditions, values, accountIds, allAccounts = self._buildScopeFilter( mandateIds, scope, userId, startTs, endTs ) if not accountIds: return self._emptyStats() whereClause = " WHERE " + " AND ".join(conditions) self.db._ensure_connection() result: Dict[str, Any] = {} with self.db.borrowCursor() as cur: # 1) Totals cur.execute( f'SELECT COALESCE(SUM("amount"), 0) AS total, COUNT(*) AS cnt FROM "{table}"{whereClause}', values, ) row = cur.fetchone() result["totalCost"] = round(float(row["total"]), 4) result["transactionCount"] = int(row["cnt"]) # 2) GROUP BY aicoreProvider cur.execute( f'SELECT COALESCE("aicoreProvider", \'unknown\') AS grp, SUM("amount") AS total ' f'FROM "{table}"{whereClause} GROUP BY grp ORDER BY total DESC', values, ) result["costByProvider"] = {r["grp"]: round(float(r["total"]), 4) for r in cur.fetchall()} # 3) GROUP BY aicoreModel cur.execute( f'SELECT COALESCE("aicoreModel", \'unknown\') AS grp, SUM("amount") AS total ' f'FROM "{table}"{whereClause} GROUP BY grp ORDER BY total DESC', values, ) result["costByModel"] = {r["grp"]: round(float(r["total"]), 4) for r in cur.fetchall()} # 4) GROUP BY accountId (will be enriched to mandateName by caller) cur.execute( f'SELECT "accountId" AS grp, SUM("amount") AS total ' f'FROM "{table}"{whereClause} GROUP BY grp ORDER BY total DESC', values, ) result["costByAccountId"] = {r["grp"]: round(float(r["total"]), 4) for r in cur.fetchall()} # 5) GROUP BY accountId + featureCode (for costByFeature) cur.execute( f'SELECT "accountId", COALESCE("featureCode", \'unknown\') AS fc, SUM("amount") AS total ' f'FROM "{table}"{whereClause} GROUP BY "accountId", fc ORDER BY total DESC', values, ) result["costByAccountFeature"] = [ {"accountId": r["accountId"], "featureCode": r["fc"], "total": round(float(r["total"]), 4)} for r in cur.fetchall() ] # 6) Time series via DATE_TRUNC on epoch timestamp _bucketSpec = { "day": ("day", "%Y-%m-%d"), "month": ("month", "%Y-%m"), "year": ("year", "%Y"), }.get(bucketSize) if _bucketSpec is None: raise ValueError( f"Invalid bucketSize: {bucketSize!r} (expected day|month|year)" ) _truncUnit, _labelFormat = _bucketSpec truncExpr = f"DATE_TRUNC('{_truncUnit}', TO_TIMESTAMP(\"sysCreatedAt\"))" cur.execute( f'SELECT {truncExpr} AS bucket, SUM("amount") AS total, COUNT(*) AS cnt ' f'FROM "{table}"{whereClause} AND "sysCreatedAt" IS NOT NULL ' f'GROUP BY bucket ORDER BY bucket', values, ) timeSeries = [] for r in cur.fetchall(): bucket = r["bucket"] label = bucket.strftime(_labelFormat) if bucket else "unknown" timeSeries.append({ "date": label, "cost": round(float(r["total"]), 4), "count": int(r["cnt"]), }) result["timeSeries"] = timeSeries # Commit/rollback are handled by `borrowCursor()` context manager. result["_allAccounts"] = allAccounts return result except Exception as e: logger.error(f"Error in getTransactionStatisticsAggregated: {e}", exc_info=True) return self._emptyStats() @staticmethod def _emptyStats() -> Dict[str, Any]: return { "totalCost": 0.0, "transactionCount": 0, "costByProvider": {}, "costByModel": {}, "costByAccountId": {}, "costByAccountFeature": [], "timeSeries": [], "_allAccounts": [], } def getTransactionDistinctValues( self, mandateIds: Optional[List[str]], column: str, pagination: Optional[PaginationParams] = None, scope: str = "all", userId: Optional[str] = None, ) -> List[str]: """SQL DISTINCT for filter-values on BillingTransaction, scoped by mandates.""" _COLUMN_MAP = { "mandateId": "accountId", "mandateName": "accountId", } dbColumn = _COLUMN_MAP.get(column, column) mappedPagination = self._mapPaginationColumns(pagination) if pagination else None try: allAccounts = self.db.getRecordset(BillingAccount) if mandateIds: allAccounts = [a for a in allAccounts if a.get("mandateId") in set(mandateIds)] accountIds = [a.get("id") for a in allAccounts if a.get("id")] if not accountIds: return [] recordFilter: Dict[str, Any] = {"accountId": accountIds} if userId: recordFilter["createdByUserId"] = userId if column in ("mandateName", "userName"): return self._getEnrichedDistinctValues(column, allAccounts, recordFilter, mappedPagination) return self.db.getDistinctColumnValues( BillingTransaction, dbColumn, mappedPagination, recordFilter ) except Exception as e: logger.error(f"Error in getTransactionDistinctValues({column}): {e}") return [] def _getEnrichedDistinctValues( self, column: str, allAccounts: List[Dict], recordFilter: Dict[str, Any], pagination: Optional[PaginationParams], ) -> List[str]: """Resolve enriched columns (mandateName, userName) via batch lookup.""" from modules.interfaces.interfaceDbApp import getInterface as getAppInterface if column == "mandateName": mandateIds = list({a.get("mandateId") for a in allAccounts if a.get("mandateId")}) appInterface = getAppInterface(self.currentUser) mandates = appInterface.getMandatesByIds(mandateIds) return sorted( {getattr(m, "label", None) or getattr(m, "name", None) or f"NA({mid})" for mid, m in mandates.items()}, key=lambda v: v.lower(), ) if column == "userName": dbCol = "createdByUserId" values = self.db.getDistinctColumnValues(BillingTransaction, dbCol, pagination, recordFilter) if not values: return [] appInterface = getAppInterface(self.currentUser) users = appInterface.getUsersByIds(values) return sorted( {getattr(u, "displayName", None) or getattr(u, "username", None) or f"NA({uid})" for uid, u in users.items()}, key=lambda v: v.lower(), ) return [] def getUserTransactionsForMandates(self, mandateIds: List[str] = None, limit: int = 100) -> List[Dict[str, Any]]: """ Get all transactions for specified mandates. All usage transactions are on user accounts (audit trail). Args: mandateIds: Optional list of mandate IDs to filter. If None, returns all. limit: Maximum number of results Returns: List of transaction dicts with mandate and user context """ from modules.interfaces.interfaceDbApp import getInterface as getAppInterface allTransactions = [] try: appInterface = getAppInterface(self.currentUser) # Get ALL accounts (both USER and MANDATE types) to cover all billing models allAccounts = self.db.getRecordset(BillingAccount) # Filter by mandate if specified if mandateIds: allAccounts = [acc for acc in allAccounts if acc.get("mandateId") in mandateIds] # Build account to user/mandate mapping accountMap = {} for acc in allAccounts: accountMap[acc.get("id")] = { "mandateId": acc.get("mandateId"), "userId": acc.get("userId") } userIds = list(set(acc.get("userId") for acc in allAccounts if acc.get("userId"))) userMap = {} for userId in userIds: user = appInterface.getUser(userId) if user: displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isinstance(user, dict) else None) username = getattr(user, 'username', None) or (user.get("username") if isinstance(user, dict) else None) userMap[userId] = displayName or username or f"NA({userId})" mandateMap = {} mandateIdList = list(set(acc.get("mandateId") for acc in allAccounts if acc.get("mandateId"))) for mandateId in mandateIdList: mandate = appInterface.getMandate(mandateId) if mandate: mandateName = getattr(mandate, 'label', None) or getattr(mandate, 'name', None) or (mandate.get("label") or mandate.get("name", None) if isinstance(mandate, dict) else None) mandateMap[mandateId] = mandateName or f"NA({mandateId})" # Get transactions for all accounts and collect createdByUserIds rawTransactions = [] for account in allAccounts: accountId = account.get("id") if not accountId: continue transactions = self.getTransactions(accountId, limit=limit) accountInfo = accountMap.get(accountId, {}) mandateId = accountInfo.get("mandateId") accountUserId = accountInfo.get("userId") for t in transactions: t["_accountUserId"] = accountUserId t["_accountMandateId"] = mandateId rawTransactions.append(t) # Resolve createdByUserIds that are not yet in userMap extraUserIds = set() for t in rawTransactions: cbUserId = t.get("createdByUserId") if cbUserId and cbUserId not in userMap: extraUserIds.add(cbUserId) for uid in extraUserIds: user = appInterface.getUser(uid) if user: displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isinstance(user, dict) else None) username = getattr(user, 'username', None) or (user.get("username") if isinstance(user, dict) else None) userMap[uid] = displayName or username or f"NA({uid})" for t in rawTransactions: mandateId = t.pop("_accountMandateId", None) accountUserId = t.pop("_accountUserId", None) t["mandateId"] = mandateId t["mandateName"] = mandateMap.get(mandateId) or (f"NA({mandateId})" if mandateId else None) txUserId = t.get("createdByUserId") or accountUserId t["userId"] = txUserId t["userName"] = userMap.get(txUserId) or (f"NA({txUserId})" if txUserId else None) allTransactions.append(t) except Exception as e: logger.error(f"Error getting user transactions for mandates: {e}") # Sort by creation date descending and limit _sortBillingTransactionsBySysCreatedAtDesc(allTransactions, "getUserTransactionsForMandates") return allTransactions[:limit] def deleteMandateData(self, mandateId: str) -> None: """Delete all billing data for a mandate (accounts, transactions, settings). Used as cascade during mandate hard-delete via the onMandateDelete lifecycle hook. """ billingAccounts = self.db.getRecordset(BillingAccount, recordFilter={"mandateId": mandateId}) for acc in billingAccounts: accTxs = self.db.getRecordset(BillingTransaction, recordFilter={"accountId": acc.get("id")}) for tx in accTxs: self.db.recordDelete(BillingTransaction, tx.get("id")) self.db.recordDelete(BillingAccount, acc.get("id")) billingSettings = self.db.getRecordset(BillingSettings, recordFilter={"mandateId": mandateId}) for bs in billingSettings: self.db.recordDelete(BillingSettings, bs.get("id")) if billingAccounts or billingSettings: logger.info("deleteMandateData: deleted billing data for mandate %s", mandateId) def onMandateDelete(mandateId: str, instances: list) -> None: """Lifecycle hook: cascade-delete billing data when a mandate is hard-deleted.""" getRootInterface().deleteMandateData(mandateId) def onUserMandateCreate(userId: str, mandateId: str) -> None: """Lifecycle hook: ensure user has a billing audit account when added to a mandate.""" try: billingInterface = getRootInterface() settings = billingInterface.getSettings(mandateId) if not settings: return billingInterface.getOrCreateUserAccount(mandateId, userId, initialBalance=0.0) logger.info("Ensured billing audit account for user %s in mandate %s", userId, mandateId) except Exception as e: logger.warning("Failed to create billing account for user %s (non-critical): %s", userId, e) def onUserMandateDelete(userId: str, mandateId: str) -> None: """Lifecycle hook: pro-rata AI budget debit when user is removed from a mandate.""" _adjustAiBudgetForUserChange(mandateId, delta=-1) def onUserBudgetAdjust(mandateId: str, delta: int) -> None: """Lifecycle hook: pro-rata AI budget credit/debit for user membership changes.""" _adjustAiBudgetForUserChange(mandateId, delta) def onMandateProvision(mandateId: str, planKey: str) -> None: """Lifecycle hook: create billing settings and activation budget for a new mandate.""" try: billingRoot = getRootInterface() billingRoot.getOrCreateSettings(mandateId) billingRoot.ensureActivationBudget(mandateId, planKey) except Exception as e: logger.error("Initial billing setup failed for mandate %s (plan=%s): %s", mandateId, planKey, e) def onStorageChanged(mandateId: str) -> None: """Lifecycle hook: reconcile storage billing after knowledge content changes.""" try: getRootInterface().reconcileMandateStorageBilling(mandateId) except Exception as e: logger.warning("reconcileMandateStorageBilling failed for mandate %s: %s", mandateId, e) def _adjustAiBudgetForUserChange(mandateId: str, delta: int) -> None: """Pro-rata AI budget credit/debit when a user is added or removed mid-cycle.""" try: from modules.interfaces.interfaceDbSubscription import getInterface as getSubInterface from modules.security.rootAccess import getRootUser rootUser = getRootUser() subIf = getSubInterface(rootUser, mandateId) operative = subIf.getOperativeForMandate(mandateId) if not operative: return planKey = operative.get("planKey", "") billingIf = getInterface(rootUser) billingIf.adjustAiBudgetForUserChange(mandateId, planKey, delta) except Exception as e: logger.debug("AI budget adjustment skipped: %s", e)