# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Subscription routes — ID-based, state-machine-driven.

Endpoints:
- GET  /api/subscription/plans                          — list selectable plans
- GET  /api/subscription/status                         — operative + scheduled subscription for current mandate
- POST /api/subscription/activate                       — start checkout for a plan
- POST /api/subscription/cancel                         — cancel a specific subscription (by ID)
- POST /api/subscription/reactivate                     — reactivate a cancelled subscription (by ID)
- POST /api/subscription/force-cancel                   — sysadmin immediate cancel (by ID)
- POST /api/subscription/checkout/verify                — verify a Stripe Checkout Session after redirect
- GET  /api/subscription/admin/all                      — sysadmin: all subscriptions, enriched + paginated
- GET  /api/subscription/admin/all/filter-values        — sysadmin: distinct values of one column
- GET  /api/subscription/data-volume/{targetMandateId}  — data volume usage vs. plan limit
"""

from fastapi import APIRouter, HTTPException, Depends, Request, Query, Path
from fastapi import status
from typing import Dict, Any, List, Optional
import logging
import json
import math
from pydantic import BaseModel, Field

from modules.auth import limiter, getRequestContext, RequestContext
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
from modules.routes.routeDataUsers import _applyFiltersAndSort, _extractDistinctValues
from modules.shared.i18nRegistry import apiRouteContext

routeApiMsg = apiRouteContext("routeSubscription")

logger = logging.getLogger(__name__)


def _resolveMandateId(context: RequestContext) -> str:
    """Return the request's mandate ID as a string, or "" when none is set."""
    if context.mandateId:
        return str(context.mandateId)
    return ""


def _assertMandateAdmin(context: RequestContext, mandateId: str) -> None:
    """Raise 403 unless the caller is a sysadmin or a mandate-level admin of ``mandateId``.

    A caller counts as mandate admin when one of their enabled UserMandate
    memberships for ``mandateId`` carries a role labelled ``"admin"`` that is
    not bound to a feature instance. Lookup errors are logged and treated as
    "not an admin" (fail closed).

    Raises:
        HTTPException: 403 when the caller does not qualify.
    """
    if context.hasSysAdminRole:
        return
    try:
        from modules.interfaces.interfaceDbApp import getRootInterface

        rootInterface = getRootInterface()
        userMandates = rootInterface.getUserMandates(str(context.user.id))
        for um in userMandates or []:
            if str(getattr(um, "mandateId", None)) != str(mandateId):
                continue
            if not getattr(um, "enabled", True):
                continue
            umId = str(getattr(um, "id", ""))
            roleIds = rootInterface.getRoleIdsForUserMandate(umId)
            for roleId in roleIds:
                role = rootInterface.getRole(roleId)
                # Only mandate-wide admin roles count; feature-instance admins do not.
                if role and role.roleLabel == "admin" and not role.featureInstanceId:
                    return
    except Exception as e:
        # Fail closed: a lookup error must never grant access, but it must be visible.
        logger.warning("Mandate admin check failed for mandate %s: %s", mandateId, e)
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=routeApiMsg("Mandate admin role required"))


# =============================================================================
# Request / Response models
# =============================================================================

class ActivatePlanRequest(BaseModel):
    planKey: str = Field(..., description="Key of the plan to activate")
    returnUrl: str = Field(..., description="Frontend URL to redirect back to after Stripe Checkout")


class CancelRequest(BaseModel):
    subscriptionId: str = Field(..., description="ID of the subscription to cancel")


class ReactivateRequest(BaseModel):
    subscriptionId: str = Field(..., description="ID of the subscription to reactivate")


class ForceCancelRequest(BaseModel):
    subscriptionId: str = Field(..., description="ID of the subscription to force-cancel")


class VerifyCheckoutRequest(BaseModel):
    sessionId: str = Field(..., description="Stripe Checkout Session ID to verify")


class SubscriptionStatusResponse(BaseModel):
    active: bool
    subscription: Optional[Dict[str, Any]] = None
    plan: Optional[Dict[str, Any]] = None
    scheduled: Optional[Dict[str, Any]] = None


