399 lines
18 KiB
Python
399 lines
18 KiB
Python
# Copyright (c) 2026 PowerOn AG
|
||
# All rights reserved.
|
||
"""
|
||
Stripe webhook and subscription business logic for billing.
|
||
Handles checkout credit, subscription lifecycle transitions, and invoice events.
|
||
"""
|
||
|
||
import logging
|
||
from datetime import datetime, timezone
|
||
from typing import Any, Dict, Optional
|
||
|
||
from fastapi import HTTPException
|
||
|
||
from modules.datamodels.datamodelBilling import (
|
||
BillingTransaction,
|
||
TransactionTypeEnum,
|
||
ReferenceTypeEnum,
|
||
)
|
||
from modules.shared.i18nRegistry import apiRouteContext
|
||
|
||
# Module-level logger for billing webhook processing.
logger = logging.getLogger(__name__)
# i18n message factory bound to the "routeBilling" context; wraps HTTPException details.
routeApiMsg = apiRouteContext("routeBilling")
|
||
|
||
|
||
def _recordStripeEventOnce(billingInterface, eventId: Optional[str]) -> None:
    """Persist ``eventId`` as a processed Stripe webhook event unless it is empty or already stored."""
    if eventId and not billingInterface.getStripeWebhookEventByEventId(eventId):
        billingInterface.createStripeWebhookEvent(eventId)


def creditStripeSessionIfNeeded(
    billingInterface,
    session: Dict[str, Any],
    eventId: Optional[str] = None,
    CheckoutConfirmResponse=None,
):
    """Credit the mandate balance from a Stripe Checkout session if not already credited.

    The Checkout session ID is stored as the payment transaction's ``referenceId`` and
    looked up first. This lets the webhook flow and the manual confirmation flow both
    call this function without crediting twice. The lookup is check-then-insert, so
    nothing here serializes two concurrent callers.

    Args:
        billingInterface: Billing data interface (payment transactions, settings,
            mandate accounts, processed webhook events).
        session: Checkout session as a dict. Reads ``id``, ``metadata`` (``mandateId``,
            ``userId``, ``amountChf``), ``amount_total`` and ``payment_status``.
        eventId: Optional Stripe event ID. It is recorded as processed once the session
            is (or already was) credited.
        CheckoutConfirmResponse: Response class to instantiate, supplied by the caller.

    Returns:
        A ``CheckoutConfirmResponse`` with the ``credited``/``alreadyCredited`` flags,
        the session and mandate IDs, and the amount in CHF. Both flags are False when
        the session is not paid yet.

    Raises:
        HTTPException: 400 if the session ID, the mandate ID or a usable positive amount
            is missing; 404 if the mandate has no billing settings.
    """
    from modules.serviceCenter.services.serviceBilling.stripeCheckout import ALLOWED_AMOUNTS_CHF

    session_id = session.get("id")
    metadata = session.get("metadata") or {}
    mandate_id = metadata.get("mandateId")
    user_id = metadata.get("userId") or None
    amount_chf_str = metadata.get("amountChf", "0")

    if not session_id:
        raise HTTPException(status_code=400, detail=routeApiMsg("Stripe session id missing"))
    if not mandate_id:
        raise HTTPException(status_code=400, detail=routeApiMsg("Invalid session metadata: mandateId missing"))

    existing_payment_tx = billingInterface.getPaymentTransactionByReferenceId(session_id)
    if existing_payment_tx:
        _recordStripeEventOnce(billingInterface, eventId)
        return CheckoutConfirmResponse(
            credited=False,
            alreadyCredited=True,
            sessionId=session_id,
            mandateId=mandate_id,
            amountChf=float(existing_payment_tx.get("amount", 0.0)),
        )

    # Prefer the package amount written into the metadata. Fall back to what Stripe
    # charged (amount_total, in minor units) when the metadata value is missing,
    # unparsable or not an allowed package.
    try:
        amount_chf = float(amount_chf_str)
    except (TypeError, ValueError):
        amount_chf = None
    if amount_chf is None or amount_chf not in ALLOWED_AMOUNTS_CHF:
        amount_total = session.get("amount_total")
        if amount_total is None:
            raise HTTPException(status_code=400, detail=routeApiMsg("Invalid amount in Stripe session"))
        amount_chf = amount_total / 100.0
    # The fallback amount is not checked against ALLOWED_AMOUNTS_CHF, so at least
    # refuse to book a zero or negative credit.
    if amount_chf <= 0:
        raise HTTPException(status_code=400, detail=routeApiMsg("Invalid amount in Stripe session"))

    settings = billingInterface.getSettings(mandate_id)
    if not settings:
        raise HTTPException(status_code=404, detail=routeApiMsg("Billing settings not found"))

    # With delayed-notification payment methods, Stripe completes the session while
    # payment_status is still "unpaid". Do not credit money that has not arrived.
    # Nothing is recorded here, so a later call with the paid session can still credit it.
    if session.get("payment_status") == "unpaid":
        logger.warning(
            "Stripe session %s for mandate %s is not paid yet; credit deferred",
            session_id, mandate_id,
        )
        return CheckoutConfirmResponse(
            credited=False,
            alreadyCredited=False,
            sessionId=session_id,
            mandateId=mandate_id,
            amountChf=amount_chf,
        )

    account = billingInterface.getOrCreateMandateAccount(mandate_id, initialBalance=0.0)
    transaction = BillingTransaction(
        accountId=account["id"],
        transactionType=TransactionTypeEnum.CREDIT,
        amount=amount_chf,
        description="Stripe-Zahlung",
        referenceType=ReferenceTypeEnum.PAYMENT,
        referenceId=session_id,  # idempotency key checked by the lookup above
        createdByUserId=user_id,
    )
    billingInterface.createTransaction(transaction)
    _recordStripeEventOnce(billingInterface, eventId)

    logger.info(
        "Stripe credit applied: %s CHF for session %s on mandate %s",
        amount_chf, session_id, mandate_id,
    )
    return CheckoutConfirmResponse(
        credited=True,
        alreadyCredited=False,
        sessionId=session_id,
        mandateId=mandate_id,
        amountChf=amount_chf,
    )
