214 lines
7.8 KiB
Python
214 lines
7.8 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""
|
|
Auto-provision Stripe Products and Prices from the built-in plan catalog.
|
|
|
|
Creates separate Stripe Products for user licenses and feature instances
|
|
so that invoice line items show clear, descriptive names:
|
|
- "Benutzer-Lizenzen"
|
|
- "Feature-Instanzen"
|
|
|
|
Idempotent — safe to call on every startup.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Dict, Optional
|
|
|
|
from modules.connectors.connectorDbPostgre import DatabaseConnector
|
|
from modules.shared.configuration import APP_CONFIG
|
|
from modules.datamodels.datamodelSubscription import (
|
|
BUILTIN_PLANS,
|
|
SubscriptionPlan,
|
|
BillingPeriodEnum,
|
|
StripePlanPrice,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)

# Billing database name (holds the StripePlanPrice mapping table).
_BILLING_DATABASE = "poweron_billing"

# Metadata keys stamped onto Stripe Products created by this module, so the
# same Products can be looked up again on later runs.
_METADATA_KEY = "poweron_plan_key"
_METADATA_LINE_TYPE = "poweron_line_type"

# Internal billing period -> Stripe ``recurring`` parameters.
# BillingPeriodEnum.NONE has no entry; plans without a mapped period are skipped.
_PERIOD_TO_STRIPE = {
    period: {"interval": stripeInterval, "interval_count": 1}
    for period, stripeInterval in (
        (BillingPeriodEnum.MONTHLY, "month"),
        (BillingPeriodEnum.YEARLY, "year"),
    )
}
|
|
|
|
|
|
def _getBillingDb() -> DatabaseConnector:
    """Build a DatabaseConnector for the billing database from APP_CONFIG.

    A new connector object is constructed on every call. Host and port fall
    back to ``localhost`` / ``5432`` when unset; user and password are passed
    through unchanged (possibly ``None``).
    """
    host = APP_CONFIG.get("DB_HOST", "localhost")
    port = int(APP_CONFIG.get("DB_PORT", "5432"))
    user = APP_CONFIG.get("DB_USER")
    password = APP_CONFIG.get("DB_PASSWORD_SECRET")
    return DatabaseConnector(
        dbDatabase=_BILLING_DATABASE,
        dbHost=host,
        dbPort=port,
        dbUser=user,
        dbPassword=password,
    )
|
|
|
|
|
|
def _loadExistingMappings(db: DatabaseConnector) -> Dict[str, StripePlanPrice]:
    """Load all persisted StripePlanPrice rows, keyed by ``planKey``.

    Columns whose names start with ``_`` are dropped before the row is passed
    to the ``StripePlanPrice`` constructor (presumably connector-internal
    columns — TODO confirm). Rows without a ``planKey`` are ignored.

    A single row that cannot be read is logged and skipped rather than
    discarding every mapping: a plan key missing from the result is treated
    by ``bootstrapStripePrices`` as "not yet persisted" and would get a
    duplicate record created for it.

    Returns:
        Mapping of plan key to its StripePlanPrice; an empty dict if the
        recordset itself cannot be loaded (logged as a warning).
    """
    result = {}
    try:
        rows = db.getRecordset(StripePlanPrice)
        for row in rows:
            try:
                pk = row.get("planKey")
                if pk:
                    result[pk] = StripePlanPrice(**{k: v for k, v in row.items() if not k.startswith("_")})
            except Exception as e:
                # Isolate per-row failures so one bad row cannot hide the rest.
                logger.warning("Skipping unreadable StripePlanPrice row: %s", e)
    except Exception as e:
        logger.warning("Could not load StripePlanPrice records: %s", e)
        return {}
    return result
|
|
|
|
|
|
def _findStripeProduct(stripe, planKey: str, lineType: str) -> Optional[str]:
    """Search Stripe for a Product tagged with ``planKey`` and ``lineType``.

    First runs one Stripe search query that filters on both metadata keys.
    If that call raises, falls back to searching on the plan key only and
    filtering the first 10 hits on the line-type metadata client-side.

    Returns:
        The Product id, or ``None`` when nothing matches or the searches
        fail. Failures are logged, because a ``None`` here makes the caller
        create a new Product.

    NOTE(review): Stripe documents its Search API as eventually consistent,
    so a Product created moments earlier may not be found yet — confirm this
    cannot produce duplicates when startups happen in quick succession.
    """
    query = f'metadata["{_METADATA_KEY}"]:"{planKey}" AND metadata["{_METADATA_LINE_TYPE}"]:"{lineType}"'
    try:
        products = stripe.Product.search(query=query, limit=1)
        return products.data[0].id if products.data else None
    except Exception as e:
        # The combined query failed; retry with the simpler plan-key-only query.
        logger.debug(
            "Combined Stripe product search failed for %s/%s, falling back: %s", planKey, lineType, e,
        )

    try:
        candidates = stripe.Product.search(
            query=f'metadata["{_METADATA_KEY}"]:"{planKey}"',
            limit=10,
        )
        for p in candidates.data:
            meta = p.get("metadata") or {}
            if meta.get(_METADATA_LINE_TYPE) == lineType:
                return p.id
    except Exception as e:
        logger.warning("Stripe product search failed for %s/%s: %s", planKey, lineType, e)
    return None
