# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Billing routes for the backend API.
Implements the endpoints for billing management and usage tracking.

Features:
- User endpoints: View balance, transactions, statistics
- Admin endpoints: Manage settings, add credits, view all accounts
"""

from fastapi import APIRouter, HTTPException, Depends, Body, Path, Request, Response, Query
from typing import List, Dict, Any, Optional, Tuple
from fastapi import status
import logging
from datetime import date, datetime
from pydantic import BaseModel, Field

# Import auth module
from modules.auth import limiter, requireSysAdmin, getRequestContext, RequestContext

# Import billing components
from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface
from modules.services.serviceBilling.mainServiceBilling import getService as getBillingService
from modules.datamodels.datamodelBilling import (
    BillingAccount,
    BillingTransaction,
    BillingSettings,
    BillingAddress,
    BillingModelEnum,
    TransactionTypeEnum,
    ReferenceTypeEnum,
    PeriodTypeEnum,
    BillingBalanceResponse,
    BillingStatisticsResponse,
    BillingStatisticsChartData,
    BillingCheckResult,
)

# Configure logger
logger = logging.getLogger(__name__)


# =============================================================================
# Request/Response Models
# =============================================================================

class CreditAddRequest(BaseModel):
    """Request model for adding credit to an account."""
    userId: Optional[str] = Field(None, description="Target user ID (for PREPAY_USER model)")
    amount: float = Field(..., gt=0, description="Amount to credit in CHF")
    description: str = Field(default="Manual credit", description="Transaction description")


class BillingSettingsUpdate(BaseModel):
    """Request model for updating billing settings.

    Every field is optional; ``None`` means "not provided" and is skipped
    when updating existing settings.
    """
    billingModel: Optional[BillingModelEnum] = None
    defaultUserCredit: Optional[float] = Field(None, ge=0)
    warningThresholdPercent: Optional[float] = Field(None, ge=0, le=100)
    blockOnZeroBalance: Optional[bool] = None
    notifyOnWarning: Optional[bool] = None
    notifyEmails: Optional[List[str]] = None
    billingAddress: Optional[BillingAddress] = None


class TransactionResponse(BaseModel):
    """Response model for a billing transaction."""
    id: str
    accountId: str
    transactionType: TransactionTypeEnum
    amount: float
    description: str
    referenceType: Optional[ReferenceTypeEnum]
    workflowId: Optional[str]
    featureCode: Optional[str]
    aicoreProvider: Optional[str]
    createdAt: Optional[datetime]


class AccountSummary(BaseModel):
    """Summary of a billing account."""
    id: str
    mandateId: str
    userId: Optional[str]
    accountType: str
    balance: float
    creditLimit: Optional[float]
    warningThreshold: float
    enabled: bool


class UsageReportResponse(BaseModel):
    """Usage report for a period."""
    period: str
    totalCost: float
    transactionCount: int
    costByProvider: Dict[str, float]
    costByFeature: Dict[str, float]


# =============================================================================
# Router Setup
# =============================================================================

router = APIRouter(
    prefix="/api/billing",
    tags=["Billing"],
    responses={404: {"description": "Not found"}}
)


# =============================================================================
# Internal Helpers
# =============================================================================

def _serverError(message: str, exc: Exception) -> HTTPException:
    """Log ``exc`` (with traceback) and wrap it in an HTTP 500 exception.

    The log line uses the ``"<message>: <exc>"`` format shared by all handlers
    in this module. The caller is expected to ``raise`` the returned object.
    """
    logger.exception(f"{message}: {exc}")
    # NOTE(review): the raw exception text is sent to the client and may expose
    # internal details (DB errors, paths); consider a generic message instead.
    return HTTPException(status_code=500, detail=str(exc))


def _toTransactionResponse(t: Dict[str, Any]) -> TransactionResponse:
    """Convert a stored transaction record (dict) into a ``TransactionResponse``.

    A missing ``transactionType`` defaults to ``"DEBIT"``; a missing or empty
    ``referenceType`` becomes ``None``. The timestamp is read from ``_createdAt``.
    """
    referenceType = t.get("referenceType")
    return TransactionResponse(
        id=t.get("id"),
        accountId=t.get("accountId"),
        transactionType=TransactionTypeEnum(t.get("transactionType", "DEBIT")),
        amount=t.get("amount", 0.0),
        description=t.get("description", ""),
        referenceType=ReferenceTypeEnum(referenceType) if referenceType else None,
        workflowId=t.get("workflowId"),
        featureCode=t.get("featureCode"),
        aicoreProvider=t.get("aicoreProvider"),
        createdAt=t.get("_createdAt"),
    )


def _emptyUsageReport(period: str) -> UsageReportResponse:
    """Return a zero-valued usage report for ``period``."""
    return UsageReportResponse(
        period=period,
        totalCost=0.0,
        transactionCount=0,
        costByProvider={},
        costByFeature={},
    )


def _statisticsDateRange(period: str, year: int, month: Optional[int]) -> Tuple[date, date]:
    """Return ``(startDate, endDate)`` for a statistics query.

    - ``"day"``: from the 1st of ``month`` to the 1st of the following month
      (``month`` must be set by the caller).
    - ``"month"`` / ``"year"``: from Jan 1st of ``year`` to Jan 1st of ``year + 1``.

    ``endDate`` is the first day *after* the range — presumably treated as an
    exclusive bound by ``calculateStatisticsFromTransactions``; verify there.

    Raises:
        ValueError: If ``year``/``month`` do not form a valid calendar date.
    """
    if period == "day":
        startDate = date(year, month, 1)
        endDate = date(year + 1, 1, 1) if month == 12 else date(year, month + 1, 1)
    else:
        startDate = date(year, 1, 1)
        endDate = date(year + 1, 1, 1)
    return startDate, endDate


def _valueOr(value: Any, default: Any) -> Any:
    """Return ``value`` unless it is ``None``, in which case return ``default``.

    Unlike ``value or default`` this preserves explicit falsy values such as
    ``0``, ``False`` or ``[]``.
    """
    return default if value is None else value


# =============================================================================
# User Endpoints
# =============================================================================

@router.get("/balance", response_model=List[BillingBalanceResponse])
@limiter.limit("60/minute")
async def getBalance(
    request: Request,
    ctx: RequestContext = Depends(getRequestContext)
):
    """
    Get billing balances for all mandates the current user belongs to.

    Returns balance information for each mandate.
    """
    try:
        billingService = getBillingService(ctx.currentUser, ctx.mandateId, featureCode="billing")
        return billingService.getBalancesForUser()
    except HTTPException:
        # Preserve deliberate HTTP errors from lower layers instead of masking them as 500
        raise
    except Exception as e:
        raise _serverError("Error getting billing balance", e) from e


@router.get("/balance/{targetMandateId}", response_model=BillingBalanceResponse)
@limiter.limit("60/minute")
async def getBalanceForMandate(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    ctx: RequestContext = Depends(getRequestContext)
):
    """
    Get billing balance for a specific mandate.
    """
    try:
        billingService = getBillingService(ctx.currentUser, targetMandateId, featureCode="billing")

        # A zero-cost check returns the current billing model and balance
        checkResult = billingService.checkBalance(0.0)

        # Get mandate name from app interface (imported lazily, as in the original)
        from modules.interfaces.interfaceDbApp import getInterface as getAppInterface
        appInterface = getAppInterface(ctx.currentUser, mandateId=targetMandateId)
        mandate = appInterface.getMandate(targetMandateId)
        mandateName = mandate.get("name", "") if mandate else ""

        return BillingBalanceResponse(
            mandateId=targetMandateId,
            mandateName=mandateName,
            billingModel=checkResult.billingModel or BillingModelEnum.UNLIMITED,
            balance=checkResult.currentBalance or 0.0,
            warningThreshold=0.0,  # TODO: Get from account
            isWarning=False,
            creditLimit=None
        )
    except HTTPException:
        raise
    except Exception as e:
        raise _serverError(f"Error getting billing balance for mandate {targetMandateId}", e) from e


@router.get("/transactions", response_model=List[TransactionResponse])
@limiter.limit("30/minute")
async def getTransactions(
    request: Request,
    limit: int = Query(default=50, ge=1, le=500),
    offset: int = Query(default=0, ge=0),
    ctx: RequestContext = Depends(getRequestContext)
):
    """
    Get transaction history for the current mandate.

    Returns at most ``limit`` transactions, skipping the first ``offset``.
    """
    try:
        billingService = getBillingService(ctx.currentUser, ctx.mandateId, featureCode="billing")

        # Fetch enough records to cover the whole requested page. Fetching only
        # ``limit`` records would yield a short (or empty) page whenever offset > 0.
        transactions = billingService.getTransactionHistory(limit=offset + limit)

        return [_toTransactionResponse(t) for t in transactions[offset:offset + limit]]
    except HTTPException:
        raise
    except Exception as e:
        raise _serverError("Error getting billing transactions", e) from e


@router.get("/statistics/{period}", response_model=UsageReportResponse)
@limiter.limit("30/minute")
async def getStatistics(
    request: Request,
    period: str = Path(..., description="Period: 'day', 'month', or 'year'"),
    year: int = Query(..., description="Year"),
    month: Optional[int] = Query(None, description="Month (1-12, required for 'day' period)"),
    ctx: RequestContext = Depends(getRequestContext)
):
    """
    Get usage statistics for a period.

    Returns an all-zero report when the mandate has no billing settings or no
    relevant account. Invalid period/year/month parameters yield HTTP 400.
    """
    try:
        # Validate period
        if period not in ["day", "month", "year"]:
            raise HTTPException(status_code=400, detail="Invalid period. Use 'day', 'month', or 'year'")

        if period == "day" and not month:
            raise HTTPException(status_code=400, detail="Month is required for 'day' period")

        # Invalid year/month is a client error, not a server error
        try:
            startDate, endDate = _statisticsDateRange(period, year, month)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=f"Invalid year/month: {e}") from e

        billingInterface = getBillingInterface(ctx.currentUser, ctx.mandateId)
        settings = billingInterface.getSettings(ctx.mandateId)

        if not settings:
            return _emptyUsageReport(period)

        billingModel = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value))

        # PREPAY_USER bills per user; every other model uses the mandate-level account
        if billingModel == BillingModelEnum.PREPAY_USER:
            account = billingInterface.getUserAccount(ctx.mandateId, ctx.currentUser.id)
        else:
            account = billingInterface.getMandateAccount(ctx.mandateId)

        if not account:
            return _emptyUsageReport(period)

        stats = billingInterface.calculateStatisticsFromTransactions(account["id"], startDate, endDate)

        return UsageReportResponse(
            period=period,
            totalCost=stats.get("totalCostCHF", 0.0),
            transactionCount=stats.get("transactionCount", 0),
            costByProvider=stats.get("costByProvider", {}),
            costByFeature=stats.get("costByFeature", {})
        )
    except HTTPException:
        raise
    except Exception as e:
        raise _serverError("Error getting billing statistics", e) from e


@router.get("/providers", response_model=List[str])
@limiter.limit("60/minute")
async def getAllowedProviders(
    request: Request,
    ctx: RequestContext = Depends(getRequestContext)
):
    """
    Get list of AICore providers the current user is allowed to use.
    """
    try:
        billingService = getBillingService(ctx.currentUser, ctx.mandateId, featureCode="billing")
        # NOTE(review): the lowercase "allowed" breaks the camelCase used by the other
        # service methods (getBalancesForUser, getTransactionHistory); confirm the
        # service really defines ``getallowedProviders`` and not ``getAllowedProviders``.
        return billingService.getallowedProviders()
    except HTTPException:
        raise
    except Exception as e:
        raise _serverError("Error getting allowed providers", e) from e


# =============================================================================
# Admin Endpoints
# =============================================================================

@router.get("/admin/settings/{targetMandateId}", response_model=Dict[str, Any])
@limiter.limit("30/minute")
async def getSettingsAdmin(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    ctx: RequestContext = Depends(getRequestContext),
    _admin = Depends(requireSysAdmin)
):
    """
    Get billing settings for a mandate (SysAdmin only).

    Raises HTTP 404 when the mandate has no billing settings.
    """
    try:
        billingInterface = getBillingInterface(ctx.currentUser, targetMandateId)
        settings = billingInterface.getSettings(targetMandateId)

        if not settings:
            raise HTTPException(status_code=404, detail="Billing settings not found")

        return settings
    except HTTPException:
        raise
    except Exception as e:
        raise _serverError("Error getting billing settings", e) from e


@router.post("/admin/settings/{targetMandateId}", response_model=Dict[str, Any])
@limiter.limit("10/minute")
async def createOrUpdateSettings(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    settingsUpdate: BillingSettingsUpdate = Body(...),
    ctx: RequestContext = Depends(getRequestContext),
    _admin = Depends(requireSysAdmin)
):
    """
    Create or update billing settings for a mandate (SysAdmin only).

    If settings exist, only the fields supplied (non-``None``) are updated.
    Otherwise new settings are created, with defaults for omitted fields:
    UNLIMITED model, 10.0 default user credit, 10.0 % warning threshold,
    blocking on zero balance and warning notifications enabled.
    """
    try:
        billingInterface = getBillingInterface(ctx.currentUser, targetMandateId)
        existingSettings = billingInterface.getSettings(targetMandateId)

        if existingSettings:
            # Partial update: omitted (None) fields are left untouched
            updates = settingsUpdate.model_dump(exclude_none=True)
            if not updates:
                return existingSettings
            result = billingInterface.updateSettings(existingSettings["id"], updates)
            return result or existingSettings

        # Create new settings. ``_valueOr`` keeps explicit zeros such as
        # defaultUserCredit=0, which ``value or default`` would overwrite.
        newSettings = BillingSettings(
            mandateId=targetMandateId,
            billingModel=_valueOr(settingsUpdate.billingModel, BillingModelEnum.UNLIMITED),
            defaultUserCredit=_valueOr(settingsUpdate.defaultUserCredit, 10.0),
            warningThresholdPercent=_valueOr(settingsUpdate.warningThresholdPercent, 10.0),
            blockOnZeroBalance=_valueOr(settingsUpdate.blockOnZeroBalance, True),
            notifyOnWarning=_valueOr(settingsUpdate.notifyOnWarning, True),
            notifyEmails=_valueOr(settingsUpdate.notifyEmails, []),
            billingAddress=settingsUpdate.billingAddress
        )
        return billingInterface.createSettings(newSettings)
    except HTTPException:
        raise
    except Exception as e:
        raise _serverError("Error updating billing settings", e) from e


@router.post("/admin/credit/{targetMandateId}", response_model=Dict[str, Any])
@limiter.limit("10/minute")
async def addCredit(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    creditRequest: CreditAddRequest = Body(...),
    ctx: RequestContext = Depends(getRequestContext),
    _admin = Depends(requireSysAdmin)
):
    """
    Add credit to a billing account (SysAdmin only).

    For PREPAY_USER model, specify userId.
    For PREPAY_MANDATE (and CREDIT_POSTPAY), leave userId empty.
    Other billing models (e.g. UNLIMITED) are rejected with HTTP 400;
    a mandate without billing settings yields HTTP 404.
    """
    try:
        # Get settings to determine billing model
        billingInterface = getBillingInterface(ctx.currentUser, targetMandateId)
        settings = billingInterface.getSettings(targetMandateId)

        if not settings:
            raise HTTPException(status_code=404, detail="Billing settings not found for this mandate")

        billingModel = BillingModelEnum(settings.get("billingModel", BillingModelEnum.UNLIMITED.value))

        # Resolve (or create) the target account based on the billing model
        if billingModel == BillingModelEnum.PREPAY_USER:
            if not creditRequest.userId:
                raise HTTPException(status_code=400, detail="userId is required for PREPAY_USER model")
            account = billingInterface.getOrCreateUserAccount(
                targetMandateId,
                creditRequest.userId,
                initialBalance=0.0
            )
        elif billingModel in [BillingModelEnum.PREPAY_MANDATE, BillingModelEnum.CREDIT_POSTPAY]:
            account = billingInterface.getOrCreateMandateAccount(targetMandateId, initialBalance=0.0)
        else:
            raise HTTPException(status_code=400, detail=f"Cannot add credit to {billingModel.value} billing model")

        # Record the credit as an ADMIN-referenced CREDIT transaction
        transaction = BillingTransaction(
            accountId=account["id"],
            transactionType=TransactionTypeEnum.CREDIT,
            amount=creditRequest.amount,
            description=creditRequest.description,
            referenceType=ReferenceTypeEnum.ADMIN
        )
        result = billingInterface.createTransaction(transaction)

        logger.info(f"Added {creditRequest.amount} CHF credit to account {account['id']} in mandate {targetMandateId}")

        return result
    except HTTPException:
        raise
    except Exception as e:
        raise _serverError("Error adding credit", e) from e


@router.get("/admin/accounts/{targetMandateId}", response_model=List[AccountSummary])
@limiter.limit("30/minute")
async def getAccounts(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    ctx: RequestContext = Depends(getRequestContext),
    _admin = Depends(requireSysAdmin)
):
    """
    Get all billing accounts for a mandate (SysAdmin only).
    """
    try:
        billingInterface = getBillingInterface(ctx.currentUser, targetMandateId)
        accounts = billingInterface.getAccountsByMandate(targetMandateId)

        return [
            AccountSummary(
                id=acc.get("id"),
                mandateId=acc.get("mandateId"),
                userId=acc.get("userId"),
                accountType=acc.get("accountType"),
                balance=acc.get("balance", 0.0),
                creditLimit=acc.get("creditLimit"),
                warningThreshold=acc.get("warningThreshold", 0.0),
                enabled=acc.get("enabled", True)
            )
            for acc in accounts
        ]
    except HTTPException:
        raise
    except Exception as e:
        raise _serverError("Error getting billing accounts", e) from e


@router.get("/admin/transactions/{targetMandateId}", response_model=List[TransactionResponse])
@limiter.limit("30/minute")
async def getTransactionsAdmin(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID"),
    limit: int = Query(default=100, ge=1, le=1000),
    ctx: RequestContext = Depends(getRequestContext),
    _admin = Depends(requireSysAdmin)
):
    """
    Get all transactions for a mandate (SysAdmin only), up to ``limit`` records.
    """
    try:
        billingInterface = getBillingInterface(ctx.currentUser, targetMandateId)
        transactions = billingInterface.getTransactionsByMandate(targetMandateId, limit=limit)
        return [_toTransactionResponse(t) for t in transactions]
    except HTTPException:
        raise
    except Exception as e:
        raise _serverError(f"Error getting billing transactions for mandate {targetMandateId}", e) from e