1895 lines
76 KiB
Python
1895 lines
76 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
||
# All rights reserved.
|
||
"""
|
||
Billing routes for the backend API.
|
||
Implements the endpoints for billing management and usage tracking.
|
||
|
||
Features:
|
||
- User endpoints: View balance, transactions, statistics
|
||
- Admin endpoints: Manage settings, add credits, view all accounts
|
||
"""
|
||
|
||
from fastapi import APIRouter, HTTPException, Depends, Body, Path, Request, Response, Query, Header
|
||
from typing import List, Dict, Any, Optional
|
||
from fastapi import status
|
||
import logging
|
||
from datetime import date, datetime, timezone
|
||
from pydantic import BaseModel, Field
|
||
|
||
# Import auth module
|
||
from modules.auth import limiter, requirePlatformAdmin, getRequestContext, RequestContext
|
||
|
||
# Import billing components
|
||
from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface, getRootInterface
|
||
from modules.serviceCenter.services.serviceBilling.mainServiceBilling import getService as getBillingService
|
||
import json
|
||
import math
|
||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
|
||
from modules.datamodels.datamodelBilling import (
|
||
BillingAccount,
|
||
BillingTransaction,
|
||
BillingSettings,
|
||
TransactionTypeEnum,
|
||
ReferenceTypeEnum,
|
||
PeriodTypeEnum,
|
||
BillingBalanceResponse,
|
||
BillingStatisticsResponse,
|
||
BillingStatisticsChartData,
|
||
BillingCheckResult,
|
||
)
|
||
from modules.shared.i18nRegistry import apiRouteContext
|
||
|
||
# Message helper used for HTTP error details in this module. Presumably it
# resolves/registers texts under the "routeBilling" i18n context of
# modules.shared.i18nRegistry — TODO confirm.
routeApiMsg = apiRouteContext("routeBilling")

# Configure logger
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# =============================================================================
|
||
# Billing RBAC Data Scope
|
||
# =============================================================================
|
||
#
|
||
# RBAC rules for billing data visibility:
|
||
#
|
||
# SysAdmin → ALL transactions and statistics across all mandates
|
||
# Mandate-Admin → ALL user data within their administrated mandates
|
||
# Feature-Instance-Admin→ Data for their administrated feature instances
|
||
# Regular User → ONLY their own data within their mandates
|
||
#
|
||
|
||
class BillingDataScope:
    """
    Per-request record of which billing data a user may see.

    Evaluated once per request and then consulted to filter transactions
    and statistics.

    Attributes:
        isGlobalAdmin: True for platform admins, who see everything.
        adminMandateIds: Mandates where the user holds a mandate-level
            "admin" role.
        adminFeatureInstanceIds: Feature instances where the user holds an
            "admin" role.
        memberMandateIds: Mandates the user belongs to without admin rights.
        userId: ID of the user this scope was computed for.
    """

    __slots__ = (
        "isGlobalAdmin",
        "adminMandateIds",
        "adminFeatureInstanceIds",
        "memberMandateIds",
        "userId",
    )

    def __init__(self, userId: str):
        self.userId: str = userId
        self.isGlobalAdmin: bool = False
        # Fresh lists per instance so two scopes never share state.
        self.adminMandateIds: list = list()
        self.adminFeatureInstanceIds: list = list()
        self.memberMandateIds: list = list()
|
||
|
||
|
||
def _getBillingDataScope(user) -> BillingDataScope:
    """
    Determine what billing data a user can see based on RBAC.

    Uses the privileged root app interface to check roles across all mandates
    and feature instances without RBAC restrictions on the lookup itself.

    Rules:
      * Platform admins (``user.isPlatformAdmin``) get ``isGlobalAdmin=True``
        and no further lookups are performed.
      * For every *enabled* mandate membership, the mandate is added to
        ``adminMandateIds`` if one of its roles is a mandate-level ``"admin"``
        role (no ``featureInstanceId``), otherwise to ``memberMandateIds``.
        Disabled memberships grant nothing, matching the checks in
        ``_isAdminOfMandate`` and ``_isMemberOfMandate``.
      * Feature instances where one of the user's feature-access roles is
        ``"admin"`` are added to ``adminFeatureInstanceIds``.

    Args:
        user: Authenticated user object; ``id`` and optionally
            ``isPlatformAdmin`` are read.

    Returns:
        BillingDataScope with the user's visibility boundaries.
    """
    scope = BillingDataScope(userId=user.id)

    if bool(getattr(user, "isPlatformAdmin", False)):
        scope.isGlobalAdmin = True
        return scope

    from modules.interfaces.interfaceDbApp import getRootInterface
    rootInterface = getRootInterface()

    # --- Mandate roles ---
    for um in rootInterface.getUserMandates(user.id):
        mandateId = getattr(um, 'mandateId', None)
        umId = getattr(um, 'id', None)
        if not mandateId or not umId:
            continue
        # A disabled membership must not widen data visibility.
        if not getattr(um, 'enabled', True):
            continue

        isAdmin = False
        for roleId in rootInterface.getRoleIdsForUserMandate(umId):
            role = rootInterface.getRole(roleId)
            # Roles bound to a feature instance do not make the user a
            # mandate admin.
            if role and role.roleLabel == "admin" and not role.featureInstanceId:
                isAdmin = True
                break

        if isAdmin:
            scope.adminMandateIds.append(mandateId)
        else:
            scope.memberMandateIds.append(mandateId)

    # --- Feature instance roles ---
    for fa in rootInterface.getFeatureAccessesForUser(user.id):
        fiId = getattr(fa, 'featureInstanceId', None)
        faId = getattr(fa, 'id', None)
        if not fiId or not faId:
            continue

        for roleId in rootInterface.getRoleIdsForFeatureAccess(faId):
            role = rootInterface.getRole(roleId)
            if role and role.roleLabel == "admin":
                scope.adminFeatureInstanceIds.append(fiId)
                break

    logger.debug(
        f"BillingDataScope for user {user.id}: "
        f"globalAdmin={scope.isGlobalAdmin}, "
        f"adminMandates={scope.adminMandateIds}, "
        f"adminInstances={scope.adminFeatureInstanceIds}, "
        f"memberMandates={scope.memberMandateIds}"
    )
    return scope
|
||
|
||
|
||
def _isAdminOfMandate(ctx: RequestContext, targetMandateId: str) -> bool:
    """Return True if the caller may administer ``targetMandateId``.

    Platform admins always qualify. Everyone else needs an enabled membership
    in the mandate that carries a mandate-level ``"admin"`` role (a role
    whose ``featureInstanceId`` is empty).

    Fail-loud: lookup errors from the root interface are not caught and
    propagate to the caller. This function never converts them into False,
    so infrastructure outages surface as errors instead of confusing 403s.
    """
    if ctx.isPlatformAdmin:
        return True

    from modules.interfaces.interfaceDbApp import getRootInterface
    root = getRootInterface()
    wanted = str(targetMandateId)

    for membership in root.getUserMandates(str(ctx.user.id)):
        isTarget = str(getattr(membership, 'mandateId', None)) == wanted
        if not isTarget or not getattr(membership, 'enabled', True):
            continue
        membershipId = str(getattr(membership, 'id', ''))
        for roleId in root.getRoleIdsForUserMandate(membershipId):
            role = root.getRole(roleId)
            if role and role.roleLabel == "admin" and not role.featureInstanceId:
                return True
    return False
|
||
|
||
|
||
def _isMemberOfMandate(ctx: RequestContext, targetMandateId: str) -> bool:
    """Return True if the caller has an enabled membership in ``targetMandateId``.

    There is no platform-admin shortcut here: only actual memberships count.

    Fail-loud: lookup errors from the root interface are not caught and
    propagate to the caller.
    """
    from modules.interfaces.interfaceDbApp import getRootInterface
    wanted = str(targetMandateId)
    memberships = getRootInterface().getUserMandates(str(ctx.user.id))
    return any(
        str(getattr(m, 'mandateId', None)) == wanted and getattr(m, 'enabled', True)
        for m in memberships
    )
|
||
|
||
|
||
|
||
# =============================================================================
|
||
# Request/Response Models
|
||
# =============================================================================
|
||
|
||
class CreditAddRequest(BaseModel):
    """Request model for adding or deducting credit from an account."""
    # Optional; per its description it is only meant for the audit trail.
    userId: Optional[str] = Field(None, description="Target user ID for audit trail only (optional)")
    # Signed CHF amount (sign selects credit vs. deduction). The "must not be
    # zero" rule is NOT enforced by this model; the admin credit route rejects
    # a zero amount explicitly.
    amount: float = Field(..., description="Amount in CHF. Positive = credit, negative = deduction. Must not be zero.")
    description: str = Field(default="Manual credit", description="Transaction description")
|
||
|
||
|
||
class CheckoutCreateRequest(BaseModel):
    """Request model for creating Stripe Checkout Session."""
    # Optional audit user ID, passed through to checkout session creation.
    userId: Optional[str] = Field(None, description="Target user ID for audit trail only (optional)")
    # This model only enforces > 0; membership in the allowed preset list is
    # validated server-side when the checkout session is created.
    amount: float = Field(..., gt=0, description="Amount to pay in CHF (must be in allowed presets)")
    # Non-empty URL used to build Stripe's success/cancel redirects.
    returnUrl: str = Field(..., min_length=1, description="Absolute frontend URL used for Stripe success/cancel redirects")
|
||
|
||
|
||
class CheckoutCreateResponse(BaseModel):
    """Response model for Checkout Session creation."""
    # URL the frontend should navigate to in order to start payment.
    redirectUrl: str = Field(..., description="Stripe Checkout URL for redirect")
|
||
|
||
|
||
class CheckoutConfirmRequest(BaseModel):
    """Request model for confirming Stripe Checkout after redirect."""
    # Must be non-empty; the route looks it up at Stripe.
    sessionId: str = Field(..., min_length=1, description="Stripe Checkout Session ID (cs_xxx)")
|
||
|
||
|
||
class CheckoutConfirmResponse(BaseModel):
    """Response model for Stripe Checkout confirmation."""
    # As produced in this module, exactly one of credited / alreadyCredited
    # is True for a given response.
    credited: bool = Field(..., description="True if a new billing credit was created")
    alreadyCredited: bool = Field(..., description="True if session was already credited before")
    sessionId: str = Field(..., description="Stripe Checkout Session ID")
    mandateId: str = Field(..., description="Mandate ID from Stripe metadata")
    # When alreadyCredited, this is the amount of the existing payment transaction.
    amountChf: float = Field(..., description="Credited amount in CHF")
