# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Auto-provision Stripe Products and Prices from the built-in plan catalog.

Creates separate Stripe Products for user licenses and modules so that invoice
line items show clear, descriptive names:

- "Benutzer-Lizenzen"
- "Module"

Idempotent — safe to call on every startup.

Source of truth for unit amounts is BUILTIN_PLANS (CHF). On each run, persisted
Stripe Price IDs are reconciled: if the stored Price is inactive, or its
unit_amount, currency or recurring interval differs from the catalog, an active
catalog-matching Price is reused (or created), the old one is archived, and the
poweron_billing StripePlanPrice record is updated. Other stale active Prices on
the same Product (same recurring interval) are archived so only the
catalog-matching Price stays active.
"""

import logging
from typing import Dict, Optional, Tuple

from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG
from modules.datamodels.datamodelSubscription import (
    BUILTIN_PLANS,
    SubscriptionPlan,
    BillingPeriodEnum,
    StripePlanPrice,
)

logger = logging.getLogger(__name__)

_BILLING_DATABASE = "poweron_billing"
_METADATA_KEY = "poweron_plan_key"
_METADATA_LINE_TYPE = "poweron_line_type"
# Currency of every Price this module creates; catalog amounts are in CHF.
_CURRENCY = "chf"

_PERIOD_TO_STRIPE = {
    BillingPeriodEnum.MONTHLY: {"interval": "month", "interval_count": 1},
    BillingPeriodEnum.YEARLY: {"interval": "year", "interval_count": 1},
}


def _getBillingDb() -> DatabaseConnector:
    """Open a connector to the billing database using DB_* settings from APP_CONFIG."""
    return DatabaseConnector(
        dbDatabase=_BILLING_DATABASE,
        dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
        dbPort=int(APP_CONFIG.get("DB_PORT", "5432")),
        dbUser=APP_CONFIG.get("DB_USER"),
        dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"),
    )


def _mappingFromRow(row: Dict) -> StripePlanPrice:
    """Build a StripePlanPrice from a DB row, dropping underscore-prefixed (internal) columns."""
    return StripePlanPrice(**{k: v for k, v in row.items() if not k.startswith("_")})


def _loadExistingMappings(db: DatabaseConnector) -> Dict[str, StripePlanPrice]:
    """Load all persisted StripePlanPrice rows, keyed by ``planKey``.

    Rows without a planKey are skipped. Any error while loading is logged as a
    warning and an empty dict is returned (best-effort).
    """
    try:
        rows = db.getRecordset(StripePlanPrice)
        result = {}
        for row in rows:
            pk = row.get("planKey")
            if pk:
                result[pk] = _mappingFromRow(row)
        return result
    except Exception as e:
        logger.warning("Could not load StripePlanPrice records: %s", e)
        return {}


def _findStripeProduct(stripe, planKey: str, lineType: str) -> Optional[str]:
    """Search Stripe for a product tagged with plan key + line type.

    Tries a combined metadata search first. Only if that call raises does it
    fall back to searching by plan key alone and filtering the line type
    client-side. Returns the product ID, or None if nothing was found or both
    searches failed.
    """
    try:
        products = stripe.Product.search(
            query=f'metadata["{_METADATA_KEY}"]:"{planKey}" AND metadata["{_METADATA_LINE_TYPE}"]:"{lineType}"',
            limit=1,
        )
        if products.data:
            return products.data[0].id
    except Exception as primaryErr:
        logger.debug("Combined product search failed for %s/%s: %s", planKey, lineType, primaryErr)
        try:
            products = stripe.Product.search(
                query=f'metadata["{_METADATA_KEY}"]:"{planKey}"',
                limit=10,
            )
            for p in products.data:
                meta = p.get("metadata") or {}
                if meta.get(_METADATA_LINE_TYPE) == lineType:
                    return p.id
        except Exception as fallbackErr:
            logger.debug("Fallback product search failed for %s: %s", planKey, fallbackErr)
    return None


def _createStripeProduct(stripe, name: str, description: str, planKey: str, lineType: str) -> str:
    """Create a Stripe Product tagged with plan key + line type metadata; return its ID."""
    product = stripe.Product.create(
        name=name,
        description=description,
        metadata={_METADATA_KEY: planKey, _METADATA_LINE_TYPE: lineType},
    )
    logger.info("Created Stripe Product %s: %s (%s/%s)", product.id, name, planKey, lineType)
    return product.id


def _recurringMatches(recurring: Dict, interval: str, intervalCount: int) -> bool:
    """Return True if a Price's ``recurring`` block has the given interval and count.

    A missing ``interval_count`` is treated as 1; an empty/absent block never matches.
    """
    if not recurring:
        return False
    if recurring.get("interval") != interval:
        return False
    ic = recurring.get("interval_count")
    if ic is None:
        ic = 1
    return int(ic) == int(intervalCount)


def _currencyMatches(currency: Optional[str], expected: str = _CURRENCY) -> bool:
    """Case-insensitive currency comparison; a missing currency never matches."""
    return (currency or "").lower() == expected.lower()


def _findExistingStripePrice(
    stripe,
    productId: str,
    unitAmount: int,
    interval: str,
    intervalCount: int = 1,
    currency: str = _CURRENCY,
) -> Optional[str]:
    """Return the ID of an active Price on the product matching amount, currency and interval.

    Args:
        unitAmount: Expected amount in the currency's minor unit (cents).
        currency: Expected currency; defaults to the catalog currency (CHF).
            Without this check a non-CHF price with the same numeric amount
            would be reused as the catalog price.

    Only the first page (up to 50) of active prices is inspected. Errors are
    logged at debug level and yield None.
    """
    try:
        prices = stripe.Price.list(product=productId, active=True, limit=50)
        for p in prices.data:
            recurring = p.get("recurring") or {}
            if (
                p.get("unit_amount") == unitAmount
                and _currencyMatches(p.get("currency"), currency)
                and _recurringMatches(recurring, interval, intervalCount)
            ):
                return p.id
    except Exception as e:
        logger.debug("Could not list Stripe Prices for product %s: %s", productId, e)
    return None


def _reconcilePrice(
    stripe,
    productId: str,
    oldPriceId: str,
    expectedCHF: float,
    interval: str,
    nickname: str,
    intervalCount: int = 1,
) -> str:
    """Ensure the stored Price still matches the catalog; rotate it if not.

    The stored Price is kept when it is active and matches the expected amount,
    currency and recurring interval. Otherwise an active matching Price on the
    same product is reused (or a new one is created) and the old Price is
    deactivated.

    Returns:
        The Price ID to persist (``oldPriceId`` when nothing had to change).
    """
    from modules.shared.stripeClient import stripeToDict

    expectedCents = int(round(expectedCHF * 100))
    actualCents: Optional[int] = None
    actualCurrency: Optional[str] = None
    matchesRecurring = False
    isActive = False
    retrieveFailed = False
    try:
        raw = stripe.Price.retrieve(oldPriceId)
        pd = stripeToDict(raw)
        actualCents = pd.get("unit_amount")
        actualCurrency = pd.get("currency")
        matchesRecurring = _recurringMatches(pd.get("recurring") or {}, interval, intervalCount)
        # Stripe.Price.retrieve returns archived prices too, so we MUST check
        # `active` explicitly. Subscription.create rejects inactive prices with
        # "The price specified is inactive. This field only accepts active prices."
        isActive = bool(pd.get("active"))
    except Exception as ex:
        retrieveFailed = True
        logger.warning("Could not retrieve Stripe Price %s: %s", oldPriceId, ex)

    matchesCurrency = _currencyMatches(actualCurrency)
    if (
        not retrieveFailed
        and isActive
        and actualCents == expectedCents
        and matchesCurrency
        and matchesRecurring
    ):
        return oldPriceId

    logger.warning(
        "Rotating Stripe Price %s on product %s: active=%s amount=%s (expected %s) "
        "currency=%s (expected %s) recurringMatches=%s retrieveFailed=%s.",
        oldPriceId, productId, isActive, actualCents, expectedCents,
        actualCurrency, _CURRENCY, matchesRecurring, retrieveFailed,
    )

    existingMatch = _findExistingStripePrice(stripe, productId, expectedCents, interval, intervalCount)
    if existingMatch:
        newPriceId = existingMatch
    else:
        newPriceId = _createStripePrice(
            stripe, productId, expectedCHF, interval, nickname, intervalCount,
        )

    if newPriceId == oldPriceId:
        # Happens when the retrieve call failed (e.g. transiently) but the list
        # call found the stored Price active and matching. Deactivating it here
        # would archive the very Price we are returning.
        logger.info("Stripe Price %s is active and matches the catalog; keeping it.", oldPriceId)
        return newPriceId

    try:
        stripe.Price.modify(oldPriceId, active=False)
        logger.info("Deactivated old Stripe Price %s", oldPriceId)
    except Exception as e:
        logger.warning("Could not deactivate old price %s: %s", oldPriceId, e)
    return newPriceId


def _createStripePrice(
    stripe,
    productId: str,
    unitAmountCHF: float,
    interval: str,
    nickname: str,
    intervalCount: int = 1,
) -> str:
    """Create a recurring CHF Price on the product and return its ID."""
    price = stripe.Price.create(
        product=productId,
        # Stripe amounts are integers in the minor unit (Rappen for CHF).
        unit_amount=int(round(unitAmountCHF * 100)),
        currency=_CURRENCY,
        recurring={"interval": interval, "interval_count": intervalCount},
        nickname=nickname,
    )
    logger.info("Created Stripe Price %s (%s, %s CHF/%s)", price.id, nickname, unitAmountCHF, interval)
    return price.id


def _archiveOtherRecurringPrices(
    stripe,
    productId: Optional[str],
    keepPriceId: Optional[str],
    interval: str,
    intervalCount: int = 1,
) -> None:
    """Archive every other active recurring price on the product (same interval pattern).

    Best-effort: no-op without both IDs. Only the first page (up to 100) of active
    prices is inspected, and failures are logged rather than raised.
    """
    if not productId or not keepPriceId:
        return
    try:
        prices = stripe.Price.list(product=productId, active=True, limit=100)
        for p in prices.data:
            if p.id == keepPriceId:
                continue
            recurring = p.get("recurring") or {}
            if not recurring:
                continue
            if not _recurringMatches(recurring, interval, intervalCount):
                continue
            try:
                stripe.Price.modify(p.id, active=False)
                logger.info("Archived stale Stripe Price %s on product %s", p.id, productId)
            except Exception as ex:
                logger.warning("Could not archive price %s: %s", p.id, ex)
    except Exception as e:
        logger.warning("Stale price archive pass failed for product %s: %s", productId, e)


def _validateStripeIdsExist(stripe, mapping: StripePlanPrice) -> bool:
    """Quick check whether the stored Stripe product IDs still exist.

    Returns False when running against a different Stripe account or after a DB
    copy from another environment. Price-level validation (active flag, drift)
    is handled by ``_reconcilePrice``; we don't fail here on archived prices,
    otherwise we'd needlessly re-provision products on every rotation.
    """
    try:
        if mapping.stripeProductIdUsers:
            stripe.Product.retrieve(mapping.stripeProductIdUsers)
        if mapping.stripeProductIdInstances:
            stripe.Product.retrieve(mapping.stripeProductIdInstances)
        return True
    except Exception as e:
        code = getattr(e, "code", None)
        if code == "resource_missing":
            return False
        # NOTE(review): any other error (network, auth, ...) also returns False,
        # which sends the caller down the re-provisioning path — confirm intended.
        logger.debug("Stripe validation check failed (non-critical): %s", e)
        return False


def _reconcileExistingMapping(
    stripe,
    db: DatabaseConnector,
    planKey: str,
    plan,
    mapping: StripePlanPrice,
    interval: str,
    intervalCount: int,
) -> bool:
    """Reconcile a persisted mapping's Prices against the catalog entry ``plan``.

    Returns:
        True if the mapping was complete and its products still exist. In that
        case its Prices have been reconciled, stale Prices archived, and the DB
        record updated when a Price ID changed, so the caller can skip the plan.
        False if the plan must be (re-)provisioned.
    """
    hasAllPrices = mapping.stripePriceIdUsers and mapping.stripePriceIdInstances
    hasAllProducts = mapping.stripeProductIdUsers and mapping.stripeProductIdInstances
    if not (hasAllPrices and hasAllProducts):
        return False

    if not _validateStripeIdsExist(stripe, mapping):
        logger.warning(
            "Stored Stripe IDs for plan %s reference unknown objects "
            "(likely wrong Stripe account or copied DB) — re-provisioning.",
            planKey,
        )
        return False

    reconciledUsers = _reconcilePrice(
        stripe,
        mapping.stripeProductIdUsers,
        mapping.stripePriceIdUsers,
        plan.pricePerUserCHF,
        interval,
        f"{planKey} — Benutzer-Lizenz",
        intervalCount,
    )
    reconciledInstances = _reconcilePrice(
        stripe,
        mapping.stripeProductIdInstances,
        mapping.stripePriceIdInstances,
        plan.pricePerFeatureInstanceCHF,
        interval,
        f"{planKey} — Modul",
        intervalCount,
    )
    changed = (
        reconciledUsers != mapping.stripePriceIdUsers
        or reconciledInstances != mapping.stripePriceIdInstances
    )

    _archiveOtherRecurringPrices(
        stripe, mapping.stripeProductIdUsers, reconciledUsers, interval, intervalCount,
    )
    _archiveOtherRecurringPrices(
        stripe, mapping.stripeProductIdInstances, reconciledInstances, interval, intervalCount,
    )

    if changed:
        db.recordModify(StripePlanPrice, mapping.id, {
            "stripePriceIdUsers": reconciledUsers,
            "stripePriceIdInstances": reconciledInstances,
        })
        logger.info(
            "Reconciled Stripe prices for plan %s to catalog (CHF): users=%s, instances=%s",
            planKey, reconciledUsers, reconciledInstances,
        )
    else:
        logger.debug("Stripe prices up-to-date for plan %s", planKey)
    return True


def _provisionLine(
    stripe,
    planKey: str,
    lineType: str,
    productName: str,
    productDescription: str,
    unitAmountCHF: float,
    interval: str,
    nickname: str,
    intervalCount: int = 1,
) -> Tuple[str, str]:
    """Find-or-create the Product and matching Price for one invoice line type.

    Any other active Price with the same interval on that Product is archived.

    Returns:
        ``(productId, priceId)``.
    """
    productId = _findStripeProduct(stripe, planKey, lineType)
    if not productId:
        productId = _createStripeProduct(stripe, productName, productDescription, planKey, lineType)
    unitCents = int(round(unitAmountCHF * 100))
    priceId = _findExistingStripePrice(stripe, productId, unitCents, interval, intervalCount)
    if not priceId:
        priceId = _createStripePrice(
            stripe, productId, unitAmountCHF, interval, nickname, intervalCount,
        )
    _archiveOtherRecurringPrices(stripe, productId, priceId, interval, intervalCount)
    return productId, priceId


def bootstrapStripePrices() -> None:
    """Ensure all paid plans have separate Stripe Products for users and instances.

    Plans are skipped when:
    - their billing period is NONE,
    - both unit prices are 0, or
    - the billing period has no Stripe mapping.

    If a plan has a complete persisted mapping whose products still exist, its
    Prices are reconciled. Otherwise each line type with a positive price is
    found or created, and the IDs are persisted to StripePlanPrice. If Stripe
    is not configured, an error is logged and nothing is done.
    """
    try:
        from modules.shared.stripeClient import getStripeClient
        stripe = getStripeClient()
    except ValueError as e:
        logger.error("Stripe not configured — cannot bootstrap subscription prices: %s", e)
        return

    db = _getBillingDb()
    existing = _loadExistingMappings(db)

    for planKey, plan in BUILTIN_PLANS.items():
        if plan.billingPeriod == BillingPeriodEnum.NONE:
            continue
        if plan.pricePerUserCHF == 0 and plan.pricePerFeatureInstanceCHF == 0:
            continue
        stripePeriod = _PERIOD_TO_STRIPE.get(plan.billingPeriod)
        if not stripePeriod:
            continue
        interval = stripePeriod["interval"]
        intervalCount = int(stripePeriod.get("interval_count") or 1)

        if planKey in existing and _reconcileExistingMapping(
            stripe, db, planKey, plan, existing[planKey], interval, intervalCount,
        ):
            continue

        productIdUsers = None
        productIdInstances = None
        priceIdUsers = None
        priceIdInstances = None

        if plan.pricePerUserCHF > 0:
            productIdUsers, priceIdUsers = _provisionLine(
                stripe,
                planKey,
                "users",
                "Benutzer-Lizenzen",
                f"Benutzer-Lizenzen für {plan.title or planKey}",
                plan.pricePerUserCHF,
                interval,
                f"{planKey} — Benutzer-Lizenz",
                intervalCount,
            )

        if plan.pricePerFeatureInstanceCHF > 0:
            productIdInstances, priceIdInstances = _provisionLine(
                stripe,
                planKey,
                "instances",
                "Module",
                f"Module für {plan.title or planKey}",
                plan.pricePerFeatureInstanceCHF,
                interval,
                f"{planKey} — Modul",
                intervalCount,
            )

        persistData = {
            "stripeProductId": "",
            "stripeProductIdUsers": productIdUsers,
            "stripeProductIdInstances": productIdInstances,
            "stripePriceIdUsers": priceIdUsers,
            "stripePriceIdInstances": priceIdInstances,
        }
        if planKey in existing:
            db.recordModify(StripePlanPrice, existing[planKey].id, persistData)
        else:
            db.recordCreate(StripePlanPrice, StripePlanPrice(planKey=planKey, **persistData).model_dump())

        logger.info(
            "Stripe bootstrapped for %s: users=%s/%s, instances=%s/%s",
            planKey, productIdUsers, priceIdUsers, productIdInstances, priceIdInstances,
        )


def getStripePricesForPlan(planKey: str) -> Optional[StripePlanPrice]:
    """Load the persisted Stripe IDs for a plan.

    Returns the first matching StripePlanPrice, or None if there is none or
    loading fails (the error is logged).
    """
    try:
        db = _getBillingDb()
        rows = db.getRecordset(StripePlanPrice, recordFilter={"planKey": planKey})
        if rows:
            return _mappingFromRow(rows[0])
    except Exception as e:
        logger.error("Error loading Stripe prices for plan %s: %s", planKey, e)
    return None