2048 lines
81 KiB
Python
2048 lines
81 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
||
# All rights reserved.
|
||
"""
|
||
Billing routes for the backend API.
|
||
Implements the endpoints for billing management and usage tracking.
|
||
|
||
Features:
|
||
- User endpoints: View balance, transactions, statistics
|
||
- Admin endpoints: Manage settings, add credits, view all accounts
|
||
"""
|
||
|
||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Request, Response, Query, Header
|
||
from typing import List, Dict, Any, Optional
|
||
from fastapi import status
|
||
import logging
|
||
from datetime import date, datetime, timezone
|
||
from pydantic import BaseModel, Field
|
||
|
||
# Import auth module
|
||
from modules.auth import limiter, requireSysAdminRole, getRequestContext, RequestContext
|
||
|
||
# Import billing components
|
||
from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface, _getRootInterface
|
||
from modules.serviceCenter.services.serviceBilling.mainServiceBilling import getService as getBillingService
|
||
import json
|
||
import math
|
||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
|
||
from modules.routes.routeDataUsers import _applyFiltersAndSort, _extractDistinctValues, _handleFilterValuesRequest
|
||
from modules.datamodels.datamodelBilling import (
|
||
BillingAccount,
|
||
BillingTransaction,
|
||
BillingSettings,
|
||
TransactionTypeEnum,
|
||
ReferenceTypeEnum,
|
||
PeriodTypeEnum,
|
||
BillingBalanceResponse,
|
||
BillingStatisticsResponse,
|
||
BillingStatisticsChartData,
|
||
BillingCheckResult,
|
||
)
|
||
|
||
# Configure logger
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
# =============================================================================
|
||
# Billing RBAC Data Scope
|
||
# =============================================================================
|
||
#
|
||
# RBAC rules for billing data visibility:
|
||
#
|
||
# SysAdmin → ALL transactions and statistics across all mandates
|
||
# Mandate-Admin → ALL user data within the mandates they administer
|
||
# Feature-Instance-Admin → Data for the feature instances they administer
|
||
# Regular User → ONLY their own data within their mandates
|
||
#
|
||
|
||
class BillingDataScope:
    """
    Determines what billing data a user can see based on RBAC roles.

    Evaluated once per request and used to filter transactions/statistics.

    Attributes:
        isGlobalAdmin: True for SysAdmins; when set, no filtering applies.
        adminMandateIds: Mandates in which the user holds a mandate-level admin role.
        adminFeatureInstanceIds: Feature instances in which the user holds an admin role.
        memberMandateIds: Mandates in which the user is a non-admin member.
        userId: The user the scope was computed for.
    """
    # __slots__ keeps instances small and rejects accidental new attributes.
    __slots__ = ('isGlobalAdmin', 'adminMandateIds', 'adminFeatureInstanceIds',
                 'memberMandateIds', 'userId')

    def __init__(self, userId: str):
        self.isGlobalAdmin: bool = False
        # Fresh lists per instance (never shared between scopes).
        self.adminMandateIds: list = []
        self.adminFeatureInstanceIds: list = []
        self.memberMandateIds: list = []
        self.userId: str = userId

    def __repr__(self) -> str:
        """Readable representation for debugging and log output."""
        return (
            f"{type(self).__name__}(userId={self.userId!r}, "
            f"isGlobalAdmin={self.isGlobalAdmin!r}, "
            f"adminMandateIds={self.adminMandateIds!r}, "
            f"adminFeatureInstanceIds={self.adminFeatureInstanceIds!r}, "
            f"memberMandateIds={self.memberMandateIds!r})"
        )
|
||
|
||
|
||
def _getBillingDataScope(user) -> BillingDataScope:
    """
    Determine what billing data a user can see based on RBAC.

    Uses the privileged root interface to check roles across all mandates and
    feature instances, so the lookup itself is not subject to RBAC filtering.

    Rules applied:
      - SysAdmin: ``isGlobalAdmin=True``; the ID lists stay empty.
      - Enabled mandate membership with a mandate-level ``admin`` role
        (a role without ``featureInstanceId``) -> ``adminMandateIds``.
      - Any other enabled mandate membership -> ``memberMandateIds``.
      - Feature access with an ``admin`` role -> ``adminFeatureInstanceIds``.

    Disabled mandate memberships (``enabled`` falsy) are ignored, consistent
    with ``_isAdminOfMandate`` / ``_isMemberOfMandate``. IDs are de-duplicated.

    Args:
        user: Authenticated user; only ``user.id`` is read.

    Returns:
        BillingDataScope with the user's visibility boundaries.
    """
    scope = BillingDataScope(userId=user.id)

    from modules.auth.authentication import _hasSysAdminRole
    if _hasSysAdminRole(str(user.id)):
        scope.isGlobalAdmin = True
        return scope

    from modules.interfaces.interfaceDbApp import getRootInterface
    rootInterface = getRootInterface()

    # --- Mandate roles ---
    for um in rootInterface.getUserMandates(user.id):
        mandateId = getattr(um, 'mandateId', None)
        umId = getattr(um, 'id', None)
        if not mandateId or not umId:
            continue
        # A disabled membership must not grant any billing visibility.
        if not getattr(um, 'enabled', True):
            continue

        roles = map(rootInterface.getRole, rootInterface.getRoleIdsForUserMandate(umId))
        # Only mandate-level admin roles count here; feature-instance roles are handled below.
        isAdmin = any(
            role and role.roleLabel == "admin" and not role.featureInstanceId
            for role in roles
        )
        bucket = scope.adminMandateIds if isAdmin else scope.memberMandateIds
        if mandateId not in bucket:
            bucket.append(mandateId)

    # --- Feature instance roles ---
    for fa in rootInterface.getFeatureAccessesForUser(user.id):
        fiId = getattr(fa, 'featureInstanceId', None)
        faId = getattr(fa, 'id', None)
        if not fiId or not faId:
            continue
        if fiId in scope.adminFeatureInstanceIds:
            continue  # already granted via another access record

        roles = map(rootInterface.getRole, rootInterface.getRoleIdsForFeatureAccess(faId))
        if any(role and role.roleLabel == "admin" for role in roles):
            scope.adminFeatureInstanceIds.append(fiId)

    logger.debug(
        f"BillingDataScope for user {user.id}: "
        f"globalAdmin={scope.isGlobalAdmin}, "
        f"adminMandates={scope.adminMandateIds}, "
        f"adminInstances={scope.adminFeatureInstanceIds}, "
        f"memberMandates={scope.memberMandateIds}"
    )
    return scope
|
||
|
||
|
||
def _isAdminOfMandate(ctx: RequestContext, targetMandateId: str) -> bool:
    """
    Check if the caller is a SysAdmin or a mandate-level admin of a mandate.

    A mandate-level admin holds an enabled membership in ``targetMandateId``
    with a role labelled ``"admin"`` that is not bound to a feature instance.

    Fails closed: any error during the role lookup is logged and treated as
    "not an admin".

    Args:
        ctx: Request context of the caller.
        targetMandateId: Mandate to check admin rights for.

    Returns:
        True if the caller may administer the mandate, else False.
    """
    if ctx.hasSysAdminRole:
        return True
    try:
        from modules.interfaces.interfaceDbApp import getRootInterface
        rootInterface = getRootInterface()
        target = str(targetMandateId)
        for um in rootInterface.getUserMandates(str(ctx.user.id)):
            if str(getattr(um, 'mandateId', None)) != target:
                continue
            if not getattr(um, 'enabled', True):
                continue
            umId = getattr(um, 'id', None)
            if not umId:
                continue  # cannot resolve roles without a membership id
            for roleId in rootInterface.getRoleIdsForUserMandate(str(umId)):
                role = rootInterface.getRole(roleId)
                if role and role.roleLabel == "admin" and not role.featureInstanceId:
                    return True
        return False
    except Exception:
        # Deny access, but leave a trace instead of silently swallowing the error.
        logger.exception(f"Admin check failed for mandate {targetMandateId}; denying access")
        return False
|
||
|
||
|
||
def _isMemberOfMandate(ctx: RequestContext, targetMandateId: str) -> bool:
    """
    Check if the caller has any enabled membership in the specified mandate.

    SysAdmin status is not considered here. Fails closed: lookup errors are
    logged and treated as "not a member".

    Args:
        ctx: Request context of the caller.
        targetMandateId: Mandate to check membership for.

    Returns:
        True if an enabled membership exists, else False.
    """
    try:
        from modules.interfaces.interfaceDbApp import getRootInterface
        rootInterface = getRootInterface()
        target = str(targetMandateId)
        return any(
            str(getattr(um, 'mandateId', None)) == target and getattr(um, 'enabled', True)
            for um in rootInterface.getUserMandates(str(ctx.user.id))
        )
    except Exception:
        logger.exception(f"Membership check failed for mandate {targetMandateId}; denying access")
        return False
|
||
|
||
|
||
def _filterTransactionsByScope(transactions: list, scope: BillingDataScope) -> list:
    """
    Return the subset of transaction dicts visible within ``scope``.

    Visibility, checked in this order per transaction:
      1. SysAdmin scope: the input list is returned unchanged.
      2. ``mandateId`` is one of the admin mandates -> visible.
      3. ``featureInstanceId`` is one of the admin feature instances -> visible.
      4. ``mandateId`` is one of the member mandates AND the transaction's
         ``createdByUserId`` (or, if empty, ``userId``) equals ``scope.userId``
         -> visible.
    Everything else is hidden.
    """
    if scope.isGlobalAdmin:
        return transactions

    adminMandates = set(scope.adminMandateIds)
    adminInstances = set(scope.adminFeatureInstanceIds)
    memberMandates = set(scope.memberMandateIds)
    ownUserId = scope.userId

    def _isVisible(tx) -> bool:
        mandate = tx.get("mandateId")
        if mandate and mandate in adminMandates:
            return True
        instance = tx.get("featureInstanceId")
        if instance and instance in adminInstances:
            return True
        if not (mandate and mandate in memberMandates):
            return False
        # Regular members only see transactions they own.
        owner = tx.get("createdByUserId") or tx.get("userId")
        return bool(owner) and owner == ownUserId

    return [tx for tx in transactions if _isVisible(tx)]
|
||
|
||
|
||
# =============================================================================
|
||
# Request/Response Models
|
||
# =============================================================================
|
||
|
||
class CreditAddRequest(BaseModel):
    """Payload for a manual credit (positive amount) or deduction (negative amount)."""

    # Optional user the adjustment relates to; informational, for the audit trail.
    userId: Optional[str] = Field(default=None, description="Target user ID for audit trail only (optional)")
    # Signed CHF amount; the sign selects credit vs. deduction.
    amount: float = Field(default=..., description="Amount in CHF. Positive = credit, negative = deduction. Must not be zero.")
    description: str = Field(default="Manual credit", description="Transaction description")
|
||
|
||
|
||
class CheckoutCreateRequest(BaseModel):
    """Payload to start a Stripe Checkout top-up for a mandate."""

    userId: Optional[str] = Field(default=None, description="Target user ID for audit trail only (optional)")
    # Strictly positive CHF amount (validated here via gt=0).
    amount: float = Field(default=..., gt=0, description="Amount to pay in CHF (must be in allowed presets)")
    # Frontend URL Stripe redirects back to; must be non-empty.
    returnUrl: str = Field(default=..., min_length=1, description="Absolute frontend URL used for Stripe success/cancel redirects")
