From c813bd63ca8e9c425dc0d6aae30f177791eeed28 Mon Sep 17 00:00:00 2001
From: ValueOn AG
Date: Sun, 22 Mar 2026 17:23:54 +0100
Subject: [PATCH 1/5] subscription base logic
---
app.py | 3 +
config.ini | 5 +
modules/datamodels/datamodelBilling.py | 9 +-
modules/datamodels/datamodelSubscription.py | 235 ++++++
.../workspace/routeFeatureWorkspace.py | 13 +-
modules/interfaces/interfaceBootstrap.py | 48 ++
modules/interfaces/interfaceDbApp.py | 32 +-
modules/interfaces/interfaceDbSubscription.py | 353 +++++++++
modules/routes/routeAdminFeatures.py | 29 +-
modules/routes/routeBilling.py | 313 +++++++-
modules/routes/routeSubscription.py | 364 +++++++++
modules/serviceCenter/registry.py | 9 +-
.../services/serviceAgent/agentLoop.py | 15 +
.../services/serviceAi/mainServiceAi.py | 23 +-
.../serviceBilling/billingExhaustedNotify.py | 102 +--
.../serviceBilling/mainServiceBilling.py | 35 +
.../services/serviceBilling/stripeCheckout.py | 13 +-
.../services/serviceSubscription/__init__.py | 0
.../mainServiceSubscription.py | 710 ++++++++++++++++++
.../serviceSubscription/stripeBootstrap.py | 214 ++++++
modules/shared/configuration.py | 2 +-
modules/shared/notifyMandateAdmins.py | 285 +++++++
modules/shared/stripeClient.py | 38 +
modules/system/mainSystem.py | 1 -
24 files changed, 2722 insertions(+), 129 deletions(-)
create mode 100644 modules/datamodels/datamodelSubscription.py
create mode 100644 modules/interfaces/interfaceDbSubscription.py
create mode 100644 modules/routes/routeSubscription.py
create mode 100644 modules/serviceCenter/services/serviceSubscription/__init__.py
create mode 100644 modules/serviceCenter/services/serviceSubscription/mainServiceSubscription.py
create mode 100644 modules/serviceCenter/services/serviceSubscription/stripeBootstrap.py
create mode 100644 modules/shared/notifyMandateAdmins.py
create mode 100644 modules/shared/stripeClient.py
diff --git a/app.py b/app.py
index 0c769a2a..c1400353 100644
--- a/app.py
+++ b/app.py
@@ -601,6 +601,9 @@ app.include_router(gdprRouter)
from modules.routes.routeBilling import router as billingRouter
app.include_router(billingRouter)
+from modules.routes.routeSubscription import router as subscriptionRouter
+app.include_router(subscriptionRouter)
+
# ============================================================================
# SYSTEM ROUTES (Navigation, etc.)
# ============================================================================
diff --git a/config.ini b/config.ini
index ccd6b77e..4a37f2f8 100644
--- a/config.ini
+++ b/config.ini
@@ -44,3 +44,8 @@ Connector_StacSwisstopo_TIMEOUT = 30
Connector_StacSwisstopo_MAX_RETRIES = 3
Connector_StacSwisstopo_RETRY_DELAY = 1.0
Connector_StacSwisstopo_ENABLE_CACHE = True
+
+# Operator company information (shown on invoice emails)
+Operator_CompanyName = PowerOn AG
+Operator_Address = Birmensdorferstrasse 94, 8003 Zürich
+Operator_VatNumber = CHE491.960.195
diff --git a/modules/datamodels/datamodelBilling.py b/modules/datamodels/datamodelBilling.py
index 8ffbdef1..995ac75d 100644
--- a/modules/datamodels/datamodelBilling.py
+++ b/modules/datamodels/datamodelBilling.py
@@ -143,6 +143,9 @@ class BillingSettings(BaseModel):
)
warningThresholdPercent: float = Field(default=10.0, description="Warning threshold as percentage")
+ # Stripe
+ stripeCustomerId: Optional[str] = Field(None, description="Stripe Customer ID (cus_xxx) — one per mandate")
+
# Notifications (e.g. mandate owner / finance — also used when PREPAY_MANDATE pool is exhausted)
notifyEmails: List[str] = Field(
default_factory=list,
@@ -163,6 +166,7 @@ registerModelLabels(
"de": "Startguthaben nur Root-Mandant (CHF)",
},
"warningThresholdPercent": {"en": "Warning Threshold (%)", "de": "Warnschwelle (%)"},
+ "stripeCustomerId": {"en": "Stripe Customer ID", "de": "Stripe-Kunden-ID"},
"notifyEmails": {
"en": "Billing notification emails (owner / admin)",
"de": "E-Mails für Billing-Alerts (Inhaber/Admin)",
@@ -260,12 +264,15 @@ class BillingStatisticsResponse(BaseModel):
class BillingCheckResult(BaseModel):
- """Result of a billing balance check."""
+ """Result of a billing balance check (budget + subscription gate)."""
allowed: bool
reason: Optional[str] = None
currentBalance: Optional[float] = None
requiredAmount: Optional[float] = None
billingModel: Optional[BillingModelEnum] = None
+ upgradeRequired: Optional[bool] = None
+ subscriptionUiPath: Optional[str] = None
+ userAction: Optional[str] = None
def parseBillingModelFromStoredValue(raw: Optional[str]) -> BillingModelEnum:
diff --git a/modules/datamodels/datamodelSubscription.py b/modules/datamodels/datamodelSubscription.py
new file mode 100644
index 00000000..1c1435d8
--- /dev/null
+++ b/modules/datamodels/datamodelSubscription.py
@@ -0,0 +1,235 @@
+# Copyright (c) 2025 Patrick Motsch
+# All rights reserved.
+"""Subscription models: SubscriptionPlan (catalog), MandateSubscription (instance per mandate),
+StripePlanPrice (persisted Stripe IDs per plan).
+
+State Machine: see wiki/concepts/Subscription-State-Machine.md
+"""
+
+from typing import Dict, List, Optional
+from enum import Enum
+from datetime import datetime, timezone
+from pydantic import BaseModel, Field
+from modules.shared.attributeUtils import registerModelLabels
+import uuid
+
+
+class SubscriptionStatusEnum(str, Enum):
+ """Lifecycle status of a mandate subscription.
+ See wiki/concepts/Subscription-State-Machine.md for transition rules."""
+ PENDING = "PENDING"
+ SCHEDULED = "SCHEDULED"
+ TRIALING = "TRIALING"
+ ACTIVE = "ACTIVE"
+ PAST_DUE = "PAST_DUE"
+ EXPIRED = "EXPIRED"
+
+
+TERMINAL_STATUSES = {SubscriptionStatusEnum.EXPIRED}
+OPERATIVE_STATUSES = {SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.TRIALING, SubscriptionStatusEnum.PAST_DUE}
+
+ALLOWED_TRANSITIONS = {
+ (SubscriptionStatusEnum.PENDING, SubscriptionStatusEnum.ACTIVE),
+ (SubscriptionStatusEnum.PENDING, SubscriptionStatusEnum.SCHEDULED),
+ (SubscriptionStatusEnum.PENDING, SubscriptionStatusEnum.EXPIRED),
+ (SubscriptionStatusEnum.SCHEDULED, SubscriptionStatusEnum.ACTIVE),
+ (SubscriptionStatusEnum.SCHEDULED, SubscriptionStatusEnum.EXPIRED),
+ (SubscriptionStatusEnum.TRIALING, SubscriptionStatusEnum.EXPIRED),
+ (SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.PAST_DUE),
+ (SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.EXPIRED),
+ (SubscriptionStatusEnum.PAST_DUE, SubscriptionStatusEnum.ACTIVE),
+ (SubscriptionStatusEnum.PAST_DUE, SubscriptionStatusEnum.EXPIRED),
+}
+
+
+class BillingPeriodEnum(str, Enum):
+ """Recurring billing interval."""
+ MONTHLY = "MONTHLY"
+ YEARLY = "YEARLY"
+ NONE = "NONE"
+
+
+# ============================================================================
+# Catalog: SubscriptionPlan (static, in-memory)
+# ============================================================================
+
+class SubscriptionPlan(BaseModel):
+ """Plan definition (catalog entry). Not stored per mandate — static."""
+ planKey: str = Field(..., description="Unique plan identifier")
+ selectableByUser: bool = Field(default=True, description="Whether users can choose this plan in the UI")
+
+ title: Dict[str, str] = Field(default_factory=dict, description="Multilingual title (en/de/fr)")
+ description: Dict[str, str] = Field(default_factory=dict, description="Multilingual description")
+
+ currency: str = Field(default="CHF", description="Billing currency")
+ billingPeriod: BillingPeriodEnum = Field(default=BillingPeriodEnum.MONTHLY, description="Recurring interval")
+ pricePerUserCHF: float = Field(default=0.0, description="Price per active user per period")
+ pricePerFeatureInstanceCHF: float = Field(default=0.0, description="Price per active feature instance per period")
+ autoRenew: bool = Field(default=True, description="Stripe renews automatically at period end")
+
+ maxUsers: Optional[int] = Field(None, description="Hard cap on active users (None = unlimited)")
+ maxFeatureInstances: Optional[int] = Field(None, description="Hard cap on active feature instances (None = unlimited)")
+ trialDays: Optional[int] = Field(None, description="Trial duration in days (only for trial plans)")
+ successorPlanKey: Optional[str] = Field(None, description="Plan to transition to when trial ends")
+
+
+registerModelLabels(
+ "SubscriptionPlan",
+ {"en": "Subscription Plan", "de": "Abonnement-Plan", "fr": "Plan d'abonnement"},
+ {
+ "planKey": {"en": "Plan", "de": "Plan", "fr": "Plan"},
+ "selectableByUser": {"en": "Selectable", "de": "Wählbar", "fr": "Sélectionnable"},
+ "billingPeriod": {"en": "Billing Period", "de": "Abrechnungszeitraum", "fr": "Période de facturation"},
+ "pricePerUserCHF": {"en": "Price per User (CHF)", "de": "Preis pro User (CHF)"},
+ "pricePerFeatureInstanceCHF": {"en": "Price per Instance (CHF)", "de": "Preis pro Instanz (CHF)"},
+ "maxUsers": {"en": "Max Users", "de": "Max. Benutzer", "fr": "Max. utilisateurs"},
+ "maxFeatureInstances": {"en": "Max Instances", "de": "Max. Instanzen", "fr": "Max. instances"},
+ },
+)
+
+
+# ============================================================================
+# Stripe Price mapping (persisted in DB, auto-created at bootstrap)
+# ============================================================================
+
+class StripePlanPrice(BaseModel):
+ """Persisted mapping from planKey to Stripe Product/Price IDs.
+ Auto-created at startup — no manual configuration needed.
+ Uses separate Stripe Products for users and instances for clear invoice labels."""
+ id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
+ planKey: str = Field(..., description="Reference to SubscriptionPlan.planKey")
+ stripeProductId: str = Field("", description="Legacy single-product ID (unused)")
+ stripeProductIdUsers: Optional[str] = Field(None, description="Stripe Product ID for user licenses")
+ stripeProductIdInstances: Optional[str] = Field(None, description="Stripe Product ID for feature instances")
+ stripePriceIdUsers: Optional[str] = Field(None, description="Stripe Price ID for user-seat line item")
+ stripePriceIdInstances: Optional[str] = Field(None, description="Stripe Price ID for instance line item")
+
+
+registerModelLabels(
+ "StripePlanPrice",
+ {"en": "Stripe Plan Prices", "de": "Stripe-Planpreise"},
+ {
+ "planKey": {"en": "Plan", "de": "Plan"},
+ "stripeProductIdUsers": {"en": "Product (Users)", "de": "Produkt (User)"},
+ "stripeProductIdInstances": {"en": "Product (Instances)", "de": "Produkt (Instanzen)"},
+ "stripePriceIdUsers": {"en": "Price ID (Users)", "de": "Preis-ID (User)"},
+ "stripePriceIdInstances": {"en": "Price ID (Instances)", "de": "Preis-ID (Instanzen)"},
+ },
+)
+
+
+# ============================================================================
+# Instance: MandateSubscription
+# ============================================================================
+
+class MandateSubscription(BaseModel):
+ """A subscription instance bound to a specific mandate.
+ See wiki/concepts/Subscription-State-Machine.md for state transitions."""
+ id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
+ mandateId: str = Field(..., description="Foreign key to Mandate")
+ planKey: str = Field(..., description="Reference to SubscriptionPlan.planKey")
+
+ status: SubscriptionStatusEnum = Field(default=SubscriptionStatusEnum.PENDING, description="Current lifecycle status")
+ recurring: bool = Field(default=True, description="True: auto-renews at period end. False: expires at period end (gekuendigt).")
+
+ startedAt: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), description="Record creation timestamp")
+ effectiveFrom: Optional[datetime] = Field(None, description="When this subscription becomes operative. None = immediate. Set for SCHEDULED subs.")
+ endedAt: Optional[datetime] = Field(None, description="When subscription ended (terminal)")
+ currentPeriodStart: Optional[datetime] = Field(None, description="Current billing period start (synced from Stripe)")
+ currentPeriodEnd: Optional[datetime] = Field(None, description="Current billing period end (synced from Stripe)")
+ trialEndsAt: Optional[datetime] = Field(None, description="Trial expiry timestamp")
+
+ snapshotPricePerUserCHF: float = Field(default=0.0, description="Price snapshot at activation (for invoice history)")
+ snapshotPricePerInstanceCHF: float = Field(default=0.0, description="Price snapshot at activation")
+
+ stripeSubscriptionId: Optional[str] = Field(None, description="Stripe Subscription ID (sub_xxx)")
+ stripeItemIdUsers: Optional[str] = Field(None, description="Stripe Subscription Item ID for user seats")
+ stripeItemIdInstances: Optional[str] = Field(None, description="Stripe Subscription Item ID for feature instances")
+
+
+registerModelLabels(
+ "MandateSubscription",
+ {"en": "Mandate Subscription", "de": "Mandanten-Abonnement", "fr": "Abonnement du mandat"},
+ {
+ "id": {"en": "ID", "de": "ID"},
+ "mandateId": {"en": "Mandate ID", "de": "Mandanten-ID"},
+ "planKey": {"en": "Plan", "de": "Plan"},
+ "status": {"en": "Status", "de": "Status"},
+ "recurring": {"en": "Recurring", "de": "Wiederkehrend"},
+ "startedAt": {"en": "Started", "de": "Gestartet"},
+ "effectiveFrom": {"en": "Effective From", "de": "Wirksam ab"},
+ "endedAt": {"en": "Ended", "de": "Beendet"},
+ "currentPeriodStart": {"en": "Period Start", "de": "Periodenbeginn"},
+ "currentPeriodEnd": {"en": "Period End", "de": "Periodenende"},
+ "trialEndsAt": {"en": "Trial Ends", "de": "Trial endet"},
+ "snapshotPricePerUserCHF": {"en": "Price/User (CHF)", "de": "Preis/User (CHF)"},
+ "snapshotPricePerInstanceCHF": {"en": "Price/Instance (CHF)", "de": "Preis/Instanz (CHF)"},
+ },
+)
+
+
+# ============================================================================
+# Built-in plan catalog (static, no env dependency)
+# ============================================================================
+
+BUILTIN_PLANS: Dict[str, SubscriptionPlan] = {
+ "ROOT": SubscriptionPlan(
+ planKey="ROOT",
+ selectableByUser=False,
+ title={"en": "Root (System)", "de": "Root (System)", "fr": "Root (Système)"},
+ description={"en": "Internal system plan — no billing.", "de": "Interner Systemplan — keine Verrechnung."},
+ billingPeriod=BillingPeriodEnum.NONE,
+ autoRenew=False,
+ maxUsers=None,
+ maxFeatureInstances=None,
+ ),
+ "TRIAL_7D": SubscriptionPlan(
+ planKey="TRIAL_7D",
+ selectableByUser=False,
+ title={"en": "Free Trial (7 days)", "de": "Gratis-Testphase (7 Tage)", "fr": "Essai gratuit (7 jours)"},
+ description={
+ "en": "Try the platform for 7 days — 1 user, up to 3 feature instances.",
+ "de": "Plattform 7 Tage testen — 1 User, bis zu 3 Feature-Instanzen.",
+ },
+ billingPeriod=BillingPeriodEnum.NONE,
+ autoRenew=False,
+ maxUsers=1,
+ maxFeatureInstances=3,
+ trialDays=7,
+ successorPlanKey="STANDARD_MONTHLY",
+ ),
+ "STANDARD_MONTHLY": SubscriptionPlan(
+ planKey="STANDARD_MONTHLY",
+ selectableByUser=True,
+ title={"en": "Standard (Monthly)", "de": "Standard (Monatlich)", "fr": "Standard (Mensuel)"},
+ description={
+ "en": "Usage-based billing per active user and feature instance, billed monthly.",
+ "de": "Nutzungsbasierte Abrechnung pro aktivem User und Feature-Instanz, monatlich.",
+ },
+ billingPeriod=BillingPeriodEnum.MONTHLY,
+ pricePerUserCHF=90.0,
+ pricePerFeatureInstanceCHF=150.0,
+ ),
+ "STANDARD_YEARLY": SubscriptionPlan(
+ planKey="STANDARD_YEARLY",
+ selectableByUser=True,
+ title={"en": "Standard (Yearly)", "de": "Standard (Jährlich)", "fr": "Standard (Annuel)"},
+ description={
+ "en": "Usage-based billing per active user and feature instance, billed yearly.",
+ "de": "Nutzungsbasierte Abrechnung pro aktivem User und Feature-Instanz, jährlich.",
+ },
+ billingPeriod=BillingPeriodEnum.YEARLY,
+ pricePerUserCHF=1080.0,
+ pricePerFeatureInstanceCHF=1800.0,
+ ),
+}
+
+
+def _getPlan(planKey: str) -> Optional[SubscriptionPlan]:
+ """Resolve a plan by key from the built-in catalog."""
+ return BUILTIN_PLANS.get(planKey)
+
+
+def _getSelectablePlans() -> List[SubscriptionPlan]:
+ """Return plans that users can choose in the UI."""
+ return [p for p in BUILTIN_PLANS.values() if p.selectableByUser]
diff --git a/modules/features/workspace/routeFeatureWorkspace.py b/modules/features/workspace/routeFeatureWorkspace.py
index d0dd22da..6dc7774e 100644
--- a/modules/features/workspace/routeFeatureWorkspace.py
+++ b/modules/features/workspace/routeFeatureWorkspace.py
@@ -19,6 +19,9 @@ from modules.auth import limiter, getRequestContext, RequestContext
from modules.serviceCenter.services.serviceBilling.mainServiceBilling import (
InsufficientBalanceException,
)
+from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
+ SubscriptionInactiveException,
+)
from modules.interfaces import interfaceDbChat, interfaceDbManagement
from modules.interfaces.interfaceAiObjects import AiObjects
from modules.serviceCenter.core.serviceStreaming import get_event_manager
@@ -803,7 +806,15 @@ async def _runWorkspaceAgent(
})
except Exception as e:
- if isinstance(e, InsufficientBalanceException):
+ if isinstance(e, SubscriptionInactiveException):
+ logger.warning(f"Workspace blocked by subscription: {e.message}")
+ await eventManager.emit_event(queueId, "error", {
+ "type": "error",
+ "content": e.message,
+ "workflowId": workflowId,
+ "item": e.toClientDict(),
+ })
+ elif isinstance(e, InsufficientBalanceException):
logger.warning(f"Workspace blocked by billing: {e.message}")
await eventManager.emit_event(queueId, "error", {
"type": "error",
diff --git a/modules/interfaces/interfaceBootstrap.py b/modules/interfaces/interfaceBootstrap.py
index c1cac9ef..4a3881f5 100644
--- a/modules/interfaces/interfaceBootstrap.py
+++ b/modules/interfaces/interfaceBootstrap.py
@@ -103,6 +103,13 @@ def initBootstrap(db: DatabaseConnector) -> None:
if mandateId:
initRootMandateBilling(mandateId)
+ # Initialize subscription for root mandate
+ if mandateId:
+ _initRootMandateSubscription(mandateId)
+
+ # Auto-provision Stripe Products/Prices for paid plans (idempotent)
+ _bootstrapStripePrices()
+
def initAutomationTemplates(dbApp: DatabaseConnector, adminUserId: Optional[str] = None) -> None:
"""
@@ -2069,6 +2076,47 @@ def initRootMandateBilling(mandateId: str) -> None:
logger.warning(f"Failed to initialize root mandate billing (non-critical): {e}")
+def _initRootMandateSubscription(mandateId: str) -> None:
+ """
+ Ensure the root mandate has an active ROOT subscription.
+ Called during bootstrap after billing init.
+ """
+ try:
+ from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
+ from modules.datamodels.datamodelSubscription import (
+ MandateSubscription,
+ SubscriptionStatusEnum,
+ )
+
+ subInterface = getSubRootInterface()
+ existing = subInterface.getOperativeForMandate(mandateId)
+ if existing:
+ logger.info("Root mandate subscription already exists")
+ return
+
+ sub = MandateSubscription(
+ mandateId=mandateId,
+ planKey="ROOT",
+ status=SubscriptionStatusEnum.ACTIVE,
+ recurring=False,
+ )
+ subInterface.createSubscription(sub)
+ logger.info("Created ROOT subscription for root mandate")
+
+ except Exception as e:
+ logger.warning(f"Failed to initialize root mandate subscription (non-critical): {e}")
+
+
+def _bootstrapStripePrices() -> None:
+ """Auto-create Stripe Products and Prices for all paid plans.
+ Idempotent — safe on every startup. IDs are persisted in the StripePlanPrice table."""
+ try:
+ from modules.serviceCenter.services.serviceSubscription.stripeBootstrap import bootstrapStripePrices
+ bootstrapStripePrices()
+ except Exception as e:
+ logger.error(f"Stripe price bootstrap failed (subscriptions will not work for paid plans): {e}")
+
+
def assignInitialUserMemberships(
db: DatabaseConnector,
mandateId: str,
diff --git a/modules/interfaces/interfaceDbApp.py b/modules/interfaces/interfaceDbApp.py
index 2fec872d..092fd589 100644
--- a/modules/interfaces/interfaceDbApp.py
+++ b/modules/interfaces/interfaceDbApp.py
@@ -1615,7 +1615,10 @@ class AppObjects:
existing = self.getUserMandate(userId, mandateId)
if existing:
raise ValueError(f"User {userId} is already member of mandate {mandateId}")
-
+
+ # Subscription capacity check (before insert)
+ self._checkSubscriptionCapacity(mandateId, "users", delta=1)
+
# Create UserMandate
userMandate = UserMandate(
userId=userId,
@@ -1636,7 +1639,10 @@ class AppObjects:
# Create billing account for user if billing is configured
self._ensureUserBillingAccount(userId, mandateId)
-
+
+ # Sync Stripe quantity after successful insert
+ self._syncSubscriptionQuantity(mandateId)
+
cleanedRecord = {k: v for k, v in createdRecord.items() if not k.startswith("_")}
return UserMandate(**cleanedRecord)
except Exception as e:
@@ -1686,6 +1692,28 @@ class AppObjects:
except Exception as e:
logger.warning(f"Failed to create billing account for user {userId} (non-critical): {e}")
+ def _checkSubscriptionCapacity(self, mandateId: str, resourceType: str, delta: int = 1) -> None:
+ """Check subscription capacity before creating a resource. Raises on cap violation."""
+ try:
+ from modules.interfaces.interfaceDbSubscription import getInterface as getSubInterface
+ from modules.security.rootAccess import getRootUser
+ subIf = getSubInterface(getRootUser(), mandateId)
+ subIf.assertCapacity(mandateId, resourceType, delta)
+ except Exception as e:
+ if "SubscriptionCapacityException" in type(e).__name__:
+ raise
+ logger.debug(f"Subscription capacity check skipped: {e}")
+
+ def _syncSubscriptionQuantity(self, mandateId: str) -> None:
+ """Sync Stripe subscription quantities after a resource mutation."""
+ try:
+ from modules.interfaces.interfaceDbSubscription import getInterface as getSubInterface
+ from modules.security.rootAccess import getRootUser
+ subIf = getSubInterface(getRootUser(), mandateId)
+ subIf.syncQuantityToStripe(mandateId)
+ except Exception as e:
+ logger.debug(f"Subscription quantity sync skipped: {e}")
+
def deleteUserMandate(self, userId: str, mandateId: str) -> bool:
"""
Delete a UserMandate record (remove user from mandate).
diff --git a/modules/interfaces/interfaceDbSubscription.py b/modules/interfaces/interfaceDbSubscription.py
new file mode 100644
index 00000000..f08025ea
--- /dev/null
+++ b/modules/interfaces/interfaceDbSubscription.py
@@ -0,0 +1,353 @@
+# Copyright (c) 2025 Patrick Motsch
+# All rights reserved.
+"""
+Interface for Subscription operations — ID-based, deterministic.
+
+Every write operation takes an explicit subscriptionId.
+No status-scan guessing. See wiki/concepts/Subscription-State-Machine.md.
+"""
+
+import logging
+from typing import Dict, Any, List, Optional
+from datetime import datetime, timezone
+
+from modules.connectors.connectorDbPostgre import DatabaseConnector
+from modules.shared.configuration import APP_CONFIG
+from modules.datamodels.datamodelUam import User
+from modules.datamodels.datamodelMembership import UserMandate
+from modules.datamodels.datamodelSubscription import (
+ SubscriptionPlan,
+ MandateSubscription,
+ SubscriptionStatusEnum,
+ BillingPeriodEnum,
+ ALLOWED_TRANSITIONS,
+ TERMINAL_STATUSES,
+ OPERATIVE_STATUSES,
+ BUILTIN_PLANS,
+ _getPlan,
+ _getSelectablePlans,
+)
+
+logger = logging.getLogger(__name__)
+
+SUBSCRIPTION_DATABASE = "poweron_billing"
+
+_subscriptionInterfaces: Dict[str, "SubscriptionObjects"] = {}
+
+
+class InvalidTransitionError(Exception):
+ """Raised when a state transition is not allowed by the state machine."""
+ def __init__(self, subscriptionId: str, fromStatus: str, toStatus: str):
+ self.subscriptionId = subscriptionId
+ self.fromStatus = fromStatus
+ self.toStatus = toStatus
+ super().__init__(f"Invalid transition {fromStatus} -> {toStatus} for subscription {subscriptionId}")
+
+
+def getInterface(currentUser: User, mandateId: str = None) -> "SubscriptionObjects":
+ cacheKey = f"{currentUser.id}_{mandateId}"
+ if cacheKey not in _subscriptionInterfaces:
+ _subscriptionInterfaces[cacheKey] = SubscriptionObjects(currentUser, mandateId)
+ else:
+ _subscriptionInterfaces[cacheKey].setUserContext(currentUser, mandateId)
+ return _subscriptionInterfaces[cacheKey]
+
+
+def _getRootInterface() -> "SubscriptionObjects":
+ from modules.security.rootAccess import getRootUser
+ return SubscriptionObjects(getRootUser(), mandateId=None)
+
+
+def _getAppDatabaseConnector() -> DatabaseConnector:
+ return DatabaseConnector(
+ dbDatabase=APP_CONFIG.get("DB_DATABASE", "poweron_app"),
+ dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
+ dbPort=int(APP_CONFIG.get("DB_PORT", "5432")),
+ dbUser=APP_CONFIG.get("DB_USER"),
+ dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"),
+ )
+
+
+class SubscriptionObjects:
+ """Interface for subscription operations: CRUD, gate checks, Stripe sync.
+ All writes are ID-based. All status changes go through transitionStatus()."""
+
+ def __init__(self, currentUser: Optional[User] = None, mandateId: str = None):
+ self.currentUser = currentUser
+ self.userId = currentUser.id if currentUser else None
+ self.mandateId = mandateId
+ self.db = DatabaseConnector(
+ dbDatabase=SUBSCRIPTION_DATABASE,
+ dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
+ dbPort=int(APP_CONFIG.get("DB_PORT", "5432")),
+ dbUser=APP_CONFIG.get("DB_USER"),
+ dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"),
+ )
+
+ def setUserContext(self, currentUser: User, mandateId: str = None):
+ self.currentUser = currentUser
+ self.userId = currentUser.id if currentUser else None
+ self.mandateId = mandateId
+
+ # =========================================================================
+ # Plan catalog (in-memory)
+ # =========================================================================
+
+ def getPlan(self, planKey: str) -> Optional[SubscriptionPlan]:
+ return _getPlan(planKey)
+
+ def getSelectablePlans(self) -> List[SubscriptionPlan]:
+ return _getSelectablePlans()
+
+ # =========================================================================
+ # Read: by ID (primary access pattern)
+ # =========================================================================
+
+ def getById(self, subscriptionId: str) -> Optional[Dict[str, Any]]:
+ """Load a single subscription by its primary key."""
+ try:
+ results = self.db.getRecordset(MandateSubscription, recordFilter={"id": subscriptionId})
+ return dict(results[0]) if results else None
+ except Exception as e:
+ logger.error("getById(%s) failed: %s", subscriptionId, e)
+ return None
+
+ def getByStripeSubscriptionId(self, stripeSubId: str) -> Optional[Dict[str, Any]]:
+ """Load subscription by Stripe subscription ID — the webhook resolution path."""
+ try:
+ results = self.db.getRecordset(MandateSubscription, recordFilter={"stripeSubscriptionId": stripeSubId})
+ return dict(results[0]) if results else None
+ except Exception as e:
+ logger.error("getByStripeSubscriptionId(%s) failed: %s", stripeSubId, e)
+ return None
+
+ # =========================================================================
+ # Read: by mandate (list queries)
+ # =========================================================================
+
+ def listForMandate(self, mandateId: str, statusFilter: List[SubscriptionStatusEnum] = None) -> List[Dict[str, Any]]:
+ """Return all subscriptions for a mandate, optionally filtered by status.
+ Sorted newest-first by startedAt."""
+ try:
+ results = self.db.getRecordset(MandateSubscription, recordFilter={"mandateId": mandateId})
+ rows = [dict(r) for r in results]
+ if statusFilter:
+ filterValues = {s.value for s in statusFilter}
+ rows = [r for r in rows if r.get("status") in filterValues]
+ rows.sort(key=lambda r: r.get("startedAt", ""), reverse=True)
+ return rows
+ except Exception as e:
+ logger.error("listForMandate(%s) failed: %s", mandateId, e)
+ return []
+
+ def getOperativeForMandate(self, mandateId: str) -> Optional[Dict[str, Any]]:
+ """Return the single operative subscription (ACTIVE, TRIALING, or PAST_DUE).
+ This is a read-only query for the billing gate. Returns None if no operative sub exists."""
+ for row in self.listForMandate(mandateId):
+ if row.get("status") in {s.value for s in OPERATIVE_STATUSES}:
+ return row
+ return None
+
+ def getScheduledForMandate(self, mandateId: str) -> Optional[Dict[str, Any]]:
+ """Return a SCHEDULED subscription if one exists (next sub waiting to start)."""
+ for row in self.listForMandate(mandateId, [SubscriptionStatusEnum.SCHEDULED]):
+ return row
+ return None
+
+ # =========================================================================
+ # Read: global (SysAdmin)
+ # =========================================================================
+
+ def listAll(self, statusFilter: List[SubscriptionStatusEnum] = None) -> List[Dict[str, Any]]:
+ """Return ALL subscriptions across all mandates, newest-first. SysAdmin use only."""
+ try:
+ results = self.db.getRecordset(MandateSubscription)
+ rows = [dict(r) for r in results]
+ if statusFilter:
+ filterValues = {s.value for s in statusFilter}
+ rows = [r for r in rows if r.get("status") in filterValues]
+ rows.sort(key=lambda r: r.get("startedAt", ""), reverse=True)
+ return rows
+ except Exception as e:
+ logger.error("listAll() failed: %s", e)
+ return []
+
+ # =========================================================================
+ # Write: create
+ # =========================================================================
+
+ def createSubscription(self, sub: MandateSubscription) -> Dict[str, Any]:
+ """Persist a new MandateSubscription record."""
+ return self.db.recordCreate(MandateSubscription, sub.model_dump())
+
+ # =========================================================================
+ # Write: update fields (no status change)
+ # =========================================================================
+
+ def updateFields(self, subscriptionId: str, data: Dict[str, Any]) -> Dict[str, Any]:
+ """Update non-status fields on a subscription (e.g. recurring, Stripe IDs, periods).
+ Must NOT be used for status changes — use transitionStatus() for that."""
+ if "status" in data:
+ raise ValueError("updateFields must not change status — use transitionStatus()")
+ return self.db.recordModify(MandateSubscription, subscriptionId, data)
+
+ # =========================================================================
+ # Write: status transition (guarded)
+ # =========================================================================
+
+ def transitionStatus(
+ self,
+ subscriptionId: str,
+ expectedFromStatus: SubscriptionStatusEnum,
+ toStatus: SubscriptionStatusEnum,
+ additionalData: Dict[str, Any] = None,
+ ) -> Dict[str, Any]:
+ """Execute a guarded status transition.
+
+ 1. Load the record by ID
+ 2. Verify current status matches expectedFromStatus
+ 3. Verify the transition is allowed by the state machine
+ 4. Apply the update
+ """
+ sub = self.getById(subscriptionId)
+ if not sub:
+ raise ValueError(f"Subscription {subscriptionId} not found")
+
+ currentStatus = sub.get("status", "")
+ if currentStatus != expectedFromStatus.value:
+ raise InvalidTransitionError(subscriptionId, currentStatus, toStatus.value)
+
+ if (expectedFromStatus, toStatus) not in ALLOWED_TRANSITIONS:
+ raise InvalidTransitionError(subscriptionId, expectedFromStatus.value, toStatus.value)
+
+ updateData = {"status": toStatus.value}
+ if toStatus in TERMINAL_STATUSES and not (additionalData or {}).get("endedAt"):
+ updateData["endedAt"] = datetime.now(timezone.utc).isoformat()
+ if additionalData:
+ updateData.update(additionalData)
+
+ result = self.db.recordModify(MandateSubscription, subscriptionId, updateData)
+ logger.info("Transition %s -> %s for subscription %s", expectedFromStatus.value, toStatus.value, subscriptionId)
+ return result
+
+ def forceExpire(self, subscriptionId: str) -> Dict[str, Any]:
+ """Sysadmin force-expire: ANY non-terminal -> EXPIRED. Bypasses normal transition guards."""
+ sub = self.getById(subscriptionId)
+ if not sub:
+ raise ValueError(f"Subscription {subscriptionId} not found")
+
+ currentStatus = sub.get("status", "")
+ if currentStatus == SubscriptionStatusEnum.EXPIRED.value:
+ raise ValueError(f"Subscription {subscriptionId} is already EXPIRED")
+
+ result = self.db.recordModify(MandateSubscription, subscriptionId, {
+ "status": SubscriptionStatusEnum.EXPIRED.value,
+ "endedAt": datetime.now(timezone.utc).isoformat(),
+ })
+ logger.info("Force-expired subscription %s (was %s)", subscriptionId, currentStatus)
+ return result
+
+ # =========================================================================
+ # Gate: assertActive (read-only, for billing gate)
+ # =========================================================================
+
+ def assertActive(self, mandateId: str) -> SubscriptionStatusEnum:
+ """Return effective status for billing decisions.
+ Returns the operative subscription's status, or EXPIRED if none exists.
+ This is the ONLY read-by-mandate operation used in the hot path."""
+ sub = self.getOperativeForMandate(mandateId)
+ if sub:
+ return SubscriptionStatusEnum(sub["status"])
+ return SubscriptionStatusEnum.EXPIRED
+
+ # =========================================================================
+ # Gate: assertCapacity
+ # =========================================================================
+
+ def assertCapacity(self, mandateId: str, resourceType: str, delta: int = 1) -> bool:
+ sub = self.getOperativeForMandate(mandateId)
+ if not sub:
+ from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import SubscriptionCapacityException
+ raise SubscriptionCapacityException(
+ resourceType=resourceType, currentCount=0, maxAllowed=0,
+ message="No active subscription for this mandate.",
+ )
+
+ plan = self.getPlan(sub.get("planKey", ""))
+ if not plan:
+ return True
+
+ if resourceType == "users":
+ cap = plan.maxUsers
+ if cap is None:
+ return True
+ current = self.countActiveUsers(mandateId)
+ if current + delta > cap:
+ from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import SubscriptionCapacityException
+ raise SubscriptionCapacityException(resourceType=resourceType, currentCount=current, maxAllowed=cap)
+ elif resourceType == "featureInstances":
+ cap = plan.maxFeatureInstances
+ if cap is None:
+ return True
+ current = self.countActiveFeatureInstances(mandateId)
+ if current + delta > cap:
+ from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import SubscriptionCapacityException
+ raise SubscriptionCapacityException(resourceType=resourceType, currentCount=current, maxAllowed=cap)
+
+ return True
+
+ # =========================================================================
+ # Counting (cross-DB queries against poweron_app)
+ # =========================================================================
+
+ def countActiveUsers(self, mandateId: str) -> int:
+ try:
+ appDb = _getAppDatabaseConnector()
+ return len(appDb.getRecordset(UserMandate, recordFilter={"mandateId": mandateId}))
+ except Exception as e:
+ logger.error("countActiveUsers(%s) failed: %s", mandateId, e)
+ return 0
+
+ def countActiveFeatureInstances(self, mandateId: str) -> int:
+ try:
+ from modules.datamodels.datamodelFeatures import FeatureInstance
+ appDb = _getAppDatabaseConnector()
+ return len(appDb.getRecordset(FeatureInstance, recordFilter={"mandateId": mandateId, "enabled": True}))
+ except Exception as e:
+ logger.error("countActiveFeatureInstances(%s) failed: %s", mandateId, e)
+ return 0
+
+ # =========================================================================
+ # Stripe quantity sync
+ # =========================================================================
+
+ def syncQuantityToStripe(self, subscriptionId: str) -> None:
+ """Update Stripe subscription item quantities to match actual active counts.
+ Takes subscriptionId, not mandateId."""
+ sub = self.getById(subscriptionId)
+ if not sub or not sub.get("stripeSubscriptionId"):
+ return
+
+ mandateId = sub["mandateId"]
+ itemIdUsers = sub.get("stripeItemIdUsers")
+ itemIdInstances = sub.get("stripeItemIdInstances")
+
+ try:
+ from modules.shared.stripeClient import getStripeClient
+ stripe = getStripeClient()
+
+ activeUsers = self.countActiveUsers(mandateId)
+ activeInstances = self.countActiveFeatureInstances(mandateId)
+
+ if itemIdUsers:
+ stripe.SubscriptionItem.modify(
+ itemIdUsers, quantity=max(activeUsers, 0), proration_behavior="create_prorations",
+ )
+ if itemIdInstances:
+ stripe.SubscriptionItem.modify(
+ itemIdInstances, quantity=max(activeInstances, 0), proration_behavior="create_prorations",
+ )
+
+ logger.info("Stripe quantity synced for sub %s: users=%d, instances=%d", subscriptionId, activeUsers, activeInstances)
+ except Exception as e:
+ logger.error("syncQuantityToStripe(%s) failed: %s", subscriptionId, e)
diff --git a/modules/routes/routeAdminFeatures.py b/modules/routes/routeAdminFeatures.py
index 1bada7fa..b31b5e4c 100644
--- a/modules/routes/routeAdminFeatures.py
+++ b/modules/routes/routeAdminFeatures.py
@@ -518,15 +518,40 @@ def create_feature_instance(
detail=f"Feature '{data.featureCode}' not found"
)
+ # Subscription capacity check
+ mandateIdStr = str(context.mandateId)
+ try:
+ from modules.interfaces.interfaceDbSubscription import getInterface as _getSubIf
+ from modules.security.rootAccess import getRootUser
+ _subIf = _getSubIf(getRootUser(), mandateIdStr)
+ _subIf.assertCapacity(mandateIdStr, "featureInstances", delta=1)
+ except HTTPException:
+ raise
+ except Exception as capErr:
+ if "SubscriptionCapacityException" in type(capErr).__name__:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=str(capErr),
+ )
+
instance = featureInterface.createFeatureInstance(
featureCode=data.featureCode,
- mandateId=str(context.mandateId),
+ mandateId=mandateIdStr,
label=data.label,
enabled=data.enabled,
copyTemplateRoles=data.copyTemplateRoles,
config=data.config
)
-
+
+ # Sync Stripe quantity after successful creation
+ try:
+ from modules.interfaces.interfaceDbSubscription import getInterface as _getSubIf2
+ from modules.security.rootAccess import getRootUser as _getRU
+ _subIf2 = _getSubIf2(_getRU(), mandateIdStr)
+ _subIf2.syncQuantityToStripe(mandateIdStr)
+ except Exception:
+ pass
+
logger.info(
f"User {context.user.id} created feature instance '{data.label}' "
f"for feature '{data.featureCode}' in mandate {context.mandateId}"
diff --git a/modules/routes/routeBilling.py b/modules/routes/routeBilling.py
index eeee79a7..e1c23b48 100644
--- a/modules/routes/routeBilling.py
+++ b/modules/routes/routeBilling.py
@@ -226,9 +226,9 @@ def _filterTransactionsByScope(transactions: list, scope: BillingDataScope) -> l
# =============================================================================
class CreditAddRequest(BaseModel):
- """Request model for adding credit to an account."""
+ """Request model for adding or deducting credit from an account."""
userId: Optional[str] = Field(None, description="Target user ID (for PREPAY_USER model)")
- amount: float = Field(..., gt=0, description="Amount to credit in CHF")
+ amount: float = Field(..., description="Amount in CHF. Positive = credit, negative = deduction. Must not be zero.")
description: str = Field(default="Manual credit", description="Transaction description")
@@ -358,19 +358,8 @@ class UserTransactionResponse(BaseModel):
def _getStripeClient():
"""Initialize and return configured Stripe SDK module."""
- import stripe
- from modules.shared.configuration import APP_CONFIG
-
- api_version = APP_CONFIG.get("STRIPE_API_VERSION")
- if api_version:
- stripe.api_version = api_version
-
- secret_key = APP_CONFIG.get("STRIPE_SECRET_KEY_SECRET") or APP_CONFIG.get("STRIPE_SECRET_KEY")
- if not secret_key:
- raise ValueError("STRIPE_SECRET_KEY_SECRET not configured")
-
- stripe.api_key = secret_key
- return stripe
+ from modules.shared.stripeClient import getStripeClient
+ return getStripeClient()
def _creditStripeSessionIfNeeded(
@@ -835,20 +824,27 @@ def addCredit(
else:
raise HTTPException(status_code=400, detail=f"Cannot add credit to {billingModel.value} billing model")
- # Create credit transaction
+ if creditRequest.amount == 0:
+ raise HTTPException(status_code=400, detail="Amount must not be zero")
+
from modules.datamodels.datamodelBilling import BillingTransaction
-
+
+ isDeduction = creditRequest.amount < 0
+ txType = TransactionTypeEnum.DEBIT if isDeduction else TransactionTypeEnum.CREDIT
+ absAmount = abs(creditRequest.amount)
+
transaction = BillingTransaction(
accountId=account["id"],
- transactionType=TransactionTypeEnum.CREDIT,
- amount=creditRequest.amount,
+ transactionType=txType,
+ amount=absAmount,
description=creditRequest.description,
referenceType=ReferenceTypeEnum.ADMIN
)
result = billingInterface.createTransaction(transaction)
- logger.info(f"Added {creditRequest.amount} CHF credit to account {account['id']} in mandate {targetMandateId}")
+ action = "Deducted" if isDeduction else "Added"
+ logger.info(f"{action} {absAmount} CHF to account {account['id']} in mandate {targetMandateId}")
return result
@@ -1006,13 +1002,33 @@ async def stripeWebhook(
logger.info(f"Stripe webhook received: event={event.id}, type={event.type}")
- accepted_event_types = {"checkout.session.completed", "checkout.session.async_payment_succeeded"}
- if event.type not in accepted_event_types:
+ # Subscription-related events
+ subscriptionEventTypes = {
+ "customer.subscription.updated",
+ "customer.subscription.deleted",
+ "invoice.paid",
+ "invoice.payment_failed",
+ "customer.subscription.trial_will_end",
+ }
+
+ # Checkout events (existing)
+ checkoutEventTypes = {"checkout.session.completed", "checkout.session.async_payment_succeeded"}
+
+ if event.type in subscriptionEventTypes:
+ _handleSubscriptionWebhook(event)
return {"received": True}
-
+
+ if event.type not in checkoutEventTypes:
+ return {"received": True}
+
session = event.data.object
event_id = event.id
+ sessionMode = session.get("mode") if hasattr(session, "get") else getattr(session, "mode", None)
+ if sessionMode == "subscription":
+ _handleSubscriptionCheckoutCompleted(session, event_id)
+ return {"received": True}
+
billingInterface = _getRootInterface()
if billingInterface.getStripeWebhookEventByEventId(event_id):
logger.info(f"Stripe event {event_id} already processed, skipping")
@@ -1027,6 +1043,257 @@ async def stripeWebhook(
return {"received": True}
+def _handleSubscriptionCheckoutCompleted(session, eventId: str) -> None:
+ """Handle checkout.session.completed for mode=subscription.
+ Resolves the local PENDING record by ID from webhook metadata and transitions it."""
+ from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
+ from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum, _getPlan
+ from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
+ getService as getSubscriptionService,
+ _notifySubscriptionChange,
+ )
+ from modules.security.rootAccess import getRootUser
+ from datetime import datetime, timezone
+
+ metadata = {}
+ if hasattr(session, "get"):
+ metadata = session.get("metadata") or {}
+ subscriptionRecordId = metadata.get("subscriptionRecordId")
+ mandateId = metadata.get("mandateId")
+ planKey = metadata.get("planKey", "")
+
+ platformUrl = metadata.get("platformUrl", "")
+
+ if not subscriptionRecordId:
+ stripeSub = session.get("subscription")
+ if stripeSub:
+ try:
+ from modules.shared.stripeClient import getStripeClient
+ stripe = getStripeClient()
+ subObj = stripe.Subscription.retrieve(stripeSub)
+ metadata = subObj.get("metadata") or {}
+ subscriptionRecordId = metadata.get("subscriptionRecordId")
+ mandateId = metadata.get("mandateId")
+ planKey = metadata.get("planKey", "")
+ platformUrl = platformUrl or metadata.get("platformUrl", "")
+ except Exception:
+ pass
+
+ stripeSubId = session.get("subscription")
+
+ if not mandateId or not subscriptionRecordId:
+ logger.warning("Subscription checkout missing metadata: %s", metadata)
+ return
+
+ subInterface = getSubRootInterface()
+ rootUser = getRootUser()
+
+ sub = subInterface.getById(subscriptionRecordId)
+ if not sub:
+ logger.error("Subscription record %s not found for checkout webhook", subscriptionRecordId)
+ return
+ if sub.get("status") != SubscriptionStatusEnum.PENDING.value:
+ logger.warning("Subscription %s is %s, expected PENDING — skipping", subscriptionRecordId, sub.get("status"))
+ return
+
+ stripeData: Dict[str, Any] = {}
+ if stripeSubId:
+ stripeData["stripeSubscriptionId"] = stripeSubId
+ try:
+ from modules.shared.stripeClient import getStripeClient
+ stripe = getStripeClient()
+ stripeSub = stripe.Subscription.retrieve(stripeSubId, expand=["items"])
+
+ if stripeSub.get("current_period_start"):
+ stripeData["currentPeriodStart"] = datetime.fromtimestamp(
+ stripeSub["current_period_start"], tz=timezone.utc
+ ).isoformat()
+ if stripeSub.get("current_period_end"):
+ stripeData["currentPeriodEnd"] = datetime.fromtimestamp(
+ stripeSub["current_period_end"], tz=timezone.utc
+ ).isoformat()
+
+ from modules.serviceCenter.services.serviceSubscription.stripeBootstrap import getStripePricesForPlan
+ priceMapping = getStripePricesForPlan(planKey)
+ for item in stripeSub.get("items", {}).get("data", []):
+ priceId = item.get("price", {}).get("id", "")
+ if priceMapping and priceId == priceMapping.stripePriceIdUsers:
+ stripeData["stripeItemIdUsers"] = item["id"]
+ elif priceMapping and priceId == priceMapping.stripePriceIdInstances:
+ stripeData["stripeItemIdInstances"] = item["id"]
+ except Exception as e:
+ logger.error("Error retrieving Stripe subscription %s: %s", stripeSubId, e)
+
+ if stripeData:
+ subInterface.updateFields(subscriptionRecordId, stripeData)
+
+ operative = subInterface.getOperativeForMandate(mandateId)
+ hasActivePredecessor = operative is not None and operative["id"] != subscriptionRecordId
+
+ if hasActivePredecessor:
+ toStatus = SubscriptionStatusEnum.SCHEDULED
+ if operative.get("recurring", True):
+ operativeStripeId = operative.get("stripeSubscriptionId")
+ if operativeStripeId:
+ try:
+ from modules.shared.stripeClient import getStripeClient
+ stripe = getStripeClient()
+ stripe.Subscription.modify(operativeStripeId, cancel_at_period_end=True)
+ except Exception as e:
+ logger.error("Failed to set cancel_at_period_end on predecessor %s: %s", operativeStripeId, e)
+ subInterface.updateFields(operative["id"], {"recurring": False})
+ effectiveFrom = operative.get("currentPeriodEnd")
+ if effectiveFrom:
+ subInterface.updateFields(subscriptionRecordId, {"effectiveFrom": effectiveFrom})
+ else:
+ toStatus = SubscriptionStatusEnum.ACTIVE
+
+ try:
+ subInterface.transitionStatus(
+ subscriptionRecordId, SubscriptionStatusEnum.PENDING, toStatus,
+ {"recurring": True},
+ )
+ except Exception as e:
+ logger.error("Failed to transition subscription %s: %s", subscriptionRecordId, e)
+ return
+
+ subService = getSubscriptionService(rootUser, mandateId)
+ subService.invalidateCache(mandateId)
+
+ if toStatus == SubscriptionStatusEnum.ACTIVE:
+ plan = _getPlan(planKey)
+ updatedSub = subInterface.getById(subscriptionRecordId)
+ _notifySubscriptionChange(mandateId, "activated", plan, subscriptionRecord=updatedSub, platformUrl=platformUrl)
+
+ logger.info(
+ "Checkout completed: sub=%s -> %s, mandate=%s, plan=%s",
+ subscriptionRecordId, toStatus.value, mandateId, planKey,
+ )
+
+
+def _handleSubscriptionWebhook(event) -> None:
+ """Process Stripe subscription webhook events.
+ All record resolution is by stripeSubscriptionId — no mandate-based guessing."""
+ from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
+ from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum, _getPlan
+ from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
+ getService as getSubscriptionService,
+ _notifySubscriptionChange,
+ )
+ from modules.security.rootAccess import getRootUser
+ from datetime import datetime, timezone
+
+ obj = event.data.object
+ stripeSubId = obj.get("id") if event.type.startswith("customer.subscription") else obj.get("subscription")
+ if not stripeSubId:
+ logger.warning("Subscription webhook %s has no subscription ID", event.type)
+ return
+
+ subInterface = getSubRootInterface()
+ sub = subInterface.getByStripeSubscriptionId(stripeSubId)
+ if not sub:
+ logger.warning("No local record for Stripe subscription %s (event: %s)", stripeSubId, event.type)
+ return
+
+ subId = sub["id"]
+ mandateId = sub["mandateId"]
+ currentStatus = SubscriptionStatusEnum(sub["status"])
+ rootUser = getRootUser()
+ subService = getSubscriptionService(rootUser, mandateId)
+
+ subMetadata = obj.get("metadata") or {}
+ webhookPlatformUrl = subMetadata.get("platformUrl", "")
+
+ if event.type == "customer.subscription.updated":
+ stripeStatus = obj.get("status", "")
+
+ periodData: Dict[str, Any] = {}
+ if obj.get("current_period_start"):
+ periodData["currentPeriodStart"] = datetime.fromtimestamp(
+ obj["current_period_start"], tz=timezone.utc
+ ).isoformat()
+ if obj.get("current_period_end"):
+ periodData["currentPeriodEnd"] = datetime.fromtimestamp(
+ obj["current_period_end"], tz=timezone.utc
+ ).isoformat()
+ if periodData:
+ subInterface.updateFields(subId, periodData)
+
+ if stripeStatus == "active" and currentStatus == SubscriptionStatusEnum.SCHEDULED:
+ subInterface.transitionStatus(subId, SubscriptionStatusEnum.SCHEDULED, SubscriptionStatusEnum.ACTIVE)
+ subService.invalidateCache(mandateId)
+ plan = _getPlan(sub.get("planKey", ""))
+ refreshedSub = subInterface.getById(subId)
+ _notifySubscriptionChange(mandateId, "activated", plan, subscriptionRecord=refreshedSub, platformUrl=webhookPlatformUrl)
+ logger.info("SCHEDULED -> ACTIVE for sub %s (mandate %s)", subId, mandateId)
+
+ elif stripeStatus == "active" and currentStatus == SubscriptionStatusEnum.PAST_DUE:
+ subInterface.transitionStatus(subId, SubscriptionStatusEnum.PAST_DUE, SubscriptionStatusEnum.ACTIVE)
+ subService.invalidateCache(mandateId)
+ logger.info("PAST_DUE -> ACTIVE for sub %s (mandate %s)", subId, mandateId)
+
+ elif stripeStatus == "past_due" and currentStatus == SubscriptionStatusEnum.ACTIVE:
+ subInterface.transitionStatus(subId, SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.PAST_DUE)
+ subService.invalidateCache(mandateId)
+ logger.info("ACTIVE -> PAST_DUE for sub %s (mandate %s)", subId, mandateId)
+
+ elif stripeStatus == "active" and currentStatus == SubscriptionStatusEnum.ACTIVE:
+ subService.invalidateCache(mandateId)
+ logger.info("Period renewed for sub %s (mandate %s)", subId, mandateId)
+
+ elif event.type == "customer.subscription.deleted":
+ if currentStatus not in (SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.PAST_DUE,
+ SubscriptionStatusEnum.SCHEDULED):
+ logger.info("Ignoring deletion for sub %s in status %s", subId, currentStatus.value)
+ return
+
+ subInterface.transitionStatus(subId, currentStatus, SubscriptionStatusEnum.EXPIRED)
+ subService.invalidateCache(mandateId)
+ logger.info("Sub %s -> EXPIRED (Stripe deleted, mandate %s)", subId, mandateId)
+
+ scheduled = subInterface.getScheduledForMandate(mandateId)
+ if scheduled:
+ try:
+ subInterface.transitionStatus(
+ scheduled["id"], SubscriptionStatusEnum.SCHEDULED, SubscriptionStatusEnum.ACTIVE,
+ )
+ subService.invalidateCache(mandateId)
+ plan = _getPlan(scheduled.get("planKey", ""))
+ refreshedScheduled = subInterface.getById(scheduled["id"])
+ _notifySubscriptionChange(mandateId, "activated", plan, subscriptionRecord=refreshedScheduled, platformUrl=webhookPlatformUrl)
+ logger.info("Promoted SCHEDULED sub %s -> ACTIVE (mandate %s)", scheduled["id"], mandateId)
+ except Exception as e:
+ logger.error("Failed to promote SCHEDULED sub %s: %s", scheduled["id"], e)
+
+ elif event.type == "invoice.payment_failed":
+ if currentStatus == SubscriptionStatusEnum.ACTIVE:
+ subInterface.transitionStatus(subId, SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.PAST_DUE)
+ subService.invalidateCache(mandateId)
+ plan = _getPlan(sub.get("planKey", ""))
+ _notifySubscriptionChange(mandateId, "payment_failed", plan, subscriptionRecord=sub, platformUrl=webhookPlatformUrl)
+ logger.info("Payment failed for sub %s (mandate %s)", subId, mandateId)
+
+ elif event.type == "customer.subscription.trial_will_end":
+ logger.info("Trial ending soon for sub %s (mandate %s)", subId, mandateId)
+ try:
+ from modules.shared.notifyMandateAdmins import notifyMandateAdmins
+ notifyMandateAdmins(
+ mandateId,
+ "[PowerOn] Testphase endet bald",
+ "Testphase endet bald",
+ [
+ "Die kostenlose Testphase für Ihren Mandanten endet in Kürze.",
+ "Bitte wählen Sie einen Plan unter Billing-Verwaltung › Abonnement.",
+ ],
+ )
+ except Exception as e:
+ logger.error("Failed to notify about trial ending: %s", e)
+
+ elif event.type == "invoice.paid":
+ logger.info("Invoice paid for sub %s (mandate %s)", subId, mandateId)
+ return None
+
+
@router.get("/admin/accounts/{targetMandateId}", response_model=List[AccountSummary])
@limiter.limit("30/minute")
def getAccounts(
diff --git a/modules/routes/routeSubscription.py b/modules/routes/routeSubscription.py
new file mode 100644
index 00000000..28ad7aa9
--- /dev/null
+++ b/modules/routes/routeSubscription.py
@@ -0,0 +1,364 @@
+# Copyright (c) 2025 Patrick Motsch
+# All rights reserved.
+"""
+Subscription routes — ID-based, state-machine-driven.
+
+Endpoints:
+- GET /api/subscription/plans — list selectable plans
+- GET /api/subscription/status — operative + scheduled subscription for current mandate
+- POST /api/subscription/activate — start checkout for a plan
+- POST /api/subscription/cancel — cancel a specific subscription (by ID)
+- POST /api/subscription/reactivate — reactivate a cancelled subscription (by ID)
+- POST /api/subscription/force-cancel — sysadmin immediate cancel (by ID)
+"""
+
+from fastapi import APIRouter, HTTPException, Depends, Request
+from fastapi import status
+from typing import Dict, Any, List, Optional
+import logging
+from pydantic import BaseModel, Field
+
+from modules.auth import limiter, getRequestContext, RequestContext
+
+logger = logging.getLogger(__name__)
+
+
+def _resolveMandateId(context: RequestContext) -> str:
+ if context.mandateId:
+ return str(context.mandateId)
+ return ""
+
+
+def _assertMandateAdmin(context: RequestContext, mandateId: str) -> None:
+ if context.hasSysAdminRole:
+ return
+ try:
+ from modules.interfaces.interfaceDbApp import getRootInterface
+ rootInterface = getRootInterface()
+ userMandates = rootInterface.getUserMandates(str(context.user.id))
+ for um in userMandates:
+ if str(getattr(um, "mandateId", None)) != str(mandateId):
+ continue
+ if not getattr(um, "enabled", True):
+ continue
+ umId = str(getattr(um, "id", ""))
+ roleIds = rootInterface.getRoleIdsForUserMandate(umId)
+ for roleId in roleIds:
+ role = rootInterface.getRole(roleId)
+ if role and role.roleLabel == "admin" and not role.featureInstanceId:
+ return
+ except Exception:
+ pass
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Mandate admin role required")
+
+
+# =============================================================================
+# Request / Response models
+# =============================================================================
+
+class ActivatePlanRequest(BaseModel):
+ planKey: str = Field(..., description="Key of the plan to activate")
+ returnUrl: str = Field(..., description="Frontend URL to redirect back to after Stripe Checkout")
+
+class CancelRequest(BaseModel):
+ subscriptionId: str = Field(..., description="ID of the subscription to cancel")
+
+class ReactivateRequest(BaseModel):
+ subscriptionId: str = Field(..., description="ID of the subscription to reactivate")
+
+class ForceCancelRequest(BaseModel):
+ subscriptionId: str = Field(..., description="ID of the subscription to force-cancel")
+
+class VerifyCheckoutRequest(BaseModel):
+ sessionId: str = Field(..., description="Stripe Checkout Session ID to verify")
+
+class SubscriptionStatusResponse(BaseModel):
+ active: bool
+ subscription: Optional[Dict[str, Any]] = None
+ plan: Optional[Dict[str, Any]] = None
+ scheduled: Optional[Dict[str, Any]] = None
+
+
+# =============================================================================
+# Router
+# =============================================================================
+
+router = APIRouter(
+ prefix="/api/subscription",
+ tags=["Subscription"],
+ responses={404: {"description": "Not found"}},
+)
+
+
+# =============================================================================
+# Endpoints
+# =============================================================================
+
+@router.get("/plans", response_model=List[Dict[str, Any]])
+@limiter.limit("30/minute")
+def getPlans(request: Request, context: RequestContext = Depends(getRequestContext)):
+ from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
+ getService as getSubscriptionService,
+ )
+ try:
+ mandateId = _resolveMandateId(context)
+ subService = getSubscriptionService(context.user, mandateId)
+ plans = subService.getSelectablePlans()
+ return [p.model_dump() for p in plans]
+ except Exception as e:
+ logger.error("Error fetching plans: %s", e)
+ raise HTTPException(status_code=500, detail=str(e))
+
+
+@router.get("/status", response_model=SubscriptionStatusResponse)
+@limiter.limit("60/minute")
+def getStatus(request: Request, context: RequestContext = Depends(getRequestContext)):
+ """Return the operative subscription and any scheduled successor for the current mandate."""
+ from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
+ getService as getSubscriptionService,
+ )
+ mandateId = _resolveMandateId(context)
+ if not mandateId:
+ return SubscriptionStatusResponse(active=False)
+ _assertMandateAdmin(context, mandateId)
+
+ try:
+ subService = getSubscriptionService(context.user, mandateId)
+ operative = subService.getOperativeSubscription(mandateId)
+ scheduled = subService.getScheduledSubscription(mandateId)
+
+ if not operative:
+ from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum
+ pending = subService.listSubscriptions(mandateId, [SubscriptionStatusEnum.PENDING])
+ if pending:
+ sub = pending[0]
+ plan = subService.getPlan(sub.get("planKey", ""))
+ return SubscriptionStatusResponse(
+ active=False,
+ subscription=sub,
+ plan=plan.model_dump() if plan else None,
+ scheduled=scheduled,
+ )
+ return SubscriptionStatusResponse(active=False, scheduled=scheduled)
+
+ plan = subService.getPlan(operative.get("planKey", ""))
+ return SubscriptionStatusResponse(
+ active=True,
+ subscription=operative,
+ plan=plan.model_dump() if plan else None,
+ scheduled=scheduled,
+ )
+ except Exception as e:
+ logger.error("Error fetching status: %s", e)
+ raise HTTPException(status_code=500, detail=str(e))
+
+
+@router.post("/activate", response_model=Dict[str, Any])
+@limiter.limit("10/minute")
+def activatePlan(
+ request: Request,
+ data: ActivatePlanRequest,
+ context: RequestContext = Depends(getRequestContext),
+):
+ from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
+ getService as getSubscriptionService,
+ )
+ mandateId = _resolveMandateId(context)
+ if not mandateId:
+ raise HTTPException(status_code=400, detail="X-Mandate-Id header required")
+ _assertMandateAdmin(context, mandateId)
+
+ try:
+ subService = getSubscriptionService(context.user, mandateId)
+ return subService.activatePlan(mandateId, data.planKey, returnUrl=data.returnUrl)
+ except ValueError as e:
+ raise HTTPException(status_code=400, detail=str(e))
+ except Exception as e:
+ logger.error("Error activating plan %s: %s", data.planKey, e)
+ raise HTTPException(status_code=500, detail=str(e))
+
+
+@router.post("/cancel", response_model=Dict[str, Any])
+@limiter.limit("5/minute")
+def cancelSubscription(
+ request: Request,
+ data: CancelRequest,
+ context: RequestContext = Depends(getRequestContext),
+):
+ """Cancel a specific subscription by its ID."""
+ from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
+ getService as getSubscriptionService,
+ )
+ mandateId = _resolveMandateId(context)
+ if not mandateId:
+ raise HTTPException(status_code=400, detail="X-Mandate-Id header required")
+ _assertMandateAdmin(context, mandateId)
+
+ try:
+ subService = getSubscriptionService(context.user, mandateId)
+ return subService.cancelSubscription(data.subscriptionId)
+ except ValueError as e:
+ raise HTTPException(status_code=400, detail=str(e))
+ except Exception as e:
+ logger.error("Error cancelling subscription %s: %s", data.subscriptionId, e)
+ raise HTTPException(status_code=500, detail=str(e))
+
+
+@router.post("/reactivate", response_model=Dict[str, Any])
+@limiter.limit("5/minute")
+def reactivateSubscription(
+ request: Request,
+ data: ReactivateRequest,
+ context: RequestContext = Depends(getRequestContext),
+):
+ """Reactivate a cancelled (non-recurring) subscription before its period ends."""
+ from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
+ getService as getSubscriptionService,
+ )
+ mandateId = _resolveMandateId(context)
+ if not mandateId:
+ raise HTTPException(status_code=400, detail="X-Mandate-Id header required")
+ _assertMandateAdmin(context, mandateId)
+
+ try:
+ subService = getSubscriptionService(context.user, mandateId)
+ return subService.reactivateSubscription(data.subscriptionId)
+ except ValueError as e:
+ raise HTTPException(status_code=400, detail=str(e))
+ except Exception as e:
+ logger.error("Error reactivating subscription %s: %s", data.subscriptionId, e)
+ raise HTTPException(status_code=500, detail=str(e))
+
+
+@router.post("/force-cancel", response_model=Dict[str, Any])
+@limiter.limit("5/minute")
+def forceCancel(
+ request: Request,
+ data: ForceCancelRequest,
+ context: RequestContext = Depends(getRequestContext),
+):
+ """Sysadmin: immediately expire any non-terminal subscription."""
+ if not context.hasSysAdminRole:
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required")
+
+ from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
+ getService as getSubscriptionService,
+ )
+ from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
+ sub = getSubRootInterface().getById(data.subscriptionId)
+ if not sub:
+ raise HTTPException(status_code=404, detail="Subscription not found")
+ mandateId = sub["mandateId"]
+
+ try:
+ subService = getSubscriptionService(context.user, mandateId)
+ return subService.forceCancel(data.subscriptionId)
+ except ValueError as e:
+ raise HTTPException(status_code=400, detail=str(e))
+ except Exception as e:
+ logger.error("Error force-cancelling subscription %s: %s", data.subscriptionId, e)
+ raise HTTPException(status_code=500, detail=str(e))
+
+
+@router.post("/checkout/verify", response_model=Dict[str, Any])
+@limiter.limit("20/minute")
+def verifyCheckout(
+ request: Request,
+ data: VerifyCheckoutRequest,
+ context: RequestContext = Depends(getRequestContext),
+):
+ """Verify a Stripe Checkout Session and activate the subscription if paid.
+
+ This is the synchronous counterpart to the checkout.session.completed webhook.
+ It's called by the frontend immediately after returning from Stripe to handle
+ environments where webhooks may be delayed or unavailable (e.g. localhost dev).
+ The logic is idempotent — if the webhook already processed the session, this is a no-op.
+ """
+ mandateId = _resolveMandateId(context)
+ if not mandateId:
+ raise HTTPException(status_code=400, detail="X-Mandate-Id header required")
+ _assertMandateAdmin(context, mandateId)
+
+ try:
+ from modules.shared.stripeClient import getStripeClient
+ stripe = getStripeClient()
+ session = stripe.checkout.Session.retrieve(data.sessionId)
+ except Exception as e:
+ logger.error("Failed to retrieve checkout session %s: %s", data.sessionId, e)
+ raise HTTPException(status_code=400, detail="Invalid session ID")
+
+ if session.get("status") != "complete" or session.get("payment_status") != "paid":
+ return {"status": "pending", "message": "Checkout not yet completed"}
+
+ if session.get("mode") != "subscription":
+ raise HTTPException(status_code=400, detail="Not a subscription checkout session")
+
+ from modules.routes.routeBilling import _handleSubscriptionCheckoutCompleted
+ _handleSubscriptionCheckoutCompleted(session, f"verify-{data.sessionId}")
+
+ return {"status": "activated", "message": "Subscription activated"}
+
+
+# =============================================================================
+# SysAdmin: global subscription overview
+# =============================================================================
+
+@router.get("/admin/all", response_model=List[Dict[str, Any]])
+@limiter.limit("30/minute")
+def getAllSubscriptions(
+ request: Request,
+ context: RequestContext = Depends(getRequestContext),
+):
+ """SysAdmin: list ALL subscriptions across all mandates with enriched metadata."""
+ if not context.hasSysAdminRole:
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required")
+
+ from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
+ from modules.datamodels.datamodelSubscription import BUILTIN_PLANS, OPERATIVE_STATUSES
+
+ subInterface = getSubRootInterface()
+ allSubs = subInterface.listAll()
+
+ mandateNames: Dict[str, str] = {}
+ try:
+ from modules.datamodels.datamodelUam import Mandate
+ from modules.security.rootAccess import getRootDbAppConnector
+ appDb = getRootDbAppConnector()
+ for row in appDb.getRecordset(Mandate):
+ r = dict(row)
+ mid = r.get("id", "")
+ mandateNames[mid] = r.get("label") or r.get("name") or mid[:8]
+ except Exception as e:
+ logger.warning("Could not bulk-resolve mandate names: %s", e)
+
+ operativeValues = {s.value for s in OPERATIVE_STATUSES}
+
+ enriched = []
+ for sub in allSubs:
+ mid = sub.get("mandateId", "")
+ planKey = sub.get("planKey", "")
+ plan = BUILTIN_PLANS.get(planKey)
+
+ sub["mandateName"] = mandateNames.get(mid, mid[:8])
+ sub["planTitle"] = (plan.title.get("de") or plan.title.get("en") or planKey) if plan else planKey
+
+ if sub.get("status") in operativeValues:
+ userPrice = sub.get("snapshotPricePerUserCHF", 0) or 0
+ instPrice = sub.get("snapshotPricePerInstanceCHF", 0) or 0
+ try:
+ userCount = subInterface.countActiveUsers(mid)
+ instanceCount = subInterface.countActiveFeatureInstances(mid)
+ except Exception:
+ userCount = 0
+ instanceCount = 0
+ sub["monthlyRevenueCHF"] = round(userPrice * userCount + instPrice * instanceCount, 2)
+ sub["activeUsers"] = userCount
+ sub["activeInstances"] = instanceCount
+ else:
+ sub["monthlyRevenueCHF"] = 0
+ sub["activeUsers"] = 0
+ sub["activeInstances"] = 0
+
+ enriched.append(sub)
+
+ return enriched
diff --git a/modules/serviceCenter/registry.py b/modules/serviceCenter/registry.py
index 900f9f0e..be0accba 100644
--- a/modules/serviceCenter/registry.py
+++ b/modules/serviceCenter/registry.py
@@ -45,10 +45,17 @@ IMPORTABLE_SERVICES: Dict[str, Dict[str, Any]] = {
"billing": {
"module": "modules.serviceCenter.services.serviceBilling.mainServiceBilling",
"class": "BillingService",
- "dependencies": [],
+ "dependencies": ["subscription"],
"objectKey": "service.billing",
"label": {"en": "Billing", "de": "Abrechnung", "fr": "Facturation"},
},
+ "subscription": {
+ "module": "modules.serviceCenter.services.serviceSubscription.mainServiceSubscription",
+ "class": "SubscriptionService",
+ "dependencies": [],
+ "objectKey": "service.subscription",
+ "label": {"en": "Subscription", "de": "Abonnement", "fr": "Abonnement"},
+ },
"sharepoint": {
"module": "modules.serviceCenter.services.serviceSharepoint.mainServiceSharepoint",
"class": "SharepointService",
diff --git a/modules/serviceCenter/services/serviceAgent/agentLoop.py b/modules/serviceCenter/services/serviceAgent/agentLoop.py
index bee03424..c196d237 100644
--- a/modules/serviceCenter/services/serviceAgent/agentLoop.py
+++ b/modules/serviceCenter/services/serviceAgent/agentLoop.py
@@ -25,6 +25,9 @@ from modules.shared.jsonUtils import closeJsonStructures
from modules.serviceCenter.services.serviceBilling.mainServiceBilling import (
InsufficientBalanceException,
)
+from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
+ SubscriptionInactiveException,
+)
logger = logging.getLogger(__name__)
@@ -191,6 +194,18 @@ async def runAgentLoop(
else:
aiResponse = await aiCallFn(aiRequest)
+ except SubscriptionInactiveException as e:
+ logger.warning(
+ f"Subscription inactive in round {state.currentRound} (mandate={mandateId}): {e.message}"
+ )
+ state.status = AgentStatusEnum.ERROR
+ state.abortReason = e.message
+ yield AgentEvent(
+ type=AgentEventTypeEnum.ERROR,
+ content=e.message,
+ data=e.toClientDict(),
+ )
+ break
except InsufficientBalanceException as e:
logger.warning(
f"Insufficient balance in round {state.currentRound} (mandate={mandateId}): {e.message}"
diff --git a/modules/serviceCenter/services/serviceAi/mainServiceAi.py b/modules/serviceCenter/services/serviceAi/mainServiceAi.py
index 90fd4d9a..09e2d708 100644
--- a/modules/serviceCenter/services/serviceAi/mainServiceAi.py
+++ b/modules/serviceCenter/services/serviceAi/mainServiceAi.py
@@ -27,6 +27,10 @@ from modules.serviceCenter.services.serviceBilling.mainServiceBilling import (
ProviderNotAllowedException,
BillingContextError
)
+from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
+ SubscriptionInactiveException,
+ SUBSCRIPTION_REASONS,
+)
logger = logging.getLogger(__name__)
@@ -590,11 +594,25 @@ detectedIntent-Werte:
balanceCheck = billingService.checkBalance(estimatedCost)
if not balanceCheck.allowed:
+ reason = balanceCheck.reason or ""
+
+ if reason in SUBSCRIPTION_REASONS:
+ from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum
+ statusMap = {
+ "SUBSCRIPTION_PAYMENT_REQUIRED": SubscriptionStatusEnum.PAST_DUE,
+ "SUBSCRIPTION_EXPIRED": SubscriptionStatusEnum.EXPIRED,
+ "SUBSCRIPTION_INACTIVE": SubscriptionStatusEnum.EXPIRED,
+ }
+ raise SubscriptionInactiveException(
+ status=statusMap.get(reason, SubscriptionStatusEnum.EXPIRED),
+ mandateId=str(mandateId),
+ )
+
balance_str = f"{(balanceCheck.currentBalance or 0):.2f}"
logger.warning(
f"Billing check failed for user {user.id}: "
f"Balance {balance_str} CHF, "
- f"Reason: {balanceCheck.reason}"
+ f"Reason: {reason}"
)
if balanceCheck.billingModel == BillingModelEnum.PREPAY_MANDATE:
ulabel = (getattr(user, "email", None) or getattr(user, "username", None) or str(user.id))
@@ -651,6 +669,8 @@ detectedIntent-Werte:
logger.debug(f"Provider check passed: {len(rbacAllowedProviders)} providers allowed")
+ except SubscriptionInactiveException:
+ raise
except InsufficientBalanceException:
raise
except ProviderNotAllowedException:
@@ -658,7 +678,6 @@ detectedIntent-Werte:
except BillingContextError:
raise
except Exception as e:
- # FAIL-SAFE: Don't silently swallow errors - log at ERROR level
logger.error(f"BILLING FAIL-SAFE: Billing check failed with unexpected error: {e}")
raise BillingContextError(f"Billing check failed: {e}")
diff --git a/modules/serviceCenter/services/serviceBilling/billingExhaustedNotify.py b/modules/serviceCenter/services/serviceBilling/billingExhaustedNotify.py
index aba08b89..d8f2acc4 100644
--- a/modules/serviceCenter/services/serviceBilling/billingExhaustedNotify.py
+++ b/modules/serviceCenter/services/serviceBilling/billingExhaustedNotify.py
@@ -1,23 +1,17 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
"""
-When the shared mandate pool (PREPAY_MANDATE) is exhausted, notify billing contacts.
+When the shared mandate pool (PREPAY_MANDATE) is exhausted, notify mandate admins.
-Recipients: BillingSettings.notifyEmails for the mandate (configure as mandate owner / finance).
+Uses the central notifyMandateAdmins() function for recipient resolution and delivery.
Emails are throttled per mandate to avoid spam (one notification per cooldown window).
"""
from __future__ import annotations
-import html
import logging
import time
-from typing import Any, Dict, List, Optional
-
-from modules.datamodels.datamodelMessaging import MessagingChannel
-from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface
-from modules.interfaces.interfaceMessaging import getInterface as getMessagingInterface
-from modules.security.rootAccess import getRootUser
+from typing import Dict
logger = logging.getLogger(__name__)
@@ -26,29 +20,6 @@ _poolExhaustedEmailLastSent: Dict[str, float] = {}
_DEFAULT_COOLDOWN_SEC = 3600
-def _normalizeNotifyEmails(raw: Any) -> List[str]:
- if raw is None:
- return []
- if isinstance(raw, list):
- return [str(e).strip() for e in raw if str(e).strip()]
- if isinstance(raw, str):
- s = raw.strip()
- if not s:
- return []
- # JSON array string
- if s.startswith("["):
- try:
- import json
-
- parsed = json.loads(s)
- if isinstance(parsed, list):
- return [str(e).strip() for e in parsed if str(e).strip()]
- except Exception:
- pass
- return [s]
- return []
-
-
def maybeEmailMandatePoolExhausted(
mandateId: str,
triggeringUserId: str,
@@ -58,7 +29,7 @@ def maybeEmailMandatePoolExhausted(
cooldownSec: float = _DEFAULT_COOLDOWN_SEC,
) -> None:
"""
- Send one email per mandate per cooldown to BillingSettings.notifyEmails.
+ Send one notification per mandate per cooldown window when the pool is exhausted.
Args:
mandateId: Mandate whose pool is empty.
@@ -82,59 +53,22 @@ def maybeEmailMandatePoolExhausted(
return
try:
- billing = getBillingInterface(getRootUser(), mandateId)
- settings = billing.getSettings(mandateId) or {}
- recipients = _normalizeNotifyEmails(settings.get("notifyEmails"))
- if not recipients:
- logger.warning(
- "PREPAY_MANDATE pool exhausted for mandate %s but notifyEmails is empty — "
- "configure BillingSettings.notifyEmails for owner alerts",
- mandateId,
- )
- return
+ from modules.shared.notifyMandateAdmins import notifyMandateAdmins
- subject = f"[PowerOn] Mandanten-Budget aufgebraucht (Mandant {mandateId[:8]}…)"
- body = (
- f"Das gemeinsame Guthaben (PREPAY_MANDATE) für diesen Mandanten ist nicht mehr ausreichend.\n\n"
- f"Mandanten-ID: {mandateId}\n"
- f"Aktuelles Guthaben (Pool): CHF {currentBalance:.2f}\n"
- f"Benötigt (mind.): CHF {requiredAmount:.2f}\n\n"
- f"Auslösende/r Benutzer/in: {triggeringUserLabel} (ID: {triggeringUserId})\n\n"
- f"Bitte laden Sie das Mandats-Guthaben in der Billing-Verwaltung auf, "
- f"damit Benutzer wieder AI-Funktionen nutzen können.\n"
+ sent = notifyMandateAdmins(
+ mandateId,
+ "[PowerOn] Mandanten-Budget aufgebraucht",
+ "Budget aufgebraucht",
+ [
+ "Das gemeinsame Guthaben (Prepaid-Pool) für diesen Mandanten ist nicht mehr ausreichend.",
+ f"Aktuelles Guthaben: CHF {currentBalance:.2f}\n"
+ f"Benötigt (mindestens): CHF {requiredAmount:.2f}",
+ f"Ausgelöst durch: {triggeringUserLabel}",
+ "Bitte laden Sie das Mandats-Guthaben in der Billing-Verwaltung auf, "
+ "damit Benutzer wieder AI-Funktionen nutzen können.",
+ ],
)
- escaped = html.escape(body)
- # Cannot use '\\n' inside f-string {…} expression (SyntaxError); build replacement outside.
- brWithNl = "
" + "\n"
- htmlMessage = f"""
-
-
-{escaped.replace(chr(10), brWithNl)}
-"""
-
- messaging = getMessagingInterface()
- any_ok = False
- for to in recipients:
- try:
- ok = messaging.send(
- channel=MessagingChannel.EMAIL,
- recipient=to,
- subject=subject,
- message=htmlMessage,
- )
- if ok:
- any_ok = True
- else:
- logger.warning("Pool exhausted email failed for %s", to)
- except Exception as send_err:
- logger.error("Error sending pool exhausted email to %s: %s", to, send_err)
-
- if any_ok:
+ if sent > 0:
_poolExhaustedEmailLastSent[mandateId] = now
- logger.info(
- "Sent mandate pool exhausted notification for mandate %s to %s recipient(s)",
- mandateId,
- len(recipients),
- )
except Exception as e:
logger.error("maybeEmailMandatePoolExhausted failed: %s", e, exc_info=True)
diff --git a/modules/serviceCenter/services/serviceBilling/mainServiceBilling.py b/modules/serviceCenter/services/serviceBilling/mainServiceBilling.py
index d0325a5e..969ec6b8 100644
--- a/modules/serviceCenter/services/serviceBilling/mainServiceBilling.py
+++ b/modules/serviceCenter/services/serviceBilling/mainServiceBilling.py
@@ -160,6 +160,10 @@ class BillingService:
def checkBalance(self, estimatedCost: float = 0.0) -> BillingCheckResult:
"""
Check if the current user/mandate has sufficient balance.
+
+ Gate order:
+ 1. Subscription active? (fast, cached) — blocks AI if not
+ 2. Budget sufficient? (existing prepaid logic)
Args:
estimatedCost: Estimated cost of the operation (with markup applied)
@@ -167,11 +171,42 @@ class BillingService:
Returns:
BillingCheckResult indicating if operation is allowed
"""
+ subResult = self._checkSubscription()
+ if subResult is not None:
+ return subResult
+
return self._billingInterface.checkBalance(
self.mandateId,
self.currentUser.id,
estimatedCost
)
+
+ def _checkSubscription(self) -> Optional[BillingCheckResult]:
+ """Return a failing BillingCheckResult if subscription is not active, else None."""
+ try:
+ from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum
+ from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
+ getService as getSubscriptionService,
+ _subscriptionReasonForStatus,
+ _subscriptionUserActionForStatus,
+ )
+
+ subService = getSubscriptionService(self.currentUser, self.mandateId)
+ status = subService.assertActive(self.mandateId)
+
+ if status in (SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.TRIALING, SubscriptionStatusEnum.PAST_DUE):
+ return None
+
+ return BillingCheckResult(
+ allowed=False,
+ reason=_subscriptionReasonForStatus(status),
+ upgradeRequired=True,
+ subscriptionUiPath="/admin/billing?tab=subscription",
+ userAction=_subscriptionUserActionForStatus(status),
+ )
+ except Exception as e:
+ logger.warning(f"Subscription check failed (allowing): {e}")
+ return None
def hasBalance(self, estimatedCost: float = 0.0) -> bool:
"""
diff --git a/modules/serviceCenter/services/serviceBilling/stripeCheckout.py b/modules/serviceCenter/services/serviceBilling/stripeCheckout.py
index 9f3f7e68..8d6b4a57 100644
--- a/modules/serviceCenter/services/serviceBilling/stripeCheckout.py
+++ b/modules/serviceCenter/services/serviceBilling/stripeCheckout.py
@@ -82,17 +82,8 @@ def create_checkout_session(
f"Invalid amount {amount_chf} CHF. Allowed: {ALLOWED_AMOUNTS_CHF}"
)
- # Pin API version from config (match Stripe Dashboard)
- api_version = APP_CONFIG.get("STRIPE_API_VERSION")
- if api_version:
- stripe.api_version = api_version
-
- # Get secrets
- secret_key = APP_CONFIG.get("STRIPE_SECRET_KEY_SECRET") or APP_CONFIG.get("STRIPE_SECRET_KEY")
- if not secret_key:
- raise ValueError("STRIPE_SECRET_KEY_SECRET not configured")
-
- stripe.api_key = secret_key
+ from modules.shared.stripeClient import getStripeClient
+ stripe = getStripeClient()
base_return_url = _normalizeReturnUrl(return_url)
query_separator = "&" if "?" in base_return_url else "?"
diff --git a/modules/serviceCenter/services/serviceSubscription/__init__.py b/modules/serviceCenter/services/serviceSubscription/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modules/serviceCenter/services/serviceSubscription/mainServiceSubscription.py b/modules/serviceCenter/services/serviceSubscription/mainServiceSubscription.py
new file mode 100644
index 00000000..b9a1481c
--- /dev/null
+++ b/modules/serviceCenter/services/serviceSubscription/mainServiceSubscription.py
@@ -0,0 +1,710 @@
+# Copyright (c) 2025 Patrick Motsch
+# All rights reserved.
+"""
+Subscription Service — state-machine-based lifecycle management.
+
+Every mutation takes an explicit subscriptionId. No status-scan guessing.
+See wiki/concepts/Subscription-State-Machine.md for the full state machine.
+"""
+
+import logging
+import time
+from typing import Dict, Any, List, Optional
+from datetime import datetime, timezone, timedelta
+
+from modules.datamodels.datamodelUam import User
+from modules.datamodels.datamodelSubscription import (
+ SubscriptionPlan,
+ MandateSubscription,
+ SubscriptionStatusEnum,
+ BillingPeriodEnum,
+ OPERATIVE_STATUSES,
+ _getPlan,
+ _getSelectablePlans,
+)
+from modules.interfaces.interfaceDbSubscription import (
+ getInterface as getSubscriptionInterface,
+ InvalidTransitionError,
+)
+
+logger = logging.getLogger(__name__)
+
+SUBSCRIPTION_CACHE_TTL_SECONDS = 60
+_STALE_PENDING_SECONDS = 30 * 60
+
+_subscriptionServices: Dict[str, "SubscriptionService"] = {}
+_statusCache: Dict[str, tuple] = {}
+
+
+def getService(currentUser: User, mandateId: str) -> "SubscriptionService":
+ cacheKey = f"{currentUser.id}_{mandateId}"
+ if cacheKey not in _subscriptionServices:
+ _subscriptionServices[cacheKey] = SubscriptionService(currentUser, mandateId)
+ else:
+ _subscriptionServices[cacheKey].setContext(currentUser, mandateId)
+ return _subscriptionServices[cacheKey]
+
+
+class SubscriptionService:
+ """State-machine-based subscription service.
+ All mutations use explicit subscriptionId. No scan-based writes."""
+
+ def __init__(self, contextOrUser, mandateId=None, get_service=None):
+ if mandateId is not None and callable(mandateId):
+ ctx = contextOrUser
+ self.currentUser = ctx.user
+ self.mandateId = ctx.mandate_id or ""
+ elif get_service is not None and hasattr(contextOrUser, "user"):
+ ctx = contextOrUser
+ self.currentUser = ctx.user
+ self.mandateId = ctx.mandate_id or ""
+ else:
+ self.currentUser = contextOrUser
+ self.mandateId = mandateId or ""
+ self._interface = getSubscriptionInterface(self.currentUser, self.mandateId)
+
+ def setContext(self, currentUser: User, mandateId: str):
+ self.currentUser = currentUser
+ self.mandateId = mandateId
+ self._interface = getSubscriptionInterface(currentUser, mandateId)
+
+ # =========================================================================
+ # Billing gate (cached, read-only)
+ # =========================================================================
+
+ def assertActive(self, mandateId: str = None) -> SubscriptionStatusEnum:
+ """Return subscription status for billing decisions. Uses TTL cache.
+ This is the ONLY method that works by mandateId (read-only)."""
+ mid = mandateId or self.mandateId
+ now = time.monotonic()
+
+ cached = _statusCache.get(mid)
+ if cached and cached[1] > now:
+ return cached[0]
+
+ status = self._interface.assertActive(mid)
+ _statusCache[mid] = (status, now + SUBSCRIPTION_CACHE_TTL_SECONDS)
+ return status
+
+ def invalidateCache(self, mandateId: str = None):
+ mid = mandateId or self.mandateId
+ _statusCache.pop(mid, None)
+
+ # =========================================================================
+ # Capacity (delegation)
+ # =========================================================================
+
+ def assertCapacity(self, mandateId: str, resourceType: str, delta: int = 1) -> bool:
+ return self._interface.assertCapacity(mandateId or self.mandateId, resourceType, delta)
+
+ # =========================================================================
+ # Read operations
+ # =========================================================================
+
+ def getById(self, subscriptionId: str) -> Optional[Dict[str, Any]]:
+ return self._interface.getById(subscriptionId)
+
+ def getOperativeSubscription(self, mandateId: str = None) -> Optional[Dict[str, Any]]:
+ return self._interface.getOperativeForMandate(mandateId or self.mandateId)
+
+ def getScheduledSubscription(self, mandateId: str = None) -> Optional[Dict[str, Any]]:
+ return self._interface.getScheduledForMandate(mandateId or self.mandateId)
+
+ def listSubscriptions(self, mandateId: str = None, statusFilter=None) -> List[Dict[str, Any]]:
+ return self._interface.listForMandate(mandateId or self.mandateId, statusFilter)
+
+ def getSelectablePlans(self) -> List[SubscriptionPlan]:
+ return _getSelectablePlans()
+
+ def getPlan(self, planKey: str) -> Optional[SubscriptionPlan]:
+ return _getPlan(planKey)
+
+ # =========================================================================
+ # T1/T2: Plan activation (creates PENDING, returns checkout URL)
+ # =========================================================================
+
+ def activatePlan(self, mandateId: str, planKey: str, returnUrl: str) -> Dict[str, Any]:
+ """Create a new subscription as PENDING and start the checkout flow.
+
+ - Free/trial plans: immediately ACTIVE/TRIALING (no checkout).
+ - Paid plans with active predecessor: PENDING -> checkout -> SCHEDULED on confirmation.
+ - Paid plans without predecessor: PENDING -> checkout -> ACTIVE on confirmation.
+
+ Cleans up any existing PENDING/SCHEDULED for this mandate first (by ID)."""
+ mid = mandateId or self.mandateId
+ plan = _getPlan(planKey)
+ if not plan:
+ raise ValueError(f"Unknown plan: {planKey}")
+
+ isPaid = plan.billingPeriod != BillingPeriodEnum.NONE and not plan.trialDays
+ currentOperative = self._interface.getOperativeForMandate(mid)
+
+ self._cleanupPreparatorySubscriptions(mid)
+
+ now = datetime.now(timezone.utc)
+ if plan.trialDays:
+ initialStatus = SubscriptionStatusEnum.TRIALING
+ elif isPaid:
+ initialStatus = SubscriptionStatusEnum.PENDING
+ else:
+ initialStatus = SubscriptionStatusEnum.ACTIVE
+
+ sub = MandateSubscription(
+ mandateId=mid,
+ planKey=planKey,
+ status=initialStatus,
+ recurring=plan.autoRenew and not plan.trialDays,
+ startedAt=now,
+ currentPeriodStart=now,
+ snapshotPricePerUserCHF=plan.pricePerUserCHF,
+ snapshotPricePerInstanceCHF=plan.pricePerFeatureInstanceCHF,
+ )
+
+ if plan.trialDays:
+ sub.trialEndsAt = now + timedelta(days=plan.trialDays)
+
+ if plan.billingPeriod == BillingPeriodEnum.MONTHLY:
+ sub.currentPeriodEnd = now + timedelta(days=30)
+ elif plan.billingPeriod == BillingPeriodEnum.YEARLY:
+ sub.currentPeriodEnd = now + timedelta(days=365)
+
+ created = self._interface.createSubscription(sub)
+
+ from urllib.parse import urlparse
+ parsed = urlparse(returnUrl) if returnUrl else None
+ pUrl = f"{parsed.scheme}://{parsed.netloc}" if parsed and parsed.scheme else ""
+
+ if isPaid:
+ try:
+ checkoutUrl = self._createCheckoutSession(mid, plan, created, currentOperative, returnUrl)
+ created["redirectUrl"] = checkoutUrl
+ except Exception as e:
+ self._interface.forceExpire(created["id"])
+ self.invalidateCache(mid)
+ raise ValueError(f"Subscription konnte nicht erstellt werden: {e}") from e
+ else:
+ if currentOperative:
+ self._expireOperative(currentOperative["id"], mid)
+ _notifySubscriptionChange(mid, "activated", plan, subscriptionRecord=created, platformUrl=pUrl)
+
+ self.invalidateCache(mid)
+ return created
+
+ def _cleanupPreparatorySubscriptions(self, mandateId: str) -> None:
+ """Expire any existing PENDING or SCHEDULED subscriptions for this mandate (by ID)."""
+ preparatory = self._interface.listForMandate(
+ mandateId, [SubscriptionStatusEnum.PENDING, SubscriptionStatusEnum.SCHEDULED],
+ )
+ for sub in preparatory:
+ subId = sub["id"]
+ currentStatus = SubscriptionStatusEnum(sub["status"])
+ stripeSubId = sub.get("stripeSubscriptionId")
+
+ if stripeSubId and currentStatus == SubscriptionStatusEnum.SCHEDULED:
+ try:
+ from modules.shared.stripeClient import getStripeClient
+ stripe = getStripeClient()
+ stripe.Subscription.cancel(stripeSubId)
+ except Exception as e:
+ logger.error("Failed to cancel Stripe sub %s during cleanup: %s", stripeSubId, e)
+
+ self._interface.transitionStatus(subId, currentStatus, SubscriptionStatusEnum.EXPIRED)
+ logger.info("Cleaned up %s subscription %s for mandate %s", currentStatus.value, subId, mandateId)
+
+ def _expireOperative(self, subscriptionId: str, mandateId: str) -> None:
+ """Expire the current operative subscription (used when a free/trial plan replaces it)."""
+ sub = self._interface.getById(subscriptionId)
+ if not sub:
+ return
+ currentStatus = SubscriptionStatusEnum(sub["status"])
+ if currentStatus in OPERATIVE_STATUSES:
+ stripeSubId = sub.get("stripeSubscriptionId")
+ if stripeSubId:
+ try:
+ from modules.shared.stripeClient import getStripeClient
+ stripe = getStripeClient()
+ stripe.Subscription.cancel(stripeSubId)
+ except Exception as e:
+ logger.error("Failed to cancel Stripe sub %s: %s", stripeSubId, e)
+ self._interface.transitionStatus(subscriptionId, currentStatus, SubscriptionStatusEnum.EXPIRED)
+
+ def _createCheckoutSession(
+ self, mandateId: str, plan: SubscriptionPlan, subRecord: Dict[str, Any],
+ currentOperative: Optional[Dict[str, Any]], returnUrl: str,
+ ) -> str:
+ """Create a Stripe Checkout Session. If a predecessor exists, delays billing
+ via trial_end to start after the predecessor's period end."""
+ from modules.shared.stripeClient import getStripeClient
+ from modules.serviceCenter.services.serviceSubscription.stripeBootstrap import getStripePricesForPlan
+
+ stripe = getStripeClient()
+ priceMapping = getStripePricesForPlan(plan.planKey)
+ if not priceMapping or (not priceMapping.stripePriceIdUsers and not priceMapping.stripePriceIdInstances):
+ raise ValueError(f"Stripe Price IDs not provisioned for plan {plan.planKey}")
+
+ stripeCustomerId = self._resolveStripeCustomer(mandateId)
+ if not stripeCustomerId:
+ raise ValueError(f"Could not resolve Stripe customer for mandate {mandateId}")
+
+ activeUsers = self._interface.countActiveUsers(mandateId)
+ activeInstances = self._interface.countActiveFeatureInstances(mandateId)
+
+ lineItems = []
+ if priceMapping.stripePriceIdUsers:
+ lineItems.append({"price": priceMapping.stripePriceIdUsers, "quantity": max(activeUsers, 1)})
+ if priceMapping.stripePriceIdInstances and activeInstances > 0:
+ lineItems.append({"price": priceMapping.stripePriceIdInstances, "quantity": activeInstances})
+
+ if not returnUrl:
+ raise ValueError("returnUrl is required for paid subscription checkout")
+
+ from urllib.parse import urlparse
+ parsedReturn = urlparse(returnUrl)
+ platformUrl = f"{parsedReturn.scheme}://{parsedReturn.netloc}" if parsedReturn.scheme else ""
+
+ separator = "&" if "?" in returnUrl else "?"
+ successUrl = f"{returnUrl}{separator}success=true&session_id={{CHECKOUT_SESSION_ID}}"
+ cancelUrl = f"{returnUrl}{separator}canceled=true"
+
+ subscriptionData: Dict[str, Any] = {
+ "metadata": {
+ "mandateId": mandateId,
+ "subscriptionRecordId": subRecord["id"],
+ "planKey": plan.planKey,
+ "platformUrl": platformUrl,
+ },
+ }
+
+ if currentOperative and currentOperative.get("currentPeriodEnd"):
+ periodEnd = currentOperative["currentPeriodEnd"]
+ if isinstance(periodEnd, str):
+ periodEnd = datetime.fromisoformat(periodEnd)
+ trialEndTs = int(periodEnd.timestamp())
+ subscriptionData["trial_end"] = trialEndTs
+ self._interface.updateFields(subRecord["id"], {"effectiveFrom": periodEnd.isoformat()})
+
+ session = stripe.checkout.Session.create(
+ mode="subscription",
+ customer=stripeCustomerId,
+ line_items=lineItems,
+ success_url=successUrl,
+ cancel_url=cancelUrl,
+ subscription_data=subscriptionData,
+ )
+
+ if not session or not session.url:
+ raise ValueError("Stripe Checkout Session creation failed")
+
+ logger.info("Checkout session %s created for mandate %s, plan %s", session.id, mandateId, plan.planKey)
+ return session.url
+
+ def _resolveStripeCustomer(self, mandateId: str) -> Optional[str]:
+ try:
+ from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface
+ billingIf = getBillingInterface(self.currentUser, mandateId)
+ settings = billingIf.getSettings(mandateId)
+ if not settings:
+ return None
+ customerId = settings.get("stripeCustomerId")
+ if customerId:
+ return customerId
+
+ from modules.shared.stripeClient import getStripeClient
+ stripe = getStripeClient()
+
+ mandateLabel = mandateId
+ try:
+ from modules.datamodels.datamodelUam import Mandate
+ from modules.security.rootAccess import getRootDbAppConnector
+ appDb = getRootDbAppConnector()
+ rows = appDb.getRecordset(Mandate, recordFilter={"id": mandateId})
+ if rows:
+ mandateLabel = rows[0].get("label") or rows[0].get("name") or mandateId
+ except Exception:
+ pass
+
+ customer = stripe.Customer.create(name=mandateLabel, metadata={"mandateId": mandateId})
+ billingIf.updateSettings(settings["id"], {"stripeCustomerId": customer.id})
+ logger.info("Stripe customer %s created for mandate %s", customer.id, mandateId)
+ return customer.id
+ except Exception as e:
+ logger.error("_resolveStripeCustomer(%s) failed: %s", mandateId, e)
+ return None
+
+ # =========================================================================
+ # T7: Cancel (set recurring=false)
+ # =========================================================================
+
+ def cancelSubscription(self, subscriptionId: str) -> Dict[str, Any]:
+ """Cancel a subscription (T7: set recurring=false, Stripe cancel_at_period_end).
+ The subscription stays ACTIVE until its period ends."""
+ sub = self._interface.getById(subscriptionId)
+ if not sub:
+ raise ValueError(f"Subscription {subscriptionId} not found")
+
+ status = sub.get("status", "")
+ mandateId = sub["mandateId"]
+
+ if status == SubscriptionStatusEnum.PENDING.value:
+ result = self._interface.transitionStatus(
+ subscriptionId, SubscriptionStatusEnum.PENDING, SubscriptionStatusEnum.EXPIRED,
+ )
+ self.invalidateCache(mandateId)
+ return result
+
+ if status == SubscriptionStatusEnum.SCHEDULED.value:
+ stripeSubId = sub.get("stripeSubscriptionId")
+ if stripeSubId:
+ try:
+ from modules.shared.stripeClient import getStripeClient
+ stripe = getStripeClient()
+ stripe.Subscription.cancel(stripeSubId)
+ except Exception as e:
+ logger.error("Failed to cancel Stripe sub %s: %s", stripeSubId, e)
+ result = self._interface.transitionStatus(
+ subscriptionId, SubscriptionStatusEnum.SCHEDULED, SubscriptionStatusEnum.EXPIRED,
+ )
+ self.invalidateCache(mandateId)
+ return result
+
+ if status != SubscriptionStatusEnum.ACTIVE.value:
+ raise ValueError(f"Cannot cancel subscription in status {status}")
+
+ if not sub.get("recurring", True):
+ raise ValueError("Subscription is already cancelled (non-recurring)")
+
+ stripeSubId = sub.get("stripeSubscriptionId")
+ if stripeSubId:
+ try:
+ from modules.shared.stripeClient import getStripeClient
+ stripe = getStripeClient()
+ stripe.Subscription.modify(stripeSubId, cancel_at_period_end=True)
+ except Exception as e:
+ logger.error("Failed to set cancel_at_period_end for %s: %s", stripeSubId, e)
+
+ result = self._interface.updateFields(subscriptionId, {"recurring": False})
+ self.invalidateCache(mandateId)
+
+ plan = _getPlan(sub.get("planKey", ""))
+ _notifySubscriptionChange(mandateId, "cancelled", plan)
+ return result
+
+ # =========================================================================
+ # T8: Reactivate (set recurring=true)
+ # =========================================================================
+
+ def reactivateSubscription(self, subscriptionId: str) -> Dict[str, Any]:
+ """Reactivate a cancelled subscription before its period ends (T8: recurring=true)."""
+ sub = self._interface.getById(subscriptionId)
+ if not sub:
+ raise ValueError(f"Subscription {subscriptionId} not found")
+
+ if sub.get("status") != SubscriptionStatusEnum.ACTIVE.value:
+ raise ValueError(f"Can only reactivate ACTIVE subscriptions, got {sub.get('status')}")
+ if sub.get("recurring", True):
+ raise ValueError("Subscription is already recurring")
+
+ periodEnd = sub.get("currentPeriodEnd")
+ if periodEnd:
+ if isinstance(periodEnd, str):
+ periodEnd = datetime.fromisoformat(periodEnd)
+ if periodEnd <= datetime.now(timezone.utc):
+ raise ValueError("Cannot reactivate — period has already ended")
+
+ stripeSubId = sub.get("stripeSubscriptionId")
+ if stripeSubId:
+ try:
+ from modules.shared.stripeClient import getStripeClient
+ stripe = getStripeClient()
+ stripe.Subscription.modify(stripeSubId, cancel_at_period_end=False)
+ except Exception as e:
+ logger.error("Failed to reactivate Stripe sub %s: %s", stripeSubId, e)
+
+ result = self._interface.updateFields(subscriptionId, {"recurring": True})
+ self.invalidateCache(sub["mandateId"])
+ return result
+
+ # =========================================================================
+ # T13: Sysadmin force-cancel
+ # =========================================================================
+
+ def forceCancel(self, subscriptionId: str) -> Dict[str, Any]:
+ """Sysadmin force-cancel: immediately expire any non-terminal subscription."""
+ sub = self._interface.getById(subscriptionId)
+ if not sub:
+ raise ValueError(f"Subscription {subscriptionId} not found")
+
+ stripeSubId = sub.get("stripeSubscriptionId")
+ if stripeSubId:
+ try:
+ from modules.shared.stripeClient import getStripeClient
+ stripe = getStripeClient()
+ stripe.Subscription.cancel(stripeSubId)
+ except Exception as e:
+ logger.error("Failed to cancel Stripe sub %s: %s", stripeSubId, e)
+
+ result = self._interface.forceExpire(subscriptionId)
+ self.invalidateCache(sub["mandateId"])
+ return result
+
+ # =========================================================================
+ # T6: Trial expiry
+ # =========================================================================
+
+ def handleTrialExpiry(self, subscriptionId: str) -> None:
+ """Expire a trial subscription (T6: TRIALING -> EXPIRED)."""
+ sub = self._interface.getById(subscriptionId)
+ if not sub or sub.get("status") != SubscriptionStatusEnum.TRIALING.value:
+ return
+
+ self._interface.transitionStatus(
+ subscriptionId, SubscriptionStatusEnum.TRIALING, SubscriptionStatusEnum.EXPIRED,
+ )
+ self.invalidateCache(sub["mandateId"])
+
+ plan = _getPlan(sub.get("planKey", ""))
+ successorPlan = _getPlan(plan.successorPlanKey) if plan and plan.successorPlanKey else None
+ _notifySubscriptionChange(sub["mandateId"], "trial_expired", successorPlan)
+ logger.info("Trial expired for subscription %s", subscriptionId)
+
+ # =========================================================================
+ # Stripe quantity sync
+ # =========================================================================
+
+ def syncStripeQuantity(self, subscriptionId: str):
+ self._interface.syncQuantityToStripe(subscriptionId)
+
+
+# ============================================================================
+# Notifications
+# ============================================================================
+
+def _notifySubscriptionChange(
+ mandateId: str,
+ event: str,
+ plan: Optional[SubscriptionPlan] = None,
+ subscriptionRecord: Optional[Dict[str, Any]] = None,
+ platformUrl: str = "",
+) -> None:
+ try:
+ from modules.shared.notifyMandateAdmins import notifyMandateAdmins
+
+ planLabel = (plan.title.get("de") or plan.title.get("en") or plan.planKey) if plan else "—"
+ platformHint = f"Plattform: {platformUrl}" if platformUrl else ""
+
+ rawHtmlBlock: Optional[str] = None
+
+ if event == "activated" and plan and subscriptionRecord:
+ rawHtmlBlock = _buildInvoiceSummaryHtml(plan, subscriptionRecord, mandateId, platformUrl)
+
+ templates: Dict[str, Dict[str, Any]] = {
+ "activated": {
+ "subject": f"[PowerOn] Abonnement aktiviert — {planLabel}",
+ "headline": "Abonnement aktiviert",
+ "paragraphs": [
+ p for p in [
+ f"Das Abonnement wurde auf den Plan «{planLabel}» aktiviert.",
+ platformHint,
+ "Sie können Ihr Abonnement jederzeit unter Billing-Verwaltung › Abonnement einsehen und verwalten.",
+ ] if p
+ ],
+ },
+ "cancelled": {
+ "subject": f"[PowerOn] Abonnement gekündigt — {planLabel}",
+ "headline": "Abonnement gekündigt",
+ "paragraphs": [
+ p for p in [
+ f"Das Abonnement «{planLabel}» wurde gekündigt.",
+ platformHint,
+ "Die Kündigung wird zum Ende der aktuellen bezahlten Periode wirksam. Bis dahin bleibt der volle Zugang bestehen.",
+ ] if p
+ ],
+ },
+ "trial_expired": {
+ "subject": "[PowerOn] Testphase abgelaufen",
+ "headline": "Testphase abgelaufen",
+ "paragraphs": [
+ p for p in [
+ "Die kostenlose Testphase ist abgelaufen.",
+ platformHint,
+ "Bitte wählen Sie einen Plan unter Billing-Verwaltung › Abonnement, damit der Zugang nicht unterbrochen wird.",
+ ] if p
+ ],
+ },
+ "payment_failed": {
+ "subject": f"[PowerOn] Zahlung fehlgeschlagen — {planLabel}",
+ "headline": "Zahlung fehlgeschlagen",
+ "paragraphs": [
+ p for p in [
+ f"Die Zahlung für das Abonnement «{planLabel}» ist fehlgeschlagen.",
+ platformHint,
+ "Bitte aktualisieren Sie Ihr Zahlungsmittel unter Billing-Verwaltung.",
+ ] if p
+ ],
+ },
+ }
+
+ tpl = templates.get(event, {
+ "subject": f"[PowerOn] Abonnement-Änderung — {planLabel}",
+ "headline": "Abonnement-Änderung",
+ "paragraphs": [f"Änderung am Abonnement «{planLabel}»."],
+ })
+
+ notifyMandateAdmins(
+ mandateId, tpl["subject"], tpl["headline"], tpl["paragraphs"],
+ rawHtmlBlock=rawHtmlBlock,
+ )
+ except Exception as e:
+ logger.error("_notifySubscriptionChange failed for mandate %s event %s: %s", mandateId, event, e)
+
+
+def _buildInvoiceSummaryHtml(
+ plan: SubscriptionPlan,
+ subRecord: Dict[str, Any],
+ mandateId: str,
+ platformUrl: str = "",
+) -> str:
+ """Build an HTML invoice summary block for inclusion in the activation email."""
+ import html as htmlmod
+ from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
+
+ subInterface = getSubRootInterface()
+ userCount = subInterface.countActiveUsers(mandateId)
+ instanceCount = subInterface.countActiveFeatureInstances(mandateId)
+
+ userPrice = plan.pricePerUserCHF
+ instancePrice = plan.pricePerFeatureInstanceCHF
+ userTotal = userCount * userPrice
+ instanceTotal = instanceCount * instancePrice
+ netTotal = userTotal + instanceTotal
+
+ periodLabel = {"MONTHLY": "Monatlich", "YEARLY": "Jährlich"}.get(plan.billingPeriod, plan.billingPeriod)
+
+ def _chf(amount: float) -> str:
+ return f"CHF {amount:,.2f}".replace(",", "'")
+
+ rows = ""
+ if userPrice > 0:
+ rows += (
+ f'| Benutzer-Lizenzen | '
+ f'{userCount} × {_chf(userPrice)} | '
+ f'{_chf(userTotal)} |
\n'
+ )
+ if instancePrice > 0:
+ rows += (
+ f'| Feature-Instanzen | '
+ f'{instanceCount} × {_chf(instancePrice)} | '
+ f'{_chf(instanceTotal)} |
\n'
+ )
+
+ invoiceLink = ""
+ stripeSubId = subRecord.get("stripeSubscriptionId")
+ if stripeSubId:
+ try:
+ from modules.shared.stripeClient import getStripeClient
+ stripe = getStripeClient()
+ invoices = stripe.Invoice.list(subscription=stripeSubId, limit=1)
+ if invoices.data:
+ hostedUrl = invoices.data[0].get("hosted_invoice_url", "")
+ if hostedUrl:
+ invoiceLink = (
+ f''
+ f''
+ f'Vollständige Rechnung mit MwSt-Ausweis anzeigen
\n'
+ )
+ except Exception as e:
+ logger.warning("Could not fetch Stripe invoice URL for sub %s: %s", stripeSubId, e)
+
+ return (
+ f''
+ f''
+ f'| Position | '
+ f'Menge × Preis | '
+ f'Total | '
+ f'
'
+ f'{rows}'
+ f''
+ f'| Netto-Total ({periodLabel}) | '
+ f' | '
+ f'{_chf(netTotal)} | '
+ f'
'
+ f'
'
+ f'{invoiceLink}'
+ )
+
+
+# ============================================================================
+# Exception Classes
+# ============================================================================
+
+SUBSCRIPTION_USER_ACTION_UPGRADE = "UPGRADE_SUBSCRIPTION"
+SUBSCRIPTION_USER_ACTION_REACTIVATE = "REACTIVATE_SUBSCRIPTION"
+SUBSCRIPTION_USER_ACTION_ADD_PAYMENT = "ADD_PAYMENT_METHOD"
+
+SUBSCRIPTION_REASONS = {
+ "SUBSCRIPTION_INACTIVE",
+ "SUBSCRIPTION_PAYMENT_REQUIRED",
+ "SUBSCRIPTION_PAYMENT_PENDING",
+ "SUBSCRIPTION_EXPIRED",
+}
+
+
+def _subscriptionReasonForStatus(status: SubscriptionStatusEnum) -> str:
+ if status == SubscriptionStatusEnum.PENDING:
+ return "SUBSCRIPTION_PAYMENT_PENDING"
+ if status == SubscriptionStatusEnum.PAST_DUE:
+ return "SUBSCRIPTION_PAYMENT_REQUIRED"
+ if status == SubscriptionStatusEnum.EXPIRED:
+ return "SUBSCRIPTION_EXPIRED"
+ return "SUBSCRIPTION_INACTIVE"
+
+
+def _subscriptionUserActionForStatus(status: SubscriptionStatusEnum) -> str:
+ if status in (SubscriptionStatusEnum.PAST_DUE, SubscriptionStatusEnum.PENDING):
+ return SUBSCRIPTION_USER_ACTION_ADD_PAYMENT
+ return SUBSCRIPTION_USER_ACTION_UPGRADE
+
+
+class SubscriptionInactiveException(Exception):
+ def __init__(self, status: SubscriptionStatusEnum, mandateId: str = "", message: Optional[str] = None):
+ self.status = status
+ self.mandateId = mandateId
+ self.reason = _subscriptionReasonForStatus(status)
+ self.userAction = _subscriptionUserActionForStatus(status)
+ self.message = message or (
+ "Kein aktives Abonnement für diesen Mandanten. Bitte wählen Sie einen Plan unter Billing."
+ )
+ super().__init__(self.message)
+
+ def toClientDict(self) -> Dict[str, Any]:
+ out: Dict[str, Any] = {
+ "error": self.reason, "message": self.message,
+ "userAction": self.userAction, "subscriptionUiPath": "/admin/billing?tab=subscription",
+ }
+ if self.mandateId:
+ out["mandateId"] = self.mandateId
+ return out
+
+
+class SubscriptionCapacityException(Exception):
+ def __init__(self, resourceType: str, currentCount: int, maxAllowed: int, message: Optional[str] = None):
+ self.resourceType = resourceType
+ self.currentCount = currentCount
+ self.maxAllowed = maxAllowed
+ self.message = message or (
+ f"Ihr Plan erlaubt maximal {maxAllowed} {'Benutzer' if resourceType == 'users' else 'Feature-Instanzen'} "
+ f"(aktuell {currentCount}). Bitte wechseln Sie zu einem grösseren Plan."
+ )
+ super().__init__(self.message)
+
+ def toClientDict(self) -> Dict[str, Any]:
+ return {
+ "error": f"SUBSCRIPTION_{self.resourceType.upper()}_LIMIT",
+ "currentCount": self.currentCount, "maxAllowed": self.maxAllowed,
+ "message": self.message, "userAction": SUBSCRIPTION_USER_ACTION_UPGRADE,
+ "subscriptionUiPath": "/admin/billing?tab=subscription",
+ }
+
+
+SubscriptionService.SubscriptionInactiveException = SubscriptionInactiveException
+SubscriptionService.SubscriptionCapacityException = SubscriptionCapacityException
diff --git a/modules/serviceCenter/services/serviceSubscription/stripeBootstrap.py b/modules/serviceCenter/services/serviceSubscription/stripeBootstrap.py
new file mode 100644
index 00000000..fd5666e0
--- /dev/null
+++ b/modules/serviceCenter/services/serviceSubscription/stripeBootstrap.py
@@ -0,0 +1,214 @@
+# Copyright (c) 2025 Patrick Motsch
+# All rights reserved.
+"""
+Auto-provision Stripe Products and Prices from the built-in plan catalog.
+
+Creates separate Stripe Products for user licenses and feature instances
+so that invoice line items show clear, descriptive names:
+ - "Benutzer-Lizenzen"
+ - "Feature-Instanzen"
+
+Idempotent — safe to call on every startup.
+"""
+
+import logging
+from typing import Dict, Optional
+
+from modules.connectors.connectorDbPostgre import DatabaseConnector
+from modules.shared.configuration import APP_CONFIG
+from modules.datamodels.datamodelSubscription import (
+ BUILTIN_PLANS,
+ SubscriptionPlan,
+ BillingPeriodEnum,
+ StripePlanPrice,
+)
+
+logger = logging.getLogger(__name__)
+
+_BILLING_DATABASE = "poweron_billing"
+_METADATA_KEY = "poweron_plan_key"
+_METADATA_LINE_TYPE = "poweron_line_type"
+
+_PERIOD_TO_STRIPE = {
+ BillingPeriodEnum.MONTHLY: {"interval": "month", "interval_count": 1},
+ BillingPeriodEnum.YEARLY: {"interval": "year", "interval_count": 1},
+}
+
+
+def _getBillingDb() -> DatabaseConnector:
+ return DatabaseConnector(
+ dbDatabase=_BILLING_DATABASE,
+ dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
+ dbPort=int(APP_CONFIG.get("DB_PORT", "5432")),
+ dbUser=APP_CONFIG.get("DB_USER"),
+ dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"),
+ )
+
+
+def _loadExistingMappings(db: DatabaseConnector) -> Dict[str, StripePlanPrice]:
+ try:
+ rows = db.getRecordset(StripePlanPrice)
+ result = {}
+ for row in rows:
+ pk = row.get("planKey")
+ if pk:
+ result[pk] = StripePlanPrice(**{k: v for k, v in row.items() if not k.startswith("_")})
+ return result
+ except Exception as e:
+ logger.warning("Could not load StripePlanPrice records: %s", e)
+ return {}
+
+
+def _findStripeProduct(stripe, planKey: str, lineType: str) -> Optional[str]:
+ """Search Stripe for a product tagged with plan key + line type."""
+ try:
+ products = stripe.Product.search(
+ query=f'metadata["{_METADATA_KEY}"]:"{planKey}" AND metadata["{_METADATA_LINE_TYPE}"]:"{lineType}"',
+ limit=1,
+ )
+ if products.data:
+ return products.data[0].id
+ except Exception:
+ try:
+ products = stripe.Product.search(
+ query=f'metadata["{_METADATA_KEY}"]:"{planKey}"',
+ limit=10,
+ )
+ for p in products.data:
+ meta = p.get("metadata") or {}
+ if meta.get(_METADATA_LINE_TYPE) == lineType:
+ return p.id
+ except Exception:
+ pass
+ return None
+
+
+def _createStripeProduct(stripe, name: str, description: str, planKey: str, lineType: str) -> str:
+ product = stripe.Product.create(
+ name=name,
+ description=description,
+ metadata={_METADATA_KEY: planKey, _METADATA_LINE_TYPE: lineType},
+ )
+ logger.info("Created Stripe Product %s: %s (%s/%s)", product.id, name, planKey, lineType)
+ return product.id
+
+
+def _findExistingStripePrice(stripe, productId: str, unitAmount: int, interval: str) -> Optional[str]:
+ try:
+ prices = stripe.Price.list(product=productId, active=True, limit=50)
+ for p in prices.data:
+ recurring = p.get("recurring") or {}
+ if p.get("unit_amount") == unitAmount and recurring.get("interval") == interval:
+ return p.id
+ except Exception:
+ pass
+ return None
+
+
+def _createStripePrice(stripe, productId: str, unitAmountCHF: float, interval: str, nickname: str) -> str:
+ price = stripe.Price.create(
+ product=productId,
+ unit_amount=int(unitAmountCHF * 100),
+ currency="chf",
+ recurring={"interval": interval},
+ nickname=nickname,
+ )
+ logger.info("Created Stripe Price %s (%s, %s CHF/%s)", price.id, nickname, unitAmountCHF, interval)
+ return price.id
+
+
+def bootstrapStripePrices() -> None:
+ """Ensure all paid plans have separate Stripe Products for users and instances."""
+ try:
+ from modules.shared.stripeClient import getStripeClient
+ stripe = getStripeClient()
+ except ValueError as e:
+ logger.error("Stripe not configured — cannot bootstrap subscription prices: %s", e)
+ return
+
+ db = _getBillingDb()
+ existing = _loadExistingMappings(db)
+
+ for planKey, plan in BUILTIN_PLANS.items():
+ if plan.billingPeriod == BillingPeriodEnum.NONE:
+ continue
+ if plan.pricePerUserCHF == 0 and plan.pricePerFeatureInstanceCHF == 0:
+ continue
+
+ stripePeriod = _PERIOD_TO_STRIPE.get(plan.billingPeriod)
+ if not stripePeriod:
+ continue
+
+ interval = stripePeriod["interval"]
+
+ if planKey in existing:
+ mapping = existing[planKey]
+ hasAllPrices = mapping.stripePriceIdUsers and mapping.stripePriceIdInstances
+ hasAllProducts = mapping.stripeProductIdUsers and mapping.stripeProductIdInstances
+ if hasAllPrices and hasAllProducts:
+ logger.debug("Stripe prices already configured for plan %s", planKey)
+ continue
+
+ productIdUsers = None
+ productIdInstances = None
+ priceIdUsers = None
+ priceIdInstances = None
+
+ if plan.pricePerUserCHF > 0:
+ productIdUsers = _findStripeProduct(stripe, planKey, "users")
+ if not productIdUsers:
+ productIdUsers = _createStripeProduct(
+ stripe, "Benutzer-Lizenzen", f"Benutzer-Lizenzen für {plan.title.get('de', planKey)}",
+ planKey, "users",
+ )
+ priceIdUsers = _findExistingStripePrice(stripe, productIdUsers, int(plan.pricePerUserCHF * 100), interval)
+ if not priceIdUsers:
+ priceIdUsers = _createStripePrice(
+ stripe, productIdUsers, plan.pricePerUserCHF, interval, f"{planKey} — Benutzer-Lizenz",
+ )
+
+ if plan.pricePerFeatureInstanceCHF > 0:
+ productIdInstances = _findStripeProduct(stripe, planKey, "instances")
+ if not productIdInstances:
+ productIdInstances = _createStripeProduct(
+ stripe, "Feature-Instanzen", f"Feature-Instanzen für {plan.title.get('de', planKey)}",
+ planKey, "instances",
+ )
+ priceIdInstances = _findExistingStripePrice(
+ stripe, productIdInstances, int(plan.pricePerFeatureInstanceCHF * 100), interval,
+ )
+ if not priceIdInstances:
+ priceIdInstances = _createStripePrice(
+ stripe, productIdInstances, plan.pricePerFeatureInstanceCHF, interval,
+ f"{planKey} — Feature-Instanz",
+ )
+
+ persistData = {
+ "stripeProductId": "",
+ "stripeProductIdUsers": productIdUsers,
+ "stripeProductIdInstances": productIdInstances,
+ "stripePriceIdUsers": priceIdUsers,
+ "stripePriceIdInstances": priceIdInstances,
+ }
+
+ if planKey in existing:
+ db.recordModify(StripePlanPrice, existing[planKey].id, persistData)
+ else:
+ db.recordCreate(StripePlanPrice, StripePlanPrice(planKey=planKey, **persistData).model_dump())
+
+ logger.info(
+ "Stripe bootstrapped for %s: users=%s/%s, instances=%s/%s",
+ planKey, productIdUsers, priceIdUsers, productIdInstances, priceIdInstances,
+ )
+
+
+def getStripePricesForPlan(planKey: str) -> Optional[StripePlanPrice]:
+ """Load the persisted Stripe IDs for a plan."""
+ try:
+ db = _getBillingDb()
+ rows = db.getRecordset(StripePlanPrice, recordFilter={"planKey": planKey})
+ if rows:
+ return StripePlanPrice(**{k: v for k, v in rows[0].items() if not k.startswith("_")})
+ except Exception as e:
+ logger.error("Error loading Stripe prices for plan %s: %s", planKey, e)
+ return None
diff --git a/modules/shared/configuration.py b/modules/shared/configuration.py
index 73e2da90..721ce448 100644
--- a/modules/shared/configuration.py
+++ b/modules/shared/configuration.py
@@ -66,7 +66,7 @@ class Configuration:
self._configMtime = currentMtime
try:
- with open(configPath, 'r') as f:
+ with open(configPath, 'r', encoding='utf-8') as f:
lines = f.readlines()
i = 0
diff --git a/modules/shared/notifyMandateAdmins.py b/modules/shared/notifyMandateAdmins.py
new file mode 100644
index 00000000..27445afb
--- /dev/null
+++ b/modules/shared/notifyMandateAdmins.py
@@ -0,0 +1,285 @@
+# Copyright (c) 2025 Patrick Motsch
+# All rights reserved.
+"""
+Central notification utility for mandate administrators.
+
+All mandate-level notifications (subscription changes, billing warnings, etc.)
+MUST go through notifyMandateAdmins() to ensure consistent recipient resolution
+and delivery.
+
+Recipients are the union of:
+1. BillingSettings.notifyEmails for the mandate (configured contact addresses)
+2. All users with the mandate-level "admin" RBAC role
+"""
+
+from __future__ import annotations
+
+import html
+import json
+import logging
+from typing import Any, Dict, List, Optional, Set
+
+from modules.datamodels.datamodelMessaging import MessagingChannel
+from modules.interfaces.interfaceMessaging import getInterface as getMessagingInterface
+
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# Recipient resolution
+# ============================================================================
+
+
+def _normalizeEmailList(raw: Any) -> List[str]:
+ """Parse notifyEmails which can be a list, JSON string, or single address."""
+ if raw is None:
+ return []
+ if isinstance(raw, list):
+ return [str(e).strip().lower() for e in raw if str(e).strip()]
+ if isinstance(raw, str):
+ s = raw.strip()
+ if not s:
+ return []
+ if s.startswith("["):
+ try:
+ parsed = json.loads(s)
+ if isinstance(parsed, list):
+ return [str(e).strip().lower() for e in parsed if str(e).strip()]
+ except Exception:
+ pass
+ return [s.lower()]
+ return []
+
+
+def _resolveMandateContactEmails(mandateId: str) -> List[str]:
+ """Get the configured notifyEmails from BillingSettings."""
+ try:
+ from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface
+ from modules.security.rootAccess import getRootUser
+ billingIf = getBillingInterface(getRootUser(), mandateId)
+ settings = billingIf.getSettings(mandateId) or {}
+ return _normalizeEmailList(settings.get("notifyEmails"))
+ except Exception as e:
+ logger.warning("Could not resolve BillingSettings.notifyEmails for mandate %s: %s", mandateId, e)
+ return []
+
+
+def _resolveMandateAdminEmails(mandateId: str) -> List[str]:
+ """Resolve all admin users of a mandate via RBAC and return their emails."""
+ emails: List[str] = []
+ try:
+ from modules.interfaces.interfaceDbApp import getRootInterface
+ rootIf = getRootInterface()
+ userMandates = rootIf.getUserMandatesByMandate(mandateId)
+ for um in userMandates:
+ if not getattr(um, "enabled", True):
+ continue
+ umId = str(getattr(um, "id", ""))
+ userId = getattr(um, "userId", None)
+ if not userId:
+ continue
+ roleIds = rootIf.getRoleIdsForUserMandate(umId)
+ isAdmin = False
+ for roleId in roleIds:
+ role = rootIf.getRole(roleId)
+ if role and role.roleLabel == "admin" and not role.featureInstanceId:
+ isAdmin = True
+ break
+ if not isAdmin:
+ continue
+ user = rootIf.getUser(str(userId))
+ if user and user.email:
+ emails.append(user.email.strip().lower())
+ except Exception as e:
+ logger.warning("Could not resolve admin emails for mandate %s: %s", mandateId, e)
+ return emails
+
+
+def _resolveAllRecipients(mandateId: str) -> List[str]:
+ """Union of BillingSettings.notifyEmails + all mandate admin user emails, deduplicated."""
+ seen: Set[str] = set()
+ result: List[str] = []
+ for email in _resolveMandateContactEmails(mandateId) + _resolveMandateAdminEmails(mandateId):
+ if email and email not in seen:
+ seen.add(email)
+ result.append(email)
+ return result
+
+
+# ============================================================================
+# Mandate name resolution
+# ============================================================================
+
+
+def _resolveMandateName(mandateId: str) -> str:
+ """Return the human-readable mandate name (label or name), falling back to a short ID."""
+ try:
+ from modules.datamodels.datamodelUam import Mandate
+ from modules.security.rootAccess import getRootDbAppConnector
+ appDb = getRootDbAppConnector()
+ rows = appDb.getRecordset(Mandate, recordFilter={"id": mandateId})
+ if rows:
+ return rows[0].get("label") or rows[0].get("name") or mandateId[:8]
+ except Exception as e:
+ logger.warning("Could not resolve mandate name for %s: %s", mandateId, e)
+ return mandateId[:8]
+
+
+# ============================================================================
+# HTML email rendering
+# ============================================================================
+
+
+def _getOperatorInfo() -> Dict[str, str]:
+ """Load operator company data from config.ini."""
+ try:
+ from modules.shared.configuration import APP_CONFIG
+ return {
+ "companyName": APP_CONFIG.get("Operator_CompanyName", ""),
+ "address": APP_CONFIG.get("Operator_Address", ""),
+ "vatNumber": APP_CONFIG.get("Operator_VatNumber", ""),
+ }
+ except Exception:
+ return {"companyName": "", "address": "", "vatNumber": ""}
+
+
+def _renderHtmlEmail(
+ headline: str,
+ bodyParagraphs: List[str],
+ mandateName: str,
+ footerNote: Optional[str] = None,
+ rawHtmlBlock: Optional[str] = None,
+) -> str:
+ """Render a clean, professional HTML notification email.
+
+ Args:
+ rawHtmlBlock: Optional pre-formatted HTML inserted after bodyParagraphs (e.g. invoice table).
+ """
+ hl = html.escape(headline)
+ mn = html.escape(mandateName)
+
+ paragraphsHtml = ""
+ for p in bodyParagraphs:
+ escaped = html.escape(p).replace("\n", "
")
+ paragraphsHtml += f'{escaped}
\n'
+
+ rawBlock = ""
+ if rawHtmlBlock:
+ rawBlock = f'{rawHtmlBlock}
\n'
+
+ footer = ""
+ if footerNote:
+ footer = (
+ f''
+ f'{html.escape(footerNote)}
\n'
+ )
+
+ operator = _getOperatorInfo()
+ operatorLine = ""
+ parts = [p for p in [operator["companyName"], operator["address"], operator["vatNumber"]] if p]
+ if parts:
+ operatorLine = (
+ f''
+ f'{html.escape(" | ".join(parts))}
\n'
+ )
+
+ return f"""
+
+
+
+
+
+
+
+
+ PowerOn
+ |
+
+
+ {hl}
+ Mandant: {mn}
+
+ {paragraphsHtml}
+ {rawBlock}
+
+ {footer}
+ |
+
+ |
+
+ Diese E-Mail wurde automatisch von PowerOn versendet.
+
+ {operatorLine}
+ |
+
+ |
+
+
+"""
+
+
+# ============================================================================
+# Public API
+# ============================================================================
+
+
+def notifyMandateAdmins(
+ mandateId: str,
+ subject: str,
+ headline: str,
+ bodyParagraphs: List[str],
+ *,
+ footerNote: Optional[str] = None,
+ rawHtmlBlock: Optional[str] = None,
+) -> int:
+ """
+ Send a styled HTML notification to all mandate admins and configured contacts.
+
+ Args:
+ mandateId: The mandate to notify admins for.
+ subject: Email subject line.
+ headline: Bold headline inside the email body.
+ bodyParagraphs: List of paragraph strings (plain text, auto-escaped).
+ footerNote: Optional small-print note below the main content.
+ rawHtmlBlock: Optional pre-formatted HTML block (e.g. invoice summary table).
+
+ Returns:
+ Number of recipients that were successfully notified.
+ """
+ if not mandateId:
+ return 0
+
+ recipients = _resolveAllRecipients(mandateId)
+ if not recipients:
+ logger.warning(
+ "notifyMandateAdmins: no recipients found for mandate %s "
+ "(no notifyEmails configured and no admin users with email)",
+ mandateId,
+ )
+ return 0
+
+ mandateName = _resolveMandateName(mandateId)
+ htmlMessage = _renderHtmlEmail(headline, bodyParagraphs, mandateName, footerNote, rawHtmlBlock)
+ messaging = getMessagingInterface()
+ successCount = 0
+
+ for to in recipients:
+ try:
+ ok = messaging.send(
+ channel=MessagingChannel.EMAIL,
+ recipient=to,
+ subject=subject,
+ message=htmlMessage,
+ )
+ if ok:
+ successCount += 1
+ else:
+ logger.warning("notifyMandateAdmins: send failed for %s", to)
+ except Exception as e:
+ logger.error("notifyMandateAdmins: error sending to %s: %s", to, e)
+
+ logger.info(
+ "notifyMandateAdmins: sent '%s' to %d/%d recipients for mandate %s (%s)",
+ subject, successCount, len(recipients), mandateId, mandateName,
+ )
+ return successCount
diff --git a/modules/shared/stripeClient.py b/modules/shared/stripeClient.py
new file mode 100644
index 00000000..9c7b4c67
--- /dev/null
+++ b/modules/shared/stripeClient.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2025 Patrick Motsch
+# All rights reserved.
+"""
+Central Stripe SDK initialization.
+
+All Stripe interactions MUST use getStripeClient() to ensure consistent
+API key, API version, and fallback handling across billing and subscription flows.
+"""
+
+import logging
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+_stripeInitialized = False
+
+
+def getStripeClient():
+ """
+ Initialize and return the configured Stripe SDK module.
+
+ Raises ValueError if no Stripe secret key is configured.
+ """
+ import stripe
+ from modules.shared.configuration import APP_CONFIG
+
+ apiVersion = APP_CONFIG.get("STRIPE_API_VERSION")
+ if apiVersion:
+ stripe.api_version = apiVersion
+
+ secretKey = APP_CONFIG.get("STRIPE_SECRET_KEY_SECRET") or APP_CONFIG.get("STRIPE_SECRET_KEY")
+ if not secretKey:
+ raise ValueError("STRIPE_SECRET_KEY_SECRET not configured")
+
+ stripe.api_key = secretKey
+ return stripe
+
+
diff --git a/modules/system/mainSystem.py b/modules/system/mainSystem.py
index a4b92ca4..58667645 100644
--- a/modules/system/mainSystem.py
+++ b/modules/system/mainSystem.py
@@ -190,7 +190,6 @@ NAVIGATION_SECTIONS = [
"path": "/admin/billing",
"order": 40,
"adminOnly": True,
- "sysAdminOnly": True,
},
],
},
From 2e7a0a73c790f26beba755caaa1880244c5b37b3 Mon Sep 17 00:00:00 2001
From: ValueOn AG
Date: Sun, 22 Mar 2026 21:34:54 +0100
Subject: [PATCH 2/5] streamlined formgeneratortable and sort/filter globally
---
modules/connectors/connectorDbPostgre.py | 270 +-
modules/datamodels/datamodelPagination.py | 6 +
.../automation/interfaceFeatureAutomation.py | 2 +-
.../automation/routeFeatureAutomation.py | 44 +
.../chatbot/interfaceFeatureChatbot.py | 6 +-
.../realEstate/interfaceFeatureRealEstate.py | 86 +-
.../realEstate/routeFeatureRealEstate.py | 50 +
.../trustee/interfaceFeatureTrustee.py | 453 ++-
.../features/trustee/routeFeatureTrustee.py | 128 +-
modules/interfaces/interfaceDbApp.py | 148 +-
modules/interfaces/interfaceDbBilling.py | 78 +-
modules/interfaces/interfaceDbChat.py | 6 +-
modules/interfaces/interfaceDbManagement.py | 56 +-
modules/interfaces/interfaceRbac.py | 307 +-
modules/routes/routeAdminAutomationEvents.py | 278 +-
modules/routes/routeAdminFeatures.py | 176 +-
modules/routes/routeAdminRbacRules.py | 81 +-
modules/routes/routeBilling.py | 245 +-
modules/routes/routeDataConnections.py | 41 +
modules/routes/routeDataFiles.py | 50 +
modules/routes/routeDataMandates.py | 117 +
modules/routes/routeDataPrompts.py | 23 +
modules/routes/routeDataUsers.py | 221 +-
modules/routes/routeInvitations.py | 55 +
modules/routes/routeMessaging.py | 10 +-
modules/routes/routeRealEstate.py | 70 +-
modules/routes/routeSubscription.py | 95 +-
modules/system/mainSystem.py | 9 +
scripts/.$import_diagram.drawio.bkp | 2723 -----------------
.../.$import_diagram_containers.drawio.bkp | 317 --
30 files changed, 2402 insertions(+), 3749 deletions(-)
delete mode 100644 scripts/.$import_diagram.drawio.bkp
delete mode 100644 scripts/.$import_diagram_containers.drawio.bkp
diff --git a/modules/connectors/connectorDbPostgre.py b/modules/connectors/connectorDbPostgre.py
index 83517f31..9675ffca 100644
--- a/modules/connectors/connectorDbPostgre.py
+++ b/modules/connectors/connectorDbPostgre.py
@@ -1,6 +1,7 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
import contextvars
+import re
import psycopg2
import psycopg2.extras
import logging
@@ -92,19 +93,27 @@ def _get_model_fields(model_class) -> Dict[str, str]:
fields[field_name] = extra["db_type"]
continue
- # Check for JSONB fields (Dict, List, or complex types)
+ # Unwrap Optional[X] → X (handles both typing.Union and types.UnionType)
+ origin = get_origin(field_type)
+ if origin is Union:
+ args = [a for a in get_args(field_type) if a is not type(None)]
+ if len(args) == 1:
+ field_type = args[0]
+ elif hasattr(field_type, '__args__') and type(None) in getattr(field_type, '__args__', ()):
+ args = [a for a in field_type.__args__ if a is not type(None)]
+ if len(args) == 1:
+ field_type = args[0]
+
if _isJsonbType(field_type):
fields[field_name] = "JSONB"
- elif field_type in (str, type(None)) or (
- get_origin(field_type) is Union and type(None) in get_args(field_type)
- ):
- fields[field_name] = "TEXT"
- elif field_type == int:
- fields[field_name] = "INTEGER"
- elif field_type == float:
- fields[field_name] = "DOUBLE PRECISION"
- elif field_type == bool:
+ elif field_type is bool:
fields[field_name] = "BOOLEAN"
+ elif field_type is int:
+ fields[field_name] = "INTEGER"
+ elif field_type is float:
+ fields[field_name] = "DOUBLE PRECISION"
+ elif field_type in (str, type(None)):
+ fields[field_name] = "TEXT"
else:
fields[field_name] = "TEXT"
@@ -135,6 +144,9 @@ def _parseRecordFields(record: Dict[str, Any], fields: Dict[str, str], context:
elif isinstance(value, list):
pass # already a list
+ elif fieldType == "BOOLEAN":
+ record[fieldName] = bool(value) if value is not None else False
+
elif fieldType == "JSONB" and value is not None:
try:
if isinstance(value, str):
@@ -969,6 +981,244 @@ class DatabaseConnector:
logger.error(f"Error loading records from table {table}: {e}")
return []
+ def _buildPaginationClauses(
+ self,
+ model_class: type,
+ pagination,
+ recordFilter: Dict[str, Any] = None,
+ ):
+ """
+ Translate PaginationParams + recordFilter into SQL clauses.
+ Returns (where_clause, order_clause, limit_clause, values, count_values).
+ """
+ fields = _get_model_fields(model_class)
+ validColumns = set(fields.keys())
+ where_parts: List[str] = []
+ values: List[Any] = []
+
+ if recordFilter:
+ for field, value in recordFilter.items():
+ if value is None:
+ where_parts.append(f'"{field}" IS NULL')
+ elif isinstance(value, list):
+ where_parts.append(f'"{field}" = ANY(%s)')
+ values.append(value)
+ else:
+ where_parts.append(f'"{field}" = %s')
+ values.append(value)
+
+ if pagination and pagination.filters:
+ for key, val in pagination.filters.items():
+ if key == "search" and isinstance(val, str) and val.strip():
+ term = f"%{val.strip()}%"
+ textCols = [c for c, t in fields.items() if t == "TEXT"]
+ if textCols:
+ orParts = [f'COALESCE("{c}"::TEXT, \'\') ILIKE %s' for c in textCols]
+ where_parts.append(f"({' OR '.join(orParts)})")
+ values.extend([term] * len(textCols))
+ continue
+ if key not in validColumns:
+ logger.debug(f"_buildPaginationClauses: key '{key}' NOT in validColumns {list(validColumns)[:10]}")
+ continue
+ colType = fields.get(key, "TEXT")
+ logger.debug(f"_buildPaginationClauses: filter key='{key}' val={val!r} type(val)={type(val).__name__} colType={colType}")
+ if isinstance(val, dict):
+ op = val.get("operator", "equals")
+ v = val.get("value", "")
+ if op in ("equals", "eq"):
+ if colType == "BOOLEAN":
+ where_parts.append(f'COALESCE("{key}", FALSE) = %s')
+ values.append(str(v).lower() == "true")
+ else:
+ where_parts.append(f'"{key}"::TEXT = %s')
+ values.append(str(v))
+ elif op == "contains":
+ where_parts.append(f'"{key}"::TEXT ILIKE %s')
+ values.append(f"%{v}%")
+ elif op == "startsWith":
+ where_parts.append(f'"{key}"::TEXT ILIKE %s')
+ values.append(f"{v}%")
+ elif op == "endsWith":
+ where_parts.append(f'"{key}"::TEXT ILIKE %s')
+ values.append(f"%{v}")
+ elif op in ("gt", "gte", "lt", "lte"):
+ sqlOp = {"gt": ">", "gte": ">=", "lt": "<", "lte": "<="}[op]
+ where_parts.append(f'"{key}"::TEXT {sqlOp} %s')
+ values.append(str(v))
+ elif op == "between":
+ fromVal = v.get("from", "") if isinstance(v, dict) else ""
+ toVal = v.get("to", "") if isinstance(v, dict) else ""
+ if not fromVal and not toVal:
+ continue
+ colType = fields.get(key, "TEXT")
+ isNumericCol = colType in ("INTEGER", "DOUBLE PRECISION")
+ isDateVal = bool(fromVal and re.match(r'^\d{4}-\d{2}-\d{2}$', str(fromVal))) or \
+ bool(toVal and re.match(r'^\d{4}-\d{2}-\d{2}$', str(toVal)))
+ if isNumericCol and isDateVal:
+ from datetime import datetime as _dt, timezone as _tz
+ if fromVal and toVal:
+ fromTs = _dt.strptime(str(fromVal), '%Y-%m-%d').replace(tzinfo=_tz.utc).timestamp()
+ toTs = _dt.strptime(str(toVal), '%Y-%m-%d').replace(hour=23, minute=59, second=59, tzinfo=_tz.utc).timestamp()
+ where_parts.append(f'"{key}" >= %s AND "{key}" <= %s')
+ values.extend([fromTs, toTs])
+ elif fromVal:
+ fromTs = _dt.strptime(str(fromVal), '%Y-%m-%d').replace(tzinfo=_tz.utc).timestamp()
+ where_parts.append(f'"{key}" >= %s')
+ values.append(fromTs)
+ else:
+ toTs = _dt.strptime(str(toVal), '%Y-%m-%d').replace(hour=23, minute=59, second=59, tzinfo=_tz.utc).timestamp()
+ where_parts.append(f'"{key}" <= %s')
+ values.append(toTs)
+ else:
+ if fromVal and toVal:
+ where_parts.append(f'"{key}"::TEXT >= %s AND "{key}"::TEXT <= %s')
+ values.extend([str(fromVal), str(toVal)])
+ elif fromVal:
+ where_parts.append(f'"{key}"::TEXT >= %s')
+ values.append(str(fromVal))
+ elif toVal:
+ where_parts.append(f'"{key}"::TEXT <= %s')
+ values.append(str(toVal))
+ else:
+ if colType == "BOOLEAN":
+ where_parts.append(f'COALESCE("{key}", FALSE) = %s')
+ values.append(str(val).lower() == "true")
+ else:
+ where_parts.append(f'"{key}"::TEXT ILIKE %s')
+ values.append(str(val))
+
+ where_clause = " WHERE " + " AND ".join(where_parts) if where_parts else ""
+ count_values = list(values)
+
+ orderParts: List[str] = []
+ if pagination and pagination.sort:
+ for sf in pagination.sort:
+ if sf.field in validColumns:
+ direction = "DESC" if sf.direction.lower() == "desc" else "ASC"
+ colType = fields.get(sf.field, "TEXT")
+ if colType == "BOOLEAN":
+ orderParts.append(f'COALESCE("{sf.field}", FALSE) {direction}')
+ else:
+ orderParts.append(f'"{sf.field}" {direction} NULLS LAST')
+ if not orderParts:
+ orderParts.append('"id"')
+ order_clause = " ORDER BY " + ", ".join(orderParts)
+
+ limit_clause = ""
+ if pagination:
+ offset = (pagination.page - 1) * pagination.pageSize
+ limit_clause = f" LIMIT {pagination.pageSize} OFFSET {offset}"
+
+ return where_clause, order_clause, limit_clause, values, count_values
+
+ def getRecordsetPaginated(
+ self,
+ model_class: type,
+ pagination=None,
+ recordFilter: Dict[str, Any] = None,
+ fieldFilter: List[str] = None,
+ ) -> Dict[str, Any]:
+ """
+ Returns paginated records with filtering + sorting at the SQL level.
+ Returns { "items": [...], "totalItems": int, "totalPages": int }.
+ If pagination is None, returns all records (no LIMIT/OFFSET).
+ """
+ from modules.datamodels.datamodelPagination import PaginationParams
+ import math
+
+ table = model_class.__name__
+
+ try:
+ if not self._ensureTableExists(model_class):
+ return {"items": [], "totalItems": 0, "totalPages": 0}
+
+ where_clause, order_clause, limit_clause, values, count_values = \
+ self._buildPaginationClauses(model_class, pagination, recordFilter)
+
+ with self.connection.cursor() as cursor:
+ countSql = f'SELECT COUNT(*) FROM "{table}"{where_clause}'
+ cursor.execute(countSql, count_values)
+ totalItems = cursor.fetchone()["count"]
+
+ dataSql = f'SELECT * FROM "{table}"{where_clause}{order_clause}{limit_clause}'
+ cursor.execute(dataSql, values)
+ records = [dict(row) for row in cursor.fetchall()]
+
+ fields = _get_model_fields(model_class)
+ modelFields = model_class.model_fields
+ for record in records:
+ _parseRecordFields(record, fields, f"table {table}")
+ for fieldName, fieldType in fields.items():
+ if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
+ fieldInfo = modelFields.get(fieldName)
+ if fieldInfo:
+ fieldAnnotation = fieldInfo.annotation
+ if (fieldAnnotation == list or
+ (hasattr(fieldAnnotation, "__origin__") and
+ fieldAnnotation.__origin__ is list)):
+ record[fieldName] = []
+ elif (fieldAnnotation == dict or
+ (hasattr(fieldAnnotation, "__origin__") and
+ fieldAnnotation.__origin__ is dict)):
+ record[fieldName] = {}
+
+ if fieldFilter and isinstance(fieldFilter, list):
+ records = [{f: r[f] for f in fieldFilter if f in r} for r in records]
+
+ pageSize = pagination.pageSize if pagination else max(totalItems, 1)
+ totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0
+
+ return {"items": records, "totalItems": totalItems, "totalPages": totalPages}
+ except Exception as e:
+ logger.error(f"Error in getRecordsetPaginated for table {table}: {e}")
+ return {"items": [], "totalItems": 0, "totalPages": 0}
+
+ def getDistinctColumnValues(
+ self,
+ model_class: type,
+ column: str,
+ pagination=None,
+ recordFilter: Dict[str, Any] = None,
+ ) -> List[str]:
+ """
+ Returns sorted distinct non-null values for a column using SQL DISTINCT.
+ Applies cross-filtering (all filters except the requested column).
+ """
+ table = model_class.__name__
+ fields = _get_model_fields(model_class)
+
+ if column not in fields:
+ return []
+
+ try:
+ if not self._ensureTableExists(model_class):
+ return []
+
+ if pagination:
+ if pagination.filters and column in pagination.filters:
+ import copy
+ pagination = copy.deepcopy(pagination)
+ pagination.filters.pop(column, None)
+
+ where_clause, _, _, values, _ = \
+ self._buildPaginationClauses(model_class, pagination, recordFilter)
+
+ sql = (
+ f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{where_clause} '
+ f'WHERE "{column}" IS NOT NULL AND "{column}"::TEXT != \'\' '
+ if not where_clause else
+ f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{where_clause} '
+ f'AND "{column}" IS NOT NULL AND "{column}"::TEXT != \'\' '
+ )
+ sql += 'ORDER BY val'
+
+ with self.connection.cursor() as cursor:
+ cursor.execute(sql, values)
+ return [row["val"] for row in cursor.fetchall()]
+ except Exception as e:
+ logger.error(f"Error in getDistinctColumnValues for {table}.{column}: {e}")
+ return []
+
def recordCreate(
self, model_class: type, record: Union[Dict[str, Any], BaseModel]
) -> Dict[str, Any]:
diff --git a/modules/datamodels/datamodelPagination.py b/modules/datamodels/datamodelPagination.py
index 9815027f..2719327b 100644
--- a/modules/datamodels/datamodelPagination.py
+++ b/modules/datamodels/datamodelPagination.py
@@ -98,6 +98,12 @@ def normalize_pagination_dict(pagination_dict: Dict[str, Any]) -> Dict[str, Any]
# Create a copy to avoid modifying the original
normalized = dict(pagination_dict)
+ # Ensure required fields have sensible defaults
+ if "page" not in normalized:
+ normalized["page"] = 1
+ if "pageSize" not in normalized:
+ normalized["pageSize"] = 25
+
# Move top-level "search" into filters if present
if "search" in normalized:
if "filters" not in normalized or normalized["filters"] is None:
diff --git a/modules/features/automation/interfaceFeatureAutomation.py b/modules/features/automation/interfaceFeatureAutomation.py
index 3b20ca3d..3fc2420b 100644
--- a/modules/features/automation/interfaceFeatureAutomation.py
+++ b/modules/features/automation/interfaceFeatureAutomation.py
@@ -270,7 +270,7 @@ class AutomationObjects:
if value.lower() not in itemValue.lower():
match = False
break
- elif itemValue != value:
+ elif str(itemValue).lower() != str(value).lower():
match = False
break
if match:
diff --git a/modules/features/automation/routeFeatureAutomation.py b/modules/features/automation/routeFeatureAutomation.py
index 6cdc4d44..81f0852b 100644
--- a/modules/features/automation/routeFeatureAutomation.py
+++ b/modules/features/automation/routeFeatureAutomation.py
@@ -109,6 +109,26 @@ def get_automations(
detail=f"Error getting automations: {str(e)}"
)
+@router.get("/filter-values")
+@limiter.limit("60/minute")
+def get_automation_filter_values(
+ request: Request,
+ column: str = Query(..., description="Column key"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
+ context: RequestContext = Depends(getRequestContext)
+) -> list:
+ """Return distinct filter values for a column in automations."""
+ try:
+ from modules.routes.routeDataUsers import _handleFilterValuesRequest
+ chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None)
+ result = chatInterface.getAllAutomationDefinitions(pagination=None)
+ items = result if isinstance(result, list) else [r if isinstance(r, dict) else r.model_dump() if hasattr(r, 'model_dump') else r for r in result]
+ return _handleFilterValuesRequest(items, column, pagination)
+ except Exception as e:
+ logger.error(f"Error getting filter values for automations: {str(e)}")
+ raise HTTPException(status_code=500, detail=str(e))
+
+
@router.post("", response_model=AutomationDefinition)
@limiter.limit("10/minute")
def create_automation(
@@ -1017,6 +1037,30 @@ def get_db_templates(
)
+@templateRouter.get("/filter-values")
+@limiter.limit("60/minute")
+def get_template_filter_values(
+ request: Request,
+ column: str = Query(..., description="Column key"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
+ context: RequestContext = Depends(getRequestContext)
+) -> list:
+ """Return distinct filter values for a column in automation templates."""
+ try:
+ from modules.routes.routeDataUsers import _handleFilterValuesRequest
+ chatInterface = getAutomationInterface(
+ context.user,
+ mandateId=str(context.mandateId) if context.mandateId else None,
+ featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None
+ )
+ result = chatInterface.getAllAutomationTemplates(pagination=None)
+ items = [r if isinstance(r, dict) else r.model_dump() if hasattr(r, 'model_dump') else r for r in result]
+ return _handleFilterValuesRequest(items, column, pagination)
+ except Exception as e:
+ logger.error(f"Error getting filter values for automation templates: {str(e)}")
+ raise HTTPException(status_code=500, detail=str(e))
+
+
@templateRouter.get("/attributes", response_model=Dict[str, Any])
def get_template_attributes(
request: Request
diff --git a/modules/features/chatbot/interfaceFeatureChatbot.py b/modules/features/chatbot/interfaceFeatureChatbot.py
index 559a9187..4a03bec9 100644
--- a/modules/features/chatbot/interfaceFeatureChatbot.py
+++ b/modules/features/chatbot/interfaceFeatureChatbot.py
@@ -514,7 +514,7 @@ class ChatObjects:
# Handle simple value (equals operator)
if not isinstance(filter_value, dict):
- if record_value != filter_value:
+ if str(record_value).lower() != str(filter_value).lower():
matches = False
break
continue
@@ -524,7 +524,7 @@ class ChatObjects:
filter_val = filter_value.get("value")
if operator in ["equals", "eq"]:
- if record_value != filter_val:
+ if str(record_value).lower() != str(filter_val).lower():
matches = False
break
@@ -609,7 +609,7 @@ class ChatObjects:
else:
# Unknown operator - default to equals
- if record_value != filter_val:
+ if str(record_value).lower() != str(filter_val).lower():
matches = False
break
diff --git a/modules/features/realEstate/interfaceFeatureRealEstate.py b/modules/features/realEstate/interfaceFeatureRealEstate.py
index 65601d9a..f7ed52b6 100644
--- a/modules/features/realEstate/interfaceFeatureRealEstate.py
+++ b/modules/features/realEstate/interfaceFeatureRealEstate.py
@@ -24,7 +24,8 @@ from modules.shared.configuration import APP_CONFIG
from modules.security.rbac import RbacClass
from modules.datamodels.datamodelRbac import AccessRuleContext
from modules.datamodels.datamodelUam import AccessLevel
-from modules.interfaces.interfaceRbac import getRecordsetWithRBAC
+from modules.interfaces.interfaceRbac import getRecordsetWithRBAC, getRecordsetPaginatedWithRBAC, getDistinctColumnValuesWithRBAC
+from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResult
logger = logging.getLogger(__name__)
@@ -175,17 +176,21 @@ class RealEstateObjects:
return Projekt(**records[0])
- def getProjekte(self, recordFilter: Optional[Dict[str, Any]] = None) -> List[Projekt]:
- """Get all projects matching the filter."""
- records = getRecordsetWithRBAC(
+ def getProjekte(self, recordFilter: Optional[Dict[str, Any]] = None, pagination: Optional[PaginationParams] = None) -> Union[List[Projekt], PaginatedResult]:
+ """Get all projects matching the filter with optional DB-level pagination."""
+ result = getRecordsetPaginatedWithRBAC(
self.db,
Projekt,
self.currentUser,
+ pagination=pagination,
recordFilter=recordFilter or {},
featureCode=self.FEATURE_CODE
)
-
- return [Projekt(**r) for r in records]
+
+ if isinstance(result, PaginatedResult):
+ result.items = [Projekt(**r) for r in result.items]
+ return result
+ return [Projekt(**r) for r in result]
def updateProjekt(self, projektId_or_projekt: Union[str, Projekt], updateData: Optional[Dict[str, Any]] = None) -> Optional[Projekt]:
"""Update a project.
@@ -265,20 +270,23 @@ class RealEstateObjects:
return Parzelle(**records[0])
- def getParzellen(self, recordFilter: Optional[Dict[str, Any]] = None) -> List[Parzelle]:
- """Get all plots matching the filter."""
- # Resolve location names to IDs if needed
+ def getParzellen(self, recordFilter: Optional[Dict[str, Any]] = None, pagination: Optional[PaginationParams] = None) -> Union[List[Parzelle], PaginatedResult]:
+ """Get all plots matching the filter with optional DB-level pagination."""
if recordFilter:
recordFilter = self._resolveLocationFilters(recordFilter)
-
- records = getRecordsetWithRBAC(
+
+ result = getRecordsetPaginatedWithRBAC(
self.db,
Parzelle,
self.currentUser,
+ pagination=pagination,
recordFilter=recordFilter or {}
)
-
- return [Parzelle(**r) for r in records]
+
+ if isinstance(result, PaginatedResult):
+ result.items = [Parzelle(**r) for r in result.items]
+ return result
+ return [Parzelle(**r) for r in result]
def _resolveLocationFilters(self, recordFilter: Dict[str, Any]) -> Dict[str, Any]:
"""
@@ -477,18 +485,23 @@ class RealEstateObjects:
return Dokument(**records[0])
- def getDokumente(self, recordFilter: Optional[Dict[str, Any]] = None) -> List[Dokument]:
- """Get all documents matching the filter."""
- records = getRecordsetWithRBAC(
+ def getDokumente(self, recordFilter: Optional[Dict[str, Any]] = None, pagination: Optional[PaginationParams] = None) -> Union[List[Dokument], PaginatedResult]:
+ """Get all documents matching the filter with optional DB-level pagination."""
+ result = getRecordsetPaginatedWithRBAC(
self.db,
Dokument,
self.currentUser,
+ pagination=pagination,
recordFilter=recordFilter or {},
mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId,
featureCode=self.FEATURE_CODE
)
- return [Dokument(**r) for r in records]
+
+ if isinstance(result, PaginatedResult):
+ result.items = [Dokument(**r) for r in result.items]
+ return result
+ return [Dokument(**r) for r in result]
def updateDokument(self, dokumentId: str, updateData: Dict[str, Any]) -> Optional[Dokument]:
"""Update a document."""
@@ -552,18 +565,23 @@ class RealEstateObjects:
return Gemeinde(**records[0])
- def getGemeinden(self, recordFilter: Optional[Dict[str, Any]] = None) -> List[Gemeinde]:
- """Get all municipalities matching the filter."""
- records = getRecordsetWithRBAC(
+ def getGemeinden(self, recordFilter: Optional[Dict[str, Any]] = None, pagination: Optional[PaginationParams] = None) -> Union[List[Gemeinde], PaginatedResult]:
+ """Get all municipalities matching the filter with optional DB-level pagination."""
+ result = getRecordsetPaginatedWithRBAC(
self.db,
Gemeinde,
self.currentUser,
+ pagination=pagination,
recordFilter=recordFilter or {},
mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId,
featureCode=self.FEATURE_CODE
)
- return [Gemeinde(**r) for r in records]
+
+ if isinstance(result, PaginatedResult):
+ result.items = [Gemeinde(**r) for r in result.items]
+ return result
+ return [Gemeinde(**r) for r in result]
def updateGemeinde(self, gemeindeId: str, updateData: Dict[str, Any]) -> Optional[Gemeinde]:
"""Update a municipality."""
@@ -627,18 +645,23 @@ class RealEstateObjects:
return Kanton(**records[0])
- def getKantone(self, recordFilter: Optional[Dict[str, Any]] = None) -> List[Kanton]:
- """Get all cantons matching the filter."""
- records = getRecordsetWithRBAC(
+ def getKantone(self, recordFilter: Optional[Dict[str, Any]] = None, pagination: Optional[PaginationParams] = None) -> Union[List[Kanton], PaginatedResult]:
+ """Get all cantons matching the filter with optional DB-level pagination."""
+ result = getRecordsetPaginatedWithRBAC(
self.db,
Kanton,
self.currentUser,
+ pagination=pagination,
recordFilter=recordFilter or {},
mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId,
featureCode=self.FEATURE_CODE
)
- return [Kanton(**r) for r in records]
+
+ if isinstance(result, PaginatedResult):
+ result.items = [Kanton(**r) for r in result.items]
+ return result
+ return [Kanton(**r) for r in result]
def updateKanton(self, kantonId: str, updateData: Dict[str, Any]) -> Optional[Kanton]:
"""Update a canton."""
@@ -700,16 +723,21 @@ class RealEstateObjects:
return Land(**records[0])
- def getLaender(self, recordFilter: Optional[Dict[str, Any]] = None) -> List[Land]:
- """Get all countries matching the filter."""
- records = getRecordsetWithRBAC(
+ def getLaender(self, recordFilter: Optional[Dict[str, Any]] = None, pagination: Optional[PaginationParams] = None) -> Union[List[Land], PaginatedResult]:
+ """Get all countries matching the filter with optional DB-level pagination."""
+ result = getRecordsetPaginatedWithRBAC(
self.db,
Land,
self.currentUser,
+ pagination=pagination,
recordFilter=recordFilter or {},
featureCode=self.FEATURE_CODE
)
- return [Land(**r) for r in records]
+
+ if isinstance(result, PaginatedResult):
+ result.items = [Land(**r) for r in result.items]
+ return result
+ return [Land(**r) for r in result]
def updateLand(self, landId: str, updateData: Dict[str, Any]) -> Optional[Land]:
"""Update a country."""
diff --git a/modules/features/realEstate/routeFeatureRealEstate.py b/modules/features/realEstate/routeFeatureRealEstate.py
index b4df86dd..82fa55ba 100644
--- a/modules/features/realEstate/routeFeatureRealEstate.py
+++ b/modules/features/realEstate/routeFeatureRealEstate.py
@@ -252,6 +252,31 @@ def get_projects(
return PaginatedResponse(items=items, pagination=None)
+@router.get("/{instanceId}/projects/filter-values")
+@limiter.limit("60/minute")
+def get_project_filter_values(
+ request: Request,
+ instanceId: str = Path(..., description="Feature Instance ID"),
+ column: str = Query(..., description="Column key"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
+ context: RequestContext = Depends(getRequestContext)
+) -> list:
+ """Return distinct filter values for a column in real estate projects."""
+ mandateId = _validateInstanceAccess(instanceId, context)
+ try:
+ from modules.routes.routeDataUsers import _handleFilterValuesRequest
+ interface = getRealEstateInterface(
+ context.user, mandateId=mandateId, featureInstanceId=instanceId
+ )
+ recordFilter = {"featureInstanceId": instanceId}
+ items = interface.getProjekte(recordFilter=recordFilter)
+ itemDicts = [i.model_dump() if hasattr(i, 'model_dump') else i for i in items]
+ return _handleFilterValuesRequest(itemDicts, column, pagination)
+ except Exception as e:
+ logger.error(f"Error getting filter values for projects: {str(e)}")
+ raise HTTPException(status_code=500, detail=str(e))
+
+
@router.get("/{instanceId}/projects/{projectId}", response_model=Projekt)
@limiter.limit("30/minute")
def get_project_by_id(
@@ -384,6 +409,31 @@ def get_parcels(
return PaginatedResponse(items=items, pagination=None)
+@router.get("/{instanceId}/parcels/filter-values")
+@limiter.limit("60/minute")
+def get_parcel_filter_values(
+ request: Request,
+ instanceId: str = Path(..., description="Feature Instance ID"),
+ column: str = Query(..., description="Column key"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
+ context: RequestContext = Depends(getRequestContext)
+) -> list:
+ """Return distinct filter values for a column in real estate parcels."""
+ mandateId = _validateInstanceAccess(instanceId, context)
+ try:
+ from modules.routes.routeDataUsers import _handleFilterValuesRequest
+ interface = getRealEstateInterface(
+ context.user, mandateId=mandateId, featureInstanceId=instanceId
+ )
+ recordFilter = {"featureInstanceId": instanceId}
+ items = interface.getParzellen(recordFilter=recordFilter)
+ itemDicts = [i.model_dump() if hasattr(i, 'model_dump') else i for i in items]
+ return _handleFilterValuesRequest(itemDicts, column, pagination)
+ except Exception as e:
+ logger.error(f"Error getting filter values for parcels: {str(e)}")
+ raise HTTPException(status_code=500, detail=str(e))
+
+
@router.get("/{instanceId}/parcels/{parcelId}", response_model=Parzelle)
@limiter.limit("30/minute")
def get_parcel_by_id(
diff --git a/modules/features/trustee/interfaceFeatureTrustee.py b/modules/features/trustee/interfaceFeatureTrustee.py
index a4b13c27..b9a95005 100644
--- a/modules/features/trustee/interfaceFeatureTrustee.py
+++ b/modules/features/trustee/interfaceFeatureTrustee.py
@@ -14,7 +14,7 @@ from pydantic import ValidationError
from modules.connectors.connectorDbPostgre import DatabaseConnector
from modules.shared.configuration import APP_CONFIG
-from modules.interfaces.interfaceRbac import getRecordsetWithRBAC
+from modules.interfaces.interfaceRbac import getRecordsetWithRBAC, getRecordsetPaginatedWithRBAC, getDistinctColumnValuesWithRBAC
from modules.security.rbac import RbacClass
from modules.datamodels.datamodelUam import User, AccessLevel
from modules.datamodels.datamodelRbac import AccessRuleContext
@@ -414,7 +414,7 @@ class TrusteeObjects:
# Handle simple value (equals operator)
if not isinstance(filterValue, dict):
- if recordValue == filterValue:
+ if str(recordValue).lower() == str(filterValue).lower():
fieldFiltered.append(record)
continue
@@ -424,7 +424,7 @@ class TrusteeObjects:
matches = False
if operator in ["equals", "eq"]:
- matches = recordValue == filterVal
+ matches = str(recordValue).lower() == str(filterVal).lower()
elif operator == "contains":
recordStr = str(recordValue).lower() if recordValue is not None else ""
@@ -577,8 +577,8 @@ class TrusteeObjects:
return None
return TrusteeOrganisation(**{k: v for k, v in records[0].items() if not k.startswith("_")})
- def getAllOrganisations(self, params: Optional[PaginationParams] = None) -> PaginatedResult:
- """Get all organisations with RBAC filtering.
+ def getAllOrganisations(self, params: Optional[PaginationParams] = None) -> Union[List[Dict], PaginatedResult]:
+ """Get all organisations with RBAC filtering and optional DB-level pagination.
Note: Organisations are managed at system level (by mandate).
Feature-level filtering (trustee.access) is NOT applied here because:
@@ -586,42 +586,18 @@ class TrusteeObjects:
- trustee.access grants access to specific orgs for other users
- New organisations wouldn't be visible without an access record
"""
- # Debug: Log user info and permissions
logger.debug(f"getAllOrganisations called for user {self.userId}, mandateId: {self.mandateId}")
- # System RBAC filtering (filters by mandate for GROUP access level)
- records = getRecordsetWithRBAC(
+ return getRecordsetPaginatedWithRBAC(
connector=self.db,
modelClass=TrusteeOrganisation,
currentUser=self.currentUser,
+ pagination=params,
recordFilter=None,
- orderBy="id",
mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId,
featureCode=self.FEATURE_CODE
)
- logger.debug(f"getAllOrganisations: getRecordsetWithRBAC returned {len(records)} records")
-
- # Apply pagination
- totalItems = len(records)
- if params:
- pageSize = params.pageSize or 20
- page = params.page or 1
- startIdx = (page - 1) * pageSize
- endIdx = startIdx + pageSize
- items = records[startIdx:endIdx]
- totalPages = math.ceil(totalItems / pageSize) if pageSize > 0 else 1
- else:
- items = records
- totalPages = 1
- page = 1
- pageSize = totalItems
-
- return PaginatedResult(
- items=items,
- totalItems=totalItems,
- totalPages=totalPages
- )
def updateOrganisation(self, orgId: str, data: Dict[str, Any]) -> Optional[TrusteeOrganisation]:
"""Update an organisation."""
@@ -677,51 +653,40 @@ class TrusteeObjects:
return None
return TrusteeRole(**{k: v for k, v in records[0].items() if not k.startswith("_")})
- def getAllRoles(self, params: Optional[PaginationParams] = None) -> PaginatedResult:
- """Get all roles with RBAC filtering.
+ def getAllRoles(self, params: Optional[PaginationParams] = None) -> Union[List[Dict], PaginatedResult]:
+ """Get all roles with RBAC filtering and optional DB-level pagination.
Note: Roles are available to all users with trustee access.
They are not filtered by organisation since they define the role types.
Users with ALL access level see all roles; others need trustee.access records.
+
+ NOTE(post-filter): Feature-level access check runs after the paginated query.
+ When pagination is active the totals may overcount if the user lacks any
+ trustee.access record (in that case the entire result is emptied).
"""
- records = getRecordsetWithRBAC(
+ result = getRecordsetPaginatedWithRBAC(
connector=self.db,
modelClass=TrusteeRole,
currentUser=self.currentUser,
+ pagination=params,
recordFilter=None,
- orderBy="id",
mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId,
featureCode=self.FEATURE_CODE
)
- # Users with ALL access level (from system RBAC) see all roles
- # Others need at least one trustee.access record
accessLevel = self.getRbacAccessLevel(TrusteeRole, "read")
if accessLevel != AccessLevel.ALL:
userAccess = self.getAllUserAccess(self.userId)
if not userAccess:
- records = [] # No trustee access at all
+ if isinstance(result, PaginatedResult):
+ result.items = []
+ result.totalItems = 0
+ result.totalPages = 0
+ return result
+ return []
- totalItems = len(records)
- if params:
- pageSize = params.pageSize or 20
- page = params.page or 1
- startIdx = (page - 1) * pageSize
- endIdx = startIdx + pageSize
- items = records[startIdx:endIdx]
- totalPages = math.ceil(totalItems / pageSize) if pageSize > 0 else 1
- else:
- items = records
- totalPages = 1
- page = 1
- pageSize = totalItems
-
- return PaginatedResult(
- items=items,
- totalItems=totalItems,
- totalPages=totalPages
- )
+ return result
def updateRole(self, roleId: str, data: Dict[str, Any]) -> Optional[TrusteeRole]:
"""Update a role (sysadmin only)."""
@@ -788,116 +753,113 @@ class TrusteeObjects:
return None
return TrusteeAccess(**{k: v for k, v in records[0].items() if not k.startswith("_")})
- def getAllAccess(self, params: Optional[PaginationParams] = None) -> PaginatedResult:
+ def getAllAccess(self, params: Optional[PaginationParams] = None) -> Union[List[Dict], PaginatedResult]:
"""Get all access records with RBAC filtering + feature-level filtering.
Users with ALL access level see all access records.
Others can only see access records for organisations they have admin access to.
+
+ NOTE(post-filter): Feature-level admin-org filtering runs after the paginated
+ query, so totals may overcount when the user has restricted org access.
"""
- # Step 1: System RBAC filtering
- records = getRecordsetWithRBAC(
+ result = getRecordsetPaginatedWithRBAC(
connector=self.db,
modelClass=TrusteeAccess,
currentUser=self.currentUser,
+ pagination=params,
recordFilter=None,
- orderBy="id",
mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId,
featureCode=self.FEATURE_CODE
)
- # Users with ALL access level (from system RBAC) see all records
accessLevel = self.getRbacAccessLevel(TrusteeAccess, "read")
-
+
if accessLevel != AccessLevel.ALL:
- # Step 2: Feature-level filtering - only see access for organisations where user is admin
userAccess = self.getAllUserAccess(self.userId)
-
- # Get organisations where user has admin role
adminOrgs = set()
for access in userAccess:
if access.get("roleId") == "admin":
adminOrgs.add(access.get("organisationId"))
-
- # Filter records to only show those in admin organisations
+
+ if isinstance(result, PaginatedResult):
+ if adminOrgs:
+ result.items = [r for r in result.items if r.get("organisationId") in adminOrgs]
+ else:
+ result.items = []
+ return result
+
if adminOrgs:
- records = [r for r in records if r.get("organisationId") in adminOrgs]
+ result = [r for r in result if r.get("organisationId") in adminOrgs]
else:
- records = []
+ result = []
- totalItems = len(records)
- if params:
- pageSize = params.pageSize or 20
- page = params.page or 1
- startIdx = (page - 1) * pageSize
- endIdx = startIdx + pageSize
- items = records[startIdx:endIdx]
- totalPages = math.ceil(totalItems / pageSize) if pageSize > 0 else 1
- else:
- items = records
- totalPages = 1
- page = 1
- pageSize = totalItems
+ return result
- return PaginatedResult(
- items=items,
- totalItems=totalItems,
- totalPages=totalPages
- )
-
- def getAccessByOrganisation(self, organisationId: str) -> List[TrusteeAccess]:
+ def getAccessByOrganisation(self, organisationId: str, pagination: Optional[PaginationParams] = None) -> Union[List[TrusteeAccess], PaginatedResult]:
"""Get all access records for a specific organisation.
Requires admin role for the organisation.
"""
- # Check if user has admin access for this organisation
if not self.checkUserTrusteePermission(self.userId, organisationId, "admin"):
logger.warning(f"User {self.userId} lacks admin role for organisation {organisationId}")
return []
-
- records = getRecordsetWithRBAC(
+
+ result = getRecordsetPaginatedWithRBAC(
connector=self.db,
modelClass=TrusteeAccess,
currentUser=self.currentUser,
+ pagination=pagination,
recordFilter={"organisationId": organisationId},
- orderBy="id",
mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId,
featureCode=self.FEATURE_CODE
)
- return [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in records]
- def getAccessByUser(self, userId: str) -> List[TrusteeAccess]:
+ if isinstance(result, PaginatedResult):
+ result.items = [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in result.items]
+ return result
+ return [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in result]
+
+ def getAccessByUser(self, userId: str, pagination: Optional[PaginationParams] = None) -> Union[List[TrusteeAccess], PaginatedResult]:
"""Get all access records for a specific user.
Users with ALL access level see all access records.
Others can only see access records for organisations where they have admin role.
+
+ NOTE(post-filter): Admin-org filtering runs after the paginated query,
+ so totals may overcount when the user has restricted org access.
"""
- records = getRecordsetWithRBAC(
+ result = getRecordsetPaginatedWithRBAC(
connector=self.db,
modelClass=TrusteeAccess,
currentUser=self.currentUser,
+ pagination=pagination,
recordFilter={"userId": userId},
- orderBy="id",
mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId,
featureCode=self.FEATURE_CODE
)
-
- # Users with ALL access level (from system RBAC) see all records
+
accessLevel = self.getRbacAccessLevel(TrusteeAccess, "read")
- if accessLevel == AccessLevel.ALL:
- return [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in records]
-
- # Filter to only organisations where current user has admin role
- userAccess = self.getAllUserAccess(self.userId)
- adminOrgs = set()
- for access in userAccess:
- if access.get("roleId") == "admin":
- adminOrgs.add(access.get("organisationId"))
-
- filtered = [r for r in records if r.get("organisationId") in adminOrgs]
- return [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in filtered]
+
+ if accessLevel != AccessLevel.ALL:
+ userAccess = self.getAllUserAccess(self.userId)
+ adminOrgs = set()
+ for access in userAccess:
+ if access.get("roleId") == "admin":
+ adminOrgs.add(access.get("organisationId"))
+
+ if isinstance(result, PaginatedResult):
+ result.items = [r for r in result.items if r.get("organisationId") in adminOrgs]
+ result.items = [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in result.items]
+ return result
+ result = [r for r in result if r.get("organisationId") in adminOrgs]
+
+ if isinstance(result, PaginatedResult):
+ result.items = [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in result.items]
+ return result
+ return [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in result]
def updateAccess(self, accessId: str, data: Dict[str, Any]) -> Optional[TrusteeAccess]:
"""Update an access record. Requires admin role for the organisation or ALL access level."""
@@ -988,54 +950,36 @@ class TrusteeObjects:
return None
return TrusteeContract(**{k: v for k, v in records[0].items() if not k.startswith("_")})
- def getAllContracts(self, params: Optional[PaginationParams] = None) -> PaginatedResult:
- """Get all contracts with RBAC filtering + feature-level access filtering."""
- # Step 1: System RBAC filtering
- records = getRecordsetWithRBAC(
+ def getAllContracts(self, params: Optional[PaginationParams] = None) -> Union[List[Dict], PaginatedResult]:
+ """Get all contracts with RBAC filtering and optional DB-level pagination."""
+ return getRecordsetPaginatedWithRBAC(
connector=self.db,
modelClass=TrusteeContract,
currentUser=self.currentUser,
+ pagination=params,
recordFilter=None,
- orderBy="id",
mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId,
featureCode=self.FEATURE_CODE
)
- totalItems = len(records)
- if params:
- pageSize = params.pageSize or 20
- page = params.page or 1
- startIdx = (page - 1) * pageSize
- endIdx = startIdx + pageSize
- items = records[startIdx:endIdx]
- totalPages = math.ceil(totalItems / pageSize) if pageSize > 0 else 1
- else:
- items = records
- totalPages = 1
- page = 1
- pageSize = totalItems
-
- return PaginatedResult(
- items=items,
- totalItems=totalItems,
- totalPages=totalPages
- )
-
- def getContractsByOrganisation(self, organisationId: str) -> List[TrusteeContract]:
+ def getContractsByOrganisation(self, organisationId: str, pagination: Optional[PaginationParams] = None) -> Union[List[TrusteeContract], PaginatedResult]:
"""Get all contracts for a specific organisation."""
- # Step 1: System RBAC filtering
- records = getRecordsetWithRBAC(
+ result = getRecordsetPaginatedWithRBAC(
connector=self.db,
modelClass=TrusteeContract,
currentUser=self.currentUser,
+ pagination=pagination,
recordFilter={"organisationId": organisationId},
- orderBy="label",
mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId,
featureCode=self.FEATURE_CODE
)
- return [TrusteeContract(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in records]
+
+ if isinstance(result, PaginatedResult):
+ result.items = [TrusteeContract(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in result.items]
+ return result
+ return [TrusteeContract(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in result]
def updateContract(self, contractId: str, data: Dict[str, Any]) -> Optional[TrusteeContract]:
"""Update a contract (organisationId is immutable)."""
@@ -1142,75 +1086,58 @@ class TrusteeObjects:
# Legacy fallback: documentData was stored directly (for migration)
return record.get("documentData")
- def getAllDocuments(self, params: Optional[PaginationParams] = None) -> PaginatedResult:
- """Get all documents with RBAC filtering + feature-level access filtering (metadata only)."""
- # Step 1: System RBAC filtering
- records = getRecordsetWithRBAC(
+ def getAllDocuments(self, params: Optional[PaginationParams] = None) -> Union[List[Dict], PaginatedResult]:
+ """Get all documents with RBAC filtering and optional DB-level pagination (metadata only).
+
+ Filtering, sorting, and pagination are handled at the SQL level by
+ getRecordsetPaginatedWithRBAC. Binary documentData is stripped from
+ the returned items.
+ """
+ result = getRecordsetPaginatedWithRBAC(
connector=self.db,
modelClass=TrusteeDocument,
currentUser=self.currentUser,
+ pagination=params,
recordFilter=None,
- orderBy="documentName",
mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId,
featureCode=self.FEATURE_CODE
)
- # Clean records (remove binary data and internal fields) - keep as dicts for filtering/sorting
- cleanedRecords = []
- for record in records:
- cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_") and k != "documentData"}
- cleanedRecords.append(cleanedRecord)
+ def _cleanDocumentRecords(records):
+ return [
+ TrusteeDocument(**{k: v for k, v in r.items() if not k.startswith("_") and k != "documentData"})
+ for r in records
+ ]
- # Step 2: Apply filters (search and field filters)
- filteredRecords = self._applyFilters(cleanedRecords, params)
-
- # Step 3: Apply sorting
- sortedRecords = self._applySorting(filteredRecords, params)
-
- # Step 4: Convert to Pydantic objects
- pydanticItems = [TrusteeDocument(**r) for r in sortedRecords]
+ if isinstance(result, PaginatedResult):
+ result.items = _cleanDocumentRecords(result.items)
+ return result
+ return _cleanDocumentRecords(result)
- # Step 5: Apply pagination
- totalItems = len(pydanticItems)
- if params:
- pageSize = params.pageSize or 20
- page = params.page or 1
- startIdx = (page - 1) * pageSize
- endIdx = startIdx + pageSize
- items = pydanticItems[startIdx:endIdx]
- totalPages = math.ceil(totalItems / pageSize) if pageSize > 0 else 1
- else:
- items = pydanticItems
- totalPages = 1
- page = 1
- pageSize = totalItems
-
- return PaginatedResult(
- items=items,
- totalItems=totalItems,
- totalPages=totalPages
- )
-
- def getDocumentsByContract(self, contractId: str) -> List[TrusteeDocument]:
+ def getDocumentsByContract(self, contractId: str, pagination: Optional[PaginationParams] = None) -> Union[List[TrusteeDocument], PaginatedResult]:
"""Get all documents for a specific contract."""
- # Step 1: System RBAC filtering
- records = getRecordsetWithRBAC(
+ result = getRecordsetPaginatedWithRBAC(
connector=self.db,
modelClass=TrusteeDocument,
currentUser=self.currentUser,
+ pagination=pagination,
recordFilter={"contractId": contractId},
- orderBy="documentName",
mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId,
featureCode=self.FEATURE_CODE
)
-
- result = []
- for record in records:
- cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_") and k != "documentData"}
- result.append(TrusteeDocument(**cleanedRecord))
- return result
+
+ def _cleanDocumentRecords(records):
+ return [
+ TrusteeDocument(**{k: v for k, v in r.items() if not k.startswith("_") and k != "documentData"})
+ for r in records
+ ]
+
+ if isinstance(result, PaginatedResult):
+ result.items = _cleanDocumentRecords(result.items)
+ return result
+ return _cleanDocumentRecords(result)
def updateDocument(self, documentId: str, data: Dict[str, Any]) -> Optional[TrusteeDocument]:
"""Update a document.
@@ -1340,101 +1267,94 @@ class TrusteeObjects:
return None
return self._toTrusteePositionOrDelete(records[0], deleteCorrupt=True)
- def getAllPositions(self, params: Optional[PaginationParams] = None) -> PaginatedResult:
- """Get all positions with RBAC filtering + feature-level access filtering."""
- # Step 1: System RBAC filtering
- records = getRecordsetWithRBAC(
+ def getAllPositions(self, params: Optional[PaginationParams] = None) -> Union[List[Dict], PaginatedResult]:
+ """Get all positions with RBAC filtering and optional DB-level pagination.
+
+ Filtering, sorting, and pagination are handled at the SQL level.
+ Post-processing cleans internal fields (keeps _createdAt) and validates
+ each record via _toTrusteePositionOrDelete (corrupt rows are deleted).
+
+ NOTE(post-process): totalItems may slightly overcount when corrupt legacy
+ records are removed from the current page.
+ """
+ result = getRecordsetPaginatedWithRBAC(
connector=self.db,
modelClass=TrusteePosition,
currentUser=self.currentUser,
+ pagination=params,
recordFilter=None,
- orderBy="valuta",
mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId,
featureCode=self.FEATURE_CODE
)
- # Clean records (remove internal fields except _createdAt) - keep as dicts for filtering/sorting
- # Keep _createdAt for display in frontend
keepFields = {'_createdAt'}
- cleanedRecords = []
- for record in records:
- cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_") or k in keepFields}
- cleanedRecords.append(cleanedRecord)
- # Step 2: Apply filters (search and field filters)
- filteredRecords = self._applyFilters(cleanedRecords, params)
-
- # Step 3: Apply sorting
- sortedRecords = self._applySorting(filteredRecords, params)
-
- # Step 4: Convert to Pydantic objects and cleanup corrupt legacy records.
- pydanticItems = []
- for record in sortedRecords:
- position = self._toTrusteePositionOrDelete(record, deleteCorrupt=True)
- if position is not None:
- pydanticItems.append(position)
+ def _cleanAndValidate(records):
+ items = []
+ for record in records:
+ cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_") or k in keepFields}
+ position = self._toTrusteePositionOrDelete(cleanedRecord, deleteCorrupt=True)
+ if position is not None:
+ items.append(position)
+ return items
- # Step 5: Apply pagination
- totalItems = len(pydanticItems)
- if params:
- pageSize = params.pageSize or 20
- page = params.page or 1
- startIdx = (page - 1) * pageSize
- endIdx = startIdx + pageSize
- items = pydanticItems[startIdx:endIdx]
- totalPages = math.ceil(totalItems / pageSize) if pageSize > 0 else 1
- else:
- items = pydanticItems
- totalPages = 1
- page = 1
- pageSize = totalItems
+ if isinstance(result, PaginatedResult):
+ result.items = _cleanAndValidate(result.items)
+ return result
+ return _cleanAndValidate(result)
- return PaginatedResult(
- items=items,
- totalItems=totalItems,
- totalPages=totalPages
- )
-
- def getPositionsByContract(self, contractId: str) -> List[TrusteePosition]:
+ def getPositionsByContract(self, contractId: str, pagination: Optional[PaginationParams] = None) -> Union[List[TrusteePosition], PaginatedResult]:
"""Get all positions for a specific contract."""
- # Step 1: System RBAC filtering
- records = getRecordsetWithRBAC(
+ result = getRecordsetPaginatedWithRBAC(
connector=self.db,
modelClass=TrusteePosition,
currentUser=self.currentUser,
+ pagination=pagination,
recordFilter={"contractId": contractId},
- orderBy="valuta",
mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId,
featureCode=self.FEATURE_CODE
)
- safeItems = []
- for record in records:
- position = self._toTrusteePositionOrDelete(record, deleteCorrupt=True)
- if position is not None:
- safeItems.append(position)
- return safeItems
- def getPositionsByOrganisation(self, organisationId: str) -> List[TrusteePosition]:
+ def _validatePositions(records):
+ items = []
+ for record in records:
+ position = self._toTrusteePositionOrDelete(record, deleteCorrupt=True)
+ if position is not None:
+ items.append(position)
+ return items
+
+ if isinstance(result, PaginatedResult):
+ result.items = _validatePositions(result.items)
+ return result
+ return _validatePositions(result)
+
+ def getPositionsByOrganisation(self, organisationId: str, pagination: Optional[PaginationParams] = None) -> Union[List[TrusteePosition], PaginatedResult]:
"""Get all positions for a specific organisation."""
- # Step 1: System RBAC filtering
- records = getRecordsetWithRBAC(
+ result = getRecordsetPaginatedWithRBAC(
connector=self.db,
modelClass=TrusteePosition,
currentUser=self.currentUser,
+ pagination=pagination,
recordFilter={"organisationId": organisationId},
- orderBy="valuta",
mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId,
featureCode=self.FEATURE_CODE
)
- safeItems = []
- for record in records:
- position = self._toTrusteePositionOrDelete(record, deleteCorrupt=True)
- if position is not None:
- safeItems.append(position)
- return safeItems
+
+ def _validatePositions(records):
+ items = []
+ for record in records:
+ position = self._toTrusteePositionOrDelete(record, deleteCorrupt=True)
+ if position is not None:
+ items.append(position)
+ return items
+
+ if isinstance(result, PaginatedResult):
+ result.items = _validatePositions(result.items)
+ return result
+ return _validatePositions(result)
def updatePosition(self, positionId: str, data: Dict[str, Any]) -> Optional[TrusteePosition]:
"""Update a position.
@@ -1481,24 +1401,31 @@ class TrusteeObjects:
# ===== Position-Document Queries =====
- def getPositionsByDocument(self, documentId: str) -> List[TrusteePosition]:
+ def getPositionsByDocument(self, documentId: str, pagination: Optional[PaginationParams] = None) -> Union[List[TrusteePosition], PaginatedResult]:
"""Get all positions that reference a specific document (1:N via documentId FK)."""
- records = getRecordsetWithRBAC(
+ result = getRecordsetPaginatedWithRBAC(
connector=self.db,
modelClass=TrusteePosition,
currentUser=self.currentUser,
+ pagination=pagination,
recordFilter={"documentId": documentId},
- orderBy="valuta",
mandateId=self.mandateId,
featureInstanceId=self.featureInstanceId,
featureCode=self.FEATURE_CODE
)
- safeItems = []
- for record in records:
- position = self._toTrusteePositionOrDelete(record, deleteCorrupt=True)
- if position is not None:
- safeItems.append(position)
- return safeItems
+
+ def _validatePositions(records):
+ items = []
+ for record in records:
+ position = self._toTrusteePositionOrDelete(record, deleteCorrupt=True)
+ if position is not None:
+ items.append(position)
+ return items
+
+ if isinstance(result, PaginatedResult):
+ result.items = _validatePositions(result.items)
+ return result
+ return _validatePositions(result)
# ===== Trustee-specific Access Check =====
diff --git a/modules/features/trustee/routeFeatureTrustee.py b/modules/features/trustee/routeFeatureTrustee.py
index feb873ae..13b28b07 100644
--- a/modules/features/trustee/routeFeatureTrustee.py
+++ b/modules/features/trustee/routeFeatureTrustee.py
@@ -338,7 +338,7 @@ def get_organisations(
interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId)
result = interface.getAllOrganisations(paginationParams)
- if paginationParams:
+ if paginationParams and hasattr(result, 'items'):
return PaginatedResponse(
items=result.items,
pagination=PaginationMetadata(
@@ -350,7 +350,7 @@ def get_organisations(
filters=paginationParams.filters if paginationParams else None
)
)
- return PaginatedResponse(items=result.items, pagination=None)
+ return PaginatedResponse(items=result if isinstance(result, list) else result.items, pagination=None)
@router.get("/{instanceId}/organisations/{orgId}", response_model=TrusteeOrganisation)
@@ -451,7 +451,7 @@ def get_roles(
interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId)
result = interface.getAllRoles(paginationParams)
- if paginationParams:
+ if paginationParams and hasattr(result, 'items'):
return PaginatedResponse(
items=result.items,
pagination=PaginationMetadata(
@@ -463,7 +463,7 @@ def get_roles(
filters=paginationParams.filters if paginationParams else None
)
)
- return PaginatedResponse(items=result.items, pagination=None)
+ return PaginatedResponse(items=result if isinstance(result, list) else result.items, pagination=None)
@router.get("/{instanceId}/roles/{roleId}", response_model=TrusteeRole)
@@ -564,7 +564,7 @@ def get_all_access(
interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId)
result = interface.getAllAccess(paginationParams)
- if paginationParams:
+ if paginationParams and hasattr(result, 'items'):
return PaginatedResponse(
items=result.items,
pagination=PaginationMetadata(
@@ -576,7 +576,7 @@ def get_all_access(
filters=paginationParams.filters if paginationParams else None
)
)
- return PaginatedResponse(items=result.items, pagination=None)
+ return PaginatedResponse(items=result if isinstance(result, list) else result.items, pagination=None)
@router.get("/{instanceId}/access/{accessId}", response_model=TrusteeAccess)
@@ -707,7 +707,7 @@ def get_contracts(
interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId)
result = interface.getAllContracts(paginationParams)
- if paginationParams:
+ if paginationParams and hasattr(result, 'items'):
return PaginatedResponse(
items=result.items,
pagination=PaginationMetadata(
@@ -719,7 +719,7 @@ def get_contracts(
filters=paginationParams.filters if paginationParams else None
)
)
- return PaginatedResponse(items=result.items, pagination=None)
+ return PaginatedResponse(items=result if isinstance(result, list) else result.items, pagination=None)
@router.get("/{instanceId}/contracts/{contractId}", response_model=TrusteeContract)
@@ -835,7 +835,7 @@ def get_documents(
interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId)
result = interface.getAllDocuments(paginationParams)
- if paginationParams:
+ if paginationParams and hasattr(result, 'items'):
return PaginatedResponse(
items=result.items,
pagination=PaginationMetadata(
@@ -847,7 +847,59 @@ def get_documents(
filters=paginationParams.filters if paginationParams else None
)
)
- return PaginatedResponse(items=result.items, pagination=None)
+ return PaginatedResponse(items=result if isinstance(result, list) else result.items, pagination=None)
+
+
+@router.get("/{instanceId}/documents/filter-values")
+@limiter.limit("60/minute")
+def get_document_filter_values(
+ request: Request,
+ instanceId: str = Path(..., description="Feature Instance ID"),
+ column: str = Query(..., description="Column key"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
+ context: RequestContext = Depends(getRequestContext)
+) -> list:
+ """Return distinct filter values for a column in trustee documents."""
+ mandateId = _validateInstanceAccess(instanceId, context)
+ try:
+ from modules.interfaces.interfaceRbac import getDistinctColumnValuesWithRBAC
+ interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId)
+
+ crossFilterPagination = None
+ if pagination:
+ try:
+ paginationDict = json.loads(pagination)
+ if paginationDict:
+ paginationDict = normalize_pagination_dict(paginationDict)
+ filters = paginationDict.get("filters", {})
+ filters.pop(column, None)
+ paginationDict["filters"] = filters
+ paginationDict.pop("sort", None)
+ crossFilterPagination = PaginationParams(**paginationDict)
+ except (json.JSONDecodeError, ValueError):
+ pass
+
+ try:
+ values = getDistinctColumnValuesWithRBAC(
+ connector=interface.db,
+ modelClass=TrusteeDocument,
+ column=column,
+ currentUser=interface.currentUser,
+ pagination=crossFilterPagination,
+ recordFilter=None,
+ mandateId=interface.mandateId,
+ featureInstanceId=interface.featureInstanceId,
+ featureCode=interface.FEATURE_CODE
+ )
+ return sorted(values, key=lambda v: str(v).lower())
+ except Exception:
+ from modules.routes.routeDataUsers import _handleFilterValuesRequest
+ result = interface.getAllDocuments(None)
+ items = [r.model_dump() if hasattr(r, 'model_dump') else r for r in (result.items if hasattr(result, 'items') else result)]
+ return _handleFilterValuesRequest(items, column, pagination)
+ except Exception as e:
+ logger.error(f"Error getting filter values for trustee documents: {str(e)}")
+ raise HTTPException(status_code=500, detail=str(e))
@router.get("/{instanceId}/documents/{documentId}", response_model=TrusteeDocument)
@@ -1039,7 +1091,7 @@ def get_positions(
interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId)
result = interface.getAllPositions(paginationParams)
- if paginationParams:
+ if paginationParams and hasattr(result, 'items'):
return PaginatedResponse(
items=result.items,
pagination=PaginationMetadata(
@@ -1051,7 +1103,59 @@ def get_positions(
filters=paginationParams.filters if paginationParams else None
)
)
- return PaginatedResponse(items=result.items, pagination=None)
+ return PaginatedResponse(items=result if isinstance(result, list) else result.items, pagination=None)
+
+
+@router.get("/{instanceId}/positions/filter-values")
+@limiter.limit("60/minute")
+def get_position_filter_values(
+ request: Request,
+ instanceId: str = Path(..., description="Feature Instance ID"),
+ column: str = Query(..., description="Column key"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
+ context: RequestContext = Depends(getRequestContext)
+) -> list:
+ """Return distinct filter values for a column in trustee positions."""
+ mandateId = _validateInstanceAccess(instanceId, context)
+ try:
+ from modules.interfaces.interfaceRbac import getDistinctColumnValuesWithRBAC
+ interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId)
+
+ crossFilterPagination = None
+ if pagination:
+ try:
+ paginationDict = json.loads(pagination)
+ if paginationDict:
+ paginationDict = normalize_pagination_dict(paginationDict)
+ filters = paginationDict.get("filters", {})
+ filters.pop(column, None)
+ paginationDict["filters"] = filters
+ paginationDict.pop("sort", None)
+ crossFilterPagination = PaginationParams(**paginationDict)
+ except (json.JSONDecodeError, ValueError):
+ pass
+
+ try:
+ values = getDistinctColumnValuesWithRBAC(
+ connector=interface.db,
+ modelClass=TrusteePosition,
+ column=column,
+ currentUser=interface.currentUser,
+ pagination=crossFilterPagination,
+ recordFilter=None,
+ mandateId=interface.mandateId,
+ featureInstanceId=interface.featureInstanceId,
+ featureCode=interface.FEATURE_CODE
+ )
+ return sorted(values, key=lambda v: str(v).lower())
+ except Exception:
+ from modules.routes.routeDataUsers import _handleFilterValuesRequest
+ result = interface.getAllPositions(None)
+ items = [r.model_dump() if hasattr(r, 'model_dump') else r for r in (result.items if hasattr(result, 'items') else result)]
+ return _handleFilterValuesRequest(items, column, pagination)
+ except Exception as e:
+ logger.error(f"Error getting filter values for trustee positions: {str(e)}")
+ raise HTTPException(status_code=500, detail=str(e))
@router.get("/{instanceId}/positions/{positionId}", response_model=TrusteePosition)
diff --git a/modules/interfaces/interfaceDbApp.py b/modules/interfaces/interfaceDbApp.py
index 092fd589..12eb935b 100644
--- a/modules/interfaces/interfaceDbApp.py
+++ b/modules/interfaces/interfaceDbApp.py
@@ -423,7 +423,7 @@ class AppObjects:
else:
# Unknown operator - default to equals
- if record_value != filter_val:
+ if str(record_value).lower() != str(filter_val).lower():
matches = False
break
@@ -512,53 +512,34 @@ class AppObjects:
If pagination is None: List[User]
If pagination is provided: PaginatedResult with items and metadata
"""
- # Get user IDs via UserMandate junction table (UserInDB has no mandateId column)
userMandates = self.db.getRecordset(UserMandate, recordFilter={"mandateId": mandateId})
userIds = [um.get("userId") for um in userMandates if um.get("userId")]
-
- # Fetch each user by ID
- filteredUsers = []
- for userId in userIds:
- userRecords = self.db.getRecordset(UserInDB, recordFilter={"id": userId})
- if userRecords:
- cleanedUser = {k: v for k, v in userRecords[0].items() if not k.startswith("_")}
- if cleanedUser.get("roleLabels") is None:
- cleanedUser["roleLabels"] = []
- filteredUsers.append(cleanedUser)
- # If no pagination requested, return all items
+ if not userIds:
+ if pagination is None:
+ return []
+ return PaginatedResult(items=[], totalItems=0, totalPages=0)
+
+ result = self.db.getRecordsetPaginated(
+ UserInDB,
+ pagination=pagination,
+ recordFilter={"id": userIds}
+ )
+
+ items = []
+ for record in result["items"]:
+ cleanedUser = {k: v for k, v in record.items() if not k.startswith("_")}
+ if cleanedUser.get("roleLabels") is None:
+ cleanedUser["roleLabels"] = []
+ items.append(User(**cleanedUser))
+
if pagination is None:
- return [User(**user) for user in filteredUsers]
-
- # Apply filtering (if filters provided)
- if pagination.filters:
- filteredUsers = self._applyFilters(filteredUsers, pagination.filters)
-
- # Apply sorting (in order of sortFields)
- if pagination.sort:
- filteredUsers = self._applySorting(filteredUsers, pagination.sort)
-
- # Count total items after filters
- totalItems = len(filteredUsers)
- totalPages = math.ceil(totalItems / pagination.pageSize) if totalItems > 0 else 0
-
- # Apply pagination (skip/limit)
- startIdx = (pagination.page - 1) * pagination.pageSize
- endIdx = startIdx + pagination.pageSize
- pagedUsers = filteredUsers[startIdx:endIdx]
-
- # Ensure roleLabels is always a list for paginated results too
- for user in pagedUsers:
- if user.get("roleLabels") is None:
- user["roleLabels"] = []
-
- # Convert to model objects
- items = [User(**user) for user in pagedUsers]
-
+ return items
+
return PaginatedResult(
items=items,
- totalItems=totalItems,
- totalPages=totalPages
+ totalItems=result["totalItems"],
+ totalPages=result["totalPages"]
)
def getUserByUsername(self, username: str) -> Optional[User]:
@@ -2312,30 +2293,43 @@ class AppObjects:
# Additional Helper Methods
# ============================================
- def getAllUsers(self) -> List[User]:
+ def getAllUsers(self, pagination: Optional[PaginationParams] = None) -> Union[List[User], PaginatedResult]:
"""
Get all users (for SysAdmin only).
+ Args:
+ pagination: Optional pagination parameters. If None, returns all items.
+
Returns:
- List of User objects (without sensitive fields)
+ If pagination is None: List[User] (without sensitive fields)
+ If pagination is provided: PaginatedResult with items and metadata
"""
try:
- records = self.db.getRecordset(UserInDB)
- result = []
- for record in records:
- # Filter out sensitive and internal fields
+ result = self.db.getRecordsetPaginated(UserInDB, pagination=pagination)
+
+ items = []
+ for record in result["items"]:
cleanedRecord = {
- k: v for k, v in record.items()
+ k: v for k, v in record.items()
if not k.startswith("_") and k not in ["hashedPassword", "resetToken", "resetTokenExpires"]
}
- # Ensure roleLabels is a list
if cleanedRecord.get("roleLabels") is None:
cleanedRecord["roleLabels"] = []
- result.append(User(**cleanedRecord))
- return result
+ items.append(User(**cleanedRecord))
+
+ if pagination is None:
+ return items
+
+ return PaginatedResult(
+ items=items,
+ totalItems=result["totalItems"],
+ totalPages=result["totalPages"]
+ )
except Exception as e:
logger.error(f"Error getting all users: {e}")
- return []
+ if pagination is None:
+ return []
+ return PaginatedResult(items=[], totalItems=0, totalPages=0)
def getUserMandateById(self, userMandateId: str) -> Optional[UserMandate]:
"""
@@ -3321,50 +3315,26 @@ class AppObjects:
If pagination is provided: PaginatedResult with items and metadata
"""
try:
- # Get all roles from database
- roles = self.db.getRecordset(Role)
-
- # Filter out database-specific fields
- filteredRoles = []
- for role in roles:
- cleanedRole = {k: v for k, v in role.items() if not k.startswith("_")}
- filteredRoles.append(cleanedRole)
-
- # If no pagination requested, return all items
+ result = self.db.getRecordsetPaginated(Role, pagination=pagination)
+
+ items = []
+ for record in result["items"]:
+ cleanedRole = {k: v for k, v in record.items() if not k.startswith("_")}
+ items.append(Role(**cleanedRole))
+
if pagination is None:
- return [Role(**role) for role in filteredRoles]
-
- # Apply filtering (if filters provided)
- if pagination.filters:
- filteredRoles = self._applyFilters(filteredRoles, pagination.filters)
-
- # Apply sorting (in order of sortFields)
- if pagination.sort:
- filteredRoles = self._applySorting(filteredRoles, pagination.sort)
-
- # Count total items after filters
- totalItems = len(filteredRoles)
- totalPages = math.ceil(totalItems / pagination.pageSize) if totalItems > 0 else 0
-
- # Apply pagination (skip/limit)
- startIdx = (pagination.page - 1) * pagination.pageSize
- endIdx = startIdx + pagination.pageSize
- pagedRoles = filteredRoles[startIdx:endIdx]
-
- # Convert to model objects
- items = [Role(**role) for role in pagedRoles]
-
+ return items
+
return PaginatedResult(
items=items,
- totalItems=totalItems,
- totalPages=totalPages
+ totalItems=result["totalItems"],
+ totalPages=result["totalPages"]
)
except Exception as e:
logger.error(f"Error getting all roles: {str(e)}")
if pagination is None:
return []
- else:
- return PaginatedResult(items=[], totalItems=0, totalPages=0)
+ return PaginatedResult(items=[], totalItems=0, totalPages=0)
def countRoleAssignments(self) -> Dict[str, int]:
"""
diff --git a/modules/interfaces/interfaceDbBilling.py b/modules/interfaces/interfaceDbBilling.py
index 58d56895..2db71bb4 100644
--- a/modules/interfaces/interfaceDbBilling.py
+++ b/modules/interfaces/interfaceDbBilling.py
@@ -8,7 +8,7 @@ All billing data is stored in the poweron_billing database.
"""
import logging
-from typing import Dict, Any, List, Optional
+from typing import Dict, Any, List, Optional, Union
from datetime import date, datetime, timedelta
import uuid
@@ -17,6 +17,7 @@ from modules.shared.configuration import APP_CONFIG
from modules.shared.timeUtils import getUtcTimestamp
from modules.datamodels.datamodelUam import User, Mandate
from modules.datamodels.datamodelMembership import UserMandate
+from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResult
from modules.datamodels.datamodelBilling import (
BillingAccount,
BillingTransaction,
@@ -633,26 +634,43 @@ class BillingObjects:
limit: int = 100,
offset: int = 0,
startDate: date = None,
- endDate: date = None
- ) -> List[Dict[str, Any]]:
+ endDate: date = None,
+ pagination: PaginationParams = None
+ ) -> Union[List[Dict[str, Any]], PaginatedResult]:
"""
Get transactions for an account.
+ When pagination is provided, uses database-level pagination and returns
+ PaginatedResult. Otherwise falls back to in-memory filtering/sorting/slicing.
+
Args:
accountId: Account ID
- limit: Maximum number of results
- offset: Offset for pagination
- startDate: Filter by start date
- endDate: Filter by end date
+ limit: Maximum number of results (legacy path)
+ offset: Offset for pagination (legacy path)
+ startDate: Filter by start date (legacy path)
+ endDate: Filter by end date (legacy path)
+ pagination: PaginationParams for DB-level pagination
Returns:
- List of transaction dicts
+ PaginatedResult when pagination is provided, List of dicts otherwise
"""
try:
+ if pagination:
+ recordFilter = {"accountId": accountId}
+ result = self.db.getRecordsetPaginated(
+ BillingTransaction,
+ pagination=pagination,
+ recordFilter=recordFilter
+ )
+ return PaginatedResult(
+ items=result["items"],
+ totalItems=result["totalItems"],
+ totalPages=result["totalPages"]
+ )
+
filterDict = {"accountId": accountId}
results = self.db.getRecordset(BillingTransaction, recordFilter=filterDict)
- # Apply date filters if provided
if startDate or endDate:
filtered = []
for t in results:
@@ -666,35 +684,61 @@ class BillingObjects:
filtered.append(t)
results = filtered
- # Sort by creation date descending
results.sort(key=lambda x: x.get("_createdAt", ""), reverse=True)
- # Apply pagination
return results[offset:offset + limit]
except Exception as e:
logger.error(f"Error getting transactions: {e}")
+ if pagination:
+ return PaginatedResult(items=[], totalItems=0, totalPages=0)
return []
- def getTransactionsByMandate(self, mandateId: str, limit: int = 100) -> List[Dict[str, Any]]:
+ def getTransactionsByMandate(
+ self,
+ mandateId: str,
+ limit: int = 100,
+ pagination: PaginationParams = None
+ ) -> Union[List[Dict[str, Any]], PaginatedResult]:
"""
Get all transactions for a mandate (across all accounts).
+ When pagination is provided, collects all accountIds for the mandate and
+ issues a single DB query with SQL-level filtering, sorting, and pagination.
+ Otherwise falls back to per-account querying and in-memory merging.
+
Args:
mandateId: Mandate ID
- limit: Maximum number of results
+ limit: Maximum number of results (legacy path)
+ pagination: PaginationParams for DB-level pagination
Returns:
- List of transaction dicts
+ PaginatedResult when pagination is provided, List of dicts otherwise
"""
- # Get all accounts for mandate
accounts = self.db.getRecordset(BillingAccount, recordFilter={"mandateId": mandateId})
-
+ accountIds = [acc["id"] for acc in accounts if acc.get("id")]
+
+ if not accountIds:
+ if pagination:
+ return PaginatedResult(items=[], totalItems=0, totalPages=0)
+ return []
+
+ if pagination:
+ result = self.db.getRecordsetPaginated(
+ BillingTransaction,
+ pagination=pagination,
+ recordFilter={"accountId": accountIds}
+ )
+ return PaginatedResult(
+ items=result["items"],
+ totalItems=result["totalItems"],
+ totalPages=result["totalPages"]
+ )
+
allTransactions = []
for account in accounts:
transactions = self.getTransactions(account["id"], limit=limit)
allTransactions.extend(transactions)
- # Sort by creation date descending and limit
allTransactions.sort(key=lambda x: x.get("_createdAt", ""), reverse=True)
return allTransactions[:limit]
diff --git a/modules/interfaces/interfaceDbChat.py b/modules/interfaces/interfaceDbChat.py
index f2a3d4c6..b0d4aff3 100644
--- a/modules/interfaces/interfaceDbChat.py
+++ b/modules/interfaces/interfaceDbChat.py
@@ -450,7 +450,7 @@ class ChatObjects:
# Handle simple value (equals operator)
if not isinstance(filter_value, dict):
- if record_value != filter_value:
+ if str(record_value).lower() != str(filter_value).lower():
matches = False
break
continue
@@ -460,7 +460,7 @@ class ChatObjects:
filter_val = filter_value.get("value")
if operator in ["equals", "eq"]:
- if record_value != filter_val:
+ if str(record_value).lower() != str(filter_val).lower():
matches = False
break
@@ -545,7 +545,7 @@ class ChatObjects:
else:
# Unknown operator - default to equals
- if record_value != filter_val:
+ if str(record_value).lower() != str(filter_val).lower():
matches = False
break
diff --git a/modules/interfaces/interfaceDbManagement.py b/modules/interfaces/interfaceDbManagement.py
index 9c266aac..ae0b14b0 100644
--- a/modules/interfaces/interfaceDbManagement.py
+++ b/modules/interfaces/interfaceDbManagement.py
@@ -390,7 +390,7 @@ class ComponentObjects:
# Handle simple value (equals operator)
if not isinstance(filter_value, dict):
- if record_value != filter_value:
+ if str(record_value).lower() != str(filter_value).lower():
matches = False
break
continue
@@ -400,7 +400,7 @@ class ComponentObjects:
filter_val = filter_value.get("value")
if operator in ["equals", "eq"]:
- if record_value != filter_val:
+ if str(record_value).lower() != str(filter_val).lower():
matches = False
break
@@ -650,6 +650,10 @@ class ComponentObjects:
- Regular user: sees own prompts + system prompts (isSystem=True), can only CRUD own
- Row-level _permissions control edit/delete buttons in the UI
+ NOTE: Cannot use db.getRecordsetPaginated() because visibility rules
+ (_getPromptsForUser: own + system for regular, all for SysAdmin) and
+ per-row _permissions enrichment require loading all records first.
+
Args:
pagination: Optional pagination parameters. If None, returns all items.
@@ -914,7 +918,7 @@ class ComponentObjects:
"""
Returns files owned by the current user (user-scoped, not RBAC-based).
Every user (including SysAdmin) only sees their own files.
- Supports optional pagination, sorting, and filtering.
+ Supports optional pagination, sorting, and filtering via database-level queries.
Args:
pagination: Optional pagination parameters. If None, returns all items.
@@ -923,24 +927,21 @@ class ComponentObjects:
If pagination is None: List[FileItem]
If pagination is provided: PaginatedResult with items and metadata
"""
- # Files are always user-scoped: filter by _createdBy (bypasses RBAC SysAdmin override)
- filteredFiles = self._getFilesByCurrentUser()
-
- # Convert database records to FileItem instances (extra='allow' preserves system fields like _createdBy)
- def convertFileItems(files):
+ # User-scoping filter: every user only sees their own files (bypasses RBAC SysAdmin override)
+ recordFilter = {"_createdBy": self.userId}
+
+ def _convertFileItems(files):
fileItems = []
for file in files:
try:
- # Ensure proper values, use defaults for invalid data
creationDate = file.get("creationDate")
if creationDate is None or not isinstance(creationDate, (int, float)) or creationDate <= 0:
file["creationDate"] = getUtcTimestamp()
fileName = file.get("fileName")
if not fileName or fileName == "None":
- continue # Skip records with invalid fileName
+ continue
- # Use **file to pass all fields including system fields (_createdBy, etc.)
fileItem = FileItem(**file)
fileItems.append(fileItem)
except Exception as e:
@@ -948,34 +949,23 @@ class ComponentObjects:
continue
return fileItems
- # If no pagination requested, return all items
if pagination is None:
- return convertFileItems(filteredFiles)
+ allFiles = self._getFilesByCurrentUser()
+ return _convertFileItems(allFiles)
- # Apply filtering (if filters provided)
- if pagination.filters:
- filteredFiles = self._applyFilters(filteredFiles, pagination.filters)
+ # Database-level pagination: filtering, sorting, and LIMIT/OFFSET happen in SQL
+ result = self.db.getRecordsetPaginated(
+ FileItem,
+ pagination=pagination,
+ recordFilter=recordFilter
+ )
- # Apply sorting (in order of sortFields)
- if pagination.sort:
- filteredFiles = self._applySorting(filteredFiles, pagination.sort)
-
- # Count total items after filters
- totalItems = len(filteredFiles)
- totalPages = math.ceil(totalItems / pagination.pageSize) if totalItems > 0 else 0
-
- # Apply pagination (skip/limit)
- startIdx = (pagination.page - 1) * pagination.pageSize
- endIdx = startIdx + pagination.pageSize
- pagedFiles = filteredFiles[startIdx:endIdx]
-
- # Convert to model objects (extra='allow' on FileItem preserves system fields)
- items = convertFileItems(pagedFiles)
+ items = _convertFileItems(result["items"])
return PaginatedResult(
items=items,
- totalItems=totalItems,
- totalPages=totalPages
+ totalItems=result["totalItems"],
+ totalPages=result["totalPages"]
)
def getFile(self, fileId: str) -> Optional[FileItem]:
diff --git a/modules/interfaces/interfaceRbac.py b/modules/interfaces/interfaceRbac.py
index b4c9a3b4..b9a4ac9b 100644
--- a/modules/interfaces/interfaceRbac.py
+++ b/modules/interfaces/interfaceRbac.py
@@ -23,10 +23,12 @@ GROUP-Berechtigung:
import logging
import json
-from typing import List, Dict, Any, Optional, Type
+import math
+from typing import List, Dict, Any, Optional, Type, Union
from pydantic import BaseModel
from modules.datamodels.datamodelRbac import AccessRuleContext
from modules.datamodels.datamodelUam import User, UserPermissions, AccessLevel
+from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResult
from modules.security.rbac import RbacClass
from modules.security.rootAccess import getRootDbAppConnector
@@ -297,6 +299,309 @@ def getRecordsetWithRBAC(
return []
+def getRecordsetPaginatedWithRBAC(
+ connector,
+ modelClass: Type[BaseModel],
+ currentUser: User,
+ pagination: Optional[PaginationParams] = None,
+ recordFilter: Dict[str, Any] = None,
+ mandateId: Optional[str] = None,
+ featureInstanceId: Optional[str] = None,
+ enrichPermissions: bool = False,
+ featureCode: Optional[str] = None,
+) -> Union[List[Dict[str, Any]], PaginatedResult]:
+ """
+ Get records with RBAC filtering and SQL-level pagination.
+ When pagination is None, returns a plain list (backward compatible).
+ When pagination is provided, returns PaginatedResult with COUNT + LIMIT/OFFSET at SQL level.
+ """
+ table = modelClass.__name__
+ objectKey = buildDataObjectKey(table, featureCode)
+ effectiveMandateId = mandateId
+
+ try:
+ if not connector._ensureTableExists(modelClass):
+ return PaginatedResult(items=[], totalItems=0, totalPages=0) if pagination else []
+
+ dbApp = getRootDbAppConnector()
+ rbacInstance = RbacClass(connector, dbApp=dbApp)
+ permissions = rbacInstance.getUserPermissions(
+ currentUser,
+ AccessRuleContext.DATA,
+ objectKey,
+ mandateId=effectiveMandateId,
+ featureInstanceId=featureInstanceId
+ )
+
+ if not permissions.view:
+ return PaginatedResult(items=[], totalItems=0, totalPages=0) if pagination else []
+
+ whereConditions = []
+ whereValues = []
+
+ featureInstanceIdForQuery = featureInstanceId
+ if featureInstanceId and hasattr(modelClass, 'model_fields') and "featureInstanceId" not in modelClass.model_fields:
+ featureInstanceIdForQuery = None
+
+ rbacWhereClause = buildRbacWhereClause(
+ permissions, currentUser, table, connector,
+ mandateId=effectiveMandateId,
+ featureInstanceId=featureInstanceIdForQuery
+ )
+ if rbacWhereClause:
+ whereConditions.append(rbacWhereClause["condition"])
+ whereValues.extend(rbacWhereClause["values"])
+
+ if recordFilter:
+ for field, value in recordFilter.items():
+ if isinstance(value, (list, tuple)):
+ if len(value) == 0:
+ whereConditions.append("1 = 0")
+ else:
+ whereConditions.append(f'"{field}" = ANY(%s)')
+ whereValues.append(list(value))
+ elif value is None:
+ whereConditions.append(f'"{field}" IS NULL')
+ else:
+ whereConditions.append(f'"{field}" = %s')
+ whereValues.append(value)
+
+ if pagination and pagination.filters:
+ from modules.connectors.connectorDbPostgre import _get_model_fields
+ fields = _get_model_fields(modelClass)
+ validColumns = set(fields.keys())
+ for key, val in pagination.filters.items():
+ if key == "search" and isinstance(val, str) and val.strip():
+ term = f"%{val.strip()}%"
+ textCols = [c for c, t in fields.items() if t == "TEXT"]
+ if textCols:
+ orParts = [f'COALESCE("{c}"::TEXT, \'\') ILIKE %s' for c in textCols]
+ whereConditions.append(f"({' OR '.join(orParts)})")
+ whereValues.extend([term] * len(textCols))
+ continue
+ if key not in validColumns:
+ continue
+ if isinstance(val, dict):
+ op = val.get("operator", "equals")
+ v = val.get("value", "")
+ if op in ("equals", "eq"):
+ whereConditions.append(f'"{key}"::TEXT = %s')
+ whereValues.append(str(v))
+ elif op == "contains":
+ whereConditions.append(f'"{key}"::TEXT ILIKE %s')
+ whereValues.append(f"%{v}%")
+ elif op == "startsWith":
+ whereConditions.append(f'"{key}"::TEXT ILIKE %s')
+ whereValues.append(f"{v}%")
+ elif op == "endsWith":
+ whereConditions.append(f'"{key}"::TEXT ILIKE %s')
+ whereValues.append(f"%{v}")
+ elif op in ("gt", "gte", "lt", "lte"):
+ sqlOp = {"gt": ">", "gte": ">=", "lt": "<", "lte": "<="}[op]
+ whereConditions.append(f'"{key}"::TEXT {sqlOp} %s')
+ whereValues.append(str(v))
+ elif op == "between":
+ fromVal = v.get("from", "") if isinstance(v, dict) else ""
+ toVal = v.get("to", "") if isinstance(v, dict) else ""
+ if fromVal and toVal:
+ whereConditions.append(f'"{key}"::TEXT >= %s AND "{key}"::TEXT <= %s')
+ whereValues.extend([str(fromVal), str(toVal)])
+ elif fromVal:
+ whereConditions.append(f'"{key}"::TEXT >= %s')
+ whereValues.append(str(fromVal))
+ elif toVal:
+ whereConditions.append(f'"{key}"::TEXT <= %s')
+ whereValues.append(str(toVal))
+ else:
+ whereConditions.append(f'"{key}"::TEXT ILIKE %s')
+ whereValues.append(str(val))
+
+ whereClause = " WHERE " + " AND ".join(whereConditions) if whereConditions else ""
+ countValues = list(whereValues)
+
+ orderParts: List[str] = []
+ if pagination and pagination.sort:
+ from modules.connectors.connectorDbPostgre import _get_model_fields
+ validColumns = set(_get_model_fields(modelClass).keys())
+ for sf in pagination.sort:
+ if sf.field in validColumns:
+ direction = "DESC" if sf.direction.lower() == "desc" else "ASC"
+ orderParts.append(f'"{sf.field}" {direction}')
+ if not orderParts:
+ orderParts.append('"id"')
+ orderByClause = " ORDER BY " + ", ".join(orderParts)
+
+ limitClause = ""
+ if pagination:
+ offset = (pagination.page - 1) * pagination.pageSize
+ limitClause = f" LIMIT {pagination.pageSize} OFFSET {offset}"
+
+ with connector.connection.cursor() as cursor:
+ countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
+ cursor.execute(countSql, countValues)
+ totalItems = cursor.fetchone()["count"]
+
+ dataSql = f'SELECT * FROM "{table}"{whereClause}{orderByClause}{limitClause}'
+ cursor.execute(dataSql, whereValues)
+ records = [dict(row) for row in cursor.fetchall()]
+
+ from modules.connectors.connectorDbPostgre import _get_model_fields, _parseRecordFields
+ fields = _get_model_fields(modelClass)
+ for record in records:
+ _parseRecordFields(record, fields, f"table {table}")
+ for fieldName, fieldType in fields.items():
+ if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
+ modelFields = modelClass.model_fields
+ fieldInfo = modelFields.get(fieldName)
+ if fieldInfo:
+ fieldAnnotation = fieldInfo.annotation
+ if (fieldAnnotation == list or
+ (hasattr(fieldAnnotation, "__origin__") and fieldAnnotation.__origin__ is list)):
+ record[fieldName] = []
+ elif (fieldAnnotation == dict or
+ (hasattr(fieldAnnotation, "__origin__") and fieldAnnotation.__origin__ is dict)):
+ record[fieldName] = {}
+
+ if enrichPermissions:
+ records = _enrichRecordsWithPermissions(records, permissions, currentUser)
+
+ if pagination:
+ pageSize = pagination.pageSize
+ totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0
+ return PaginatedResult(items=records, totalItems=totalItems, totalPages=totalPages)
+
+ return records
+ except Exception as e:
+ logger.error(f"Error in getRecordsetPaginatedWithRBAC for table {table}: {e}")
+ return PaginatedResult(items=[], totalItems=0, totalPages=0) if pagination else []
+
+
+def getDistinctColumnValuesWithRBAC(
+ connector,
+ modelClass: Type[BaseModel],
+ currentUser: User,
+ column: str,
+ pagination: Optional[PaginationParams] = None,
+ recordFilter: Dict[str, Any] = None,
+ mandateId: Optional[str] = None,
+ featureInstanceId: Optional[str] = None,
+ featureCode: Optional[str] = None,
+) -> List[str]:
+ """
+ Get sorted distinct values for a column with RBAC filtering at SQL level.
+ Cross-filtering: removes the requested column from active filters.
+ """
+ import copy
+ table = modelClass.__name__
+ objectKey = buildDataObjectKey(table, featureCode)
+
+ try:
+ if not connector._ensureTableExists(modelClass):
+ return []
+
+ from modules.connectors.connectorDbPostgre import _get_model_fields
+ fields = _get_model_fields(modelClass)
+ if column not in fields:
+ return []
+
+ dbApp = getRootDbAppConnector()
+ rbacInstance = RbacClass(connector, dbApp=dbApp)
+ permissions = rbacInstance.getUserPermissions(
+ currentUser, AccessRuleContext.DATA, objectKey,
+ mandateId=mandateId, featureInstanceId=featureInstanceId
+ )
+ if not permissions.view:
+ return []
+
+ whereConditions = []
+ whereValues = []
+
+ featureInstanceIdForQuery = featureInstanceId
+ if featureInstanceId and hasattr(modelClass, 'model_fields') and "featureInstanceId" not in modelClass.model_fields:
+ featureInstanceIdForQuery = None
+
+ rbacWhereClause = buildRbacWhereClause(
+ permissions, currentUser, table, connector,
+ mandateId=mandateId, featureInstanceId=featureInstanceIdForQuery
+ )
+ if rbacWhereClause:
+ whereConditions.append(rbacWhereClause["condition"])
+ whereValues.extend(rbacWhereClause["values"])
+
+ if recordFilter:
+ for field, value in recordFilter.items():
+ if isinstance(value, (list, tuple)):
+ if not value:
+ whereConditions.append("1 = 0")
+ else:
+ whereConditions.append(f'"{field}" = ANY(%s)')
+ whereValues.append(list(value))
+ elif value is None:
+ whereConditions.append(f'"{field}" IS NULL')
+ else:
+ whereConditions.append(f'"{field}" = %s')
+ whereValues.append(value)
+
+ crossPagination = copy.deepcopy(pagination) if pagination else None
+ if crossPagination and crossPagination.filters:
+ crossPagination.filters.pop(column, None)
+ validColumns = set(fields.keys())
+ for key, val in crossPagination.filters.items():
+ if key == "search" and isinstance(val, str) and val.strip():
+ term = f"%{val.strip()}%"
+ textCols = [c for c, t in fields.items() if t == "TEXT"]
+ if textCols:
+ orParts = [f'COALESCE("{c}"::TEXT, \'\') ILIKE %s' for c in textCols]
+ whereConditions.append(f"({' OR '.join(orParts)})")
+ whereValues.extend([term] * len(textCols))
+ continue
+ if key not in validColumns:
+ continue
+ if isinstance(val, dict):
+ op = val.get("operator", "equals")
+ v = val.get("value", "")
+ if op in ("equals", "eq"):
+ whereConditions.append(f'"{key}"::TEXT = %s')
+ whereValues.append(str(v))
+ elif op == "contains":
+ whereConditions.append(f'"{key}"::TEXT ILIKE %s')
+ whereValues.append(f"%{v}%")
+ elif op == "between":
+ fromVal = v.get("from", "") if isinstance(v, dict) else ""
+ toVal = v.get("to", "") if isinstance(v, dict) else ""
+ if fromVal and toVal:
+ whereConditions.append(f'"{key}"::TEXT >= %s AND "{key}"::TEXT <= %s')
+ whereValues.extend([str(fromVal), str(toVal)])
+ elif fromVal:
+ whereConditions.append(f'"{key}"::TEXT >= %s')
+ whereValues.append(str(fromVal))
+ elif toVal:
+ whereConditions.append(f'"{key}"::TEXT <= %s')
+ whereValues.append(str(toVal))
+ else:
+ whereConditions.append(f'"{key}"::TEXT ILIKE %s')
+ whereValues.append(str(v) if isinstance(v, str) else str(val))
+ else:
+ whereConditions.append(f'"{key}"::TEXT ILIKE %s')
+ whereValues.append(str(val))
+
+ whereClause = " WHERE " + " AND ".join(whereConditions) if whereConditions else ""
+ notNullCond = f'"{column}" IS NOT NULL AND "{column}"::TEXT != \'\''
+ if whereClause:
+ whereClause += f" AND {notNullCond}"
+ else:
+ whereClause = f" WHERE {notNullCond}"
+
+ sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{whereClause} ORDER BY val'
+
+ with connector.connection.cursor() as cursor:
+ cursor.execute(sql, whereValues)
+ return [row["val"] for row in cursor.fetchall()]
+ except Exception as e:
+ logger.error(f"Error in getDistinctColumnValuesWithRBAC for {table}.{column}: {e}")
+ return []
+
+
def buildRbacWhereClause(
permissions: UserPermissions,
currentUser: User,
diff --git a/modules/routes/routeAdminAutomationEvents.py b/modules/routes/routeAdminAutomationEvents.py
index d184ae76..47d3ac9c 100644
--- a/modules/routes/routeAdminAutomationEvents.py
+++ b/modules/routes/routeAdminAutomationEvents.py
@@ -5,15 +5,19 @@ Admin automation events routes for the backend API.
Sysadmin-only endpoints for viewing and controlling scheduler events.
"""
-from fastapi import APIRouter, HTTPException, Depends, Path, Request, Response
-from typing import List, Dict, Any
+from fastapi import APIRouter, HTTPException, Depends, Path, Request, Response, Query
+from typing import List, Dict, Any, Optional
from fastapi import status
import logging
+import json
+import math
# Import interfaces and models from feature containers
import modules.features.automation.interfaceFeatureAutomation as interfaceAutomation
from modules.auth import limiter, getRequestContext, requireSysAdminRole, RequestContext
from modules.datamodels.datamodelUam import User
+from modules.datamodels.datamodelPagination import PaginationParams, PaginationMetadata, normalize_pagination_dict
+from modules.routes.routeDataUsers import _applyFiltersAndSort, _extractDistinctValues
# Configure logger
logger = logging.getLogger(__name__)
@@ -31,122 +35,176 @@ router = APIRouter(
}
)
+def _buildEnrichedAutomationEvents(currentUser: User) -> List[Dict[str, Any]]:
+ """Build the full enriched automation events list."""
+ from modules.shared.eventManagement import eventManager
+ from modules.interfaces.interfaceDbApp import getRootInterface
+ from modules.features.automation.mainAutomation import getAutomationServices
+
+ if not eventManager.scheduler:
+ return []
+
+ jobs = []
+ for job in eventManager.scheduler.get_jobs():
+ if job.id.startswith("automation."):
+ automationId = job.id.replace("automation.", "")
+ jobs.append({
+ "eventId": job.id,
+ "id": job.id,
+ "automationId": automationId,
+ "nextRunTime": str(job.next_run_time) if job.next_run_time else None,
+ "trigger": str(job.trigger) if job.trigger else None,
+ "name": "",
+ "createdBy": "",
+ "mandate": "",
+ "featureInstance": ""
+ })
+
+ if jobs:
+ try:
+ rootInterface = getRootInterface()
+ eventUser = rootInterface.getUserByUsername("event")
+ if eventUser:
+ services = getAutomationServices(currentUser, mandateId=None, featureInstanceId=None)
+ allAutomations = services.interfaceDbAutomation.getAllAutomationDefinitionsWithRBAC(eventUser)
+
+ automationLookup = {}
+ for a in allAutomations:
+ aId = a.get("id", "") if isinstance(a, dict) else getattr(a, "id", "")
+ automationLookup[aId] = a
+
+ _userCache: Dict[str, str] = {}
+ _mandateCache: Dict[str, str] = {}
+ _featureCache: Dict[str, str] = {}
+
+ def _resolveUsername(userId):
+ if not userId: return ""
+ if userId not in _userCache:
+ try:
+ user = rootInterface.getUser(userId)
+ _userCache[userId] = user.username if user else userId[:8]
+ except Exception:
+ _userCache[userId] = userId[:8]
+ return _userCache[userId]
+
+ def _resolveMandateLabel(mandateId):
+ if not mandateId: return ""
+ if mandateId not in _mandateCache:
+ try:
+ mandate = rootInterface.getMandate(mandateId)
+ _mandateCache[mandateId] = getattr(mandate, "label", None) or mandateId[:8]
+ except Exception:
+ _mandateCache[mandateId] = mandateId[:8]
+ return _mandateCache[mandateId]
+
+ def _resolveFeatureLabel(featureInstanceId):
+ if not featureInstanceId: return ""
+ if featureInstanceId not in _featureCache:
+ try:
+ instance = rootInterface.getFeatureInstance(featureInstanceId)
+ _featureCache[featureInstanceId] = getattr(instance, "label", None) or getattr(instance, "featureCode", None) or featureInstanceId[:8]
+ except Exception:
+ _featureCache[featureInstanceId] = featureInstanceId[:8]
+ return _featureCache[featureInstanceId]
+
+ for job in jobs:
+ automation = automationLookup.get(job["automationId"])
+ if automation:
+ if isinstance(automation, dict):
+ job["name"] = automation.get("label", "")
+ job["createdBy"] = _resolveUsername(automation.get("_createdBy", ""))
+ job["mandate"] = _resolveMandateLabel(automation.get("mandateId", ""))
+ job["featureInstance"] = _resolveFeatureLabel(automation.get("featureInstanceId", ""))
+ else:
+ job["name"] = getattr(automation, "label", "")
+ job["createdBy"] = _resolveUsername(getattr(automation, "_createdBy", ""))
+ job["mandate"] = _resolveMandateLabel(getattr(automation, "mandateId", ""))
+ job["featureInstance"] = _resolveFeatureLabel(getattr(automation, "featureInstanceId", ""))
+ else:
+ job["name"] = "(orphaned)"
+ except Exception as e:
+ logger.warning(f"Could not enrich automation events with context: {e}")
+
+ return jobs
+
+
@router.get("")
@limiter.limit("30/minute")
def get_all_automation_events(
request: Request,
- currentUser: User = Depends(requireSysAdminRole)
-) -> List[Dict[str, Any]]:
- """
- Get all active scheduler jobs (sysadmin only).
- Each job is enriched with context from its automation definition
- (name, mandate, feature instance, creator) for readability.
- """
+ pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams"),
+ currentUser: User = Depends(requireSysAdminRole),
+):
+ """Get all active scheduler jobs with pagination support (sysadmin only)."""
try:
- from modules.shared.eventManagement import eventManager
- from modules.interfaces.interfaceDbApp import getRootInterface
- from modules.features.automation.mainAutomation import getAutomationServices
-
- if not eventManager.scheduler:
- return []
-
- # 1. Collect all scheduler jobs
- jobs = []
- automationIds = []
- for job in eventManager.scheduler.get_jobs():
- if job.id.startswith("automation."):
- automationId = job.id.replace("automation.", "")
- automationIds.append(automationId)
- jobs.append({
- "eventId": job.id,
- "automationId": automationId,
- "nextRunTime": str(job.next_run_time) if job.next_run_time else None,
- "trigger": str(job.trigger) if job.trigger else None,
- "name": "",
- "createdBy": "",
- "mandate": "",
- "featureInstance": ""
- })
-
- # 2. Enrich with context from automation definitions
- if jobs:
+ paginationParams: Optional[PaginationParams] = None
+ if pagination:
try:
- rootInterface = getRootInterface()
- eventUser = rootInterface.getUserByUsername("event")
- if eventUser:
- services = getAutomationServices(currentUser, mandateId=None, featureInstanceId=None)
- allAutomations = services.interfaceDbAutomation.getAllAutomationDefinitionsWithRBAC(eventUser)
-
- # Build lookup by automation ID
- automationLookup = {}
- for a in allAutomations:
- aId = a.get("id", "") if isinstance(a, dict) else getattr(a, "id", "")
- automationLookup[aId] = a
-
- # Caches for resolving UUIDs to names
- _userCache = {}
- _mandateCache = {}
- _featureCache = {}
-
- def _resolveUsername(userId):
- if not userId:
- return ""
- if userId not in _userCache:
- try:
- user = rootInterface.getUser(userId)
- _userCache[userId] = user.username if user else userId[:8]
- except Exception:
- _userCache[userId] = userId[:8]
- return _userCache[userId]
-
- def _resolveMandateLabel(mandateId):
- if not mandateId:
- return ""
- if mandateId not in _mandateCache:
- try:
- mandate = rootInterface.getMandate(mandateId)
- _mandateCache[mandateId] = getattr(mandate, "label", None) or mandateId[:8]
- except Exception:
- _mandateCache[mandateId] = mandateId[:8]
- return _mandateCache[mandateId]
-
- def _resolveFeatureLabel(featureInstanceId):
- if not featureInstanceId:
- return ""
- if featureInstanceId not in _featureCache:
- try:
- instance = rootInterface.getFeatureInstance(featureInstanceId)
- _featureCache[featureInstanceId] = getattr(instance, "label", None) or getattr(instance, "featureCode", None) or featureInstanceId[:8]
- except Exception:
- _featureCache[featureInstanceId] = featureInstanceId[:8]
- return _featureCache[featureInstanceId]
-
- # Enrich each job
- for job in jobs:
- automation = automationLookup.get(job["automationId"])
- if automation:
- if isinstance(automation, dict):
- job["name"] = automation.get("label", "")
- job["createdBy"] = _resolveUsername(automation.get("_createdBy", ""))
- job["mandate"] = _resolveMandateLabel(automation.get("mandateId", ""))
- job["featureInstance"] = _resolveFeatureLabel(automation.get("featureInstanceId", ""))
- else:
- job["name"] = getattr(automation, "label", "")
- job["createdBy"] = _resolveUsername(getattr(automation, "_createdBy", ""))
- job["mandate"] = _resolveMandateLabel(getattr(automation, "mandateId", ""))
- job["featureInstance"] = _resolveFeatureLabel(getattr(automation, "featureInstanceId", ""))
- else:
- job["name"] = "(orphaned)"
- except Exception as e:
- logger.warning(f"Could not enrich automation events with context: {e}")
-
- return jobs
+ paginationDict = json.loads(pagination)
+ if paginationDict:
+ paginationDict = normalize_pagination_dict(paginationDict)
+ paginationParams = PaginationParams(**paginationDict)
+ except (json.JSONDecodeError, ValueError) as e:
+ raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")
+
+ enriched = _buildEnrichedAutomationEvents(currentUser)
+ filtered = _applyFiltersAndSort(enriched, paginationParams)
+
+ if paginationParams:
+ totalItems = len(filtered)
+ totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0
+ startIdx = (paginationParams.page - 1) * paginationParams.pageSize
+ endIdx = startIdx + paginationParams.pageSize
+ return {
+ "items": filtered[startIdx:endIdx],
+ "pagination": PaginationMetadata(
+ currentPage=paginationParams.page,
+ pageSize=paginationParams.pageSize,
+ totalItems=totalItems,
+ totalPages=totalPages,
+ sort=paginationParams.sort,
+ filters=paginationParams.filters,
+ ).model_dump(),
+ }
+
+ return {"items": enriched, "pagination": None}
+ except HTTPException:
+ raise
except Exception as e:
logger.error(f"Error getting automation events: {str(e)}")
- raise HTTPException(
- status_code=500,
- detail=f"Error getting automation events: {str(e)}"
- )
+ raise HTTPException(status_code=500, detail=f"Error getting automation events: {str(e)}")
+
+
+@router.get("/filter-values")
+@limiter.limit("60/minute")
+def get_automation_event_filter_values(
+ request: Request,
+ column: str = Query(..., description="Column key"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
+ currentUser: User = Depends(requireSysAdminRole),
+):
+ """Return distinct filter values for a column in automation events."""
+ try:
+ crossFilterParams: Optional[PaginationParams] = None
+ if pagination:
+ try:
+ paginationDict = json.loads(pagination)
+ if paginationDict:
+ paginationDict = normalize_pagination_dict(paginationDict)
+ filters = paginationDict.get("filters", {})
+ filters.pop(column, None)
+ paginationDict["filters"] = filters
+ paginationDict.pop("sort", None)
+ crossFilterParams = PaginationParams(**paginationDict)
+ except (json.JSONDecodeError, ValueError):
+ pass
+
+ enriched = _buildEnrichedAutomationEvents(currentUser)
+ crossFiltered = _applyFiltersAndSort(enriched, crossFilterParams)
+ return _extractDistinctValues(crossFiltered, column)
+ except Exception as e:
+ logger.error(f"Error getting filter values: {str(e)}")
+ raise HTTPException(status_code=500, detail=str(e))
@router.post("/sync")
@limiter.limit("5/minute")
diff --git a/modules/routes/routeAdminFeatures.py b/modules/routes/routeAdminFeatures.py
index b31b5e4c..3e701548 100644
--- a/modules/routes/routeAdminFeatures.py
+++ b/modules/routes/routeAdminFeatures.py
@@ -14,7 +14,11 @@ from fastapi import APIRouter, HTTPException, Depends, Request, Query
from typing import List, Dict, Any, Optional
from fastapi import status
import logging
+import json
+import math
from pydantic import BaseModel, Field
+from modules.datamodels.datamodelPagination import PaginationParams, PaginationMetadata, normalize_pagination_dict
+from modules.routes.routeDataUsers import _applyFiltersAndSort, _extractDistinctValues
from modules.auth import limiter, getRequestContext, RequestContext, requireSysAdminRole
from modules.datamodels.datamodelUam import User, UserInDB
@@ -433,6 +437,35 @@ def list_feature_instances(
)
+@router.get("/instances/filter-values")
+@limiter.limit("60/minute")
+def get_feature_instance_filter_values(
+ request: Request,
+ column: str = Query(..., description="Column key"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
+ featureCode: Optional[str] = Query(None, description="Filter by feature code"),
+ context: RequestContext = Depends(getRequestContext)
+) -> list:
+ """Return distinct filter values for a column in feature instances."""
+ if not context.mandateId:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="X-Mandate-Id header is required")
+ try:
+ from modules.routes.routeDataUsers import _handleFilterValuesRequest
+ rootInterface = getRootInterface()
+ featureInterface = getFeatureInterface(rootInterface.db)
+ instances = featureInterface.getFeatureInstancesForMandate(
+ mandateId=str(context.mandateId),
+ featureCode=featureCode
+ )
+ items = [inst.model_dump() for inst in instances]
+ return _handleFilterValuesRequest(items, column, pagination)
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error getting filter values for feature instances: {e}")
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
+
+
@router.get("/instances/{instanceId}", response_model=Dict[str, Any])
@limiter.limit("60/minute")
def get_feature_instance(
@@ -783,27 +816,55 @@ def sync_instance_roles(
# Template Role Endpoints (SysAdmin only)
# =============================================================================
-@router.get("/templates/roles", response_model=List[Dict[str, Any]])
+def _buildTemplateRolesList(featureCode: Optional[str] = None) -> List[Dict[str, Any]]:
+ """Build the full template roles list."""
+ rootInterface = getRootInterface()
+ featureInterface = getFeatureInterface(rootInterface.db)
+ roles = featureInterface.getTemplateRoles(featureCode)
+ return [r.model_dump() for r in roles]
+
+
+@router.get("/templates/roles")
@limiter.limit("60/minute")
def list_template_roles(
request: Request,
featureCode: Optional[str] = Query(None, description="Filter by feature code"),
- sysAdmin: User = Depends(requireSysAdminRole)
-) -> List[Dict[str, Any]]:
- """
- List global template roles.
-
- SysAdmin only - returns template roles that are copied to new feature instances.
-
- Args:
- featureCode: Optional filter by feature code
- """
+ pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams"),
+ sysAdmin: User = Depends(requireSysAdminRole),
+):
+ """List global template roles with pagination support."""
try:
- rootInterface = getRootInterface()
- featureInterface = getFeatureInterface(rootInterface.db)
-
- roles = featureInterface.getTemplateRoles(featureCode)
- return [r.model_dump() for r in roles]
+ paginationParams: Optional[PaginationParams] = None
+ if pagination:
+ try:
+ paginationDict = json.loads(pagination)
+ if paginationDict:
+ paginationDict = normalize_pagination_dict(paginationDict)
+ paginationParams = PaginationParams(**paginationDict)
+ except (json.JSONDecodeError, ValueError) as e:
+ raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")
+
+ enriched = _buildTemplateRolesList(featureCode)
+ filtered = _applyFiltersAndSort(enriched, paginationParams)
+
+ if paginationParams:
+ totalItems = len(filtered)
+ totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0
+ startIdx = (paginationParams.page - 1) * paginationParams.pageSize
+ endIdx = startIdx + paginationParams.pageSize
+ return {
+ "items": filtered[startIdx:endIdx],
+ "pagination": PaginationMetadata(
+ currentPage=paginationParams.page,
+ pageSize=paginationParams.pageSize,
+ totalItems=totalItems,
+ totalPages=totalPages,
+ sort=paginationParams.sort,
+ filters=paginationParams.filters,
+ ).model_dump(),
+ }
+
+ return {"items": enriched, "pagination": None}
except Exception as e:
logger.error(f"Error listing template roles: {e}")
@@ -813,6 +874,39 @@ def list_template_roles(
)
+@router.get("/templates/roles/filter-values")
+@limiter.limit("60/minute")
+def get_template_role_filter_values(
+ request: Request,
+ column: str = Query(..., description="Column key"),
+ featureCode: Optional[str] = Query(None, description="Filter by feature code"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
+ sysAdmin: User = Depends(requireSysAdminRole),
+):
+ """Return distinct filter values for a column in template roles."""
+ try:
+ crossFilterParams: Optional[PaginationParams] = None
+ if pagination:
+ try:
+ paginationDict = json.loads(pagination)
+ if paginationDict:
+ paginationDict = normalize_pagination_dict(paginationDict)
+ filters = paginationDict.get("filters", {})
+ filters.pop(column, None)
+ paginationDict["filters"] = filters
+ paginationDict.pop("sort", None)
+ crossFilterParams = PaginationParams(**paginationDict)
+ except (json.JSONDecodeError, ValueError):
+ pass
+
+ enriched = _buildTemplateRolesList(featureCode)
+ crossFiltered = _applyFiltersAndSort(enriched, crossFilterParams)
+ return _extractDistinctValues(crossFiltered, column)
+ except Exception as e:
+ logger.error(f"Error getting filter values: {e}")
+ raise HTTPException(status_code=500, detail=str(e))
+
+
@router.post("/templates/roles", response_model=Dict[str, Any])
@limiter.limit("10/minute")
def create_template_role(
@@ -976,6 +1070,56 @@ def list_feature_instance_users(
)
+@router.get("/instances/{instanceId}/users/filter-values")
+@limiter.limit("60/minute")
+def get_feature_instance_users_filter_values(
+ request: Request,
+ instanceId: str,
+ column: str = Query(..., description="Column key"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
+ context: RequestContext = Depends(getRequestContext)
+) -> list:
+ """Return distinct filter values for a column in feature instance users."""
+ try:
+ from modules.routes.routeDataUsers import _handleFilterValuesRequest
+ rootInterface = getRootInterface()
+ featureInterface = getFeatureInterface(rootInterface.db)
+ instance = featureInterface.getFeatureInstance(instanceId)
+ if not instance:
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Feature instance '{instanceId}' not found")
+ if context.mandateId and str(instance.mandateId) != str(context.mandateId):
+ if not context.hasSysAdminRole:
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied to this feature instance")
+ featureAccesses = rootInterface.getFeatureAccessesByInstance(instanceId)
+ result = []
+ for fa in featureAccesses:
+ user = rootInterface.getUser(str(fa.userId))
+ if not user:
+ continue
+ roleIds = rootInterface.getRoleIdsForFeatureAccess(str(fa.id))
+ roleLabels = []
+ for roleId in roleIds:
+ role = rootInterface.getRole(roleId)
+ if role:
+ roleLabels.append(role.roleLabel)
+ result.append({
+ "id": str(fa.id),
+ "userId": str(fa.userId),
+ "username": user.username,
+ "email": user.email,
+ "fullName": user.fullName,
+ "roleIds": roleIds,
+ "roleLabels": roleLabels,
+ "enabled": fa.enabled
+ })
+ return _handleFilterValuesRequest(result, column, pagination)
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error getting filter values for feature instance users: {e}")
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
+
+
@router.post("/instances/{instanceId}/users", response_model=Dict[str, Any])
@limiter.limit("30/minute")
def add_user_to_feature_instance(
diff --git a/modules/routes/routeAdminRbacRules.py b/modules/routes/routeAdminRbacRules.py
index a7d4637e..3778d227 100644
--- a/modules/routes/routeAdminRbacRules.py
+++ b/modules/routes/routeAdminRbacRules.py
@@ -394,7 +394,10 @@ def get_access_rules(
)
# Get rules with optional pagination
- # MandateAdmin: fetch all then filter by admin's mandates
+ # MandateAdmin: fetch all then filter by admin's mandates.
+ # NOTE: Cannot use DB-level pagination for MandateAdmin because
+ # _isRoleInAdminMandates requires joining Role → mandateId which
+ # isn't expressible via getRecordsetPaginated's recordFilter.
if not isSysAdmin:
allRules = interface.getAccessRules(
roleLabel=roleLabel,
@@ -814,6 +817,11 @@ def list_roles(
By default, only returns true global roles (mandateId=None, featureInstanceId=None, featureCode=None).
Feature template roles are managed via /api/features/templates/roles.
+ NOTE: Base query (getAllRoles) already uses db.getRecordsetPaginated() internally.
+ However pagination=None is passed here because post-processing adds computed fields
+ (userCount, scopeType) and applies scope/mandate/template filtering that cannot run
+ at the DB level. In-memory pagination is applied after all transformations.
+
Args:
pagination: Optional pagination parameters (includes search, filters, sort)
includeTemplates: If True, also include feature template roles (featureCode != None)
@@ -983,6 +991,77 @@ def list_roles(
)
+@router.get("/roles/filter-values")
+@limiter.limit("60/minute")
+def get_roles_filter_values(
+ request: Request,
+ column: str = Query(..., description="Column key"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
+ includeTemplates: bool = Query(False, description="Include feature template roles"),
+ mandateId: Optional[str] = Query(None, description="Include mandate-specific roles for this mandate"),
+ scopeFilter: Optional[str] = Query(None, description="Filter by scope: 'all', 'mandate', 'global', 'system'"),
+ reqContext: RequestContext = Depends(getRequestContext)
+) -> list:
+ """Return distinct filter values for a column in roles."""
+ try:
+ from modules.routes.routeDataUsers import _handleFilterValuesRequest
+ isSysAdmin = reqContext.hasSysAdminRole
+ adminMandateIds = [] if isSysAdmin else _getAdminMandateIds(reqContext)
+ if not isSysAdmin and not adminMandateIds:
+ raise HTTPException(status_code=403, detail="Admin role required")
+
+ interface = getRootInterface()
+ dbRoles = interface.getAllRoles(pagination=None)
+ roleCounts = interface.countRoleAssignments()
+
+ def _computeScopeType(role) -> str:
+ if role.mandateId:
+ return "mandate"
+ if role.isSystemRole:
+ return "system"
+ return "global"
+
+ result = []
+ for role in dbRoles:
+ if role.featureInstanceId is not None:
+ continue
+ if mandateId:
+ if role.mandateId != mandateId:
+ continue
+ else:
+ if role.mandateId is not None:
+ continue
+ if not includeTemplates and role.featureCode is not None:
+ continue
+ scopeType = _computeScopeType(role)
+ if scopeFilter and scopeFilter != 'all':
+ if scopeFilter == 'mandate' and scopeType != 'mandate':
+ continue
+ if scopeFilter == 'global' and scopeType not in ('global', 'system'):
+ continue
+ if scopeFilter == 'system' and scopeType != 'system':
+ continue
+ result.append({
+ "id": role.id,
+ "roleLabel": role.roleLabel,
+ "description": role.description,
+ "mandateId": role.mandateId,
+ "featureInstanceId": role.featureInstanceId,
+ "featureCode": role.featureCode,
+ "userCount": roleCounts.get(str(role.id), 0),
+ "isSystemRole": role.isSystemRole,
+ "scopeType": scopeType
+ })
+ if not isSysAdmin:
+ result = [r for r in result if r.get("mandateId") and str(r["mandateId"]) in adminMandateIds]
+ return _handleFilterValuesRequest(result, column, pagination)
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error getting filter values for roles: {str(e)}")
+ raise HTTPException(status_code=500, detail=str(e))
+
+
@router.post("/roles", response_model=Dict[str, Any])
@limiter.limit("30/minute")
def create_role(
diff --git a/modules/routes/routeBilling.py b/modules/routes/routeBilling.py
index e1c23b48..04412752 100644
--- a/modules/routes/routeBilling.py
+++ b/modules/routes/routeBilling.py
@@ -22,8 +22,10 @@ from modules.auth import limiter, requireSysAdminRole, getRequestContext, Reques
# Import billing components
from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface, _getRootInterface
from modules.serviceCenter.services.serviceBilling.mainServiceBilling import getService as getBillingService
+import json
+import math
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
-from modules.routes.routeDataUsers import _applyFiltersAndSort
+from modules.routes.routeDataUsers import _applyFiltersAndSort, _extractDistinctValues, _handleFilterValuesRequest
from modules.datamodels.datamodelBilling import (
BillingAccount,
BillingTransaction,
@@ -1402,49 +1404,187 @@ def getUsersForMandate(
raise HTTPException(status_code=500, detail=str(e))
-@router.get("/admin/transactions/{targetMandateId}", response_model=List[TransactionResponse])
+def _enrichTransactionRows(transactions) -> List[Dict[str, Any]]:
+ """Convert raw transaction dicts to enriched TransactionResponse rows with resolved usernames."""
+ result = []
+ for t in transactions:
+ row = TransactionResponse(
+ id=t.get("id"),
+ accountId=t.get("accountId"),
+ transactionType=TransactionTypeEnum(t.get("transactionType", "DEBIT")),
+ amount=t.get("amount", 0.0),
+ description=t.get("description", ""),
+ referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None,
+ workflowId=t.get("workflowId"),
+ featureCode=t.get("featureCode"),
+ featureInstanceId=t.get("featureInstanceId"),
+ aicoreProvider=t.get("aicoreProvider"),
+ aicoreModel=t.get("aicoreModel"),
+ createdByUserId=t.get("createdByUserId"),
+ createdAt=t.get("_createdAt")
+ )
+ result.append(row.model_dump())
+
+ try:
+ from modules.interfaces.interfaceDbUam import _getRootInterface as getUamRoot
+ uamInterface = getUamRoot()
+ userNames: Dict[str, str] = {}
+ for row in result:
+ uid = row.get("createdByUserId")
+ if uid and uid not in userNames:
+ try:
+ user = uamInterface.getUser(uid)
+ userNames[uid] = user.get("username", uid[:8]) if user else uid[:8]
+ except Exception:
+ userNames[uid] = uid[:8]
+ row["userName"] = userNames.get(uid, "") if uid else ""
+ except Exception:
+ for row in result:
+ row["userName"] = row.get("createdByUserId", "")[:8] if row.get("createdByUserId") else ""
+
+ return result
+
+
+def _buildTransactionsList(ctx: RequestContext, targetMandateId: str) -> List[Dict[str, Any]]:
+ """Build the full enriched transactions list for a mandate."""
+ billingInterface = getBillingInterface(ctx.user, targetMandateId)
+ transactions = billingInterface.getTransactionsByMandate(targetMandateId, limit=5000)
+
+ result = []
+ for t in transactions:
+ row = TransactionResponse(
+ id=t.get("id"),
+ accountId=t.get("accountId"),
+ transactionType=TransactionTypeEnum(t.get("transactionType", "DEBIT")),
+ amount=t.get("amount", 0.0),
+ description=t.get("description", ""),
+ referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None,
+ workflowId=t.get("workflowId"),
+ featureCode=t.get("featureCode"),
+ featureInstanceId=t.get("featureInstanceId"),
+ aicoreProvider=t.get("aicoreProvider"),
+ aicoreModel=t.get("aicoreModel"),
+ createdByUserId=t.get("createdByUserId"),
+ createdAt=t.get("_createdAt")
+ )
+ result.append(row.model_dump())
+
+ # Resolve user names
+ try:
+ from modules.interfaces.interfaceDbUam import _getRootInterface as getUamRoot
+ uamInterface = getUamRoot()
+ userNames: Dict[str, str] = {}
+ for row in result:
+ uid = row.get("createdByUserId")
+ if uid and uid not in userNames:
+ try:
+ user = uamInterface.getUser(uid)
+ userNames[uid] = user.get("username", uid[:8]) if user else uid[:8]
+ except Exception:
+ userNames[uid] = uid[:8]
+ row["userName"] = userNames.get(uid, "") if uid else ""
+ except Exception:
+ for row in result:
+ row["userName"] = row.get("createdByUserId", "")[:8] if row.get("createdByUserId") else ""
+
+ return result
+
+
+@router.get("/admin/transactions/{targetMandateId}")
@limiter.limit("30/minute")
def getTransactionsAdmin(
request: Request,
targetMandateId: str = Path(..., description="Mandate ID"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams"),
limit: int = Query(default=100, ge=1, le=1000),
ctx: RequestContext = Depends(getRequestContext),
):
- """
- Get all transactions for a mandate.
- Access: SysAdmin (any mandate) or MandateAdmin (own mandate).
- """
+ """Get all transactions for a mandate with pagination support."""
if not _isAdminOfMandate(ctx, targetMandateId):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin role required for this mandate")
try:
- billingInterface = getBillingInterface(ctx.user, targetMandateId)
- transactions = billingInterface.getTransactionsByMandate(targetMandateId, limit=limit)
-
- result = []
- for t in transactions:
- result.append(TransactionResponse(
- id=t.get("id"),
- accountId=t.get("accountId"),
- transactionType=TransactionTypeEnum(t.get("transactionType", "DEBIT")),
- amount=t.get("amount", 0.0),
- description=t.get("description", ""),
- referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None,
- workflowId=t.get("workflowId"),
- featureCode=t.get("featureCode"),
- featureInstanceId=t.get("featureInstanceId"),
- aicoreProvider=t.get("aicoreProvider"),
- aicoreModel=t.get("aicoreModel"),
- createdByUserId=t.get("createdByUserId"),
- createdAt=t.get("_createdAt")
- ))
-
- return result
-
+ paginationParams: Optional[PaginationParams] = None
+ if pagination:
+ try:
+ paginationDict = json.loads(pagination)
+ if paginationDict:
+ paginationDict = normalize_pagination_dict(paginationDict)
+ paginationParams = PaginationParams(**paginationDict)
+ except (json.JSONDecodeError, ValueError) as e:
+ raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")
+
+ if paginationParams:
+ # DB-level pagination — enrich only the returned page
+ billingInterface = getBillingInterface(ctx.user, targetMandateId)
+ result = billingInterface.getTransactionsByMandate(targetMandateId, pagination=paginationParams)
+ transactions = result.items if hasattr(result, 'items') else result
+ enrichedItems = _enrichTransactionRows(transactions)
+ return {
+ "items": enrichedItems,
+ "pagination": PaginationMetadata(
+ currentPage=paginationParams.page,
+ pageSize=paginationParams.pageSize,
+ totalItems=result.totalItems if hasattr(result, 'totalItems') else len(enrichedItems),
+ totalPages=result.totalPages if hasattr(result, 'totalPages') else 0,
+ sort=paginationParams.sort,
+ filters=paginationParams.filters,
+ ).model_dump(),
+ }
+
+ enriched = _buildTransactionsList(ctx, targetMandateId)
+ return {"items": enriched, "pagination": None}
+
+ except HTTPException:
+ raise
except Exception as e:
logger.error(f"Error getting billing transactions for mandate {targetMandateId}: {e}")
raise HTTPException(status_code=500, detail=str(e))
+@router.get("/admin/transactions/{targetMandateId}/filter-values")
+@limiter.limit("60/minute")
+def getTransactionFilterValues(
+ request: Request,
+ targetMandateId: str = Path(..., description="Mandate ID"),
+ column: str = Query(..., description="Column key"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
+ ctx: RequestContext = Depends(getRequestContext),
+):
+ """Return distinct filter values for a column in mandate transactions."""
+ if not _isAdminOfMandate(ctx, targetMandateId):
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin role required for this mandate")
+ try:
+ crossFilterParams: Optional[PaginationParams] = None
+ if pagination:
+ try:
+ paginationDict = json.loads(pagination)
+ if paginationDict:
+ paginationDict = normalize_pagination_dict(paginationDict)
+ filters = paginationDict.get("filters", {})
+ filters.pop(column, None)
+ paginationDict["filters"] = filters
+ paginationDict.pop("sort", None)
+ crossFilterParams = PaginationParams(**paginationDict)
+ except (json.JSONDecodeError, ValueError):
+ pass
+
+ # Try SQL DISTINCT for native DB columns; fallback to in-memory for enriched columns (e.g. userName)
+ try:
+ rootBillingInterface = _getRootInterface()
+ recordFilter = {"mandateId": targetMandateId}
+ values = rootBillingInterface.db.getDistinctColumnValues(
+ BillingTransaction, column, crossFilterParams, recordFilter
+ )
+ return sorted(values, key=lambda v: str(v).lower())
+ except Exception:
+ enriched = _buildTransactionsList(ctx, targetMandateId)
+ crossFiltered = _applyFiltersAndSort(enriched, crossFilterParams)
+ return _extractDistinctValues(crossFiltered, column)
+ except Exception as e:
+ logger.error(f"Error getting filter values for transactions: {e}")
+ raise HTTPException(status_code=500, detail=str(e))
+
+
# =============================================================================
# Mandate View Endpoints (for Admins)
# =============================================================================
@@ -1892,3 +2032,50 @@ def getUserViewTransactions(
except Exception as e:
logger.error(f"Error getting user view transactions: {e}")
raise HTTPException(status_code=500, detail=str(e))
+
+
+@router.get("/view/users/transactions/filter-values")
+@limiter.limit("60/minute")
+def getUserViewTransactionsFilterValues(
+ request: Request,
+ column: str = Query(..., description="Column key"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
+ ctx: RequestContext = Depends(getRequestContext)
+):
+ """Return distinct filter values for a column in user transactions."""
+ try:
+ billingInterface = getBillingInterface(ctx.user, ctx.mandateId)
+ scope = _getBillingDataScope(ctx.user)
+ if scope.isGlobalAdmin:
+ mandateIds = None
+ else:
+ mandateIds = scope.adminMandateIds + scope.memberMandateIds
+ if not mandateIds:
+ return []
+ allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=10000)
+ allTransactions = _filterTransactionsByScope(allTransactions, scope)
+ transactionDicts = []
+ for t in allTransactions:
+ transactionDicts.append({
+ "id": t.get("id"),
+ "accountId": t.get("accountId"),
+ "transactionType": t.get("transactionType", "DEBIT"),
+ "amount": t.get("amount", 0.0),
+ "description": t.get("description", ""),
+ "referenceType": t.get("referenceType"),
+ "workflowId": t.get("workflowId"),
+ "featureCode": t.get("featureCode"),
+ "featureInstanceId": t.get("featureInstanceId"),
+ "aicoreProvider": t.get("aicoreProvider"),
+ "aicoreModel": t.get("aicoreModel"),
+ "createdByUserId": t.get("createdByUserId"),
+ "createdAt": t.get("_createdAt"),
+ "mandateId": t.get("mandateId"),
+ "mandateName": t.get("mandateName"),
+ "userId": t.get("userId"),
+ "userName": t.get("userName"),
+ })
+ return _handleFilterValuesRequest(transactionDicts, column, pagination)
+ except Exception as e:
+ logger.error(f"Error getting filter values for user transactions: {e}")
+ raise HTTPException(status_code=500, detail=str(e))
diff --git a/modules/routes/routeDataConnections.py b/modules/routes/routeDataConnections.py
index 2a6be738..17ef0115 100644
--- a/modules/routes/routeDataConnections.py
+++ b/modules/routes/routeDataConnections.py
@@ -147,6 +147,11 @@ async def get_connections(
try:
interface = getInterface(currentUser)
+ # NOTE: Cannot use db.getRecordsetPaginated() here because each connection
+ # is enriched with computed tokenStatus/tokenExpiresAt (requires per-row DB lookup).
+ # Token refresh also may trigger re-fetch. Connections per user are typically < 10,
+ # so in-memory pagination is acceptable.
+
# Parse pagination parameter
paginationParams = None
if pagination:
@@ -287,6 +292,42 @@ async def get_connections(
detail=f"Failed to get connections: {str(e)}"
)
+@router.get("/filter-values")
+@limiter.limit("60/minute")
+def get_connection_filter_values(
+ request: Request,
+ column: str = Query(..., description="Column key"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
+ currentUser: User = Depends(getCurrentUser)
+) -> List[str]:
+ """Return distinct filter values for a column in connections."""
+ try:
+ from modules.routes.routeDataUsers import _handleFilterValuesRequest
+ interface = getInterface(currentUser)
+ connections = interface.getUserConnections(currentUser.id)
+ items = []
+ for connection in connections:
+ tokenStatus, tokenExpiresAt = getTokenStatusForConnection(interface, connection.id)
+ items.append({
+ "id": connection.id,
+ "userId": connection.userId,
+ "authority": connection.authority.value if hasattr(connection.authority, 'value') else str(connection.authority),
+ "externalId": connection.externalId,
+ "externalUsername": connection.externalUsername or "",
+ "externalEmail": connection.externalEmail,
+ "status": connection.status.value if hasattr(connection.status, 'value') else str(connection.status),
+ "connectedAt": connection.connectedAt,
+ "lastChecked": connection.lastChecked,
+ "expiresAt": connection.expiresAt,
+ "tokenStatus": tokenStatus,
+ "tokenExpiresAt": tokenExpiresAt
+ })
+ return _handleFilterValuesRequest(items, column, pagination)
+ except Exception as e:
+ logger.error(f"Error getting filter values for connections: {str(e)}")
+ raise HTTPException(status_code=500, detail=str(e))
+
+
@router.post("/", response_model=UserConnection)
@limiter.limit("10/minute")
def create_connection(
diff --git a/modules/routes/routeDataFiles.py b/modules/routes/routeDataFiles.py
index 470793bb..bb77ffc5 100644
--- a/modules/routes/routeDataFiles.py
+++ b/modules/routes/routeDataFiles.py
@@ -214,6 +214,56 @@ def get_files(
detail=f"Failed to get files: {str(e)}"
)
+@router.get("/list/filter-values")
+@limiter.limit("60/minute")
+def get_file_filter_values(
+ request: Request,
+ column: str = Query(..., description="Column key"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
+ currentUser: User = Depends(getCurrentUser),
+ context: RequestContext = Depends(getRequestContext)
+) -> list:
+ """Return distinct filter values for a column in files."""
+ try:
+ managementInterface = interfaceDbManagement.getInterface(
+ currentUser,
+ mandateId=str(context.mandateId) if context.mandateId else None,
+ featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None
+ )
+
+ crossFilterPagination = None
+ if pagination:
+ try:
+ paginationDict = json.loads(pagination)
+ if paginationDict:
+ paginationDict = normalize_pagination_dict(paginationDict)
+ filters = paginationDict.get("filters", {})
+ filters.pop(column, None)
+ paginationDict["filters"] = filters
+ paginationDict.pop("sort", None)
+ crossFilterPagination = PaginationParams(**paginationDict)
+ except (json.JSONDecodeError, ValueError):
+ pass
+
+ try:
+ recordFilter = {"_createdBy": managementInterface.userId}
+ values = managementInterface.db.getDistinctColumnValues(
+ FileItem, column, crossFilterPagination, recordFilter
+ )
+ return sorted(values, key=lambda v: str(v).lower())
+ except Exception:
+ from modules.routes.routeDataUsers import _handleFilterValuesRequest
+ result = managementInterface.getAllFiles(pagination=None)
+ items = [r.model_dump() if hasattr(r, 'model_dump') else r for r in result]
+ return _handleFilterValuesRequest(items, column, pagination)
+ except Exception as e:
+ logger.error(f"Error getting filter values for files: {str(e)}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=str(e)
+ )
+
+
@router.post("/upload", status_code=status.HTTP_201_CREATED)
@limiter.limit("10/minute")
async def upload_file(
diff --git a/modules/routes/routeDataMandates.py b/modules/routes/routeDataMandates.py
index 16b7166d..2c2bd31c 100644
--- a/modules/routes/routeDataMandates.py
+++ b/modules/routes/routeDataMandates.py
@@ -164,6 +164,66 @@ def get_mandates(
detail=f"Failed to get mandates: {str(e)}"
)
+@router.get("/filter-values")
+@limiter.limit("60/minute")
+def get_mandate_filter_values(
+ request: Request,
+ column: str = Query(..., description="Column key"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
+ context: RequestContext = Depends(getRequestContext)
+) -> list:
+ """Return distinct filter values for a column in mandates."""
+ try:
+ from modules.routes.routeDataUsers import _handleFilterValuesRequest
+ isSysAdmin = context.hasSysAdminRole
+ if not isSysAdmin:
+ adminMandateIds = _getAdminMandateIds(context)
+ if not adminMandateIds:
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin role required")
+
+ appInterface = interfaceDbApp.getRootInterface()
+
+ if isSysAdmin:
+ # SysAdmin: try SQL DISTINCT for DB columns
+ crossFilterPagination = None
+ if pagination:
+ try:
+ paginationDict = json.loads(pagination)
+ if paginationDict:
+ paginationDict = normalize_pagination_dict(paginationDict)
+ filters = paginationDict.get("filters", {})
+ filters.pop(column, None)
+ paginationDict["filters"] = filters
+ paginationDict.pop("sort", None)
+ crossFilterPagination = PaginationParams(**paginationDict)
+ except (json.JSONDecodeError, ValueError):
+ pass
+ try:
+ values = appInterface.db.getDistinctColumnValues(
+ Mandate, column, crossFilterPagination
+ )
+ return sorted(values, key=lambda v: str(v).lower())
+ except Exception:
+ result = appInterface.getAllMandates(pagination=None)
+ items = result if isinstance(result, list) else (result.items if hasattr(result, 'items') else result)
+ items = [i.model_dump() if hasattr(i, 'model_dump') else i for i in items]
+ return _handleFilterValuesRequest(items, column, pagination)
+ else:
+ # MandateAdmin: in-memory (small set of individual mandate lookups)
+ result = []
+ for mid in adminMandateIds:
+ mandate = appInterface.getMandate(mid)
+ if mandate:
+ result.append(mandate if isinstance(mandate, dict) else mandate.model_dump() if hasattr(mandate, 'model_dump') else vars(mandate))
+ items = [i.model_dump() if hasattr(i, 'model_dump') else i for i in result]
+ return _handleFilterValuesRequest(items, column, pagination)
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error getting filter values for mandates: {str(e)}")
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
+
+
@router.get("/{targetMandateId}", response_model=Mandate)
@limiter.limit("30/minute")
def get_mandate(
@@ -540,6 +600,63 @@ def list_mandate_users(
)
+@router.get("/{targetMandateId}/users/filter-values")
+@limiter.limit("60/minute")
+def get_mandate_users_filter_values(
+ request: Request,
+ targetMandateId: str = Path(..., description="ID of the mandate"),
+ column: str = Query(..., description="Column key"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
+ context: RequestContext = Depends(getRequestContext)
+) -> list:
+ """Return distinct filter values for a column in mandate users."""
+ if not _hasMandateAdminRole(context, targetMandateId) and not context.hasSysAdminRole:
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Mandate-Admin role required")
+
+ try:
+ from modules.routes.routeDataUsers import _handleFilterValuesRequest
+ rootInterface = interfaceDbApp.getRootInterface()
+ mandate = rootInterface.getMandate(targetMandateId)
+ if not mandate:
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Mandate {targetMandateId} not found")
+
+ userMandates = rootInterface.getUserMandatesByMandate(targetMandateId)
+ result = []
+ for um in userMandates:
+ user = rootInterface.getUser(str(um.userId))
+ if not user:
+ continue
+ roleIds = rootInterface.getRoleIdsForUserMandate(str(um.id))
+ roleLabels = []
+ filteredRoleIds = []
+ seenLabels = set()
+ for roleId in roleIds:
+ role = rootInterface.getRole(roleId)
+ if role:
+ if role.featureInstanceId:
+ continue
+ filteredRoleIds.append(roleId)
+ if role.roleLabel not in seenLabels:
+ roleLabels.append(role.roleLabel)
+ seenLabels.add(role.roleLabel)
+ result.append({
+ "id": str(um.id),
+ "userId": str(user.id),
+ "username": user.username,
+ "email": user.email,
+ "fullName": user.fullName,
+ "roleIds": filteredRoleIds,
+ "roleLabels": roleLabels,
+ "enabled": um.enabled
+ })
+ return _handleFilterValuesRequest(result, column, pagination)
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error getting filter values for mandate users: {str(e)}")
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
+
+
@router.post("/{targetMandateId}/users", response_model=UserMandateResponse)
@limiter.limit("30/minute")
def add_user_to_mandate(
diff --git a/modules/routes/routeDataPrompts.py b/modules/routes/routeDataPrompts.py
index faf692fe..f9246ab6 100644
--- a/modules/routes/routeDataPrompts.py
+++ b/modules/routes/routeDataPrompts.py
@@ -81,6 +81,29 @@ def get_prompts(
pagination=None
)
+@router.get("/filter-values")
+@limiter.limit("60/minute")
+def get_prompt_filter_values(
+ request: Request,
+ column: str = Query(..., description="Column key"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
+ currentUser: User = Depends(getCurrentUser)
+) -> list:
+ """Return distinct filter values for a column in prompts.
+
+ NOTE: Cannot use db.getDistinctColumnValues() because visibility rules
+ (own + system for regular users) require pre-filtering the recordset.
+ """
+ try:
+ from modules.routes.routeDataUsers import _handleFilterValuesRequest
+ managementInterface = interfaceDbManagement.getInterface(currentUser)
+ result = managementInterface.getAllPrompts(pagination=None)
+ items = [r.model_dump() if hasattr(r, 'model_dump') else r for r in result]
+ return _handleFilterValuesRequest(items, column, pagination)
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
+
+
@router.post("", response_model=Prompt)
@limiter.limit("10/minute")
def create_prompt(
diff --git a/modules/routes/routeDataUsers.py b/modules/routes/routeDataUsers.py
index 16742388..7e903466 100644
--- a/modules/routes/routeDataUsers.py
+++ b/modules/routes/routeDataUsers.py
@@ -69,6 +69,55 @@ def _isAdminForUser(context: RequestContext, targetUserId: str) -> bool:
return False
+def _extractDistinctValues(items: List[Dict[str, Any]], columnKey: str) -> List[str]:
+ """Extract sorted distinct display values for a column from enriched items."""
+ values = set()
+ for item in items:
+ val = item.get(columnKey)
+ if val is None or val == "":
+ continue
+ if isinstance(val, bool):
+ values.add("true" if val else "false")
+ elif isinstance(val, (int, float)):
+ values.add(str(val))
+ elif isinstance(val, dict):
+ text = val.get("en") or next((v for v in val.values() if isinstance(v, str) and v), None)
+ if text:
+ values.add(str(text))
+ else:
+ values.add(str(val))
+ return sorted(values, key=lambda v: v.lower())
+
+
+def _handleFilterValuesRequest(
+ items: List[Dict[str, Any]],
+ column: str,
+ paginationJson: Optional[str] = None,
+) -> List[str]:
+ """
+ Generic handler for /filter-values endpoints.
+ Applies all active filters EXCEPT the one for the requested column (cross-filtering),
+ then extracts distinct values for that column.
+ """
+ crossFilterParams: Optional[PaginationParams] = None
+ if paginationJson:
+ try:
+ import json
+ paginationDict = json.loads(paginationJson)
+ if paginationDict:
+ paginationDict = normalize_pagination_dict(paginationDict)
+ filters = paginationDict.get("filters", {})
+ filters.pop(column, None)
+ paginationDict["filters"] = filters
+ paginationDict.pop("sort", None)
+ crossFilterParams = PaginationParams(**paginationDict)
+ except (json.JSONDecodeError, ValueError):
+ pass
+
+ crossFiltered = _applyFiltersAndSort(items, crossFilterParams)
+ return _extractDistinctValues(crossFiltered, column)
+
+
def _applyFiltersAndSort(items: List[Dict[str, Any]], paginationParams: Optional[PaginationParams]) -> List[Dict[str, Any]]:
"""
Apply filters and sorting to a list of items.
@@ -121,7 +170,6 @@ def _applyFiltersAndSort(items: List[Dict[str, Any]], paginationParams: Optional
if itemValue is None:
return False
- # Convert to string for comparison if needed
itemStr = str(itemValue).lower()
valueStr = str(v).lower()
@@ -147,6 +195,42 @@ def _applyFiltersAndSort(items: List[Dict[str, Any]], paginationParams: Optional
return itemNum <= valueNum
except (ValueError, TypeError):
return False
+ elif op == 'between':
+ if isinstance(v, dict):
+ fromVal = v.get('from', '')
+ toVal = v.get('to', '')
+ if not fromVal and not toVal:
+ return True
+ # Date range: from/to are YYYY-MM-DD strings, itemValue may be Unix timestamp
+ try:
+ from datetime import datetime, timezone
+ fromTs = None
+ toTs = None
+ if fromVal:
+ fromTs = datetime.strptime(str(fromVal), '%Y-%m-%d').replace(tzinfo=timezone.utc).timestamp()
+ if toVal:
+ toTs = datetime.strptime(str(toVal), '%Y-%m-%d').replace(hour=23, minute=59, second=59, tzinfo=timezone.utc).timestamp()
+ itemNum = float(itemValue) if not isinstance(itemValue, (int, float)) else itemValue
+ # Normalize: if item looks like a millisecond timestamp, convert to seconds
+ if itemNum > 10000000000:
+ itemNum = itemNum / 1000
+ if fromTs is not None and toTs is not None:
+ return fromTs <= itemNum <= toTs
+ elif fromTs is not None:
+ return itemNum >= fromTs
+ elif toTs is not None:
+ return itemNum <= toTs
+ except (ValueError, TypeError):
+ # Fallback: string comparison (for non-numeric date fields)
+ fromStr = str(fromVal).lower() if fromVal else ''
+ toStr = str(toVal).lower() if toVal else ''
+ if fromStr and toStr:
+ return fromStr <= itemStr <= toStr
+ elif fromStr:
+ return itemStr >= fromStr
+ elif toStr:
+ return itemStr <= toStr
+ return True
elif op == 'in':
if isinstance(v, list):
return itemStr in [str(x).lower() for x in v]
@@ -159,23 +243,25 @@ def _applyFiltersAndSort(items: List[Dict[str, Any]], paginationParams: Optional
result = [item for item in result if matchesFilter(item, field, operator, value)]
- # Apply sorting
+ # Apply sorting — None values always last
if paginationParams.sort:
for sortField in reversed(paginationParams.sort):
fieldName = sortField.field
ascending = sortField.direction == 'asc'
- def getSortKey(item: Dict[str, Any]):
- value = item.get(fieldName)
- if value is None:
- return (1, '') # Nulls last
- if isinstance(value, bool):
- return (0, not value if ascending else value)
- if isinstance(value, (int, float)):
- return (0, value)
- return (0, str(value).lower())
+ noneItems = [item for item in result if item.get(fieldName) is None]
+ nonNoneItems = [item for item in result if item.get(fieldName) is not None]
- result = sorted(result, key=getSortKey, reverse=not ascending)
+ def getSortKey(item: Dict[str, Any], _fn=fieldName):
+ value = item.get(_fn)
+ if isinstance(value, bool):
+ return (0, int(value), '')
+ if isinstance(value, (int, float)):
+ return (0, value, '')
+ return (1, 0, str(value).lower())
+
+ nonNoneItems = sorted(nonNoneItems, key=getSortKey, reverse=not ascending)
+ result = nonNoneItems + noneItems
return result
@@ -291,38 +377,23 @@ def get_users(
pagination=None
)
elif context.hasSysAdminRole:
- # SysAdmin without mandateId sees all users
- # Get all users via interface method (returns Pydantic User models)
- allUserModels = appInterface.getAllUsers()
- # Convert to dictionaries for filtering/sorting
- cleanedUsers = [u.model_dump() for u in allUserModels]
+ # SysAdmin without mandateId — DB-level pagination via interface
+ result = appInterface.getAllUsers(paginationParams)
- # Apply server-side filtering and sorting
- filteredUsers = _applyFiltersAndSort(cleanedUsers, paginationParams)
-
- # Convert to User objects
- users = [User(**u) for u in filteredUsers]
-
- if paginationParams:
- import math
- totalItems = len(users)
- totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0
- startIdx = (paginationParams.page - 1) * paginationParams.pageSize
- endIdx = startIdx + paginationParams.pageSize
- paginatedUsers = users[startIdx:endIdx]
-
+ if paginationParams and hasattr(result, 'items'):
return PaginatedResponse(
- items=paginatedUsers,
+ items=result.items,
pagination=PaginationMetadata(
currentPage=paginationParams.page,
pageSize=paginationParams.pageSize,
- totalItems=totalItems,
- totalPages=totalPages,
+ totalItems=result.totalItems,
+ totalPages=result.totalPages,
sort=paginationParams.sort,
filters=paginationParams.filters
)
)
else:
+ users = result if isinstance(result, list) else (result.items if hasattr(result, 'items') else [])
return PaginatedResponse(
items=users,
pagination=None
@@ -407,6 +478,88 @@ def get_users(
detail=f"Failed to get users: {str(e)}"
)
+@router.get("/filter-values")
+@limiter.limit("60/minute")
+def get_user_filter_values(
+ request: Request,
+ column: str = Query(..., description="Column key"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
+ context: RequestContext = Depends(getRequestContext)
+) -> list:
+ """Return distinct filter values for a column in users."""
+ try:
+ appInterface = interfaceDbApp.getInterface(context.user, mandateId=context.mandateId)
+
+ # Build cross-filter pagination (all filters except the requested column)
+ crossFilterPagination = None
+ if pagination:
+ try:
+ paginationDict = json.loads(pagination)
+ if paginationDict:
+ paginationDict = normalize_pagination_dict(paginationDict)
+ filters = paginationDict.get("filters", {})
+ filters.pop(column, None)
+ paginationDict["filters"] = filters
+ paginationDict.pop("sort", None)
+ crossFilterPagination = PaginationParams(**paginationDict)
+ except (json.JSONDecodeError, ValueError):
+ pass
+
+ if context.mandateId:
+ # Mandate-scoped: in-memory (users require UserMandate join)
+ result = appInterface.getUsersByMandate(str(context.mandateId), None)
+ users = result if isinstance(result, list) else (result.items if hasattr(result, 'items') else [])
+ items = [u.model_dump() if hasattr(u, 'model_dump') else u for u in users]
+ return _handleFilterValuesRequest(items, column, pagination)
+ elif context.hasSysAdminRole:
+ # SysAdmin: use SQL DISTINCT for DB columns
+ try:
+ rootInterface = getRootInterface()
+ values = rootInterface.db.getDistinctColumnValues(
+ UserInDB, column, crossFilterPagination
+ )
+ return sorted(values, key=lambda v: v.lower())
+ except Exception:
+ users = appInterface.getAllUsers()
+ items = [u.model_dump() if hasattr(u, 'model_dump') else u for u in users]
+ return _handleFilterValuesRequest(items, column, pagination)
+ else:
+ # Non-admin multi-mandate: aggregate across admin mandates (in-memory)
+ rootInterface = getRootInterface()
+ userMandates = rootInterface.getUserMandates(str(context.user.id))
+ adminMandateIds = []
+ for um in userMandates:
+ umId = getattr(um, 'id', None)
+ mandateId = getattr(um, 'mandateId', None)
+ if not umId or not mandateId:
+ continue
+ roleIds = rootInterface.getRoleIdsForUserMandate(str(umId))
+ for roleId in roleIds:
+ role = rootInterface.getRole(roleId)
+ if role and role.roleLabel == "admin" and not role.featureInstanceId:
+ adminMandateIds.append(str(mandateId))
+ break
+ if not adminMandateIds:
+ return []
+ seenUserIds = set()
+ users = []
+ for mid in adminMandateIds:
+ mandateUsers = rootInterface.getUsersByMandate(mid)
+ uList = mandateUsers if isinstance(mandateUsers, list) else (mandateUsers.items if hasattr(mandateUsers, 'items') else [])
+ for u in uList:
+ uid = u.get("id") if isinstance(u, dict) else getattr(u, "id", None)
+ if uid and uid not in seenUserIds:
+ seenUserIds.add(uid)
+ users.append(u)
+ items = [u.model_dump() if hasattr(u, 'model_dump') else u for u in users]
+ return _handleFilterValuesRequest(items, column, pagination)
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error getting filter values for users: {str(e)}")
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
+
+
@router.get("/{userId}", response_model=User)
@limiter.limit("30/minute")
def get_user(
diff --git a/modules/routes/routeInvitations.py b/modules/routes/routeInvitations.py
index 906b11e7..cb913137 100644
--- a/modules/routes/routeInvitations.py
+++ b/modules/routes/routeInvitations.py
@@ -420,6 +420,13 @@ def list_invitations(
Requires Mandate-Admin role. Returns all invitations created for this mandate.
+ NOTE: Cannot use db.getRecordsetPaginated() because:
+ - Computed status fields (isExpired, isUsedUp) are derived in-memory
+ - Filtering by revoked/used/expired requires post-fetch logic
+ - Invitation volume per mandate is typically low (< 100)
+ When this endpoint needs FormGeneratorTable pagination, add PaginatedResponse
+ support with in-memory slicing (similar to routeDataConnections).
+
Args:
includeUsed: Include invitations that have reached maxUses
includeExpired: Include expired invitations
@@ -485,6 +492,54 @@ def list_invitations(
)
+@router.get("/filter-values")
+@limiter.limit("60/minute")
+def get_invitation_filter_values(
+ request: Request,
+ column: str = Query(..., description="Column key"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
+ frontendUrl: str = Query("", description="Frontend URL for building invite links"),
+ includeUsed: bool = Query(False, description="Include already used invitations"),
+ includeExpired: bool = Query(False, description="Include expired invitations"),
+ context: RequestContext = Depends(getRequestContext)
+) -> list:
+ """Return distinct filter values for a column in invitations."""
+ if not context.mandateId:
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="X-Mandate-Id header is required")
+ if not _hasMandateAdminRole(context):
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Mandate-Admin role required")
+ try:
+ from modules.routes.routeDataUsers import _handleFilterValuesRequest
+ rootInterface = getRootInterface()
+ allInvitations = rootInterface.getInvitationsByMandate(str(context.mandateId))
+ currentTime = getUtcTimestamp()
+ result = []
+ for inv in allInvitations:
+ if inv.revokedAt:
+ continue
+ currentUses = inv.currentUses or 0
+ maxUses = inv.maxUses or 1
+ if not includeUsed and currentUses >= maxUses:
+ continue
+ expiresAt = inv.expiresAt or 0
+ if not includeExpired and expiresAt < currentTime:
+ continue
+ baseUrl = frontendUrl.rstrip("/") if frontendUrl else ""
+ inviteUrl = f"{baseUrl}/invite/{inv.token}" if baseUrl else ""
+ result.append({
+ **inv.model_dump(),
+ "inviteUrl": inviteUrl,
+ "isExpired": expiresAt < currentTime,
+ "isUsedUp": currentUses >= maxUses
+ })
+ return _handleFilterValuesRequest(result, column, pagination)
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error getting filter values for invitations: {e}")
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
+
+
@router.delete("/{invitationId}", response_model=Dict[str, str])
@limiter.limit("30/minute")
def revoke_invitation(
diff --git a/modules/routes/routeMessaging.py b/modules/routes/routeMessaging.py
index 0b4784ff..42e15f0e 100644
--- a/modules/routes/routeMessaging.py
+++ b/modules/routes/routeMessaging.py
@@ -21,7 +21,7 @@ from modules.datamodels.datamodelMessaging import (
MessagingSubscriptionExecutionResult
)
from modules.datamodels.datamodelUam import User
-from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata
+from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
# Configure logger
logger = logging.getLogger(__name__)
@@ -48,6 +48,8 @@ def get_subscriptions(
if pagination:
try:
paginationDict = json.loads(pagination)
+ if paginationDict:
+ paginationDict = normalize_pagination_dict(paginationDict)
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
except (json.JSONDecodeError, ValueError) as e:
raise HTTPException(
@@ -185,6 +187,8 @@ def get_subscription_registrations(
if pagination:
try:
paginationDict = json.loads(pagination)
+ if paginationDict:
+ paginationDict = normalize_pagination_dict(paginationDict)
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
except (json.JSONDecodeError, ValueError) as e:
raise HTTPException(
@@ -277,6 +281,8 @@ def get_my_registrations(
if pagination:
try:
paginationDict = json.loads(pagination)
+ if paginationDict:
+ paginationDict = normalize_pagination_dict(paginationDict)
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
except (json.JSONDecodeError, ValueError) as e:
raise HTTPException(
@@ -450,6 +456,8 @@ def get_deliveries(
if pagination:
try:
paginationDict = json.loads(pagination)
+ if paginationDict:
+ paginationDict = normalize_pagination_dict(paginationDict)
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
except (json.JSONDecodeError, ValueError) as e:
raise HTTPException(
diff --git a/modules/routes/routeRealEstate.py b/modules/routes/routeRealEstate.py
index 881a77dc..a3466aca 100644
--- a/modules/routes/routeRealEstate.py
+++ b/modules/routes/routeRealEstate.py
@@ -227,33 +227,22 @@ async def get_projects(
context.user, mandateId=mandateId, featureInstanceId=instanceId
)
recordFilter = {"featureInstanceId": instanceId}
- items = interface.getProjekte(recordFilter=recordFilter)
paginationParams = _parsePagination(pagination)
if paginationParams:
- if paginationParams.sort:
- for sort_field in reversed(paginationParams.sort):
- field_name = sort_field.field
- direction = sort_field.direction.lower()
- items.sort(
- key=lambda x: getattr(x, field_name, None),
- reverse=(direction == "desc")
+ result = interface.getProjekte(pagination=paginationParams, recordFilter=recordFilter)
+ if hasattr(result, 'items'):
+ return PaginatedResponse(
+ items=result.items,
+ pagination=PaginationMetadata(
+ currentPage=paginationParams.page,
+ pageSize=paginationParams.pageSize,
+ totalItems=result.totalItems,
+ totalPages=result.totalPages,
+ sort=paginationParams.sort or [],
+ filters=paginationParams.filters
)
- total_items = len(items)
- total_pages = (total_items + paginationParams.pageSize - 1) // paginationParams.pageSize
- start_idx = (paginationParams.page - 1) * paginationParams.pageSize
- end_idx = start_idx + paginationParams.pageSize
- paginated_items = items[start_idx:end_idx]
- return PaginatedResponse(
- items=paginated_items,
- pagination=PaginationMetadata(
- currentPage=paginationParams.page,
- pageSize=paginationParams.pageSize,
- totalItems=total_items,
- totalPages=total_pages,
- sort=paginationParams.sort or [],
- filters=paginationParams.filters
)
- )
+ items = interface.getProjekte(recordFilter=recordFilter)
return PaginatedResponse(items=items, pagination=None)
@@ -359,33 +348,22 @@ async def get_parcels(
context.user, mandateId=mandateId, featureInstanceId=instanceId
)
recordFilter = {"featureInstanceId": instanceId}
- items = interface.getParzellen(recordFilter=recordFilter)
paginationParams = _parsePagination(pagination)
if paginationParams:
- if paginationParams.sort:
- for sort_field in reversed(paginationParams.sort):
- field_name = sort_field.field
- direction = sort_field.direction.lower()
- items.sort(
- key=lambda x: getattr(x, field_name, None),
- reverse=(direction == "desc")
+ result = interface.getParzellen(pagination=paginationParams, recordFilter=recordFilter)
+ if hasattr(result, 'items'):
+ return PaginatedResponse(
+ items=result.items,
+ pagination=PaginationMetadata(
+ currentPage=paginationParams.page,
+ pageSize=paginationParams.pageSize,
+ totalItems=result.totalItems,
+ totalPages=result.totalPages,
+ sort=paginationParams.sort or [],
+ filters=paginationParams.filters
)
- total_items = len(items)
- total_pages = (total_items + paginationParams.pageSize - 1) // paginationParams.pageSize
- start_idx = (paginationParams.page - 1) * paginationParams.pageSize
- end_idx = start_idx + paginationParams.pageSize
- paginated_items = items[start_idx:end_idx]
- return PaginatedResponse(
- items=paginated_items,
- pagination=PaginationMetadata(
- currentPage=paginationParams.page,
- pageSize=paginationParams.pageSize,
- totalItems=total_items,
- totalPages=total_pages,
- sort=paginationParams.sort or [],
- filters=paginationParams.filters
)
- )
+ items = interface.getParzellen(recordFilter=recordFilter)
return PaginatedResponse(items=items, pagination=None)
diff --git a/modules/routes/routeSubscription.py b/modules/routes/routeSubscription.py
index 28ad7aa9..8334a8c0 100644
--- a/modules/routes/routeSubscription.py
+++ b/modules/routes/routeSubscription.py
@@ -12,13 +12,17 @@ Endpoints:
- POST /api/subscription/force-cancel — sysadmin immediate cancel (by ID)
"""
-from fastapi import APIRouter, HTTPException, Depends, Request
+from fastapi import APIRouter, HTTPException, Depends, Request, Query
from fastapi import status
from typing import Dict, Any, List, Optional
import logging
+import json
+import math
from pydantic import BaseModel, Field
from modules.auth import limiter, getRequestContext, RequestContext
+from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
+from modules.routes.routeDataUsers import _applyFiltersAndSort, _extractDistinctValues
logger = logging.getLogger(__name__)
@@ -303,16 +307,8 @@ def verifyCheckout(
# SysAdmin: global subscription overview
# =============================================================================
-@router.get("/admin/all", response_model=List[Dict[str, Any]])
-@limiter.limit("30/minute")
-def getAllSubscriptions(
- request: Request,
- context: RequestContext = Depends(getRequestContext),
-):
- """SysAdmin: list ALL subscriptions across all mandates with enriched metadata."""
- if not context.hasSysAdminRole:
- raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required")
-
+def _buildEnrichedSubscriptions() -> List[Dict[str, Any]]:
+ """Build the full enriched subscription list (shared by list + filter-values endpoints)."""
from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
from modules.datamodels.datamodelSubscription import BUILTIN_PLANS, OPERATIVE_STATUSES
@@ -362,3 +358,80 @@ def getAllSubscriptions(
enriched.append(sub)
return enriched
+
+
+@router.get("/admin/all")
+@limiter.limit("30/minute")
+def getAllSubscriptions(
+ request: Request,
+ pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
+ context: RequestContext = Depends(getRequestContext),
+):
+ """SysAdmin: list ALL subscriptions across all mandates with enriched metadata."""
+ if not context.hasSysAdminRole:
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required")
+
+ paginationParams: Optional[PaginationParams] = None
+ if pagination:
+ try:
+ paginationDict = json.loads(pagination)
+ if paginationDict:
+ paginationDict = normalize_pagination_dict(paginationDict)
+ paginationParams = PaginationParams(**paginationDict)
+ except (json.JSONDecodeError, ValueError) as e:
+ raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")
+
+ enriched = _buildEnrichedSubscriptions()
+ filtered = _applyFiltersAndSort(enriched, paginationParams)
+
+ if paginationParams:
+ totalItems = len(filtered)
+ totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0
+ startIdx = (paginationParams.page - 1) * paginationParams.pageSize
+ endIdx = startIdx + paginationParams.pageSize
+ pageItems = filtered[startIdx:endIdx]
+ return {
+ "items": pageItems,
+ "pagination": PaginationMetadata(
+ currentPage=paginationParams.page,
+ pageSize=paginationParams.pageSize,
+ totalItems=totalItems,
+ totalPages=totalPages,
+ sort=paginationParams.sort,
+ filters=paginationParams.filters,
+ ).model_dump(),
+ }
+
+ return {"items": enriched, "pagination": None}
+
+
+@router.get("/admin/all/filter-values")
+@limiter.limit("60/minute")
+def getFilterValues(
+ request: Request,
+ column: str = Query(..., description="Column key to extract distinct values for"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters (applied except for the requested column)"),
+ context: RequestContext = Depends(getRequestContext),
+):
+ """Return distinct values for a column, respecting all active filters except the requested one."""
+ if not context.hasSysAdminRole:
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required")
+
+ crossFilterParams: Optional[PaginationParams] = None
+ if pagination:
+ try:
+ paginationDict = json.loads(pagination)
+ if paginationDict:
+ paginationDict = normalize_pagination_dict(paginationDict)
+ filters = paginationDict.get("filters", {})
+ filters.pop(column, None)
+ paginationDict["filters"] = filters
+ paginationDict.pop("sort", None)
+ crossFilterParams = PaginationParams(**paginationDict)
+ except (json.JSONDecodeError, ValueError) as e:
+ raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")
+
+ enriched = _buildEnrichedSubscriptions()
+ crossFiltered = _applyFiltersAndSort(enriched, crossFilterParams)
+
+ return _extractDistinctValues(crossFiltered, column)
diff --git a/modules/system/mainSystem.py b/modules/system/mainSystem.py
index 58667645..bfa3e23f 100644
--- a/modules/system/mainSystem.py
+++ b/modules/system/mainSystem.py
@@ -191,6 +191,15 @@ NAVIGATION_SECTIONS = [
"order": 40,
"adminOnly": True,
},
+ {
+ "id": "admin-subscriptions",
+ "objectKey": "ui.admin.subscriptions",
+ "label": {"en": "Subscriptions", "de": "Abonnements", "fr": "Abonnements"},
+ "icon": "FaFileContract",
+ "path": "/admin/billing/subscriptions",
+ "order": 50,
+ "adminOnly": True,
+ },
],
},
# ── System ──
diff --git a/scripts/.$import_diagram.drawio.bkp b/scripts/.$import_diagram.drawio.bkp
deleted file mode 100644
index 588285b0..00000000
--- a/scripts/.$import_diagram.drawio.bkp
+++ /dev/null
@@ -1,2723 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/scripts/.$import_diagram_containers.drawio.bkp b/scripts/.$import_diagram_containers.drawio.bkp
deleted file mode 100644
index cad95eef..00000000
--- a/scripts/.$import_diagram_containers.drawio.bkp
+++ /dev/null
@@ -1,317 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
From 9ef0d430915e9b7870eec2c224d5abf44f822dc1 Mon Sep 17 00:00:00 2001
From: ValueOn AG
Date: Sun, 22 Mar 2026 22:19:50 +0100
Subject: [PATCH 3/5] fixed tables and forms
---
app.py | 3 +
modules/connectors/connectorDbPostgre.py | 8 +
.../automation/interfaceFeatureAutomation.py | 2 +
modules/features/automation/mainAutomation.py | 5 -
.../automation/routeFeatureAutomation.py | 1 +
modules/routes/routeAdminAutomationLogs.py | 207 ++++++++++++++++++
modules/shared/attributeUtils.py | 21 ++
modules/system/mainSystem.py | 12 +-
8 files changed, 253 insertions(+), 6 deletions(-)
create mode 100644 modules/routes/routeAdminAutomationLogs.py
diff --git a/app.py b/app.py
index c1400353..e69d7a0f 100644
--- a/app.py
+++ b/app.py
@@ -568,6 +568,9 @@ app.include_router(sharepointRouter)
from modules.routes.routeAdminAutomationEvents import router as adminAutomationEventsRouter
app.include_router(adminAutomationEventsRouter)
+from modules.routes.routeAdminAutomationLogs import router as adminAutomationLogsRouter
+app.include_router(adminAutomationLogsRouter)
+
from modules.routes.routeAdminLogs import router as adminLogsRouter
app.include_router(adminLogsRouter)
diff --git a/modules/connectors/connectorDbPostgre.py b/modules/connectors/connectorDbPostgre.py
index 9675ffca..67cceb45 100644
--- a/modules/connectors/connectorDbPostgre.py
+++ b/modules/connectors/connectorDbPostgre.py
@@ -992,6 +992,10 @@ class DatabaseConnector:
Returns (where_clause, order_clause, limit_clause, values, count_values).
"""
fields = _get_model_fields(model_class)
+ fields["_createdAt"] = "DOUBLE PRECISION"
+ fields["_modifiedAt"] = "DOUBLE PRECISION"
+ fields["_createdBy"] = "TEXT"
+ fields["_modifiedBy"] = "TEXT"
validColumns = set(fields.keys())
where_parts: List[str] = []
values: List[Any] = []
@@ -1186,6 +1190,10 @@ class DatabaseConnector:
"""
table = model_class.__name__
fields = _get_model_fields(model_class)
+ fields["_createdAt"] = "DOUBLE PRECISION"
+ fields["_modifiedAt"] = "DOUBLE PRECISION"
+ fields["_createdBy"] = "TEXT"
+ fields["_modifiedBy"] = "TEXT"
if column not in fields:
return []
diff --git a/modules/features/automation/interfaceFeatureAutomation.py b/modules/features/automation/interfaceFeatureAutomation.py
index 3fc2420b..4091bc28 100644
--- a/modules/features/automation/interfaceFeatureAutomation.py
+++ b/modules/features/automation/interfaceFeatureAutomation.py
@@ -418,6 +418,8 @@ class AutomationObjects:
if not self.checkRbacPermission(AutomationDefinition, "update", automationId):
raise PermissionError(f"No permission to modify automation {automationId}")
+ automationData.pop("executionLogs", None)
+
# If deactivating: immediately remove scheduler job (don't rely on async callback)
isBeingDeactivated = "active" in automationData and not automationData["active"]
if isBeingDeactivated:
diff --git a/modules/features/automation/mainAutomation.py b/modules/features/automation/mainAutomation.py
index 35a61512..4bb30f7f 100644
--- a/modules/features/automation/mainAutomation.py
+++ b/modules/features/automation/mainAutomation.py
@@ -27,11 +27,6 @@ UI_OBJECTS = [
"label": {"en": "Templates", "de": "Vorlagen", "fr": "Modèles"},
"meta": {"area": "templates"}
},
- {
- "objectKey": "ui.feature.automation.logs",
- "label": {"en": "Execution Logs", "de": "Ausführungsprotokolle", "fr": "Journaux d'exécution"},
- "meta": {"area": "logs"}
- },
]
# Resource Objects for RBAC catalog
diff --git a/modules/features/automation/routeFeatureAutomation.py b/modules/features/automation/routeFeatureAutomation.py
index 81f0852b..48f53eea 100644
--- a/modules/features/automation/routeFeatureAutomation.py
+++ b/modules/features/automation/routeFeatureAutomation.py
@@ -785,6 +785,7 @@ def update_automation(
try:
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None)
automationData = automation.model_dump()
+ automationData.pop("executionLogs", None)
updated = chatInterface.updateAutomationDefinition(automationId, automationData)
return updated
except HTTPException:
diff --git a/modules/routes/routeAdminAutomationLogs.py b/modules/routes/routeAdminAutomationLogs.py
new file mode 100644
index 00000000..8b4d897b
--- /dev/null
+++ b/modules/routes/routeAdminAutomationLogs.py
@@ -0,0 +1,207 @@
+# Copyright (c) 2025 Patrick Motsch
+# All rights reserved.
+"""
+Admin automation execution logs routes.
+SysAdmin-only endpoints for viewing consolidated automation execution history
+across all mandates and feature instances.
+"""
+
+from fastapi import APIRouter, HTTPException, Depends, Request, Query
+from typing import List, Dict, Any, Optional
+import logging
+import json
+import math
+import uuid
+
+from modules.auth import limiter, requireSysAdminRole
+from modules.datamodels.datamodelUam import User
+from modules.datamodels.datamodelPagination import PaginationParams, PaginationMetadata, normalize_pagination_dict
+from modules.routes.routeDataUsers import _applyFiltersAndSort, _extractDistinctValues
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(
+ prefix="/api/admin/automation-logs",
+ tags=["Admin Automation Logs"],
+ responses={
+ 401: {"description": "Unauthorized"},
+ 403: {"description": "Forbidden - Sysadmin only"},
+ 500: {"description": "Internal server error"},
+ },
+)
+
+
+def _buildFlattenedExecutionLogs(currentUser: User) -> List[Dict[str, Any]]:
+ """Flatten executionLogs from all AutomationDefinitions across all mandates.
+ Called from a SysAdmin-only endpoint — bypasses RBAC, reads directly from DB."""
+ from modules.interfaces.interfaceDbApp import getRootInterface
+ from modules.features.automation.mainAutomation import getAutomationServices
+ from modules.features.automation.datamodelFeatureAutomation import AutomationDefinition
+
+ rootInterface = getRootInterface()
+ services = getAutomationServices(currentUser, mandateId=None, featureInstanceId=None)
+ allAutomations = services.interfaceDbAutomation.db.getRecordset(AutomationDefinition)
+
+ userCache: Dict[str, str] = {}
+ mandateCache: Dict[str, str] = {}
+ featureCache: Dict[str, str] = {}
+
+ def _resolveUsername(userId: str) -> str:
+ if not userId:
+ return ""
+ if userId not in userCache:
+ try:
+ user = rootInterface.getUser(userId)
+ userCache[userId] = user.username if user else userId[:8]
+ except Exception:
+ userCache[userId] = userId[:8]
+ return userCache[userId]
+
+ def _resolveMandateLabel(mandateId: str) -> str:
+ if not mandateId:
+ return ""
+ if mandateId not in mandateCache:
+ try:
+ mandate = rootInterface.getMandate(mandateId)
+ mandateCache[mandateId] = getattr(mandate, "label", None) or mandateId[:8]
+ except Exception:
+ mandateCache[mandateId] = mandateId[:8]
+ return mandateCache[mandateId]
+
+ def _resolveFeatureLabel(featureInstanceId: str) -> str:
+ if not featureInstanceId:
+ return ""
+ if featureInstanceId not in featureCache:
+ try:
+ instance = rootInterface.getFeatureInstance(featureInstanceId)
+ featureCache[featureInstanceId] = (
+ getattr(instance, "label", None)
+ or getattr(instance, "featureCode", None)
+ or featureInstanceId[:8]
+ )
+ except Exception:
+ featureCache[featureInstanceId] = featureInstanceId[:8]
+ return featureCache[featureInstanceId]
+
+ flatLogs: List[Dict[str, Any]] = []
+
+ for automation in allAutomations:
+ if isinstance(automation, dict):
+ automationId = automation.get("id", "")
+ automationLabel = automation.get("label", "")
+ mandateId = automation.get("mandateId", "")
+ featureInstanceId = automation.get("featureInstanceId", "")
+ createdBy = automation.get("_createdBy", "")
+ logs = automation.get("executionLogs") or []
+ else:
+ automationId = getattr(automation, "id", "")
+ automationLabel = getattr(automation, "label", "")
+ mandateId = getattr(automation, "mandateId", "")
+ featureInstanceId = getattr(automation, "featureInstanceId", "")
+ createdBy = getattr(automation, "_createdBy", "")
+ logs = getattr(automation, "executionLogs", None) or []
+
+ mandateName = _resolveMandateLabel(mandateId)
+ featureInstanceName = _resolveFeatureLabel(featureInstanceId)
+ executedByName = _resolveUsername(createdBy)
+
+ for log in logs:
+ timestamp = log.get("timestamp", 0) if isinstance(log, dict) else 0
+ status = log.get("status", "") if isinstance(log, dict) else ""
+ workflowId = log.get("workflowId", "") if isinstance(log, dict) else ""
+ messages = log.get("messages", []) if isinstance(log, dict) else []
+
+ flatLogs.append({
+ "id": str(uuid.uuid4()),
+ "timestamp": timestamp,
+ "automationId": automationId,
+ "automationLabel": automationLabel,
+ "mandateName": mandateName,
+ "featureInstanceName": featureInstanceName,
+ "executedBy": executedByName,
+ "status": status,
+ "workflowId": workflowId,
+ "messages": "; ".join(messages) if messages else "",
+ })
+
+ flatLogs.sort(key=lambda x: x.get("timestamp", 0), reverse=True)
+ return flatLogs
+
+
+@router.get("")
+@limiter.limit("30/minute")
+def get_all_automation_logs(
+ request: Request,
+ pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams"),
+ currentUser: User = Depends(requireSysAdminRole),
+):
+ """Get consolidated execution logs from all automations (sysadmin only)."""
+ try:
+ paginationParams: Optional[PaginationParams] = None
+ if pagination:
+ try:
+ paginationDict = json.loads(pagination)
+ if paginationDict:
+ paginationDict = normalize_pagination_dict(paginationDict)
+ paginationParams = PaginationParams(**paginationDict)
+ except (json.JSONDecodeError, ValueError) as e:
+ raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")
+
+ logs = _buildFlattenedExecutionLogs(currentUser)
+ filtered = _applyFiltersAndSort(logs, paginationParams)
+
+ if paginationParams:
+ totalItems = len(filtered)
+ totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0
+ startIdx = (paginationParams.page - 1) * paginationParams.pageSize
+ endIdx = startIdx + paginationParams.pageSize
+ return {
+ "items": filtered[startIdx:endIdx],
+ "pagination": PaginationMetadata(
+ currentPage=paginationParams.page,
+ pageSize=paginationParams.pageSize,
+ totalItems=totalItems,
+ totalPages=totalPages,
+ sort=paginationParams.sort,
+ filters=paginationParams.filters,
+ ).model_dump(),
+ }
+
+ return {"items": logs, "pagination": None}
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Error getting automation logs: {str(e)}")
+ raise HTTPException(status_code=500, detail=f"Error getting automation logs: {str(e)}")
+
+
+@router.get("/filter-values")
+@limiter.limit("60/minute")
+def get_automation_log_filter_values(
+ request: Request,
+ column: str = Query(..., description="Column key"),
+ pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
+ currentUser: User = Depends(requireSysAdminRole),
+):
+ """Return distinct filter values for a column in automation logs."""
+ try:
+ crossFilterParams: Optional[PaginationParams] = None
+ if pagination:
+ try:
+ paginationDict = json.loads(pagination)
+ if paginationDict:
+ paginationDict = normalize_pagination_dict(paginationDict)
+ filters = paginationDict.get("filters", {})
+ filters.pop(column, None)
+ paginationDict["filters"] = filters
+ paginationDict.pop("sort", None)
+ crossFilterParams = PaginationParams(**paginationDict)
+ except (json.JSONDecodeError, ValueError):
+ pass
+
+ logs = _buildFlattenedExecutionLogs(currentUser)
+ crossFiltered = _applyFiltersAndSort(logs, crossFilterParams)
+ return _extractDistinctValues(crossFiltered, column)
+ except Exception as e:
+ logger.error(f"Error getting filter values: {str(e)}")
+ raise HTTPException(status_code=500, detail=str(e))
diff --git a/modules/shared/attributeUtils.py b/modules/shared/attributeUtils.py
index 863d7f36..6a857d85 100644
--- a/modules/shared/attributeUtils.py
+++ b/modules/shared/attributeUtils.py
@@ -258,6 +258,27 @@ def getModelAttributeDefinitions(modelClass: Type[BaseModel] = None, userLanguag
attributes.append(attr_def)
+ # Append system timestamp fields (set automatically by DatabaseConnector)
+ systemTimestampFields = [
+ ("_createdAt", {"en": "Created at", "de": "Erstellt am", "fr": "Créé le"}),
+ ("_modifiedAt", {"en": "Modified at", "de": "Geändert am", "fr": "Modifié le"}),
+ ]
+ for sysName, sysLabels in systemTimestampFields:
+ attributes.append({
+ "name": sysName,
+ "type": "timestamp",
+ "required": False,
+ "description": "",
+ "label": sysLabels.get(userLanguage, sysLabels["en"]),
+ "placeholder": "",
+ "editable": False,
+ "visible": True,
+ "order": len(attributes),
+ "readonly": True,
+ "options": None,
+ "default": None,
+ })
+
return {"model": model_label, "attributes": attributes}
diff --git a/modules/system/mainSystem.py b/modules/system/mainSystem.py
index bfa3e23f..1325f060 100644
--- a/modules/system/mainSystem.py
+++ b/modules/system/mainSystem.py
@@ -196,7 +196,7 @@ NAVIGATION_SECTIONS = [
"objectKey": "ui.admin.subscriptions",
"label": {"en": "Subscriptions", "de": "Abonnements", "fr": "Abonnements"},
"icon": "FaFileContract",
- "path": "/admin/billing/subscriptions",
+ "path": "/admin/subscriptions",
"order": 50,
"adminOnly": True,
},
@@ -282,6 +282,16 @@ NAVIGATION_SECTIONS = [
"adminOnly": True,
"sysAdminOnly": True,
},
+ {
+ "id": "admin-automation-logs",
+ "objectKey": "ui.admin.automationLogs",
+ "label": {"en": "Execution Logs", "de": "Ausführungsprotokolle", "fr": "Journaux d'exécution"},
+ "icon": "FaClipboardList",
+ "path": "/admin/automation-logs",
+ "order": 85,
+ "adminOnly": True,
+ "sysAdminOnly": True,
+ },
{
"id": "admin-logs",
"objectKey": "ui.admin.logs",
From 3934cdd3ee5e5f9e1186f6c9679a122c276e5bbd Mon Sep 17 00:00:00 2001
From: ValueOn AG
Date: Mon, 23 Mar 2026 00:05:29 +0100
Subject: [PATCH 4/5] tool fixes
---
.../connectors/providerMsft/connectorMsft.py | 49 +++-
modules/datamodels/datamodelVoice.py | 45 +---
.../commcoach/routeFeatureCommcoach.py | 73 +++++-
.../features/commcoach/serviceCommcoach.py | 3 +-
.../mainServiceNeutralization.py | 18 +-
.../workspace/datamodelFeatureWorkspace.py | 65 +++++
.../workspace/interfaceFeatureWorkspace.py | 248 ++++++++++++++++++
.../workspace/routeFeatureWorkspace.py | 95 ++++++-
modules/interfaces/interfaceDbKnowledge.py | 2 +
modules/interfaces/interfaceDbManagement.py | 153 -----------
modules/serviceCenter/registry.py | 2 +-
.../services/serviceAgent/mainServiceAgent.py | 59 ++++-
.../serviceBilling/mainServiceBilling.py | 2 +-
.../mainServiceSubscription.py | 54 +++-
.../serviceSubscription/stripeBootstrap.py | 61 ++++-
15 files changed, 691 insertions(+), 238 deletions(-)
create mode 100644 modules/features/workspace/datamodelFeatureWorkspace.py
create mode 100644 modules/features/workspace/interfaceFeatureWorkspace.py
diff --git a/modules/connectors/providerMsft/connectorMsft.py b/modules/connectors/providerMsft/connectorMsft.py
index 3654bad9..a51fa231 100644
--- a/modules/connectors/providerMsft/connectorMsft.py
+++ b/modules/connectors/providerMsft/connectorMsft.py
@@ -299,26 +299,65 @@ class OutlookAdapter(_GraphApiMixin, ServiceAdapter):
for m in result.get("value", [])
]
- async def sendMail(
+ def _buildMessage(
self, to: List[str], subject: str, body: str,
- cc: Optional[List[str]] = None, attachments: Optional[List[Dict]] = None
+ bodyType: str = "Text",
+ cc: Optional[List[str]] = None,
+ attachments: Optional[List[Dict]] = None,
) -> Dict[str, Any]:
- """Send an email via Microsoft Graph."""
- import json
+ """Build a Graph API message object.
+
+ attachments: list of {"name": str, "contentBytes": str (base64), "contentType": str}
+ """
message: Dict[str, Any] = {
"subject": subject,
- "body": {"contentType": "Text", "content": body},
+ "body": {"contentType": bodyType, "content": body},
"toRecipients": [{"emailAddress": {"address": addr}} for addr in to],
}
if cc:
message["ccRecipients"] = [{"emailAddress": {"address": addr}} for addr in cc]
+ if attachments:
+ message["attachments"] = [
+ {
+ "@odata.type": "#microsoft.graph.fileAttachment",
+ "name": att["name"],
+ "contentBytes": att["contentBytes"],
+ "contentType": att.get("contentType", "application/octet-stream"),
+ }
+ for att in attachments
+ ]
+ return message
+ async def sendMail(
+ self, to: List[str], subject: str, body: str,
+ bodyType: str = "Text",
+ cc: Optional[List[str]] = None,
+ attachments: Optional[List[Dict]] = None,
+ ) -> Dict[str, Any]:
+ """Send an email via Microsoft Graph. bodyType: 'Text' or 'HTML'."""
+ import json
+ message = self._buildMessage(to, subject, body, bodyType, cc, attachments)
payload = json.dumps({"message": message, "saveToSentItems": True}).encode("utf-8")
result = await self._graphPost("me/sendMail", payload)
if "error" in result:
return result
return {"success": True}
+ async def createDraft(
+ self, to: List[str], subject: str, body: str,
+ bodyType: str = "Text",
+ cc: Optional[List[str]] = None,
+ attachments: Optional[List[Dict]] = None,
+ ) -> Dict[str, Any]:
+ """Create a draft email in the user's Drafts folder via Microsoft Graph."""
+ import json
+ message = self._buildMessage(to, subject, body, bodyType, cc, attachments)
+ payload = json.dumps(message).encode("utf-8")
+ result = await self._graphPost("me/messages", payload)
+ if "error" in result:
+ return result
+ return {"success": True, "draft": True, "messageId": result.get("id", "")}
+
# ---------------------------------------------------------------------------
# Teams Adapter (Stub)
diff --git a/modules/datamodels/datamodelVoice.py b/modules/datamodels/datamodelVoice.py
index 2223a3e6..565c7677 100644
--- a/modules/datamodels/datamodelVoice.py
+++ b/modules/datamodels/datamodelVoice.py
@@ -1,46 +1,7 @@
# Copyright (c) 2025 Patrick Motsch
# All rights reserved.
-"""Voice settings datamodel."""
-
-from typing import Dict, Any, Optional
-from pydantic import BaseModel, Field
-from modules.shared.attributeUtils import registerModelLabels
-from modules.shared.timeUtils import getUtcTimestamp
-import uuid
-
-
-class VoiceSettings(BaseModel):
- id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
- userId: str = Field(description="ID of the user these settings belong to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True})
- mandateId: str = Field(description="ID of the mandate these settings belong to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True})
- featureInstanceId: str = Field(description="ID of the feature instance these settings belong to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True})
- sttLanguage: str = Field(default="de-DE", description="Speech-to-Text language", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True})
- ttsLanguage: str = Field(default="de-DE", description="Text-to-Speech language", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True})
- ttsVoice: str = Field(default="de-DE-KatjaNeural", description="Text-to-Speech voice", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True})
- ttsVoiceMap: Dict[str, Any] = Field(default_factory=dict, description="Per-language voice mapping, e.g. {'de-DE': {'voiceName': 'de-DE-Wavenet-A'}, 'en-US': {'voiceName': 'en-US-Wavenet-C'}}", json_schema_extra={"frontend_type": "json", "frontend_readonly": False, "frontend_required": False})
- translationEnabled: bool = Field(default=True, description="Whether translation is enabled", json_schema_extra={"frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False})
- targetLanguage: str = Field(default="en-US", description="Target language for translation", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False})
- creationDate: float = Field(default_factory=getUtcTimestamp, description="Date when the settings were created (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})
- lastModified: float = Field(default_factory=getUtcTimestamp, description="Date when the settings were last modified (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})
-
-
-registerModelLabels(
- "VoiceSettings",
- {"en": "Voice Settings", "fr": "Paramètres vocaux"},
- {
- "id": {"en": "ID", "fr": "ID"},
- "userId": {"en": "User ID", "fr": "ID utilisateur"},
- "mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
- "featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance de fonctionnalité"},
- "sttLanguage": {"en": "STT Language", "fr": "Langue STT"},
- "ttsLanguage": {"en": "TTS Language", "fr": "Langue TTS"},
- "ttsVoice": {"en": "TTS Voice", "fr": "Voix TTS"},
- "ttsVoiceMap": {"en": "TTS Voice Map", "fr": "Carte des voix TTS"},
- "translationEnabled": {"en": "Translation Enabled", "fr": "Traduction activée"},
- "targetLanguage": {"en": "Target Language", "fr": "Langue cible"},
- "creationDate": {"en": "Creation Date", "fr": "Date de création"},
- "lastModified": {"en": "Last Modified", "fr": "Dernière modification"},
- },
-)
+"""Voice settings datamodel — re-exported from workspace feature for backward compatibility."""
+from modules.features.workspace.datamodelFeatureWorkspace import VoiceSettings
+__all__ = ["VoiceSettings"]
diff --git a/modules/features/commcoach/routeFeatureCommcoach.py b/modules/features/commcoach/routeFeatureCommcoach.py
index 81585a19..9074d2ba 100644
--- a/modules/features/commcoach/routeFeatureCommcoach.py
+++ b/modules/features/commcoach/routeFeatureCommcoach.py
@@ -1192,24 +1192,79 @@ async def deleteDocumentRoute(
def _extractText(content: bytes, mimeType: str, fileName: str) -> Optional[str]:
- """Extract text from uploaded file content."""
+ """Extract text from uploaded file content (TXT, MD, HTML, PDF, DOCX, XLSX, PPTX)."""
+ import io
+
+ lowerName = fileName.lower()
try:
- if mimeType == "text/plain" or fileName.endswith(".txt"):
+ if mimeType in ("text/plain",) or lowerName.endswith(".txt"):
return content.decode("utf-8", errors="replace")
- if mimeType == "text/markdown" or fileName.endswith(".md"):
+
+ if mimeType in ("text/markdown",) or lowerName.endswith(".md"):
return content.decode("utf-8", errors="replace")
- if "pdf" in mimeType or fileName.endswith(".pdf"):
+
+ if mimeType in ("text/html",) or lowerName.endswith((".html", ".htm")):
+ from html.parser import HTMLParser
+ class _Strip(HTMLParser):
+ def __init__(self):
+ super().__init__()
+ self._parts: list[str] = []
+ def handle_data(self, d):
+ self._parts.append(d)
+ def result(self):
+ return " ".join(self._parts)
+ parser = _Strip()
+ parser.feed(content.decode("utf-8", errors="replace"))
+ return parser.result()
+
+ if "pdf" in mimeType or lowerName.endswith(".pdf"):
try:
- import io
from PyPDF2 import PdfReader
reader = PdfReader(io.BytesIO(content))
- text = ""
- for page in reader.pages:
- text += page.extract_text() or ""
- return text
+ return "".join(page.extract_text() or "" for page in reader.pages)
except ImportError:
logger.warning("PyPDF2 not installed, cannot extract PDF text")
return None
+
+ if "wordprocessingml" in mimeType or lowerName.endswith(".docx"):
+ try:
+ from docx import Document
+ doc = Document(io.BytesIO(content))
+ return "\n".join(p.text for p in doc.paragraphs if p.text)
+ except ImportError:
+ logger.warning("python-docx not installed, cannot extract DOCX text")
+ return None
+
+ if "spreadsheetml" in mimeType or lowerName.endswith(".xlsx"):
+ try:
+ from openpyxl import load_workbook
+ wb = load_workbook(io.BytesIO(content), read_only=True, data_only=True)
+ parts: list[str] = []
+ for ws in wb.worksheets:
+ for row in ws.iter_rows(values_only=True):
+ cells = [str(c) for c in row if c is not None]
+ if cells:
+ parts.append("\t".join(cells))
+ return "\n".join(parts)
+ except ImportError:
+ logger.warning("openpyxl not installed, cannot extract XLSX text")
+ return None
+
+ if "presentationml" in mimeType or lowerName.endswith(".pptx"):
+ try:
+ from pptx import Presentation
+ prs = Presentation(io.BytesIO(content))
+ parts = []
+ for slide in prs.slides:
+ for shape in slide.shapes:
+ if shape.has_text_frame:
+ parts.append(shape.text_frame.text)
+ return "\n".join(parts)
+ except ImportError:
+ logger.warning("python-pptx not installed, cannot extract PPTX text")
+ return None
+
+ logger.info(f"No text extractor for {fileName} (mime={mimeType})")
except Exception as e:
logger.warning(f"Text extraction failed for {fileName}: {e}")
return None
diff --git a/modules/features/commcoach/serviceCommcoach.py b/modules/features/commcoach/serviceCommcoach.py
index be47a917..bf5ec281 100644
--- a/modules/features/commcoach/serviceCommcoach.py
+++ b/modules/features/commcoach/serviceCommcoach.py
@@ -331,7 +331,8 @@ def _getDocumentSummaries(contextId: str, userId: str, interface) -> Optional[Li
elif doc.get("extractedText"):
summaries.append(f"[{doc.get('fileName', 'Dokument')}] {doc['extractedText'][:200]}...")
return summaries if summaries else None
- except Exception:
+ except Exception as e:
+ logger.warning(f"Failed to load document summaries for context {contextId}: {e}")
return None
diff --git a/modules/features/neutralization/serviceNeutralization/mainServiceNeutralization.py b/modules/features/neutralization/serviceNeutralization/mainServiceNeutralization.py
index cf8f0f53..c803b375 100644
--- a/modules/features/neutralization/serviceNeutralization/mainServiceNeutralization.py
+++ b/modules/features/neutralization/serviceNeutralization/mainServiceNeutralization.py
@@ -39,19 +39,21 @@ EXTRACTABLE_BINARY_MIME_TYPES = frozenset({
class NeutralizationService:
"""Service for handling data neutralization operations"""
- def __init__(self, serviceCenter=None, NamesToParse: List[str] = None):
+ def __init__(self, serviceCenter=None, getServiceFn=None, NamesToParse: List[str] = None):
"""Initialize the service with user context and anonymization processors
Args:
- serviceCenter: Service center instance for accessing other services
+ serviceCenter: Service center context or legacy service center instance
+ getServiceFn: Service resolver function (injected by ServiceCenter resolver)
NamesToParse: List of names to parse and replace (case-insensitive)
"""
self.services = serviceCenter
- self.interfaceDbComponent = serviceCenter.interfaceDbComponent
+ self._getService = getServiceFn
+ self.interfaceDbComponent = getattr(serviceCenter, "interfaceDbComponent", None)
# Create feature-specific interface for neutralizer DB operations
self.interfaceNeutralizer: InterfaceFeatureNeutralizer = None
- if serviceCenter and serviceCenter.interfaceDbApp:
+ if serviceCenter and getattr(serviceCenter, "interfaceDbApp", None):
dbApp = serviceCenter.interfaceDbApp
self.interfaceNeutralizer = getNeutralizerInterface(
currentUser=dbApp.currentUser,
@@ -59,10 +61,10 @@ class NeutralizationService:
featureInstanceId=getattr(serviceCenter, 'featureInstanceId', None) or getattr(dbApp, 'featureInstanceId', None)
)
- # Initialize anonymization processors
- self.NamesToParse = NamesToParse or []
- self.textProcessor = TextProcessor(NamesToParse)
- self.listProcessor = ListProcessor(NamesToParse)
+ namesList = NamesToParse if isinstance(NamesToParse, list) else []
+ self.NamesToParse = namesList
+ self.textProcessor = TextProcessor(namesList)
+ self.listProcessor = ListProcessor(namesList)
self.binaryProcessor = BinaryProcessor()
self.commonUtils = CommonUtils()
diff --git a/modules/features/workspace/datamodelFeatureWorkspace.py b/modules/features/workspace/datamodelFeatureWorkspace.py
new file mode 100644
index 00000000..80da5915
--- /dev/null
+++ b/modules/features/workspace/datamodelFeatureWorkspace.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2025 Patrick Motsch
+# All rights reserved.
+"""Workspace feature data models — VoiceSettings and WorkspaceUserSettings."""
+
+from typing import Dict, Any, Optional
+from pydantic import BaseModel, Field
+from modules.shared.attributeUtils import registerModelLabels
+from modules.shared.timeUtils import getUtcTimestamp
+import uuid
+
+
+class VoiceSettings(BaseModel):
+ id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
+ userId: str = Field(description="ID of the user these settings belong to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True})
+ mandateId: str = Field(description="ID of the mandate these settings belong to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True})
+ featureInstanceId: str = Field(description="ID of the feature instance these settings belong to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True})
+ sttLanguage: str = Field(default="de-DE", description="Speech-to-Text language", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True})
+ ttsLanguage: str = Field(default="de-DE", description="Text-to-Speech language", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True})
+ ttsVoice: str = Field(default="de-DE-KatjaNeural", description="Text-to-Speech voice", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True})
+ ttsVoiceMap: Dict[str, Any] = Field(default_factory=dict, description="Per-language voice mapping, e.g. {'de-DE': {'voiceName': 'de-DE-Wavenet-A'}, 'en-US': {'voiceName': 'en-US-Wavenet-C'}}", json_schema_extra={"frontend_type": "json", "frontend_readonly": False, "frontend_required": False})
+ translationEnabled: bool = Field(default=True, description="Whether translation is enabled", json_schema_extra={"frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False})
+ targetLanguage: str = Field(default="en-US", description="Target language for translation", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False})
+ creationDate: float = Field(default_factory=getUtcTimestamp, description="Date when the settings were created (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})
+ lastModified: float = Field(default_factory=getUtcTimestamp, description="Date when the settings were last modified (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})
+
+
+class WorkspaceUserSettings(BaseModel):
+ """Per-user workspace settings. None values mean 'use instance default'."""
+ id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
+ userId: str = Field(description="User ID", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True})
+ mandateId: str = Field(description="Mandate ID", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True})
+ featureInstanceId: str = Field(description="Feature Instance ID", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True})
+ maxAgentRounds: Optional[int] = Field(default=None, description="Max agent rounds override (None = instance default)", json_schema_extra={"frontend_type": "number", "frontend_readonly": False, "frontend_required": False})
+
+
+registerModelLabels(
+ "VoiceSettings",
+ {"en": "Voice Settings", "fr": "Paramètres vocaux"},
+ {
+ "id": {"en": "ID", "fr": "ID"},
+ "userId": {"en": "User ID", "fr": "ID utilisateur"},
+ "mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
+ "featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance de fonctionnalité"},
+ "sttLanguage": {"en": "STT Language", "fr": "Langue STT"},
+ "ttsLanguage": {"en": "TTS Language", "fr": "Langue TTS"},
+ "ttsVoice": {"en": "TTS Voice", "fr": "Voix TTS"},
+ "ttsVoiceMap": {"en": "TTS Voice Map", "fr": "Carte des voix TTS"},
+ "translationEnabled": {"en": "Translation Enabled", "fr": "Traduction activée"},
+ "targetLanguage": {"en": "Target Language", "fr": "Langue cible"},
+ "creationDate": {"en": "Creation Date", "fr": "Date de création"},
+ "lastModified": {"en": "Last Modified", "fr": "Dernière modification"},
+ },
+)
+
+registerModelLabels(
+ "WorkspaceUserSettings",
+ {"en": "Workspace User Settings", "de": "Workspace Benutzereinstellungen"},
+ {
+ "id": {"en": "ID", "de": "ID"},
+ "userId": {"en": "User ID", "de": "Benutzer-ID"},
+ "mandateId": {"en": "Mandate ID", "de": "Mandanten-ID"},
+ "featureInstanceId": {"en": "Feature Instance ID", "de": "Feature-Instanz-ID"},
+ "maxAgentRounds": {"en": "Max Agent Rounds", "de": "Max. Agenten-Runden"},
+ },
+)
diff --git a/modules/features/workspace/interfaceFeatureWorkspace.py b/modules/features/workspace/interfaceFeatureWorkspace.py
new file mode 100644
index 00000000..bd1a03c4
--- /dev/null
+++ b/modules/features/workspace/interfaceFeatureWorkspace.py
@@ -0,0 +1,248 @@
+# Copyright (c) 2025 Patrick Motsch
+# All rights reserved.
+"""
+Interface for Workspace feature — manages VoiceSettings and WorkspaceUserSettings.
+Uses a dedicated poweron_workspace database.
+"""
+
+import logging
+from typing import Dict, Any, Optional
+
+from modules.connectors.connectorDbPostgre import DatabaseConnector
+from modules.datamodels.datamodelUam import User
+from modules.features.workspace.datamodelFeatureWorkspace import VoiceSettings, WorkspaceUserSettings
+from modules.interfaces.interfaceRbac import getRecordsetWithRBAC
+from modules.security.rbac import RbacClass
+from modules.shared.configuration import APP_CONFIG
+from modules.shared.timeUtils import getUtcTimestamp
+
+logger = logging.getLogger(__name__)
+
+_workspaceInterfaces: Dict[str, "WorkspaceObjects"] = {}
+
+
+class WorkspaceObjects:
+ """Interface for Workspace-specific database operations (voice + general settings)."""
+
+ def __init__(self, currentUser: User, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None):
+ self.currentUser = currentUser
+ self.mandateId = mandateId
+ self.featureInstanceId = featureInstanceId
+ self.userId = currentUser.id if currentUser else None
+
+ self._initializeDatabase()
+
+ from modules.security.rootAccess import getRootDbAppConnector
+ dbApp = getRootDbAppConnector()
+ self.rbac = RbacClass(self.db, dbApp=dbApp)
+
+ self.db.updateContext(self.userId)
+
+ def _initializeDatabase(self):
+ dbHost = APP_CONFIG.get("DB_HOST", "_no_config_default_data")
+ dbDatabase = "poweron_workspace"
+ dbUser = APP_CONFIG.get("DB_USER")
+ dbPassword = APP_CONFIG.get("DB_PASSWORD_SECRET")
+ dbPort = int(APP_CONFIG.get("DB_PORT", 5432))
+
+ self.db = DatabaseConnector(
+ dbHost=dbHost,
+ dbDatabase=dbDatabase,
+ dbUser=dbUser,
+ dbPassword=dbPassword,
+ dbPort=dbPort,
+ userId=self.userId,
+ )
+ logger.debug(f"Workspace database initialized for user {self.userId}")
+
+ def setUserContext(self, currentUser: User, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None):
+ self.currentUser = currentUser
+ self.userId = currentUser.id if currentUser else None
+ self.mandateId = mandateId
+ self.featureInstanceId = featureInstanceId
+ self.db.updateContext(self.userId)
+
+ # =========================================================================
+ # VoiceSettings CRUD
+ # =========================================================================
+
+ def getVoiceSettings(self, userId: Optional[str] = None) -> Optional[VoiceSettings]:
+ try:
+ targetUserId = userId or self.userId
+ if not targetUserId:
+ logger.error("No user ID provided for voice settings")
+ return None
+
+ recordFilter: Dict[str, Any] = {"userId": targetUserId}
+ if self.featureInstanceId:
+ recordFilter["featureInstanceId"] = self.featureInstanceId
+
+ filteredSettings = getRecordsetWithRBAC(
+ self.db, VoiceSettings, self.currentUser,
+ recordFilter=recordFilter, mandateId=self.mandateId,
+ )
+
+ if not filteredSettings:
+ return None
+
+ settingsData = filteredSettings[0]
+ if not settingsData.get("creationDate"):
+ settingsData["creationDate"] = getUtcTimestamp()
+ if not settingsData.get("lastModified"):
+ settingsData["lastModified"] = getUtcTimestamp()
+
+ return VoiceSettings(**settingsData)
+
+ except Exception as e:
+ logger.error(f"Error getting voice settings: {e}")
+ return None
+
+ def createVoiceSettings(self, settingsData: Dict[str, Any]) -> Dict[str, Any]:
+ try:
+ if "userId" not in settingsData:
+ settingsData["userId"] = self.userId
+ if "mandateId" not in settingsData:
+ settingsData["mandateId"] = self.mandateId
+ if "featureInstanceId" not in settingsData:
+ settingsData["featureInstanceId"] = self.featureInstanceId
+
+ existing = self.getVoiceSettings(settingsData["userId"])
+ if existing:
+ raise ValueError(f"Voice settings already exist for user {settingsData['userId']}")
+
+ createdRecord = self.db.recordCreate(VoiceSettings, settingsData)
+ if not createdRecord or not createdRecord.get("id"):
+ raise ValueError("Failed to create voice settings record")
+
+ logger.info(f"Created voice settings for user {settingsData['userId']}")
+ return createdRecord
+
+ except Exception as e:
+ logger.error(f"Error creating voice settings: {e}")
+ raise
+
+ def updateVoiceSettings(self, userId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
+ try:
+ existing = self.getVoiceSettings(userId)
+ if not existing:
+ raise ValueError(f"Voice settings not found for user {userId}")
+
+ updateData["lastModified"] = getUtcTimestamp()
+ success = self.db.recordModify(VoiceSettings, existing.id, updateData)
+ if not success:
+ raise ValueError("Failed to update voice settings record")
+
+ updated = self.getVoiceSettings(userId)
+ if not updated:
+ raise ValueError("Failed to retrieve updated voice settings")
+
+ logger.info(f"Updated voice settings for user {userId}")
+ return updated.model_dump()
+
+ except Exception as e:
+ logger.error(f"Error updating voice settings: {e}")
+ raise
+
+ def deleteVoiceSettings(self, userId: str) -> bool:
+ try:
+ existing = self.getVoiceSettings(userId)
+ if not existing:
+ return False
+ success = self.db.recordDelete(VoiceSettings, existing.id)
+ if success:
+ logger.info(f"Deleted voice settings for user {userId}")
+ return success
+ except Exception as e:
+ logger.error(f"Error deleting voice settings: {e}")
+ return False
+
+ def getOrCreateVoiceSettings(self, userId: Optional[str] = None) -> VoiceSettings:
+ targetUserId = userId or self.userId
+ if not targetUserId:
+ raise ValueError("No user ID provided for voice settings")
+
+ existing = self.getVoiceSettings(targetUserId)
+ if existing:
+ return existing
+
+ defaultSettings = {
+ "userId": targetUserId,
+ "mandateId": self.mandateId,
+ "featureInstanceId": self.featureInstanceId,
+ "sttLanguage": "de-DE",
+ "ttsLanguage": "de-DE",
+ "ttsVoice": "de-DE-KatjaNeural",
+ "translationEnabled": True,
+ "targetLanguage": "en-US",
+ }
+ createdRecord = self.createVoiceSettings(defaultSettings)
+ return VoiceSettings(**createdRecord)
+
+ # =========================================================================
+ # WorkspaceUserSettings CRUD
+ # =========================================================================
+
+ def getWorkspaceUserSettings(self, userId: Optional[str] = None) -> Optional[WorkspaceUserSettings]:
+ targetUserId = userId or self.userId
+ if not targetUserId:
+ return None
+
+ try:
+ recordFilter: Dict[str, Any] = {"userId": targetUserId}
+ if self.featureInstanceId:
+ recordFilter["featureInstanceId"] = self.featureInstanceId
+
+ records = getRecordsetWithRBAC(
+ self.db, WorkspaceUserSettings, self.currentUser,
+ recordFilter=recordFilter, mandateId=self.mandateId,
+ )
+ if not records:
+ return None
+ return WorkspaceUserSettings(**records[0])
+
+ except Exception as e:
+ logger.error(f"Error getting workspace user settings: {e}")
+ return None
+
+ def saveWorkspaceUserSettings(self, data: Dict[str, Any]) -> WorkspaceUserSettings:
+ """Upsert: create or update workspace user settings for the current user."""
+ targetUserId = data.get("userId", self.userId)
+ if not targetUserId:
+ raise ValueError("No user ID provided")
+
+ existing = self.getWorkspaceUserSettings(targetUserId)
+ if existing:
+ updateData = {k: v for k, v in data.items() if k not in ("id", "userId", "mandateId", "featureInstanceId")}
+ self.db.recordModify(WorkspaceUserSettings, existing.id, updateData)
+ updated = self.getWorkspaceUserSettings(targetUserId)
+ if not updated:
+ raise ValueError("Failed to retrieve updated workspace user settings")
+ return updated
+
+ createData = {
+ "userId": targetUserId,
+ "mandateId": data.get("mandateId", self.mandateId),
+ "featureInstanceId": data.get("featureInstanceId", self.featureInstanceId),
+ }
+ createData.update({k: v for k, v in data.items() if k not in ("id", "userId", "mandateId", "featureInstanceId")})
+ createdRecord = self.db.recordCreate(WorkspaceUserSettings, createData)
+ if not createdRecord or not createdRecord.get("id"):
+ raise ValueError("Failed to create workspace user settings")
+ return WorkspaceUserSettings(**createdRecord)
+
+
+def getInterface(currentUser: Optional[User] = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> WorkspaceObjects:
+ if not currentUser:
+ raise ValueError("Invalid user context: user is required")
+
+ effectiveMandateId = str(mandateId) if mandateId else None
+ effectiveFeatureInstanceId = str(featureInstanceId) if featureInstanceId else None
+
+ contextKey = f"workspace_{effectiveMandateId}_{effectiveFeatureInstanceId}_{currentUser.id}"
+
+ if contextKey not in _workspaceInterfaces:
+ _workspaceInterfaces[contextKey] = WorkspaceObjects(currentUser, mandateId=effectiveMandateId, featureInstanceId=effectiveFeatureInstanceId)
+ else:
+ _workspaceInterfaces[contextKey].setUserContext(currentUser, mandateId=effectiveMandateId, featureInstanceId=effectiveFeatureInstanceId)
+
+ return _workspaceInterfaces[contextKey]
diff --git a/modules/features/workspace/routeFeatureWorkspace.py b/modules/features/workspace/routeFeatureWorkspace.py
index 6dc7774e..dd8481ff 100644
--- a/modules/features/workspace/routeFeatureWorkspace.py
+++ b/modules/features/workspace/routeFeatureWorkspace.py
@@ -23,6 +23,7 @@ from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription
SubscriptionInactiveException,
)
from modules.interfaces import interfaceDbChat, interfaceDbManagement
+from modules.features.workspace import interfaceFeatureWorkspace
from modules.interfaces.interfaceAiObjects import AiObjects
from modules.serviceCenter.core.serviceStreaming import get_event_manager
from modules.serviceCenter.services.serviceAgent.datamodelAgent import AgentEventTypeEnum, PendingFileEdit
@@ -145,6 +146,14 @@ def _getDbManagement(context: RequestContext, featureInstanceId: str = None):
)
+def _getWorkspaceInterface(context: RequestContext, featureInstanceId: str = None):
+ return interfaceFeatureWorkspace.getInterface(
+ context.user,
+ mandateId=str(context.mandateId) if context.mandateId else None,
+ featureInstanceId=featureInstanceId,
+ )
+
+
_SOURCE_TYPE_TO_SERVICE = {
"sharepointFolder": "sharepoint",
"onedriveFolder": "onedrive",
@@ -701,7 +710,17 @@ async def _runWorkspaceAgent(
_toolSet = _cfg.get("toolSet", "core")
_agentCfg = _cfg.get("agentConfig")
from modules.serviceCenter.services.serviceAgent.datamodelAgent import AgentConfig
- agentConfig = AgentConfig(**_agentCfg) if isinstance(_agentCfg, dict) else None
+
+ agentCfgDict = dict(_agentCfg) if isinstance(_agentCfg, dict) else {}
+ try:
+ wsIf = interfaceFeatureWorkspace.getInterface(user, mandateId=mandateId, featureInstanceId=instanceId)
+ userSettings = wsIf.getWorkspaceUserSettings(user.id if user else None)
+ if userSettings and userSettings.maxAgentRounds is not None:
+ agentCfgDict["maxRounds"] = userSettings.maxAgentRounds
+ except Exception as e:
+ logger.debug(f"Could not load workspace user settings for agent config: {e}")
+
+ agentConfig = AgentConfig(**agentCfgDict) if agentCfgDict else None
async for event in agentService.runAgent(
prompt=enrichedPrompt,
@@ -1575,13 +1594,13 @@ async def getVoiceSettings(
):
"""Load voice settings for the current user and instance."""
_validateInstanceAccess(instanceId, context)
- dbMgmt = _getDbManagement(context, instanceId)
+ wsInterface = _getWorkspaceInterface(context, instanceId)
userId = str(context.user.id)
try:
- vs = dbMgmt.getVoiceSettings(userId)
+ vs = wsInterface.getVoiceSettings(userId)
if not vs:
logger.info(f"GET voice settings: not found for user={userId}, creating defaults")
- vs = dbMgmt.getOrCreateVoiceSettings(userId)
+ vs = wsInterface.getOrCreateVoiceSettings(userId)
result = vs.model_dump() if vs else {}
mapKeys = list(result.get("ttsVoiceMap", {}).keys()) if result else []
logger.info(f"GET voice settings for user={userId}: ttsVoiceMap languages={mapKeys}")
@@ -1601,12 +1620,12 @@ async def updateVoiceSettings(
):
"""Update voice settings for the current user and instance."""
_validateInstanceAccess(instanceId, context)
- dbMgmt = _getDbManagement(context, instanceId)
+ wsInterface = _getWorkspaceInterface(context, instanceId)
userId = str(context.user.id)
try:
logger.info(f"PUT voice settings for user={userId}, instance={instanceId}, body keys={list(body.keys())}")
- vs = dbMgmt.getVoiceSettings(userId)
+ vs = wsInterface.getVoiceSettings(userId)
if not vs:
logger.info(f"No existing voice settings, creating new for user={userId}")
createData = {
@@ -1615,13 +1634,13 @@ async def updateVoiceSettings(
"featureInstanceId": instanceId,
}
createData.update(body)
- created = dbMgmt.createVoiceSettings(createData)
+ created = wsInterface.createVoiceSettings(createData)
logger.info(f"Created voice settings for user={userId}, ttsVoiceMap keys={list((created or {}).get('ttsVoiceMap', {}).keys())}")
return JSONResponse(created)
updateData = {k: v for k, v in body.items() if k not in ("id", "userId", "mandateId", "featureInstanceId", "creationDate")}
logger.info(f"Updating voice settings for user={userId}, update keys={list(updateData.keys())}")
- updated = dbMgmt.updateVoiceSettings(userId, updateData)
+ updated = wsInterface.updateVoiceSettings(userId, updateData)
logger.info(f"Updated voice settings for user={userId}, ttsVoiceMap keys={list((updated or {}).get('ttsVoiceMap', {}).keys())}")
return JSONResponse(updated)
except Exception as e:
@@ -1827,3 +1846,63 @@ async def rejectAllEdits(
logger.info(f"Rejected {len(rejected)} edits for instance {instanceId}")
return JSONResponse({"rejected": rejected})
+
+
+# =========================================================================
+# General Settings Endpoints (per-user workspace settings)
+# =========================================================================
+
+@router.get("/{instanceId}/settings/general")
+@limiter.limit("120/minute")
+async def getGeneralSettings(
+ request: Request,
+ instanceId: str = Path(...),
+ context: RequestContext = Depends(getRequestContext),
+):
+ """Load general workspace settings for the current user, with effective values."""
+ _mandateId, instanceConfig = _validateInstanceAccess(instanceId, context)
+ wsInterface = _getWorkspaceInterface(context, instanceId)
+ userId = str(context.user.id)
+
+ userSettings = wsInterface.getWorkspaceUserSettings(userId)
+
+ agentCfg = (instanceConfig or {}).get("agentConfig", {})
+ instanceDefault = agentCfg.get("maxRounds", 25) if isinstance(agentCfg, dict) else 25
+
+ userOverride = userSettings.maxAgentRounds if userSettings else None
+ effective = userOverride if userOverride is not None else instanceDefault
+
+ return JSONResponse({
+ "maxAgentRounds": {
+ "effective": effective,
+ "userOverride": userOverride,
+ "instanceDefault": instanceDefault,
+ },
+ })
+
+
+@router.put("/{instanceId}/settings/general")
+@limiter.limit("120/minute")
+async def updateGeneralSettings(
+ request: Request,
+ instanceId: str = Path(...),
+ body: dict = Body(...),
+ context: RequestContext = Depends(getRequestContext),
+):
+ """Update general workspace settings for the current user."""
+ _validateInstanceAccess(instanceId, context)
+ wsInterface = _getWorkspaceInterface(context, instanceId)
+ userId = str(context.user.id)
+
+ data = {
+ "userId": userId,
+ "mandateId": str(context.mandateId) if context.mandateId else "",
+ "featureInstanceId": instanceId,
+ }
+ if "maxAgentRounds" in body:
+ val = body["maxAgentRounds"]
+ data["maxAgentRounds"] = int(val) if val is not None else None
+
+ wsInterface.saveWorkspaceUserSettings(data)
+
+ return await getGeneralSettings(request, instanceId, context)
diff --git a/modules/interfaces/interfaceDbKnowledge.py b/modules/interfaces/interfaceDbKnowledge.py
index c8a597df..ae822db8 100644
--- a/modules/interfaces/interfaceDbKnowledge.py
+++ b/modules/interfaces/interfaceDbKnowledge.py
@@ -7,6 +7,8 @@ and semantic search via pgvector.
"""
import logging
+from collections import defaultdict
+from datetime import datetime, timezone, timedelta
from typing import Dict, Any, List, Optional
from modules.connectors.connectorDbPostgre import _get_cached_connector
diff --git a/modules/interfaces/interfaceDbManagement.py b/modules/interfaces/interfaceDbManagement.py
index ae0b14b0..64883b95 100644
--- a/modules/interfaces/interfaceDbManagement.py
+++ b/modules/interfaces/interfaceDbManagement.py
@@ -21,7 +21,6 @@ from modules.datamodels.datamodelUam import AccessLevel
from modules.datamodels.datamodelFiles import FilePreview, FileItem, FileData
from modules.datamodels.datamodelFileFolder import FileFolder
from modules.datamodels.datamodelUtils import Prompt
-from modules.datamodels.datamodelVoice import VoiceSettings
from modules.datamodels.datamodelMessaging import (
MessagingSubscription,
MessagingSubscriptionRegistration,
@@ -1723,158 +1722,6 @@ class ComponentObjects:
logger.error(f"Error in saveUploadedFile for {fileName}: {str(e)}", exc_info=True)
raise FileStorageError(f"Error saving file: {str(e)}")
- # VoiceSettings methods
-
- def getVoiceSettings(self, userId: Optional[str] = None) -> Optional[VoiceSettings]:
- """Returns voice settings for a user if user has access."""
- try:
- targetUserId = userId or self.userId
- if not targetUserId:
- logger.error("No user ID provided for voice settings")
- return None
-
- recordFilter: Dict[str, Any] = {"userId": targetUserId}
- if self.featureInstanceId:
- recordFilter["featureInstanceId"] = self.featureInstanceId
-
- # Get voice settings for the user (scoped to current feature instance if available), filtered by RBAC
- filteredSettings = getRecordsetWithRBAC(self.db,
- VoiceSettings,
- self.currentUser,
- recordFilter=recordFilter,
- mandateId=self.mandateId
- )
-
- if not filteredSettings:
- logger.warning(f"No access to voice settings for user {targetUserId}")
- return None
-
- # Ensure timestamps are set for validation
- settings_data = filteredSettings[0]
- if not settings_data.get("creationDate"):
- settings_data["creationDate"] = getUtcTimestamp()
- if not settings_data.get("lastModified"):
- settings_data["lastModified"] = getUtcTimestamp()
-
- return VoiceSettings(**settings_data)
-
- except Exception as e:
- logger.error(f"Error getting voice settings: {str(e)}")
- return None
-
- def createVoiceSettings(self, settingsData: Dict[str, Any]) -> Dict[str, Any]:
- """Creates voice settings for a user if user has permission."""
- try:
- if not self.checkRbacPermission(VoiceSettings, "update"):
- raise PermissionError("No permission to create voice settings")
-
- # Ensure userId is set
- if "userId" not in settingsData:
- settingsData["userId"] = self.userId
-
- # Ensure mandateId and featureInstanceId are set from context
- if "mandateId" not in settingsData:
- settingsData["mandateId"] = self.mandateId
- if "featureInstanceId" not in settingsData:
- settingsData["featureInstanceId"] = self.featureInstanceId
-
- # Check if settings already exist for this user
- existingSettings = self.getVoiceSettings(settingsData["userId"])
- if existingSettings:
- raise ValueError(f"Voice settings already exist for user {settingsData['userId']}")
-
- # Create voice settings record
- createdRecord = self.db.recordCreate(VoiceSettings, settingsData)
- if not createdRecord or not createdRecord.get("id"):
- raise ValueError("Failed to create voice settings record")
-
- logger.info(f"Created voice settings for user {settingsData['userId']}")
- return createdRecord
-
- except Exception as e:
- logger.error(f"Error creating voice settings: {str(e)}")
- raise
-
- def updateVoiceSettings(self, userId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
- """Updates voice settings for a user if user has access."""
- try:
- # Get existing settings
- existingSettings = self.getVoiceSettings(userId)
- if not existingSettings:
- raise ValueError(f"Voice settings not found for user {userId}")
-
- # Update lastModified timestamp
- updateData["lastModified"] = getUtcTimestamp()
-
- # Update voice settings record
- success = self.db.recordModify(VoiceSettings, existingSettings.id, updateData)
- if not success:
- raise ValueError("Failed to update voice settings record")
-
- # Get updated settings
- updatedSettings = self.getVoiceSettings(userId)
- if not updatedSettings:
- raise ValueError("Failed to retrieve updated voice settings")
-
- logger.info(f"Updated voice settings for user {userId}")
- return updatedSettings.model_dump()
-
- except Exception as e:
- logger.error(f"Error updating voice settings: {str(e)}")
- raise
-
- def deleteVoiceSettings(self, userId: str) -> bool:
- """Deletes voice settings for a user if user has access."""
- try:
- # Get existing settings
- existingSettings = self.getVoiceSettings(userId)
- if not existingSettings:
- logger.warning(f"Voice settings not found for user {userId}")
- return False
-
- # Delete voice settings
- success = self.db.recordDelete(VoiceSettings, existingSettings.id)
- if success:
- logger.info(f"Deleted voice settings for user {userId}")
- else:
- logger.error(f"Failed to delete voice settings for user {userId}")
-
- return success
-
- except Exception as e:
- logger.error(f"Error deleting voice settings: {str(e)}")
- return False
-
- def getOrCreateVoiceSettings(self, userId: Optional[str] = None) -> VoiceSettings:
- """Gets existing voice settings or creates default ones for a user."""
- try:
- targetUserId = userId or self.userId
- if not targetUserId:
- raise ValueError("No user ID provided for voice settings")
-
- # Try to get existing settings
- existingSettings = self.getVoiceSettings(targetUserId)
- if existingSettings:
- return existingSettings
-
- # Create default settings
- defaultSettings = {
- "userId": targetUserId,
- "mandateId": self.mandateId,
- "sttLanguage": "de-DE",
- "ttsLanguage": "de-DE",
- "ttsVoice": "de-DE-KatjaNeural",
- "translationEnabled": True,
- "targetLanguage": "en-US"
- }
-
- createdRecord = self.createVoiceSettings(defaultSettings)
- return VoiceSettings(**createdRecord)
-
- except Exception as e:
- logger.error(f"Error getting or creating voice settings: {str(e)}")
- raise
-
# Messaging Subscription methods
def getAllSubscriptions(self, pagination: Optional[PaginationParams] = None) -> Union[List[MessagingSubscription], PaginatedResult]:
diff --git a/modules/serviceCenter/registry.py b/modules/serviceCenter/registry.py
index be0accba..cdf57304 100644
--- a/modules/serviceCenter/registry.py
+++ b/modules/serviceCenter/registry.py
@@ -99,7 +99,7 @@ IMPORTABLE_SERVICES: Dict[str, Dict[str, Any]] = {
"label": {"en": "Web Research", "de": "Web-Recherche", "fr": "Recherche Web"},
},
"neutralization": {
- "module": "modules.serviceCenter.services.serviceNeutralization.mainServiceNeutralization",
+ "module": "modules.features.neutralization.serviceNeutralization.mainServiceNeutralization",
"class": "NeutralizationService",
"dependencies": ["extraction", "generation"],
"objectKey": "service.neutralization",
diff --git a/modules/serviceCenter/services/serviceAgent/mainServiceAgent.py b/modules/serviceCenter/services/serviceAgent/mainServiceAgent.py
index 03b8598e..78c69ff3 100644
--- a/modules/serviceCenter/services/serviceAgent/mainServiceAgent.py
+++ b/modules/serviceCenter/services/serviceAgent/mainServiceAgent.py
@@ -1432,21 +1432,55 @@ def _registerCoreTools(registry: ToolRegistry, services):
return ToolResult(toolCallId="", toolName="uploadToExternal", success=False, error=str(e))
async def _sendMail(args: Dict[str, Any], context: Dict[str, Any]):
+ import base64 as _b64
+
connectionId = args.get("connectionId", "")
to = args.get("to", [])
subject = args.get("subject", "")
body = args.get("body", "")
+ bodyType = "HTML" if args.get("bodyType", "text").lower() == "html" else "Text"
+ draft = args.get("draft", False)
+ attachmentFileIds = args.get("attachmentFileIds") or []
+
if not connectionId or not to or not subject:
return ToolResult(toolCallId="", toolName="sendMail", success=False, error="connectionId, to, and subject are required")
try:
+ graphAttachments: List[Dict[str, Any]] = []
+ if attachmentFileIds:
+ chatService = services.chat
+ dbMgmt = chatService.interfaceDbComponent
+ for fid in attachmentFileIds:
+ fileRow = dbMgmt.getFile(fid)
+ if not fileRow:
+ return ToolResult(toolCallId="", toolName="sendMail", success=False, error=f"Attachment file not found: {fid}")
+ rawBytes = dbMgmt.getFileData(fid)
+ if not rawBytes:
+ return ToolResult(toolCallId="", toolName="sendMail", success=False, error=f"Attachment file has no data: {fid}")
+ graphAttachments.append({
+ "name": fileRow.fileName,
+ "contentBytes": _b64.b64encode(rawBytes).decode("ascii"),
+ "contentType": getattr(fileRow, "mimeType", "application/octet-stream"),
+ })
+
from modules.connectors.connectorResolver import ConnectorResolver
resolver = ConnectorResolver(
services.getService("security"),
_buildResolverDb(),
)
adapter = await resolver.resolveService(connectionId, "outlook")
+
+ if draft and hasattr(adapter, "createDraft"):
+ result = await adapter.createDraft(
+ to=to, subject=subject, body=body, bodyType=bodyType,
+ cc=args.get("cc"), attachments=graphAttachments or None,
+ )
+ return ToolResult(toolCallId="", toolName="sendMail", success=True, data=str(result))
+
if hasattr(adapter, "sendMail"):
- result = await adapter.sendMail(to=to, subject=subject, body=body, cc=args.get("cc"))
+ result = await adapter.sendMail(
+ to=to, subject=subject, body=body, bodyType=bodyType,
+ cc=args.get("cc"), attachments=graphAttachments or None,
+ )
return ToolResult(toolCallId="", toolName="sendMail", success=True, data=str(result))
return ToolResult(toolCallId="", toolName="sendMail", success=False, error="Mail not supported by this adapter")
except Exception as e:
@@ -1484,15 +1518,26 @@ def _registerCoreTools(registry: ToolRegistry, services):
registry.register(
"sendMail", _sendMail,
- description="Send an email via a connected mail service (Outlook, Gmail). Use listConnections to find the connectionId.",
+ description=(
+ "Send or draft an email via a connected mail service (Outlook). "
+ "Supports HTML body and file attachments from the workspace. "
+ "Set draft=true to save as draft without sending. "
+ "Use listConnections to find the connectionId."
+ ),
parameters={
"type": "object",
"properties": {
"connectionId": {"type": "string", "description": "UserConnection ID"},
"to": {"type": "array", "items": {"type": "string"}, "description": "Recipient email addresses"},
"subject": {"type": "string", "description": "Email subject"},
- "body": {"type": "string", "description": "Email body text"},
+ "body": {"type": "string", "description": "Email body — plain text or HTML markup"},
+ "bodyType": {"type": "string", "enum": ["text", "html"], "description": "Body format: 'text' (default) or 'html'"},
"cc": {"type": "array", "items": {"type": "string"}, "description": "CC addresses"},
+ "attachmentFileIds": {
+ "type": "array", "items": {"type": "string"},
+ "description": "File IDs from the workspace to attach (use listFiles to find IDs)",
+ },
+ "draft": {"type": "boolean", "description": "If true, save as draft in Drafts folder instead of sending"},
},
"required": ["connectionId", "to", "subject", "body"],
},
@@ -2471,16 +2516,16 @@ def _registerCoreTools(registry: ToolRegistry, services):
if not voiceName:
try:
- from modules.interfaces import interfaceDbManagement
+ from modules.features.workspace import interfaceFeatureWorkspace
featureInstanceId = context.get("featureInstanceId", "")
userId = context.get("userId", "")
if userId:
- dbMgmt = interfaceDbManagement.getInterface(
+ wsIf = interfaceFeatureWorkspace.getInterface(
services.user,
mandateId=mandateId or None,
featureInstanceId=featureInstanceId or None,
)
- vs = dbMgmt.getVoiceSettings(userId) if dbMgmt and hasattr(dbMgmt, "getVoiceSettings") else None
+ vs = wsIf.getVoiceSettings(userId) if wsIf else None
if vs:
voiceMap = {}
if hasattr(vs, "ttsVoiceMap") and vs.ttsVoiceMap:
@@ -2914,6 +2959,8 @@ def _registerCoreTools(registry: ToolRegistry, services):
neutralizationService = services.getService("neutralization")
if not neutralizationService:
return ToolResult(toolCallId="", toolName="neutralizeData", success=False, error="Neutralization service not available")
+ if not neutralizationService.interfaceDbComponent:
+ neutralizationService.interfaceDbComponent = services.chat.interfaceDbComponent
if text:
result = neutralizationService.processText(text)
else:
diff --git a/modules/serviceCenter/services/serviceBilling/mainServiceBilling.py b/modules/serviceCenter/services/serviceBilling/mainServiceBilling.py
index 969ec6b8..790612ed 100644
--- a/modules/serviceCenter/services/serviceBilling/mainServiceBilling.py
+++ b/modules/serviceCenter/services/serviceBilling/mainServiceBilling.py
@@ -29,7 +29,7 @@ from modules.interfaces.interfaceDbBilling import getInterface as getBillingInte
logger = logging.getLogger(__name__)
# Markup percentage for internal pricing (+50% für Infrastruktur und Platform Service + 50% für Währungsrisiko ==> Faktor 2.0)
-BILLING_MARKUP_PERCENT = 100
+BILLING_MARKUP_PERCENT = 400
# Singleton cache
_billingServices: Dict[str, "BillingService"] = {}
diff --git a/modules/serviceCenter/services/serviceSubscription/mainServiceSubscription.py b/modules/serviceCenter/services/serviceSubscription/mainServiceSubscription.py
index b9a1481c..944da4f7 100644
--- a/modules/serviceCenter/services/serviceSubscription/mainServiceSubscription.py
+++ b/modules/serviceCenter/services/serviceSubscription/mainServiceSubscription.py
@@ -374,11 +374,13 @@ class SubscriptionService:
raise ValueError("Subscription is already cancelled (non-recurring)")
stripeSubId = sub.get("stripeSubscriptionId")
+ pUrl = ""
if stripeSubId:
try:
from modules.shared.stripeClient import getStripeClient
stripe = getStripeClient()
- stripe.Subscription.modify(stripeSubId, cancel_at_period_end=True)
+ stripeSub = stripe.Subscription.modify(stripeSubId, cancel_at_period_end=True)
+ pUrl = (stripeSub.get("metadata") or {}).get("platformUrl", "")
except Exception as e:
logger.error("Failed to set cancel_at_period_end for %s: %s", stripeSubId, e)
@@ -386,7 +388,7 @@ class SubscriptionService:
self.invalidateCache(mandateId)
plan = _getPlan(sub.get("planKey", ""))
- _notifySubscriptionChange(mandateId, "cancelled", plan)
+ _notifySubscriptionChange(mandateId, "cancelled", plan, subscriptionRecord=sub, platformUrl=pUrl)
return result
# =========================================================================
@@ -435,16 +437,23 @@ class SubscriptionService:
raise ValueError(f"Subscription {subscriptionId} not found")
stripeSubId = sub.get("stripeSubscriptionId")
+ pUrl = ""
if stripeSubId:
try:
from modules.shared.stripeClient import getStripeClient
stripe = getStripeClient()
+ stripeSub = stripe.Subscription.retrieve(stripeSubId)
+ pUrl = (stripeSub.get("metadata") or {}).get("platformUrl", "")
stripe.Subscription.cancel(stripeSubId)
except Exception as e:
logger.error("Failed to cancel Stripe sub %s: %s", stripeSubId, e)
result = self._interface.forceExpire(subscriptionId)
- self.invalidateCache(sub["mandateId"])
+ mandateId = sub["mandateId"]
+ self.invalidateCache(mandateId)
+
+ plan = _getPlan(sub.get("planKey", ""))
+ _notifySubscriptionChange(mandateId, "force_cancelled", plan, subscriptionRecord=sub, platformUrl=pUrl)
return result
# =========================================================================
@@ -496,6 +505,8 @@ def _notifySubscriptionChange(
if event == "activated" and plan and subscriptionRecord:
rawHtmlBlock = _buildInvoiceSummaryHtml(plan, subscriptionRecord, mandateId, platformUrl)
+ elif event in ("cancelled", "force_cancelled") and subscriptionRecord:
+ rawHtmlBlock = _buildCancelSummaryHtml(subscriptionRecord, platformUrl)
templates: Dict[str, Dict[str, Any]] = {
"activated": {
@@ -520,6 +531,17 @@ def _notifySubscriptionChange(
] if p
],
},
+ "force_cancelled": {
+ "subject": f"[PowerOn] Abonnement sofort beendet — {planLabel}",
+ "headline": "Abonnement sofort beendet",
+ "paragraphs": [
+ p for p in [
+ f"Das Abonnement «{planLabel}» wurde durch den Plattform-Administrator sofort beendet.",
+ platformHint,
+ "Der Zugang wurde per sofort deaktiviert. Bei Fragen wenden Sie sich an den Plattform-Support.",
+ ] if p
+ ],
+ },
"trial_expired": {
"subject": "[PowerOn] Testphase abgelaufen",
"headline": "Testphase abgelaufen",
@@ -633,6 +655,32 @@ def _buildInvoiceSummaryHtml(
)
+def _buildCancelSummaryHtml(subRecord: Dict[str, Any], platformUrl: str = "") -> str:
+ """Build an HTML block with billing link and Stripe invoice link for cancel emails."""
+ import html as htmlmod
+
+ parts: list[str] = []
+
+ stripeSubId = subRecord.get("stripeSubscriptionId")
+ if stripeSubId:
+ try:
+ from modules.shared.stripeClient import getStripeClient
+ stripe = getStripeClient()
+ invoices = stripe.Invoice.list(subscription=stripeSubId, limit=1)
+ if invoices.data:
+ hostedUrl = invoices.data[0].get("hosted_invoice_url", "")
+ if hostedUrl:
+ parts.append(
+ f''
+ f''
+ f'Letzte Stripe-Rechnung anzeigen
'
+ )
+ except Exception as e:
+ logger.warning("Could not fetch Stripe invoice URL for sub %s: %s", stripeSubId, e)
+
+ return "\n".join(parts) if parts else ""
+
+
# ============================================================================
# Exception Classes
# ============================================================================
diff --git a/modules/serviceCenter/services/serviceSubscription/stripeBootstrap.py b/modules/serviceCenter/services/serviceSubscription/stripeBootstrap.py
index fd5666e0..38ac29e1 100644
--- a/modules/serviceCenter/services/serviceSubscription/stripeBootstrap.py
+++ b/modules/serviceCenter/services/serviceSubscription/stripeBootstrap.py
@@ -105,6 +105,43 @@ def _findExistingStripePrice(stripe, productId: str, unitAmount: int, interval:
return None
+def _getStripePriceAmount(stripe, priceId: str) -> Optional[int]:
+ """Retrieve the unit_amount (in Rappen) of an existing Stripe Price."""
+ try:
+ price = stripe.Price.retrieve(priceId)
+ return price.get("unit_amount") if price else None
+ except Exception:
+ return None
+
+
+def _reconcilePrice(stripe, productId: str, oldPriceId: str, expectedCHF: float, interval: str, nickname: str) -> str:
+ """If the stored Stripe Price has a different amount, create a new one and deactivate the old."""
+ expectedCents = int(expectedCHF * 100)
+ actualCents = _getStripePriceAmount(stripe, oldPriceId)
+
+ if actualCents == expectedCents:
+ return oldPriceId
+
+ logger.warning(
+ "Price drift detected for %s: Stripe has %s Rappen, catalog expects %s Rappen. Rotating price.",
+ oldPriceId, actualCents, expectedCents,
+ )
+
+ existingMatch = _findExistingStripePrice(stripe, productId, expectedCents, interval)
+ if existingMatch:
+ newPriceId = existingMatch
+ else:
+ newPriceId = _createStripePrice(stripe, productId, expectedCHF, interval, nickname)
+
+ try:
+ stripe.Price.modify(oldPriceId, active=False)
+ logger.info("Deactivated old Stripe Price %s", oldPriceId)
+ except Exception as e:
+ logger.warning("Could not deactivate old price %s: %s", oldPriceId, e)
+
+ return newPriceId
+
+
def _createStripePrice(stripe, productId: str, unitAmountCHF: float, interval: str, nickname: str) -> str:
price = stripe.Price.create(
product=productId,
@@ -146,7 +183,29 @@ def bootstrapStripePrices() -> None:
hasAllPrices = mapping.stripePriceIdUsers and mapping.stripePriceIdInstances
hasAllProducts = mapping.stripeProductIdUsers and mapping.stripeProductIdInstances
if hasAllPrices and hasAllProducts:
- logger.debug("Stripe prices already configured for plan %s", planKey)
+ changed = False
+ reconciledUsers = _reconcilePrice(
+ stripe, mapping.stripeProductIdUsers, mapping.stripePriceIdUsers,
+ plan.pricePerUserCHF, interval, f"{planKey} — Benutzer-Lizenz",
+ )
+ if reconciledUsers != mapping.stripePriceIdUsers:
+ changed = True
+
+ reconciledInstances = _reconcilePrice(
+ stripe, mapping.stripeProductIdInstances, mapping.stripePriceIdInstances,
+ plan.pricePerFeatureInstanceCHF, interval, f"{planKey} — Feature-Instanz",
+ )
+ if reconciledInstances != mapping.stripePriceIdInstances:
+ changed = True
+
+ if changed:
+ db.recordModify(StripePlanPrice, mapping.id, {
+ "stripePriceIdUsers": reconciledUsers,
+ "stripePriceIdInstances": reconciledInstances,
+ })
+ logger.info("Reconciled Stripe prices for plan %s: users=%s, instances=%s", planKey, reconciledUsers, reconciledInstances)
+ else:
+ logger.debug("Stripe prices up-to-date for plan %s", planKey)
continue
productIdUsers = None
From da096bb6e294503bf325847177ce357dd3c5c8f7 Mon Sep 17 00:00:00 2001
From: ValueOn AG
Date: Mon, 23 Mar 2026 00:17:59 +0100
Subject: [PATCH 5/5] rag stats
---
modules/features/workspace/mainWorkspace.py | 11 ++
.../workspace/routeFeatureWorkspace.py | 51 +++++
modules/interfaces/interfaceDbKnowledge.py | 177 ++++++++++++++++++
modules/routes/routeDataFiles.py | 21 ++-
4 files changed, 259 insertions(+), 1 deletion(-)
diff --git a/modules/features/workspace/mainWorkspace.py b/modules/features/workspace/mainWorkspace.py
index 81526414..353129cc 100644
--- a/modules/features/workspace/mainWorkspace.py
+++ b/modules/features/workspace/mainWorkspace.py
@@ -31,6 +31,15 @@ UI_OBJECTS = [
"label": {"en": "Settings", "de": "Einstellungen", "fr": "Parametres"},
"meta": {"area": "settings"}
},
+ {
+ "objectKey": "ui.feature.workspace.rag-insights",
+ "label": {
+ "en": "Knowledge insights",
+ "de": "Wissens-Insights",
+ "fr": "Aperçu des connaissances",
+ },
+ "meta": {"area": "rag-insights"},
+ },
]
RESOURCE_OBJECTS = [
@@ -83,6 +92,7 @@ TEMPLATE_ROLES = [
{"context": "UI", "item": "ui.feature.workspace.dashboard", "view": True},
{"context": "UI", "item": "ui.feature.workspace.editor", "view": True},
{"context": "UI", "item": "ui.feature.workspace.settings", "view": True},
+ {"context": "UI", "item": "ui.feature.workspace.rag-insights", "view": True},
{"context": "DATA", "item": None, "view": True, "read": "m", "create": "n", "update": "n", "delete": "n"},
]
},
@@ -97,6 +107,7 @@ TEMPLATE_ROLES = [
{"context": "UI", "item": "ui.feature.workspace.dashboard", "view": True},
{"context": "UI", "item": "ui.feature.workspace.editor", "view": True},
{"context": "UI", "item": "ui.feature.workspace.settings", "view": True},
+ {"context": "UI", "item": "ui.feature.workspace.rag-insights", "view": True},
{"context": "RESOURCE", "item": "resource.feature.workspace.start", "view": True},
{"context": "RESOURCE", "item": "resource.feature.workspace.stop", "view": True},
{"context": "RESOURCE", "item": "resource.feature.workspace.files", "view": True},
diff --git a/modules/features/workspace/routeFeatureWorkspace.py b/modules/features/workspace/routeFeatureWorkspace.py
index dd8481ff..6b8c529b 100644
--- a/modules/features/workspace/routeFeatureWorkspace.py
+++ b/modules/features/workspace/routeFeatureWorkspace.py
@@ -24,6 +24,7 @@ from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription
)
from modules.interfaces import interfaceDbChat, interfaceDbManagement
from modules.features.workspace import interfaceFeatureWorkspace
+from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
from modules.interfaces.interfaceAiObjects import AiObjects
from modules.serviceCenter.core.serviceStreaming import get_event_manager
from modules.serviceCenter.services.serviceAgent.datamodelAgent import AgentEventTypeEnum, PendingFileEdit
@@ -1906,3 +1907,53 @@ async def updateGeneralSettings(
wsInterface.saveWorkspaceUserSettings(data)
return await getGeneralSettings(request, instanceId, context)
+
+
+# =========================================================================
+# RAG / Knowledge — anonymised instance statistics (presentation / KPIs)
+# =========================================================================
+
+def _collectWorkspaceFileIdsForStats(instanceId: str, mandateId: Optional[str]) -> List[str]:
+ """All FileItem ids for this feature instance (any user). Knowledge rows are often stored
+ without featureInstanceId; we correlate by file id from the Management DB."""
+ from modules.datamodels.datamodelFiles import FileItem
+ from modules.interfaces.interfaceDbManagement import ComponentObjects
+
+ co = ComponentObjects()
+ rows = co.db.getRecordset(FileItem, recordFilter={"featureInstanceId": instanceId})
+ out: List[str] = []
+ m = str(mandateId) if mandateId else ""
+ for r in rows or []:
+ rid = r.get("id") if isinstance(r, dict) else getattr(r, "id", None)
+ if not rid:
+ continue
+ if m:
+ mid = r.get("mandateId") if isinstance(r, dict) else getattr(r, "mandateId", "") or ""
+ if mid and mid != m:
+ continue
+ out.append(str(rid))
+ return out
+
+
+@router.get("/{instanceId}/rag-statistics")
+@limiter.limit("60/minute")
+async def getRagStatistics(
+ request: Request,
+ instanceId: str = Path(...),
+ days: int = Query(90, ge=7, le=365, description="Timeline window in days"),
+ context: RequestContext = Depends(getRequestContext),
+):
+ """Aggregated, non-identifying knowledge-store metrics for this workspace instance."""
+ mandateId, _instanceConfig = _validateInstanceAccess(instanceId, context)
+ workspaceFileIds = _collectWorkspaceFileIdsForStats(instanceId, mandateId)
+ kdb = getKnowledgeInterface(context.user)
+ stats = kdb.getRagStatisticsForInstance(
+ featureInstanceId=instanceId,
+ mandateId=str(mandateId) if mandateId else "",
+ timelineDays=days,
+ workspaceFileIds=workspaceFileIds,
+ )
+ if isinstance(stats, dict):
+ stats.setdefault("scope", {})
+ stats["scope"]["workspaceFileIdsResolved"] = len(workspaceFileIds)
+ return JSONResponse(stats)
diff --git a/modules/interfaces/interfaceDbKnowledge.py b/modules/interfaces/interfaceDbKnowledge.py
index ae822db8..adf8ed0a 100644
--- a/modules/interfaces/interfaceDbKnowledge.py
+++ b/modules/interfaces/interfaceDbKnowledge.py
@@ -288,6 +288,183 @@ class KnowledgeObjects:
minScore=minScore,
)
+ def getRagStatisticsForInstance(
+ self,
+ featureInstanceId: str,
+ mandateId: str,
+ timelineDays: int = 90,
+ workspaceFileIds: Optional[List[str]] = None,
+ ) -> Dict[str, Any]:
+ """Aggregate anonymised RAG / knowledge-store metrics for one workspace instance.
+
+ No file names, user identifiers, or chunk text are returned — only counts and
+ distributions suitable for dashboards and presentations.
+
+ workspaceFileIds: optional list of FileItem ids for this feature instance (from Management DB).
+ Index pipelines often stored rows with empty featureInstanceId; linking by file id fixes stats.
+ """
+ if not featureInstanceId:
+ return {"error": "featureInstanceId required"}
+
+ ws_ids = [x for x in (workspaceFileIds or []) if x]
+ ws_id_set = set(ws_ids)
+
+ files_inst = self.db.getRecordset(
+ FileContentIndex,
+ recordFilter={"featureInstanceId": featureInstanceId},
+ )
+ files_shared: List[Dict[str, Any]] = []
+ if mandateId:
+ files_shared = self.db.getRecordset(
+ FileContentIndex,
+ recordFilter={"mandateId": mandateId, "isShared": True},
+ )
+
+ by_id: Dict[str, Dict[str, Any]] = {}
+ for row in files_inst + files_shared:
+ rid = row.get("id")
+ if rid and rid not in by_id:
+ by_id[rid] = row
+
+ for fid in ws_ids:
+ if fid in by_id:
+ continue
+ row = self.getFileContentIndex(fid)
+ if row:
+ by_id[fid] = row
+
+ files = list(by_id.values())
+
+ chunks_by_id: Dict[str, Dict[str, Any]] = {}
+ inst_chunks = self.db.getRecordset(
+ ContentChunk,
+ recordFilter={"featureInstanceId": featureInstanceId},
+ )
+ for c in inst_chunks:
+ cid = c.get("id")
+ if cid:
+ chunks_by_id[cid] = c
+
+ for fid in ws_id_set:
+ for c in self.getContentChunks(fid):
+ cid = c.get("id")
+ if cid and cid not in chunks_by_id:
+ chunks_by_id[cid] = c
+
+ covered_file_ids = {c.get("fileId") for c in chunks_by_id.values() if c.get("fileId")}
+ for row in files:
+ fid = row.get("id")
+ if fid and fid not in covered_file_ids:
+ for c in self.getContentChunks(fid):
+ cid = c.get("id")
+ if cid and cid not in chunks_by_id:
+ chunks_by_id[cid] = c
+
+ chunks = list(chunks_by_id.values())
+
+ def _mimeCategory(mime: str) -> str:
+ m = (mime or "").lower()
+ if "pdf" in m:
+ return "pdf"
+ if "wordprocessing" in m or "msword" in m or "officedocument.wordprocessing" in m:
+ return "office_doc"
+ if "spreadsheet" in m or "excel" in m or "officedocument.spreadsheet" in m:
+ return "office_sheet"
+ if "presentation" in m or "officedocument.presentation" in m:
+ return "office_slides"
+ if m.startswith("text/"):
+ return "text"
+ if m.startswith("image/"):
+ return "image"
+ if "html" in m:
+ return "html"
+ return "other"
+
+ def _utcDay(ts: Any) -> str:
+ if ts is None:
+ return ""
+ try:
+ return datetime.fromtimestamp(float(ts), tz=timezone.utc).strftime("%Y-%m-%d")
+ except (TypeError, ValueError, OSError):
+ return ""
+
+ status_counts: Dict[str, int] = defaultdict(int)
+ mime_counts: Dict[str, int] = defaultdict(int)
+ extracted_by_day: Dict[str, int] = defaultdict(int)
+ total_bytes = 0
+ user_ids = set()
+
+ for row in files:
+ st = row.get("status") or "unknown"
+ status_counts[st] += 1
+ mime_counts[_mimeCategory(row.get("mimeType") or "")] += 1
+ day = _utcDay(row.get("extractedAt"))
+ if day:
+ extracted_by_day[day] += 1
+ try:
+ total_bytes += int(row.get("totalSize") or 0)
+ except (TypeError, ValueError):
+ pass
+ uid = row.get("userId")
+ if uid:
+ user_ids.add(str(uid))
+
+ content_type_counts: Dict[str, int] = defaultdict(int)
+ chunks_with_embedding = 0
+ for c in chunks:
+ ct = c.get("contentType") or "other"
+ content_type_counts[ct] += 1
+ emb = c.get("embedding")
+ if emb is not None and (
+ (isinstance(emb, list) and len(emb) > 0)
+ or (isinstance(emb, str) and len(emb) > 10)
+ ):
+ chunks_with_embedding += 1
+
+ wf_mem = self.db.getRecordset(
+ WorkflowMemory,
+ recordFilter={"featureInstanceId": featureInstanceId},
+ )
+
+ cutoff = datetime.now(timezone.utc) - timedelta(days=max(1, int(timelineDays)))
+ cutoff_ts = cutoff.timestamp()
+
+ timeline: List[Dict[str, Any]] = []
+ for day in sorted(extracted_by_day.keys()):
+ try:
+ d = datetime.strptime(day, "%Y-%m-%d").replace(tzinfo=timezone.utc)
+ except ValueError:
+ continue
+ if d.timestamp() >= cutoff_ts:
+ timeline.append({"date": day, "indexedDocuments": extracted_by_day[day]})
+
+ if len(timeline) > 120:
+ timeline = timeline[-120:]
+
+ total_chunks = len(chunks)
+ embedding_pct = round(100.0 * chunks_with_embedding / total_chunks, 1) if total_chunks else 0.0
+
+ return {
+ "scope": {
+ "featureInstanceId": featureInstanceId,
+ "mandateScopedShared": bool(mandateId),
+ },
+ "kpis": {
+ "indexedDocuments": len(files),
+ "indexedBytesTotal": total_bytes,
+ "contributorUsers": len(user_ids),
+ "contentChunks": total_chunks,
+ "chunksWithEmbedding": chunks_with_embedding,
+ "embeddingCoveragePercent": embedding_pct,
+ "workflowEntities": len(wf_mem),
+ },
+ "indexedDocumentsByStatus": dict(sorted(status_counts.items())),
+ "documentsByMimeCategory": dict(sorted(mime_counts.items(), key=lambda x: -x[1])),
+ "chunksByContentType": dict(sorted(content_type_counts.items())),
+ "timelineIndexedDocuments": timeline,
+ "generatedAtUtc": datetime.now(timezone.utc).isoformat(),
+ }
+
def getInterface(currentUser: Optional[User] = None) -> KnowledgeObjects:
"""Get or create a KnowledgeObjects singleton."""
diff --git a/modules/routes/routeDataFiles.py b/modules/routes/routeDataFiles.py
index bb77ffc5..999d07df 100644
--- a/modules/routes/routeDataFiles.py
+++ b/modules/routes/routeDataFiles.py
@@ -37,6 +37,17 @@ async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user):
mgmtInterface.updateFile(fileId, {"status": "active"})
return
+ file_meta = mgmtInterface.getFile(fileId)
+ feature_instance_id = ""
+ mandate_id = ""
+ if file_meta:
+ if isinstance(file_meta, dict):
+ feature_instance_id = file_meta.get("featureInstanceId") or ""
+ mandate_id = file_meta.get("mandateId") or ""
+ else:
+ feature_instance_id = getattr(file_meta, "featureInstanceId", None) or ""
+ mandate_id = getattr(file_meta, "mandateId", None) or ""
+
logger.info(f"Auto-index starting for {fileName} ({len(rawBytes)} bytes, {mimeType})")
# Step 1: Structure Pre-Scan (AI-free)
@@ -47,6 +58,8 @@ async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user):
fileId=fileId,
fileName=fileName,
userId=userId,
+ featureInstanceId=str(feature_instance_id) if feature_instance_id else "",
+ mandateId=str(mandate_id) if mandate_id else "",
)
logger.info(
f"Pre-scan complete for {fileName}: "
@@ -105,7 +118,11 @@ async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user):
from modules.serviceCenter import getService
from modules.serviceCenter.context import ServiceCenterContext
- ctx = ServiceCenterContext(user=user, mandate_id="", feature_instance_id="")
+ ctx = ServiceCenterContext(
+ user=user,
+ mandate_id=str(mandate_id) if mandate_id else "",
+ feature_instance_id=str(feature_instance_id) if feature_instance_id else "",
+ )
knowledgeService = getService("knowledge", ctx)
await knowledgeService.indexFile(
@@ -113,6 +130,8 @@ async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user):
fileName=fileName,
mimeType=mimeType,
userId=userId,
+ featureInstanceId=str(feature_instance_id) if feature_instance_id else "",
+ mandateId=str(mandate_id) if mandate_id else "",
contentObjects=contentObjects,
structure=contentIndex.structure,
)