|
||
|
||
|
||
def _fetchStripeSubscriptionMetadata(stripeSubscriptionId: str) -> Dict[str, Any]:
    """Return the metadata stored on a Stripe subscription; logs and re-raises Stripe errors."""
    try:
        from modules.shared.stripeClient import getStripeClient, stripeToDict

        stripe = getStripeClient()
        subObj = stripeToDict(stripe.Subscription.retrieve(stripeSubscriptionId))
    except Exception as e:
        logger.error(
            "Stripe Subscription.retrieve(%s) failed during checkout "
            "metadata recovery: %s", stripeSubscriptionId, e,
        )
        raise
    return subObj.get("metadata") or {}


def _collectStripeSubscriptionData(stripeSubscriptionId: str, planKey: str) -> Dict[str, Any]:
    """Build the local field updates for a freshly checked-out Stripe subscription.

    Always contains ``stripeSubscriptionId``. ``currentPeriodStart``/``currentPeriodEnd``
    are added when Stripe reports them. ``stripeItemIdUsers``/``stripeItemIdInstances``
    are added for items whose price matches the plan's Stripe price mapping.

    Raises:
        Exception: Any retrieval/mapping error is logged and re-raised so the webhook
            fails and Stripe redelivers the event.
    """
    stripeData: Dict[str, Any] = {"stripeSubscriptionId": stripeSubscriptionId}
    try:
        from modules.shared.stripeClient import getStripeClient, stripeToDict
        from modules.serviceCenter.services.serviceSubscription.stripeBootstrap import getStripePricesForPlan

        stripe = getStripeClient()
        stripeSubObj = stripeToDict(stripe.Subscription.retrieve(stripeSubscriptionId, expand=["items"]))

        for stripeKey, localKey in (
            ("current_period_start", "currentPeriodStart"),
            ("current_period_end", "currentPeriodEnd"),
        ):
            if stripeSubObj.get(stripeKey):
                stripeData[localKey] = float(stripeSubObj[stripeKey])

        priceMapping = getStripePricesForPlan(planKey)
        items = stripeSubObj.get("items") or {}
        if not isinstance(items, dict):
            items = dict(items)
        if priceMapping:
            for item in items.get("data", []):
                priceId = (item.get("price") or {}).get("id", "")
                if priceId == priceMapping.stripePriceIdUsers:
                    stripeData["stripeItemIdUsers"] = item["id"]
                elif priceId == priceMapping.stripePriceIdInstances:
                    stripeData["stripeItemIdInstances"] = item["id"]
    except Exception as e:
        logger.error(
            "Error retrieving Stripe subscription %s during checkout "
            "completion (will be retried by Stripe): %s",
            stripeSubscriptionId, e,
        )
        raise
    return stripeData


