386 lines
15 KiB
Python
386 lines
15 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""
Auto-provision Stripe Products and Prices from the built-in plan catalog.

Creates separate Stripe Products for user licenses and feature instances
so that invoice line items show clear, descriptive names:
- "Benutzer-Lizenzen"
- "Feature-Instanzen"

Idempotent — safe to call on every startup.

Source of truth for unit amounts is BUILTIN_PLANS (CHF). On each run, persisted
Stripe Price IDs are reconciled: if Stripe's unit_amount or recurring interval
differs from the catalog, a matching active Price is reused or a new one is
created, the old one is archived, and the StripePlanPrice record in the
poweron_billing database is updated. Other stale active Prices on the same
Product (same recurring interval) are archived so only the catalog-matching
Price stays active.
"""
|
|
|
|
import logging
|
|
from typing import Dict, Optional
|
|
|
|
from modules.connectors.connectorDbPostgre import DatabaseConnector
|
|
from modules.shared.configuration import APP_CONFIG
|
|
from modules.datamodels.datamodelSubscription import (
|
|
BUILTIN_PLANS,
|
|
SubscriptionPlan,
|
|
BillingPeriodEnum,
|
|
StripePlanPrice,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)

# Database (via DatabaseConnector) that stores the StripePlanPrice mapping rows.
_BILLING_DATABASE = "poweron_billing"

# Metadata keys written onto every Stripe Product created here; the product
# search queries filter on them to find those Products again.
_METADATA_KEY = "poweron_plan_key"
_METADATA_LINE_TYPE = "poweron_line_type"

# Catalog billing period -> keyword arguments for a Stripe ``recurring`` spec.
_PERIOD_TO_STRIPE = {
    period: {"interval": stripeInterval, "interval_count": 1}
    for period, stripeInterval in (
        (BillingPeriodEnum.MONTHLY, "month"),
        (BillingPeriodEnum.YEARLY, "year"),
    )
}
|
|
|
|
|
|
def _getBillingDb() -> DatabaseConnector:
    """Build a DatabaseConnector for the billing database.

    Host, port, user and password come from APP_CONFIG. Host defaults to
    "localhost" and port to 5432; the port is converted to int.
    """
    connectionArgs = {
        "dbDatabase": _BILLING_DATABASE,
        "dbHost": APP_CONFIG.get("DB_HOST", "localhost"),
        "dbPort": int(APP_CONFIG.get("DB_PORT", "5432")),
        "dbUser": APP_CONFIG.get("DB_USER"),
        "dbPassword": APP_CONFIG.get("DB_PASSWORD_SECRET"),
    }
    return DatabaseConnector(**connectionArgs)
|
|
|
|
|
|
def _loadExistingMappings(db: DatabaseConnector) -> Dict[str, StripePlanPrice]:
    """Load persisted StripePlanPrice rows, keyed by their planKey.

    Columns whose names start with "_" are dropped before the row is turned
    into a StripePlanPrice (presumably connector-internal fields — verify).
    Rows without a planKey are ignored.

    A row that fails to convert is skipped with a warning instead of discarding
    every other mapping. Losing all mappings would make the bootstrap create
    duplicate rows for every plan.

    Returns:
        Mapping of planKey -> StripePlanPrice. Empty if the recordset itself
        cannot be read (logged as a warning).
    """
    try:
        rows = db.getRecordset(StripePlanPrice)
    except Exception as e:
        logger.warning("Could not load StripePlanPrice records: %s", e)
        return {}

    result: Dict[str, StripePlanPrice] = {}
    for row in rows or []:
        planKey = row.get("planKey")
        if not planKey:
            continue
        try:
            result[planKey] = StripePlanPrice(
                **{k: v for k, v in row.items() if not k.startswith("_")}
            )
        except Exception as e:
            # One bad row must not hide the valid mappings of other plans.
            logger.warning("Skipping unreadable StripePlanPrice record for plan %s: %s", planKey, e)
    return result
|
|
|
|
|
|
def _findStripeProduct(stripe, planKey: str, lineType: str) -> Optional[str]:
    """Search Stripe for a product tagged with plan key + line type.

    A compound metadata query is tried first. Only if that call raises is a
    fallback used: search by plan key alone (first 10 hits), then match the
    line type client-side.

    Returns:
        The product ID, or None when nothing matches or both searches fail.
        Search failures are logged rather than silently dropped, because a
        None here makes the caller create a new Product.

    NOTE(review): Stripe documents Search as eventually consistent, so a
    Product created moments ago may not be found yet. Rapid or concurrent
    bootstraps could then create a duplicate — confirm this is acceptable.
    """
    try:
        products = stripe.Product.search(
            query=f'metadata["{_METADATA_KEY}"]:"{planKey}" AND metadata["{_METADATA_LINE_TYPE}"]:"{lineType}"',
            limit=1,
        )
        if products.data:
            return products.data[0].id
        return None
    except Exception as e:
        logger.debug(
            "Compound Stripe product search failed for %s/%s, falling back to plan-key search: %s",
            planKey, lineType, e,
        )

    try:
        products = stripe.Product.search(
            query=f'metadata["{_METADATA_KEY}"]:"{planKey}"',
            limit=10,
        )
        for p in products.data:
            meta = p.get("metadata") or {}
            if meta.get(_METADATA_LINE_TYPE) == lineType:
                return p.id
    except Exception as e:
        logger.warning("Stripe product search failed for %s/%s: %s", planKey, lineType, e)
    return None
|
|
|
|
|
|
def _createStripeProduct(stripe, name: str, description: str, planKey: str, lineType: str) -> str:
    """Create a Stripe Product tagged with the plan key and line type.

    The two metadata keys written here are what the product search looks for.

    Returns:
        The ID of the newly created Product.
    """
    tags = {_METADATA_KEY: planKey, _METADATA_LINE_TYPE: lineType}
    created = stripe.Product.create(name=name, description=description, metadata=tags)
    newId = created.id
    logger.info("Created Stripe Product %s: %s (%s/%s)", newId, name, planKey, lineType)
    return newId
|
|
|
|
|
|
def _recurringMatches(recurring: Dict, interval: str, intervalCount: int) -> bool:
|
|
if not recurring:
|
|
return False
|
|
if recurring.get("interval") != interval:
|
|
return False
|
|
ic = recurring.get("interval_count")
|
|
if ic is None:
|
|
ic = 1
|
|
return int(ic) == int(intervalCount)
|
|
|
|
|
|
def _findExistingStripePrice(
    stripe, productId: str, unitAmount: int, interval: str, intervalCount: int = 1,
    currency: str = "chf",
) -> Optional[str]:
    """Return the ID of an active Price on *productId* that matches the catalog.

    A Price matches when all three hold:
      * its unit_amount equals *unitAmount* (smallest currency unit, Rappen for CHF);
      * its currency equals *currency* (case-insensitive);
      * its ``recurring`` spec matches *interval* / *intervalCount*.

    Without the currency check, a non-CHF Price with the same number would be
    picked up as the CHF catalog Price.

    Args:
        stripe: Stripe client/module exposing ``Price.list``.
        productId: Product whose active Prices are scanned (all pages).
        unitAmount: Expected unit_amount.
        interval: Stripe recurring interval, e.g. "month".
        intervalCount: Stripe recurring interval_count.
        currency: Expected currency code; defaults to "chf", which
            ``_createStripePrice`` uses.

    Returns:
        The matching Price ID, or None when nothing matches or listing fails
        (the failure is logged).
    """
    try:
        prices = stripe.Price.list(product=productId, active=True, limit=100)
        # Follow pagination when the list object supports it; otherwise only the first page is scanned.
        pager = getattr(prices, "auto_paging_iter", None)
        candidates = pager() if callable(pager) else prices.data
        wantedCurrency = currency.lower()
        for p in candidates:
            if p.get("unit_amount") != unitAmount:
                continue
            if (p.get("currency") or "").lower() != wantedCurrency:
                continue
            if _recurringMatches(p.get("recurring") or {}, interval, intervalCount):
                return p.id
    except Exception as e:
        logger.warning("Could not list Stripe prices for product %s: %s", productId, e)
    return None
|
|
|
|
|
|
def _getStripePriceAmount(stripe, priceId: str) -> Optional[int]:
|
|
"""Retrieve the unit_amount (in Rappen) of an existing Stripe Price."""
|
|
try:
|
|
from modules.shared.stripeClient import stripeToDict
|
|
price = stripeToDict(stripe.Price.retrieve(priceId))
|
|
return price.get("unit_amount") if price else None
|
|
except Exception:
|
|
return None
|
|
|
|
|
|
def _reconcilePrice(
    stripe,
    productId: str,
    oldPriceId: str,
    expectedCHF: float,
    interval: str,
    nickname: str,
    intervalCount: int = 1,
) -> str:
    """Keep the stored Price if it still matches the catalog; otherwise rotate it.

    The stored Price is fetched once and compared on four criteria:
      * unit_amount (Rappen) against *expectedCHF*;
      * currency against CHF;
      * recurring interval/count;
      * active flag — an archived Price is not kept.

    On a mismatch, an already-existing matching active Price on *productId* is
    reused, or a new one is created. The old Price is then deactivated, unless
    it is the very Price being returned.

    If the Price cannot be fetched for any reason other than Stripe reporting it
    missing (network error, rate limit, auth, ...), drift cannot be proven.
    The old ID is then kept, rather than creating a duplicate and archiving a
    good Price.

    Returns:
        The Price ID that should be persisted for this line item.
    """
    from modules.shared.stripeClient import stripeToDict

    expectedCents = int(round(expectedCHF * 100))

    try:
        priceData = stripeToDict(stripe.Price.retrieve(oldPriceId)) or {}
    except Exception as e:
        if getattr(e, "code", None) != "resource_missing":
            logger.warning(
                "Could not verify Stripe Price %s (%s); keeping it until the next run.",
                oldPriceId, e,
            )
            return oldPriceId
        priceData = {}  # Price is gone: fall through to rotation.

    actualCents = priceData.get("unit_amount")
    actualCurrency = priceData.get("currency")
    isActive = priceData.get("active") is not False
    matchesCurrency = (actualCurrency or "").lower() == "chf"
    matchesRecurring = _recurringMatches(priceData.get("recurring") or {}, interval, intervalCount)

    if actualCents == expectedCents and matchesCurrency and matchesRecurring and isActive:
        return oldPriceId

    logger.warning(
        "Price drift or recurring mismatch for %s: Stripe amount=%s Rappen (expected %s), "
        "currency=%s, active=%s, recurring match=%s. Rotating price.",
        oldPriceId, actualCents, expectedCents, actualCurrency, isActive, matchesRecurring,
    )

    newPriceId = _findExistingStripePrice(stripe, productId, expectedCents, interval, intervalCount)
    if not newPriceId:
        newPriceId = _createStripePrice(
            stripe, productId, expectedCHF, interval, nickname, intervalCount,
        )

    # Never archive the Price we are about to hand back as the current one.
    if newPriceId != oldPriceId:
        try:
            stripe.Price.modify(oldPriceId, active=False)
            logger.info("Deactivated old Stripe Price %s", oldPriceId)
        except Exception as e:
            logger.warning("Could not deactivate old price %s: %s", oldPriceId, e)

    return newPriceId
|
|
|
|
|
|
def _createStripePrice(
    stripe, productId: str, unitAmountCHF: float, interval: str, nickname: str, intervalCount: int = 1,
) -> str:
    """Create a recurring CHF Price on *productId* and return its ID.

    The CHF amount is multiplied by 100 and rounded to whole Rappen, which is
    the unit used for ``unit_amount``.
    """
    amountRappen = int(round(unitAmountCHF * 100))
    recurrence = {"interval": interval, "interval_count": intervalCount}
    created = stripe.Price.create(
        product=productId,
        unit_amount=amountRappen,
        currency="chf",
        recurring=recurrence,
        nickname=nickname,
    )
    logger.info("Created Stripe Price %s (%s, %s CHF/%s)", created.id, nickname, unitAmountCHF, interval)
    return created.id
|
|
|
|
|
|
def _archiveOtherRecurringPrices(
    stripe, productId: Optional[str], keepPriceId: Optional[str], interval: str, intervalCount: int = 1,
) -> None:
    """Archive every other active recurring price on the product (same interval pattern).

    All active Prices of *productId* are listed across all pages, not just the
    first 100. Those whose ``recurring`` matches *interval* / *intervalCount*
    and whose ID differs from *keepPriceId* are then deactivated.

    Best-effort: every failure is logged, none is raised. It is a no-op when
    either ID is missing.
    """
    if not productId or not keepPriceId:
        return

    try:
        prices = stripe.Price.list(product=productId, active=True, limit=100)
        pager = getattr(prices, "auto_paging_iter", None)
        candidates = pager() if callable(pager) else prices.data
        # Collect first, archive afterwards, so the active=True listing is not
        # being mutated while it is still being paged through.
        staleIds = [
            p.id
            for p in candidates
            if p.id != keepPriceId
            and _recurringMatches(p.get("recurring") or {}, interval, intervalCount)
        ]
    except Exception as e:
        logger.warning("Stale price archive pass failed for product %s: %s", productId, e)
        return

    for staleId in staleIds:
        try:
            stripe.Price.modify(staleId, active=False)
            logger.info("Archived stale Stripe Price %s on product %s", staleId, productId)
        except Exception as ex:
            logger.warning("Could not archive price %s: %s", staleId, ex)
|
|
|
|
|
|
def _validateStripeIdsExist(stripe, mapping: StripePlanPrice) -> bool:
    """Quick check whether the stored product IDs still exist in Stripe.

    Returns False only when Stripe reports a stored Product as missing
    (error code ``resource_missing``), e.g. when running against a different
    Stripe account or after a DB copy.

    Any other error (network, rate limit, auth, ...) does not prove the IDs are
    wrong, so it is logged and the check continues. Treating such errors as
    "missing" would trigger re-provisioning, which can create duplicate Products
    and overwrite a valid mapping. Unset product IDs are skipped.
    """
    for productId in (mapping.stripeProductIdUsers, mapping.stripeProductIdInstances):
        if not productId:
            continue
        try:
            stripe.Product.retrieve(productId)
        except Exception as e:
            if getattr(e, "code", None) == "resource_missing":
                return False
            logger.warning(
                "Could not verify Stripe Product %s (%s); assuming stored IDs are still valid.",
                productId, e,
            )
    return True
|
|
|
|
|
|
def bootstrapStripePrices() -> None:
    """Ensure all paid plans have separate Stripe Products for users and instances.

    A plan is processed when it is in BUILTIN_PLANS, has a recurring billing
    period known to ``_PERIOD_TO_STRIPE``, and has at least one non-zero price.
    For each such plan:

      * if a complete StripePlanPrice mapping exists and its Products are still
        present in Stripe, its Prices are reconciled against the catalog;
      * otherwise the Products/Prices are looked up or created per line item,
        and the mapping row is created or overwritten.

    Each plan is isolated: an exception for one plan is logged, and the
    remaining plans are still bootstrapped. Returns early, logging an error,
    when ``getStripeClient`` raises ValueError (Stripe not configured).
    """
    try:
        from modules.shared.stripeClient import getStripeClient
        stripe = getStripeClient()
    except ValueError as e:
        logger.error("Stripe not configured — cannot bootstrap subscription prices: %s", e)
        return

    db = _getBillingDb()
    existing = _loadExistingMappings(db)

    for planKey, plan in BUILTIN_PLANS.items():
        if plan.billingPeriod == BillingPeriodEnum.NONE:
            continue
        if plan.pricePerUserCHF == 0 and plan.pricePerFeatureInstanceCHF == 0:
            continue

        stripePeriod = _PERIOD_TO_STRIPE.get(plan.billingPeriod)
        if not stripePeriod:
            continue

        interval = stripePeriod["interval"]
        intervalCount = int(stripePeriod.get("interval_count") or 1)

        try:
            _bootstrapPlan(stripe, db, planKey, plan, existing.get(planKey), interval, intervalCount)
        except Exception:
            # One failing plan (or a transient Stripe error) must not leave all later plans unprovisioned.
            logger.exception("Stripe bootstrap failed for plan %s; continuing with remaining plans", planKey)


def _bootstrapPlan(
    stripe,
    db: "DatabaseConnector",
    planKey: str,
    plan: "SubscriptionPlan",
    mapping: Optional["StripePlanPrice"],
    interval: str,
    intervalCount: int,
) -> None:
    """Reconcile a plan's existing valid mapping, or (re-)provision the plan from scratch."""
    if mapping is not None:
        hasAllPrices = mapping.stripePriceIdUsers and mapping.stripePriceIdInstances
        hasAllProducts = mapping.stripeProductIdUsers and mapping.stripeProductIdInstances
        if hasAllPrices and hasAllProducts:
            if _validateStripeIdsExist(stripe, mapping):
                _reconcileMapping(stripe, db, planKey, plan, mapping, interval, intervalCount)
                return
            logger.warning(
                "Stored Stripe IDs for plan %s reference unknown objects "
                "(likely wrong Stripe account or copied DB) — re-provisioning.",
                planKey,
            )

    productIdUsers = priceIdUsers = None
    productIdInstances = priceIdInstances = None

    if plan.pricePerUserCHF > 0:
        productIdUsers, priceIdUsers = _provisionLine(
            stripe, planKey, plan, "users", "Benutzer-Lizenzen", f"{planKey} — Benutzer-Lizenz",
            plan.pricePerUserCHF, interval, intervalCount,
        )

    if plan.pricePerFeatureInstanceCHF > 0:
        productIdInstances, priceIdInstances = _provisionLine(
            stripe, planKey, plan, "instances", "Feature-Instanzen", f"{planKey} — Feature-Instanz",
            plan.pricePerFeatureInstanceCHF, interval, intervalCount,
        )

    persistData = {
        "stripeProductId": "",
        "stripeProductIdUsers": productIdUsers,
        "stripeProductIdInstances": productIdInstances,
        "stripePriceIdUsers": priceIdUsers,
        "stripePriceIdInstances": priceIdInstances,
    }

    if mapping is not None:
        db.recordModify(StripePlanPrice, mapping.id, persistData)
    else:
        db.recordCreate(StripePlanPrice, StripePlanPrice(planKey=planKey, **persistData).model_dump())

    logger.info(
        "Stripe bootstrapped for %s: users=%s/%s, instances=%s/%s",
        planKey, productIdUsers, priceIdUsers, productIdInstances, priceIdInstances,
    )


