# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Auto-provision Stripe Products and Prices from the built-in plan catalog.

Creates separate Stripe Products for user licenses and feature instances so
that invoice line items show clear, descriptive names:
  - "Benutzer-Lizenzen"
  - "Feature-Instanzen"

Idempotent — safe to call on every startup.
"""

import logging
from typing import Dict, Optional, Tuple

from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG
from modules.datamodels.datamodelSubscription import (
    BUILTIN_PLANS,
    SubscriptionPlan,
    BillingPeriodEnum,
    StripePlanPrice,
)

logger = logging.getLogger(__name__)

_BILLING_DATABASE = "poweron_billing"
_METADATA_KEY = "poweron_plan_key"
_METADATA_LINE_TYPE = "poweron_line_type"
# Every Price this module creates (and is willing to reuse) is in CHF.
_CURRENCY = "chf"

_PERIOD_TO_STRIPE = {
    BillingPeriodEnum.MONTHLY: {"interval": "month", "interval_count": 1},
    BillingPeriodEnum.YEARLY: {"interval": "year", "interval_count": 1},
}


def _toRappen(amountCHF: float) -> int:
    """Convert a CHF amount to integer Rappen (1/100 CHF) for Stripe's ``unit_amount``.

    Rounds instead of truncating: binary floats such as 0.29 or 1.15 multiply
    out to 28.999999999999996 / 114.99999999999999, which ``int()`` alone
    would truncate one Rappen short.
    """
    return int(round(amountCHF * 100))


def _getBillingDb() -> DatabaseConnector:
    """Open a connector to the billing database using the DB_* settings from APP_CONFIG."""
    return DatabaseConnector(
        dbDatabase=_BILLING_DATABASE,
        dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
        dbPort=int(APP_CONFIG.get("DB_PORT", "5432")),
        dbUser=APP_CONFIG.get("DB_USER"),
        dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"),
    )


def _rowToStripePlanPrice(row: dict) -> StripePlanPrice:
    """Build a StripePlanPrice from a DB row, dropping ``_``-prefixed columns."""
    return StripePlanPrice(**{k: v for k, v in row.items() if not k.startswith("_")})


def _loadExistingMappings(db: DatabaseConnector) -> Dict[str, StripePlanPrice]:
    """Load all persisted plan→Stripe mappings, keyed by plan key.

    Rows without a plan key are skipped; if several rows share a key, the last
    one read wins. Any error is logged and yields an empty dict, so bootstrap
    falls back to searching Stripe directly.
    """
    try:
        rows = db.getRecordset(StripePlanPrice)
        result: Dict[str, StripePlanPrice] = {}
        for row in rows:
            planKey = row.get("planKey")
            if planKey:
                result[planKey] = _rowToStripePlanPrice(row)
        return result
    except Exception as e:
        logger.warning("Could not load StripePlanPrice records: %s", e)
        return {}


def _findStripeProduct(stripe, planKey: str, lineType: str) -> Optional[str]:
    """Search Stripe for a product tagged with plan key + line type.

    Tries a combined metadata search first. Only if that call raises does it
    fall back to searching by plan key alone and filtering the line type
    client-side. Returns None when nothing matches or both searches fail.
    """
    try:
        products = stripe.Product.search(
            query=f'metadata["{_METADATA_KEY}"]:"{planKey}" AND metadata["{_METADATA_LINE_TYPE}"]:"{lineType}"',
            limit=1,
        )
        if products.data:
            return products.data[0].id
    except Exception as e:
        logger.debug("Combined Stripe product search failed for %s/%s: %s", planKey, lineType, e)
        try:
            products = stripe.Product.search(
                query=f'metadata["{_METADATA_KEY}"]:"{planKey}"',
                limit=10,
            )
            for p in products.data:
                meta = p.get("metadata") or {}
                if meta.get(_METADATA_LINE_TYPE) == lineType:
                    return p.id
        except Exception as fallbackError:
            logger.warning(
                "Stripe product search failed for %s/%s: %s", planKey, lineType, fallbackError
            )
    return None


def _createStripeProduct(stripe, name: str, description: str, planKey: str, lineType: str) -> str:
    """Create a Stripe Product tagged with plan key + line type metadata; return its id."""
    product = stripe.Product.create(
        name=name,
        description=description,
        metadata={_METADATA_KEY: planKey, _METADATA_LINE_TYPE: lineType},
    )
    logger.info("Created Stripe Product %s: %s (%s/%s)", product.id, name, planKey, lineType)
    return product.id


def _findExistingStripePrice(stripe, productId: str, unitAmount: int, interval: str) -> Optional[str]:
    """Return the id of an active CHF Price on ``productId`` matching amount (Rappen) and interval.

    Only the first 50 active prices of the product are inspected. Returns None
    if none match or the listing fails (the failure is logged).
    """
    try:
        prices = stripe.Price.list(product=productId, active=True, limit=50)
        for p in prices.data:
            recurring = p.get("recurring") or {}
            if (
                p.get("unit_amount") == unitAmount
                and p.get("currency") == _CURRENCY
                and recurring.get("interval") == interval
            ):
                return p.id
    except Exception as e:
        logger.warning("Could not list Stripe Prices for product %s: %s", productId, e)
    return None


def _getStripePriceAmount(stripe, priceId: str) -> Optional[int]:
    """Retrieve the unit_amount (in Rappen) of an existing Stripe Price, or None if unavailable."""
    try:
        price = stripe.Price.retrieve(priceId)
        return price.get("unit_amount") if price else None
    except Exception as e:
        logger.warning("Could not retrieve Stripe Price %s: %s", priceId, e)
        return None


def _createStripePrice(stripe, productId: str, unitAmountCHF: float, interval: str, nickname: str) -> str:
    """Create a recurring CHF Price on ``productId``; return its id."""
    price = stripe.Price.create(
        product=productId,
        unit_amount=_toRappen(unitAmountCHF),
        currency=_CURRENCY,
        recurring={"interval": interval},
        nickname=nickname,
    )
    logger.info("Created Stripe Price %s (%s, %s CHF/%s)", price.id, nickname, unitAmountCHF, interval)
    return price.id


def _reconcilePrice(stripe, productId: str, oldPriceId: str, expectedCHF: float, interval: str, nickname: str) -> str:
    """If the stored Stripe Price has a different amount, switch to a matching one and deactivate the old.

    A matching active price on the product is reused when present; otherwise a
    new one is created. When the stored price's amount cannot be determined
    (lookup failed or no unit_amount), the stored id is kept unchanged rather
    than rotating on a guess.

    Returns:
        The price id that should be persisted for this line.
    """
    expectedCents = _toRappen(expectedCHF)
    actualCents = _getStripePriceAmount(stripe, oldPriceId)
    if actualCents is None:
        # An unknown amount is not evidence of drift; rotating here would
        # create a duplicate price and archive a possibly valid one.
        logger.warning(
            "Could not verify amount of Stripe Price %s; keeping it without reconciliation.",
            oldPriceId,
        )
        return oldPriceId
    if actualCents == expectedCents:
        return oldPriceId

    logger.warning(
        "Price drift detected for %s: Stripe has %s Rappen, catalog expects %s Rappen. Rotating price.",
        oldPriceId, actualCents, expectedCents,
    )
    newPriceId = _findExistingStripePrice(stripe, productId, expectedCents, interval)
    if not newPriceId:
        newPriceId = _createStripePrice(stripe, productId, expectedCHF, interval, nickname)

    try:
        stripe.Price.modify(oldPriceId, active=False)
        logger.info("Deactivated old Stripe Price %s", oldPriceId)
    except Exception as e:
        logger.warning("Could not deactivate old price %s: %s", oldPriceId, e)
    return newPriceId


def _ensureLine(
    stripe,
    planKey: str,
    plan: SubscriptionPlan,
    lineType: str,
    productName: str,
    unitAmountCHF: float,
    interval: str,
    nickname: str,
) -> Tuple[str, str]:
    """Find or create the Stripe Product and recurring Price for one invoice line of a plan.

    Returns:
        ``(productId, priceId)``.
    """
    productId = _findStripeProduct(stripe, planKey, lineType)
    if not productId:
        productId = _createStripeProduct(
            stripe,
            productName,
            f"{productName} für {plan.title.get('de', planKey)}",
            planKey,
            lineType,
        )
    priceId = _findExistingStripePrice(stripe, productId, _toRappen(unitAmountCHF), interval)
    if not priceId:
        priceId = _createStripePrice(stripe, productId, unitAmountCHF, interval, nickname)
    return productId, priceId


def _reconcileExistingMapping(
    stripe,
    db: DatabaseConnector,
    planKey: str,
    plan: SubscriptionPlan,
    mapping: StripePlanPrice,
    interval: str,
) -> bool:
    """Reconcile a persisted mapping's prices against the catalog, persisting any rotation.

    Returns:
        True when the mapping already covers every paid line of the plan (and
        has been reconciled), False when a paid line is missing its product or
        price and the plan needs a full provisioning pass.
    """
    usersReady = bool(mapping.stripeProductIdUsers and mapping.stripePriceIdUsers)
    instancesReady = bool(mapping.stripeProductIdInstances and mapping.stripePriceIdInstances)
    # A line priced at zero never gets Stripe objects, so its absence must not
    # force a full re-provisioning of the plan on every startup.
    if not (usersReady or plan.pricePerUserCHF <= 0):
        return False
    if not (instancesReady or plan.pricePerFeatureInstanceCHF <= 0):
        return False

    reconciledUsers = mapping.stripePriceIdUsers
    if usersReady:
        reconciledUsers = _reconcilePrice(
            stripe, mapping.stripeProductIdUsers, mapping.stripePriceIdUsers,
            plan.pricePerUserCHF, interval, f"{planKey} — Benutzer-Lizenz",
        )
    reconciledInstances = mapping.stripePriceIdInstances
    if instancesReady:
        reconciledInstances = _reconcilePrice(
            stripe, mapping.stripeProductIdInstances, mapping.stripePriceIdInstances,
            plan.pricePerFeatureInstanceCHF, interval, f"{planKey} — Feature-Instanz",
        )

    changed = (
        reconciledUsers != mapping.stripePriceIdUsers
        or reconciledInstances != mapping.stripePriceIdInstances
    )
    if changed:
        db.recordModify(StripePlanPrice, mapping.id, {
            "stripePriceIdUsers": reconciledUsers,
            "stripePriceIdInstances": reconciledInstances,
        })
        logger.info("Reconciled Stripe prices for plan %s: users=%s, instances=%s",
                    planKey, reconciledUsers, reconciledInstances)
    else:
        logger.debug("Stripe prices up-to-date for plan %s", planKey)
    return True


def bootstrapStripePrices() -> None:
    """Ensure all paid plans have separate Stripe Products for users and instances.

    For each plan in BUILTIN_PLANS that has a billing period and a non-zero
    price:
      - If a persisted mapping covers all of the plan's paid lines, the stored
        prices are checked for drift against the catalog and rotated if needed.
      - Otherwise the Stripe Products/Prices are found by metadata (or created)
        and the resulting ids are persisted.

    Returns early with an error log when the Stripe client cannot be
    configured (ValueError from getStripeClient). Other errors propagate.
    """
    try:
        from modules.shared.stripeClient import getStripeClient
        stripe = getStripeClient()
    except ValueError as e:
        logger.error("Stripe not configured — cannot bootstrap subscription prices: %s", e)
        return

    db = _getBillingDb()
    existing = _loadExistingMappings(db)

    for planKey, plan in BUILTIN_PLANS.items():
        if plan.billingPeriod == BillingPeriodEnum.NONE:
            continue
        if plan.pricePerUserCHF == 0 and plan.pricePerFeatureInstanceCHF == 0:
            continue
        stripePeriod = _PERIOD_TO_STRIPE.get(plan.billingPeriod)
        if not stripePeriod:
            continue
        interval = stripePeriod["interval"]

        mapping = existing.get(planKey)
        if mapping is not None and _reconcileExistingMapping(stripe, db, planKey, plan, mapping, interval):
            continue

        productIdUsers = None
        priceIdUsers = None
        if plan.pricePerUserCHF > 0:
            productIdUsers, priceIdUsers = _ensureLine(
                stripe, planKey, plan, "users", "Benutzer-Lizenzen",
                plan.pricePerUserCHF, interval, f"{planKey} — Benutzer-Lizenz",
            )

        productIdInstances = None
        priceIdInstances = None
        if plan.pricePerFeatureInstanceCHF > 0:
            productIdInstances, priceIdInstances = _ensureLine(
                stripe, planKey, plan, "instances", "Feature-Instanzen",
                plan.pricePerFeatureInstanceCHF, interval, f"{planKey} — Feature-Instanz",
            )

        persistData = {
            # Legacy single-product field; cleared now that lines have separate products.
            "stripeProductId": "",
            "stripeProductIdUsers": productIdUsers,
            "stripeProductIdInstances": productIdInstances,
            "stripePriceIdUsers": priceIdUsers,
            "stripePriceIdInstances": priceIdInstances,
        }
        if mapping is not None:
            db.recordModify(StripePlanPrice, mapping.id, persistData)
        else:
            db.recordCreate(StripePlanPrice, StripePlanPrice(planKey=planKey, **persistData).model_dump())

        logger.info(
            "Stripe bootstrapped for %s: users=%s/%s, instances=%s/%s",
            planKey, productIdUsers, priceIdUsers, productIdInstances, priceIdInstances,
        )


def getStripePricesForPlan(planKey: str) -> Optional[StripePlanPrice]:
    """Load the persisted Stripe IDs for a plan.

    Returns the first matching record, or None when there is none or loading
    fails (the error is logged).
    """
    try:
        db = _getBillingDb()
        rows = db.getRecordset(StripePlanPrice, recordFilter={"planKey": planKey})
        if rows:
            return _rowToStripePlanPrice(rows[0])
    except Exception as e:
        logger.error("Error loading Stripe prices for plan %s: %s", planKey, e)
    return None