404 lines
18 KiB
Python
404 lines
18 KiB
Python
# Copyright (c) 2025 Patrick Motsch
|
|
# All rights reserved.
|
|
"""
|
|
Interface for Subscription operations — ID-based, deterministic.
|
|
|
|
Every write operation takes an explicit subscriptionId.
|
|
No status-scan guessing. See wiki/concepts/Subscription-State-Machine.md.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Dict, Any, List, Optional
|
|
from datetime import datetime, timezone
|
|
|
|
from modules.connectors.connectorDbPostgre import DatabaseConnector
|
|
from modules.shared.configuration import APP_CONFIG
|
|
from modules.shared.dbRegistry import registerDatabase
|
|
from modules.datamodels.datamodelUam import User
|
|
from modules.datamodels.datamodelMembership import UserMandate
|
|
from modules.datamodels.datamodelSubscription import (
|
|
SubscriptionPlan,
|
|
MandateSubscription,
|
|
SubscriptionStatusEnum,
|
|
BillingPeriodEnum,
|
|
ALLOWED_TRANSITIONS,
|
|
TERMINAL_STATUSES,
|
|
OPERATIVE_STATUSES,
|
|
BUILTIN_PLANS,
|
|
getPlan as getPlanFromCatalog,
|
|
_getSelectablePlans,
|
|
)
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Name of the database the SubscriptionObjects connector is created for.
SUBSCRIPTION_DATABASE = "poweron_billing"
# Registered at import time; what registration entails is defined in
# modules.shared.dbRegistry (not visible from this file).
registerDatabase(SUBSCRIPTION_DATABASE)

# Cache of interface instances keyed by "<userId>_<mandateId>"; filled and
# reused by getInterface().
_subscriptionInterfaces: Dict[str, "SubscriptionObjects"] = {}
|
|
|
|
|
|
class InvalidTransitionError(Exception):
    """Signals a status change that the subscription state machine rejects.

    Attributes:
        subscriptionId: ID of the subscription the transition was attempted on.
        fromStatus: Status value the transition started from.
        toStatus: Status value that was requested.
    """

    def __init__(self, subscriptionId: str, fromStatus: str, toStatus: str):
        message = f"Invalid transition {fromStatus} -> {toStatus} for subscription {subscriptionId}"
        super().__init__(message)
        # Keep the individual parts available to handlers.
        self.subscriptionId, self.fromStatus, self.toStatus = subscriptionId, fromStatus, toStatus
|
|
|
|
|
|
def getInterface(currentUser: User, mandateId: str = None) -> "SubscriptionObjects":
    """Return the cached SubscriptionObjects for this user/mandate pair.

    The cache key is "<currentUser.id>_<mandateId>". On a miss a new instance
    is created and stored; on a hit the stored instance gets its user context
    refreshed via setUserContext() before it is returned.
    """
    cacheKey = f"{currentUser.id}_{mandateId}"
    interface = _subscriptionInterfaces.get(cacheKey)
    if interface is None:
        # First request for this pair: build and remember the instance.
        interface = SubscriptionObjects(currentUser, mandateId)
        _subscriptionInterfaces[cacheKey] = interface
        return interface
    interface.setUserContext(currentUser, mandateId)
    return interface
|
|
|
|
|
|
def getRootInterface() -> "SubscriptionObjects":
    """Build a new SubscriptionObjects bound to the root user, without a mandate.

    Unlike getInterface(), the instance is not stored in the module cache.
    """
    # Imported inside the function rather than at module level.
    from modules.security.rootAccess import getRootUser

    rootUser = getRootUser()
    return SubscriptionObjects(rootUser, mandateId=None)
|
|
|
|
|
|
def _getAppDatabaseConnector() -> DatabaseConnector:
    """Create a DatabaseConnector for the application database.

    Settings are read from APP_CONFIG: DB_DATABASE (default "poweron_app"),
    DB_HOST (default "localhost"), DB_PORT (default "5432", converted to int),
    DB_USER and DB_PASSWORD_SECRET.
    """
    connectionSettings = {
        "dbDatabase": APP_CONFIG.get("DB_DATABASE", "poweron_app"),
        "dbHost": APP_CONFIG.get("DB_HOST", "localhost"),
        "dbPort": int(APP_CONFIG.get("DB_PORT", "5432")),
        "dbUser": APP_CONFIG.get("DB_USER"),
        "dbPassword": APP_CONFIG.get("DB_PASSWORD_SECRET"),
    }
    return DatabaseConnector(**connectionSettings)
|
|
|
|
|
|
class SubscriptionObjects:
    """Interface for subscription operations: CRUD, gate checks, Stripe sync.

    All writes are ID-based. Status changes go through transitionStatus();
    forceExpire() is the one sysadmin path that bypasses the transition table.
    """

    def __init__(self, currentUser: Optional[User] = None, mandateId: Optional[str] = None):
        """Store the user/mandate context and create the billing-database connector.

        Args:
            currentUser: User the interface acts for; may be None.
            mandateId: Mandate the interface is scoped to; may be None.
        """
        self.currentUser = currentUser
        self.userId = currentUser.id if currentUser else None
        self.mandateId = mandateId
        self.db = DatabaseConnector(
            dbDatabase=SUBSCRIPTION_DATABASE,
            dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
            dbPort=int(APP_CONFIG.get("DB_PORT", "5432")),
            dbUser=APP_CONFIG.get("DB_USER"),
            dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"),
        )

    def setUserContext(self, currentUser: User, mandateId: Optional[str] = None):
        """Rebind this (possibly cached) instance to another user/mandate context."""
        self.currentUser = currentUser
        self.userId = currentUser.id if currentUser else None
        self.mandateId = mandateId

    # =========================================================================
    # Plan catalog (in-memory)
    # =========================================================================

    def getPlan(self, planKey: str) -> Optional[SubscriptionPlan]:
        """Look up a plan by key; delegates to the catalog's getPlan."""
        return getPlanFromCatalog(planKey)

    def getSelectablePlans(self) -> List[SubscriptionPlan]:
        """Return the plans the catalog reports as selectable."""
        return _getSelectablePlans()

    # =========================================================================
    # Read: by ID (primary access pattern)
    # =========================================================================

    def getById(self, subscriptionId: str) -> Optional[Dict[str, Any]]:
        """Load a single subscription by its primary key.

        Returns:
            The record as a dict, or None when no record matches OR when the
            lookup raised (the error is logged, not propagated).
        """
        try:
            results = self.db.getRecordset(MandateSubscription, recordFilter={"id": subscriptionId})
            return dict(results[0]) if results else None
        except Exception as e:
            logger.error("getById(%s) failed: %s", subscriptionId, e)
            return None

    def getByStripeSubscriptionId(self, stripeSubId: str) -> Optional[Dict[str, Any]]:
        """Load subscription by Stripe subscription ID — the webhook resolution path.

        Returns:
            The record as a dict, or None when no record matches OR when the
            lookup raised (the error is logged, not propagated).
        """
        try:
            results = self.db.getRecordset(MandateSubscription, recordFilter={"stripeSubscriptionId": stripeSubId})
            return dict(results[0]) if results else None
        except Exception as e:
            logger.error("getByStripeSubscriptionId(%s) failed: %s", stripeSubId, e)
            return None

    # =========================================================================
    # Read: by mandate (list queries)
    # =========================================================================

    @staticmethod
    def _startedAtSortKey(row: Dict[str, Any]) -> tuple:
        """Sort key on startedAt that tolerates missing, None and empty values.

        Rows without a usable startedAt get a key that orders below every row
        that has one, so with reverse=True they end up last. The leading bool
        decides those comparisons, so a None is never compared against a real
        value (which would raise TypeError).
        """
        startedAt = row.get("startedAt")
        if startedAt is None or startedAt == "":
            return (False, "")
        return (True, startedAt)

    @staticmethod
    def _filterAndSortRows(results, statusFilter=None) -> List[Dict[str, Any]]:
        """Convert records to dicts, apply the optional status filter, sort newest-first."""
        rows = [dict(record) for record in results]
        if statusFilter:
            wantedValues = {status.value for status in statusFilter}
            rows = [row for row in rows if row.get("status") in wantedValues]
        rows.sort(key=SubscriptionObjects._startedAtSortKey, reverse=True)
        return rows

    def listForMandate(self, mandateId: str, statusFilter: List[SubscriptionStatusEnum] = None) -> List[Dict[str, Any]]:
        """Return all subscriptions for a mandate, optionally filtered by status.

        Sorted newest-first by startedAt; rows without a startedAt come last.
        Returns an empty list when the query raises (the error is logged).
        """
        try:
            results = self.db.getRecordset(MandateSubscription, recordFilter={"mandateId": mandateId})
            return self._filterAndSortRows(results, statusFilter)
        except Exception as e:
            logger.error("listForMandate(%s) failed: %s", mandateId, e)
            return []

    def getOperativeForMandate(self, mandateId: str) -> Optional[Dict[str, Any]]:
        """Return the single operative subscription (ACTIVE, TRIALING, or PAST_DUE).

        This is a read-only query for the billing gate. The first row of
        listForMandate() whose status is in OPERATIVE_STATUSES is returned;
        None if no operative sub exists.
        """
        # Built once instead of once per row.
        operativeValues = {status.value for status in OPERATIVE_STATUSES}
        for row in self.listForMandate(mandateId):
            if row.get("status") in operativeValues:
                return row
        return None

    def getScheduledForMandate(self, mandateId: str) -> Optional[Dict[str, Any]]:
        """Return a SCHEDULED subscription if one exists (next sub waiting to start)."""
        scheduled = self.listForMandate(mandateId, [SubscriptionStatusEnum.SCHEDULED])
        return scheduled[0] if scheduled else None

    # =========================================================================
    # Read: global (SysAdmin)
    # =========================================================================

    def listAll(self, statusFilter: List[SubscriptionStatusEnum] = None) -> List[Dict[str, Any]]:
        """Return ALL subscriptions across all mandates, newest-first. SysAdmin use only.

        Returns an empty list when the query raises (the error is logged).
        """
        try:
            results = self.db.getRecordset(MandateSubscription)
            return self._filterAndSortRows(results, statusFilter)
        except Exception as e:
            logger.error("listAll() failed: %s", e)
            return []

    # =========================================================================
    # Write: create
    # =========================================================================

    def createSubscription(self, sub: MandateSubscription) -> Dict[str, Any]:
        """Persist a new MandateSubscription record (stored as sub.model_dump())."""
        return self.db.recordCreate(MandateSubscription, sub.model_dump())

    # =========================================================================
    # Write: update fields (no status change)
    # =========================================================================

    def updateFields(self, subscriptionId: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Update non-status fields on a subscription (e.g. recurring, Stripe IDs, periods).

        Must NOT be used for status changes — use transitionStatus() for that.

        Raises:
            ValueError: If data contains a "status" key.
        """
        if "status" in data:
            raise ValueError("updateFields must not change status — use transitionStatus()")
        return self.db.recordModify(MandateSubscription, subscriptionId, data)

    # =========================================================================
    # Write: status transition (guarded)
    # =========================================================================

    @staticmethod
    def _buildTransitionUpdate(toStatusValue: str, isTerminal: bool, additionalData: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Build the update payload for a status transition.

        Caller-supplied fields are merged first; afterwards the guarded target
        status is (re)applied so additionalData can never override it, and a
        terminal transition gets an endedAt timestamp whenever the merged
        payload has none (including an explicit None).
        """
        updateData: Dict[str, Any] = {"status": toStatusValue}
        if additionalData:
            updateData.update(additionalData)
        updateData["status"] = toStatusValue
        if isTerminal and not updateData.get("endedAt"):
            updateData["endedAt"] = datetime.now(timezone.utc).isoformat()
        return updateData

    def transitionStatus(
        self,
        subscriptionId: str,
        expectedFromStatus: SubscriptionStatusEnum,
        toStatus: SubscriptionStatusEnum,
        additionalData: Dict[str, Any] = None,
    ) -> Dict[str, Any]:
        """Execute a guarded status transition.

        1. Load the record by ID
        2. Verify current status matches expectedFromStatus
        3. Verify the transition is allowed by the state machine
        4. Apply the update

        Args:
            subscriptionId: Primary key of the subscription.
            expectedFromStatus: Status the record must currently have.
            toStatus: Target status.
            additionalData: Extra fields written together with the status.
                A "status" key in here is ignored in favour of toStatus.

        Raises:
            ValueError: If getById() returns nothing (not found or lookup failed).
            InvalidTransitionError: If the current status differs from
                expectedFromStatus or the pair is not in ALLOWED_TRANSITIONS.
        """
        sub = self.getById(subscriptionId)
        if not sub:
            raise ValueError(f"Subscription {subscriptionId} not found")

        currentStatus = sub.get("status", "")
        if currentStatus != expectedFromStatus.value:
            raise InvalidTransitionError(subscriptionId, currentStatus, toStatus.value)

        if (expectedFromStatus, toStatus) not in ALLOWED_TRANSITIONS:
            raise InvalidTransitionError(subscriptionId, expectedFromStatus.value, toStatus.value)

        if additionalData and additionalData.get("status", toStatus.value) != toStatus.value:
            logger.warning(
                "transitionStatus(%s): ignoring conflicting status %r in additionalData, applying %s",
                subscriptionId, additionalData.get("status"), toStatus.value,
            )

        updateData = self._buildTransitionUpdate(toStatus.value, toStatus in TERMINAL_STATUSES, additionalData)

        # NOTE(review): the status check above and this write are separate
        # calls; nothing in this method makes the pair atomic.
        result = self.db.recordModify(MandateSubscription, subscriptionId, updateData)
        logger.info("Transition %s -> %s for subscription %s", expectedFromStatus.value, toStatus.value, subscriptionId)
        return result

    def forceExpire(self, subscriptionId: str) -> Dict[str, Any]:
        """Sysadmin force-expire: set status to EXPIRED and endedAt to now.

        Bypasses the normal transition guards (ALLOWED_TRANSITIONS is not consulted).

        Raises:
            ValueError: If the subscription is not found or is already EXPIRED.
        """
        sub = self.getById(subscriptionId)
        if not sub:
            raise ValueError(f"Subscription {subscriptionId} not found")

        currentStatus = sub.get("status", "")
        # NOTE(review): only EXPIRED is rejected here. If TERMINAL_STATUSES
        # contains other values, those records are still force-expired and
        # their endedAt is overwritten — confirm this is intended.
        if currentStatus == SubscriptionStatusEnum.EXPIRED.value:
            raise ValueError(f"Subscription {subscriptionId} is already EXPIRED")

        result = self.db.recordModify(MandateSubscription, subscriptionId, {
            "status": SubscriptionStatusEnum.EXPIRED.value,
            "endedAt": datetime.now(timezone.utc).isoformat(),
        })
        logger.info("Force-expired subscription %s (was %s)", subscriptionId, currentStatus)
        return result

    # =========================================================================
    # Gate: assertActive (read-only, for billing gate)
    # =========================================================================

    def assertActive(self, mandateId: str) -> SubscriptionStatusEnum:
        """Return effective status for billing decisions.

        Returns the operative subscription's status, or EXPIRED if none exists.
        This is the ONLY read-by-mandate operation used in the hot path.
        """
        sub = self.getOperativeForMandate(mandateId)
        if sub:
            return SubscriptionStatusEnum(sub["status"])
        return SubscriptionStatusEnum.EXPIRED

    # =========================================================================
    # Gate: assertCapacity
    # =========================================================================

    @staticmethod
    def _raiseCapacityExceeded(resourceType: str, currentCount: int, maxAllowed: int, message: Optional[str] = None):
        """Raise SubscriptionCapacityException; the class is imported only when needed."""
        from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import SubscriptionCapacityException

        details: Dict[str, Any] = {
            "resourceType": resourceType,
            "currentCount": currentCount,
            "maxAllowed": maxAllowed,
        }
        # Only forward a message when one was given, matching the call sites
        # that rely on the exception's own default.
        if message is not None:
            details["message"] = message
        raise SubscriptionCapacityException(**details)

    def assertCapacity(self, mandateId: str, resourceType: str, delta: int = 1) -> bool:
        """Check that adding delta units of a resource stays within the plan limit.

        Args:
            mandateId: Mandate whose operative subscription is checked.
            resourceType: "users", "featureInstances" or "dataVolumeMB".
            delta: Amount about to be added to the current usage.

        Returns:
            True when the addition fits. Also True when the plan key is not in
            the catalog, the plan has no limit (None) for the resource, or
            resourceType is not one of the three known values.

        Raises:
            SubscriptionCapacityException: If the mandate has no operative
                subscription, or current usage + delta exceeds the plan limit.
        """
        sub = self.getOperativeForMandate(mandateId)
        if not sub:
            self._raiseCapacityExceeded(
                resourceType, 0, 0, message="No active subscription for this mandate.",
            )

        plan = self.getPlan(sub.get("planKey", ""))
        if not plan:
            return True

        # Pick the limit and the matching usage counter; the plan attribute is
        # only read for the requested resource type.
        if resourceType == "users":
            cap, measureUsage = plan.maxUsers, self.countActiveUsers
        elif resourceType == "featureInstances":
            cap, measureUsage = plan.maxFeatureInstances, self.countActiveFeatureInstances
        elif resourceType == "dataVolumeMB":
            cap, measureUsage = plan.maxDataVolumeMB, self.getMandateDataVolumeMB
        else:
            return True

        if cap is None:
            return True

        # NOTE(review): the counters return 0 when their lookup fails, so a
        # failed lookup makes this check pass.
        current = measureUsage(mandateId)
        if current + delta > cap:
            self._raiseCapacityExceeded(resourceType, int(current), cap)
        return True

    def getMandateDataVolumeMB(self, mandateId: str) -> float:
        """Total indexed data volume for the mandate (MB), for billing and capacity checks."""
        return self._getMandateDataVolumeMB(mandateId)

    def _getMandateDataVolumeMB(self, mandateId: str) -> float:
        """Convert aggregateMandateRagTotalBytes(mandateId) from bytes to MB.

        Returns 0.0 when the import or the aggregation raises; the failure is
        logged instead of being dropped silently.
        """
        try:
            from modules.interfaces.interfaceDbKnowledge import aggregateMandateRagTotalBytes

            return aggregateMandateRagTotalBytes(mandateId) / (1024 * 1024)
        except Exception as e:
            logger.error("_getMandateDataVolumeMB(%s) failed: %s", mandateId, e)
            return 0.0

    def getDataVolumeWarning(self, mandateId: str) -> Optional[Dict[str, Any]]:
        """Report data-volume usage against the plan's maxDataVolumeMB.

        Returns:
            None when the mandate has no operative subscription, the plan is
            unknown, or the plan has no (or a zero) maxDataVolumeMB. Otherwise
            a dict with "usedMB", "limitMB", "percent" and "warning", where
            "warning" is True once usage reaches 80% of the limit.
        """
        sub = self.getOperativeForMandate(mandateId)
        if not sub:
            return None
        plan = self.getPlan(sub.get("planKey", ""))
        if not plan or not plan.maxDataVolumeMB:
            return None
        usedMB = self.getMandateDataVolumeMB(mandateId)
        limitMB = plan.maxDataVolumeMB
        percent = (usedMB / limitMB * 100) if limitMB > 0 else 0
        return {
            "usedMB": round(usedMB, 2),
            "limitMB": limitMB,
            "percent": round(percent, 1),
            "warning": percent >= 80,
        }

    # =========================================================================
    # Counting (cross-DB queries against poweron_app)
    # =========================================================================

    def countActiveUsers(self, mandateId: str) -> int:
        """Count the UserMandate records of the mandate in the app database.

        Returns 0 when the lookup raises (the error is logged).
        """
        try:
            appDb = _getAppDatabaseConnector()
            return len(appDb.getRecordset(UserMandate, recordFilter={"mandateId": mandateId}))
        except Exception as e:
            logger.error("countActiveUsers(%s) failed: %s", mandateId, e)
            return 0

    def countActiveFeatureInstances(self, mandateId: str) -> int:
        """Count the enabled FeatureInstance records of the mandate in the app database.

        Returns 0 when the lookup raises (the error is logged).
        """
        try:
            from modules.datamodels.datamodelFeatures import FeatureInstance

            appDb = _getAppDatabaseConnector()
            return len(appDb.getRecordset(FeatureInstance, recordFilter={"mandateId": mandateId, "enabled": True}))
        except Exception as e:
            logger.error("countActiveFeatureInstances(%s) failed: %s", mandateId, e)
            return 0

    # =========================================================================
    # Stripe quantity sync
    # =========================================================================

    def syncQuantityToStripe(self, subscriptionId: str, *, raiseOnError: bool = False) -> None:
        """Update Stripe subscription item quantities to match actual active counts.

        Takes subscriptionId, not mandateId. The users item gets the active
        user count; the instances item gets the number of enabled feature
        instances beyond the plan's includedModules.

        Args:
            subscriptionId: Primary key of the subscription to sync.
            raiseOnError: If True, propagate Stripe API errors instead of logging them.
                Use True for billing-critical paths (store activation).

        Raises:
            ValueError: With raiseOnError=True, when the subscription is missing
                or has no stripeSubscriptionId.
            Exception: With raiseOnError=True, whatever the sync itself raised
                (it is logged first, then re-raised).
        """
        sub = self.getById(subscriptionId)
        if not sub or not sub.get("stripeSubscriptionId"):
            if raiseOnError:
                raise ValueError(f"Subscription {subscriptionId} hat keine Stripe-Anbindung — Abrechnung nicht möglich.")
            return

        mandateId = sub["mandateId"]
        itemIdUsers = sub.get("stripeItemIdUsers")
        itemIdInstances = sub.get("stripeItemIdInstances")

        plan = self.getPlan(sub.get("planKey", ""))
        # An unknown plan or an unset includedModules both count as zero included modules.
        includedModules = (plan.includedModules or 0) if plan else 0

        try:
            from modules.shared.stripeClient import getStripeClient
            stripe = getStripeClient()

            # NOTE(review): both counters return 0 when their lookup fails, so
            # a failed app-DB read would be pushed to Stripe as quantity 0.
            activeUsers = self.countActiveUsers(mandateId)
            activeInstances = self.countActiveFeatureInstances(mandateId)
            billableModules = max(0, activeInstances - includedModules)

            if itemIdUsers:
                stripe.SubscriptionItem.modify(
                    itemIdUsers, quantity=max(activeUsers, 0), proration_behavior="create_prorations",
                )
            if itemIdInstances:
                stripe.SubscriptionItem.modify(
                    itemIdInstances, quantity=billableModules, proration_behavior="create_prorations",
                )

            logger.info("Stripe quantity synced for sub %s: users=%d, modules=%d (total=%d, included=%d)", subscriptionId, activeUsers, billableModules, activeInstances, includedModules)
        except Exception as e:
            logger.error("syncQuantityToStripe(%s) failed: %s", subscriptionId, e)
            if raiseOnError:
                raise
|