531 lines
22 KiB
Python
531 lines
22 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""
|
|
Subscription routes — ID-based, state-machine-driven.
|
|
|
|
Endpoints:
|
|
- GET /api/subscription/plans — list selectable plans
|
|
- GET /api/subscription/status — operative + scheduled subscription for current mandate
|
|
- POST /api/subscription/activate — start checkout for a plan
|
|
- POST /api/subscription/cancel — cancel a specific subscription (by ID)
|
|
- POST /api/subscription/reactivate — reactivate a cancelled subscription (by ID)
|
|
- POST /api/subscription/force-cancel — sysadmin immediate cancel (by ID)
|
|
"""
|
|
|
|
from fastapi import APIRouter, HTTPException, Depends, Request, Query, Path
|
|
from fastapi import status
|
|
from typing import Dict, Any, List, Optional
|
|
import logging
|
|
import json
|
|
import math
|
|
from pydantic import BaseModel, Field
|
|
|
|
from modules.auth import limiter, getRequestContext, RequestContext
|
|
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
|
|
from modules.routes.routeDataUsers import _applyFiltersAndSort, _extractDistinctValues
|
|
|
|
# Module-level logger shared by all route handlers and helpers in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _resolveMandateId(context: RequestContext) -> str:
    """Return the request's mandate ID as a string, or "" if none is set.

    Any falsy ``context.mandateId`` (None, empty string) is normalised to
    the empty string, so callers can rely on a plain truthiness check.
    """
    rawMandateId = context.mandateId
    return str(rawMandateId) if rawMandateId else ""
|
|
|
|
|
|
def _assertMandateAdmin(context: RequestContext, mandateId: str) -> None:
    """Ensure the caller may administer ``mandateId``; raise 403 otherwise.

    Access is granted when the caller has the sysadmin role, or when one of
    the caller's enabled user-mandate memberships for ``mandateId`` carries a
    role labelled ``"admin"`` that has no ``featureInstanceId`` (i.e. is not
    scoped to a single feature instance).

    Errors while looking up memberships or roles are logged and treated as
    "not authorized" (fail closed).

    Raises:
        HTTPException: 403 if no qualifying admin role is found.
    """
    if context.hasSysAdminRole:
        return

    try:
        from modules.interfaces.interfaceDbApp import getRootInterface

        rootInterface = getRootInterface()
        targetMandateId = str(mandateId)
        for um in rootInterface.getUserMandates(str(context.user.id)):
            if str(getattr(um, "mandateId", None)) != targetMandateId:
                continue
            # Disabled memberships grant no rights.
            if not getattr(um, "enabled", True):
                continue
            umId = str(getattr(um, "id", ""))
            for roleId in rootInterface.getRoleIdsForUserMandate(umId):
                role = rootInterface.getRole(roleId)
                # Feature-instance-scoped "admin" roles do not count.
                if role and role.roleLabel == "admin" and not role.featureInstanceId:
                    return
    except Exception:
        # Still deny access, but leave a trace so lookup failures are not
        # indistinguishable from genuine permission errors.
        logger.exception("Mandate admin check failed for mandate %s", mandateId)

    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Mandate admin role required")
|
|
|
|
|
|
# =============================================================================
|
|
# Request / Response models
|
|
# =============================================================================
|
|
|
|
class ActivatePlanRequest(BaseModel):
    # Request body for POST /api/subscription/activate.
    planKey: str = Field(..., description="Key of the plan to activate")
    # Forwarded to the subscription service's activatePlan(returnUrl=...).
    # NOTE(review): confirm the service validates this against allowed
    # frontend origins before handing it to Stripe (open-redirect risk).
    returnUrl: str = Field(..., description="Frontend URL to redirect back to after Stripe Checkout")
|
|
|
|
class CancelRequest(BaseModel):
    # Request body for POST /api/subscription/cancel.
    subscriptionId: str = Field(..., description="ID of the subscription to cancel")
|
|
|
|
class ReactivateRequest(BaseModel):
    # Request body for POST /api/subscription/reactivate.
    subscriptionId: str = Field(..., description="ID of the subscription to reactivate")
|
|
|
|
class ForceCancelRequest(BaseModel):
    # Request body for POST /api/subscription/force-cancel (sysadmin only).
    subscriptionId: str = Field(..., description="ID of the subscription to force-cancel")
|
|
|
|
class VerifyCheckoutRequest(BaseModel):
    # Request body for POST /api/subscription/checkout/verify.
    # The ID is passed straight to Stripe; any retrieval failure becomes a
    # 400 "Invalid session ID" response.
    sessionId: str = Field(..., description="Stripe Checkout Session ID to verify")
|
|
|
|
class SubscriptionStatusResponse(BaseModel):
    # Response of GET /api/subscription/status.
    # True only when the mandate has an operative subscription.
    active: bool
    # The operative subscription, or the first pending one when nothing is operative.
    subscription: Optional[Dict[str, Any]] = None
    # model_dump() of the plan matching subscription["planKey"], or None if unknown.
    plan: Optional[Dict[str, Any]] = None
    # Scheduled successor subscription, if any.
    scheduled: Optional[Dict[str, Any]] = None
|
|
|
|
|
|
# =============================================================================
|
|
# Router
|
|
# =============================================================================
|
|
|
|
# Every endpoint in this module is mounted under /api/subscription and
# grouped under the "Subscription" tag in the generated OpenAPI docs.
router = APIRouter(
    tags=["Subscription"],
    prefix="/api/subscription",
    responses={404: {"description": "Not found"}},
)
|
|
|
|
|
|
# =============================================================================
|
|
# Endpoints
|
|
# =============================================================================
|
|
|
|
@router.get("/plans", response_model=List[Dict[str, Any]])
@limiter.limit("30/minute")
def getPlans(request: Request, context: RequestContext = Depends(getRequestContext)):
    """List the plans returned by the subscription service's ``getSelectablePlans()``.

    The mandate from the request context (possibly empty) scopes the service;
    no mandate-admin check is applied here.

    Returns:
        One ``model_dump()`` dict per selectable plan.

    Raises:
        HTTPException: 500 on any service error (message passed through).
    """
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )

    try:
        service = getSubscriptionService(context.user, _resolveMandateId(context))
        dumpedPlans = []
        for planModel in service.getSelectablePlans():
            dumpedPlans.append(planModel.model_dump())
        return dumpedPlans
    except Exception as e:
        logger.error("Error fetching plans: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@router.get("/status", response_model=SubscriptionStatusResponse)
@limiter.limit("60/minute")
def getStatus(request: Request, context: RequestContext = Depends(getRequestContext)):
    """Return the operative subscription and any scheduled successor for the current mandate.

    Resolution order:
      1. an operative subscription -> ``active=True`` with that subscription;
      2. otherwise the first PENDING subscription -> ``active=False`` with it;
      3. otherwise ``active=False`` with only ``scheduled`` populated.
    ``scheduled`` always carries the scheduled subscription (if any).

    Without a mandate in the request context this returns ``active=False``
    immediately; otherwise the caller must be a mandate admin.

    Raises:
        HTTPException: 403 for non-admins; 500 on service errors.
    """
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )

    mandateId = _resolveMandateId(context)
    if not mandateId:
        return SubscriptionStatusResponse(active=False)
    _assertMandateAdmin(context, mandateId)

    try:
        service = getSubscriptionService(context.user, mandateId)
        operative = service.getOperativeSubscription(mandateId)
        scheduled = service.getScheduledSubscription(mandateId)

        shown = operative
        if not shown:
            # Nothing operative yet: surface a pending subscription, if any.
            from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum

            pendingSubs = service.listSubscriptions(mandateId, [SubscriptionStatusEnum.PENDING])
            if not pendingSubs:
                return SubscriptionStatusResponse(active=False, scheduled=scheduled)
            shown = pendingSubs[0]

        matchedPlan = service.getPlan(shown.get("planKey", ""))
        return SubscriptionStatusResponse(
            active=bool(operative),
            subscription=shown,
            plan=matchedPlan.model_dump() if matchedPlan else None,
            scheduled=scheduled,
        )
    except Exception as e:
        logger.error("Error fetching status: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@router.post("/activate", response_model=Dict[str, Any])
@limiter.limit("10/minute")
def activatePlan(
    request: Request,
    data: ActivatePlanRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Start checkout for ``data.planKey`` on the current mandate.

    Returns the subscription service's ``activatePlan`` result unchanged.

    Raises:
        HTTPException: 400 without a mandate or when the service raises
        ``ValueError``; 403 for non-admins; 500 on any other error.
    """
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )

    mandateId = _resolveMandateId(context)
    if not mandateId:
        raise HTTPException(status_code=400, detail="X-Mandate-Id header required")
    _assertMandateAdmin(context, mandateId)

    try:
        service = getSubscriptionService(context.user, mandateId)
        result = service.activatePlan(mandateId, data.planKey, returnUrl=data.returnUrl)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error("Error activating plan %s: %s", data.planKey, e)
        raise HTTPException(status_code=500, detail=str(e))
    return result