|
||
|
||
|
||
class BillingSettingsUpdate(BaseModel):
    """Request model for updating billing settings."""
    # Every field is optional; None means "not provided". The settings route
    # dumps this model with exclude_none=True, so None never overwrites an
    # existing value.
    warningThresholdPercent: Optional[float] = Field(None, ge=0, le=100)
    notifyOnWarning: Optional[bool] = None
    notifyEmails: Optional[List[str]] = None  # not validated as e-mail addresses here
    autoRechargeEnabled: Optional[bool] = None
    rechargeAmountCHF: Optional[float] = Field(None, gt=0)
    rechargeMaxPerMonth: Optional[int] = Field(None, ge=0)
|
||
|
||
|
||
class TransactionResponse(BaseModel):
    """Response model for a billing transaction."""
    id: str
    accountId: str
    transactionType: TransactionTypeEnum
    # Amount as stored on the transaction; the direction is given by
    # transactionType.
    amount: float
    description: str
    # The Optional fields below that have NO default are, under pydantic v2
    # (this module uses model_dump), required keys whose value may be None.
    referenceType: Optional[ReferenceTypeEnum]
    workflowId: Optional[str]
    featureCode: Optional[str]
    featureInstanceId: Optional[str] = None
    aicoreProvider: Optional[str]
    aicoreModel: Optional[str] = None
    createdByUserId: Optional[str] = None
    sysCreatedAt: Optional[datetime] = None
    mandateId: Optional[str] = None
    mandateName: Optional[str] = None
|
||
|
||
|
||
class AccountSummary(BaseModel):
    """Summary of a billing account."""
    id: str
    mandateId: str
    # Required key but may be None (presumably None for mandate-level
    # accounts — TODO confirm).
    userId: Optional[str]
    balance: float
    warningThreshold: float
    enabled: bool
|
||
|
||
|
||
class UsageReportResponse(BaseModel):
    """Usage report for an explicit date range."""
    # Echo of the request's ISO date strings and bucket size.
    dateFrom: str
    dateTo: str
    bucketSize: str
    totalCost: float
    transactionCount: int
    # Aggregated cost keyed by provider / model / feature code.
    costByProvider: Dict[str, float]
    # The mutable {} default is safe here: pydantic copies field defaults
    # per instance.
    costByModel: Dict[str, float] = {}
    costByFeature: Dict[str, float]
|
||
|
||
|
||
# =============================================================================
|
||
# Response Models for Mandate/User Views
|
||
# =============================================================================
|
||
|
||
class MandateBalanceResponse(BaseModel):
    """Mandate-level balance summary."""
    mandateId: str
    mandateName: str
    totalBalance: float
    # Number of users counted for this mandate.
    userCount: int
    warningThresholdPercent: float
|
||
|
||
|
||
class UserBalanceResponse(BaseModel):
    """User-level balance summary."""
    accountId: str
    mandateId: str
    mandateName: str
    userId: str
    userName: str
    balance: float
    warningThreshold: float
    # Flag for the balance-warning state; computed by the producer, not here.
    isWarning: bool
    enabled: bool
|
||
|
||
|
||
class UserTransactionResponse(BaseModel):
    """User-level transaction with user context."""
    # Same fields (and order) as TransactionResponse, plus userId / userName.
    id: str
    accountId: str
    transactionType: TransactionTypeEnum
    amount: float
    description: str
    # Optional without default: required key, value may be None (pydantic v2).
    referenceType: Optional[ReferenceTypeEnum]
    workflowId: Optional[str]
    featureCode: Optional[str]
    featureInstanceId: Optional[str] = None
    aicoreProvider: Optional[str]
    aicoreModel: Optional[str] = None
    createdByUserId: Optional[str] = None
    sysCreatedAt: Optional[datetime] = None
    mandateId: Optional[str] = None
    mandateName: Optional[str] = None
    userId: Optional[str] = None
    userName: Optional[str] = None
|
||
|
||
|
||
def _getStripeClient():
    """Return the configured Stripe SDK module.

    Delegates to ``modules.shared.stripeClient.getStripeClient``; the import
    is done at call time, which keeps it out of this module's import.
    """
    from modules.shared.stripeClient import getStripeClient as _factory
    client = _factory()
    return client
|
||
|
||
|
||
def _creditStripeSessionIfNeeded(
    billingInterface,
    session: Dict[str, Any],
    eventId: Optional[str] = None,
) -> CheckoutConfirmResponse:
    """
    Credit the mandate balance from a Stripe Checkout session, at most once.

    The Checkout session ID is stored as the transaction ``referenceId`` and
    serves as the idempotency key across the webhook and manual confirmation
    flows.

    Args:
        billingInterface: Billing DB interface used for lookups and writes.
        session: Stripe Checkout session as a plain dict. Reads ``id``,
            ``metadata.mandateId``, ``metadata.userId``, ``metadata.amountChf``
            and ``amount_total`` (Stripe minor units).
        eventId: Optional Stripe webhook event ID; recorded via
            ``createStripeWebhookEvent`` when not already present.

    Returns:
        CheckoutConfirmResponse with ``credited=True`` for a new credit, or
        ``alreadyCredited=True`` if the session was credited before.

    Raises:
        HTTPException: 400 for a missing session ID / mandateId or an
            invalid, non-positive or non-finite amount; 404 when the mandate
            has no billing settings.
    """
    from modules.serviceCenter.services.serviceBilling.stripeCheckout import ALLOWED_AMOUNTS_CHF

    session_id = session.get("id")
    metadata = session.get("metadata") or {}
    mandate_id = metadata.get("mandateId")
    user_id = metadata.get("userId") or None
    amount_chf_str = metadata.get("amountChf", "0")

    if not session_id:
        raise HTTPException(status_code=400, detail=routeApiMsg("Stripe session id missing"))
    if not mandate_id:
        raise HTTPException(status_code=400, detail=routeApiMsg("Invalid session metadata: mandateId missing"))

    # NOTE(review): this check-then-create is not atomic. Concurrent webhook
    # and confirm calls for the same session could both pass it unless the DB
    # enforces uniqueness of the payment referenceId — confirm.
    existing_payment_tx = billingInterface.getPaymentTransactionByReferenceId(session_id)
    if existing_payment_tx:
        if eventId and not billingInterface.getStripeWebhookEventByEventId(eventId):
            billingInterface.createStripeWebhookEvent(eventId)
        return CheckoutConfirmResponse(
            credited=False,
            alreadyCredited=True,
            sessionId=session_id,
            mandateId=mandate_id,
            amountChf=float(existing_payment_tx.get("amount", 0.0)),
        )

    # Prefer the server-set metadata amount if it is an allowed preset,
    # otherwise fall back to what Stripe actually charged.
    try:
        amount_chf = float(amount_chf_str)
    except (TypeError, ValueError):
        amount_chf = None

    if amount_chf is None or amount_chf not in ALLOWED_AMOUNTS_CHF:
        amount_total = session.get("amount_total")
        if amount_total is None:
            raise HTTPException(status_code=400, detail=routeApiMsg("Invalid amount in Stripe session"))
        amount_chf = amount_total / 100.0  # Stripe amounts are in minor units

    # The fallback path is otherwise unvalidated: never book a zero, negative
    # or non-finite credit.
    if not math.isfinite(amount_chf) or amount_chf <= 0:
        raise HTTPException(status_code=400, detail=routeApiMsg("Invalid amount in Stripe session"))

    settings = billingInterface.getSettings(mandate_id)
    if not settings:
        raise HTTPException(status_code=404, detail=routeApiMsg("Billing settings not found"))

    account = billingInterface.getOrCreateMandateAccount(mandate_id, initialBalance=0.0)

    transaction = BillingTransaction(
        accountId=account["id"],
        transactionType=TransactionTypeEnum.CREDIT,
        amount=amount_chf,
        description="Stripe-Zahlung",
        referenceType=ReferenceTypeEnum.PAYMENT,
        referenceId=session_id,
        createdByUserId=user_id,
    )
    billingInterface.createTransaction(transaction)

    if eventId and not billingInterface.getStripeWebhookEventByEventId(eventId):
        billingInterface.createStripeWebhookEvent(eventId)

    logger.info(f"Stripe credit applied: {amount_chf} CHF for session {session_id} on mandate {mandate_id}")
    return CheckoutConfirmResponse(
        credited=True,
        alreadyCredited=False,
        sessionId=session_id,
        mandateId=mandate_id,
        amountChf=amount_chf,
    )
