gateway/modules/routes/routeSubscription.py

437 lines
18 KiB
Python

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
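Subscription routes — ID-based, state-machine-driven.
Endpoints:
- GET /api/subscription/plans — list selectable plans
- GET /api/subscription/status — operative + scheduled subscription for current mandate
- POST /api/subscription/activate — start checkout for a plan
- POST /api/subscription/cancel — cancel a specific subscription (by ID)
- POST /api/subscription/reactivate — reactivate a cancelled subscription (by ID)
- POST /api/subscription/force-cancel — sysadmin immediate cancel (by ID)
- POST /api/subscription/checkout/verify — verify a Stripe Checkout Session and activate the subscription once completed
- GET /api/subscription/admin/all — sysadmin: list all subscriptions across mandates (optionally paginated)
- GET /api/subscription/admin/all/filter-values — sysadmin: distinct column values for the list's filters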
"""
from fastapi import APIRouter, HTTPException, Depends, Request, Query
from fastapi import status
from typing import Dict, Any, List, Optional
import logging
import json
import math
from pydantic import BaseModel, Field
from modules.auth import limiter, getRequestContext, RequestContext
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
from modules.routes.routeDataUsers import _applyFiltersAndSort, _extractDistinctValues
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
def _resolveMandateId(context: RequestContext) -> str:
    """Return the request context's mandate ID as a string, or ``""`` when it is unset/falsy."""
    mandate = context.mandateId
    return str(mandate) if mandate else ""
def _assertMandateAdmin(context: RequestContext, mandateId: str) -> None:
    """Ensure the caller may administer ``mandateId``; raise 403 otherwise.

    Sysadmins always pass. Any other user passes when one of their *enabled*
    user-mandate memberships for ``mandateId`` has a role whose ``roleLabel`` is
    ``"admin"`` and which is not tied to a feature instance (empty
    ``featureInstanceId``).

    Any error while loading memberships or roles (including a missing user) is
    logged and treated as "not authorised", so the check fails closed.

    Args:
        context: Request context of the caller.
        mandateId: Mandate the caller wants to administer.

    Raises:
        HTTPException: 403 when no qualifying admin role is found.
    """
    if context.hasSysAdminRole:
        return
    try:
        from modules.interfaces.interfaceDbApp import getRootInterface
        rootInterface = getRootInterface()
        targetMandateId = str(mandateId)
        for um in rootInterface.getUserMandates(str(context.user.id)):
            if str(getattr(um, "mandateId", None)) != targetMandateId:
                continue
            if not getattr(um, "enabled", True):
                continue
            umId = str(getattr(um, "id", ""))
            for roleId in rootInterface.getRoleIdsForUserMandate(umId):
                role = rootInterface.getRole(roleId)
                if role and role.roleLabel == "admin" and not role.featureInstanceId:
                    return
    except Exception:
        # Deliberately fail closed, but keep the cause visible instead of silently
        # turning infrastructure errors into unexplained 403s.
        logger.exception("Mandate admin check failed for mandate %s", mandateId)
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Mandate admin role required")
# =============================================================================
# Request / Response models
# =============================================================================
class ActivatePlanRequest(BaseModel):
    # Request body of POST /activate: both fields are required.
    planKey: str = Field(default=..., description="Key of the plan to activate")
    returnUrl: str = Field(default=..., description="Frontend URL to redirect back to after Stripe Checkout")
class CancelRequest(BaseModel):
    # Request body of POST /cancel (required ID).
    subscriptionId: str = Field(default=..., description="ID of the subscription to cancel")
class ReactivateRequest(BaseModel):
    # Request body of POST /reactivate (required ID).
    subscriptionId: str = Field(default=..., description="ID of the subscription to reactivate")
class ForceCancelRequest(BaseModel):
    # Request body of POST /force-cancel (sysadmin only; required ID).
    subscriptionId: str = Field(default=..., description="ID of the subscription to force-cancel")
class VerifyCheckoutRequest(BaseModel):
    # Request body of POST /checkout/verify (required Stripe session ID).
    sessionId: str = Field(default=..., description="Stripe Checkout Session ID to verify")
class SubscriptionStatusResponse(BaseModel):
    # Response of GET /status.
    # True only when the mandate has an operative subscription.
    active: bool = Field(...)
    # The operative subscription, or a pending one when nothing is operative.
    subscription: Optional[Dict[str, Any]] = Field(default=None)
    # model_dump() of the plan referenced by ``subscription``, when it resolves.
    plan: Optional[Dict[str, Any]] = Field(default=None)
    # Scheduled successor subscription, if any.
    scheduled: Optional[Dict[str, Any]] = Field(default=None)
# =============================================================================
# Router
# =============================================================================
# Every endpoint below is mounted under /api/subscription and grouped under the
# "Subscription" tag in the generated OpenAPI docs.
router = APIRouter(
    tags=["Subscription"],
    prefix="/api/subscription",
    responses={404: {"description": "Not found"}},
)
# =============================================================================
# Endpoints
# =============================================================================
@router.get("/plans", response_model=List[Dict[str, Any]])
@limiter.limit("30/minute")
def getPlans(request: Request, context: RequestContext = Depends(getRequestContext)):
    """List the plans the current user may select.

    Unlike the other mandate-scoped endpoints this neither requires a mandate
    header nor the mandate admin role; the (possibly empty) mandate ID from the
    request context is passed to the subscription service as-is.

    Returns:
        The selectable plans, each serialised with ``model_dump()``.

    Raises:
        HTTPException: 500 when the subscription service fails.
    """
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )
    try:
        subService = getSubscriptionService(context.user, _resolveMandateId(context))
        return [plan.model_dump() for plan in subService.getSelectablePlans()]
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Error fetching plans: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/status", response_model=SubscriptionStatusResponse)
@limiter.limit("60/minute")
def getStatus(request: Request, context: RequestContext = Depends(getRequestContext)):
    """Return the operative subscription and any scheduled successor for the current mandate.

    Without a mandate in the request context, ``active=False`` is returned without
    further checks. Otherwise the caller must be a mandate admin. When no
    subscription is operative, the first PENDING subscription (if any) is reported
    with ``active=False``.

    Raises:
        HTTPException: 403 for non-admins, 500 when the subscription service fails.
    """
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )
    mandateId = _resolveMandateId(context)
    if not mandateId:
        return SubscriptionStatusResponse(active=False)
    _assertMandateAdmin(context, mandateId)
    try:
        subService = getSubscriptionService(context.user, mandateId)
        operative = subService.getOperativeSubscription(mandateId)
        scheduled = subService.getScheduledSubscription(mandateId)
        if operative:
            current, active = operative, True
        else:
            from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum
            pending = subService.listSubscriptions(mandateId, [SubscriptionStatusEnum.PENDING])
            if not pending:
                return SubscriptionStatusResponse(active=False, scheduled=scheduled)
            # A pending subscription (e.g. checkout in progress) is shown but not active.
            current, active = pending[0], False
        plan = subService.getPlan(current.get("planKey", ""))
        return SubscriptionStatusResponse(
            active=active,
            subscription=current,
            plan=plan.model_dump() if plan else None,
            scheduled=scheduled,
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Error fetching status: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/activate", response_model=Dict[str, Any])
@limiter.limit("10/minute")
def activatePlan(
    request: Request,
    data: ActivatePlanRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Start a Stripe checkout for ``data.planKey`` on the current mandate.

    Requires a mandate in the request context and the mandate admin role. The
    subscription service's result is returned unchanged.

    Raises:
        HTTPException: 400 when no mandate is set or the service raises
            ``ValueError``; 403 for non-admins; 500 for any other service failure.
    """
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )
    mandateId = _resolveMandateId(context)
    if not mandateId:
        raise HTTPException(status_code=400, detail="X-Mandate-Id header required")
    _assertMandateAdmin(context, mandateId)
    try:
        subService = getSubscriptionService(context.user, mandateId)
        return subService.activatePlan(mandateId, data.planKey, returnUrl=data.returnUrl)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Error activating plan %s: %s", data.planKey, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/cancel", response_model=Dict[str, Any])
@limiter.limit("5/minute")
def cancelSubscription(
    request: Request,
    data: CancelRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Cancel a specific subscription of the current mandate by its ID.

    The caller must be a mandate admin of the mandate in the request context, and
    the subscription must belong to that same mandate.

    Raises:
        HTTPException: 400 when no mandate is set or the service raises
            ``ValueError``; 403 for non-admins; 404 when the subscription does not
            exist or belongs to another mandate; 500 for any other service failure.
    """
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )
    from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface

    mandateId = _resolveMandateId(context)
    if not mandateId:
        raise HTTPException(status_code=400, detail="X-Mandate-Id header required")
    _assertMandateAdmin(context, mandateId)
    # The admin check only covers the header mandate, so bind the requested ID to it;
    # otherwise an admin of one mandate could cancel another mandate's subscription.
    # Kept outside the try below so the 404 is not converted into a 500.
    sub = getSubRootInterface().getById(data.subscriptionId)
    if not sub or str(sub.get("mandateId")) != mandateId:
        raise HTTPException(status_code=404, detail="Subscription not found")
    try:
        subService = getSubscriptionService(context.user, mandateId)
        return subService.cancelSubscription(data.subscriptionId)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Error cancelling subscription %s: %s", data.subscriptionId, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/reactivate", response_model=Dict[str, Any])
@limiter.limit("5/minute")
def reactivateSubscription(
    request: Request,
    data: ReactivateRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Reactivate a cancelled (non-recurring) subscription before its period ends.

    The caller must be a mandate admin of the mandate in the request context, and
    the subscription must belong to that same mandate.

    Raises:
        HTTPException: 400 when no mandate is set or the service raises
            ``ValueError``; 403 for non-admins; 404 when the subscription does not
            exist or belongs to another mandate; 500 for any other service failure.
    """
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )
    from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface

    mandateId = _resolveMandateId(context)
    if not mandateId:
        raise HTTPException(status_code=400, detail="X-Mandate-Id header required")
    _assertMandateAdmin(context, mandateId)
    # The admin check only covers the header mandate, so bind the requested ID to it;
    # otherwise an admin of one mandate could reactivate another mandate's subscription.
    # Kept outside the try below so the 404 is not converted into a 500.
    sub = getSubRootInterface().getById(data.subscriptionId)
    if not sub or str(sub.get("mandateId")) != mandateId:
        raise HTTPException(status_code=404, detail="Subscription not found")
    try:
        subService = getSubscriptionService(context.user, mandateId)
        return subService.reactivateSubscription(data.subscriptionId)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Error reactivating subscription %s: %s", data.subscriptionId, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/force-cancel", response_model=Dict[str, Any])
@limiter.limit("5/minute")
def forceCancel(
    request: Request,
    data: ForceCancelRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Sysadmin: immediately expire any non-terminal subscription.

    The subscription is looked up across all mandates and the service is built
    for the subscription's own mandate, so no mandate header is needed.

    Raises:
        HTTPException: 403 for non-sysadmins; 404 when the ID is unknown; 400 when
            the service raises ``ValueError``; 500 for any other service failure.
    """
    if not context.hasSysAdminRole:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required")
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )
    from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface

    sub = getSubRootInterface().getById(data.subscriptionId)
    if not sub:
        raise HTTPException(status_code=404, detail="Subscription not found")
    mandateId = sub["mandateId"]
    try:
        subService = getSubscriptionService(context.user, mandateId)
        return subService.forceCancel(data.subscriptionId)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Error force-cancelling subscription %s: %s", data.subscriptionId, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/checkout/verify", response_model=Dict[str, Any])
@limiter.limit("20/minute")
def verifyCheckout(
    request: Request,
    data: VerifyCheckoutRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Verify a Stripe Checkout Session and activate the subscription once it completed.

    This is the synchronous counterpart to the ``checkout.session.completed`` webhook.
    The frontend calls it right after returning from Stripe so that activation also
    works where webhooks are delayed or unavailable (e.g. localhost dev). Activation
    is delegated to the same handler the webhook uses, which is intended to be
    idempotent: if the webhook already processed the session, this is a no-op.

    Returns:
        ``{"status": "pending", ...}`` while the session is not complete, otherwise
        ``{"status": "activated", ...}`` after the handler ran.

    Raises:
        HTTPException: 400 when no mandate is set, the session ID cannot be
            retrieved, or the session is not a subscription checkout; 403 for
            non-admins; 500 when the activation handler fails.
    """
    mandateId = _resolveMandateId(context)
    if not mandateId:
        raise HTTPException(status_code=400, detail="X-Mandate-Id header required")
    _assertMandateAdmin(context, mandateId)
    try:
        from modules.shared.stripeClient import getStripeClient
        stripe = getStripeClient()
        session = stripe.checkout.Session.retrieve(data.sessionId)
    except Exception as e:
        logger.error("Failed to retrieve checkout session %s: %s", data.sessionId, e)
        raise HTTPException(status_code=400, detail="Invalid session ID") from e
    # Reject non-subscription sessions up front so they are never reported as "pending".
    if session.get("mode") != "subscription":
        raise HTTPException(status_code=400, detail="Not a subscription checkout session")
    # Stripe reports payment_status "no_payment_required" for completed subscription
    # checkouts that charge nothing up front (e.g. a trial or a 100% discount); the
    # checkout.session.completed webhook fires for those too, so treat them as settled.
    if session.get("status") != "complete" or session.get("payment_status") not in ("paid", "no_payment_required"):
        return {"status": "pending", "message": "Checkout not yet completed"}
    # NOTE(review): the session is not checked against ``mandateId``; consider verifying
    # that the session's metadata/customer belongs to this mandate before activating.
    from modules.routes.routeBilling import _handleSubscriptionCheckoutCompleted
    try:
        _handleSubscriptionCheckoutCompleted(session, f"verify-{data.sessionId}")
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error activating subscription for checkout session %s: %s", data.sessionId, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "activated", "message": "Subscription activated"}
# =============================================================================
# SysAdmin: global subscription overview
# =============================================================================
def _buildEnrichedSubscriptions() -> List[Dict[str, Any]]:
    """Build the enriched list of all subscriptions across mandates.

    Shared by the ``/admin/all`` list and ``/admin/all/filter-values`` endpoints.
    Each record from the subscription root interface's ``listAll()`` is copied and
    extended with:

    - ``mandateName``: the mandate's label, else its name, else the first 8
      characters of the mandate ID.
    - ``planTitle``: the builtin plan's ``"de"`` title, else ``"en"``, else the plan key.
    - ``monthlyRevenueCHF`` / ``activeUsers`` / ``activeInstances``: for subscriptions
      in an operative status, computed from the snapshotted per-user and
      per-instance prices and the mandate's current counts; ``0`` otherwise.

    Name and count lookups are best effort: failures are logged and fall back to
    the shortened ID and to zero counts respectively.
    """
    from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
    from modules.datamodels.datamodelSubscription import BUILTIN_PLANS, OPERATIVE_STATUSES

    subInterface = getSubRootInterface()
    allSubs = subInterface.listAll()

    # Resolve all mandate display names in one recordset read instead of per subscription.
    mandateNames: Dict[str, str] = {}
    try:
        from modules.datamodels.datamodelUam import Mandate
        from modules.security.rootAccess import getRootDbAppConnector
        appDb = getRootDbAppConnector()
        for row in appDb.getRecordset(Mandate):
            r = dict(row)
            # `or ""` guards against a None id, which would crash the slice below.
            mid = str(r.get("id") or "")
            mandateNames[mid] = r.get("label") or r.get("name") or mid[:8]
    except Exception as e:
        logger.warning("Could not bulk-resolve mandate names: %s", e)

    # (activeUsers, activeInstances) per mandate, fetched at most once per mandate.
    usageByMandate: Dict[str, tuple] = {}

    def _usageFor(mid: str) -> tuple:
        """Return cached ``(activeUsers, activeInstances)`` for a mandate; ``(0, 0)`` on error."""
        if mid not in usageByMandate:
            try:
                usageByMandate[mid] = (
                    subInterface.countActiveUsers(mid),
                    subInterface.countActiveFeatureInstances(mid),
                )
            except Exception as e:
                logger.warning("Could not count active usage for mandate %s: %s", mid, e)
                usageByMandate[mid] = (0, 0)
        return usageByMandate[mid]

    operativeValues = {s.value for s in OPERATIVE_STATUSES}
    enriched: List[Dict[str, Any]] = []
    for record in allSubs:
        # Work on a copy so records returned by listAll() are not mutated in place.
        sub = dict(record)
        # `or ""` guards against a stored None mandateId, which crashed `mid[:8]`.
        mid = str(sub.get("mandateId") or "")
        planKey = sub.get("planKey", "")
        plan = BUILTIN_PLANS.get(planKey)
        sub["mandateName"] = mandateNames.get(mid, mid[:8])
        sub["planTitle"] = (plan.title.get("de") or plan.title.get("en") or planKey) if plan else planKey
        if sub.get("status") in operativeValues:
            userCount, instanceCount = _usageFor(mid)
            userPrice = sub.get("snapshotPricePerUserCHF", 0) or 0
            instPrice = sub.get("snapshotPricePerInstanceCHF", 0) or 0
            sub["monthlyRevenueCHF"] = round(userPrice * userCount + instPrice * instanceCount, 2)
            sub["activeUsers"] = userCount
            sub["activeInstances"] = instanceCount
        else:
            sub["monthlyRevenueCHF"] = 0
            sub["activeUsers"] = 0
            sub["activeInstances"] = 0
        enriched.append(sub)
    return enriched
@router.get("/admin/all")
@limiter.limit("30/minute")
def getAllSubscriptions(
    request: Request,
    pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
    context: RequestContext = Depends(getRequestContext),
):
    """SysAdmin: list ALL subscriptions across all mandates with enriched metadata.

    Without ``pagination`` (or with an empty JSON value) every enriched subscription
    is returned and the response's ``pagination`` is ``None``. Otherwise filters and
    sort are applied and only the requested page is returned, with metadata.

    Raises:
        HTTPException: 403 for non-sysadmins; 400 when ``pagination`` is not a
            valid JSON object, fails ``PaginationParams`` validation, or requests a
            page/pageSize below 1.
    """
    if not context.hasSysAdminRole:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required")
    paginationParams: Optional[PaginationParams] = None
    if pagination:
        try:
            paginationDict = json.loads(pagination)
            # A non-object JSON value (list, number, string) previously crashed in
            # normalize/PaginationParams(**...) with an unhandled 500.
            if paginationDict and not isinstance(paginationDict, dict):
                raise ValueError("expected a JSON object")
            if paginationDict:
                paginationDict = normalize_pagination_dict(paginationDict)
                paginationParams = PaginationParams(**paginationDict)
        except (json.JSONDecodeError, ValueError) as e:
            raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}") from e
    # Guard against division by zero / negative slice offsets in case PaginationParams
    # does not already enforce these bounds; checked before the costly enrichment.
    if paginationParams is not None and (paginationParams.page < 1 or paginationParams.pageSize < 1):
        raise HTTPException(
            status_code=400,
            detail="Invalid pagination parameter: page and pageSize must be >= 1",
        )

    enriched = _buildEnrichedSubscriptions()
    if paginationParams is None:
        return {"items": enriched, "pagination": None}

    page = paginationParams.page
    pageSize = paginationParams.pageSize
    filtered = _applyFiltersAndSort(enriched, paginationParams)
    totalItems = len(filtered)
    totalPages = math.ceil(totalItems / pageSize)
    startIdx = (page - 1) * pageSize
    return {
        "items": filtered[startIdx:startIdx + pageSize],
        "pagination": PaginationMetadata(
            currentPage=page,
            pageSize=pageSize,
            totalItems=totalItems,
            totalPages=totalPages,
            sort=paginationParams.sort,
            filters=paginationParams.filters,
        ).model_dump(),
    }
@router.get("/admin/all/filter-values")
@limiter.limit("60/minute")
def getFilterValues(
    request: Request,
    column: str = Query(..., description="Column key to extract distinct values for"),
    pagination: Optional[str] = Query(None, description="JSON-encoded current filters (applied except for the requested column)"),
    context: RequestContext = Depends(getRequestContext),
):
    """Return distinct values for a column, respecting all active filters except the requested one.

    The filter on ``column`` itself and any sort are dropped from ``pagination``
    before filtering, so the values are restricted only by the *other* filters.

    Raises:
        HTTPException: 403 for non-sysadmins; 400 when ``pagination`` is not a
            valid JSON object, its ``filters`` is not an object, or it fails
            ``PaginationParams`` validation.
    """
    if not context.hasSysAdminRole:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required")
    crossFilterParams: Optional[PaginationParams] = None
    if pagination:
        try:
            paginationDict = json.loads(pagination)
            # A non-object JSON value previously crashed with an unhandled 500.
            if paginationDict and not isinstance(paginationDict, dict):
                raise ValueError("expected a JSON object")
            if paginationDict:
                paginationDict = normalize_pagination_dict(paginationDict)
                # `or {}` also covers an explicit `"filters": null`, which crashed `.pop`.
                filters = paginationDict.get("filters") or {}
                if not isinstance(filters, dict):
                    raise ValueError("'filters' must be a JSON object")
                # Build a new dict instead of mutating the parsed one in place.
                paginationDict["filters"] = {key: value for key, value in filters.items() if key != column}
                paginationDict.pop("sort", None)
                crossFilterParams = PaginationParams(**paginationDict)
        except (json.JSONDecodeError, ValueError) as e:
            raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}") from e
    enriched = _buildEnrichedSubscriptions()
    crossFiltered = _applyFiltersAndSort(enriched, crossFilterParams)
    return _extractDistinctValues(crossFiltered, column)