|
|
|
|
|
|
@router.post("/cancel", response_model=Dict[str, Any])
@limiter.limit("30/minute")
def cancelSubscription(
    request: Request,
    data: CancelRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Cancel a specific subscription of the current mandate by its ID.

    The caller must be an admin of the mandate in the request context, and
    the subscription must belong to that mandate.

    Returns:
        The subscription service's ``cancelSubscription`` result.

    Raises:
        HTTPException: 400 without a mandate or on ``ValueError`` from the
        service; 403 for non-admins; 404 if the subscription does not exist
        or belongs to a different mandate; 500 on any other error.
    """
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )
    from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface

    mandateId = _resolveMandateId(context)
    if not mandateId:
        raise HTTPException(status_code=400, detail="X-Mandate-Id header required")
    _assertMandateAdmin(context, mandateId)

    # The admin check only covers the mandate from the request context, while
    # the subscription ID comes from the body: make sure they match before
    # acting, so an admin of one mandate cannot cancel another mandate's
    # subscription. 404 (not 403) avoids confirming foreign IDs exist.
    sub = getSubRootInterface().getById(data.subscriptionId)
    if not sub or str(sub["mandateId"]) != mandateId:
        raise HTTPException(status_code=404, detail="Subscription not found")

    try:
        subService = getSubscriptionService(context.user, mandateId)
        return subService.cancelSubscription(data.subscriptionId)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error("Error cancelling subscription %s: %s", data.subscriptionId, e)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@router.post("/reactivate", response_model=Dict[str, Any])
@limiter.limit("30/minute")
def reactivateSubscription(
    request: Request,
    data: ReactivateRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Reactivate a cancelled (non-recurring) subscription before its period ends.

    The caller must be an admin of the mandate in the request context, and
    the subscription must belong to that mandate.

    Returns:
        The subscription service's ``reactivateSubscription`` result.

    Raises:
        HTTPException: 400 without a mandate or on ``ValueError`` from the
        service; 403 for non-admins; 404 if the subscription does not exist
        or belongs to a different mandate; 500 on any other error.
    """
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )
    from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface

    mandateId = _resolveMandateId(context)
    if not mandateId:
        raise HTTPException(status_code=400, detail="X-Mandate-Id header required")
    _assertMandateAdmin(context, mandateId)

    # The subscription ID comes from the body, not from the authorized
    # mandate: verify ownership so cross-mandate reactivation is impossible.
    sub = getSubRootInterface().getById(data.subscriptionId)
    if not sub or str(sub["mandateId"]) != mandateId:
        raise HTTPException(status_code=404, detail="Subscription not found")

    try:
        subService = getSubscriptionService(context.user, mandateId)
        return subService.reactivateSubscription(data.subscriptionId)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error("Error reactivating subscription %s: %s", data.subscriptionId, e)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@router.post("/force-cancel", response_model=Dict[str, Any])