|
||
|
||
|
||
# =============================================================================
|
||
# Router Setup
|
||
# =============================================================================
|
||
|
||
# All billing endpoints live under /api/billing and share the "Billing" tag
# and a generic 404 response in the OpenAPI docs.
router = APIRouter(
    tags=["Billing"],
    prefix="/api/billing",
    responses={404: {"description": "Not found"}},
)
|
||
|
||
# =============================================================================
|
||
# User Endpoints
|
||
# =============================================================================
|
||
|
||
@router.get("/balance", response_model=List[BillingBalanceResponse])
@limiter.limit("60/minute")
def getBalance(
    request: Request,
    ctx: RequestContext = Depends(getRequestContext)
):
    """
    Get billing balances for all mandates the current user belongs to.

    Returns one balance entry per mandate. HTTP errors raised by the billing
    service are passed through unchanged; any other error becomes a 500.
    """
    try:
        billingService = getBillingService(ctx.user, ctx.mandateId, featureCode="billing")
        return billingService.getBalancesForUser()
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. 403/404) instead of masking them as 500.
        raise
    except Exception as e:
        logger.error(f"Error getting billing balance: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.get("/balance/{targetMandateId}", response_model=BillingBalanceResponse)
@limiter.limit("60/minute")
def getBalanceForMandate(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    ctx: RequestContext = Depends(getRequestContext)
):
    """
    Get billing balance for a specific mandate.

    Access: platform admins, or users with an enabled membership in the
    mandate; everyone else gets 403.
    """
    # Enforce the module's RBAC rule ("Regular User -> ONLY their own data
    # within their mandates") at the route instead of relying on downstream
    # services.
    if not ctx.isPlatformAdmin and not _isMemberOfMandate(ctx, targetMandateId):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=routeApiMsg("Not a member of this mandate"))

    try:
        billingService = getBillingService(ctx.user, targetMandateId, featureCode="billing")

        # A zero-cost check yields the current balance.
        checkResult = billingService.checkBalance(0.0)

        from modules.interfaces.interfaceDbApp import getInterface as getAppInterface
        appInterface = getAppInterface(ctx.user, mandateId=targetMandateId)
        mandate = appInterface.getMandate(targetMandateId)

        # getMandate's result is read both as a dict and via attributes in
        # this module; accept either shape and never return None as a name.
        if not mandate:
            mandateName = ""
        elif isinstance(mandate, dict):
            mandateName = mandate.get("label") or mandate.get("name") or ""
        else:
            mandateName = getattr(mandate, "label", None) or getattr(mandate, "name", None) or ""

        return BillingBalanceResponse(
            mandateId=targetMandateId,
            mandateName=mandateName,
            balance=checkResult.currentBalance or 0.0,
            warningThreshold=0.0,  # TODO: Get from account
            isWarning=False,
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting billing balance for mandate {targetMandateId}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.get("/transactions", response_model=List[TransactionResponse])
@limiter.limit("30/minute")
def getTransactions(
    request: Request,
    limit: int = Query(default=50, ge=1, le=500),
    offset: int = Query(default=0, ge=0),
    ctx: RequestContext = Depends(getRequestContext)
):
    """
    Get transaction history across all mandates the user belongs to.

    Pagination is done in memory: ``offset + limit`` records are fetched
    from the billing service and the requested window is sliced out.
    """
    try:
        billingService = getBillingService(ctx.user, ctx.mandateId, featureCode="billing")

        # Fetch enough transactions to cover the requested page.
        transactions = billingService.getTransactionHistory(limit=offset + limit)

        return [
            TransactionResponse(
                id=t.get("id"),
                accountId=t.get("accountId"),
                transactionType=TransactionTypeEnum(t.get("transactionType", "DEBIT")),
                amount=t.get("amount", 0.0),
                description=t.get("description", ""),
                referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None,
                workflowId=t.get("workflowId"),
                featureCode=t.get("featureCode"),
                featureInstanceId=t.get("featureInstanceId"),
                aicoreProvider=t.get("aicoreProvider"),
                aicoreModel=t.get("aicoreModel"),
                createdByUserId=t.get("createdByUserId"),
                sysCreatedAt=t.get("sysCreatedAt"),
                mandateId=t.get("mandateId"),
                mandateName=t.get("mandateName"),
            )
            for t in transactions[offset:offset + limit]
        ]

    except HTTPException:
        # Keep deliberate HTTP errors instead of masking them as 500.
        raise
    except Exception as e:
        logger.error(f"Error getting billing transactions: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.get("/statistics", response_model=UsageReportResponse)
@limiter.limit("30/minute")
def getStatistics(
    request: Request,
    dateFrom: str = Query(..., description="ISO YYYY-MM-DD (inclusive)"),
    dateTo: str = Query(..., description="ISO YYYY-MM-DD (inclusive)"),
    bucketSize: str = Query(..., pattern="^(day|month|year)$",
                            description="Time-bucket granularity: day, month, or year"),
    ctx: RequestContext = Depends(getRequestContext)
):
    """
    Get usage statistics for an explicit date range.

    `dateFrom`/`dateTo` are inclusive day boundaries. `bucketSize` is
    validated and echoed back in the response; this handler itself does not
    use it for aggregation.

    Statistics are computed from the caller's own user account in the
    current mandate. An empty report is returned when the mandate has no
    billing settings or the user has no account. Unparseable dates yield 400.
    """
    from datetime import timedelta
    from modules.shared.dateRange import parseIsoDateRange

    try:
        try:
            startDate, toDateInclusive = parseIsoDateRange(dateFrom, dateTo)
        except ValueError as e:
            # Bad client-supplied dates are a 400, not a server error
            # (assumes parseIsoDateRange signals them with ValueError).
            raise HTTPException(status_code=400, detail=str(e))
        # `calculateStatisticsFromTransactions` expects a half-open
        # [startDate, endDate) interval, so widen the upper bound by one day.
        endDate = toDateInclusive + timedelta(days=1)

        billingInterface = getBillingInterface(ctx.user, ctx.mandateId)
        settings = billingInterface.getSettings(ctx.mandateId)

        emptyResponse = UsageReportResponse(
            dateFrom=dateFrom,
            dateTo=dateTo,
            bucketSize=bucketSize,
            totalCost=0.0,
            transactionCount=0,
            costByProvider={},
            costByFeature={},
        )
        if not settings:
            return emptyResponse

        # Transactions are always on user accounts (audit trail)
        account = billingInterface.getUserAccount(ctx.mandateId, ctx.user.id)
        if not account:
            return emptyResponse

        stats = billingInterface.calculateStatisticsFromTransactions(
            account["id"],
            startDate,
            endDate,
        )

        return UsageReportResponse(
            dateFrom=dateFrom,
            dateTo=dateTo,
            bucketSize=bucketSize,
            totalCost=stats.get("totalCostCHF", 0.0),
            transactionCount=stats.get("transactionCount", 0),
            costByProvider=stats.get("costByProvider", {}),
            costByModel=stats.get("costByModel", {}),
            costByFeature=stats.get("costByFeature", {}),
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting billing statistics: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.get("/providers", response_model=List[str])
@limiter.limit("60/minute")
def getAllowedProviders(
    request: Request,
    ctx: RequestContext = Depends(getRequestContext)
):
    """
    Get list of AICore providers the current user is allowed to use.

    HTTP errors raised by the billing service pass through unchanged; any
    other error becomes a 500.
    """
    try:
        billingService = getBillingService(ctx.user, ctx.mandateId, featureCode="billing")
        # NOTE(review): lower-case "allowed" differs from the camelCase used
        # elsewhere; kept as-is because the service's method name is not
        # visible here — confirm.
        return billingService.getallowedProviders()
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting allowed providers: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
# =============================================================================
|
||
# Admin Endpoints
|
||
# =============================================================================
|
||
|
||
@router.get("/admin/settings/{targetMandateId}", response_model=Dict[str, Any])
@limiter.limit("30/minute")
def getSettingsAdmin(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    ctx: RequestContext = Depends(getRequestContext),
):
    """
    Return the billing settings record of a mandate.

    Access: platform admin (any mandate) or mandate admin (own mandate).
    Responds 403 without admin rights, 404 when no settings exist, and 500
    on unexpected lookup errors.
    """
    if not _isAdminOfMandate(ctx, targetMandateId):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=routeApiMsg("Admin role required for this mandate"))

    try:
        billingInterface = getBillingInterface(ctx.user, targetMandateId)
        settings = billingInterface.getSettings(targetMandateId)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting billing settings: {e}")
        raise HTTPException(status_code=500, detail=str(e))

    if settings:
        return settings
    raise HTTPException(status_code=404, detail=routeApiMsg("Billing settings not found"))
|
||
|
||
|
||
@router.post("/admin/settings/{targetMandateId}", response_model=Dict[str, Any])
@limiter.limit("10/minute")
def createOrUpdateSettings(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    settingsUpdate: BillingSettingsUpdate = Body(...),
    ctx: RequestContext = Depends(getRequestContext),
):
    """
    Create or partially update the billing settings of a mandate.

    Access: platform admin (any mandate) or mandate admin (own mandate).

    If settings exist, only the fields sent (non-None) are updated. If
    nothing was sent, the existing record is returned unchanged. Otherwise a
    new record is created from built-in defaults overlaid with the sent
    fields.
    """
    if not _isAdminOfMandate(ctx, targetMandateId):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=routeApiMsg("Admin role required for this mandate"))

    try:
        billingInterface = getBillingInterface(ctx.user, targetMandateId)
        existingSettings = billingInterface.getSettings(targetMandateId)
        # Only the fields the client actually provided (None = not provided).
        providedFields = settingsUpdate.model_dump(exclude_none=True)

        if not existingSettings:
            initialValues = {
                "warningThresholdPercent": 10.0,
                "notifyOnWarning": True,
                "notifyEmails": [],
                "autoRechargeEnabled": False,
                "rechargeAmountCHF": 10.0,
                "rechargeMaxPerMonth": 3,
            }
            initialValues.update(providedFields)
            return billingInterface.createSettings(
                BillingSettings(mandateId=targetMandateId, **initialValues)
            )

        if not providedFields:
            return existingSettings
        updated = billingInterface.updateSettings(existingSettings["id"], providedFields)
        return updated or existingSettings

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error updating billing settings: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.post("/admin/credit/{targetMandateId}", response_model=Dict[str, Any])
@limiter.limit("10/minute")
def addCredit(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    creditRequest: CreditAddRequest = Body(...),
    ctx: RequestContext = Depends(getRequestContext),
    _admin = Depends(requirePlatformAdmin)
):
    """
    Add credit to a billing account (SysAdmin only).

    A positive ``amount`` books a CREDIT, a negative one a DEBIT of
    ``abs(amount)`` on the mandate account. The amount is validated before
    the mandate account is fetched/created, so a rejected request never
    creates an account.

    Raises:
        HTTPException: 404 if the mandate has no billing settings, 400 for a
            zero or non-finite amount, 500 on unexpected errors.
    """
    try:
        billingInterface = getBillingInterface(ctx.user, targetMandateId)
        settings = billingInterface.getSettings(targetMandateId)
        if not settings:
            raise HTTPException(status_code=404, detail=routeApiMsg("Billing settings not found for this mandate"))

        amount = creditRequest.amount
        # JSON parsing can yield NaN/Infinity; never book those.
        if not math.isfinite(amount):
            raise HTTPException(status_code=400, detail=routeApiMsg("Amount must be a finite number"))
        if amount == 0:
            raise HTTPException(status_code=400, detail=routeApiMsg("Amount must not be zero"))

        account = billingInterface.getOrCreateMandateAccount(targetMandateId, initialBalance=0.0)

        isDeduction = amount < 0
        txType = TransactionTypeEnum.DEBIT if isDeduction else TransactionTypeEnum.CREDIT
        absAmount = abs(amount)

        transaction = BillingTransaction(
            accountId=account["id"],
            transactionType=txType,
            amount=absAmount,
            description=creditRequest.description,
            referenceType=ReferenceTypeEnum.ADMIN,
            # Keep the optional audit user ID instead of dropping it, mirroring
            # how the Stripe credit path fills createdByUserId.
            createdByUserId=creditRequest.userId,
        )
        result = billingInterface.createTransaction(transaction)

        if isDeduction:
            logger.info(f"Deducted {absAmount} CHF from account {account['id']} in mandate {targetMandateId}")
        else:
            logger.info(f"Added {absAmount} CHF to account {account['id']} in mandate {targetMandateId}")
        return result

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error adding credit: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.get("/checkout/amounts", response_model=List[float])
@limiter.limit("60/minute")
def getCheckoutAmounts(
    request: Request,
    ctx: RequestContext = Depends(getRequestContext),
):
    """
    List the CHF top-up amounts accepted for Stripe Checkout.

    This is the server-side allow-list; the frontend should build its
    dropdown from it, because `create_checkout_session` rejects any other
    value. ``ctx`` is only required so the caller must be authenticated.
    """
    from modules.serviceCenter.services.serviceBilling.stripeCheckout import ALLOWED_AMOUNTS_CHF as allowedAmounts
    return list(map(float, allowedAmounts))
|
||
|
||
|
||
@router.post("/checkout/create/{targetMandateId}", response_model=CheckoutCreateResponse)
@limiter.limit("10/minute")
def createCheckoutSession(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    checkoutRequest: CheckoutCreateRequest = Body(...),
    ctx: RequestContext = Depends(getRequestContext),
):
    """
    Create Stripe Checkout Session for credit top-up. Returns redirect URL.

    Requires mandate admin role (platform admins qualify for any mandate).
    The admin check runs before any billing data is read, so non-admins get
    403 without learning whether the mandate has billing settings.

    Raises:
        HTTPException: 403 without admin rights, 404 when the mandate has no
            billing settings, 400 for a ``ValueError`` raised during session
            creation, 500 otherwise.
    """
    try:
        if not _isAdminOfMandate(ctx, targetMandateId):
            raise HTTPException(status_code=403, detail=routeApiMsg("Mandate admin role required to load mandate credit"))

        billingInterface = getBillingInterface(ctx.user, targetMandateId)
        settings = billingInterface.getSettings(targetMandateId)
        if not settings:
            raise HTTPException(status_code=404, detail=routeApiMsg("Billing settings not found for this mandate"))

        from modules.interfaces.interfaceDbApp import getInterface as getAppInterface
        appInterface = getAppInterface(ctx.user, mandateId=targetMandateId)
        mandateRecord = appInterface.getMandate(targetMandateId)

        if mandateRecord is not None:
            # getMandate's result is read both as a dict and via attributes in
            # this module; accept either so label/invoice data is never lost.
            def readField(fieldName):
                if isinstance(mandateRecord, dict):
                    return mandateRecord.get(fieldName)
                return getattr(mandateRecord, fieldName, None)

            mandateLabel = readField("label") or readField("name") or targetMandateId
            invoiceAddress = {
                "companyName": readField("invoiceCompanyName"),
                "contactName": readField("invoiceContactName"),
                "email": readField("invoiceEmail"),
                "line1": readField("invoiceLine1"),
                "line2": readField("invoiceLine2"),
                "postalCode": readField("invoicePostalCode"),
                "city": readField("invoiceCity"),
                "state": readField("invoiceState"),
                "country": readField("invoiceCountry") or "CH",
                "vatNumber": readField("invoiceVatNumber"),
            }
        else:
            mandateLabel = targetMandateId
            invoiceAddress = None

        from modules.serviceCenter.services.serviceBilling.stripeCheckout import create_checkout_session
        redirect_url = create_checkout_session(
            mandate_id=targetMandateId,
            user_id=checkoutRequest.userId,
            amount_chf=checkoutRequest.amount,
            return_url=checkoutRequest.returnUrl,
            mandate_label=mandateLabel,
            invoice_address=invoiceAddress,
            settings=settings,
            billing_interface=billingInterface,
        )
        return CheckoutCreateResponse(redirectUrl=redirect_url)

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error creating checkout session: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.post("/checkout/confirm", response_model=CheckoutConfirmResponse)
@limiter.limit("20/minute")
def confirmCheckoutSession(
    request: Request,
    confirmRequest: CheckoutConfirmRequest = Body(...),
    ctx: RequestContext = Depends(getRequestContext),
):
    """
    Confirm Stripe Checkout success by session ID and apply credit idempotently.

    This is a fallback/reconciliation path in addition to webhook processing.
    The caller must be admin of the mandate named in the session metadata.
    This is checked before the payment status or billing settings are
    revealed, so other users cannot probe foreign sessions.

    Raises:
        HTTPException: 404 if the session or billing settings are not found,
            400 for missing mandate metadata or a ``ValueError``, 403 without
            admin rights, 409 if the payment is not completed, 500 otherwise.
    """
    try:
        stripe = _getStripeClient()
        session = stripe.checkout.Session.retrieve(confirmRequest.sessionId)
        if not session:
            raise HTTPException(status_code=404, detail=routeApiMsg("Stripe Checkout Session not found"))

        from modules.shared.stripeClient import stripeToDict
        session_dict = stripeToDict(session)
        metadata = session_dict.get("metadata") or {}
        mandate_id = metadata.get("mandateId")
        if not mandate_id:
            raise HTTPException(status_code=400, detail=routeApiMsg("Invalid session metadata: mandateId missing"))

        # Authorize first: everything below discloses session/mandate details.
        if not _isAdminOfMandate(ctx, mandate_id):
            raise HTTPException(status_code=403, detail=routeApiMsg("Mandate admin role required"))

        payment_status = session_dict.get("payment_status")
        if payment_status != "paid":
            raise HTTPException(status_code=409, detail=f"Payment not completed yet (payment_status={payment_status})")

        billingInterface = getBillingInterface(ctx.user, mandate_id)
        settings = billingInterface.getSettings(mandate_id)
        if not settings:
            raise HTTPException(status_code=404, detail=routeApiMsg("Billing settings not found"))

        # Crediting uses the privileged billing interface, as the webhook
        # path does not have a user context either.
        root_billing_interface = getRootInterface()
        return _creditStripeSessionIfNeeded(root_billing_interface, session_dict, eventId=None)

    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Error confirming checkout session {confirmRequest.sessionId}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.post("/webhook/stripe")
async def stripeWebhook(
    request: Request,
    stripe_signature: Optional[str] = Header(None, alias="Stripe-Signature")
):
    """
    Stripe webhook endpoint. Verifies signature and processes checkout.session.completed.

    No JWT auth - Stripe authenticates via Stripe-Signature header.

    Dispatch after signature verification:
    - Subscription lifecycle events -> ``_handleSubscriptionWebhook``.
    - ``checkout.session.completed`` / ``checkout.session.async_payment_succeeded``:
      mode=subscription -> ``handleSubscriptionCheckoutCompleted``; otherwise the
      session is credited via ``_creditStripeSessionIfNeeded`` unless this event
      ID has already been recorded.
    - Any other event type is acknowledged and ignored.

    Returning ``{"received": True}`` acknowledges the event. Raising makes
    Stripe retry, except that 4xx failures while crediting are recorded as
    processed so Stripe stops retrying a permanently failing event.
    """
    from modules.shared.configuration import APP_CONFIG

    webhook_secret = APP_CONFIG.get("STRIPE_WEBHOOK_SECRET")
    if not webhook_secret:
        logger.error("STRIPE_WEBHOOK_SECRET not configured")
        raise HTTPException(status_code=500, detail=routeApiMsg("Webhook not configured"))

    if not stripe_signature:
        raise HTTPException(status_code=400, detail=routeApiMsg("Missing Stripe-Signature header"))

    payload = await request.body()

    # Imported outside the try below so a missing/broken SDK surfaces as a
    # server error instead of being misreported as an invalid signature.
    import stripe

    try:
        event = stripe.Webhook.construct_event(
            payload, stripe_signature, webhook_secret
        )
    except ValueError as e:
        logger.warning(f"Stripe webhook invalid payload: {e}")
        raise HTTPException(status_code=400, detail=routeApiMsg("Invalid payload"))
    except Exception as e:
        logger.warning(f"Stripe webhook signature verification failed: {e}")
        raise HTTPException(status_code=400, detail=routeApiMsg("Invalid signature"))

    logger.info(f"Stripe webhook received: event={event.id}, type={event.type}")

    # Subscription-related events
    subscriptionEventTypes = {
        "customer.subscription.updated",
        "customer.subscription.deleted",
        "invoice.paid",
        "invoice.payment_failed",
        "customer.subscription.trial_will_end",
    }

    # Checkout events (one-off credit top-ups and subscription checkouts)
    checkoutEventTypes = {"checkout.session.completed", "checkout.session.async_payment_succeeded"}

    # NOTE(review): the handlers below are synchronous and make blocking
    # Stripe/DB calls inside this async endpoint; consider offloading them to
    # a threadpool if webhook volume grows.
    if event.type in subscriptionEventTypes:
        _handleSubscriptionWebhook(event)
        return {"received": True}

    if event.type not in checkoutEventTypes:
        return {"received": True}

    session = event.data.object
    event_id = event.id

    sessionMode = session.get("mode") if hasattr(session, "get") else getattr(session, "mode", None)
    if sessionMode == "subscription":
        handleSubscriptionCheckoutCompleted(session, event_id)
        return {"received": True}

    billingInterface = getRootInterface()
    if billingInterface.getStripeWebhookEventByEventId(event_id):
        logger.info(f"Stripe event {event_id} already processed, skipping")
        return {"received": True}

    # Use the same converter as the confirm endpoint and the subscription
    # handler, so _creditStripeSessionIfNeeded always receives the same shape
    # (rather than depending on to_dict_recursive / a shallow dict() copy).
    from modules.shared.stripeClient import stripeToDict

    session_dict = stripeToDict(session)
    try:
        result = _creditStripeSessionIfNeeded(billingInterface, session_dict, eventId=event_id)
        logger.info(
            f"Stripe webhook processed session {result.sessionId}: "
            f"credited={result.credited}, alreadyCredited={result.alreadyCredited}"
        )
    except HTTPException as he:
        logger.error(
            "Stripe webhook %s for session %s failed: status=%s detail=%s metadata=%s amount_total=%s",
            event_id,
            session_dict.get("id"),
            he.status_code,
            he.detail,
            session_dict.get("metadata"),
            session_dict.get("amount_total"),
        )
        # Client errors will not fix themselves on retry: record the event so
        # Stripe stops redelivering it, then acknowledge.
        if 400 <= he.status_code < 500 and event_id:
            if not billingInterface.getStripeWebhookEventByEventId(event_id):
                try:
                    billingInterface.createStripeWebhookEvent(event_id)
                    logger.warning(
                        "Marked Stripe event %s as processed (permanent 4xx) to stop retries",
                        event_id,
                    )
                except Exception as markEx:
                    logger.error("Failed to mark Stripe event %s as processed: %s", event_id, markEx)
            return {"received": True}
        raise

    return {"received": True}
|
||
|
||
|
||
def handleSubscriptionCheckoutCompleted(session, eventId: str) -> None:
    """Handle checkout.session.completed for mode=subscription.

    Resolves the local PENDING record by ID from webhook metadata and transitions it.

    Steps:
    1. Read ``subscriptionRecordId`` / ``mandateId`` / ``planKey`` / ``platformUrl``
       from the session metadata. If the record ID is missing, recover the
       metadata from the Stripe Subscription object (a lookup failure re-raises).
    2. Skip unless the local record exists and is PENDING.
    3. Copy the Stripe subscription ID, period bounds and item IDs onto the
       record (a lookup failure re-raises so Stripe retries the webhook).
    4. Pick the target status:
       - ACTIVE when there is no other operative subscription, or the
         operative one is a trial (it is force-expired first);
       - SCHEDULED behind a non-trial operative subscription, which is set to
         cancel at period end and whose ``currentPeriodEnd`` becomes this
         record's ``effectiveFrom``.
    5. Transition PENDING -> target. For ACTIVE, notify and credit the plan budget.

    Args:
        session: Checkout Session as a dict or Stripe object.
        eventId: Stripe event ID (not used by this handler).

    Raises:
        Exception: re-raised Stripe lookup failures (so Stripe retries delivery).
    """
    from modules.interfaces.interfaceDbSubscription import getRootInterface as getSubRootInterface
    from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum, getPlan
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
        _notifySubscriptionChange,
    )
    from modules.security.rootAccess import getRootUser
    from modules.shared.stripeClient import getStripeClient, stripeToDict

    if not isinstance(session, dict):
        session = stripeToDict(session)

    metadata = session.get("metadata") or {}
    subscriptionRecordId = metadata.get("subscriptionRecordId")
    mandateId = metadata.get("mandateId")
    planKey = metadata.get("planKey", "")
    platformUrl = metadata.get("platformUrl", "")

    stripeSubId = session.get("subscription")

    if not subscriptionRecordId and stripeSubId:
        try:
            stripe = getStripeClient()
            subObj = stripeToDict(stripe.Subscription.retrieve(stripeSubId))
        except Exception as e:
            # Stripe lookup is the only way to recover the metadata at this
            # point — if it fails we MUST surface it, otherwise the webhook
            # later short-circuits with "missing metadata" and the user
            # silently gets stuck in PENDING.
            logger.error(
                "Stripe Subscription.retrieve(%s) failed during checkout "
                "metadata recovery: %s", stripeSubId, e,
            )
            raise
        metadata = subObj.get("metadata") or {}
        subscriptionRecordId = metadata.get("subscriptionRecordId")
        # Prefer the Subscription's values, but keep what the session already
        # provided instead of overwriting it with None / "".
        mandateId = metadata.get("mandateId") or mandateId
        planKey = metadata.get("planKey") or planKey
        platformUrl = platformUrl or metadata.get("platformUrl", "")

    if not mandateId or not subscriptionRecordId:
        logger.warning("Subscription checkout missing metadata: %s", metadata)
        return

    subInterface = getSubRootInterface()
    rootUser = getRootUser()

    sub = subInterface.getById(subscriptionRecordId)
    if not sub:
        logger.error("Subscription record %s not found for checkout webhook", subscriptionRecordId)
        return
    if sub.get("status") != SubscriptionStatusEnum.PENDING.value:
        logger.warning("Subscription %s is %s, expected PENDING — skipping", subscriptionRecordId, sub.get("status"))
        return

    stripeData: Dict[str, Any] = {}
    if stripeSubId:
        stripeData["stripeSubscriptionId"] = stripeSubId
        try:
            from modules.serviceCenter.services.serviceSubscription.stripeBootstrap import getStripePricesForPlan

            stripe = getStripeClient()
            stripeSubObj = stripeToDict(stripe.Subscription.retrieve(stripeSubId, expand=["items"]))

            if stripeSubObj.get("current_period_start"):
                stripeData["currentPeriodStart"] = float(stripeSubObj["current_period_start"])
            if stripeSubObj.get("current_period_end"):
                stripeData["currentPeriodEnd"] = float(stripeSubObj["current_period_end"])

            # Map Stripe subscription items back to the plan's per-user /
            # per-instance prices so later add-on billing can update them.
            priceMapping = getStripePricesForPlan(planKey)
            items = stripeSubObj.get("items") or {}
            if not isinstance(items, dict):
                items = dict(items)
            for item in items.get("data", []):
                priceId = (item.get("price") or {}).get("id", "")
                if priceMapping and priceId == priceMapping.stripePriceIdUsers:
                    stripeData["stripeItemIdUsers"] = item["id"]
                elif priceMapping and priceId == priceMapping.stripePriceIdInstances:
                    stripeData["stripeItemIdInstances"] = item["id"]
        except Exception as e:
            # Without these enrichment fields the activation completes anyway
            # (status flips to ACTIVE/SCHEDULED below), but periods + Stripe
            # item-IDs are missing on the local record, which breaks later
            # add-on billing and renewal accounting. Re-raise so the webhook
            # is retried by Stripe instead of silently shipping a broken row.
            logger.error(
                "Error retrieving Stripe subscription %s during checkout "
                "completion (will be retried by Stripe): %s",
                stripeSubId, e,
            )
            raise

    if stripeData:
        subInterface.updateFields(subscriptionRecordId, stripeData)

    operative = subInterface.getOperativeForMandate(mandateId)
    hasActivePredecessor = operative is not None and operative["id"] != subscriptionRecordId
    predecessorIsTrial = (
        hasActivePredecessor
        and operative.get("status") == SubscriptionStatusEnum.TRIALING.value
    )

    if hasActivePredecessor and predecessorIsTrial:
        # A paid upgrade replaces a trial immediately.
        try:
            subInterface.forceExpire(operative["id"])
            logger.info(
                "Trial subscription %s expired immediately for mandate %s due to paid upgrade %s",
                operative["id"], mandateId, subscriptionRecordId,
            )
        except Exception as e:
            logger.error("Failed to expire trial predecessor %s: %s", operative["id"], e)
        toStatus = SubscriptionStatusEnum.ACTIVE
    elif hasActivePredecessor:
        # A paid predecessor runs until its period end; the new plan starts then.
        toStatus = SubscriptionStatusEnum.SCHEDULED
        if operative.get("recurring", True):
            operativeStripeId = operative.get("stripeSubscriptionId")
            if operativeStripeId:
                try:
                    getStripeClient().Subscription.modify(operativeStripeId, cancel_at_period_end=True)
                except Exception as e:
                    # NOTE(review): the local record is still marked
                    # non-recurring below even if Stripe was not updated.
                    logger.error("Failed to set cancel_at_period_end on predecessor %s: %s", operativeStripeId, e)
            subInterface.updateFields(operative["id"], {"recurring": False})
        effectiveFrom = operative.get("currentPeriodEnd")
        if effectiveFrom:
            subInterface.updateFields(subscriptionRecordId, {"effectiveFrom": effectiveFrom})
    else:
        toStatus = SubscriptionStatusEnum.ACTIVE

    try:
        subInterface.transitionStatus(
            subscriptionRecordId, SubscriptionStatusEnum.PENDING, toStatus,
            {"recurring": True},
        )
    except Exception as e:
        logger.error("Failed to transition subscription %s: %s", subscriptionRecordId, e)
        return

    subService = getSubscriptionService(rootUser, mandateId)
    subService.invalidateCache(mandateId)

    if toStatus == SubscriptionStatusEnum.ACTIVE:
        plan = getPlan(planKey)
        updatedSub = subInterface.getById(subscriptionRecordId)
        _notifySubscriptionChange(mandateId, "activated", plan, subscriptionRecord=updatedSub, platformUrl=platformUrl)

        # Budget credit is best-effort: activation has already happened.
        try:
            billingIf = getRootInterface()
            billingIf.creditSubscriptionBudget(mandateId, planKey, periodLabel="Erstaktivierung")
        except Exception as ex:
            logger.error("creditSubscriptionBudget on activation failed: %s", ex)

    logger.info(
        "Checkout completed: sub=%s -> %s, mandate=%s, plan=%s",
        subscriptionRecordId, toStatus.value, mandateId, planKey,
    )
