gateway/modules/serviceCenter/services/serviceSubscription/stripeBootstrap.py

297 lines
11 KiB
Python

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Auto-provision Stripe Products and Prices from the built-in plan catalog.
Creates separate Stripe Products for user licenses and feature instances
so that invoice line items show clear, descriptive names:
- "Benutzer-Lizenzen"
- "Feature-Instanzen"
Idempotent — safe to call on every startup.
"""
import logging
from typing import Dict, Optional, Tuple

from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG
from modules.datamodels.datamodelSubscription import (
    BUILTIN_PLANS,
    SubscriptionPlan,
    BillingPeriodEnum,
    StripePlanPrice,
)
logger = logging.getLogger(__name__)

# Name of the database that _getBillingDb() connects to (holds StripePlanPrice rows).
_BILLING_DATABASE = "poweron_billing"

# Metadata keys written onto Stripe Products and used again in Product searches.
_METADATA_KEY = "poweron_plan_key"
_METADATA_LINE_TYPE = "poweron_line_type"

# Catalog billing period -> Stripe recurring settings (every period is a single interval).
_PERIOD_TO_STRIPE = {
    period: {"interval": stripeInterval, "interval_count": 1}
    for period, stripeInterval in (
        (BillingPeriodEnum.MONTHLY, "month"),
        (BillingPeriodEnum.YEARLY, "year"),
    )
}
def _getBillingDb() -> DatabaseConnector:
    """Build a DatabaseConnector for the billing database.

    Host, port, user and password come from APP_CONFIG (DB_HOST defaults to
    "localhost", DB_PORT to "5432"); the database name is fixed to
    ``_BILLING_DATABASE``.
    """
    connectionSettings = {
        "dbDatabase": _BILLING_DATABASE,
        "dbHost": APP_CONFIG.get("DB_HOST", "localhost"),
        "dbPort": int(APP_CONFIG.get("DB_PORT", "5432")),
        "dbUser": APP_CONFIG.get("DB_USER"),
        "dbPassword": APP_CONFIG.get("DB_PASSWORD_SECRET"),
    }
    return DatabaseConnector(**connectionSettings)
def _loadExistingMappings(db: DatabaseConnector) -> Dict[str, StripePlanPrice]:
    """Return every persisted StripePlanPrice row, keyed by its plan key.

    Rows without a ``planKey`` are skipped. Keys starting with "_" are dropped
    before the model is built (presumably connector bookkeeping columns).
    Any failure while reading or parsing is logged as a warning and yields an
    empty dict.
    """
    mappings: Dict[str, StripePlanPrice] = {}
    try:
        for record in db.getRecordset(StripePlanPrice):
            planKey = record.get("planKey")
            if not planKey:
                continue
            modelFields = {name: value for name, value in record.items() if not name.startswith("_")}
            mappings[planKey] = StripePlanPrice(**modelFields)
    except Exception as e:
        logger.warning("Could not load StripePlanPrice records: %s", e)
        return {}
    return mappings
def _findStripeProduct(stripe, planKey: str, lineType: str) -> Optional[str]:
    """Return the ID of the Stripe Product tagged with *planKey* and *lineType*.

    A combined metadata search (one hit) is tried first. An empty result from it
    is final. Only if that search raises does the code fall back to a
    plan-key-only search (up to 10 hits) and filter the line type on the client
    side. Returns None when nothing matches or both searches fail.
    """
    planClause = f'metadata["{_METADATA_KEY}"]:"{planKey}"'
    lineClause = f'metadata["{_METADATA_LINE_TYPE}"]:"{lineType}"'
    try:
        hits = stripe.Product.search(query=f"{planClause} AND {lineClause}", limit=1)
        if hits.data:
            return hits.data[0].id
        return None
    except Exception:
        pass
    # Fallback path: only reached when the combined query itself failed.
    try:
        candidates = stripe.Product.search(query=planClause, limit=10)
        for candidate in candidates.data:
            if (candidate.get("metadata") or {}).get(_METADATA_LINE_TYPE) == lineType:
                return candidate.id
    except Exception:
        pass
    return None
def _createStripeProduct(stripe, name: str, description: str, planKey: str, lineType: str) -> str:
    """Create a Stripe Product carrying plan-key/line-type metadata and return its ID.

    The metadata lets _findStripeProduct locate the product again later.
    """
    searchTags = {_METADATA_KEY: planKey, _METADATA_LINE_TYPE: lineType}
    created = stripe.Product.create(name=name, description=description, metadata=searchTags)
    newProductId = created.id
    logger.info("Created Stripe Product %s: %s (%s/%s)", newProductId, name, planKey, lineType)
    return newProductId
def _findExistingStripePrice(stripe, productId: str, unitAmount: int, interval: str) -> Optional[str]:
try:
prices = stripe.Price.list(product=productId, active=True, limit=50)
for p in prices.data:
recurring = p.get("recurring") or {}
if p.get("unit_amount") == unitAmount and recurring.get("interval") == interval:
return p.id
except Exception:
pass
return None
def _getStripePriceAmount(stripe, priceId: str) -> Optional[int]:
"""Retrieve the unit_amount (in Rappen) of an existing Stripe Price."""
try:
price = stripe.Price.retrieve(priceId)
return price.get("unit_amount") if price else None
except Exception:
return None
def _toRappen(amountCHF: float) -> int:
    """Convert a CHF amount to integer Rappen for Stripe's ``unit_amount``.

    ``int(amount * 100)`` truncates binary floating-point artefacts downwards
    (``19.99 * 100`` evaluates to ``1998.9999999999998``), which would charge one
    Rappen too little. Rounding yields the intended value for amounts with at
    most two decimal places. Price lookup, price creation and drift detection
    below all convert through this helper so they always agree on the amount.
    """
    return int(round(amountCHF * 100))


# Billable line types: (metadata line type, Stripe product name, price nickname suffix).
_USERS_LINE = ("users", "Benutzer-Lizenzen", "Benutzer-Lizenz")
_INSTANCES_LINE = ("instances", "Feature-Instanzen", "Feature-Instanz")


def _reconcilePrice(stripe, productId: str, oldPriceId: str, expectedCHF: float, interval: str, nickname: str) -> str:
    """Return a Price ID on *productId* whose amount matches the catalog amount *expectedCHF*.

    If *oldPriceId* already charges the expected amount, it is returned unchanged.
    Otherwise an active price with the expected amount/interval is reused or
    created, and the old price is deactivated.

    If the current amount cannot be read (retrieve failed, or the price has no
    ``unit_amount``), the old price is kept. A transient API error must not
    archive a price that is still in use.
    """
    expectedCents = _toRappen(expectedCHF)
    actualCents = _getStripePriceAmount(stripe, oldPriceId)
    if actualCents is None:
        logger.warning("Could not read amount of Stripe Price %s — keeping it unchanged.", oldPriceId)
        return oldPriceId
    if actualCents == expectedCents:
        return oldPriceId
    logger.warning(
        "Price drift detected for %s: Stripe has %s Rappen, catalog expects %s Rappen. Rotating price.",
        oldPriceId, actualCents, expectedCents,
    )
    newPriceId = _findExistingStripePrice(stripe, productId, expectedCents, interval)
    if not newPriceId:
        newPriceId = _createStripePrice(stripe, productId, expectedCHF, interval, nickname)
    if newPriceId == oldPriceId:
        # Never deactivate the price we are going to keep using.
        return oldPriceId
    try:
        stripe.Price.modify(oldPriceId, active=False)
        logger.info("Deactivated old Stripe Price %s", oldPriceId)
    except Exception as e:
        logger.warning("Could not deactivate old price %s: %s", oldPriceId, e)
    return newPriceId


def _createStripePrice(stripe, productId: str, unitAmountCHF: float, interval: str, nickname: str) -> str:
    """Create a recurring CHF Price on *productId* for *unitAmountCHF* and return its ID."""
    price = stripe.Price.create(
        product=productId,
        unit_amount=_toRappen(unitAmountCHF),
        currency="chf",
        recurring={"interval": interval},
        nickname=nickname,
    )
    logger.info("Created Stripe Price %s (%s, %s CHF/%s)", price.id, nickname, unitAmountCHF, interval)
    return price.id


def _validateStripeIdsExist(stripe, mapping: StripePlanPrice) -> bool:
    """Check whether the product IDs stored in *mapping* still exist in Stripe.

    Returns False when Stripe reports a product as missing ("resource_missing"),
    e.g. when running against a different Stripe account or after copying the DB.
    Any other failure of the check also returns False (with a warning), so the
    caller falls back to re-provisioning by metadata search.
    """
    try:
        for productId in (mapping.stripeProductIdUsers, mapping.stripeProductIdInstances):
            if productId:
                stripe.Product.retrieve(productId)
        return True
    except Exception as e:
        if getattr(e, "code", None) == "resource_missing":
            return False
        logger.warning("Stripe validation check failed, treating stored IDs as invalid: %s", e)
        return False


def _mappingMatchesPlan(mapping: StripePlanPrice, plan) -> bool:
    """True when *mapping* holds product+price IDs for exactly the lines *plan* charges for."""
    hasUsers = bool(mapping.stripeProductIdUsers and mapping.stripePriceIdUsers)
    hasInstances = bool(mapping.stripeProductIdInstances and mapping.stripePriceIdInstances)
    return hasUsers == (plan.pricePerUserCHF > 0) and hasInstances == (plan.pricePerFeatureInstanceCHF > 0)


def _provisionLine(
    stripe, planKey: str, plan, line: Tuple[str, str, str], amountCHF: float, interval: str,
) -> Tuple[str, str]:
    """Find or create the Stripe Product and Price for one billable line of *plan*.

    Returns ``(productId, priceId)``.
    """
    lineType, productName, nicknameSuffix = line
    productId = _findStripeProduct(stripe, planKey, lineType)
    if not productId:
        productId = _createStripeProduct(
            stripe, productName, f"{productName} für {plan.title.get('de', planKey)}",
            planKey, lineType,
        )
    priceId = _findExistingStripePrice(stripe, productId, _toRappen(amountCHF), interval)
    if not priceId:
        priceId = _createStripePrice(stripe, productId, amountCHF, interval, f"{planKey} — {nicknameSuffix}")
    return productId, priceId


def _reconcileMapping(stripe, db, planKey: str, plan, mapping: StripePlanPrice, interval: str) -> None:
    """Rotate drifted prices of an already-provisioned plan and persist any changed price IDs."""
    priceIdUsers = mapping.stripePriceIdUsers
    priceIdInstances = mapping.stripePriceIdInstances
    if plan.pricePerUserCHF > 0:
        priceIdUsers = _reconcilePrice(
            stripe, mapping.stripeProductIdUsers, mapping.stripePriceIdUsers,
            plan.pricePerUserCHF, interval, f"{planKey} — {_USERS_LINE[2]}",
        )
    if plan.pricePerFeatureInstanceCHF > 0:
        priceIdInstances = _reconcilePrice(
            stripe, mapping.stripeProductIdInstances, mapping.stripePriceIdInstances,
            plan.pricePerFeatureInstanceCHF, interval, f"{planKey} — {_INSTANCES_LINE[2]}",
        )
    if priceIdUsers == mapping.stripePriceIdUsers and priceIdInstances == mapping.stripePriceIdInstances:
        logger.debug("Stripe prices up-to-date for plan %s", planKey)
        return
    db.recordModify(StripePlanPrice, mapping.id, {
        "stripePriceIdUsers": priceIdUsers,
        "stripePriceIdInstances": priceIdInstances,
    })
    logger.info("Reconciled Stripe prices for plan %s: users=%s, instances=%s", planKey, priceIdUsers, priceIdInstances)


def _provisionPlan(stripe, db, planKey: str, plan, mapping: Optional[StripePlanPrice], interval: str) -> None:
    """Find or create Stripe objects for every charged line of *plan*, then upsert its mapping row."""
    productIdUsers = priceIdUsers = None
    productIdInstances = priceIdInstances = None
    if plan.pricePerUserCHF > 0:
        productIdUsers, priceIdUsers = _provisionLine(
            stripe, planKey, plan, _USERS_LINE, plan.pricePerUserCHF, interval,
        )
    if plan.pricePerFeatureInstanceCHF > 0:
        productIdInstances, priceIdInstances = _provisionLine(
            stripe, planKey, plan, _INSTANCES_LINE, plan.pricePerFeatureInstanceCHF, interval,
        )
    persistData: Dict[str, Optional[str]] = {
        # NOTE(review): presumably the legacy single-product field, cleared now
        # that users and instances have separate products — confirm.
        "stripeProductId": "",
        "stripeProductIdUsers": productIdUsers,
        "stripeProductIdInstances": productIdInstances,
        "stripePriceIdUsers": priceIdUsers,
        "stripePriceIdInstances": priceIdInstances,
    }
    if mapping is not None:
        db.recordModify(StripePlanPrice, mapping.id, persistData)
    else:
        db.recordCreate(StripePlanPrice, StripePlanPrice(planKey=planKey, **persistData).model_dump())
    logger.info(
        "Stripe bootstrapped for %s: users=%s/%s, instances=%s/%s",
        planKey, productIdUsers, priceIdUsers, productIdInstances, priceIdInstances,
    )


def bootstrapStripePrices() -> None:
    """Ensure every paid, recurring built-in plan has Stripe Products and Prices.

    The function skips plans with no billing period, plans whose prices are
    all 0, and plans whose period has no Stripe mapping. Every other plan in
    ``BUILTIN_PLANS`` gets one Stripe Product/Price pair per charged line
    ("users" and/or "instances").

    If the stored mapping covers exactly the charged lines and its products
    still exist in Stripe, the prices are only reconciled against the catalog
    amounts. Otherwise the plan is (re-)provisioned by metadata search, missing
    objects are created, and the mapping row is upserted.

    Returns early (after logging an error) if the Stripe client raises
    ValueError. Exceptions raised while creating Stripe objects are not caught
    here and propagate to the caller.
    """
    try:
        from modules.shared.stripeClient import getStripeClient
        stripe = getStripeClient()
    except ValueError as e:
        logger.error("Stripe not configured — cannot bootstrap subscription prices: %s", e)
        return
    db = _getBillingDb()
    existing = _loadExistingMappings(db)
    for planKey, plan in BUILTIN_PLANS.items():
        if plan.billingPeriod == BillingPeriodEnum.NONE:
            continue
        if plan.pricePerUserCHF == 0 and plan.pricePerFeatureInstanceCHF == 0:
            continue
        stripePeriod = _PERIOD_TO_STRIPE.get(plan.billingPeriod)
        if not stripePeriod:
            continue
        interval = stripePeriod["interval"]
        mapping = existing.get(planKey)
        if mapping is not None and _mappingMatchesPlan(mapping, plan):
            if _validateStripeIdsExist(stripe, mapping):
                _reconcileMapping(stripe, db, planKey, plan, mapping, interval)
                continue
            logger.warning(
                "Stored Stripe IDs for plan %s reference unknown objects "
                "(likely wrong Stripe account or copied DB) — re-provisioning.",
                planKey,
            )
        _provisionPlan(stripe, db, planKey, plan, mapping, interval)
def getStripePricesForPlan(planKey: str) -> Optional[StripePlanPrice]:
    """Return the persisted Stripe IDs for *planKey*, or None.

    None is returned both when no row exists and when loading fails; failures
    are logged as errors. If several rows match, the first one returned by the
    connector is used. Keys starting with "_" are dropped before the model is built.
    """
    try:
        matches = _getBillingDb().getRecordset(StripePlanPrice, recordFilter={"planKey": planKey})
        if not matches:
            return None
        firstRow = matches[0]
        return StripePlanPrice(**{key: val for key, val in firstRow.items() if not key.startswith("_")})
    except Exception as e:
        logger.error("Error loading Stripe prices for plan %s: %s", planKey, e)
    return None