@limiter.limit("30/minute")
def forceCancel(
    request: Request,
    data: ForceCancelRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Sysadmin: immediately expire any non-terminal subscription.

    Unlike the mandate-scoped endpoints, the mandate is read from the
    subscription record itself, not from the request context.

    Raises:
        HTTPException: 403 for non-sysadmins; 404 if the subscription does
        not exist; 400 on ``ValueError`` from the service; 500 otherwise.
    """
    if not context.hasSysAdminRole:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required")

    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )
    from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface

    targetSub = getSubRootInterface().getById(data.subscriptionId)
    if not targetSub:
        raise HTTPException(status_code=404, detail="Subscription not found")
    ownerMandateId = targetSub["mandateId"]

    try:
        service = getSubscriptionService(context.user, ownerMandateId)
        outcome = service.forceCancel(data.subscriptionId)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error("Error force-cancelling subscription %s: %s", data.subscriptionId, e)
        raise HTTPException(status_code=500, detail=str(e))
    return outcome
|
|
|
|
|
|
@router.post("/checkout/verify", response_model=Dict[str, Any])
@limiter.limit("20/minute")
def verifyCheckout(
    request: Request,
    data: VerifyCheckoutRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Check a Stripe Checkout Session and activate its subscription when paid.

    Meant to be called by the frontend right after returning from Stripe.
    It reuses routeBilling's checkout-completed handler, so it is idempotent:
    if the webhook already processed the session, it still reports success.

    Returns:
        ``{"status": "activated", ...}`` once the mandate has an operative
        subscription, otherwise ``{"status": "pending", ...}``.

    Raises:
        HTTPException: 400 without a mandate, when the session cannot be
        retrieved, or when it is not a subscription session; 403 for
        non-admins.
    """
    mandateId = _resolveMandateId(context)
    if not mandateId:
        raise HTTPException(status_code=400, detail="X-Mandate-Id header required")
    _assertMandateAdmin(context, mandateId)

    try:
        from modules.shared.stripeClient import getStripeClient, stripeToDict

        stripeApi = getStripeClient()
        session = stripeToDict(stripeApi.checkout.Session.retrieve(data.sessionId))
    except Exception as e:
        logger.error("Failed to retrieve checkout session %s: %s", data.sessionId, e)
        raise HTTPException(status_code=400, detail="Invalid session ID")

    # Both the session and its payment must be settled before activating.
    isComplete = session.get("status") == "complete"
    isPaid = session.get("payment_status") in ("paid", "no_payment_required")
    if not (isComplete and isPaid):
        return {"status": "pending", "message": "Checkout not yet completed"}

    if session.get("mode") != "subscription":
        raise HTTPException(status_code=400, detail="Not a subscription checkout session")

    # NOTE(review): the session is not checked against mandateId here —
    # confirm the handler resolves the target mandate from the session itself.
    from modules.routes.routeBilling import _handleSubscriptionCheckoutCompleted

    try:
        _handleSubscriptionCheckoutCompleted(session, f"verify-{data.sessionId}")
    except Exception as e:
        # Best effort: the operative-subscription check below decides the outcome.
        logger.warning(
            "verifyCheckout: handler raised for session %s mandate %s: %s",
            data.sessionId,
            mandateId,
            e,
        )

    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )
    from modules.datamodels.datamodelSubscription import OPERATIVE_STATUSES

    current = getSubscriptionService(context.user, mandateId).getOperativeSubscription(mandateId)
    operativeValues = tuple(s.value for s in OPERATIVE_STATUSES)
    if not current or current.get("status") not in operativeValues:
        return {"status": "pending", "message": "Subscription activation pending — webhook may still be processing."}

    activePlanKey = current.get("planKey", "")
    if activePlanKey:
        try:
            from modules.interfaces.interfaceDbBilling import _getRootInterface as _getBillingRoot

            _getBillingRoot().ensureActivationBudget(mandateId, activePlanKey)
        except Exception as ex:
            logger.warning("verifyCheckout: ensureActivationBudget failed: %s", ex)
    return {"status": "activated", "message": "Subscription activated"}
