gateway/modules/serviceCenter/services/serviceSubscription/stripeBootstrap.py
ValueOn AG 268c4b8e1e prices
2026-04-02 13:09:04 +02:00

386 lines
15 KiB
Python

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Auto-provision Stripe Products and Prices from the built-in plan catalog.
Creates separate Stripe Products for user licenses and feature instances
so that invoice line items show clear, descriptive names:
- "Benutzer-Lizenzen"
- "Feature-Instanzen"
Idempotent — safe to call on every startup.
Source of truth for unit amounts is BUILTIN_PLANS (CHF). On each run, persisted
Stripe Price IDs are reconciled: if Stripe's unit_amount differs from the
catalog, a new Price is created, the old one is archived, and poweron_billing
StripePlanPrice is updated. Other stale active Prices on the same Product
(same recurring interval) are archived so only the catalog-matching Price stays active.
"""
import logging
from typing import Dict, Optional, Tuple

from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG
from modules.datamodels.datamodelSubscription import (
    BUILTIN_PLANS,
    SubscriptionPlan,
    BillingPeriodEnum,
    StripePlanPrice,
)
logger = logging.getLogger(__name__)
# Database that stores the StripePlanPrice mapping rows (see _getBillingDb).
_BILLING_DATABASE = "poweron_billing"
# Stripe Product metadata key holding the catalog plan key; _findStripeProduct searches on it.
_METADATA_KEY = "poweron_plan_key"
# Stripe Product metadata key distinguishing a plan's "users" product from its "instances" product.
_METADATA_LINE_TYPE = "poweron_line_type"
# Catalog billing period -> Stripe `recurring` settings. Periods missing from this map
# are skipped by bootstrapStripePrices.
_PERIOD_TO_STRIPE = {
    BillingPeriodEnum.MONTHLY: {"interval": "month", "interval_count": 1},
    BillingPeriodEnum.YEARLY: {"interval": "year", "interval_count": 1},
}
def _getBillingDb() -> DatabaseConnector:
    """Build a connector for the billing database from the APP_CONFIG DB_* settings.

    Host defaults to "localhost" and port to 5432 when not configured.
    """
    host = APP_CONFIG.get("DB_HOST", "localhost")
    port = int(APP_CONFIG.get("DB_PORT", "5432"))
    user = APP_CONFIG.get("DB_USER")
    password = APP_CONFIG.get("DB_PASSWORD_SECRET")
    return DatabaseConnector(
        dbDatabase=_BILLING_DATABASE,
        dbHost=host,
        dbPort=port,
        dbUser=user,
        dbPassword=password,
    )
def _loadExistingMappings(db: DatabaseConnector) -> Dict[str, StripePlanPrice]:
    """Return every persisted StripePlanPrice row, keyed by its planKey.

    Rows without a planKey are ignored, and keys starting with "_" are stripped
    before building the model. Any failure is logged as a warning and yields {}.
    """
    try:
        mappingsByPlan: Dict[str, StripePlanPrice] = {}
        for record in db.getRecordset(StripePlanPrice):
            planKey = record.get("planKey")
            if not planKey:
                continue
            publicFields = {name: value for name, value in record.items() if not name.startswith("_")}
            mappingsByPlan[planKey] = StripePlanPrice(**publicFields)
        return mappingsByPlan
    except Exception as e:
        logger.warning("Could not load StripePlanPrice records: %s", e)
        return {}
def _findStripeProduct(stripe, planKey: str, lineType: str) -> Optional[str]:
    """Return the ID of a Stripe Product tagged with *planKey* and *lineType*, or None.

    It first runs one metadata search query matching both tags. If that query raises,
    it falls back to searching by plan key only (first 10 hits) and filters on the
    line-type tag client-side.

    Search errors are logged rather than silently dropped. A None result makes the
    caller create a new Product, so a hidden failure here would produce duplicates.
    """
    try:
        products = stripe.Product.search(
            query=f'metadata["{_METADATA_KEY}"]:"{planKey}" AND metadata["{_METADATA_LINE_TYPE}"]:"{lineType}"',
            limit=1,
        )
        if products.data:
            return products.data[0].id
        # The combined query worked but matched nothing: no fallback needed.
        return None
    except Exception as e:
        logger.debug(
            "Combined Stripe Product search failed for %s/%s, trying fallback: %s",
            planKey, lineType, e,
        )
    try:
        products = stripe.Product.search(
            query=f'metadata["{_METADATA_KEY}"]:"{planKey}"',
            limit=10,
        )
        for p in products.data:
            meta = p.get("metadata") or {}
            if meta.get(_METADATA_LINE_TYPE) == lineType:
                return p.id
    except Exception as e:
        logger.warning(
            "Stripe Product search failed for %s/%s; a new Product may be created: %s",
            planKey, lineType, e,
        )
    return None
def _createStripeProduct(stripe, name: str, description: str, planKey: str, lineType: str) -> str:
    """Create a Stripe Product tagged with the plan-key/line-type metadata and return its ID."""
    tags = {_METADATA_KEY: planKey, _METADATA_LINE_TYPE: lineType}
    newProduct = stripe.Product.create(name=name, description=description, metadata=tags)
    productId = newProduct.id
    logger.info("Created Stripe Product %s: %s (%s/%s)", productId, name, planKey, lineType)
    return productId
def _recurringMatches(recurring: Dict, interval: str, intervalCount: int) -> bool:
if not recurring:
return False
if recurring.get("interval") != interval:
return False
ic = recurring.get("interval_count")
if ic is None:
ic = 1
return int(ic) == int(intervalCount)
def _findExistingStripePrice(
    stripe, productId: str, unitAmount: int, interval: str, intervalCount: int = 1,
    currency: str = "chf",
) -> Optional[str]:
    """Return the ID of an active Price on *productId* matching amount, currency and interval.

    Args:
        stripe: Configured Stripe client/module.
        productId: Stripe Product whose active Prices are inspected.
        unitAmount: Expected unit_amount in the smallest currency unit (Rappen for CHF).
        interval: Stripe recurring interval ("month"/"year").
        intervalCount: Stripe recurring interval_count.
        currency: Currency code the Price must use. Defaults to "chf", the only currency
            this module creates Prices in. Without this check, a same-amount Price in
            another currency would be reused as the CHF price.

    Returns:
        The first matching Price ID, or None. Only the first 50 active Prices are
        inspected. Stripe errors are logged and treated as "no match".
    """
    wantedCurrency = currency.lower()
    try:
        prices = stripe.Price.list(product=productId, active=True, limit=50)
        for p in prices.data:
            if p.get("unit_amount") != unitAmount:
                continue
            if (p.get("currency") or "").lower() != wantedCurrency:
                continue
            if _recurringMatches(p.get("recurring") or {}, interval, intervalCount):
                return p.id
    except Exception as e:
        logger.warning("Could not list Stripe Prices for product %s: %s", productId, e)
    return None
def _getStripePriceAmount(stripe, priceId: str) -> Optional[int]:
"""Retrieve the unit_amount (in Rappen) of an existing Stripe Price."""
try:
from modules.shared.stripeClient import stripeToDict
price = stripeToDict(stripe.Price.retrieve(priceId))
return price.get("unit_amount") if price else None
except Exception:
return None
def _reconcilePrice(
    stripe,
    productId: str,
    oldPriceId: str,
    expectedCHF: float,
    interval: str,
    nickname: str,
    intervalCount: int = 1,
) -> str:
    """Return a Price ID on *productId* that matches the catalog amount and interval.

    The stored *oldPriceId* is kept when its unit_amount equals
    round(expectedCHF * 100) and its recurring settings match. Otherwise an active
    matching Price on the product is reused, or a new one is created, and the old
    Price is deactivated.

    Args:
        stripe: Configured Stripe client/module.
        productId: Stripe Product the Price belongs to.
        oldPriceId: Price ID currently persisted in StripePlanPrice.
        expectedCHF: Catalog unit amount in CHF.
        interval: Stripe recurring interval ("month"/"year").
        nickname: Nickname used if a new Price must be created.
        intervalCount: Stripe recurring interval_count.

    Returns:
        The Price ID to persist. It may equal *oldPriceId*.
    """
    from modules.shared.stripeClient import stripeToDict

    expectedCents = int(round(expectedCHF * 100))
    actualCents: Optional[int] = None
    matchesRecurring = False
    try:
        # One retrieve supplies both the amount and the recurring settings. With two
        # separate calls, a transient failure of one call alone would force a rotation.
        oldPrice = stripeToDict(stripe.Price.retrieve(oldPriceId)) or {}
        actualCents = oldPrice.get("unit_amount")
        matchesRecurring = _recurringMatches(oldPrice.get("recurring") or {}, interval, intervalCount)
    except Exception as e:
        logger.warning("Could not retrieve Stripe Price %s for reconciliation: %s", oldPriceId, e)
    if actualCents == expectedCents and matchesRecurring:
        return oldPriceId
    logger.warning(
        "Price drift or recurring mismatch for %s: Stripe amount=%s Rappen (expected %s). Rotating price.",
        oldPriceId, actualCents, expectedCents,
    )
    newPriceId = _findExistingStripePrice(stripe, productId, expectedCents, interval, intervalCount)
    if not newPriceId:
        newPriceId = _createStripePrice(
            stripe, productId, expectedCHF, interval, nickname, intervalCount,
        )
    if newPriceId == oldPriceId:
        # The lookup returned the stored Price itself, e.g. because the retrieve above
        # failed transiently although the Price is correct. Deactivating it would leave
        # the persisted mapping pointing at an archived Price.
        return newPriceId
    try:
        stripe.Price.modify(oldPriceId, active=False)
        logger.info("Deactivated old Stripe Price %s", oldPriceId)
    except Exception as e:
        logger.warning("Could not deactivate old price %s: %s", oldPriceId, e)
    return newPriceId
def _createStripePrice(
    stripe, productId: str, unitAmountCHF: float, interval: str, nickname: str, intervalCount: int = 1,
) -> str:
    """Create a recurring CHF Price on *productId* and return its ID.

    *unitAmountCHF* is converted to Rappen by rounding ``unitAmountCHF * 100``
    to the nearest integer.
    """
    amountRappen = int(round(unitAmountCHF * 100))
    recurringSpec = {"interval": interval, "interval_count": intervalCount}
    created = stripe.Price.create(
        product=productId,
        unit_amount=amountRappen,
        currency="chf",
        recurring=recurringSpec,
        nickname=nickname,
    )
    logger.info("Created Stripe Price %s (%s, %s CHF/%s)", created.id, nickname, unitAmountCHF, interval)
    return created.id
def _archiveOtherRecurringPrices(
    stripe, productId: Optional[str], keepPriceId: Optional[str], interval: str, intervalCount: int = 1,
) -> None:
    """Deactivate every active Price on *productId* with the same recurring pattern, except *keepPriceId*.

    Does nothing when either ID is missing. Only the first 100 active Prices returned
    by Stripe are considered. Errors are logged as warnings and never raised.
    """
    if not (productId and keepPriceId):
        return
    try:
        # Filtered lazily, so each candidate is checked right before it is archived.
        stalePrices = (
            candidate
            for candidate in stripe.Price.list(product=productId, active=True, limit=100).data
            if candidate.id != keepPriceId
            and _recurringMatches(candidate.get("recurring") or {}, interval, intervalCount)
        )
        for stale in stalePrices:
            try:
                stripe.Price.modify(stale.id, active=False)
                logger.info("Archived stale Stripe Price %s on product %s", stale.id, productId)
            except Exception as ex:
                logger.warning("Could not archive price %s: %s", stale.id, ex)
    except Exception as e:
        logger.warning("Stale price archive pass failed for product %s: %s", productId, e)
def _validateStripeIdsExist(stripe, mapping: StripePlanPrice) -> bool:
    """Return True if every stored product ID of *mapping* can be retrieved from Stripe.

    A False result typically means the IDs belong to a different Stripe account, for
    example after a DB copy. Empty product IDs are not checked.

    NOTE(review): any exception makes this return False, not only "resource_missing";
    the code only decides whether a debug line is logged. Confirm that transient
    errors (network, rate limits) are really meant to trigger re-provisioning.
    """
    try:
        for productId in (mapping.stripeProductIdUsers, mapping.stripeProductIdInstances):
            if productId:
                stripe.Product.retrieve(productId)
    except Exception as e:
        if getattr(e, "code", None) != "resource_missing":
            logger.debug("Stripe validation check failed (non-critical): %s", e)
        return False
    return True
def _reconcileMappedPlan(
    stripe,
    db: DatabaseConnector,
    planKey: str,
    plan: SubscriptionPlan,
    mapping: StripePlanPrice,
    interval: str,
    intervalCount: int,
) -> bool:
    """Reconcile a plan's persisted mapping against Stripe and the catalog.

    Returns True when the mapping was complete, its products exist in Stripe, and both
    prices were reconciled (and persisted if they changed). Returns False when the
    caller must (re-)provision the plan from scratch.
    """
    hasAllPrices = mapping.stripePriceIdUsers and mapping.stripePriceIdInstances
    hasAllProducts = mapping.stripeProductIdUsers and mapping.stripeProductIdInstances
    if not (hasAllPrices and hasAllProducts):
        return False
    if not _validateStripeIdsExist(stripe, mapping):
        logger.warning(
            "Stored Stripe IDs for plan %s reference unknown objects "
            "(likely wrong Stripe account or copied DB) — re-provisioning.",
            planKey,
        )
        return False
    reconciledUsers = _reconcilePrice(
        stripe, mapping.stripeProductIdUsers, mapping.stripePriceIdUsers,
        plan.pricePerUserCHF, interval, f"{planKey} — Benutzer-Lizenz",
        intervalCount,
    )
    reconciledInstances = _reconcilePrice(
        stripe, mapping.stripeProductIdInstances, mapping.stripePriceIdInstances,
        plan.pricePerFeatureInstanceCHF, interval, f"{planKey} — Feature-Instanz",
        intervalCount,
    )
    # Leave only the catalog-matching Price active on each Product.
    _archiveOtherRecurringPrices(
        stripe, mapping.stripeProductIdUsers, reconciledUsers, interval, intervalCount,
    )
    _archiveOtherRecurringPrices(
        stripe, mapping.stripeProductIdInstances, reconciledInstances, interval, intervalCount,
    )
    changed = (
        reconciledUsers != mapping.stripePriceIdUsers
        or reconciledInstances != mapping.stripePriceIdInstances
    )
    if changed:
        db.recordModify(StripePlanPrice, mapping.id, {
            "stripePriceIdUsers": reconciledUsers,
            "stripePriceIdInstances": reconciledInstances,
        })
        logger.info(
            "Reconciled Stripe prices for plan %s to catalog (CHF): users=%s, instances=%s",
            planKey, reconciledUsers, reconciledInstances,
        )
    else:
        logger.debug("Stripe prices up-to-date for plan %s", planKey)
    return True


def _provisionLine(
    stripe,
    planKey: str,
    plan: SubscriptionPlan,
    lineType: str,
    productName: str,
    nickname: str,
    unitCHF: float,
    interval: str,
    intervalCount: int,
) -> Tuple[Optional[str], Optional[str]]:
    """Find or create the Product and catalog-matching Price for one invoice line.

    Returns (productId, priceId), or (None, None) when the line is free (unitCHF <= 0).
    Other active Prices with the same interval on the Product are archived.
    """
    if unitCHF <= 0:
        return None, None
    productId = _findStripeProduct(stripe, planKey, lineType)
    if not productId:
        productId = _createStripeProduct(
            stripe, productName, f"{productName} für {plan.title.get('de', planKey)}",
            planKey, lineType,
        )
    cents = int(round(unitCHF * 100))
    priceId = _findExistingStripePrice(stripe, productId, cents, interval, intervalCount)
    if not priceId:
        priceId = _createStripePrice(
            stripe, productId, unitCHF, interval, nickname, intervalCount,
        )
    _archiveOtherRecurringPrices(stripe, productId, priceId, interval, intervalCount)
    return productId, priceId


def bootstrapStripePrices() -> None:
    """Ensure all paid plans have separate Stripe Products for users and instances.

    For each plan in BUILTIN_PLANS with a Stripe-mappable billing period and a
    non-zero price:
      * If a complete, valid StripePlanPrice mapping exists, its Prices are
        reconciled to the catalog amounts.
      * Otherwise the "users" and "instances" Products/Prices are found or created
        and the mapping row is created or overwritten.

    Returns early, after logging an error, when getStripeClient raises ValueError
    (Stripe not configured). Other Stripe or DB exceptions are not caught here.
    """
    try:
        from modules.shared.stripeClient import getStripeClient
        stripe = getStripeClient()
    except ValueError as e:
        logger.error("Stripe not configured — cannot bootstrap subscription prices: %s", e)
        return
    db = _getBillingDb()
    existing = _loadExistingMappings(db)
    for planKey, plan in BUILTIN_PLANS.items():
        if plan.billingPeriod == BillingPeriodEnum.NONE:
            continue
        if plan.pricePerUserCHF == 0 and plan.pricePerFeatureInstanceCHF == 0:
            continue
        stripePeriod = _PERIOD_TO_STRIPE.get(plan.billingPeriod)
        if not stripePeriod:
            continue
        interval = stripePeriod["interval"]
        intervalCount = int(stripePeriod.get("interval_count") or 1)

        mapping = existing.get(planKey)
        if mapping is not None and _reconcileMappedPlan(
            stripe, db, planKey, plan, mapping, interval, intervalCount,
        ):
            continue

        productIdUsers, priceIdUsers = _provisionLine(
            stripe, planKey, plan, "users", "Benutzer-Lizenzen",
            f"{planKey} — Benutzer-Lizenz", plan.pricePerUserCHF, interval, intervalCount,
        )
        productIdInstances, priceIdInstances = _provisionLine(
            stripe, planKey, plan, "instances", "Feature-Instanzen",
            f"{planKey} — Feature-Instanz", plan.pricePerFeatureInstanceCHF, interval, intervalCount,
        )
        persistData = {
            # The single-product field is always written empty; presumably legacy — confirm.
            "stripeProductId": "",
            "stripeProductIdUsers": productIdUsers,
            "stripeProductIdInstances": productIdInstances,
            "stripePriceIdUsers": priceIdUsers,
            "stripePriceIdInstances": priceIdInstances,
        }
        if mapping is not None:
            db.recordModify(StripePlanPrice, mapping.id, persistData)
        else:
            db.recordCreate(StripePlanPrice, StripePlanPrice(planKey=planKey, **persistData).model_dump())
        logger.info(
            "Stripe bootstrapped for %s: users=%s/%s, instances=%s/%s",
            planKey, productIdUsers, priceIdUsers, productIdInstances, priceIdInstances,
        )
def getStripePricesForPlan(planKey: str) -> Optional[StripePlanPrice]:
    """Return the persisted Stripe IDs for *planKey*, or None if absent or on error.

    Only the first matching row is used. Keys starting with "_" are stripped before
    building the model. Errors are logged, not raised.
    """
    try:
        matches = _getBillingDb().getRecordset(StripePlanPrice, recordFilter={"planKey": planKey})
        if not matches:
            return None
        firstRow = matches[0]
        return StripePlanPrice(**{name: value for name, value in firstRow.items() if not name.startswith("_")})
    except Exception as e:
        logger.error("Error loading Stripe prices for plan %s: %s", planKey, e)
    return None