gateway/modules/serviceCenter/services/serviceSubscription/stripeBootstrap.py
2026-04-26 18:11:42 +02:00

423 lines
16 KiB
Python

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Auto-provision Stripe Products and Prices from the built-in plan catalog.
Creates separate Stripe Products for user licenses and modules
so that invoice line items show clear, descriptive names:
- "Benutzer-Lizenzen"
- "Module"
Idempotent — safe to call on every startup.
Source of truth for unit amounts is BUILTIN_PLANS (CHF). On each run, persisted
Stripe Price IDs are reconciled. If a stored Price is inactive, or its
unit_amount or recurring interval differs from the catalog, an existing
matching active Price is reused (or a new one is created). The old Price is
then archived and the StripePlanPrice record in poweron_billing is updated.
Other stale active Prices on the same Product (same recurring interval) are
archived so that only the catalog-matching Price stays active.
"""
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Optional
from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG
from modules.datamodels.datamodelSubscription import (
BUILTIN_PLANS,
SubscriptionPlan,
BillingPeriodEnum,
StripePlanPrice,
)
logger = logging.getLogger(__name__)

# Postgres database holding the StripePlanPrice mapping table.
_BILLING_DATABASE = "poweron_billing"

# Metadata keys stamped on every Stripe Product created by this module; they
# are what _findStripeProduct searches on to find the product again.
_METADATA_KEY = "poweron_plan_key"
_METADATA_LINE_TYPE = "poweron_line_type"

# Catalog billing period -> Stripe ``recurring`` parameters. Periods that are
# not listed here (e.g. BillingPeriodEnum.NONE) are skipped by _processOnePlan.
_PERIOD_TO_STRIPE = {
    BillingPeriodEnum.MONTHLY: dict(interval="month", interval_count=1),
    BillingPeriodEnum.YEARLY: dict(interval="year", interval_count=1),
}
def _getBillingDb() -> DatabaseConnector:
    """Build a new DatabaseConnector for the billing database.

    Connection settings come from APP_CONFIG (DB_HOST defaults to
    "localhost", DB_PORT to 5432). Every call returns a fresh connector.
    """
    cfg = APP_CONFIG
    connectionArgs = {
        "dbDatabase": _BILLING_DATABASE,
        "dbHost": cfg.get("DB_HOST", "localhost"),
        "dbPort": int(cfg.get("DB_PORT", "5432")),
        "dbUser": cfg.get("DB_USER"),
        "dbPassword": cfg.get("DB_PASSWORD_SECRET"),
    }
    return DatabaseConnector(**connectionArgs)
def _loadExistingMappings(db: DatabaseConnector) -> Dict[str, StripePlanPrice]:
    """Return every persisted StripePlanPrice record, keyed by ``planKey``.

    Columns whose names start with an underscore are dropped before the model
    is built. Rows without a planKey are skipped; if several rows share a key,
    the last one wins. Any error while loading is logged as a warning and an
    empty dict is returned.
    """
    try:
        return {
            row.get("planKey"): StripePlanPrice(
                **{column: value for column, value in row.items() if not column.startswith("_")}
            )
            for row in db.getRecordset(StripePlanPrice)
            if row.get("planKey")
        }
    except Exception as e:
        logger.warning("Could not load StripePlanPrice records: %s", e)
        return {}
def _findStripeProduct(stripe, planKey: str, lineType: str) -> Optional[str]:
    """Search Stripe for a product tagged with plan key + line type.

    First runs one search query that matches both metadata keys. If that
    query raises, it falls back to searching by plan key only (up to 10
    results) and filters on the line type client-side.

    Returns:
        The product ID, or None when nothing matches or both searches fail.
        Failures are logged rather than raised, because the caller responds
        to None by creating a new product.

    NOTE(review): Stripe's Search API is eventually consistent, so a product
    created moments ago may not be returned yet.
    """
    combinedQuery = (
        f'metadata["{_METADATA_KEY}"]:"{planKey}" '
        f'AND metadata["{_METADATA_LINE_TYPE}"]:"{lineType}"'
    )
    try:
        products = stripe.Product.search(query=combinedQuery, limit=1)
        if products.data:
            return products.data[0].id
        return None
    except Exception as primaryError:
        # Previously swallowed silently; keep the fallback but leave a trace.
        logger.debug(
            "Combined Stripe product search failed for %s/%s (%s); retrying with plan key only",
            planKey, lineType, primaryError,
        )
    try:
        products = stripe.Product.search(
            query=f'metadata["{_METADATA_KEY}"]:"{planKey}"',
            limit=10,
        )
        for p in products.data:
            meta = p.get("metadata") or {}
            if meta.get(_METADATA_LINE_TYPE) == lineType:
                return p.id
    except Exception as fallbackError:
        # A None result here leads the caller to create a new product, so this
        # failure is worth surfacing: it may result in a duplicate Product.
        logger.warning(
            "Stripe product search failed for %s/%s: %s", planKey, lineType, fallbackError,
        )
    return None
def _createStripeProduct(stripe, name: str, description: str, planKey: str, lineType: str) -> str:
    """Create a Stripe Product carrying the plan-key / line-type metadata tags.

    Returns the new product's ID. The creation is logged at INFO.
    """
    tags = {_METADATA_KEY: planKey, _METADATA_LINE_TYPE: lineType}
    created = stripe.Product.create(name=name, description=description, metadata=tags)
    newId = created.id
    logger.info("Created Stripe Product %s: %s (%s/%s)", newId, name, planKey, lineType)
    return newId
def _recurringMatches(recurring: Dict, interval: str, intervalCount: int) -> bool:
if not recurring:
return False
if recurring.get("interval") != interval:
return False
ic = recurring.get("interval_count")
if ic is None:
ic = 1
return int(ic) == int(intervalCount)
def _findExistingStripePrice(
    stripe, productId: str, unitAmount: int, interval: str, intervalCount: int = 1,
    currency: Optional[str] = "chf",
) -> Optional[str]:
    """Return the ID of an active recurring price on ``productId`` that matches.

    A price matches when all of the following hold:
    - its ``unit_amount`` equals ``unitAmount`` (smallest currency unit; callers
      pass CHF * 100);
    - its currency equals ``currency`` (case-insensitive; pass None/"" to skip
      this check);
    - its recurring interval and count match.

    Every active price on the product is scanned, page by page, so matches
    beyond the first page are found.

    Returns:
        The first matching price ID, or None when nothing matches or the
        Stripe call fails (the failure is logged).
    """
    try:
        listing = stripe.Price.list(product=productId, active=True, limit=100)
        # auto_paging_iter() walks every page. Fall back to the first page
        # when the list object does not provide it.
        prices = listing.auto_paging_iter() if hasattr(listing, "auto_paging_iter") else listing.data
        for p in prices:
            if p.get("unit_amount") != unitAmount:
                continue
            # Every price created by this module is CHF. Do not reuse a
            # same-amount price in another currency.
            if currency and (p.get("currency") or "").lower() != currency.lower():
                continue
            if _recurringMatches(p.get("recurring") or {}, interval, intervalCount):
                return p.id
    except Exception as e:
        logger.warning("Could not list Stripe prices for product %s: %s", productId, e)
    return None
def _reconcilePrice(
    stripe,
    productId: str,
    oldPriceId: str,
    expectedCHF: float,
    interval: str,
    nickname: str,
    intervalCount: int = 1,
) -> str:
    """Return an active Stripe Price ID that matches the catalog amount and interval.

    ``oldPriceId`` is kept when it can be retrieved, is active, its
    ``unit_amount`` equals ``expectedCHF`` in cents, and its recurring
    interval/count match.

    Otherwise the price is rotated: an existing active matching price on
    ``productId`` is reused, or a new one is created, and ``oldPriceId`` is
    deactivated. The one exception is when the lookup returns ``oldPriceId``
    itself, which can happen when only the retrieve call failed. The price is
    then left active, because deactivating it would hand the caller an
    inactive price ID.

    Args:
        stripe: Stripe SDK module/client.
        productId: Product the price belongs to.
        oldPriceId: Currently persisted price ID.
        expectedCHF: Catalog amount in CHF.
        interval: Stripe recurring interval ("month"/"year").
        nickname: Nickname used if a new price has to be created.
        intervalCount: Stripe recurring interval_count.

    Returns:
        The price ID to persist (``oldPriceId`` when nothing changed).

    Raises:
        Whatever ``stripe.Price.create`` raises if a new price must be created.
    """
    from modules.shared.stripeClient import stripeToDict

    expectedCents = int(round(expectedCHF * 100))
    actualCents: Optional[int] = None
    matchesRecurring = False
    isActive = False
    retrieveFailed = False
    try:
        pd = stripeToDict(stripe.Price.retrieve(oldPriceId))
        actualCents = pd.get("unit_amount")
        matchesRecurring = _recurringMatches(pd.get("recurring") or {}, interval, intervalCount)
        # Stripe.Price.retrieve also returns archived prices, so `active` MUST
        # be checked explicitly. Subscription.create rejects inactive prices with
        # "The price specified is inactive. This field only accepts active prices."
        isActive = bool(pd.get("active"))
    except Exception as ex:
        retrieveFailed = True
        logger.warning("Could not retrieve Stripe Price %s: %s", oldPriceId, ex)

    if not retrieveFailed and isActive and actualCents == expectedCents and matchesRecurring:
        return oldPriceId

    logger.warning(
        "Rotating Stripe Price %s on product %s: active=%s amount=%s (expected %s) recurringMatches=%s retrieveFailed=%s.",
        oldPriceId, productId, isActive, actualCents, expectedCents, matchesRecurring, retrieveFailed,
    )
    newPriceId = _findExistingStripePrice(stripe, productId, expectedCents, interval, intervalCount)
    if not newPriceId:
        newPriceId = _createStripePrice(
            stripe, productId, expectedCHF, interval, nickname, intervalCount,
        )

    if newPriceId == oldPriceId:
        # The lookup only returns active, matching prices, so the stored price
        # is fine (typically the retrieve call failed transiently). Never
        # deactivate the price we are about to return.
        logger.info("Stripe Price %s still matches the catalog; keeping it active", oldPriceId)
        return oldPriceId

    try:
        stripe.Price.modify(oldPriceId, active=False)
        logger.info("Deactivated old Stripe Price %s", oldPriceId)
    except Exception as e:
        logger.warning("Could not deactivate old price %s: %s", oldPriceId, e)
    return newPriceId
def _createStripePrice(
    stripe, productId: str, unitAmountCHF: float, interval: str, nickname: str, intervalCount: int = 1,
) -> str:
    """Create a recurring CHF Price on ``productId`` and return its ID.

    ``unitAmountCHF`` is rounded to whole cents for Stripe's ``unit_amount``.
    The creation is logged at INFO.
    """
    cents = int(round(unitAmountCHF * 100))
    recurringSpec = {"interval": interval, "interval_count": intervalCount}
    created = stripe.Price.create(
        product=productId,
        unit_amount=cents,
        currency="chf",
        recurring=recurringSpec,
        nickname=nickname,
    )
    logger.info("Created Stripe Price %s (%s, %s CHF/%s)", created.id, nickname, unitAmountCHF, interval)
    return created.id
def _archiveOtherRecurringPrices(
    stripe, productId: Optional[str], keepPriceId: Optional[str], interval: str, intervalCount: int = 1,
) -> None:
    """Archive every other active recurring price on the product (same interval pattern).

    A price is archived (``active=False``) when it is active, recurring,
    matches ``interval``/``intervalCount``, and is not ``keepPriceId``.
    Currency is not checked. Does nothing when ``productId`` or
    ``keepPriceId`` is empty.

    All pages of the listing are scanned. Failures are logged and never raised.
    """
    if not productId or not keepPriceId:
        return
    try:
        listing = stripe.Price.list(product=productId, active=True, limit=100)
        prices = listing.auto_paging_iter() if hasattr(listing, "auto_paging_iter") else listing.data
        # Collect IDs first and archive afterwards. Archiving while paging an
        # ``active=True`` listing would mutate the result set being paged.
        # _recurringMatches returns False for a missing/empty ``recurring``,
        # so one-off prices are skipped.
        staleIds = [
            p.id
            for p in prices
            if p.id != keepPriceId
            and _recurringMatches(p.get("recurring") or {}, interval, intervalCount)
        ]
    except Exception as e:
        logger.warning("Stale price archive pass failed for product %s: %s", productId, e)
        return
    for staleId in staleIds:
        try:
            stripe.Price.modify(staleId, active=False)
            logger.info("Archived stale Stripe Price %s on product %s", staleId, productId)
        except Exception as ex:
            logger.warning("Could not archive price %s: %s", staleId, ex)
def _validateStripeIdsExist(stripe, mapping: StripePlanPrice) -> bool:
    """Check whether the stored Stripe product IDs still exist.

    Returns False only when Stripe answers ``resource_missing`` for a stored
    product, e.g. when running against a different Stripe account or after a
    DB copy from another environment. The caller then re-provisions.

    Any other error (network, rate limit, ...) is non-critical: it is logged
    and True is returned. Otherwise a transient failure would push the caller
    into the re-provisioning path, which creates new Products whenever the
    product search also fails.

    Price-level validation (active flag, drift) is handled by
    ``_reconcilePrice``. This check does not fail on archived prices,
    otherwise products would be needlessly re-provisioned on every rotation.
    """
    try:
        for productId in (mapping.stripeProductIdUsers, mapping.stripeProductIdInstances):
            if productId:
                stripe.Product.retrieve(productId)
    except Exception as e:
        if getattr(e, "code", None) == "resource_missing":
            return False
        logger.warning("Stripe validation check failed (non-critical): %s", e)
        return True
    return True
def _provisionLineItem(
    stripe,
    planKey: str,
    lineType: str,
    productName: str,
    productDescription: str,
    unitCHF: float,
    nickname: str,
    interval: str,
    intervalCount: int,
):
    """Find or create the Product and Price for one invoice line of a plan.

    Afterwards, archives the product's other active prices with the same
    recurring interval. Returns ``(productId, priceId)``.
    """
    productId = _findStripeProduct(stripe, planKey, lineType) or _createStripeProduct(
        stripe, productName, productDescription, planKey, lineType,
    )
    priceId = _findExistingStripePrice(
        stripe, productId, int(round(unitCHF * 100)), interval, intervalCount,
    ) or _createStripePrice(
        stripe, productId, unitCHF, interval, nickname, intervalCount,
    )
    _archiveOtherRecurringPrices(stripe, productId, priceId, interval, intervalCount)
    return productId, priceId


def _reconcileMapping(
    stripe,
    planKey: str,
    plan: SubscriptionPlan,
    mapping: StripePlanPrice,
    interval: str,
    intervalCount: int,
    db: DatabaseConnector,
) -> None:
    """Reconcile a complete stored mapping against the catalog prices.

    Persists the price IDs only when either of them changed.
    """
    reconciledUsers = _reconcilePrice(
        stripe, mapping.stripeProductIdUsers, mapping.stripePriceIdUsers,
        plan.pricePerUserCHF, interval, f"{planKey} — Benutzer-Lizenz",
        intervalCount,
    )
    reconciledInstances = _reconcilePrice(
        stripe, mapping.stripeProductIdInstances, mapping.stripePriceIdInstances,
        plan.pricePerFeatureInstanceCHF, interval, f"{planKey} — Modul",
        intervalCount,
    )
    _archiveOtherRecurringPrices(
        stripe, mapping.stripeProductIdUsers, reconciledUsers, interval, intervalCount,
    )
    _archiveOtherRecurringPrices(
        stripe, mapping.stripeProductIdInstances, reconciledInstances, interval, intervalCount,
    )
    changed = (
        reconciledUsers != mapping.stripePriceIdUsers
        or reconciledInstances != mapping.stripePriceIdInstances
    )
    if not changed:
        logger.debug("Stripe prices up-to-date for plan %s", planKey)
        return
    db.recordModify(StripePlanPrice, mapping.id, {
        "stripePriceIdUsers": reconciledUsers,
        "stripePriceIdInstances": reconciledInstances,
    })
    logger.info(
        "Reconciled Stripe prices for plan %s to catalog (CHF): users=%s, instances=%s",
        planKey, reconciledUsers, reconciledInstances,
    )


def _processOnePlan(
    stripe,
    planKey: str,
    plan: SubscriptionPlan,
    existingMapping: Optional[StripePlanPrice],
) -> None:
    """Reconcile or provision Stripe Products/Prices for a single plan.

    Paths:
    - Plans whose billing period has no Stripe mapping are skipped.
    - If a stored mapping has all four product/price IDs and its products
      still exist, prices are reconciled and the function returns.
    - Otherwise a product and price are found or created for each line type
      (users / instances) with a positive catalog price. The IDs are then
      written to the existing record, or to a new one.

    Each call creates its own billing-DB connector (via ``_getBillingDb``),
    so concurrent calls do not share one.
    """
    stripePeriod = _PERIOD_TO_STRIPE.get(plan.billingPeriod)
    if not stripePeriod:
        return
    interval = stripePeriod["interval"]
    intervalCount = int(stripePeriod.get("interval_count") or 1)
    db = _getBillingDb()

    if existingMapping:
        complete = all((
            existingMapping.stripePriceIdUsers,
            existingMapping.stripePriceIdInstances,
            existingMapping.stripeProductIdUsers,
            existingMapping.stripeProductIdInstances,
        ))
        if complete:
            if _validateStripeIdsExist(stripe, existingMapping):
                _reconcileMapping(
                    stripe, planKey, plan, existingMapping, interval, intervalCount, db,
                )
                return
            logger.warning(
                "Stored Stripe IDs for plan %s reference unknown objects "
                "(likely wrong Stripe account or copied DB) — re-provisioning.",
                planKey,
            )

    productIdUsers = priceIdUsers = None
    productIdInstances = priceIdInstances = None
    title = plan.title or planKey
    if plan.pricePerUserCHF > 0:
        productIdUsers, priceIdUsers = _provisionLineItem(
            stripe, planKey, "users",
            "Benutzer-Lizenzen", f"Benutzer-Lizenzen für {title}",
            plan.pricePerUserCHF, f"{planKey} — Benutzer-Lizenz",
            interval, intervalCount,
        )
    if plan.pricePerFeatureInstanceCHF > 0:
        productIdInstances, priceIdInstances = _provisionLineItem(
            stripe, planKey, "instances",
            "Module", f"Module für {title}",
            plan.pricePerFeatureInstanceCHF, f"{planKey} — Modul",
            interval, intervalCount,
        )

    persistData = {
        "stripeProductId": "",
        "stripeProductIdUsers": productIdUsers,
        "stripeProductIdInstances": productIdInstances,
        "stripePriceIdUsers": priceIdUsers,
        "stripePriceIdInstances": priceIdInstances,
    }
    if existingMapping:
        db.recordModify(StripePlanPrice, existingMapping.id, persistData)
    else:
        db.recordCreate(StripePlanPrice, StripePlanPrice(planKey=planKey, **persistData).model_dump())
    logger.info(
        "Stripe bootstrapped for %s: users=%s/%s, instances=%s/%s",
        planKey, productIdUsers, priceIdUsers, productIdInstances, priceIdInstances,
    )
def bootstrapStripePrices() -> None:
    """Ensure all paid plans have separate Stripe Products for users and instances.

    Only plans with a billing period other than NONE and at least one
    positive price (per user or per feature instance) are processed. Plans
    are processed in parallel (one thread per plan) to reduce boot time; each
    worker creates its own DB connector.

    A failure in one plan is logged with its full traceback and does not stop
    the others. If ``getStripeClient`` raises ValueError, an error is logged
    and the function returns without doing anything.

    NOTE(review): the single ``stripe`` client is shared by all worker
    threads; this relies on the Stripe SDK being thread-safe.
    """
    try:
        from modules.shared.stripeClient import getStripeClient
        stripe = getStripeClient()
    except ValueError as e:
        logger.error("Stripe not configured — cannot bootstrap subscription prices: %s", e)
        return
    existing = _loadExistingMappings(_getBillingDb())
    plans = [
        (planKey, plan)
        for planKey, plan in BUILTIN_PLANS.items()
        if plan.billingPeriod != BillingPeriodEnum.NONE
        and (plan.pricePerUserCHF > 0 or plan.pricePerFeatureInstanceCHF > 0)
    ]
    if not plans:
        return
    with ThreadPoolExecutor(max_workers=len(plans)) as executor:
        futures = {
            executor.submit(_processOnePlan, stripe, planKey, plan, existing.get(planKey)): planKey
            for planKey, plan in plans
        }
        for future in as_completed(futures):
            planKey = futures[future]
            try:
                future.result()
            except Exception as e:
                # logger.exception keeps the worker's traceback, which the
                # previous logger.error call dropped.
                logger.exception("Stripe bootstrap failed for plan %s: %s", planKey, e)
def getStripePricesForPlan(planKey: str) -> Optional[StripePlanPrice]:
    """Return the persisted Stripe IDs for ``planKey``, or None.

    None is returned when no record exists or when loading fails (the error
    is logged). If several rows match, the first one returned is used.
    Columns whose names start with an underscore are dropped.
    """
    try:
        rows = _getBillingDb().getRecordset(StripePlanPrice, recordFilter={"planKey": planKey})
        if not rows:
            return None
        firstRow = rows[0]
        publicFields = {name: val for name, val in firstRow.items() if not name.startswith("_")}
        return StripePlanPrice(**publicFields)
    except Exception as e:
        logger.error("Error loading Stripe prices for plan %s: %s", planKey, e)
        return None