|
||
|
||
|
||
class CheckoutCreateResponse(BaseModel):
    """Result of creating a Stripe Checkout Session: where to send the browser."""

    redirectUrl: str = Field(default=..., description="Stripe Checkout URL for redirect")
|
||
|
||
|
||
class CheckoutConfirmRequest(BaseModel):
    """Payload sent after the Stripe redirect to confirm (and credit) a Checkout Session."""

    # Non-empty Stripe Checkout Session ID.
    sessionId: str = Field(default=..., min_length=1, description="Stripe Checkout Session ID (cs_xxx)")
|
||
|
||
|
||
class CheckoutConfirmResponse(BaseModel):
    """Outcome of confirming a Stripe Checkout Session.

    Exactly one of ``credited`` / ``alreadyCredited`` is expected to be True.
    """

    credited: bool = Field(default=..., description="True if a new billing credit was created")
    alreadyCredited: bool = Field(default=..., description="True if session was already credited before")
    sessionId: str = Field(default=..., description="Stripe Checkout Session ID")
    mandateId: str = Field(default=..., description="Mandate ID from Stripe metadata")
    amountChf: float = Field(default=..., description="Credited amount in CHF")
|
||
|
||
|
||
class BillingSettingsUpdate(BaseModel):
    """Partial update of a mandate's billing settings; ``None`` means "leave unchanged"."""

    # Percentage in [0, 100].
    warningThresholdPercent: Optional[float] = Field(default=None, ge=0, le=100)
    notifyOnWarning: Optional[bool] = Field(default=None)
    notifyEmails: Optional[List[str]] = Field(default=None)
    autoRechargeEnabled: Optional[bool] = Field(default=None)
    # Strictly positive CHF amount.
    rechargeAmountCHF: Optional[float] = Field(default=None, gt=0)
    # Non-negative count.
    rechargeMaxPerMonth: Optional[int] = Field(default=None, ge=0)
|
||
|
||
|
||
class TransactionResponse(BaseModel):
    """Response model for a billing transaction.

    Every ``Optional`` field defaults to ``None``. Under pydantic v2, an
    ``Optional[...]`` annotation without a default is still a *required*
    field, which made some of these inconsistent with their siblings.
    """
    id: str
    accountId: str
    transactionType: TransactionTypeEnum
    amount: float
    description: str
    referenceType: Optional[ReferenceTypeEnum] = None
    workflowId: Optional[str] = None
    featureCode: Optional[str] = None
    featureInstanceId: Optional[str] = None
    aicoreProvider: Optional[str] = None
    aicoreModel: Optional[str] = None
    createdByUserId: Optional[str] = None
    createdAt: Optional[datetime] = None
    mandateId: Optional[str] = None
    mandateName: Optional[str] = None
|
||
|
||
|
||
class AccountSummary(BaseModel):
    """Summary of a billing account.

    ``userId`` defaults to ``None``. Under pydantic v2, an ``Optional[...]``
    annotation without a default would still make the field required.
    """
    id: str
    mandateId: str
    userId: Optional[str] = None
    balance: float
    warningThreshold: float
    enabled: bool
|
||
|
||
|
||
class UsageReportResponse(BaseModel):
    """Usage report for a period.

    The ``costBy*`` maps group the period's total cost by provider, model and feature.
    """
    period: str
    totalCost: float
    transactionCount: int
    costByProvider: Dict[str, float]
    # default_factory avoids a shared mutable ``{}`` literal as the class-level default.
    costByModel: Dict[str, float] = Field(default_factory=dict)
    costByFeature: Dict[str, float]
|
||
|
||
|
||
# =============================================================================
|
||
# Response Models for Mandate/User Views
|
||
# =============================================================================
|
||
|
||
class MandateBalanceResponse(BaseModel):
    """Aggregated balance of one mandate (all of its users combined)."""

    mandateId: str = Field(...)
    mandateName: str = Field(...)
    totalBalance: float = Field(...)
    userCount: int = Field(...)
    warningThresholdPercent: float = Field(...)
|
||
|
||
|
||
class UserBalanceResponse(BaseModel):
    """Balance of a single user's billing account within a mandate."""

    accountId: str = Field(...)
    mandateId: str = Field(...)
    mandateName: str = Field(...)
    userId: str = Field(...)
    userName: str = Field(...)
    balance: float = Field(...)
    warningThreshold: float = Field(...)
    isWarning: bool = Field(...)
    enabled: bool = Field(...)
|
||
|
||
|
||
class UserTransactionResponse(BaseModel):
    """User-level transaction with user context (``userId`` / ``userName``).

    Every ``Optional`` field defaults to ``None``. Under pydantic v2, an
    ``Optional[...]`` annotation without a default is still a *required* field.
    """
    id: str
    accountId: str
    transactionType: TransactionTypeEnum
    amount: float
    description: str
    referenceType: Optional[ReferenceTypeEnum] = None
    workflowId: Optional[str] = None
    featureCode: Optional[str] = None
    featureInstanceId: Optional[str] = None
    aicoreProvider: Optional[str] = None
    aicoreModel: Optional[str] = None
    createdByUserId: Optional[str] = None
    createdAt: Optional[datetime] = None
    mandateId: Optional[str] = None
    mandateName: Optional[str] = None
    userId: Optional[str] = None
    userName: Optional[str] = None
|
||
|
||
|
||
def _getStripeClient():
    """
    Return the Stripe SDK object produced by ``modules.shared.stripeClient.getStripeClient``.

    The import is deferred to call time, so the shared Stripe module is only
    loaded when a Stripe call is actually made.
    """
    from modules.shared.stripeClient import getStripeClient as _configuredStripe
    stripeSdk = _configuredStripe()
    return stripeSdk
|
||
|
||
|
||
def _recordStripeEventIfNew(billingInterface, eventId: Optional[str]) -> None:
    """Persist a Stripe webhook event ID once (no-op when ``eventId`` is falsy or already stored)."""
    if eventId and not billingInterface.getStripeWebhookEventByEventId(eventId):
        billingInterface.createStripeWebhookEvent(eventId)


def _resolveStripeAmountChf(session: Dict[str, Any], allowedAmounts) -> float:
    """
    Determine the CHF amount to credit for a Checkout session.

    Prefers ``metadata.amountChf`` when it parses as a float contained in
    ``allowedAmounts``. Otherwise falls back to Stripe's ``amount_total``
    divided by 100.

    Raises:
        HTTPException(400): if no positive, finite amount can be determined.
    """
    metadata = session.get("metadata") or {}
    try:
        amountChf = float(metadata.get("amountChf", "0"))
    except (TypeError, ValueError):
        amountChf = None

    if amountChf is not None and amountChf in allowedAmounts:
        return amountChf

    amountTotal = session.get("amount_total")
    if amountTotal is None:
        raise HTTPException(status_code=400, detail="Invalid amount in Stripe session")
    try:
        fallbackChf = float(amountTotal) / 100.0
    except (TypeError, ValueError):
        raise HTTPException(status_code=400, detail="Invalid amount in Stripe session") from None
    # Never create a zero/negative/NaN credit from an unexpected amount_total.
    if not math.isfinite(fallbackChf) or fallbackChf <= 0:
        raise HTTPException(status_code=400, detail="Invalid amount in Stripe session")

    logger.warning(
        f"Stripe session {session.get('id')}: metadata amount {metadata.get('amountChf')!r} "
        f"not in allowed presets; using amount_total ({fallbackChf} CHF)"
    )
    return fallbackChf


def _creditStripeSessionIfNeeded(
    billingInterface,
    session: Dict[str, Any],
    eventId: Optional[str] = None,
) -> CheckoutConfirmResponse:
    """
    Credit a mandate's balance from a Stripe Checkout session, unless already credited.

    Idempotency is keyed on the Checkout session ID (stored as the payment
    transaction's ``referenceId``), so the webhook and the manual confirmation
    flow can both call this safely in sequence. If ``eventId`` is given, the
    webhook event is recorded as processed.

    NOTE(review): the "look up, then create" idempotency check is not atomic;
    concurrent webhook + confirm calls could double-credit unless storage
    enforces uniqueness on ``referenceId``. Confirm.
    NOTE: this function does not inspect ``payment_status``; callers must only
    pass sessions that are actually paid.

    Args:
        billingInterface: Billing DB interface (callers here pass the root interface).
        session: Stripe Checkout session as a plain dict.
        eventId: Stripe webhook event ID, if called from the webhook.

    Returns:
        CheckoutConfirmResponse describing whether a new credit was created.

    Raises:
        HTTPException(400): missing session ID / mandateId or invalid amount.
        HTTPException(404): no billing settings for the mandate.
    """
    from modules.serviceCenter.services.serviceBilling.stripeCheckout import ALLOWED_AMOUNTS_CHF

    sessionId = session.get("id")
    metadata = session.get("metadata") or {}
    mandateId = metadata.get("mandateId")
    userId = metadata.get("userId") or None

    if not sessionId:
        raise HTTPException(status_code=400, detail="Stripe session id missing")
    if not mandateId:
        raise HTTPException(status_code=400, detail="Invalid session metadata: mandateId missing")

    existingPaymentTx = billingInterface.getPaymentTransactionByReferenceId(sessionId)
    if existingPaymentTx:
        _recordStripeEventIfNew(billingInterface, eventId)
        return CheckoutConfirmResponse(
            credited=False,
            alreadyCredited=True,
            sessionId=sessionId,
            mandateId=mandateId,
            amountChf=float(existingPaymentTx.get("amount", 0.0)),
        )

    amountChf = _resolveStripeAmountChf(session, ALLOWED_AMOUNTS_CHF)

    if not billingInterface.getSettings(mandateId):
        raise HTTPException(status_code=404, detail="Billing settings not found")

    account = billingInterface.getOrCreateMandateAccount(mandateId, initialBalance=0.0)
    billingInterface.createTransaction(BillingTransaction(
        accountId=account["id"],
        transactionType=TransactionTypeEnum.CREDIT,
        amount=amountChf,
        description="Stripe-Zahlung",
        referenceType=ReferenceTypeEnum.PAYMENT,
        referenceId=sessionId,
        createdByUserId=userId,
    ))
    _recordStripeEventIfNew(billingInterface, eventId)

    logger.info(f"Stripe credit applied: {amountChf} CHF for session {sessionId} on mandate {mandateId}")
    return CheckoutConfirmResponse(
        credited=True,
        alreadyCredited=False,
        sessionId=sessionId,
        mandateId=mandateId,
        amountChf=amountChf,
    )
