# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Auto-provision Stripe Products and Prices from the built-in plan catalog.

Creates separate Stripe Products for user licenses and modules so that
invoice line items show clear, descriptive names:
  - "Benutzer-Lizenzen"
  - "Module"

Idempotent — safe to call on every startup.

Source of truth for unit amounts is BUILTIN_PLANS (CHF). On each run, persisted
Stripe Price IDs are reconciled: if Stripe's unit_amount differs from the
catalog, a new Price is created, the old one is archived, and the
poweron_billing StripePlanPrice record is updated. Other stale active Prices on
the same Product (same recurring interval) are archived so only the
catalog-matching Price stays active.
"""

import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Optional

from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG
from modules.datamodels.datamodelSubscription import (
    BUILTIN_PLANS,
    SubscriptionPlan,
    BillingPeriodEnum,
    StripePlanPrice,
)

logger = logging.getLogger(__name__)

_BILLING_DATABASE = "poweron_billing"
_METADATA_KEY = "poweron_plan_key"
_METADATA_LINE_TYPE = "poweron_line_type"
# Currency of every Price this module creates or accepts (Stripe uses
# lowercase ISO codes). The catalog amounts in BUILTIN_PLANS are CHF.
_CURRENCY = "chf"

_PERIOD_TO_STRIPE = {
    BillingPeriodEnum.MONTHLY: {"interval": "month", "interval_count": 1},
    BillingPeriodEnum.YEARLY: {"interval": "year", "interval_count": 1},
}


def _getBillingDb() -> DatabaseConnector:
    """Build a DatabaseConnector for the billing database from APP_CONFIG settings."""
    return DatabaseConnector(
        dbDatabase=_BILLING_DATABASE,
        dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
        dbPort=int(APP_CONFIG.get("DB_PORT", "5432")),
        dbUser=APP_CONFIG.get("DB_USER"),
        dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"),
    )


def _loadExistingMappings(db: DatabaseConnector) -> Dict[str, StripePlanPrice]:
    """Return all persisted StripePlanPrice rows keyed by ``planKey``.

    Columns whose name starts with ``_`` are dropped before building the model
    (presumably connector-internal metadata — TODO confirm). Rows without a
    ``planKey`` are skipped. Any load error is logged and an empty dict is
    returned, so every plan is then treated as not yet provisioned.
    """
    try:
        rows = db.getRecordset(StripePlanPrice)
        result = {}
        for row in rows:
            pk = row.get("planKey")
            if pk:
                result[pk] = StripePlanPrice(**{k: v for k, v in row.items() if not k.startswith("_")})
        return result
    except Exception as e:
        logger.warning("Could not load StripePlanPrice records: %s", e)
        return {}


def _findStripeProduct(stripe, planKey: str, lineType: str) -> Optional[str]:
    """Search Stripe for a product tagged with plan key + line type.

    Tries a combined metadata query first. If that query raises, falls back to
    searching by plan key only and filtering on the line type locally.

    Returns:
        The product ID, or None when nothing matches or the searches fail.
        NOTE(review): Stripe Search is documented as eventually consistent, so a
        product created moments ago may not be found yet.
    """
    try:
        products = stripe.Product.search(
            query=f'metadata["{_METADATA_KEY}"]:"{planKey}" AND metadata["{_METADATA_LINE_TYPE}"]:"{lineType}"',
            limit=1,
        )
        if products.data:
            return products.data[0].id
    except Exception as e:
        logger.debug(
            "Combined Stripe product search failed for %s/%s, falling back: %s",
            planKey, lineType, e,
        )
        try:
            products = stripe.Product.search(
                query=f'metadata["{_METADATA_KEY}"]:"{planKey}"',
                limit=10,
            )
            for p in products.data:
                meta = p.get("metadata") or {}
                if meta.get(_METADATA_LINE_TYPE) == lineType:
                    return p.id
        except Exception as fallbackError:
            # Best effort: the caller creates a new product when None is returned.
            logger.warning(
                "Stripe product search failed for %s/%s: %s",
                planKey, lineType, fallbackError,
            )
    return None


def _createStripeProduct(stripe, name: str, description: str, planKey: str, lineType: str) -> str:
    """Create a Stripe Product tagged with plan key + line type metadata; return its ID."""
    product = stripe.Product.create(
        name=name,
        description=description,
        metadata={_METADATA_KEY: planKey, _METADATA_LINE_TYPE: lineType},
    )
    logger.info("Created Stripe Product %s: %s (%s/%s)", product.id, name, planKey, lineType)
    return product.id


def _recurringMatches(recurring: Dict, interval: str, intervalCount: int) -> bool:
    """Return True if a Stripe ``recurring`` dict has the given interval and count.

    An empty/None ``recurring`` (one-time price) never matches. A missing
    ``interval_count`` is treated as 1.
    """
    if not recurring:
        return False
    if recurring.get("interval") != interval:
        return False
    ic = recurring.get("interval_count")
    if ic is None:
        ic = 1
    return int(ic) == int(intervalCount)


def _findExistingStripePrice(
    stripe,
    productId: str,
    unitAmount: int,
    interval: str,
    intervalCount: int = 1,
) -> Optional[str]:
    """Return an active CHF price on ``productId`` matching amount and interval.

    Args:
        unitAmount: Expected amount in the smallest currency unit (callers pass
            ``round(CHF * 100)``).

    Only the first page (up to 50 active prices) is inspected. Listing errors
    are logged and treated as "no match".
    """
    try:
        prices = stripe.Price.list(product=productId, active=True, limit=50)
        for p in prices.data:
            recurring = p.get("recurring") or {}
            if (
                p.get("unit_amount") == unitAmount
                # A same-amount price in another currency is not a catalog match.
                and (p.get("currency") or "").lower() == _CURRENCY
                and _recurringMatches(recurring, interval, intervalCount)
            ):
                return p.id
    except Exception as e:
        logger.warning("Could not list Stripe prices for product %s: %s", productId, e)
    return None


def _reconcilePrice(
    stripe,
    productId: str,
    oldPriceId: str,
    expectedCHF: float,
    interval: str,
    nickname: str,
    intervalCount: int = 1,
) -> str:
    """Return an active Stripe Price ID on ``productId`` that matches the catalog.

    The stored price ``oldPriceId`` is kept when it is active and has the
    expected CHF amount and recurring interval. Otherwise a matching active
    price on the product is reused, or a new one is created, and the old price
    is deactivated.

    The old price is never deactivated when the lookup resolves to the old
    price itself. That happens when the initial retrieve failed transiently;
    deactivating it then would leave the plan without an active price.
    """
    from modules.shared.stripeClient import stripeToDict

    expectedCents = int(round(expectedCHF * 100))
    actualCents: Optional[int] = None
    actualCurrency: Optional[str] = None
    matchesRecurring = False
    isActive = False
    retrieveFailed = False
    try:
        raw = stripe.Price.retrieve(oldPriceId)
        pd = stripeToDict(raw)
        actualCents = pd.get("unit_amount")
        actualCurrency = pd.get("currency")
        matchesRecurring = _recurringMatches(pd.get("recurring") or {}, interval, intervalCount)
        # Stripe.Price.retrieve returns archived prices too, so we MUST check
        # `active` explicitly. Subscription.create rejects inactive prices with
        # "The price specified is inactive. This field only accepts active prices."
        isActive = bool(pd.get("active"))
    except Exception as ex:
        retrieveFailed = True
        logger.warning("Could not retrieve Stripe Price %s: %s", oldPriceId, ex)

    matchesCurrency = (actualCurrency or "").lower() == _CURRENCY
    if (
        not retrieveFailed
        and isActive
        and actualCents == expectedCents
        and matchesCurrency
        and matchesRecurring
    ):
        return oldPriceId

    logger.warning(
        "Rotating Stripe Price %s on product %s: active=%s amount=%s (expected %s) "
        "currency=%s recurringMatches=%s retrieveFailed=%s.",
        oldPriceId, productId, isActive, actualCents, expectedCents,
        actualCurrency, matchesRecurring, retrieveFailed,
    )

    newPriceId = _findExistingStripePrice(stripe, productId, expectedCents, interval, intervalCount)
    if not newPriceId:
        newPriceId = _createStripePrice(
            stripe, productId, expectedCHF, interval, nickname, intervalCount,
        )

    if newPriceId == oldPriceId:
        # The stored price is in fact the active catalog match (only the
        # retrieve above failed) — keep it active.
        logger.info("Stripe Price %s still matches the catalog; keeping it active", oldPriceId)
        return oldPriceId

    try:
        stripe.Price.modify(oldPriceId, active=False)
        logger.info("Deactivated old Stripe Price %s", oldPriceId)
    except Exception as e:
        logger.warning("Could not deactivate old price %s: %s", oldPriceId, e)
    return newPriceId


def _createStripePrice(
    stripe,
    productId: str,
    unitAmountCHF: float,
    interval: str,
    nickname: str,
    intervalCount: int = 1,
) -> str:
    """Create a recurring CHF Price on ``productId`` and return its ID.

    ``unitAmountCHF`` is converted to the smallest currency unit by rounding
    ``unitAmountCHF * 100``.
    """
    price = stripe.Price.create(
        product=productId,
        unit_amount=int(round(unitAmountCHF * 100)),
        currency=_CURRENCY,
        recurring={"interval": interval, "interval_count": intervalCount},
        nickname=nickname,
    )
    logger.info("Created Stripe Price %s (%s, %s CHF/%s)", price.id, nickname, unitAmountCHF, interval)
    return price.id


def _archiveOtherRecurringPrices(
    stripe,
    productId: Optional[str],
    keepPriceId: Optional[str],
    interval: str,
    intervalCount: int = 1,
) -> None:
    """Archive every other active recurring price on the product (same interval pattern).

    No-op when either ID is missing. One-time prices and prices with a different
    interval are left alone. Only the first page (up to 100 active prices) is
    scanned. Failures are logged, never raised.
    """
    if not productId or not keepPriceId:
        return
    try:
        prices = stripe.Price.list(product=productId, active=True, limit=100)
        for p in prices.data:
            if p.id == keepPriceId:
                continue
            recurring = p.get("recurring") or {}
            if not recurring:
                continue
            if not _recurringMatches(recurring, interval, intervalCount):
                continue
            try:
                stripe.Price.modify(p.id, active=False)
                logger.info("Archived stale Stripe Price %s on product %s", p.id, productId)
            except Exception as ex:
                logger.warning("Could not archive price %s: %s", p.id, ex)
    except Exception as e:
        logger.warning("Stale price archive pass failed for product %s: %s", productId, e)


def _validateStripeIdsExist(stripe, mapping: StripePlanPrice) -> bool:
    """Quick check whether the stored Stripe product IDs still exist.

    Returns False only when Stripe reports a product as missing
    (``resource_missing``). This happens when running against a different
    Stripe account or after a DB copy from another environment.

    Other errors (network, rate limiting, …) are logged and treated as "exists".
    Re-provisioning on a transient error would rewrite the mapping and, if the
    product search also failed, create duplicate products. The reconcile path
    re-checks each price on its own.

    Price-level validation (active flag, drift) is handled by
    ``_reconcilePrice``. We don't fail here on archived prices, otherwise we'd
    needlessly re-provision products on every rotation.
    """
    try:
        if mapping.stripeProductIdUsers:
            stripe.Product.retrieve(mapping.stripeProductIdUsers)
        if mapping.stripeProductIdInstances:
            stripe.Product.retrieve(mapping.stripeProductIdInstances)
    except Exception as e:
        if getattr(e, "code", None) == "resource_missing":
            return False
        logger.warning("Stripe validation check failed (non-critical): %s", e)
    return True


def _processOnePlan(
    stripe,
    planKey: str,
    plan: SubscriptionPlan,
    existingMapping: Optional[StripePlanPrice],
) -> None:
    """Reconcile or provision Stripe Products/Prices for a single plan.

    Each call uses its own DB connection so it is safe to run in a thread pool.

    If a complete, valid mapping exists, its prices are reconciled against the
    catalog and the DB record is updated only when a price ID changed.
    Otherwise products and prices are looked up (or created) for every
    component with a positive catalog price. The resulting IDs are then written
    to a new or the existing StripePlanPrice record.
    """
    stripePeriod = _PERIOD_TO_STRIPE.get(plan.billingPeriod)
    if not stripePeriod:
        return
    interval = stripePeriod["interval"]
    intervalCount = int(stripePeriod.get("interval_count") or 1)

    db = _getBillingDb()

    if existingMapping:
        mapping = existingMapping
        hasAllPrices = mapping.stripePriceIdUsers and mapping.stripePriceIdInstances
        hasAllProducts = mapping.stripeProductIdUsers and mapping.stripeProductIdInstances

        # Partial mappings (e.g. a component priced at 0) fall through to the
        # provisioning path below.
        if hasAllPrices and hasAllProducts:
            if _validateStripeIdsExist(stripe, mapping):
                changed = False
                reconciledUsers = _reconcilePrice(
                    stripe,
                    mapping.stripeProductIdUsers,
                    mapping.stripePriceIdUsers,
                    plan.pricePerUserCHF,
                    interval,
                    f"{planKey} — Benutzer-Lizenz",
                    intervalCount,
                )
                if reconciledUsers != mapping.stripePriceIdUsers:
                    changed = True

                reconciledInstances = _reconcilePrice(
                    stripe,
                    mapping.stripeProductIdInstances,
                    mapping.stripePriceIdInstances,
                    plan.pricePerFeatureInstanceCHF,
                    interval,
                    f"{planKey} — Modul",
                    intervalCount,
                )
                if reconciledInstances != mapping.stripePriceIdInstances:
                    changed = True

                _archiveOtherRecurringPrices(
                    stripe, mapping.stripeProductIdUsers, reconciledUsers, interval, intervalCount,
                )
                _archiveOtherRecurringPrices(
                    stripe, mapping.stripeProductIdInstances, reconciledInstances, interval, intervalCount,
                )

                if changed:
                    db.recordModify(StripePlanPrice, mapping.id, {
                        "stripePriceIdUsers": reconciledUsers,
                        "stripePriceIdInstances": reconciledInstances,
                    })
                    logger.info(
                        "Reconciled Stripe prices for plan %s to catalog (CHF): users=%s, instances=%s",
                        planKey, reconciledUsers, reconciledInstances,
                    )
                else:
                    logger.debug("Stripe prices up-to-date for plan %s", planKey)
                return
            else:
                logger.warning(
                    "Stored Stripe IDs for plan %s reference unknown objects "
                    "(likely wrong Stripe account or copied DB) — re-provisioning.",
                    planKey,
                )

    productIdUsers = None
    productIdInstances = None
    priceIdUsers = None
    priceIdInstances = None

    if plan.pricePerUserCHF > 0:
        productIdUsers = _findStripeProduct(stripe, planKey, "users")
        if not productIdUsers:
            productIdUsers = _createStripeProduct(
                stripe,
                "Benutzer-Lizenzen",
                f"Benutzer-Lizenzen für {plan.title or planKey}",
                planKey,
                "users",
            )
        userCents = int(round(plan.pricePerUserCHF * 100))
        priceIdUsers = _findExistingStripePrice(
            stripe, productIdUsers, userCents, interval, intervalCount,
        )
        if not priceIdUsers:
            priceIdUsers = _createStripePrice(
                stripe,
                productIdUsers,
                plan.pricePerUserCHF,
                interval,
                f"{planKey} — Benutzer-Lizenz",
                intervalCount,
            )
        _archiveOtherRecurringPrices(stripe, productIdUsers, priceIdUsers, interval, intervalCount)

    if plan.pricePerFeatureInstanceCHF > 0:
        productIdInstances = _findStripeProduct(stripe, planKey, "instances")
        if not productIdInstances:
            productIdInstances = _createStripeProduct(
                stripe,
                "Module",
                f"Module für {plan.title or planKey}",
                planKey,
                "instances",
            )
        instCents = int(round(plan.pricePerFeatureInstanceCHF * 100))
        priceIdInstances = _findExistingStripePrice(
            stripe, productIdInstances, instCents, interval, intervalCount,
        )
        if not priceIdInstances:
            priceIdInstances = _createStripePrice(
                stripe,
                productIdInstances,
                plan.pricePerFeatureInstanceCHF,
                interval,
                f"{planKey} — Modul",
                intervalCount,
            )
        _archiveOtherRecurringPrices(
            stripe, productIdInstances, priceIdInstances, interval, intervalCount,
        )

    persistData = {
        # Always written as "" here; presumably a legacy single-product field — TODO confirm.
        "stripeProductId": "",
        "stripeProductIdUsers": productIdUsers,
        "stripeProductIdInstances": productIdInstances,
        "stripePriceIdUsers": priceIdUsers,
        "stripePriceIdInstances": priceIdInstances,
    }

    if existingMapping:
        db.recordModify(StripePlanPrice, existingMapping.id, persistData)
    else:
        db.recordCreate(StripePlanPrice, StripePlanPrice(planKey=planKey, **persistData).model_dump())

    logger.info(
        "Stripe bootstrapped for %s: users=%s/%s, instances=%s/%s",
        planKey, productIdUsers, priceIdUsers, productIdInstances, priceIdInstances,
    )


def bootstrapStripePrices() -> None:
    """Ensure all paid plans have separate Stripe Products for users and instances.

    Plans are processed in parallel (one thread per plan) to reduce boot time.
    Each thread uses its own DB connection; Stripe SDK is thread-safe.

    Returns early (after logging) when Stripe is not configured. Per-plan
    failures are logged with a traceback and do not stop the other plans.
    """
    try:
        from modules.shared.stripeClient import getStripeClient
        stripe = getStripeClient()
    except ValueError as e:
        logger.error("Stripe not configured — cannot bootstrap subscription prices: %s", e)
        return

    existing = _loadExistingMappings(_getBillingDb())

    plans = [
        (planKey, plan)
        for planKey, plan in BUILTIN_PLANS.items()
        if plan.billingPeriod != BillingPeriodEnum.NONE
        and (plan.pricePerUserCHF > 0 or plan.pricePerFeatureInstanceCHF > 0)
    ]
    if not plans:
        return

    with ThreadPoolExecutor(max_workers=len(plans)) as executor:
        futures = {
            executor.submit(_processOnePlan, stripe, planKey, plan, existing.get(planKey)): planKey
            for planKey, plan in plans
        }
        for future in as_completed(futures):
            planKey = futures[future]
            try:
                future.result()
            except Exception as e:
                # logger.exception keeps the traceback of the worker failure.
                logger.exception("Stripe bootstrap failed for plan %s: %s", planKey, e)


def getStripePricesForPlan(planKey: str) -> Optional[StripePlanPrice]:
    """Load the persisted Stripe IDs for a plan.

    Returns the first matching StripePlanPrice record, or None when there is no
    record or loading fails (the error is logged).
    """
    try:
        db = _getBillingDb()
        rows = db.getRecordset(StripePlanPrice, recordFilter={"planKey": planKey})
        if rows:
            return StripePlanPrice(**{k: v for k, v in rows[0].items() if not k.startswith("_")})
    except Exception as e:
        logger.error("Error loading Stripe prices for plan %s: %s", planKey, e)
    return None