# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Interface for Billing operations.
Manages billing accounts, transactions, and usage statistics.
All billing data is stored in the poweron_billing database.
"""

import logging
from typing import Dict, Any, List, Optional, Union
from datetime import date, datetime, timedelta, timezone
import uuid

from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG
from modules.shared.timeUtils import getUtcTimestamp
from modules.datamodels.datamodelUam import User, Mandate
from modules.datamodels.datamodelMembership import UserMandate
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResult
from modules.datamodels.datamodelBilling import (
    BillingAccount,
    BillingTransaction,
    BillingSettings,
    StripeWebhookEvent,
    UsageStatistics,
    TransactionTypeEnum,
    ReferenceTypeEnum,
    PeriodTypeEnum,
    BillingBalanceResponse,
    BillingCheckResult,
    STORAGE_PRICE_PER_GB_CHF,
)

logger = logging.getLogger(__name__)


def _logBillingTransactionsMissingSysCreatedAt(rows: List[Dict[str, Any]], context: str) -> None:
    """
    Log ERROR when sysCreatedAt is missing; does not raise.

    At most 40 offending transaction ids are listed. The remainder is summarised
    as a count so the log line stays bounded.

    Args:
        rows: Transaction rows to inspect.
        context: Free-text label identifying the caller, included in the log line.
    """
    missingIds = [r.get("id") for r in rows if r.get("sysCreatedAt") is None]
    if not missingIds:
        return
    cap = 40
    sample = missingIds[:cap]
    suffix = f"; ... (+{len(missingIds) - cap} more)" if len(missingIds) > cap else ""
    logger.error(
        "BillingTransaction missing sysCreatedAt (%s): count=%s; transactionIds=%s%s",
        context,
        len(missingIds),
        sample,
        suffix,
    )


def _numericSysCreatedAtForSort(row: Dict[str, Any]) -> float:
    """
    Return a numeric sort key (epoch seconds as float) for a row's sysCreatedAt.

    Callers must only pass rows whose sysCreatedAt is not None.

    - ``datetime``: naive values are interpreted as UTC, the convention used elsewhere
      in this module. They are not read as server-local time, so they order
      consistently against numeric timestamps.
    - ``str``: numeric strings are converted with ``float``. Other strings are parsed
      as ISO-8601, and a trailing ``Z`` is accepted. Unparseable strings raise
      ``ValueError``.
    - anything else: converted with ``float``.
    """
    value = row["sysCreatedAt"]
    if isinstance(value, str):
        try:
            return float(value)
        except ValueError:
            # fromisoformat() only accepts a "Z" suffix from Python 3.11 on.
            value = datetime.fromisoformat(value.replace("Z", "+00:00"))
    if isinstance(value, datetime):
        if value.tzinfo is None:
            # datetime.timestamp() would otherwise treat a naive value as local time.
            value = value.replace(tzinfo=timezone.utc)
        return value.timestamp()
    return float(value)


def _sortBillingTransactionsBySysCreatedAtDesc(rows: List[Dict[str, Any]], context: str) -> None:
    """
    Sort ``rows`` in place, newest sysCreatedAt first.

    Rows without sysCreatedAt are logged as an error and placed after all dated
    rows, keeping their relative order.

    Args:
        rows: Transaction rows; the list object itself is reordered.
        context: Caller label passed through to the missing-timestamp log.
    """
    _logBillingTransactionsMissingSysCreatedAt(rows, context)
    valid = [r for r in rows if r.get("sysCreatedAt") is not None]
    invalid = [r for r in rows if r.get("sysCreatedAt") is None]
    valid.sort(key=_numericSysCreatedAtForSort, reverse=True)
    rows[:] = valid + invalid


def _getAppDatabaseConnector() -> DatabaseConnector:
    """App DB connector (same config as UserMandate reads in this module)."""
    return DatabaseConnector(
        dbDatabase=APP_CONFIG.get("DB_DATABASE", "poweron_app"),
        dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
        dbPort=int(APP_CONFIG.get("DB_PORT", "5432")),
        dbUser=APP_CONFIG.get("DB_USER"),
        dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"),
    )


def _getRootMandateIdFromAppDb(appDb: DatabaseConnector) -> Optional[str]:
    """
    Resolve root mandate id (name='root', isSystem=True) from app database.

    Returns None both when no such mandate exists and when the lookup raises.
    In the second case a warning is logged.
    """
    try:
        rows = appDb.getRecordset(Mandate, recordFilter={"name": "root", "isSystem": True})
        if rows:
            rid = rows[0].get("id")
            return str(rid) if rid is not None else None
    except Exception as e:
        logger.warning("Could not resolve root mandate id from app DB: %s", e)
    return None


_cachedRootMandateId: Optional[str] = None
_rootMandateIdCacheResolved: bool = False


def _getCachedRootMandateId() -> Optional[str]:
    """
    Lazy-cached root mandate id (name=root, isSystem=True) for hot paths.

    The first call queries the app DB. Every later call returns the cached value.
    NOTE(review): a None result is cached for the life of the process. That includes
    a None caused by a swallowed lookup error in _getRootMandateIdFromAppDb.
    """
    global _cachedRootMandateId, _rootMandateIdCacheResolved
    if not _rootMandateIdCacheResolved:
        appDb = _getAppDatabaseConnector()
        _cachedRootMandateId = _getRootMandateIdFromAppDb(appDb)
        _rootMandateIdCacheResolved = True
    return _cachedRootMandateId
# Singleton factory for BillingObjects instances
_billingInterfaces: Dict[str, "BillingObjects"] = {}

# Database name for billing
BILLING_DATABASE = "poweron_billing"


def getInterface(currentUser: User, mandateId: Optional[str] = None) -> "BillingObjects":
    """
    Factory function to get or create a BillingObjects instance.

    Instances are cached per ``(currentUser.id, mandateId)``. A cached instance is
    re-bound through ``setUserContext`` before it is returned.

    Args:
        currentUser: Current user object
        mandateId: Mandate ID for context

    Returns:
        BillingObjects instance
    """
    cacheKey = f"{currentUser.id}_{mandateId}"
    if cacheKey not in _billingInterfaces:
        _billingInterfaces[cacheKey] = BillingObjects(currentUser, mandateId)
    else:
        _billingInterfaces[cacheKey].setUserContext(currentUser, mandateId)
    return _billingInterfaces[cacheKey]


def _getRootInterface() -> "BillingObjects":
    """
    Get interface with system access for bootstrap operations.

    Returns a new, uncached BillingObjects bound to the root user with no mandate.
    """
    from modules.security.rootAccess import getRootUser
    rootUser = getRootUser()
    return BillingObjects(rootUser, mandateId=None)


def _sysCreatedAtToDate(value: Any) -> Optional[date]:
    """
    Normalize a transaction's sysCreatedAt value to a calendar date for range filtering.

    Accepts ``datetime``, ``date``, numeric epoch timestamps (including numeric
    strings) and ISO-8601 strings. Returns None when the value is missing or
    cannot be interpreted.
    """
    if value is None:
        return None
    if isinstance(value, datetime):
        return value.date()
    if isinstance(value, date):
        return value
    try:
        seconds = float(value)
    except (TypeError, ValueError):
        if isinstance(value, str):
            try:
                return datetime.fromisoformat(value.replace("Z", "+00:00")).date()
            except ValueError:
                return None
        return None
    # NOTE(review): numeric values are assumed to be UTC epoch seconds. Values above
    # 1e11 (beyond year 5000 in seconds) are treated as milliseconds.
    # Confirm against getUtcTimestamp().
    if seconds > 1e11:
        seconds /= 1000.0
    try:
        return datetime.fromtimestamp(seconds, tz=timezone.utc).date()
    except (OverflowError, OSError, ValueError):
        return None


class BillingObjects:
    """
    Interface for billing operations.
    Manages accounts, transactions, settings, and statistics.
    """

    def __init__(self, currentUser: Optional[User] = None, mandateId: Optional[str] = None):
        """
        Initialize the billing interface.

        Args:
            currentUser: Current user object
            mandateId: Mandate ID for context
        """
        self.currentUser = currentUser
        self.userId = currentUser.id if currentUser else None
        self.mandateId = mandateId

        # Initialize database connection
        self._initializeDatabase()

    def _initializeDatabase(self):
        """Open a connector to BILLING_DATABASE using host, port and credentials from APP_CONFIG."""
        self.db = DatabaseConnector(
            dbDatabase=BILLING_DATABASE,
            dbHost=APP_CONFIG.get('DB_HOST', 'localhost'),
            dbPort=int(APP_CONFIG.get('DB_PORT', '5432')),
            dbUser=APP_CONFIG.get('DB_USER'),
            dbPassword=APP_CONFIG.get('DB_PASSWORD_SECRET')
        )

    def setUserContext(self, currentUser: User, mandateId: Optional[str] = None):
        """
        Update user context.

        Only the user and mandate attributes are replaced. The existing database
        connector is kept.

        Args:
            currentUser: Current user object
            mandateId: Mandate ID for context
        """
        self.currentUser = currentUser
        self.userId = currentUser.id if currentUser else None
        self.mandateId = mandateId

    # =========================================================================
    # BillingSettings Operations
    # =========================================================================

    def getSettings(self, mandateId: str) -> Optional[Dict[str, Any]]:
        """
        Get billing settings for a mandate.

        Args:
            mandateId: Mandate ID

        Returns:
            BillingSettings dict, or None if not found. Errors are logged and also
            return None.
        """
        try:
            results = self.db.getRecordset(
                BillingSettings,
                recordFilter={"mandateId": mandateId}
            )
            if not results:
                return None
            return dict(results[0])
        except Exception as e:
            logger.error(f"Error getting billing settings: {e}")
            return None

    def createSettings(self, settings: BillingSettings) -> Dict[str, Any]:
        """
        Create billing settings for a mandate.

        Args:
            settings: BillingSettings object (None-valued fields are not written)

        Returns:
            Created settings dict
        """
        settingsDict = settings.model_dump(exclude_none=True)
        return self.db.recordCreate(BillingSettings, settingsDict)

    def updateSettings(self, settingsId: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """
        Update billing settings.

        Args:
            settingsId: Settings ID
            updates: Fields to update

        Returns:
            Updated settings dict or None
        """
        return self.db.recordModify(BillingSettings, settingsId, updates)

    def getOrCreateSettings(self, mandateId: str) -> Dict[str, Any]:
        """
        Get or create billing settings for a mandate.

        New settings use warningThresholdPercent=10.0 and notifyOnWarning=True.

        Args:
            mandateId: Mandate ID

        Returns:
            BillingSettings dict
        """
        existing = self.getSettings(mandateId)
        if existing:
            return existing

        settings = BillingSettings(
            mandateId=mandateId,
            warningThresholdPercent=10.0,
            notifyOnWarning=True,
        )
        return self.createSettings(settings)

    # =========================================================================
    # BillingAccount Operations
    # =========================================================================

    def getAccount(self, accountId: str) -> Optional[Dict[str, Any]]:
        """Get a billing account by ID (None if missing or on error)."""
        try:
            results = self.db.getRecordset(
                BillingAccount,
                recordFilter={"id": accountId}
            )
            return results[0] if results else None
        except Exception as e:
            logger.error(f"Error getting billing account: {e}")
            return None

    def getMandateAccount(self, mandateId: str) -> Optional[Dict[str, Any]]:
        """
        Get the mandate-level billing account (the account with userId None).

        Args:
            mandateId: Mandate ID

        Returns:
            BillingAccount dict or None
        """
        try:
            results = self.db.getRecordset(
                BillingAccount,
                recordFilter={
                    "mandateId": mandateId,
                    "userId": None
                }
            )
            return results[0] if results else None
        except Exception as e:
            logger.error(f"Error getting mandate account: {e}")
            return None

    def getUserAccount(self, mandateId: str, userId: str) -> Optional[Dict[str, Any]]:
        """
        Get a user-level billing account within a mandate.

        Args:
            mandateId: Mandate ID
            userId: User ID

        Returns:
            BillingAccount dict or None
        """
        try:
            results = self.db.getRecordset(
                BillingAccount,
                recordFilter={
                    "mandateId": mandateId,
                    "userId": userId
                }
            )
            return results[0] if results else None
        except Exception as e:
            logger.error(f"Error getting user account: {e}")
            return None

    def getAccountsByMandate(self, mandateId: str) -> List[Dict[str, Any]]:
        """
        Get all billing accounts for a mandate (pool and user accounts).

        Args:
            mandateId: Mandate ID

        Returns:
            List of BillingAccount dicts (empty on error)
        """
        try:
            return self.db.getRecordset(
                BillingAccount,
                recordFilter={"mandateId": mandateId}
            )
        except Exception as e:
            logger.error(f"Error getting accounts for mandate: {e}")
            return []

    def createAccount(self, account: BillingAccount) -> Dict[str, Any]:
        """
        Create a new billing account.

        Args:
            account: BillingAccount object

        Returns:
            Created account dict
        """
        accountDict = account.model_dump(exclude_none=True)
        return self.db.recordCreate(BillingAccount, accountDict)

    def updateAccountBalance(self, accountId: str, newBalance: float) -> Optional[Dict[str, Any]]:
        """
        Overwrite the stored balance of an account with ``newBalance``.

        This is a plain record update, not an atomic increment. Callers compute
        ``newBalance`` from a previously read value.

        Args:
            accountId: Account ID
            newBalance: New balance value

        Returns:
            Updated account dict or None
        """
        return self.db.recordModify(BillingAccount, accountId, {"balance": newBalance})

    def getOrCreateMandateAccount(self, mandateId: str, initialBalance: float = 0.0) -> Dict[str, Any]:
        """
        Get or create a mandate-level billing account.

        Args:
            mandateId: Mandate ID
            initialBalance: Initial balance if creating (no transaction is recorded)

        Returns:
            BillingAccount dict
        """
        existing = self.getMandateAccount(mandateId)
        if existing:
            return existing

        account = BillingAccount(
            mandateId=mandateId,
            balance=initialBalance,
            enabled=True
        )
        return self.createAccount(account)

    def getOrCreateUserAccount(self, mandateId: str, userId: str, initialBalance: float = 0.0) -> Dict[str, Any]:
        """
        Get or create a user-level billing account.

        A positive ``initialBalance`` is booked as a SYSTEM CREDIT transaction so the
        starting funds appear in the audit trail. The resulting balance equals
        ``initialBalance``.

        Args:
            mandateId: Mandate ID
            userId: User ID
            initialBalance: Initial balance if creating

        Returns:
            BillingAccount dict
        """
        existing = self.getUserAccount(mandateId, userId)
        if existing:
            return existing

        # createTransaction adds the credit to the stored balance. A positive initial
        # balance must therefore start at 0.0, or it would be counted twice.
        startingBalance = 0.0 if initialBalance > 0 else initialBalance
        account = BillingAccount(
            mandateId=mandateId,
            userId=userId,
            balance=startingBalance,
            enabled=True
        )
        created = self.createAccount(account)

        # If initial balance > 0, create a SYSTEM credit transaction
        if initialBalance > 0:
            self.createTransaction(BillingTransaction(
                accountId=created["id"],
                transactionType=TransactionTypeEnum.CREDIT,
                amount=initialBalance,
                description="Initial credit for new user",
                referenceType=ReferenceTypeEnum.SYSTEM
            ))
            # Re-read so the returned dict reflects the credited balance.
            return self.getAccount(created["id"]) or created

        return created

    def ensureAllMandateSettingsExist(self) -> int:
        """
        Efficiently ensure all enabled mandates have billing settings.

        Creates default settings (warningThresholdPercent=10.0, notifyOnWarning=True)
        for mandates without settings. Uses bulk queries to minimize database round trips.

        Returns:
            Number of settings created (0 on error, which is logged)
        """
        try:
            settingsCreated = 0

            # Step 1: Get all existing billing settings in one query (from billing DB)
            allSettings = self.db.getRecordset(BillingSettings)
            existingMandateIds = {s.get("mandateId") for s in allSettings if s.get("mandateId")}

            # Step 2: Get all enabled mandates from the APP database (separate connection)
            appDb = _getAppDatabaseConnector()
            allMandates = appDb.getRecordset(Mandate, recordFilter={"enabled": True})

            # Step 3: Create settings for mandates that don't have them
            for mandate in allMandates:
                mandateId = mandate.get("id")
                if not mandateId or mandateId in existingMandateIds:
                    continue

                settings = BillingSettings(
                    mandateId=mandateId,
                    warningThresholdPercent=10.0,
                    notifyOnWarning=True,
                )
                self.createSettings(settings)
                existingMandateIds.add(mandateId)
                settingsCreated += 1

            if settingsCreated > 0:
                logger.info(f"Created {settingsCreated} missing billing settings for mandates")

            return settingsCreated

        except Exception as e:
            logger.error(f"Error ensuring mandate settings exist: {e}")
            return 0

    def ensureAllUserAccountsExist(self) -> int:
        """
        Ensure all users across all billable mandates have billing accounts.

        A mandate is billable when it has a BillingSettings row. Only enabled
        UserMandate memberships are considered. User accounts are always created
        for audit trail with initial balance 0.0. Uses bulk queries to minimize
        database round trips.

        Returns:
            Number of accounts created (0 on error, which is logged)
        """
        try:
            accountsCreated = 0
            appDb = _getAppDatabaseConnector()

            allSettings = self.db.getRecordset(BillingSettings)
            billingMandateIds = {s.get("mandateId") for s in allSettings if s.get("mandateId")}

            if not billingMandateIds:
                logger.debug("No billable mandates found, skipping account check")
                return 0

            # Existing (mandateId, userId) pairs; mandate pool accounts have no userId.
            allAccounts = self.db.getRecordset(BillingAccount)
            existingAccountKeys = {
                (acc.get("mandateId"), acc.get("userId"))
                for acc in allAccounts
                if acc.get("userId")
            }

            allUserMandates = appDb.getRecordset(
                UserMandate,
                recordFilter={"enabled": True}
            )

            for um in allUserMandates:
                mandateId = um.get("mandateId")
                userId = um.get("userId")
                if not mandateId or not userId:
                    continue
                if mandateId not in billingMandateIds:
                    continue

                key = (mandateId, userId)
                if key in existingAccountKeys:
                    continue

                account = BillingAccount(
                    mandateId=mandateId,
                    userId=userId,
                    balance=0.0,
                    enabled=True
                )
                self.createAccount(account)
                existingAccountKeys.add(key)
                accountsCreated += 1

            if accountsCreated > 0:
                logger.info(f"Created {accountsCreated} missing billing accounts")

            return accountsCreated

        except Exception as e:
            logger.error(f"Error ensuring user accounts exist: {e}")
            return 0

    # =========================================================================
    # BillingTransaction Operations
    # =========================================================================

    def createTransaction(self, transaction: BillingTransaction, balanceAccountId: Optional[str] = None) -> Dict[str, Any]:
        """
        Create a new billing transaction and update account balance.

        The transaction is always recorded against transaction.accountId (audit trail).
        The balance is updated on balanceAccountId if provided, otherwise on
        transaction.accountId. This allows recording a transaction on a user account
        (audit) while deducting from a mandate pool account (shared budget).

        CREDIT and ADJUSTMENT add ``amount`` to the balance; DEBIT subtracts it.

        NOTE: the balance update is read-modify-write (read via getAccount, then
        overwritten via updateAccountBalance). Concurrent transactions against the
        same balance account can therefore lose updates.

        Args:
            transaction: BillingTransaction object
            balanceAccountId: Optional account ID for balance update (defaults to transaction.accountId)

        Returns:
            Created transaction dict

        Raises:
            ValueError: If the transaction account or the balance account does not exist.
        """
        # Validate that the transaction's account exists
        txAccount = self.getAccount(transaction.accountId)
        if not txAccount:
            raise ValueError(f"Transaction account {transaction.accountId} not found")

        # Determine which account to update balance on
        targetBalanceAccountId = balanceAccountId or transaction.accountId
        if targetBalanceAccountId == transaction.accountId:
            balanceAccount = txAccount
        else:
            balanceAccount = self.getAccount(targetBalanceAccountId)
            if not balanceAccount:
                raise ValueError(f"Balance account {targetBalanceAccountId} not found")

        currentBalance = balanceAccount.get("balance", 0.0)

        # Calculate new balance
        if transaction.transactionType == TransactionTypeEnum.CREDIT:
            newBalance = currentBalance + transaction.amount
        elif transaction.transactionType == TransactionTypeEnum.DEBIT:
            newBalance = currentBalance - transaction.amount
        else:  # ADJUSTMENT
            newBalance = currentBalance + transaction.amount

        # Create transaction record (always on transaction.accountId for audit)
        transactionDict = transaction.model_dump(exclude_none=True)
        created = self.db.recordCreate(BillingTransaction, transactionDict)

        # Update balance on the target account
        self.updateAccountBalance(targetBalanceAccountId, newBalance)

        logger.info(f"Billing transaction created: {transaction.transactionType.value} {transaction.amount} CHF, "
                    f"audit={transaction.accountId}, balance on {targetBalanceAccountId}: {currentBalance} -> {newBalance}")

        return created

    def getTransactions(
        self,
        accountId: str,
        limit: int = 100,
        offset: int = 0,
        startDate: Optional[date] = None,
        endDate: Optional[date] = None,
        pagination: Optional[PaginationParams] = None
    ) -> Union[List[Dict[str, Any]], PaginatedResult]:
        """
        Get transactions for an account.

        When pagination is provided, uses database-level pagination and returns
        PaginatedResult. Otherwise filters, sorts (newest first) and slices in memory.

        Date filtering is inclusive on both ends. It works whether sysCreatedAt is
        stored as a datetime, a date, an epoch number or an ISO string. Rows whose
        creation date cannot be determined are excluded when a date bound is given.

        Args:
            accountId: Account ID
            limit: Maximum number of results (legacy path)
            offset: Offset for pagination (legacy path)
            startDate: Filter by start date (legacy path)
            endDate: Filter by end date (legacy path)
            pagination: PaginationParams for DB-level pagination

        Returns:
            PaginatedResult when pagination is provided, List of dicts otherwise.
            Errors are logged and produce an empty result of the same shape.
        """
        try:
            if pagination:
                recordFilter = {"accountId": accountId}
                result = self.db.getRecordsetPaginated(
                    BillingTransaction,
                    pagination=pagination,
                    recordFilter=recordFilter
                )
                _logBillingTransactionsMissingSysCreatedAt(
                    result["items"],
                    "getTransactions(accountId) paginated",
                )
                return PaginatedResult(
                    items=result["items"],
                    totalItems=result["totalItems"],
                    totalPages=result["totalPages"]
                )

            filterDict = {"accountId": accountId}
            results = self.db.getRecordset(BillingTransaction, recordFilter=filterDict)

            if startDate or endDate:
                # A datetime bound cannot be compared with a date; compare calendar days.
                if isinstance(startDate, datetime):
                    startDate = startDate.date()
                if isinstance(endDate, datetime):
                    endDate = endDate.date()
                filtered = []
                for t in results:
                    tDate = _sysCreatedAtToDate(t.get("sysCreatedAt"))
                    if tDate is None:
                        continue
                    if startDate and tDate < startDate:
                        continue
                    if endDate and tDate > endDate:
                        continue
                    filtered.append(t)
                results = filtered

            _sortBillingTransactionsBySysCreatedAtDesc(results, "getTransactions(accountId)")
            return results[offset:offset + limit]
        except Exception as e:
            logger.error(f"Error getting transactions: {e}")
            if pagination:
                return PaginatedResult(items=[], totalItems=0, totalPages=0)
            return []

    def getTransactionsByMandate(
        self,
        mandateId: str,
        limit: int = 100,
        pagination: Optional[PaginationParams] = None
    ) -> Union[List[Dict[str, Any]], PaginatedResult]:
        """
        Get all transactions for a mandate (across all accounts).

        When pagination is provided, collects all accountIds for the mandate and issues
        a single DB query with SQL-level filtering, sorting, and pagination.
        Otherwise takes the newest ``limit`` transactions of each account, merges
        them, re-sorts newest first and keeps the first ``limit``.

        Args:
            mandateId: Mandate ID
            limit: Maximum number of results (legacy path)
            pagination: PaginationParams for DB-level pagination

        Returns:
            PaginatedResult when pagination is provided, List of dicts otherwise
        """
        accounts = self.db.getRecordset(BillingAccount, recordFilter={"mandateId": mandateId})
        accountIds = [acc["id"] for acc in accounts if acc.get("id")]

        if not accountIds:
            if pagination:
                return PaginatedResult(items=[], totalItems=0, totalPages=0)
            return []

        if pagination:
            result = self.db.getRecordsetPaginated(
                BillingTransaction,
                pagination=pagination,
                recordFilter={"accountId": accountIds}
            )
            _logBillingTransactionsMissingSysCreatedAt(
                result["items"],
                "getTransactionsByMandate paginated",
            )
            return PaginatedResult(
                items=result["items"],
                totalItems=result["totalItems"],
                totalPages=result["totalPages"]
            )

        allTransactions = []
        for account in accounts:
            transactions = self.getTransactions(account["id"], limit=limit)
            allTransactions.extend(transactions)

        _sortBillingTransactionsBySysCreatedAtDesc(
            allTransactions,
            "getTransactionsByMandate",
        )
        return allTransactions[:limit]

    # =========================================================================
    # StripeWebhookEvent Operations (idempotency)
    # =========================================================================

    def getStripeWebhookEventByEventId(self, event_id: str) -> Optional[Dict[str, Any]]:
        """
        Check if a Stripe event has already been processed (idempotency).

        Args:
            event_id: Stripe event ID (evt_xxx)

        Returns:
            Event record if exists, else None (also None on error)
        """
        try:
            results = self.db.getRecordset(
                StripeWebhookEvent,
                recordFilter={"event_id": event_id}
            )
            return results[0] if results else None
        except Exception as e:
            logger.error(f"Error checking Stripe webhook event: {e}")
            return None

    def createStripeWebhookEvent(self, event_id: str) -> Dict[str, Any]:
        """
        Record that a Stripe event has been processed.

        Args:
            event_id: Stripe event ID (evt_xxx)

        Returns:
            Created event record
        """
        record = StripeWebhookEvent(event_id=event_id)
        return self.db.recordCreate(StripeWebhookEvent, record.model_dump())

    def getPaymentTransactionByReferenceId(self, referenceId: str) -> Optional[Dict[str, Any]]:
        """
        Find an existing Stripe payment credit transaction by Checkout Session ID.

        Args:
            referenceId: Stripe Checkout Session ID (cs_xxx)

        Returns:
            Transaction record if found, else None (also None on error)
        """
        try:
            results = self.db.getRecordset(
                BillingTransaction,
                recordFilter={
                    "referenceType": ReferenceTypeEnum.PAYMENT.value,
                    "referenceId": referenceId,
                }
            )
            return results[0] if results else None
        except Exception as e:
            logger.error(f"Error checking Stripe payment transaction by referenceId: {e}")
            return None

    # =========================================================================
    # Balance Check Operations
    # =========================================================================

    def checkBalance(self, mandateId: str, userId: str, estimatedCost: float) -> BillingCheckResult:
        """
        Check if there's sufficient balance for an operation.

        Checks mandate pool balance against estimatedCost.
        User accounts are ensured to exist for audit trail.
        Missing settings: treated as PREPAY_MANDATE with empty pool. Settings are not
        read here; a missing pool account is created with balance 0.0.
        """
        self.getOrCreateUserAccount(mandateId, userId, initialBalance=0.0)
        poolAccount = self.getOrCreateMandateAccount(mandateId)
        currentBalance = poolAccount.get("balance", 0.0)
        if currentBalance < estimatedCost:
            return BillingCheckResult(
                allowed=False,
                reason="INSUFFICIENT_BALANCE",
                currentBalance=currentBalance,
                requiredAmount=estimatedCost,
            )
        return BillingCheckResult(allowed=True, currentBalance=currentBalance)

    def recordUsage(
        self,
        mandateId: str,
        userId: str,
        priceCHF: float,
        workflowId: Optional[str] = None,
        featureInstanceId: Optional[str] = None,
        featureCode: Optional[str] = None,
        aicoreProvider: Optional[str] = None,
        aicoreModel: Optional[str] = None,
        description: str = "AI Usage",
        processingTime: Optional[float] = None,
        bytesSent: Optional[int] = None,
        bytesReceived: Optional[int] = None,
        errorCount: Optional[int] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Record usage cost as a billing transaction.

        Transaction is recorded on the user's account (audit trail).
        Balance is always deducted from the mandate pool account (PREPAY_MANDATE).

        Returns:
            The created DEBIT transaction. Returns None when priceCHF <= 0 or the
            mandate has no billing settings.
        """
        if priceCHF <= 0:
            return None

        settings = self.getSettings(mandateId)
        if not settings:
            logger.debug(f"No billing settings for mandate {mandateId}, skipping usage recording")
            return None

        userAccount = self.getOrCreateUserAccount(mandateId, userId)

        transaction = BillingTransaction(
            accountId=userAccount["id"],
            transactionType=TransactionTypeEnum.DEBIT,
            amount=priceCHF,
            description=description,
            referenceType=ReferenceTypeEnum.WORKFLOW,
            workflowId=workflowId,
            featureInstanceId=featureInstanceId,
            featureCode=featureCode,
            aicoreProvider=aicoreProvider,
            aicoreModel=aicoreModel,
            createdByUserId=userId,
            processingTime=processingTime,
            bytesSent=bytesSent,
            bytesReceived=bytesReceived,
            errorCount=errorCount
        )

        poolAccount = self.getOrCreateMandateAccount(mandateId)
        return self.createTransaction(transaction, balanceAccountId=poolAccount["id"])

    def _parseSettingsDateTime(self, value: Any) -> Optional[datetime]:
        """
        Parse a datetime (ISO string or datetime) and return it as aware UTC.

        Naive values are taken to be UTC and a trailing "Z" is accepted. Returns None
        for None, unparseable strings and other types.
        """
        if value is None:
            return None
        if isinstance(value, datetime):
            if value.tzinfo:
                return value.astimezone(timezone.utc)
            return value.replace(tzinfo=timezone.utc)
        if isinstance(value, str):
            s = value.replace("Z", "+00:00")
            try:
                dt = datetime.fromisoformat(s)
            except ValueError:
                return None
            if dt.tzinfo:
                return dt.astimezone(timezone.utc)
            return dt.replace(tzinfo=timezone.utc)
        return None

    def resetStorageBillingPeriod(self, mandateId: str, periodStartAt: datetime) -> None:
        """
        Reset storage watermark state for a new subscription billing period (e.g. Stripe invoice.paid).

        The high watermark becomes the mandate's current data volume, and the billed
        overage is reset to 0. A call whose period start is within 2 seconds of the
        stored one is ignored, so repeated calls for the same period are no-ops.
        """
        if periodStartAt.tzinfo is None:
            periodStartAt = periodStartAt.replace(tzinfo=timezone.utc)
        else:
            periodStartAt = periodStartAt.astimezone(timezone.utc)
        settings = self.getOrCreateSettings(mandateId)
        prev = self._parseSettingsDateTime(settings.get("storagePeriodStartAt"))
        if prev is not None and abs((prev - periodStartAt).total_seconds()) < 2:
            return
        from modules.interfaces.interfaceDbSubscription import _getRootInterface as _getSubRoot

        usedMB = float(_getSubRoot().getMandateDataVolumeMB(mandateId))
        self.updateSettings(
            settings["id"],
            {
                "storageHighWatermarkMB": usedMB,
                "storageBilledUpToMB": 0.0,
                "storagePeriodStartAt": periodStartAt,
            },
        )
        logger.info(
            "Storage billing period reset for mandate %s at %s (usedMB=%.2f)",
            mandateId,
            periodStartAt.isoformat(),
            usedMB,
        )

    def reconcileMandateStorageBilling(self, mandateId: str) -> Optional[Dict[str, Any]]:
        """
        Debit prepay pool for new storage overage using period high-watermark (no credit on delete).

        Overage is measured as (max watermark so far in the period - plan-included MB).
        Only the part not yet billed (storageBilledUpToMB) is charged, at
        STORAGE_PRICE_PER_GB_CHF per 1024 MB.

        Returns:
            The created DEBIT transaction. Returns None when there is no settings row,
            the plan has no data-volume limit, or nothing new is billable.
        """
        settings = self.getSettings(mandateId)
        if not settings:
            return None
        from modules.interfaces.interfaceDbSubscription import _getRootInterface as _getSubRoot
        from modules.datamodels.datamodelSubscription import _getPlan

        subIface = _getSubRoot()
        usedMB = float(subIface.getMandateDataVolumeMB(mandateId))
        sub = subIface.getOperativeForMandate(mandateId)
        plan = _getPlan(sub.get("planKey", "")) if sub else None
        includedMB = plan.maxDataVolumeMB if plan and plan.maxDataVolumeMB is not None else None
        if includedMB is None:
            return None

        prevHigh = float(settings.get("storageHighWatermarkMB") or 0.0)
        high = max(prevHigh, usedMB)
        overageMB = max(0.0, high - float(includedMB))
        billed = float(settings.get("storageBilledUpToMB") or 0.0)
        deltaOverage = overageMB - billed

        settingsUpdates: Dict[str, Any] = {}
        if high != prevHigh:
            settingsUpdates["storageHighWatermarkMB"] = high

        if deltaOverage <= 1e-9:
            if settingsUpdates:
                self.updateSettings(settings["id"], settingsUpdates)
            return None

        costCHF = round((deltaOverage / 1024.0) * float(STORAGE_PRICE_PER_GB_CHF), 4)
        if costCHF <= 0:
            # Too small to charge after rounding; leave it unbilled so it accumulates.
            if settingsUpdates:
                self.updateSettings(settings["id"], settingsUpdates)
            return None

        poolAccount = self.getOrCreateMandateAccount(mandateId)
        transaction = BillingTransaction(
            accountId=poolAccount["id"],
            transactionType=TransactionTypeEnum.DEBIT,
            amount=costCHF,
            description=f"Speicher-Überhang ({deltaOverage:.2f} MB über Plan)",
            referenceType=ReferenceTypeEnum.STORAGE,
            referenceId=mandateId,
        )
        created = self.createTransaction(transaction)
        settingsUpdates["storageBilledUpToMB"] = overageMB
        self.updateSettings(settings["id"], settingsUpdates)
        logger.info(
            "Storage overage billed mandate=%s deltaOverageMB=%.4f costCHF=%s",
            mandateId,
            deltaOverage,
            costCHF,
        )
        return created

    # =========================================================================
    # Subscription AI-Budget Credit
    # =========================================================================

    def creditSubscriptionBudget(self, mandateId: str, planKey: str, periodLabel: str = "") -> Optional[Dict[str, Any]]:
        """
        Credit AI budget to the mandate pool account.

        Amount = budgetAiPerUserCHF * activeUsers (dynamic, not the static
        plan.budgetAiCHF). At least one user is always counted.
        Should be called once per billing period (initial activation + each invoice.paid).

        Returns:
            The created CREDIT transaction, or None if the plan has no per-user AI budget.
        """
        from modules.datamodels.datamodelSubscription import _getPlan
        plan = _getPlan(planKey)
        if not plan or not plan.budgetAiPerUserCHF or plan.budgetAiPerUserCHF <= 0:
            return None

        from modules.interfaces.interfaceDbSubscription import _getRootInterface as _getSubRoot
        subRoot = _getSubRoot()
        activeUsers = max(subRoot.countActiveUsers(mandateId), 1)
        amount = plan.budgetAiPerUserCHF * activeUsers

        poolAccount = self.getOrCreateMandateAccount(mandateId)
        description = f"AI-Budget ({planKey}, {activeUsers} User)"
        if periodLabel:
            description += f" – {periodLabel}"

        transaction = BillingTransaction(
            accountId=poolAccount["id"],
            transactionType=TransactionTypeEnum.CREDIT,
            amount=amount,
            description=description,
            referenceType=ReferenceTypeEnum.SUBSCRIPTION,
            referenceId=mandateId,
        )
        created = self.createTransaction(transaction)
        logger.info(
            "AI-Budget credited mandate=%s plan=%s users=%d amount=%.2f CHF",
            mandateId, planKey, activeUsers, amount,
        )
        return created

    def ensureActivationBudget(self, mandateId: str, planKey: str) -> Optional[Dict[str, Any]]:
        """Idempotent: credit the activation budget only if no SUBSCRIPTION credit exists yet on the pool account."""
        poolAccount = self.getMandateAccount(mandateId)
        if not poolAccount:
            return self.creditSubscriptionBudget(mandateId, planKey, periodLabel="Erstaktivierung")

        existing = self.db.getRecordset(
            BillingTransaction,
            recordFilter={
                "accountId": poolAccount["id"],
                "transactionType": TransactionTypeEnum.CREDIT.value,
                "referenceType": ReferenceTypeEnum.SUBSCRIPTION.value,
            },
        )
        if existing:
            return None
        return self.creditSubscriptionBudget(mandateId, planKey, periodLabel="Erstaktivierung")

    def adjustAiBudgetForUserChange(self, mandateId: str, planKey: str, delta: int) -> Optional[Dict[str, Any]]:
        """
        Pro-rata AI budget adjustment when users are added/removed mid-cycle.

        delta > 0: user added -> CREDIT pro-rata portion
        delta < 0: user removed -> DEBIT pro-rata portion

        The portion is the share of the current subscription period still remaining
        (now until currentPeriodEnd), rounded to 2 decimals.

        Returns:
            The created transaction. Returns None when there is no per-user budget,
            no operative subscription, the period bounds are missing or unparseable,
            or the amount rounds to 0.
        """
        from modules.datamodels.datamodelSubscription import _getPlan
        plan = _getPlan(planKey)
        if not plan or not plan.budgetAiPerUserCHF or plan.budgetAiPerUserCHF <= 0:
            return None

        from modules.interfaces.interfaceDbSubscription import _getRootInterface as _getSubRoot
        subRoot = _getSubRoot()
        operative = subRoot.getOperativeForMandate(mandateId)
        if not operative:
            return None

        # Shared parser: accepts datetimes or ISO strings (including a "Z" suffix) and
        # returns aware UTC datetimes, treating naive values as UTC.
        periodStart = self._parseSettingsDateTime(operative.get("currentPeriodStart"))
        periodEnd = self._parseSettingsDateTime(operative.get("currentPeriodEnd"))
        if not periodStart or not periodEnd:
            return None

        now = datetime.now(timezone.utc)
        totalSeconds = (periodEnd - periodStart).total_seconds()
        remainingSeconds = max((periodEnd - now).total_seconds(), 0)
        proRataFraction = remainingSeconds / totalSeconds if totalSeconds > 0 else 0

        amount = round(abs(delta) * plan.budgetAiPerUserCHF * proRataFraction, 2)
        if amount <= 0:
            return None

        poolAccount = self.getOrCreateMandateAccount(mandateId)
        if delta > 0:
            txType = TransactionTypeEnum.CREDIT
            label = f"AI-Budget pro-rata +{abs(delta)} User ({planKey})"
        else:
            txType = TransactionTypeEnum.DEBIT
            label = f"AI-Budget pro-rata -{abs(delta)} User ({planKey})"

        transaction = BillingTransaction(
            accountId=poolAccount["id"],
            transactionType=txType,
            amount=amount,
            description=label,
            referenceType=ReferenceTypeEnum.SUBSCRIPTION,
            referenceId=mandateId,
        )
        created = self.createTransaction(transaction)
        logger.info(
            "AI-Budget pro-rata %s mandate=%s delta=%+d amount=%.2f CHF (fraction=%.4f)",
            txType.value, mandateId, delta, amount, proRataFraction,
        )
        return created