def _reconcileMapping(
    stripe,
    db: "DatabaseConnector",
    planKey: str,
    plan: "SubscriptionPlan",
    mapping: "StripePlanPrice",
    interval: str,
    intervalCount: int,
) -> None:
    """Align a validated mapping's Prices with the catalog and persist any rotated Price IDs."""
    reconciledUsers = _reconcilePrice(
        stripe, mapping.stripeProductIdUsers, mapping.stripePriceIdUsers,
        plan.pricePerUserCHF, interval, f"{planKey} — Benutzer-Lizenz",
        intervalCount,
    )
    reconciledInstances = _reconcilePrice(
        stripe, mapping.stripeProductIdInstances, mapping.stripePriceIdInstances,
        plan.pricePerFeatureInstanceCHF, interval, f"{planKey} — Feature-Instanz",
        intervalCount,
    )

    _archiveOtherRecurringPrices(
        stripe, mapping.stripeProductIdUsers, reconciledUsers, interval, intervalCount,
    )
    _archiveOtherRecurringPrices(
        stripe, mapping.stripeProductIdInstances, reconciledInstances, interval, intervalCount,
    )

    changed = (
        reconciledUsers != mapping.stripePriceIdUsers
        or reconciledInstances != mapping.stripePriceIdInstances
    )
    if changed:
        db.recordModify(StripePlanPrice, mapping.id, {
            "stripePriceIdUsers": reconciledUsers,
            "stripePriceIdInstances": reconciledInstances,
        })
        logger.info(
            "Reconciled Stripe prices for plan %s to catalog (CHF): users=%s, instances=%s",
            planKey, reconciledUsers, reconciledInstances,
        )
    else:
        logger.debug("Stripe prices up-to-date for plan %s", planKey)