def _scheduleBehindPredecessor(subInterface, operative: Dict[str, Any], subscriptionRecordId: str) -> None:
    """Stop a non-trial predecessor from renewing and let the new record start at its period end.

    If the predecessor is recurring, Stripe is asked to cancel it at period end. That
    call is best effort: failures are only logged. The predecessor is then marked
    non-recurring locally. The new record's ``effectiveFrom`` is set to the
    predecessor's ``currentPeriodEnd`` when one is known.
    """
    if operative.get("recurring", True):
        operativeStripeId = operative.get("stripeSubscriptionId")
        if operativeStripeId:
            try:
                from modules.shared.stripeClient import getStripeClient

                stripe = getStripeClient()
                stripe.Subscription.modify(operativeStripeId, cancel_at_period_end=True)
            except Exception as e:
                logger.error("Failed to set cancel_at_period_end on predecessor %s: %s", operativeStripeId, e)
        # NOTE(review): this runs even if the Stripe call above failed, so local state can
        # say "non-recurring" while Stripe still renews the predecessor — confirm intended.
        subInterface.updateFields(operative["id"], {"recurring": False})

    effectiveFrom = operative.get("currentPeriodEnd")
    if effectiveFrom:
        subInterface.updateFields(subscriptionRecordId, {"effectiveFrom": effectiveFrom})


def handleSubscriptionCheckoutCompleted(session, eventId: str, getRootInterface) -> None:
    """Handle ``checkout.session.completed`` for ``mode=subscription``.

    Resolves the local PENDING record by the ``subscriptionRecordId`` in the session
    metadata. If the session lacks it, the Stripe subscription's own metadata is used
    instead. The record is synced with Stripe period/item data and transitioned out of
    PENDING:

    * no other operative subscription -> ACTIVE;
    * operative trial predecessor -> trial force-expired, new record ACTIVE;
    * other operative predecessor -> new record SCHEDULED behind it.

    On activation the admins are notified and the plan budget is credited ("Erstaktivierung").

    Args:
        session: Checkout session as a dict or Stripe object (converted via ``stripeToDict``).
        eventId: Stripe event ID; not used by this handler.
        getRootInterface: Factory returning the billing root interface used for the
            budget credit.

    Raises:
        Exception: Stripe API errors while recovering metadata or loading the subscription
            are re-raised so Stripe redelivers the event. Redeliveries stop early once the
            record is no longer PENDING.
    """
    from modules.interfaces.interfaceDbSubscription import getRootInterface as getSubRootInterface
    from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum, getPlan
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
        _notifySubscriptionChange,
    )
    from modules.security.rootAccess import getRootUser

    if not isinstance(session, dict):
        from modules.shared.stripeClient import stripeToDict

        session = stripeToDict(session)

    # ``subscription`` is normally a bare ID but is an object when the session was
    # expanded. Normalize it so Stripe calls and the stored field always get the ID.
    rawStripeSub = session.get("subscription")
    stripeSubId = rawStripeSub.get("id") if isinstance(rawStripeSub, dict) else rawStripeSub

    metadata = session.get("metadata") or {}
    subscriptionRecordId = metadata.get("subscriptionRecordId")
    mandateId = metadata.get("mandateId")
    planKey = metadata.get("planKey", "")
    platformUrl = metadata.get("platformUrl", "")

    if not subscriptionRecordId and stripeSubId:
        # Session carries no record reference; the Stripe subscription's metadata then
        # supplies record, mandate and plan (platformUrl only if the session had none).
        metadata = _fetchStripeSubscriptionMetadata(stripeSubId)
        subscriptionRecordId = metadata.get("subscriptionRecordId")
        mandateId = metadata.get("mandateId")
        planKey = metadata.get("planKey", "")
        platformUrl = platformUrl or metadata.get("platformUrl", "")

    if not mandateId or not subscriptionRecordId:
        logger.warning("Subscription checkout missing metadata: %s", metadata)
        return

    subInterface = getSubRootInterface()
    rootUser = getRootUser()

    sub = subInterface.getById(subscriptionRecordId)
    if not sub:
        logger.error("Subscription record %s not found for checkout webhook", subscriptionRecordId)
        return
    if sub.get("status") != SubscriptionStatusEnum.PENDING.value:
        logger.warning("Subscription %s is %s, expected PENDING — skipping", subscriptionRecordId, sub.get("status"))
        return

    if stripeSubId:
        subInterface.updateFields(subscriptionRecordId, _collectStripeSubscriptionData(stripeSubId, planKey))

    operative = subInterface.getOperativeForMandate(mandateId)
    hasActivePredecessor = operative is not None and operative["id"] != subscriptionRecordId

    if hasActivePredecessor and operative.get("status") == SubscriptionStatusEnum.TRIALING.value:
        # A paid plan replaces a running trial immediately.
        try:
            subInterface.forceExpire(operative["id"])
            logger.info(
                "Trial subscription %s expired immediately for mandate %s due to paid upgrade %s",
                operative["id"], mandateId, subscriptionRecordId,
            )
        except Exception as e:
            logger.error("Failed to expire trial predecessor %s: %s", operative["id"], e)
        toStatus = SubscriptionStatusEnum.ACTIVE
    elif hasActivePredecessor:
        toStatus = SubscriptionStatusEnum.SCHEDULED
        _scheduleBehindPredecessor(subInterface, operative, subscriptionRecordId)
    else:
        toStatus = SubscriptionStatusEnum.ACTIVE

    try:
        subInterface.transitionStatus(
            subscriptionRecordId, SubscriptionStatusEnum.PENDING, toStatus,
            {"recurring": True},
        )
    except Exception as e:
        logger.error("Failed to transition subscription %s: %s", subscriptionRecordId, e)
        return

    subService = getSubscriptionService(rootUser, mandateId)
    subService.invalidateCache(mandateId)

    if toStatus == SubscriptionStatusEnum.ACTIVE:
        plan = getPlan(planKey)
        updatedSub = subInterface.getById(subscriptionRecordId)
        _notifySubscriptionChange(mandateId, "activated", plan, subscriptionRecord=updatedSub, platformUrl=platformUrl)

        try:
            billingIf = getRootInterface()
            billingIf.creditSubscriptionBudget(mandateId, planKey, periodLabel="Erstaktivierung")
        except Exception as ex:
            logger.error("creditSubscriptionBudget on activation failed: %s", ex)

    logger.info(
        "Checkout completed: sub=%s -> %s, mandate=%s, plan=%s",
        subscriptionRecordId, toStatus.value, mandateId, planKey,
    )
