2145 lines
84 KiB
Python
2145 lines
84 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
||
# All rights reserved.
|
||
"""
|
||
Interface for Billing operations.
|
||
Manages billing accounts, transactions, and usage statistics.
|
||
|
||
All billing data is stored in the poweron_billing database.
|
||
"""
|
||
|
||
import logging
|
||
from typing import Dict, Any, List, Optional, Union
|
||
from datetime import date, datetime, timedelta, timezone
|
||
import uuid
|
||
|
||
from modules.connectors.connectorDbPostgre import DatabaseConnector
|
||
from modules.shared.configuration import APP_CONFIG
|
||
from modules.shared.dbRegistry import registerDatabase
|
||
from modules.shared.timeUtils import getUtcTimestamp
|
||
from modules.datamodels.datamodelUam import User, Mandate
|
||
from modules.datamodels.datamodelMembership import UserMandate
|
||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResult
|
||
from modules.datamodels.datamodelBilling import (
|
||
BillingAccount,
|
||
BillingTransaction,
|
||
BillingSettings,
|
||
StripeWebhookEvent,
|
||
UsageStatistics,
|
||
TransactionTypeEnum,
|
||
ReferenceTypeEnum,
|
||
PeriodTypeEnum,
|
||
BillingBalanceResponse,
|
||
BillingCheckResult,
|
||
STORAGE_PRICE_PER_GB_CHF,
|
||
)
|
||
|
||
# Module logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _logBillingTransactionsMissingSysCreatedAt(rows: List[Dict[str, Any]], context: str) -> None:
    """Emit a single ERROR entry listing transactions whose ``sysCreatedAt`` is None.

    At most 40 offending ids are included verbatim; any remainder is summarised
    as a ``(+N more)`` suffix. Never raises and never mutates ``rows``.

    Args:
        rows: Transaction records (dicts) to inspect.
        context: Free-form caller label included in the log line.
    """
    offenders = []
    for row in rows:
        if row.get("sysCreatedAt") is None:
            offenders.append(row.get("id"))
    if len(offenders) == 0:
        return

    limit = 40
    overflow = len(offenders) - limit
    tail = ""
    if overflow > 0:
        tail = f"; ... (+{overflow} more)"
    logger.error(
        "BillingTransaction missing sysCreatedAt (%s): count=%s; transactionIds=%s%s",
        context,
        len(offenders),
        offenders[:limit],
        tail,
    )
|
||
|
||
|
||
def _numericSysCreatedAtForSort(row: Dict[str, Any]) -> float:
    """Return ``row["sysCreatedAt"]`` as POSIX seconds (UTC), for use as a sort key.

    Accepted representations:
      * ``datetime``: naive values are interpreted as UTC, consistent with
        ``BillingObjects._parseSettingsDateTime``, rather than as host-local time.
      * ``date``: midnight UTC of that day.
      * ``int`` / ``float``: returned as float.
      * ``str``: a numeric string, or an ISO-8601 timestamp (trailing ``Z`` allowed).

    Args:
        row: Transaction record; must contain a non-None ``sysCreatedAt``.

    Returns:
        Seconds since the Unix epoch.

    Raises:
        KeyError: if ``sysCreatedAt`` is absent.
        ValueError: if a string is neither numeric nor ISO-8601.
        TypeError: for unsupported types (e.g. None).
    """
    value = row["sysCreatedAt"]
    if isinstance(value, datetime):
        # datetime.timestamp() treats naive values as *local* time; pin them to UTC
        # so mixed naive/aware/numeric rows sort consistently on any host.
        if value.tzinfo is None:
            value = value.replace(tzinfo=timezone.utc)
        return value.timestamp()
    if isinstance(value, date):
        return datetime(value.year, value.month, value.day, tzinfo=timezone.utc).timestamp()
    if isinstance(value, str):
        text = value.strip()
        try:
            return float(text)
        except ValueError:
            pass
        parsed = datetime.fromisoformat(text.replace("Z", "+00:00"))
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed.timestamp()
    return float(value)
|
||
|
||
|
||
def _sortBillingTransactionsBySysCreatedAtDesc(rows: List[Dict[str, Any]], context: str) -> None:
    """Reorder ``rows`` in place: newest ``sysCreatedAt`` first, undated rows last.

    Rows lacking ``sysCreatedAt`` are first reported through
    ``_logBillingTransactionsMissingSysCreatedAt`` and then appended after the
    dated rows in their original relative order (the sort is stable).

    Args:
        rows: Transaction dicts; the list object itself is rewritten.
        context: Caller label forwarded to the missing-timestamp log.
    """
    _logBillingTransactionsMissingSysCreatedAt(rows, context)
    dated: List[Dict[str, Any]] = []
    undated: List[Dict[str, Any]] = []
    for row in rows:
        bucket = undated if row.get("sysCreatedAt") is None else dated
        bucket.append(row)
    rows[:] = sorted(dated, key=_numericSysCreatedAtForSort, reverse=True) + undated
|
||
|
||
|
||
def _getAppDatabaseConnector() -> DatabaseConnector:
    """Build a new connector to the application (non-billing) database.

    Settings come from APP_CONFIG: database defaults to ``poweron_app``, host to
    ``localhost``, port to ``5432``. A fresh connector is constructed on every call.
    """
    config = APP_CONFIG
    databaseName = config.get("DB_DATABASE", "poweron_app")
    host = config.get("DB_HOST", "localhost")
    port = int(config.get("DB_PORT", "5432"))
    user = config.get("DB_USER")
    password = config.get("DB_PASSWORD_SECRET")
    return DatabaseConnector(
        dbDatabase=databaseName,
        dbHost=host,
        dbPort=port,
        dbUser=user,
        dbPassword=password,
    )
|
||
|
||
|
||
def _getRootMandateIdFromAppDb(appDb: DatabaseConnector) -> Optional[str]:
    """Look up the id of the system root mandate (name='root', isSystem=True).

    Args:
        appDb: Connector to the application database.

    Returns:
        The first matching Mandate's id as ``str``. Returns None when nothing
        matches, when the record has no id, or when the lookup raises (logged as
        a warning).
    """
    try:
        matches = appDb.getRecordset(Mandate, recordFilter={"name": "root", "isSystem": True})
        if not matches:
            return None
        mandateId = matches[0].get("id")
        return None if mandateId is None else str(mandateId)
    except Exception as exc:
        logger.warning("Could not resolve root mandate id from app DB: %s", exc)
        return None
|
||
|
||
|
||
# Process-wide memo used by _getCachedRootMandateId(): a flag saying whether the
# cached value is final, and the cached root mandate id itself.
_rootMandateIdCacheResolved: bool = False
_cachedRootMandateId: Optional[str] = None
|
||
|
||
|
||
def _getCachedRootMandateId() -> Optional[str]:
    """Return the root mandate id (name='root', isSystem=True), memoised per process.

    Only a successful lookup is cached. ``_getRootMandateIdFromAppDb`` returns
    None both when the root mandate does not exist (yet) and when the app-DB
    query fails. In that case nothing is cached and the next call retries, so a
    transient error or a not-yet-bootstrapped root mandate is not remembered for
    the lifetime of the process.

    Returns:
        The root mandate id, or None if it could not be resolved this time.
    """
    global _cachedRootMandateId, _rootMandateIdCacheResolved
    if _rootMandateIdCacheResolved:
        return _cachedRootMandateId
    rootId = _getRootMandateIdFromAppDb(_getAppDatabaseConnector())
    if rootId is not None:
        _cachedRootMandateId = rootId
        _rootMandateIdCacheResolved = True
    return rootId
|
||
|
||
|
||
# Name of the database holding all billing data; it is registered with the shared
# database registry as a side effect of importing this module.
BILLING_DATABASE = "poweron_billing"

# Memo of BillingObjects instances built by getInterface(), keyed "<userId>_<mandateId>".
_billingInterfaces: Dict[str, "BillingObjects"] = {}

registerDatabase(BILLING_DATABASE)
|
||
|
||
|
||
def getInterface(currentUser: User, mandateId: str = None) -> "BillingObjects":
    """Return the memoised BillingObjects for (user, mandate), creating it on first use.

    Instances live in the module-level ``_billingInterfaces`` dict under the key
    ``"<userId>_<mandateId>"``. On a cache hit the instance's context is refreshed
    through ``setUserContext`` before it is handed back.

    Args:
        currentUser: Acting user; its ``id`` is part of the cache key.
        mandateId: Mandate context (may be None).

    Returns:
        The BillingObjects instance for this key.
    """
    key = f"{currentUser.id}_{mandateId}"
    if key in _billingInterfaces:
        instance = _billingInterfaces[key]
        instance.setUserContext(currentUser, mandateId)
        return instance
    instance = BillingObjects(currentUser, mandateId)
    _billingInterfaces[key] = instance
    return instance
|
||
|
||
|
||
def _getRootInterface() -> "BillingObjects":
    """Create a fresh, uncached BillingObjects bound to the root user and no mandate.

    Used for bootstrap/system operations. ``getRootUser`` is imported inside the
    function body rather than at module level.
    """
    from modules.security.rootAccess import getRootUser

    return BillingObjects(getRootUser(), mandateId=None)
|
||
|
||
|
||
class BillingObjects:
|
||
"""
|
||
Interface for billing operations.
|
||
Manages accounts, transactions, settings, and statistics.
|
||
"""
|
||
|
||
def __init__(self, currentUser: Optional[User] = None, mandateId: str = None):
    """Bind the interface to a user/mandate context and open the billing DB connector.

    Args:
        currentUser: Acting user, or None; ``self.userId`` mirrors its ``id``.
        mandateId: Mandate context (may be None).
    """
    self.currentUser = currentUser
    resolvedUserId = None
    if currentUser:
        resolvedUserId = currentUser.id
    self.userId = resolvedUserId
    self.mandateId = mandateId

    # Creates self.db (billing database connector).
    self._initializeDatabase()
|
||
|
||
def _initializeDatabase(self):
    """Create ``self.db``, a connector to the billing database (BILLING_DATABASE).

    Host, port, user and password are read from APP_CONFIG
    (host defaults to ``localhost``, port to ``5432``).
    """
    cfg = APP_CONFIG
    self.db = DatabaseConnector(
        dbDatabase=BILLING_DATABASE,
        dbHost=cfg.get("DB_HOST", "localhost"),
        dbPort=int(cfg.get("DB_PORT", "5432")),
        dbUser=cfg.get("DB_USER"),
        dbPassword=cfg.get("DB_PASSWORD_SECRET"),
    )
|
||
|
||
def setUserContext(self, currentUser: User, mandateId: str = None):
    """Rebind this interface to another user/mandate; the existing ``self.db`` is kept.

    Args:
        currentUser: Acting user, or None (``self.userId`` then becomes None).
        mandateId: Mandate context (may be None).
    """
    self.currentUser = currentUser
    resolvedUserId = None
    if currentUser:
        resolvedUserId = currentUser.id
    self.userId = resolvedUserId
    self.mandateId = mandateId
|
||
|
||
# =========================================================================
|
||
# BillingSettings Operations
|
||
# =========================================================================
|
||
|
||
def getSettings(self, mandateId: str) -> Optional[Dict[str, Any]]:
    """Fetch the BillingSettings row for a mandate.

    Args:
        mandateId: Mandate ID.

    Returns:
        A plain-dict copy of the first matching row. Returns None if there is no
        row or the query raised (the error is logged).
    """
    try:
        rows = self.db.getRecordset(BillingSettings, recordFilter={"mandateId": mandateId})
        return dict(rows[0]) if rows else None
    except Exception as e:
        logger.error(f"Error getting billing settings: {e}")
        return None
|
||
|
||
def createSettings(self, settings: BillingSettings) -> Dict[str, Any]:
    """Insert a BillingSettings record and return the created row.

    Fields whose value is None are left out of the insert payload.

    Args:
        settings: Settings model to persist.
    """
    payload = settings.model_dump(exclude_none=True)
    return self.db.recordCreate(BillingSettings, payload)