# =============================================================================
# Router
# =============================================================================

router = APIRouter(
    prefix="/api/subscription",
    tags=["Subscription"],
    responses={404: {"description": "Not found"}},
)


# =============================================================================
# Endpoints
# =============================================================================

@router.get("/plans", response_model=List[Dict[str, Any]])
@limiter.limit("30/minute")
def getPlans(request: Request, context: RequestContext = Depends(getRequestContext)):
    """List the plans selectable for the current mandate (the mandate ID may be empty)."""
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )

    try:
        mandateId = _resolveMandateId(context)
        subService = getSubscriptionService(context.user, mandateId)
        plans = subService.getSelectablePlans()
        return [p.model_dump() for p in plans]
    except Exception as e:
        logger.exception("Error fetching plans: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.get("/status", response_model=SubscriptionStatusResponse)
@limiter.limit("60/minute")
def getStatus(request: Request, context: RequestContext = Depends(getRequestContext)):
    """Return the operative subscription and any scheduled successor for the current mandate.

    Without a mandate header the response is simply ``active=False``. When no
    subscription is operative, the first PENDING one (if any) is returned as
    ``subscription`` with ``active=False``.
    """
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )

    mandateId = _resolveMandateId(context)
    if not mandateId:
        return SubscriptionStatusResponse(active=False)
    _assertMandateAdmin(context, mandateId)

    try:
        subService = getSubscriptionService(context.user, mandateId)
        operative = subService.getOperativeSubscription(mandateId)
        scheduled = subService.getScheduledSubscription(mandateId)

        if not operative:
            from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum

            pending = subService.listSubscriptions(mandateId, [SubscriptionStatusEnum.PENDING])
            if pending:
                sub = pending[0]
                plan = subService.getPlan(sub.get("planKey", ""))
                return SubscriptionStatusResponse(
                    active=False,
                    subscription=sub,
                    plan=plan.model_dump() if plan else None,
                    scheduled=scheduled,
                )
            return SubscriptionStatusResponse(active=False, scheduled=scheduled)

        plan = subService.getPlan(operative.get("planKey", ""))
        return SubscriptionStatusResponse(
            active=True,
            subscription=operative,
            plan=plan.model_dump() if plan else None,
            scheduled=scheduled,
        )
    except Exception as e:
        logger.exception("Error fetching status: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/activate", response_model=Dict[str, Any])
@limiter.limit("10/minute")
def activatePlan(
    request: Request,
    data: ActivatePlanRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Start activation (checkout) of ``data.planKey`` for the current mandate.

    Raises:
        HTTPException: 400 without mandate header or on service ``ValueError``,
            403 for non-admins, 500 on any other failure.
    """
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )

    mandateId = _resolveMandateId(context)
    if not mandateId:
        raise HTTPException(status_code=400, detail=routeApiMsg("X-Mandate-Id header required"))
    _assertMandateAdmin(context, mandateId)

    try:
        subService = getSubscriptionService(context.user, mandateId)
        return subService.activatePlan(mandateId, data.planKey, returnUrl=data.returnUrl)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Error activating plan %s: %s", data.planKey, e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/cancel", response_model=Dict[str, Any])
@limiter.limit("30/minute")
def cancelSubscription(
    request: Request,
    data: CancelRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Cancel a specific subscription by its ID (caller must be mandate admin)."""
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )

    mandateId = _resolveMandateId(context)
    if not mandateId:
        raise HTTPException(status_code=400, detail=routeApiMsg("X-Mandate-Id header required"))
    _assertMandateAdmin(context, mandateId)

    try:
        subService = getSubscriptionService(context.user, mandateId)
        return subService.cancelSubscription(data.subscriptionId)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Error cancelling subscription %s: %s", data.subscriptionId, e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/reactivate", response_model=Dict[str, Any])