|
||
|
||
|
||
def _handleSubscriptionWebhook(event) -> None:
    """Process Stripe subscription webhook events.

    All record resolution is by stripeSubscriptionId — no mandate-based guessing.

    Handled event types:
    - customer.subscription.updated: sync period bounds. Then either
      SCHEDULED -> ACTIVE (notify + credit), PAST_DUE <-> ACTIVE, or a cache
      refresh on renewal.
    - customer.subscription.deleted: ACTIVE/PAST_DUE/SCHEDULED -> EXPIRED, then
      promote the mandate's SCHEDULED record (if any) to ACTIVE.
    - invoice.payment_failed: ACTIVE -> PAST_DUE and notify.
    - customer.subscription.trial_will_end: e-mail the mandate admins.
    - invoice.paid: reset/reconcile storage billing, then credit the plan budget.

    Events for Stripe subscriptions without a local record are logged and ignored.
    """
    from modules.interfaces.interfaceDbSubscription import getRootInterface as getSubRootInterface
    from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum, getPlan
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
        _notifySubscriptionChange,
    )
    from modules.security.rootAccess import getRootUser
    from datetime import datetime, timezone

    obj = event.data.object
    # customer.subscription.* events carry the subscription itself; invoice
    # events reference it (as an ID or an expanded object).
    rawSub = obj.get("id") if event.type.startswith("customer.subscription") else obj.get("subscription")
    stripeSubId = rawSub.get("id") if isinstance(rawSub, dict) else rawSub
    if not stripeSubId:
        logger.warning("Subscription webhook %s has no subscription ID", event.type)
        return

    subInterface = getSubRootInterface()
    sub = subInterface.getByStripeSubscriptionId(stripeSubId)
    if not sub:
        logger.warning("No local record for Stripe subscription %s (event: %s)", stripeSubId, event.type)
        return

    subId = sub["id"]
    mandateId = sub["mandateId"]
    currentStatus = SubscriptionStatusEnum(sub["status"])
    rootUser = getRootUser()
    subService = getSubscriptionService(rootUser, mandateId)

    # platformUrl from the event object's metadata ("" when absent) is passed to notifications.
    subMetadata = obj.get("metadata") or {}
    webhookPlatformUrl = subMetadata.get("platformUrl", "")

    if event.type == "customer.subscription.updated":
        stripeStatus = obj.get("status", "")

        periodData: Dict[str, Any] = {}
        if obj.get("current_period_start"):
            periodData["currentPeriodStart"] = float(obj["current_period_start"])
        if obj.get("current_period_end"):
            periodData["currentPeriodEnd"] = float(obj["current_period_end"])
        if periodData:
            subInterface.updateFields(subId, periodData)

        if stripeStatus == "active" and currentStatus == SubscriptionStatusEnum.SCHEDULED:
            subInterface.transitionStatus(subId, SubscriptionStatusEnum.SCHEDULED, SubscriptionStatusEnum.ACTIVE)
            subService.invalidateCache(mandateId)
            planKey = sub.get("planKey", "")
            plan = getPlan(planKey)
            refreshedSub = subInterface.getById(subId)
            _notifySubscriptionChange(mandateId, "activated", plan, subscriptionRecord=refreshedSub, platformUrl=webhookPlatformUrl)
            try:
                getRootInterface().creditSubscriptionBudget(mandateId, planKey, periodLabel="Erstaktivierung")
            except Exception as ex:
                logger.error("creditSubscriptionBudget SCHEDULED->ACTIVE failed: %s", ex)
            logger.info("SCHEDULED -> ACTIVE for sub %s (mandate %s)", subId, mandateId)

        elif stripeStatus == "active" and currentStatus == SubscriptionStatusEnum.PAST_DUE:
            subInterface.transitionStatus(subId, SubscriptionStatusEnum.PAST_DUE, SubscriptionStatusEnum.ACTIVE)
            subService.invalidateCache(mandateId)
            logger.info("PAST_DUE -> ACTIVE for sub %s (mandate %s)", subId, mandateId)

        elif stripeStatus == "past_due" and currentStatus == SubscriptionStatusEnum.ACTIVE:
            subInterface.transitionStatus(subId, SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.PAST_DUE)
            subService.invalidateCache(mandateId)
            logger.info("ACTIVE -> PAST_DUE for sub %s (mandate %s)", subId, mandateId)

        elif stripeStatus == "active" and currentStatus == SubscriptionStatusEnum.ACTIVE:
            subService.invalidateCache(mandateId)
            logger.info("Period renewed for sub %s (mandate %s)", subId, mandateId)

    elif event.type == "customer.subscription.deleted":
        if currentStatus not in (SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.PAST_DUE,
                                 SubscriptionStatusEnum.SCHEDULED):
            logger.info("Ignoring deletion for sub %s in status %s", subId, currentStatus.value)
            return

        subInterface.transitionStatus(subId, currentStatus, SubscriptionStatusEnum.EXPIRED)
        subService.invalidateCache(mandateId)
        logger.info("Sub %s -> EXPIRED (Stripe deleted, mandate %s)", subId, mandateId)

        # NOTE(review): unlike the SCHEDULED -> ACTIVE path in the "updated"
        # branch, this promotion does not call creditSubscriptionBudget —
        # confirm that is intended.
        scheduled = subInterface.getScheduledForMandate(mandateId)
        if scheduled:
            try:
                subInterface.transitionStatus(
                    scheduled["id"], SubscriptionStatusEnum.SCHEDULED, SubscriptionStatusEnum.ACTIVE,
                )
                subService.invalidateCache(mandateId)
                plan = getPlan(scheduled.get("planKey", ""))
                refreshedScheduled = subInterface.getById(scheduled["id"])
                _notifySubscriptionChange(mandateId, "activated", plan, subscriptionRecord=refreshedScheduled, platformUrl=webhookPlatformUrl)
                logger.info("Promoted SCHEDULED sub %s -> ACTIVE (mandate %s)", scheduled["id"], mandateId)
            except Exception as e:
                logger.error("Failed to promote SCHEDULED sub %s: %s", scheduled["id"], e)

    elif event.type == "invoice.payment_failed":
        if currentStatus == SubscriptionStatusEnum.ACTIVE:
            subInterface.transitionStatus(subId, SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.PAST_DUE)
            subService.invalidateCache(mandateId)
            plan = getPlan(sub.get("planKey", ""))
            # Notify with the post-transition record (as the other branches do),
            # not the stale pre-transition snapshot; fall back if it vanished.
            refreshedSub = subInterface.getById(subId) or sub
            _notifySubscriptionChange(mandateId, "payment_failed", plan, subscriptionRecord=refreshedSub, platformUrl=webhookPlatformUrl)
        logger.info("Payment failed for sub %s (mandate %s)", subId, mandateId)

    elif event.type == "customer.subscription.trial_will_end":
        logger.info("Trial ending soon for sub %s (mandate %s)", subId, mandateId)
        try:
            from modules.shared.notifyMandateAdmins import notifyMandateAdmins
            notifyMandateAdmins(
                mandateId,
                "[PowerOn] Testphase endet bald",
                "Testphase endet bald",
                [
                    "Die kostenlose Testphase für Ihren Mandanten endet in Kürze.",
                    "Bitte wählen Sie einen Plan unter Billing-Verwaltung › Abonnement.",
                ],
            )
        except Exception as e:
            logger.error("Failed to notify about trial ending: %s", e)

    elif event.type == "invoice.paid":
        # NOTE(review): this branch does not look at currentStatus, and Stripe
        # also sends invoice.paid for a subscription's first invoice. Whether
        # that double-credits alongside the "Erstaktivierung" credit depends on
        # creditSubscriptionBudget — confirm it is idempotent per period.
        period_ts = obj.get("period_start")
        periodLabel = ""
        if period_ts:
            period_start_at = datetime.fromtimestamp(int(period_ts), tz=timezone.utc)
            periodLabel = period_start_at.strftime("%Y-%m-%d")
            try:
                billing_if = getRootInterface()
                billing_if.resetStorageBillingPeriod(mandateId, period_start_at)
                billing_if.reconcileMandateStorageBilling(mandateId)
            except Exception as ex:
                logger.error("Storage billing on invoice.paid failed: %s", ex)

        planKey = sub.get("planKey", "")
        try:
            billing_if = getRootInterface()
            billing_if.creditSubscriptionBudget(mandateId, planKey, periodLabel=periodLabel or "Periodenverlängerung")
        except Exception as ex:
            logger.error("creditSubscriptionBudget on invoice.paid failed: %s", ex)

        logger.info("Invoice paid for sub %s (mandate %s)", subId, mandateId)