|
|
|
|
|
|
def _createStripeProduct(stripe, name: str, description: str, planKey: str, lineType: str) -> str:
    """Create a Stripe Product tagged with plan-key/line-type metadata.

    The metadata tags are what ``_findStripeProduct`` searches on.

    Returns:
        The id of the newly created Product.
    """
    tags = {_METADATA_KEY: planKey, _METADATA_LINE_TYPE: lineType}
    created = stripe.Product.create(name=name, description=description, metadata=tags)
    newId = created.id
    logger.info("Created Stripe Product %s: %s (%s/%s)", newId, name, planKey, lineType)
    return newId
|
|
|
|
|
|
def _findExistingStripePrice(
    stripe,
    productId: str,
    unitAmount: int,
    interval: str,
    currency: str = "chf",
    intervalCount: int = 1,
) -> Optional[str]:
    """Find an active recurring Price on ``productId`` with the given terms.

    A Price matches only when all of the following equal the requested
    values:

    - ``unit_amount`` (smallest currency unit)
    - ``currency``
    - ``recurring.interval``
    - ``recurring.interval_count``

    The defaults (CHF, interval count 1) correspond to what this module's
    price creation uses. All active prices of the product are scanned,
    following Stripe list pagination.

    Returns:
        The Price id, or ``None`` if nothing matches or the Stripe call fails
        (failures are logged).
    """
    try:
        prices = stripe.Price.list(product=productId, active=True, limit=100)
        for p in prices.auto_paging_iter():
            recurring = p.get("recurring") or {}
            if (
                p.get("unit_amount") == unitAmount
                and p.get("currency") == currency
                and recurring.get("interval") == interval
                and recurring.get("interval_count", 1) == intervalCount
            ):
                return p.id
    except Exception as e:
        logger.warning("Could not list Stripe Prices for product %s: %s", productId, e)
    return None
|
|
|
|
|
|
def _toRappen(amountCHF: float) -> int:
    """Convert a CHF amount to whole Rappen (1/100 CHF) for Stripe's ``unit_amount``.

    Rounds to the nearest Rappen instead of truncating: ``19.99 * 100`` is
    ``1998.9999999999998`` in binary floating point, which ``int()`` alone
    would turn into 1998.
    """
    return int(round(amountCHF * 100))


def _createStripePrice(stripe, productId: str, unitAmountCHF: float, interval: str, nickname: str) -> str:
    """Create a recurring CHF Price on ``productId`` and return its id.

    Args:
        stripe: Object from ``getStripeClient()``; must expose ``Price.create``.
        productId: Stripe Product the Price is attached to.
        unitAmountCHF: Amount per unit in CHF, converted with ``_toRappen``.
        interval: Value for Stripe ``recurring.interval`` (e.g. "month", "year").
        nickname: Nickname stored on the Price.
    """
    price = stripe.Price.create(
        product=productId,
        unit_amount=_toRappen(unitAmountCHF),
        currency="chf",
        recurring={"interval": interval},
        nickname=nickname,
    )
    logger.info("Created Stripe Price %s (%s, %s CHF/%s)", price.id, nickname, unitAmountCHF, interval)
    return price.id


def _ensureProductAndPrice(
    stripe,
    planKey: str,
    plan: "SubscriptionPlan",
    lineType: str,
    productLabel: str,
    nicknameLabel: str,
    unitAmountCHF: float,
    interval: str,
) -> tuple:
    """Find or create the Product and recurring Price for one invoice line of a plan.

    The Product is looked up by its plan-key/line-type metadata. The Price is
    looked up by amount and interval. The same ``_toRappen`` conversion is used
    for lookup and creation, so a Price created here is found again next run.

    Returns:
        ``(productId, priceId)``.
    """
    productId = _findStripeProduct(stripe, planKey, lineType)
    if not productId:
        productId = _createStripeProduct(
            stripe, productLabel, f"{productLabel} für {plan.title.get('de', planKey)}",
            planKey, lineType,
        )
    priceId = _findExistingStripePrice(stripe, productId, _toRappen(unitAmountCHF), interval)
    if not priceId:
        priceId = _createStripePrice(
            stripe, productId, unitAmountCHF, interval, f"{planKey} — {nicknameLabel}",
        )
    return productId, priceId


