gateway/modules/interfaces/interfaceDbSubscription.py
2026-04-10 22:44:08 +02:00

402 lines
18 KiB
Python

# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
Interface for Subscription operations — ID-based, deterministic.
Every write operation takes an explicit subscriptionId.
No status-scan guessing. See wiki/concepts/Subscription-State-Machine.md.
"""
import logging
from typing import Dict, Any, List, Optional
from datetime import datetime, timezone
from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG
from modules.datamodels.datamodelUam import User
from modules.datamodels.datamodelMembership import UserMandate
from modules.datamodels.datamodelSubscription import (
SubscriptionPlan,
MandateSubscription,
SubscriptionStatusEnum,
BillingPeriodEnum,
ALLOWED_TRANSITIONS,
TERMINAL_STATUSES,
OPERATIVE_STATUSES,
BUILTIN_PLANS,
_getPlan,
_getSelectablePlans,
)
logger = logging.getLogger(__name__)
# Database name passed to DatabaseConnector for MandateSubscription records (see SubscriptionObjects.__init__).
SUBSCRIPTION_DATABASE: str = "poweron_billing"
# Process-wide cache of SubscriptionObjects keyed by "<userId>_<mandateId>" (filled by getInterface).
# NOTE(review): no eviction is visible in this file, so it grows with distinct user/mandate pairs.
_subscriptionInterfaces: Dict[str, "SubscriptionObjects"] = {}
class InvalidTransitionError(Exception):
    """State-machine violation: the requested subscription status change is not permitted.

    Attributes:
        subscriptionId: ID of the subscription whose transition was rejected.
        fromStatus: Status value the transition was attempted from.
        toStatus: Status value the caller tried to move to.
    """

    def __init__(self, subscriptionId: str, fromStatus: str, toStatus: str):
        message = f"Invalid transition {fromStatus} -> {toStatus} for subscription {subscriptionId}"
        super().__init__(message)
        self.subscriptionId, self.fromStatus, self.toStatus = subscriptionId, fromStatus, toStatus
def getInterface(currentUser: User, mandateId: str = None) -> "SubscriptionObjects":
    """Return the cached SubscriptionObjects for (user, mandate), creating it on first use.

    On a cache hit the instance's user context is refreshed via setUserContext()
    before it is returned.
    """
    cacheKey = f"{currentUser.id}_{mandateId}"
    cached = _subscriptionInterfaces.get(cacheKey)
    if cached is None:
        cached = SubscriptionObjects(currentUser, mandateId)
        _subscriptionInterfaces[cacheKey] = cached
    else:
        cached.setUserContext(currentUser, mandateId)
    return cached
def _getRootInterface() -> "SubscriptionObjects":
    """Build a fresh (uncached) SubscriptionObjects bound to the root user, without a mandate."""
    from modules.security.rootAccess import getRootUser

    rootUser = getRootUser()
    return SubscriptionObjects(rootUser, mandateId=None)
def _getAppDatabaseConnector() -> DatabaseConnector:
    """Create a connector to the application database (DB_DATABASE, default "poweron_app").

    Used for cross-database counts (user memberships, feature instances). Host,
    port and credentials are read from the same APP_CONFIG keys as the billing
    connector; only the database name differs.
    """
    settings = {
        "dbDatabase": APP_CONFIG.get("DB_DATABASE", "poweron_app"),
        "dbHost": APP_CONFIG.get("DB_HOST", "localhost"),
        "dbPort": int(APP_CONFIG.get("DB_PORT", "5432")),
        "dbUser": APP_CONFIG.get("DB_USER"),
        "dbPassword": APP_CONFIG.get("DB_PASSWORD_SECRET"),
    }
    return DatabaseConnector(**settings)
class SubscriptionObjects:
    """Interface for subscription operations: CRUD, gate checks, Stripe sync.

    All writes are ID-based. Regular status changes go through transitionStatus();
    the only exception is the sysadmin escape hatch forceExpire().
    """

    def __init__(self, currentUser: Optional[User] = None, mandateId: str = None):
        """Bind the interface to a user/mandate context and create the billing DB connector.

        Args:
            currentUser: Acting user; when None, ``userId`` is None as well.
            mandateId: Optional mandate context stored on the instance.
        """
        self.currentUser = currentUser
        self.userId = currentUser.id if currentUser else None
        self.mandateId = mandateId
        self.db = DatabaseConnector(
            dbDatabase=SUBSCRIPTION_DATABASE,
            dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
            dbPort=int(APP_CONFIG.get("DB_PORT", "5432")),
            dbUser=APP_CONFIG.get("DB_USER"),
            dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"),
        )

    def setUserContext(self, currentUser: User, mandateId: str = None):
        """Rebind this (possibly cached) instance to a new user/mandate context."""
        self.currentUser = currentUser
        self.userId = currentUser.id if currentUser else None
        self.mandateId = mandateId

    # =========================================================================
    # Internal helpers (pure, no I/O)
    # =========================================================================
    @staticmethod
    def _sortNewestFirst(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Sort subscription rows in place by ``startedAt`` descending and return them.

        Rows whose ``startedAt`` is missing or None sort last. The key compares
        "has a value" first, so None is never ordered against a real value. A plain
        ``r.get("startedAt", "")`` key raises TypeError for a stored None, and the
        list queries' error handling would then report an empty result. Rows with
        equal keys keep their relative order (list.sort is stable, even with reverse).
        """
        rows.sort(key=lambda r: (r.get("startedAt") is not None, r.get("startedAt")), reverse=True)
        return rows

    @classmethod
    def _filterAndSort(
        cls, rows: List[Dict[str, Any]], statusFilter: Optional[List[SubscriptionStatusEnum]],
    ) -> List[Dict[str, Any]]:
        """Keep rows whose status value is in ``statusFilter`` (if given), newest-first."""
        if statusFilter:
            filterValues = {s.value for s in statusFilter}
            rows = [r for r in rows if r.get("status") in filterValues]
        return cls._sortNewestFirst(rows)

    @staticmethod
    def _buildTransitionUpdate(
        toStatusValue: str, additionalData: Optional[Dict[str, Any]], isTerminal: bool,
    ) -> Dict[str, Any]:
        """Build the recordModify payload for a guarded status transition.

        Caller-supplied fields are copied (the input dict is not mutated). The
        target status is written last, so additionalData cannot smuggle a
        different status past the state-machine check. Terminal transitions
        always get an ``endedAt`` (UTC ISO-8601) unless the caller supplied a
        truthy one.

        Raises:
            ValueError: additionalData carries a "status" different from the target.
        """
        updateData = dict(additionalData or {})
        if "status" in updateData:
            suppliedStatus = updateData["status"]
            # Accept either an enum member or its raw value, as long as it matches.
            if getattr(suppliedStatus, "value", suppliedStatus) != toStatusValue:
                raise ValueError("additionalData must not set a different status — use toStatus")
        updateData["status"] = toStatusValue
        if isTerminal and not updateData.get("endedAt"):
            updateData["endedAt"] = datetime.now(timezone.utc).isoformat()
        return updateData

    @staticmethod
    def _volumeUsage(usedMB: float, limitMB: float, warnPercent: float = 80) -> Dict[str, Any]:
        """Summarize data-volume usage against a limit; ``warning`` is True at >= warnPercent."""
        percent = (usedMB / limitMB * 100) if limitMB > 0 else 0
        return {
            "usedMB": round(usedMB, 2),
            "limitMB": limitMB,
            "percent": round(percent, 1),
            "warning": percent >= warnPercent,
        }

    @staticmethod
    def _capacityError(**kwargs: Any) -> Exception:
        """Construct a SubscriptionCapacityException with the given fields (caller raises it).

        The exception class is imported lazily, only when a capacity error actually
        occurs. Presumably this avoids an import cycle with the service layer (confirm).
        """
        from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import SubscriptionCapacityException
        return SubscriptionCapacityException(**kwargs)

    # =========================================================================
    # Plan catalog (in-memory)
    # =========================================================================
    def getPlan(self, planKey: str) -> Optional[SubscriptionPlan]:
        """Look up a plan by key in the in-memory catalog (may return None)."""
        return _getPlan(planKey)

    def getSelectablePlans(self) -> List[SubscriptionPlan]:
        """Return the plans offered for selection, from the in-memory catalog."""
        return _getSelectablePlans()

    # =========================================================================
    # Read: by ID (primary access pattern)
    # =========================================================================
    def getById(self, subscriptionId: str) -> Optional[Dict[str, Any]]:
        """Load a single subscription by its primary key.

        Returns None when the ID is empty, no record matches, or the query fails
        (failures are logged).
        """
        if not subscriptionId:
            return None
        try:
            results = self.db.getRecordset(MandateSubscription, recordFilter={"id": subscriptionId})
            return dict(results[0]) if results else None
        except Exception as e:
            logger.error("getById(%s) failed: %s", subscriptionId, e)
            return None

    def getByStripeSubscriptionId(self, stripeSubId: str) -> Optional[Dict[str, Any]]:
        """Load a subscription by Stripe subscription ID — the webhook resolution path.

        An empty/None ID returns None without querying. Filtering on a null
        stripeSubscriptionId could otherwise match an arbitrary subscription that
        simply has no Stripe link yet. Query failures are logged and yield None.
        """
        if not stripeSubId:
            return None
        try:
            results = self.db.getRecordset(MandateSubscription, recordFilter={"stripeSubscriptionId": stripeSubId})
            return dict(results[0]) if results else None
        except Exception as e:
            logger.error("getByStripeSubscriptionId(%s) failed: %s", stripeSubId, e)
            return None

    # =========================================================================
    # Read: by mandate (list queries)
    # =========================================================================
    def listForMandate(self, mandateId: str, statusFilter: List[SubscriptionStatusEnum] = None) -> List[Dict[str, Any]]:
        """Return all subscriptions for a mandate, optionally filtered by status.

        Sorted newest-first by startedAt (rows without startedAt last). Returns []
        on failure (logged).
        """
        try:
            results = self.db.getRecordset(MandateSubscription, recordFilter={"mandateId": mandateId})
            return self._filterAndSort([dict(r) for r in results], statusFilter)
        except Exception as e:
            logger.error("listForMandate(%s) failed: %s", mandateId, e)
            return []

    def getOperativeForMandate(self, mandateId: str) -> Optional[Dict[str, Any]]:
        """Return the operative subscription (status in OPERATIVE_STATUSES — ACTIVE, TRIALING, or PAST_DUE).

        This is a read-only query for the billing gate. If several operative rows
        exist, the newest by startedAt wins. Returns None if none exists.
        """
        operativeValues = {s.value for s in OPERATIVE_STATUSES}
        return next(
            (row for row in self.listForMandate(mandateId) if row.get("status") in operativeValues),
            None,
        )

    def getScheduledForMandate(self, mandateId: str) -> Optional[Dict[str, Any]]:
        """Return the newest SCHEDULED subscription (next sub waiting to start), or None."""
        rows = self.listForMandate(mandateId, [SubscriptionStatusEnum.SCHEDULED])
        return rows[0] if rows else None

    # =========================================================================
    # Read: global (SysAdmin)
    # =========================================================================
    def listAll(self, statusFilter: List[SubscriptionStatusEnum] = None) -> List[Dict[str, Any]]:
        """Return ALL subscriptions across all mandates, newest-first. SysAdmin use only.

        Returns [] on failure (logged).
        """
        try:
            results = self.db.getRecordset(MandateSubscription)
            return self._filterAndSort([dict(r) for r in results], statusFilter)
        except Exception as e:
            logger.error("listAll() failed: %s", e)
            return []

    # =========================================================================
    # Write: create
    # =========================================================================
    def createSubscription(self, sub: MandateSubscription) -> Dict[str, Any]:
        """Persist a new MandateSubscription record and return the connector's result."""
        return self.db.recordCreate(MandateSubscription, sub.model_dump())

    # =========================================================================
    # Write: update fields (no status change)
    # =========================================================================
    def updateFields(self, subscriptionId: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Update non-status fields on a subscription (e.g. recurring, Stripe IDs, periods).

        Must NOT be used for status changes — use transitionStatus() for that.

        Raises:
            ValueError: ``data`` contains a "status" key.
        """
        if "status" in data:
            raise ValueError("updateFields must not change status — use transitionStatus()")
        return self.db.recordModify(MandateSubscription, subscriptionId, data)

    # =========================================================================
    # Write: status transition (guarded)
    # =========================================================================
    def transitionStatus(
        self,
        subscriptionId: str,
        expectedFromStatus: SubscriptionStatusEnum,
        toStatus: SubscriptionStatusEnum,
        additionalData: Dict[str, Any] = None,
    ) -> Dict[str, Any]:
        """Execute a guarded status transition.

        1. Load the record by ID.
        2. Verify its current status equals ``expectedFromStatus``.
        3. Verify (expectedFromStatus, toStatus) is in ALLOWED_TRANSITIONS.
        4. Write the new status plus ``additionalData``; terminal targets get ``endedAt``.

        Args:
            subscriptionId: Primary key of the subscription.
            expectedFromStatus: Status the record must currently have.
            toStatus: Target status.
            additionalData: Extra fields written in the same update. A "status"
                key is only tolerated if it equals ``toStatus``.

        Returns:
            The result of ``recordModify``.

        Raises:
            ValueError: Subscription not found, or additionalData has a conflicting status.
            InvalidTransitionError: Current status mismatch or transition not allowed.
        """
        sub = self.getById(subscriptionId)
        if not sub:
            raise ValueError(f"Subscription {subscriptionId} not found")
        currentStatus = sub.get("status", "")
        if currentStatus != expectedFromStatus.value:
            raise InvalidTransitionError(subscriptionId, currentStatus, toStatus.value)
        if (expectedFromStatus, toStatus) not in ALLOWED_TRANSITIONS:
            raise InvalidTransitionError(subscriptionId, expectedFromStatus.value, toStatus.value)
        updateData = self._buildTransitionUpdate(toStatus.value, additionalData, toStatus in TERMINAL_STATUSES)
        # NOTE(review): check-then-write is not atomic; a concurrent writer between
        # getById() and recordModify() is not detected here.
        result = self.db.recordModify(MandateSubscription, subscriptionId, updateData)
        logger.info("Transition %s -> %s for subscription %s", expectedFromStatus.value, toStatus.value, subscriptionId)
        return result

    def forceExpire(self, subscriptionId: str) -> Dict[str, Any]:
        """Sysadmin force-expire: ANY non-terminal -> EXPIRED. Bypasses ALLOWED_TRANSITIONS.

        Raises:
            ValueError: Subscription not found, already EXPIRED, or in another
                terminal status (re-expiring it would overwrite its original endedAt).
        """
        sub = self.getById(subscriptionId)
        if not sub:
            raise ValueError(f"Subscription {subscriptionId} not found")
        currentStatus = sub.get("status", "")
        if currentStatus == SubscriptionStatusEnum.EXPIRED.value:
            raise ValueError(f"Subscription {subscriptionId} is already EXPIRED")
        if currentStatus in {s.value for s in TERMINAL_STATUSES}:
            raise ValueError(f"Subscription {subscriptionId} is already in terminal status {currentStatus}")
        result = self.db.recordModify(MandateSubscription, subscriptionId, {
            "status": SubscriptionStatusEnum.EXPIRED.value,
            "endedAt": datetime.now(timezone.utc).isoformat(),
        })
        logger.info("Force-expired subscription %s (was %s)", subscriptionId, currentStatus)
        return result

    # =========================================================================
    # Gate: assertActive (read-only, for billing gate)
    # =========================================================================
    def assertActive(self, mandateId: str) -> SubscriptionStatusEnum:
        """Return effective status for billing decisions.

        Returns the operative subscription's status, or EXPIRED if none exists.
        This is the ONLY read-by-mandate operation used in the hot path.
        """
        sub = self.getOperativeForMandate(mandateId)
        if sub:
            return SubscriptionStatusEnum(sub["status"])
        return SubscriptionStatusEnum.EXPIRED

    # =========================================================================
    # Gate: assertCapacity
    # =========================================================================
    def assertCapacity(self, mandateId: str, resourceType: str, delta: int = 1) -> bool:
        """Ensure adding ``delta`` units of ``resourceType`` stays within the plan's cap.

        Supported resourceType values: "users", "featureInstances", "dataVolumeMB".
        Unknown resource types, unknown plans, and uncapped (None) limits pass.

        Returns:
            True when the change fits.

        Raises:
            SubscriptionCapacityException: No operative subscription exists, or
                current usage + delta would exceed the cap.
        """
        sub = self.getOperativeForMandate(mandateId)
        if not sub:
            raise self._capacityError(
                resourceType=resourceType, currentCount=0, maxAllowed=0,
                message="No active subscription for this mandate.",
            )
        plan = self.getPlan(sub.get("planKey", ""))
        if not plan:
            return True
        # resourceType -> (plan attribute holding the cap, current-usage getter)
        limits = {
            "users": ("maxUsers", self.countActiveUsers),
            "featureInstances": ("maxFeatureInstances", self.countActiveFeatureInstances),
            "dataVolumeMB": ("maxDataVolumeMB", self.getMandateDataVolumeMB),
        }
        if resourceType not in limits:
            return True
        capAttr, getUsage = limits[resourceType]
        cap = getattr(plan, capAttr)
        if cap is None:
            return True
        current = getUsage(mandateId)
        if current + delta > cap:
            # int() only changes anything for dataVolumeMB, whose usage is a float.
            raise self._capacityError(resourceType=resourceType, currentCount=int(current), maxAllowed=cap)
        return True

    def getMandateDataVolumeMB(self, mandateId: str) -> float:
        """Total indexed data volume for the mandate (MB), for billing and capacity checks."""
        return self._getMandateDataVolumeMB(mandateId)

    def _getMandateDataVolumeMB(self, mandateId: str) -> float:
        """Sum RAG index size (FileContentIndex.totalSize) for the mandate; reads poweron_knowledge.

        Best-effort: any failure is logged and reported as 0.0 MB.
        """
        try:
            from modules.interfaces.interfaceDbKnowledge import aggregateMandateRagTotalBytes
            return aggregateMandateRagTotalBytes(mandateId) / (1024 * 1024)
        except Exception as e:
            logger.warning("_getMandateDataVolumeMB(%s) failed: %s", mandateId, e)
            return 0.0

    def getDataVolumeWarning(self, mandateId: str) -> Optional[Dict[str, Any]]:
        """Report the mandate's data-volume usage against its plan limit.

        Returns None when there is no operative subscription, the plan is unknown,
        or the plan has no (or a zero) maxDataVolumeMB. Otherwise returns
        ``{"usedMB", "limitMB", "percent", "warning"}`` where ``warning`` is True
        at >= 80% usage.
        """
        sub = self.getOperativeForMandate(mandateId)
        if not sub:
            return None
        plan = self.getPlan(sub.get("planKey", ""))
        if not plan or not plan.maxDataVolumeMB:
            return None
        return self._volumeUsage(self.getMandateDataVolumeMB(mandateId), plan.maxDataVolumeMB)

    # =========================================================================
    # Counting (cross-DB queries against poweron_app)
    # =========================================================================
    def _queryActiveUserCount(self, mandateId: str) -> int:
        """Count UserMandate rows for the mandate in the app database. Does not catch errors."""
        appDb = _getAppDatabaseConnector()
        return len(appDb.getRecordset(UserMandate, recordFilter={"mandateId": mandateId}))

    def _queryActiveFeatureInstanceCount(self, mandateId: str) -> int:
        """Count enabled FeatureInstance rows for the mandate in the app database. Does not catch errors."""
        from modules.datamodels.datamodelFeatures import FeatureInstance
        appDb = _getAppDatabaseConnector()
        return len(appDb.getRecordset(FeatureInstance, recordFilter={"mandateId": mandateId, "enabled": True}))

    def countActiveUsers(self, mandateId: str) -> int:
        """Best-effort user-membership count for the mandate; returns 0 (and logs) on failure."""
        try:
            return self._queryActiveUserCount(mandateId)
        except Exception as e:
            logger.error("countActiveUsers(%s) failed: %s", mandateId, e)
            return 0

    def countActiveFeatureInstances(self, mandateId: str) -> int:
        """Best-effort enabled-feature-instance count; returns 0 (and logs) on failure."""
        try:
            return self._queryActiveFeatureInstanceCount(mandateId)
        except Exception as e:
            logger.error("countActiveFeatureInstances(%s) failed: %s", mandateId, e)
            return 0

    # =========================================================================
    # Stripe quantity sync
    # =========================================================================
    def syncQuantityToStripe(self, subscriptionId: str, *, raiseOnError: bool = False) -> None:
        """Update Stripe subscription item quantities to match actual active counts.

        Takes subscriptionId, not mandateId. The users item gets the mandate's
        UserMandate count. The instances item gets the enabled feature instances
        beyond the plan's ``includedModules``.

        Counts come from the non-catching query helpers, not the best-effort
        count* methods. Those return 0 on a DB error, which would push quantity 0
        to Stripe (with prorations) instead of skipping the sync.

        Args:
            subscriptionId: ID of the subscription to sync.
            raiseOnError: If True, propagate errors (missing Stripe link, DB or
                Stripe API failures) instead of only logging them. Use True for
                billing-critical paths (store activation).

        Raises:
            ValueError: raiseOnError is True and the subscription is missing or
                has no stripeSubscriptionId.
        """
        sub = self.getById(subscriptionId)
        if not sub or not sub.get("stripeSubscriptionId"):
            if raiseOnError:
                raise ValueError(f"Subscription {subscriptionId} hat keine Stripe-Anbindung — Abrechnung nicht möglich.")
            return
        mandateId = sub["mandateId"]
        itemIdUsers = sub.get("stripeItemIdUsers")
        itemIdInstances = sub.get("stripeItemIdInstances")
        plan = self.getPlan(sub.get("planKey", ""))
        includedModules = plan.includedModules if plan else 0
        try:
            from modules.shared.stripeClient import getStripeClient
            stripe = getStripeClient()
            activeUsers = self._queryActiveUserCount(mandateId)
            activeInstances = self._queryActiveFeatureInstanceCount(mandateId)
            billableModules = max(0, activeInstances - includedModules)
            if itemIdUsers:
                stripe.SubscriptionItem.modify(
                    itemIdUsers, quantity=activeUsers, proration_behavior="create_prorations",
                )
            if itemIdInstances:
                stripe.SubscriptionItem.modify(
                    itemIdInstances, quantity=billableModules, proration_behavior="create_prorations",
                )
            logger.info("Stripe quantity synced for sub %s: users=%d, modules=%d (total=%d, included=%d)", subscriptionId, activeUsers, billableModules, activeInstances, includedModules)
        except Exception as e:
            logger.error("syncQuantityToStripe(%s) failed: %s", subscriptionId, e)
            if raiseOnError:
                raise