|
|
|
|
|
|
# =============================================================================
|
|
# SysAdmin: global subscription overview
|
|
# =============================================================================
|
|
|
|
def _buildEnrichedSubscriptions() -> List[Dict[str, Any]]:
    """Build the full enriched subscription list (shared by list + filter-values endpoints).

    Every subscription dict returned by the subscription root interface's
    ``listAll()`` is extended in place with:

    - ``mandateName``: mandate label, else name, else the first 8 chars of the ID
    - ``planTitle``: built-in plan title (German, else English), else the plan key
    - ``monthlyRevenueCHF`` / ``activeUsers`` / ``activeInstances``: for
      subscriptions in an operative status, the snapshot per-user and
      per-instance prices times the mandate's current counts; 0 otherwise.

    Mandate names and usage counts are best effort: lookup failures are
    logged and fall back to short IDs / zero counts.
    """
    from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
    from modules.datamodels.datamodelSubscription import BUILTIN_PLANS, OPERATIVE_STATUSES

    subInterface = getSubRootInterface()
    allSubs = subInterface.listAll()

    # Resolve all mandate display names with a single query.
    mandateNames: Dict[str, str] = {}
    try:
        from modules.datamodels.datamodelUam import Mandate
        from modules.security.rootAccess import getRootDbAppConnector

        appDb = getRootDbAppConnector()
        for row in appDb.getRecordset(Mandate):
            r = dict(row)
            mid = str(r.get("id") or "")
            mandateNames[mid] = r.get("label") or r.get("name") or mid[:8]
    except Exception as e:
        logger.warning("Could not bulk-resolve mandate names: %s", e)

    operativeValues = {s.value for s in OPERATIVE_STATUSES}

    # Usage counts depend only on the mandate; cache them so each mandate is
    # counted at most once per build.
    usageCache: Dict[str, tuple] = {}

    def _usageFor(mid: str) -> tuple:
        """Return (activeUsers, activeInstances) for ``mid``; (0, 0) on failure."""
        if mid not in usageCache:
            try:
                usageCache[mid] = (
                    subInterface.countActiveUsers(mid),
                    subInterface.countActiveFeatureInstances(mid),
                )
            except Exception as e:
                logger.warning("Could not count active usage for mandate %s: %s", mid, e)
                usageCache[mid] = (0, 0)
        return usageCache[mid]

    enriched = []
    for sub in allSubs:
        # ``or ""`` also covers an explicit None, which would break the [:8] slice.
        mid = str(sub.get("mandateId") or "")
        planKey = sub.get("planKey", "")
        plan = BUILTIN_PLANS.get(planKey)

        sub["mandateName"] = mandateNames.get(mid, mid[:8])
        sub["planTitle"] = (plan.title.get("de") or plan.title.get("en") or planKey) if plan else planKey

        if sub.get("status") in operativeValues:
            userPrice = sub.get("snapshotPricePerUserCHF", 0) or 0
            instPrice = sub.get("snapshotPricePerInstanceCHF", 0) or 0
            userCount, instanceCount = _usageFor(mid)
            sub["monthlyRevenueCHF"] = round(userPrice * userCount + instPrice * instanceCount, 2)
            sub["activeUsers"] = userCount
            sub["activeInstances"] = instanceCount
        else:
            sub["monthlyRevenueCHF"] = 0
            sub["activeUsers"] = 0
            sub["activeInstances"] = 0

        enriched.append(sub)

    return enriched
