# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Subscription routes — ID-based, state-machine-driven.

Endpoints:
- GET  /api/subscription/plans                   — list selectable plans
- GET  /api/subscription/status                  — operative + scheduled subscription for current mandate
- POST /api/subscription/activate                — start checkout for a plan
- POST /api/subscription/cancel                  — cancel a specific subscription (by ID)
- POST /api/subscription/reactivate              — reactivate a cancelled subscription (by ID)
- POST /api/subscription/force-cancel            — sysadmin immediate cancel (by ID)
- POST /api/subscription/checkout/verify         — verify a Stripe Checkout Session after redirect
- GET  /api/subscription/admin/all               — sysadmin: list all subscriptions (optionally paginated)
- GET  /api/subscription/admin/all/filter-values — sysadmin: distinct values of one column for filter UIs
"""

from fastapi import APIRouter, HTTPException, Depends, Request, Query
from fastapi import status
from typing import Dict, Any, List, Optional
import logging
import json
import math

from pydantic import BaseModel, Field

from modules.auth import limiter, getRequestContext, RequestContext
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
from modules.routes.routeDataUsers import _applyFiltersAndSort, _extractDistinctValues

logger = logging.getLogger(__name__)


def _resolveMandateId(context: RequestContext) -> str:
    """Return the request's mandate ID as a string, or "" when none is set."""
    if context.mandateId:
        return str(context.mandateId)
    return ""


def _assertMandateAdmin(context: RequestContext, mandateId: str) -> None:
    """Ensure the caller may administer ``mandateId``.

    Access is granted when the caller has the sysadmin role, or when one of
    the caller's *enabled* memberships of ``mandateId`` carries a role labelled
    ``"admin"`` that is mandate-wide (i.e. not bound to a feature instance).

    Raises:
        HTTPException: 403 when neither condition holds. The check fails
            closed: any error during the lookup also results in a 403, but the
            error is logged so infrastructure problems are not mistaken for
            permission problems.
    """
    if context.hasSysAdminRole:
        return

    try:
        from modules.interfaces.interfaceDbApp import getRootInterface

        rootInterface = getRootInterface()
        userMandates = rootInterface.getUserMandates(str(context.user.id))
        for um in userMandates:
            if str(getattr(um, "mandateId", None)) != str(mandateId):
                continue
            # Memberships without an explicit "enabled" flag are treated as enabled.
            if not getattr(um, "enabled", True):
                continue
            umId = str(getattr(um, "id", ""))
            roleIds = rootInterface.getRoleIdsForUserMandate(umId)
            for roleId in roleIds:
                role = rootInterface.getRole(roleId)
                # Only mandate-level admin roles count; feature-instance admins do not.
                if role and role.roleLabel == "admin" and not role.featureInstanceId:
                    return
    except Exception:
        # Deliberately fail closed (fall through to the 403 below), but keep a
        # record of why the lookup failed instead of silently discarding it.
        logger.warning(
            "Mandate admin check failed for mandate %s; denying access",
            mandateId,
            exc_info=True,
        )

    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Mandate admin role required")
# =============================================================================
# Request / Response models
# =============================================================================


# Body of POST /activate.
class ActivatePlanRequest(BaseModel):
    planKey: str = Field(..., description="Key of the plan to activate")
    returnUrl: str = Field(..., description="Frontend URL to redirect back to after Stripe Checkout")


# Body of POST /cancel.
class CancelRequest(BaseModel):
    subscriptionId: str = Field(..., description="ID of the subscription to cancel")


# Body of POST /reactivate.
class ReactivateRequest(BaseModel):
    subscriptionId: str = Field(..., description="ID of the subscription to reactivate")


# Body of POST /force-cancel (sysadmin only).
class ForceCancelRequest(BaseModel):
    subscriptionId: str = Field(..., description="ID of the subscription to force-cancel")


# Body of POST /checkout/verify.
class VerifyCheckoutRequest(BaseModel):
    sessionId: str = Field(..., description="Stripe Checkout Session ID to verify")


# Response of GET /status.
class SubscriptionStatusResponse(BaseModel):
    active: bool
    subscription: Optional[Dict[str, Any]] = None
    plan: Optional[Dict[str, Any]] = None
    scheduled: Optional[Dict[str, Any]] = None


# =============================================================================
# Router
# =============================================================================

router = APIRouter(
    prefix="/api/subscription",
    tags=["Subscription"],
    responses={404: {"description": "Not found"}},
)


