gateway/modules/routes/routeSubscription.py
ValueOn AG b142c0fa6c NT-Problem: Unhandled exception: get
Root Cause: Die Stripe-Python-Bibliothek auf INT hat keine .get()-Methode auf Stripe-Objekten. Wenn session.get("payment_status") aufgerufen wird, sucht Python via __getattr__ nach einem Feld namens "get" → AttributeError("get").

Bestätigung: In routeBilling.py gab es bereits einen hasattr(session, "get")-Check (Zeile 998) — jemand kannte das Problem.

Fix: Alle Stripe-Objekte werden sofort nach dem API-Call in dict() konvertiert.
2026-03-31 02:05:16 +02:00

531 lines
22 KiB
Python

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Subscription routes — ID-based, state-machine-driven.
Endpoints:
- GET /api/subscription/plans — list selectable plans
- GET /api/subscription/status — operative + scheduled subscription for current mandate
- POST /api/subscription/activate — start checkout for a plan
- POST /api/subscription/cancel — cancel a specific subscription (by ID)
- POST /api/subscription/reactivate — reactivate a cancelled subscription (by ID)
- POST /api/subscription/force-cancel — sysadmin immediate cancel (by ID)
"""
from fastapi import APIRouter, HTTPException, Depends, Request, Query, Path
from fastapi import status
from typing import Dict, Any, List, Optional
import logging
import json
import math
from pydantic import BaseModel, Field
from modules.auth import limiter, getRequestContext, RequestContext
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
from modules.routes.routeDataUsers import _applyFiltersAndSort, _extractDistinctValues
logger = logging.getLogger(__name__)
def _resolveMandateId(context: RequestContext) -> str:
    """Return the request's mandate ID as a string, or "" when none is set.

    Returning an empty string (rather than None) lets callers detect a missing
    mandate with a plain truthiness check.
    """
    rawMandateId = context.mandateId
    return str(rawMandateId) if rawMandateId else ""
def _assertMandateAdmin(context: RequestContext, mandateId: str) -> None:
    """Ensure the caller may administer ``mandateId``; raise 403 otherwise.

    Sysadmins always pass. Any other user passes when they have an enabled
    membership in the mandate that carries a role labelled ``"admin"``, and
    that role is not scoped to a feature instance (``featureInstanceId`` unset).

    Args:
        context: Request context of the calling user.
        mandateId: Mandate the caller wants to administer.

    Raises:
        HTTPException: 403 when no qualifying admin role is found, including
            when the role lookup itself fails (fail closed).
    """
    if context.hasSysAdminRole:
        return
    targetMandateId = str(mandateId)
    try:
        from modules.interfaces.interfaceDbApp import getRootInterface
        rootInterface = getRootInterface()
        userMandates = rootInterface.getUserMandates(str(context.user.id))
        for um in userMandates:
            if str(getattr(um, "mandateId", None)) != targetMandateId:
                continue
            if not getattr(um, "enabled", True):
                continue
            umId = str(getattr(um, "id", ""))
            for roleId in rootInterface.getRoleIdsForUserMandate(umId):
                role = rootInterface.getRole(roleId)
                if role and role.roleLabel == "admin" and not role.featureInstanceId:
                    return
    except Exception:
        # Still deny (fail closed), but make lookup failures visible instead of
        # silently presenting them as a plain permission problem.
        logger.warning(
            "Mandate admin check failed for user %s on mandate %s",
            getattr(context.user, "id", None),
            targetMandateId,
            exc_info=True,
        )
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Mandate admin role required")
# =============================================================================
# Request / Response models
# =============================================================================
class ActivatePlanRequest(BaseModel):
    # Body of POST /activate: the plan to start and the frontend URL to return
    # to after Stripe Checkout.
    planKey: str = Field(..., description="Key of the plan to activate")
    returnUrl: str = Field(..., description="Frontend URL to redirect back to after Stripe Checkout")
class CancelRequest(BaseModel):
    # Body of POST /cancel.
    subscriptionId: str = Field(..., description="ID of the subscription to cancel")
class ReactivateRequest(BaseModel):
    # Body of POST /reactivate.
    subscriptionId: str = Field(..., description="ID of the subscription to reactivate")
class ForceCancelRequest(BaseModel):
    # Body of POST /force-cancel (sysadmin only).
    subscriptionId: str = Field(..., description="ID of the subscription to force-cancel")
class VerifyCheckoutRequest(BaseModel):
    # Body of POST /checkout/verify.
    sessionId: str = Field(..., description="Stripe Checkout Session ID to verify")
class SubscriptionStatusResponse(BaseModel):
    # Response of GET /status. `active` is True only when an operative
    # subscription exists; the dict fields are None when there is nothing to show.
    active: bool
    subscription: Optional[Dict[str, Any]] = None
    plan: Optional[Dict[str, Any]] = None
    scheduled: Optional[Dict[str, Any]] = None
# =============================================================================
# Router
# =============================================================================
# Every endpoint below is mounted under /api/subscription and grouped under
# the "Subscription" tag in the OpenAPI docs.
router = APIRouter(
    tags=["Subscription"],
    prefix="/api/subscription",
    responses={404: {"description": "Not found"}},
)
# =============================================================================
# Endpoints
# =============================================================================
@router.get("/plans", response_model=List[Dict[str, Any]])
@limiter.limit("30/minute")
def getPlans(request: Request, context: RequestContext = Depends(getRequestContext)):
    """List the plans the current mandate can select.

    Note that no mandate-admin check is performed here. When no mandate is
    set, an empty mandate ID is passed to the subscription service.

    Returns:
        Each selectable plan serialised via ``model_dump()``.

    Raises:
        HTTPException: 500 if the subscription service fails.
    """
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )
    try:
        mandateId = _resolveMandateId(context)
        subService = getSubscriptionService(context.user, mandateId)
        plans = subService.getSelectablePlans()
        return [p.model_dump() for p in plans]
    except HTTPException:
        # Preserve deliberate HTTP errors instead of rewrapping them as 500.
        raise
    except Exception as e:
        logger.exception("Error fetching plans: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/status", response_model=SubscriptionStatusResponse)
@limiter.limit("60/minute")
def getStatus(request: Request, context: RequestContext = Depends(getRequestContext)):
    """Return the operative subscription and any scheduled successor for the current mandate.

    If there is no operative subscription, the first PENDING subscription (if
    any) is reported instead with ``active=False``. Without a mandate, a bare
    ``active=False`` response is returned.

    Raises:
        HTTPException: 403 if the caller is not a mandate admin, 500 on
            service errors.
    """
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )
    mandateId = _resolveMandateId(context)
    if not mandateId:
        return SubscriptionStatusResponse(active=False)
    _assertMandateAdmin(context, mandateId)
    try:
        subService = getSubscriptionService(context.user, mandateId)
        operative = subService.getOperativeSubscription(mandateId)
        scheduled = subService.getScheduledSubscription(mandateId)
        if operative:
            shownSub, isActive = operative, True
        else:
            from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum
            pending = subService.listSubscriptions(mandateId, [SubscriptionStatusEnum.PENDING])
            if not pending:
                return SubscriptionStatusResponse(active=False, scheduled=scheduled)
            shownSub, isActive = pending[0], False
        plan = subService.getPlan(shownSub.get("planKey", ""))
        return SubscriptionStatusResponse(
            active=isActive,
            subscription=shownSub,
            plan=plan.model_dump() if plan else None,
            scheduled=scheduled,
        )
    except HTTPException:
        # Preserve deliberate HTTP errors instead of rewrapping them as 500.
        raise
    except Exception as e:
        logger.exception("Error fetching status: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/activate", response_model=Dict[str, Any])
@limiter.limit("10/minute")
def activatePlan(
    request: Request,
    data: ActivatePlanRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Start activation (checkout) of ``data.planKey`` for the current mandate.

    Returns:
        Whatever the subscription service's ``activatePlan`` returns.

    Raises:
        HTTPException: 400 if the mandate header is missing or the service
            rejects the request (``ValueError``), 403 if the caller is not a
            mandate admin, 500 on other failures.
    """
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )
    mandateId = _resolveMandateId(context)
    if not mandateId:
        raise HTTPException(status_code=400, detail="X-Mandate-Id header required")
    _assertMandateAdmin(context, mandateId)
    try:
        subService = getSubscriptionService(context.user, mandateId)
        return subService.activatePlan(mandateId, data.planKey, returnUrl=data.returnUrl)
    except HTTPException:
        # Preserve deliberate HTTP errors instead of rewrapping them as 500.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Error activating plan %s: %s", data.planKey, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/cancel", response_model=Dict[str, Any])