def bootstrapStripePrices() -> None:
    """Ensure all paid plans have separate Stripe Products for users and instances.

    Applies to each plan in ``BUILTIN_PLANS`` that meets both conditions:

    - its billing period has an entry in ``_PERIOD_TO_STRIPE``;
    - it has a non-zero user or feature-instance price.

    For every line the plan charges for (price > 0), this finds or creates a
    "Benutzer-Lizenzen" / "Feature-Instanzen" Product and a recurring CHF
    Price. The resulting ids are then upserted into the ``StripePlanPrice``
    table. A plan whose persisted mapping already has Product and Price ids
    for every charged line is skipped without contacting Stripe.

    Returns early, logging an error, if ``getStripeClient`` raises
    ``ValueError``. Other Stripe/DB errors propagate to the caller.

    NOTE(review): an already-complete mapping is not compared against the
    catalog, so changing a price in ``BUILTIN_PLANS`` does not create a new
    Stripe Price — TODO confirm that is intended.
    """
    try:
        from modules.shared.stripeClient import getStripeClient
        stripe = getStripeClient()
    except ValueError as e:
        logger.error("Stripe not configured — cannot bootstrap subscription prices: %s", e)
        return

    db = _getBillingDb()
    existing = _loadExistingMappings(db)

    for planKey, plan in BUILTIN_PLANS.items():
        if plan.billingPeriod == BillingPeriodEnum.NONE:
            continue
        if plan.pricePerUserCHF == 0 and plan.pricePerFeatureInstanceCHF == 0:
            continue

        stripePeriod = _PERIOD_TO_STRIPE.get(plan.billingPeriod)
        if not stripePeriod:
            continue
        interval = stripePeriod["interval"]

        needsUsers = plan.pricePerUserCHF > 0
        needsInstances = plan.pricePerFeatureInstanceCHF > 0

        mapping = existing.get(planKey)
        if mapping is not None:
            # Only require ids for lines the plan actually charges for;
            # otherwise single-line plans would never count as configured and
            # would query Stripe and rewrite the DB row on every startup.
            usersReady = not needsUsers or bool(
                mapping.stripeProductIdUsers and mapping.stripePriceIdUsers
            )
            instancesReady = not needsInstances or bool(
                mapping.stripeProductIdInstances and mapping.stripePriceIdInstances
            )
            if usersReady and instancesReady:
                logger.debug("Stripe prices already configured for plan %s", planKey)
                continue

        productIdUsers = priceIdUsers = None
        if needsUsers:
            productIdUsers, priceIdUsers = _ensureProductAndPrice(
                stripe, planKey, plan, "users", "Benutzer-Lizenzen", "Benutzer-Lizenz",
                plan.pricePerUserCHF, interval,
            )

        productIdInstances = priceIdInstances = None
        if needsInstances:
            productIdInstances, priceIdInstances = _ensureProductAndPrice(
                stripe, planKey, plan, "instances", "Feature-Instanzen", "Feature-Instanz",
                plan.pricePerFeatureInstanceCHF, interval,
            )

        persistData = {
            # Cleared explicitly; per-line product ids are stored below instead.
            "stripeProductId": "",
            "stripeProductIdUsers": productIdUsers,
            "stripeProductIdInstances": productIdInstances,
            "stripePriceIdUsers": priceIdUsers,
            "stripePriceIdInstances": priceIdInstances,
        }

        if mapping is not None:
            db.recordModify(StripePlanPrice, mapping.id, persistData)
        else:
            db.recordCreate(StripePlanPrice, StripePlanPrice(planKey=planKey, **persistData).model_dump())

        logger.info(
            "Stripe bootstrapped for %s: users=%s/%s, instances=%s/%s",
            planKey, productIdUsers, priceIdUsers, productIdInstances, priceIdInstances,
        )
|
|
|
|
|
|
def getStripePricesForPlan(planKey: str) -> Optional[StripePlanPrice]:
    """Return the persisted Stripe ids for ``planKey``, or ``None``.

    Reads the ``StripePlanPrice`` rows filtered by ``planKey`` and builds a
    model from the first one, dropping columns whose names start with ``_``.
    Any error (DB access or model construction) is logged and yields ``None``.
    """
    try:
        matches = _getBillingDb().getRecordset(StripePlanPrice, recordFilter={"planKey": planKey})
        if not matches:
            return None
        fields = {key: value for key, value in matches[0].items() if not key.startswith("_")}
        return StripePlanPrice(**fields)
    except Exception as e:
        logger.error("Error loading Stripe prices for plan %s: %s", planKey, e)
        return None
|