def _requireAdminMandateId(context: RequestContext) -> str:
    """Return the request's mandate ID after verifying the caller administers it.

    Raises:
        HTTPException: 400 when the request carries no mandate ID,
            403 when the caller is neither sysadmin nor mandate admin.
    """
    mandateId = _resolveMandateId(context)
    if not mandateId:
        raise HTTPException(status_code=400, detail="X-Mandate-Id header required")
    _assertMandateAdmin(context, mandateId)
    return mandateId


# =============================================================================
# Endpoints
# =============================================================================


@router.get("/plans", response_model=List[Dict[str, Any]])
@limiter.limit("30/minute")
def getPlans(request: Request, context: RequestContext = Depends(getRequestContext)):
    """Return the plans selectable for the current user/mandate as plain dicts.

    No mandate-admin check is applied; the mandate ID may be empty.
    """
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )

    try:
        mandateId = _resolveMandateId(context)
        subService = getSubscriptionService(context.user, mandateId)
        plans = subService.getSelectablePlans()
        return [p.model_dump() for p in plans]
    except Exception as e:
        logger.exception("Error fetching plans: %s", e)
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/status", response_model=SubscriptionStatusResponse)
@limiter.limit("60/minute")
def getStatus(request: Request, context: RequestContext = Depends(getRequestContext)):
    """Return the operative subscription and any scheduled successor for the current mandate.

    Without a mandate ID the response is just ``active=False``. When nothing is
    operative, the first PENDING subscription (if any) is reported instead,
    still with ``active=False``.
    """
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )

    mandateId = _resolveMandateId(context)
    if not mandateId:
        return SubscriptionStatusResponse(active=False)
    _assertMandateAdmin(context, mandateId)

    try:
        subService = getSubscriptionService(context.user, mandateId)
        operative = subService.getOperativeSubscription(mandateId)
        scheduled = subService.getScheduledSubscription(mandateId)

        if not operative:
            from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum

            pending = subService.listSubscriptions(mandateId, [SubscriptionStatusEnum.PENDING])
            if pending:
                sub = pending[0]
                plan = subService.getPlan(sub.get("planKey", ""))
                return SubscriptionStatusResponse(
                    active=False,
                    subscription=sub,
                    plan=plan.model_dump() if plan else None,
                    scheduled=scheduled,
                )
            return SubscriptionStatusResponse(active=False, scheduled=scheduled)

        plan = subService.getPlan(operative.get("planKey", ""))
        return SubscriptionStatusResponse(
            active=True,
            subscription=operative,
            plan=plan.model_dump() if plan else None,
            scheduled=scheduled,
        )
    except Exception as e:
        logger.exception("Error fetching status: %s", e)
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/activate", response_model=Dict[str, Any])
@limiter.limit("10/minute")
def activatePlan(
    request: Request,
    data: ActivatePlanRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Start activation (checkout) of ``data.planKey`` for the current mandate.

    Service-level ``ValueError`` maps to 400; any other failure maps to 500.
    """
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )

    mandateId = _requireAdminMandateId(context)

    try:
        subService = getSubscriptionService(context.user, mandateId)
        return subService.activatePlan(mandateId, data.planKey, returnUrl=data.returnUrl)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception("Error activating plan %s: %s", data.planKey, e)
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/cancel", response_model=Dict[str, Any])
@limiter.limit("5/minute")
def cancelSubscription(
    request: Request,
    data: CancelRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Cancel a specific subscription by its ID."""
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )

    mandateId = _requireAdminMandateId(context)

    try:
        subService = getSubscriptionService(context.user, mandateId)
        return subService.cancelSubscription(data.subscriptionId)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception("Error cancelling subscription %s: %s", data.subscriptionId, e)
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/reactivate", response_model=Dict[str, Any])
@limiter.limit("5/minute")
def reactivateSubscription(
    request: Request,
    data: ReactivateRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Reactivate a cancelled (non-recurring) subscription before its period ends."""
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )

    mandateId = _requireAdminMandateId(context)

    try:
        subService = getSubscriptionService(context.user, mandateId)
        return subService.reactivateSubscription(data.subscriptionId)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception("Error reactivating subscription %s: %s", data.subscriptionId, e)
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/force-cancel", response_model=Dict[str, Any])
@limiter.limit("5/minute")
def forceCancel(
    request: Request,
    data: ForceCancelRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Sysadmin: immediately expire any non-terminal subscription.

    The subscription is looked up through the root subscription interface, and
    the service is scoped to that subscription's own mandate rather than to the
    caller's mandate header.
    """
    if not context.hasSysAdminRole:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required")

    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )
    from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface

    sub = getSubRootInterface().getById(data.subscriptionId)
    if not sub:
        raise HTTPException(status_code=404, detail="Subscription not found")
    mandateId = sub["mandateId"]

    try:
        subService = getSubscriptionService(context.user, mandateId)
        return subService.forceCancel(data.subscriptionId)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception("Error force-cancelling subscription %s: %s", data.subscriptionId, e)
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/checkout/verify", response_model=Dict[str, Any])
@limiter.limit("20/minute")
def verifyCheckout(
    request: Request,
    data: VerifyCheckoutRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Verify a Stripe Checkout Session and activate the subscription if paid.

    This is the synchronous counterpart to the checkout.session.completed
    webhook. The frontend calls it right after returning from Stripe, to cover
    environments where webhooks may be delayed or unavailable (e.g. localhost
    dev). It relies on ``_handleSubscriptionCheckoutCompleted`` being
    idempotent, so a session already processed by the webhook should be a no-op.

    Returns ``{"status": "pending", ...}`` while the session is not both
    complete and paid, and ``{"status": "activated", ...}`` once it has been
    handed to the completion handler.
    """
    _requireAdminMandateId(context)
    # NOTE(review): the session is not matched against the caller's mandate;
    # presumably the completion handler resolves the mandate from session data —
    # confirm.

    try:
        from modules.shared.stripeClient import getStripeClient

        stripe = getStripeClient()
        session = stripe.checkout.Session.retrieve(data.sessionId)
    except Exception as e:
        logger.error("Failed to retrieve checkout session %s: %s", data.sessionId, e)
        raise HTTPException(status_code=400, detail="Invalid session ID")

    # NOTE(review): Stripe reports payment_status "no_payment_required" for e.g.
    # trial subscriptions; such sessions stay "pending" here — confirm intended.
    if session.get("status") != "complete" or session.get("payment_status") != "paid":
        return {"status": "pending", "message": "Checkout not yet completed"}

    if session.get("mode") != "subscription":
        raise HTTPException(status_code=400, detail="Not a subscription checkout session")

    try:
        from modules.routes.routeBilling import _handleSubscriptionCheckoutCompleted

        _handleSubscriptionCheckoutCompleted(session, f"verify-{data.sessionId}")
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error completing checkout session %s: %s", data.sessionId, e)
        raise HTTPException(status_code=500, detail=str(e))

    return {"status": "activated", "message": "Subscription activated"}


# =============================================================================
# SysAdmin: global subscription overview
# =============================================================================


def _buildEnrichedSubscriptions() -> List[Dict[str, Any]]:
    """Build the full enriched subscription list (shared by list + filter-values endpoints).

    Each subscription dict returned by ``listAll()`` is annotated in place with
    ``mandateName``, ``planTitle``, ``monthlyRevenueCHF``, ``activeUsers`` and
    ``activeInstances``. Revenue and counts are computed only for subscriptions
    whose status is operative; all others get zeros.
    """
    from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
    from modules.datamodels.datamodelSubscription import BUILTIN_PLANS, OPERATIVE_STATUSES

    subInterface = getSubRootInterface()
    allSubs = subInterface.listAll()

    # Resolve all mandate display names in one pass; on failure fall back to ID prefixes.
    mandateNames: Dict[str, str] = {}
    try:
        from modules.datamodels.datamodelUam import Mandate
        from modules.security.rootAccess import getRootDbAppConnector

        appDb = getRootDbAppConnector()
        for row in appDb.getRecordset(Mandate):
            r = dict(row)
            # "or" guards against an explicit None ID, which would break slicing.
            mid = r.get("id") or ""
            mandateNames[mid] = r.get("label") or r.get("name") or mid[:8]
    except Exception as e:
        logger.warning("Could not bulk-resolve mandate names: %s", e)

    operativeValues = {s.value for s in OPERATIVE_STATUSES}
    enriched = []
    for sub in allSubs:
        # "or" guards against an explicit None mandateId, which would break slicing.
        mid = sub.get("mandateId") or ""
        planKey = sub.get("planKey", "")
        plan = BUILTIN_PLANS.get(planKey)
        sub["mandateName"] = mandateNames.get(mid, mid[:8])
        sub["planTitle"] = (plan.title.get("de") or plan.title.get("en") or planKey) if plan else planKey

        if sub.get("status") in operativeValues:
            userPrice = sub.get("snapshotPricePerUserCHF", 0) or 0
            instPrice = sub.get("snapshotPricePerInstanceCHF", 0) or 0
            try:
                userCount = subInterface.countActiveUsers(mid)
                instanceCount = subInterface.countActiveFeatureInstances(mid)
            except Exception as e:
                # Best effort: one failing mandate must not break the overview.
                logger.warning("Could not count active users/instances for mandate %s: %s", mid, e)
                userCount = 0
                instanceCount = 0
            sub["monthlyRevenueCHF"] = round(userPrice * userCount + instPrice * instanceCount, 2)
            sub["activeUsers"] = userCount
            sub["activeInstances"] = instanceCount
        else:
            sub["monthlyRevenueCHF"] = 0
            sub["activeUsers"] = 0
            sub["activeInstances"] = 0
        enriched.append(sub)

    return enriched


def _parsePaginationParams(
    pagination: Optional[str],
    excludeColumn: Optional[str] = None,
) -> Optional[PaginationParams]:
    """Decode the JSON ``pagination`` query parameter into ``PaginationParams``.

    Args:
        pagination: Raw query-string value; may be None or empty.
        excludeColumn: When given, that column's filter is removed and any sort
            is dropped (used to compute cross-filtered distinct values).

    Returns:
        None when the parameter is absent or decodes to an empty value.

    Raises:
        HTTPException: 400 for malformed JSON, a non-object payload, or values
            that ``normalize_pagination_dict`` / ``PaginationParams`` reject.
    """
    if not pagination:
        return None
    try:
        paginationDict = json.loads(pagination)
        if not paginationDict:
            return None
        if not isinstance(paginationDict, dict):
            raise ValueError("expected a JSON object")
        paginationDict = normalize_pagination_dict(paginationDict)
        if excludeColumn is not None:
            # Copy so a null/missing "filters" value is tolerated.
            filters = dict(paginationDict.get("filters") or {})
            filters.pop(excludeColumn, None)
            paginationDict["filters"] = filters
            paginationDict.pop("sort", None)
        return PaginationParams(**paginationDict)
    except (json.JSONDecodeError, ValueError, TypeError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")


@router.get("/admin/all")
@limiter.limit("30/minute")
def getAllSubscriptions(
    request: Request,
    pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
    context: RequestContext = Depends(getRequestContext),
):
    """SysAdmin: list ALL subscriptions across all mandates with enriched metadata.

    With ``pagination`` the result is filtered, sorted and sliced (1-based
    ``page``) and accompanied by pagination metadata; without it every
    enriched subscription is returned and ``pagination`` is None.
    """
    if not context.hasSysAdminRole:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required")

    paginationParams = _parsePaginationParams(pagination)
    if paginationParams and (paginationParams.pageSize < 1 or paginationParams.page < 1):
        # Guards against division by zero and negative slice offsets below.
        raise HTTPException(status_code=400, detail="Invalid pagination parameter: page and pageSize must be >= 1")

    enriched = _buildEnrichedSubscriptions()
    filtered = _applyFiltersAndSort(enriched, paginationParams)

    if paginationParams:
        totalItems = len(filtered)
        totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0
        startIdx = (paginationParams.page - 1) * paginationParams.pageSize
        endIdx = startIdx + paginationParams.pageSize
        pageItems = filtered[startIdx:endIdx]
        return {
            "items": pageItems,
            "pagination": PaginationMetadata(
                currentPage=paginationParams.page,
                pageSize=paginationParams.pageSize,
                totalItems=totalItems,
                totalPages=totalPages,
                sort=paginationParams.sort,
                filters=paginationParams.filters,
            ).model_dump(),
        }

    return {"items": enriched, "pagination": None}


@router.get("/admin/all/filter-values")
@limiter.limit("60/minute")
def getFilterValues(
    request: Request,
    column: str = Query(..., description="Column key to extract distinct values for"),
    pagination: Optional[str] = Query(None, description="JSON-encoded current filters (applied except for the requested column)"),
    context: RequestContext = Depends(getRequestContext),
):
    """Return distinct values for a column, respecting all active filters except the requested one."""
    if not context.hasSysAdminRole:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required")

    crossFilterParams = _parsePaginationParams(pagination, excludeColumn=column)

    enriched = _buildEnrichedSubscriptions()
    crossFiltered = _applyFiltersAndSort(enriched, crossFilterParams)
    return _extractDistinctValues(crossFiltered, column)