|
||
|
||
|
||
# =============================================================================
|
||
# Router Setup
|
||
# =============================================================================
|
||
|
||
# All billing endpoints live under /api/billing and are grouped as "Billing" in OpenAPI.
router = APIRouter(
    prefix="/api/billing",
    tags=["Billing"],
    responses={404: {"description": "Not found"}},
)
|
||
|
||
# =============================================================================
|
||
# User Endpoints
|
||
# =============================================================================
|
||
|
||
@router.get("/balance", response_model=List[BillingBalanceResponse])
@limiter.limit("60/minute")
def getBalance(
    request: Request,
    ctx: RequestContext = Depends(getRequestContext)
):
    """
    Get billing balances for all mandates the current user belongs to.

    Returns the billing service's ``getBalancesForUser()`` result (one entry per mandate).

    Raises:
        HTTPException: passed through unchanged if raised by the service;
            any other error is logged and mapped to 500.
    """
    try:
        billingService = getBillingService(ctx.user, ctx.mandateId, featureCode="billing")
        return billingService.getBalancesForUser()
    except HTTPException:
        # Keep deliberate 4xx/5xx responses from lower layers intact.
        raise
    except Exception as e:
        logger.error(f"Error getting billing balance: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.get("/balance/{targetMandateId}", response_model=BillingBalanceResponse)
@limiter.limit("60/minute")
def getBalanceForMandate(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    ctx: RequestContext = Depends(getRequestContext)
):
    """
    Get billing balance for a specific mandate.

    Access: SysAdmin (any mandate) or a user with an enabled membership in the
    mandate. The check is explicit so that access does not depend on whatever
    the billing service may or may not enforce.

    Raises:
        HTTPException(403): caller is neither SysAdmin nor a member.
        HTTPException(500): unexpected errors.
    """
    if not (ctx.hasSysAdminRole or _isMemberOfMandate(ctx, targetMandateId)):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access to this mandate's billing data denied")

    try:
        billingService = getBillingService(ctx.user, targetMandateId, featureCode="billing")

        # A zero-cost balance check yields the current balance.
        checkResult = billingService.checkBalance(0.0)

        # Get mandate name from app interface
        from modules.interfaces.interfaceDbApp import getInterface as getAppInterface
        appInterface = getAppInterface(ctx.user, mandateId=targetMandateId)
        mandate = appInterface.getMandate(targetMandateId)
        mandateName = (mandate.get("label") or mandate.get("name", "")) if mandate else ""

        return BillingBalanceResponse(
            mandateId=targetMandateId,
            mandateName=mandateName,
            balance=checkResult.currentBalance or 0.0,
            warningThreshold=0.0,  # TODO: Get from account
            isWarning=False,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting billing balance for mandate {targetMandateId}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
def _transactionDictToResponse(t: Dict[str, Any]) -> TransactionResponse:
    """
    Convert a transaction dict from the billing service into a TransactionResponse.

    Keys that are missing or ``None`` fall back to defaults: ``"DEBIT"`` for the
    type, ``0.0`` for the amount and ``""`` for the description. This avoids
    failing the whole page on one incomplete record.
    """
    rawReferenceType = t.get("referenceType")
    return TransactionResponse(
        id=t.get("id"),
        accountId=t.get("accountId"),
        transactionType=TransactionTypeEnum(t.get("transactionType") or "DEBIT"),
        amount=t.get("amount") or 0.0,
        description=t.get("description") or "",
        referenceType=ReferenceTypeEnum(rawReferenceType) if rawReferenceType else None,
        workflowId=t.get("workflowId"),
        featureCode=t.get("featureCode"),
        featureInstanceId=t.get("featureInstanceId"),
        aicoreProvider=t.get("aicoreProvider"),
        aicoreModel=t.get("aicoreModel"),
        createdByUserId=t.get("createdByUserId"),
        createdAt=t.get("sysCreatedAt"),
        mandateId=t.get("mandateId"),
        mandateName=t.get("mandateName"),
    )


@router.get("/transactions", response_model=List[TransactionResponse])
@limiter.limit("30/minute")
def getTransactions(
    request: Request,
    limit: int = Query(default=50, ge=1, le=500),
    offset: int = Query(default=0, ge=0),
    ctx: RequestContext = Depends(getRequestContext)
):
    """
    Get transaction history across all mandates the user belongs to.

    Pagination is done here: ``offset + limit`` records are fetched from the
    service, then the ``[offset, offset + limit)`` window is returned.
    """
    try:
        billingService = getBillingService(ctx.user, ctx.mandateId, featureCode="billing")
        transactions = billingService.getTransactionHistory(limit=offset + limit)
        return [_transactionDictToResponse(t) for t in transactions[offset:offset + limit]]
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting billing transactions: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
def _statisticsDateRange(period: str, year: int, month: Optional[int]) -> tuple:
    """
    Return the half-open ``[start, end)`` date range for a statistics period.

    'day' covers the given month; 'month' and 'year' both cover the whole year.

    Raises:
        ValueError: if year/month fall outside what ``datetime.date`` accepts.
    """
    if period == "day":
        start = date(year, month, 1)
        end = date(year + 1, 1, 1) if month == 12 else date(year, month + 1, 1)
    else:
        start = date(year, 1, 1)
        end = date(year + 1, 1, 1)
    return start, end


def _emptyUsageReport(period: str) -> UsageReportResponse:
    """Usage report with no cost and no transactions."""
    return UsageReportResponse(
        period=period,
        totalCost=0.0,
        transactionCount=0,
        costByProvider={},
        costByFeature={},
    )


@router.get("/statistics/{period}", response_model=UsageReportResponse)
@limiter.limit("30/minute")
def getStatistics(
    request: Request,
    period: str = Path(..., description="Period: 'day', 'month', or 'year'"),
    year: int = Query(..., description="Year"),
    month: Optional[int] = Query(None, description="Month (1-12, required for 'day' period)"),
    ctx: RequestContext = Depends(getRequestContext)
):
    """
    Get usage statistics for the current user's account in the current mandate.

    Returns an empty report when the mandate has no billing settings or the
    user has no account.

    Raises:
        HTTPException(400): invalid period, missing month for 'day', or a
            year/month outside the supported date range.
        HTTPException(500): unexpected errors.
    """
    try:
        if period not in ["day", "month", "year"]:
            raise HTTPException(status_code=400, detail="Invalid period. Use 'day', 'month', or 'year'")
        if period == "day" and not month:
            raise HTTPException(status_code=400, detail="Month is required for 'day' period")

        # Validate the date range up front: date() rejects e.g. month=13 or year=0,
        # which previously surfaced as a 500.
        try:
            startDate, endDate = _statisticsDateRange(period, year, month)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid year or month") from None

        billingInterface = getBillingInterface(ctx.user, ctx.mandateId)
        if not billingInterface.getSettings(ctx.mandateId):
            return _emptyUsageReport(period)

        # Transactions are always on user accounts (audit trail)
        account = billingInterface.getUserAccount(ctx.mandateId, ctx.user.id)
        if not account:
            return _emptyUsageReport(period)

        stats = billingInterface.calculateStatisticsFromTransactions(account["id"], startDate, endDate)
        return UsageReportResponse(
            period=period,
            totalCost=stats.get("totalCostCHF", 0.0),
            transactionCount=stats.get("transactionCount", 0),
            costByProvider=stats.get("costByProvider", {}),
            costByModel=stats.get("costByModel", {}),
            costByFeature=stats.get("costByFeature", {})
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting billing statistics: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.get("/providers", response_model=List[str])
@limiter.limit("60/minute")
def getAllowedProviders(
    request: Request,
    ctx: RequestContext = Depends(getRequestContext)
):
    """
    Get list of AICore providers the current user is allowed to use.

    Raises:
        HTTPException: passed through from the service; other errors map to 500.
    """
    try:
        billingService = getBillingService(ctx.user, ctx.mandateId, featureCode="billing")
        # NOTE(review): the original called ``getallowedProviders`` (lower-case "a"),
        # which breaks the service's camelCase naming (getBalancesForUser,
        # getTransactionHistory, ...). Prefer the correctly cased method when it
        # exists and keep the original name as fallback. Confirm, then drop the fallback.
        listProviders = (
            getattr(billingService, "getAllowedProviders", None)
            or billingService.getallowedProviders
        )
        return listProviders()
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting allowed providers: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
# =============================================================================
|
||
# Admin Endpoints
|
||
# =============================================================================
|
||
|
||
@router.get("/admin/settings/{targetMandateId}", response_model=Dict[str, Any])
@limiter.limit("30/minute")
def getSettingsAdmin(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    ctx: RequestContext = Depends(getRequestContext),
):
    """
    Return the billing settings of ``targetMandateId``.

    Access: SysAdmin (any mandate) or MandateAdmin (own mandate).

    Raises:
        HTTPException(403): caller is not an admin of the mandate.
        HTTPException(404): the mandate has no billing settings.
        HTTPException(500): unexpected errors.
    """
    if not _isAdminOfMandate(ctx, targetMandateId):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin role required for this mandate")

    try:
        mandateSettings = getBillingInterface(ctx.user, targetMandateId).getSettings(targetMandateId)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting billing settings: {e}")
        raise HTTPException(status_code=500, detail=str(e))

    if not mandateSettings:
        raise HTTPException(status_code=404, detail="Billing settings not found")
    return mandateSettings
|
||
|
||
|
||
@router.post("/admin/settings/{targetMandateId}", response_model=Dict[str, Any])
@limiter.limit("10/minute")
def createOrUpdateSettings(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    settingsUpdate: BillingSettingsUpdate = Body(...),
    ctx: RequestContext = Depends(getRequestContext),
):
    """
    Create or update billing settings for a mandate.

    - Settings exist: apply only the fields provided (non-``None``). If nothing
      was provided, or the update yields nothing, the existing settings are returned.
    - No settings yet: create them from defaults overlaid with the provided fields.

    Access: SysAdmin (any mandate) or MandateAdmin (own mandate).
    """
    if not _isAdminOfMandate(ctx, targetMandateId):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin role required for this mandate")
    try:
        billingInterface = getBillingInterface(ctx.user, targetMandateId)
        currentSettings = billingInterface.getSettings(targetMandateId)
        providedValues = settingsUpdate.model_dump(exclude_none=True)

        if not currentSettings:
            # Defaults for a fresh mandate; explicitly provided values win.
            initialValues = {
                "warningThresholdPercent": 10.0,
                "notifyOnWarning": True,
                "notifyEmails": [],
                "autoRechargeEnabled": False,
                "rechargeAmountCHF": 10.0,
                "rechargeMaxPerMonth": 3,
            }
            initialValues.update(providedValues)
            return billingInterface.createSettings(
                BillingSettings(mandateId=targetMandateId, **initialValues)
            )

        if not providedValues:
            return currentSettings
        return billingInterface.updateSettings(currentSettings["id"], providedValues) or currentSettings
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error updating billing settings: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.post("/admin/credit/{targetMandateId}", response_model=Dict[str, Any])
@limiter.limit("10/minute")
def addCredit(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    creditRequest: CreditAddRequest = Body(...),
    ctx: RequestContext = Depends(getRequestContext),
    _admin = Depends(requireSysAdminRole)
):
    """
    Add credit to (positive amount) or deduct from (negative amount) a mandate account.

    SysAdmin only. The transaction records ``creditRequest.userId`` as
    ``createdByUserId``, falling back to the acting admin for the audit trail.

    Raises:
        HTTPException(400): amount is zero or not a finite number.
        HTTPException(404): mandate has no billing settings.
        HTTPException(500): unexpected errors.
    """
    amount = creditRequest.amount
    # Validate before touching storage, so a bad request cannot create a
    # mandate account as a side effect.
    if amount == 0:
        raise HTTPException(status_code=400, detail="Amount must not be zero")
    if not math.isfinite(amount):
        raise HTTPException(status_code=400, detail="Amount must be a finite number")

    try:
        billingInterface = getBillingInterface(ctx.user, targetMandateId)
        if not billingInterface.getSettings(targetMandateId):
            raise HTTPException(status_code=404, detail="Billing settings not found for this mandate")

        account = billingInterface.getOrCreateMandateAccount(targetMandateId, initialBalance=0.0)

        isDeduction = amount < 0
        absAmount = abs(amount)
        transaction = BillingTransaction(
            accountId=account["id"],
            transactionType=TransactionTypeEnum.DEBIT if isDeduction else TransactionTypeEnum.CREDIT,
            amount=absAmount,
            description=creditRequest.description,
            referenceType=ReferenceTypeEnum.ADMIN,
            createdByUserId=creditRequest.userId or str(ctx.user.id),
        )
        result = billingInterface.createTransaction(transaction)

        if isDeduction:
            logger.info(f"Deducted {absAmount} CHF from account {account['id']} in mandate {targetMandateId}")
        else:
            logger.info(f"Added {absAmount} CHF to account {account['id']} in mandate {targetMandateId}")
        return result
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error adding credit: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.post("/checkout/create/{targetMandateId}", response_model=CheckoutCreateResponse)
@limiter.limit("10/minute")
def createCheckoutSession(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    checkoutRequest: CheckoutCreateRequest = Body(...),
    ctx: RequestContext = Depends(getRequestContext),
):
    """
    Create Stripe Checkout Session for credit top-up. Returns redirect URL.

    Requires mandate admin role (or SysAdmin). Authorization is checked before
    any lookup, so non-admins cannot probe whether billing settings exist.
    The audit ``userId`` defaults to the calling user when not provided.

    Raises:
        HTTPException(403): caller is not an admin of the mandate.
        HTTPException(404): mandate has no billing settings.
        HTTPException(400): ``create_checkout_session`` raised ValueError.
        HTTPException(500): unexpected errors.
    """
    if not _isAdminOfMandate(ctx, targetMandateId):
        raise HTTPException(status_code=403, detail="Mandate admin role required to load mandate credit")

    try:
        billingInterface = getBillingInterface(ctx.user, targetMandateId)
        if not billingInterface.getSettings(targetMandateId):
            raise HTTPException(status_code=404, detail="Billing settings not found for this mandate")

        from modules.serviceCenter.services.serviceBilling.stripeCheckout import create_checkout_session
        redirectUrl = create_checkout_session(
            mandate_id=targetMandateId,
            user_id=checkoutRequest.userId or str(ctx.user.id),
            amount_chf=checkoutRequest.amount,
            return_url=checkoutRequest.returnUrl
        )
        return CheckoutCreateResponse(redirectUrl=redirectUrl)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error creating checkout session: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.post("/checkout/confirm", response_model=CheckoutConfirmResponse)
@limiter.limit("20/minute")
def confirmCheckoutSession(
    request: Request,
    confirmRequest: CheckoutConfirmRequest = Body(...),
    ctx: RequestContext = Depends(getRequestContext),
):
    """
    Confirm Stripe Checkout success by session ID and apply credit idempotently.

    This is a fallback/reconciliation path in addition to webhook processing.
    The caller must be an admin of the mandate named in the session metadata.
    This is checked before the payment status or billing settings are revealed.

    Raises:
        HTTPException(404): session or billing settings not found.
        HTTPException(400): session metadata lacks mandateId, or ValueError.
        HTTPException(403): caller is not an admin of the session's mandate.
        HTTPException(409): the session is not paid yet.
        HTTPException(500): unexpected errors.
    """
    try:
        stripe = _getStripeClient()
        session = stripe.checkout.Session.retrieve(confirmRequest.sessionId)
        if not session:
            raise HTTPException(status_code=404, detail="Stripe Checkout Session not found")

        from modules.shared.stripeClient import stripeToDict
        sessionDict = stripeToDict(session)
        metadata = sessionDict.get("metadata") or {}
        mandateId = metadata.get("mandateId")
        if not mandateId:
            raise HTTPException(status_code=400, detail="Invalid session metadata: mandateId missing")

        if not _isAdminOfMandate(ctx, mandateId):
            raise HTTPException(status_code=403, detail="Mandate admin role required")

        paymentStatus = sessionDict.get("payment_status")
        if paymentStatus != "paid":
            raise HTTPException(status_code=409, detail=f"Payment not completed yet (payment_status={paymentStatus})")

        billingInterface = getBillingInterface(ctx.user, mandateId)
        if not billingInterface.getSettings(mandateId):
            raise HTTPException(status_code=404, detail="Billing settings not found")

        # Crediting uses the privileged root interface (same path as the webhook).
        return _creditStripeSessionIfNeeded(_getRootInterface(), sessionDict, eventId=None)
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Error confirming checkout session {confirmRequest.sessionId}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.post("/webhook/stripe")
async def stripeWebhook(
    request: Request,
    stripe_signature: Optional[str] = Header(None, alias="Stripe-Signature")
):
    """
    Stripe webhook endpoint. Verifies the signature and dispatches the event.

    No JWT auth - Stripe authenticates via the Stripe-Signature header.

    Dispatch:
    - subscription lifecycle / invoice events -> ``_handleSubscriptionWebhook``
    - ``checkout.session.completed`` / ``checkout.session.async_payment_succeeded``:
      subscription-mode sessions -> ``_handleSubscriptionCheckoutCompleted``;
      other sessions are credited via ``_creditStripeSessionIfNeeded`` unless this
      event ID was already recorded.
    - every other event type is acknowledged and ignored.

    Raises:
        HTTPException 500: ``STRIPE_WEBHOOK_SECRET`` is not configured.
        HTTPException 400: missing signature header, invalid payload or invalid signature.
    """
    from modules.shared.configuration import APP_CONFIG

    webhook_secret = APP_CONFIG.get("STRIPE_WEBHOOK_SECRET")
    if not webhook_secret:
        logger.error("STRIPE_WEBHOOK_SECRET not configured")
        raise HTTPException(status_code=500, detail="Webhook not configured")

    if not stripe_signature:
        raise HTTPException(status_code=400, detail="Missing Stripe-Signature header")

    payload = await request.body()

    # Imported outside the try below: a missing/broken SDK is a server error,
    # not an invalid signature from the caller.
    import stripe

    try:
        event = stripe.Webhook.construct_event(payload, stripe_signature, webhook_secret)
    except ValueError as e:
        logger.warning("Stripe webhook invalid payload: %s", e)
        raise HTTPException(status_code=400, detail="Invalid payload") from e
    except Exception as e:
        logger.warning("Stripe webhook signature verification failed: %s", e)
        raise HTTPException(status_code=400, detail="Invalid signature") from e

    logger.info("Stripe webhook received: event=%s, type=%s", event.id, event.type)

    # Subscription-related events
    subscriptionEventTypes = {
        "customer.subscription.updated",
        "customer.subscription.deleted",
        "invoice.paid",
        "invoice.payment_failed",
        "customer.subscription.trial_will_end",
    }
    # Checkout events (one-off credit purchases and subscription checkouts)
    checkoutEventTypes = {"checkout.session.completed", "checkout.session.async_payment_succeeded"}

    if event.type in subscriptionEventTypes:
        _handleSubscriptionWebhook(event)
        return {"received": True}

    if event.type not in checkoutEventTypes:
        return {"received": True}

    from modules.shared.stripeClient import stripeToDict

    event_id = event.id
    # Normalize once with the same helper the confirm endpoint uses, so mode
    # detection and crediting both see a plain dict independent of SDK version.
    session_dict = stripeToDict(event.data.object)

    if session_dict.get("mode") == "subscription":
        _handleSubscriptionCheckoutCompleted(session_dict, event_id)
        return {"received": True}

    billingInterface = _getRootInterface()
    if billingInterface.getStripeWebhookEventByEventId(event_id):
        logger.info("Stripe event %s already processed, skipping", event_id)
        return {"received": True}

    result = _creditStripeSessionIfNeeded(billingInterface, session_dict, eventId=event_id)
    logger.info(
        "Stripe webhook processed session %s: credited=%s, alreadyCredited=%s",
        result.sessionId, result.credited, result.alreadyCredited,
    )
    return {"received": True}
|
||
|
||
|
||
def _resolveSubscriptionCheckoutMetadata(session: Dict[str, Any], stripeSubId: Optional[str]) -> Dict[str, Any]:
    """Resolve the local-record metadata for a subscription-mode checkout session.

    Session metadata is used when it carries ``subscriptionRecordId``. Otherwise, if
    the session references a Stripe subscription, that subscription's metadata is
    loaded and used instead (a ``platformUrl`` already present on the session wins).
    A failed Stripe lookup is logged and the session metadata is kept.

    Returns:
        Dict with ``subscriptionRecordId``, ``mandateId``, ``planKey``, ``platformUrl``
        and the raw ``metadata`` dict the values were read from.
    """
    metadata = session.get("metadata") or {}
    resolved: Dict[str, Any] = {
        "subscriptionRecordId": metadata.get("subscriptionRecordId"),
        "mandateId": metadata.get("mandateId"),
        "planKey": metadata.get("planKey", ""),
        "platformUrl": metadata.get("platformUrl", ""),
        "metadata": metadata,
    }
    if resolved["subscriptionRecordId"] or not stripeSubId:
        return resolved

    try:
        from modules.shared.stripeClient import getStripeClient, stripeToDict

        stripe = getStripeClient()
        subMetadata = stripeToDict(stripe.Subscription.retrieve(stripeSubId)).get("metadata") or {}
    except Exception as e:
        # Best effort: keep the (incomplete) session metadata, but leave a trace.
        logger.warning("Could not load metadata of Stripe subscription %s: %s", stripeSubId, e)
        return resolved

    return {
        "subscriptionRecordId": subMetadata.get("subscriptionRecordId"),
        "mandateId": subMetadata.get("mandateId"),
        "planKey": subMetadata.get("planKey", ""),
        "platformUrl": resolved["platformUrl"] or subMetadata.get("platformUrl", ""),
        "metadata": subMetadata,
    }


def _collectStripeSubscriptionFields(stripeSubId: str, planKey: str) -> Dict[str, Any]:
    """Build the Stripe-derived fields to store on the local subscription record.

    Always contains ``stripeSubscriptionId``. When the Stripe subscription can be
    retrieved, adds ``currentPeriodStart``/``currentPeriodEnd`` (ISO-8601, UTC) and
    the item IDs whose price matches the plan's users/instances price. Errors are
    logged; fields gathered before the error are still returned.
    """
    from datetime import datetime, timezone

    fields: Dict[str, Any] = {"stripeSubscriptionId": stripeSubId}
    try:
        from modules.shared.stripeClient import getStripeClient, stripeToDict

        stripe = getStripeClient()
        stripeSub = stripeToDict(stripe.Subscription.retrieve(stripeSubId, expand=["items"]))

        for sourceKey, targetKey in (
            ("current_period_start", "currentPeriodStart"),
            ("current_period_end", "currentPeriodEnd"),
        ):
            if stripeSub.get(sourceKey):
                fields[targetKey] = datetime.fromtimestamp(stripeSub[sourceKey], tz=timezone.utc).isoformat()

        from modules.serviceCenter.services.serviceSubscription.stripeBootstrap import getStripePricesForPlan

        priceMapping = getStripePricesForPlan(planKey)
        items = stripeSub.get("items") or {}
        if not isinstance(items, dict):
            items = dict(items)
        for item in items.get("data", []):
            priceId = (item.get("price") or {}).get("id", "")
            if priceMapping and priceId == priceMapping.stripePriceIdUsers:
                fields["stripeItemIdUsers"] = item["id"]
            elif priceMapping and priceId == priceMapping.stripePriceIdInstances:
                fields["stripeItemIdInstances"] = item["id"]
    except Exception as e:
        logger.error("Error retrieving Stripe subscription %s: %s", stripeSubId, e)
    return fields


def _handleSubscriptionCheckoutCompleted(session, eventId: str) -> None:
    """Handle checkout.session.completed for mode=subscription.

    Resolves the local PENDING record by ID from webhook metadata and transitions it:
    to ACTIVE when no other operative subscription exists for the mandate, otherwise
    to SCHEDULED behind that predecessor (whose Stripe subscription is set to cancel
    at period end when it is recurring). On activation, admins are notified and the
    plan budget is credited. ``eventId`` is accepted for the caller's signature but is
    currently unused here.
    """
    from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
    from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum, _getPlan
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
        _notifySubscriptionChange,
    )
    from modules.security.rootAccess import getRootUser

    if not isinstance(session, dict):
        from modules.shared.stripeClient import stripeToDict

        session = stripeToDict(session)

    # ``subscription`` may be an ID string or an expanded subscription object.
    rawSub = session.get("subscription")
    stripeSubId = rawSub.get("id") if isinstance(rawSub, dict) else rawSub

    meta = _resolveSubscriptionCheckoutMetadata(session, stripeSubId)
    subscriptionRecordId = meta["subscriptionRecordId"]
    mandateId = meta["mandateId"]
    planKey = meta["planKey"]
    platformUrl = meta["platformUrl"]

    if not mandateId or not subscriptionRecordId:
        logger.warning("Subscription checkout missing metadata: %s", meta["metadata"])
        return

    subInterface = getSubRootInterface()
    rootUser = getRootUser()

    sub = subInterface.getById(subscriptionRecordId)
    if not sub:
        logger.error("Subscription record %s not found for checkout webhook", subscriptionRecordId)
        return
    if sub.get("status") != SubscriptionStatusEnum.PENDING.value:
        # Duplicate delivery or out-of-order event: only PENDING records are transitioned.
        logger.warning("Subscription %s is %s, expected PENDING — skipping", subscriptionRecordId, sub.get("status"))
        return

    if stripeSubId:
        subInterface.updateFields(subscriptionRecordId, _collectStripeSubscriptionFields(stripeSubId, planKey))

    operative = subInterface.getOperativeForMandate(mandateId)
    hasActivePredecessor = operative is not None and operative["id"] != subscriptionRecordId

    if hasActivePredecessor:
        toStatus = SubscriptionStatusEnum.SCHEDULED
        if operative.get("recurring", True):
            operativeStripeId = operative.get("stripeSubscriptionId")
            if operativeStripeId:
                try:
                    from modules.shared.stripeClient import getStripeClient

                    stripe = getStripeClient()
                    stripe.Subscription.modify(operativeStripeId, cancel_at_period_end=True)
                except Exception as e:
                    logger.error("Failed to set cancel_at_period_end on predecessor %s: %s", operativeStripeId, e)
            subInterface.updateFields(operative["id"], {"recurring": False})
        # The new subscription takes over when the predecessor's period ends.
        effectiveFrom = operative.get("currentPeriodEnd")
        if effectiveFrom:
            subInterface.updateFields(subscriptionRecordId, {"effectiveFrom": effectiveFrom})
    else:
        toStatus = SubscriptionStatusEnum.ACTIVE

    try:
        subInterface.transitionStatus(
            subscriptionRecordId, SubscriptionStatusEnum.PENDING, toStatus,
            {"recurring": True},
        )
    except Exception as e:
        logger.error("Failed to transition subscription %s: %s", subscriptionRecordId, e)
        return

    subService = getSubscriptionService(rootUser, mandateId)
    subService.invalidateCache(mandateId)

    if toStatus == SubscriptionStatusEnum.ACTIVE:
        plan = _getPlan(planKey)
        updatedSub = subInterface.getById(subscriptionRecordId)
        _notifySubscriptionChange(mandateId, "activated", plan, subscriptionRecord=updatedSub, platformUrl=platformUrl)
        try:
            billingIf = _getRootInterface()
            billingIf.creditSubscriptionBudget(mandateId, planKey, periodLabel="Erstaktivierung")
        except Exception as ex:
            logger.error("creditSubscriptionBudget on activation failed: %s", ex)

    logger.info(
        "Checkout completed: sub=%s -> %s, mandate=%s, plan=%s",
        subscriptionRecordId, toStatus.value, mandateId, planKey,
    )
|
||
|
||
|
||
def _handleSubscriptionWebhook(event) -> None:
    """Process Stripe subscription webhook events.

    All record resolution is by stripeSubscriptionId — no mandate-based guessing.

    Handled ``event.type`` values:
    - ``customer.subscription.updated``: stores the current period and maps Stripe
      status changes onto SCHEDULED/PAST_DUE -> ACTIVE and ACTIVE -> PAST_DUE.
    - ``customer.subscription.deleted``: expires the record and promotes a SCHEDULED
      successor of the same mandate, if one exists.
    - ``invoice.payment_failed``: ACTIVE -> PAST_DUE plus an admin notification.
    - ``customer.subscription.trial_will_end``: notifies the mandate admins.
    - ``invoice.paid``: resets/reconciles storage billing and credits the plan budget.

    Events without a resolvable subscription ID or local record are logged and ignored.
    ``transitionStatus`` errors in the updated/deleted/payment_failed branches are not
    caught here and propagate to the caller.
    """
    from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
    from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum, _getPlan
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
        _notifySubscriptionChange,
    )
    from modules.security.rootAccess import getRootUser
    from datetime import datetime, timezone

    obj = event.data.object
    if event.type.startswith("customer.subscription"):
        rawSub = obj.get("id")
    else:
        # Invoice events: older API versions expose ``invoice.subscription``; newer
        # ones (2025-03-31.basil+) moved it to ``invoice.parent.subscription_details``.
        rawSub = obj.get("subscription")
        if not rawSub:
            parent = obj.get("parent") or {}
            subscriptionDetails = parent.get("subscription_details") or {}
            rawSub = subscriptionDetails.get("subscription")
    stripeSubId = rawSub.get("id") if isinstance(rawSub, dict) else rawSub
    if not stripeSubId:
        logger.warning("Subscription webhook %s has no subscription ID", event.type)
        return

    subInterface = getSubRootInterface()
    sub = subInterface.getByStripeSubscriptionId(stripeSubId)
    if not sub:
        logger.warning("No local record for Stripe subscription %s (event: %s)", stripeSubId, event.type)
        return

    subId = sub["id"]
    mandateId = sub["mandateId"]
    currentStatus = SubscriptionStatusEnum(sub["status"])
    rootUser = getRootUser()
    subService = getSubscriptionService(rootUser, mandateId)

    subMetadata = obj.get("metadata") or {}
    webhookPlatformUrl = subMetadata.get("platformUrl", "")

    if event.type == "customer.subscription.updated":
        stripeStatus = obj.get("status", "")

        periodData: Dict[str, Any] = {}
        if obj.get("current_period_start"):
            periodData["currentPeriodStart"] = datetime.fromtimestamp(
                obj["current_period_start"], tz=timezone.utc
            ).isoformat()
        if obj.get("current_period_end"):
            periodData["currentPeriodEnd"] = datetime.fromtimestamp(
                obj["current_period_end"], tz=timezone.utc
            ).isoformat()
        if periodData:
            subInterface.updateFields(subId, periodData)

        if stripeStatus == "active" and currentStatus == SubscriptionStatusEnum.SCHEDULED:
            subInterface.transitionStatus(subId, SubscriptionStatusEnum.SCHEDULED, SubscriptionStatusEnum.ACTIVE)
            subService.invalidateCache(mandateId)
            planKey = sub.get("planKey", "")
            plan = _getPlan(planKey)
            refreshedSub = subInterface.getById(subId)
            _notifySubscriptionChange(mandateId, "activated", plan, subscriptionRecord=refreshedSub, platformUrl=webhookPlatformUrl)
            try:
                _getRootInterface().creditSubscriptionBudget(mandateId, planKey, periodLabel="Erstaktivierung")
            except Exception as ex:
                logger.error("creditSubscriptionBudget SCHEDULED->ACTIVE failed: %s", ex)
            logger.info("SCHEDULED -> ACTIVE for sub %s (mandate %s)", subId, mandateId)

        elif stripeStatus == "active" and currentStatus == SubscriptionStatusEnum.PAST_DUE:
            subInterface.transitionStatus(subId, SubscriptionStatusEnum.PAST_DUE, SubscriptionStatusEnum.ACTIVE)
            subService.invalidateCache(mandateId)
            logger.info("PAST_DUE -> ACTIVE for sub %s (mandate %s)", subId, mandateId)

        elif stripeStatus == "past_due" and currentStatus == SubscriptionStatusEnum.ACTIVE:
            subInterface.transitionStatus(subId, SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.PAST_DUE)
            subService.invalidateCache(mandateId)
            logger.info("ACTIVE -> PAST_DUE for sub %s (mandate %s)", subId, mandateId)

        elif stripeStatus == "active" and currentStatus == SubscriptionStatusEnum.ACTIVE:
            subService.invalidateCache(mandateId)
            logger.info("Period renewed for sub %s (mandate %s)", subId, mandateId)

    elif event.type == "customer.subscription.deleted":
        if currentStatus not in (SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.PAST_DUE,
                                 SubscriptionStatusEnum.SCHEDULED):
            logger.info("Ignoring deletion for sub %s in status %s", subId, currentStatus.value)
            return

        subInterface.transitionStatus(subId, currentStatus, SubscriptionStatusEnum.EXPIRED)
        subService.invalidateCache(mandateId)
        logger.info("Sub %s -> EXPIRED (Stripe deleted, mandate %s)", subId, mandateId)

        scheduled = subInterface.getScheduledForMandate(mandateId)
        if scheduled:
            try:
                subInterface.transitionStatus(
                    scheduled["id"], SubscriptionStatusEnum.SCHEDULED, SubscriptionStatusEnum.ACTIVE,
                )
                subService.invalidateCache(mandateId)
                plan = _getPlan(scheduled.get("planKey", ""))
                refreshedScheduled = subInterface.getById(scheduled["id"])
                _notifySubscriptionChange(mandateId, "activated", plan, subscriptionRecord=refreshedScheduled, platformUrl=webhookPlatformUrl)
                logger.info("Promoted SCHEDULED sub %s -> ACTIVE (mandate %s)", scheduled["id"], mandateId)
            except Exception as e:
                logger.error("Failed to promote SCHEDULED sub %s: %s", scheduled["id"], e)

    elif event.type == "invoice.payment_failed":
        if currentStatus == SubscriptionStatusEnum.ACTIVE:
            subInterface.transitionStatus(subId, SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.PAST_DUE)
            subService.invalidateCache(mandateId)
            plan = _getPlan(sub.get("planKey", ""))
            # Notify with the post-transition record, as the other branches do;
            # ``sub`` still carries the pre-transition status.
            refreshedSub = subInterface.getById(subId)
            _notifySubscriptionChange(mandateId, "payment_failed", plan, subscriptionRecord=refreshedSub, platformUrl=webhookPlatformUrl)
        logger.info("Payment failed for sub %s (mandate %s)", subId, mandateId)

    elif event.type == "customer.subscription.trial_will_end":
        logger.info("Trial ending soon for sub %s (mandate %s)", subId, mandateId)
        try:
            from modules.shared.notifyMandateAdmins import notifyMandateAdmins

            notifyMandateAdmins(
                mandateId,
                "[PowerOn] Testphase endet bald",
                "Testphase endet bald",
                [
                    "Die kostenlose Testphase für Ihren Mandanten endet in Kürze.",
                    "Bitte wählen Sie einen Plan unter Billing-Verwaltung › Abonnement.",
                ],
            )
        except Exception as e:
            logger.error("Failed to notify about trial ending: %s", e)

    elif event.type == "invoice.paid":
        period_ts = obj.get("period_start")
        periodLabel = ""
        if period_ts:
            period_start_at = datetime.fromtimestamp(int(period_ts), tz=timezone.utc)
            periodLabel = period_start_at.strftime("%Y-%m-%d")
            try:
                billing_if = _getRootInterface()
                billing_if.resetStorageBillingPeriod(mandateId, period_start_at)
                billing_if.reconcileMandateStorageBilling(mandateId)
            except Exception as ex:
                logger.error("Storage billing on invoice.paid failed: %s", ex)

        # NOTE(review): invoice.paid also fires for a subscription's first invoice,
        # while the checkout / SCHEDULED->ACTIVE paths credit "Erstaktivierung"
        # separately — confirm creditSubscriptionBudget is idempotent per period,
        # otherwise the first period may be credited twice.
        planKey = sub.get("planKey", "")
        try:
            billing_if = _getRootInterface()
            billing_if.creditSubscriptionBudget(mandateId, planKey, periodLabel=periodLabel or "Periodenverlängerung")
        except Exception as ex:
            logger.error("creditSubscriptionBudget on invoice.paid failed: %s", ex)

        logger.info("Invoice paid for sub %s (mandate %s)", subId, mandateId)

    return None
|
||
|
||
|
||
@router.get("/admin/accounts/{targetMandateId}", response_model=List[AccountSummary])
@limiter.limit("30/minute")
def getAccounts(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    ctx: RequestContext = Depends(getRequestContext),
):
    """
    List every billing account of one mandate as ``AccountSummary`` rows.

    Access: SysAdmin (any mandate) or MandateAdmin (own mandate); everyone else
    receives 403. Missing account fields fall back to balance 0.0, warning
    threshold 0.0 and enabled=True. Any failure is reported as HTTP 500.
    """
    if not _isAdminOfMandate(ctx, targetMandateId):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin role required for this mandate")
    try:
        mandateAccounts = getBillingInterface(ctx.user, targetMandateId).getAccountsByMandate(targetMandateId)
        return [
            AccountSummary(
                id=account.get("id"),
                mandateId=account.get("mandateId"),
                userId=account.get("userId"),
                balance=account.get("balance", 0.0),
                warningThreshold=account.get("warningThreshold", 0.0),
                enabled=account.get("enabled", True),
            )
            for account in mandateAccounts
        ]
    except Exception as e:
        logger.error(f"Error getting billing accounts: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
class MandateUserSummary(BaseModel):
    """Lightweight user record returned to the billing admin UI.

    Only ``id`` is required; every profile field may be absent (``None``).
    """

    # Identifier of the user.
    id: str
    # Optional profile details, as far as they are known.
    username: Optional[str] = Field(default=None)
    email: Optional[str] = Field(default=None)
    firstName: Optional[str] = Field(default=None)
    lastName: Optional[str] = Field(default=None)
    # Human-readable label for selection lists.
    displayName: Optional[str] = Field(default=None)
|
||
|
||
|
||
@router.get("/admin/users/{targetMandateId}", response_model=List[MandateUserSummary])
@limiter.limit("30/minute")
def getUsersForMandate(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    ctx: RequestContext = Depends(getRequestContext),
):
    """
    Get all users belonging to a mandate.

    Access: SysAdmin (any mandate) or MandateAdmin (own mandate).
    Used by billing admin to select users for credit assignment.

    Memberships without a ``userId`` and users that cannot be loaded are skipped.
    ``displayName`` is "first last", falling back to the username, then the user ID.
    Profile fields are normalized to "" whether the user arrives as dict or model.
    """
    if not _isAdminOfMandate(ctx, targetMandateId):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin role required for this mandate")
    try:
        from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

        def _readText(record, fieldName: str) -> str:
            # Users may be dicts or models; missing and None values both become ""
            # so the displayName below never renders as "None None".
            if isinstance(record, dict):
                value = record.get(fieldName)
            else:
                value = getattr(record, fieldName, None)
            return value or ""

        appInterface = getAppInterface(ctx.user, mandateId=targetMandateId)
        userMandates = appInterface.getUserMandatesByMandate(targetMandateId)

        result = []
        for um in userMandates:
            userId = um.get("userId") if isinstance(um, dict) else getattr(um, "userId", None)
            if not userId:
                continue

            user = appInterface.getUser(userId)
            if not user:
                continue

            username = _readText(user, "username")
            firstName = _readText(user, "firstName")
            lastName = _readText(user, "lastName")
            email = _readText(user, "email")
            displayName = f"{firstName} {lastName}".strip() or username or userId

            result.append(MandateUserSummary(
                id=userId,
                username=username,
                email=email,
                firstName=firstName,
                lastName=lastName,
                displayName=displayName,
            ))

        return result
    except Exception as e:
        logger.error(f"Error getting users for mandate {targetMandateId}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
def _attachCreatedByUserNamesToTransactionRows(rows: List[Dict[str, Any]]) -> None:
|
||
"""Resolve createdByUserId to userName using root app interface (sysadmin transaction views)."""
|
||
try:
|
||
from modules.interfaces.interfaceDbApp import getRootInterface
|
||
|
||
appRoot = getRootInterface()
|
||
userNames: Dict[str, str] = {}
|
||
for row in rows:
|
||
uid = row.get("createdByUserId")
|
||
if not uid:
|
||
row["userName"] = ""
|
||
continue
|
||
if uid not in userNames:
|
||
try:
|
||
u = appRoot.getUser(uid)
|
||
userNames[uid] = u.username if u else uid[:8]
|
||
except Exception:
|
||
userNames[uid] = uid[:8]
|
||
row["userName"] = userNames.get(uid, "")
|
||
except Exception:
|
||
for row in rows:
|
||
uid = row.get("createdByUserId")
|
||
row["userName"] = uid[:8] if uid else ""
|
||
|
||
|
||
def _enrichTransactionRows(transactions) -> List[Dict[str, Any]]:
    """Turn raw transaction records into serialized ``TransactionResponse`` dicts.

    Missing ``transactionType`` defaults to DEBIT, missing ``amount`` to 0.0 and
    missing ``description`` to ""; ``createdAt`` is taken from ``sysCreatedAt``.
    A ``userName`` is then attached to every row from ``createdByUserId``.
    """
    def _toRow(record) -> Dict[str, Any]:
        rawReferenceType = record.get("referenceType")
        response = TransactionResponse(
            id=record.get("id"),
            accountId=record.get("accountId"),
            transactionType=TransactionTypeEnum(record.get("transactionType", "DEBIT")),
            amount=record.get("amount", 0.0),
            description=record.get("description", ""),
            referenceType=ReferenceTypeEnum(rawReferenceType) if rawReferenceType else None,
            workflowId=record.get("workflowId"),
            featureCode=record.get("featureCode"),
            featureInstanceId=record.get("featureInstanceId"),
            aicoreProvider=record.get("aicoreProvider"),
            aicoreModel=record.get("aicoreModel"),
            createdByUserId=record.get("createdByUserId"),
            createdAt=record.get("sysCreatedAt"),
        )
        return response.model_dump()

    rows = [_toRow(record) for record in transactions]
    _attachCreatedByUserNamesToTransactionRows(rows)
    return rows
|
||
|
||
|
||
def _buildTransactionsList(ctx: RequestContext, targetMandateId: str, maxRows: int = 5000) -> List[Dict[str, Any]]:
    """Build the full enriched transactions list for a mandate.

    Loads at most ``maxRows`` transactions (default 5000, the previous hard-coded cap)
    through the caller's billing interface and converts them with
    ``_enrichTransactionRows``, so this path and the paginated admin path share one
    row format instead of two copies of the same conversion.
    """
    billingInterface = getBillingInterface(ctx.user, targetMandateId)
    transactions = billingInterface.getTransactionsByMandate(targetMandateId, limit=maxRows)
    return _enrichTransactionRows(transactions)
|
||
|
||
|
||
@router.get("/admin/transactions/{targetMandateId}")
@limiter.limit("30/minute")
def getTransactionsAdmin(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams"),
    limit: int = Query(default=100, ge=1, le=1000),
    ctx: RequestContext = Depends(getRequestContext),
):
    """Get all transactions for a mandate with pagination support.

    With ``pagination`` (JSON-encoded PaginationParams) the page is fetched at DB level
    and only that page is enriched. Without it, the full enriched list (capped at
    5000 rows by ``_buildTransactionsList``) is returned with ``pagination: None``.
    NOTE(review): ``limit`` is accepted but currently not applied on either path.

    Raises:
        HTTPException 403: caller is not an admin of the mandate.
        HTTPException 400: ``pagination`` is not valid JSON / PaginationParams.
        HTTPException 500: any other failure.
    """
    if not _isAdminOfMandate(ctx, targetMandateId):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin role required for this mandate")
    try:
        paginationParams: Optional[PaginationParams] = None
        if pagination:
            try:
                paginationDict = json.loads(pagination)
                if paginationDict:
                    paginationDict = normalize_pagination_dict(paginationDict)
                    paginationParams = PaginationParams(**paginationDict)
            except (json.JSONDecodeError, ValueError) as e:
                raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}") from e

        if not paginationParams:
            enriched = _buildTransactionsList(ctx, targetMandateId)
            return {"items": enriched, "pagination": None}

        # DB-level pagination — enrich only the returned page
        billingInterface = getBillingInterface(ctx.user, targetMandateId)
        result = billingInterface.getTransactionsByMandate(targetMandateId, pagination=paginationParams)
        transactions = result.items if hasattr(result, "items") else result
        enrichedItems = _enrichTransactionRows(transactions)

        totalItems = result.totalItems if hasattr(result, "totalItems") else len(enrichedItems)
        if hasattr(result, "totalPages"):
            totalPages = result.totalPages
        else:
            # Plain-list result: derive the page count instead of reporting 0 pages
            # for a non-empty result.
            pageSize = paginationParams.pageSize
            totalPages = math.ceil(totalItems / pageSize) if pageSize else 0

        return {
            "items": enrichedItems,
            "pagination": PaginationMetadata(
                currentPage=paginationParams.page,
                pageSize=paginationParams.pageSize,
                totalItems=totalItems,
                totalPages=totalPages,
                sort=paginationParams.sort,
                filters=paginationParams.filters,
            ).model_dump(),
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting billing transactions for mandate {targetMandateId}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.get("/admin/transactions/{targetMandateId}/filter-values")
@limiter.limit("60/minute")
def getTransactionFilterValues(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    column: str = Query(..., description="Column key"),
    pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
    ctx: RequestContext = Depends(getRequestContext),
):
    """Return distinct filter values for a column in mandate transactions.

    The current filters from ``pagination`` are applied as cross-filters, except the
    filter on ``column`` itself (and any sort), so a column's own selection does not
    narrow its value list. Invalid pagination JSON is ignored. Values come from a SQL
    DISTINCT query where possible; otherwise (e.g. enriched columns such as
    ``userName``) they are computed in memory from the enriched transaction list.
    """
    if not _isAdminOfMandate(ctx, targetMandateId):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin role required for this mandate")
    try:
        crossFilterParams: Optional[PaginationParams] = None
        if pagination:
            try:
                paginationDict = json.loads(pagination)
                if paginationDict:
                    paginationDict = normalize_pagination_dict(paginationDict)
                    # Copy so a null/absent "filters" value is handled and the caller's
                    # structure is not mutated while dropping this column's own filter.
                    filters = dict(paginationDict.get("filters") or {})
                    filters.pop(column, None)
                    paginationDict["filters"] = filters
                    paginationDict.pop("sort", None)
                    crossFilterParams = PaginationParams(**paginationDict)
            except (json.JSONDecodeError, ValueError) as e:
                logger.debug("Ignoring invalid pagination for transaction filter values: %s", e)

        # Try SQL DISTINCT for native DB columns; fallback to in-memory for enriched columns (e.g. userName)
        try:
            rootBillingInterface = _getRootInterface()
            recordFilter = {"mandateId": targetMandateId}
            values = rootBillingInterface.db.getDistinctColumnValues(
                BillingTransaction, column, crossFilterParams, recordFilter
            )
            return sorted(values, key=lambda v: str(v).lower())
        except Exception as e:
            logger.debug("SQL DISTINCT not usable for transaction column %s, using in-memory values: %s", column, e)
            enriched = _buildTransactionsList(ctx, targetMandateId)
            crossFiltered = _applyFiltersAndSort(enriched, crossFilterParams)
            return _extractDistinctValues(crossFiltered, column)
    except Exception as e:
        logger.error(f"Error getting filter values for transactions: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
# =============================================================================
|
||
# Mandate View Endpoints (for Admins)
|
||
# =============================================================================
|
||
|
||
@router.get("/view/mandates/balances", response_model=List[MandateBalanceResponse])
@limiter.limit("30/minute")
def getMandateViewBalances(
    request: Request,
    ctx: RequestContext = Depends(getRequestContext),
    _admin = Depends(requireSysAdminRole)
):
    """
    Return one aggregated balance entry per mandate (SysAdmin only).

    Each entry returned by ``getMandateBalances`` is wrapped in a
    ``MandateBalanceResponse``; any failure is reported as HTTP 500.
    """
    try:
        interface = getBillingInterface(ctx.user, ctx.mandateId)
        mandateRows = interface.getMandateBalances()

        responses = []
        for mandateRow in mandateRows:
            responses.append(MandateBalanceResponse(**mandateRow))
        return responses
    except Exception as e:
        logger.error(f"Error getting mandate view balances: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.get("/view/mandates/transactions", response_model=List[TransactionResponse])
@limiter.limit("30/minute")
def getMandateViewTransactions(
    request: Request,
    limit: int = Query(default=100, ge=1, le=1000),
    ctx: RequestContext = Depends(getRequestContext),
    _admin = Depends(requireSysAdminRole)
):
    """
    List transactions across all mandates (SysAdmin only).

    ``limit`` is forwarded to ``getMandateTransactions``. Each row also carries the
    owning ``mandateId`` and ``mandateName``; any failure is reported as HTTP 500.
    """
    def _toResponse(record):
        referenceValue = record.get("referenceType")
        return TransactionResponse(
            id=record.get("id"),
            accountId=record.get("accountId"),
            transactionType=TransactionTypeEnum(record.get("transactionType", "DEBIT")),
            amount=record.get("amount", 0.0),
            description=record.get("description", ""),
            referenceType=ReferenceTypeEnum(referenceValue) if referenceValue else None,
            workflowId=record.get("workflowId"),
            featureCode=record.get("featureCode"),
            featureInstanceId=record.get("featureInstanceId"),
            aicoreProvider=record.get("aicoreProvider"),
            aicoreModel=record.get("aicoreModel"),
            createdByUserId=record.get("createdByUserId"),
            createdAt=record.get("sysCreatedAt"),
            mandateId=record.get("mandateId"),
            mandateName=record.get("mandateName"),
        )

    try:
        interface = getBillingInterface(ctx.user, ctx.mandateId)
        return [_toResponse(record) for record in interface.getMandateTransactions(limit=limit)]
    except Exception as e:
        logger.error(f"Error getting mandate view transactions: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
# =============================================================================
|
||
# User View Endpoints (RBAC-based)
|
||
# =============================================================================
|
||
|
||
@router.get("/view/users/balances", response_model=List[UserBalanceResponse])
@limiter.limit("30/minute")
def getUserViewBalances(
    request: Request,
    ctx: RequestContext = Depends(getRequestContext)
):
    """
    Return user-level balances visible to the caller.

    Visibility follows the billing RBAC scope:
    - SysAdmin: every user balance in every mandate.
    - Mandate admin: all user balances in the mandates they administrate.
    - Member: only their own balances.
    A non-global caller without any mandate gets an empty list.
    """
    try:
        billingInterface = getBillingInterface(ctx.user, ctx.mandateId)
        scope = _getBillingDataScope(ctx.user)
        isGlobal = scope.isGlobalAdmin

        # None means "no mandate restriction" for global admins.
        mandateIds = None if isGlobal else scope.adminMandateIds + scope.memberMandateIds
        if not isGlobal and not mandateIds:
            return []

        balances = billingInterface.getUserBalancesForMandates(mandateIds)

        if not isGlobal:
            administered = set(scope.adminMandateIds)
            balances = [
                entry for entry in balances
                if entry.get("mandateId") in administered or entry.get("userId") == scope.userId
            ]

        return [UserBalanceResponse(**entry) for entry in balances]
    except Exception as e:
        logger.error(f"Error getting user view balances: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
class ViewStatisticsResponse(BaseModel):
    """Usage statistics aggregated over every mandate visible to the caller.

    Every field defaults to zero or empty, so a bare instance is a valid
    "no data" response.
    """

    # Overall totals.
    totalCost: float = Field(default=0.0)
    transactionCount: int = Field(default=0)
    # Cost breakdowns keyed by provider, model, feature and mandate respectively.
    costByProvider: Dict[str, float] = Field(default={})
    costByModel: Dict[str, float] = Field(default={})
    costByFeature: Dict[str, float] = Field(default={})
    costByMandate: Dict[str, float] = Field(default={})
    # Time-series points; the entry shape is defined by the statistics endpoint.
    timeSeries: List[Dict[str, Any]] = Field(default=[])
|
||
|
||
|
||
def _viewStatsParseTxDate(createdAt: Any) -> Optional[date]:
    """Convert a transaction's ``sysCreatedAt`` value to a ``date``, or None if unusable.

    Accepts Unix timestamps (int/float or numeric strings), ``datetime``/``date``
    objects and ISO-8601 strings (a trailing ``Z`` is read as UTC). Numeric
    timestamps go through ``datetime.fromtimestamp`` and therefore use the
    server's local timezone. Empty/falsy values (including a 0 timestamp) and
    values that cannot be parsed or are out of the platform's range yield None.
    """
    if not createdAt:
        return None
    if isinstance(createdAt, (int, float)):
        try:
            return datetime.fromtimestamp(createdAt).date()
        except (ValueError, OverflowError, OSError):
            # Out-of-range / non-finite timestamps must not fail the whole request.
            return None
    if isinstance(createdAt, datetime):
        return createdAt.date()
    if isinstance(createdAt, date):
        return createdAt
    if isinstance(createdAt, str):
        try:
            # DB stores DOUBLE PRECISION Unix timestamps; try that form first.
            return datetime.fromtimestamp(float(createdAt)).date()
        except (ValueError, TypeError, OverflowError, OSError):
            pass
        try:
            return datetime.fromisoformat(createdAt.replace("Z", "+00:00")).date()
        except (ValueError, TypeError):
            return None
    return None


def _viewStatsIsDebit(txType: Any) -> bool:
    """Return True if ``txType`` denotes a DEBIT transaction.

    Accepts the plain string "DEBIT", the ``str(enum)`` spelling
    "TransactionTypeEnum.DEBIT", and any object whose ``.value`` is "DEBIT".
    """
    if txType is None:
        return False
    if str(txType) in ("DEBIT", "TransactionTypeEnum.DEBIT"):
        return True
    return getattr(txType, "value", None) == "DEBIT"


def _viewStatsDateRange(period: str, year: int, month: Optional[int]) -> tuple:
    """Return the half-open ``(startDate, endDate)`` window for a statistics request.

    ``period == "day"`` covers the given month; any other period covers the
    whole year. Raises ValueError (from ``date``) for out-of-range year/month.
    """
    if period == "day":
        startDate = date(year, month, 1)
        endDate = date(year + 1, 1, 1) if month == 12 else date(year, month + 1, 1)
    else:
        startDate = date(year, 1, 1)
        endDate = date(year + 1, 1, 1)
    return startDate, endDate


@router.get("/view/statistics")
@limiter.limit("30/minute")
def getUserViewStatistics(
    request: Request,
    period: str = Query(default="month", description="Period: 'day' or 'month'"),
    year: Optional[int] = Query(default=None, description="Year"),
    month: Optional[int] = Query(None, description="Month (1-12, required for period='day')"),
    scope: str = Query(default="all", description="Scope: 'personal' (own costs only), 'mandate' (filter by mandateId), 'all' (RBAC-filtered)"),
    mandateId: Optional[str] = Query(None, description="Mandate ID filter (used with scope='mandate')"),
    ctx: RequestContext = Depends(getRequestContext)
) -> ViewStatisticsResponse:
    """
    Get aggregated usage statistics across all user's mandates.

    Scope:
    - personal: only the current user's own transactions (ignores admin role)
    - mandate: transactions for a specific mandate (requires mandateId parameter;
      without mandateId it behaves like 'all'). A mandate outside the caller's
      RBAC scope yields an empty result.
    - all: RBAC-filtered (SysAdmin sees everything, admin sees mandate, user sees own)

    - period='month': returns monthly time series for the given year (all 12 months)
    - period='day': returns daily time series for the given month/year
      (only days with at least one DEBIT transaction)
    Any other period value is treated like 'month'. Missing year/month default
    to the current (server-local) year/month.

    Only DEBIT transactions inside the selected window are aggregated.

    Raises:
        HTTPException 400: year/month do not form a valid date window.
        HTTPException 500: any other unexpected error (logged with traceback).
    """
    try:
        now = datetime.now()
        if year is None:
            year = now.year
        if period == "day" and not month:
            month = now.month

        # Validate the window before touching the database: bad values such as
        # month=13 are client errors, not server errors.
        try:
            startDate, endDate = _viewStatsDateRange(period, year, month)
        except ValueError as e:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid date parameters (year={year}, month={month}): {e}",
            ) from e

        billingInterface = getBillingInterface(ctx.user, ctx.mandateId)

        # Evaluate RBAC scope
        rbacScope = _getBillingDataScope(ctx.user)

        # Determine mandate IDs for data loading (None = no mandate restriction)
        if rbacScope.isGlobalAdmin:
            loadMandateIds = None
        else:
            loadMandateIds = list(dict.fromkeys(rbacScope.adminMandateIds + rbacScope.memberMandateIds))
            if not loadMandateIds:
                logger.warning("No mandate IDs found for user")
                return ViewStatisticsResponse()

        # Scope=mandate: restrict to a specific mandate, but never widen the
        # set of mandates the caller is allowed to load.
        if scope == "mandate" and mandateId:
            if loadMandateIds is not None and mandateId not in loadMandateIds:
                logger.warning(f"View statistics: mandate {mandateId} is outside the RBAC scope of user {ctx.user.id}")
                return ViewStatisticsResponse()
            loadMandateIds = [mandateId]

        # NOTE(review): limit=10000 presumably caps the rows loaded, so very
        # large histories may be aggregated partially — confirm in the interface.
        allTransactions = billingInterface.getUserTransactionsForMandates(loadMandateIds, limit=10000)

        # Apply RBAC filter (respects admin/user roles)
        allTransactions = _filterTransactionsByScope(allTransactions, rbacScope)

        # Scope=personal: further filter to only own transactions
        if scope == "personal":
            userId = str(ctx.user.id)
            allTransactions = [
                t for t in allTransactions
                if (t.get("createdByUserId") or t.get("userId")) == userId
            ]

        logger.info(f"View statistics: {len(allTransactions)} RBAC-filtered transactions for period={period}, year={year}, month={month}")

        # Keep (txDate, transaction) pairs so the caller-owned dicts are not mutated.
        debits = []
        skippedNoDate = 0
        skippedDateRange = 0
        skippedNotDebit = 0

        for t in allTransactions:
            txDate = _viewStatsParseTxDate(t.get("sysCreatedAt"))
            if txDate is None:
                skippedNoDate += 1
                continue
            if txDate < startDate or txDate >= endDate:
                skippedDateRange += 1
                continue
            if not _viewStatsIsDebit(t.get("transactionType")):
                skippedNotDebit += 1
                continue
            debits.append((txDate, t))

        logger.info(f"View statistics: {len(debits)} DEBIT transactions after filter. "
                    f"Skipped: noDate={skippedNoDate}, dateRange={skippedDateRange}, notDebit={skippedNotDebit}")

        # Aggregate totals, breakdowns and time-series buckets in a single pass.
        totalCost = 0
        costByProvider: Dict[str, float] = {}
        costByModel: Dict[str, float] = {}
        costByFeature: Dict[str, float] = {}
        costByMandate: Dict[str, float] = {}
        # Bucket key: the transaction date (period='day') or the month number.
        bucketCost: Dict[Any, float] = {}
        bucketCount: Dict[Any, int] = {}

        for txDate, t in debits:
            # A stored None amount counts as 0 instead of breaking the sums.
            amount = t.get("amount") or 0
            totalCost += amount

            provider = t.get("aicoreProvider") or "unknown"
            costByProvider[provider] = costByProvider.get(provider, 0) + amount

            model = t.get("aicoreModel") or "unknown"
            costByModel[model] = costByModel.get(model, 0) + amount

            mandate = t.get("mandateName") or t.get("mandateId") or "unknown"
            featureCode = t.get("featureCode") or "unknown"
            featureKey = f"{mandate} / {featureCode}"
            costByFeature[featureKey] = costByFeature.get(featureKey, 0) + amount

            costByMandate[mandate] = costByMandate.get(mandate, 0) + amount

            bucket = txDate if period == "day" else txDate.month
            bucketCost[bucket] = bucketCost.get(bucket, 0) + amount
            bucketCount[bucket] = bucketCount.get(bucket, 0) + 1

        # Build time series (raw data only, no display logic)
        if period == "day":
            # Only days that have at least one transaction, in chronological order.
            timeSeries = [
                {"date": d.isoformat(), "cost": round(bucketCost[d], 4), "count": bucketCount[d]}
                for d in sorted(bucketCount)
            ]
        else:
            # Always all 12 months, zero-filled.
            timeSeries = [
                {"date": f"{year}-{m:02d}", "cost": round(bucketCost.get(m, 0), 4), "count": bucketCount.get(m, 0)}
                for m in range(1, 13)
            ]

        return ViewStatisticsResponse(
            totalCost=round(totalCost, 4),
            transactionCount=len(debits),
            costByProvider=costByProvider,
            costByModel=costByModel,
            costByFeature=costByFeature,
            costByMandate=costByMandate,
            timeSeries=timeSeries
        )

    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 400 above) instead of turning them into 500.
        raise
    except Exception as e:
        logger.error(f"Error getting view statistics: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
def _userViewTransactionToDict(t: Dict[str, Any]) -> Dict[str, Any]:
    """Map a raw transaction record to the flat row used for filtering, sorting and responses."""
    return {
        "id": t.get("id"),
        "accountId": t.get("accountId"),
        "transactionType": t.get("transactionType", "DEBIT"),
        "amount": t.get("amount", 0.0),
        "description": t.get("description", ""),
        "referenceType": t.get("referenceType"),
        "workflowId": t.get("workflowId"),
        "featureCode": t.get("featureCode"),
        "featureInstanceId": t.get("featureInstanceId"),
        "aicoreProvider": t.get("aicoreProvider"),
        "aicoreModel": t.get("aicoreModel"),
        "createdByUserId": t.get("createdByUserId"),
        "createdAt": t.get("sysCreatedAt"),
        "mandateId": t.get("mandateId"),
        "mandateName": t.get("mandateName"),
        "userId": t.get("userId"),
        "userName": t.get("userName"),
    }


def _userViewDictToResponse(d: Dict[str, Any]) -> UserTransactionResponse:
    """Build a UserTransactionResponse from a row produced by _userViewTransactionToDict."""
    return UserTransactionResponse(
        id=d.get("id"),
        accountId=d.get("accountId"),
        # A stored None would make the enum constructor raise; fall back to
        # DEBIT, the same default used when the key is missing.
        transactionType=TransactionTypeEnum(d.get("transactionType") or "DEBIT"),
        amount=d.get("amount", 0.0),
        description=d.get("description", ""),
        referenceType=ReferenceTypeEnum(d["referenceType"]) if d.get("referenceType") else None,
        workflowId=d.get("workflowId"),
        featureCode=d.get("featureCode"),
        featureInstanceId=d.get("featureInstanceId"),
        aicoreProvider=d.get("aicoreProvider"),
        aicoreModel=d.get("aicoreModel"),
        createdByUserId=d.get("createdByUserId"),
        createdAt=d.get("createdAt"),
        mandateId=d.get("mandateId"),
        mandateName=d.get("mandateName"),
        userId=d.get("userId"),
        userName=d.get("userName")
    )


def _parseUserViewPagination(pagination: Optional[str]) -> Optional[PaginationParams]:
    """Decode the JSON ``pagination`` query parameter into PaginationParams.

    Returns None when no pagination was requested. Malformed input is a client
    error and raises HTTPException 400 instead of surfacing as a 500.
    """
    if not pagination:
        return None
    try:
        paginationDict = json.loads(pagination)
        paginationDict = normalize_pagination_dict(paginationDict)
        return PaginationParams(**paginationDict)
    except (ValueError, TypeError, AttributeError) as e:
        # ValueError covers json.JSONDecodeError and pydantic validation errors;
        # TypeError/AttributeError cover JSON that is not an object.
        raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {e}") from e


@router.get("/view/users/transactions", response_model=PaginatedResponse[UserTransactionResponse])
@limiter.limit("30/minute")
def getUserViewTransactions(
    request: Request,
    pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
    ctx: RequestContext = Depends(getRequestContext)
) -> PaginatedResponse[UserTransactionResponse]:
    """
    Get user-level transactions with pagination support.

    RBAC filtering:
    - SysAdmin: sees all user transactions across all mandates
    - Mandate-Admin: sees all user transactions for mandates they administrate
    - Feature-Instance-Admin: sees transactions for their feature instances
    - Regular user: sees only their own transactions

    Query Parameters:
    - pagination: JSON-encoded PaginationParams object, or None for no pagination

    Raises:
        HTTPException 400: the pagination parameter is not valid JSON / PaginationParams.
        HTTPException 500: any other unexpected error (logged with traceback).
    """
    try:
        billingInterface = getBillingInterface(ctx.user, ctx.mandateId)

        # Parse pagination params (400 on malformed input)
        paginationParams = _parseUserViewPagination(pagination)

        # Evaluate RBAC scope
        scope = _getBillingDataScope(ctx.user)

        # Determine mandate IDs for data loading
        if scope.isGlobalAdmin:
            mandateIds = None  # Load all
        else:
            # Load data for all mandates the user belongs to (admin + member),
            # de-duplicated while preserving order.
            mandateIds = list(dict.fromkeys(scope.adminMandateIds + scope.memberMandateIds))
            if not mandateIds:
                return PaginatedResponse(items=[], pagination=None)

        allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=10000)

        # Apply RBAC filter
        allTransactions = _filterTransactionsByScope(allTransactions, scope)

        logger.debug(f"RBAC-filtered {len(allTransactions)} transactions for user {ctx.user.id}")

        # Flatten to dicts so the generic filter/sort helper can operate on them
        transactionDicts = [_userViewTransactionToDict(t) for t in allTransactions]
        filteredDicts = _applyFiltersAndSort(transactionDicts, paginationParams)

        if not paginationParams:
            return PaginatedResponse(
                items=[_userViewDictToResponse(d) for d in filteredDicts],
                pagination=None
            )

        totalItems = len(filteredDicts)
        totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0
        startIdx = (paginationParams.page - 1) * paginationParams.pageSize
        endIdx = startIdx + paginationParams.pageSize

        return PaginatedResponse(
            items=[_userViewDictToResponse(d) for d in filteredDicts[startIdx:endIdx]],
            pagination=PaginationMetadata(
                currentPage=paginationParams.page,
                pageSize=paginationParams.pageSize,
                totalItems=totalItems,
                totalPages=totalPages,
                sort=paginationParams.sort,
                filters=paginationParams.filters
            )
        )

    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 400 for bad pagination).
        raise
    except Exception as e:
        logger.error(f"Error getting user view transactions: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.get("/view/users/transactions/filter-values")
@limiter.limit("60/minute")
def getUserViewTransactionsFilterValues(
    request: Request,
    column: str = Query(..., description="Column key"),
    pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
    ctx: RequestContext = Depends(getRequestContext)
):
    """Return distinct filter values for a column in user transactions.

    Loads the caller's RBAC-visible transactions (all mandates for SysAdmins,
    otherwise the caller's admin + member mandates filtered by
    ``_filterTransactionsByScope``), flattens them into rows and delegates the
    distinct-value computation to ``_handleFilterValuesRequest``. Returns an
    empty list when a non-SysAdmin caller belongs to no mandate.

    Raises:
        HTTPException: any HTTPException raised by the helpers is propagated
            unchanged; every other error is logged (with traceback) and
            reported as HTTP 500.
    """
    try:
        billingInterface = getBillingInterface(ctx.user, ctx.mandateId)
        scope = _getBillingDataScope(ctx.user)

        if scope.isGlobalAdmin:
            mandateIds = None
        else:
            # De-duplicate mandates that are both admin and member mandates.
            mandateIds = list(dict.fromkeys(scope.adminMandateIds + scope.memberMandateIds))
            if not mandateIds:
                return []

        allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=10000)
        allTransactions = _filterTransactionsByScope(allTransactions, scope)

        # Same flat row shape as the /view/users/transactions listing, so
        # column keys and filters line up between the two endpoints.
        transactionDicts = [
            {
                "id": t.get("id"),
                "accountId": t.get("accountId"),
                "transactionType": t.get("transactionType", "DEBIT"),
                "amount": t.get("amount", 0.0),
                "description": t.get("description", ""),
                "referenceType": t.get("referenceType"),
                "workflowId": t.get("workflowId"),
                "featureCode": t.get("featureCode"),
                "featureInstanceId": t.get("featureInstanceId"),
                "aicoreProvider": t.get("aicoreProvider"),
                "aicoreModel": t.get("aicoreModel"),
                "createdByUserId": t.get("createdByUserId"),
                "createdAt": t.get("sysCreatedAt"),
                "mandateId": t.get("mandateId"),
                "mandateName": t.get("mandateName"),
                "userId": t.get("userId"),
                "userName": t.get("userName"),
            }
            for t in allTransactions
        ]

        return _handleFilterValuesRequest(transactionDicts, column, pagination)

    except HTTPException:
        # Don't mask client errors from the helpers (e.g. bad column/filters) as 500.
        raise
    except Exception as e:
        logger.error(f"Error getting filter values for user transactions: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|