gateway/modules/interfaces/interfaceDbBilling.py
2026-04-26 18:11:42 +02:00

2128 lines
83 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Interface for Billing operations.
Manages billing accounts, transactions, and usage statistics.
All billing data is stored in the poweron_billing database.
"""
import logging
from typing import Dict, Any, List, Optional, Union
from datetime import date, datetime, timedelta, timezone
import uuid
from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG
from modules.shared.dbRegistry import registerDatabase
from modules.shared.timeUtils import getUtcTimestamp
from modules.datamodels.datamodelUam import User, Mandate
from modules.datamodels.datamodelMembership import UserMandate
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResult
from modules.datamodels.datamodelBilling import (
BillingAccount,
BillingTransaction,
BillingSettings,
StripeWebhookEvent,
UsageStatistics,
TransactionTypeEnum,
ReferenceTypeEnum,
PeriodTypeEnum,
BillingBalanceResponse,
BillingCheckResult,
STORAGE_PRICE_PER_GB_CHF,
)
# Module-level logger used by the helper functions and BillingObjects methods below.
logger = logging.getLogger(__name__)
def _logBillingTransactionsMissingSysCreatedAt(rows: List[Dict[str, Any]], context: str) -> None:
"""Log ERROR when sysCreatedAt is missing; does not raise."""
missingIds = [r.get("id") for r in rows if r.get("sysCreatedAt") is None]
if not missingIds:
return
cap = 40
sample = missingIds[:cap]
suffix = f"; ... (+{len(missingIds) - cap} more)" if len(missingIds) > cap else ""
logger.error(
"BillingTransaction missing sysCreatedAt (%s): count=%s; transactionIds=%s%s",
context,
len(missingIds),
sample,
suffix,
)
def _numericSysCreatedAtForSort(row: Dict[str, Any]) -> float:
v = row["sysCreatedAt"]
if isinstance(v, datetime):
return v.timestamp()
return float(v)
def _sortBillingTransactionsBySysCreatedAtDesc(rows: List[Dict[str, Any]], context: str) -> None:
    """Reorder ``rows`` in place, newest ``sysCreatedAt`` first.

    Rows lacking ``sysCreatedAt`` are reported via
    ``_logBillingTransactionsMissingSysCreatedAt`` and moved to the end in their
    original relative order. The list object itself is mutated (slice
    assignment), so every reference to it sees the new order.

    Args:
        rows: Transaction records to sort.
        context: Caller label forwarded to the missing-timestamp log line.
    """
    _logBillingTransactionsMissingSysCreatedAt(rows, context)
    dated: List[Dict[str, Any]] = []
    undated: List[Dict[str, Any]] = []
    for row in rows:
        (undated if row.get("sysCreatedAt") is None else dated).append(row)
    dated.sort(key=_numericSysCreatedAtForSort, reverse=True)
    rows[:] = dated + undated
def _getAppDatabaseConnector() -> DatabaseConnector:
    """Construct a new connector to the application (non-billing) database.

    Parameters come from ``APP_CONFIG`` (defaults: database ``poweron_app``,
    host ``localhost``, port ``5432``). A fresh connector is built on every
    call; nothing is cached here.
    """
    connectionArgs = {
        "dbDatabase": APP_CONFIG.get("DB_DATABASE", "poweron_app"),
        "dbHost": APP_CONFIG.get("DB_HOST", "localhost"),
        "dbPort": int(APP_CONFIG.get("DB_PORT", "5432")),
        "dbUser": APP_CONFIG.get("DB_USER"),
        "dbPassword": APP_CONFIG.get("DB_PASSWORD_SECRET"),
    }
    return DatabaseConnector(**connectionArgs)
def _getRootMandateIdFromAppDb(appDb: DatabaseConnector) -> Optional[str]:
    """Look up the id of the root mandate (``name='root'``, ``isSystem=True``).

    Args:
        appDb: Connector to the application database.

    Returns:
        The first matching row's id as ``str``; ``None`` when nothing matches,
        the row has no id, or the lookup raises (logged as a warning).
    """
    try:
        matches = appDb.getRecordset(Mandate, recordFilter={"name": "root", "isSystem": True})
        if not matches:
            return None
        rootId = matches[0].get("id")
        return None if rootId is None else str(rootId)
    except Exception as e:
        logger.warning("Could not resolve root mandate id from app DB: %s", e)
        return None
# Memo for _getCachedRootMandateId(). The separate "resolved" flag allows a None result
# (root mandate not found, or the lookup raised) to be cached too, so it is not re-queried.
# NOTE(review): a transient lookup failure is therefore cached as None for the process lifetime.
_cachedRootMandateId: Optional[str] = None
_rootMandateIdCacheResolved: bool = False
def _getCachedRootMandateId() -> Optional[str]:
    """Return the root mandate id, hitting the app DB only on the first call.

    Whatever the first lookup yields (including ``None`` when the root mandate
    is missing or the query failed) is memoised in module globals and returned
    on every later call. NOTE(review): there is no lock, so concurrent first
    calls may each run the lookup.
    """
    global _cachedRootMandateId, _rootMandateIdCacheResolved
    if _rootMandateIdCacheResolved:
        return _cachedRootMandateId
    _cachedRootMandateId = _getRootMandateIdFromAppDb(_getAppDatabaseConnector())
    _rootMandateIdCacheResolved = True
    return _cachedRootMandateId
# Cache of BillingObjects instances used by getInterface(), keyed by "<userId>_<mandateId>".
# getInterface() only ever adds entries; cache hits are re-targeted via setUserContext().
_billingInterfaces: Dict[str, "BillingObjects"] = {}
# Name of the dedicated billing database; passed to registerDatabase() at import time.
BILLING_DATABASE = "poweron_billing"
registerDatabase(BILLING_DATABASE)
def getInterface(currentUser: User, mandateId: str = None) -> "BillingObjects":
    """Return the cached BillingObjects for (user, mandate), creating it on first use.

    Instances are memoised in ``_billingInterfaces`` under the key
    ``"<currentUser.id>_<mandateId>"``. On a cache hit the stored instance is
    refreshed via ``setUserContext`` so it carries the caller's user object.

    Args:
        currentUser: Acting user; its ``id`` is part of the cache key.
        mandateId: Mandate context (may be ``None``).

    Returns:
        The BillingObjects instance for this user/mandate pair.
    """
    key = f"{currentUser.id}_{mandateId}"
    instance = _billingInterfaces.get(key)
    if instance is None:
        instance = BillingObjects(currentUser, mandateId)
        _billingInterfaces[key] = instance
    else:
        instance.setUserContext(currentUser, mandateId)
    return instance
def getRootInterface() -> "BillingObjects":
    """Create a fresh, uncached BillingObjects bound to the system root user.

    Intended for bootstrap/system operations; no mandate context is set.
    """
    from modules.security.rootAccess import getRootUser
    return BillingObjects(getRootUser(), mandateId=None)
class BillingObjects:
    """
    Data-access interface for the billing database (``poweron_billing``).

    Holds a DatabaseConnector to the billing database plus the current
    user/mandate context, and provides helpers for billing settings, billing
    accounts (mandate pool and per-user), transactions with balance updates,
    Stripe webhook idempotency records, balance checks, usage/storage/
    subscription bookings and usage statistics.
    """
def __init__(self, currentUser: Optional[User] = None, mandateId: str = None):
    """
    Create a billing interface bound to a user/mandate context.

    Sets ``currentUser``, ``userId`` (``None`` without a user) and
    ``mandateId``, then creates this instance's billing-database connector.

    Args:
        currentUser: Acting user, or ``None``.
        mandateId: Mandate ID for context.
    """
    # setUserContext maintains exactly the three context attributes needed here.
    self.setUserContext(currentUser, mandateId)
    self._initializeDatabase()
def _initializeDatabase(self):
    """Create the billing-database connector and store it on ``self.db``.

    Host, port and credentials come from ``APP_CONFIG``; the database name is
    always ``BILLING_DATABASE``.
    """
    host = APP_CONFIG.get('DB_HOST', 'localhost')
    port = int(APP_CONFIG.get('DB_PORT', '5432'))
    user = APP_CONFIG.get('DB_USER')
    password = APP_CONFIG.get('DB_PASSWORD_SECRET')
    self.db = DatabaseConnector(
        dbDatabase=BILLING_DATABASE,
        dbHost=host,
        dbPort=port,
        dbUser=user,
        dbPassword=password,
    )
def setUserContext(self, currentUser: User, mandateId: str = None):
    """
    Re-point this interface at another user and/or mandate.

    Only the context attributes change; the existing ``self.db`` connector is
    kept.

    Args:
        currentUser: Acting user (``userId`` becomes ``None`` if this is falsy).
        mandateId: Mandate ID for context.
    """
    self.currentUser = currentUser
    if currentUser:
        self.userId = currentUser.id
    else:
        self.userId = None
    self.mandateId = mandateId
# =========================================================================
# BillingSettings Operations
# =========================================================================
def getSettings(self, mandateId: str) -> Optional[Dict[str, Any]]:
    """
    Fetch the BillingSettings row for a mandate.

    Args:
        mandateId: Mandate ID to look up.

    Returns:
        A plain ``dict`` copy of the first matching row, or ``None`` when no row
        exists or the query fails (the failure is logged at ERROR level).
    """
    try:
        rows = self.db.getRecordset(BillingSettings, recordFilter={"mandateId": mandateId})
        return dict(rows[0]) if rows else None
    except Exception as e:
        logger.error(f"Error getting billing settings: {e}")
        return None
def createSettings(self, settings: BillingSettings) -> Dict[str, Any]:
    """
    Insert a new BillingSettings row.

    Fields whose value is ``None`` are left out of the insert (``exclude_none``).

    Args:
        settings: Settings model to persist.

    Returns:
        The record returned by ``self.db.recordCreate``.
    """
    payload = settings.model_dump(exclude_none=True)
    created = self.db.recordCreate(BillingSettings, payload)
    return created
def updateSettings(self, settingsId: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """
    Apply a partial update to a BillingSettings row.

    Args:
        settingsId: ID of the settings row.
        updates: Mapping of column name to new value.

    Returns:
        Whatever ``self.db.recordModify`` returns (updated record or ``None``).
    """
    modified = self.db.recordModify(BillingSettings, settingsId, updates)
    return modified
def getOrCreateSettings(self, mandateId: str) -> Dict[str, Any]:
    """
    Return the mandate's BillingSettings, inserting a default row if none exists.

    Defaults: ``warningThresholdPercent=10.0`` and ``notifyOnWarning=True``.
    NOTE(review): ``getSettings`` also returns ``None`` on query errors, in
    which case an insert is attempted as well.

    Args:
        mandateId: Mandate ID.

    Returns:
        The existing or newly created settings record.
    """
    found = self.getSettings(mandateId)
    if not found:
        found = self.createSettings(
            BillingSettings(
                mandateId=mandateId,
                warningThresholdPercent=10.0,
                notifyOnWarning=True,
            )
        )
    return found
# =========================================================================
# BillingAccount Operations
# =========================================================================
def getAccount(self, accountId: str) -> Optional[Dict[str, Any]]:
    """
    Fetch a single BillingAccount by its ID.

    Returns:
        The first matching record, or ``None`` if none exists or the query fails
        (the failure is logged at ERROR level).
    """
    try:
        matches = self.db.getRecordset(BillingAccount, recordFilter={"id": accountId})
        if not matches:
            return None
        return matches[0]
    except Exception as e:
        logger.error(f"Error getting billing account: {e}")
        return None
def getMandateAccount(self, mandateId: str) -> Optional[Dict[str, Any]]:
    """
    Fetch the mandate-level (pool) billing account, i.e. the one with ``userId`` None.

    Args:
        mandateId: Mandate ID.

    Returns:
        The first matching account, or ``None`` if none exists or the query
        fails (the failure is logged at ERROR level).
    """
    try:
        poolRows = self.db.getRecordset(
            BillingAccount,
            recordFilter={"mandateId": mandateId, "userId": None},
        )
        return poolRows[0] if poolRows else None
    except Exception as e:
        logger.error(f"Error getting mandate account: {e}")
        return None
def getUserAccount(self, mandateId: str, userId: str) -> Optional[Dict[str, Any]]:
    """
    Fetch a user's personal BillingAccount inside a mandate.

    Args:
        mandateId: Mandate ID.
        userId: User ID.

    Returns:
        The first account matching both IDs, or ``None`` if none exists or the
        query fails (the failure is logged at ERROR level).
    """
    try:
        matches = self.db.getRecordset(
            BillingAccount,
            recordFilter={"mandateId": mandateId, "userId": userId},
        )
        if matches:
            return matches[0]
        return None
    except Exception as e:
        logger.error(f"Error getting user account: {e}")
        return None
def getAccountsByMandate(self, mandateId: str) -> List[Dict[str, Any]]:
    """
    List every billing account (pool and per-user) belonging to a mandate.

    Args:
        mandateId: Mandate ID.

    Returns:
        The matching account records; an empty list if the query fails (the
        failure is logged at ERROR level).
    """
    try:
        accounts = self.db.getRecordset(BillingAccount, recordFilter={"mandateId": mandateId})
    except Exception as e:
        logger.error(f"Error getting accounts for mandate: {e}")
        accounts = []
    return accounts
def createAccount(self, account: BillingAccount) -> Dict[str, Any]:
    """
    Insert a new BillingAccount row.

    Fields whose value is ``None`` are omitted from the insert (``exclude_none``);
    presumably this is how mandate-level accounts end up with no ``userId`` —
    confirm against the BillingAccount model defaults.

    Args:
        account: Account model to persist.

    Returns:
        The record returned by ``self.db.recordCreate``.
    """
    payload = account.model_dump(exclude_none=True)
    created = self.db.recordCreate(BillingAccount, payload)
    return created
def updateAccountBalance(self, accountId: str, newBalance: float) -> Optional[Dict[str, Any]]:
    """
    Overwrite an account's stored balance with an absolute value.

    This is a single ``recordModify`` write of ``{"balance": newBalance}``. It
    neither reads nor compares the previous balance (no compare-and-set), so it
    is NOT an atomic increment: callers that derive ``newBalance`` from an
    earlier read (as ``createTransaction`` does) can lose updates if two writers
    interleave, unless the connector provides its own isolation.

    Args:
        accountId: ID of the account to update.
        newBalance: The balance value to store.

    Returns:
        Whatever ``self.db.recordModify`` returns (updated record or ``None``).
    """
    return self.db.recordModify(BillingAccount, accountId, {"balance": newBalance})
def getOrCreateMandateAccount(self, mandateId: str, initialBalance: float = 0.0) -> Dict[str, Any]:
    """
    Return the mandate's pool account (``userId`` None), creating it if missing.

    A newly created account is inserted with ``balance=initialBalance`` and
    ``enabled=True``; no CREDIT transaction is recorded for that opening balance.
    NOTE(review): ``getMandateAccount`` returns ``None`` on query errors too,
    so a failed read leads to an insert attempt.

    Args:
        mandateId: Mandate ID.
        initialBalance: Opening balance used only when the account is created.

    Returns:
        The existing or newly created account record.
    """
    pool = self.getMandateAccount(mandateId)
    if not pool:
        pool = self.createAccount(
            BillingAccount(mandateId=mandateId, balance=initialBalance, enabled=True)
        )
    return pool
def getOrCreateUserAccount(self, mandateId: str, userId: str, initialBalance: float = 0.0) -> Dict[str, Any]:
    """
    Return the user's BillingAccount within a mandate, creating it if missing.

    When a new account is created with ``initialBalance > 0``, it is inserted
    with a zero balance and the amount is then booked through a SYSTEM CREDIT
    transaction. The balance therefore ends at exactly ``initialBalance`` and
    the audit trail explains it. Inserting with ``balance=initialBalance`` and
    then crediting would count the amount twice, because ``createTransaction``
    adds the credit to the stored balance. Non-positive initial balances are
    written directly, without a transaction.

    Args:
        mandateId: Mandate ID.
        userId: User ID.
        initialBalance: Opening balance for a newly created account.

    Returns:
        The existing account, or the newly created one (re-read after an initial
        credit so that its ``balance`` is current).
    """
    existing = self.getUserAccount(mandateId, userId)
    if existing:
        return existing
    creditInitial = initialBalance > 0
    account = BillingAccount(
        mandateId=mandateId,
        userId=userId,
        # A positive opening balance is applied by the CREDIT below, not by the insert.
        balance=0.0 if creditInitial else initialBalance,
        enabled=True
    )
    created = self.createAccount(account)
    if not creditInitial:
        return created
    self.createTransaction(BillingTransaction(
        accountId=created["id"],
        transactionType=TransactionTypeEnum.CREDIT,
        amount=initialBalance,
        description="Initial credit for new user",
        referenceType=ReferenceTypeEnum.SYSTEM
    ))
    # createAccount's return value predates the credit; re-read to return the current balance.
    return self.getAccount(created["id"]) or created
def ensureAllMandateSettingsExist(self) -> int:
    """
    Create default BillingSettings for every enabled mandate that has none.

    Defaults match ``getOrCreateSettings``: ``warningThresholdPercent=10.0``
    and ``notifyOnWarning=True``. Existing settings are read in one query on
    the billing DB and enabled mandates in one query on the app DB; then one
    insert is issued per missing mandate.

    A failed insert is logged and skipped, so the remaining mandates are still
    processed and the returned count reflects what was actually created.

    Returns:
        Number of settings rows created (0 if the initial queries fail).
    """
    settingsCreated = 0
    try:
        allSettings = self.db.getRecordset(BillingSettings)
        existingMandateIds = {s.get("mandateId") for s in allSettings if s.get("mandateId")}
        # Same app-DB connection parameters as the other app-DB reads in this module.
        appDb = _getAppDatabaseConnector()
        allMandates = appDb.getRecordset(Mandate, recordFilter={"enabled": True})
        for mandate in allMandates:
            mandateId = mandate.get("id")
            if not mandateId or mandateId in existingMandateIds:
                continue
            try:
                self.createSettings(BillingSettings(
                    mandateId=mandateId,
                    warningThresholdPercent=10.0,
                    notifyOnWarning=True,
                ))
            except Exception as e:
                # Isolate per-mandate failures so one bad row doesn't abort the whole pass.
                logger.error(f"Error creating billing settings for mandate {mandateId}: {e}")
                continue
            existingMandateIds.add(mandateId)
            settingsCreated += 1
    except Exception as e:
        logger.error(f"Error ensuring mandate settings exist: {e}")
    if settingsCreated > 0:
        logger.info(f"Created {settingsCreated} missing billing settings for mandates")
    return settingsCreated
def ensureAllUserAccountsExist(self) -> int:
    """
    Create a zero-balance user-level BillingAccount for every enabled
    UserMandate membership whose mandate has BillingSettings.

    User accounts serve as the audit trail for usage (costs are deducted from
    the mandate pool), so they are always created with balance 0.0. Existing
    settings, existing accounts and enabled memberships are each read with one
    bulk query; then one insert is issued per missing (mandate, user) pair.

    A failed insert is logged and skipped, so the remaining memberships are
    still processed and the returned count reflects what was actually created.

    Returns:
        Number of accounts created (0 if the initial queries fail).
    """
    accountsCreated = 0
    try:
        allSettings = self.db.getRecordset(BillingSettings)
        billingMandateIds = {s.get("mandateId") for s in allSettings if s.get("mandateId")}
        if not billingMandateIds:
            logger.debug("No billable mandates found, skipping account check")
            return 0
        # Only user-level accounts count; pool accounts have no userId.
        existingAccountKeys = {
            (acc.get("mandateId"), acc.get("userId"))
            for acc in self.db.getRecordset(BillingAccount)
            if acc.get("userId")
        }
        appDb = _getAppDatabaseConnector()
        allUserMandates = appDb.getRecordset(UserMandate, recordFilter={"enabled": True})
        for um in allUserMandates:
            mandateId = um.get("mandateId")
            userId = um.get("userId")
            if not mandateId or not userId or mandateId not in billingMandateIds:
                continue
            key = (mandateId, userId)
            if key in existingAccountKeys:
                continue
            try:
                self.createAccount(BillingAccount(
                    mandateId=mandateId,
                    userId=userId,
                    balance=0.0,
                    enabled=True
                ))
            except Exception as e:
                # Isolate per-membership failures so one bad row doesn't abort the whole pass.
                logger.error(f"Error creating billing account for user {userId} in mandate {mandateId}: {e}")
                continue
            existingAccountKeys.add(key)
            accountsCreated += 1
    except Exception as e:
        logger.error(f"Error ensuring user accounts exist: {e}")
    if accountsCreated > 0:
        logger.info(f"Created {accountsCreated} missing billing accounts")
    return accountsCreated
# =========================================================================
# BillingTransaction Operations
# =========================================================================
def createTransaction(self, transaction: BillingTransaction, balanceAccountId: str = None) -> Dict[str, Any]:
    """
    Persist a billing transaction and apply it to an account balance.

    The transaction row is always written against ``transaction.accountId``
    (audit trail). The balance change is applied to ``balanceAccountId`` when
    given, otherwise to ``transaction.accountId``. This lets usage be recorded
    on a user account while the cost is deducted from the mandate pool account.

    Balance effect by ``transactionType``: CREDIT adds ``amount``, DEBIT
    subtracts it, and every other type (ADJUSTMENT) adds ``amount`` as given,
    so a negative adjustment amount lowers the balance.

    NOTE(review): the balance is read, recomputed in Python and written back as
    an absolute value, and the insert and the balance update are two separate
    writes. Concurrent calls on the same balance account can lose updates
    unless the connector adds isolation — confirm.

    Args:
        transaction: BillingTransaction to persist.
        balanceAccountId: Optional account whose balance changes (defaults to
            ``transaction.accountId``).

    Returns:
        The created transaction record as returned by ``recordCreate``.

    Raises:
        ValueError: If the transaction account or the balance account cannot be
            found. ``getAccount`` also returns ``None`` on query errors, so a DB
            read failure surfaces as this error too.
    """
    # Validate that the transaction's account exists
    txAccount = self.getAccount(transaction.accountId)
    if not txAccount:
        raise ValueError(f"Transaction account {transaction.accountId} not found")
    # Determine which account to update balance on
    targetBalanceAccountId = balanceAccountId or transaction.accountId
    if targetBalanceAccountId == transaction.accountId:
        # Reuse the record already fetched above instead of a second lookup.
        balanceAccount = txAccount
    else:
        balanceAccount = self.getAccount(targetBalanceAccountId)
        if not balanceAccount:
            raise ValueError(f"Balance account {targetBalanceAccountId} not found")
    currentBalance = balanceAccount.get("balance", 0.0)
    # Calculate new balance
    if transaction.transactionType == TransactionTypeEnum.CREDIT:
        newBalance = currentBalance + transaction.amount
    elif transaction.transactionType == TransactionTypeEnum.DEBIT:
        newBalance = currentBalance - transaction.amount
    else: # ADJUSTMENT (amount is applied with its own sign)
        newBalance = currentBalance + transaction.amount
    # Create transaction record (always on transaction.accountId for audit)
    transactionDict = transaction.model_dump(exclude_none=True)
    created = self.db.recordCreate(BillingTransaction, transactionDict)
    # Update balance on the target account (absolute write of the value computed above)
    self.updateAccountBalance(targetBalanceAccountId, newBalance)
    logger.info(f"Billing transaction created: {transaction.transactionType.value} {transaction.amount} CHF, "
                f"audit={transaction.accountId}, balance on {targetBalanceAccountId}: {currentBalance} -> {newBalance}")
    return created
def getTransactions(
    self,
    accountId: str,
    limit: int = 100,
    offset: int = 0,
    startDate: date = None,
    endDate: date = None,
    pagination: PaginationParams = None
) -> Union[List[Dict[str, Any]], PaginatedResult]:
    """
    Get transactions for an account, newest first.

    With ``pagination``, filtering, sorting and paging are delegated to the
    database and a PaginatedResult is returned. Otherwise all of the account's
    transactions are loaded, optionally filtered by creation date, sorted by
    ``sysCreatedAt`` descending (rows without it last) and sliced with
    ``offset``/``limit``.

    Date filtering compares the calendar date of ``sysCreatedAt``. That value
    may be a ``datetime``, a ``date``, POSIX epoch seconds (interpreted as UTC)
    or an ISO 8601 string. Rows whose creation time is missing or cannot be
    interpreted are kept, since they cannot be shown to fall outside the range.
    ``datetime`` bounds are reduced to their date so they compare with row dates.

    Args:
        accountId: Account ID.
        limit: Maximum number of results (legacy path).
        offset: Number of leading results to skip (legacy path).
        startDate: Inclusive lower date bound (legacy path).
        endDate: Inclusive upper date bound (legacy path).
        pagination: PaginationParams for DB-level pagination.

    Returns:
        PaginatedResult when pagination is provided, list of dicts otherwise.
        On error an empty result of the matching kind is returned (and logged).
    """

    def _createdDate(value: Any) -> Optional[date]:
        """Calendar date of a sysCreatedAt value, or None if it cannot be interpreted."""
        if isinstance(value, datetime):
            return value.date()
        if isinstance(value, date):
            return value
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            try:
                return datetime.fromtimestamp(value, tz=timezone.utc).date()
            except (OverflowError, OSError, ValueError):
                return None
        if isinstance(value, str):
            try:
                return datetime.fromisoformat(value.replace("Z", "+00:00")).date()
            except ValueError:
                return None
        return None

    def _asDate(bound: Optional[date]) -> Optional[date]:
        """Reduce a datetime bound to its date (datetime vs date comparisons raise TypeError)."""
        return bound.date() if isinstance(bound, datetime) else bound

    try:
        if pagination:
            recordFilter = {"accountId": accountId}
            result = self.db.getRecordsetPaginated(
                BillingTransaction,
                pagination=pagination,
                recordFilter=recordFilter
            )
            _logBillingTransactionsMissingSysCreatedAt(
                result["items"],
                "getTransactions(accountId) paginated",
            )
            return PaginatedResult(
                items=result["items"],
                totalItems=result["totalItems"],
                totalPages=result["totalPages"]
            )
        results = self.db.getRecordset(BillingTransaction, recordFilter={"accountId": accountId})
        if startDate or endDate:
            lower = _asDate(startDate)
            upper = _asDate(endDate)
            filtered = []
            for t in results:
                tDate = _createdDate(t.get("sysCreatedAt"))
                if tDate is not None:
                    if lower and tDate < lower:
                        continue
                    if upper and tDate > upper:
                        continue
                filtered.append(t)
            results = filtered
        _sortBillingTransactionsBySysCreatedAtDesc(results, "getTransactions(accountId)")
        return results[offset:offset + limit]
    except Exception as e:
        logger.error(f"Error getting transactions: {e}")
        if pagination:
            return PaginatedResult(items=[], totalItems=0, totalPages=0)
        return []
def getTransactionsByMandate(
    self,
    mandateId: str,
    limit: int = 100,
    pagination: PaginationParams = None
) -> Union[List[Dict[str, Any]], PaginatedResult]:
    """
    Get all transactions for a mandate (across all of its accounts), newest first.

    With ``pagination``, all of the mandate's account IDs are passed as a single
    ``accountId`` filter (a list value; presumably matched as IN (...) by the
    connector) so filtering, sorting and paging happen in one DB query.
    Otherwise each account is queried via ``getTransactions`` (top ``limit``
    each) and the results are merged in memory and cut to ``limit``.

    Both paths log transactions missing ``sysCreatedAt``, matching
    ``getTransactions``.

    Args:
        mandateId: Mandate ID.
        limit: Maximum number of results (legacy path).
        pagination: PaginationParams for DB-level pagination.

    Returns:
        PaginatedResult when pagination is provided, list of dicts otherwise.
    """
    accounts = self.db.getRecordset(BillingAccount, recordFilter={"mandateId": mandateId})
    accountIds = [acc["id"] for acc in accounts if acc.get("id")]
    if not accountIds:
        if pagination:
            return PaginatedResult(items=[], totalItems=0, totalPages=0)
        return []
    if pagination:
        result = self.db.getRecordsetPaginated(
            BillingTransaction,
            pagination=pagination,
            recordFilter={"accountId": accountIds}
        )
        _logBillingTransactionsMissingSysCreatedAt(
            result["items"],
            "getTransactionsByMandate paginated",
        )
        return PaginatedResult(
            items=result["items"],
            totalItems=result["totalItems"],
            totalPages=result["totalPages"]
        )
    allTransactions: List[Dict[str, Any]] = []
    # Iterate the validated IDs so rows without an id cannot raise KeyError here.
    for accId in accountIds:
        allTransactions.extend(self.getTransactions(accId, limit=limit))
    _sortBillingTransactionsBySysCreatedAtDesc(
        allTransactions,
        "getTransactionsByMandate",
    )
    return allTransactions[:limit]
# =========================================================================
# StripeWebhookEvent Operations (idempotency)
# =========================================================================
def getStripeWebhookEventByEventId(self, event_id: str) -> Optional[Dict[str, Any]]:
    """
    Look up a previously recorded Stripe webhook event (idempotency check).

    Args:
        event_id: Stripe event ID (``evt_...``).

    Returns:
        The stored event record, or ``None`` if it was never recorded or the query
        fails (the failure is logged at ERROR level).
    """
    try:
        hits = self.db.getRecordset(StripeWebhookEvent, recordFilter={"event_id": event_id})
        if not hits:
            return None
        return hits[0]
    except Exception as e:
        logger.error(f"Error checking Stripe webhook event: {e}")
        return None
def createStripeWebhookEvent(self, event_id: str) -> Dict[str, Any]:
    """
    Record a Stripe webhook event as processed.

    Unlike the other create helpers here, the full model dump is inserted
    (``None`` fields are not excluded).

    Args:
        event_id: Stripe event ID (``evt_...``).

    Returns:
        The record returned by ``self.db.recordCreate``.
    """
    eventPayload = StripeWebhookEvent(event_id=event_id).model_dump()
    return self.db.recordCreate(StripeWebhookEvent, eventPayload)
def getPaymentTransactionByReferenceId(self, referenceId: str) -> Optional[Dict[str, Any]]:
    """
    Find an existing payment transaction by its Stripe Checkout Session ID.

    Matches transactions with ``referenceType == PAYMENT`` and the given
    ``referenceId``.

    Args:
        referenceId: Stripe Checkout Session ID (``cs_...``).

    Returns:
        The first matching transaction, or ``None`` if none exists or the query
        fails (the failure is logged at ERROR level).
    """
    try:
        paymentFilter = {
            "referenceType": ReferenceTypeEnum.PAYMENT.value,
            "referenceId": referenceId,
        }
        hits = self.db.getRecordset(BillingTransaction, recordFilter=paymentFilter)
        return hits[0] if hits else None
    except Exception as e:
        logger.error(f"Error checking Stripe payment transaction by referenceId: {e}")
        return None
# =========================================================================
# Balance Check Operations
# =========================================================================
def checkBalance(self, mandateId: str, userId: str, estimatedCost: float) -> BillingCheckResult:
    """
    Decide whether the mandate pool can cover ``estimatedCost``.

    Only the mandate-level (pool) balance is compared. The user's own account is
    merely ensured to exist (created with 0.0 if missing) for the audit trail.
    A missing pool account is created with balance 0.0, so a mandate without
    one is treated as an empty pool.

    Args:
        mandateId: Mandate ID.
        userId: User ID.
        estimatedCost: Expected cost of the operation in CHF.

    Returns:
        ``allowed=True`` with the pool balance, or ``allowed=False`` with reason
        ``"INSUFFICIENT_BALANCE"`` and the required amount when the balance is
        below ``estimatedCost``.
    """
    self.getOrCreateUserAccount(mandateId, userId, initialBalance=0.0)
    poolBalance = self.getOrCreateMandateAccount(mandateId).get("balance", 0.0)
    insufficient = poolBalance < estimatedCost
    if not insufficient:
        return BillingCheckResult(allowed=True, currentBalance=poolBalance)
    return BillingCheckResult(
        allowed=False,
        reason="INSUFFICIENT_BALANCE",
        currentBalance=poolBalance,
        requiredAmount=estimatedCost,
    )
def recordUsage(
    self,
    mandateId: str,
    userId: str,
    priceCHF: float,
    workflowId: str = None,
    featureInstanceId: str = None,
    featureCode: str = None,
    aicoreProvider: str = None,
    aicoreModel: str = None,
    description: str = "AI Usage",
    processingTime: float = None,
    bytesSent: int = None,
    bytesReceived: int = None,
    errorCount: int = None
) -> Optional[Dict[str, Any]]:
    """
    Book a usage cost as a WORKFLOW-referenced DEBIT transaction.

    The transaction row is written on the user's account (audit trail, created
    if missing), while the amount is deducted from the mandate pool account.

    Returns:
        The created transaction, or ``None`` when ``priceCHF <= 0`` or the mandate
        has no billing settings (nothing is recorded in either case).
    """
    if priceCHF <= 0:
        return None
    if not self.getSettings(mandateId):
        logger.debug(f"No billing settings for mandate {mandateId}, skipping usage recording")
        return None
    auditAccountId = self.getOrCreateUserAccount(mandateId, userId)["id"]
    usageTx = BillingTransaction(
        accountId=auditAccountId,
        transactionType=TransactionTypeEnum.DEBIT,
        amount=priceCHF,
        description=description,
        referenceType=ReferenceTypeEnum.WORKFLOW,
        workflowId=workflowId,
        featureInstanceId=featureInstanceId,
        featureCode=featureCode,
        aicoreProvider=aicoreProvider,
        aicoreModel=aicoreModel,
        createdByUserId=userId,
        processingTime=processingTime,
        bytesSent=bytesSent,
        bytesReceived=bytesReceived,
        errorCount=errorCount
    )
    poolAccountId = self.getOrCreateMandateAccount(mandateId)["id"]
    return self.createTransaction(usageTx, balanceAccountId=poolAccountId)
def _parseSettingsDateTime(self, value: Any) -> Optional[datetime]:
"""Parse datetime from billing settings row (ISO string or datetime)."""
if value is None:
return None
if isinstance(value, datetime):
if value.tzinfo:
return value.astimezone(timezone.utc)
return value.replace(tzinfo=timezone.utc)
if isinstance(value, str):
s = value.replace("Z", "+00:00")
try:
dt = datetime.fromisoformat(s)
except ValueError:
return None
if dt.tzinfo:
return dt.astimezone(timezone.utc)
return dt.replace(tzinfo=timezone.utc)
return None
def resetStorageBillingPeriod(self, mandateId: str, periodStartAt: datetime) -> None:
    """Start a new storage-billing period for a mandate (e.g. on Stripe ``invoice.paid``).

    The period start is normalised to UTC (naive values are taken as UTC) and
    stored as epoch seconds. If the stored ``storagePeriodStartAt`` is already
    within 2 seconds of it, the call is a no-op, so repeated triggers for the
    same period are harmless. Otherwise the high-watermark is reset to the
    current data volume and ``storageBilledUpToMB`` to 0.

    Args:
        mandateId: Mandate ID (settings are created if missing).
        periodStartAt: Start of the new billing period.
    """
    periodStartUtc = (
        periodStartAt.replace(tzinfo=timezone.utc)
        if periodStartAt.tzinfo is None
        else periodStartAt.astimezone(timezone.utc)
    )
    newStartTs = periodStartUtc.timestamp()
    settings = self.getOrCreateSettings(mandateId)
    previousStartTs = settings.get("storagePeriodStartAt")
    if previousStartTs is not None and abs(previousStartTs - newStartTs) < 2:
        return
    from modules.interfaces.interfaceDbSubscription import getRootInterface as _getSubRoot
    currentUsageMB = float(_getSubRoot().getMandateDataVolumeMB(mandateId))
    self.updateSettings(
        settings["id"],
        {
            "storageHighWatermarkMB": currentUsageMB,
            "storageBilledUpToMB": 0.0,
            "storagePeriodStartAt": newStartTs,
        },
    )
    logger.info(
        "Storage billing period reset for mandate %s at %s (usedMB=%.2f)",
        mandateId,
        periodStartUtc.isoformat(),
        currentUsageMB,
    )
def reconcileMandateStorageBilling(self, mandateId: str) -> Optional[Dict[str, Any]]:
    """Debit the mandate pool for storage overage not yet billed in the current period.

    Billing follows the period's high-watermark of used storage, so deleting
    data never yields a credit. Steps:
      1. Read current usage (MB) and the operative subscription plan's allowance
         (``plan.maxDataVolumeMB``). Without an allowance nothing is billed.
      2. Raise the stored high-watermark to ``max(previous, current usage)``.
      3. Overage = high-watermark - allowance. Only the part above
         ``storageBilledUpToMB`` is charged, at ``STORAGE_PRICE_PER_GB_CHF`` per
         GB (1 GB = 1024 MB here), rounded to 4 decimals.
      4. Debit the pool account, then persist the new ``storageBilledUpToMB``.

    NOTE(review): step 4 is two separate writes; if the settings update fails
    after the debit, the next run would charge the same delta again.

    Args:
        mandateId: Mandate ID.

    Returns:
        The created DEBIT transaction, or ``None`` when the mandate has no billing
        settings, no storage allowance, or nothing new to charge.
    """
    settings = self.getSettings(mandateId)
    if not settings:
        return None
    # Function-level imports; presumably to avoid an import cycle with the subscription module — TODO confirm.
    from modules.interfaces.interfaceDbSubscription import getRootInterface as _getSubRoot
    from modules.datamodels.datamodelSubscription import getPlan
    subIface = _getSubRoot()
    usedMB = float(subIface.getMandateDataVolumeMB(mandateId))
    sub = subIface.getOperativeForMandate(mandateId)
    plan = getPlan(sub.get("planKey", "")) if sub else None
    includedMB = plan.maxDataVolumeMB if plan and plan.maxDataVolumeMB is not None else None
    if includedMB is None:
        # No operative subscription/plan, or the plan has no storage cap: nothing to bill.
        return None
    prevHigh = float(settings.get("storageHighWatermarkMB") or 0.0)
    high = max(prevHigh, usedMB)
    overageMB = max(0.0, high - float(includedMB))
    billed = float(settings.get("storageBilledUpToMB") or 0.0)
    deltaOverage = overageMB - billed
    settingsUpdates: Dict[str, Any] = {}
    if high != prevHigh:
        settingsUpdates["storageHighWatermarkMB"] = high
    if deltaOverage <= 1e-9:
        # Nothing new to bill (float tolerance); still persist a raised watermark.
        if settingsUpdates:
            self.updateSettings(settings["id"], settingsUpdates)
        return None
    costCHF = round((deltaOverage / 1024.0) * float(STORAGE_PRICE_PER_GB_CHF), 4)
    if costCHF <= 0:
        # Delta rounds to 0 CHF. storageBilledUpToMB is not advanced here, so the
        # unbilled delta carries over and is charged once it becomes billable.
        if settingsUpdates:
            self.updateSettings(settings["id"], settingsUpdates)
        return None
    poolAccount = self.getOrCreateMandateAccount(mandateId)
    transaction = BillingTransaction(
        accountId=poolAccount["id"],
        transactionType=TransactionTypeEnum.DEBIT,
        amount=costCHF,
        description=f"Speicher-Überhang ({deltaOverage:.2f} MB über Plan)",
        referenceType=ReferenceTypeEnum.STORAGE,
        referenceId=mandateId,
    )
    created = self.createTransaction(transaction)
    # Mark the whole current overage as billed so the next run only charges new growth.
    settingsUpdates["storageBilledUpToMB"] = overageMB
    self.updateSettings(settings["id"], settingsUpdates)
    logger.info(
        "Storage overage billed mandate=%s deltaOverageMB=%.4f costCHF=%s",
        mandateId,
        deltaOverage,
        costCHF,
    )
    return created
# =========================================================================
# Subscription AI-Budget Credit
# =========================================================================
def creditSubscriptionBudget(self, mandateId: str, planKey: str, periodLabel: str = "") -> Optional[Dict[str, Any]]:
    """Credit one period's AI budget to the mandate pool account.

    Amount = ``plan.budgetAiPerUserCHF`` x number of active users (at least 1),
    so it follows the mandate's current head count rather than a fixed plan
    amount. Intended to run once per billing period (initial activation and
    each paid invoice).

    Args:
        mandateId: Mandate ID.
        planKey: Subscription plan key passed to ``getPlan``.
        periodLabel: Optional suffix appended to the transaction description.

    Returns:
        The created CREDIT transaction, or ``None`` when the plan is unknown or has
        no positive per-user AI budget.
    """
    from modules.datamodels.datamodelSubscription import getPlan
    plan = getPlan(planKey)
    perUserBudget = plan.budgetAiPerUserCHF if plan else None
    if not perUserBudget or perUserBudget <= 0:
        return None
    from modules.interfaces.interfaceDbSubscription import getRootInterface as _getSubRoot
    activeUsers = max(_getSubRoot().countActiveUsers(mandateId), 1)
    amount = perUserBudget * activeUsers
    poolAccount = self.getOrCreateMandateAccount(mandateId)
    description = f"AI-Budget ({planKey}, {activeUsers} User)" + (f" {periodLabel}" if periodLabel else "")
    created = self.createTransaction(BillingTransaction(
        accountId=poolAccount["id"],
        transactionType=TransactionTypeEnum.CREDIT,
        amount=amount,
        description=description,
        referenceType=ReferenceTypeEnum.SUBSCRIPTION,
        referenceId=mandateId,
    ))
    logger.info(
        "AI-Budget credited mandate=%s plan=%s users=%d amount=%.2f CHF",
        mandateId, planKey, activeUsers, amount,
    )
    return created
def ensureActivationBudget(self, mandateId: str, planKey: str) -> Optional[Dict[str, Any]]:
    """Credit the initial-activation AI budget unless a subscription credit already exists.

    If the mandate's pool account already has any CREDIT transaction with
    ``referenceType`` SUBSCRIPTION (from any period or pro-rata adjustment),
    nothing is credited. Without a pool account no such credit can exist, so
    the budget is credited (``creditSubscriptionBudget`` creates the account).

    Returns:
        The created CREDIT transaction, or ``None`` if skipped or the plan has no
        per-user AI budget.
    """
    poolAccount = self.getMandateAccount(mandateId)
    if poolAccount:
        priorSubscriptionCredits = self.db.getRecordset(
            BillingTransaction,
            recordFilter={
                "accountId": poolAccount["id"],
                "transactionType": TransactionTypeEnum.CREDIT.value,
                "referenceType": ReferenceTypeEnum.SUBSCRIPTION.value,
            },
        )
        if priorSubscriptionCredits:
            return None
    return self.creditSubscriptionBudget(mandateId, planKey, periodLabel="Erstaktivierung")
def adjustAiBudgetForUserChange(self, mandateId: str, planKey: str, delta: int) -> Optional[Dict[str, Any]]:
    """Book a pro-rata AI-budget correction when users join or leave mid-period.

    Amount = ``|delta| x plan.budgetAiPerUserCHF x fraction`` rounded to 2
    decimals. ``fraction`` is the remaining share of the operative
    subscription's current period (0 if the period has non-positive length).
    ``delta > 0`` (users added) books a CREDIT; ``delta < 0`` (users removed)
    books a DEBIT, both on the mandate pool account.

    Returns:
        The created transaction, or ``None`` when the plan has no per-user AI
        budget, there is no operative subscription/period, or the amount rounds
        to zero (including ``delta == 0``).
    """
    from modules.datamodels.datamodelSubscription import getPlan
    plan = getPlan(planKey)
    perUserBudget = plan.budgetAiPerUserCHF if plan else None
    if not perUserBudget or perUserBudget <= 0:
        return None
    from modules.interfaces.interfaceDbSubscription import getRootInterface as _getSubRoot
    operative = _getSubRoot().getOperativeForMandate(mandateId)
    if not operative:
        return None
    periodStart = operative.get("currentPeriodStart")
    periodEnd = operative.get("currentPeriodEnd")
    if not periodStart or not periodEnd:
        return None
    # Period bounds are combined with an epoch-seconds "now", so presumably they are POSIX timestamps.
    nowTs = datetime.now(timezone.utc).timestamp()
    periodLength = periodEnd - periodStart
    if periodLength > 0:
        proRataFraction = max(periodEnd - nowTs, 0) / periodLength
    else:
        proRataFraction = 0
    amount = round(abs(delta) * perUserBudget * proRataFraction, 2)
    if amount <= 0:
        return None
    poolAccount = self.getOrCreateMandateAccount(mandateId)
    usersAdded = delta > 0
    txType = TransactionTypeEnum.CREDIT if usersAdded else TransactionTypeEnum.DEBIT
    sign = "+" if usersAdded else "-"
    created = self.createTransaction(BillingTransaction(
        accountId=poolAccount["id"],
        transactionType=txType,
        amount=amount,
        description=f"AI-Budget pro-rata {sign}{abs(delta)} User ({planKey})",
        referenceType=ReferenceTypeEnum.SUBSCRIPTION,
        referenceId=mandateId,
    ))
    logger.info(
        "AI-Budget pro-rata %s mandate=%s delta=%+d amount=%.2f CHF (fraction=%.4f)",
        txType.value, mandateId, delta, amount, proRataFraction,
    )
    return created
# =========================================================================
# Workflow Cost Query
# =========================================================================
def getWorkflowCost(self, workflowId: str) -> float:
    """Total ``amount`` of all transactions referencing a workflow.

    Amounts are summed regardless of transaction type. A falsy
    ``workflowId`` short-circuits to ``0.0`` without querying.
    """
    if not workflowId:
        return 0.0
    workflowRows = self.db.getRecordset(BillingTransaction, recordFilter={"workflowId": workflowId})
    return sum(row.get("amount", 0.0) for row in workflowRows)
# =========================================================================
# Statistics Operations
# =========================================================================
def getUsageStatistics(
    self,
    accountId: str,
    periodType: PeriodTypeEnum,
    year: int,
    month: int = None
) -> List[Dict[str, Any]]:
    """
    Return stored UsageStatistics rows for one account, period type and year.

    Rows without ``periodStart`` are dropped. For ``PeriodTypeEnum.DAY`` a truthy
    ``month`` additionally restricts the result to that month. The result is
    sorted by ``periodStart`` ascending. ``periodStart`` is expected to be a
    date/datetime (``.year``/``.month`` are read).

    Args:
        accountId: Account ID.
        periodType: Period granularity (DAY, MONTH, YEAR).
        year: Calendar year to keep.
        month: Month filter, applied only for DAY statistics.

    Returns:
        List of statistics dicts.
    """
    rows = self.db.getRecordset(
        UsageStatistics,
        recordFilter={"accountId": accountId, "periodType": periodType.value},
    )
    selected = []
    for stat in rows:
        start = stat.get("periodStart")
        if not start or start.year != year:
            continue
        selected.append(stat)
    if month and periodType == PeriodTypeEnum.DAY:
        selected = [stat for stat in selected if stat["periodStart"].month == month]
    selected.sort(key=lambda stat: stat.get("periodStart", date.min))
    return selected
def calculateStatisticsFromTransactions(
    self,
    accountId: str,
    startDate: date,
    endDate: date
) -> Dict[str, Any]:
    """
    Aggregate the DEBIT (usage) transactions of one account over a date range.

    Args:
        accountId: Account ID.
        startDate: Start date (passed through to ``getTransactions``).
        endDate: End date (passed through to ``getTransactions``).

    Returns:
        Dict with ``totalCostCHF``, ``transactionCount`` and the per-group
        sums ``costByProvider``, ``costByModel`` and ``costByFeature``.
        A missing or None provider/model/feature code is grouped under
        ``"unknown"``; a missing or None amount counts as 0.

    Note:
        At most 10000 transactions are fetched, so very large ranges are
        truncated.
    """
    transactions = self.getTransactions(accountId, limit=10000, startDate=startDate, endDate=endDate)
    debitType = TransactionTypeEnum.DEBIT.value

    def _groupKey(value: Any) -> Any:
        # `.get(key, "unknown")` only covers absent keys; DB rows usually
        # carry the key with a None value. Map None like the SQL
        # aggregation's COALESCE(..., 'unknown') does.
        return "unknown" if value is None else value

    totalCost = 0
    transactionCount = 0
    costByProvider: Dict[Any, Any] = {}
    costByModel: Dict[Any, Any] = {}
    costByFeature: Dict[Any, Any] = {}

    for t in transactions:
        # Only DEBIT transactions represent usage.
        if t.get("transactionType") != debitType:
            continue
        amount = t.get("amount") or 0
        totalCost += amount
        transactionCount += 1
        for bucket, key in (
            (costByProvider, _groupKey(t.get("aicoreProvider"))),
            (costByModel, _groupKey(t.get("aicoreModel"))),
            (costByFeature, _groupKey(t.get("featureCode"))),
        ):
            bucket[key] = bucket.get(key, 0) + amount

    return {
        "totalCostCHF": totalCost,
        "transactionCount": transactionCount,
        "costByProvider": costByProvider,
        "costByModel": costByModel,
        "costByFeature": costByFeature
    }
# =========================================================================
# Utility Methods
# =========================================================================
def getBalancesForUser(self, userId: str) -> List[BillingBalanceResponse]:
    """
    Collect the mandate pool balance for every mandate the user belongs to.

    The user is shown the shared pool account of each mandate. Mandates that
    cannot be loaded, are disabled, have no billing settings, or yield no
    pool account are skipped.

    Args:
        userId: User ID.

    Returns:
        One BillingBalanceResponse per qualifying mandate. On error the
        problem is logged and the balances gathered so far are returned.
    """
    from modules.interfaces.interfaceDbApp import getRootInterface

    collected: List[BillingBalanceResponse] = []
    try:
        # The privileged root interface bypasses RBAC for mandate/user
        # lookups; access is implied by the user's UserMandate memberships.
        root = getRootInterface()
        for membership in root.getUserMandates(userId):
            mandateId = getattr(membership, 'mandateId', None)
            if not mandateId and isinstance(membership, dict):
                mandateId = membership.get("mandateId")
            if not mandateId:
                continue

            mandate = root.getMandate(mandateId)
            if not mandate or not getattr(mandate, "enabled", True):
                continue

            label = getattr(mandate, 'label', None) or getattr(mandate, 'name', None)
            if not label and isinstance(mandate, dict):
                label = mandate.get("label") or mandate.get("name", None)
            mandateName = label or f"NA({mandateId})"

            if not self.getSettings(mandateId):
                continue
            pool = self.getOrCreateMandateAccount(mandateId)
            if not pool:
                continue

            poolBalance = pool.get("balance", 0.0)
            threshold = pool.get("warningThreshold", 0.0)
            collected.append(BillingBalanceResponse(
                mandateId=mandateId,
                mandateName=mandateName,
                balance=poolBalance,
                warningThreshold=threshold,
                isWarning=poolBalance <= threshold,
            ))
    except Exception as e:
        logger.error(f"Error getting balances for user: {e}")
    return collected
def getTransactionsForUser(self, userId: str, limit: int = 100) -> List[Dict[str, Any]]:
    """
    Return the newest transactions of a user's own accounts across mandates.

    For each mandate membership of the user that has billing settings and a
    user account, up to ``limit`` transactions of that account are fetched
    and tagged with ``mandateId`` and ``mandateName``.

    Args:
        userId: User ID.
        limit: Maximum number of transactions per account and in the result.

    Returns:
        Transaction dicts sorted by ``sysCreatedAt`` descending, cut to
        ``limit``. Errors are logged; rows collected before the error are
        still returned.
    """
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    collected: List[Dict[str, Any]] = []
    try:
        app = getAppInterface(self.currentUser)
        for membership in app.getUserMandates(userId):
            mandateId = getattr(membership, 'mandateId', None)
            if not mandateId and isinstance(membership, dict):
                mandateId = membership.get("mandateId")
            if not mandateId or not self.getSettings(mandateId):
                continue

            account = self.getUserAccount(mandateId, userId)
            if not account:
                continue

            rows = self.getTransactions(account["id"], limit=limit)
            mandate = app.getMandate(mandateId)
            fallbackName = f"NA({mandateId})"
            mandateName = fallbackName
            if mandate:
                label = getattr(mandate, 'label', None) or getattr(mandate, 'name', None)
                if not label and isinstance(mandate, dict):
                    label = mandate.get("label") or mandate.get("name", None)
                mandateName = label or fallbackName

            for row in rows:
                row["mandateId"] = mandateId
                row["mandateName"] = mandateName
                collected.append(row)
    except Exception as e:
        logger.error(f"Error getting transactions for user: {e}")

    _sortBillingTransactionsBySysCreatedAtDesc(collected, "getTransactionsForUser")
    return collected[:limit]
# =========================================================================
# Mandate View Operations (Admin-Level)
# =========================================================================
def getMandateBalances(self, mandateIds: List[str] = None) -> List[Dict[str, Any]]:
    """
    Return the pool balance and user-account count of each billed mandate.

    Args:
        mandateIds: Mandates to report; falsy means every mandate that has a
            BillingSettings record. Mandates without settings are skipped.

    Returns:
        List of dicts with ``mandateId``, ``mandateName``, ``totalBalance``
        (pool account balance, 0.0 without a pool account), ``userCount``
        (accounts of the mandate that carry a userId) and
        ``warningThresholdPercent``. Errors are logged and yield the entries
        collected so far.
    """
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    report: List[Dict[str, Any]] = []
    try:
        app = getAppInterface(self.currentUser)

        if mandateIds:
            settingsRows = [s for s in (self.getSettings(mId) for mId in mandateIds) if s]
        else:
            settingsRows = self.db.getRecordset(BillingSettings)

        for settings in settingsRows:
            mandateId = settings.get("mandateId")
            if not mandateId:
                continue

            mandate = app.getMandate(mandateId)
            fallbackName = f"NA({mandateId})"
            mandateName = fallbackName
            if mandate:
                label = getattr(mandate, 'label', None) or getattr(mandate, 'name', None)
                if not label and isinstance(mandate, dict):
                    label = mandate.get("label") or mandate.get("name", None)
                mandateName = label or fallbackName

            mandateAccounts = self.db.getRecordset(
                BillingAccount,
                recordFilter={"mandateId": mandateId}
            )
            personalAccounts = len([acc for acc in mandateAccounts if acc.get("userId")])
            pool = self.getMandateAccount(mandateId)

            report.append({
                "mandateId": mandateId,
                "mandateName": mandateName,
                "totalBalance": pool.get("balance", 0.0) if pool else 0.0,
                "userCount": personalAccounts,
                "warningThresholdPercent": settings.get("warningThresholdPercent", 10.0),
            })
    except Exception as e:
        logger.error(f"Error getting mandate balances: {e}")
    return report
def getMandateTransactions(self, mandateIds: List[str] = None, limit: int = 100) -> List[Dict[str, Any]]:
    """
    Return the newest transactions of one or more mandates with mandate info.

    Args:
        mandateIds: Mandates to query; falsy means every mandate that has a
            BillingSettings record.
        limit: Maximum number per mandate and in the merged result.

    Returns:
        Transaction dicts with added ``mandateId`` and ``mandateName``,
        sorted by ``sysCreatedAt`` descending and cut to ``limit``. Errors are
        logged; rows collected before the error are still returned.
    """
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    merged: List[Dict[str, Any]] = []
    try:
        app = getAppInterface(self.currentUser)

        if mandateIds:
            targets = mandateIds
        else:
            targets = [
                s.get("mandateId")
                for s in self.db.getRecordset(BillingSettings)
                if s.get("mandateId")
            ]

        for mandateId in targets:
            rows = self.getTransactionsByMandate(mandateId, limit=limit)
            mandate = app.getMandate(mandateId)
            fallbackName = f"NA({mandateId})"
            mandateName = fallbackName
            if mandate:
                label = getattr(mandate, 'label', None) or getattr(mandate, 'name', None)
                if not label and isinstance(mandate, dict):
                    label = mandate.get("label") or mandate.get("name", None)
                mandateName = label or fallbackName
            for row in rows:
                row["mandateId"] = mandateId
                row["mandateName"] = mandateName
                merged.append(row)
    except Exception as e:
        logger.error(f"Error getting mandate transactions: {e}")

    # Newest first, then cap the merged list.
    _sortBillingTransactionsBySysCreatedAtDesc(merged, "getMandateTransactions")
    return merged[:limit]
# =========================================================================
# User View Operations (User-Level with RBAC)
# =========================================================================
def getUserBalancesForMandates(self, mandateIds: List[str] = None) -> List[Dict[str, Any]]:
    """
    List personal (user-level) billing accounts with their balances.

    Only accounts that carry a ``userId`` and whose mandate has billing
    settings are included. Each entry is enriched with mandate and user
    display names (``NA(<id>)`` when a name cannot be resolved).

    Args:
        mandateIds: Restrict to accounts of these mandates; falsy means all.

    Returns:
        List of dicts with accountId, mandateId, mandateName, userId,
        userName, balance, warningThreshold, isWarning
        (balance <= warningThreshold) and enabled. Errors are logged and
        yield the entries collected so far.
    """
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    entries: List[Dict[str, Any]] = []
    try:
        app = getAppInterface(self.currentUser)

        personalAccounts = [acc for acc in self.db.getRecordset(BillingAccount) if acc.get("userId")]
        if mandateIds:
            wanted = set(mandateIds)
            personalAccounts = [acc for acc in personalAccounts if acc.get("mandateId") in wanted]

        settingsByMandate = {s.get("mandateId"): s for s in self.db.getRecordset(BillingSettings)}

        userNames: Dict[str, str] = {}
        for uid in {acc.get("userId") for acc in personalAccounts if acc.get("userId")}:
            user = app.getUser(uid)
            if not user:
                continue
            isDict = isinstance(user, dict)
            displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isDict else None)
            username = getattr(user, 'username', None) or (user.get("username") if isDict else None)
            userNames[uid] = displayName or username or f"NA({uid})"

        mandateNames: Dict[str, str] = {}
        for mid in {acc.get("mandateId") for acc in personalAccounts if acc.get("mandateId")}:
            mandate = app.getMandate(mid)
            if not mandate:
                continue
            label = getattr(mandate, 'label', None) or getattr(mandate, 'name', None)
            if not label and isinstance(mandate, dict):
                label = mandate.get("label") or mandate.get("name", None)
            mandateNames[mid] = label or f"NA({mid})"

        for account in personalAccounts:
            mandateId = account.get("mandateId")
            userId = account.get("userId")
            if not (mandateId and userId):
                continue
            if not settingsByMandate.get(mandateId):
                continue

            balance = account.get("balance", 0.0)
            threshold = account.get("warningThreshold", 0.0)
            entries.append({
                "accountId": account.get("id"),
                "mandateId": mandateId,
                "mandateName": mandateNames.get(mandateId) or f"NA({mandateId})",
                "userId": userId,
                "userName": userNames.get(userId) or f"NA({userId})",
                "balance": balance,
                "warningThreshold": threshold,
                "isWarning": balance <= threshold,
                "enabled": account.get("enabled", True)
            })
    except Exception as e:
        logger.error(f"Error getting user balances for mandates: {e}")
    return entries
@staticmethod
def _mapPaginationColumns(pagination: PaginationParams) -> PaginationParams:
"""Remap frontend column names to DB column names in filters and sort."""
_COL_MAP: dict = {}
_ENRICHED_COLS = {"mandateName", "userName", "mandateId", "userId"}
import copy
p = copy.deepcopy(pagination)
if p.filters:
mapped = {}
for k, v in p.filters.items():
if k in _ENRICHED_COLS:
continue
mapped[_COL_MAP.get(k, k)] = v
p.filters = mapped
if p.sort:
mapped = []
for s in p.sort:
field = s.get("field", "") if isinstance(s, dict) else getattr(s, "field", "")
if field in _ENRICHED_COLS:
continue
newField = _COL_MAP.get(field, field)
if isinstance(s, dict):
mapped.append({**s, "field": newField})
else:
mapped.append({"field": newField, "direction": getattr(s, "direction", "asc")})
p.sort = mapped if mapped else [{"field": "sysCreatedAt", "direction": "desc"}]
return p
def getTransactionsForMandatesPaginated(
    self,
    mandateIds: Optional[List[str]],
    pagination: PaginationParams,
    scope: str = "all",
    userId: Optional[str] = None,
) -> PaginatedResult:
    """
    SQL-level paginated transactions across multiple mandates.

    A single paginated query runs over all billing accounts of the given
    mandates (``accountId = ANY(...)``) with ORDER BY / LIMIT / OFFSET. A
    free-text ``search`` filter is delegated to
    ``_searchTransactionsPaginated``. Enrichment (mandateId, mandateName,
    userId, userName) uses batch lookups and covers only the returned page.

    Args:
        mandateIds: Mandates in scope; falsy means all billing accounts.
        pagination: Page/sort/filter parameters using frontend column names.
        scope: Currently unused; kept for signature compatibility.
        userId: If set, only transactions with this ``createdByUserId``.

    Returns:
        PaginatedResult of enriched transaction dicts. Any error is logged
        (with traceback) and an empty result is returned.
    """
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    try:
        mappedPagination = self._mapPaginationColumns(pagination)

        allAccounts = self.db.getRecordset(BillingAccount)
        if mandateIds:
            # Build the lookup set once instead of once per account.
            wantedMandates = set(mandateIds)
            allAccounts = [a for a in allAccounts if a.get("mandateId") in wantedMandates]
        accountIds = [a.get("id") for a in allAccounts if a.get("id")]
        if not accountIds:
            return PaginatedResult(items=[], totalItems=0, totalPages=0)

        # The generic SQL search only covers TEXT columns of the
        # BillingTransaction table. A free-text term therefore goes through
        # a custom query that also covers the enriched columns (mandateName,
        # userName) and the numeric amount column.
        searchTerm: Optional[str] = None
        if mappedPagination and mappedPagination.filters:
            raw = mappedPagination.filters.get("search")
            if isinstance(raw, str) and raw.strip():
                searchTerm = raw.strip()

        if searchTerm:
            searchResult = self._searchTransactionsPaginated(
                allAccounts=allAccounts,
                accountIds=accountIds,
                userId=userId,
                searchTerm=searchTerm,
                pagination=mappedPagination,
            )
            pageItems = searchResult["items"]
            totalItems = searchResult["totalItems"]
            totalPages = searchResult["totalPages"]
        else:
            recordFilter: Dict[str, Any] = {"accountId": accountIds}
            if userId:
                recordFilter["createdByUserId"] = userId
            result = self.db.getRecordsetPaginated(
                BillingTransaction,
                pagination=mappedPagination,
                recordFilter=recordFilter,
            )
            # The connector may return a plain dict or a PaginatedResult.
            if isinstance(result, dict):
                pageItems = result.get("items", [])
                totalItems = result.get("totalItems", 0)
                totalPages = result.get("totalPages", 0)
            else:
                pageItems = result.items
                totalItems = result.totalItems
                totalPages = result.totalPages

        accountMap = {a.get("id"): a for a in allAccounts}

        # Collect only the IDs needed for this page's enrichment.
        pageUserIds = set()
        pageMandateIds = set()
        for t in pageItems:
            acc = accountMap.get(t.get("accountId"), {})
            mid = acc.get("mandateId")
            uid = t.get("createdByUserId") or acc.get("userId")
            if uid:
                pageUserIds.add(uid)
            if mid:
                pageMandateIds.add(mid)

        appInterface = getAppInterface(self.currentUser)
        userMap: Dict[str, str] = {}
        if pageUserIds:
            users = appInterface.getUsersByIds(list(pageUserIds))
            for uid, u in users.items():
                userMap[uid] = getattr(u, "displayName", None) or getattr(u, "username", None) or f"NA({uid})"
        mandateMap: Dict[str, str] = {}
        if pageMandateIds:
            mandates = appInterface.getMandatesByIds(list(pageMandateIds))
            for mid, m in mandates.items():
                mandateMap[mid] = getattr(m, "label", None) or getattr(m, "name", None) or f"NA({mid})"

        enriched = []
        for t in pageItems:
            row = dict(t)
            acc = accountMap.get(row.get("accountId"), {})
            mid = acc.get("mandateId")
            # The transaction creator wins over the account owner.
            txUserId = row.get("createdByUserId") or acc.get("userId")
            row["mandateId"] = mid
            row["mandateName"] = mandateMap.get(mid) or (f"NA({mid})" if mid else None)
            row["userId"] = txUserId
            row["userName"] = userMap.get(txUserId) or (f"NA({txUserId})" if txUserId else None)
            enriched.append(row)

        return PaginatedResult(items=enriched, totalItems=totalItems, totalPages=totalPages)
    except Exception as e:
        logger.error(f"Error in getTransactionsForMandatesPaginated: {e}", exc_info=True)
        return PaginatedResult(items=[], totalItems=0, totalPages=0)
def _searchTransactionsPaginated(
    self,
    allAccounts: List[Dict[str, Any]],
    accountIds: List[str],
    userId: Optional[str],
    searchTerm: str,
    pagination: PaginationParams,
) -> Dict[str, Any]:
    """
    Custom paginated free-text search over BillingTransaction.

    Besides the TEXT columns of the table, the search covers the enriched
    columns ``mandateName`` and ``userName`` and the numeric ``amount``
    column. Mandate and user names are first resolved to IDs via the app DB.
    All match conditions are OR-combined. The account scope, the optional
    ``userId`` restriction and the remaining (non-``search``) pagination
    filters are AND-combined with them.

    Args:
        allAccounts: Billing account dicts in scope; used to map matching
            mandates to account IDs.
        accountIds: IDs of the accounts in scope.
        userId: If set, only transactions with this ``createdByUserId``.
        searchTerm: Non-empty search text. LIKE wildcards (``%``, ``_``)
            and backslashes in it are matched literally.
        pagination: Page, page size, sort and filters.

    Returns:
        Dict with ``items`` (parsed record dicts), ``totalItems`` and
        ``totalPages``. On SQL failure the error is logged, the billing
        connection is rolled back and an empty page is returned.
    """
    import copy
    import math
    from modules.connectors.connectorDbPostgre import getModelFields, parseRecordFields
    from modules.datamodels.datamodelUam import UserInDB
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    table = BillingTransaction.__name__
    fields = getModelFields(BillingTransaction)
    # Escape LIKE metacharacters so that "%", "_" and "\" typed by the user
    # match literally instead of acting as wildcards. Backslash is
    # PostgreSQL's default LIKE/ILIKE escape character.
    escapedTerm = searchTerm.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    pattern = f"%{escapedTerm}%"

    # Resolve matching user / mandate IDs via the app DB. It is separate
    # from the billing DB and hosts the UserInDB / Mandate tables.
    matchingUserIds: List[str] = []
    matchingMandateIds: List[str] = []
    appInterface = None
    try:
        appInterface = getAppInterface(self.currentUser)
        appInterface.db._ensure_connection()
        with appInterface.db.connection.cursor() as cur:
            if appInterface.db._ensureTableExists(UserInDB):
                cur.execute(
                    'SELECT "id" FROM "UserInDB" WHERE '
                    'COALESCE("username", \'\') ILIKE %s OR '
                    'COALESCE("fullName", \'\') ILIKE %s OR '
                    'COALESCE("email", \'\') ILIKE %s',
                    (pattern, pattern, pattern),
                )
                matchingUserIds = [r["id"] for r in cur.fetchall() if r.get("id")]
            if appInterface.db._ensureTableExists(Mandate):
                cur.execute(
                    'SELECT "id" FROM "Mandate" WHERE '
                    'COALESCE("label", \'\') ILIKE %s OR '
                    'COALESCE("name", \'\') ILIKE %s',
                    (pattern, pattern),
                )
                matchingMandateIds = [r["id"] for r in cur.fetchall() if r.get("id")]
    except Exception as e:
        logger.warning(f"_searchTransactionsPaginated: user/mandate resolution failed: {e}")
        # A failed statement leaves the app-DB transaction aborted. Roll it
        # back so later queries on that connection are not rejected.
        if appInterface is not None:
            try:
                appInterface.db.connection.rollback()
            except Exception:
                pass

    matchingMandateSet = set(matchingMandateIds)
    matchingAccountIds = [
        a.get("id") for a in allAccounts
        if a.get("id") and a.get("mandateId") in matchingMandateSet
    ]

    # Also interpret the term as a number (decimal comma accepted) for an
    # exact amount match.
    amountVal: Optional[float]
    try:
        amountVal = float(searchTerm.replace(",", "."))
    except ValueError:
        amountVal = None

    whereParts: List[str] = ['"accountId" = ANY(%s)']
    whereValues: List[Any] = [accountIds]
    if userId:
        whereParts.append('"createdByUserId" = %s')
        whereValues.append(userId)

    orParts: List[str] = []
    orValues: List[Any] = []
    textCols = [c for c, t in fields.items() if t == "TEXT"]
    for col in textCols:
        orParts.append(f'COALESCE("{col}"::TEXT, \'\') ILIKE %s')
        orValues.append(pattern)
    if matchingUserIds:
        orParts.append('"createdByUserId" = ANY(%s)')
        orValues.append(matchingUserIds)
    if matchingAccountIds:
        orParts.append('"accountId" = ANY(%s)')
        orValues.append(matchingAccountIds)
    orParts.append('"amount"::TEXT ILIKE %s')
    orValues.append(pattern)
    if amountVal is not None:
        orParts.append('"amount" = %s')
        orValues.append(amountVal)
    whereParts.append(f"({' OR '.join(orParts)})")
    whereValues.extend(orValues)

    # Reuse the generic clause builder for the remaining structured filters.
    # It is fed a filter-only copy (no sort, minimal page), and only its
    # WHERE contribution is used.
    extraWhere = ""
    extraValues: List[Any] = []
    if pagination and pagination.filters and any(k != "search" for k in pagination.filters):
        try:
            filterOnly = copy.deepcopy(pagination)
            filterOnly.filters = {k: v for k, v in filterOnly.filters.items() if k != "search"}
            filterOnly.sort = []
            filterOnly.page = 1
            filterOnly.pageSize = 1
            ew, _, _, builtValues, _ = self.db._buildPaginationClauses(
                BillingTransaction, filterOnly, recordFilter=None
            )
            if ew:
                newValues = list(builtValues)
                # NOTE(review): assumes the builder emits the clause as
                # " WHERE ..."; confirm against _buildPaginationClauses.
                extraWhere = ew.replace(" WHERE ", " AND ", 1)
                extraValues = newValues
        except Exception as e:
            logger.warning(f"_searchTransactionsPaginated: extra-filter build failed: {e}")

    whereClause = " WHERE " + " AND ".join(whereParts) + extraWhere
    whereValues.extend(extraValues)

    # ORDER BY from pagination.sort. Only real columns are accepted, which
    # keeps field names from reaching the SQL unchecked.
    validColumns = set(fields.keys())
    orderParts: List[str] = []
    sortedById = False
    sortSpecs = (pagination.sort or []) if pagination else []
    for sf in sortSpecs:
        sfField = sf.get("field") if isinstance(sf, dict) else getattr(sf, "field", None)
        sfDir = sf.get("direction", "asc") if isinstance(sf, dict) else getattr(sf, "direction", "asc")
        if sfField and sfField in validColumns:
            direction = "DESC" if str(sfDir).lower() == "desc" else "ASC"
            colType = fields.get(sfField, "TEXT")
            if colType == "BOOLEAN":
                orderParts.append(f'COALESCE("{sfField}", FALSE) {direction}')
            else:
                orderParts.append(f'"{sfField}" {direction} NULLS LAST')
            if sfField == "id":
                sortedById = True
    # End with the primary key so OFFSET paging is deterministic when the
    # requested sort column has ties. Otherwise rows can repeat or be
    # skipped between pages.
    if not sortedById:
        orderParts.append('"id"')
    orderClause = " ORDER BY " + ", ".join(orderParts)

    # Guard against 0/None/negative values (ZeroDivision / invalid OFFSET).
    pageSize = max(1, int((pagination.pageSize if pagination else None) or 50))
    page = max(1, int((pagination.page if pagination else None) or 1))
    offset = (page - 1) * pageSize
    limitClause = f" LIMIT {pageSize} OFFSET {offset}"

    try:
        self.db._ensure_connection()
        with self.db.connection.cursor() as cur:
            countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
            cur.execute(countSql, whereValues)
            totalItems = cur.fetchone()["count"]

            dataSql = f'SELECT * FROM "{table}"{whereClause}{orderClause}{limitClause}'
            cur.execute(dataSql, whereValues)
            records = [dict(row) for row in cur.fetchall()]
            for rec in records:
                parseRecordFields(rec, fields, f"search table {table}")

        totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0
        return {"items": records, "totalItems": totalItems, "totalPages": totalPages}
    except Exception as e:
        logger.error(f"_searchTransactionsPaginated SQL error: {e}", exc_info=True)
        try:
            self.db.connection.rollback()
        except Exception:
            pass
        return {"items": [], "totalItems": 0, "totalPages": 0}
def _buildScopeFilter(
    self,
    mandateIds: Optional[List[str]],
    scope: str = "all",
    userId: Optional[str] = None,
    startTs: Optional[float] = None,
    endTs: Optional[float] = None,
) -> tuple:
    """
    Build WHERE-clause parts for scoped DEBIT transaction queries.

    Args:
        mandateIds: Restrict to billing accounts of these mandates; falsy
            means all accounts.
        scope: Currently unused; accepted for signature compatibility.
        userId: If set, adds a ``"createdByUserId" = userId`` condition.
        startTs: If set, inclusive lower bound on ``sysCreatedAt``.
        endTs: If set, exclusive upper bound on ``sysCreatedAt``.

    Returns:
        4-tuple ``(conditions, values, accountIds, allAccounts)``:
        ``conditions`` are SQL fragments with ``%s`` placeholders meant to be
        AND-joined, ``values`` the parameters in placeholder order,
        ``accountIds`` the IDs of the scoped accounts and ``allAccounts`` the
        scoped account dicts. When no account is in scope, the first three
        are empty lists.
    """
    allAccounts = self.db.getRecordset(BillingAccount)
    if mandateIds:
        mandateSet = set(mandateIds)
        allAccounts = [a for a in allAccounts if a.get("mandateId") in mandateSet]
    accountIds = [a.get("id") for a in allAccounts if a.get("id")]
    if not accountIds:
        return [], [], [], allAccounts

    # Use the enum's stored value rather than a hard-coded literal so the
    # filter always matches how DEBIT transactions are persisted.
    conditions = ['"accountId" = ANY(%s)', '"transactionType" = %s']
    values: list = [accountIds, TransactionTypeEnum.DEBIT.value]
    if userId:
        conditions.append('"createdByUserId" = %s')
        values.append(userId)
    if startTs is not None:
        conditions.append('"sysCreatedAt" >= %s')
        values.append(startTs)
    if endTs is not None:
        conditions.append('"sysCreatedAt" < %s')
        values.append(endTs)
    return conditions, values, accountIds, allAccounts
def getTransactionStatisticsAggregated(
    self,
    mandateIds: Optional[List[str]],
    scope: str = "all",
    userId: Optional[str] = None,
    startTs: Optional[float] = None,
    endTs: Optional[float] = None,
    bucketSize: str = "month",
) -> Dict[str, Any]:
    """
    Pure SQL aggregation of DEBIT transactions for statistics (no row loading).

    Args:
        mandateIds: Mandates in scope; falsy means all billing accounts.
        scope: Forwarded to ``_buildScopeFilter``.
        userId: If set, only transactions with this ``createdByUserId``.
        startTs: Inclusive lower bound on ``sysCreatedAt``.
        endTs: Exclusive upper bound on ``sysCreatedAt``.
        bucketSize: Time-series granularity, ``'day' | 'month' | 'year'``.
            It only controls bucketing; the range comes from startTs/endTs.

    Returns:
        Dict with ``totalCost``, ``transactionCount``, ``costByProvider``,
        ``costByModel``, ``costByAccountId`` (to be enriched to mandate names
        by the caller), ``costByAccountFeature`` (list of
        accountId/featureCode/total dicts), ``timeSeries`` (list of
        date/cost/count dicts) and ``_allAccounts`` (the scoped account
        dicts). Amounts are rounded to 4 decimals. An invalid ``bucketSize``
        or any error yields ``_emptyStats()``.
    """
    bucketSpec = {
        "day": ("day", "%Y-%m-%d"),
        "month": ("month", "%Y-%m"),
        "year": ("year", "%Y"),
    }.get(bucketSize)
    if bucketSpec is None:
        # Validate up front so no queries run for an unsupported bucket size.
        logger.error(
            "Error in getTransactionStatisticsAggregated: Invalid bucketSize: %r (expected day|month|year)",
            bucketSize,
        )
        return self._emptyStats()
    truncUnit, labelFormat = bucketSpec
    table = BillingTransaction.__name__

    def _money(value: Any) -> float:
        # SUM() is NULL when every summed amount is NULL; count that as 0.
        return round(float(value or 0), 4)

    try:
        if not self.db._ensureTableExists(BillingTransaction):
            return self._emptyStats()

        conditions, values, accountIds, allAccounts = self._buildScopeFilter(
            mandateIds, scope, userId, startTs, endTs
        )
        if not accountIds:
            return self._emptyStats()

        whereClause = " WHERE " + " AND ".join(conditions)
        self.db._ensure_connection()
        result: Dict[str, Any] = {}

        with self.db.connection.cursor() as cur:

            def _sumBy(groupExpr: str) -> Dict[Any, float]:
                # One GROUP BY query keyed on groupExpr, largest totals first.
                cur.execute(
                    f'SELECT {groupExpr} AS grp, SUM("amount") AS total '
                    f'FROM "{table}"{whereClause} GROUP BY grp ORDER BY total DESC',
                    values,
                )
                return {r["grp"]: _money(r["total"]) for r in cur.fetchall()}

            # 1) Totals
            cur.execute(
                f'SELECT COALESCE(SUM("amount"), 0) AS total, COUNT(*) AS cnt FROM "{table}"{whereClause}',
                values,
            )
            row = cur.fetchone()
            result["totalCost"] = _money(row["total"])
            result["transactionCount"] = int(row["cnt"])

            # 2-4) Per provider, per model, per account
            result["costByProvider"] = _sumBy('COALESCE("aicoreProvider", \'unknown\')')
            result["costByModel"] = _sumBy('COALESCE("aicoreModel", \'unknown\')')
            result["costByAccountId"] = _sumBy('"accountId"')

            # 5) Per account + featureCode (basis for the caller's costByFeature)
            cur.execute(
                f'SELECT "accountId", COALESCE("featureCode", \'unknown\') AS fc, SUM("amount") AS total '
                f'FROM "{table}"{whereClause} GROUP BY "accountId", fc ORDER BY total DESC',
                values,
            )
            result["costByAccountFeature"] = [
                {"accountId": r["accountId"], "featureCode": r["fc"], "total": _money(r["total"])}
                for r in cur.fetchall()
            ]

            # 6) Time series via DATE_TRUNC on the epoch timestamp.
            # NOTE(review): TO_TIMESTAMP yields timestamptz, so bucket
            # boundaries follow the session TimeZone; confirm it is UTC.
            truncExpr = f"DATE_TRUNC('{truncUnit}', TO_TIMESTAMP(\"sysCreatedAt\"))"
            cur.execute(
                f'SELECT {truncExpr} AS bucket, SUM("amount") AS total, COUNT(*) AS cnt '
                f'FROM "{table}"{whereClause} AND "sysCreatedAt" IS NOT NULL '
                f'GROUP BY bucket ORDER BY bucket',
                values,
            )
            result["timeSeries"] = [
                {
                    "date": r["bucket"].strftime(labelFormat) if r["bucket"] else "unknown",
                    "cost": _money(r["total"]),
                    "count": int(r["cnt"]),
                }
                for r in cur.fetchall()
            ]

        self.db.connection.commit()
        result["_allAccounts"] = allAccounts
        return result
    except Exception as e:
        logger.error(f"Error in getTransactionStatisticsAggregated: {e}", exc_info=True)
        try:
            self.db.connection.rollback()
        except Exception:
            pass
        return self._emptyStats()
@staticmethod
def _emptyStats() -> Dict[str, Any]:
return {
"totalCost": 0.0,
"transactionCount": 0,
"costByProvider": {},
"costByModel": {},
"costByAccountId": {},
"costByAccountFeature": [],
"timeSeries": [],
"_allAccounts": [],
}
def getTransactionDistinctValues(
    self,
    mandateIds: Optional[List[str]],
    column: str,
    pagination: Optional[PaginationParams] = None,
    scope: str = "all",
    userId: Optional[str] = None,
) -> List[str]:
    """
    SQL DISTINCT for filter values on BillingTransaction, scoped by mandates.

    Args:
        mandateIds: Mandates in scope; falsy means all billing accounts.
        column: Frontend column name. ``mandateName``/``userName`` are
            resolved through lookups; other names go through ``_COLUMN_MAP``.
        pagination: Optional active filters that further narrow the values.
        scope: Currently unused; kept for signature compatibility.
        userId: If set, only transactions with this ``createdByUserId``.

    Returns:
        Distinct values, or ``[]`` when no account is in scope or on error
        (logged with traceback).
    """
    # NOTE(review): "mandateId" is answered with distinct accountId values,
    # i.e. account IDs rather than mandate IDs; confirm the caller maps them.
    _COLUMN_MAP = {
        "mandateId": "accountId",
        "mandateName": "accountId",
    }
    dbColumn = _COLUMN_MAP.get(column, column)
    mappedPagination = self._mapPaginationColumns(pagination) if pagination else None
    try:
        allAccounts = self.db.getRecordset(BillingAccount)
        if mandateIds:
            # Build the lookup set once instead of once per account.
            wantedMandates = set(mandateIds)
            allAccounts = [a for a in allAccounts if a.get("mandateId") in wantedMandates]
        accountIds = [a.get("id") for a in allAccounts if a.get("id")]
        if not accountIds:
            return []

        recordFilter: Dict[str, Any] = {"accountId": accountIds}
        if userId:
            recordFilter["createdByUserId"] = userId

        if column in ("mandateName", "userName"):
            return self._getEnrichedDistinctValues(column, allAccounts, recordFilter, mappedPagination)

        return self.db.getDistinctColumnValues(
            BillingTransaction, dbColumn, mappedPagination, recordFilter
        )
    except Exception as e:
        logger.error(f"Error in getTransactionDistinctValues({column}): {e}", exc_info=True)
        return []
def _getEnrichedDistinctValues(
    self,
    column: str,
    allAccounts: List[Dict],
    recordFilter: Dict[str, Any],
    pagination: Optional[PaginationParams],
) -> List[str]:
    """
    Resolve distinct display names for an enriched column via batch lookups.

    ``mandateName``: the names of all mandates owning one of ``allAccounts``.
    ``userName``: the names of the distinct ``createdByUserId`` values of the
    transactions matching ``recordFilter`` and ``pagination``. Unnamed
    entities become ``NA(<id>)``. Results are sorted case-insensitively;
    any other column yields ``[]``.
    """
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    def _caseInsensitive(names):
        return sorted(names, key=lambda v: v.lower())

    if column == "mandateName":
        ownerMandateIds = list({acc.get("mandateId") for acc in allAccounts if acc.get("mandateId")})
        found = getAppInterface(self.currentUser).getMandatesByIds(ownerMandateIds)
        labels = set()
        for mid, mandate in found.items():
            labels.add(getattr(mandate, "label", None) or getattr(mandate, "name", None) or f"NA({mid})")
        return _caseInsensitive(labels)

    if column != "userName":
        return []

    creatorIds = self.db.getDistinctColumnValues(BillingTransaction, "createdByUserId", pagination, recordFilter)
    if not creatorIds:
        return []
    found = getAppInterface(self.currentUser).getUsersByIds(creatorIds)
    names = set()
    for uid, user in found.items():
        names.add(getattr(user, "displayName", None) or getattr(user, "username", None) or f"NA({uid})")
    return _caseInsensitive(names)
def getUserTransactionsForMandates(self, mandateIds: List[str] = None, limit: int = 100) -> List[Dict[str, Any]]:
    """
    Return the newest transactions of all billing accounts of given mandates.

    Every account of the selected mandates is included (accounts with and
    without a userId). Up to ``limit`` transactions are fetched per account.
    Each one is tagged with ``mandateId``, ``mandateName``, ``userId`` (the
    ``createdByUserId`` if set, otherwise the account's user) and
    ``userName``.

    Args:
        mandateIds: Restrict to accounts of these mandates; falsy means all.
        limit: Maximum per account and for the merged result.

    Returns:
        Transaction dicts sorted by ``sysCreatedAt`` descending, cut to
        ``limit``. Errors are logged and yield what was collected so far.
    """
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    def _userLabel(user: Any, uid: str) -> str:
        isDict = isinstance(user, dict)
        displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isDict else None)
        username = getattr(user, 'username', None) or (user.get("username") if isDict else None)
        return displayName or username or f"NA({uid})"

    merged: List[Dict[str, Any]] = []
    try:
        app = getAppInterface(self.currentUser)

        accounts = self.db.getRecordset(BillingAccount)
        if mandateIds:
            wanted = set(mandateIds)
            accounts = [acc for acc in accounts if acc.get("mandateId") in wanted]

        userNames: Dict[str, str] = {}
        for uid in {acc.get("userId") for acc in accounts if acc.get("userId")}:
            user = app.getUser(uid)
            if user:
                userNames[uid] = _userLabel(user, uid)

        mandateNames: Dict[str, str] = {}
        for mid in {acc.get("mandateId") for acc in accounts if acc.get("mandateId")}:
            mandate = app.getMandate(mid)
            if not mandate:
                continue
            label = getattr(mandate, 'label', None) or getattr(mandate, 'name', None)
            if not label and isinstance(mandate, dict):
                label = mandate.get("label") or mandate.get("name", None)
            mandateNames[mid] = label or f"NA({mid})"

        # (transaction, owning mandate, owning user) triples
        fetched = []
        for account in accounts:
            accountId = account.get("id")
            if not accountId:
                continue
            owningMandate = account.get("mandateId")
            owningUser = account.get("userId")
            for tx in self.getTransactions(accountId, limit=limit):
                fetched.append((tx, owningMandate, owningUser))

        # Creators that are not account owners still need a display name.
        creatorIds = set()
        for tx, _, _ in fetched:
            creator = tx.get("createdByUserId")
            if creator and creator not in userNames:
                creatorIds.add(creator)
        for uid in creatorIds:
            user = app.getUser(uid)
            if user:
                userNames[uid] = _userLabel(user, uid)

        for tx, owningMandate, owningUser in fetched:
            tx["mandateId"] = owningMandate
            tx["mandateName"] = mandateNames.get(owningMandate) or (f"NA({owningMandate})" if owningMandate else None)
            actorId = tx.get("createdByUserId") or owningUser
            tx["userId"] = actorId
            tx["userName"] = userNames.get(actorId) or (f"NA({actorId})" if actorId else None)
            merged.append(tx)
    except Exception as e:
        logger.error(f"Error getting user transactions for mandates: {e}")

    # Newest first, then cap the merged list.
    _sortBillingTransactionsBySysCreatedAtDesc(merged, "getUserTransactionsForMandates")
    return merged[:limit]