|
||
|
||
|
||
@router.get("/admin/accounts/{targetMandateId}", response_model=List[AccountSummary])
@limiter.limit("30/minute")
def getAccounts(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    ctx: RequestContext = Depends(getRequestContext),
):
    """
    List every billing account of one mandate as ``AccountSummary`` rows.

    Access: SysAdmin (any mandate) or MandateAdmin (own mandate).
    Missing ``balance`` / ``warningThreshold`` fall back to 0.0 and a missing
    ``enabled`` flag to True. Any failure after authorization becomes a 500.
    """
    if not _isAdminOfMandate(ctx, targetMandateId):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=routeApiMsg("Admin role required for this mandate"))

    try:
        accountRows = getBillingInterface(ctx.user, targetMandateId).getAccountsByMandate(targetMandateId)
        return [
            AccountSummary(
                id=row.get("id"),
                mandateId=row.get("mandateId"),
                userId=row.get("userId"),
                balance=row.get("balance", 0.0),
                warningThreshold=row.get("warningThreshold", 0.0),
                enabled=row.get("enabled", True),
            )
            for row in accountRows
        ]
    except Exception as e:
        logger.error(f"Error getting billing accounts: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
class MandateUserSummary(BaseModel):
    """Compact user record returned to billing admins (used to pick users for credit assignment)."""

    # User ID; the only mandatory field.
    id: str
    # Optional profile fields; null when unknown.
    username: Optional[str] = Field(default=None)
    email: Optional[str] = Field(default=None)
    firstName: Optional[str] = Field(default=None)
    lastName: Optional[str] = Field(default=None)
    # Human-readable label for selection lists.
    displayName: Optional[str] = Field(default=None)
|
||
|
||
|
||
@router.get("/admin/users/{targetMandateId}", response_model=List[MandateUserSummary])
@limiter.limit("30/minute")
def getUsersForMandate(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    ctx: RequestContext = Depends(getRequestContext),
):
    """
    Get all users belonging to a mandate.

    Access: SysAdmin (any mandate) or MandateAdmin (own mandate).
    Used by billing admin to select users for credit assignment.

    Membership rows without a ``userId``, and users that cannot be loaded, are
    skipped. Profile fields are normalized to "" when missing or null.
    ``displayName`` is "first last", falling back to the username and then the user ID.

    Raises:
        HTTPException 403: caller is not an admin of the mandate.
        HTTPException 500: any failure while loading memberships or users.
    """
    if not _isAdminOfMandate(ctx, targetMandateId):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=routeApiMsg("Admin role required for this mandate"))

    try:
        from modules.interfaces.interfaceDbApp import getInterface as getAppInterface

        def _readUserField(user, fieldName: str) -> str:
            """Read a field from a dict or model user, mapping missing/None to ""."""
            # dict.get(name, "") would still return None for a key stored as
            # null, which rendered as "None" in displayName.
            raw = user.get(fieldName) if isinstance(user, dict) else getattr(user, fieldName, None)
            return raw or ""

        appInterface = getAppInterface(ctx.user, mandateId=targetMandateId)
        userMandates = appInterface.getUserMandatesByMandate(targetMandateId)

        result: List[MandateUserSummary] = []
        for um in userMandates:
            userId = um.get("userId") if isinstance(um, dict) else getattr(um, "userId", None)
            if not userId:
                continue

            user = appInterface.getUser(userId)
            if not user:
                continue

            username = _readUserField(user, "username")
            firstName = _readUserField(user, "firstName")
            lastName = _readUserField(user, "lastName")
            email = _readUserField(user, "email")
            displayName = f"{firstName} {lastName}".strip() or username or userId

            result.append(MandateUserSummary(
                id=userId,
                username=username,
                email=email,
                firstName=firstName,
                lastName=lastName,
                displayName=displayName,
            ))

        return result

    except Exception as e:
        logger.error(f"Error getting users for mandate {targetMandateId}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
def _attachCreatedByUserNamesToTransactionRows(rows: List[Dict[str, Any]]) -> None:
    """Set ``row["userName"]`` on each row, in place, from its ``createdByUserId``.

    The distinct non-empty creator IDs are resolved in a single call to the
    central ``resolveUserLabels`` FK resolver. Rows with no creator, or whose
    creator the resolver does not return, get ``None`` rather than a truncated
    UUID, so the frontend shows an explicit NA() indicator.
    """
    from modules.routes.routeHelpers import resolveUserLabels

    creatorIds = list({uid for uid in (row.get("createdByUserId") for row in rows) if uid})
    labels: Dict[str, Optional[str]] = resolveUserLabels(creatorIds) if creatorIds else {}

    for row in rows:
        creator = row.get("createdByUserId")
        row["userName"] = labels.get(creator) if creator else None
|
||
|
||
|
||
def _enrichTransactionRows(transactions) -> List[Dict[str, Any]]:
    """Turn raw transaction dicts into ``TransactionResponse`` row dicts with ``userName`` attached.

    A missing ``transactionType`` defaults to DEBIT. An empty
    ``referenceType`` becomes None.
    """
    def _toRow(txn) -> Dict[str, Any]:
        rawReferenceType = txn.get("referenceType")
        response = TransactionResponse(
            id=txn.get("id"),
            accountId=txn.get("accountId"),
            transactionType=TransactionTypeEnum(txn.get("transactionType", "DEBIT")),
            amount=txn.get("amount", 0.0),
            description=txn.get("description", ""),
            referenceType=ReferenceTypeEnum(rawReferenceType) if rawReferenceType else None,
            workflowId=txn.get("workflowId"),
            featureCode=txn.get("featureCode"),
            featureInstanceId=txn.get("featureInstanceId"),
            aicoreProvider=txn.get("aicoreProvider"),
            aicoreModel=txn.get("aicoreModel"),
            createdByUserId=txn.get("createdByUserId"),
            sysCreatedAt=txn.get("sysCreatedAt"),
        )
        return response.model_dump()

    rows = [_toRow(txn) for txn in transactions]
    _attachCreatedByUserNamesToTransactionRows(rows)
    return rows
|
||
|
||
|
||
def _buildTransactionsList(ctx: RequestContext, targetMandateId: str, paginationParams: Optional[PaginationParams] = None) -> tuple:
    """Load one page of a mandate's transactions and enrich the rows.

    Args:
        ctx: request context; its user scopes the billing interface.
        targetMandateId: mandate whose transactions are listed.
        paginationParams: page/sort/filter request. When omitted, the newest
            200 transactions are loaded (page 1, ``sysCreatedAt`` descending).

    Returns:
        ``(rows, paginatedResult)``: the enriched row dicts plus the raw
        interface result (an object with ``items``, or a dict with an
        ``"items"`` key). The raw result is never None.
    """
    billingInterface = getBillingInterface(ctx.user, targetMandateId)

    if not paginationParams:
        paginationParams = PaginationParams(page=1, pageSize=200, sort=[{"field": "sysCreatedAt", "direction": "desc"}])

    paginatedResult = billingInterface.getTransactionsByMandate(targetMandateId, pagination=paginationParams)

    # Check for dict first: every dict has an ``items`` *method*, so a
    # hasattr() test would select the bound method instead of the row list.
    if isinstance(paginatedResult, dict):
        transactions = paginatedResult.get("items") or []
    else:
        transactions = getattr(paginatedResult, "items", None) or []

    return _enrichTransactionRows(transactions), paginatedResult
|
||
|
||
|
||
@router.get("/admin/transactions/{targetMandateId}")
@limiter.limit("30/minute")
def getTransactionsAdmin(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams"),
    ctx: RequestContext = Depends(getRequestContext),
):
    """Get all transactions for a mandate with pagination support.

    ``pagination`` is a JSON object accepted by ``PaginationParams`` (after
    ``normalize_pagination_dict``). The response is
    ``{"items": [...], "pagination": PaginationMetadata | None}``.
    ``pagination`` in the response is None when no pagination was requested;
    in that case the default listing (see ``_buildTransactionsList``) is returned.

    Raises:
        HTTPException 403: caller is not an admin of the mandate.
        HTTPException 400: ``pagination`` is not a valid JSON object / params.
        HTTPException 500: any other failure.
    """
    if not _isAdminOfMandate(ctx, targetMandateId):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=routeApiMsg("Admin role required for this mandate"))

    try:
        paginationParams: Optional[PaginationParams] = None
        if pagination:
            try:
                paginationDict = json.loads(pagination)
                if paginationDict:
                    # Valid JSON that is not an object (e.g. a list) would
                    # otherwise crash below with a TypeError -> 500.
                    if not isinstance(paginationDict, dict):
                        raise ValueError("expected a JSON object")
                    paginationDict = normalize_pagination_dict(paginationDict)
                    paginationParams = PaginationParams(**paginationDict)
            except (json.JSONDecodeError, ValueError) as e:
                raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")

        enriched, paginatedResult = _buildTransactionsList(ctx, targetMandateId, paginationParams)

        paginationMeta = None
        if paginationParams:
            # The interface may return a model or a dict; getattr() alone
            # silently ignored the totals of a dict result.
            # NOTE(review): dict key names assumed to match the attribute names.
            if isinstance(paginatedResult, dict):
                totalItems = paginatedResult.get("totalItems")
                totalPages = paginatedResult.get("totalPages")
            else:
                totalItems = getattr(paginatedResult, "totalItems", None)
                totalPages = getattr(paginatedResult, "totalPages", None)
            if totalItems is None:
                totalItems = len(enriched)
            if totalPages is None:
                pageSize = paginationParams.pageSize
                totalPages = math.ceil(totalItems / pageSize) if pageSize else 0

            paginationMeta = PaginationMetadata(
                currentPage=paginationParams.page,
                pageSize=paginationParams.pageSize,
                totalItems=totalItems,
                totalPages=totalPages,
                sort=paginationParams.sort,
                filters=paginationParams.filters,
            ).model_dump()

        return {"items": enriched, "pagination": paginationMeta}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error getting billing transactions for mandate {targetMandateId}: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