def _provisionLine(
    stripe,
    planKey: str,
    plan: "SubscriptionPlan",
    lineType: str,
    productName: str,
    nickname: str,
    unitCHF: float,
    interval: str,
    intervalCount: int,
) -> tuple:
    """Find or create the Product and catalog-matching Price for one line type of a plan.

    The Product description is ``"<productName> für <German plan title>"``,
    falling back to the plan key when there is no German title. Other matching
    recurring Prices on the Product are archived afterwards.

    Returns:
        ``(productId, priceId)``.
    """
    productId = _findStripeProduct(stripe, planKey, lineType)
    if not productId:
        productId = _createStripeProduct(
            stripe, productName, f"{productName} für {plan.title.get('de', planKey)}",
            planKey, lineType,
        )

    priceId = _findExistingStripePrice(
        stripe, productId, int(round(unitCHF * 100)), interval, intervalCount,
    )
    if not priceId:
        priceId = _createStripePrice(stripe, productId, unitCHF, interval, nickname, intervalCount)

    _archiveOtherRecurringPrices(stripe, productId, priceId, interval, intervalCount)
    return productId, priceId
|
|
|
|
|
|
def getStripePricesForPlan(planKey: str) -> Optional[StripePlanPrice]:
    """Return the persisted Stripe IDs for *planKey*.

    Only the first matching row is used. Columns whose names start with "_"
    are dropped before the StripePlanPrice is built. Returns None when no row
    exists or on any error; errors are logged, not raised.
    """
    try:
        matches = _getBillingDb().getRecordset(StripePlanPrice, recordFilter={"planKey": planKey})
        if not matches:
            return None
        firstRow = matches[0]
        publicFields = {key: value for key, value in firstRow.items() if not key.startswith("_")}
        return StripePlanPrice(**publicFields)
    except Exception as e:
        logger.error("Error loading Stripe prices for plan %s: %s", planKey, e)
        return None
|