# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Interface for Subscription operations — ID-based, deterministic.

Every write operation takes an explicit subscriptionId. No status-scan guessing.
See wiki/concepts/Subscription-State-Machine.md.
"""

import logging
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime, timezone

from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG
from modules.datamodels.datamodelUam import User
from modules.datamodels.datamodelMembership import UserMandate
from modules.datamodels.datamodelSubscription import (
    SubscriptionPlan,
    MandateSubscription,
    SubscriptionStatusEnum,
    BillingPeriodEnum,
    ALLOWED_TRANSITIONS,
    TERMINAL_STATUSES,
    OPERATIVE_STATUSES,
    BUILTIN_PLANS,
    _getPlan,
    _getSelectablePlans,
)

logger = logging.getLogger(__name__)

SUBSCRIPTION_DATABASE = "poweron_billing"

# Default usage percentage of maxDataVolumeMB at which getDataVolumeWarning sets "warning": True.
DATA_VOLUME_WARNING_PERCENT = 80.0

# Process-wide cache of interfaces keyed by "<userId>_<mandateId>" (see getInterface).
_subscriptionInterfaces: Dict[str, "SubscriptionObjects"] = {}


class InvalidTransitionError(Exception):
    """Raised when a state transition is not allowed by the state machine."""

    def __init__(self, subscriptionId: str, fromStatus: str, toStatus: str):
        self.subscriptionId = subscriptionId
        self.fromStatus = fromStatus
        self.toStatus = toStatus
        super().__init__(f"Invalid transition {fromStatus} -> {toStatus} for subscription {subscriptionId}")


def getInterface(currentUser: User, mandateId: Optional[str] = None) -> "SubscriptionObjects":
    """Return the cached SubscriptionObjects for (user, mandate), creating it on first use.

    An already-cached instance is re-bound to ``currentUser``/``mandateId`` on every call.
    NOTE(review): entries are never evicted, so the cache grows with the number of
    distinct user/mandate pairs seen by this process.
    """
    cacheKey = f"{currentUser.id}_{mandateId}"
    interface = _subscriptionInterfaces.get(cacheKey)
    if interface is None:
        interface = SubscriptionObjects(currentUser, mandateId)
        _subscriptionInterfaces[cacheKey] = interface
    else:
        interface.setUserContext(currentUser, mandateId)
    return interface


def _getRootInterface() -> "SubscriptionObjects":
    """Return a fresh, uncached interface bound to the root user and no mandate."""
    from modules.security.rootAccess import getRootUser
    return SubscriptionObjects(getRootUser(), mandateId=None)


def _getAppDatabaseConnector() -> DatabaseConnector:
    """Create a connector to the application database (users, feature instances, knowledge).

    A new connector is created on every call.
    NOTE(review): confirm whether DatabaseConnector pools connections; if not,
    callers on hot paths may want to reuse one instance.
    """
    return DatabaseConnector(
        dbDatabase=APP_CONFIG.get("DB_DATABASE", "poweron_app"),
        dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
        dbPort=int(APP_CONFIG.get("DB_PORT", "5432")),
        dbUser=APP_CONFIG.get("DB_USER"),
        dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"),
    )


def _startedAtSortKey(row: Dict[str, Any]) -> Tuple[bool, Any]:
    """Sort key for newest-first ordering by ``startedAt`` that tolerates missing/None values.

    Rows without a startedAt sort as the oldest. None is never compared against a
    real value, which would otherwise raise TypeError and abort the whole sort.
    """
    startedAt = row.get("startedAt")
    if startedAt is None:
        return (False, "")
    return (True, startedAt)


def _filterAndSortNewestFirst(
    rows: List[Dict[str, Any]],
    statusFilter: Optional[List[SubscriptionStatusEnum]],
) -> List[Dict[str, Any]]:
    """Keep only rows whose status is in ``statusFilter`` (if given), newest startedAt first."""
    if statusFilter:
        allowedValues = {s.value for s in statusFilter}
        rows = [r for r in rows if r.get("status") in allowedValues]
    return sorted(rows, key=_startedAtSortKey, reverse=True)


def _raiseCapacityExceeded(
    resourceType: str,
    currentCount: int,
    maxAllowed: int,
    message: Optional[str] = None,
) -> None:
    """Raise SubscriptionCapacityException with the given details.

    The exception class is imported lazily, only on the raise path.
    NOTE(review): presumably this avoids a circular import with the service module — confirm.
    """
    from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
        SubscriptionCapacityException,
    )
    kwargs: Dict[str, Any] = {
        "resourceType": resourceType,
        "currentCount": currentCount,
        "maxAllowed": maxAllowed,
    }
    if message is not None:
        kwargs["message"] = message
    raise SubscriptionCapacityException(**kwargs)


class SubscriptionObjects:
    """Interface for subscription operations: CRUD, gate checks, Stripe sync.

    All writes are ID-based. All status changes go through transitionStatus(),
    except the deliberate sysadmin override forceExpire().
    """

    def __init__(self, currentUser: Optional[User] = None, mandateId: Optional[str] = None):
        self.currentUser = currentUser
        self.userId = currentUser.id if currentUser else None
        self.mandateId = mandateId
        # Connector to the billing database that stores MandateSubscription records.
        self.db = DatabaseConnector(
            dbDatabase=SUBSCRIPTION_DATABASE,
            dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
            dbPort=int(APP_CONFIG.get("DB_PORT", "5432")),
            dbUser=APP_CONFIG.get("DB_USER"),
            dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"),
        )

    def setUserContext(self, currentUser: User, mandateId: Optional[str] = None):
        """Re-bind this interface to another user/mandate without reconnecting."""
        self.currentUser = currentUser
        self.userId = currentUser.id if currentUser else None
        self.mandateId = mandateId

    # =========================================================================
    # Plan catalog (in-memory)
    # =========================================================================

    def getPlan(self, planKey: str) -> Optional[SubscriptionPlan]:
        """Look up a plan definition by key; None if unknown."""
        return _getPlan(planKey)

    def getSelectablePlans(self) -> List[SubscriptionPlan]:
        """Return the plans that can be chosen."""
        return _getSelectablePlans()

    # =========================================================================
    # Read: by ID (primary access pattern)
    # =========================================================================

    def getById(self, subscriptionId: str) -> Optional[Dict[str, Any]]:
        """Load a single subscription by its primary key.

        Returns None when not found or when the query fails (the failure is logged).
        """
        try:
            results = self.db.getRecordset(MandateSubscription, recordFilter={"id": subscriptionId})
            return dict(results[0]) if results else None
        except Exception as e:
            logger.error("getById(%s) failed: %s", subscriptionId, e)
            return None

    def getByStripeSubscriptionId(self, stripeSubId: str) -> Optional[Dict[str, Any]]:
        """Load subscription by Stripe subscription ID — the webhook resolution path.

        Returns None when not found or when the query fails (the failure is logged).
        """
        try:
            results = self.db.getRecordset(MandateSubscription, recordFilter={"stripeSubscriptionId": stripeSubId})
            return dict(results[0]) if results else None
        except Exception as e:
            logger.error("getByStripeSubscriptionId(%s) failed: %s", stripeSubId, e)
            return None

    # =========================================================================
    # Read: by mandate (list queries)
    # =========================================================================

    def listForMandate(
        self,
        mandateId: str,
        statusFilter: Optional[List[SubscriptionStatusEnum]] = None,
    ) -> List[Dict[str, Any]]:
        """Return all subscriptions for a mandate, optionally filtered by status.

        Sorted newest-first by startedAt; rows without startedAt come last.
        Returns [] if the query fails (the failure is logged).
        """
        try:
            results = self.db.getRecordset(MandateSubscription, recordFilter={"mandateId": mandateId})
            return _filterAndSortNewestFirst([dict(r) for r in results], statusFilter)
        except Exception as e:
            logger.error("listForMandate(%s) failed: %s", mandateId, e)
            return []

    def getOperativeForMandate(self, mandateId: str) -> Optional[Dict[str, Any]]:
        """Return the newest operative subscription (a status in OPERATIVE_STATUSES).

        This is a read-only query for the billing gate.
        Returns None if no operative sub exists.
        """
        operativeValues = {s.value for s in OPERATIVE_STATUSES}
        return next(
            (row for row in self.listForMandate(mandateId) if row.get("status") in operativeValues),
            None,
        )

    def getScheduledForMandate(self, mandateId: str) -> Optional[Dict[str, Any]]:
        """Return the newest SCHEDULED subscription if one exists (next sub waiting to start)."""
        rows = self.listForMandate(mandateId, [SubscriptionStatusEnum.SCHEDULED])
        return rows[0] if rows else None

    # =========================================================================
    # Read: global (SysAdmin)
    # =========================================================================

    def listAll(self, statusFilter: Optional[List[SubscriptionStatusEnum]] = None) -> List[Dict[str, Any]]:
        """Return ALL subscriptions across all mandates, newest-first. SysAdmin use only.

        Returns [] if the query fails (the failure is logged).
        """
        try:
            results = self.db.getRecordset(MandateSubscription)
            return _filterAndSortNewestFirst([dict(r) for r in results], statusFilter)
        except Exception as e:
            logger.error("listAll() failed: %s", e)
            return []

    # =========================================================================
    # Write: create
    # =========================================================================

    def createSubscription(self, sub: MandateSubscription) -> Dict[str, Any]:
        """Persist a new MandateSubscription record and return what the DB layer returns."""
        return self.db.recordCreate(MandateSubscription, sub.model_dump())

    # =========================================================================
    # Write: update fields (no status change)
    # =========================================================================

    def updateFields(self, subscriptionId: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Update non-status fields on a subscription (e.g. recurring, Stripe IDs, periods).

        Must NOT be used for status changes — use transitionStatus() for that.

        Raises:
            ValueError: if ``data`` contains a "status" key.
        """
        if "status" in data:
            raise ValueError("updateFields must not change status — use transitionStatus()")
        return self.db.recordModify(MandateSubscription, subscriptionId, data)

    # =========================================================================
    # Write: status transition (guarded)
    # =========================================================================

    def transitionStatus(
        self,
        subscriptionId: str,
        expectedFromStatus: SubscriptionStatusEnum,
        toStatus: SubscriptionStatusEnum,
        additionalData: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Execute a guarded status transition.

        1. Load the record by ID
        2. Verify current status matches expectedFromStatus
        3. Verify the transition is allowed by the state machine
        4. Apply the update (status, endedAt for terminal targets, plus additionalData)

        Raises:
            ValueError: if the subscription does not exist, or additionalData tries
                to set a status different from ``toStatus``.
            InvalidTransitionError: if the current status differs from
                ``expectedFromStatus`` or the pair is not in ALLOWED_TRANSITIONS.
        """
        extra = dict(additionalData or {})
        # additionalData is applied after the guards, so a "status" key in it
        # would silently bypass the state machine.
        if "status" in extra and extra["status"] not in (toStatus, toStatus.value):
            raise ValueError("additionalData must not change status — pass the target as toStatus")

        sub = self.getById(subscriptionId)
        if not sub:
            raise ValueError(f"Subscription {subscriptionId} not found")

        currentStatus = sub.get("status", "")
        if currentStatus != expectedFromStatus.value:
            raise InvalidTransitionError(subscriptionId, currentStatus, toStatus.value)
        if (expectedFromStatus, toStatus) not in ALLOWED_TRANSITIONS:
            raise InvalidTransitionError(subscriptionId, expectedFromStatus.value, toStatus.value)

        updateData: Dict[str, Any] = {}
        if toStatus in TERMINAL_STATUSES and not extra.get("endedAt"):
            updateData["endedAt"] = datetime.now(timezone.utc).isoformat()
        updateData.update(extra)
        # Status is written last so nothing above can override the guarded target.
        updateData["status"] = toStatus.value

        # NOTE: the status check above and this write are separate DB calls, so a
        # concurrent writer could change the status in between.
        result = self.db.recordModify(MandateSubscription, subscriptionId, updateData)
        logger.info("Transition %s -> %s for subscription %s", expectedFromStatus.value, toStatus.value, subscriptionId)
        return result

    def forceExpire(self, subscriptionId: str) -> Dict[str, Any]:
        """Sysadmin force-expire: ANY non-terminal -> EXPIRED.

        Bypasses the ALLOWED_TRANSITIONS guard, but refuses terminal subscriptions
        so their recorded endedAt is not overwritten.

        Raises:
            ValueError: if the subscription does not exist or is already terminal.
        """
        sub = self.getById(subscriptionId)
        if not sub:
            raise ValueError(f"Subscription {subscriptionId} not found")
        currentStatus = sub.get("status", "")
        if currentStatus == SubscriptionStatusEnum.EXPIRED.value:
            raise ValueError(f"Subscription {subscriptionId} is already EXPIRED")
        if currentStatus in {s.value for s in TERMINAL_STATUSES}:
            raise ValueError(f"Subscription {subscriptionId} is already in terminal status {currentStatus}")
        result = self.db.recordModify(MandateSubscription, subscriptionId, {
            "status": SubscriptionStatusEnum.EXPIRED.value,
            "endedAt": datetime.now(timezone.utc).isoformat(),
        })
        logger.info("Force-expired subscription %s (was %s)", subscriptionId, currentStatus)
        return result

    # =========================================================================
    # Gate: assertActive (read-only, for billing gate)
    # =========================================================================

    def assertActive(self, mandateId: str) -> SubscriptionStatusEnum:
        """Return effective status for billing decisions.

        Returns the operative subscription's status, or EXPIRED if none exists.
        This is the ONLY read-by-mandate operation used in the hot path.
        """
        sub = self.getOperativeForMandate(mandateId)
        if sub:
            return SubscriptionStatusEnum(sub["status"])
        return SubscriptionStatusEnum.EXPIRED

    # =========================================================================
    # Gate: assertCapacity
    # =========================================================================

    def assertCapacity(self, mandateId: str, resourceType: str, delta: int = 1) -> bool:
        """Check that adding ``delta`` units of ``resourceType`` stays within the plan limit.

        resourceType is one of "users", "featureInstances" or "dataVolumeMB"
        (for the latter, ``delta`` is in MB). Unknown resource types, unknown plans
        and plans with no limit (None) are allowed.

        Returns:
            True when the addition is allowed.

        Raises:
            SubscriptionCapacityException: if the mandate has no operative
                subscription or the limit would be exceeded.
        """
        sub = self.getOperativeForMandate(mandateId)
        if not sub:
            _raiseCapacityExceeded(
                resourceType, 0, 0, message="No active subscription for this mandate."
            )

        plan = self.getPlan(sub.get("planKey", ""))
        if not plan:
            return True

        # resourceType -> (plan attribute holding the cap, function returning current usage)
        limits = {
            "users": ("maxUsers", self.countActiveUsers),
            "featureInstances": ("maxFeatureInstances", self.countActiveFeatureInstances),
            "dataVolumeMB": ("maxDataVolumeMB", self._getMandateDataVolumeMB),
        }
        entry = limits.get(resourceType)
        if entry is None:
            return True

        capAttribute, currentUsage = entry
        cap = getattr(plan, capAttribute)
        if cap is None:
            return True

        current = currentUsage(mandateId)
        if current + delta > cap:
            _raiseCapacityExceeded(resourceType, int(current), cap)
        return True

    def _getMandateDataVolumeMB(self, mandateId: str) -> float:
        """Sum RAG index size (FileContentIndex.totalSize) for the mandate, in MB.

        Returns 0.0 if the query fails (the failure is logged), so the capacity
        gate fails open in that case.
        """
        try:
            from modules.datamodels.datamodelKnowledge import FileContentIndex
            knowledgeDb = _getAppDatabaseConnector()
            indexes = knowledgeDb.getRecordset(FileContentIndex, recordFilter={"mandateId": mandateId})
            totalBytes = sum(int(idx.get("totalSize") or 0) for idx in indexes)
            return totalBytes / (1024 * 1024)
        except Exception as e:
            logger.error("_getMandateDataVolumeMB(%s) failed: %s", mandateId, e)
            return 0.0

    def getDataVolumeWarning(
        self,
        mandateId: str,
        thresholdPercent: float = DATA_VOLUME_WARNING_PERCENT,
    ) -> Optional[Dict[str, Any]]:
        """Report data-volume usage against the plan's maxDataVolumeMB.

        Returns None when there is no operative subscription, the plan is unknown,
        or the plan sets no data-volume limit. Otherwise returns a dict with
        usedMB, limitMB, percent and warning (True when percent >= thresholdPercent,
        80 by default).
        """
        sub = self.getOperativeForMandate(mandateId)
        if not sub:
            return None
        plan = self.getPlan(sub.get("planKey", ""))
        if not plan or not plan.maxDataVolumeMB:
            return None

        usedMB = self._getMandateDataVolumeMB(mandateId)
        limitMB = plan.maxDataVolumeMB
        percent = (usedMB / limitMB * 100) if limitMB > 0 else 0
        return {
            "usedMB": round(usedMB, 2),
            "limitMB": limitMB,
            "percent": round(percent, 1),
            "warning": percent >= thresholdPercent,
        }

    # =========================================================================
    # Counting (cross-DB queries against poweron_app)
    # =========================================================================

    def _countActiveUsersStrict(self, mandateId: str) -> int:
        """Count UserMandate rows for the mandate; DB errors propagate."""
        appDb = _getAppDatabaseConnector()
        return len(appDb.getRecordset(UserMandate, recordFilter={"mandateId": mandateId}))

    def _countActiveFeatureInstancesStrict(self, mandateId: str) -> int:
        """Count enabled FeatureInstance rows for the mandate; DB errors propagate."""
        from modules.datamodels.datamodelFeatures import FeatureInstance
        appDb = _getAppDatabaseConnector()
        return len(appDb.getRecordset(FeatureInstance, recordFilter={"mandateId": mandateId, "enabled": True}))

    def countActiveUsers(self, mandateId: str) -> int:
        """Count UserMandate rows for the mandate; 0 if the query fails (logged)."""
        try:
            return self._countActiveUsersStrict(mandateId)
        except Exception as e:
            logger.error("countActiveUsers(%s) failed: %s", mandateId, e)
            return 0

    def countActiveFeatureInstances(self, mandateId: str) -> int:
        """Count enabled feature instances for the mandate; 0 if the query fails (logged)."""
        try:
            return self._countActiveFeatureInstancesStrict(mandateId)
        except Exception as e:
            logger.error("countActiveFeatureInstances(%s) failed: %s", mandateId, e)
            return 0

    # =========================================================================
    # Stripe quantity sync
    # =========================================================================

    def syncQuantityToStripe(self, subscriptionId: str) -> None:
        """Update Stripe subscription item quantities to match actual active counts.

        Takes subscriptionId, not mandateId. Best-effort: all failures are logged,
        never raised. Counts are taken with the strict helpers so a DB failure aborts
        the sync instead of pushing a quantity of 0 to Stripe.
        """
        sub = self.getById(subscriptionId)
        if not sub or not sub.get("stripeSubscriptionId"):
            return

        mandateId = sub.get("mandateId")
        if not mandateId:
            logger.warning("syncQuantityToStripe(%s): subscription has no mandateId", subscriptionId)
            return

        itemIdUsers = sub.get("stripeItemIdUsers")
        itemIdInstances = sub.get("stripeItemIdInstances")
        if not itemIdUsers and not itemIdInstances:
            return  # nothing on the Stripe side to keep in sync

        try:
            from modules.shared.stripeClient import getStripeClient
            stripe = getStripeClient()

            # Count everything first so a failure leaves Stripe completely untouched.
            activeUsers = self._countActiveUsersStrict(mandateId) if itemIdUsers else None
            activeInstances = self._countActiveFeatureInstancesStrict(mandateId) if itemIdInstances else None

            if itemIdUsers:
                stripe.SubscriptionItem.modify(
                    itemIdUsers,
                    quantity=max(activeUsers, 0),
                    proration_behavior="create_prorations",
                )
            if itemIdInstances:
                stripe.SubscriptionItem.modify(
                    itemIdInstances,
                    quantity=max(activeInstances, 0),
                    proration_behavior="create_prorations",
                )
            logger.info("Stripe quantity synced for sub %s: users=%s, instances=%s",
                        subscriptionId, activeUsers, activeInstances)
        except Exception as e:
            logger.error("syncQuantityToStripe(%s) failed: %s", subscriptionId, e)