|
|
|
|
|
|
@router.get("/admin/all")
@limiter.limit("30/minute")
def getAllSubscriptions(
    request: Request,
    pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
    context: RequestContext = Depends(getRequestContext),
):
    """SysAdmin: list ALL subscriptions across all mandates with enriched metadata.

    Args:
        pagination: Optional JSON object with PaginationParams fields
            (page, pageSize, sort, filters). When given, the list is filtered,
            sorted and sliced to the requested page.

    Returns:
        ``{"items": [...], "pagination": {...}}`` when paginated, otherwise
        ``{"items": <all enriched subscriptions>, "pagination": None}``.

    Raises:
        HTTPException: 403 for non-sysadmins; 400 for malformed pagination
        (invalid JSON, a non-object value, or invalid fields).
    """
    if not context.hasSysAdminRole:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required")

    paginationParams: Optional[PaginationParams] = None
    if pagination:
        try:
            paginationDict = json.loads(pagination)
            # A JSON array/number/string would otherwise crash further down (500).
            if paginationDict and not isinstance(paginationDict, dict):
                raise ValueError("expected a JSON object")
            if paginationDict:
                paginationDict = normalize_pagination_dict(paginationDict)
                paginationParams = PaginationParams(**paginationDict)
        except (json.JSONDecodeError, ValueError) as e:
            raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")

    enriched = _buildEnrichedSubscriptions()
    filtered = _applyFiltersAndSort(enriched, paginationParams)

    if paginationParams:
        totalItems = len(filtered)
        totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0
        startIdx = (paginationParams.page - 1) * paginationParams.pageSize
        endIdx = startIdx + paginationParams.pageSize
        pageItems = filtered[startIdx:endIdx]
        return {
            "items": pageItems,
            "pagination": PaginationMetadata(
                currentPage=paginationParams.page,
                pageSize=paginationParams.pageSize,
                totalItems=totalItems,
                totalPages=totalPages,
                sort=paginationParams.sort,
                filters=paginationParams.filters,
            ).model_dump(),
        }

    return {"items": enriched, "pagination": None}