|
||
|
||
|
||
def _invoiceServicePeriodStart(invoice) -> Optional[int]:
    """Return the start (Unix seconds) of the service period an invoice pays for.

    The invoice-level ``period_start`` is the period during which invoice items were
    added; for subscription renewal invoices Stripe documents that it looks back one
    period. The service period is carried on the invoice lines, so the latest line
    ``period.start`` is used. Proration lines from the previous period start earlier,
    so taking the maximum skips them. Falls back to the invoice-level ``period_start``
    when no line carries a period.
    """
    lineStarts = []
    for line in (invoice.get("lines") or {}).get("data") or []:
        start = (line.get("period") or {}).get("start")
        if start:
            lineStarts.append(int(start))
    if lineStarts:
        return max(lineStarts)
    return invoice.get("period_start")


def handleSubscriptionWebhook(event, getRootInterface) -> None:
    """Process Stripe subscription lifecycle and invoice webhook events.

    The local record is always resolved by ``stripeSubscriptionId``. For
    ``customer.subscription.*`` events this is the event object's own ``id``; for other
    events it is the object's ``subscription`` field. Events for unknown subscriptions
    are logged and ignored.

    Handled event types:
        customer.subscription.updated: syncs period bounds and mirrors Stripe status
            changes (SCHEDULED/PAST_DUE -> ACTIVE, ACTIVE -> PAST_DUE).
        customer.subscription.deleted: expires an ACTIVE/PAST_DUE/SCHEDULED record. If it
            was ACTIVE or PAST_DUE, the mandate's SCHEDULED successor is promoted.
        invoice.payment_failed: ACTIVE -> PAST_DUE.
        customer.subscription.trial_will_end: notifies the mandate admins.
        invoice.paid: resets/reconciles storage billing and credits the period budget.

    An ACTIVE -> PAST_DUE transition sends one "payment_failed" notification, whichever
    of the two events that can cause it arrives first.

    Args:
        event: Stripe event; ``type`` and ``data.object`` are read.
        getRootInterface: Factory returning the billing root interface.
    """
    from modules.interfaces.interfaceDbSubscription import getRootInterface as getSubRootInterface
    from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum, getPlan
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        getService as getSubscriptionService,
        _notifySubscriptionChange,
    )
    from modules.security.rootAccess import getRootUser

    obj = event.data.object
    rawSub = obj.get("id") if event.type.startswith("customer.subscription") else obj.get("subscription")
    stripeSubId = rawSub.get("id") if isinstance(rawSub, dict) else rawSub
    if not stripeSubId:
        logger.warning("Subscription webhook %s has no subscription ID", event.type)
        return

    subInterface = getSubRootInterface()
    sub = subInterface.getByStripeSubscriptionId(stripeSubId)
    if not sub:
        logger.warning("No local record for Stripe subscription %s (event: %s)", stripeSubId, event.type)
        return

    subId = sub["id"]
    mandateId = sub["mandateId"]
    currentStatus = SubscriptionStatusEnum(sub["status"])
    rootUser = getRootUser()
    subService = getSubscriptionService(rootUser, mandateId)

    # For invoice events this is the invoice's metadata, not the subscription's.
    subMetadata = obj.get("metadata") or {}
    webhookPlatformUrl = subMetadata.get("platformUrl", "")

    def markPastDueAndNotify() -> None:
        """Transition ACTIVE -> PAST_DUE, then notify admins of the failed payment."""
        subInterface.transitionStatus(subId, SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.PAST_DUE)
        subService.invalidateCache(mandateId)
        # Both invoice.payment_failed and customer.subscription.updated(past_due) perform
        # this transition, in either order. The first one notifies; the second sees
        # PAST_DUE and does nothing. The notification is best effort: the state change
        # is already stored, and a redelivery would not resend it.
        try:
            plan = getPlan(sub.get("planKey", ""))
            _notifySubscriptionChange(
                mandateId, "payment_failed", plan,
                subscriptionRecord=subInterface.getById(subId),
                platformUrl=webhookPlatformUrl,
            )
        except Exception as e:
            logger.error("Failed to notify about payment failure for sub %s: %s", subId, e)

    if event.type == "customer.subscription.updated":
        stripeStatus = obj.get("status", "")

        periodData: Dict[str, Any] = {}
        if obj.get("current_period_start"):
            periodData["currentPeriodStart"] = float(obj["current_period_start"])
        if obj.get("current_period_end"):
            periodData["currentPeriodEnd"] = float(obj["current_period_end"])
        if periodData:
            subInterface.updateFields(subId, periodData)

        if stripeStatus == "active" and currentStatus == SubscriptionStatusEnum.SCHEDULED:
            subInterface.transitionStatus(subId, SubscriptionStatusEnum.SCHEDULED, SubscriptionStatusEnum.ACTIVE)
            subService.invalidateCache(mandateId)
            planKey = sub.get("planKey", "")
            plan = getPlan(planKey)
            refreshedSub = subInterface.getById(subId)
            _notifySubscriptionChange(mandateId, "activated", plan, subscriptionRecord=refreshedSub, platformUrl=webhookPlatformUrl)
            try:
                getRootInterface().creditSubscriptionBudget(mandateId, planKey, periodLabel="Erstaktivierung")
            except Exception as ex:
                logger.error("creditSubscriptionBudget SCHEDULED->ACTIVE failed: %s", ex)
            logger.info("SCHEDULED -> ACTIVE for sub %s (mandate %s)", subId, mandateId)

        elif stripeStatus == "active" and currentStatus == SubscriptionStatusEnum.PAST_DUE:
            subInterface.transitionStatus(subId, SubscriptionStatusEnum.PAST_DUE, SubscriptionStatusEnum.ACTIVE)
            subService.invalidateCache(mandateId)
            logger.info("PAST_DUE -> ACTIVE for sub %s (mandate %s)", subId, mandateId)

        elif stripeStatus == "past_due" and currentStatus == SubscriptionStatusEnum.ACTIVE:
            markPastDueAndNotify()
            logger.info("ACTIVE -> PAST_DUE for sub %s (mandate %s)", subId, mandateId)

        elif stripeStatus == "active" and currentStatus == SubscriptionStatusEnum.ACTIVE:
            subService.invalidateCache(mandateId)
            logger.info("Period renewed for sub %s (mandate %s)", subId, mandateId)

    elif event.type == "customer.subscription.deleted":
        if currentStatus not in (SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.PAST_DUE,
                                 SubscriptionStatusEnum.SCHEDULED):
            logger.info("Ignoring deletion for sub %s in status %s", subId, currentStatus.value)
            return

        subInterface.transitionStatus(subId, currentStatus, SubscriptionStatusEnum.EXPIRED)
        subService.invalidateCache(mandateId)
        logger.info("Sub %s -> EXPIRED (Stripe deleted, mandate %s)", subId, mandateId)

        # Only a running (ACTIVE/PAST_DUE) subscription hands over to its successor. If the
        # deleted record was itself SCHEDULED, the running one is untouched, and promoting
        # another SCHEDULED record could leave the mandate with two ACTIVE subscriptions.
        if currentStatus != SubscriptionStatusEnum.SCHEDULED:
            scheduled = subInterface.getScheduledForMandate(mandateId)
            if scheduled:
                try:
                    subInterface.transitionStatus(
                        scheduled["id"], SubscriptionStatusEnum.SCHEDULED, SubscriptionStatusEnum.ACTIVE,
                    )
                    subService.invalidateCache(mandateId)
                    plan = getPlan(scheduled.get("planKey", ""))
                    refreshedScheduled = subInterface.getById(scheduled["id"])
                    _notifySubscriptionChange(mandateId, "activated", plan, subscriptionRecord=refreshedScheduled, platformUrl=webhookPlatformUrl)
                    logger.info("Promoted SCHEDULED sub %s -> ACTIVE (mandate %s)", scheduled["id"], mandateId)
                except Exception as e:
                    logger.error("Failed to promote SCHEDULED sub %s: %s", scheduled["id"], e)

    elif event.type == "invoice.payment_failed":
        if currentStatus == SubscriptionStatusEnum.ACTIVE:
            markPastDueAndNotify()
            logger.info("Payment failed for sub %s (mandate %s)", subId, mandateId)

    elif event.type == "customer.subscription.trial_will_end":
        logger.info("Trial ending soon for sub %s (mandate %s)", subId, mandateId)
        try:
            from modules.system.notifyMandateAdmins import notifyMandateAdmins

            notifyMandateAdmins(
                mandateId,
                "[PowerOn] Testphase endet bald",
                "Testphase endet bald",
                [
                    "Die kostenlose Testphase für Ihren Mandanten endet in Kürze.",
                    "Bitte wählen Sie einen Plan unter Billing-Verwaltung › Abonnement.",
                ],
            )
        except Exception as e:
            logger.error("Failed to notify about trial ending: %s", e)

    elif event.type == "invoice.paid":
        period_ts = _invoiceServicePeriodStart(obj)
        periodLabel = ""
        if period_ts:
            period_start_at = datetime.fromtimestamp(int(period_ts), tz=timezone.utc)
            periodLabel = period_start_at.strftime("%Y-%m-%d")
            try:
                billing_if = getRootInterface()
                billing_if.resetStorageBillingPeriod(mandateId, period_start_at)
                billing_if.reconcileMandateStorageBilling(mandateId)
            except Exception as ex:
                logger.error("Storage billing on invoice.paid failed: %s", ex)

        # NOTE(review): the credit below runs for every paid invoice, whatever the record's
        # status or the invoice's billing_reason. Checkout completion already credits
        # "Erstaktivierung" on activation, so the first invoice may be credited twice
        # unless creditSubscriptionBudget de-duplicates. Confirm.
        planKey = sub.get("planKey", "")
        try:
            billing_if = getRootInterface()
            billing_if.creditSubscriptionBudget(mandateId, planKey, periodLabel=periodLabel or "Periodenverlängerung")
        except Exception as ex:
            logger.error("creditSubscriptionBudget on invoice.paid failed: %s", ex)

        logger.info("Invoice paid for sub %s (mandate %s)", subId, mandateId)
|