@limiter.limit("30/minute")
def reactivateSubscription(
    request: Request,
    data: ReactivateRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Reactivate a cancelled (non-recurring) subscription before its period ends."""
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )

    mandateId = _resolveMandateId(context)
    if not mandateId:
        raise HTTPException(status_code=400, detail=routeApiMsg("X-Mandate-Id header required"))
    _assertMandateAdmin(context, mandateId)

    try:
        subService = getSubscriptionService(context.user, mandateId)
        return subService.reactivateSubscription(data.subscriptionId)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Error reactivating subscription %s: %s", data.subscriptionId, e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/force-cancel", response_model=Dict[str, Any])
@limiter.limit("30/minute")
def forceCancel(
    request: Request,
    data: ForceCancelRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Sysadmin: immediately expire any non-terminal subscription.

    The owning mandate is read from the subscription record itself, so no
    mandate header is required.
    """
    if not context.hasSysAdminRole:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=routeApiMsg("Sysadmin role required"))

    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )
    from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface

    sub = getSubRootInterface().getById(data.subscriptionId)
    if not sub:
        raise HTTPException(status_code=404, detail=routeApiMsg("Subscription not found"))
    mandateId = sub["mandateId"]

    try:
        subService = getSubscriptionService(context.user, mandateId)
        return subService.forceCancel(data.subscriptionId)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Error force-cancelling subscription %s: %s", data.subscriptionId, e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@router.post("/checkout/verify", response_model=Dict[str, Any])