|
|
|
|
|
|
@router.get("/admin/all/filter-values")
@limiter.limit("60/minute")
def getFilterValues(
    request: Request,
    column: str = Query(..., description="Column key to extract distinct values for"),
    pagination: Optional[str] = Query(None, description="JSON-encoded current filters (applied except for the requested column)"),
    context: RequestContext = Depends(getRequestContext),
):
    """Return distinct values for a column, respecting all active filters except the requested one.

    Args:
        column: Key of the enriched subscription field to collect values for.
        pagination: Optional JSON object with the current PaginationParams;
            its filter on ``column`` and its sort are ignored.

    Raises:
        HTTPException: 403 for non-sysadmins; 400 for malformed pagination
        (invalid JSON, a non-object value, or invalid fields).
    """
    if not context.hasSysAdminRole:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required")

    crossFilterParams: Optional[PaginationParams] = None
    if pagination:
        try:
            paginationDict = json.loads(pagination)
            # A JSON array/number/string would otherwise crash further down (500).
            if paginationDict and not isinstance(paginationDict, dict):
                raise ValueError("expected a JSON object")
            if paginationDict:
                paginationDict = normalize_pagination_dict(paginationDict)
                # Cross-filtering: drop the filter on the requested column so
                # its value list reflects only the *other* filters. ``or {}``
                # also handles an explicit ``"filters": null``.
                filters = paginationDict.get("filters") or {}
                filters.pop(column, None)
                paginationDict["filters"] = filters
                # Sorting is irrelevant for distinct values.
                paginationDict.pop("sort", None)
                crossFilterParams = PaginationParams(**paginationDict)
        except (json.JSONDecodeError, ValueError) as e:
            raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")

    enriched = _buildEnrichedSubscriptions()
    crossFiltered = _applyFiltersAndSort(enriched, crossFilterParams)

    return _extractDistinctValues(crossFiltered, column)
|
|
|
|
|
|
# ============================================================
|
|
# Data Volume Usage per Mandate
|
|
# ============================================================
|
|
|
|
@router.get("/data-volume/{targetMandateId}")
|
|
@limiter.limit("60/minute")
|
|
def _getDataVolumeUsage(
|
|
request: Request,
|
|
targetMandateId: str = Path(..., description="Mandate ID to check volume for"),
|
|
context: RequestContext = Depends(getRequestContext),
|
|
):
|
|
"""Calculate current data volume usage for a mandate vs. plan limit."""
|
|
from modules.interfaces.interfaceDbApp import getRootInterface
|
|
from modules.datamodels.datamodelFiles import FileItem
|
|
from modules.datamodels.datamodelFeatures import FeatureInstance
|
|
from modules.interfaces.interfaceDbKnowledge import aggregateMandateRagTotalBytes
|
|
from modules.interfaces.interfaceDbManagement import getInterface as getMgmtInterface
|
|
from modules.interfaces.interfaceDbSubscription import _getRootInterface as _getSubRootIf
|
|
|
|
rootIf = getRootInterface()
|
|
mandateId = targetMandateId
|
|
|
|
instances = rootIf.db.getRecordset(FeatureInstance, recordFilter={"mandateId": mandateId})
|
|
instIds = [str(inst.get("id") or "") for inst in instances if inst.get("id")]
|
|
|
|
mgmtDb = getMgmtInterface().db
|
|
totalFileBytes = 0
|
|
for instId in instIds:
|
|
files = mgmtDb.getRecordset(FileItem, recordFilter={"featureInstanceId": instId})
|
|
for f in files:
|
|
size = f.get("fileSize") if isinstance(f, dict) else getattr(f, "fileSize", 0)
|
|
totalFileBytes += (size or 0)
|
|
mandateFiles = mgmtDb.getRecordset(FileItem, recordFilter={"mandateId": mandateId})
|
|
for f in mandateFiles:
|
|
size = f.get("fileSize") if isinstance(f, dict) else getattr(f, "fileSize", 0)
|
|
totalFileBytes += (size or 0)
|
|
filesMB = round(totalFileBytes / (1024 * 1024), 2)
|
|
|
|
ragBytes = aggregateMandateRagTotalBytes(mandateId)
|
|
ragMB = round(ragBytes / (1024 * 1024), 2)
|
|
|
|
maxMB = None
|
|
subIf = _getSubRootIf()
|
|
operative = subIf.getOperativeForMandate(mandateId)
|
|
if operative:
|
|
plan = subIf.getPlan(operative.get("planKey") or "")
|
|
if plan and plan.maxDataVolumeMB is not None:
|
|
maxMB = int(plan.maxDataVolumeMB)
|
|
|
|
usedMB = ragMB
|
|
percentUsed = round((usedMB / maxMB) * 100, 1) if maxMB else None
|
|
logger.info(
|
|
"data-volume mandate=%s: files=%.2f MB, rag=%.2f MB, max=%s MB",
|
|
mandateId, filesMB, ragMB, maxMB,
|
|
)
|
|
return {
|
|
"mandateId": mandateId,
|
|
"usedMB": usedMB,
|
|
"filesMB": filesMB,
|
|
"ragIndexMB": ragMB,
|
|
"maxDataVolumeMB": maxMB,
|
|
"percentUsed": percentUsed,
|
|
"warning": (percentUsed or 0) >= 80,
|
|
}
|