@limiter.limit("30/minute")
def cancelSubscription(
    request: Request,
    data: CancelRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Cancel a specific subscription by its ID.

    The admin check covers the mandate from the request header. For
    non-sysadmins, the subscription must also belong to that mandate.

    Raises:
        HTTPException: 400 if the mandate header is missing or the service
            rejects the request (``ValueError``), 403 if the caller is not a
            mandate admin, 404 if the subscription belongs to another mandate,
            500 on other failures.
    """
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )
    mandateId = _resolveMandateId(context)
    if not mandateId:
        raise HTTPException(status_code=400, detail="X-Mandate-Id header required")
    _assertMandateAdmin(context, mandateId)
    try:
        if not context.hasSysAdminRole:
            # The subscription ID comes from the client. Bind it to the mandate
            # the admin check was done for, so an admin of one mandate cannot
            # act on another mandate's subscription.
            # NOTE(review): the service may already enforce this — confirm.
            from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
            target = getSubRootInterface().getById(data.subscriptionId)
            if target and str(target.get("mandateId")) != mandateId:
                raise HTTPException(status_code=404, detail="Subscription not found")
        subService = getSubscriptionService(context.user, mandateId)
        return subService.cancelSubscription(data.subscriptionId)
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Error cancelling subscription %s: %s", data.subscriptionId, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/reactivate", response_model=Dict[str, Any])
@limiter.limit("30/minute")
def reactivateSubscription(
    request: Request,
    data: ReactivateRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Reactivate a cancelled (non-recurring) subscription before its period ends.

    The admin check covers the mandate from the request header. For
    non-sysadmins, the subscription must also belong to that mandate.

    Raises:
        HTTPException: 400 if the mandate header is missing or the service
            rejects the request (``ValueError``), 403 if the caller is not a
            mandate admin, 404 if the subscription belongs to another mandate,
            500 on other failures.
    """
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )
    mandateId = _resolveMandateId(context)
    if not mandateId:
        raise HTTPException(status_code=400, detail="X-Mandate-Id header required")
    _assertMandateAdmin(context, mandateId)
    try:
        if not context.hasSysAdminRole:
            # The subscription ID comes from the client. Bind it to the mandate
            # the admin check was done for, so an admin of one mandate cannot
            # act on another mandate's subscription.
            # NOTE(review): the service may already enforce this — confirm.
            from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
            target = getSubRootInterface().getById(data.subscriptionId)
            if target and str(target.get("mandateId")) != mandateId:
                raise HTTPException(status_code=404, detail="Subscription not found")
        subService = getSubscriptionService(context.user, mandateId)
        return subService.reactivateSubscription(data.subscriptionId)
    except HTTPException:
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Error reactivating subscription %s: %s", data.subscriptionId, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/force-cancel", response_model=Dict[str, Any])
@limiter.limit("30/minute")
def forceCancel(
    request: Request,
    data: ForceCancelRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Sysadmin: immediately expire any non-terminal subscription.

    The subscription service is built for the subscription's own mandate,
    not the mandate from the request header.

    Raises:
        HTTPException: 403 for non-sysadmins, 404 if the subscription does not
            exist, 400 when the service rejects the request (``ValueError``),
            500 on other failures.
    """
    if not context.hasSysAdminRole:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required")
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )
    from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
    sub = getSubRootInterface().getById(data.subscriptionId)
    if not sub:
        raise HTTPException(status_code=404, detail="Subscription not found")
    mandateId = sub["mandateId"]
    try:
        subService = getSubscriptionService(context.user, mandateId)
        return subService.forceCancel(data.subscriptionId)
    except HTTPException:
        # Preserve deliberate HTTP errors instead of rewrapping them as 500.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Error force-cancelling subscription %s: %s", data.subscriptionId, e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/checkout/verify", response_model=Dict[str, Any])
@limiter.limit("20/minute")
def verifyCheckout(
    request: Request,
    data: VerifyCheckoutRequest,
    context: RequestContext = Depends(getRequestContext),
):
    """Verify a Stripe Checkout Session and activate the subscription if paid.

    Idempotent: if the webhook already processed the session, returns success.
    Called by the frontend immediately after returning from Stripe.

    Returns:
        ``{"status": "activated", ...}`` once the mandate has an operative
        subscription, otherwise ``{"status": "pending", ...}``.

    Raises:
        HTTPException: 400 if the mandate header is missing, the session
            cannot be retrieved, or it is not a subscription checkout. 403 if
            the caller is not a mandate admin.
    """
    mandateId = _resolveMandateId(context)
    if not mandateId:
        raise HTTPException(status_code=400, detail="X-Mandate-Id header required")
    _assertMandateAdmin(context, mandateId)

    # Client setup is kept outside the try so configuration problems surface
    # as server errors instead of being reported as "Invalid session ID".
    from modules.shared.stripeClient import getStripeClient
    stripe = getStripeClient()
    try:
        rawSession = stripe.checkout.Session.retrieve(data.sessionId)
    except Exception as e:
        logger.error("Failed to retrieve checkout session %s: %s", data.sessionId, e)
        raise HTTPException(status_code=400, detail="Invalid session ID") from e
    # Work on a plain dict: on the Stripe library version in use, Stripe objects
    # have no .get(), and looking up "get" raises AttributeError.
    # NOTE(review): dict() is shallow — nested values (e.g. metadata) keep their
    # Stripe type; confirm _handleSubscriptionCheckoutCompleted never calls .get() on them.
    session = dict(rawSession)

    checkoutComplete = session.get("status") == "complete"
    paymentSettled = session.get("payment_status") in ("paid", "no_payment_required")
    if not (checkoutComplete and paymentSettled):
        return {"status": "pending", "message": "Checkout not yet completed"}
    if session.get("mode") != "subscription":
        raise HTTPException(status_code=400, detail="Not a subscription checkout session")

    # NOTE(review): nothing here checks that this session was created for
    # `mandateId`; the handler below processes whatever session ID is supplied.
    from modules.routes.routeBilling import _handleSubscriptionCheckoutCompleted
    try:
        _handleSubscriptionCheckoutCompleted(session, f"verify-{data.sessionId}")
    except Exception as e:
        # Best effort: the webhook path may already have handled this session,
        # so the result is decided by the operative-subscription check below.
        logger.warning(
            "verifyCheckout: handler raised for session %s mandate %s: %s",
            data.sessionId,
            mandateId,
            e,
            exc_info=True,
        )

    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
    )
    from modules.datamodels.datamodelSubscription import OPERATIVE_STATUSES
    subService = getSubscriptionService(context.user, mandateId)
    operative = subService.getOperativeSubscription(mandateId)
    operativeValues = {s.value for s in OPERATIVE_STATUSES}
    if not operative or operative.get("status") not in operativeValues:
        return {"status": "pending", "message": "Subscription activation pending — webhook may still be processing."}

    planKey = operative.get("planKey", "")
    if planKey:
        try:
            from modules.interfaces.interfaceDbBilling import _getRootInterface as _getBillingRoot
            _getBillingRoot().ensureActivationBudget(mandateId, planKey)
        except Exception as ex:
            logger.warning("verifyCheckout: ensureActivationBudget failed: %s", ex, exc_info=True)
    return {"status": "activated", "message": "Subscription activated"}
# =============================================================================
# SysAdmin: global subscription overview
# =============================================================================
def _buildEnrichedSubscriptions() -> List[Dict[str, Any]]:
    """Build the full enriched subscription list (shared by list + filter-values endpoints).

    Each record returned by ``listAll()`` is augmented in place with:

    * ``mandateName`` — mandate label/name, falling back to the first 8
      characters of the mandate ID.
    * ``planTitle`` — the German, then English, built-in plan title, falling
      back to the plan key.
    * ``monthlyRevenueCHF``, ``activeUsers``, ``activeInstances`` — computed
      from snapshot prices and live counts for operative subscriptions; 0 for
      all others.
    """
    from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
    from modules.datamodels.datamodelSubscription import BUILTIN_PLANS, OPERATIVE_STATUSES
    subInterface = getSubRootInterface()
    allSubs = subInterface.listAll()

    # Resolve all mandate names in a single query; keys are normalised to str
    # so lookups work no matter how the ID is typed on either side.
    mandateNames: Dict[str, str] = {}
    try:
        from modules.datamodels.datamodelUam import Mandate
        from modules.security.rootAccess import getRootDbAppConnector
        appDb = getRootDbAppConnector()
        for row in appDb.getRecordset(Mandate):
            r = dict(row)
            rowMandateId = str(r.get("id") or "")
            mandateNames[rowMandateId] = r.get("label") or r.get("name") or rowMandateId[:8]
    except Exception as e:
        logger.warning("Could not bulk-resolve mandate names: %s", e)

    operativeValues = {s.value for s in OPERATIVE_STATUSES}
    # (activeUsers, activeInstances) per mandate, so several operative
    # subscriptions of one mandate do not repeat the same count queries.
    usageByMandate: Dict[str, Any] = {}
    enriched = []
    for sub in allSubs:
        rawMandateId = sub.get("mandateId") or ""
        mandateKey = str(rawMandateId)
        planKey = sub.get("planKey", "")
        plan = BUILTIN_PLANS.get(planKey)
        sub["mandateName"] = mandateNames.get(mandateKey, mandateKey[:8])
        sub["planTitle"] = (plan.title.get("de") or plan.title.get("en") or planKey) if plan else planKey
        if sub.get("status") in operativeValues:
            userPrice = sub.get("snapshotPricePerUserCHF", 0) or 0
            instPrice = sub.get("snapshotPricePerInstanceCHF", 0) or 0
            if mandateKey not in usageByMandate:
                try:
                    usageByMandate[mandateKey] = (
                        subInterface.countActiveUsers(rawMandateId),
                        subInterface.countActiveFeatureInstances(rawMandateId),
                    )
                except Exception as e:
                    logger.warning("Could not count active usage for mandate %s: %s", mandateKey, e)
                    usageByMandate[mandateKey] = (0, 0)
            userCount, instanceCount = usageByMandate[mandateKey]
            sub["monthlyRevenueCHF"] = round(userPrice * userCount + instPrice * instanceCount, 2)
            sub["activeUsers"] = userCount
            sub["activeInstances"] = instanceCount
        else:
            sub["monthlyRevenueCHF"] = 0
            sub["activeUsers"] = 0
            sub["activeInstances"] = 0
        enriched.append(sub)
    return enriched
@router.get("/admin/all")
@limiter.limit("30/minute")
def getAllSubscriptions(
    request: Request,
    pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
    context: RequestContext = Depends(getRequestContext),
):
    """SysAdmin: list ALL subscriptions across all mandates with enriched metadata.

    Args:
        pagination: Optional JSON object parsed into ``PaginationParams``. When
            present, filters and sort are applied and a single page is
            returned with pagination metadata.

    Returns:
        ``{"items": [...], "pagination": {...}}``, or ``"pagination": None``
        and the full list when no pagination is given.

    Raises:
        HTTPException: 403 for non-sysadmins, 400 for a malformed pagination
            parameter (invalid JSON, not an object, or page/pageSize < 1).
    """
    if not context.hasSysAdminRole:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required")
    paginationParams: Optional[PaginationParams] = None
    if pagination:
        try:
            paginationDict = json.loads(pagination)
            if paginationDict:
                if not isinstance(paginationDict, dict):
                    # e.g. "[1]" would otherwise crash further down with a 500.
                    raise ValueError("expected a JSON object")
                paginationDict = normalize_pagination_dict(paginationDict)
                paginationParams = PaginationParams(**paginationDict)
        except (json.JSONDecodeError, ValueError, TypeError) as e:
            raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}") from e
    if paginationParams and (paginationParams.pageSize < 1 or paginationParams.page < 1):
        # Guards against ZeroDivisionError (pageSize 0) and negative slice
        # offsets that would silently return items from the end of the list.
        raise HTTPException(
            status_code=400,
            detail="Invalid pagination parameter: page and pageSize must be >= 1",
        )

    enriched = _buildEnrichedSubscriptions()
    filtered = _applyFiltersAndSort(enriched, paginationParams)
    if not paginationParams:
        return {"items": enriched, "pagination": None}

    totalItems = len(filtered)
    totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0
    startIdx = (paginationParams.page - 1) * paginationParams.pageSize
    endIdx = startIdx + paginationParams.pageSize
    return {
        "items": filtered[startIdx:endIdx],
        "pagination": PaginationMetadata(
            currentPage=paginationParams.page,
            pageSize=paginationParams.pageSize,
            totalItems=totalItems,
            totalPages=totalPages,
            sort=paginationParams.sort,
            filters=paginationParams.filters,
        ).model_dump(),
    }
@router.get("/admin/all/filter-values")
@limiter.limit("60/minute")
def getFilterValues(
    request: Request,
    column: str = Query(..., description="Column key to extract distinct values for"),
    pagination: Optional[str] = Query(None, description="JSON-encoded current filters (applied except for the requested column)"),
    context: RequestContext = Depends(getRequestContext),
):
    """Return distinct values for a column, respecting all active filters except the requested one.

    The requested column's own filter and any sort are removed before
    filtering, so the offered values are not narrowed by the column itself.

    Raises:
        HTTPException: 403 for non-sysadmins, 400 for a malformed pagination
            parameter.
    """
    if not context.hasSysAdminRole:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required")
    crossFilterParams: Optional[PaginationParams] = None
    if pagination:
        try:
            paginationDict = json.loads(pagination)
            if paginationDict:
                if not isinstance(paginationDict, dict):
                    raise ValueError("expected a JSON object")
                paginationDict = normalize_pagination_dict(paginationDict)
                # "filters" may be absent or null; treat both as "no filters"
                # instead of crashing on None.pop().
                filters = dict(paginationDict.get("filters") or {})
                filters.pop(column, None)
                paginationDict["filters"] = filters
                paginationDict.pop("sort", None)
                crossFilterParams = PaginationParams(**paginationDict)
        except (json.JSONDecodeError, ValueError, TypeError) as e:
            raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}") from e
    enriched = _buildEnrichedSubscriptions()
    crossFiltered = _applyFiltersAndSort(enriched, crossFilterParams)
    return _extractDistinctValues(crossFiltered, column)
# ============================================================
# Data Volume Usage per Mandate
# ============================================================
@router.get("/data-volume/{targetMandateId}")
@limiter.limit("60/minute")
def _getDataVolumeUsage(
    request: Request,
    targetMandateId: str = Path(..., description="Mandate ID to check volume for"),
    context: RequestContext = Depends(getRequestContext),
):
    """Calculate current data volume usage for a mandate vs. plan limit.

    Only the RAG index size counts towards ``usedMB`` and the plan limit.
    Stored file sizes are reported separately as ``filesMB``.

    Returns:
        Dict with ``mandateId``, ``usedMB``, ``filesMB``, ``ragIndexMB``,
        ``maxDataVolumeMB`` (None without an operative plan limit),
        ``percentUsed`` (None without a limit) and ``warning`` (usage >= 80%).

    Raises:
        HTTPException: 403 if the caller is neither a sysadmin nor an enabled
            member of the target mandate.
    """
    from modules.interfaces.interfaceDbApp import getRootInterface
    from modules.datamodels.datamodelFiles import FileItem
    from modules.datamodels.datamodelFeatures import FeatureInstance
    from modules.interfaces.interfaceDbKnowledge import aggregateMandateRagTotalBytes
    from modules.interfaces.interfaceDbManagement import getInterface as getMgmtInterface
    from modules.interfaces.interfaceDbSubscription import _getRootInterface as _getSubRootIf
    rootIf = getRootInterface()
    mandateId = targetMandateId

    # The mandate ID comes from the URL. Without this check, any authenticated
    # user could read another mandate's usage and plan limit.
    if not context.hasSysAdminRole:
        userMandates = rootIf.getUserMandates(str(context.user.id))
        isMember = any(
            str(getattr(um, "mandateId", None)) == str(mandateId) and getattr(um, "enabled", True)
            for um in userMandates
        )
        if not isMember:
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not a member of this mandate")

    instances = rootIf.db.getRecordset(FeatureInstance, recordFilter={"mandateId": mandateId})
    instIds = [str(inst.get("id") or "") for inst in instances if inst.get("id")]
    mgmtDb = getMgmtInterface().db
    fileRecords = []
    for instId in instIds:
        fileRecords.extend(mgmtDb.getRecordset(FileItem, recordFilter={"featureInstanceId": instId}))
    fileRecords.extend(mgmtDb.getRecordset(FileItem, recordFilter={"mandateId": mandateId}))

    # A file may match both the instance query and the mandate query. Count each
    # file ID once; records without an ID are always counted.
    seenFileIds = set()
    totalFileBytes = 0
    for f in fileRecords:
        isDict = isinstance(f, dict)
        fileId = f.get("id") if isDict else getattr(f, "id", None)
        if fileId is not None:
            if str(fileId) in seenFileIds:
                continue
            seenFileIds.add(str(fileId))
        size = f.get("fileSize") if isDict else getattr(f, "fileSize", 0)
        totalFileBytes += (size or 0)
    filesMB = round(totalFileBytes / (1024 * 1024), 2)

    ragBytes = aggregateMandateRagTotalBytes(mandateId) or 0
    ragMB = round(ragBytes / (1024 * 1024), 2)

    maxMB = None
    subIf = _getSubRootIf()
    operative = subIf.getOperativeForMandate(mandateId)
    if operative:
        plan = subIf.getPlan(operative.get("planKey") or "")
        if plan and plan.maxDataVolumeMB is not None:
            maxMB = int(plan.maxDataVolumeMB)

    usedMB = ragMB
    percentUsed = round((usedMB / maxMB) * 100, 1) if maxMB else None
    logger.info(
        "data-volume mandate=%s: files=%.2f MB, rag=%.2f MB, max=%s MB",
        mandateId, filesMB, ragMB, maxMB,
    )
    return {
        "mandateId": mandateId,
        "usedMB": usedMB,
        "filesMB": filesMB,
        "ragIndexMB": ragMB,
        "maxDataVolumeMB": maxMB,
        "percentUsed": percentUsed,
        "warning": (percentUsed or 0) >= 80,
    }