# =============================================================================
|
||
# Mandate View Endpoints (for Admins)
|
||
# =============================================================================
|
||
|
||
@router.get("/view/mandates/balances", response_model=List[MandateBalanceResponse])
@limiter.limit("30/minute")
def getMandateViewBalances(
    request: Request,
    ctx: RequestContext = Depends(getRequestContext),
    _admin = Depends(requirePlatformAdmin)
):
    """
    Return one aggregated balance entry per mandate (SysAdmin only).

    Access is enforced by the ``requirePlatformAdmin`` dependency. Any failure
    is logged and returned as a 500.
    """
    try:
        mandateBalances = getBillingInterface(ctx.user, ctx.mandateId).getMandateBalances()
        return [MandateBalanceResponse(**entry) for entry in mandateBalances]
    except Exception as e:
        logger.error(f"Error getting mandate view balances: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.get("/view/mandates/transactions", response_model=List[TransactionResponse])
@limiter.limit("30/minute")
def getMandateViewTransactions(
    request: Request,
    limit: int = Query(default=100, ge=1, le=1000),
    ctx: RequestContext = Depends(getRequestContext),
    _admin = Depends(requirePlatformAdmin)
):
    """
    List transactions across all mandates, at most ``limit`` rows (SysAdmin only).

    Each row also carries ``mandateId`` / ``mandateName`` as supplied by the
    billing interface. Any failure is logged and returned as a 500.
    """
    try:
        billingInterface = getBillingInterface(ctx.user, ctx.mandateId)
        rawTransactions = billingInterface.getMandateTransactions(limit=limit)

        def _toResponse(txn) -> TransactionResponse:
            rawReferenceType = txn.get("referenceType")
            return TransactionResponse(
                id=txn.get("id"),
                accountId=txn.get("accountId"),
                transactionType=TransactionTypeEnum(txn.get("transactionType", "DEBIT")),
                amount=txn.get("amount", 0.0),
                description=txn.get("description", ""),
                referenceType=ReferenceTypeEnum(rawReferenceType) if rawReferenceType else None,
                workflowId=txn.get("workflowId"),
                featureCode=txn.get("featureCode"),
                featureInstanceId=txn.get("featureInstanceId"),
                aicoreProvider=txn.get("aicoreProvider"),
                aicoreModel=txn.get("aicoreModel"),
                createdByUserId=txn.get("createdByUserId"),
                sysCreatedAt=txn.get("sysCreatedAt"),
                mandateId=txn.get("mandateId"),
                mandateName=txn.get("mandateName"),
            )

        return [_toResponse(txn) for txn in rawTransactions]
    except Exception as e:
        logger.error(f"Error getting mandate view transactions: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
# =============================================================================
|
||
# User View Endpoints (RBAC-based)
|
||
# =============================================================================
|
||
|
||
@router.get("/view/users/balances", response_model=List[UserBalanceResponse])
@limiter.limit("30/minute")
def getUserViewBalances(
    request: Request,
    ctx: RequestContext = Depends(getRequestContext)
):
    """
    Return user-level balances visible to the caller.

    Visibility (from ``_getBillingDataScope``):
    - SysAdmin: every user balance in every mandate.
    - Mandate admin: all user balances of the mandates they administer.
    - Regular member: only their own balances.

    Any failure is logged and returned as a 500.
    """
    try:
        billingInterface = getBillingInterface(ctx.user, ctx.mandateId)
        scope = _getBillingDataScope(ctx.user)

        if scope.isGlobalAdmin:
            # Global admins are not restricted to a mandate list.
            balances = billingInterface.getUserBalancesForMandates(None)
            return [UserBalanceResponse(**entry) for entry in balances]

        mandateIds = scope.adminMandateIds + scope.memberMandateIds
        if not mandateIds:
            return []

        balances = billingInterface.getUserBalancesForMandates(mandateIds)

        # Admins see everyone in mandates they administer; elsewhere only the caller's own rows.
        administered = set(scope.adminMandateIds)
        visible = [
            entry for entry in balances
            if entry.get("mandateId") in administered or entry.get("userId") == scope.userId
        ]
        return [UserBalanceResponse(**entry) for entry in visible]
    except Exception as e:
        logger.error(f"Error getting user view balances: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
class ViewStatisticsResponse(BaseModel):
    """Aggregated statistics across all user's mandates.

    Every field defaults to an empty/zero value, so ``ViewStatisticsResponse()``
    is a valid "no data" response.
    """

    # Sum of transaction costs in the selected range.
    totalCost: float = 0.0
    # Number of transactions included in the aggregation.
    transactionCount: int = 0
    # Cost totals keyed by AI provider.
    costByProvider: Dict[str, float] = Field(default_factory=dict)
    # Cost totals keyed by AI model.
    costByModel: Dict[str, float] = Field(default_factory=dict)
    # Cost totals keyed by a "<mandate label> / <featureCode>" string.
    costByFeature: Dict[str, float] = Field(default_factory=dict)
    # Cost totals keyed by mandate label.
    costByMandate: Dict[str, float] = Field(default_factory=dict)
    # Time-bucketed cost entries; the entry shape is defined by the billing interface.
    timeSeries: List[Dict[str, Any]] = Field(default_factory=list)
|
||
|
||
|
||
@router.get("/view/statistics")
@limiter.limit("30/minute")
def getUserViewStatistics(
    request: Request,
    dateFrom: str = Query(..., description="ISO YYYY-MM-DD (inclusive)"),
    dateTo: str = Query(..., description="ISO YYYY-MM-DD (inclusive)"),
    bucketSize: str = Query(..., pattern="^(day|month|year)$",
                            description="Time-bucket granularity: day, month, or year"),
    scope: str = Query(default="all", description="Scope: 'personal' (own costs only), 'mandate' (filter by mandateId), 'all' (RBAC-filtered)"),
    mandateId: Optional[str] = Query(None, description="Mandate ID filter (used with scope='mandate')"),
    onlyMine: Optional[bool] = Query(None, description="Additional filter: restrict to current user's transactions within the selected scope"),
    ctx: RequestContext = Depends(getRequestContext)
) -> ViewStatisticsResponse:
    """
    Get aggregated usage statistics across all user's mandates.

    Scope:
    - personal: only the current user's own transactions (ignores admin role)
    - mandate: transactions for a specific mandate (requires mandateId parameter).
      Non-SysAdmins may only request mandates they administrate or belong to.
    - all: RBAC-filtered (SysAdmin sees everything, admin sees mandate, user sees own)

    onlyMine: additional filter that restricts results to the current user's
    transactions while keeping the scope-based mandate selection.

    `dateFrom`/`dateTo` are inclusive local-day boundaries. `bucketSize`
    controls the time-series aggregation granularity and is independent of
    the chosen range.

    Raises:
        HTTPException(403): scope='mandate' with a mandateId outside the caller's mandates.
        HTTPException(500): any other error while aggregating.
    """
    from modules.shared.dateRange import isoDateRangeToLocalEpoch

    try:
        startTs, endTs = isoDateRangeToLocalEpoch(dateFrom, dateTo)

        billingInterface = getBillingInterface(ctx.user, ctx.mandateId)

        rbacScope = _getBillingDataScope(ctx.user)

        # None = no mandate restriction (SysAdmin only).
        if rbacScope.isGlobalAdmin:
            loadMandateIds = None
        else:
            loadMandateIds = rbacScope.adminMandateIds + rbacScope.memberMandateIds
            if not loadMandateIds:
                logger.warning("No mandate IDs found for user")
                return ViewStatisticsResponse()

        if scope == "mandate" and mandateId:
            # Never let a caller-supplied mandateId widen access beyond the RBAC scope.
            if loadMandateIds is not None and mandateId not in loadMandateIds:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="Access to the requested mandate is not permitted",
                )
            loadMandateIds = [mandateId]

        personalUserId = str(ctx.user.id) if (scope == "personal" or onlyMine) else None

        agg = billingInterface.getTransactionStatisticsAggregated(
            mandateIds=loadMandateIds,
            scope=scope,
            userId=personalUserId,
            startTs=startTs,
            endTs=endTs,
            bucketSize=bucketSize,
        )

        logger.info(
            f"View statistics (SQL-aggregated): totalCost={agg['totalCost']}, "
            f"count={agg['transactionCount']}, dateFrom={dateFrom}, dateTo={dateTo}, "
            f"bucketSize={bucketSize}"
        )

        # Map billing account id -> mandate id so account-keyed totals can be
        # re-keyed by a human-readable mandate label.
        accountToMandate: Dict[str, str] = {
            acc.get("id", ""): acc.get("mandateId", "")
            for acc in agg.get("_allAccounts", [])
        }

        from modules.routes.routeHelpers import resolveMandateLabels
        mandateIdsForLookup = list({v for v in accountToMandate.values() if v})
        mandateMap: Dict[str, Optional[str]] = resolveMandateLabels(mandateIdsForLookup) if mandateIdsForLookup else {}

        def _mandateName(accountId: str) -> str:
            """Label for the account's mandate: resolved name, NA(<id>) fallback, or 'unknown'."""
            mid = accountToMandate.get(accountId, "")
            if not mid:
                return "unknown"
            return mandateMap.get(mid) or f"NA({mid})"

        costByMandate: Dict[str, float] = {}
        for accId, total in agg.get("costByAccountId", {}).items():
            name = _mandateName(accId)
            costByMandate[name] = costByMandate.get(name, 0) + total

        costByFeature: Dict[str, float] = {}
        for entry in agg.get("costByAccountFeature", []):
            key = f"{_mandateName(entry['accountId'])} / {entry['featureCode']}"
            costByFeature[key] = costByFeature.get(key, 0) + entry["total"]

        return ViewStatisticsResponse(
            totalCost=agg["totalCost"],
            transactionCount=agg["transactionCount"],
            costByProvider=agg.get("costByProvider", {}),
            costByModel=agg.get("costByModel", {}),
            costByFeature=costByFeature,
            costByMandate=costByMandate,
            timeSeries=agg.get("timeSeries", []),
        )

    except HTTPException:
        # Deliberate client errors (e.g. 403) must not be rewritten into 500s.
        raise
    except Exception as e:
        logger.error(f"Error getting view statistics: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.get("/view/users/transactions", response_model=PaginatedResponse[UserTransactionResponse])
@limiter.limit("30/minute")
def getUserViewTransactions(
    request: Request,
    pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
    scope: str = Query(default="all", description="Scope: 'personal' (own costs only), 'mandate' (filter by mandateId), 'all' (RBAC-filtered)"),
    mandateId: Optional[str] = Query(None, description="Mandate ID filter (used with scope='mandate')"),
    onlyMine: Optional[bool] = Query(None, description="Additional filter: restrict to current user's transactions within the selected scope"),
    mode: Optional[str] = Query(None, description="'filterValues' for distinct column values, 'ids' for all filtered IDs"),
    column: Optional[str] = Query(None, description="Column key (required when mode=filterValues)"),
    ctx: RequestContext = Depends(getRequestContext)
) -> PaginatedResponse[UserTransactionResponse]:
    """
    Get user-level transactions with pagination support.

    Scope (same contract as /view/statistics):
    - personal: only the current user's own transactions (ignores admin role)
    - mandate: transactions for a specific mandate (requires mandateId parameter).
      Non-SysAdmins may only request mandates they administrate or belong to.
    - all: RBAC-filtered (SysAdmin sees everything, admin sees mandate, user sees own)

    onlyMine: additional filter that restricts results to the current user's
    transactions while keeping the scope-based mandate selection.

    Query Parameters:
    - pagination: JSON-encoded PaginationParams object, or None for no pagination
    - scope: 'personal', 'mandate', or 'all'
    - mandateId: required when scope='mandate'
    - onlyMine: true to restrict to current user's data within the scope
    - mode: 'filterValues' returns distinct values of `column`; 'ids' returns
      all filtered transaction IDs. Both are returned as a raw JSON list.

    Raises:
        HTTPException(400): missing `column` for mode=filterValues, or malformed `pagination`.
        HTTPException(403): scope='mandate' with a mandateId outside the caller's mandates.
        HTTPException(500): any other error while loading transactions.
    """
    from fastapi.responses import JSONResponse
    from modules.routes.routeHelpers import parseCrossFilterPagination

    def _parsePagination(raw: Optional[str]) -> Optional[PaginationParams]:
        """Decode the JSON `pagination` query value; None when absent, 400 when invalid."""
        if not raw:
            return None
        try:
            return PaginationParams(**normalize_pagination_dict(json.loads(raw)))
        except (ValueError, TypeError) as err:
            # json.JSONDecodeError and pydantic validation errors are ValueErrors;
            # TypeError covers a JSON value that is not an object.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid pagination parameter: {err}",
            )

    try:
        billingInterface = getBillingInterface(ctx.user, ctx.mandateId)
        rbacScope = _getBillingDataScope(ctx.user)

        # None = no mandate restriction (SysAdmin only).
        if rbacScope.isGlobalAdmin:
            loadMandateIds = None
        else:
            loadMandateIds = rbacScope.adminMandateIds + rbacScope.memberMandateIds
            if not loadMandateIds:
                if mode:
                    # Raw-list modes bypass the PaginatedResponse model, so the
                    # empty result must be returned as a Response, too.
                    return JSONResponse(content=[])
                return PaginatedResponse(items=[], pagination=None)

        if scope == "mandate" and mandateId:
            # Never let a caller-supplied mandateId widen access beyond the RBAC scope.
            if loadMandateIds is not None and mandateId not in loadMandateIds:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="Access to the requested mandate is not permitted",
                )
            loadMandateIds = [mandateId]

        personalUserId = str(ctx.user.id) if (scope == "personal" or onlyMine) else None

        if mode == "filterValues":
            if not column:
                raise HTTPException(status_code=400, detail="column parameter required for mode=filterValues")
            crossFilterParams = parseCrossFilterPagination(column, pagination)
            values = billingInterface.getTransactionDistinctValues(
                mandateIds=loadMandateIds,
                column=column,
                pagination=crossFilterParams,
                scope=scope,
                userId=personalUserId,
            )
            return JSONResponse(content=values)

        paginationParams = _parsePagination(pagination)

        if mode == "ids":
            ids = billingInterface.getTransactionIds(
                mandateIds=loadMandateIds,
                pagination=paginationParams,
                scope=scope,
                userId=personalUserId,
            ) if hasattr(billingInterface, 'getTransactionIds') else []
            return JSONResponse(content=ids)

        if paginationParams is None:
            paginationParams = PaginationParams(page=1, pageSize=50)

        result = billingInterface.getTransactionsForMandatesPaginated(
            mandateIds=loadMandateIds,
            pagination=paginationParams,
            scope=scope,
            userId=personalUserId,
        )

        logger.debug(f"SQL-paginated {result.totalItems} transactions for user {ctx.user.id} "
                     f"(scope={scope}, mandateId={mandateId}, page={paginationParams.page})")

        def _toResponse(d):
            """Convert one transaction row dict into the API response model."""
            return UserTransactionResponse(
                id=d.get("id"),
                accountId=d.get("accountId"),
                transactionType=TransactionTypeEnum(d.get("transactionType", "DEBIT")),
                amount=d.get("amount", 0.0),
                description=d.get("description", ""),
                referenceType=ReferenceTypeEnum(d["referenceType"]) if d.get("referenceType") else None,
                workflowId=d.get("workflowId"),
                featureCode=d.get("featureCode"),
                featureInstanceId=d.get("featureInstanceId"),
                aicoreProvider=d.get("aicoreProvider"),
                aicoreModel=d.get("aicoreModel"),
                createdByUserId=d.get("createdByUserId"),
                sysCreatedAt=d.get("sysCreatedAt"),
                mandateId=d.get("mandateId"),
                mandateName=d.get("mandateName"),
                userId=d.get("userId"),
                userName=d.get("userName")
            )

        return PaginatedResponse(
            items=[_toResponse(d) for d in result.items],
            pagination=PaginationMetadata(
                currentPage=paginationParams.page,
                pageSize=paginationParams.pageSize,
                totalItems=result.totalItems,
                totalPages=result.totalPages,
                sort=paginationParams.sort,
                filters=paginationParams.filters,
            )
        )

    except HTTPException:
        # Deliberate client errors (400/403) must not be rewritten into 500s.
        raise
    except Exception as e:
        logger.error(f"Error getting user view transactions: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|