========================================================================= # Workflow Cost Query # ========================================================================= def getWorkflowCost(self, workflowId: str) -> float: """Sum of all transaction amounts for a workflow.""" if not workflowId: return 0.0 transactions = self.db.getRecordset( BillingTransaction, recordFilter={"workflowId": workflowId} ) return sum(t.get("amount", 0.0) for t in transactions) # ========================================================================= # Statistics Operations # ========================================================================= def getUsageStatistics( self, accountId: str, periodType: PeriodTypeEnum, year: int, month: int = None ) -> List[Dict[str, Any]]: """ Get usage statistics for an account. Args: accountId: Account ID periodType: Period type (DAY, MONTH, YEAR) year: Year month: Month (for DAY period type) Returns: List of statistics dicts """ filterDict = { "accountId": accountId, "periodType": periodType.value } results = self.db.getRecordset(UsageStatistics, recordFilter=filterDict) # Filter by year filtered = [s for s in results if s.get("periodStart") and s["periodStart"].year == year] # Filter by month if specified if month and periodType == PeriodTypeEnum.DAY: filtered = [s for s in filtered if s["periodStart"].month == month] return sorted(filtered, key=lambda x: x.get("periodStart", date.min)) def calculateStatisticsFromTransactions( self, accountId: str, startDate: date, endDate: date ) -> Dict[str, Any]: """ Calculate statistics from transactions for a period. 
Args: accountId: Account ID startDate: Start date endDate: End date Returns: Statistics dict """ transactions = self.getTransactions(accountId, limit=10000, startDate=startDate, endDate=endDate) # Filter only DEBIT transactions (usage) debits = [t for t in transactions if t.get("transactionType") == TransactionTypeEnum.DEBIT.value] totalCost = sum(t.get("amount", 0) for t in debits) # Calculate by provider costByProvider = {} costByModel = {} for t in debits: provider = t.get("aicoreProvider", "unknown") costByProvider[provider] = costByProvider.get(provider, 0) + t.get("amount", 0) model = t.get("aicoreModel", "unknown") costByModel[model] = costByModel.get(model, 0) + t.get("amount", 0) # Calculate by feature costByFeature = {} for t in debits: feature = t.get("featureCode", "unknown") costByFeature[feature] = costByFeature.get(feature, 0) + t.get("amount", 0) return { "totalCostCHF": totalCost, "transactionCount": len(debits), "costByProvider": costByProvider, "costByModel": costByModel, "costByFeature": costByFeature } # ========================================================================= # Utility Methods # ========================================================================= def getBalancesForUser(self, userId: str) -> List[BillingBalanceResponse]: """ Get all billing balances for a user across mandates. Shows the mandate pool balance (shared budget visible to user). Args: userId: User ID Returns: List of BillingBalanceResponse """ from modules.interfaces.interfaceDbApp import getRootInterface balances = [] try: # Use rootInterface (privileged, SysAdmin context) to bypass RBAC # for mandate/user lookups. User access is verified via UserMandate membership. 
rootInterface = getRootInterface() userMandates = rootInterface.getUserMandates(userId) for um in userMandates: mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None) if not mandateId: continue mandate = rootInterface.getMandate(mandateId) if not mandate or not getattr(mandate, "enabled", True): continue mandateName = getattr(mandate, 'label', None) or getattr(mandate, 'name', None) or (mandate.get("label") or mandate.get("name", "") if isinstance(mandate, dict) else "") settings = self.getSettings(mandateId) if not settings: continue poolAccount = self.getOrCreateMandateAccount(mandateId) if not poolAccount: continue balance = poolAccount.get("balance", 0.0) warningThreshold = poolAccount.get("warningThreshold", 0.0) balances.append(BillingBalanceResponse( mandateId=mandateId, mandateName=mandateName, balance=balance, warningThreshold=warningThreshold, isWarning=balance <= warningThreshold, )) except Exception as e: logger.error(f"Error getting balances for user: {e}") return balances def getTransactionsForUser(self, userId: str, limit: int = 100) -> List[Dict[str, Any]]: """ Get all transactions for a user across all mandates they belong to. Since transactions are always recorded on user accounts, we query directly by user account - clean and simple. 
Args: userId: User ID limit: Maximum number of results Returns: List of transaction dicts """ from modules.interfaces.interfaceDbApp import getInterface as getAppInterface allTransactions = [] try: appInterface = getAppInterface(self.currentUser) userMandates = appInterface.getUserMandates(userId) for um in userMandates: mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None) if not mandateId: continue settings = self.getSettings(mandateId) if not settings: continue # Get user's account in this mandate userAccount = self.getUserAccount(mandateId, userId) if not userAccount: continue # Get transactions for user's account (all transactions are on user accounts now) transactions = self.getTransactions(userAccount["id"], limit=limit) mandate = appInterface.getMandate(mandateId) mandateName = "" if mandate: mandateName = getattr(mandate, 'label', None) or getattr(mandate, 'name', None) or (mandate.get("label") or mandate.get("name", "") if isinstance(mandate, dict) else "") for t in transactions: t["mandateId"] = mandateId t["mandateName"] = mandateName allTransactions.append(t) except Exception as e: logger.error(f"Error getting transactions for user: {e}") _sortBillingTransactionsBySysCreatedAtDesc(allTransactions, "getTransactionsForUser") return allTransactions[:limit] # ========================================================================= # Mandate View Operations (Admin-Level) # ========================================================================= def getMandateBalances(self, mandateIds: List[str] = None) -> List[Dict[str, Any]]: """ Get mandate-level balances. Args: mandateIds: Optional list of mandate IDs to filter. If None, returns all. 
Returns: List of mandate balance dicts """ from modules.interfaces.interfaceDbApp import getInterface as getAppInterface balances = [] try: appInterface = getAppInterface(self.currentUser) # Get settings for filtering if mandateIds: allSettings = [self.getSettings(mId) for mId in mandateIds] allSettings = [s for s in allSettings if s] else: allSettings = self.db.getRecordset(BillingSettings) for settings in allSettings: mandateId = settings.get("mandateId") if not mandateId: continue mandate = appInterface.getMandate(mandateId) mandateName = "" if mandate: mandateName = getattr(mandate, 'label', None) or getattr(mandate, 'name', None) or (mandate.get("label") or mandate.get("name", "") if isinstance(mandate, dict) else "") allMandateAccounts = self.db.getRecordset( BillingAccount, recordFilter={"mandateId": mandateId} ) userCount = sum(1 for acc in allMandateAccounts if acc.get("userId")) poolAccount = self.getMandateAccount(mandateId) totalBalance = poolAccount.get("balance", 0.0) if poolAccount else 0.0 balances.append({ "mandateId": mandateId, "mandateName": mandateName, "totalBalance": totalBalance, "userCount": userCount, "warningThresholdPercent": settings.get("warningThresholdPercent", 10.0), }) except Exception as e: logger.error(f"Error getting mandate balances: {e}") return balances def getMandateTransactions(self, mandateIds: List[str] = None, limit: int = 100) -> List[Dict[str, Any]]: """ Get all transactions for specified mandates. Args: mandateIds: Optional list of mandate IDs to filter. If None, returns all. 
limit: Maximum number of results Returns: List of transaction dicts with mandate context """ from modules.interfaces.interfaceDbApp import getInterface as getAppInterface allTransactions = [] try: appInterface = getAppInterface(self.currentUser) # Determine which mandates to query if mandateIds: targetMandateIds = mandateIds else: allSettings = self.db.getRecordset(BillingSettings) targetMandateIds = [s.get("mandateId") for s in allSettings if s.get("mandateId")] for mandateId in targetMandateIds: transactions = self.getTransactionsByMandate(mandateId, limit=limit) # Get mandate name mandate = appInterface.getMandate(mandateId) mandateName = "" if mandate: mandateName = getattr(mandate, 'label', None) or getattr(mandate, 'name', None) or (mandate.get("label") or mandate.get("name", "") if isinstance(mandate, dict) else "") for t in transactions: t["mandateId"] = mandateId t["mandateName"] = mandateName allTransactions.append(t) except Exception as e: logger.error(f"Error getting mandate transactions: {e}") # Sort by creation date descending and limit _sortBillingTransactionsBySysCreatedAtDesc(allTransactions, "getMandateTransactions") return allTransactions[:limit] # ========================================================================= # User View Operations (User-Level with RBAC) # ========================================================================= def getUserBalancesForMandates(self, mandateIds: List[str] = None) -> List[Dict[str, Any]]: """ Get all user-level balances for specified mandates. Args: mandateIds: Optional list of mandate IDs to filter. If None, returns all. 
Returns: List of user balance dicts with mandate and user context """ from modules.interfaces.interfaceDbApp import getInterface as getAppInterface balances = [] try: appInterface = getAppInterface(self.currentUser) allAccounts = self.db.getRecordset(BillingAccount) allAccounts = [acc for acc in allAccounts if acc.get("userId")] # Filter by mandate if specified if mandateIds: allAccounts = [acc for acc in allAccounts if acc.get("mandateId") in mandateIds] # Get all relevant settings in one query settingsMap = {} allSettings = self.db.getRecordset(BillingSettings) for s in allSettings: settingsMap[s.get("mandateId")] = s # Get user info efficiently userIds = list(set(acc.get("userId") for acc in allAccounts if acc.get("userId"))) userMap = {} for userId in userIds: user = appInterface.getUser(userId) if user: displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isinstance(user, dict) else None) username = getattr(user, 'username', None) or (user.get("username") if isinstance(user, dict) else None) userMap[userId] = displayName or username or userId # Get mandate info efficiently mandateMap = {} mandateIdList = list(set(acc.get("mandateId") for acc in allAccounts if acc.get("mandateId"))) for mandateId in mandateIdList: mandate = appInterface.getMandate(mandateId) if mandate: mandateName = getattr(mandate, 'label', None) or getattr(mandate, 'name', None) or (mandate.get("label") or mandate.get("name", "") if isinstance(mandate, dict) else "") mandateMap[mandateId] = mandateName for account in allAccounts: mandateId = account.get("mandateId") userId = account.get("userId") if not mandateId or not userId: continue settings = settingsMap.get(mandateId) if not settings: continue balance = account.get("balance", 0.0) warningThreshold = account.get("warningThreshold", 0.0) balances.append({ "accountId": account.get("id"), "mandateId": mandateId, "mandateName": mandateMap.get(mandateId, ""), "userId": userId, "userName": userMap.get(userId, 
userId), "balance": balance, "warningThreshold": warningThreshold, "isWarning": balance <= warningThreshold, "enabled": account.get("enabled", True) }) except Exception as e: logger.error(f"Error getting user balances for mandates: {e}") return balances @staticmethod def _mapPaginationColumns(pagination: PaginationParams) -> PaginationParams: """Remap frontend column names to DB column names in filters and sort.""" _COL_MAP = {"createdAt": "sysCreatedAt"} _ENRICHED_COLS = {"mandateName", "userName", "mandateId", "userId"} import copy p = copy.deepcopy(pagination) if p.filters: mapped = {} for k, v in p.filters.items(): if k in _ENRICHED_COLS: continue mapped[_COL_MAP.get(k, k)] = v p.filters = mapped if p.sort: mapped = [] for s in p.sort: field = s.get("field", "") if isinstance(s, dict) else getattr(s, "field", "") if field in _ENRICHED_COLS: continue newField = _COL_MAP.get(field, field) if isinstance(s, dict): mapped.append({**s, "field": newField}) else: mapped.append({"field": newField, "direction": getattr(s, "direction", "asc")}) p.sort = mapped if mapped else [{"field": "sysCreatedAt", "direction": "desc"}] return p def getTransactionsForMandatesPaginated( self, mandateIds: Optional[List[str]], pagination: PaginationParams, scope: str = "all", userId: Optional[str] = None, ) -> PaginatedResult: """ SQL-level paginated transactions across multiple mandates. Single SQL query with WHERE accountId = ANY(...), ORDER BY, LIMIT/OFFSET. Enrichment (userName, mandateName) only for the returned page. 
""" from modules.interfaces.interfaceDbApp import getInterface as getAppInterface try: mappedPagination = self._mapPaginationColumns(pagination) allAccounts = self.db.getRecordset(BillingAccount) if mandateIds: allAccounts = [a for a in allAccounts if a.get("mandateId") in set(mandateIds)] accountIds = [a.get("id") for a in allAccounts if a.get("id")] if not accountIds: return PaginatedResult(items=[], totalItems=0, totalPages=0) recordFilter: Dict[str, Any] = {"accountId": accountIds} if scope == "personal" and userId: recordFilter["createdByUserId"] = userId result = self.db.getRecordsetPaginated( BillingTransaction, pagination=mappedPagination, recordFilter=recordFilter, ) pageItems = result.get("items", []) if isinstance(result, dict) else result.items accountMap = {a.get("id"): a for a in allAccounts} pageUserIds = set() pageMandateIds = set() for t in pageItems: accId = t.get("accountId") acc = accountMap.get(accId, {}) mid = acc.get("mandateId") uid = t.get("createdByUserId") or acc.get("userId") if uid: pageUserIds.add(uid) if mid: pageMandateIds.add(mid) appInterface = getAppInterface(self.currentUser) userMap: Dict[str, str] = {} if pageUserIds: users = appInterface.getUsersByIds(list(pageUserIds)) for uid, u in users.items(): dn = getattr(u, "displayName", None) or getattr(u, "username", None) or uid userMap[uid] = dn mandateMap: Dict[str, str] = {} if pageMandateIds: mandates = appInterface.getMandatesByIds(list(pageMandateIds)) for mid, m in mandates.items(): mandateMap[mid] = getattr(m, "label", None) or getattr(m, "name", None) or mid enriched = [] for t in pageItems: row = dict(t) accId = row.get("accountId") acc = accountMap.get(accId, {}) mid = acc.get("mandateId") txUserId = row.get("createdByUserId") or acc.get("userId") row["mandateId"] = mid row["mandateName"] = mandateMap.get(mid, "") row["userId"] = txUserId row["userName"] = userMap.get(txUserId, txUserId) if txUserId else None enriched.append(row) totalItems = result.get("totalItems", 0) 
if isinstance(result, dict) else result.totalItems totalPages = result.get("totalPages", 0) if isinstance(result, dict) else result.totalPages return PaginatedResult(items=enriched, totalItems=totalItems, totalPages=totalPages) except Exception as e: logger.error(f"Error in getTransactionsForMandatesPaginated: {e}") return PaginatedResult(items=[], totalItems=0, totalPages=0) def _buildScopeFilter( self, mandateIds: Optional[List[str]], scope: str = "all", userId: Optional[str] = None, startTs: Optional[float] = None, endTs: Optional[float] = None, ) -> tuple: """Build WHERE clause parts for scoped transaction queries. Returns (conditions, values, accountIds).""" allAccounts = self.db.getRecordset(BillingAccount) if mandateIds: mandateSet = set(mandateIds) allAccounts = [a for a in allAccounts if a.get("mandateId") in mandateSet] accountIds = [a.get("id") for a in allAccounts if a.get("id")] if not accountIds: return [], [], [], allAccounts conditions = ['"accountId" = ANY(%s)', '"transactionType" = %s'] values: list = [accountIds, "DEBIT"] if scope == "personal" and userId: conditions.append('"createdByUserId" = %s') values.append(userId) if startTs is not None: conditions.append('"sysCreatedAt" >= %s') values.append(startTs) if endTs is not None: conditions.append('"sysCreatedAt" < %s') values.append(endTs) return conditions, values, accountIds, allAccounts def getTransactionStatisticsAggregated( self, mandateIds: Optional[List[str]], scope: str = "all", userId: Optional[str] = None, startTs: Optional[float] = None, endTs: Optional[float] = None, period: str = "month", ) -> Dict[str, Any]: """ Pure SQL aggregation for statistics. No row-level loading. 
Returns: totalCost, transactionCount, costByProvider, costByModel, costByFeature, costByAccountId, timeSeries """ table = BillingTransaction.__name__ try: if not self.db._ensureTableExists(BillingTransaction): return self._emptyStats() conditions, values, accountIds, allAccounts = self._buildScopeFilter( mandateIds, scope, userId, startTs, endTs ) if not accountIds: return self._emptyStats() whereClause = " WHERE " + " AND ".join(conditions) self.db._ensure_connection() result: Dict[str, Any] = {} with self.db.connection.cursor() as cur: # 1) Totals cur.execute( f'SELECT COALESCE(SUM("amount"), 0) AS total, COUNT(*) AS cnt FROM "{table}"{whereClause}', values, ) row = cur.fetchone() result["totalCost"] = round(float(row["total"]), 4) result["transactionCount"] = int(row["cnt"]) # 2) GROUP BY aicoreProvider cur.execute( f'SELECT COALESCE("aicoreProvider", \'unknown\') AS grp, SUM("amount") AS total ' f'FROM "{table}"{whereClause} GROUP BY grp ORDER BY total DESC', values, ) result["costByProvider"] = {r["grp"]: round(float(r["total"]), 4) for r in cur.fetchall()} # 3) GROUP BY aicoreModel cur.execute( f'SELECT COALESCE("aicoreModel", \'unknown\') AS grp, SUM("amount") AS total ' f'FROM "{table}"{whereClause} GROUP BY grp ORDER BY total DESC', values, ) result["costByModel"] = {r["grp"]: round(float(r["total"]), 4) for r in cur.fetchall()} # 4) GROUP BY accountId (will be enriched to mandateName by caller) cur.execute( f'SELECT "accountId" AS grp, SUM("amount") AS total ' f'FROM "{table}"{whereClause} GROUP BY grp ORDER BY total DESC', values, ) result["costByAccountId"] = {r["grp"]: round(float(r["total"]), 4) for r in cur.fetchall()} # 5) GROUP BY accountId + featureCode (for costByFeature) cur.execute( f'SELECT "accountId", COALESCE("featureCode", \'unknown\') AS fc, SUM("amount") AS total ' f'FROM "{table}"{whereClause} GROUP BY "accountId", fc ORDER BY total DESC', values, ) result["costByAccountFeature"] = [ {"accountId": r["accountId"], "featureCode": r["fc"], 
"total": round(float(r["total"]), 4)} for r in cur.fetchall() ] # 6) Time series via DATE_TRUNC on epoch timestamp if period == "day": truncExpr = "DATE_TRUNC('day', TO_TIMESTAMP(\"sysCreatedAt\"))" else: truncExpr = "DATE_TRUNC('month', TO_TIMESTAMP(\"sysCreatedAt\"))" cur.execute( f'SELECT {truncExpr} AS bucket, SUM("amount") AS total, COUNT(*) AS cnt ' f'FROM "{table}"{whereClause} AND "sysCreatedAt" IS NOT NULL ' f'GROUP BY bucket ORDER BY bucket', values, ) timeSeries = [] for r in cur.fetchall(): bucket = r["bucket"] if period == "day": label = bucket.strftime("%Y-%m-%d") if bucket else "unknown" else: label = bucket.strftime("%Y-%m") if bucket else "unknown" timeSeries.append({ "date": label, "cost": round(float(r["total"]), 4), "count": int(r["cnt"]), }) result["timeSeries"] = timeSeries self.db.connection.commit() result["_allAccounts"] = allAccounts return result except Exception as e: logger.error(f"Error in getTransactionStatisticsAggregated: {e}", exc_info=True) try: self.db.connection.rollback() except Exception: pass return self._emptyStats() @staticmethod def _emptyStats() -> Dict[str, Any]: return { "totalCost": 0.0, "transactionCount": 0, "costByProvider": {}, "costByModel": {}, "costByAccountId": {}, "costByAccountFeature": [], "timeSeries": [], "_allAccounts": [], } def getTransactionDistinctValues( self, mandateIds: Optional[List[str]], column: str, pagination: Optional[PaginationParams] = None, scope: str = "all", userId: Optional[str] = None, ) -> List[str]: """SQL DISTINCT for filter-values on BillingTransaction, scoped by mandates.""" _COLUMN_MAP = { "createdAt": "sysCreatedAt", "mandateId": "accountId", "mandateName": "accountId", } dbColumn = _COLUMN_MAP.get(column, column) mappedPagination = self._mapPaginationColumns(pagination) if pagination else None try: allAccounts = self.db.getRecordset(BillingAccount) if mandateIds: allAccounts = [a for a in allAccounts if a.get("mandateId") in set(mandateIds)] accountIds = [a.get("id") for a in 
allAccounts if a.get("id")] if not accountIds: return [] recordFilter: Dict[str, Any] = {"accountId": accountIds} if scope == "personal" and userId: recordFilter["createdByUserId"] = userId if column in ("mandateName", "userName"): return self._getEnrichedDistinctValues(column, allAccounts, recordFilter, mappedPagination) return self.db.getDistinctColumnValues( BillingTransaction, dbColumn, mappedPagination, recordFilter ) except Exception as e: logger.error(f"Error in getTransactionDistinctValues({column}): {e}") return [] def _getEnrichedDistinctValues( self, column: str, allAccounts: List[Dict], recordFilter: Dict[str, Any], pagination: Optional[PaginationParams], ) -> List[str]: """Resolve enriched columns (mandateName, userName) via batch lookup.""" from modules.interfaces.interfaceDbApp import getInterface as getAppInterface if column == "mandateName": mandateIds = list({a.get("mandateId") for a in allAccounts if a.get("mandateId")}) appInterface = getAppInterface(self.currentUser) mandates = appInterface.getMandatesByIds(mandateIds) return sorted( {getattr(m, "label", None) or getattr(m, "name", "") or mid for mid, m in mandates.items()}, key=lambda v: v.lower(), ) if column == "userName": dbCol = "createdByUserId" values = self.db.getDistinctColumnValues(BillingTransaction, dbCol, pagination, recordFilter) if not values: return [] appInterface = getAppInterface(self.currentUser) users = appInterface.getUsersByIds(values) return sorted( {getattr(u, "displayName", None) or getattr(u, "username", None) or uid for uid, u in users.items()}, key=lambda v: v.lower(), ) return [] def getUserTransactionsForMandates(self, mandateIds: List[str] = None, limit: int = 100) -> List[Dict[str, Any]]: """ Get all transactions for specified mandates. All usage transactions are on user accounts (audit trail). Args: mandateIds: Optional list of mandate IDs to filter. If None, returns all. 
limit: Maximum number of results Returns: List of transaction dicts with mandate and user context """ from modules.interfaces.interfaceDbApp import getInterface as getAppInterface allTransactions = [] try: appInterface = getAppInterface(self.currentUser) # Get ALL accounts (both USER and MANDATE types) to cover all billing models allAccounts = self.db.getRecordset(BillingAccount) # Filter by mandate if specified if mandateIds: allAccounts = [acc for acc in allAccounts if acc.get("mandateId") in mandateIds] # Build account to user/mandate mapping accountMap = {} for acc in allAccounts: accountMap[acc.get("id")] = { "mandateId": acc.get("mandateId"), "userId": acc.get("userId") } # Get user info efficiently userIds = list(set(acc.get("userId") for acc in allAccounts if acc.get("userId"))) userMap = {} for userId in userIds: user = appInterface.getUser(userId) if user: displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isinstance(user, dict) else None) username = getattr(user, 'username', None) or (user.get("username") if isinstance(user, dict) else None) userMap[userId] = displayName or username or userId # Get mandate info efficiently mandateMap = {} mandateIdList = list(set(acc.get("mandateId") for acc in allAccounts if acc.get("mandateId"))) for mandateId in mandateIdList: mandate = appInterface.getMandate(mandateId) if mandate: mandateName = getattr(mandate, 'label', None) or getattr(mandate, 'name', None) or (mandate.get("label") or mandate.get("name", "") if isinstance(mandate, dict) else "") mandateMap[mandateId] = mandateName # Get transactions for all accounts and collect createdByUserIds rawTransactions = [] for account in allAccounts: accountId = account.get("id") if not accountId: continue transactions = self.getTransactions(accountId, limit=limit) accountInfo = accountMap.get(accountId, {}) mandateId = accountInfo.get("mandateId") accountUserId = accountInfo.get("userId") for t in transactions: t["_accountUserId"] = 
accountUserId t["_accountMandateId"] = mandateId rawTransactions.append(t) # Resolve createdByUserIds that are not yet in userMap extraUserIds = set() for t in rawTransactions: cbUserId = t.get("createdByUserId") if cbUserId and cbUserId not in userMap: extraUserIds.add(cbUserId) for uid in extraUserIds: user = appInterface.getUser(uid) if user: displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isinstance(user, dict) else None) username = getattr(user, 'username', None) or (user.get("username") if isinstance(user, dict) else None) userMap[uid] = displayName or username or uid # Enrich transactions for t in rawTransactions: mandateId = t.pop("_accountMandateId", None) accountUserId = t.pop("_accountUserId", None) t["mandateId"] = mandateId t["mandateName"] = mandateMap.get(mandateId, "") # Prefer createdByUserId (per-transaction) over account-derived userId txUserId = t.get("createdByUserId") or accountUserId t["userId"] = txUserId t["userName"] = userMap.get(txUserId, txUserId) if txUserId else None allTransactions.append(t) except Exception as e: logger.error(f"Error getting user transactions for mandates: {e}") # Sort by creation date descending and limit _sortBillingTransactionsBySysCreatedAtDesc(allTransactions, "getUserTransactionsForMandates") return allTransactions[:limit]