|
||
|
||
def updateSettings(self, settingsId: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Apply a partial update to the BillingSettings record ``settingsId``.

    Args:
        settingsId: Id of the settings record.
        updates: Field -> new value mapping.

    Returns:
        Whatever ``self.db.recordModify`` returns (the updated record, or None).
    """
    modified = self.db.recordModify(BillingSettings, settingsId, updates)
    return modified
|
||
|
||
def getOrCreateSettings(self, mandateId: str) -> Dict[str, Any]:
    """Return the mandate's billing settings, creating a default record if none is found.

    Defaults: ``warningThresholdPercent=10.0``, ``notifyOnWarning=True``.
    Because ``getSettings`` also yields None when its query fails, a creation
    attempt is made in that case too.

    Args:
        mandateId: Mandate ID.
    """
    current = self.getSettings(mandateId)
    if not current:
        defaults = BillingSettings(
            mandateId=mandateId,
            warningThresholdPercent=10.0,
            notifyOnWarning=True,
        )
        return self.createSettings(defaults)
    return current
|
||
|
||
# =========================================================================
|
||
# BillingAccount Operations
|
||
# =========================================================================
|
||
|
||
def getAccount(self, accountId: str) -> Optional[Dict[str, Any]]:
    """Return the BillingAccount with id ``accountId``.

    Returns None when no row matches or when the query raises (the error is logged).
    """
    try:
        matches = self.db.getRecordset(BillingAccount, recordFilter={"id": accountId})
    except Exception as e:
        logger.error(f"Error getting billing account: {e}")
        return None
    return matches[0] if matches else None
|
||
|
||
def getMandateAccount(self, mandateId: str) -> Optional[Dict[str, Any]]:
    """Return the mandate-level (pool) account, i.e. the mandate's account with ``userId`` None.

    Args:
        mandateId: Mandate ID.

    Returns:
        The first matching BillingAccount dict. Returns None when absent or when
        the query raises (the error is logged).
    """
    poolFilter = {"mandateId": mandateId, "userId": None}
    try:
        matches = self.db.getRecordset(BillingAccount, recordFilter=poolFilter)
    except Exception as e:
        logger.error(f"Error getting mandate account: {e}")
        return None
    return matches[0] if matches else None
|
||
|
||
def getUserAccount(self, mandateId: str, userId: str) -> Optional[Dict[str, Any]]:
    """Return the user-level account of ``userId`` within ``mandateId``.

    Args:
        mandateId: Mandate ID.
        userId: User ID.

    Returns:
        The first matching BillingAccount dict. Returns None when absent or when
        the query raises (the error is logged).
    """
    criteria = {"mandateId": mandateId, "userId": userId}
    try:
        matches = self.db.getRecordset(BillingAccount, recordFilter=criteria)
    except Exception as e:
        logger.error(f"Error getting user account: {e}")
        return None
    return matches[0] if matches else None
|
||
|
||
def getAccountsByMandate(self, mandateId: str) -> List[Dict[str, Any]]:
    """List every billing account of a mandate (pool account and user accounts).

    Args:
        mandateId: Mandate ID.

    Returns:
        The matching BillingAccount dicts; an empty list if the query raises
        (the error is logged).
    """
    try:
        accounts = self.db.getRecordset(BillingAccount, recordFilter={"mandateId": mandateId})
    except Exception as e:
        logger.error(f"Error getting accounts for mandate: {e}")
        return []
    return accounts
|
||
|
||
def createAccount(self, account: BillingAccount) -> Dict[str, Any]:
    """Insert a BillingAccount and return the created row.

    Fields whose value is None are left out of the insert payload.

    Args:
        account: Account model to persist.
    """
    payload = account.model_dump(exclude_none=True)
    return self.db.recordCreate(BillingAccount, payload)
|
||
|
||
def updateAccountBalance(self, accountId: str, newBalance: float) -> Optional[Dict[str, Any]]:
    """Overwrite the stored balance of ``accountId`` with ``newBalance``.

    This is a plain field write through ``recordModify``. It neither reads nor
    increments the current value, so it is NOT atomic with respect to a balance
    the caller read earlier; callers compute ``newBalance`` themselves.

    Args:
        accountId: Account to update.
        newBalance: Absolute balance to store.

    Returns:
        The result of ``recordModify`` (updated account dict, or None).
    """
    balancePatch = {"balance": newBalance}
    return self.db.recordModify(BillingAccount, accountId, balancePatch)
|
||
|
||
def getOrCreateMandateAccount(self, mandateId: str, initialBalance: float = 0.0) -> Dict[str, Any]:
    """Return the mandate pool account, creating an enabled one if it is missing.

    Args:
        mandateId: Mandate ID.
        initialBalance: Balance written directly onto a newly created account.
            No transaction is recorded for it. Ignored when the account exists.

    Returns:
        The existing or newly created BillingAccount dict.
    """
    account = self.getMandateAccount(mandateId)
    if account:
        return account
    return self.createAccount(
        BillingAccount(mandateId=mandateId, balance=initialBalance, enabled=True)
    )
|
||
|
||
def getOrCreateUserAccount(self, mandateId: str, userId: str, initialBalance: float = 0.0) -> Dict[str, Any]:
    """Return the user's billing account inside a mandate, creating it if missing.

    When ``initialBalance`` > 0, the new account is created at 0.0. A SYSTEM
    CREDIT transaction of ``initialBalance`` is then recorded, and
    ``createTransaction`` brings the stored balance to exactly ``initialBalance``.
    Values <= 0 are written directly as the starting balance, with no transaction.

    Args:
        mandateId: Mandate ID.
        userId: User ID.
        initialBalance: Starting funds for a newly created account; ignored when
            the account already exists.

    Returns:
        The existing or newly created BillingAccount dict. For a credited
        account it is re-read after the credit so ``balance`` is current.
    """
    existing = self.getUserAccount(mandateId, userId)
    if existing:
        return existing

    # createTransaction adds the credit on top of the stored balance. Creating the
    # account with balance=initialBalance AND crediting it doubled the starting
    # funds, so start at 0.0 whenever a credit transaction will follow.
    startingBalance = 0.0 if initialBalance > 0 else initialBalance
    created = self.createAccount(BillingAccount(
        mandateId=mandateId,
        userId=userId,
        balance=startingBalance,
        enabled=True,
    ))

    if initialBalance > 0:
        self.createTransaction(BillingTransaction(
            accountId=created["id"],
            transactionType=TransactionTypeEnum.CREDIT,
            amount=initialBalance,
            description="Initial credit for new user",
            referenceType=ReferenceTypeEnum.SYSTEM,
        ))
        # Return the post-credit state rather than the 0.0 snapshot from creation.
        refreshed = self.getAccount(created["id"])
        if refreshed:
            created = refreshed

    return created
|
||
|
||
def ensureAllMandateSettingsExist(self) -> int:
    """Create default BillingSettings for every enabled mandate that has none.

    Defaults match ``getOrCreateSettings``: ``warningThresholdPercent=10.0``,
    ``notifyOnWarning=True``. Existing settings are read in one query from the
    billing DB, and enabled mandates in one query from the app DB (via
    ``_getAppDatabaseConnector``). One insert is issued per missing mandate.

    Returns:
        Number of settings records created. If an error occurs it is logged, and
        the number created before the failure is returned.
    """
    settingsCreated = 0
    try:
        # Step 1: mandates that already have settings (billing DB).
        allSettings = self.db.getRecordset(BillingSettings)
        existingMandateIds = {s.get("mandateId") for s in allSettings if s.get("mandateId")}

        # Step 2: enabled mandates (app DB, separate connection).
        appDb = _getAppDatabaseConnector()
        allMandates = appDb.getRecordset(Mandate, recordFilter={"enabled": True})

        # Step 3: fill the gaps.
        for mandate in allMandates:
            mandateId = mandate.get("id")
            if not mandateId or mandateId in existingMandateIds:
                continue
            self.createSettings(BillingSettings(
                mandateId=mandateId,
                warningThresholdPercent=10.0,
                notifyOnWarning=True,
            ))
            existingMandateIds.add(mandateId)
            settingsCreated += 1

        if settingsCreated > 0:
            logger.info(f"Created {settingsCreated} missing billing settings for mandates")
        return settingsCreated

    except Exception as e:
        logger.error(f"Error ensuring mandate settings exist: {e}")
        # Report what was actually created before the failure instead of 0.
        return settingsCreated
|
||
|
||
def ensureAllUserAccountsExist(self) -> int:
    """Create missing 0.0-balance user accounts for enabled memberships in billable mandates.

    A mandate counts as billable when it has a BillingSettings record. For every
    enabled UserMandate row (app DB) in such a mandate without a user-level
    account (billing DB), one account is created. The app-DB connector is opened
    only once at least one billable mandate exists.

    Returns:
        Number of accounts created. If an error occurs it is logged, and the
        number created before the failure is returned.
    """
    accountsCreated = 0
    try:
        allSettings = self.db.getRecordset(BillingSettings)
        billingMandateIds = {s.get("mandateId") for s in allSettings if s.get("mandateId")}
        if not billingMandateIds:
            logger.debug("No billable mandates found, skipping account check")
            return 0

        # Only user-level accounts matter here; the pool account has no userId.
        existingAccountKeys = {
            (acc.get("mandateId"), acc.get("userId"))
            for acc in self.db.getRecordset(BillingAccount)
            if acc.get("userId")
        }

        appDb = _getAppDatabaseConnector()
        allUserMandates = appDb.getRecordset(UserMandate, recordFilter={"enabled": True})

        for um in allUserMandates:
            mandateId = um.get("mandateId")
            userId = um.get("userId")
            if not mandateId or not userId or mandateId not in billingMandateIds:
                continue
            key = (mandateId, userId)
            if key in existingAccountKeys:
                continue

            self.createAccount(BillingAccount(
                mandateId=mandateId,
                userId=userId,
                balance=0.0,
                enabled=True,
            ))
            existingAccountKeys.add(key)
            accountsCreated += 1

        if accountsCreated > 0:
            logger.info(f"Created {accountsCreated} missing billing accounts")
        return accountsCreated

    except Exception as e:
        logger.error(f"Error ensuring user accounts exist: {e}")
        return accountsCreated
|
||
|
||
# =========================================================================
|
||
# BillingTransaction Operations
|
||
# =========================================================================
|
||
|
||
def createTransaction(self, transaction: BillingTransaction, balanceAccountId: str = None) -> Dict[str, Any]:
    """Record a billing transaction and apply it to an account balance.

    The transaction row is always stored against ``transaction.accountId`` (the
    audit trail). The balance change goes to ``balanceAccountId`` when given,
    otherwise to ``transaction.accountId``. This lets usage be recorded on a user
    account while being paid from the mandate pool account.

    Balance arithmetic:
      * DEBIT subtracts ``amount``.
      * CREDIT and any other type (ADJUSTMENT) add ``amount`` as-is, so a negative
        adjustment lowers the balance.
      * A stored balance that is missing or None is treated as 0.0.

    NOTE(review): this is read-modify-write. The balance is read, recomputed in
    Python and written back by a separate call, after the transaction row has
    been inserted. Concurrent calls on the same account, or a failure between
    the two writes, can leave the balance and the history out of sync.

    Args:
        transaction: Transaction to record.
        balanceAccountId: Account whose balance changes (defaults to
            ``transaction.accountId``).

    Returns:
        The created transaction record.

    Raises:
        ValueError: if the transaction account or the balance account does not exist.
    """
    txAccount = self.getAccount(transaction.accountId)
    if not txAccount:
        raise ValueError(f"Transaction account {transaction.accountId} not found")

    targetBalanceAccountId = balanceAccountId or transaction.accountId
    if targetBalanceAccountId == transaction.accountId:
        balanceAccount = txAccount
    else:
        balanceAccount = self.getAccount(targetBalanceAccountId)
        if not balanceAccount:
            raise ValueError(f"Balance account {targetBalanceAccountId} not found")

    # .get("balance", 0.0) still yields None for a NULL column; coerce so the
    # arithmetic below cannot fail with a TypeError.
    currentBalance = balanceAccount.get("balance")
    if currentBalance is None:
        currentBalance = 0.0

    if transaction.transactionType == TransactionTypeEnum.DEBIT:
        newBalance = currentBalance - transaction.amount
    else:  # CREDIT, or ADJUSTMENT with a signed amount
        newBalance = currentBalance + transaction.amount

    # Transaction row always lives on transaction.accountId (audit trail).
    created = self.db.recordCreate(BillingTransaction, transaction.model_dump(exclude_none=True))
    self.updateAccountBalance(targetBalanceAccountId, newBalance)

    logger.info(
        "Billing transaction created: %s %s CHF, audit=%s, balance on %s: %s -> %s",
        transaction.transactionType.value,
        transaction.amount,
        transaction.accountId,
        targetBalanceAccountId,
        currentBalance,
        newBalance,
    )
    return created
|
||
|
||
def getTransactions(
    self,
    accountId: str,
    limit: int = 100,
    offset: int = 0,
    startDate: date = None,
    endDate: date = None,
    pagination: PaginationParams = None
) -> Union[List[Dict[str, Any]], PaginatedResult]:
    """Get transactions recorded on one account.

    With ``pagination``, filtering, sorting and paging happen in the database and
    a PaginatedResult is returned. Without it (legacy path), all rows of the
    account are loaded, optionally date-filtered, sorted newest-first and sliced
    with ``offset``/``limit``.

    Legacy date filter:
      * ``startDate``/``endDate`` are inclusive. Datetime arguments are reduced
        to their date part.
      * Each row's ``sysCreatedAt`` is reduced to a calendar date, whether it is a
        datetime (aware values converted to UTC), a date, a numeric epoch
        timestamp (UTC) or a numeric/ISO-8601 string.
      * While a filter is active, rows whose timestamp is missing or unparseable
        are excluded.

    Previously, numeric or string timestamps were compared directly with
    ``date`` objects. That raised TypeError, and the error handler turned the
    whole call into an empty result.

    Args:
        accountId: Account ID.
        limit: Maximum number of results (legacy path).
        offset: Number of newest rows to skip (legacy path).
        startDate: Earliest creation date to include (legacy path).
        endDate: Latest creation date to include (legacy path).
        pagination: PaginationParams for DB-level pagination.

    Returns:
        PaginatedResult when pagination is provided, list of dicts otherwise.
        An empty result of the matching shape is returned on error (logged).
    """

    def _toUtcDate(value: Any) -> Optional[date]:
        """Reduce a sysCreatedAt value to a calendar date, or None if unusable."""
        if value is None:
            return None
        if isinstance(value, datetime):
            return (value.astimezone(timezone.utc) if value.tzinfo else value).date()
        if isinstance(value, date):
            return value
        if isinstance(value, str):
            text = value.strip()
            try:
                value = float(text)
            except ValueError:
                try:
                    parsed = datetime.fromisoformat(text.replace("Z", "+00:00"))
                except ValueError:
                    return None
                return (parsed.astimezone(timezone.utc) if parsed.tzinfo else parsed).date()
        if isinstance(value, (int, float)):
            try:
                return datetime.fromtimestamp(value, tz=timezone.utc).date()
            except (OverflowError, OSError, ValueError):
                return None
        return None

    try:
        if pagination:
            result = self.db.getRecordsetPaginated(
                BillingTransaction,
                pagination=pagination,
                recordFilter={"accountId": accountId},
            )
            _logBillingTransactionsMissingSysCreatedAt(
                result["items"],
                "getTransactions(accountId) paginated",
            )
            return PaginatedResult(
                items=result["items"],
                totalItems=result["totalItems"],
                totalPages=result["totalPages"],
            )

        results = self.db.getRecordset(BillingTransaction, recordFilter={"accountId": accountId})

        # date < datetime comparisons raise TypeError; compare dates with dates.
        if isinstance(startDate, datetime):
            startDate = startDate.date()
        if isinstance(endDate, datetime):
            endDate = endDate.date()

        if startDate or endDate:
            filtered = []
            for t in results:
                tDate = _toUtcDate(t.get("sysCreatedAt"))
                if tDate is None:
                    continue
                if startDate and tDate < startDate:
                    continue
                if endDate and tDate > endDate:
                    continue
                filtered.append(t)
            results = filtered

        _sortBillingTransactionsBySysCreatedAtDesc(results, "getTransactions(accountId)")
        return results[offset:offset + limit]

    except Exception as e:
        logger.error(f"Error getting transactions: {e}")
        if pagination:
            return PaginatedResult(items=[], totalItems=0, totalPages=0)
        return []
|
||
|
||
def getTransactionsByMandate(
    self,
    mandateId: str,
    limit: int = 100,
    pagination: PaginationParams = None
) -> Union[List[Dict[str, Any]], PaginatedResult]:
    """Get all transactions recorded on any account of a mandate.

    With ``pagination``, the ids of all the mandate's accounts are collected and
    a single DB query does the filtering, sorting and paging. Without it (legacy
    path), each account's newest ``limit`` rows are fetched via
    ``getTransactions``, merged, sorted newest-first and cut to ``limit``.

    Rows lacking ``sysCreatedAt`` are reported (ERROR log) on both paths,
    matching ``getTransactions``.

    Args:
        mandateId: Mandate ID.
        limit: Maximum number of results (legacy path).
        pagination: PaginationParams for DB-level pagination.

    Returns:
        PaginatedResult when pagination is provided, list of dicts otherwise.
        An empty result of the matching shape is returned when the mandate has
        no accounts.
    """
    accounts = self.db.getRecordset(BillingAccount, recordFilter={"mandateId": mandateId})
    accountIds = [acc["id"] for acc in accounts if acc.get("id")]

    if not accountIds:
        if pagination:
            return PaginatedResult(items=[], totalItems=0, totalPages=0)
        return []

    if pagination:
        result = self.db.getRecordsetPaginated(
            BillingTransaction,
            pagination=pagination,
            recordFilter={"accountId": accountIds},
        )
        _logBillingTransactionsMissingSysCreatedAt(
            result["items"],
            "getTransactionsByMandate paginated",
        )
        return PaginatedResult(
            items=result["items"],
            totalItems=result["totalItems"],
            totalPages=result["totalPages"],
        )

    # Iterate the already-validated ids; account rows without an id would raise KeyError.
    allTransactions: List[Dict[str, Any]] = []
    for accountId in accountIds:
        allTransactions.extend(self.getTransactions(accountId, limit=limit))

    _sortBillingTransactionsBySysCreatedAtDesc(allTransactions, "getTransactionsByMandate")
    return allTransactions[:limit]
|
||
|
||
# =========================================================================
|
||
# StripeWebhookEvent Operations (idempotency)
|
||
# =========================================================================
|
||
|
||
def getStripeWebhookEventByEventId(self, event_id: str) -> Optional[Dict[str, Any]]:
    """Idempotency lookup: return the stored record for a Stripe event id, if any.

    Args:
        event_id: Stripe event ID (evt_xxx).

    Returns:
        The first matching StripeWebhookEvent dict. Returns None when absent or
        when the query raises (the error is logged).
    """
    try:
        hits = self.db.getRecordset(StripeWebhookEvent, recordFilter={"event_id": event_id})
    except Exception as e:
        logger.error(f"Error checking Stripe webhook event: {e}")
        return None
    return hits[0] if hits else None
|
||
|
||
def createStripeWebhookEvent(self, event_id: str) -> Dict[str, Any]:
    """Persist a StripeWebhookEvent marker for ``event_id`` and return the created record.

    Unlike the other create helpers here, the full model dump (including
    None-valued fields) is passed to ``recordCreate``.

    Args:
        event_id: Stripe event ID (evt_xxx).
    """
    payload = StripeWebhookEvent(event_id=event_id).model_dump()
    return self.db.recordCreate(StripeWebhookEvent, payload)
|
||
|
||
def getPaymentTransactionByReferenceId(self, referenceId: str) -> Optional[Dict[str, Any]]:
    """Find a PAYMENT-type transaction by its referenceId (Stripe Checkout Session id).

    Args:
        referenceId: Stripe Checkout Session ID (cs_xxx).

    Returns:
        The first matching BillingTransaction dict. Returns None when absent or
        when the query raises (the error is logged).
    """
    paymentFilter = {
        "referenceType": ReferenceTypeEnum.PAYMENT.value,
        "referenceId": referenceId,
    }
    try:
        hits = self.db.getRecordset(BillingTransaction, recordFilter=paymentFilter)
    except Exception as e:
        logger.error(f"Error checking Stripe payment transaction by referenceId: {e}")
        return None
    return hits[0] if hits else None
|
||
|
||
# =========================================================================
|
||
# Balance Check Operations
|
||
# =========================================================================
|
||
|
||
def checkBalance(self, mandateId: str, userId: str, estimatedCost: float) -> BillingCheckResult:
    """Decide whether the mandate pool balance covers ``estimatedCost``.

    Side effects: the user's account (audit trail) and the mandate pool account
    are created with balance 0.0 if they do not exist yet. Billing settings are
    not consulted; a mandate without a pool account simply starts at 0.0.

    Args:
        mandateId: Mandate whose pool pays.
        userId: Acting user.
        estimatedCost: Amount the operation is expected to cost.

    Returns:
        ``allowed=False, reason="INSUFFICIENT_BALANCE"`` (with current and
        required amounts) when the pool balance is below ``estimatedCost``.
        Otherwise ``allowed=True`` with the current balance. A NULL pool balance
        is treated as 0.0.
    """
    self.getOrCreateUserAccount(mandateId, userId, initialBalance=0.0)

    poolAccount = self.getOrCreateMandateAccount(mandateId)
    # .get("balance", 0.0) returns None for a NULL column, which would make the
    # comparison below raise TypeError.
    currentBalance = poolAccount.get("balance")
    if currentBalance is None:
        currentBalance = 0.0

    if currentBalance < estimatedCost:
        return BillingCheckResult(
            allowed=False,
            reason="INSUFFICIENT_BALANCE",
            currentBalance=currentBalance,
            requiredAmount=estimatedCost,
        )

    return BillingCheckResult(allowed=True, currentBalance=currentBalance)
|
||
|
||
def recordUsage(
    self,
    mandateId: str,
    userId: str,
    priceCHF: float,
    workflowId: str = None,
    featureInstanceId: str = None,
    featureCode: str = None,
    aicoreProvider: str = None,
    aicoreModel: str = None,
    description: str = "AI Usage",
    processingTime: float = None,
    bytesSent: int = None,
    bytesReceived: int = None,
    errorCount: int = None
) -> Optional[Dict[str, Any]]:
    """Book a usage charge as a DEBIT transaction with WORKFLOW reference type.

    The transaction row is written to the user's account (created if missing)
    for the audit trail. The amount is deducted from the mandate pool account
    (also created if missing).

    Returns:
        The created transaction record. Returns None when ``priceCHF`` <= 0 or
        when the mandate has no billing settings; ``getSettings`` also yields
        None if its query fails.
    """
    if priceCHF <= 0:
        return None

    if not self.getSettings(mandateId):
        logger.debug(f"No billing settings for mandate {mandateId}, skipping usage recording")
        return None

    auditAccount = self.getOrCreateUserAccount(mandateId, userId)

    debit = BillingTransaction(
        accountId=auditAccount["id"],
        transactionType=TransactionTypeEnum.DEBIT,
        amount=priceCHF,
        description=description,
        referenceType=ReferenceTypeEnum.WORKFLOW,
        workflowId=workflowId,
        featureInstanceId=featureInstanceId,
        featureCode=featureCode,
        aicoreProvider=aicoreProvider,
        aicoreModel=aicoreModel,
        createdByUserId=userId,
        processingTime=processingTime,
        bytesSent=bytesSent,
        bytesReceived=bytesReceived,
        errorCount=errorCount,
    )

    payingAccount = self.getOrCreateMandateAccount(mandateId)
    return self.createTransaction(debit, balanceAccountId=payingAccount["id"])
|
||
|
||
def _parseSettingsDateTime(self, value: Any) -> Optional[datetime]:
|
||
"""Parse datetime from billing settings row (ISO string or datetime)."""
|
||
if value is None:
|
||
return None
|
||
if isinstance(value, datetime):
|
||
if value.tzinfo:
|
||
return value.astimezone(timezone.utc)
|
||
return value.replace(tzinfo=timezone.utc)
|
||
if isinstance(value, str):
|
||
s = value.replace("Z", "+00:00")
|
||
try:
|
||
dt = datetime.fromisoformat(s)
|
||
except ValueError:
|
||
return None
|
||
if dt.tzinfo:
|
||
return dt.astimezone(timezone.utc)
|
||
return dt.replace(tzinfo=timezone.utc)
|
||
return None
|
||
|
||
def resetStorageBillingPeriod(self, mandateId: str, periodStartAt: datetime) -> None:
    """Start a new storage-billing period for a mandate (e.g. on Stripe invoice.paid).

    ``periodStartAt`` is normalised to UTC, with naive values taken as UTC. If the
    stored ``storagePeriodStartAt`` lies within 2 seconds of it, nothing happens;
    presumably this is a repeated event for the same period. Otherwise:
      * ``storageHighWatermarkMB`` is set to the mandate's current data volume (MB);
      * ``storageBilledUpToMB`` is reset to 0.0;
      * the new period start is stored.

    Settings are created with defaults if the mandate has none.
    """
    if periodStartAt.tzinfo is None:
        periodStartUtc = periodStartAt.replace(tzinfo=timezone.utc)
    else:
        periodStartUtc = periodStartAt.astimezone(timezone.utc)

    settings = self.getOrCreateSettings(mandateId)
    previousStart = self._parseSettingsDateTime(settings.get("storagePeriodStartAt"))
    alreadyCurrent = (
        previousStart is not None
        and abs((previousStart - periodStartUtc).total_seconds()) < 2
    )
    if alreadyCurrent:
        return

    from modules.interfaces.interfaceDbSubscription import _getRootInterface as _getSubRoot

    usedMB = float(_getSubRoot().getMandateDataVolumeMB(mandateId))
    resetValues = {
        "storageHighWatermarkMB": usedMB,
        "storageBilledUpToMB": 0.0,
        "storagePeriodStartAt": periodStartUtc,
    }
    self.updateSettings(settings["id"], resetValues)
    logger.info(
        "Storage billing period reset for mandate %s at %s (usedMB=%.2f)",
        mandateId,
        periodStartUtc.isoformat(),
        usedMB,
    )
|
||
|
||
def reconcileMandateStorageBilling(self, mandateId: str) -> Optional[Dict[str, Any]]:
    """Debit the mandate pool for storage overage not yet billed in the current period.

    Algorithm:
      * The period high-watermark is ``max(stored watermark, current usage)``.
        Usage drops never lower it, so deletes earn no credit.
      * Overage is the watermark minus the plan's ``maxDataVolumeMB`` (floored at 0).
      * Only the part above ``storageBilledUpToMB`` is charged, at
        ``STORAGE_PRICE_PER_GB_CHF`` per 1024 MB, rounded to 4 decimals.

    A changed watermark is persisted even when nothing is charged.

    Returns:
        The created DEBIT transaction. Returns None when:
          * the mandate has no settings;
          * the operative plan has no data-volume limit (or there is no plan);
          * the uncharged overage is <= 1e-9 MB or costs <= 0 CHF.
    """
    settings = self.getSettings(mandateId)
    if not settings:
        return None

    from modules.interfaces.interfaceDbSubscription import _getRootInterface as _getSubRoot
    from modules.datamodels.datamodelSubscription import _getPlan

    subIface = _getSubRoot()
    usedMB = float(subIface.getMandateDataVolumeMB(mandateId))
    subscription = subIface.getOperativeForMandate(mandateId)
    plan = _getPlan(subscription.get("planKey", "")) if subscription else None
    includedMB = plan.maxDataVolumeMB if plan else None
    if includedMB is None:
        return None

    previousHigh = float(settings.get("storageHighWatermarkMB") or 0.0)
    watermark = max(previousHigh, usedMB)
    overageMB = max(0.0, watermark - float(includedMB))
    alreadyBilledMB = float(settings.get("storageBilledUpToMB") or 0.0)
    unbilledMB = overageMB - alreadyBilledMB

    pendingUpdates: Dict[str, Any] = {}
    if watermark != previousHigh:
        pendingUpdates["storageHighWatermarkMB"] = watermark

    # Nothing (or a negligible amount) left to charge: just persist the watermark.
    costCHF = (
        round((unbilledMB / 1024.0) * float(STORAGE_PRICE_PER_GB_CHF), 4)
        if unbilledMB > 1e-9
        else 0.0
    )
    if costCHF <= 0:
        if pendingUpdates:
            self.updateSettings(settings["id"], pendingUpdates)
        return None

    poolAccount = self.getOrCreateMandateAccount(mandateId)
    charge = BillingTransaction(
        accountId=poolAccount["id"],
        transactionType=TransactionTypeEnum.DEBIT,
        amount=costCHF,
        description=f"Speicher-Überhang ({unbilledMB:.2f} MB über Plan)",
        referenceType=ReferenceTypeEnum.STORAGE,
        referenceId=mandateId,
    )
    created = self.createTransaction(charge)

    pendingUpdates["storageBilledUpToMB"] = overageMB
    self.updateSettings(settings["id"], pendingUpdates)
    logger.info(
        "Storage overage billed mandate=%s deltaOverageMB=%.4f costCHF=%s",
        mandateId,
        unbilledMB,
        costCHF,
    )
    return created
|
||
|
||
# =========================================================================
|
||
# Subscription AI-Budget Credit
|
||
# =========================================================================
|
||
|
||
def creditSubscriptionBudget(self, mandateId: str, planKey: str, periodLabel: str = "") -> Optional[Dict[str, Any]]:
    """Credit the plan's per-user AI budget to the mandate pool account.

    Amount = ``plan.budgetAiPerUserCHF * max(active users, 1)``, computed from the
    current user count rather than a static plan total. Intended to run once per
    billing period (initial activation and each paid invoice).

    Args:
        mandateId: Mandate to credit.
        planKey: Subscription plan key.
        periodLabel: Optional suffix appended to the transaction description.

    Returns:
        The created CREDIT transaction (reference type SUBSCRIPTION). Returns None
        when the plan is unknown or its per-user budget is missing or <= 0.
    """
    from modules.datamodels.datamodelSubscription import _getPlan

    plan = _getPlan(planKey)
    perUserBudget = plan.budgetAiPerUserCHF if plan else None
    if not perUserBudget or perUserBudget <= 0:
        return None

    from modules.interfaces.interfaceDbSubscription import _getRootInterface as _getSubRoot

    activeUsers = max(_getSubRoot().countActiveUsers(mandateId), 1)
    amount = plan.budgetAiPerUserCHF * activeUsers

    poolAccount = self.getOrCreateMandateAccount(mandateId)
    description = f"AI-Budget ({planKey}, {activeUsers} User)"
    if periodLabel:
        description = description + f" – {periodLabel}"

    created = self.createTransaction(BillingTransaction(
        accountId=poolAccount["id"],
        transactionType=TransactionTypeEnum.CREDIT,
        amount=amount,
        description=description,
        referenceType=ReferenceTypeEnum.SUBSCRIPTION,
        referenceId=mandateId,
    ))
    logger.info(
        "AI-Budget credited mandate=%s plan=%s users=%d amount=%.2f CHF",
        mandateId, planKey, activeUsers, amount,
    )
    return created
|
||
|
||
def ensureActivationBudget(self, mandateId: str, planKey: str) -> Optional[Dict[str, Any]]:
    """Credit the first-activation AI budget at most once per mandate.

    Nothing is credited if the mandate already has a pool account carrying any
    CREDIT transaction with reference type SUBSCRIPTION. Otherwise the budget is
    credited via ``creditSubscriptionBudget`` with the period label
    "Erstaktivierung" (German for "initial activation").

    NOTE: pro-rata credits booked by ``adjustAiBudgetForUserChange`` also use
    CREDIT/SUBSCRIPTION, so they count as an existing subscription credit here.

    Returns:
        The created CREDIT transaction, or None when a subscription credit
        already exists (or when ``creditSubscriptionBudget`` returns None).
    """
    pool = self.getMandateAccount(mandateId)
    if pool:
        priorSubscriptionCredits = self.db.getRecordset(
            BillingTransaction,
            recordFilter={
                "accountId": pool["id"],
                "transactionType": TransactionTypeEnum.CREDIT.value,
                "referenceType": ReferenceTypeEnum.SUBSCRIPTION.value,
            },
        )
        if priorSubscriptionCredits:
            return None
    # No pool account yet, or no subscription credit on it: credit now.
    return self.creditSubscriptionBudget(mandateId, planKey, periodLabel="Erstaktivierung")
|
||
|
||
def adjustAiBudgetForUserChange(self, mandateId: str, planKey: str, delta: int) -> Optional[Dict[str, Any]]:
    """Book a pro-rata AI-budget adjustment after users are added or removed mid-cycle.

    The adjustment is ``abs(delta) * plan.budgetAiPerUserCHF * fraction``,
    rounded to 2 decimals. ``fraction`` is the share of the operative
    subscription's current billing period that is still remaining, clamped
    to ``[0, 1]``.

    - ``delta > 0`` (users added): a CREDIT is booked on the mandate pool account.
    - ``delta < 0`` (users removed): a DEBIT is booked on the mandate pool account.

    Args:
        mandateId: Mandate whose pool account is adjusted.
        planKey: Subscription plan key used to look up the per-user budget.
        delta: Change in the number of users.

    Returns:
        The created transaction, or None when nothing is booked. That happens when:
        ``delta`` is 0; the plan has no positive per-user AI budget; there is no
        operative subscription, or it lacks period bounds; the period has no
        positive length; or the rounded amount is 0.

    Raises:
        ValueError: If a period bound is a string that is not valid ISO 8601.
    """
    if not delta:
        return None

    from modules.datamodels.datamodelSubscription import _getPlan

    plan = _getPlan(planKey)
    if not plan or not plan.budgetAiPerUserCHF or plan.budgetAiPerUserCHF <= 0:
        return None

    from modules.interfaces.interfaceDbSubscription import _getRootInterface as _getSubRoot

    subRoot = _getSubRoot()
    operative = subRoot.getOperativeForMandate(mandateId)
    if not operative:
        return None

    periodStart = operative.get("currentPeriodStart")
    periodEnd = operative.get("currentPeriodEnd")
    if not periodStart or not periodEnd:
        return None

    def _toAwareUtc(value: Any) -> datetime:
        """Coerce a stored period bound to a timezone-aware datetime (naive values are taken as UTC)."""
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            # Numeric bounds are treated as POSIX timestamps.
            return datetime.fromtimestamp(value, tz=timezone.utc)
        if isinstance(value, str):
            text = value.strip()
            # datetime.fromisoformat() accepts a trailing "Z" only from Python 3.11 on.
            if text.endswith("Z"):
                text = text[:-1] + "+00:00"
            value = datetime.fromisoformat(text)
        if value.tzinfo is None:
            value = value.replace(tzinfo=timezone.utc)
        return value

    periodStart = _toAwareUtc(periodStart)
    periodEnd = _toAwareUtc(periodEnd)

    totalSeconds = (periodEnd - periodStart).total_seconds()
    if totalSeconds <= 0:
        return None

    now = datetime.now(timezone.utc)
    # Clamp to [0, totalSeconds]. If "now" lies before the period start, the
    # unclamped remainder exceeds the period length and would over-credit/debit.
    remainingSeconds = min(max((periodEnd - now).total_seconds(), 0.0), totalSeconds)
    proRataFraction = remainingSeconds / totalSeconds

    amount = round(abs(delta) * plan.budgetAiPerUserCHF * proRataFraction, 2)
    if amount <= 0:
        return None

    poolAccount = self.getOrCreateMandateAccount(mandateId)

    if delta > 0:
        txType = TransactionTypeEnum.CREDIT
        label = f"AI-Budget pro-rata +{abs(delta)} User ({planKey})"
    else:
        txType = TransactionTypeEnum.DEBIT
        label = f"AI-Budget pro-rata -{abs(delta)} User ({planKey})"

    transaction = BillingTransaction(
        accountId=poolAccount["id"],
        transactionType=txType,
        amount=amount,
        description=label,
        referenceType=ReferenceTypeEnum.SUBSCRIPTION,
        referenceId=mandateId,
    )
    created = self.createTransaction(transaction)
    logger.info(
        "AI-Budget pro-rata %s mandate=%s delta=%+d amount=%.2f CHF (fraction=%.4f)",
        txType.value, mandateId, delta, amount, proRataFraction,
    )
    return created
|
||
|
||
# =========================================================================
|
||
# Workflow Cost Query
|
||
# =========================================================================
|
||
|
||
def getWorkflowCost(self, workflowId: str) -> float:
    """Return the summed ``amount`` of every transaction tagged with *workflowId*.

    All transactions with that ``workflowId`` are included as stored; no sign
    is applied per transaction type. Missing or None amounts count as 0.

    Args:
        workflowId: Workflow identifier. Falsy values short-circuit to 0.0.

    Returns:
        The total as a float (0.0 when no transaction matches).
    """
    if not workflowId:
        return 0.0
    transactions = self.db.getRecordset(
        BillingTransaction,
        recordFilter={"workflowId": workflowId},
    )
    # `or 0.0` also covers rows whose amount column is NULL. float() keeps the
    # declared return type even for an empty result (sum() would give int 0).
    return float(sum((t.get("amount") or 0.0) for t in transactions))
|
||
|
||
# =========================================================================
|
||
# Statistics Operations
|
||
# =========================================================================
|
||
|
||
def getUsageStatistics(
    self,
    accountId: str,
    periodType: PeriodTypeEnum,
    year: int,
    month: int = None
) -> List[Dict[str, Any]]:
    """Return stored usage statistics of an account for one year.

    Rows are loaded by ``accountId`` and ``periodType`` and kept only if their
    ``periodStart`` is set and falls in *year*. For DAY statistics, a truthy
    *month* further restricts the rows to that month.

    Args:
        accountId: Account ID.
        periodType: Period type (DAY, MONTH, YEAR).
        year: Calendar year of ``periodStart``.
        month: Month filter, only applied for DAY period type.

    Returns:
        List of statistics dicts, ascending by ``periodStart``.
    """
    rows = self.db.getRecordset(
        UsageStatistics,
        recordFilter={"accountId": accountId, "periodType": periodType.value},
    )

    restrictToMonth = bool(month) and periodType == PeriodTypeEnum.DAY

    def _inRequestedPeriod(row: Dict[str, Any]) -> bool:
        start = row.get("periodStart")
        if not start or start.year != year:
            return False
        return not restrictToMonth or start.month == month

    return sorted(
        (row for row in rows if _inRequestedPeriod(row)),
        key=lambda row: row.get("periodStart", date.min),
    )
|
||
|
||
def calculateStatisticsFromTransactions(
    self,
    accountId: str,
    startDate: date,
    endDate: date
) -> Dict[str, Any]:
    """Compute usage statistics for an account from its DEBIT transactions.

    Transactions are fetched with ``getTransactions``, requesting at most
    10000 rows, so very large periods may be truncated. Only rows whose
    ``transactionType`` equals ``TransactionTypeEnum.DEBIT.value`` count as
    usage.

    Provider, model and feature values that are missing or None are grouped
    under ``"unknown"``. This mirrors the ``COALESCE(..., 'unknown')`` used by
    the SQL aggregation. Missing or None amounts count as 0.

    Args:
        accountId: Account ID.
        startDate: Start date.
        endDate: End date.

    Returns:
        Dict with ``totalCostCHF``, ``transactionCount``, ``costByProvider``,
        ``costByModel`` and ``costByFeature``.
    """
    transactions = self.getTransactions(accountId, limit=10000, startDate=startDate, endDate=endDate)

    # Only DEBIT transactions represent usage.
    debitType = TransactionTypeEnum.DEBIT.value
    debits = [t for t in transactions if t.get("transactionType") == debitType]

    def _groupKey(value: Any) -> Any:
        # Like SQL COALESCE: replace only NULL/None. A row whose column exists
        # but is None must not end up under a None key.
        return "unknown" if value is None else value

    totalCost = 0.0
    costByProvider: Dict[Any, float] = {}
    costByModel: Dict[Any, float] = {}
    costByFeature: Dict[Any, float] = {}

    for t in debits:
        amount = t.get("amount") or 0.0
        totalCost += amount

        provider = _groupKey(t.get("aicoreProvider"))
        costByProvider[provider] = costByProvider.get(provider, 0) + amount

        model = _groupKey(t.get("aicoreModel"))
        costByModel[model] = costByModel.get(model, 0) + amount

        feature = _groupKey(t.get("featureCode"))
        costByFeature[feature] = costByFeature.get(feature, 0) + amount

    return {
        "totalCostCHF": totalCost,
        "transactionCount": len(debits),
        "costByProvider": costByProvider,
        "costByModel": costByModel,
        "costByFeature": costByFeature
    }
|
||
|
||
# =========================================================================
|
||
# Utility Methods
|
||
# =========================================================================
|
||
|
||
def getBalancesForUser(self, userId: str) -> List[BillingBalanceResponse]:
    """Return the mandate pool balance of every enabled mandate the user belongs to.

    Memberships and mandates are read through the privileged root app
    interface, which bypasses RBAC. Access is implied by the user's
    UserMandate records. Mandates are skipped when they are missing,
    disabled, without billing settings, or without a pool account.

    Errors are logged rather than raised:
    - If the memberships cannot be resolved, an empty list is returned.
    - If a single mandate fails, only that mandate is skipped.

    Args:
        userId: User ID.

    Returns:
        List of BillingBalanceResponse, one per qualifying mandate.
    """
    from modules.interfaces.interfaceDbApp import getRootInterface

    balances: List[BillingBalanceResponse] = []

    try:
        # Use rootInterface (privileged, SysAdmin context) to bypass RBAC
        # for mandate/user lookups. User access is verified via UserMandate membership.
        rootInterface = getRootInterface()
        userMandates = rootInterface.getUserMandates(userId) or []
    except Exception as e:
        logger.error(f"Error getting balances for user {userId}: {e}")
        return balances

    for um in userMandates:
        mandateId = getattr(um, 'mandateId', None) or (um.get("mandateId") if isinstance(um, dict) else None)
        if not mandateId:
            continue

        try:
            mandate = rootInterface.getMandate(mandateId)
            if not mandate:
                continue

            if isinstance(mandate, dict):
                # getattr() never sees dict entries, so read dict-shaped
                # mandates by key; otherwise every dict mandate counts as enabled.
                enabled = mandate.get("enabled", True)
                mandateName = mandate.get("label") or mandate.get("name") or ""
            else:
                enabled = getattr(mandate, "enabled", True)
                mandateName = getattr(mandate, 'label', None) or getattr(mandate, 'name', None) or ""
            if not enabled:
                continue

            settings = self.getSettings(mandateId)
            if not settings:
                continue

            poolAccount = self.getOrCreateMandateAccount(mandateId)
            if not poolAccount:
                continue

            # NULL columns become 0.0 so the warning comparison cannot raise.
            balance = poolAccount.get("balance") or 0.0
            warningThreshold = poolAccount.get("warningThreshold") or 0.0

            balances.append(BillingBalanceResponse(
                mandateId=mandateId,
                mandateName=mandateName,
                balance=balance,
                warningThreshold=warningThreshold,
                isWarning=balance <= warningThreshold,
            ))
        except Exception as e:
            logger.error(f"Error getting balance of mandate {mandateId} for user {userId}: {e}")

    return balances
|
||
|
||
def getTransactionsForUser(self, userId: str, limit: int = 100) -> List[Dict[str, Any]]:
    """Return the user's most recent transactions across all their mandates.

    A mandate membership of *userId* is used only if it has billing settings
    and a user-level billing account. For each such mandate, up to *limit*
    transactions of that account are fetched and tagged with ``mandateId``
    and ``mandateName``. The merged list is sorted newest-first by
    ``sysCreatedAt`` and cut to *limit* entries.

    NOTE(review): only the user's own account (``getUserAccount``) is read.
    Transactions booked on the mandate pool account are not included; for
    example, subscription credits are booked via ``getOrCreateMandateAccount``.
    Confirm this is intended.

    Args:
        userId: User ID.
        limit: Maximum number of results (also the per-account fetch size).

    Returns:
        List of transaction dicts. Errors are logged; whatever was collected
        before the error is still sorted and returned.
    """
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    def _mandateLabel(mandate: Any):
        """Display name of *mandate* (label, then name), or "" if there is no mandate."""
        if not mandate:
            return ""
        return getattr(mandate, 'label', None) or getattr(mandate, 'name', None) or (mandate.get("label") or mandate.get("name", "") if isinstance(mandate, dict) else "")

    collected: List[Dict[str, Any]] = []

    try:
        appInterface = getAppInterface(self.currentUser)

        for membership in appInterface.getUserMandates(userId):
            mandateId = getattr(membership, 'mandateId', None)
            if not mandateId and isinstance(membership, dict):
                mandateId = membership.get("mandateId")
            if not mandateId or not self.getSettings(mandateId):
                continue

            userAccount = self.getUserAccount(mandateId, userId)
            if not userAccount:
                continue

            rows = self.getTransactions(userAccount["id"], limit=limit)
            mandateName = _mandateLabel(appInterface.getMandate(mandateId))

            for row in rows:
                row.update(mandateId=mandateId, mandateName=mandateName)
                collected.append(row)
    except Exception as e:
        logger.error(f"Error getting transactions for user: {e}")

    _sortBillingTransactionsBySysCreatedAtDesc(collected, "getTransactionsForUser")
    return collected[:limit]
|
||
|
||
# =========================================================================
|
||
# Mandate View Operations (Admin-Level)
|
||
# =========================================================================
|
||
|
||
def getMandateBalances(self, mandateIds: List[str] = None) -> List[Dict[str, Any]]:
    """Return the pool balance and user-account count for each billed mandate.

    Mandates are taken from BillingSettings:
    - with *mandateIds*, the settings of those IDs (IDs without settings are skipped);
    - otherwise, every BillingSettings record.

    Args:
        mandateIds: Optional list of mandate IDs to filter. If None, returns all.

    Returns:
        List of dicts with these keys:
        - mandateId
        - mandateName
        - totalBalance: pool-account balance, or 0.0 without a pool account
        - userCount: number of the mandate's accounts that carry a userId
        - warningThresholdPercent: default 10.0
        Errors are logged and the entries collected so far are returned.
    """
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    def _mandateLabel(mandate: Any):
        """Display name of *mandate* (label, then name), or "" if there is no mandate."""
        if not mandate:
            return ""
        return getattr(mandate, 'label', None) or getattr(mandate, 'name', None) or (mandate.get("label") or mandate.get("name", "") if isinstance(mandate, dict) else "")

    summaries: List[Dict[str, Any]] = []

    try:
        appInterface = getAppInterface(self.currentUser)

        if mandateIds:
            settingsRecords = [found for found in map(self.getSettings, mandateIds) if found]
        else:
            settingsRecords = self.db.getRecordset(BillingSettings)

        for settings in settingsRecords:
            mandateId = settings.get("mandateId")
            if not mandateId:
                continue

            mandateName = _mandateLabel(appInterface.getMandate(mandateId))

            mandateAccounts = self.db.getRecordset(
                BillingAccount,
                recordFilter={"mandateId": mandateId},
            )
            userCount = len([acc for acc in mandateAccounts if acc.get("userId")])

            pool = self.getMandateAccount(mandateId)

            summaries.append({
                "mandateId": mandateId,
                "mandateName": mandateName,
                "totalBalance": pool.get("balance", 0.0) if pool else 0.0,
                "userCount": userCount,
                "warningThresholdPercent": settings.get("warningThresholdPercent", 10.0),
            })
    except Exception as e:
        logger.error(f"Error getting mandate balances: {e}")

    return summaries
|
||
|
||
def getMandateTransactions(self, mandateIds: List[str] = None, limit: int = 100) -> List[Dict[str, Any]]:
    """Return the newest transactions across the given mandates.

    For each mandate, up to *limit* transactions are fetched via
    ``getTransactionsByMandate`` and tagged with ``mandateId`` and
    ``mandateName``. The results are merged, sorted newest-first by
    ``sysCreatedAt`` and cut to *limit* entries.

    Args:
        mandateIds: Optional list of mandate IDs. If falsy, every mandate that
            has a BillingSettings record is queried.
        limit: Maximum number of results.

    Returns:
        List of transaction dicts with mandate context. Errors are logged;
        whatever was collected before the error is still sorted and returned.
    """
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    def _mandateLabel(mandate: Any):
        """Display name of *mandate* (label, then name), or "" if there is no mandate."""
        if not mandate:
            return ""
        return getattr(mandate, 'label', None) or getattr(mandate, 'name', None) or (mandate.get("label") or mandate.get("name", "") if isinstance(mandate, dict) else "")

    merged: List[Dict[str, Any]] = []

    try:
        appInterface = getAppInterface(self.currentUser)

        targetMandateIds = mandateIds or [
            record.get("mandateId")
            for record in self.db.getRecordset(BillingSettings)
            if record.get("mandateId")
        ]

        for mandateId in targetMandateIds:
            rows = self.getTransactionsByMandate(mandateId, limit=limit)
            mandateName = _mandateLabel(appInterface.getMandate(mandateId))
            for row in rows:
                row.update(mandateId=mandateId, mandateName=mandateName)
                merged.append(row)
    except Exception as e:
        logger.error(f"Error getting mandate transactions: {e}")

    # Newest first, then keep only the overall top *limit*.
    _sortBillingTransactionsBySysCreatedAtDesc(merged, "getMandateTransactions")
    return merged[:limit]
|
||
|
||
# =========================================================================
|
||
# User View Operations (User-Level with RBAC)
|
||
# =========================================================================
|
||
|
||
def getUserBalancesForMandates(self, mandateIds: List[str] = None) -> List[Dict[str, Any]]:
    """Return balances of all user-level billing accounts, optionally per mandate.

    Two kinds of accounts are excluded:
    - pool accounts (accounts without ``userId``);
    - accounts whose mandate has no BillingSettings record.
    Each entry is enriched with the mandate's and the user's display names.

    Args:
        mandateIds: Optional list of mandate IDs to filter. If falsy, all
            mandates are included.

    Returns:
        List of dicts with keys accountId, mandateId, mandateName, userId,
        userName, balance, warningThreshold, isWarning and enabled. On error,
        the error is logged and the entries collected so far are returned.
    """
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    balances: List[Dict[str, Any]] = []

    try:
        appInterface = getAppInterface(self.currentUser)

        # A set keeps the mandate filter O(1) per account (it was a list scan).
        mandateFilter = set(mandateIds) if mandateIds else None
        userAccounts = [
            acc for acc in self.db.getRecordset(BillingAccount)
            if acc.get("userId")
            and (mandateFilter is None or acc.get("mandateId") in mandateFilter)
        ]

        # All settings in one query, keyed by mandate.
        settingsMap = {s.get("mandateId"): s for s in self.db.getRecordset(BillingSettings)}

        userMap: Dict[str, Any] = {}
        for uid in {acc.get("userId") for acc in userAccounts}:
            user = appInterface.getUser(uid)
            if user:
                displayName = getattr(user, 'displayName', None) or (user.get("displayName") if isinstance(user, dict) else None)
                username = getattr(user, 'username', None) or (user.get("username") if isinstance(user, dict) else None)
                userMap[uid] = displayName or username or uid

        mandateMap: Dict[str, Any] = {}
        for mandateId in {acc.get("mandateId") for acc in userAccounts if acc.get("mandateId")}:
            mandate = appInterface.getMandate(mandateId)
            if mandate:
                mandateMap[mandateId] = getattr(mandate, 'label', None) or getattr(mandate, 'name', None) or (mandate.get("label") or mandate.get("name", "") if isinstance(mandate, dict) else "")

        for account in userAccounts:
            mandateId = account.get("mandateId")
            userId = account.get("userId")
            if not mandateId or not userId:
                continue
            if not settingsMap.get(mandateId):
                continue

            # NULL columns become 0.0 so the comparison below cannot raise
            # and abort the whole listing.
            balance = account.get("balance") or 0.0
            warningThreshold = account.get("warningThreshold") or 0.0

            balances.append({
                "accountId": account.get("id"),
                "mandateId": mandateId,
                "mandateName": mandateMap.get(mandateId, ""),
                "userId": userId,
                "userName": userMap.get(userId, userId),
                "balance": balance,
                "warningThreshold": warningThreshold,
                "isWarning": balance <= warningThreshold,
                "enabled": account.get("enabled", True)
            })
    except Exception as e:
        logger.error(f"Error getting user balances for mandates: {e}")

    return balances
|
||
|
||
@staticmethod
def _mapPaginationColumns(pagination: PaginationParams) -> PaginationParams:
    """Return a copy of *pagination* with frontend column names mapped to DB columns.

    - ``createdAt`` becomes ``sysCreatedAt`` in filter keys and sort fields.
    - Filters and sort entries on enriched, non-DB columns (mandateName,
      userName, mandateId, userId) are dropped.
    - Sort entries are always emitted as dicts.
    - If a sort list was given but every entry was dropped, sorting falls back
      to ``sysCreatedAt`` descending.

    The input is deep-copied first and never modified.
    """
    import copy

    columnMap = {"createdAt": "sysCreatedAt"}
    enrichedColumns = {"mandateName", "userName", "mandateId", "userId"}

    result = copy.deepcopy(pagination)

    if result.filters:
        result.filters = {
            columnMap.get(key, key): value
            for key, value in result.filters.items()
            if key not in enrichedColumns
        }

    if result.sort:
        remapped = []
        for entry in result.sort:
            isDict = isinstance(entry, dict)
            field = entry.get("field", "") if isDict else getattr(entry, "field", "")
            if field in enrichedColumns:
                continue
            target = columnMap.get(field, field)
            if isDict:
                remapped.append({**entry, "field": target})
            else:
                remapped.append({"field": target, "direction": getattr(entry, "direction", "asc")})
        result.sort = remapped or [{"field": "sysCreatedAt", "direction": "desc"}]

    return result
|
||
|
||
def getTransactionsForMandatesPaginated(
    self,
    mandateIds: Optional[List[str]],
    pagination: PaginationParams,
    scope: str = "all",
    userId: Optional[str] = None,
) -> PaginatedResult:
    """Page through transactions across multiple mandates at the SQL level.

    The billing accounts of the requested mandates are resolved first (all
    accounts when *mandateIds* is falsy), then their transactions are paged
    in SQL:
    - A non-blank ``search`` filter is delegated to
      ``_searchTransactionsPaginated``, so it can also match mandate/user
      names and amounts.
    - Otherwise ``getRecordsetPaginated`` is used, filtering on ``accountId``
      (and on ``createdByUserId`` when *userId* is given).
    Only the returned page is enriched with ``mandateId``, ``mandateName``,
    ``userId`` and ``userName``.

    Args:
        mandateIds: Mandates to include; falsy means all mandates.
        pagination: Frontend pagination; column names are remapped via
            ``_mapPaginationColumns``.
        scope: Accepted for API compatibility; not used by this method.
        userId: When set, only transactions created by this user are returned.

    Returns:
        PaginatedResult with enriched rows. On any error the error is logged
        and an empty result is returned.
    """
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    try:
        mappedPagination = self._mapPaginationColumns(pagination)

        allAccounts = self.db.getRecordset(BillingAccount)
        if mandateIds:
            # Build the lookup set once; it was previously rebuilt per account.
            wantedMandateIds = set(mandateIds)
            allAccounts = [a for a in allAccounts if a.get("mandateId") in wantedMandateIds]

        accountIds = [a.get("id") for a in allAccounts if a.get("id")]
        if not accountIds:
            return PaginatedResult(items=[], totalItems=0, totalPages=0)

        # Free-text search needs a custom query. The generic SQL search only
        # covers TEXT columns of BillingTransaction, which excludes the enriched
        # columns (mandateName, userName) and the numeric amount column.
        searchTerm: Optional[str] = None
        if mappedPagination and mappedPagination.filters:
            raw = mappedPagination.filters.get("search")
            if isinstance(raw, str) and raw.strip():
                searchTerm = raw.strip()

        if searchTerm:
            searchResult = self._searchTransactionsPaginated(
                allAccounts=allAccounts,
                accountIds=accountIds,
                userId=userId,
                searchTerm=searchTerm,
                pagination=mappedPagination,
            )
            pageItems = searchResult["items"]
            totalItems = searchResult["totalItems"]
            totalPages = searchResult["totalPages"]
        else:
            recordFilter: Dict[str, Any] = {"accountId": accountIds}
            if userId:
                recordFilter["createdByUserId"] = userId

            result = self.db.getRecordsetPaginated(
                BillingTransaction,
                pagination=mappedPagination,
                recordFilter=recordFilter,
            )
            # The connector may return either a dict or a result object.
            if isinstance(result, dict):
                pageItems = result.get("items", [])
                totalItems = result.get("totalItems", 0)
                totalPages = result.get("totalPages", 0)
            else:
                pageItems = result.items
                totalItems = result.totalItems
                totalPages = result.totalPages

        accountMap = {a.get("id"): a for a in allAccounts}

        # Collect only the user/mandate IDs referenced by this page.
        pageUserIds = set()
        pageMandateIds = set()
        for t in pageItems:
            acc = accountMap.get(t.get("accountId"), {})
            mid = acc.get("mandateId")
            uid = t.get("createdByUserId") or acc.get("userId")
            if uid:
                pageUserIds.add(uid)
            if mid:
                pageMandateIds.add(mid)

        appInterface = getAppInterface(self.currentUser)

        userMap: Dict[str, str] = {}
        if pageUserIds:
            users = appInterface.getUsersByIds(list(pageUserIds))
            for uid, u in users.items():
                userMap[uid] = getattr(u, "displayName", None) or getattr(u, "username", None) or uid

        mandateMap: Dict[str, str] = {}
        if pageMandateIds:
            mandates = appInterface.getMandatesByIds(list(pageMandateIds))
            for mid, m in mandates.items():
                mandateMap[mid] = getattr(m, "label", None) or getattr(m, "name", None) or mid

        enriched = []
        for t in pageItems:
            row = dict(t)
            acc = accountMap.get(row.get("accountId"), {})
            mid = acc.get("mandateId")
            txUserId = row.get("createdByUserId") or acc.get("userId")
            row["mandateId"] = mid
            row["mandateName"] = mandateMap.get(mid, "")
            row["userId"] = txUserId
            row["userName"] = userMap.get(txUserId, txUserId) if txUserId else None
            enriched.append(row)

        return PaginatedResult(items=enriched, totalItems=totalItems, totalPages=totalPages)

    except Exception as e:
        logger.error(f"Error in getTransactionsForMandatesPaginated: {e}", exc_info=True)
        return PaginatedResult(items=[], totalItems=0, totalPages=0)
|
||
|
||
def _searchTransactionsPaginated(
    self,
    allAccounts: List[Dict[str, Any]],
    accountIds: List[str],
    userId: Optional[str],
    searchTerm: str,
    pagination: PaginationParams,
) -> Dict[str, Any]:
    """Run a paginated free-text search over BillingTransaction rows.

    Besides the TEXT columns of BillingTransaction, the search also matches:
    - ``userName``: app-DB users whose username, fullName or email matches
      are resolved to IDs and matched against ``createdByUserId``;
    - ``mandateName``: mandates whose label or name matches are resolved to
      IDs and matched via the accounts (from *allAccounts*) they own;
    - ``amount``: as text (ILIKE) and, when the term parses as a number
      (``,`` accepted as decimal separator), by exact equality.

    All alternatives are OR-combined. They are then AND-ed with:
    - the account scope (*accountIds*);
    - the optional ``createdByUserId`` restriction (*userId*);
    - any non-``search`` filters from *pagination*.

    Args:
        allAccounts: In-scope billing account records (used to map mandates).
        accountIds: IDs of the in-scope accounts.
        userId: Optional creator restriction.
        searchTerm: Non-empty search text.
        pagination: Already column-mapped pagination (sort, page, pageSize, filters).

    Returns:
        ``{"items": [...], "totalItems": int, "totalPages": int}``. If the
        billing query fails, the error is logged, the billing connection is
        rolled back and an empty result is returned.
    """
    import copy
    import math

    from modules.connectors.connectorDbPostgre import _get_model_fields, _parseRecordFields
    from modules.datamodels.datamodelUam import UserInDB
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    table = BillingTransaction.__name__
    fields = _get_model_fields(BillingTransaction)
    pattern = f"%{searchTerm}%"

    # 1) Resolve matching user / mandate IDs via the app DB. The app DB is
    #    separate from the billing DB and hosts the UserInDB / Mandate tables.
    matchingUserIds: List[str] = []
    matchingMandateIds: List[str] = []
    appInterface = None
    try:
        appInterface = getAppInterface(self.currentUser)
        appInterface.db._ensure_connection()
        with appInterface.db.connection.cursor() as cur:
            if appInterface.db._ensureTableExists(UserInDB):
                cur.execute(
                    'SELECT "id" FROM "UserInDB" WHERE '
                    'COALESCE("username", \'\') ILIKE %s OR '
                    'COALESCE("fullName", \'\') ILIKE %s OR '
                    'COALESCE("email", \'\') ILIKE %s',
                    (pattern, pattern, pattern),
                )
                matchingUserIds = [r["id"] for r in cur.fetchall() if r.get("id")]

            if appInterface.db._ensureTableExists(Mandate):
                cur.execute(
                    'SELECT "id" FROM "Mandate" WHERE '
                    'COALESCE("label", \'\') ILIKE %s OR '
                    'COALESCE("name", \'\') ILIKE %s',
                    (pattern, pattern),
                )
                matchingMandateIds = [r["id"] for r in cur.fetchall() if r.get("id")]
    except Exception as e:
        logger.warning(f"_searchTransactionsPaginated: user/mandate resolution failed: {e}")
        # Mirror the billing-side error path below: a failed statement leaves
        # the app connection in an aborted transaction, so roll it back.
        if appInterface is not None:
            try:
                appInterface.db.connection.rollback()
            except Exception:
                pass

    # Build the set once instead of once per account.
    matchingMandateSet = set(matchingMandateIds)
    matchingAccountIds = [
        a.get("id") for a in allAccounts
        if a.get("id") and a.get("mandateId") in matchingMandateSet
    ]

    # 2) Try to interpret the search term as a number for exact amount matching.
    amountVal: Optional[float] = None
    try:
        amountVal = float(searchTerm.replace(",", "."))
    except (TypeError, ValueError):
        amountVal = None

    whereParts: List[str] = ['"accountId" = ANY(%s)']
    whereValues: List[Any] = [accountIds]
    if userId:
        whereParts.append('"createdByUserId" = %s')
        whereValues.append(userId)

    # 3) OR-combined search alternatives (all values bound as parameters).
    orParts: List[str] = []
    orValues: List[Any] = []

    for col in (c for c, t in fields.items() if t == "TEXT"):
        orParts.append(f'COALESCE("{col}"::TEXT, \'\') ILIKE %s')
        orValues.append(pattern)

    if matchingUserIds:
        orParts.append('"createdByUserId" = ANY(%s)')
        orValues.append(matchingUserIds)
    if matchingAccountIds:
        orParts.append('"accountId" = ANY(%s)')
        orValues.append(matchingAccountIds)

    orParts.append('"amount"::TEXT ILIKE %s')
    orValues.append(pattern)
    if amountVal is not None:
        orParts.append('"amount" = %s')
        orValues.append(amountVal)

    whereParts.append(f"({' OR '.join(orParts)})")
    whereValues.extend(orValues)

    # 4) Remaining structured (non-search) filters. Reuse the generic clause
    #    builder and keep only its WHERE contribution, turned into an AND.
    extraWhere = ""
    extraValues: List[Any] = []
    filterPagination = copy.deepcopy(pagination) if pagination else None
    if filterPagination and filterPagination.filters:
        filterPagination.filters = {
            k: v for k, v in filterPagination.filters.items() if k != "search"
        }
    if filterPagination and filterPagination.filters:
        try:
            filterPagination.sort = []
            filterPagination.page = 1
            filterPagination.pageSize = 1
            ew, _, _, values, _ = self.db._buildPaginationClauses(
                BillingTransaction, filterPagination, recordFilter=None
            )
            if ew:
                extraWhere = ew.replace(" WHERE ", " AND ", 1)
                extraValues = list(values)
        except Exception as e:
            logger.warning(f"_searchTransactionsPaginated: extra-filter build failed: {e}")

    whereClause = " WHERE " + " AND ".join(whereParts) + extraWhere
    whereValues.extend(extraValues)

    # 5) ORDER BY: fields are whitelisted against the model's columns, so
    #    interpolating the quoted name is safe.
    validColumns = set(fields.keys())
    orderParts: List[str] = []
    if pagination and pagination.sort:
        for sf in pagination.sort:
            sfField = sf.get("field") if isinstance(sf, dict) else getattr(sf, "field", None)
            sfDir = sf.get("direction", "asc") if isinstance(sf, dict) else getattr(sf, "direction", "asc")
            if sfField and sfField in validColumns:
                direction = "DESC" if str(sfDir).lower() == "desc" else "ASC"
                if fields.get(sfField, "TEXT") == "BOOLEAN":
                    orderParts.append(f'COALESCE("{sfField}", FALSE) {direction}')
                else:
                    orderParts.append(f'"{sfField}" {direction} NULLS LAST')
    if not orderParts:
        orderParts.append('"id"')
    orderClause = " ORDER BY " + ", ".join(orderParts)

    # Clamp to >= 1 so OFFSET is never negative and totalPages never divides by zero.
    pageSize = 50
    page = 1
    if pagination:
        if pagination.pageSize is not None:
            pageSize = max(1, int(pagination.pageSize))
        if pagination.page is not None:
            page = max(1, int(pagination.page))
    offset = (page - 1) * pageSize
    limitClause = f" LIMIT {pageSize} OFFSET {offset}"

    try:
        self.db._ensure_connection()
        with self.db.connection.cursor() as cur:
            countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
            cur.execute(countSql, whereValues)
            totalItems = cur.fetchone()["count"]

            dataSql = f'SELECT * FROM "{table}"{whereClause}{orderClause}{limitClause}'
            cur.execute(dataSql, whereValues)
            records = [dict(row) for row in cur.fetchall()]

        for rec in records:
            _parseRecordFields(rec, fields, f"search table {table}")

        totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0
        return {"items": records, "totalItems": totalItems, "totalPages": totalPages}

    except Exception as e:
        logger.error(f"_searchTransactionsPaginated SQL error: {e}", exc_info=True)
        try:
            self.db.connection.rollback()
        except Exception:
            pass
        return {"items": [], "totalItems": 0, "totalPages": 0}
|
||
|
||
def _buildScopeFilter(
    self,
    mandateIds: Optional[List[str]],
    scope: str = "all",
    userId: Optional[str] = None,
    startTs: Optional[float] = None,
    endTs: Optional[float] = None,
) -> tuple:
    """Build WHERE-clause parts for scoped DEBIT-transaction queries.

    The conditions restrict results to:
    - the billing accounts of *mandateIds* (all accounts when falsy);
    - DEBIT transactions;
    - optionally, rows created by *userId*;
    - optionally, ``startTs <= sysCreatedAt < endTs``.

    Args:
        mandateIds: Mandates whose accounts are in scope; falsy means all.
        scope: Accepted for API compatibility; not used here.
        userId: Restrict to ``createdByUserId == userId`` when given.
        startTs: Inclusive lower bound on ``sysCreatedAt``.
        endTs: Exclusive upper bound on ``sysCreatedAt``.

    Returns:
        A 4-tuple ``(conditions, values, accountIds, allAccounts)``:
        - conditions: SQL strings with ``%s`` placeholders, meant to be AND-joined;
        - values: the matching parameter values, in order;
        - accountIds: the in-scope account IDs;
        - allAccounts: the in-scope account records.
        When no account is in scope, the first three are empty lists.
    """
    allAccounts = self.db.getRecordset(BillingAccount)
    if mandateIds:
        mandateSet = set(mandateIds)
        allAccounts = [a for a in allAccounts if a.get("mandateId") in mandateSet]

    accountIds = [a.get("id") for a in allAccounts if a.get("id")]
    if not accountIds:
        return [], [], [], allAccounts

    # Use the enum value, as the rest of this interface does when filtering
    # by transactionType, instead of a hard-coded literal.
    conditions: List[str] = ['"accountId" = ANY(%s)', '"transactionType" = %s']
    values: List[Any] = [accountIds, TransactionTypeEnum.DEBIT.value]

    if userId:
        conditions.append('"createdByUserId" = %s')
        values.append(userId)

    if startTs is not None:
        conditions.append('"sysCreatedAt" >= %s')
        values.append(startTs)
    if endTs is not None:
        conditions.append('"sysCreatedAt" < %s')
        values.append(endTs)

    return conditions, values, accountIds, allAccounts
|
||
|
||
def getTransactionStatisticsAggregated(
    self,
    mandateIds: Optional[List[str]],
    scope: str = "all",
    userId: Optional[str] = None,
    startTs: Optional[float] = None,
    endTs: Optional[float] = None,
    bucketSize: str = "month",
) -> Dict[str, Any]:
    """
    Pure SQL aggregation for statistics. No row-level loading.

    The WHERE clause is produced by ``_buildScopeFilter`` from ``mandateIds``,
    ``scope``, ``userId`` and the ``[startTs, endTs)`` window.

    `bucketSize` controls only the time-series aggregation granularity
    (`'day' | 'month' | 'year'`); the date range is set via `startTs`/`endTs`.

    Returns:
        Dict with keys ``totalCost``, ``transactionCount``, ``costByProvider``
        (``{provider: cost}``), ``costByModel`` (``{model: cost}``),
        ``costByAccountId`` (``{accountId: cost}``, to be enriched to mandate
        names by the caller), ``costByAccountFeature`` (list of
        ``{accountId, featureCode, total}``), ``timeSeries`` (list of
        ``{date, cost, count}`` in bucket order) and ``_allAccounts`` (the
        in-scope BillingAccount records). Costs are rounded to 4 decimals.
        If the table does not exist, no account is in scope, or any query
        fails (logged, connection rolled back), ``_emptyStats()`` is returned.

    Raises:
        ValueError: if ``bucketSize`` is not ``day``, ``month`` or ``year``.
    """
    # Validate before touching the database: inside the try below this error
    # would be swallowed by the generic handler (turned into empty stats), and
    # only after five queries had already been executed.
    bucketSpecs = {
        "day": ("day", "%Y-%m-%d"),
        "month": ("month", "%Y-%m"),
        "year": ("year", "%Y"),
    }
    bucketSpec = bucketSpecs.get(bucketSize)
    if bucketSpec is None:
        raise ValueError(
            f"Invalid bucketSize: {bucketSize!r} (expected day|month|year)"
        )
    truncUnit, labelFormat = bucketSpec

    table = BillingTransaction.__name__

    try:
        if not self.db._ensureTableExists(BillingTransaction):
            return self._emptyStats()

        conditions, values, accountIds, allAccounts = self._buildScopeFilter(
            mandateIds, scope, userId, startTs, endTs
        )
        if not accountIds:
            return self._emptyStats()

        whereClause = " WHERE " + " AND ".join(conditions)
        self.db._ensure_connection()

        result: Dict[str, Any] = {}

        with self.db.connection.cursor() as cur:

            def costBy(groupExpr: str) -> Dict[str, float]:
                """Run one GROUP BY query; return {group: cost}, highest cost first."""
                cur.execute(
                    f'SELECT {groupExpr} AS grp, COALESCE(SUM("amount"), 0) AS total '
                    f'FROM "{table}"{whereClause} GROUP BY grp ORDER BY total DESC',
                    values,
                )
                return {r["grp"]: round(float(r["total"]), 4) for r in cur.fetchall()}

            # 1) Totals
            cur.execute(
                f'SELECT COALESCE(SUM("amount"), 0) AS total, COUNT(*) AS cnt FROM "{table}"{whereClause}',
                values,
            )
            row = cur.fetchone()
            result["totalCost"] = round(float(row["total"]), 4)
            result["transactionCount"] = int(row["cnt"])

            # 2)-4) Single-column breakdowns. COALESCE around SUM keeps a group
            # whose amounts are all NULL from reaching float(None), which would
            # otherwise abort the whole statistics call.
            result["costByProvider"] = costBy('COALESCE("aicoreProvider", \'unknown\')')
            result["costByModel"] = costBy('COALESCE("aicoreModel", \'unknown\')')
            # accountId breakdown is enriched to mandateName by the caller.
            result["costByAccountId"] = costBy('"accountId"')

            # 5) accountId x featureCode (used for costByFeature)
            cur.execute(
                f'SELECT "accountId", COALESCE("featureCode", \'unknown\') AS fc, '
                f'COALESCE(SUM("amount"), 0) AS total '
                f'FROM "{table}"{whereClause} GROUP BY "accountId", fc ORDER BY total DESC',
                values,
            )
            result["costByAccountFeature"] = [
                {"accountId": r["accountId"], "featureCode": r["fc"], "total": round(float(r["total"]), 4)}
                for r in cur.fetchall()
            ]

            # 6) Time series via DATE_TRUNC on the epoch timestamp. TO_TIMESTAMP
            # yields a timestamptz, so buckets follow the session TimeZone.
            # NOTE(review): assumes the DB session runs in UTC — confirm.
            truncExpr = f"DATE_TRUNC('{truncUnit}', TO_TIMESTAMP(\"sysCreatedAt\"))"
            cur.execute(
                f'SELECT {truncExpr} AS bucket, COALESCE(SUM("amount"), 0) AS total, COUNT(*) AS cnt '
                f'FROM "{table}"{whereClause} AND "sysCreatedAt" IS NOT NULL '
                f'GROUP BY bucket ORDER BY bucket',
                values,
            )
            result["timeSeries"] = [
                {
                    "date": r["bucket"].strftime(labelFormat) if r["bucket"] else "unknown",
                    "cost": round(float(r["total"]), 4),
                    "count": int(r["cnt"]),
                }
                for r in cur.fetchall()
            ]

        self.db.connection.commit()

        result["_allAccounts"] = allAccounts
        return result

    except Exception as e:
        logger.error(f"Error in getTransactionStatisticsAggregated: {e}", exc_info=True)
        try:
            self.db.connection.rollback()
        except Exception:
            pass
        return self._emptyStats()
|
||
|
||
@staticmethod
|
||
def _emptyStats() -> Dict[str, Any]:
|
||
return {
|
||
"totalCost": 0.0,
|
||
"transactionCount": 0,
|
||
"costByProvider": {},
|
||
"costByModel": {},
|
||
"costByAccountId": {},
|
||
"costByAccountFeature": [],
|
||
"timeSeries": [],
|
||
"_allAccounts": [],
|
||
}
|
||
|
||
def getTransactionDistinctValues(
    self,
    mandateIds: Optional[List[str]],
    column: str,
    pagination: Optional[PaginationParams] = None,
    scope: str = "all",
    userId: Optional[str] = None,
) -> List[str]:
    """
    SQL DISTINCT for filter-values on BillingTransaction, scoped by mandates.

    Args:
        mandateIds: Only transactions on accounts of these mandates. ``None``
            means no mandate restriction; an empty list yields no values.
        column: UI column name. ``createdAt``, ``mandateId`` and ``mandateName``
            are mapped to DB columns; ``mandateName`` and ``userName`` are
            resolved to display labels via ``_getEnrichedDistinctValues``.
        pagination: Optional filter/sort context, passed through
            ``_mapPaginationColumns`` before use.
        scope: Accepted for signature parity; not used here.
        userId: If given, only transactions whose ``createdByUserId`` matches.

    Returns:
        The distinct values, or ``[]`` when nothing is in scope or on error
        (the error is logged).
    """
    # UI column -> BillingTransaction column.
    # NOTE(review): "mandateId" resolves to account IDs, not mandate IDs —
    # confirm the caller maps them back.
    _COLUMN_MAP = {
        "createdAt": "sysCreatedAt",
        "mandateId": "accountId",
        "mandateName": "accountId",
    }
    dbColumn = _COLUMN_MAP.get(column, column)

    mappedPagination = self._mapPaginationColumns(pagination) if pagination else None

    try:
        allAccounts = self.db.getRecordset(BillingAccount)
        # `is not None` so an empty list narrows to nothing instead of widening
        # to every account; the set is built once rather than per account.
        if mandateIds is not None:
            mandateSet = set(mandateIds)
            allAccounts = [a for a in allAccounts if a.get("mandateId") in mandateSet]
        accountIds = [a.get("id") for a in allAccounts if a.get("id")]
        if not accountIds:
            return []

        recordFilter: Dict[str, Any] = {"accountId": accountIds}
        if userId:
            recordFilter["createdByUserId"] = userId

        if column in ("mandateName", "userName"):
            return self._getEnrichedDistinctValues(column, allAccounts, recordFilter, mappedPagination)

        return self.db.getDistinctColumnValues(
            BillingTransaction, dbColumn, mappedPagination, recordFilter
        )
    except Exception as e:
        logger.error(f"Error in getTransactionDistinctValues({column}): {e}", exc_info=True)
        return []
|
||
|
||
def _getEnrichedDistinctValues(
    self,
    column: str,
    allAccounts: List[Dict],
    recordFilter: Dict[str, Any],
    pagination: Optional[PaginationParams],
) -> List[str]:
    """
    Produce sorted, de-duplicated display labels for a derived filter column.

    - ``mandateName``: labels of the mandates referenced by ``allAccounts``
      (as returned by ``getMandatesByIds``); ``recordFilter`` and
      ``pagination`` are not applied on this path.
    - ``userName``: labels of the distinct ``createdByUserId`` values among
      BillingTransaction rows matching ``recordFilter``/``pagination``.
    - Any other column yields ``[]``.

    A label falls back to the mandate/user ID when no label/name is set.
    Results are sorted case-insensitively.
    """
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    def caseInsensitive(label):
        return label.lower()

    if column == "mandateName":
        referencedMandateIds = list({acc.get("mandateId") for acc in allAccounts if acc.get("mandateId")})
        mandatesById = getAppInterface(self.currentUser).getMandatesByIds(referencedMandateIds)
        mandateLabels = set()
        for mandateId, mandate in mandatesById.items():
            mandateLabels.add(getattr(mandate, "label", None) or getattr(mandate, "name", "") or mandateId)
        return sorted(mandateLabels, key=caseInsensitive)

    if column != "userName":
        return []

    creatorIds = self.db.getDistinctColumnValues(BillingTransaction, "createdByUserId", pagination, recordFilter)
    if not creatorIds:
        return []
    usersById = getAppInterface(self.currentUser).getUsersByIds(creatorIds)
    userLabels = set()
    for userKey, user in usersById.items():
        userLabels.add(getattr(user, "displayName", None) or getattr(user, "username", None) or userKey)
    return sorted(userLabels, key=caseInsensitive)
|
||
|
||
def getUserTransactionsForMandates(self, mandateIds: Optional[List[str]] = None, limit: int = 100) -> List[Dict[str, Any]]:
    """
    Get all transactions for specified mandates, enriched with mandate and user context.

    All usage transactions are on user accounts (audit trail). Up to ``limit``
    transactions are fetched per BillingAccount (all account types), merged,
    sorted by creation date descending and cut to ``limit``. Each returned
    dict gains ``mandateId``, ``mandateName``, ``userId`` and ``userName``;
    ``userId`` prefers the transaction's ``createdByUserId`` over the
    account owner.

    Args:
        mandateIds: Optional list of mandate IDs to filter. If None, returns all;
            an empty list returns nothing.
        limit: Maximum number of results (also the per-account fetch size).

    Returns:
        List of transaction dicts with mandate and user context. On error the
        failure is logged and the (possibly partial) list collected so far is
        returned.
    """
    from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

    def _userLabel(user: Any, fallback: str) -> Any:
        """displayName, then username, then ``fallback``; accepts a model or a dict."""
        isDict = isinstance(user, dict)
        displayName = getattr(user, "displayName", None) or (user.get("displayName") if isDict else None)
        username = getattr(user, "username", None) or (user.get("username") if isDict else None)
        return displayName or username or fallback

    def _mandateLabel(mandate: Any) -> Any:
        """label, then name, from attributes or (for dicts) keys; ``""`` if neither."""
        fromAttrs = getattr(mandate, "label", None) or getattr(mandate, "name", None)
        if fromAttrs:
            return fromAttrs
        if isinstance(mandate, dict):
            return mandate.get("label") or mandate.get("name", "")
        return ""

    allTransactions: List[Dict[str, Any]] = []

    try:
        appInterface = getAppInterface(self.currentUser)

        # Get ALL accounts (both USER and MANDATE types) to cover all billing models
        allAccounts = self.db.getRecordset(BillingAccount)

        # `is not None` so an empty list narrows to nothing instead of widening to
        # every mandate; a set gives O(1) membership per account.
        if mandateIds is not None:
            wantedMandates = set(mandateIds)
            allAccounts = [acc for acc in allAccounts if acc.get("mandateId") in wantedMandates]

        # (transaction, account mandateId, account owner userId). Kept alongside
        # the dict so no temporary keys are written into the transaction records.
        rawTransactions = []
        for account in allAccounts:
            accountId = account.get("id")
            if not accountId:
                continue
            accMandateId = account.get("mandateId")
            accUserId = account.get("userId")
            for t in self.getTransactions(accountId, limit=limit):
                rawTransactions.append((t, accMandateId, accUserId))

        # Mandate names, one lookup per distinct mandate that has transactions.
        mandateMap: Dict[str, Any] = {}
        for mid in {accMandateId for _, accMandateId, _ in rawTransactions if accMandateId}:
            mandate = appInterface.getMandate(mid)
            if mandate:
                mandateMap[mid] = _mandateLabel(mandate)

        # User names, resolved only for the user IDs that end up on a transaction.
        userMap: Dict[str, Any] = {}
        for uid in {t.get("createdByUserId") or accUserId for t, _, accUserId in rawTransactions}:
            if not uid:
                continue
            user = appInterface.getUser(uid)
            if user:
                userMap[uid] = _userLabel(user, uid)

        # Enrich transactions
        for t, accMandateId, accUserId in rawTransactions:
            t["mandateId"] = accMandateId
            t["mandateName"] = mandateMap.get(accMandateId, "")
            # Prefer createdByUserId (per-transaction) over account-derived userId
            txUserId = t.get("createdByUserId") or accUserId
            t["userId"] = txUserId
            t["userName"] = userMap.get(txUserId, txUserId) if txUserId else None
            allTransactions.append(t)

    except Exception as e:
        logger.error(f"Error getting user transactions for mandates: {e}", exc_info=True)

    # Sort by creation date descending and limit
    _sortBillingTransactionsBySysCreatedAtDesc(allTransactions, "getUserTransactionsForMandates")
    return allTransactions[:limit]
|