@limiter.limit("20/minute")
def verifyCheckout(
    request: Request,
    data: VerifyCheckoutRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Verify a Stripe Checkout Session and activate the subscription if paid.

    Idempotent: if the webhook already processed the session, returns success.
    Called by the frontend immediately after returning from Stripe.

    Returns ``{"status": "activated" | "pending", "message": ...}``.
    """
    mandateId = _resolveMandateId(context)
    if not mandateId:
        raise HTTPException(status_code=400, detail=routeApiMsg("X-Mandate-Id header required"))
    _assertMandateAdmin(context, mandateId)

    try:
        from modules.shared.stripeClient import getStripeClient, stripeToDict

        stripe = getStripeClient()
        rawSession = stripe.checkout.Session.retrieve(data.sessionId)
        session = stripeToDict(rawSession)
    except Exception as e:
        logger.error("Failed to retrieve checkout session %s: %s", data.sessionId, e)
        raise HTTPException(status_code=400, detail=routeApiMsg("Invalid session ID")) from e

    # NOTE(review): nothing here checks that the session belongs to ``mandateId``;
    # presumably _handleSubscriptionCheckoutCompleted resolves the mandate from the
    # session itself — confirm, or compare session metadata against mandateId.
    payStatus = session.get("payment_status")
    if session.get("status") != "complete":
        return {"status": "pending", "message": "Checkout not yet completed"}
    if payStatus not in ("paid", "no_payment_required"):
        return {"status": "pending", "message": "Checkout not yet completed"}
    if session.get("mode") != "subscription":
        raise HTTPException(status_code=400, detail=routeApiMsg("Not a subscription checkout session"))

    from modules.routes.routeBilling import _handleSubscriptionCheckoutCompleted

    try:
        _handleSubscriptionCheckoutCompleted(session, f"verify-{data.sessionId}")
    except Exception as e:
        # Best effort: the webhook may already have handled (or will handle) this session.
        logger.warning(
            "verifyCheckout: handler raised for session %s mandate %s: %s",
            data.sessionId, mandateId, e,
        )

    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )
    from modules.datamodels.datamodelSubscription import OPERATIVE_STATUSES

    subService = getSubscriptionService(context.user, mandateId)
    operative = subService.getOperativeSubscription(mandateId)
    if operative and operative.get("status") in [s.value for s in OPERATIVE_STATUSES]:
        planKey = operative.get("planKey", "")
        if planKey:
            try:
                from modules.interfaces.interfaceDbBilling import _getRootInterface as _getBillingRoot

                _getBillingRoot().ensureActivationBudget(mandateId, planKey)
            except Exception as ex:
                logger.warning("verifyCheckout: ensureActivationBudget failed: %s", ex)
        return {"status": "activated", "message": "Subscription activated"}

    return {"status": "pending", "message": "Subscription activation pending — webhook may still be processing."}


# =============================================================================
# SysAdmin: global subscription overview
# =============================================================================

def _buildEnrichedSubscriptions() -> List[Dict[str, Any]]:
    """Build the full enriched subscription list (shared by list + filter-values endpoints).

    Each subscription dict is augmented in place with ``mandateName``,
    ``planTitle``, ``monthlyRevenueCHF``, ``activeUsers`` and
    ``activeInstances``. Revenue and counts are only computed for subscriptions
    in an operative status; all others get zeros. Name/count lookups are
    best-effort: failures are logged and fall back to defaults.
    """
    from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
    from modules.datamodels.datamodelSubscription import BUILTIN_PLANS, OPERATIVE_STATUSES

    subInterface = getSubRootInterface()
    allSubs = subInterface.listAll()

    mandateNames: Dict[str, str] = {}
    try:
        from modules.datamodels.datamodelUam import Mandate
        from modules.security.rootAccess import getRootDbAppConnector

        appDb = getRootDbAppConnector()
        for row in appDb.getRecordset(Mandate):
            r = dict(row)
            mid = r.get("id") or ""
            mandateNames[mid] = r.get("label") or r.get("name") or mid[:8]
    except Exception as e:
        logger.warning("Could not bulk-resolve mandate names: %s", e)

    operativeValues = {s.value for s in OPERATIVE_STATUSES}
    operativeMandateIds = list({
        sub.get("mandateId") for sub in allSubs
        if sub.get("mandateId") and sub.get("status") in operativeValues
    })

    userCountMap: Dict[str, int] = {}
    instanceCountMap: Dict[str, int] = {}
    if operativeMandateIds:
        try:
            from modules.datamodels.datamodelMembership import UserMandate
            from modules.datamodels.datamodelFeatures import FeatureInstance
            from modules.security.rootAccess import getRootDbAppConnector

            appDb = getRootDbAppConnector()
            allUM = appDb.getRecordset(UserMandate, recordFilter={"mandateId": operativeMandateIds})
            for um in (allUM or []):
                mid = um.get("mandateId") if isinstance(um, dict) else getattr(um, "mandateId", None)
                if mid:
                    userCountMap[mid] = userCountMap.get(mid, 0) + 1
            allFI = appDb.getRecordset(FeatureInstance, recordFilter={"mandateId": operativeMandateIds})
            for fi in (allFI or []):
                fid = fi if isinstance(fi, dict) else fi.__dict__
                if fid.get("enabled"):
                    mid = fid.get("mandateId")
                    if mid:
                        instanceCountMap[mid] = instanceCountMap.get(mid, 0) + 1
        except Exception as e:
            logger.warning("Batch count for subscriptions failed: %s", e)

    enriched = []
    for sub in allSubs:
        # ``or ""`` (not a .get default) so a stored null mandateId cannot crash the slice below.
        mid = sub.get("mandateId") or ""
        planKey = sub.get("planKey", "")
        plan = BUILTIN_PLANS.get(planKey)
        sub["mandateName"] = mandateNames.get(mid, mid[:8])
        sub["planTitle"] = (plan.title or planKey) if plan else planKey

        if sub.get("status") in operativeValues:
            userPrice = sub.get("snapshotPricePerUserCHF", 0) or 0
            instPrice = sub.get("snapshotPricePerInstanceCHF", 0) or 0
            userCount = userCountMap.get(mid, 0)
            instanceCount = instanceCountMap.get(mid, 0)
            includedModules = (plan.includedModules or 0) if plan else 0
            # Only feature instances beyond the plan's included modules are billed.
            billableModules = max(0, instanceCount - includedModules)
            sub["monthlyRevenueCHF"] = round(userPrice * userCount + instPrice * billableModules, 2)
            sub["activeUsers"] = userCount
            sub["activeInstances"] = instanceCount
        else:
            sub["monthlyRevenueCHF"] = 0
            sub["activeUsers"] = 0
            sub["activeInstances"] = 0

        enriched.append(sub)
    return enriched


@router.get("/admin/all")
@limiter.limit("30/minute")
def getAllSubscriptions(
    request: Request,
    pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
    context: RequestContext = Depends(getRequestContext),
):
    """SysAdmin: list ALL subscriptions across all mandates with enriched metadata.

    With ``pagination`` the result is filtered, sorted and sliced to one page and
    returned with ``PaginationMetadata``; without it every subscription is
    returned and ``pagination`` is ``None``.

    Raises:
        HTTPException: 403 for non-sysadmins, 400 for a malformed pagination parameter.
    """
    if not context.hasSysAdminRole:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=routeApiMsg("Sysadmin role required"))

    paginationParams: Optional[PaginationParams] = None
    if pagination:
        try:
            paginationDict = json.loads(pagination)
            if paginationDict:
                if not isinstance(paginationDict, dict):
                    raise ValueError("expected a JSON object")
                paginationDict = normalize_pagination_dict(paginationDict)
                paginationParams = PaginationParams(**paginationDict)
        except (json.JSONDecodeError, ValueError) as e:
            raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}") from e

    enriched = _buildEnrichedSubscriptions()
    filtered = _applyFiltersAndSort(enriched, paginationParams)

    if paginationParams:
        # Guard the slice math: page < 1 would slice with negative indices and
        # pageSize < 1 would divide by zero.
        if paginationParams.page < 1 or paginationParams.pageSize < 1:
            raise HTTPException(
                status_code=400,
                detail="Invalid pagination parameter: page and pageSize must be >= 1",
            )
        totalItems = len(filtered)
        totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0
        startIdx = (paginationParams.page - 1) * paginationParams.pageSize
        endIdx = startIdx + paginationParams.pageSize
        pageItems = filtered[startIdx:endIdx]
        return {
            "items": pageItems,
            "pagination": PaginationMetadata(
                currentPage=paginationParams.page,
                pageSize=paginationParams.pageSize,
                totalItems=totalItems,
                totalPages=totalPages,
                sort=paginationParams.sort,
                filters=paginationParams.filters,
            ).model_dump(),
        }

    return {"items": enriched, "pagination": None}


@router.get("/admin/all/filter-values")
@limiter.limit("60/minute")
def getFilterValues(
    request: Request,
    column: str = Query(..., description="Column key to extract distinct values for"),
    pagination: Optional[str] = Query(None, description="JSON-encoded current filters (applied except for the requested column)"),
    context: RequestContext = Depends(getRequestContext),
):
    """Return distinct values for a column, respecting all active filters except the requested one.

    Sorting is dropped from the supplied pagination before filtering.

    Raises:
        HTTPException: 403 for non-sysadmins, 400 for a malformed pagination parameter.
    """
    if not context.hasSysAdminRole:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=routeApiMsg("Sysadmin role required"))

    crossFilterParams: Optional[PaginationParams] = None
    if pagination:
        try:
            paginationDict = json.loads(pagination)
            if paginationDict:
                if not isinstance(paginationDict, dict):
                    raise ValueError("expected a JSON object")
                paginationDict = normalize_pagination_dict(paginationDict)
                # A JSON ``"filters": null`` must not crash the .pop below.
                filters = paginationDict.get("filters") or {}
                if not isinstance(filters, dict):
                    raise ValueError("filters must be an object")
                filters.pop(column, None)
                paginationDict["filters"] = filters
                paginationDict.pop("sort", None)
                crossFilterParams = PaginationParams(**paginationDict)
        except (json.JSONDecodeError, ValueError) as e:
            raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}") from e

    enriched = _buildEnrichedSubscriptions()
    crossFiltered = _applyFiltersAndSort(enriched, crossFilterParams)
    return _extractDistinctValues(crossFiltered, column)


# ============================================================
# Data Volume Usage per Mandate
# ============================================================

def _sumUniqueFileBytes(*fileRecordSets: Any) -> int:
    """Sum ``fileSize`` over one or more file recordsets, counting each file ID once.

    Records may be dicts or objects. Records without an ``id`` cannot be
    deduplicated and are always counted. ``None`` recordsets and sizes count as empty/zero.
    """
    seenIds = set()
    total = 0
    for records in fileRecordSets:
        for f in records or []:
            if isinstance(f, dict):
                fileId, size = f.get("id"), f.get("fileSize")
            else:
                fileId, size = getattr(f, "id", None), getattr(f, "fileSize", 0)
            if fileId is not None:
                if fileId in seenIds:
                    continue
                seenIds.add(fileId)
            total += size or 0
    return total


@router.get("/data-volume/{targetMandateId}")
@limiter.limit("60/minute")
def _getDataVolumeUsage(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID to check volume for"),
    context: RequestContext = Depends(getRequestContext),
):
    """Calculate current data volume usage for a mandate vs. plan limit.

    Caller must be a sysadmin or a mandate admin of ``targetMandateId``.

    Reports file storage (files attached to the mandate's feature instances or
    directly to the mandate, each counted once) and RAG index size. Only the RAG
    index size is used as ``usedMB`` against the operative plan's
    ``maxDataVolumeMB``. ``percentUsed`` is ``None`` when no limit applies;
    ``warning`` is set at >= 80 %.
    """
    mandateId = targetMandateId
    # Without this check any authenticated user could read any mandate's usage by ID.
    _assertMandateAdmin(context, mandateId)

    from modules.interfaces.interfaceDbApp import getRootInterface
    from modules.datamodels.datamodelFiles import FileItem
    from modules.datamodels.datamodelFeatures import FeatureInstance
    from modules.interfaces.interfaceDbKnowledge import aggregateMandateRagTotalBytes
    from modules.interfaces.interfaceDbManagement import getInterface as getMgmtInterface
    from modules.interfaces.interfaceDbSubscription import _getRootInterface as _getSubRootIf

    rootIf = getRootInterface()
    instances = rootIf.db.getRecordset(FeatureInstance, recordFilter={"mandateId": mandateId}) or []
    instIds = [str(inst.get("id") or "") for inst in instances if inst.get("id")]

    mgmtDb = getMgmtInterface().db
    instanceFiles = (
        mgmtDb.getRecordset(FileItem, recordFilter={"featureInstanceId": instIds}) if instIds else []
    )
    mandateFiles = mgmtDb.getRecordset(FileItem, recordFilter={"mandateId": mandateId})
    # A file can match both queries (instance-scoped and mandate-scoped); count it once.
    totalFileBytes = _sumUniqueFileBytes(instanceFiles, mandateFiles)
    filesMB = round(totalFileBytes / (1024 * 1024), 2)

    ragBytes = aggregateMandateRagTotalBytes(mandateId)
    ragMB = round(ragBytes / (1024 * 1024), 2)

    maxMB = None
    subIf = _getSubRootIf()
    operative = subIf.getOperativeForMandate(mandateId)
    if operative:
        plan = subIf.getPlan(operative.get("planKey") or "")
        if plan and plan.maxDataVolumeMB is not None:
            maxMB = int(plan.maxDataVolumeMB)

    # Only the RAG index counts against the plan limit; file bytes are informational.
    usedMB = ragMB
    percentUsed = round((usedMB / maxMB) * 100, 1) if maxMB else None

    logger.info(
        "data-volume mandate=%s: files=%.2f MB, rag=%.2f MB, max=%s MB",
        mandateId, filesMB, ragMB, maxMB,
    )

    return {
        "mandateId": mandateId,
        "usedMB": usedMB,
        "filesMB": filesMB,
        "ragIndexMB": ragMB,
        "maxDataVolumeMB": maxMB,
        "percentUsed": percentUsed,
        "warning": (percentUsed or 0) >= 80,
    }