diff --git a/app.py b/app.py index 0c769a2a..8268377a 100644 --- a/app.py +++ b/app.py @@ -21,6 +21,8 @@ from datetime import datetime from modules.shared.configuration import APP_CONFIG from modules.shared.eventManagement import eventManager from modules.workflows.automation import subAutomationSchedule +from modules.features.automation2.emailPoller import start as startAutomation2EmailPoller +from modules.features.automation2.emailPoller import stop as stopAutomation2EmailPoller from modules.interfaces.interfaceDbApp import getRootInterface from modules.system.registry import loadFeatureMainModules @@ -354,6 +356,7 @@ async def lifespan(app: FastAPI): # --- Init Managers --- subAutomationSchedule.start(eventUser) # Automation scheduler + # Automation2 email poller: started on-demand when a run pauses for email.checkEmail eventManager.start() # Register audit log cleanup scheduler @@ -382,6 +385,7 @@ async def lifespan(app: FastAPI): yield # --- Stop Managers --- + stopAutomation2EmailPoller(eventUser) # Automation2 email poller (no-op if not running) eventManager.stop() subAutomationSchedule.stop(eventUser) # Automation scheduler @@ -568,6 +572,9 @@ app.include_router(sharepointRouter) from modules.routes.routeAdminAutomationEvents import router as adminAutomationEventsRouter app.include_router(adminAutomationEventsRouter) +from modules.routes.routeAdminAutomationLogs import router as adminAutomationLogsRouter +app.include_router(adminAutomationLogsRouter) + from modules.routes.routeAdminLogs import router as adminLogsRouter app.include_router(adminLogsRouter) @@ -601,6 +608,9 @@ app.include_router(gdprRouter) from modules.routes.routeBilling import router as billingRouter app.include_router(billingRouter) +from modules.routes.routeSubscription import router as subscriptionRouter +app.include_router(subscriptionRouter) + # ============================================================================ # SYSTEM ROUTES (Navigation, etc.) # ============================================================================ diff --git a/config.ini b/config.ini index ccd6b77e..4a37f2f8 100644 --- a/config.ini +++ b/config.ini @@ -44,3 +44,8 @@ Connector_StacSwisstopo_TIMEOUT = 30 Connector_StacSwisstopo_MAX_RETRIES = 3 Connector_StacSwisstopo_RETRY_DELAY = 1.0 Connector_StacSwisstopo_ENABLE_CACHE = True + +# Operator company information (shown on invoice emails) +Operator_CompanyName = PowerOn AG +Operator_Address = Birmensdorferstrasse 94, 8003 Zürich +Operator_VatNumber = CHE491.960.195 diff --git a/modules/connectors/connectorDbPostgre.py b/modules/connectors/connectorDbPostgre.py index 83517f31..67cceb45 100644 --- a/modules/connectors/connectorDbPostgre.py +++ b/modules/connectors/connectorDbPostgre.py @@ -1,6 +1,7 @@ # Copyright (c) 2025 Patrick Motsch # All rights reserved. import contextvars +import re import psycopg2 import psycopg2.extras import logging @@ -92,19 +93,27 @@ def _get_model_fields(model_class) -> Dict[str, str]: fields[field_name] = extra["db_type"] continue - # Check for JSONB fields (Dict, List, or complex types) + # Unwrap Optional[X] → X (handles both typing.Union and types.UnionType) + origin = get_origin(field_type) + if origin is Union: + args = [a for a in get_args(field_type) if a is not type(None)] + if len(args) == 1: + field_type = args[0] + elif hasattr(field_type, '__args__') and type(None) in getattr(field_type, '__args__', ()): + args = [a for a in field_type.__args__ if a is not type(None)] + if len(args) == 1: + field_type = args[0] + if _isJsonbType(field_type): fields[field_name] = "JSONB" - elif field_type in (str, type(None)) or ( - get_origin(field_type) is Union and type(None) in get_args(field_type) - ): - fields[field_name] = "TEXT" - elif field_type == int: - fields[field_name] = "INTEGER" - elif field_type == float: - fields[field_name] = "DOUBLE PRECISION" - elif field_type == bool: + elif field_type is bool: fields[field_name] = "BOOLEAN" + elif field_type is int: + fields[field_name] = "INTEGER" + elif field_type is float: + fields[field_name] = "DOUBLE PRECISION" + elif field_type in (str, type(None)): + fields[field_name] = "TEXT" else: fields[field_name] = "TEXT" @@ -135,6 +144,9 @@ def _parseRecordFields(record: Dict[str, Any], fields: Dict[str, str], context: elif isinstance(value, list): pass # already a list + elif fieldType == "BOOLEAN": + record[fieldName] = bool(value) if value is not None else False + elif fieldType == "JSONB" and value is not None: try: if isinstance(value, str): @@ -969,6 +981,252 @@ class DatabaseConnector: logger.error(f"Error loading records from table {table}: {e}") return [] + def _buildPaginationClauses( + self, + model_class: type, + pagination, + recordFilter: Dict[str, Any] = None, + ): + """ + Translate PaginationParams + recordFilter into SQL clauses. + Returns (where_clause, order_clause, limit_clause, values, count_values). + """ + fields = _get_model_fields(model_class) + fields["_createdAt"] = "DOUBLE PRECISION" + fields["_modifiedAt"] = "DOUBLE PRECISION" + fields["_createdBy"] = "TEXT" + fields["_modifiedBy"] = "TEXT" + validColumns = set(fields.keys()) + where_parts: List[str] = [] + values: List[Any] = [] + + if recordFilter: + for field, value in recordFilter.items(): + if value is None: + where_parts.append(f'"{field}" IS NULL') + elif isinstance(value, list): + where_parts.append(f'"{field}" = ANY(%s)') + values.append(value) + else: + where_parts.append(f'"{field}" = %s') + values.append(value) + + if pagination and pagination.filters: + for key, val in pagination.filters.items(): + if key == "search" and isinstance(val, str) and val.strip(): + term = f"%{val.strip()}%" + textCols = [c for c, t in fields.items() if t == "TEXT"] + if textCols: + orParts = [f'COALESCE("{c}"::TEXT, \'\') ILIKE %s' for c in textCols] + where_parts.append(f"({' OR '.join(orParts)})") + values.extend([term] * len(textCols)) + continue + if key not in validColumns: + logger.debug(f"_buildPaginationClauses: key '{key}' NOT in validColumns {list(validColumns)[:10]}") + continue + colType = fields.get(key, "TEXT") + logger.debug(f"_buildPaginationClauses: filter key='{key}' val={val!r} type(val)={type(val).__name__} colType={colType}") + if isinstance(val, dict): + op = val.get("operator", "equals") + v = val.get("value", "") + if op in ("equals", "eq"): + if colType == "BOOLEAN": + where_parts.append(f'COALESCE("{key}", FALSE) = %s') + values.append(str(v).lower() == "true") + else: + where_parts.append(f'"{key}"::TEXT = %s') + values.append(str(v)) + elif op == "contains": + where_parts.append(f'"{key}"::TEXT ILIKE %s') + values.append(f"%{v}%") + elif op == "startsWith": + where_parts.append(f'"{key}"::TEXT ILIKE %s') + values.append(f"{v}%") + elif op == "endsWith": + where_parts.append(f'"{key}"::TEXT ILIKE %s') + values.append(f"%{v}") + elif op in ("gt", "gte", "lt", "lte"): + sqlOp = {"gt": ">", "gte": ">=", "lt": "<", "lte": "<="}[op] + where_parts.append(f'"{key}"::TEXT {sqlOp} %s') + values.append(str(v)) + elif op == "between": + fromVal = v.get("from", "") if isinstance(v, dict) else "" + toVal = v.get("to", "") if isinstance(v, dict) else "" + if not fromVal and not toVal: + continue + colType = fields.get(key, "TEXT") + isNumericCol = colType in ("INTEGER", "DOUBLE PRECISION") + isDateVal = bool(fromVal and re.match(r'^\d{4}-\d{2}-\d{2}$', str(fromVal))) or \ + bool(toVal and re.match(r'^\d{4}-\d{2}-\d{2}$', str(toVal))) + if isNumericCol and isDateVal: + from datetime import datetime as _dt, timezone as _tz + if fromVal and toVal: + fromTs = _dt.strptime(str(fromVal), '%Y-%m-%d').replace(tzinfo=_tz.utc).timestamp() + toTs = _dt.strptime(str(toVal), '%Y-%m-%d').replace(hour=23, minute=59, second=59, tzinfo=_tz.utc).timestamp() + where_parts.append(f'"{key}" >= %s AND "{key}" <= %s') + values.extend([fromTs, toTs]) + elif fromVal: + fromTs = _dt.strptime(str(fromVal), '%Y-%m-%d').replace(tzinfo=_tz.utc).timestamp() + where_parts.append(f'"{key}" >= %s') + values.append(fromTs) + else: + toTs = _dt.strptime(str(toVal), '%Y-%m-%d').replace(hour=23, minute=59, second=59, tzinfo=_tz.utc).timestamp() + where_parts.append(f'"{key}" <= %s') + values.append(toTs) + else: + if fromVal and toVal: + where_parts.append(f'"{key}"::TEXT >= %s AND "{key}"::TEXT <= %s') + values.extend([str(fromVal), str(toVal)]) + elif fromVal: + where_parts.append(f'"{key}"::TEXT >= %s') + values.append(str(fromVal)) + elif toVal: + where_parts.append(f'"{key}"::TEXT <= %s') + values.append(str(toVal)) + else: + if colType == "BOOLEAN": + where_parts.append(f'COALESCE("{key}", FALSE) = %s') + values.append(str(val).lower() == "true") + else: + where_parts.append(f'"{key}"::TEXT ILIKE %s') + values.append(str(val)) + + where_clause = " WHERE " + " AND ".join(where_parts) if where_parts else "" + count_values = list(values) + + orderParts: List[str] = [] + if pagination and pagination.sort: + for sf in pagination.sort: + if sf.field in validColumns: + direction = "DESC" if sf.direction.lower() == "desc" else "ASC" + colType = fields.get(sf.field, "TEXT") + if colType == "BOOLEAN": + orderParts.append(f'COALESCE("{sf.field}", FALSE) {direction}') + else: + orderParts.append(f'"{sf.field}" {direction} NULLS LAST') + if not orderParts: + orderParts.append('"id"') + order_clause = " ORDER BY " + ", ".join(orderParts) + + limit_clause = "" + if pagination: + offset = (pagination.page - 1) * pagination.pageSize + limit_clause = f" LIMIT {pagination.pageSize} OFFSET {offset}" + + return where_clause, order_clause, limit_clause, values, count_values + + def getRecordsetPaginated( + self, + model_class: type, + pagination=None, + recordFilter: Dict[str, Any] = None, + fieldFilter: List[str] = None, + ) -> Dict[str, Any]: + """ + Returns paginated records with filtering + sorting at the SQL level. + Returns { "items": [...], "totalItems": int, "totalPages": int }. + If pagination is None, returns all records (no LIMIT/OFFSET). + """ + from modules.datamodels.datamodelPagination import PaginationParams + import math + + table = model_class.__name__ + + try: + if not self._ensureTableExists(model_class): + return {"items": [], "totalItems": 0, "totalPages": 0} + + where_clause, order_clause, limit_clause, values, count_values = \ + self._buildPaginationClauses(model_class, pagination, recordFilter) + + with self.connection.cursor() as cursor: + countSql = f'SELECT COUNT(*) FROM "{table}"{where_clause}' + cursor.execute(countSql, count_values) + totalItems = cursor.fetchone()["count"] + + dataSql = f'SELECT * FROM "{table}"{where_clause}{order_clause}{limit_clause}' + cursor.execute(dataSql, values) + records = [dict(row) for row in cursor.fetchall()] + + fields = _get_model_fields(model_class) + modelFields = model_class.model_fields + for record in records: + _parseRecordFields(record, fields, f"table {table}") + for fieldName, fieldType in fields.items(): + if fieldType == "JSONB" and fieldName in record and record[fieldName] is None: + fieldInfo = modelFields.get(fieldName) + if fieldInfo: + fieldAnnotation = fieldInfo.annotation + if (fieldAnnotation == list or + (hasattr(fieldAnnotation, "__origin__") and + fieldAnnotation.__origin__ is list)): + record[fieldName] = [] + elif (fieldAnnotation == dict or + (hasattr(fieldAnnotation, "__origin__") and + fieldAnnotation.__origin__ is dict)): + record[fieldName] = {} + + if fieldFilter and isinstance(fieldFilter, list): + records = [{f: r[f] for f in fieldFilter if f in r} for r in records] + + pageSize = pagination.pageSize if pagination else max(totalItems, 1) + totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0 + + return {"items": records, "totalItems": totalItems, "totalPages": totalPages} + except Exception as e: + logger.error(f"Error in getRecordsetPaginated for table {table}: {e}") + return {"items": [], "totalItems": 0, "totalPages": 0} + + def getDistinctColumnValues( + self, + model_class: type, + column: str, + pagination=None, + recordFilter: Dict[str, Any] = None, + ) -> List[str]: + """ + Returns sorted distinct non-null values for a column using SQL DISTINCT. + Applies cross-filtering (all filters except the requested column). + """ + table = model_class.__name__ + fields = _get_model_fields(model_class) + fields["_createdAt"] = "DOUBLE PRECISION" + fields["_modifiedAt"] = "DOUBLE PRECISION" + fields["_createdBy"] = "TEXT" + fields["_modifiedBy"] = "TEXT" + + if column not in fields: + return [] + + try: + if not self._ensureTableExists(model_class): + return [] + + if pagination: + if pagination.filters and column in pagination.filters: + import copy + pagination = copy.deepcopy(pagination) + pagination.filters.pop(column, None) + + where_clause, _, _, values, _ = \ + self._buildPaginationClauses(model_class, pagination, recordFilter) + + sql = ( + f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{where_clause} ' + f'WHERE "{column}" IS NOT NULL AND "{column}"::TEXT != \'\' ' + if not where_clause else + f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{where_clause} ' + f'AND "{column}" IS NOT NULL AND "{column}"::TEXT != \'\' ' + ) + sql += 'ORDER BY val' + + with self.connection.cursor() as cursor: + cursor.execute(sql, values) + return [row["val"] for row in cursor.fetchall()] + except Exception as e: + logger.error(f"Error in getDistinctColumnValues for {table}.{column}: {e}") + return [] + def recordCreate( self, model_class: type, record: Union[Dict[str, Any], BaseModel] ) -> Dict[str, Any]: diff --git a/modules/connectors/providerMsft/connectorMsft.py b/modules/connectors/providerMsft/connectorMsft.py index 3654bad9..a51fa231 100644 --- a/modules/connectors/providerMsft/connectorMsft.py +++ b/modules/connectors/providerMsft/connectorMsft.py @@ -299,26 +299,65 @@ class OutlookAdapter(_GraphApiMixin, ServiceAdapter): for m in result.get("value", []) ] - async def sendMail( + def _buildMessage( self, to: List[str], subject: str, body: str, - cc: Optional[List[str]] = None, attachments: Optional[List[Dict]] = None + bodyType: str = "Text", + cc: Optional[List[str]] = None, + attachments: Optional[List[Dict]] = None, ) -> Dict[str, Any]: - """Send an email via Microsoft Graph.""" - import json + """Build a Graph API message object. + + attachments: list of {"name": str, "contentBytes": str (base64), "contentType": str} + """ message: Dict[str, Any] = { "subject": subject, - "body": {"contentType": "Text", "content": body}, + "body": {"contentType": bodyType, "content": body}, "toRecipients": [{"emailAddress": {"address": addr}} for addr in to], } if cc: message["ccRecipients"] = [{"emailAddress": {"address": addr}} for addr in cc] + if attachments: + message["attachments"] = [ + { + "@odata.type": "#microsoft.graph.fileAttachment", + "name": att["name"], + "contentBytes": att["contentBytes"], + "contentType": att.get("contentType", "application/octet-stream"), + } + for att in attachments + ] + return message + async def sendMail( + self, to: List[str], subject: str, body: str, + bodyType: str = "Text", + cc: Optional[List[str]] = None, + attachments: Optional[List[Dict]] = None, + ) -> Dict[str, Any]: + """Send an email via Microsoft Graph. bodyType: 'Text' or 'HTML'.""" + import json + message = self._buildMessage(to, subject, body, bodyType, cc, attachments) payload = json.dumps({"message": message, "saveToSentItems": True}).encode("utf-8") result = await self._graphPost("me/sendMail", payload) if "error" in result: return result return {"success": True} + async def createDraft( + self, to: List[str], subject: str, body: str, + bodyType: str = "Text", + cc: Optional[List[str]] = None, + attachments: Optional[List[Dict]] = None, + ) -> Dict[str, Any]: + """Create a draft email in the user's Drafts folder via Microsoft Graph.""" + import json + message = self._buildMessage(to, subject, body, bodyType, cc, attachments) + payload = json.dumps(message).encode("utf-8") + result = await self._graphPost("me/messages", payload) + if "error" in result: + return result + return {"success": True, "draft": True, "messageId": result.get("id", "")} + # --------------------------------------------------------------------------- # Teams Adapter (Stub) diff --git a/modules/datamodels/datamodelBilling.py b/modules/datamodels/datamodelBilling.py index 8ffbdef1..995ac75d 100644 --- a/modules/datamodels/datamodelBilling.py +++ b/modules/datamodels/datamodelBilling.py @@ -143,6 +143,9 @@ class BillingSettings(BaseModel): ) warningThresholdPercent: float = Field(default=10.0, description="Warning threshold as percentage") + # Stripe + stripeCustomerId: Optional[str] = Field(None, description="Stripe Customer ID (cus_xxx) — one per mandate") + # Notifications (e.g. mandate owner / finance — also used when PREPAY_MANDATE pool is exhausted) notifyEmails: List[str] = Field( default_factory=list, @@ -163,6 +166,7 @@ registerModelLabels( "de": "Startguthaben nur Root-Mandant (CHF)", }, "warningThresholdPercent": {"en": "Warning Threshold (%)", "de": "Warnschwelle (%)"}, + "stripeCustomerId": {"en": "Stripe Customer ID", "de": "Stripe-Kunden-ID"}, "notifyEmails": { "en": "Billing notification emails (owner / admin)", "de": "E-Mails für Billing-Alerts (Inhaber/Admin)", @@ -260,12 +264,15 @@ class BillingStatisticsResponse(BaseModel): class BillingCheckResult(BaseModel): - """Result of a billing balance check.""" + """Result of a billing balance check (budget + subscription gate).""" allowed: bool reason: Optional[str] = None currentBalance: Optional[float] = None requiredAmount: Optional[float] = None billingModel: Optional[BillingModelEnum] = None + upgradeRequired: Optional[bool] = None + subscriptionUiPath: Optional[str] = None + userAction: Optional[str] = None def parseBillingModelFromStoredValue(raw: Optional[str]) -> BillingModelEnum: diff --git a/modules/datamodels/datamodelPagination.py b/modules/datamodels/datamodelPagination.py index 9815027f..2719327b 100644 --- a/modules/datamodels/datamodelPagination.py +++ b/modules/datamodels/datamodelPagination.py @@ -98,6 +98,12 @@ def normalize_pagination_dict(pagination_dict: Dict[str, Any]) -> Dict[str, Any] # Create a copy to avoid modifying the original normalized = dict(pagination_dict) + # Ensure required fields have sensible defaults + if "page" not in normalized: + normalized["page"] = 1 + if "pageSize" not in normalized: + normalized["pageSize"] = 25 + # Move top-level "search" into filters if present if "search" in normalized: if "filters" not in normalized or normalized["filters"] is None: diff --git a/modules/datamodels/datamodelSubscription.py b/modules/datamodels/datamodelSubscription.py new file mode 100644 index 00000000..1c1435d8 --- /dev/null +++ b/modules/datamodels/datamodelSubscription.py @@ -0,0 +1,235 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +"""Subscription models: SubscriptionPlan (catalog), MandateSubscription (instance per mandate), +StripePlanPrice (persisted Stripe IDs per plan). + +State Machine: see wiki/concepts/Subscription-State-Machine.md +""" + +from typing import Dict, List, Optional +from enum import Enum +from datetime import datetime, timezone +from pydantic import BaseModel, Field +from modules.shared.attributeUtils import registerModelLabels +import uuid + + +class SubscriptionStatusEnum(str, Enum): + """Lifecycle status of a mandate subscription. + See wiki/concepts/Subscription-State-Machine.md for transition rules.""" + PENDING = "PENDING" + SCHEDULED = "SCHEDULED" + TRIALING = "TRIALING" + ACTIVE = "ACTIVE" + PAST_DUE = "PAST_DUE" + EXPIRED = "EXPIRED" + + +TERMINAL_STATUSES = {SubscriptionStatusEnum.EXPIRED} +OPERATIVE_STATUSES = {SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.TRIALING, SubscriptionStatusEnum.PAST_DUE} + +ALLOWED_TRANSITIONS = { + (SubscriptionStatusEnum.PENDING, SubscriptionStatusEnum.ACTIVE), + (SubscriptionStatusEnum.PENDING, SubscriptionStatusEnum.SCHEDULED), + (SubscriptionStatusEnum.PENDING, SubscriptionStatusEnum.EXPIRED), + (SubscriptionStatusEnum.SCHEDULED, SubscriptionStatusEnum.ACTIVE), + (SubscriptionStatusEnum.SCHEDULED, SubscriptionStatusEnum.EXPIRED), + (SubscriptionStatusEnum.TRIALING, SubscriptionStatusEnum.EXPIRED), + (SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.PAST_DUE), + (SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.EXPIRED), + (SubscriptionStatusEnum.PAST_DUE, SubscriptionStatusEnum.ACTIVE), + (SubscriptionStatusEnum.PAST_DUE, SubscriptionStatusEnum.EXPIRED), +} + + +class BillingPeriodEnum(str, Enum): + """Recurring billing interval.""" + MONTHLY = "MONTHLY" + YEARLY = "YEARLY" + NONE = "NONE" + + +# ============================================================================ +# Catalog: SubscriptionPlan (static, in-memory) +# ============================================================================ + +class SubscriptionPlan(BaseModel): + """Plan definition (catalog entry). Not stored per mandate — static.""" + planKey: str = Field(..., description="Unique plan identifier") + selectableByUser: bool = Field(default=True, description="Whether users can choose this plan in the UI") + + title: Dict[str, str] = Field(default_factory=dict, description="Multilingual title (en/de/fr)") + description: Dict[str, str] = Field(default_factory=dict, description="Multilingual description") + + currency: str = Field(default="CHF", description="Billing currency") + billingPeriod: BillingPeriodEnum = Field(default=BillingPeriodEnum.MONTHLY, description="Recurring interval") + pricePerUserCHF: float = Field(default=0.0, description="Price per active user per period") + pricePerFeatureInstanceCHF: float = Field(default=0.0, description="Price per active feature instance per period") + autoRenew: bool = Field(default=True, description="Stripe renews automatically at period end") + + maxUsers: Optional[int] = Field(None, description="Hard cap on active users (None = unlimited)") + maxFeatureInstances: Optional[int] = Field(None, description="Hard cap on active feature instances (None = unlimited)") + trialDays: Optional[int] = Field(None, description="Trial duration in days (only for trial plans)") + successorPlanKey: Optional[str] = Field(None, description="Plan to transition to when trial ends") + + +registerModelLabels( + "SubscriptionPlan", + {"en": "Subscription Plan", "de": "Abonnement-Plan", "fr": "Plan d'abonnement"}, + { + "planKey": {"en": "Plan", "de": "Plan", "fr": "Plan"}, + "selectableByUser": {"en": "Selectable", "de": "Wählbar", "fr": "Sélectionnable"}, + "billingPeriod": {"en": "Billing Period", "de": "Abrechnungszeitraum", "fr": "Période de facturation"}, + "pricePerUserCHF": {"en": "Price per User (CHF)", "de": "Preis pro User (CHF)"}, + "pricePerFeatureInstanceCHF": {"en": "Price per Instance (CHF)", "de": "Preis pro Instanz (CHF)"}, + "maxUsers": {"en": "Max Users", "de": "Max. Benutzer", "fr": "Max. utilisateurs"}, + "maxFeatureInstances": {"en": "Max Instances", "de": "Max. Instanzen", "fr": "Max. instances"}, + }, +) + + +# ============================================================================ +# Stripe Price mapping (persisted in DB, auto-created at bootstrap) +# ============================================================================ + +class StripePlanPrice(BaseModel): + """Persisted mapping from planKey to Stripe Product/Price IDs. + Auto-created at startup — no manual configuration needed. + Uses separate Stripe Products for users and instances for clear invoice labels.""" + id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key") + planKey: str = Field(..., description="Reference to SubscriptionPlan.planKey") + stripeProductId: str = Field("", description="Legacy single-product ID (unused)") + stripeProductIdUsers: Optional[str] = Field(None, description="Stripe Product ID for user licenses") + stripeProductIdInstances: Optional[str] = Field(None, description="Stripe Product ID for feature instances") + stripePriceIdUsers: Optional[str] = Field(None, description="Stripe Price ID for user-seat line item") + stripePriceIdInstances: Optional[str] = Field(None, description="Stripe Price ID for instance line item") + + +registerModelLabels( + "StripePlanPrice", + {"en": "Stripe Plan Prices", "de": "Stripe-Planpreise"}, + { + "planKey": {"en": "Plan", "de": "Plan"}, + "stripeProductIdUsers": {"en": "Product (Users)", "de": "Produkt (User)"}, + "stripeProductIdInstances": {"en": "Product (Instances)", "de": "Produkt (Instanzen)"}, + "stripePriceIdUsers": {"en": "Price ID (Users)", "de": "Preis-ID (User)"}, + "stripePriceIdInstances": {"en": "Price ID (Instances)", "de": "Preis-ID (Instanzen)"}, + }, +) + + +# ============================================================================ +# Instance: MandateSubscription +# ============================================================================ + +class MandateSubscription(BaseModel): + """A subscription instance bound to a specific mandate. + See wiki/concepts/Subscription-State-Machine.md for state transitions.""" + id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key") + mandateId: str = Field(..., description="Foreign key to Mandate") + planKey: str = Field(..., description="Reference to SubscriptionPlan.planKey") + + status: SubscriptionStatusEnum = Field(default=SubscriptionStatusEnum.PENDING, description="Current lifecycle status") + recurring: bool = Field(default=True, description="True: auto-renews at period end. False: expires at period end (gekuendigt).") + + startedAt: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), description="Record creation timestamp") + effectiveFrom: Optional[datetime] = Field(None, description="When this subscription becomes operative. None = immediate. Set for SCHEDULED subs.") + endedAt: Optional[datetime] = Field(None, description="When subscription ended (terminal)") + currentPeriodStart: Optional[datetime] = Field(None, description="Current billing period start (synced from Stripe)") + currentPeriodEnd: Optional[datetime] = Field(None, description="Current billing period end (synced from Stripe)") + trialEndsAt: Optional[datetime] = Field(None, description="Trial expiry timestamp") + + snapshotPricePerUserCHF: float = Field(default=0.0, description="Price snapshot at activation (for invoice history)") + snapshotPricePerInstanceCHF: float = Field(default=0.0, description="Price snapshot at activation") + + stripeSubscriptionId: Optional[str] = Field(None, description="Stripe Subscription ID (sub_xxx)") + stripeItemIdUsers: Optional[str] = Field(None, description="Stripe Subscription Item ID for user seats") + stripeItemIdInstances: Optional[str] = Field(None, description="Stripe Subscription Item ID for feature instances") + + +registerModelLabels( + "MandateSubscription", + {"en": "Mandate Subscription", "de": "Mandanten-Abonnement", "fr": "Abonnement du mandat"}, + { + "id": {"en": "ID", "de": "ID"}, + "mandateId": {"en": "Mandate ID", "de": "Mandanten-ID"}, + "planKey": {"en": "Plan", "de": "Plan"}, + "status": {"en": "Status", "de": "Status"}, + "recurring": {"en": "Recurring", "de": "Wiederkehrend"}, + "startedAt": {"en": "Started", "de": "Gestartet"}, + "effectiveFrom": {"en": "Effective From", "de": "Wirksam ab"}, + "endedAt": {"en": "Ended", "de": "Beendet"}, + "currentPeriodStart": {"en": "Period Start", "de": "Periodenbeginn"}, + "currentPeriodEnd": {"en": "Period End", "de": "Periodenende"}, + "trialEndsAt": {"en": "Trial Ends", "de": "Trial endet"}, + "snapshotPricePerUserCHF": {"en": "Price/User (CHF)", "de": "Preis/User (CHF)"}, + "snapshotPricePerInstanceCHF": {"en": "Price/Instance (CHF)", "de": "Preis/Instanz (CHF)"}, + }, +) + + +# ============================================================================ +# Built-in plan catalog (static, no env dependency) +# ============================================================================ + +BUILTIN_PLANS: Dict[str, SubscriptionPlan] = { + "ROOT": SubscriptionPlan( + planKey="ROOT", + selectableByUser=False, + title={"en": "Root (System)", "de": "Root (System)", "fr": "Root (Système)"}, + description={"en": "Internal system plan — no billing.", "de": "Interner Systemplan — keine Verrechnung."}, + billingPeriod=BillingPeriodEnum.NONE, + autoRenew=False, + maxUsers=None, + maxFeatureInstances=None, + ), + "TRIAL_7D": SubscriptionPlan( + planKey="TRIAL_7D", + selectableByUser=False, + title={"en": "Free Trial (7 days)", "de": "Gratis-Testphase (7 Tage)", "fr": "Essai gratuit (7 jours)"}, + description={ + "en": "Try the platform for 7 days — 1 user, up to 3 feature instances.", + "de": "Plattform 7 Tage testen — 1 User, bis zu 3 Feature-Instanzen.", + }, + billingPeriod=BillingPeriodEnum.NONE, + autoRenew=False, + maxUsers=1, + maxFeatureInstances=3, + trialDays=7, + successorPlanKey="STANDARD_MONTHLY", + ), + "STANDARD_MONTHLY": SubscriptionPlan( + planKey="STANDARD_MONTHLY", + selectableByUser=True, + title={"en": "Standard (Monthly)", "de": "Standard (Monatlich)", "fr": "Standard (Mensuel)"}, + description={ + "en": "Usage-based billing per active user and feature instance, billed monthly.", + "de": "Nutzungsbasierte Abrechnung pro aktivem User und Feature-Instanz, monatlich.", + }, + billingPeriod=BillingPeriodEnum.MONTHLY, + pricePerUserCHF=90.0, + pricePerFeatureInstanceCHF=150.0, + ), + "STANDARD_YEARLY": SubscriptionPlan( + planKey="STANDARD_YEARLY", + selectableByUser=True, + title={"en": "Standard (Yearly)", "de": "Standard (Jährlich)", "fr": "Standard (Annuel)"}, + description={ + "en": "Usage-based billing per active user and feature instance, billed yearly.", + "de": "Nutzungsbasierte Abrechnung pro aktivem User und Feature-Instanz, jährlich.", + }, + billingPeriod=BillingPeriodEnum.YEARLY, + pricePerUserCHF=1080.0, + pricePerFeatureInstanceCHF=1800.0, + ), +} + + +def _getPlan(planKey: str) -> Optional[SubscriptionPlan]: + """Resolve a plan by key from the built-in catalog.""" + return BUILTIN_PLANS.get(planKey) + + +def _getSelectablePlans() -> List[SubscriptionPlan]: + """Return plans that users can choose in the UI.""" + return [p for p in BUILTIN_PLANS.values() if p.selectableByUser] diff --git a/modules/datamodels/datamodelVoice.py b/modules/datamodels/datamodelVoice.py index 2223a3e6..565c7677 100644 --- a/modules/datamodels/datamodelVoice.py +++ b/modules/datamodels/datamodelVoice.py @@ -1,46 +1,7 @@ # Copyright (c) 2025 Patrick Motsch # All rights reserved. -"""Voice settings datamodel.""" - -from typing import Dict, Any, Optional -from pydantic import BaseModel, Field -from modules.shared.attributeUtils import registerModelLabels -from modules.shared.timeUtils import getUtcTimestamp -import uuid - - -class VoiceSettings(BaseModel): - id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}) - userId: str = Field(description="ID of the user these settings belong to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True}) - mandateId: str = Field(description="ID of the mandate these settings belong to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True}) - featureInstanceId: str = Field(description="ID of the feature instance these settings belong to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True}) - sttLanguage: str = Field(default="de-DE", description="Speech-to-Text language", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True}) - ttsLanguage: str = Field(default="de-DE", description="Text-to-Speech language", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True}) - ttsVoice: str = Field(default="de-DE-KatjaNeural", description="Text-to-Speech voice", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True}) - ttsVoiceMap: Dict[str, Any] = Field(default_factory=dict, description="Per-language voice mapping, e.g. {'de-DE': {'voiceName': 'de-DE-Wavenet-A'}, 'en-US': {'voiceName': 'en-US-Wavenet-C'}}", json_schema_extra={"frontend_type": "json", "frontend_readonly": False, "frontend_required": False}) - translationEnabled: bool = Field(default=True, description="Whether translation is enabled", json_schema_extra={"frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False}) - targetLanguage: str = Field(default="en-US", description="Target language for translation", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False}) - creationDate: float = Field(default_factory=getUtcTimestamp, description="Date when the settings were created (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False}) - lastModified: float = Field(default_factory=getUtcTimestamp, description="Date when the settings were last modified (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False}) - - -registerModelLabels( - "VoiceSettings", - {"en": "Voice Settings", "fr": "Paramètres vocaux"}, - { - "id": {"en": "ID", "fr": "ID"}, - "userId": {"en": "User ID", "fr": "ID utilisateur"}, - "mandateId": {"en": "Mandate ID", "fr": "ID du mandat"}, - "featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance de fonctionnalité"}, - "sttLanguage": {"en": "STT Language", "fr": "Langue STT"}, - "ttsLanguage": {"en": "TTS Language", "fr": "Langue TTS"}, - "ttsVoice": {"en": "TTS Voice", "fr": "Voix TTS"}, - "ttsVoiceMap": {"en": "TTS Voice Map", "fr": "Carte des voix TTS"}, - "translationEnabled": {"en": "Translation Enabled", "fr": "Traduction activée"}, - "targetLanguage": {"en": "Target Language", "fr": "Langue cible"}, - "creationDate": {"en": "Creation Date", "fr": "Date de création"}, - "lastModified": {"en": "Last Modified", "fr": "Dernière modification"}, - }, -) +"""Voice settings datamodel — re-exported from workspace feature for backward compatibility.""" +from modules.features.workspace.datamodelFeatureWorkspace import VoiceSettings +__all__ = ["VoiceSettings"] diff --git a/modules/features/automation/interfaceFeatureAutomation.py b/modules/features/automation/interfaceFeatureAutomation.py index 3b20ca3d..4091bc28 100644 --- a/modules/features/automation/interfaceFeatureAutomation.py +++ b/modules/features/automation/interfaceFeatureAutomation.py @@ -270,7 +270,7 @@ class AutomationObjects: if value.lower() not in itemValue.lower(): match = False break - elif itemValue != value: + elif str(itemValue).lower() != str(value).lower(): match = False break if match: @@ -418,6 +418,8 @@ class AutomationObjects: if not self.checkRbacPermission(AutomationDefinition, "update", automationId): raise PermissionError(f"No permission to modify automation {automationId}") + automationData.pop("executionLogs", None) + # If deactivating: immediately remove scheduler job (don't rely on async callback) isBeingDeactivated = "active" in automationData and not automationData["active"] if isBeingDeactivated: diff --git a/modules/features/automation/mainAutomation.py b/modules/features/automation/mainAutomation.py index 35a61512..4bb30f7f 100644 --- a/modules/features/automation/mainAutomation.py +++ b/modules/features/automation/mainAutomation.py @@ -27,11 +27,6 @@ UI_OBJECTS = [ "label": {"en": "Templates", "de": "Vorlagen", "fr": "Modèles"}, "meta": {"area": "templates"} }, - { - "objectKey": "ui.feature.automation.logs", - "label": {"en": "Execution Logs", "de": "Ausführungsprotokolle", "fr": "Journaux d'exécution"}, - "meta": {"area": "logs"} - }, ] # Resource Objects for RBAC catalog diff --git a/modules/features/automation/routeFeatureAutomation.py b/modules/features/automation/routeFeatureAutomation.py index 6cdc4d44..48f53eea 100644 --- a/modules/features/automation/routeFeatureAutomation.py +++ b/modules/features/automation/routeFeatureAutomation.py @@ -109,6 +109,26 @@ def get_automations( detail=f"Error getting automations: {str(e)}" ) +@router.get("/filter-values") +@limiter.limit("60/minute") +def get_automation_filter_values( + request: Request, + column: str = Query(..., description="Column key"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters"), + context: RequestContext = Depends(getRequestContext) +) -> list: + """Return distinct filter values for a column in automations.""" + try: + from modules.routes.routeDataUsers import _handleFilterValuesRequest + chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None) + result = chatInterface.getAllAutomationDefinitions(pagination=None) + items = result if isinstance(result, list) else [r if isinstance(r, dict) else r.model_dump() if hasattr(r, 'model_dump') else r for r in result] + return _handleFilterValuesRequest(items, column, pagination) + except Exception as e: + logger.error(f"Error getting filter values for automations: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + @router.post("", response_model=AutomationDefinition) @limiter.limit("10/minute") def create_automation( @@ -765,6 +785,7 @@ def update_automation( try: chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None) automationData = automation.model_dump() + automationData.pop("executionLogs", None) updated = chatInterface.updateAutomationDefinition(automationId, automationData) return updated except HTTPException: @@ -1017,6 +1038,30 @@ def get_db_templates( ) +@templateRouter.get("/filter-values") +@limiter.limit("60/minute") +def get_template_filter_values( + request: Request, + column: str = Query(..., description="Column key"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters"), + context: RequestContext = Depends(getRequestContext) +) -> list: + """Return distinct filter values for a column in automation templates.""" + try: + from modules.routes.routeDataUsers import _handleFilterValuesRequest + chatInterface = getAutomationInterface( + context.user, + mandateId=str(context.mandateId) if context.mandateId else None, + featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None + ) + result = chatInterface.getAllAutomationTemplates(pagination=None) + items = [r if isinstance(r, dict) else r.model_dump() if hasattr(r, 'model_dump') else r for r in result] + return _handleFilterValuesRequest(items, column, pagination) + except Exception as e: + logger.error(f"Error getting filter values for automation templates: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + @templateRouter.get("/attributes", response_model=Dict[str, Any]) def get_template_attributes( request: Request diff --git a/modules/features/automation2/__init__.py b/modules/features/automation2/__init__.py new file mode 100644 index 00000000..c86d7e61 --- /dev/null +++ b/modules/features/automation2/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2025 Patrick Motsch +# Automation2 feature - n8n-style flow automation (backup/parallel to legacy automation) diff --git a/modules/features/automation2/datamodelFeatureAutomation2.py b/modules/features/automation2/datamodelFeatureAutomation2.py new file mode 100644 index 00000000..f505c7d0 --- /dev/null +++ b/modules/features/automation2/datamodelFeatureAutomation2.py @@ -0,0 +1,159 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +"""Automation2 models: Automation2Workflow, Automation2WorkflowRun, Automation2HumanTask.""" + +from typing import Dict, Any, List, Optional +from pydantic import BaseModel, Field +from modules.shared.attributeUtils import registerModelLabels +import uuid + + +class Automation2Workflow(BaseModel): + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), + description="Primary key", + json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}, + ) + mandateId: str = Field( + description="Mandate ID", + json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}, + ) + featureInstanceId: str = Field( + description="Feature instance ID", + json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}, + ) + label: str = Field( + description="User-friendly workflow name", + json_schema_extra={"frontend_type": "text", "frontend_required": True}, + ) + graph: Dict[str, Any] = Field( + default_factory=dict, + description="Graph with nodes and connections (incl. node parameters)", + json_schema_extra={"frontend_type": "textarea", "frontend_required": True}, + ) + active: bool = Field( + default=True, + description="Whether workflow is active", + json_schema_extra={"frontend_type": "checkbox", "frontend_required": False}, + ) + + +registerModelLabels( + "Automation2Workflow", + {"en": "Automation2 Workflow", "de": "Automation2 Workflow", "fr": "Workflow Automation2"}, + { + "id": {"en": "ID", "de": "ID", "fr": "ID"}, + "mandateId": {"en": "Mandate ID", "de": "Mandanten-ID", "fr": "ID du mandat"}, + "featureInstanceId": {"en": "Feature Instance ID", "de": "Feature-Instanz-ID", "fr": "ID instance"}, + "label": {"en": "Label", "de": "Bezeichnung", "fr": "Libellé"}, + "graph": {"en": "Graph", "de": "Graph", "fr": "Graphe"}, + "active": {"en": "Active", "de": "Aktiv", "fr": "Actif"}, + }, +) + + +class Automation2WorkflowRun(BaseModel): + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), + description="Primary key", + json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}, + ) + workflowId: str = Field( + description="Workflow ID", + json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True}, + ) + status: str = Field( + default="running", + description="Status: running|paused|completed|failed", + json_schema_extra={"frontend_type": "text", "frontend_required": False}, + ) + nodeOutputs: Dict[str, Any] = Field( + default_factory=dict, + description="Outputs from executed nodes", + json_schema_extra={"frontend_type": "textarea", "frontend_required": False}, + ) + currentNodeId: Optional[str] = Field( + default=None, + description="Node ID when paused (human task)", + json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}, + ) + context: Dict[str, Any] = Field( + default_factory=dict, + description="Context for resume (connectionMap, inputSources, etc.)", + json_schema_extra={"frontend_type": "textarea", "frontend_required": False}, + ) + + +registerModelLabels( + "Automation2WorkflowRun", + {"en": "Automation2 Workflow Run", "de": "Automation2 Workflow-Ausführung", "fr": "Exécution workflow"}, + { + "id": {"en": "ID", "de": "ID", "fr": "ID"}, + "workflowId": {"en": "Workflow ID", "de": "Workflow-ID", "fr": "ID workflow"}, + "status": {"en": "Status", "de": "Status", "fr": "Statut"}, + "nodeOutputs": {"en": "Node Outputs", "de": "Node-Ausgaben", "fr": "Sorties nœuds"}, + "currentNodeId": {"en": "Current Node", "de": "Aktueller Knoten", "fr": "Nœud actuel"}, + "context": {"en": "Context", "de": "Kontext", "fr": "Contexte"}, + }, +) + + +class Automation2HumanTask(BaseModel): + id: str = Field( + default_factory=lambda: str(uuid.uuid4()), + description="Primary key", + json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}, + ) + runId: str = Field( + description="Workflow run ID", + json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True}, + ) + workflowId: str = Field( + description="Workflow ID", + json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True}, + ) + nodeId: str = Field( + description="Node ID in the graph", + json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True}, + ) + nodeType: str = Field( + description="Node type: form|approval|upload|comment|review|selection|confirmation", + json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True}, + ) + config: Dict[str, Any] = Field( + default_factory=dict, + description="Node config (form schema, approval text, etc.)", + json_schema_extra={"frontend_type": "textarea", "frontend_required": False}, + ) + assigneeId: Optional[str] = Field( + default=None, + description="User ID assigned to complete the task", + json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": False}, + ) + status: str = Field( + default="pending", + description="Status: pending|completed|rejected", + json_schema_extra={"frontend_type": "text", "frontend_required": False}, + ) + result: Optional[Dict[str, Any]] = Field( + default=None, + description="Task result (form data, approval decision, etc.)", + json_schema_extra={"frontend_type": "textarea", "frontend_required": False}, + ) + + +registerModelLabels( + "Automation2HumanTask", + {"en": "Automation2 Human Task", "de": "Automation2 Benutzer-Aufgabe", "fr": "Tâche utilisateur"}, + { + "id": {"en": "ID", "de": "ID", "fr": "ID"}, + "runId": {"en": "Run ID", "de": "Lauf-ID", "fr": "ID exécution"}, + "workflowId": {"en": "Workflow ID", "de": "Workflow-ID", "fr": "ID workflow"}, + "nodeId": {"en": "Node ID", "de": "Knoten-ID", "fr": "ID nœud"}, + "nodeType": {"en": "Node Type", "de": "Knotentyp", "fr": "Type nœud"}, + "config": {"en": "Config", "de": "Konfiguration", "fr": "Configuration"}, + "assigneeId": {"en": "Assignee", "de": "Zugewiesen an", "fr": "Assigné à"}, + "status": {"en": "Status", "de": "Status", "fr": "Statut"}, + "result": {"en": "Result", "de": "Ergebnis", "fr": "Résultat"}, + }, +) diff --git a/modules/features/automation2/emailPoller.py b/modules/features/automation2/emailPoller.py new file mode 100644 index 00000000..ca440ca2 --- /dev/null +++ b/modules/features/automation2/emailPoller.py @@ -0,0 +1,268 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +""" +Background email poller for automation2. +Checks paused runs waiting for email (email.checkEmail node) and resumes when a new matching email arrives. +""" + +import asyncio +import json +import logging +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, Optional + +logger = logging.getLogger(__name__) + +# Job ID for scheduler +POLLER_JOB_ID = "automation2_email_poller" +POLL_INTERVAL_MINUTES = 2 + + +async def _pollEmailWaits(eventUser) -> None: + """ + Poll for new emails for runs waiting on email.checkEmail. + Uses eventUser for DB access; loads owner user for each run to call readEmails. + Stops the poller when no runs are waiting. + """ + try: + from modules.features.automation2.interfaceFeatureAutomation2 import getAutomation2Interface + from modules.features.automation2.mainAutomation2 import getAutomation2Services + from modules.workflows.automation2.executionEngine import executeGraph + from modules.workflows.processing.shared.methodDiscovery import discoverMethods + from modules.interfaces.interfaceDbApp import getRootInterface + + root = getRootInterface() + if not root: + logger.warning("Email poller: root interface not available") + return + # Use eventUser - getRunsWaitingForEmail queries by status only + a2 = getAutomation2Interface(eventUser, mandateId="", featureInstanceId="") + runs = a2.getRunsWaitingForEmail() + if not runs: + # No workflows waiting for email - stop the poller + stop(eventUser) + return + logger.info("Automation2 email poller: checking %d run(s) waiting for email", len(runs)) + + for run in runs: + run_id = run.get("id") + workflow_id = run.get("workflowId") + context = run.get("context") or {} + wait_config = context.get("waitConfig") or {} + node_id = run.get("currentNodeId") or context.get("waitConfig", {}).get("_nodeId") + owner_id = context.get("ownerId") + mandate_id = context.get("mandateId") + instance_id = context.get("instanceId") + last_checked = context.get("lastCheckedAt") + + if not owner_id or not mandate_id or not instance_id or not workflow_id or not node_id: + logger.warning("Email wait run %s missing ownerId/mandateId/instanceId/workflowId/nodeId - skipping", run_id) + continue + + # First poll: use pausedAt (or now - 5 min) as baseline so we don't miss emails + # that arrived between pause and first poll + if last_checked is None: + paused_at = context.get("pausedAt") + if paused_at: + baseline = paused_at + else: + # Fallback: look back 5 minutes for runs created before pausedAt existed + baseline = (datetime.now(timezone.utc) - timedelta(minutes=5)).strftime("%Y-%m-%dT%H:%M:%SZ") + last_checked = baseline + + # Load owner user (root interface has broad access) + owner = root.getUser(owner_id) if hasattr(root, "getUser") else None + if not owner: + logger.warning("Email wait run %s: owner user %s not found", run_id, owner_id) + continue + + # Get workflow (need scoped interface for mandate/instance) + a2_scoped = getAutomation2Interface(eventUser, mandateId=mandate_id, featureInstanceId=instance_id) + wf = a2_scoped.getWorkflow(workflow_id) + if not wf or not wf.get("graph"): + logger.warning("Email wait run %s: workflow %s not found or has no graph", run_id, workflow_id) + continue + + # Only process runs paused at email.checkEmail – searchEmail never waits, it searches all immediately + nodes = (wf.get("graph") or {}).get("nodes") or [] + paused_node = next((n for n in nodes if n.get("id") == node_id), None) + if paused_node and paused_node.get("type") == "email.searchEmail": + logger.warning("Email wait run %s: paused at email.searchEmail (should not wait) – skipping", run_id) + continue + + services = getAutomation2Services(owner, mandateId=mandate_id, featureInstanceId=instance_id) + discoverMethods(services) + + # Build filter with receivedDateTime – only emails received at or after baseline (new emails) + base_filter = wait_config.get("filter") or "" + dt_filter = f"receivedDateTime ge {last_checked}" + combined_filter = f"({base_filter}) and {dt_filter}" if base_filter else dt_filter + logger.debug("Email wait run %s: fetch filter (new emails only) %s", run_id, combined_filter) + + from modules.workflows.processing.core.actionExecutor import ActionExecutor + executor = ActionExecutor(services) + params = { + "connectionReference": wait_config.get("connectionReference"), + "folder": wait_config.get("folder", "Inbox"), + "limit": min(int(wait_config.get("limit", 10)), 50), + "filter": combined_filter, + } + + try: + result = await executor.executeAction("outlook", "readEmails", params) + except Exception as e: + logger.warning("Email wait run %s: readEmails failed: %s", run_id, e) + continue + + # readEmails always returns 1 document (JSON wrapper); check actual email count + email_count = 0 + if result and result.documents: + doc = result.documents[0] + meta = getattr(doc, "validationMetadata", None) + if not meta and isinstance(doc, dict): + meta = doc.get("validationMetadata") + if meta and isinstance(meta, dict): + email_count = int(meta.get("emailCount", 0)) + else: + try: + data = json.loads(getattr(doc, "documentData", "") or "{}") + email_count = len(data.get("emails", {}).get("emails", [])) + except Exception: + pass + + if not result or not result.success or email_count == 0: + # No new emails - persist lastCheckedAt so next poll uses this as baseline + now_iso = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + ctx = dict(context) + ctx["lastCheckedAt"] = now_iso + a2_scoped.updateRun(run_id, context=ctx) + continue + + # Only pass NEW emails (receivedDateTime >= last_checked) – filter server-side as safeguard + doc = result.documents[0] + raw_data = json.loads(getattr(doc, "documentData", "") or "{}") + emails_data = raw_data.get("emails", {}) + all_emails = emails_data.get("emails", []) + new_emails = [ + e for e in all_emails + if (e.get("receivedDateTime") or "") >= last_checked + ] + if not new_emails: + now_iso = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + ctx = dict(context) + ctx["lastCheckedAt"] = now_iso + a2_scoped.updateRun(run_id, context=ctx) + continue + + # Rebuild document with only new emails for downstream nodes + result_data = dict(raw_data) + result_data["emails"] = dict(emails_data) + result_data["emails"]["emails"] = new_emails + result_data["emails"]["count"] = len(new_emails) + + from modules.datamodels.datamodelChat import ActionDocument + filtered_doc = ActionDocument( + documentName=getattr(doc, "documentName", "outlook_emails.json"), + documentData=json.dumps(result_data, indent=2), + mimeType=getattr(doc, "mimeType", "application/json"), + validationMetadata={**(getattr(doc, "validationMetadata") or {}), "emailCount": len(new_emails)}, + ) + + # Build node output in same format as ActionNodeExecutor for readEmails + node_output = { + "success": result.success, + "error": result.error, + "documents": [filtered_doc.model_dump() if hasattr(filtered_doc, "model_dump") else filtered_doc], + "data": result.model_dump() if hasattr(result, "model_dump") else {"success": result.success, "error": result.error}, + } + + # Update lastCheckedAt before resume + now_iso = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + ctx = dict(context) + ctx["lastCheckedAt"] = now_iso + a2_scoped.updateRun(run_id, status="running", context=ctx) + + node_outputs = dict(run.get("nodeOutputs") or {}) + node_outputs[node_id] = node_output + + logger.info("Email wait run %s: found new email, resuming from node %s", run_id, node_id) + + resume_result = await executeGraph( + graph=wf["graph"], + services=services, + workflowId=workflow_id, + instanceId=instance_id, + userId=owner_id, + mandateId=mandate_id, + automation2_interface=a2_scoped, + initialNodeOutputs=node_outputs, + startAfterNodeId=node_id, + runId=run_id, + ) + + if resume_result.get("success"): + logger.info("Email wait run %s: completed successfully", run_id) + elif resume_result.get("paused"): + logger.info("Email wait run %s: paused again (e.g. human task)", run_id) + else: + logger.warning("Email wait run %s: failed: %s", run_id, resume_result.get("error")) + + except Exception as e: + logger.exception("Email poller failed: %s", e) + + +def _runPollSync(ev_user): + """Sync job for scheduler - runs async poll. Thread-safe for both main loop and worker threads.""" + try: + try: + loop = asyncio.get_running_loop() + # Already in event loop - schedule, don't block + loop.create_task(_pollEmailWaits(ev_user)) + except RuntimeError: + # No running loop (worker thread) - run in new loop + asyncio.run(_pollEmailWaits(ev_user)) + except Exception as e: + logger.exception("Automation2 email poller job failed: %s", e) + + +def ensureRunning(eventUser) -> bool: + """Start the poller if not already running. Called when a run pauses for email.checkEmail.""" + return start(eventUser) + + +def start(eventUser) -> bool: + """Register the email poller interval job.""" + if not eventUser: + logger.warning("Automation2 email poller: no eventUser, not registering") + return False + try: + from modules.shared.eventManagement import eventManager + + # Use sync wrapper - APScheduler may run jobs in thread pool where async doesn't work + job_func = lambda: _runPollSync(eventUser) + eventManager.registerInterval( + POLLER_JOB_ID, + job_func, + seconds=0, + minutes=POLL_INTERVAL_MINUTES, + hours=0, + ) + logger.info("Automation2 email poller started (interval=%s min)", POLL_INTERVAL_MINUTES) + # Run once immediately so we don't wait 2 minutes for the first check + _runPollSync(eventUser) + return True + except Exception as e: + logger.error("Failed to register automation2 email poller: %s", e) + return False + + +def stop(eventUser) -> bool: + """Remove the email poller job.""" + try: + from modules.shared.eventManagement import eventManager + eventManager.remove(POLLER_JOB_ID) + logger.info("Automation2 email poller removed") + return True + except Exception as e: + logger.warning("Error removing automation2 email poller: %s", e) + return True diff --git a/modules/features/automation2/interfaceFeatureAutomation2.py b/modules/features/automation2/interfaceFeatureAutomation2.py new file mode 100644 index 00000000..cdc9bccf --- /dev/null +++ b/modules/features/automation2/interfaceFeatureAutomation2.py @@ -0,0 +1,311 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +""" +Interface for Automation2 feature - Workflows, Runs, Human Tasks. +Uses PostgreSQL poweron_automation2 database. +""" + +import base64 +import logging +import uuid +from typing import Dict, Any, List, Optional + + +def _make_json_serializable(obj: Any) -> Any: + """ + Recursively convert bytes to base64 strings so structures can be JSON-serialized + for storage in JSONB columns. + """ + if isinstance(obj, bytes): + return base64.b64encode(obj).decode("ascii") + if isinstance(obj, dict): + return {k: _make_json_serializable(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_make_json_serializable(v) for v in obj] + return obj + +from modules.datamodels.datamodelUam import User +from modules.features.automation2.datamodelFeatureAutomation2 import ( + Automation2Workflow, + Automation2WorkflowRun, + Automation2HumanTask, +) +from modules.connectors.connectorDbPostgre import DatabaseConnector +from modules.shared.configuration import APP_CONFIG + +logger = logging.getLogger(__name__) + + +def getAutomation2Interface( + currentUser: User, + mandateId: str, + featureInstanceId: str, +) -> "Automation2Objects": + """Factory for Automation2 interface with user context.""" + return Automation2Objects( + currentUser=currentUser, + mandateId=mandateId, + featureInstanceId=featureInstanceId, + ) + + +class Automation2Objects: + """Interface for Automation2 database operations.""" + + def __init__( + self, + currentUser: User, + mandateId: str, + featureInstanceId: str, + ): + self.currentUser = currentUser + self.mandateId = mandateId + self.featureInstanceId = featureInstanceId + self.userId = currentUser.id if currentUser else None + self._init_db() + if hasattr(self.db, "updateContext") and self.userId: + self.db.updateContext(self.userId) + + def _init_db(self): + """Initialize database connection to poweron_automation2.""" + dbHost = APP_CONFIG.get("DB_HOST", "localhost") + dbDatabase = "poweron_automation2" + dbUser = APP_CONFIG.get("DB_USER") + dbPassword = APP_CONFIG.get("DB_PASSWORD_SECRET") or APP_CONFIG.get("DB_PASSWORD") + dbPort = int(APP_CONFIG.get("DB_PORT", 5432)) + self.db = DatabaseConnector( + dbHost=dbHost, + dbDatabase=dbDatabase, + dbUser=dbUser, + dbPassword=dbPassword, + dbPort=dbPort, + userId=self.userId, + ) + logger.debug("Automation2 database initialized for user %s", self.userId) + + # ------------------------------------------------------------------------- + # Workflow CRUD + # ------------------------------------------------------------------------- + + def getWorkflows(self) -> List[Dict[str, Any]]: + """Get all workflows for this mandate and feature instance.""" + if not self.db._ensureTableExists(Automation2Workflow): + return [] + records = self.db.getRecordset( + Automation2Workflow, + recordFilter={ + "mandateId": self.mandateId, + "featureInstanceId": self.featureInstanceId, + }, + ) + return [dict(r) for r in records] if records else [] + + def getWorkflow(self, workflowId: str) -> Optional[Dict[str, Any]]: + """Get a single workflow by ID.""" + if not self.db._ensureTableExists(Automation2Workflow): + return None + records = self.db.getRecordset( + Automation2Workflow, + recordFilter={ + "id": workflowId, + "mandateId": self.mandateId, + "featureInstanceId": self.featureInstanceId, + }, + ) + if not records: + return None + return dict(records[0]) + + def createWorkflow(self, data: Dict[str, Any]) -> Dict[str, Any]: + """Create a new workflow.""" + if "id" not in data or not data.get("id"): + data["id"] = str(uuid.uuid4()) + data["mandateId"] = self.mandateId + data["featureInstanceId"] = self.featureInstanceId + created = self.db.recordCreate(Automation2Workflow, data) + return dict(created) + + def updateWorkflow(self, workflowId: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Update an existing workflow.""" + existing = self.getWorkflow(workflowId) + if not existing: + return None + # Don't overwrite mandateId/featureInstanceId + data.pop("mandateId", None) + data.pop("featureInstanceId", None) + updated = self.db.recordModify(Automation2Workflow, workflowId, data) + return dict(updated) + + def deleteWorkflow(self, workflowId: str) -> bool: + """Delete a workflow.""" + existing = self.getWorkflow(workflowId) + if not existing: + return False + self.db.recordDelete(Automation2Workflow, workflowId) + return True + + # ------------------------------------------------------------------------- + # Workflow Runs + # ------------------------------------------------------------------------- + + def createRun(self, workflowId: str, nodeOutputs: Dict = None, context: Dict = None) -> Dict[str, Any]: + """Create a new workflow run.""" + data = { + "id": str(uuid.uuid4()), + "workflowId": workflowId, + "status": "running", + "nodeOutputs": _make_json_serializable(nodeOutputs or {}), + "currentNodeId": None, + "context": context or {}, + } + created = self.db.recordCreate(Automation2WorkflowRun, data) + return dict(created) + + def getRun(self, runId: str) -> Optional[Dict[str, Any]]: + """Get a run by ID.""" + if not self.db._ensureTableExists(Automation2WorkflowRun): + return None + records = self.db.getRecordset( + Automation2WorkflowRun, + recordFilter={"id": runId}, + ) + if not records: + return None + return dict(records[0]) + + def updateRun( + self, + runId: str, + status: str = None, + nodeOutputs: Dict = None, + currentNodeId: str = None, + context: Dict = None, + ) -> Optional[Dict[str, Any]]: + """Update a run.""" + run = self.getRun(runId) + if not run: + return None + updates = {} + if status is not None: + updates["status"] = status + if nodeOutputs is not None: + updates["nodeOutputs"] = _make_json_serializable(nodeOutputs) + if currentNodeId is not None: + updates["currentNodeId"] = currentNodeId + if context is not None: + updates["context"] = context + if not updates: + return run + updated = self.db.recordModify(Automation2WorkflowRun, runId, updates) + return dict(updated) + + def getRunsByWorkflow(self, workflowId: str) -> List[Dict[str, Any]]: + """Get all runs for a workflow.""" + if not self.db._ensureTableExists(Automation2WorkflowRun): + return [] + records = self.db.getRecordset( + Automation2WorkflowRun, + recordFilter={"workflowId": workflowId}, + ) + return [dict(r) for r in records] if records else [] + + def getRunsWaitingForEmail(self) -> List[Dict[str, Any]]: + """Get all paused runs waiting for a new email (for background poller).""" + if not self.db._ensureTableExists(Automation2WorkflowRun): + return [] + records = self.db.getRecordset( + Automation2WorkflowRun, + recordFilter={"status": "paused"}, + ) + if not records: + return [] + result = [] + for r in records: + rec = dict(r) + ctx = rec.get("context") or {} + if ctx.get("waitReason") == "email": + result.append(rec) + return result + + # ------------------------------------------------------------------------- + # Human Tasks + # ------------------------------------------------------------------------- + + def createTask( + self, + runId: str, + workflowId: str, + nodeId: str, + nodeType: str, + config: Dict[str, Any], + assigneeId: str = None, + ) -> Dict[str, Any]: + """Create a human task.""" + data = { + "id": str(uuid.uuid4()), + "runId": runId, + "workflowId": workflowId, + "nodeId": nodeId, + "nodeType": nodeType, + "config": config, + "assigneeId": assigneeId, + "status": "pending", + "result": None, + } + created = self.db.recordCreate(Automation2HumanTask, data) + return dict(created) + + def getTask(self, taskId: str) -> Optional[Dict[str, Any]]: + """Get a task by ID.""" + if not self.db._ensureTableExists(Automation2HumanTask): + return None + records = self.db.getRecordset( + Automation2HumanTask, + recordFilter={"id": taskId}, + ) + if not records: + return None + return dict(records[0]) + + def updateTask(self, taskId: str, status: str = None, result: Dict = None) -> Optional[Dict[str, Any]]: + """Update a task (e.g. complete with result).""" + task = self.getTask(taskId) + if not task: + return None + updates = {} + if status is not None: + updates["status"] = status + if result is not None: + updates["result"] = result + if not updates: + return task + updated = self.db.recordModify(Automation2HumanTask, taskId, updates) + return dict(updated) + + def getTasks( + self, + workflowId: str = None, + runId: str = None, + status: str = None, + assigneeId: str = None, + ) -> List[Dict[str, Any]]: + """Get tasks with optional filters. AssigneeId filters to that user; None returns all.""" + if not self.db._ensureTableExists(Automation2HumanTask): + return [] + rf = {} + if workflowId: + rf["workflowId"] = workflowId + if runId: + rf["runId"] = runId + if status: + rf["status"] = status + if assigneeId: + rf["assigneeId"] = assigneeId + records = self.db.getRecordset( + Automation2HumanTask, + recordFilter=rf if rf else None, + ) + items = [dict(r) for r in records] if records else [] + workflows = {w["id"]: w for w in self.getWorkflows()} + filtered = [t for t in items if t.get("workflowId") in workflows] + return filtered diff --git a/modules/features/automation2/mainAutomation2.py b/modules/features/automation2/mainAutomation2.py new file mode 100644 index 00000000..9ec97eca --- /dev/null +++ b/modules/features/automation2/mainAutomation2.py @@ -0,0 +1,303 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +""" +Automation2 Feature - n8n-style flow automation. +Minimal bootstrap for feature instance creation. Build from here. +""" + +import logging +from typing import Dict, List, Any, Optional + +logger = logging.getLogger(__name__) + +FEATURE_CODE = "automation2" + +# Services required for automation2 (methodDiscovery, ActionExecutor, etc.) +REQUIRED_SERVICES = [ + {"serviceKey": "chat", "meta": {"usage": "Interfaces, RBAC"}}, + {"serviceKey": "utils", "meta": {"usage": "Timestamps, utilities"}}, + {"serviceKey": "ai", "meta": {"usage": "AI nodes"}}, + {"serviceKey": "extraction", "meta": {"usage": "Workflow method actions"}}, + {"serviceKey": "sharepoint", "meta": {"usage": "SharePoint actions"}}, +] +FEATURE_LABEL = {"en": "Automation 2", "de": "Automatisierung 2", "fr": "Automatisation 2"} +FEATURE_ICON = "mdi-sitemap" + +UI_OBJECTS = [ + { + "objectKey": "ui.feature.automation2.editor", + "label": {"en": "Editor", "de": "Editor", "fr": "Éditeur"}, + "meta": {"area": "editor"} + }, + { + "objectKey": "ui.feature.automation2.workflows", + "label": {"en": "Workflows", "de": "Workflows", "fr": "Workflows"}, + "meta": {"area": "workflows"} + }, + { + "objectKey": "ui.feature.automation2.workflows-tasks", + "label": {"en": "Tasks", "de": "Tasks", "fr": "Tâches"}, + "meta": {"area": "tasks"} + }, +] + +RESOURCE_OBJECTS = [ + { + "objectKey": "resource.feature.automation2.dashboard", + "label": {"en": "Access Dashboard", "de": "Dashboard aufrufen", "fr": "Acceder au tableau de bord"}, + "meta": {"endpoint": "/api/automation2/{instanceId}/info", "method": "GET"} + }, + { + "objectKey": "resource.feature.automation2.node-types", + "label": {"en": "Get Node Types", "de": "Node-Typen abrufen", "fr": "Obtenir types de nœuds"}, + "meta": {"endpoint": "/api/automation2/{instanceId}/node-types", "method": "GET"} + }, + { + "objectKey": "resource.feature.automation2.execute", + "label": {"en": "Execute Workflow", "de": "Workflow ausführen", "fr": "Exécuter le workflow"}, + "meta": {"endpoint": "/api/automation2/{instanceId}/execute", "method": "POST"} + }, +] + +TEMPLATE_ROLES = [ + { + "roleLabel": "automation2-user", + "description": { + "en": "Automation2 User - Use automation2 flow builder", + "de": "Automation2 Benutzer - Flow-Builder nutzen", + "fr": "Utilisateur Automation2 - Utiliser le flow builder" + }, + "accessRules": [ + {"context": "UI", "item": "ui.feature.automation2.editor", "view": True}, + {"context": "UI", "item": "ui.feature.automation2.workflows", "view": True}, + {"context": "UI", "item": "ui.feature.automation2.workflows-tasks", "view": True}, + {"context": "RESOURCE", "item": "resource.feature.automation2.dashboard", "view": True}, + {"context": "RESOURCE", "item": "resource.feature.automation2.node-types", "view": True}, + {"context": "RESOURCE", "item": "resource.feature.automation2.execute", "view": True}, + {"context": "DATA", "item": None, "view": True, "read": "m", "create": "m", "update": "m", "delete": "m"}, + ] + }, +] + + +def getRequiredServiceKeys() -> List[str]: + """Return list of service keys this feature requires.""" + return [s["serviceKey"] for s in REQUIRED_SERVICES] + + +def getAutomation2Services( + user, + mandateId: Optional[str] = None, + featureInstanceId: Optional[str] = None, + workflow=None, +) -> "_Automation2ServiceHub": + """ + Get a service hub for automation2 using the service center. + Used for methodDiscovery (I/O nodes) and execution (ActionExecutor). + """ + from modules.serviceCenter import getService + from modules.serviceCenter.context import ServiceCenterContext + + _workflow = workflow + if _workflow is None: + _workflow = type( + "_Placeholder", + (), + {"featureCode": FEATURE_CODE, "id": None, "workflowMode": None, "messages": []}, + )() + + ctx = ServiceCenterContext( + user=user, + mandate_id=mandateId, + feature_instance_id=featureInstanceId, + workflow=_workflow, + ) + + hub = _Automation2ServiceHub() + hub.user = user + hub.mandateId = mandateId + hub.featureInstanceId = featureInstanceId + hub._service_context = ctx + hub.workflow = workflow + hub.featureCode = FEATURE_CODE + + for spec in REQUIRED_SERVICES: + key = spec["serviceKey"] + try: + svc = getService(key, ctx) + setattr(hub, key, svc) + except Exception as e: + logger.warning(f"Could not resolve service '{key}' for automation2: {e}") + setattr(hub, key, None) + + if hub.chat: + hub.interfaceDbApp = getattr(hub.chat, "interfaceDbApp", None) + hub.interfaceDbComponent = getattr(hub.chat, "interfaceDbComponent", None) + hub.interfaceDbChat = getattr(hub.chat, "interfaceDbChat", None) + hub.rbac = getattr(hub.interfaceDbApp, "rbac", None) if getattr(hub, "interfaceDbApp", None) else None + + return hub + + +class _Automation2ServiceHub: + """Lightweight hub for automation2 (methodDiscovery, execution).""" + + user = None + mandateId = None + featureInstanceId = None + _service_context = None + workflow = None + featureCode = FEATURE_CODE + interfaceDbApp = None + interfaceDbComponent = None + interfaceDbChat = None + rbac = None + chat = None + ai = None + utils = None + extraction = None + sharepoint = None + + +async def onStart(eventUser) -> None: + """Feature startup. Email poller is started on-demand when a run pauses for email.checkEmail.""" + + +async def onStop(eventUser) -> None: + """Feature shutdown - remove email poller if running.""" + from modules.features.automation2.emailPoller import stop as stopEmailPoller + stopEmailPoller(eventUser) + + +def getFeatureDefinition() -> Dict[str, Any]: + """Return the feature definition for registration.""" + return { + "code": FEATURE_CODE, + "label": FEATURE_LABEL, + "icon": FEATURE_ICON, + "autoCreateInstance": True, + } + + +def getUiObjects() -> List[Dict[str, Any]]: + """Return UI objects for RBAC catalog registration.""" + return UI_OBJECTS + + +def getResourceObjects() -> List[Dict[str, Any]]: + """Return resource objects for RBAC catalog registration.""" + return RESOURCE_OBJECTS + + +def getTemplateRoles() -> List[Dict[str, Any]]: + """Return template roles for this feature.""" + return TEMPLATE_ROLES + + +def registerFeature(catalogService) -> bool: + """Register this feature's RBAC objects in the catalog.""" + try: + for uiObj in UI_OBJECTS: + catalogService.registerUiObject( + featureCode=FEATURE_CODE, + objectKey=uiObj["objectKey"], + label=uiObj["label"], + meta=uiObj.get("meta") + ) + for resObj in RESOURCE_OBJECTS: + catalogService.registerResourceObject( + featureCode=FEATURE_CODE, + objectKey=resObj["objectKey"], + label=resObj["label"], + meta=resObj.get("meta") + ) + _syncTemplateRolesToDb() + logger.info(f"Feature '{FEATURE_CODE}' registered {len(UI_OBJECTS)} UI objects and {len(RESOURCE_OBJECTS)} resource objects") + return True + except Exception as e: + logger.error(f"Failed to register feature '{FEATURE_CODE}': {e}") + return False + + +def _syncTemplateRolesToDb() -> int: + """Sync template roles and their AccessRules to database. + Also syncs rules to mandate-specific roles (same roleLabel) so new UI objects + become visible after gateway restart without manual role update. + """ + try: + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.datamodels.datamodelRbac import Role + + rootInterface = getRootInterface() + existingRoles = rootInterface.getRolesByFeatureCode(FEATURE_CODE) + existingLabels = {r.roleLabel: str(r.id) for r in existingRoles if r.mandateId is None} + created = 0 + + for template in TEMPLATE_ROLES: + roleLabel = template["roleLabel"] + if roleLabel in existingLabels: + roleId = existingLabels[roleLabel] + else: + newRole = Role( + roleLabel=roleLabel, + description=template.get("description", {}), + featureCode=FEATURE_CODE, + mandateId=None, + featureInstanceId=None, + isSystemRole=False + ) + rec = rootInterface.db.recordCreate(Role, newRole.model_dump()) + roleId = rec.get("id") + created += 1 + logger.info(f"Created template role '{roleLabel}' for {FEATURE_CODE}") + + _ensureAccessRulesForRole(rootInterface, roleId, template.get("accessRules", [])) + + # Sync same rules to mandate-specific roles (so Workflows & Tasks etc. appear in sidebar) + for r in existingRoles: + if r.mandateId and r.roleLabel == roleLabel: + added = _ensureAccessRulesForRole( + rootInterface, str(r.id), template.get("accessRules", []) + ) + if added: + logger.debug(f"Added {added} access rules to mandate role {r.id}") + return created + except Exception as e: + logger.warning(f"Template role sync for {FEATURE_CODE}: {e}") + return 0 + + +def _ensureAccessRulesForRole(rootInterface, roleId: str, ruleTemplates: List[Dict[str, Any]]) -> int: + """Ensure AccessRules exist for a role based on templates.""" + from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext + + existingRules = rootInterface.getAccessRulesByRole(roleId) + existingSignatures = { + (r.context.value if r.context else None, r.item) + for r in existingRules + } + created = 0 + for t in ruleTemplates: + context = t.get("context", "UI") + item = t.get("item") + sig = (context, item) + if sig in existingSignatures: + continue + ctx_enum = ( + AccessRuleContext.UI if context == "UI" else + AccessRuleContext.DATA if context == "DATA" else + AccessRuleContext.RESOURCE if context == "RESOURCE" else context + ) + newRule = AccessRule( + roleId=roleId, + context=ctx_enum, + item=item, + view=t.get("view", False), + read=t.get("read"), + create=t.get("create"), + update=t.get("update"), + delete=t.get("delete"), + ) + rootInterface.db.recordCreate(AccessRule, newRule.model_dump()) + created += 1 + return created diff --git a/modules/features/automation2/nodeDefinitions/__init__.py b/modules/features/automation2/nodeDefinitions/__init__.py new file mode 100644 index 00000000..61eec51a --- /dev/null +++ b/modules/features/automation2/nodeDefinitions/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2025 Patrick Motsch +# Node type definitions for automation2 flow builder. + +from .triggers import TRIGGER_NODES +from .flow import FLOW_NODES +from .data import DATA_NODES +from .input import INPUT_NODES +from .ai import AI_NODES +from .email import EMAIL_NODES +from .sharepoint import SHAREPOINT_NODES + +STATIC_NODE_TYPES = ( + TRIGGER_NODES + + FLOW_NODES + + DATA_NODES + + INPUT_NODES + + AI_NODES + + EMAIL_NODES + + SHAREPOINT_NODES +) diff --git a/modules/features/automation2/nodeDefinitions/ai.py b/modules/features/automation2/nodeDefinitions/ai.py new file mode 100644 index 00000000..4fdf0db9 --- /dev/null +++ b/modules/features/automation2/nodeDefinitions/ai.py @@ -0,0 +1,113 @@ +# Copyright (c) 2025 Patrick Motsch +# AI node definitions - map to methodAi actions. + +AI_NODES = [ + { + "id": "ai.prompt", + "category": "ai", + "label": {"en": "Prompt", "de": "Prompt", "fr": "Invite"}, + "description": {"en": "Enter a prompt and AI does something", "de": "Prompt eingeben und KI führt aus", "fr": "Entrer une invite et l'IA exécute"}, + "parameters": [ + {"name": "prompt", "type": "string", "required": True, "description": {"en": "AI prompt", "de": "KI-Prompt", "fr": "Invite IA"}}, + {"name": "resultType", "type": "string", "required": False, "description": {"en": "Output format (txt, json, md, etc.)", "de": "Ausgabeformat", "fr": "Format de sortie"}, "default": "txt"}, + ], + "inputs": 1, + "outputs": 1, + "meta": {"icon": "mdi-robot", "color": "#9C27B0"}, + "_method": "ai", + "_action": "process", + "_paramMap": {"prompt": "aiPrompt"}, + }, + { + "id": "ai.webResearch", + "category": "ai", + "label": {"en": "Web Research", "de": "Web-Recherche", "fr": "Recherche web"}, + "description": {"en": "Research on the web", "de": "Recherche im Web", "fr": "Recherche sur le web"}, + "parameters": [ + {"name": "query", "type": "string", "required": True, "description": {"en": "Research query", "de": "Recherche-Anfrage", "fr": "Requête de recherche"}}, + ], + "inputs": 1, + "outputs": 1, + "meta": {"icon": "mdi-magnify", "color": "#9C27B0"}, + "_method": "ai", + "_action": "webResearch", + "_paramMap": {"query": "prompt"}, + }, + { + "id": "ai.summarizeDocument", + "category": "ai", + "label": {"en": "Summarize Document", "de": "Dokument zusammenfassen", "fr": "Résumer document"}, + "description": {"en": "Summarize document content", "de": "Dokumentinhalt zusammenfassen", "fr": "Résumer le contenu du document"}, + "parameters": [ + {"name": "summaryLength", "type": "string", "required": False, "description": {"en": "Short, medium, or long", "de": "Kurz, mittel oder lang", "fr": "Court, moyen ou long"}, "default": "medium"}, + ], + "inputs": 1, + "outputs": 1, + "meta": {"icon": "mdi-file-document-outline", "color": "#9C27B0"}, + "_method": "ai", + "_action": "summarizeDocument", + "_paramMap": {}, + }, + { + "id": "ai.translateDocument", + "category": "ai", + "label": {"en": "Translate Document", "de": "Dokument übersetzen", "fr": "Traduire document"}, + "description": {"en": "Translate document to target language", "de": "Dokument in Zielsprache übersetzen", "fr": "Traduire le document"}, + "parameters": [ + {"name": "targetLanguage", "type": "string", "required": True, "description": {"en": "Target language (e.g. en, de, fr)", "de": "Zielsprache", "fr": "Langue cible"}}, + ], + "inputs": 1, + "outputs": 1, + "meta": {"icon": "mdi-translate", "color": "#9C27B0"}, + "_method": "ai", + "_action": "translateDocument", + "_paramMap": {"targetLanguage": "targetLanguage"}, + }, + { + "id": "ai.convertDocument", + "category": "ai", + "label": {"en": "Convert Document", "de": "Dokument konvertieren", "fr": "Convertir document"}, + "description": {"en": "Convert document to another format", "de": "Dokument in anderes Format konvertieren", "fr": "Convertir le document"}, + "parameters": [ + {"name": "targetFormat", "type": "string", "required": True, "description": {"en": "Target format (pdf, docx, txt, etc.)", "de": "Zielformat", "fr": "Format cible"}}, + ], + "inputs": 1, + "outputs": 1, + "meta": {"icon": "mdi-file-convert", "color": "#9C27B0"}, + "_method": "ai", + "_action": "convertDocument", + "_paramMap": {"targetFormat": "targetFormat"}, + }, + { + "id": "ai.generateDocument", + "category": "ai", + "label": {"en": "Generate Document", "de": "Dokument generieren", "fr": "Générer document"}, + "description": {"en": "Generate document from prompt", "de": "Dokument aus Prompt generieren", "fr": "Générer un document"}, + "parameters": [ + {"name": "prompt", "type": "string", "required": True, "description": {"en": "Generation prompt", "de": "Generierungs-Prompt", "fr": "Invite de génération"}}, + {"name": "format", "type": "string", "required": False, "description": {"en": "Output format", "de": "Ausgabeformat", "fr": "Format de sortie"}, "default": "docx"}, + ], + "inputs": 1, + "outputs": 1, + "meta": {"icon": "mdi-file-plus", "color": "#9C27B0"}, + "_method": "ai", + "_action": "generateDocument", + "_paramMap": {"prompt": "prompt", "format": "format"}, + }, + { + "id": "ai.generateCode", + "category": "ai", + "label": {"en": "Generate Code", "de": "Code generieren", "fr": "Générer code"}, + "description": {"en": "Generate code from description", "de": "Code aus Beschreibung generieren", "fr": "Générer du code"}, + "parameters": [ + {"name": "prompt", "type": "string", "required": True, "description": {"en": "Code generation prompt", "de": "Code-Generierungs-Prompt", "fr": "Invite de génération de code"}}, + {"name": "language", "type": "string", "required": False, "description": {"en": "Programming language", "de": "Programmiersprache", "fr": "Langage de programmation"}, "default": "python"}, + ], + "inputs": 1, + "outputs": 1, + "meta": {"icon": "mdi-code-tags", "color": "#9C27B0"}, + "_method": "ai", + "_action": "generateCode", + "_paramMap": {"prompt": "prompt", "language": "language"}, + }, +] diff --git a/modules/features/automation2/nodeDefinitions/data.py b/modules/features/automation2/nodeDefinitions/data.py new file mode 100644 index 00000000..b44618d1 --- /dev/null +++ b/modules/features/automation2/nodeDefinitions/data.py @@ -0,0 +1,58 @@ +# Copyright (c) 2025 Patrick Motsch +# Data transformation node definitions. + +DATA_NODES = [ + { + "id": "data.setFields", + "category": "data", + "label": {"en": "Set Fields", "de": "Felder setzen", "fr": "Définir champs"}, + "description": {"en": "Set or override fields on payload", "de": "Felder setzen oder überschreiben", "fr": "Définir ou écraser des champs"}, + "parameters": [ + {"name": "fields", "type": "object", "required": True, "description": {"en": "Key-value pairs", "de": "Schlüssel-Wert-Paare", "fr": "Paires clé-valeur"}}, + ], + "inputs": 1, + "outputs": 1, + "executor": "data", + "meta": {"icon": "mdi-pencil", "color": "#673AB7"}, + }, + { + "id": "data.filter", + "category": "data", + "label": {"en": "Filter", "de": "Filtern", "fr": "Filtrer"}, + "description": {"en": "Filter array by condition", "de": "Array nach Bedingung filtern", "fr": "Filtrer tableau par condition"}, + "parameters": [ + {"name": "condition", "type": "string", "required": True, "description": {"en": "Expression (e.g. item.active == true)", "de": "Bedingung", "fr": "Condition"}}, + {"name": "itemsPath", "type": "string", "required": False, "description": {"en": "Path to array", "de": "Pfad zum Array", "fr": "Chemin vers le tableau"}}, + ], + "inputs": 1, + "outputs": 1, + "executor": "data", + "meta": {"icon": "mdi-filter", "color": "#673AB7"}, + }, + { + "id": "data.parseJson", + "category": "data", + "label": {"en": "Parse JSON", "de": "JSON parsen", "fr": "Parser JSON"}, + "description": {"en": "Parse JSON string to object", "de": "JSON-String in Objekt parsen", "fr": "Parser chaîne JSON en objet"}, + "parameters": [ + {"name": "jsonPath", "type": "string", "required": False, "description": {"en": "Path to JSON string (default: input)", "de": "Pfad zum JSON", "fr": "Chemin vers JSON"}}, + ], + "inputs": 1, + "outputs": 1, + "executor": "data", + "meta": {"icon": "mdi-code-json", "color": "#673AB7"}, + }, + { + "id": "data.template", + "category": "data", + "label": {"en": "Template / Interpolation", "de": "Vorlage / Interpolation", "fr": "Modèle / Interpolation"}, + "description": {"en": "Text with {{placeholder}} substitution", "de": "Text mit {{platzhalter}}-Ersetzung", "fr": "Texte avec substitution {{placeholder}}"}, + "parameters": [ + {"name": "template", "type": "string", "required": True, "description": {"en": "Template (use {{path}} for values)", "de": "Vorlage", "fr": "Modèle"}}, + ], + "inputs": 1, + "outputs": 1, + "executor": "data", + "meta": {"icon": "mdi-format-text", "color": "#673AB7"}, + }, +] diff --git a/modules/features/automation2/nodeDefinitions/email.py b/modules/features/automation2/nodeDefinitions/email.py new file mode 100644 index 00000000..b96a5389 --- /dev/null +++ b/modules/features/automation2/nodeDefinitions/email.py @@ -0,0 +1,70 @@ +# Copyright (c) 2025 Patrick Motsch +# Email node definitions - map to methodOutlook actions. +# Use connectionId from user connections (like AI workspace sources). + +EMAIL_NODES = [ + { + "id": "email.checkEmail", + "category": "email", + "label": {"en": "Check Email", "de": "E-Mail prüfen", "fr": "Vérifier email"}, + "description": {"en": "Check for new emails (general or from specific account)", "de": "Neue E-Mails prüfen", "fr": "Vérifier les nouveaux emails"}, + "parameters": [ + {"name": "connectionId", "type": "string", "required": True, "description": {"en": "Email account connection", "de": "E-Mail-Konto Verbindung", "fr": "Connexion compte email"}}, + {"name": "folder", "type": "string", "required": False, "description": {"en": "Folder (e.g. Inbox)", "de": "Ordner (z.B. Posteingang)", "fr": "Dossier (ex. Boîte de réception)"}, "default": "Inbox"}, + {"name": "limit", "type": "number", "required": False, "description": {"en": "Max emails to fetch", "de": "Max E-Mails", "fr": "Max emails"}, "default": 100}, + {"name": "fromAddress", "type": "string", "required": False, "description": {"en": "Only emails from this address", "de": "Nur E-Mails von dieser Adresse", "fr": "Seulement les e-mails de cette adresse"}, "default": ""}, + {"name": "subjectContains", "type": "string", "required": False, "description": {"en": "Subject must contain this text", "de": "Betreff muss diesen Text enthalten", "fr": "Le sujet doit contenir ce texte"}, "default": ""}, + {"name": "hasAttachment", "type": "boolean", "required": False, "description": {"en": "Only emails with attachments", "de": "Nur E-Mails mit Anhängen", "fr": "Seulement les e-mails avec pièces jointes"}, "default": False}, + {"name": "filter", "type": "string", "required": False, "description": {"en": "Advanced: raw filter (overrides above if set)", "de": "Erweitert: Filter-Text (überschreibt obige)", "fr": "Avancé: filtre brut"}, "default": ""}, + ], + "inputs": 1, + "outputs": 1, + "meta": {"icon": "mdi-email-check", "color": "#1976D2"}, + "_method": "outlook", + "_action": "readEmails", + "_paramMap": {"connectionId": "connectionReference", "folder": "folder", "limit": "limit", "filter": "filter"}, + }, + { + "id": "email.searchEmail", + "category": "email", + "label": {"en": "Search Email", "de": "E-Mail suchen", "fr": "Rechercher email"}, + "description": {"en": "Search or find emails", "de": "E-Mails suchen oder finden", "fr": "Rechercher des emails"}, + "parameters": [ + {"name": "connectionId", "type": "string", "required": True, "description": {"en": "Email account connection", "de": "E-Mail-Konto Verbindung", "fr": "Connexion compte email"}}, + {"name": "query", "type": "string", "required": False, "description": {"en": "General search term (searches subject, body, from)", "de": "Suchbegriff (durchsucht Betreff, Inhalt, Absender)", "fr": "Terme de recherche (sujet, corps, expéditeur)"}, "default": ""}, + {"name": "folder", "type": "string", "required": False, "description": {"en": "Folder to search", "de": "Ordner zum Suchen", "fr": "Dossier à rechercher"}, "default": "Inbox"}, + {"name": "limit", "type": "number", "required": False, "description": {"en": "Max emails to return", "de": "Max E-Mails", "fr": "Max emails"}, "default": 100}, + {"name": "fromAddress", "type": "string", "required": False, "description": {"en": "Only emails from this address", "de": "Nur E-Mails von dieser Adresse", "fr": "Seulement les e-mails de cette adresse"}, "default": ""}, + {"name": "toAddress", "type": "string", "required": False, "description": {"en": "Only emails to this recipient", "de": "Nur E-Mails an diesen Empfänger", "fr": "Seulement les e-mails à ce destinataire"}, "default": ""}, + {"name": "subjectContains", "type": "string", "required": False, "description": {"en": "Subject must contain this text", "de": "Betreff muss diesen Text enthalten", "fr": "Le sujet doit contenir ce texte"}, "default": ""}, + {"name": "bodyContains", "type": "string", "required": False, "description": {"en": "Body/content must contain this text", "de": "Inhalt muss diesen Text enthalten", "fr": "Le corps doit contenir ce texte"}, "default": ""}, + {"name": "hasAttachment", "type": "boolean", "required": False, "description": {"en": "Only emails with attachments", "de": "Nur E-Mails mit Anhängen", "fr": "Seulement les e-mails avec pièces jointes"}, "default": False}, + {"name": "filter", "type": "string", "required": False, "description": {"en": "Advanced: raw KQL (overrides above if set)", "de": "Erweitert: KQL-Filter (überschreibt obige)", "fr": "Avancé: filtre KQL brut"}, "default": ""}, + ], + "inputs": 1, + "outputs": 1, + "meta": {"icon": "mdi-email-search", "color": "#1976D2"}, + "_method": "outlook", + "_action": "searchEmails", + "_paramMap": {"connectionId": "connectionReference", "query": "query", "folder": "folder", "limit": "limit", "filter": "filter"}, + }, + { + "id": "email.draftEmail", + "category": "email", + "label": {"en": "Draft Email", "de": "E-Mail entwerfen", "fr": "Brouillon email"}, + "description": {"en": "Create a draft email", "de": "E-Mail-Entwurf erstellen", "fr": "Créer un brouillon d'email"}, + "parameters": [ + {"name": "connectionId", "type": "string", "required": True, "description": {"en": "Email account connection", "de": "E-Mail-Konto Verbindung", "fr": "Connexion compte email"}}, + {"name": "subject", "type": "string", "required": True, "description": {"en": "Email subject", "de": "E-Mail-Betreff", "fr": "Sujet"}}, + {"name": "body", "type": "string", "required": True, "description": {"en": "Email body", "de": "E-Mail-Text", "fr": "Corps de l'email"}}, + {"name": "to", "type": "string", "required": False, "description": {"en": "Recipient(s)", "de": "Empfänger", "fr": "Destinataire(s)"}, "default": ""}, + ], + "inputs": 1, + "outputs": 1, + "meta": {"icon": "mdi-email-edit", "color": "#1976D2"}, + "_method": "outlook", + "_action": "composeAndDraftEmailWithContext", + "_paramMap": {"connectionId": "connectionReference", "to": "to"}, + "_contextFrom": ["subject", "body"], + }, +] diff --git a/modules/features/automation2/nodeDefinitions/flow.py b/modules/features/automation2/nodeDefinitions/flow.py new file mode 100644 index 00000000..573a83ad --- /dev/null +++ b/modules/features/automation2/nodeDefinitions/flow.py @@ -0,0 +1,82 @@ +# Copyright (c) 2025 Patrick Motsch +# Flow control node definitions. + +FLOW_NODES = [ + { + "id": "flow.ifElse", + "category": "flow", + "label": {"en": "If / Else", "de": "Wenn / Sonst", "fr": "Si / Sinon"}, + "description": {"en": "Branch based on condition", "de": "Verzweigung nach Bedingung", "fr": "Branche selon condition"}, + "parameters": [ + {"name": "condition", "type": "string", "required": True, "description": {"en": "Expression to evaluate (e.g. {{value}} > 0)", "de": "Bedingung", "fr": "Condition"}}, + ], + "inputs": 1, + "outputs": 2, + "executor": "flow", + "meta": {"icon": "mdi-source-branch", "color": "#FF9800"}, + }, + { + "id": "flow.switch", + "category": "flow", + "label": {"en": "Switch", "de": "Switch", "fr": "Switch"}, + "description": {"en": "Multiple branches based on value", "de": "Mehrere Zweige nach Wert", "fr": "Branches multiples selon valeur"}, + "parameters": [ + {"name": "value", "type": "string", "required": True, "description": {"en": "Value to match", "de": "Zu vergleichender Wert", "fr": "Valeur à comparer"}}, + {"name": "cases", "type": "array", "required": False, "description": {"en": "List of cases", "de": "Fälle", "fr": "Cas"}}, + ], + "inputs": 1, + "outputs": 1, + "executor": "flow", + "meta": {"icon": "mdi-swap-horizontal", "color": "#FF9800"}, + }, + { + "id": "flow.merge", + "category": "flow", + "label": {"en": "Merge", "de": "Zusammenführen", "fr": "Fusionner"}, + "description": {"en": "Merge multiple inputs", "de": "Mehrere Eingaben zusammenführen", "fr": "Fusionner plusieurs entrées"}, + "parameters": [ + {"name": "mode", "type": "string", "required": False, "description": {"en": "append | combine", "de": "Modus", "fr": "Mode"}}, + ], + "inputs": 2, + "outputs": 1, + "executor": "flow", + "meta": {"icon": "mdi-merge", "color": "#FF9800"}, + }, + { + "id": "flow.loop", + "category": "flow", + "label": {"en": "Loop / For Each", "de": "Schleife / Für Jedes", "fr": "Boucle / Pour Chaque"}, + "description": {"en": "Iterate over array items", "de": "Über Array-Elemente iterieren", "fr": "Itérer sur les éléments"}, + "parameters": [ + {"name": "items", "type": "string", "required": True, "description": {"en": "Path to array (e.g. {{input.items}})", "de": "Pfad zum Array", "fr": "Chemin vers le tableau"}}, + ], + "inputs": 1, + "outputs": 1, + "executor": "flow", + "meta": {"icon": "mdi-repeat", "color": "#FF9800"}, + }, + { + "id": "flow.wait", + "category": "flow", + "label": {"en": "Wait / Delay", "de": "Warten / Verzögerung", "fr": "Attendre / Délai"}, + "description": {"en": "Pause for duration", "de": "Pause für Dauer", "fr": "Pause pour durée"}, + "parameters": [ + {"name": "seconds", "type": "number", "required": True, "description": {"en": "Seconds to wait", "de": "Sekunden", "fr": "Secondes"}}, + ], + "inputs": 1, + "outputs": 1, + "executor": "flow", + "meta": {"icon": "mdi-timer", "color": "#FF9800"}, + }, + { + "id": "flow.stop", + "category": "flow", + "label": {"en": "Stop / Terminate", "de": "Stopp / Beenden", "fr": "Arrêter / Terminer"}, + "description": {"en": "Stop workflow execution", "de": "Workflow-Ausführung beenden", "fr": "Arrêter l'exécution"}, + "parameters": [], + "inputs": 1, + "outputs": 0, + "executor": "flow", + "meta": {"icon": "mdi-stop", "color": "#F44336"}, + }, +] diff --git a/modules/features/automation2/nodeDefinitions/input.py b/modules/features/automation2/nodeDefinitions/input.py new file mode 100644 index 00000000..8eb43e63 --- /dev/null +++ b/modules/features/automation2/nodeDefinitions/input.py @@ -0,0 +1,117 @@ +# Copyright (c) 2025 Patrick Motsch +# Input/Human node definitions - nodes that require user action. + +INPUT_NODES = [ + { + "id": "input.form", + "category": "input", + "label": {"en": "Form", "de": "Formular", "fr": "Formulaire"}, + "description": {"en": "User fills out a form", "de": "Benutzer füllt ein Formular aus", "fr": "L'utilisateur remplit un formulaire"}, + "parameters": [ + { + "name": "fields", + "type": "json", + "required": True, + "description": {"en": "Form fields: [{name, type, label, required, options?}]", "de": "Formularfelder", "fr": "Champs du formulaire"}, + "default": [], + }, + ], + "inputs": 1, + "outputs": 1, + "executor": "input", + "meta": {"icon": "mdi-form-textbox", "color": "#9C27B0"}, + }, + { + "id": "input.approval", + "category": "input", + "label": {"en": "Approval", "de": "Genehmigung", "fr": "Approbation"}, + "description": {"en": "User approves or rejects", "de": "Benutzer genehmigt oder lehnt ab", "fr": "L'utilisateur approuve ou rejette"}, + "parameters": [ + {"name": "title", "type": "string", "required": True, "description": {"en": "Approval title", "de": "Genehmigungstitel", "fr": "Titre"}}, + {"name": "description", "type": "string", "required": False, "description": {"en": "What to approve", "de": "Was genehmigt werden soll", "fr": "Ce qu'il faut approuver"}}, + {"name": "approvalType", "type": "string", "required": False, "description": {"en": "Type: document or generic", "de": "Typ: document oder generic", "fr": "Type: document ou generic"}, "default": "generic"}, + ], + "inputs": 1, + "outputs": 1, + "executor": "input", + "meta": {"icon": "mdi-check-decagram", "color": "#4CAF50"}, + }, + { + "id": "input.upload", + "category": "input", + "label": {"en": "Upload", "de": "Upload", "fr": "Téléversement"}, + "description": {"en": "User uploads file(s)", "de": "Benutzer lädt Datei(en) hoch", "fr": "L'utilisateur téléverse des fichiers"}, + "parameters": [ + {"name": "accept", "type": "string", "required": False, "description": {"en": "MIME types (e.g. .pdf,image/*)", "de": "MIME-Typen", "fr": "Types MIME"}, "default": ""}, + {"name": "maxSize", "type": "number", "required": False, "description": {"en": "Max file size in MB", "de": "Max. Dateigröße in MB", "fr": "Taille max en Mo"}, "default": 10}, + {"name": "multiple", "type": "boolean", "required": False, "description": {"en": "Allow multiple files", "de": "Mehrere Dateien erlauben", "fr": "Autoriser plusieurs fichiers"}, "default": False}, + ], + "inputs": 1, + "outputs": 1, + "executor": "input", + "meta": {"icon": "mdi-upload", "color": "#2196F3"}, + }, + { + "id": "input.comment", + "category": "input", + "label": {"en": "Comment", "de": "Kommentar", "fr": "Commentaire"}, + "description": {"en": "User adds a comment", "de": "Benutzer fügt einen Kommentar hinzu", "fr": "L'utilisateur ajoute un commentaire"}, + "parameters": [ + {"name": "placeholder", "type": "string", "required": False, "description": {"en": "Placeholder text", "de": "Platzhalter", "fr": "Texte indicatif"}, "default": ""}, + {"name": "required", "type": "boolean", "required": False, "description": {"en": "Comment required", "de": "Kommentar erforderlich", "fr": "Commentaire requis"}, "default": True}, + ], + "inputs": 1, + "outputs": 1, + "executor": "input", + "meta": {"icon": "mdi-comment-text", "color": "#FF9800"}, + }, + { + "id": "input.review", + "category": "input", + "label": {"en": "Review", "de": "Prüfung", "fr": "Revue"}, + "description": {"en": "User reviews content", "de": "Benutzer prüft Inhalt", "fr": "L'utilisateur révise le contenu"}, + "parameters": [ + {"name": "contentRef", "type": "string", "required": True, "description": {"en": "Reference to content (e.g. {{nodeId.field}})", "de": "Referenz auf Inhalt", "fr": "Référence au contenu"}}, + {"name": "reviewType", "type": "string", "required": False, "description": {"en": "Type of review", "de": "Art der Prüfung", "fr": "Type de revue"}, "default": "generic"}, + ], + "inputs": 1, + "outputs": 1, + "executor": "input", + "meta": {"icon": "mdi-magnify-scan", "color": "#673AB7"}, + }, + { + "id": "input.selection", + "category": "input", + "label": {"en": "Selection", "de": "Auswahl", "fr": "Sélection"}, + "description": {"en": "User selects from options", "de": "Benutzer wählt aus Optionen", "fr": "L'utilisateur choisit parmi les options"}, + "parameters": [ + { + "name": "options", + "type": "json", + "required": True, + "description": {"en": "Options: [{value, label}]", "de": "Optionen", "fr": "Options"}, + "default": [], + }, + {"name": "multiple", "type": "boolean", "required": False, "description": {"en": "Allow multiple selection", "de": "Mehrfachauswahl erlauben", "fr": "Sélection multiple"}, "default": False}, + ], + "inputs": 1, + "outputs": 1, + "executor": "input", + "meta": {"icon": "mdi-format-list-checks", "color": "#009688"}, + }, + { + "id": "input.confirmation", + "category": "input", + "label": {"en": "Confirmation", "de": "Bestätigung", "fr": "Confirmation"}, + "description": {"en": "User confirms yes/no", "de": "Benutzer bestätigt Ja/Nein", "fr": "L'utilisateur confirme oui/non"}, + "parameters": [ + {"name": "question", "type": "string", "required": True, "description": {"en": "Question to confirm", "de": "Zu bestätigende Frage", "fr": "Question à confirmer"}}, + {"name": "confirmLabel", "type": "string", "required": False, "description": {"en": "Label for confirm button", "de": "Label für Bestätigen-Button", "fr": "Libellé du bouton confirmer"}, "default": "Confirm"}, + {"name": "rejectLabel", "type": "string", "required": False, "description": {"en": "Label for reject button", "de": "Label für Ablehnen-Button", "fr": "Libellé du bouton refuser"}, "default": "Reject"}, + ], + "inputs": 1, + "outputs": 1, + "executor": "input", + "meta": {"icon": "mdi-checkbox-marked-circle", "color": "#8BC34A"}, + }, +] diff --git a/modules/features/automation2/nodeDefinitions/sharepoint.py b/modules/features/automation2/nodeDefinitions/sharepoint.py new file mode 100644 index 00000000..f0dd30cf --- /dev/null +++ b/modules/features/automation2/nodeDefinitions/sharepoint.py @@ -0,0 +1,105 @@ +# Copyright (c) 2025 Patrick Motsch +# SharePoint node definitions - map to methodSharepoint actions. +# Use connectionId and path from connection selector (like workflow folder view). + +SHAREPOINT_NODES = [ + { + "id": "sharepoint.findFile", + "category": "sharepoint", + "label": {"en": "Find File", "de": "Datei finden", "fr": "Trouver fichier"}, + "description": {"en": "Find file by path or search", "de": "Datei nach Pfad oder Suche finden", "fr": "Trouver fichier par chemin ou recherche"}, + "parameters": [ + {"name": "connectionId", "type": "string", "required": True, "description": {"en": "SharePoint connection", "de": "SharePoint-Verbindung", "fr": "Connexion SharePoint"}}, + {"name": "searchQuery", "type": "string", "required": True, "description": {"en": "Search query or path", "de": "Suchanfrage oder Pfad", "fr": "Requête ou chemin"}}, + {"name": "site", "type": "string", "required": False, "description": {"en": "Optional site hint", "de": "Optionaler Site-Hinweis", "fr": "Indication de site"}, "default": ""}, + {"name": "maxResults", "type": "number", "required": False, "description": {"en": "Max results", "de": "Max Ergebnisse", "fr": "Max résultats"}, "default": 1000}, + ], + "inputs": 1, + "outputs": 1, + "meta": {"icon": "mdi-file-search", "color": "#0078D4"}, + "_method": "sharepoint", + "_action": "findDocumentPath", + "_paramMap": {"connectionId": "connectionReference", "searchQuery": "searchQuery", "site": "site", "maxResults": "maxResults"}, + }, + { + "id": "sharepoint.readFile", + "category": "sharepoint", + "label": {"en": "Read File", "de": "Datei lesen", "fr": "Lire fichier"}, + "description": {"en": "Extract content from file", "de": "Inhalt aus Datei extrahieren", "fr": "Extraire le contenu du fichier"}, + "parameters": [ + {"name": "connectionId", "type": "string", "required": True, "description": {"en": "SharePoint connection", "de": "SharePoint-Verbindung", "fr": "Connexion SharePoint"}}, + {"name": "path", "type": "string", "required": True, "description": {"en": "File path or documentList from find file", "de": "Dateipfad oder documentList von Find", "fr": "Chemin ou documentList"}}, + ], + "inputs": 1, + "outputs": 1, + "meta": {"icon": "mdi-file-document", "color": "#0078D4"}, + "_method": "sharepoint", + "_action": "readDocuments", + "_paramMap": {"connectionId": "connectionReference", "path": "pathQuery"}, + }, + { + "id": "sharepoint.uploadFile", + "category": "sharepoint", + "label": {"en": "Upload File", "de": "Datei hochladen", "fr": "Téléverser fichier"}, + "description": {"en": "Upload file to SharePoint", "de": "Datei zu SharePoint hochladen", "fr": "Téléverser fichier vers SharePoint"}, + "parameters": [ + {"name": "connectionId", "type": "string", "required": True, "description": {"en": "SharePoint connection", "de": "SharePoint-Verbindung", "fr": "Connexion SharePoint"}}, + {"name": "path", "type": "string", "required": True, "description": {"en": "Target folder path (e.g. /sites/.../Folder)", "de": "Zielordner-Pfad", "fr": "Chemin du dossier cible"}}, + ], + "inputs": 1, + "outputs": 1, + "meta": {"icon": "mdi-upload", "color": "#0078D4"}, + "_method": "sharepoint", + "_action": "uploadFile", + "_paramMap": {"connectionId": "connectionReference", "path": "pathQuery"}, + }, + { + "id": "sharepoint.listFiles", + "category": "sharepoint", + "label": {"en": "List Files", "de": "Dateien auflisten", "fr": "Lister fichiers"}, + "description": {"en": "List files in folder or SharePoint", "de": "Dateien in Ordner oder SharePoint auflisten", "fr": "Lister les fichiers dans un dossier"}, + "parameters": [ + {"name": "connectionId", "type": "string", "required": True, "description": {"en": "SharePoint connection", "de": "SharePoint-Verbindung", "fr": "Connexion SharePoint"}}, + {"name": "path", "type": "string", "required": False, "description": {"en": "Folder path (e.g. /sites/SiteName/Shared Documents)", "de": "Ordnerpfad", "fr": "Chemin du dossier"}, "default": "/"}, + ], + "inputs": 1, + "outputs": 1, + "meta": {"icon": "mdi-folder-open", "color": "#0078D4"}, + "_method": "sharepoint", + "_action": "listDocuments", + "_paramMap": {"connectionId": "connectionReference", "path": "pathQuery"}, + }, + { + "id": "sharepoint.downloadFile", + "category": "sharepoint", + "label": {"en": "Download File", "de": "Datei herunterladen", "fr": "Télécharger fichier"}, + "description": {"en": "Download file from path (e.g. /sites/SiteName/Shared Documents/file.pdf)", "de": "Datei vom Pfad herunterladen", "fr": "Télécharger le fichier"}, + "parameters": [ + {"name": "connectionId", "type": "string", "required": True, "description": {"en": "SharePoint connection", "de": "SharePoint-Verbindung", "fr": "Connexion SharePoint"}}, + {"name": "path", "type": "string", "required": True, "description": {"en": "Full file path (e.g. /sites/SiteName/Shared Documents/file.pdf)", "de": "Vollständiger Dateipfad", "fr": "Chemin complet du fichier"}}, + ], + "inputs": 1, + "outputs": 1, + "meta": {"icon": "mdi-download", "color": "#0078D4"}, + "_method": "sharepoint", + "_action": "downloadFileByPath", + "_paramMap": {"connectionId": "connectionReference", "path": "pathQuery", "siteId": "siteId", "filePath": "filePath"}, + }, + { + "id": "sharepoint.copyFile", + "category": "sharepoint", + "label": {"en": "Copy File", "de": "Datei kopieren", "fr": "Copier fichier"}, + "description": {"en": "Copy file to destination", "de": "Datei an Ziel kopieren", "fr": "Copier le fichier"}, + "parameters": [ + {"name": "connectionId", "type": "string", "required": True, "description": {"en": "SharePoint connection", "de": "SharePoint-Verbindung", "fr": "Connexion SharePoint"}}, + {"name": "sourcePath", "type": "string", "required": True, "description": {"en": "Source file path (from browse)", "de": "Quelldatei-Pfad", "fr": "Chemin fichier source"}}, + {"name": "destPath", "type": "string", "required": True, "description": {"en": "Destination folder path (from browse)", "de": "Zielordner-Pfad", "fr": "Chemin dossier cible"}}, + ], + "inputs": 1, + "outputs": 1, + "meta": {"icon": "mdi-content-copy", "color": "#0078D4"}, + "_method": "sharepoint", + "_action": "copyFile", + "_paramMap": {"connectionId": "connectionReference", "sourcePath": "sourcePath", "destPath": "destPath"}, + }, +] diff --git a/modules/features/automation2/nodeDefinitions/triggers.py b/modules/features/automation2/nodeDefinitions/triggers.py new file mode 100644 index 00000000..0e206dc0 --- /dev/null +++ b/modules/features/automation2/nodeDefinitions/triggers.py @@ -0,0 +1,42 @@ +# Copyright (c) 2025 Patrick Motsch +# Trigger node definitions - workflow entry points. + +TRIGGER_NODES = [ + { + "id": "trigger.manual", + "category": "trigger", + "label": {"en": "Manual Trigger", "de": "Manueller Trigger", "fr": "Déclencheur manuel"}, + "description": {"en": "Start workflow on button press", "de": "Startet den Workflow bei Knopfdruck", "fr": "Démarre le workflow sur clic"}, + "parameters": [], + "inputs": 0, + "outputs": 1, + "executor": "trigger", + "meta": {"icon": "mdi-play", "color": "#4CAF50"}, + }, + { + "id": "trigger.schedule", + "category": "trigger", + "label": {"en": "Schedule", "de": "Zeitplan", "fr": "Planification"}, + "description": {"en": "Run on a cron schedule", "de": "Läuft nach Cron-Zeitplan", "fr": "S'exécute selon un cron"}, + "parameters": [ + {"name": "cron", "type": "string", "required": True, "description": {"en": "Cron expression (e.g. 0 9 * * * for daily at 9)", "de": "Cron-Ausdruck", "fr": "Expression cron"}}, + ], + "inputs": 0, + "outputs": 1, + "executor": "trigger", + "meta": {"icon": "mdi-clock", "color": "#2196F3"}, + }, + { + "id": "trigger.formSubmit", + "category": "trigger", + "label": {"en": "Form Submit", "de": "Formular-Absendung", "fr": "Soumission formulaire"}, + "description": {"en": "Start when form is submitted", "de": "Startet bei Formular-Absendung", "fr": "Démarre à la soumission du formulaire"}, + "parameters": [ + {"name": "formId", "type": "string", "required": True, "description": {"en": "Form identifier", "de": "Formular-ID", "fr": "Identifiant du formulaire"}}, + ], + "inputs": 0, + "outputs": 1, + "executor": "trigger", + "meta": {"icon": "mdi-form-select", "color": "#9C27B0"}, + }, +] diff --git a/modules/features/automation2/nodeRegistry.py b/modules/features/automation2/nodeRegistry.py new file mode 100644 index 00000000..39c3e2c9 --- /dev/null +++ b/modules/features/automation2/nodeRegistry.py @@ -0,0 +1,81 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +""" +Node Type Registry for automation2 - static node definitions (ai, email, sharepoint, trigger, flow, data, input). +Nodes are defined first; IO/method actions are used at execution time. +""" + +import logging +from typing import Dict, List, Any + +from modules.features.automation2.nodeDefinitions import STATIC_NODE_TYPES + +logger = logging.getLogger(__name__) + + +def getNodeTypes( + services: Any = None, + language: str = "en", +) -> List[Dict[str, Any]]: + """ + Return static node types. No dynamic I/O derivation from methodDiscovery. + services: Optional (kept for API compatibility, not used). + """ + return list(STATIC_NODE_TYPES) + + +def _localizeNode(node: Dict[str, Any], language: str) -> Dict[str, Any]: + """Apply language to label/description/parameters.""" + lang = language if language in ("en", "de", "fr") else "en" + out = dict(node) + # Strip internal keys for API response + for key in list(out.keys()): + if key.startswith("_"): + del out[key] + if isinstance(node.get("label"), dict): + out["label"] = node["label"].get(lang, node["label"].get("en", str(node["label"]))) + if isinstance(node.get("description"), dict): + out["description"] = node["description"].get(lang, node["description"].get("en", str(node["description"]))) + params = [] + for p in node.get("parameters", []): + pc = dict(p) + if isinstance(p.get("description"), dict): + pc["description"] = p["description"].get(lang, p["description"].get("en", str(p.get("description", "")))) + params.append(pc) + out["parameters"] = params + return out + + +def getNodeTypesForApi( + services: Any, + language: str = "en", +) -> Dict[str, Any]: + """ + API-ready response: nodeTypes with localized strings, plus categories list. + """ + nodes = getNodeTypes(services, language) + localized = [_localizeNode(n, language) for n in nodes] + categories = [ + {"id": "trigger", "label": {"en": "Trigger", "de": "Trigger", "fr": "Déclencheur"}}, + {"id": "input", "label": {"en": "Input/Human", "de": "Eingabe/Mensch", "fr": "Entrée/Humain"}}, + {"id": "flow", "label": {"en": "Flow", "de": "Ablauf", "fr": "Flux"}}, + {"id": "data", "label": {"en": "Data", "de": "Daten", "fr": "Données"}}, + {"id": "ai", "label": {"en": "AI", "de": "KI", "fr": "IA"}}, + {"id": "email", "label": {"en": "Email", "de": "E-Mail", "fr": "Email"}}, + {"id": "sharepoint", "label": {"en": "SharePoint", "de": "SharePoint", "fr": "SharePoint"}}, + ] + return {"nodeTypes": localized, "categories": categories} + + +def getNodeTypeToMethodAction() -> Dict[str, tuple]: + """ + Mapping from node type id to (method, action) for execution. + Used by ActionNodeExecutor. + """ + mapping = {} + for node in STATIC_NODE_TYPES: + method = node.get("_method") + action = node.get("_action") + if method and action: + mapping[node["id"]] = (method, action) + return mapping diff --git a/modules/features/automation2/routeFeatureAutomation2.py b/modules/features/automation2/routeFeatureAutomation2.py new file mode 100644 index 00000000..996c3cb6 --- /dev/null +++ b/modules/features/automation2/routeFeatureAutomation2.py @@ -0,0 +1,602 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +""" +Automation2 routes - node-types, execute, workflows, runs, tasks, connections, browse. +""" + +import logging +from fastapi import APIRouter, Depends, Path, Query, Body, Request, HTTPException +from fastapi.responses import JSONResponse +from modules.auth import limiter, getRequestContext, RequestContext + +from modules.features.automation2.mainAutomation2 import getAutomation2Services +from modules.features.automation2.nodeRegistry import getNodeTypesForApi +from modules.features.automation2.interfaceFeatureAutomation2 import getAutomation2Interface +from modules.workflows.automation2.executionEngine import executeGraph + +logger = logging.getLogger(__name__) + +router = APIRouter( + prefix="/api/automation2", + tags=["Automation2"], + responses={404: {"description": "Not found"}, 403: {"description": "Forbidden"}}, +) + + +def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str: + """Validate user has access to the automation2 feature instance. Returns mandateId.""" + from fastapi import HTTPException + from modules.interfaces.interfaceDbApp import getRootInterface + + rootInterface = getRootInterface() + instance = rootInterface.getFeatureInstance(instanceId) + if not instance: + raise HTTPException(status_code=404, detail=f"Feature instance {instanceId} not found") + featureAccess = rootInterface.getFeatureAccess(str(context.user.id), instanceId) + if not featureAccess or not featureAccess.enabled: + raise HTTPException(status_code=403, detail="Access denied to this feature instance") + return str(instance.mandateId) if instance.mandateId else "" + + +@router.get("/{instanceId}/info") +@limiter.limit("60/minute") +def get_automation2_info( + request: Request, + instanceId: str = Path(..., description="Feature instance ID"), + context: RequestContext = Depends(getRequestContext), +) -> dict: + """Minimal info endpoint - proves the feature works.""" + _validateInstanceAccess(instanceId, context) + return { + "featureCode": "automation2", + "instanceId": instanceId, + "status": "ok", + "message": "Automation2 feature ready. Build from here.", + } + + +@router.get("/{instanceId}/node-types") +@limiter.limit("60/minute") +def get_node_types( + request: Request, + instanceId: str = Path(..., description="Feature instance ID"), + language: str = Query("en", description="Localization (en, de, fr)"), + context: RequestContext = Depends(getRequestContext), +) -> dict: + """Return node types for the flow builder: static + I/O from methodDiscovery.""" + logger.info("automation2 node-types request: instanceId=%s language=%s", instanceId, language) + mandateId = _validateInstanceAccess(instanceId, context) + services = getAutomation2Services( + context.user, + mandateId=mandateId, + featureInstanceId=instanceId, + ) + result = getNodeTypesForApi(services, language=language) + logger.info( + "automation2 node-types response: %d nodeTypes %d categories", + len(result.get("nodeTypes", [])), + len(result.get("categories", [])), + ) + return result + + +@router.post("/{instanceId}/execute") +@limiter.limit("30/minute") +async def post_execute( + request: Request, + instanceId: str = Path(..., description="Feature instance ID"), + body: dict = Body(..., description="{ workflowId?, graph: { nodes, connections } }"), + context: RequestContext = Depends(getRequestContext), +) -> dict: + """Execute automation2 graph. Body: { workflowId?, graph: { nodes, connections } }.""" + userId = str(context.user.id) if context.user else None + logger.info( + "automation2 execute request: instanceId=%s userId=%s body_keys=%s", + instanceId, + userId, + list(body.keys()), + ) + mandateId = _validateInstanceAccess(instanceId, context) + services = getAutomation2Services( + context.user, + mandateId=mandateId, + featureInstanceId=instanceId, + ) + # Ensure workflow methods (outlook, ai, sharepoint, etc.) are discovered for ActionExecutor + from modules.workflows.processing.shared.methodDiscovery import discoverMethods + discoverMethods(services) + + graph = body.get("graph") or body + workflowId = body.get("workflowId") + req_nodes = graph.get("nodes") or [] + # When workflowId is set: prefer graph from request (current editor state) if it has nodes. + # Only fall back to stored workflow graph when request graph is empty (e.g. resume from email). + if workflowId and len(req_nodes) == 0: + a2 = getAutomation2Interface(context.user, mandateId, instanceId) + wf = a2.getWorkflow(workflowId) + if wf and wf.get("graph"): + graph = wf["graph"] + logger.info("automation2 execute: loaded graph from workflow %s", workflowId) + # Use transient workflowId when none provided (e.g. execute from editor without save) + # Required for email.checkEmail pause/resume - run must be created + if not workflowId: + import uuid + workflowId = f"transient-{uuid.uuid4().hex[:12]}" + logger.info("automation2 execute: using transient workflowId=%s", workflowId) + nodes_count = len(graph.get("nodes") or []) + connections_count = len(graph.get("connections") or []) + logger.info( + "automation2 execute: graph nodes=%d connections=%d workflowId=%s mandateId=%s", + nodes_count, + connections_count, + workflowId, + mandateId, + ) + a2_interface = getAutomation2Interface(context.user, mandateId, instanceId) + result = await executeGraph( + graph=graph, + services=services, + workflowId=workflowId, + instanceId=instanceId, + userId=userId, + mandateId=mandateId, + automation2_interface=a2_interface, + ) + logger.info( + "automation2 execute result: success=%s error=%s nodeOutputs_keys=%s failedNode=%s paused=%s", + result.get("success"), + result.get("error"), + list(result.get("nodeOutputs", {}).keys()) if result.get("nodeOutputs") else [], + result.get("failedNode"), + result.get("paused"), + ) + return result + + +# ------------------------------------------------------------------------- +# Connections and Browse (for Email/SharePoint node config - like workspace) +# ------------------------------------------------------------------------- + + +def _buildResolverDbInterface(chatService): + """Build a DB adapter that ConnectorResolver can use to load UserConnections.""" + class _ResolverDbAdapter: + def __init__(self, appInterface): + self._app = appInterface + + def getUserConnection(self, connectionId: str): + if hasattr(self._app, "getUserConnectionById"): + return self._app.getUserConnectionById(connectionId) + return None + + appIf = getattr(chatService, "interfaceDbApp", None) + if appIf: + return _ResolverDbAdapter(appIf) + return getattr(chatService, "interfaceDbComponent", None) + + +@router.get("/{instanceId}/connections") +@limiter.limit("300/minute") +def list_automation2_connections( + request: Request, + instanceId: str = Path(..., description="Feature instance ID"), + context: RequestContext = Depends(getRequestContext), +) -> dict: + """Return the user's active connections (UserConnections) for Email/SharePoint node config.""" + mandateId = _validateInstanceAccess(instanceId, context) + from modules.serviceCenter import getService + from modules.serviceCenter.context import ServiceCenterContext + ctx = ServiceCenterContext( + user=context.user, + mandate_id=str(context.mandateId) if context.mandateId else mandateId, + feature_instance_id=instanceId, + ) + chatService = getService("chat", ctx) + connections = chatService.getUserConnections() + items = [] + for c in connections or []: + conn = c if isinstance(c, dict) else (c.model_dump() if hasattr(c, "model_dump") else {}) + authority = conn.get("authority") + if hasattr(authority, "value"): + authority = authority.value + status = conn.get("status") + if hasattr(status, "value"): + status = status.value + items.append({ + "id": conn.get("id"), + "authority": authority, + "externalUsername": conn.get("externalUsername"), + "externalEmail": conn.get("externalEmail"), + "status": status, + }) + return {"connections": items} + + +@router.get("/{instanceId}/connections/{connectionId}/services") +@limiter.limit("120/minute") +async def list_connection_services( + request: Request, + instanceId: str = Path(..., description="Feature instance ID"), + connectionId: str = Path(..., description="Connection ID"), + context: RequestContext = Depends(getRequestContext), +) -> dict: + """Return the available services for a specific UserConnection.""" + mandateId = _validateInstanceAccess(instanceId, context) + try: + from modules.connectors.connectorResolver import ConnectorResolver + from modules.serviceCenter import getService as getSvc + from modules.serviceCenter.context import ServiceCenterContext + ctx = ServiceCenterContext( + user=context.user, + mandate_id=str(context.mandateId) if context.mandateId else mandateId, + feature_instance_id=instanceId, + ) + chatService = getSvc("chat", ctx) + securityService = getSvc("security", ctx) + dbInterface = _buildResolverDbInterface(chatService) + resolver = ConnectorResolver(securityService, dbInterface) + provider = await resolver.resolve(connectionId) + services = provider.getAvailableServices() + _serviceLabels = { + "sharepoint": "SharePoint", + "outlook": "Outlook", + "teams": "Teams", + "onedrive": "OneDrive", + "drive": "Google Drive", + "gmail": "Gmail", + "files": "Files (FTP)", + } + _serviceIcons = { + "sharepoint": "sharepoint", + "outlook": "mail", + "teams": "chat", + "onedrive": "cloud", + "drive": "cloud", + "gmail": "mail", + "files": "folder", + } + items = [ + {"service": s, "label": _serviceLabels.get(s, s), "icon": _serviceIcons.get(s, "folder")} + for s in services + ] + return {"services": items} + except Exception as e: + logger.error(f"Error listing services for connection {connectionId}: {e}") + return JSONResponse({"services": [], "error": str(e)}, status_code=400) + + +@router.get("/{instanceId}/connections/{connectionId}/browse") +@limiter.limit("300/minute") +async def browse_connection_service( + request: Request, + instanceId: str = Path(..., description="Feature instance ID"), + connectionId: str = Path(..., description="Connection ID"), + service: str = Query(..., description="Service name (e.g. sharepoint, onedrive, outlook)"), + path: str = Query("/", description="Path within the service to browse"), + context: RequestContext = Depends(getRequestContext), +) -> dict: + """Browse folders/items within a connection's service at a given path.""" + mandateId = _validateInstanceAccess(instanceId, context) + try: + from modules.connectors.connectorResolver import ConnectorResolver + from modules.serviceCenter import getService as getSvc + from modules.serviceCenter.context import ServiceCenterContext + ctx = ServiceCenterContext( + user=context.user, + mandate_id=str(context.mandateId) if context.mandateId else mandateId, + feature_instance_id=instanceId, + ) + chatService = getSvc("chat", ctx) + securityService = getSvc("security", ctx) + dbInterface = _buildResolverDbInterface(chatService) + resolver = ConnectorResolver(securityService, dbInterface) + adapter = await resolver.resolveService(connectionId, service) + entries = await adapter.browse(path, filter=None) + items = [] + for entry in (entries or []): + items.append({ + "name": entry.name, + "path": entry.path, + "isFolder": entry.isFolder, + "size": entry.size, + "mimeType": entry.mimeType, + "metadata": entry.metadata if hasattr(entry, "metadata") else {}, + }) + return {"items": items, "path": path, "service": service} + except Exception as e: + logger.error(f"Error browsing {service} for connection {connectionId} at '{path}': {e}") + return JSONResponse({"items": [], "error": str(e)}, status_code=400) + + +# ------------------------------------------------------------------------- +# Workflow CRUD +# ------------------------------------------------------------------------- + + +def _get_node_label_from_graph(graph: dict, nodeId: str) -> str: + """Extract human-readable label for a node from graph.""" + if not graph or not nodeId: + return nodeId or "" + nodes = graph.get("nodes") or [] + for n in nodes: + if n.get("id") == nodeId: + params = n.get("parameters") or {} + config = params.get("config") or {} + if isinstance(config, dict): + label = config.get("title") or config.get("label") + else: + label = None + return ( + n.get("title") + or label + or params.get("title") + or params.get("label") + or n.get("type", "") + or nodeId + ) + return nodeId or "" + + +@router.get("/{instanceId}/workflows") +@limiter.limit("60/minute") +def get_workflows( + request: Request, + instanceId: str = Path(..., description="Feature instance ID"), + context: RequestContext = Depends(getRequestContext), +) -> dict: + """List all workflows for this feature instance. + Enriches each workflow with runCount, isRunning, stuckAtNodeId, stuckAtNodeLabel, + createdAt, lastStartedAt. + """ + mandateId = _validateInstanceAccess(instanceId, context) + a2 = getAutomation2Interface(context.user, mandateId, instanceId) + items = a2.getWorkflows() + enriched = [] + for wf in items: + wf_id = wf.get("id") + runs = a2.getRunsByWorkflow(wf_id) if wf_id else [] + run_count = len(runs) + active_run = None + last_started_at = None + for r in runs: + ts = r.get("_createdAt") + if ts and (last_started_at is None or ts > last_started_at): + last_started_at = ts + if r.get("status") in ("running", "paused"): + active_run = r + stuck_at_node_id = active_run.get("currentNodeId") if active_run else None + stuck_at_node_label = "" + if stuck_at_node_id and wf.get("graph"): + stuck_at_node_label = _get_node_label_from_graph(wf["graph"], stuck_at_node_id) + enriched.append({ + **wf, + "runCount": run_count, + "isRunning": active_run is not None, + "runStatus": active_run.get("status") if active_run else None, + "stuckAtNodeId": stuck_at_node_id, + "stuckAtNodeLabel": stuck_at_node_label or stuck_at_node_id or "", + "createdAt": wf.get("_createdAt"), + "lastStartedAt": last_started_at, + }) + return {"workflows": enriched} + + +@router.get("/{instanceId}/workflows/{workflowId}") +@limiter.limit("60/minute") +def get_workflow( + request: Request, + instanceId: str = Path(..., description="Feature instance ID"), + workflowId: str = Path(..., description="Workflow ID"), + context: RequestContext = Depends(getRequestContext), +) -> dict: + """Get a single workflow by ID.""" + mandateId = _validateInstanceAccess(instanceId, context) + a2 = getAutomation2Interface(context.user, mandateId, instanceId) + wf = a2.getWorkflow(workflowId) + if not wf: + raise HTTPException(status_code=404, detail="Workflow not found") + return wf + + +@router.post("/{instanceId}/workflows") +@limiter.limit("30/minute") +def create_workflow( + request: Request, + instanceId: str = Path(..., description="Feature instance ID"), + body: dict = Body(..., description="{ label, graph }"), + context: RequestContext = Depends(getRequestContext), +) -> dict: + """Create a new workflow.""" + mandateId = _validateInstanceAccess(instanceId, context) + a2 = getAutomation2Interface(context.user, mandateId, instanceId) + created = a2.createWorkflow(body) + return created + + +@router.put("/{instanceId}/workflows/{workflowId}") +@limiter.limit("30/minute") +def update_workflow( + request: Request, + instanceId: str = Path(..., description="Feature instance ID"), + workflowId: str = Path(..., description="Workflow ID"), + body: dict = Body(..., description="{ label?, graph? }"), + context: RequestContext = Depends(getRequestContext), +) -> dict: + """Update a workflow.""" + mandateId = _validateInstanceAccess(instanceId, context) + a2 = getAutomation2Interface(context.user, mandateId, instanceId) + updated = a2.updateWorkflow(workflowId, body) + if not updated: + raise HTTPException(status_code=404, detail="Workflow not found") + return updated + + +@router.delete("/{instanceId}/workflows/{workflowId}") +@limiter.limit("30/minute") +def delete_workflow( + request: Request, + instanceId: str = Path(..., description="Feature instance ID"), + workflowId: str = Path(..., description="Workflow ID"), + context: RequestContext = Depends(getRequestContext), +) -> dict: + """Delete a workflow.""" + mandateId = _validateInstanceAccess(instanceId, context) + a2 = getAutomation2Interface(context.user, mandateId, instanceId) + if not a2.deleteWorkflow(workflowId): + raise HTTPException(status_code=404, detail="Workflow not found") + return {"success": True} + + +# ------------------------------------------------------------------------- +# Runs and Resume +# ------------------------------------------------------------------------- + + +@router.get("/{instanceId}/workflows/{workflowId}/runs") +@limiter.limit("60/minute") +def get_workflow_runs( + request: Request, + instanceId: str = Path(..., description="Feature instance ID"), + workflowId: str = Path(..., description="Workflow ID"), + context: RequestContext = Depends(getRequestContext), +) -> dict: + """Get runs for a workflow.""" + mandateId = _validateInstanceAccess(instanceId, context) + a2 = getAutomation2Interface(context.user, mandateId, instanceId) + if not a2.getWorkflow(workflowId): + raise HTTPException(status_code=404, detail="Workflow not found") + runs = a2.getRunsByWorkflow(workflowId) + return {"runs": runs} + + +@router.post("/{instanceId}/runs/{runId}/resume") +@limiter.limit("30/minute") +async def resume_run( + request: Request, + instanceId: str = Path(..., description="Feature instance ID"), + runId: str = Path(..., description="Run ID"), + body: dict = Body(..., description="{ taskId, result }"), + context: RequestContext = Depends(getRequestContext), +) -> dict: + """Resume a paused run after task completion.""" + mandateId = _validateInstanceAccess(instanceId, context) + a2 = getAutomation2Interface(context.user, mandateId, instanceId) + run = a2.getRun(runId) + if not run: + raise HTTPException(status_code=404, detail="Run not found") + taskId = body.get("taskId") + result = body.get("result") + if not taskId or result is None: + raise HTTPException(status_code=400, detail="taskId and result required") + task = a2.getTask(taskId) + if not task or task.get("runId") != runId: + raise HTTPException(status_code=404, detail="Task not found") + if task.get("status") != "pending": + raise HTTPException(status_code=400, detail="Task already completed") + a2.updateTask(taskId, status="completed", result=result) + nodeId = task.get("nodeId") + nodeOutputs = dict(run.get("nodeOutputs") or {}) + nodeOutputs[nodeId] = result + runContext = run.get("context") or {} + connectionMap = runContext.get("connectionMap", {}) + inputSources = runContext.get("inputSources", {}) + workflowId = run.get("workflowId") + wf = a2.getWorkflow(workflowId) if workflowId else None + if not wf or not wf.get("graph"): + raise HTTPException(status_code=400, detail="Workflow graph not found") + graph = wf["graph"] + services = getAutomation2Services(context.user, mandateId=mandateId, featureInstanceId=instanceId) + resume_result = await executeGraph( + graph=graph, + services=services, + workflowId=workflowId, + instanceId=instanceId, + userId=str(context.user.id) if context.user else None, + mandateId=mandateId, + automation2_interface=a2, + initialNodeOutputs=nodeOutputs, + startAfterNodeId=nodeId, + runId=runId, + ) + return resume_result + + +# ------------------------------------------------------------------------- +# Tasks +# ------------------------------------------------------------------------- + + +@router.get("/{instanceId}/tasks") +@limiter.limit("60/minute") +def get_tasks( + request: Request, + instanceId: str = Path(..., description="Feature instance ID"), + workflowId: str = Query(None, description="Filter by workflow ID"), + status: str = Query(None, description="Filter: pending, completed, rejected"), + context: RequestContext = Depends(getRequestContext), +) -> dict: + """Get tasks - by default those assigned to current user, or all if no assignee filter. + Enriches each task with workflowLabel and createdAt (_createdAt). + """ + mandateId = _validateInstanceAccess(instanceId, context) + a2 = getAutomation2Interface(context.user, mandateId, instanceId) + assigneeId = str(context.user.id) if context.user else None + items = a2.getTasks(workflowId=workflowId, status=status, assigneeId=assigneeId) + workflows = {w["id"]: w for w in a2.getWorkflows()} + enriched = [] + for t in items: + wf = workflows.get(t.get("workflowId") or "") + enriched.append({ + **t, + "workflowLabel": wf.get("label", t.get("workflowId", "")) if wf else t.get("workflowId", ""), + "createdAt": t.get("_createdAt"), + }) + return {"tasks": enriched} + + +@router.post("/{instanceId}/tasks/{taskId}/complete") +@limiter.limit("30/minute") +async def complete_task( + request: Request, + instanceId: str = Path(..., description="Feature instance ID"), + taskId: str = Path(..., description="Task ID"), + body: dict = Body(..., description="{ result }"), + context: RequestContext = Depends(getRequestContext), +) -> dict: + """Complete a task and resume the run.""" + mandateId = _validateInstanceAccess(instanceId, context) + a2 = getAutomation2Interface(context.user, mandateId, instanceId) + task = a2.getTask(taskId) + if not task: + raise HTTPException(status_code=404, detail="Task not found") + runId = task.get("runId") + result = body.get("result") + if result is None: + raise HTTPException(status_code=400, detail="result required") + run = a2.getRun(runId) + if not run: + raise HTTPException(status_code=404, detail="Run not found") + if task.get("status") != "pending": + raise HTTPException(status_code=400, detail="Task already completed") + a2.updateTask(taskId, status="completed", result=result) + nodeId = task.get("nodeId") + nodeOutputs = dict(run.get("nodeOutputs") or {}) + nodeOutputs[nodeId] = result + workflowId = run.get("workflowId") + wf = a2.getWorkflow(workflowId) if workflowId else None + if not wf or not wf.get("graph"): + raise HTTPException(status_code=400, detail="Workflow graph not found") + graph = wf["graph"] + services = getAutomation2Services(context.user, mandateId=mandateId, featureInstanceId=instanceId) + return await executeGraph( + graph=graph, + services=services, + workflowId=workflowId, + instanceId=instanceId, + userId=str(context.user.id) if context.user else None, + mandateId=mandateId, + automation2_interface=a2, + initialNodeOutputs=nodeOutputs, + startAfterNodeId=nodeId, + runId=runId, + ) diff --git a/modules/features/chatbot/interfaceFeatureChatbot.py b/modules/features/chatbot/interfaceFeatureChatbot.py index 559a9187..4a03bec9 100644 --- a/modules/features/chatbot/interfaceFeatureChatbot.py +++ b/modules/features/chatbot/interfaceFeatureChatbot.py @@ -514,7 +514,7 @@ class ChatObjects: # Handle simple value (equals operator) if not isinstance(filter_value, dict): - if record_value != filter_value: + if str(record_value).lower() != str(filter_value).lower(): matches = False break continue @@ -524,7 +524,7 @@ class ChatObjects: filter_val = filter_value.get("value") if operator in ["equals", "eq"]: - if record_value != filter_val: + if str(record_value).lower() != str(filter_val).lower(): matches = False break @@ -609,7 +609,7 @@ class ChatObjects: else: # Unknown operator - default to equals - if record_value != filter_val: + if str(record_value).lower() != str(filter_val).lower(): matches = False break diff --git a/modules/features/commcoach/routeFeatureCommcoach.py b/modules/features/commcoach/routeFeatureCommcoach.py index 81585a19..9074d2ba 100644 --- a/modules/features/commcoach/routeFeatureCommcoach.py +++ b/modules/features/commcoach/routeFeatureCommcoach.py @@ -1192,24 +1192,79 @@ async def deleteDocumentRoute( def _extractText(content: bytes, mimeType: str, fileName: str) -> Optional[str]: - """Extract text from uploaded file content.""" + """Extract text from uploaded file content (TXT, MD, HTML, PDF, DOCX, XLSX, PPTX).""" + import io + + lowerName = fileName.lower() try: - if mimeType == "text/plain" or fileName.endswith(".txt"): + if mimeType in ("text/plain",) or lowerName.endswith(".txt"): return content.decode("utf-8", errors="replace") - if mimeType == "text/markdown" or fileName.endswith(".md"): + + if mimeType in ("text/markdown",) or lowerName.endswith(".md"): return content.decode("utf-8", errors="replace") - if "pdf" in mimeType or fileName.endswith(".pdf"): + + if mimeType in ("text/html",) or lowerName.endswith((".html", ".htm")): + from html.parser import HTMLParser + class _Strip(HTMLParser): + def __init__(self): + super().__init__() + self._parts: list[str] = [] + def handle_data(self, d): + self._parts.append(d) + def result(self): + return " ".join(self._parts) + parser = _Strip() + parser.feed(content.decode("utf-8", errors="replace")) + return parser.result() + + if "pdf" in mimeType or lowerName.endswith(".pdf"): try: - import io from PyPDF2 import PdfReader reader = PdfReader(io.BytesIO(content)) - text = "" - for page in reader.pages: - text += page.extract_text() or "" - return text + return "".join(page.extract_text() or "" for page in reader.pages) except ImportError: logger.warning("PyPDF2 not installed, cannot extract PDF text") return None + + if "wordprocessingml" in mimeType or lowerName.endswith(".docx"): + try: + from docx import Document + doc = Document(io.BytesIO(content)) + return "\n".join(p.text for p in doc.paragraphs if p.text) + except ImportError: + logger.warning("python-docx not installed, cannot extract DOCX text") + return None + + if "spreadsheetml" in mimeType or lowerName.endswith(".xlsx"): + try: + from openpyxl import load_workbook + wb = load_workbook(io.BytesIO(content), read_only=True, data_only=True) + parts: list[str] = [] + for ws in wb.worksheets: + for row in ws.iter_rows(values_only=True): + cells = [str(c) for c in row if c is not None] + if cells: + parts.append("\t".join(cells)) + return "\n".join(parts) + except ImportError: + logger.warning("openpyxl not installed, cannot extract XLSX text") + return None + + if "presentationml" in mimeType or lowerName.endswith(".pptx"): + try: + from pptx import Presentation + prs = Presentation(io.BytesIO(content)) + parts = [] + for slide in prs.slides: + for shape in slide.shapes: + if shape.has_text_frame: + parts.append(shape.text_frame.text) + return "\n".join(parts) + except ImportError: + logger.warning("python-pptx not installed, cannot extract PPTX text") + return None + + logger.info(f"No text extractor for {fileName} (mime={mimeType})") except Exception as e: logger.warning(f"Text extraction failed for {fileName}: {e}") return None diff --git a/modules/features/commcoach/serviceCommcoach.py b/modules/features/commcoach/serviceCommcoach.py index be47a917..bf5ec281 100644 --- a/modules/features/commcoach/serviceCommcoach.py +++ b/modules/features/commcoach/serviceCommcoach.py @@ -331,7 +331,8 @@ def _getDocumentSummaries(contextId: str, userId: str, interface) -> Optional[Li elif doc.get("extractedText"): summaries.append(f"[{doc.get('fileName', 'Dokument')}] {doc['extractedText'][:200]}...") return summaries if summaries else None - except Exception: + except Exception as e: + logger.warning(f"Failed to load document summaries for context {contextId}: {e}") return None diff --git a/modules/features/neutralization/serviceNeutralization/mainServiceNeutralization.py b/modules/features/neutralization/serviceNeutralization/mainServiceNeutralization.py index cf8f0f53..c803b375 100644 --- a/modules/features/neutralization/serviceNeutralization/mainServiceNeutralization.py +++ b/modules/features/neutralization/serviceNeutralization/mainServiceNeutralization.py @@ -39,19 +39,21 @@ EXTRACTABLE_BINARY_MIME_TYPES = frozenset({ class NeutralizationService: """Service for handling data neutralization operations""" - def __init__(self, serviceCenter=None, NamesToParse: List[str] = None): + def __init__(self, serviceCenter=None, getServiceFn=None, NamesToParse: List[str] = None): """Initialize the service with user context and anonymization processors Args: - serviceCenter: Service center instance for accessing other services + serviceCenter: Service center context or legacy service center instance + getServiceFn: Service resolver function (injected by ServiceCenter resolver) NamesToParse: List of names to parse and replace (case-insensitive) """ self.services = serviceCenter - self.interfaceDbComponent = serviceCenter.interfaceDbComponent + self._getService = getServiceFn + self.interfaceDbComponent = getattr(serviceCenter, "interfaceDbComponent", None) # Create feature-specific interface for neutralizer DB operations self.interfaceNeutralizer: InterfaceFeatureNeutralizer = None - if serviceCenter and serviceCenter.interfaceDbApp: + if serviceCenter and getattr(serviceCenter, "interfaceDbApp", None): dbApp = serviceCenter.interfaceDbApp self.interfaceNeutralizer = getNeutralizerInterface( currentUser=dbApp.currentUser, @@ -59,10 +61,10 @@ class NeutralizationService: featureInstanceId=getattr(serviceCenter, 'featureInstanceId', None) or getattr(dbApp, 'featureInstanceId', None) ) - # Initialize anonymization processors - self.NamesToParse = NamesToParse or [] - self.textProcessor = TextProcessor(NamesToParse) - self.listProcessor = ListProcessor(NamesToParse) + namesList = NamesToParse if isinstance(NamesToParse, list) else [] + self.NamesToParse = namesList + self.textProcessor = TextProcessor(namesList) + self.listProcessor = ListProcessor(namesList) self.binaryProcessor = BinaryProcessor() self.commonUtils = CommonUtils() diff --git a/modules/features/realEstate/interfaceFeatureRealEstate.py b/modules/features/realEstate/interfaceFeatureRealEstate.py index 65601d9a..f7ed52b6 100644 --- a/modules/features/realEstate/interfaceFeatureRealEstate.py +++ b/modules/features/realEstate/interfaceFeatureRealEstate.py @@ -24,7 +24,8 @@ from modules.shared.configuration import APP_CONFIG from modules.security.rbac import RbacClass from modules.datamodels.datamodelRbac import AccessRuleContext from modules.datamodels.datamodelUam import AccessLevel -from modules.interfaces.interfaceRbac import getRecordsetWithRBAC +from modules.interfaces.interfaceRbac import getRecordsetWithRBAC, getRecordsetPaginatedWithRBAC, getDistinctColumnValuesWithRBAC +from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResult logger = logging.getLogger(__name__) @@ -175,17 +176,21 @@ class RealEstateObjects: return Projekt(**records[0]) - def getProjekte(self, recordFilter: Optional[Dict[str, Any]] = None) -> List[Projekt]: - """Get all projects matching the filter.""" - records = getRecordsetWithRBAC( + def getProjekte(self, recordFilter: Optional[Dict[str, Any]] = None, pagination: Optional[PaginationParams] = None) -> Union[List[Projekt], PaginatedResult]: + """Get all projects matching the filter with optional DB-level pagination.""" + result = getRecordsetPaginatedWithRBAC( self.db, Projekt, self.currentUser, + pagination=pagination, recordFilter=recordFilter or {}, featureCode=self.FEATURE_CODE ) - - return [Projekt(**r) for r in records] + + if isinstance(result, PaginatedResult): + result.items = [Projekt(**r) for r in result.items] + return result + return [Projekt(**r) for r in result] def updateProjekt(self, projektId_or_projekt: Union[str, Projekt], updateData: Optional[Dict[str, Any]] = None) -> Optional[Projekt]: """Update a project. @@ -265,20 +270,23 @@ class RealEstateObjects: return Parzelle(**records[0]) - def getParzellen(self, recordFilter: Optional[Dict[str, Any]] = None) -> List[Parzelle]: - """Get all plots matching the filter.""" - # Resolve location names to IDs if needed + def getParzellen(self, recordFilter: Optional[Dict[str, Any]] = None, pagination: Optional[PaginationParams] = None) -> Union[List[Parzelle], PaginatedResult]: + """Get all plots matching the filter with optional DB-level pagination.""" if recordFilter: recordFilter = self._resolveLocationFilters(recordFilter) - - records = getRecordsetWithRBAC( + + result = getRecordsetPaginatedWithRBAC( self.db, Parzelle, self.currentUser, + pagination=pagination, recordFilter=recordFilter or {} ) - - return [Parzelle(**r) for r in records] + + if isinstance(result, PaginatedResult): + result.items = [Parzelle(**r) for r in result.items] + return result + return [Parzelle(**r) for r in result] def _resolveLocationFilters(self, recordFilter: Dict[str, Any]) -> Dict[str, Any]: """ @@ -477,18 +485,23 @@ class RealEstateObjects: return Dokument(**records[0]) - def getDokumente(self, recordFilter: Optional[Dict[str, Any]] = None) -> List[Dokument]: - """Get all documents matching the filter.""" - records = getRecordsetWithRBAC( + def getDokumente(self, recordFilter: Optional[Dict[str, Any]] = None, pagination: Optional[PaginationParams] = None) -> Union[List[Dokument], PaginatedResult]: + """Get all documents matching the filter with optional DB-level pagination.""" + result = getRecordsetPaginatedWithRBAC( self.db, Dokument, self.currentUser, + pagination=pagination, recordFilter=recordFilter or {}, mandateId=self.mandateId, featureInstanceId=self.featureInstanceId, featureCode=self.FEATURE_CODE ) - return [Dokument(**r) for r in records] + + if isinstance(result, PaginatedResult): + result.items = [Dokument(**r) for r in result.items] + return result + return [Dokument(**r) for r in result] def updateDokument(self, dokumentId: str, updateData: Dict[str, Any]) -> Optional[Dokument]: """Update a document.""" @@ -552,18 +565,23 @@ class RealEstateObjects: return Gemeinde(**records[0]) - def getGemeinden(self, recordFilter: Optional[Dict[str, Any]] = None) -> List[Gemeinde]: - """Get all municipalities matching the filter.""" - records = getRecordsetWithRBAC( + def getGemeinden(self, recordFilter: Optional[Dict[str, Any]] = None, pagination: Optional[PaginationParams] = None) -> Union[List[Gemeinde], PaginatedResult]: + """Get all municipalities matching the filter with optional DB-level pagination.""" + result = getRecordsetPaginatedWithRBAC( self.db, Gemeinde, self.currentUser, + pagination=pagination, recordFilter=recordFilter or {}, mandateId=self.mandateId, featureInstanceId=self.featureInstanceId, featureCode=self.FEATURE_CODE ) - return [Gemeinde(**r) for r in records] + + if isinstance(result, PaginatedResult): + result.items = [Gemeinde(**r) for r in result.items] + return result + return [Gemeinde(**r) for r in result] def updateGemeinde(self, gemeindeId: str, updateData: Dict[str, Any]) -> Optional[Gemeinde]: """Update a municipality.""" @@ -627,18 +645,23 @@ class RealEstateObjects: return Kanton(**records[0]) - def getKantone(self, recordFilter: Optional[Dict[str, Any]] = None) -> List[Kanton]: - """Get all cantons matching the filter.""" - records = getRecordsetWithRBAC( + def getKantone(self, recordFilter: Optional[Dict[str, Any]] = None, pagination: Optional[PaginationParams] = None) -> Union[List[Kanton], PaginatedResult]: + """Get all cantons matching the filter with optional DB-level pagination.""" + result = getRecordsetPaginatedWithRBAC( self.db, Kanton, self.currentUser, + pagination=pagination, recordFilter=recordFilter or {}, mandateId=self.mandateId, featureInstanceId=self.featureInstanceId, featureCode=self.FEATURE_CODE ) - return [Kanton(**r) for r in records] + + if isinstance(result, PaginatedResult): + result.items = [Kanton(**r) for r in result.items] + return result + return [Kanton(**r) for r in result] def updateKanton(self, kantonId: str, updateData: Dict[str, Any]) -> Optional[Kanton]: """Update a canton.""" @@ -700,16 +723,21 @@ class RealEstateObjects: return Land(**records[0]) - def getLaender(self, recordFilter: Optional[Dict[str, Any]] = None) -> List[Land]: - """Get all countries matching the filter.""" - records = getRecordsetWithRBAC( + def getLaender(self, recordFilter: Optional[Dict[str, Any]] = None, pagination: Optional[PaginationParams] = None) -> Union[List[Land], PaginatedResult]: + """Get all countries matching the filter with optional DB-level pagination.""" + result = getRecordsetPaginatedWithRBAC( self.db, Land, self.currentUser, + pagination=pagination, recordFilter=recordFilter or {}, featureCode=self.FEATURE_CODE ) - return [Land(**r) for r in records] + + if isinstance(result, PaginatedResult): + result.items = [Land(**r) for r in result.items] + return result + return [Land(**r) for r in result] def updateLand(self, landId: str, updateData: Dict[str, Any]) -> Optional[Land]: """Update a country.""" diff --git a/modules/features/realEstate/routeFeatureRealEstate.py b/modules/features/realEstate/routeFeatureRealEstate.py index b4df86dd..82fa55ba 100644 --- a/modules/features/realEstate/routeFeatureRealEstate.py +++ b/modules/features/realEstate/routeFeatureRealEstate.py @@ -252,6 +252,31 @@ def get_projects( return PaginatedResponse(items=items, pagination=None) +@router.get("/{instanceId}/projects/filter-values") +@limiter.limit("60/minute") +def get_project_filter_values( + request: Request, + instanceId: str = Path(..., description="Feature Instance ID"), + column: str = Query(..., description="Column key"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters"), + context: RequestContext = Depends(getRequestContext) +) -> list: + """Return distinct filter values for a column in real estate projects.""" + mandateId = _validateInstanceAccess(instanceId, context) + try: + from modules.routes.routeDataUsers import _handleFilterValuesRequest + interface = getRealEstateInterface( + context.user, mandateId=mandateId, featureInstanceId=instanceId + ) + recordFilter = {"featureInstanceId": instanceId} + items = interface.getProjekte(recordFilter=recordFilter) + itemDicts = [i.model_dump() if hasattr(i, 'model_dump') else i for i in items] + return _handleFilterValuesRequest(itemDicts, column, pagination) + except Exception as e: + logger.error(f"Error getting filter values for projects: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + @router.get("/{instanceId}/projects/{projectId}", response_model=Projekt) @limiter.limit("30/minute") def get_project_by_id( @@ -384,6 +409,31 @@ def get_parcels( return PaginatedResponse(items=items, pagination=None) +@router.get("/{instanceId}/parcels/filter-values") +@limiter.limit("60/minute") +def get_parcel_filter_values( + request: Request, + instanceId: str = Path(..., description="Feature Instance ID"), + column: str = Query(..., description="Column key"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters"), + context: RequestContext = Depends(getRequestContext) +) -> list: + """Return distinct filter values for a column in real estate parcels.""" + mandateId = _validateInstanceAccess(instanceId, context) + try: + from modules.routes.routeDataUsers import _handleFilterValuesRequest + interface = getRealEstateInterface( + context.user, mandateId=mandateId, featureInstanceId=instanceId + ) + recordFilter = {"featureInstanceId": instanceId} + items = interface.getParzellen(recordFilter=recordFilter) + itemDicts = [i.model_dump() if hasattr(i, 'model_dump') else i for i in items] + return _handleFilterValuesRequest(itemDicts, column, pagination) + except Exception as e: + logger.error(f"Error getting filter values for parcels: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + @router.get("/{instanceId}/parcels/{parcelId}", response_model=Parzelle) @limiter.limit("30/minute") def get_parcel_by_id( diff --git a/modules/features/trustee/interfaceFeatureTrustee.py b/modules/features/trustee/interfaceFeatureTrustee.py index a4b13c27..b9a95005 100644 --- a/modules/features/trustee/interfaceFeatureTrustee.py +++ b/modules/features/trustee/interfaceFeatureTrustee.py @@ -14,7 +14,7 @@ from pydantic import ValidationError from modules.connectors.connectorDbPostgre import DatabaseConnector from modules.shared.configuration import APP_CONFIG -from modules.interfaces.interfaceRbac import getRecordsetWithRBAC +from modules.interfaces.interfaceRbac import getRecordsetWithRBAC, getRecordsetPaginatedWithRBAC, getDistinctColumnValuesWithRBAC from modules.security.rbac import RbacClass from modules.datamodels.datamodelUam import User, AccessLevel from modules.datamodels.datamodelRbac import AccessRuleContext @@ -414,7 +414,7 @@ class TrusteeObjects: # Handle simple value (equals operator) if not isinstance(filterValue, dict): - if recordValue == filterValue: + if str(recordValue).lower() == str(filterValue).lower(): fieldFiltered.append(record) continue @@ -424,7 +424,7 @@ class TrusteeObjects: matches = False if operator in ["equals", "eq"]: - matches = recordValue == filterVal + matches = str(recordValue).lower() == str(filterVal).lower() elif operator == "contains": recordStr = str(recordValue).lower() if recordValue is not None else "" @@ -577,8 +577,8 @@ class TrusteeObjects: return None return TrusteeOrganisation(**{k: v for k, v in records[0].items() if not k.startswith("_")}) - def getAllOrganisations(self, params: Optional[PaginationParams] = None) -> PaginatedResult: - """Get all organisations with RBAC filtering. + def getAllOrganisations(self, params: Optional[PaginationParams] = None) -> Union[List[Dict], PaginatedResult]: + """Get all organisations with RBAC filtering and optional DB-level pagination. Note: Organisations are managed at system level (by mandate). Feature-level filtering (trustee.access) is NOT applied here because: @@ -586,42 +586,18 @@ class TrusteeObjects: - trustee.access grants access to specific orgs for other users - New organisations wouldn't be visible without an access record """ - # Debug: Log user info and permissions logger.debug(f"getAllOrganisations called for user {self.userId}, mandateId: {self.mandateId}") - # System RBAC filtering (filters by mandate for GROUP access level) - records = getRecordsetWithRBAC( + return getRecordsetPaginatedWithRBAC( connector=self.db, modelClass=TrusteeOrganisation, currentUser=self.currentUser, + pagination=params, recordFilter=None, - orderBy="id", mandateId=self.mandateId, featureInstanceId=self.featureInstanceId, featureCode=self.FEATURE_CODE ) - logger.debug(f"getAllOrganisations: getRecordsetWithRBAC returned {len(records)} records") - - # Apply pagination - totalItems = len(records) - if params: - pageSize = params.pageSize or 20 - page = params.page or 1 - startIdx = (page - 1) * pageSize - endIdx = startIdx + pageSize - items = records[startIdx:endIdx] - totalPages = math.ceil(totalItems / pageSize) if pageSize > 0 else 1 - else: - items = records - totalPages = 1 - page = 1 - pageSize = totalItems - - return PaginatedResult( - items=items, - totalItems=totalItems, - totalPages=totalPages - ) def updateOrganisation(self, orgId: str, data: Dict[str, Any]) -> Optional[TrusteeOrganisation]: """Update an organisation.""" @@ -677,51 +653,40 @@ class TrusteeObjects: return None return TrusteeRole(**{k: v for k, v in records[0].items() if not k.startswith("_")}) - def getAllRoles(self, params: Optional[PaginationParams] = None) -> PaginatedResult: - """Get all roles with RBAC filtering. + def getAllRoles(self, params: Optional[PaginationParams] = None) -> Union[List[Dict], PaginatedResult]: + """Get all roles with RBAC filtering and optional DB-level pagination. Note: Roles are available to all users with trustee access. They are not filtered by organisation since they define the role types. Users with ALL access level see all roles; others need trustee.access records. + + NOTE(post-filter): Feature-level access check runs after the paginated query. + When pagination is active the totals may overcount if the user lacks any + trustee.access record (in that case the entire result is emptied). """ - records = getRecordsetWithRBAC( + result = getRecordsetPaginatedWithRBAC( connector=self.db, modelClass=TrusteeRole, currentUser=self.currentUser, + pagination=params, recordFilter=None, - orderBy="id", mandateId=self.mandateId, featureInstanceId=self.featureInstanceId, featureCode=self.FEATURE_CODE ) - # Users with ALL access level (from system RBAC) see all roles - # Others need at least one trustee.access record accessLevel = self.getRbacAccessLevel(TrusteeRole, "read") if accessLevel != AccessLevel.ALL: userAccess = self.getAllUserAccess(self.userId) if not userAccess: - records = [] # No trustee access at all + if isinstance(result, PaginatedResult): + result.items = [] + result.totalItems = 0 + result.totalPages = 0 + return result + return [] - totalItems = len(records) - if params: - pageSize = params.pageSize or 20 - page = params.page or 1 - startIdx = (page - 1) * pageSize - endIdx = startIdx + pageSize - items = records[startIdx:endIdx] - totalPages = math.ceil(totalItems / pageSize) if pageSize > 0 else 1 - else: - items = records - totalPages = 1 - page = 1 - pageSize = totalItems - - return PaginatedResult( - items=items, - totalItems=totalItems, - totalPages=totalPages - ) + return result def updateRole(self, roleId: str, data: Dict[str, Any]) -> Optional[TrusteeRole]: """Update a role (sysadmin only).""" @@ -788,116 +753,113 @@ class TrusteeObjects: return None return TrusteeAccess(**{k: v for k, v in records[0].items() if not k.startswith("_")}) - def getAllAccess(self, params: Optional[PaginationParams] = None) -> PaginatedResult: + def getAllAccess(self, params: Optional[PaginationParams] = None) -> Union[List[Dict], PaginatedResult]: """Get all access records with RBAC filtering + feature-level filtering. Users with ALL access level see all access records. Others can only see access records for organisations they have admin access to. + + NOTE(post-filter): Feature-level admin-org filtering runs after the paginated + query, so totals may overcount when the user has restricted org access. """ - # Step 1: System RBAC filtering - records = getRecordsetWithRBAC( + result = getRecordsetPaginatedWithRBAC( connector=self.db, modelClass=TrusteeAccess, currentUser=self.currentUser, + pagination=params, recordFilter=None, - orderBy="id", mandateId=self.mandateId, featureInstanceId=self.featureInstanceId, featureCode=self.FEATURE_CODE ) - # Users with ALL access level (from system RBAC) see all records accessLevel = self.getRbacAccessLevel(TrusteeAccess, "read") - + if accessLevel != AccessLevel.ALL: - # Step 2: Feature-level filtering - only see access for organisations where user is admin userAccess = self.getAllUserAccess(self.userId) - - # Get organisations where user has admin role adminOrgs = set() for access in userAccess: if access.get("roleId") == "admin": adminOrgs.add(access.get("organisationId")) - - # Filter records to only show those in admin organisations + + if isinstance(result, PaginatedResult): + if adminOrgs: + result.items = [r for r in result.items if r.get("organisationId") in adminOrgs] + else: + result.items = [] + return result + if adminOrgs: - records = [r for r in records if r.get("organisationId") in adminOrgs] + result = [r for r in result if r.get("organisationId") in adminOrgs] else: - records = [] + result = [] - totalItems = len(records) - if params: - pageSize = params.pageSize or 20 - page = params.page or 1 - startIdx = (page - 1) * pageSize - endIdx = startIdx + pageSize - items = records[startIdx:endIdx] - totalPages = math.ceil(totalItems / pageSize) if pageSize > 0 else 1 - else: - items = records - totalPages = 1 - page = 1 - pageSize = totalItems + return result - return PaginatedResult( - items=items, - totalItems=totalItems, - totalPages=totalPages - ) - - def getAccessByOrganisation(self, organisationId: str) -> List[TrusteeAccess]: + def getAccessByOrganisation(self, organisationId: str, pagination: Optional[PaginationParams] = None) -> Union[List[TrusteeAccess], PaginatedResult]: """Get all access records for a specific organisation. Requires admin role for the organisation. """ - # Check if user has admin access for this organisation if not self.checkUserTrusteePermission(self.userId, organisationId, "admin"): logger.warning(f"User {self.userId} lacks admin role for organisation {organisationId}") return [] - - records = getRecordsetWithRBAC( + + result = getRecordsetPaginatedWithRBAC( connector=self.db, modelClass=TrusteeAccess, currentUser=self.currentUser, + pagination=pagination, recordFilter={"organisationId": organisationId}, - orderBy="id", mandateId=self.mandateId, featureInstanceId=self.featureInstanceId, featureCode=self.FEATURE_CODE ) - return [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in records] - def getAccessByUser(self, userId: str) -> List[TrusteeAccess]: + if isinstance(result, PaginatedResult): + result.items = [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in result.items] + return result + return [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in result] + + def getAccessByUser(self, userId: str, pagination: Optional[PaginationParams] = None) -> Union[List[TrusteeAccess], PaginatedResult]: """Get all access records for a specific user. Users with ALL access level see all access records. Others can only see access records for organisations where they have admin role. + + NOTE(post-filter): Admin-org filtering runs after the paginated query, + so totals may overcount when the user has restricted org access. """ - records = getRecordsetWithRBAC( + result = getRecordsetPaginatedWithRBAC( connector=self.db, modelClass=TrusteeAccess, currentUser=self.currentUser, + pagination=pagination, recordFilter={"userId": userId}, - orderBy="id", mandateId=self.mandateId, featureInstanceId=self.featureInstanceId, featureCode=self.FEATURE_CODE ) - - # Users with ALL access level (from system RBAC) see all records + accessLevel = self.getRbacAccessLevel(TrusteeAccess, "read") - if accessLevel == AccessLevel.ALL: - return [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in records] - - # Filter to only organisations where current user has admin role - userAccess = self.getAllUserAccess(self.userId) - adminOrgs = set() - for access in userAccess: - if access.get("roleId") == "admin": - adminOrgs.add(access.get("organisationId")) - - filtered = [r for r in records if r.get("organisationId") in adminOrgs] - return [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in filtered] + + if accessLevel != AccessLevel.ALL: + userAccess = self.getAllUserAccess(self.userId) + adminOrgs = set() + for access in userAccess: + if access.get("roleId") == "admin": + adminOrgs.add(access.get("organisationId")) + + if isinstance(result, PaginatedResult): + result.items = [r for r in result.items if r.get("organisationId") in adminOrgs] + result.items = [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in result.items] + return result + result = [r for r in result if r.get("organisationId") in adminOrgs] + + if isinstance(result, PaginatedResult): + result.items = [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in result.items] + return result + return [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in result] def updateAccess(self, accessId: str, data: Dict[str, Any]) -> Optional[TrusteeAccess]: """Update an access record. Requires admin role for the organisation or ALL access level.""" @@ -988,54 +950,36 @@ class TrusteeObjects: return None return TrusteeContract(**{k: v for k, v in records[0].items() if not k.startswith("_")}) - def getAllContracts(self, params: Optional[PaginationParams] = None) -> PaginatedResult: - """Get all contracts with RBAC filtering + feature-level access filtering.""" - # Step 1: System RBAC filtering - records = getRecordsetWithRBAC( + def getAllContracts(self, params: Optional[PaginationParams] = None) -> Union[List[Dict], PaginatedResult]: + """Get all contracts with RBAC filtering and optional DB-level pagination.""" + return getRecordsetPaginatedWithRBAC( connector=self.db, modelClass=TrusteeContract, currentUser=self.currentUser, + pagination=params, recordFilter=None, - orderBy="id", mandateId=self.mandateId, featureInstanceId=self.featureInstanceId, featureCode=self.FEATURE_CODE ) - totalItems = len(records) - if params: - pageSize = params.pageSize or 20 - page = params.page or 1 - startIdx = (page - 1) * pageSize - endIdx = startIdx + pageSize - items = records[startIdx:endIdx] - totalPages = math.ceil(totalItems / pageSize) if pageSize > 0 else 1 - else: - items = records - totalPages = 1 - page = 1 - pageSize = totalItems - - return PaginatedResult( - items=items, - totalItems=totalItems, - totalPages=totalPages - ) - - def getContractsByOrganisation(self, organisationId: str) -> List[TrusteeContract]: + def getContractsByOrganisation(self, organisationId: str, pagination: Optional[PaginationParams] = None) -> Union[List[TrusteeContract], PaginatedResult]: """Get all contracts for a specific organisation.""" - # Step 1: System RBAC filtering - records = getRecordsetWithRBAC( + result = getRecordsetPaginatedWithRBAC( connector=self.db, modelClass=TrusteeContract, currentUser=self.currentUser, + pagination=pagination, recordFilter={"organisationId": organisationId}, - orderBy="label", mandateId=self.mandateId, featureInstanceId=self.featureInstanceId, featureCode=self.FEATURE_CODE ) - return [TrusteeContract(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in records] + + if isinstance(result, PaginatedResult): + result.items = [TrusteeContract(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in result.items] + return result + return [TrusteeContract(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in result] def updateContract(self, contractId: str, data: Dict[str, Any]) -> Optional[TrusteeContract]: """Update a contract (organisationId is immutable).""" @@ -1142,75 +1086,58 @@ class TrusteeObjects: # Legacy fallback: documentData was stored directly (for migration) return record.get("documentData") - def getAllDocuments(self, params: Optional[PaginationParams] = None) -> PaginatedResult: - """Get all documents with RBAC filtering + feature-level access filtering (metadata only).""" - # Step 1: System RBAC filtering - records = getRecordsetWithRBAC( + def getAllDocuments(self, params: Optional[PaginationParams] = None) -> Union[List[Dict], PaginatedResult]: + """Get all documents with RBAC filtering and optional DB-level pagination (metadata only). + + Filtering, sorting, and pagination are handled at the SQL level by + getRecordsetPaginatedWithRBAC. Binary documentData is stripped from + the returned items. + """ + result = getRecordsetPaginatedWithRBAC( connector=self.db, modelClass=TrusteeDocument, currentUser=self.currentUser, + pagination=params, recordFilter=None, - orderBy="documentName", mandateId=self.mandateId, featureInstanceId=self.featureInstanceId, featureCode=self.FEATURE_CODE ) - # Clean records (remove binary data and internal fields) - keep as dicts for filtering/sorting - cleanedRecords = [] - for record in records: - cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_") and k != "documentData"} - cleanedRecords.append(cleanedRecord) + def _cleanDocumentRecords(records): + return [ + TrusteeDocument(**{k: v for k, v in r.items() if not k.startswith("_") and k != "documentData"}) + for r in records + ] - # Step 2: Apply filters (search and field filters) - filteredRecords = self._applyFilters(cleanedRecords, params) - - # Step 3: Apply sorting - sortedRecords = self._applySorting(filteredRecords, params) - - # Step 4: Convert to Pydantic objects - pydanticItems = [TrusteeDocument(**r) for r in sortedRecords] + if isinstance(result, PaginatedResult): + result.items = _cleanDocumentRecords(result.items) + return result + return _cleanDocumentRecords(result) - # Step 5: Apply pagination - totalItems = len(pydanticItems) - if params: - pageSize = params.pageSize or 20 - page = params.page or 1 - startIdx = (page - 1) * pageSize - endIdx = startIdx + pageSize - items = pydanticItems[startIdx:endIdx] - totalPages = math.ceil(totalItems / pageSize) if pageSize > 0 else 1 - else: - items = pydanticItems - totalPages = 1 - page = 1 - pageSize = totalItems - - return PaginatedResult( - items=items, - totalItems=totalItems, - totalPages=totalPages - ) - - def getDocumentsByContract(self, contractId: str) -> List[TrusteeDocument]: + def getDocumentsByContract(self, contractId: str, pagination: Optional[PaginationParams] = None) -> Union[List[TrusteeDocument], PaginatedResult]: """Get all documents for a specific contract.""" - # Step 1: System RBAC filtering - records = getRecordsetWithRBAC( + result = getRecordsetPaginatedWithRBAC( connector=self.db, modelClass=TrusteeDocument, currentUser=self.currentUser, + pagination=pagination, recordFilter={"contractId": contractId}, - orderBy="documentName", mandateId=self.mandateId, featureInstanceId=self.featureInstanceId, featureCode=self.FEATURE_CODE ) - - result = [] - for record in records: - cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_") and k != "documentData"} - result.append(TrusteeDocument(**cleanedRecord)) - return result + + def _cleanDocumentRecords(records): + return [ + TrusteeDocument(**{k: v for k, v in r.items() if not k.startswith("_") and k != "documentData"}) + for r in records + ] + + if isinstance(result, PaginatedResult): + result.items = _cleanDocumentRecords(result.items) + return result + return _cleanDocumentRecords(result) def updateDocument(self, documentId: str, data: Dict[str, Any]) -> Optional[TrusteeDocument]: """Update a document. @@ -1340,101 +1267,94 @@ class TrusteeObjects: return None return self._toTrusteePositionOrDelete(records[0], deleteCorrupt=True) - def getAllPositions(self, params: Optional[PaginationParams] = None) -> PaginatedResult: - """Get all positions with RBAC filtering + feature-level access filtering.""" - # Step 1: System RBAC filtering - records = getRecordsetWithRBAC( + def getAllPositions(self, params: Optional[PaginationParams] = None) -> Union[List[Dict], PaginatedResult]: + """Get all positions with RBAC filtering and optional DB-level pagination. + + Filtering, sorting, and pagination are handled at the SQL level. + Post-processing cleans internal fields (keeps _createdAt) and validates + each record via _toTrusteePositionOrDelete (corrupt rows are deleted). + + NOTE(post-process): totalItems may slightly overcount when corrupt legacy + records are removed from the current page. + """ + result = getRecordsetPaginatedWithRBAC( connector=self.db, modelClass=TrusteePosition, currentUser=self.currentUser, + pagination=params, recordFilter=None, - orderBy="valuta", mandateId=self.mandateId, featureInstanceId=self.featureInstanceId, featureCode=self.FEATURE_CODE ) - # Clean records (remove internal fields except _createdAt) - keep as dicts for filtering/sorting - # Keep _createdAt for display in frontend keepFields = {'_createdAt'} - cleanedRecords = [] - for record in records: - cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_") or k in keepFields} - cleanedRecords.append(cleanedRecord) - # Step 2: Apply filters (search and field filters) - filteredRecords = self._applyFilters(cleanedRecords, params) - - # Step 3: Apply sorting - sortedRecords = self._applySorting(filteredRecords, params) - - # Step 4: Convert to Pydantic objects and cleanup corrupt legacy records. - pydanticItems = [] - for record in sortedRecords: - position = self._toTrusteePositionOrDelete(record, deleteCorrupt=True) - if position is not None: - pydanticItems.append(position) + def _cleanAndValidate(records): + items = [] + for record in records: + cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_") or k in keepFields} + position = self._toTrusteePositionOrDelete(cleanedRecord, deleteCorrupt=True) + if position is not None: + items.append(position) + return items - # Step 5: Apply pagination - totalItems = len(pydanticItems) - if params: - pageSize = params.pageSize or 20 - page = params.page or 1 - startIdx = (page - 1) * pageSize - endIdx = startIdx + pageSize - items = pydanticItems[startIdx:endIdx] - totalPages = math.ceil(totalItems / pageSize) if pageSize > 0 else 1 - else: - items = pydanticItems - totalPages = 1 - page = 1 - pageSize = totalItems + if isinstance(result, PaginatedResult): + result.items = _cleanAndValidate(result.items) + return result + return _cleanAndValidate(result) - return PaginatedResult( - items=items, - totalItems=totalItems, - totalPages=totalPages - ) - - def getPositionsByContract(self, contractId: str) -> List[TrusteePosition]: + def getPositionsByContract(self, contractId: str, pagination: Optional[PaginationParams] = None) -> Union[List[TrusteePosition], PaginatedResult]: """Get all positions for a specific contract.""" - # Step 1: System RBAC filtering - records = getRecordsetWithRBAC( + result = getRecordsetPaginatedWithRBAC( connector=self.db, modelClass=TrusteePosition, currentUser=self.currentUser, + pagination=pagination, recordFilter={"contractId": contractId}, - orderBy="valuta", mandateId=self.mandateId, featureInstanceId=self.featureInstanceId, featureCode=self.FEATURE_CODE ) - safeItems = [] - for record in records: - position = self._toTrusteePositionOrDelete(record, deleteCorrupt=True) - if position is not None: - safeItems.append(position) - return safeItems - def getPositionsByOrganisation(self, organisationId: str) -> List[TrusteePosition]: + def _validatePositions(records): + items = [] + for record in records: + position = self._toTrusteePositionOrDelete(record, deleteCorrupt=True) + if position is not None: + items.append(position) + return items + + if isinstance(result, PaginatedResult): + result.items = _validatePositions(result.items) + return result + return _validatePositions(result) + + def getPositionsByOrganisation(self, organisationId: str, pagination: Optional[PaginationParams] = None) -> Union[List[TrusteePosition], PaginatedResult]: """Get all positions for a specific organisation.""" - # Step 1: System RBAC filtering - records = getRecordsetWithRBAC( + result = getRecordsetPaginatedWithRBAC( connector=self.db, modelClass=TrusteePosition, currentUser=self.currentUser, + pagination=pagination, recordFilter={"organisationId": organisationId}, - orderBy="valuta", mandateId=self.mandateId, featureInstanceId=self.featureInstanceId, featureCode=self.FEATURE_CODE ) - safeItems = [] - for record in records: - position = self._toTrusteePositionOrDelete(record, deleteCorrupt=True) - if position is not None: - safeItems.append(position) - return safeItems + + def _validatePositions(records): + items = [] + for record in records: + position = self._toTrusteePositionOrDelete(record, deleteCorrupt=True) + if position is not None: + items.append(position) + return items + + if isinstance(result, PaginatedResult): + result.items = _validatePositions(result.items) + return result + return _validatePositions(result) def updatePosition(self, positionId: str, data: Dict[str, Any]) -> Optional[TrusteePosition]: """Update a position. @@ -1481,24 +1401,31 @@ class TrusteeObjects: # ===== Position-Document Queries ===== - def getPositionsByDocument(self, documentId: str) -> List[TrusteePosition]: + def getPositionsByDocument(self, documentId: str, pagination: Optional[PaginationParams] = None) -> Union[List[TrusteePosition], PaginatedResult]: """Get all positions that reference a specific document (1:N via documentId FK).""" - records = getRecordsetWithRBAC( + result = getRecordsetPaginatedWithRBAC( connector=self.db, modelClass=TrusteePosition, currentUser=self.currentUser, + pagination=pagination, recordFilter={"documentId": documentId}, - orderBy="valuta", mandateId=self.mandateId, featureInstanceId=self.featureInstanceId, featureCode=self.FEATURE_CODE ) - safeItems = [] - for record in records: - position = self._toTrusteePositionOrDelete(record, deleteCorrupt=True) - if position is not None: - safeItems.append(position) - return safeItems + + def _validatePositions(records): + items = [] + for record in records: + position = self._toTrusteePositionOrDelete(record, deleteCorrupt=True) + if position is not None: + items.append(position) + return items + + if isinstance(result, PaginatedResult): + result.items = _validatePositions(result.items) + return result + return _validatePositions(result) # ===== Trustee-specific Access Check ===== diff --git a/modules/features/trustee/routeFeatureTrustee.py b/modules/features/trustee/routeFeatureTrustee.py index feb873ae..13b28b07 100644 --- a/modules/features/trustee/routeFeatureTrustee.py +++ b/modules/features/trustee/routeFeatureTrustee.py @@ -338,7 +338,7 @@ def get_organisations( interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) result = interface.getAllOrganisations(paginationParams) - if paginationParams: + if paginationParams and hasattr(result, 'items'): return PaginatedResponse( items=result.items, pagination=PaginationMetadata( @@ -350,7 +350,7 @@ def get_organisations( filters=paginationParams.filters if paginationParams else None ) ) - return PaginatedResponse(items=result.items, pagination=None) + return PaginatedResponse(items=result if isinstance(result, list) else result.items, pagination=None) @router.get("/{instanceId}/organisations/{orgId}", response_model=TrusteeOrganisation) @@ -451,7 +451,7 @@ def get_roles( interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) result = interface.getAllRoles(paginationParams) - if paginationParams: + if paginationParams and hasattr(result, 'items'): return PaginatedResponse( items=result.items, pagination=PaginationMetadata( @@ -463,7 +463,7 @@ def get_roles( filters=paginationParams.filters if paginationParams else None ) ) - return PaginatedResponse(items=result.items, pagination=None) + return PaginatedResponse(items=result if isinstance(result, list) else result.items, pagination=None) @router.get("/{instanceId}/roles/{roleId}", response_model=TrusteeRole) @@ -564,7 +564,7 @@ def get_all_access( interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) result = interface.getAllAccess(paginationParams) - if paginationParams: + if paginationParams and hasattr(result, 'items'): return PaginatedResponse( items=result.items, pagination=PaginationMetadata( @@ -576,7 +576,7 @@ def get_all_access( filters=paginationParams.filters if paginationParams else None ) ) - return PaginatedResponse(items=result.items, pagination=None) + return PaginatedResponse(items=result if isinstance(result, list) else result.items, pagination=None) @router.get("/{instanceId}/access/{accessId}", response_model=TrusteeAccess) @@ -707,7 +707,7 @@ def get_contracts( interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) result = interface.getAllContracts(paginationParams) - if paginationParams: + if paginationParams and hasattr(result, 'items'): return PaginatedResponse( items=result.items, pagination=PaginationMetadata( @@ -719,7 +719,7 @@ def get_contracts( filters=paginationParams.filters if paginationParams else None ) ) - return PaginatedResponse(items=result.items, pagination=None) + return PaginatedResponse(items=result if isinstance(result, list) else result.items, pagination=None) @router.get("/{instanceId}/contracts/{contractId}", response_model=TrusteeContract) @@ -835,7 +835,7 @@ def get_documents( interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) result = interface.getAllDocuments(paginationParams) - if paginationParams: + if paginationParams and hasattr(result, 'items'): return PaginatedResponse( items=result.items, pagination=PaginationMetadata( @@ -847,7 +847,59 @@ def get_documents( filters=paginationParams.filters if paginationParams else None ) ) - return PaginatedResponse(items=result.items, pagination=None) + return PaginatedResponse(items=result if isinstance(result, list) else result.items, pagination=None) + + +@router.get("/{instanceId}/documents/filter-values") +@limiter.limit("60/minute") +def get_document_filter_values( + request: Request, + instanceId: str = Path(..., description="Feature Instance ID"), + column: str = Query(..., description="Column key"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters"), + context: RequestContext = Depends(getRequestContext) +) -> list: + """Return distinct filter values for a column in trustee documents.""" + mandateId = _validateInstanceAccess(instanceId, context) + try: + from modules.interfaces.interfaceRbac import getDistinctColumnValuesWithRBAC + interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) + + crossFilterPagination = None + if pagination: + try: + paginationDict = json.loads(pagination) + if paginationDict: + paginationDict = normalize_pagination_dict(paginationDict) + filters = paginationDict.get("filters", {}) + filters.pop(column, None) + paginationDict["filters"] = filters + paginationDict.pop("sort", None) + crossFilterPagination = PaginationParams(**paginationDict) + except (json.JSONDecodeError, ValueError): + pass + + try: + values = getDistinctColumnValuesWithRBAC( + connector=interface.db, + modelClass=TrusteeDocument, + column=column, + currentUser=interface.currentUser, + pagination=crossFilterPagination, + recordFilter=None, + mandateId=interface.mandateId, + featureInstanceId=interface.featureInstanceId, + featureCode=interface.FEATURE_CODE + ) + return sorted(values, key=lambda v: str(v).lower()) + except Exception: + from modules.routes.routeDataUsers import _handleFilterValuesRequest + result = interface.getAllDocuments(None) + items = [r.model_dump() if hasattr(r, 'model_dump') else r for r in (result.items if hasattr(result, 'items') else result)] + return _handleFilterValuesRequest(items, column, pagination) + except Exception as e: + logger.error(f"Error getting filter values for trustee documents: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) @router.get("/{instanceId}/documents/{documentId}", response_model=TrusteeDocument) @@ -1039,7 +1091,7 @@ def get_positions( interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) result = interface.getAllPositions(paginationParams) - if paginationParams: + if paginationParams and hasattr(result, 'items'): return PaginatedResponse( items=result.items, pagination=PaginationMetadata( @@ -1051,7 +1103,59 @@ def get_positions( filters=paginationParams.filters if paginationParams else None ) ) - return PaginatedResponse(items=result.items, pagination=None) + return PaginatedResponse(items=result if isinstance(result, list) else result.items, pagination=None) + + +@router.get("/{instanceId}/positions/filter-values") +@limiter.limit("60/minute") +def get_position_filter_values( + request: Request, + instanceId: str = Path(..., description="Feature Instance ID"), + column: str = Query(..., description="Column key"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters"), + context: RequestContext = Depends(getRequestContext) +) -> list: + """Return distinct filter values for a column in trustee positions.""" + mandateId = _validateInstanceAccess(instanceId, context) + try: + from modules.interfaces.interfaceRbac import getDistinctColumnValuesWithRBAC + interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId) + + crossFilterPagination = None + if pagination: + try: + paginationDict = json.loads(pagination) + if paginationDict: + paginationDict = normalize_pagination_dict(paginationDict) + filters = paginationDict.get("filters", {}) + filters.pop(column, None) + paginationDict["filters"] = filters + paginationDict.pop("sort", None) + crossFilterPagination = PaginationParams(**paginationDict) + except (json.JSONDecodeError, ValueError): + pass + + try: + values = getDistinctColumnValuesWithRBAC( + connector=interface.db, + modelClass=TrusteePosition, + column=column, + currentUser=interface.currentUser, + pagination=crossFilterPagination, + recordFilter=None, + mandateId=interface.mandateId, + featureInstanceId=interface.featureInstanceId, + featureCode=interface.FEATURE_CODE + ) + return sorted(values, key=lambda v: str(v).lower()) + except Exception: + from modules.routes.routeDataUsers import _handleFilterValuesRequest + result = interface.getAllPositions(None) + items = [r.model_dump() if hasattr(r, 'model_dump') else r for r in (result.items if hasattr(result, 'items') else result)] + return _handleFilterValuesRequest(items, column, pagination) + except Exception as e: + logger.error(f"Error getting filter values for trustee positions: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) @router.get("/{instanceId}/positions/{positionId}", response_model=TrusteePosition) diff --git a/modules/features/workspace/datamodelFeatureWorkspace.py b/modules/features/workspace/datamodelFeatureWorkspace.py new file mode 100644 index 00000000..80da5915 --- /dev/null +++ b/modules/features/workspace/datamodelFeatureWorkspace.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +"""Workspace feature data models — VoiceSettings and WorkspaceUserSettings.""" + +from typing import Dict, Any, Optional +from pydantic import BaseModel, Field +from modules.shared.attributeUtils import registerModelLabels +from modules.shared.timeUtils import getUtcTimestamp +import uuid + + +class VoiceSettings(BaseModel): + id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}) + userId: str = Field(description="ID of the user these settings belong to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True}) + mandateId: str = Field(description="ID of the mandate these settings belong to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True}) + featureInstanceId: str = Field(description="ID of the feature instance these settings belong to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True}) + sttLanguage: str = Field(default="de-DE", description="Speech-to-Text language", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True}) + ttsLanguage: str = Field(default="de-DE", description="Text-to-Speech language", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True}) + ttsVoice: str = Field(default="de-DE-KatjaNeural", description="Text-to-Speech voice", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True}) + ttsVoiceMap: Dict[str, Any] = Field(default_factory=dict, description="Per-language voice mapping, e.g. {'de-DE': {'voiceName': 'de-DE-Wavenet-A'}, 'en-US': {'voiceName': 'en-US-Wavenet-C'}}", json_schema_extra={"frontend_type": "json", "frontend_readonly": False, "frontend_required": False}) + translationEnabled: bool = Field(default=True, description="Whether translation is enabled", json_schema_extra={"frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False}) + targetLanguage: str = Field(default="en-US", description="Target language for translation", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False}) + creationDate: float = Field(default_factory=getUtcTimestamp, description="Date when the settings were created (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False}) + lastModified: float = Field(default_factory=getUtcTimestamp, description="Date when the settings were last modified (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False}) + + +class WorkspaceUserSettings(BaseModel): + """Per-user workspace settings. None values mean 'use instance default'.""" + id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False}) + userId: str = Field(description="User ID", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True}) + mandateId: str = Field(description="Mandate ID", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True}) + featureInstanceId: str = Field(description="Feature Instance ID", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True}) + maxAgentRounds: Optional[int] = Field(default=None, description="Max agent rounds override (None = instance default)", json_schema_extra={"frontend_type": "number", "frontend_readonly": False, "frontend_required": False}) + + +registerModelLabels( + "VoiceSettings", + {"en": "Voice Settings", "fr": "Paramètres vocaux"}, + { + "id": {"en": "ID", "fr": "ID"}, + "userId": {"en": "User ID", "fr": "ID utilisateur"}, + "mandateId": {"en": "Mandate ID", "fr": "ID du mandat"}, + "featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance de fonctionnalité"}, + "sttLanguage": {"en": "STT Language", "fr": "Langue STT"}, + "ttsLanguage": {"en": "TTS Language", "fr": "Langue TTS"}, + "ttsVoice": {"en": "TTS Voice", "fr": "Voix TTS"}, + "ttsVoiceMap": {"en": "TTS Voice Map", "fr": "Carte des voix TTS"}, + "translationEnabled": {"en": "Translation Enabled", "fr": "Traduction activée"}, + "targetLanguage": {"en": "Target Language", "fr": "Langue cible"}, + "creationDate": {"en": "Creation Date", "fr": "Date de création"}, + "lastModified": {"en": "Last Modified", "fr": "Dernière modification"}, + }, +) + +registerModelLabels( + "WorkspaceUserSettings", + {"en": "Workspace User Settings", "de": "Workspace Benutzereinstellungen"}, + { + "id": {"en": "ID", "de": "ID"}, + "userId": {"en": "User ID", "de": "Benutzer-ID"}, + "mandateId": {"en": "Mandate ID", "de": "Mandanten-ID"}, + "featureInstanceId": {"en": "Feature Instance ID", "de": "Feature-Instanz-ID"}, + "maxAgentRounds": {"en": "Max Agent Rounds", "de": "Max. Agenten-Runden"}, + }, +) diff --git a/modules/features/workspace/interfaceFeatureWorkspace.py b/modules/features/workspace/interfaceFeatureWorkspace.py new file mode 100644 index 00000000..bd1a03c4 --- /dev/null +++ b/modules/features/workspace/interfaceFeatureWorkspace.py @@ -0,0 +1,248 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +""" +Interface for Workspace feature — manages VoiceSettings and WorkspaceUserSettings. +Uses a dedicated poweron_workspace database. +""" + +import logging +from typing import Dict, Any, Optional + +from modules.connectors.connectorDbPostgre import DatabaseConnector +from modules.datamodels.datamodelUam import User +from modules.features.workspace.datamodelFeatureWorkspace import VoiceSettings, WorkspaceUserSettings +from modules.interfaces.interfaceRbac import getRecordsetWithRBAC +from modules.security.rbac import RbacClass +from modules.shared.configuration import APP_CONFIG +from modules.shared.timeUtils import getUtcTimestamp + +logger = logging.getLogger(__name__) + +_workspaceInterfaces: Dict[str, "WorkspaceObjects"] = {} + + +class WorkspaceObjects: + """Interface for Workspace-specific database operations (voice + general settings).""" + + def __init__(self, currentUser: User, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None): + self.currentUser = currentUser + self.mandateId = mandateId + self.featureInstanceId = featureInstanceId + self.userId = currentUser.id if currentUser else None + + self._initializeDatabase() + + from modules.security.rootAccess import getRootDbAppConnector + dbApp = getRootDbAppConnector() + self.rbac = RbacClass(self.db, dbApp=dbApp) + + self.db.updateContext(self.userId) + + def _initializeDatabase(self): + dbHost = APP_CONFIG.get("DB_HOST", "_no_config_default_data") + dbDatabase = "poweron_workspace" + dbUser = APP_CONFIG.get("DB_USER") + dbPassword = APP_CONFIG.get("DB_PASSWORD_SECRET") + dbPort = int(APP_CONFIG.get("DB_PORT", 5432)) + + self.db = DatabaseConnector( + dbHost=dbHost, + dbDatabase=dbDatabase, + dbUser=dbUser, + dbPassword=dbPassword, + dbPort=dbPort, + userId=self.userId, + ) + logger.debug(f"Workspace database initialized for user {self.userId}") + + def setUserContext(self, currentUser: User, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None): + self.currentUser = currentUser + self.userId = currentUser.id if currentUser else None + self.mandateId = mandateId + self.featureInstanceId = featureInstanceId + self.db.updateContext(self.userId) + + # ========================================================================= + # VoiceSettings CRUD + # ========================================================================= + + def getVoiceSettings(self, userId: Optional[str] = None) -> Optional[VoiceSettings]: + try: + targetUserId = userId or self.userId + if not targetUserId: + logger.error("No user ID provided for voice settings") + return None + + recordFilter: Dict[str, Any] = {"userId": targetUserId} + if self.featureInstanceId: + recordFilter["featureInstanceId"] = self.featureInstanceId + + filteredSettings = getRecordsetWithRBAC( + self.db, VoiceSettings, self.currentUser, + recordFilter=recordFilter, mandateId=self.mandateId, + ) + + if not filteredSettings: + return None + + settingsData = filteredSettings[0] + if not settingsData.get("creationDate"): + settingsData["creationDate"] = getUtcTimestamp() + if not settingsData.get("lastModified"): + settingsData["lastModified"] = getUtcTimestamp() + + return VoiceSettings(**settingsData) + + except Exception as e: + logger.error(f"Error getting voice settings: {e}") + return None + + def createVoiceSettings(self, settingsData: Dict[str, Any]) -> Dict[str, Any]: + try: + if "userId" not in settingsData: + settingsData["userId"] = self.userId + if "mandateId" not in settingsData: + settingsData["mandateId"] = self.mandateId + if "featureInstanceId" not in settingsData: + settingsData["featureInstanceId"] = self.featureInstanceId + + existing = self.getVoiceSettings(settingsData["userId"]) + if existing: + raise ValueError(f"Voice settings already exist for user {settingsData['userId']}") + + createdRecord = self.db.recordCreate(VoiceSettings, settingsData) + if not createdRecord or not createdRecord.get("id"): + raise ValueError("Failed to create voice settings record") + + logger.info(f"Created voice settings for user {settingsData['userId']}") + return createdRecord + + except Exception as e: + logger.error(f"Error creating voice settings: {e}") + raise + + def updateVoiceSettings(self, userId: str, updateData: Dict[str, Any]) -> Dict[str, Any]: + try: + existing = self.getVoiceSettings(userId) + if not existing: + raise ValueError(f"Voice settings not found for user {userId}") + + updateData["lastModified"] = getUtcTimestamp() + success = self.db.recordModify(VoiceSettings, existing.id, updateData) + if not success: + raise ValueError("Failed to update voice settings record") + + updated = self.getVoiceSettings(userId) + if not updated: + raise ValueError("Failed to retrieve updated voice settings") + + logger.info(f"Updated voice settings for user {userId}") + return updated.model_dump() + + except Exception as e: + logger.error(f"Error updating voice settings: {e}") + raise + + def deleteVoiceSettings(self, userId: str) -> bool: + try: + existing = self.getVoiceSettings(userId) + if not existing: + return False + success = self.db.recordDelete(VoiceSettings, existing.id) + if success: + logger.info(f"Deleted voice settings for user {userId}") + return success + except Exception as e: + logger.error(f"Error deleting voice settings: {e}") + return False + + def getOrCreateVoiceSettings(self, userId: Optional[str] = None) -> VoiceSettings: + targetUserId = userId or self.userId + if not targetUserId: + raise ValueError("No user ID provided for voice settings") + + existing = self.getVoiceSettings(targetUserId) + if existing: + return existing + + defaultSettings = { + "userId": targetUserId, + "mandateId": self.mandateId, + "featureInstanceId": self.featureInstanceId, + "sttLanguage": "de-DE", + "ttsLanguage": "de-DE", + "ttsVoice": "de-DE-KatjaNeural", + "translationEnabled": True, + "targetLanguage": "en-US", + } + createdRecord = self.createVoiceSettings(defaultSettings) + return VoiceSettings(**createdRecord) + + # ========================================================================= + # WorkspaceUserSettings CRUD + # ========================================================================= + + def getWorkspaceUserSettings(self, userId: Optional[str] = None) -> Optional[WorkspaceUserSettings]: + targetUserId = userId or self.userId + if not targetUserId: + return None + + try: + recordFilter: Dict[str, Any] = {"userId": targetUserId} + if self.featureInstanceId: + recordFilter["featureInstanceId"] = self.featureInstanceId + + records = getRecordsetWithRBAC( + self.db, WorkspaceUserSettings, self.currentUser, + recordFilter=recordFilter, mandateId=self.mandateId, + ) + if not records: + return None + return WorkspaceUserSettings(**records[0]) + + except Exception as e: + logger.error(f"Error getting workspace user settings: {e}") + return None + + def saveWorkspaceUserSettings(self, data: Dict[str, Any]) -> WorkspaceUserSettings: + """Upsert: create or update workspace user settings for the current user.""" + targetUserId = data.get("userId", self.userId) + if not targetUserId: + raise ValueError("No user ID provided") + + existing = self.getWorkspaceUserSettings(targetUserId) + if existing: + updateData = {k: v for k, v in data.items() if k not in ("id", "userId", "mandateId", "featureInstanceId")} + self.db.recordModify(WorkspaceUserSettings, existing.id, updateData) + updated = self.getWorkspaceUserSettings(targetUserId) + if not updated: + raise ValueError("Failed to retrieve updated workspace user settings") + return updated + + createData = { + "userId": targetUserId, + "mandateId": data.get("mandateId", self.mandateId), + "featureInstanceId": data.get("featureInstanceId", self.featureInstanceId), + } + createData.update({k: v for k, v in data.items() if k not in ("id", "userId", "mandateId", "featureInstanceId")}) + createdRecord = self.db.recordCreate(WorkspaceUserSettings, createData) + if not createdRecord or not createdRecord.get("id"): + raise ValueError("Failed to create workspace user settings") + return WorkspaceUserSettings(**createdRecord) + + +def getInterface(currentUser: Optional[User] = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> WorkspaceObjects: + if not currentUser: + raise ValueError("Invalid user context: user is required") + + effectiveMandateId = str(mandateId) if mandateId else None + effectiveFeatureInstanceId = str(featureInstanceId) if featureInstanceId else None + + contextKey = f"workspace_{effectiveMandateId}_{effectiveFeatureInstanceId}_{currentUser.id}" + + if contextKey not in _workspaceInterfaces: + _workspaceInterfaces[contextKey] = WorkspaceObjects(currentUser, mandateId=effectiveMandateId, featureInstanceId=effectiveFeatureInstanceId) + else: + _workspaceInterfaces[contextKey].setUserContext(currentUser, mandateId=effectiveMandateId, featureInstanceId=effectiveFeatureInstanceId) + + return _workspaceInterfaces[contextKey] diff --git a/modules/features/workspace/mainWorkspace.py b/modules/features/workspace/mainWorkspace.py index 81526414..c502a82e 100644 --- a/modules/features/workspace/mainWorkspace.py +++ b/modules/features/workspace/mainWorkspace.py @@ -31,6 +31,15 @@ UI_OBJECTS = [ "label": {"en": "Settings", "de": "Einstellungen", "fr": "Parametres"}, "meta": {"area": "settings"} }, + { + "objectKey": "ui.feature.workspace.rag-insights", + "label": { + "en": "Knowledge insights", + "de": "Wissens-Insights", + "fr": "Aperçu des connaissances", + }, + "meta": {"area": "rag-insights"}, + }, ] RESOURCE_OBJECTS = [ @@ -83,6 +92,7 @@ TEMPLATE_ROLES = [ {"context": "UI", "item": "ui.feature.workspace.dashboard", "view": True}, {"context": "UI", "item": "ui.feature.workspace.editor", "view": True}, {"context": "UI", "item": "ui.feature.workspace.settings", "view": True}, + {"context": "UI", "item": "ui.feature.workspace.rag-insights", "view": True}, {"context": "DATA", "item": None, "view": True, "read": "m", "create": "n", "update": "n", "delete": "n"}, ] }, @@ -97,6 +107,7 @@ TEMPLATE_ROLES = [ {"context": "UI", "item": "ui.feature.workspace.dashboard", "view": True}, {"context": "UI", "item": "ui.feature.workspace.editor", "view": True}, {"context": "UI", "item": "ui.feature.workspace.settings", "view": True}, + {"context": "UI", "item": "ui.feature.workspace.rag-insights", "view": True}, {"context": "RESOURCE", "item": "resource.feature.workspace.start", "view": True}, {"context": "RESOURCE", "item": "resource.feature.workspace.stop", "view": True}, {"context": "RESOURCE", "item": "resource.feature.workspace.files", "view": True}, @@ -215,6 +226,8 @@ def _syncTemplateRolesToDb() -> int: if createdCount > 0: logger.info(f"Feature '{FEATURE_CODE}': Created {createdCount} template roles") + _repairWorkspaceUserInstanceUiNav(rootInterface) + return createdCount except Exception as e: @@ -222,6 +235,57 @@ def _syncTemplateRolesToDb() -> int: return 0 +def _repairWorkspaceUserInstanceUiNav(rootInterface) -> int: + """ + Ensure every instance-scoped workspace-user role grants UI view on Editor and Settings. + Covers older instance roles copied before template updates (bootstrap / app startup). + """ + from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext, Role + + workspaceNavObjectKeys = ( + "ui.feature.workspace.editor", + "ui.feature.workspace.settings", + ) + repairCount = 0 + try: + userRoleRecords = rootInterface.db.getRecordset( + Role, + recordFilter={"featureCode": FEATURE_CODE, "roleLabel": "workspace-user"}, + ) + for roleRec in userRoleRecords or []: + if not roleRec.get("featureInstanceId"): + continue + roleId = str(roleRec.get("id")) + accessRules = rootInterface.getAccessRulesByRole(roleId) + rulesByUiKey = {(r.context, r.item): r for r in accessRules} + for objectKey in workspaceNavObjectKeys: + uiKey = (AccessRuleContext.UI, objectKey) + existingRule = rulesByUiKey.get(uiKey) + if existingRule is None: + newRule = AccessRule( + roleId=roleId, + context=AccessRuleContext.UI, + item=objectKey, + view=True, + read=None, + create=None, + update=None, + delete=None, + ) + rootInterface.db.recordCreate(AccessRule, newRule.model_dump()) + repairCount += 1 + elif not existingRule.view: + rootInterface.db.recordModify(AccessRule, str(existingRule.id), {"view": True}) + repairCount += 1 + if repairCount: + logger.info( + f"Feature '{FEATURE_CODE}': Repaired {repairCount} UI AccessRules for instance workspace-user roles (Editor/Settings)" + ) + except Exception as e: + logger.warning(f"Feature '{FEATURE_CODE}': workspace-user UI nav repair skipped: {e}") + return repairCount + + def _ensureAccessRulesForRole(rootInterface, roleId: str, ruleTemplates: List[Dict[str, Any]]) -> int: """Ensure AccessRules exist for a role based on templates.""" from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext diff --git a/modules/features/workspace/routeFeatureWorkspace.py b/modules/features/workspace/routeFeatureWorkspace.py index d0dd22da..6b8c529b 100644 --- a/modules/features/workspace/routeFeatureWorkspace.py +++ b/modules/features/workspace/routeFeatureWorkspace.py @@ -19,7 +19,12 @@ from modules.auth import limiter, getRequestContext, RequestContext from modules.serviceCenter.services.serviceBilling.mainServiceBilling import ( InsufficientBalanceException, ) +from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import ( + SubscriptionInactiveException, +) from modules.interfaces import interfaceDbChat, interfaceDbManagement +from modules.features.workspace import interfaceFeatureWorkspace +from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface from modules.interfaces.interfaceAiObjects import AiObjects from modules.serviceCenter.core.serviceStreaming import get_event_manager from modules.serviceCenter.services.serviceAgent.datamodelAgent import AgentEventTypeEnum, PendingFileEdit @@ -142,6 +147,14 @@ def _getDbManagement(context: RequestContext, featureInstanceId: str = None): ) +def _getWorkspaceInterface(context: RequestContext, featureInstanceId: str = None): + return interfaceFeatureWorkspace.getInterface( + context.user, + mandateId=str(context.mandateId) if context.mandateId else None, + featureInstanceId=featureInstanceId, + ) + + _SOURCE_TYPE_TO_SERVICE = { "sharepointFolder": "sharepoint", "onedriveFolder": "onedrive", @@ -698,7 +711,17 @@ async def _runWorkspaceAgent( _toolSet = _cfg.get("toolSet", "core") _agentCfg = _cfg.get("agentConfig") from modules.serviceCenter.services.serviceAgent.datamodelAgent import AgentConfig - agentConfig = AgentConfig(**_agentCfg) if isinstance(_agentCfg, dict) else None + + agentCfgDict = dict(_agentCfg) if isinstance(_agentCfg, dict) else {} + try: + wsIf = interfaceFeatureWorkspace.getInterface(user, mandateId=mandateId, featureInstanceId=instanceId) + userSettings = wsIf.getWorkspaceUserSettings(user.id if user else None) + if userSettings and userSettings.maxAgentRounds is not None: + agentCfgDict["maxRounds"] = userSettings.maxAgentRounds + except Exception as e: + logger.debug(f"Could not load workspace user settings for agent config: {e}") + + agentConfig = AgentConfig(**agentCfgDict) if agentCfgDict else None async for event in agentService.runAgent( prompt=enrichedPrompt, @@ -803,7 +826,15 @@ async def _runWorkspaceAgent( }) except Exception as e: - if isinstance(e, InsufficientBalanceException): + if isinstance(e, SubscriptionInactiveException): + logger.warning(f"Workspace blocked by subscription: {e.message}") + await eventManager.emit_event(queueId, "error", { + "type": "error", + "content": e.message, + "workflowId": workflowId, + "item": e.toClientDict(), + }) + elif isinstance(e, InsufficientBalanceException): logger.warning(f"Workspace blocked by billing: {e.message}") await eventManager.emit_event(queueId, "error", { "type": "error", @@ -1564,13 +1595,13 @@ async def getVoiceSettings( ): """Load voice settings for the current user and instance.""" _validateInstanceAccess(instanceId, context) - dbMgmt = _getDbManagement(context, instanceId) + wsInterface = _getWorkspaceInterface(context, instanceId) userId = str(context.user.id) try: - vs = dbMgmt.getVoiceSettings(userId) + vs = wsInterface.getVoiceSettings(userId) if not vs: logger.info(f"GET voice settings: not found for user={userId}, creating defaults") - vs = dbMgmt.getOrCreateVoiceSettings(userId) + vs = wsInterface.getOrCreateVoiceSettings(userId) result = vs.model_dump() if vs else {} mapKeys = list(result.get("ttsVoiceMap", {}).keys()) if result else [] logger.info(f"GET voice settings for user={userId}: ttsVoiceMap languages={mapKeys}") @@ -1590,12 +1621,12 @@ async def updateVoiceSettings( ): """Update voice settings for the current user and instance.""" _validateInstanceAccess(instanceId, context) - dbMgmt = _getDbManagement(context, instanceId) + wsInterface = _getWorkspaceInterface(context, instanceId) userId = str(context.user.id) try: logger.info(f"PUT voice settings for user={userId}, instance={instanceId}, body keys={list(body.keys())}") - vs = dbMgmt.getVoiceSettings(userId) + vs = wsInterface.getVoiceSettings(userId) if not vs: logger.info(f"No existing voice settings, creating new for user={userId}") createData = { @@ -1604,13 +1635,13 @@ async def updateVoiceSettings( "featureInstanceId": instanceId, } createData.update(body) - created = dbMgmt.createVoiceSettings(createData) + created = wsInterface.createVoiceSettings(createData) logger.info(f"Created voice settings for user={userId}, ttsVoiceMap keys={list((created or {}).get('ttsVoiceMap', {}).keys())}") return JSONResponse(created) updateData = {k: v for k, v in body.items() if k not in ("id", "userId", "mandateId", "featureInstanceId", "creationDate")} logger.info(f"Updating voice settings for user={userId}, update keys={list(updateData.keys())}") - updated = dbMgmt.updateVoiceSettings(userId, updateData) + updated = wsInterface.updateVoiceSettings(userId, updateData) logger.info(f"Updated voice settings for user={userId}, ttsVoiceMap keys={list((updated or {}).get('ttsVoiceMap', {}).keys())}") return JSONResponse(updated) except Exception as e: @@ -1816,3 +1847,113 @@ async def rejectAllEdits( logger.info(f"Rejected {len(rejected)} edits for instance {instanceId}") return JSONResponse({"rejected": rejected}) + + +# ========================================================================= +# General Settings Endpoints (per-user workspace settings) +# ========================================================================= + +@router.get("/{instanceId}/settings/general") +@limiter.limit("120/minute") +async def getGeneralSettings( + request: Request, + instanceId: str = Path(...), + context: RequestContext = Depends(getRequestContext), +): + """Load general workspace settings for the current user, with effective values.""" + _mandateId, instanceConfig = _validateInstanceAccess(instanceId, context) + wsInterface = _getWorkspaceInterface(context, instanceId) + userId = str(context.user.id) + + userSettings = wsInterface.getWorkspaceUserSettings(userId) + + agentCfg = (instanceConfig or {}).get("agentConfig", {}) + instanceDefault = agentCfg.get("maxRounds", 25) if isinstance(agentCfg, dict) else 25 + + userOverride = userSettings.maxAgentRounds if userSettings else None + effective = userOverride if userOverride is not None else instanceDefault + + return JSONResponse({ + "maxAgentRounds": { + "effective": effective, + "userOverride": userOverride, + "instanceDefault": instanceDefault, + }, + }) + + +@router.put("/{instanceId}/settings/general") +@limiter.limit("120/minute") +async def updateGeneralSettings( + request: Request, + instanceId: str = Path(...), + body: dict = Body(...), + context: RequestContext = Depends(getRequestContext), +): + """Update general workspace settings for the current user.""" + _validateInstanceAccess(instanceId, context) + wsInterface = _getWorkspaceInterface(context, instanceId) + userId = str(context.user.id) + + data = { + "userId": userId, + "mandateId": str(context.mandateId) if context.mandateId else "", + "featureInstanceId": instanceId, + } + if "maxAgentRounds" in body: + val = body["maxAgentRounds"] + data["maxAgentRounds"] = int(val) if val is not None else None + + wsInterface.saveWorkspaceUserSettings(data) + + return await getGeneralSettings(request, instanceId, context) + + +# ========================================================================= +# RAG / Knowledge — anonymised instance statistics (presentation / KPIs) +# ========================================================================= + +def _collectWorkspaceFileIdsForStats(instanceId: str, mandateId: Optional[str]) -> List[str]: + """All FileItem ids for this feature instance (any user). Knowledge rows are often stored + without featureInstanceId; we correlate by file id from the Management DB.""" + from modules.datamodels.datamodelFiles import FileItem + from modules.interfaces.interfaceDbManagement import ComponentObjects + + co = ComponentObjects() + rows = co.db.getRecordset(FileItem, recordFilter={"featureInstanceId": instanceId}) + out: List[str] = [] + m = str(mandateId) if mandateId else "" + for r in rows or []: + rid = r.get("id") if isinstance(r, dict) else getattr(r, "id", None) + if not rid: + continue + if m: + mid = r.get("mandateId") if isinstance(r, dict) else getattr(r, "mandateId", "") or "" + if mid and mid != m: + continue + out.append(str(rid)) + return out + + +@router.get("/{instanceId}/rag-statistics") +@limiter.limit("60/minute") +async def getRagStatistics( + request: Request, + instanceId: str = Path(...), + days: int = Query(90, ge=7, le=365, description="Timeline window in days"), + context: RequestContext = Depends(getRequestContext), +): + """Aggregated, non-identifying knowledge-store metrics for this workspace instance.""" + mandateId, _instanceConfig = _validateInstanceAccess(instanceId, context) + workspaceFileIds = _collectWorkspaceFileIdsForStats(instanceId, mandateId) + kdb = getKnowledgeInterface(context.user) + stats = kdb.getRagStatisticsForInstance( + featureInstanceId=instanceId, + mandateId=str(mandateId) if mandateId else "", + timelineDays=days, + workspaceFileIds=workspaceFileIds, + ) + if isinstance(stats, dict): + stats.setdefault("scope", {}) + stats["scope"]["workspaceFileIdsResolved"] = len(workspaceFileIds) + return JSONResponse(stats) diff --git a/modules/interfaces/interfaceBootstrap.py b/modules/interfaces/interfaceBootstrap.py index c1cac9ef..89cf4126 100644 --- a/modules/interfaces/interfaceBootstrap.py +++ b/modules/interfaces/interfaceBootstrap.py @@ -103,6 +103,13 @@ def initBootstrap(db: DatabaseConnector) -> None: if mandateId: initRootMandateBilling(mandateId) + # Initialize subscription for root mandate + if mandateId: + _initRootMandateSubscription(mandateId) + + # Auto-provision Stripe Products/Prices for paid plans (idempotent) + _bootstrapStripePrices() + def initAutomationTemplates(dbApp: DatabaseConnector, adminUserId: Optional[str] = None) -> None: """ @@ -1864,7 +1871,7 @@ def _createAicoreProviderRules(db: DatabaseConnector) -> None: Provider access per role: - admin: all providers allowed - - user: all providers EXCEPT anthropic (view=False) + - user: all providers allowed - viewer: NO provider access (viewer has no RESOURCE permissions) NOTE: Provider list is dynamically discovered from AICore model registry. @@ -1909,7 +1916,7 @@ def _createAicoreProviderRules(db: DatabaseConnector) -> None: read=None, create=None, update=None, delete=None, )) - # User: access to all providers EXCEPT anthropic + # User: access to all providers (same provider keys as admin) userId = _getRoleId(db, "user") if userId: for provider in providers: @@ -1923,13 +1930,11 @@ def _createAicoreProviderRules(db: DatabaseConnector) -> None: } ) if not existingRules: - # Anthropic is not allowed for user role - isAllowed = provider != "anthropic" providerRules.append(AccessRule( roleId=userId, context=AccessRuleContext.RESOURCE, item=resourceKey, - view=isAllowed, + view=True, read=None, create=None, update=None, delete=None, )) @@ -2069,6 +2074,47 @@ def initRootMandateBilling(mandateId: str) -> None: logger.warning(f"Failed to initialize root mandate billing (non-critical): {e}") +def _initRootMandateSubscription(mandateId: str) -> None: + """ + Ensure the root mandate has an active ROOT subscription. + Called during bootstrap after billing init. + """ + try: + from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface + from modules.datamodels.datamodelSubscription import ( + MandateSubscription, + SubscriptionStatusEnum, + ) + + subInterface = getSubRootInterface() + existing = subInterface.getOperativeForMandate(mandateId) + if existing: + logger.info("Root mandate subscription already exists") + return + + sub = MandateSubscription( + mandateId=mandateId, + planKey="ROOT", + status=SubscriptionStatusEnum.ACTIVE, + recurring=False, + ) + subInterface.createSubscription(sub) + logger.info("Created ROOT subscription for root mandate") + + except Exception as e: + logger.warning(f"Failed to initialize root mandate subscription (non-critical): {e}") + + +def _bootstrapStripePrices() -> None: + """Auto-create Stripe Products and Prices for all paid plans. + Idempotent — safe on every startup. IDs are persisted in the StripePlanPrice table.""" + try: + from modules.serviceCenter.services.serviceSubscription.stripeBootstrap import bootstrapStripePrices + bootstrapStripePrices() + except Exception as e: + logger.error(f"Stripe price bootstrap failed (subscriptions will not work for paid plans): {e}") + + def assignInitialUserMemberships( db: DatabaseConnector, mandateId: str, diff --git a/modules/interfaces/interfaceDbApp.py b/modules/interfaces/interfaceDbApp.py index 2fec872d..12eb935b 100644 --- a/modules/interfaces/interfaceDbApp.py +++ b/modules/interfaces/interfaceDbApp.py @@ -423,7 +423,7 @@ class AppObjects: else: # Unknown operator - default to equals - if record_value != filter_val: + if str(record_value).lower() != str(filter_val).lower(): matches = False break @@ -512,53 +512,34 @@ class AppObjects: If pagination is None: List[User] If pagination is provided: PaginatedResult with items and metadata """ - # Get user IDs via UserMandate junction table (UserInDB has no mandateId column) userMandates = self.db.getRecordset(UserMandate, recordFilter={"mandateId": mandateId}) userIds = [um.get("userId") for um in userMandates if um.get("userId")] - - # Fetch each user by ID - filteredUsers = [] - for userId in userIds: - userRecords = self.db.getRecordset(UserInDB, recordFilter={"id": userId}) - if userRecords: - cleanedUser = {k: v for k, v in userRecords[0].items() if not k.startswith("_")} - if cleanedUser.get("roleLabels") is None: - cleanedUser["roleLabels"] = [] - filteredUsers.append(cleanedUser) - # If no pagination requested, return all items + if not userIds: + if pagination is None: + return [] + return PaginatedResult(items=[], totalItems=0, totalPages=0) + + result = self.db.getRecordsetPaginated( + UserInDB, + pagination=pagination, + recordFilter={"id": userIds} + ) + + items = [] + for record in result["items"]: + cleanedUser = {k: v for k, v in record.items() if not k.startswith("_")} + if cleanedUser.get("roleLabels") is None: + cleanedUser["roleLabels"] = [] + items.append(User(**cleanedUser)) + if pagination is None: - return [User(**user) for user in filteredUsers] - - # Apply filtering (if filters provided) - if pagination.filters: - filteredUsers = self._applyFilters(filteredUsers, pagination.filters) - - # Apply sorting (in order of sortFields) - if pagination.sort: - filteredUsers = self._applySorting(filteredUsers, pagination.sort) - - # Count total items after filters - totalItems = len(filteredUsers) - totalPages = math.ceil(totalItems / pagination.pageSize) if totalItems > 0 else 0 - - # Apply pagination (skip/limit) - startIdx = (pagination.page - 1) * pagination.pageSize - endIdx = startIdx + pagination.pageSize - pagedUsers = filteredUsers[startIdx:endIdx] - - # Ensure roleLabels is always a list for paginated results too - for user in pagedUsers: - if user.get("roleLabels") is None: - user["roleLabels"] = [] - - # Convert to model objects - items = [User(**user) for user in pagedUsers] - + return items + return PaginatedResult( items=items, - totalItems=totalItems, - totalPages=totalPages + totalItems=result["totalItems"], + totalPages=result["totalPages"] ) def getUserByUsername(self, username: str) -> Optional[User]: @@ -1615,7 +1596,10 @@ class AppObjects: existing = self.getUserMandate(userId, mandateId) if existing: raise ValueError(f"User {userId} is already member of mandate {mandateId}") - + + # Subscription capacity check (before insert) + self._checkSubscriptionCapacity(mandateId, "users", delta=1) + # Create UserMandate userMandate = UserMandate( userId=userId, @@ -1636,7 +1620,10 @@ class AppObjects: # Create billing account for user if billing is configured self._ensureUserBillingAccount(userId, mandateId) - + + # Sync Stripe quantity after successful insert + self._syncSubscriptionQuantity(mandateId) + cleanedRecord = {k: v for k, v in createdRecord.items() if not k.startswith("_")} return UserMandate(**cleanedRecord) except Exception as e: @@ -1686,6 +1673,28 @@ class AppObjects: except Exception as e: logger.warning(f"Failed to create billing account for user {userId} (non-critical): {e}") + def _checkSubscriptionCapacity(self, mandateId: str, resourceType: str, delta: int = 1) -> None: + """Check subscription capacity before creating a resource. Raises on cap violation.""" + try: + from modules.interfaces.interfaceDbSubscription import getInterface as getSubInterface + from modules.security.rootAccess import getRootUser + subIf = getSubInterface(getRootUser(), mandateId) + subIf.assertCapacity(mandateId, resourceType, delta) + except Exception as e: + if "SubscriptionCapacityException" in type(e).__name__: + raise + logger.debug(f"Subscription capacity check skipped: {e}") + + def _syncSubscriptionQuantity(self, mandateId: str) -> None: + """Sync Stripe subscription quantities after a resource mutation.""" + try: + from modules.interfaces.interfaceDbSubscription import getInterface as getSubInterface + from modules.security.rootAccess import getRootUser + subIf = getSubInterface(getRootUser(), mandateId) + subIf.syncQuantityToStripe(mandateId) + except Exception as e: + logger.debug(f"Subscription quantity sync skipped: {e}") + def deleteUserMandate(self, userId: str, mandateId: str) -> bool: """ Delete a UserMandate record (remove user from mandate). @@ -2284,30 +2293,43 @@ class AppObjects: # Additional Helper Methods # ============================================ - def getAllUsers(self) -> List[User]: + def getAllUsers(self, pagination: Optional[PaginationParams] = None) -> Union[List[User], PaginatedResult]: """ Get all users (for SysAdmin only). + Args: + pagination: Optional pagination parameters. If None, returns all items. + Returns: - List of User objects (without sensitive fields) + If pagination is None: List[User] (without sensitive fields) + If pagination is provided: PaginatedResult with items and metadata """ try: - records = self.db.getRecordset(UserInDB) - result = [] - for record in records: - # Filter out sensitive and internal fields + result = self.db.getRecordsetPaginated(UserInDB, pagination=pagination) + + items = [] + for record in result["items"]: cleanedRecord = { - k: v for k, v in record.items() + k: v for k, v in record.items() if not k.startswith("_") and k not in ["hashedPassword", "resetToken", "resetTokenExpires"] } - # Ensure roleLabels is a list if cleanedRecord.get("roleLabels") is None: cleanedRecord["roleLabels"] = [] - result.append(User(**cleanedRecord)) - return result + items.append(User(**cleanedRecord)) + + if pagination is None: + return items + + return PaginatedResult( + items=items, + totalItems=result["totalItems"], + totalPages=result["totalPages"] + ) except Exception as e: logger.error(f"Error getting all users: {e}") - return [] + if pagination is None: + return [] + return PaginatedResult(items=[], totalItems=0, totalPages=0) def getUserMandateById(self, userMandateId: str) -> Optional[UserMandate]: """ @@ -3293,50 +3315,26 @@ class AppObjects: If pagination is provided: PaginatedResult with items and metadata """ try: - # Get all roles from database - roles = self.db.getRecordset(Role) - - # Filter out database-specific fields - filteredRoles = [] - for role in roles: - cleanedRole = {k: v for k, v in role.items() if not k.startswith("_")} - filteredRoles.append(cleanedRole) - - # If no pagination requested, return all items + result = self.db.getRecordsetPaginated(Role, pagination=pagination) + + items = [] + for record in result["items"]: + cleanedRole = {k: v for k, v in record.items() if not k.startswith("_")} + items.append(Role(**cleanedRole)) + if pagination is None: - return [Role(**role) for role in filteredRoles] - - # Apply filtering (if filters provided) - if pagination.filters: - filteredRoles = self._applyFilters(filteredRoles, pagination.filters) - - # Apply sorting (in order of sortFields) - if pagination.sort: - filteredRoles = self._applySorting(filteredRoles, pagination.sort) - - # Count total items after filters - totalItems = len(filteredRoles) - totalPages = math.ceil(totalItems / pagination.pageSize) if totalItems > 0 else 0 - - # Apply pagination (skip/limit) - startIdx = (pagination.page - 1) * pagination.pageSize - endIdx = startIdx + pagination.pageSize - pagedRoles = filteredRoles[startIdx:endIdx] - - # Convert to model objects - items = [Role(**role) for role in pagedRoles] - + return items + return PaginatedResult( items=items, - totalItems=totalItems, - totalPages=totalPages + totalItems=result["totalItems"], + totalPages=result["totalPages"] ) except Exception as e: logger.error(f"Error getting all roles: {str(e)}") if pagination is None: return [] - else: - return PaginatedResult(items=[], totalItems=0, totalPages=0) + return PaginatedResult(items=[], totalItems=0, totalPages=0) def countRoleAssignments(self) -> Dict[str, int]: """ diff --git a/modules/interfaces/interfaceDbBilling.py b/modules/interfaces/interfaceDbBilling.py index 58d56895..2db71bb4 100644 --- a/modules/interfaces/interfaceDbBilling.py +++ b/modules/interfaces/interfaceDbBilling.py @@ -8,7 +8,7 @@ All billing data is stored in the poweron_billing database. """ import logging -from typing import Dict, Any, List, Optional +from typing import Dict, Any, List, Optional, Union from datetime import date, datetime, timedelta import uuid @@ -17,6 +17,7 @@ from modules.shared.configuration import APP_CONFIG from modules.shared.timeUtils import getUtcTimestamp from modules.datamodels.datamodelUam import User, Mandate from modules.datamodels.datamodelMembership import UserMandate +from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResult from modules.datamodels.datamodelBilling import ( BillingAccount, BillingTransaction, @@ -633,26 +634,43 @@ class BillingObjects: limit: int = 100, offset: int = 0, startDate: date = None, - endDate: date = None - ) -> List[Dict[str, Any]]: + endDate: date = None, + pagination: PaginationParams = None + ) -> Union[List[Dict[str, Any]], PaginatedResult]: """ Get transactions for an account. + When pagination is provided, uses database-level pagination and returns + PaginatedResult. Otherwise falls back to in-memory filtering/sorting/slicing. + Args: accountId: Account ID - limit: Maximum number of results - offset: Offset for pagination - startDate: Filter by start date - endDate: Filter by end date + limit: Maximum number of results (legacy path) + offset: Offset for pagination (legacy path) + startDate: Filter by start date (legacy path) + endDate: Filter by end date (legacy path) + pagination: PaginationParams for DB-level pagination Returns: - List of transaction dicts + PaginatedResult when pagination is provided, List of dicts otherwise """ try: + if pagination: + recordFilter = {"accountId": accountId} + result = self.db.getRecordsetPaginated( + BillingTransaction, + pagination=pagination, + recordFilter=recordFilter + ) + return PaginatedResult( + items=result["items"], + totalItems=result["totalItems"], + totalPages=result["totalPages"] + ) + filterDict = {"accountId": accountId} results = self.db.getRecordset(BillingTransaction, recordFilter=filterDict) - # Apply date filters if provided if startDate or endDate: filtered = [] for t in results: @@ -666,35 +684,61 @@ class BillingObjects: filtered.append(t) results = filtered - # Sort by creation date descending results.sort(key=lambda x: x.get("_createdAt", ""), reverse=True) - # Apply pagination return results[offset:offset + limit] except Exception as e: logger.error(f"Error getting transactions: {e}") + if pagination: + return PaginatedResult(items=[], totalItems=0, totalPages=0) return [] - def getTransactionsByMandate(self, mandateId: str, limit: int = 100) -> List[Dict[str, Any]]: + def getTransactionsByMandate( + self, + mandateId: str, + limit: int = 100, + pagination: PaginationParams = None + ) -> Union[List[Dict[str, Any]], PaginatedResult]: """ Get all transactions for a mandate (across all accounts). + When pagination is provided, collects all accountIds for the mandate and + issues a single DB query with SQL-level filtering, sorting, and pagination. + Otherwise falls back to per-account querying and in-memory merging. + Args: mandateId: Mandate ID - limit: Maximum number of results + limit: Maximum number of results (legacy path) + pagination: PaginationParams for DB-level pagination Returns: - List of transaction dicts + PaginatedResult when pagination is provided, List of dicts otherwise """ - # Get all accounts for mandate accounts = self.db.getRecordset(BillingAccount, recordFilter={"mandateId": mandateId}) - + accountIds = [acc["id"] for acc in accounts if acc.get("id")] + + if not accountIds: + if pagination: + return PaginatedResult(items=[], totalItems=0, totalPages=0) + return [] + + if pagination: + result = self.db.getRecordsetPaginated( + BillingTransaction, + pagination=pagination, + recordFilter={"accountId": accountIds} + ) + return PaginatedResult( + items=result["items"], + totalItems=result["totalItems"], + totalPages=result["totalPages"] + ) + allTransactions = [] for account in accounts: transactions = self.getTransactions(account["id"], limit=limit) allTransactions.extend(transactions) - # Sort by creation date descending and limit allTransactions.sort(key=lambda x: x.get("_createdAt", ""), reverse=True) return allTransactions[:limit] diff --git a/modules/interfaces/interfaceDbChat.py b/modules/interfaces/interfaceDbChat.py index f2a3d4c6..b0d4aff3 100644 --- a/modules/interfaces/interfaceDbChat.py +++ b/modules/interfaces/interfaceDbChat.py @@ -450,7 +450,7 @@ class ChatObjects: # Handle simple value (equals operator) if not isinstance(filter_value, dict): - if record_value != filter_value: + if str(record_value).lower() != str(filter_value).lower(): matches = False break continue @@ -460,7 +460,7 @@ class ChatObjects: filter_val = filter_value.get("value") if operator in ["equals", "eq"]: - if record_value != filter_val: + if str(record_value).lower() != str(filter_val).lower(): matches = False break @@ -545,7 +545,7 @@ class ChatObjects: else: # Unknown operator - default to equals - if record_value != filter_val: + if str(record_value).lower() != str(filter_val).lower(): matches = False break diff --git a/modules/interfaces/interfaceDbKnowledge.py b/modules/interfaces/interfaceDbKnowledge.py index c8a597df..adf8ed0a 100644 --- a/modules/interfaces/interfaceDbKnowledge.py +++ b/modules/interfaces/interfaceDbKnowledge.py @@ -7,6 +7,8 @@ and semantic search via pgvector. """ import logging +from collections import defaultdict +from datetime import datetime, timezone, timedelta from typing import Dict, Any, List, Optional from modules.connectors.connectorDbPostgre import _get_cached_connector @@ -286,6 +288,183 @@ class KnowledgeObjects: minScore=minScore, ) + def getRagStatisticsForInstance( + self, + featureInstanceId: str, + mandateId: str, + timelineDays: int = 90, + workspaceFileIds: Optional[List[str]] = None, + ) -> Dict[str, Any]: + """Aggregate anonymised RAG / knowledge-store metrics for one workspace instance. + + No file names, user identifiers, or chunk text are returned — only counts and + distributions suitable for dashboards and presentations. + + workspaceFileIds: optional list of FileItem ids for this feature instance (from Management DB). + Index pipelines often stored rows with empty featureInstanceId; linking by file id fixes stats. + """ + if not featureInstanceId: + return {"error": "featureInstanceId required"} + + ws_ids = [x for x in (workspaceFileIds or []) if x] + ws_id_set = set(ws_ids) + + files_inst = self.db.getRecordset( + FileContentIndex, + recordFilter={"featureInstanceId": featureInstanceId}, + ) + files_shared: List[Dict[str, Any]] = [] + if mandateId: + files_shared = self.db.getRecordset( + FileContentIndex, + recordFilter={"mandateId": mandateId, "isShared": True}, + ) + + by_id: Dict[str, Dict[str, Any]] = {} + for row in files_inst + files_shared: + rid = row.get("id") + if rid and rid not in by_id: + by_id[rid] = row + + for fid in ws_ids: + if fid in by_id: + continue + row = self.getFileContentIndex(fid) + if row: + by_id[fid] = row + + files = list(by_id.values()) + + chunks_by_id: Dict[str, Dict[str, Any]] = {} + inst_chunks = self.db.getRecordset( + ContentChunk, + recordFilter={"featureInstanceId": featureInstanceId}, + ) + for c in inst_chunks: + cid = c.get("id") + if cid: + chunks_by_id[cid] = c + + for fid in ws_id_set: + for c in self.getContentChunks(fid): + cid = c.get("id") + if cid and cid not in chunks_by_id: + chunks_by_id[cid] = c + + covered_file_ids = {c.get("fileId") for c in chunks_by_id.values() if c.get("fileId")} + for row in files: + fid = row.get("id") + if fid and fid not in covered_file_ids: + for c in self.getContentChunks(fid): + cid = c.get("id") + if cid and cid not in chunks_by_id: + chunks_by_id[cid] = c + + chunks = list(chunks_by_id.values()) + + def _mimeCategory(mime: str) -> str: + m = (mime or "").lower() + if "pdf" in m: + return "pdf" + if "wordprocessing" in m or "msword" in m or "officedocument.wordprocessing" in m: + return "office_doc" + if "spreadsheet" in m or "excel" in m or "officedocument.spreadsheet" in m: + return "office_sheet" + if "presentation" in m or "officedocument.presentation" in m: + return "office_slides" + if m.startswith("text/"): + return "text" + if m.startswith("image/"): + return "image" + if "html" in m: + return "html" + return "other" + + def _utcDay(ts: Any) -> str: + if ts is None: + return "" + try: + return datetime.fromtimestamp(float(ts), tz=timezone.utc).strftime("%Y-%m-%d") + except (TypeError, ValueError, OSError): + return "" + + status_counts: Dict[str, int] = defaultdict(int) + mime_counts: Dict[str, int] = defaultdict(int) + extracted_by_day: Dict[str, int] = defaultdict(int) + total_bytes = 0 + user_ids = set() + + for row in files: + st = row.get("status") or "unknown" + status_counts[st] += 1 + mime_counts[_mimeCategory(row.get("mimeType") or "")] += 1 + day = _utcDay(row.get("extractedAt")) + if day: + extracted_by_day[day] += 1 + try: + total_bytes += int(row.get("totalSize") or 0) + except (TypeError, ValueError): + pass + uid = row.get("userId") + if uid: + user_ids.add(str(uid)) + + content_type_counts: Dict[str, int] = defaultdict(int) + chunks_with_embedding = 0 + for c in chunks: + ct = c.get("contentType") or "other" + content_type_counts[ct] += 1 + emb = c.get("embedding") + if emb is not None and ( + (isinstance(emb, list) and len(emb) > 0) + or (isinstance(emb, str) and len(emb) > 10) + ): + chunks_with_embedding += 1 + + wf_mem = self.db.getRecordset( + WorkflowMemory, + recordFilter={"featureInstanceId": featureInstanceId}, + ) + + cutoff = datetime.now(timezone.utc) - timedelta(days=max(1, int(timelineDays))) + cutoff_ts = cutoff.timestamp() + + timeline: List[Dict[str, Any]] = [] + for day in sorted(extracted_by_day.keys()): + try: + d = datetime.strptime(day, "%Y-%m-%d").replace(tzinfo=timezone.utc) + except ValueError: + continue + if d.timestamp() >= cutoff_ts: + timeline.append({"date": day, "indexedDocuments": extracted_by_day[day]}) + + if len(timeline) > 120: + timeline = timeline[-120:] + + total_chunks = len(chunks) + embedding_pct = round(100.0 * chunks_with_embedding / total_chunks, 1) if total_chunks else 0.0 + + return { + "scope": { + "featureInstanceId": featureInstanceId, + "mandateScopedShared": bool(mandateId), + }, + "kpis": { + "indexedDocuments": len(files), + "indexedBytesTotal": total_bytes, + "contributorUsers": len(user_ids), + "contentChunks": total_chunks, + "chunksWithEmbedding": chunks_with_embedding, + "embeddingCoveragePercent": embedding_pct, + "workflowEntities": len(wf_mem), + }, + "indexedDocumentsByStatus": dict(sorted(status_counts.items())), + "documentsByMimeCategory": dict(sorted(mime_counts.items(), key=lambda x: -x[1])), + "chunksByContentType": dict(sorted(content_type_counts.items())), + "timelineIndexedDocuments": timeline, + "generatedAtUtc": datetime.now(timezone.utc).isoformat(), + } + def getInterface(currentUser: Optional[User] = None) -> KnowledgeObjects: """Get or create a KnowledgeObjects singleton.""" diff --git a/modules/interfaces/interfaceDbManagement.py b/modules/interfaces/interfaceDbManagement.py index 9c266aac..64883b95 100644 --- a/modules/interfaces/interfaceDbManagement.py +++ b/modules/interfaces/interfaceDbManagement.py @@ -21,7 +21,6 @@ from modules.datamodels.datamodelUam import AccessLevel from modules.datamodels.datamodelFiles import FilePreview, FileItem, FileData from modules.datamodels.datamodelFileFolder import FileFolder from modules.datamodels.datamodelUtils import Prompt -from modules.datamodels.datamodelVoice import VoiceSettings from modules.datamodels.datamodelMessaging import ( MessagingSubscription, MessagingSubscriptionRegistration, @@ -390,7 +389,7 @@ class ComponentObjects: # Handle simple value (equals operator) if not isinstance(filter_value, dict): - if record_value != filter_value: + if str(record_value).lower() != str(filter_value).lower(): matches = False break continue @@ -400,7 +399,7 @@ class ComponentObjects: filter_val = filter_value.get("value") if operator in ["equals", "eq"]: - if record_value != filter_val: + if str(record_value).lower() != str(filter_val).lower(): matches = False break @@ -650,6 +649,10 @@ class ComponentObjects: - Regular user: sees own prompts + system prompts (isSystem=True), can only CRUD own - Row-level _permissions control edit/delete buttons in the UI + NOTE: Cannot use db.getRecordsetPaginated() because visibility rules + (_getPromptsForUser: own + system for regular, all for SysAdmin) and + per-row _permissions enrichment require loading all records first. + Args: pagination: Optional pagination parameters. If None, returns all items. @@ -914,7 +917,7 @@ class ComponentObjects: """ Returns files owned by the current user (user-scoped, not RBAC-based). Every user (including SysAdmin) only sees their own files. - Supports optional pagination, sorting, and filtering. + Supports optional pagination, sorting, and filtering via database-level queries. Args: pagination: Optional pagination parameters. If None, returns all items. @@ -923,24 +926,21 @@ class ComponentObjects: If pagination is None: List[FileItem] If pagination is provided: PaginatedResult with items and metadata """ - # Files are always user-scoped: filter by _createdBy (bypasses RBAC SysAdmin override) - filteredFiles = self._getFilesByCurrentUser() - - # Convert database records to FileItem instances (extra='allow' preserves system fields like _createdBy) - def convertFileItems(files): + # User-scoping filter: every user only sees their own files (bypasses RBAC SysAdmin override) + recordFilter = {"_createdBy": self.userId} + + def _convertFileItems(files): fileItems = [] for file in files: try: - # Ensure proper values, use defaults for invalid data creationDate = file.get("creationDate") if creationDate is None or not isinstance(creationDate, (int, float)) or creationDate <= 0: file["creationDate"] = getUtcTimestamp() fileName = file.get("fileName") if not fileName or fileName == "None": - continue # Skip records with invalid fileName + continue - # Use **file to pass all fields including system fields (_createdBy, etc.) fileItem = FileItem(**file) fileItems.append(fileItem) except Exception as e: @@ -948,34 +948,23 @@ class ComponentObjects: continue return fileItems - # If no pagination requested, return all items if pagination is None: - return convertFileItems(filteredFiles) + allFiles = self._getFilesByCurrentUser() + return _convertFileItems(allFiles) - # Apply filtering (if filters provided) - if pagination.filters: - filteredFiles = self._applyFilters(filteredFiles, pagination.filters) + # Database-level pagination: filtering, sorting, and LIMIT/OFFSET happen in SQL + result = self.db.getRecordsetPaginated( + FileItem, + pagination=pagination, + recordFilter=recordFilter + ) - # Apply sorting (in order of sortFields) - if pagination.sort: - filteredFiles = self._applySorting(filteredFiles, pagination.sort) - - # Count total items after filters - totalItems = len(filteredFiles) - totalPages = math.ceil(totalItems / pagination.pageSize) if totalItems > 0 else 0 - - # Apply pagination (skip/limit) - startIdx = (pagination.page - 1) * pagination.pageSize - endIdx = startIdx + pagination.pageSize - pagedFiles = filteredFiles[startIdx:endIdx] - - # Convert to model objects (extra='allow' on FileItem preserves system fields) - items = convertFileItems(pagedFiles) + items = _convertFileItems(result["items"]) return PaginatedResult( items=items, - totalItems=totalItems, - totalPages=totalPages + totalItems=result["totalItems"], + totalPages=result["totalPages"] ) def getFile(self, fileId: str) -> Optional[FileItem]: @@ -1733,158 +1722,6 @@ class ComponentObjects: logger.error(f"Error in saveUploadedFile for {fileName}: {str(e)}", exc_info=True) raise FileStorageError(f"Error saving file: {str(e)}") - # VoiceSettings methods - - def getVoiceSettings(self, userId: Optional[str] = None) -> Optional[VoiceSettings]: - """Returns voice settings for a user if user has access.""" - try: - targetUserId = userId or self.userId - if not targetUserId: - logger.error("No user ID provided for voice settings") - return None - - recordFilter: Dict[str, Any] = {"userId": targetUserId} - if self.featureInstanceId: - recordFilter["featureInstanceId"] = self.featureInstanceId - - # Get voice settings for the user (scoped to current feature instance if available), filtered by RBAC - filteredSettings = getRecordsetWithRBAC(self.db, - VoiceSettings, - self.currentUser, - recordFilter=recordFilter, - mandateId=self.mandateId - ) - - if not filteredSettings: - logger.warning(f"No access to voice settings for user {targetUserId}") - return None - - # Ensure timestamps are set for validation - settings_data = filteredSettings[0] - if not settings_data.get("creationDate"): - settings_data["creationDate"] = getUtcTimestamp() - if not settings_data.get("lastModified"): - settings_data["lastModified"] = getUtcTimestamp() - - return VoiceSettings(**settings_data) - - except Exception as e: - logger.error(f"Error getting voice settings: {str(e)}") - return None - - def createVoiceSettings(self, settingsData: Dict[str, Any]) -> Dict[str, Any]: - """Creates voice settings for a user if user has permission.""" - try: - if not self.checkRbacPermission(VoiceSettings, "update"): - raise PermissionError("No permission to create voice settings") - - # Ensure userId is set - if "userId" not in settingsData: - settingsData["userId"] = self.userId - - # Ensure mandateId and featureInstanceId are set from context - if "mandateId" not in settingsData: - settingsData["mandateId"] = self.mandateId - if "featureInstanceId" not in settingsData: - settingsData["featureInstanceId"] = self.featureInstanceId - - # Check if settings already exist for this user - existingSettings = self.getVoiceSettings(settingsData["userId"]) - if existingSettings: - raise ValueError(f"Voice settings already exist for user {settingsData['userId']}") - - # Create voice settings record - createdRecord = self.db.recordCreate(VoiceSettings, settingsData) - if not createdRecord or not createdRecord.get("id"): - raise ValueError("Failed to create voice settings record") - - logger.info(f"Created voice settings for user {settingsData['userId']}") - return createdRecord - - except Exception as e: - logger.error(f"Error creating voice settings: {str(e)}") - raise - - def updateVoiceSettings(self, userId: str, updateData: Dict[str, Any]) -> Dict[str, Any]: - """Updates voice settings for a user if user has access.""" - try: - # Get existing settings - existingSettings = self.getVoiceSettings(userId) - if not existingSettings: - raise ValueError(f"Voice settings not found for user {userId}") - - # Update lastModified timestamp - updateData["lastModified"] = getUtcTimestamp() - - # Update voice settings record - success = self.db.recordModify(VoiceSettings, existingSettings.id, updateData) - if not success: - raise ValueError("Failed to update voice settings record") - - # Get updated settings - updatedSettings = self.getVoiceSettings(userId) - if not updatedSettings: - raise ValueError("Failed to retrieve updated voice settings") - - logger.info(f"Updated voice settings for user {userId}") - return updatedSettings.model_dump() - - except Exception as e: - logger.error(f"Error updating voice settings: {str(e)}") - raise - - def deleteVoiceSettings(self, userId: str) -> bool: - """Deletes voice settings for a user if user has access.""" - try: - # Get existing settings - existingSettings = self.getVoiceSettings(userId) - if not existingSettings: - logger.warning(f"Voice settings not found for user {userId}") - return False - - # Delete voice settings - success = self.db.recordDelete(VoiceSettings, existingSettings.id) - if success: - logger.info(f"Deleted voice settings for user {userId}") - else: - logger.error(f"Failed to delete voice settings for user {userId}") - - return success - - except Exception as e: - logger.error(f"Error deleting voice settings: {str(e)}") - return False - - def getOrCreateVoiceSettings(self, userId: Optional[str] = None) -> VoiceSettings: - """Gets existing voice settings or creates default ones for a user.""" - try: - targetUserId = userId or self.userId - if not targetUserId: - raise ValueError("No user ID provided for voice settings") - - # Try to get existing settings - existingSettings = self.getVoiceSettings(targetUserId) - if existingSettings: - return existingSettings - - # Create default settings - defaultSettings = { - "userId": targetUserId, - "mandateId": self.mandateId, - "sttLanguage": "de-DE", - "ttsLanguage": "de-DE", - "ttsVoice": "de-DE-KatjaNeural", - "translationEnabled": True, - "targetLanguage": "en-US" - } - - createdRecord = self.createVoiceSettings(defaultSettings) - return VoiceSettings(**createdRecord) - - except Exception as e: - logger.error(f"Error getting or creating voice settings: {str(e)}") - raise - # Messaging Subscription methods def getAllSubscriptions(self, pagination: Optional[PaginationParams] = None) -> Union[List[MessagingSubscription], PaginatedResult]: diff --git a/modules/interfaces/interfaceDbSubscription.py b/modules/interfaces/interfaceDbSubscription.py new file mode 100644 index 00000000..f08025ea --- /dev/null +++ b/modules/interfaces/interfaceDbSubscription.py @@ -0,0 +1,353 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +""" +Interface for Subscription operations — ID-based, deterministic. + +Every write operation takes an explicit subscriptionId. +No status-scan guessing. See wiki/concepts/Subscription-State-Machine.md. +""" + +import logging +from typing import Dict, Any, List, Optional +from datetime import datetime, timezone + +from modules.connectors.connectorDbPostgre import DatabaseConnector +from modules.shared.configuration import APP_CONFIG +from modules.datamodels.datamodelUam import User +from modules.datamodels.datamodelMembership import UserMandate +from modules.datamodels.datamodelSubscription import ( + SubscriptionPlan, + MandateSubscription, + SubscriptionStatusEnum, + BillingPeriodEnum, + ALLOWED_TRANSITIONS, + TERMINAL_STATUSES, + OPERATIVE_STATUSES, + BUILTIN_PLANS, + _getPlan, + _getSelectablePlans, +) + +logger = logging.getLogger(__name__) + +SUBSCRIPTION_DATABASE = "poweron_billing" + +_subscriptionInterfaces: Dict[str, "SubscriptionObjects"] = {} + + +class InvalidTransitionError(Exception): + """Raised when a state transition is not allowed by the state machine.""" + def __init__(self, subscriptionId: str, fromStatus: str, toStatus: str): + self.subscriptionId = subscriptionId + self.fromStatus = fromStatus + self.toStatus = toStatus + super().__init__(f"Invalid transition {fromStatus} -> {toStatus} for subscription {subscriptionId}") + + +def getInterface(currentUser: User, mandateId: str = None) -> "SubscriptionObjects": + cacheKey = f"{currentUser.id}_{mandateId}" + if cacheKey not in _subscriptionInterfaces: + _subscriptionInterfaces[cacheKey] = SubscriptionObjects(currentUser, mandateId) + else: + _subscriptionInterfaces[cacheKey].setUserContext(currentUser, mandateId) + return _subscriptionInterfaces[cacheKey] + + +def _getRootInterface() -> "SubscriptionObjects": + from modules.security.rootAccess import getRootUser + return SubscriptionObjects(getRootUser(), mandateId=None) + + +def _getAppDatabaseConnector() -> DatabaseConnector: + return DatabaseConnector( + dbDatabase=APP_CONFIG.get("DB_DATABASE", "poweron_app"), + dbHost=APP_CONFIG.get("DB_HOST", "localhost"), + dbPort=int(APP_CONFIG.get("DB_PORT", "5432")), + dbUser=APP_CONFIG.get("DB_USER"), + dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"), + ) + + +class SubscriptionObjects: + """Interface for subscription operations: CRUD, gate checks, Stripe sync. + All writes are ID-based. All status changes go through transitionStatus().""" + + def __init__(self, currentUser: Optional[User] = None, mandateId: str = None): + self.currentUser = currentUser + self.userId = currentUser.id if currentUser else None + self.mandateId = mandateId + self.db = DatabaseConnector( + dbDatabase=SUBSCRIPTION_DATABASE, + dbHost=APP_CONFIG.get("DB_HOST", "localhost"), + dbPort=int(APP_CONFIG.get("DB_PORT", "5432")), + dbUser=APP_CONFIG.get("DB_USER"), + dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"), + ) + + def setUserContext(self, currentUser: User, mandateId: str = None): + self.currentUser = currentUser + self.userId = currentUser.id if currentUser else None + self.mandateId = mandateId + + # ========================================================================= + # Plan catalog (in-memory) + # ========================================================================= + + def getPlan(self, planKey: str) -> Optional[SubscriptionPlan]: + return _getPlan(planKey) + + def getSelectablePlans(self) -> List[SubscriptionPlan]: + return _getSelectablePlans() + + # ========================================================================= + # Read: by ID (primary access pattern) + # ========================================================================= + + def getById(self, subscriptionId: str) -> Optional[Dict[str, Any]]: + """Load a single subscription by its primary key.""" + try: + results = self.db.getRecordset(MandateSubscription, recordFilter={"id": subscriptionId}) + return dict(results[0]) if results else None + except Exception as e: + logger.error("getById(%s) failed: %s", subscriptionId, e) + return None + + def getByStripeSubscriptionId(self, stripeSubId: str) -> Optional[Dict[str, Any]]: + """Load subscription by Stripe subscription ID — the webhook resolution path.""" + try: + results = self.db.getRecordset(MandateSubscription, recordFilter={"stripeSubscriptionId": stripeSubId}) + return dict(results[0]) if results else None + except Exception as e: + logger.error("getByStripeSubscriptionId(%s) failed: %s", stripeSubId, e) + return None + + # ========================================================================= + # Read: by mandate (list queries) + # ========================================================================= + + def listForMandate(self, mandateId: str, statusFilter: List[SubscriptionStatusEnum] = None) -> List[Dict[str, Any]]: + """Return all subscriptions for a mandate, optionally filtered by status. + Sorted newest-first by startedAt.""" + try: + results = self.db.getRecordset(MandateSubscription, recordFilter={"mandateId": mandateId}) + rows = [dict(r) for r in results] + if statusFilter: + filterValues = {s.value for s in statusFilter} + rows = [r for r in rows if r.get("status") in filterValues] + rows.sort(key=lambda r: r.get("startedAt", ""), reverse=True) + return rows + except Exception as e: + logger.error("listForMandate(%s) failed: %s", mandateId, e) + return [] + + def getOperativeForMandate(self, mandateId: str) -> Optional[Dict[str, Any]]: + """Return the single operative subscription (ACTIVE, TRIALING, or PAST_DUE). + This is a read-only query for the billing gate. Returns None if no operative sub exists.""" + for row in self.listForMandate(mandateId): + if row.get("status") in {s.value for s in OPERATIVE_STATUSES}: + return row + return None + + def getScheduledForMandate(self, mandateId: str) -> Optional[Dict[str, Any]]: + """Return a SCHEDULED subscription if one exists (next sub waiting to start).""" + for row in self.listForMandate(mandateId, [SubscriptionStatusEnum.SCHEDULED]): + return row + return None + + # ========================================================================= + # Read: global (SysAdmin) + # ========================================================================= + + def listAll(self, statusFilter: List[SubscriptionStatusEnum] = None) -> List[Dict[str, Any]]: + """Return ALL subscriptions across all mandates, newest-first. SysAdmin use only.""" + try: + results = self.db.getRecordset(MandateSubscription) + rows = [dict(r) for r in results] + if statusFilter: + filterValues = {s.value for s in statusFilter} + rows = [r for r in rows if r.get("status") in filterValues] + rows.sort(key=lambda r: r.get("startedAt", ""), reverse=True) + return rows + except Exception as e: + logger.error("listAll() failed: %s", e) + return [] + + # ========================================================================= + # Write: create + # ========================================================================= + + def createSubscription(self, sub: MandateSubscription) -> Dict[str, Any]: + """Persist a new MandateSubscription record.""" + return self.db.recordCreate(MandateSubscription, sub.model_dump()) + + # ========================================================================= + # Write: update fields (no status change) + # ========================================================================= + + def updateFields(self, subscriptionId: str, data: Dict[str, Any]) -> Dict[str, Any]: + """Update non-status fields on a subscription (e.g. recurring, Stripe IDs, periods). + Must NOT be used for status changes — use transitionStatus() for that.""" + if "status" in data: + raise ValueError("updateFields must not change status — use transitionStatus()") + return self.db.recordModify(MandateSubscription, subscriptionId, data) + + # ========================================================================= + # Write: status transition (guarded) + # ========================================================================= + + def transitionStatus( + self, + subscriptionId: str, + expectedFromStatus: SubscriptionStatusEnum, + toStatus: SubscriptionStatusEnum, + additionalData: Dict[str, Any] = None, + ) -> Dict[str, Any]: + """Execute a guarded status transition. + + 1. Load the record by ID + 2. Verify current status matches expectedFromStatus + 3. Verify the transition is allowed by the state machine + 4. Apply the update + """ + sub = self.getById(subscriptionId) + if not sub: + raise ValueError(f"Subscription {subscriptionId} not found") + + currentStatus = sub.get("status", "") + if currentStatus != expectedFromStatus.value: + raise InvalidTransitionError(subscriptionId, currentStatus, toStatus.value) + + if (expectedFromStatus, toStatus) not in ALLOWED_TRANSITIONS: + raise InvalidTransitionError(subscriptionId, expectedFromStatus.value, toStatus.value) + + updateData = {"status": toStatus.value} + if toStatus in TERMINAL_STATUSES and not (additionalData or {}).get("endedAt"): + updateData["endedAt"] = datetime.now(timezone.utc).isoformat() + if additionalData: + updateData.update(additionalData) + + result = self.db.recordModify(MandateSubscription, subscriptionId, updateData) + logger.info("Transition %s -> %s for subscription %s", expectedFromStatus.value, toStatus.value, subscriptionId) + return result + + def forceExpire(self, subscriptionId: str) -> Dict[str, Any]: + """Sysadmin force-expire: ANY non-terminal -> EXPIRED. Bypasses normal transition guards.""" + sub = self.getById(subscriptionId) + if not sub: + raise ValueError(f"Subscription {subscriptionId} not found") + + currentStatus = sub.get("status", "") + if currentStatus == SubscriptionStatusEnum.EXPIRED.value: + raise ValueError(f"Subscription {subscriptionId} is already EXPIRED") + + result = self.db.recordModify(MandateSubscription, subscriptionId, { + "status": SubscriptionStatusEnum.EXPIRED.value, + "endedAt": datetime.now(timezone.utc).isoformat(), + }) + logger.info("Force-expired subscription %s (was %s)", subscriptionId, currentStatus) + return result + + # ========================================================================= + # Gate: assertActive (read-only, for billing gate) + # ========================================================================= + + def assertActive(self, mandateId: str) -> SubscriptionStatusEnum: + """Return effective status for billing decisions. + Returns the operative subscription's status, or EXPIRED if none exists. + This is the ONLY read-by-mandate operation used in the hot path.""" + sub = self.getOperativeForMandate(mandateId) + if sub: + return SubscriptionStatusEnum(sub["status"]) + return SubscriptionStatusEnum.EXPIRED + + # ========================================================================= + # Gate: assertCapacity + # ========================================================================= + + def assertCapacity(self, mandateId: str, resourceType: str, delta: int = 1) -> bool: + sub = self.getOperativeForMandate(mandateId) + if not sub: + from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import SubscriptionCapacityException + raise SubscriptionCapacityException( + resourceType=resourceType, currentCount=0, maxAllowed=0, + message="No active subscription for this mandate.", + ) + + plan = self.getPlan(sub.get("planKey", "")) + if not plan: + return True + + if resourceType == "users": + cap = plan.maxUsers + if cap is None: + return True + current = self.countActiveUsers(mandateId) + if current + delta > cap: + from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import SubscriptionCapacityException + raise SubscriptionCapacityException(resourceType=resourceType, currentCount=current, maxAllowed=cap) + elif resourceType == "featureInstances": + cap = plan.maxFeatureInstances + if cap is None: + return True + current = self.countActiveFeatureInstances(mandateId) + if current + delta > cap: + from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import SubscriptionCapacityException + raise SubscriptionCapacityException(resourceType=resourceType, currentCount=current, maxAllowed=cap) + + return True + + # ========================================================================= + # Counting (cross-DB queries against poweron_app) + # ========================================================================= + + def countActiveUsers(self, mandateId: str) -> int: + try: + appDb = _getAppDatabaseConnector() + return len(appDb.getRecordset(UserMandate, recordFilter={"mandateId": mandateId})) + except Exception as e: + logger.error("countActiveUsers(%s) failed: %s", mandateId, e) + return 0 + + def countActiveFeatureInstances(self, mandateId: str) -> int: + try: + from modules.datamodels.datamodelFeatures import FeatureInstance + appDb = _getAppDatabaseConnector() + return len(appDb.getRecordset(FeatureInstance, recordFilter={"mandateId": mandateId, "enabled": True})) + except Exception as e: + logger.error("countActiveFeatureInstances(%s) failed: %s", mandateId, e) + return 0 + + # ========================================================================= + # Stripe quantity sync + # ========================================================================= + + def syncQuantityToStripe(self, subscriptionId: str) -> None: + """Update Stripe subscription item quantities to match actual active counts. + Takes subscriptionId, not mandateId.""" + sub = self.getById(subscriptionId) + if not sub or not sub.get("stripeSubscriptionId"): + return + + mandateId = sub["mandateId"] + itemIdUsers = sub.get("stripeItemIdUsers") + itemIdInstances = sub.get("stripeItemIdInstances") + + try: + from modules.shared.stripeClient import getStripeClient + stripe = getStripeClient() + + activeUsers = self.countActiveUsers(mandateId) + activeInstances = self.countActiveFeatureInstances(mandateId) + + if itemIdUsers: + stripe.SubscriptionItem.modify( + itemIdUsers, quantity=max(activeUsers, 0), proration_behavior="create_prorations", + ) + if itemIdInstances: + stripe.SubscriptionItem.modify( + itemIdInstances, quantity=max(activeInstances, 0), proration_behavior="create_prorations", + ) + + logger.info("Stripe quantity synced for sub %s: users=%d, instances=%d", subscriptionId, activeUsers, activeInstances) + except Exception as e: + logger.error("syncQuantityToStripe(%s) failed: %s", subscriptionId, e) diff --git a/modules/interfaces/interfaceRbac.py b/modules/interfaces/interfaceRbac.py index b4c9a3b4..e65cd5ab 100644 --- a/modules/interfaces/interfaceRbac.py +++ b/modules/interfaces/interfaceRbac.py @@ -23,10 +23,12 @@ GROUP-Berechtigung: import logging import json -from typing import List, Dict, Any, Optional, Type +import math +from typing import List, Dict, Any, Optional, Type, Union from pydantic import BaseModel from modules.datamodels.datamodelRbac import AccessRuleContext from modules.datamodels.datamodelUam import User, UserPermissions, AccessLevel +from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResult from modules.security.rbac import RbacClass from modules.security.rootAccess import getRootDbAppConnector @@ -72,6 +74,10 @@ TABLE_NAMESPACE = { # Automation - benutzer-eigen "AutomationDefinition": "automation", "AutomationTemplate": "automation", + # Automation2 - feature-scoped + "Automation2Workflow": "automation2", + "Automation2WorkflowRun": "automation2", + "Automation2HumanTask": "automation2", # Knowledge Store - benutzer-eigen "FileContentIndex": "knowledge", "ContentChunk": "knowledge", @@ -297,6 +303,309 @@ def getRecordsetWithRBAC( return [] +def getRecordsetPaginatedWithRBAC( + connector, + modelClass: Type[BaseModel], + currentUser: User, + pagination: Optional[PaginationParams] = None, + recordFilter: Dict[str, Any] = None, + mandateId: Optional[str] = None, + featureInstanceId: Optional[str] = None, + enrichPermissions: bool = False, + featureCode: Optional[str] = None, +) -> Union[List[Dict[str, Any]], PaginatedResult]: + """ + Get records with RBAC filtering and SQL-level pagination. + When pagination is None, returns a plain list (backward compatible). + When pagination is provided, returns PaginatedResult with COUNT + LIMIT/OFFSET at SQL level. + """ + table = modelClass.__name__ + objectKey = buildDataObjectKey(table, featureCode) + effectiveMandateId = mandateId + + try: + if not connector._ensureTableExists(modelClass): + return PaginatedResult(items=[], totalItems=0, totalPages=0) if pagination else [] + + dbApp = getRootDbAppConnector() + rbacInstance = RbacClass(connector, dbApp=dbApp) + permissions = rbacInstance.getUserPermissions( + currentUser, + AccessRuleContext.DATA, + objectKey, + mandateId=effectiveMandateId, + featureInstanceId=featureInstanceId + ) + + if not permissions.view: + return PaginatedResult(items=[], totalItems=0, totalPages=0) if pagination else [] + + whereConditions = [] + whereValues = [] + + featureInstanceIdForQuery = featureInstanceId + if featureInstanceId and hasattr(modelClass, 'model_fields') and "featureInstanceId" not in modelClass.model_fields: + featureInstanceIdForQuery = None + + rbacWhereClause = buildRbacWhereClause( + permissions, currentUser, table, connector, + mandateId=effectiveMandateId, + featureInstanceId=featureInstanceIdForQuery + ) + if rbacWhereClause: + whereConditions.append(rbacWhereClause["condition"]) + whereValues.extend(rbacWhereClause["values"]) + + if recordFilter: + for field, value in recordFilter.items(): + if isinstance(value, (list, tuple)): + if len(value) == 0: + whereConditions.append("1 = 0") + else: + whereConditions.append(f'"{field}" = ANY(%s)') + whereValues.append(list(value)) + elif value is None: + whereConditions.append(f'"{field}" IS NULL') + else: + whereConditions.append(f'"{field}" = %s') + whereValues.append(value) + + if pagination and pagination.filters: + from modules.connectors.connectorDbPostgre import _get_model_fields + fields = _get_model_fields(modelClass) + validColumns = set(fields.keys()) + for key, val in pagination.filters.items(): + if key == "search" and isinstance(val, str) and val.strip(): + term = f"%{val.strip()}%" + textCols = [c for c, t in fields.items() if t == "TEXT"] + if textCols: + orParts = [f'COALESCE("{c}"::TEXT, \'\') ILIKE %s' for c in textCols] + whereConditions.append(f"({' OR '.join(orParts)})") + whereValues.extend([term] * len(textCols)) + continue + if key not in validColumns: + continue + if isinstance(val, dict): + op = val.get("operator", "equals") + v = val.get("value", "") + if op in ("equals", "eq"): + whereConditions.append(f'"{key}"::TEXT = %s') + whereValues.append(str(v)) + elif op == "contains": + whereConditions.append(f'"{key}"::TEXT ILIKE %s') + whereValues.append(f"%{v}%") + elif op == "startsWith": + whereConditions.append(f'"{key}"::TEXT ILIKE %s') + whereValues.append(f"{v}%") + elif op == "endsWith": + whereConditions.append(f'"{key}"::TEXT ILIKE %s') + whereValues.append(f"%{v}") + elif op in ("gt", "gte", "lt", "lte"): + sqlOp = {"gt": ">", "gte": ">=", "lt": "<", "lte": "<="}[op] + whereConditions.append(f'"{key}"::TEXT {sqlOp} %s') + whereValues.append(str(v)) + elif op == "between": + fromVal = v.get("from", "") if isinstance(v, dict) else "" + toVal = v.get("to", "") if isinstance(v, dict) else "" + if fromVal and toVal: + whereConditions.append(f'"{key}"::TEXT >= %s AND "{key}"::TEXT <= %s') + whereValues.extend([str(fromVal), str(toVal)]) + elif fromVal: + whereConditions.append(f'"{key}"::TEXT >= %s') + whereValues.append(str(fromVal)) + elif toVal: + whereConditions.append(f'"{key}"::TEXT <= %s') + whereValues.append(str(toVal)) + else: + whereConditions.append(f'"{key}"::TEXT ILIKE %s') + whereValues.append(str(val)) + + whereClause = " WHERE " + " AND ".join(whereConditions) if whereConditions else "" + countValues = list(whereValues) + + orderParts: List[str] = [] + if pagination and pagination.sort: + from modules.connectors.connectorDbPostgre import _get_model_fields + validColumns = set(_get_model_fields(modelClass).keys()) + for sf in pagination.sort: + if sf.field in validColumns: + direction = "DESC" if sf.direction.lower() == "desc" else "ASC" + orderParts.append(f'"{sf.field}" {direction}') + if not orderParts: + orderParts.append('"id"') + orderByClause = " ORDER BY " + ", ".join(orderParts) + + limitClause = "" + if pagination: + offset = (pagination.page - 1) * pagination.pageSize + limitClause = f" LIMIT {pagination.pageSize} OFFSET {offset}" + + with connector.connection.cursor() as cursor: + countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}' + cursor.execute(countSql, countValues) + totalItems = cursor.fetchone()["count"] + + dataSql = f'SELECT * FROM "{table}"{whereClause}{orderByClause}{limitClause}' + cursor.execute(dataSql, whereValues) + records = [dict(row) for row in cursor.fetchall()] + + from modules.connectors.connectorDbPostgre import _get_model_fields, _parseRecordFields + fields = _get_model_fields(modelClass) + for record in records: + _parseRecordFields(record, fields, f"table {table}") + for fieldName, fieldType in fields.items(): + if fieldType == "JSONB" and fieldName in record and record[fieldName] is None: + modelFields = modelClass.model_fields + fieldInfo = modelFields.get(fieldName) + if fieldInfo: + fieldAnnotation = fieldInfo.annotation + if (fieldAnnotation == list or + (hasattr(fieldAnnotation, "__origin__") and fieldAnnotation.__origin__ is list)): + record[fieldName] = [] + elif (fieldAnnotation == dict or + (hasattr(fieldAnnotation, "__origin__") and fieldAnnotation.__origin__ is dict)): + record[fieldName] = {} + + if enrichPermissions: + records = _enrichRecordsWithPermissions(records, permissions, currentUser) + + if pagination: + pageSize = pagination.pageSize + totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0 + return PaginatedResult(items=records, totalItems=totalItems, totalPages=totalPages) + + return records + except Exception as e: + logger.error(f"Error in getRecordsetPaginatedWithRBAC for table {table}: {e}") + return PaginatedResult(items=[], totalItems=0, totalPages=0) if pagination else [] + + +def getDistinctColumnValuesWithRBAC( + connector, + modelClass: Type[BaseModel], + currentUser: User, + column: str, + pagination: Optional[PaginationParams] = None, + recordFilter: Dict[str, Any] = None, + mandateId: Optional[str] = None, + featureInstanceId: Optional[str] = None, + featureCode: Optional[str] = None, +) -> List[str]: + """ + Get sorted distinct values for a column with RBAC filtering at SQL level. + Cross-filtering: removes the requested column from active filters. + """ + import copy + table = modelClass.__name__ + objectKey = buildDataObjectKey(table, featureCode) + + try: + if not connector._ensureTableExists(modelClass): + return [] + + from modules.connectors.connectorDbPostgre import _get_model_fields + fields = _get_model_fields(modelClass) + if column not in fields: + return [] + + dbApp = getRootDbAppConnector() + rbacInstance = RbacClass(connector, dbApp=dbApp) + permissions = rbacInstance.getUserPermissions( + currentUser, AccessRuleContext.DATA, objectKey, + mandateId=mandateId, featureInstanceId=featureInstanceId + ) + if not permissions.view: + return [] + + whereConditions = [] + whereValues = [] + + featureInstanceIdForQuery = featureInstanceId + if featureInstanceId and hasattr(modelClass, 'model_fields') and "featureInstanceId" not in modelClass.model_fields: + featureInstanceIdForQuery = None + + rbacWhereClause = buildRbacWhereClause( + permissions, currentUser, table, connector, + mandateId=mandateId, featureInstanceId=featureInstanceIdForQuery + ) + if rbacWhereClause: + whereConditions.append(rbacWhereClause["condition"]) + whereValues.extend(rbacWhereClause["values"]) + + if recordFilter: + for field, value in recordFilter.items(): + if isinstance(value, (list, tuple)): + if not value: + whereConditions.append("1 = 0") + else: + whereConditions.append(f'"{field}" = ANY(%s)') + whereValues.append(list(value)) + elif value is None: + whereConditions.append(f'"{field}" IS NULL') + else: + whereConditions.append(f'"{field}" = %s') + whereValues.append(value) + + crossPagination = copy.deepcopy(pagination) if pagination else None + if crossPagination and crossPagination.filters: + crossPagination.filters.pop(column, None) + validColumns = set(fields.keys()) + for key, val in crossPagination.filters.items(): + if key == "search" and isinstance(val, str) and val.strip(): + term = f"%{val.strip()}%" + textCols = [c for c, t in fields.items() if t == "TEXT"] + if textCols: + orParts = [f'COALESCE("{c}"::TEXT, \'\') ILIKE %s' for c in textCols] + whereConditions.append(f"({' OR '.join(orParts)})") + whereValues.extend([term] * len(textCols)) + continue + if key not in validColumns: + continue + if isinstance(val, dict): + op = val.get("operator", "equals") + v = val.get("value", "") + if op in ("equals", "eq"): + whereConditions.append(f'"{key}"::TEXT = %s') + whereValues.append(str(v)) + elif op == "contains": + whereConditions.append(f'"{key}"::TEXT ILIKE %s') + whereValues.append(f"%{v}%") + elif op == "between": + fromVal = v.get("from", "") if isinstance(v, dict) else "" + toVal = v.get("to", "") if isinstance(v, dict) else "" + if fromVal and toVal: + whereConditions.append(f'"{key}"::TEXT >= %s AND "{key}"::TEXT <= %s') + whereValues.extend([str(fromVal), str(toVal)]) + elif fromVal: + whereConditions.append(f'"{key}"::TEXT >= %s') + whereValues.append(str(fromVal)) + elif toVal: + whereConditions.append(f'"{key}"::TEXT <= %s') + whereValues.append(str(toVal)) + else: + whereConditions.append(f'"{key}"::TEXT ILIKE %s') + whereValues.append(str(v) if isinstance(v, str) else str(val)) + else: + whereConditions.append(f'"{key}"::TEXT ILIKE %s') + whereValues.append(str(val)) + + whereClause = " WHERE " + " AND ".join(whereConditions) if whereConditions else "" + notNullCond = f'"{column}" IS NOT NULL AND "{column}"::TEXT != \'\'' + if whereClause: + whereClause += f" AND {notNullCond}" + else: + whereClause = f" WHERE {notNullCond}" + + sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{whereClause} ORDER BY val' + + with connector.connection.cursor() as cursor: + cursor.execute(sql, whereValues) + return [row["val"] for row in cursor.fetchall()] + except Exception as e: + logger.error(f"Error in getDistinctColumnValuesWithRBAC for {table}.{column}: {e}") + return [] + + def buildRbacWhereClause( permissions: UserPermissions, currentUser: User, diff --git a/modules/routes/routeAdminAutomationEvents.py b/modules/routes/routeAdminAutomationEvents.py index d184ae76..47d3ac9c 100644 --- a/modules/routes/routeAdminAutomationEvents.py +++ b/modules/routes/routeAdminAutomationEvents.py @@ -5,15 +5,19 @@ Admin automation events routes for the backend API. Sysadmin-only endpoints for viewing and controlling scheduler events. """ -from fastapi import APIRouter, HTTPException, Depends, Path, Request, Response -from typing import List, Dict, Any +from fastapi import APIRouter, HTTPException, Depends, Path, Request, Response, Query +from typing import List, Dict, Any, Optional from fastapi import status import logging +import json +import math # Import interfaces and models from feature containers import modules.features.automation.interfaceFeatureAutomation as interfaceAutomation from modules.auth import limiter, getRequestContext, requireSysAdminRole, RequestContext from modules.datamodels.datamodelUam import User +from modules.datamodels.datamodelPagination import PaginationParams, PaginationMetadata, normalize_pagination_dict +from modules.routes.routeDataUsers import _applyFiltersAndSort, _extractDistinctValues # Configure logger logger = logging.getLogger(__name__) @@ -31,122 +35,176 @@ router = APIRouter( } ) +def _buildEnrichedAutomationEvents(currentUser: User) -> List[Dict[str, Any]]: + """Build the full enriched automation events list.""" + from modules.shared.eventManagement import eventManager + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.features.automation.mainAutomation import getAutomationServices + + if not eventManager.scheduler: + return [] + + jobs = [] + for job in eventManager.scheduler.get_jobs(): + if job.id.startswith("automation."): + automationId = job.id.replace("automation.", "") + jobs.append({ + "eventId": job.id, + "id": job.id, + "automationId": automationId, + "nextRunTime": str(job.next_run_time) if job.next_run_time else None, + "trigger": str(job.trigger) if job.trigger else None, + "name": "", + "createdBy": "", + "mandate": "", + "featureInstance": "" + }) + + if jobs: + try: + rootInterface = getRootInterface() + eventUser = rootInterface.getUserByUsername("event") + if eventUser: + services = getAutomationServices(currentUser, mandateId=None, featureInstanceId=None) + allAutomations = services.interfaceDbAutomation.getAllAutomationDefinitionsWithRBAC(eventUser) + + automationLookup = {} + for a in allAutomations: + aId = a.get("id", "") if isinstance(a, dict) else getattr(a, "id", "") + automationLookup[aId] = a + + _userCache: Dict[str, str] = {} + _mandateCache: Dict[str, str] = {} + _featureCache: Dict[str, str] = {} + + def _resolveUsername(userId): + if not userId: return "" + if userId not in _userCache: + try: + user = rootInterface.getUser(userId) + _userCache[userId] = user.username if user else userId[:8] + except Exception: + _userCache[userId] = userId[:8] + return _userCache[userId] + + def _resolveMandateLabel(mandateId): + if not mandateId: return "" + if mandateId not in _mandateCache: + try: + mandate = rootInterface.getMandate(mandateId) + _mandateCache[mandateId] = getattr(mandate, "label", None) or mandateId[:8] + except Exception: + _mandateCache[mandateId] = mandateId[:8] + return _mandateCache[mandateId] + + def _resolveFeatureLabel(featureInstanceId): + if not featureInstanceId: return "" + if featureInstanceId not in _featureCache: + try: + instance = rootInterface.getFeatureInstance(featureInstanceId) + _featureCache[featureInstanceId] = getattr(instance, "label", None) or getattr(instance, "featureCode", None) or featureInstanceId[:8] + except Exception: + _featureCache[featureInstanceId] = featureInstanceId[:8] + return _featureCache[featureInstanceId] + + for job in jobs: + automation = automationLookup.get(job["automationId"]) + if automation: + if isinstance(automation, dict): + job["name"] = automation.get("label", "") + job["createdBy"] = _resolveUsername(automation.get("_createdBy", "")) + job["mandate"] = _resolveMandateLabel(automation.get("mandateId", "")) + job["featureInstance"] = _resolveFeatureLabel(automation.get("featureInstanceId", "")) + else: + job["name"] = getattr(automation, "label", "") + job["createdBy"] = _resolveUsername(getattr(automation, "_createdBy", "")) + job["mandate"] = _resolveMandateLabel(getattr(automation, "mandateId", "")) + job["featureInstance"] = _resolveFeatureLabel(getattr(automation, "featureInstanceId", "")) + else: + job["name"] = "(orphaned)" + except Exception as e: + logger.warning(f"Could not enrich automation events with context: {e}") + + return jobs + + @router.get("") @limiter.limit("30/minute") def get_all_automation_events( request: Request, - currentUser: User = Depends(requireSysAdminRole) -) -> List[Dict[str, Any]]: - """ - Get all active scheduler jobs (sysadmin only). - Each job is enriched with context from its automation definition - (name, mandate, feature instance, creator) for readability. - """ + pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams"), + currentUser: User = Depends(requireSysAdminRole), +): + """Get all active scheduler jobs with pagination support (sysadmin only).""" try: - from modules.shared.eventManagement import eventManager - from modules.interfaces.interfaceDbApp import getRootInterface - from modules.features.automation.mainAutomation import getAutomationServices - - if not eventManager.scheduler: - return [] - - # 1. Collect all scheduler jobs - jobs = [] - automationIds = [] - for job in eventManager.scheduler.get_jobs(): - if job.id.startswith("automation."): - automationId = job.id.replace("automation.", "") - automationIds.append(automationId) - jobs.append({ - "eventId": job.id, - "automationId": automationId, - "nextRunTime": str(job.next_run_time) if job.next_run_time else None, - "trigger": str(job.trigger) if job.trigger else None, - "name": "", - "createdBy": "", - "mandate": "", - "featureInstance": "" - }) - - # 2. Enrich with context from automation definitions - if jobs: + paginationParams: Optional[PaginationParams] = None + if pagination: try: - rootInterface = getRootInterface() - eventUser = rootInterface.getUserByUsername("event") - if eventUser: - services = getAutomationServices(currentUser, mandateId=None, featureInstanceId=None) - allAutomations = services.interfaceDbAutomation.getAllAutomationDefinitionsWithRBAC(eventUser) - - # Build lookup by automation ID - automationLookup = {} - for a in allAutomations: - aId = a.get("id", "") if isinstance(a, dict) else getattr(a, "id", "") - automationLookup[aId] = a - - # Caches for resolving UUIDs to names - _userCache = {} - _mandateCache = {} - _featureCache = {} - - def _resolveUsername(userId): - if not userId: - return "" - if userId not in _userCache: - try: - user = rootInterface.getUser(userId) - _userCache[userId] = user.username if user else userId[:8] - except Exception: - _userCache[userId] = userId[:8] - return _userCache[userId] - - def _resolveMandateLabel(mandateId): - if not mandateId: - return "" - if mandateId not in _mandateCache: - try: - mandate = rootInterface.getMandate(mandateId) - _mandateCache[mandateId] = getattr(mandate, "label", None) or mandateId[:8] - except Exception: - _mandateCache[mandateId] = mandateId[:8] - return _mandateCache[mandateId] - - def _resolveFeatureLabel(featureInstanceId): - if not featureInstanceId: - return "" - if featureInstanceId not in _featureCache: - try: - instance = rootInterface.getFeatureInstance(featureInstanceId) - _featureCache[featureInstanceId] = getattr(instance, "label", None) or getattr(instance, "featureCode", None) or featureInstanceId[:8] - except Exception: - _featureCache[featureInstanceId] = featureInstanceId[:8] - return _featureCache[featureInstanceId] - - # Enrich each job - for job in jobs: - automation = automationLookup.get(job["automationId"]) - if automation: - if isinstance(automation, dict): - job["name"] = automation.get("label", "") - job["createdBy"] = _resolveUsername(automation.get("_createdBy", "")) - job["mandate"] = _resolveMandateLabel(automation.get("mandateId", "")) - job["featureInstance"] = _resolveFeatureLabel(automation.get("featureInstanceId", "")) - else: - job["name"] = getattr(automation, "label", "") - job["createdBy"] = _resolveUsername(getattr(automation, "_createdBy", "")) - job["mandate"] = _resolveMandateLabel(getattr(automation, "mandateId", "")) - job["featureInstance"] = _resolveFeatureLabel(getattr(automation, "featureInstanceId", "")) - else: - job["name"] = "(orphaned)" - except Exception as e: - logger.warning(f"Could not enrich automation events with context: {e}") - - return jobs + paginationDict = json.loads(pagination) + if paginationDict: + paginationDict = normalize_pagination_dict(paginationDict) + paginationParams = PaginationParams(**paginationDict) + except (json.JSONDecodeError, ValueError) as e: + raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}") + + enriched = _buildEnrichedAutomationEvents(currentUser) + filtered = _applyFiltersAndSort(enriched, paginationParams) + + if paginationParams: + totalItems = len(filtered) + totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0 + startIdx = (paginationParams.page - 1) * paginationParams.pageSize + endIdx = startIdx + paginationParams.pageSize + return { + "items": filtered[startIdx:endIdx], + "pagination": PaginationMetadata( + currentPage=paginationParams.page, + pageSize=paginationParams.pageSize, + totalItems=totalItems, + totalPages=totalPages, + sort=paginationParams.sort, + filters=paginationParams.filters, + ).model_dump(), + } + + return {"items": enriched, "pagination": None} + except HTTPException: + raise except Exception as e: logger.error(f"Error getting automation events: {str(e)}") - raise HTTPException( - status_code=500, - detail=f"Error getting automation events: {str(e)}" - ) + raise HTTPException(status_code=500, detail=f"Error getting automation events: {str(e)}") + + +@router.get("/filter-values") +@limiter.limit("60/minute") +def get_automation_event_filter_values( + request: Request, + column: str = Query(..., description="Column key"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters"), + currentUser: User = Depends(requireSysAdminRole), +): + """Return distinct filter values for a column in automation events.""" + try: + crossFilterParams: Optional[PaginationParams] = None + if pagination: + try: + paginationDict = json.loads(pagination) + if paginationDict: + paginationDict = normalize_pagination_dict(paginationDict) + filters = paginationDict.get("filters", {}) + filters.pop(column, None) + paginationDict["filters"] = filters + paginationDict.pop("sort", None) + crossFilterParams = PaginationParams(**paginationDict) + except (json.JSONDecodeError, ValueError): + pass + + enriched = _buildEnrichedAutomationEvents(currentUser) + crossFiltered = _applyFiltersAndSort(enriched, crossFilterParams) + return _extractDistinctValues(crossFiltered, column) + except Exception as e: + logger.error(f"Error getting filter values: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) @router.post("/sync") @limiter.limit("5/minute") diff --git a/modules/routes/routeAdminAutomationLogs.py b/modules/routes/routeAdminAutomationLogs.py new file mode 100644 index 00000000..8b4d897b --- /dev/null +++ b/modules/routes/routeAdminAutomationLogs.py @@ -0,0 +1,207 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +""" +Admin automation execution logs routes. +SysAdmin-only endpoints for viewing consolidated automation execution history +across all mandates and feature instances. +""" + +from fastapi import APIRouter, HTTPException, Depends, Request, Query +from typing import List, Dict, Any, Optional +import logging +import json +import math +import uuid + +from modules.auth import limiter, requireSysAdminRole +from modules.datamodels.datamodelUam import User +from modules.datamodels.datamodelPagination import PaginationParams, PaginationMetadata, normalize_pagination_dict +from modules.routes.routeDataUsers import _applyFiltersAndSort, _extractDistinctValues + +logger = logging.getLogger(__name__) + +router = APIRouter( + prefix="/api/admin/automation-logs", + tags=["Admin Automation Logs"], + responses={ + 401: {"description": "Unauthorized"}, + 403: {"description": "Forbidden - Sysadmin only"}, + 500: {"description": "Internal server error"}, + }, +) + + +def _buildFlattenedExecutionLogs(currentUser: User) -> List[Dict[str, Any]]: + """Flatten executionLogs from all AutomationDefinitions across all mandates. + Called from a SysAdmin-only endpoint — bypasses RBAC, reads directly from DB.""" + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.features.automation.mainAutomation import getAutomationServices + from modules.features.automation.datamodelFeatureAutomation import AutomationDefinition + + rootInterface = getRootInterface() + services = getAutomationServices(currentUser, mandateId=None, featureInstanceId=None) + allAutomations = services.interfaceDbAutomation.db.getRecordset(AutomationDefinition) + + userCache: Dict[str, str] = {} + mandateCache: Dict[str, str] = {} + featureCache: Dict[str, str] = {} + + def _resolveUsername(userId: str) -> str: + if not userId: + return "" + if userId not in userCache: + try: + user = rootInterface.getUser(userId) + userCache[userId] = user.username if user else userId[:8] + except Exception: + userCache[userId] = userId[:8] + return userCache[userId] + + def _resolveMandateLabel(mandateId: str) -> str: + if not mandateId: + return "" + if mandateId not in mandateCache: + try: + mandate = rootInterface.getMandate(mandateId) + mandateCache[mandateId] = getattr(mandate, "label", None) or mandateId[:8] + except Exception: + mandateCache[mandateId] = mandateId[:8] + return mandateCache[mandateId] + + def _resolveFeatureLabel(featureInstanceId: str) -> str: + if not featureInstanceId: + return "" + if featureInstanceId not in featureCache: + try: + instance = rootInterface.getFeatureInstance(featureInstanceId) + featureCache[featureInstanceId] = ( + getattr(instance, "label", None) + or getattr(instance, "featureCode", None) + or featureInstanceId[:8] + ) + except Exception: + featureCache[featureInstanceId] = featureInstanceId[:8] + return featureCache[featureInstanceId] + + flatLogs: List[Dict[str, Any]] = [] + + for automation in allAutomations: + if isinstance(automation, dict): + automationId = automation.get("id", "") + automationLabel = automation.get("label", "") + mandateId = automation.get("mandateId", "") + featureInstanceId = automation.get("featureInstanceId", "") + createdBy = automation.get("_createdBy", "") + logs = automation.get("executionLogs") or [] + else: + automationId = getattr(automation, "id", "") + automationLabel = getattr(automation, "label", "") + mandateId = getattr(automation, "mandateId", "") + featureInstanceId = getattr(automation, "featureInstanceId", "") + createdBy = getattr(automation, "_createdBy", "") + logs = getattr(automation, "executionLogs", None) or [] + + mandateName = _resolveMandateLabel(mandateId) + featureInstanceName = _resolveFeatureLabel(featureInstanceId) + executedByName = _resolveUsername(createdBy) + + for log in logs: + timestamp = log.get("timestamp", 0) if isinstance(log, dict) else 0 + status = log.get("status", "") if isinstance(log, dict) else "" + workflowId = log.get("workflowId", "") if isinstance(log, dict) else "" + messages = log.get("messages", []) if isinstance(log, dict) else [] + + flatLogs.append({ + "id": str(uuid.uuid4()), + "timestamp": timestamp, + "automationId": automationId, + "automationLabel": automationLabel, + "mandateName": mandateName, + "featureInstanceName": featureInstanceName, + "executedBy": executedByName, + "status": status, + "workflowId": workflowId, + "messages": "; ".join(messages) if messages else "", + }) + + flatLogs.sort(key=lambda x: x.get("timestamp", 0), reverse=True) + return flatLogs + + +@router.get("") +@limiter.limit("30/minute") +def get_all_automation_logs( + request: Request, + pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams"), + currentUser: User = Depends(requireSysAdminRole), +): + """Get consolidated execution logs from all automations (sysadmin only).""" + try: + paginationParams: Optional[PaginationParams] = None + if pagination: + try: + paginationDict = json.loads(pagination) + if paginationDict: + paginationDict = normalize_pagination_dict(paginationDict) + paginationParams = PaginationParams(**paginationDict) + except (json.JSONDecodeError, ValueError) as e: + raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}") + + logs = _buildFlattenedExecutionLogs(currentUser) + filtered = _applyFiltersAndSort(logs, paginationParams) + + if paginationParams: + totalItems = len(filtered) + totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0 + startIdx = (paginationParams.page - 1) * paginationParams.pageSize + endIdx = startIdx + paginationParams.pageSize + return { + "items": filtered[startIdx:endIdx], + "pagination": PaginationMetadata( + currentPage=paginationParams.page, + pageSize=paginationParams.pageSize, + totalItems=totalItems, + totalPages=totalPages, + sort=paginationParams.sort, + filters=paginationParams.filters, + ).model_dump(), + } + + return {"items": logs, "pagination": None} + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting automation logs: {str(e)}") + raise HTTPException(status_code=500, detail=f"Error getting automation logs: {str(e)}") + + +@router.get("/filter-values") +@limiter.limit("60/minute") +def get_automation_log_filter_values( + request: Request, + column: str = Query(..., description="Column key"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters"), + currentUser: User = Depends(requireSysAdminRole), +): + """Return distinct filter values for a column in automation logs.""" + try: + crossFilterParams: Optional[PaginationParams] = None + if pagination: + try: + paginationDict = json.loads(pagination) + if paginationDict: + paginationDict = normalize_pagination_dict(paginationDict) + filters = paginationDict.get("filters", {}) + filters.pop(column, None) + paginationDict["filters"] = filters + paginationDict.pop("sort", None) + crossFilterParams = PaginationParams(**paginationDict) + except (json.JSONDecodeError, ValueError): + pass + + logs = _buildFlattenedExecutionLogs(currentUser) + crossFiltered = _applyFiltersAndSort(logs, crossFilterParams) + return _extractDistinctValues(crossFiltered, column) + except Exception as e: + logger.error(f"Error getting filter values: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/modules/routes/routeAdminFeatures.py b/modules/routes/routeAdminFeatures.py index 1bada7fa..c95c0b1b 100644 --- a/modules/routes/routeAdminFeatures.py +++ b/modules/routes/routeAdminFeatures.py @@ -14,7 +14,11 @@ from fastapi import APIRouter, HTTPException, Depends, Request, Query from typing import List, Dict, Any, Optional from fastapi import status import logging +import json +import math from pydantic import BaseModel, Field +from modules.datamodels.datamodelPagination import PaginationParams, PaginationMetadata, normalize_pagination_dict +from modules.routes.routeDataUsers import _applyFiltersAndSort, _extractDistinctValues from modules.auth import limiter, getRequestContext, RequestContext, requireSysAdminRole from modules.datamodels.datamodelUam import User, UserInDB @@ -433,6 +437,35 @@ def list_feature_instances( ) +@router.get("/instances/filter-values") +@limiter.limit("60/minute") +def get_feature_instance_filter_values( + request: Request, + column: str = Query(..., description="Column key"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters"), + featureCode: Optional[str] = Query(None, description="Filter by feature code"), + context: RequestContext = Depends(getRequestContext) +) -> list: + """Return distinct filter values for a column in feature instances.""" + if not context.mandateId: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="X-Mandate-Id header is required") + try: + from modules.routes.routeDataUsers import _handleFilterValuesRequest + rootInterface = getRootInterface() + featureInterface = getFeatureInterface(rootInterface.db) + instances = featureInterface.getFeatureInstancesForMandate( + mandateId=str(context.mandateId), + featureCode=featureCode + ) + items = [inst.model_dump() for inst in instances] + return _handleFilterValuesRequest(items, column, pagination) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting filter values for feature instances: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) + + @router.get("/instances/{instanceId}", response_model=Dict[str, Any]) @limiter.limit("60/minute") def get_feature_instance( @@ -518,15 +551,40 @@ def create_feature_instance( detail=f"Feature '{data.featureCode}' not found" ) + # Subscription capacity check + mandateIdStr = str(context.mandateId) + try: + from modules.interfaces.interfaceDbSubscription import getInterface as _getSubIf + from modules.security.rootAccess import getRootUser + _subIf = _getSubIf(getRootUser(), mandateIdStr) + _subIf.assertCapacity(mandateIdStr, "featureInstances", delta=1) + except HTTPException: + raise + except Exception as capErr: + if "SubscriptionCapacityException" in type(capErr).__name__: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(capErr), + ) + instance = featureInterface.createFeatureInstance( featureCode=data.featureCode, - mandateId=str(context.mandateId), + mandateId=mandateIdStr, label=data.label, enabled=data.enabled, copyTemplateRoles=data.copyTemplateRoles, config=data.config ) - + + # Sync Stripe quantity after successful creation + try: + from modules.interfaces.interfaceDbSubscription import getInterface as _getSubIf2 + from modules.security.rootAccess import getRootUser as _getRU + _subIf2 = _getSubIf2(_getRU(), mandateIdStr) + _subIf2.syncQuantityToStripe(mandateIdStr) + except Exception: + pass + logger.info( f"User {context.user.id} created feature instance '{data.label}' " f"for feature '{data.featureCode}' in mandate {context.mandateId}" @@ -758,27 +816,55 @@ def sync_instance_roles( # Template Role Endpoints (SysAdmin only) # ============================================================================= -@router.get("/templates/roles", response_model=List[Dict[str, Any]]) +def _buildTemplateRolesList(featureCode: Optional[str] = None) -> List[Dict[str, Any]]: + """Build the full template roles list.""" + rootInterface = getRootInterface() + featureInterface = getFeatureInterface(rootInterface.db) + roles = featureInterface.getTemplateRoles(featureCode) + return [r.model_dump() for r in roles] + + +@router.get("/templates/roles") @limiter.limit("60/minute") def list_template_roles( request: Request, featureCode: Optional[str] = Query(None, description="Filter by feature code"), - sysAdmin: User = Depends(requireSysAdminRole) -) -> List[Dict[str, Any]]: - """ - List global template roles. - - SysAdmin only - returns template roles that are copied to new feature instances. - - Args: - featureCode: Optional filter by feature code - """ + pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams"), + sysAdmin: User = Depends(requireSysAdminRole), +): + """List global template roles with pagination support.""" try: - rootInterface = getRootInterface() - featureInterface = getFeatureInterface(rootInterface.db) - - roles = featureInterface.getTemplateRoles(featureCode) - return [r.model_dump() for r in roles] + paginationParams: Optional[PaginationParams] = None + if pagination: + try: + paginationDict = json.loads(pagination) + if paginationDict: + paginationDict = normalize_pagination_dict(paginationDict) + paginationParams = PaginationParams(**paginationDict) + except (json.JSONDecodeError, ValueError) as e: + raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}") + + enriched = _buildTemplateRolesList(featureCode) + filtered = _applyFiltersAndSort(enriched, paginationParams) + + if paginationParams: + totalItems = len(filtered) + totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0 + startIdx = (paginationParams.page - 1) * paginationParams.pageSize + endIdx = startIdx + paginationParams.pageSize + return { + "items": filtered[startIdx:endIdx], + "pagination": PaginationMetadata( + currentPage=paginationParams.page, + pageSize=paginationParams.pageSize, + totalItems=totalItems, + totalPages=totalPages, + sort=paginationParams.sort, + filters=paginationParams.filters, + ).model_dump(), + } + + return {"items": enriched, "pagination": None} except Exception as e: logger.error(f"Error listing template roles: {e}") @@ -788,6 +874,39 @@ def list_template_roles( ) +@router.get("/templates/roles/filter-values") +@limiter.limit("60/minute") +def get_template_role_filter_values( + request: Request, + column: str = Query(..., description="Column key"), + featureCode: Optional[str] = Query(None, description="Filter by feature code"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters"), + sysAdmin: User = Depends(requireSysAdminRole), +): + """Return distinct filter values for a column in template roles.""" + try: + crossFilterParams: Optional[PaginationParams] = None + if pagination: + try: + paginationDict = json.loads(pagination) + if paginationDict: + paginationDict = normalize_pagination_dict(paginationDict) + filters = paginationDict.get("filters", {}) + filters.pop(column, None) + paginationDict["filters"] = filters + paginationDict.pop("sort", None) + crossFilterParams = PaginationParams(**paginationDict) + except (json.JSONDecodeError, ValueError): + pass + + enriched = _buildTemplateRolesList(featureCode) + crossFiltered = _applyFiltersAndSort(enriched, crossFilterParams) + return _extractDistinctValues(crossFiltered, column) + except Exception as e: + logger.error(f"Error getting filter values: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + @router.post("/templates/roles", response_model=Dict[str, Any]) @limiter.limit("10/minute") def create_template_role( @@ -951,6 +1070,56 @@ def list_feature_instance_users( ) +@router.get("/instances/{instanceId}/users/filter-values") +@limiter.limit("60/minute") +def get_feature_instance_users_filter_values( + request: Request, + instanceId: str, + column: str = Query(..., description="Column key"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters"), + context: RequestContext = Depends(getRequestContext) +) -> list: + """Return distinct filter values for a column in feature instance users.""" + try: + from modules.routes.routeDataUsers import _handleFilterValuesRequest + rootInterface = getRootInterface() + featureInterface = getFeatureInterface(rootInterface.db) + instance = featureInterface.getFeatureInstance(instanceId) + if not instance: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Feature instance '{instanceId}' not found") + if context.mandateId and str(instance.mandateId) != str(context.mandateId): + if not context.hasSysAdminRole: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied to this feature instance") + featureAccesses = rootInterface.getFeatureAccessesByInstance(instanceId) + result = [] + for fa in featureAccesses: + user = rootInterface.getUser(str(fa.userId)) + if not user: + continue + roleIds = rootInterface.getRoleIdsForFeatureAccess(str(fa.id)) + roleLabels = [] + for roleId in roleIds: + role = rootInterface.getRole(roleId) + if role: + roleLabels.append(role.roleLabel) + result.append({ + "id": str(fa.id), + "userId": str(fa.userId), + "username": user.username, + "email": user.email, + "fullName": user.fullName, + "roleIds": roleIds, + "roleLabels": roleLabels, + "enabled": fa.enabled + }) + return _handleFilterValuesRequest(result, column, pagination) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting filter values for feature instance users: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) + + @router.post("/instances/{instanceId}/users", response_model=Dict[str, Any]) @limiter.limit("30/minute") def add_user_to_feature_instance( @@ -1239,7 +1408,7 @@ def update_feature_instance_user_roles( "userId": userId, "featureInstanceId": instanceId, "roleIds": data.roleIds, - "enabled": data.enabled if data.enabled is not None else existingAccess[0].get("enabled", True) + "enabled": data.enabled if data.enabled is not None else bool(existingAccess.enabled), } except HTTPException: diff --git a/modules/routes/routeAdminRbacRules.py b/modules/routes/routeAdminRbacRules.py index a7d4637e..3778d227 100644 --- a/modules/routes/routeAdminRbacRules.py +++ b/modules/routes/routeAdminRbacRules.py @@ -394,7 +394,10 @@ def get_access_rules( ) # Get rules with optional pagination - # MandateAdmin: fetch all then filter by admin's mandates + # MandateAdmin: fetch all then filter by admin's mandates. + # NOTE: Cannot use DB-level pagination for MandateAdmin because + # _isRoleInAdminMandates requires joining Role → mandateId which + # isn't expressible via getRecordsetPaginated's recordFilter. if not isSysAdmin: allRules = interface.getAccessRules( roleLabel=roleLabel, @@ -814,6 +817,11 @@ def list_roles( By default, only returns true global roles (mandateId=None, featureInstanceId=None, featureCode=None). Feature template roles are managed via /api/features/templates/roles. + NOTE: Base query (getAllRoles) already uses db.getRecordsetPaginated() internally. + However pagination=None is passed here because post-processing adds computed fields + (userCount, scopeType) and applies scope/mandate/template filtering that cannot run + at the DB level. In-memory pagination is applied after all transformations. + Args: pagination: Optional pagination parameters (includes search, filters, sort) includeTemplates: If True, also include feature template roles (featureCode != None) @@ -983,6 +991,77 @@ def list_roles( ) +@router.get("/roles/filter-values") +@limiter.limit("60/minute") +def get_roles_filter_values( + request: Request, + column: str = Query(..., description="Column key"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters"), + includeTemplates: bool = Query(False, description="Include feature template roles"), + mandateId: Optional[str] = Query(None, description="Include mandate-specific roles for this mandate"), + scopeFilter: Optional[str] = Query(None, description="Filter by scope: 'all', 'mandate', 'global', 'system'"), + reqContext: RequestContext = Depends(getRequestContext) +) -> list: + """Return distinct filter values for a column in roles.""" + try: + from modules.routes.routeDataUsers import _handleFilterValuesRequest + isSysAdmin = reqContext.hasSysAdminRole + adminMandateIds = [] if isSysAdmin else _getAdminMandateIds(reqContext) + if not isSysAdmin and not adminMandateIds: + raise HTTPException(status_code=403, detail="Admin role required") + + interface = getRootInterface() + dbRoles = interface.getAllRoles(pagination=None) + roleCounts = interface.countRoleAssignments() + + def _computeScopeType(role) -> str: + if role.mandateId: + return "mandate" + if role.isSystemRole: + return "system" + return "global" + + result = [] + for role in dbRoles: + if role.featureInstanceId is not None: + continue + if mandateId: + if role.mandateId != mandateId: + continue + else: + if role.mandateId is not None: + continue + if not includeTemplates and role.featureCode is not None: + continue + scopeType = _computeScopeType(role) + if scopeFilter and scopeFilter != 'all': + if scopeFilter == 'mandate' and scopeType != 'mandate': + continue + if scopeFilter == 'global' and scopeType not in ('global', 'system'): + continue + if scopeFilter == 'system' and scopeType != 'system': + continue + result.append({ + "id": role.id, + "roleLabel": role.roleLabel, + "description": role.description, + "mandateId": role.mandateId, + "featureInstanceId": role.featureInstanceId, + "featureCode": role.featureCode, + "userCount": roleCounts.get(str(role.id), 0), + "isSystemRole": role.isSystemRole, + "scopeType": scopeType + }) + if not isSysAdmin: + result = [r for r in result if r.get("mandateId") and str(r["mandateId"]) in adminMandateIds] + return _handleFilterValuesRequest(result, column, pagination) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting filter values for roles: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + @router.post("/roles", response_model=Dict[str, Any]) @limiter.limit("30/minute") def create_role( diff --git a/modules/routes/routeAdminUserAccessOverview.py b/modules/routes/routeAdminUserAccessOverview.py index 758eff65..9b19fc41 100644 --- a/modules/routes/routeAdminUserAccessOverview.py +++ b/modules/routes/routeAdminUserAccessOverview.py @@ -278,7 +278,8 @@ def getUserAccessOverview( # Get mandate name mandate = interface.getMandate(umMandateId) mandateName = mandate.name if mandate else umMandateId - + mandateLabel = (mandate.label or None) if mandate else None + # Get roles for this UserMandate using interface method umRoles = interface.getUserMandateRoles(umId) @@ -368,6 +369,7 @@ def getUserAccessOverview( mandatesInfo.append({ "id": umMandateId, "name": mandateName, + "label": mandateLabel, "roleIds": mandateRoleIds, "featureInstances": featureInstancesInfo, }) diff --git a/modules/routes/routeBilling.py b/modules/routes/routeBilling.py index eeee79a7..04412752 100644 --- a/modules/routes/routeBilling.py +++ b/modules/routes/routeBilling.py @@ -22,8 +22,10 @@ from modules.auth import limiter, requireSysAdminRole, getRequestContext, Reques # Import billing components from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface, _getRootInterface from modules.serviceCenter.services.serviceBilling.mainServiceBilling import getService as getBillingService +import json +import math from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict -from modules.routes.routeDataUsers import _applyFiltersAndSort +from modules.routes.routeDataUsers import _applyFiltersAndSort, _extractDistinctValues, _handleFilterValuesRequest from modules.datamodels.datamodelBilling import ( BillingAccount, BillingTransaction, @@ -226,9 +228,9 @@ def _filterTransactionsByScope(transactions: list, scope: BillingDataScope) -> l # ============================================================================= class CreditAddRequest(BaseModel): - """Request model for adding credit to an account.""" + """Request model for adding or deducting credit from an account.""" userId: Optional[str] = Field(None, description="Target user ID (for PREPAY_USER model)") - amount: float = Field(..., gt=0, description="Amount to credit in CHF") + amount: float = Field(..., description="Amount in CHF. Positive = credit, negative = deduction. Must not be zero.") description: str = Field(default="Manual credit", description="Transaction description") @@ -358,19 +360,8 @@ class UserTransactionResponse(BaseModel): def _getStripeClient(): """Initialize and return configured Stripe SDK module.""" - import stripe - from modules.shared.configuration import APP_CONFIG - - api_version = APP_CONFIG.get("STRIPE_API_VERSION") - if api_version: - stripe.api_version = api_version - - secret_key = APP_CONFIG.get("STRIPE_SECRET_KEY_SECRET") or APP_CONFIG.get("STRIPE_SECRET_KEY") - if not secret_key: - raise ValueError("STRIPE_SECRET_KEY_SECRET not configured") - - stripe.api_key = secret_key - return stripe + from modules.shared.stripeClient import getStripeClient + return getStripeClient() def _creditStripeSessionIfNeeded( @@ -835,20 +826,27 @@ def addCredit( else: raise HTTPException(status_code=400, detail=f"Cannot add credit to {billingModel.value} billing model") - # Create credit transaction + if creditRequest.amount == 0: + raise HTTPException(status_code=400, detail="Amount must not be zero") + from modules.datamodels.datamodelBilling import BillingTransaction - + + isDeduction = creditRequest.amount < 0 + txType = TransactionTypeEnum.DEBIT if isDeduction else TransactionTypeEnum.CREDIT + absAmount = abs(creditRequest.amount) + transaction = BillingTransaction( accountId=account["id"], - transactionType=TransactionTypeEnum.CREDIT, - amount=creditRequest.amount, + transactionType=txType, + amount=absAmount, description=creditRequest.description, referenceType=ReferenceTypeEnum.ADMIN ) result = billingInterface.createTransaction(transaction) - logger.info(f"Added {creditRequest.amount} CHF credit to account {account['id']} in mandate {targetMandateId}") + action = "Deducted" if isDeduction else "Added" + logger.info(f"{action} {absAmount} CHF to account {account['id']} in mandate {targetMandateId}") return result @@ -1006,13 +1004,33 @@ async def stripeWebhook( logger.info(f"Stripe webhook received: event={event.id}, type={event.type}") - accepted_event_types = {"checkout.session.completed", "checkout.session.async_payment_succeeded"} - if event.type not in accepted_event_types: + # Subscription-related events + subscriptionEventTypes = { + "customer.subscription.updated", + "customer.subscription.deleted", + "invoice.paid", + "invoice.payment_failed", + "customer.subscription.trial_will_end", + } + + # Checkout events (existing) + checkoutEventTypes = {"checkout.session.completed", "checkout.session.async_payment_succeeded"} + + if event.type in subscriptionEventTypes: + _handleSubscriptionWebhook(event) return {"received": True} - + + if event.type not in checkoutEventTypes: + return {"received": True} + session = event.data.object event_id = event.id + sessionMode = session.get("mode") if hasattr(session, "get") else getattr(session, "mode", None) + if sessionMode == "subscription": + _handleSubscriptionCheckoutCompleted(session, event_id) + return {"received": True} + billingInterface = _getRootInterface() if billingInterface.getStripeWebhookEventByEventId(event_id): logger.info(f"Stripe event {event_id} already processed, skipping") @@ -1027,6 +1045,257 @@ async def stripeWebhook( return {"received": True} +def _handleSubscriptionCheckoutCompleted(session, eventId: str) -> None: + """Handle checkout.session.completed for mode=subscription. + Resolves the local PENDING record by ID from webhook metadata and transitions it.""" + from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface + from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum, _getPlan + from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import ( + getService as getSubscriptionService, + _notifySubscriptionChange, + ) + from modules.security.rootAccess import getRootUser + from datetime import datetime, timezone + + metadata = {} + if hasattr(session, "get"): + metadata = session.get("metadata") or {} + subscriptionRecordId = metadata.get("subscriptionRecordId") + mandateId = metadata.get("mandateId") + planKey = metadata.get("planKey", "") + + platformUrl = metadata.get("platformUrl", "") + + if not subscriptionRecordId: + stripeSub = session.get("subscription") + if stripeSub: + try: + from modules.shared.stripeClient import getStripeClient + stripe = getStripeClient() + subObj = stripe.Subscription.retrieve(stripeSub) + metadata = subObj.get("metadata") or {} + subscriptionRecordId = metadata.get("subscriptionRecordId") + mandateId = metadata.get("mandateId") + planKey = metadata.get("planKey", "") + platformUrl = platformUrl or metadata.get("platformUrl", "") + except Exception: + pass + + stripeSubId = session.get("subscription") + + if not mandateId or not subscriptionRecordId: + logger.warning("Subscription checkout missing metadata: %s", metadata) + return + + subInterface = getSubRootInterface() + rootUser = getRootUser() + + sub = subInterface.getById(subscriptionRecordId) + if not sub: + logger.error("Subscription record %s not found for checkout webhook", subscriptionRecordId) + return + if sub.get("status") != SubscriptionStatusEnum.PENDING.value: + logger.warning("Subscription %s is %s, expected PENDING — skipping", subscriptionRecordId, sub.get("status")) + return + + stripeData: Dict[str, Any] = {} + if stripeSubId: + stripeData["stripeSubscriptionId"] = stripeSubId + try: + from modules.shared.stripeClient import getStripeClient + stripe = getStripeClient() + stripeSub = stripe.Subscription.retrieve(stripeSubId, expand=["items"]) + + if stripeSub.get("current_period_start"): + stripeData["currentPeriodStart"] = datetime.fromtimestamp( + stripeSub["current_period_start"], tz=timezone.utc + ).isoformat() + if stripeSub.get("current_period_end"): + stripeData["currentPeriodEnd"] = datetime.fromtimestamp( + stripeSub["current_period_end"], tz=timezone.utc + ).isoformat() + + from modules.serviceCenter.services.serviceSubscription.stripeBootstrap import getStripePricesForPlan + priceMapping = getStripePricesForPlan(planKey) + for item in stripeSub.get("items", {}).get("data", []): + priceId = item.get("price", {}).get("id", "") + if priceMapping and priceId == priceMapping.stripePriceIdUsers: + stripeData["stripeItemIdUsers"] = item["id"] + elif priceMapping and priceId == priceMapping.stripePriceIdInstances: + stripeData["stripeItemIdInstances"] = item["id"] + except Exception as e: + logger.error("Error retrieving Stripe subscription %s: %s", stripeSubId, e) + + if stripeData: + subInterface.updateFields(subscriptionRecordId, stripeData) + + operative = subInterface.getOperativeForMandate(mandateId) + hasActivePredecessor = operative is not None and operative["id"] != subscriptionRecordId + + if hasActivePredecessor: + toStatus = SubscriptionStatusEnum.SCHEDULED + if operative.get("recurring", True): + operativeStripeId = operative.get("stripeSubscriptionId") + if operativeStripeId: + try: + from modules.shared.stripeClient import getStripeClient + stripe = getStripeClient() + stripe.Subscription.modify(operativeStripeId, cancel_at_period_end=True) + except Exception as e: + logger.error("Failed to set cancel_at_period_end on predecessor %s: %s", operativeStripeId, e) + subInterface.updateFields(operative["id"], {"recurring": False}) + effectiveFrom = operative.get("currentPeriodEnd") + if effectiveFrom: + subInterface.updateFields(subscriptionRecordId, {"effectiveFrom": effectiveFrom}) + else: + toStatus = SubscriptionStatusEnum.ACTIVE + + try: + subInterface.transitionStatus( + subscriptionRecordId, SubscriptionStatusEnum.PENDING, toStatus, + {"recurring": True}, + ) + except Exception as e: + logger.error("Failed to transition subscription %s: %s", subscriptionRecordId, e) + return + + subService = getSubscriptionService(rootUser, mandateId) + subService.invalidateCache(mandateId) + + if toStatus == SubscriptionStatusEnum.ACTIVE: + plan = _getPlan(planKey) + updatedSub = subInterface.getById(subscriptionRecordId) + _notifySubscriptionChange(mandateId, "activated", plan, subscriptionRecord=updatedSub, platformUrl=platformUrl) + + logger.info( + "Checkout completed: sub=%s -> %s, mandate=%s, plan=%s", + subscriptionRecordId, toStatus.value, mandateId, planKey, + ) + + +def _handleSubscriptionWebhook(event) -> None: + """Process Stripe subscription webhook events. + All record resolution is by stripeSubscriptionId — no mandate-based guessing.""" + from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface + from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum, _getPlan + from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import ( + getService as getSubscriptionService, + _notifySubscriptionChange, + ) + from modules.security.rootAccess import getRootUser + from datetime import datetime, timezone + + obj = event.data.object + stripeSubId = obj.get("id") if event.type.startswith("customer.subscription") else obj.get("subscription") + if not stripeSubId: + logger.warning("Subscription webhook %s has no subscription ID", event.type) + return + + subInterface = getSubRootInterface() + sub = subInterface.getByStripeSubscriptionId(stripeSubId) + if not sub: + logger.warning("No local record for Stripe subscription %s (event: %s)", stripeSubId, event.type) + return + + subId = sub["id"] + mandateId = sub["mandateId"] + currentStatus = SubscriptionStatusEnum(sub["status"]) + rootUser = getRootUser() + subService = getSubscriptionService(rootUser, mandateId) + + subMetadata = obj.get("metadata") or {} + webhookPlatformUrl = subMetadata.get("platformUrl", "") + + if event.type == "customer.subscription.updated": + stripeStatus = obj.get("status", "") + + periodData: Dict[str, Any] = {} + if obj.get("current_period_start"): + periodData["currentPeriodStart"] = datetime.fromtimestamp( + obj["current_period_start"], tz=timezone.utc + ).isoformat() + if obj.get("current_period_end"): + periodData["currentPeriodEnd"] = datetime.fromtimestamp( + obj["current_period_end"], tz=timezone.utc + ).isoformat() + if periodData: + subInterface.updateFields(subId, periodData) + + if stripeStatus == "active" and currentStatus == SubscriptionStatusEnum.SCHEDULED: + subInterface.transitionStatus(subId, SubscriptionStatusEnum.SCHEDULED, SubscriptionStatusEnum.ACTIVE) + subService.invalidateCache(mandateId) + plan = _getPlan(sub.get("planKey", "")) + refreshedSub = subInterface.getById(subId) + _notifySubscriptionChange(mandateId, "activated", plan, subscriptionRecord=refreshedSub, platformUrl=webhookPlatformUrl) + logger.info("SCHEDULED -> ACTIVE for sub %s (mandate %s)", subId, mandateId) + + elif stripeStatus == "active" and currentStatus == SubscriptionStatusEnum.PAST_DUE: + subInterface.transitionStatus(subId, SubscriptionStatusEnum.PAST_DUE, SubscriptionStatusEnum.ACTIVE) + subService.invalidateCache(mandateId) + logger.info("PAST_DUE -> ACTIVE for sub %s (mandate %s)", subId, mandateId) + + elif stripeStatus == "past_due" and currentStatus == SubscriptionStatusEnum.ACTIVE: + subInterface.transitionStatus(subId, SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.PAST_DUE) + subService.invalidateCache(mandateId) + logger.info("ACTIVE -> PAST_DUE for sub %s (mandate %s)", subId, mandateId) + + elif stripeStatus == "active" and currentStatus == SubscriptionStatusEnum.ACTIVE: + subService.invalidateCache(mandateId) + logger.info("Period renewed for sub %s (mandate %s)", subId, mandateId) + + elif event.type == "customer.subscription.deleted": + if currentStatus not in (SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.PAST_DUE, + SubscriptionStatusEnum.SCHEDULED): + logger.info("Ignoring deletion for sub %s in status %s", subId, currentStatus.value) + return + + subInterface.transitionStatus(subId, currentStatus, SubscriptionStatusEnum.EXPIRED) + subService.invalidateCache(mandateId) + logger.info("Sub %s -> EXPIRED (Stripe deleted, mandate %s)", subId, mandateId) + + scheduled = subInterface.getScheduledForMandate(mandateId) + if scheduled: + try: + subInterface.transitionStatus( + scheduled["id"], SubscriptionStatusEnum.SCHEDULED, SubscriptionStatusEnum.ACTIVE, + ) + subService.invalidateCache(mandateId) + plan = _getPlan(scheduled.get("planKey", "")) + refreshedScheduled = subInterface.getById(scheduled["id"]) + _notifySubscriptionChange(mandateId, "activated", plan, subscriptionRecord=refreshedScheduled, platformUrl=webhookPlatformUrl) + logger.info("Promoted SCHEDULED sub %s -> ACTIVE (mandate %s)", scheduled["id"], mandateId) + except Exception as e: + logger.error("Failed to promote SCHEDULED sub %s: %s", scheduled["id"], e) + + elif event.type == "invoice.payment_failed": + if currentStatus == SubscriptionStatusEnum.ACTIVE: + subInterface.transitionStatus(subId, SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.PAST_DUE) + subService.invalidateCache(mandateId) + plan = _getPlan(sub.get("planKey", "")) + _notifySubscriptionChange(mandateId, "payment_failed", plan, subscriptionRecord=sub, platformUrl=webhookPlatformUrl) + logger.info("Payment failed for sub %s (mandate %s)", subId, mandateId) + + elif event.type == "customer.subscription.trial_will_end": + logger.info("Trial ending soon for sub %s (mandate %s)", subId, mandateId) + try: + from modules.shared.notifyMandateAdmins import notifyMandateAdmins + notifyMandateAdmins( + mandateId, + "[PowerOn] Testphase endet bald", + "Testphase endet bald", + [ + "Die kostenlose Testphase für Ihren Mandanten endet in Kürze.", + "Bitte wählen Sie einen Plan unter Billing-Verwaltung › Abonnement.", + ], + ) + except Exception as e: + logger.error("Failed to notify about trial ending: %s", e) + + elif event.type == "invoice.paid": + logger.info("Invoice paid for sub %s (mandate %s)", subId, mandateId) + return None + + @router.get("/admin/accounts/{targetMandateId}", response_model=List[AccountSummary]) @limiter.limit("30/minute") def getAccounts( @@ -1135,49 +1404,187 @@ def getUsersForMandate( raise HTTPException(status_code=500, detail=str(e)) -@router.get("/admin/transactions/{targetMandateId}", response_model=List[TransactionResponse]) +def _enrichTransactionRows(transactions) -> List[Dict[str, Any]]: + """Convert raw transaction dicts to enriched TransactionResponse rows with resolved usernames.""" + result = [] + for t in transactions: + row = TransactionResponse( + id=t.get("id"), + accountId=t.get("accountId"), + transactionType=TransactionTypeEnum(t.get("transactionType", "DEBIT")), + amount=t.get("amount", 0.0), + description=t.get("description", ""), + referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None, + workflowId=t.get("workflowId"), + featureCode=t.get("featureCode"), + featureInstanceId=t.get("featureInstanceId"), + aicoreProvider=t.get("aicoreProvider"), + aicoreModel=t.get("aicoreModel"), + createdByUserId=t.get("createdByUserId"), + createdAt=t.get("_createdAt") + ) + result.append(row.model_dump()) + + try: + from modules.interfaces.interfaceDbUam import _getRootInterface as getUamRoot + uamInterface = getUamRoot() + userNames: Dict[str, str] = {} + for row in result: + uid = row.get("createdByUserId") + if uid and uid not in userNames: + try: + user = uamInterface.getUser(uid) + userNames[uid] = user.get("username", uid[:8]) if user else uid[:8] + except Exception: + userNames[uid] = uid[:8] + row["userName"] = userNames.get(uid, "") if uid else "" + except Exception: + for row in result: + row["userName"] = row.get("createdByUserId", "")[:8] if row.get("createdByUserId") else "" + + return result + + +def _buildTransactionsList(ctx: RequestContext, targetMandateId: str) -> List[Dict[str, Any]]: + """Build the full enriched transactions list for a mandate.""" + billingInterface = getBillingInterface(ctx.user, targetMandateId) + transactions = billingInterface.getTransactionsByMandate(targetMandateId, limit=5000) + + result = [] + for t in transactions: + row = TransactionResponse( + id=t.get("id"), + accountId=t.get("accountId"), + transactionType=TransactionTypeEnum(t.get("transactionType", "DEBIT")), + amount=t.get("amount", 0.0), + description=t.get("description", ""), + referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None, + workflowId=t.get("workflowId"), + featureCode=t.get("featureCode"), + featureInstanceId=t.get("featureInstanceId"), + aicoreProvider=t.get("aicoreProvider"), + aicoreModel=t.get("aicoreModel"), + createdByUserId=t.get("createdByUserId"), + createdAt=t.get("_createdAt") + ) + result.append(row.model_dump()) + + # Resolve user names + try: + from modules.interfaces.interfaceDbUam import _getRootInterface as getUamRoot + uamInterface = getUamRoot() + userNames: Dict[str, str] = {} + for row in result: + uid = row.get("createdByUserId") + if uid and uid not in userNames: + try: + user = uamInterface.getUser(uid) + userNames[uid] = user.get("username", uid[:8]) if user else uid[:8] + except Exception: + userNames[uid] = uid[:8] + row["userName"] = userNames.get(uid, "") if uid else "" + except Exception: + for row in result: + row["userName"] = row.get("createdByUserId", "")[:8] if row.get("createdByUserId") else "" + + return result + + +@router.get("/admin/transactions/{targetMandateId}") @limiter.limit("30/minute") def getTransactionsAdmin( request: Request, targetMandateId: str = Path(..., description="Mandate ID"), + pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams"), limit: int = Query(default=100, ge=1, le=1000), ctx: RequestContext = Depends(getRequestContext), ): - """ - Get all transactions for a mandate. - Access: SysAdmin (any mandate) or MandateAdmin (own mandate). - """ + """Get all transactions for a mandate with pagination support.""" if not _isAdminOfMandate(ctx, targetMandateId): raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin role required for this mandate") try: - billingInterface = getBillingInterface(ctx.user, targetMandateId) - transactions = billingInterface.getTransactionsByMandate(targetMandateId, limit=limit) - - result = [] - for t in transactions: - result.append(TransactionResponse( - id=t.get("id"), - accountId=t.get("accountId"), - transactionType=TransactionTypeEnum(t.get("transactionType", "DEBIT")), - amount=t.get("amount", 0.0), - description=t.get("description", ""), - referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None, - workflowId=t.get("workflowId"), - featureCode=t.get("featureCode"), - featureInstanceId=t.get("featureInstanceId"), - aicoreProvider=t.get("aicoreProvider"), - aicoreModel=t.get("aicoreModel"), - createdByUserId=t.get("createdByUserId"), - createdAt=t.get("_createdAt") - )) - - return result - + paginationParams: Optional[PaginationParams] = None + if pagination: + try: + paginationDict = json.loads(pagination) + if paginationDict: + paginationDict = normalize_pagination_dict(paginationDict) + paginationParams = PaginationParams(**paginationDict) + except (json.JSONDecodeError, ValueError) as e: + raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}") + + if paginationParams: + # DB-level pagination — enrich only the returned page + billingInterface = getBillingInterface(ctx.user, targetMandateId) + result = billingInterface.getTransactionsByMandate(targetMandateId, pagination=paginationParams) + transactions = result.items if hasattr(result, 'items') else result + enrichedItems = _enrichTransactionRows(transactions) + return { + "items": enrichedItems, + "pagination": PaginationMetadata( + currentPage=paginationParams.page, + pageSize=paginationParams.pageSize, + totalItems=result.totalItems if hasattr(result, 'totalItems') else len(enrichedItems), + totalPages=result.totalPages if hasattr(result, 'totalPages') else 0, + sort=paginationParams.sort, + filters=paginationParams.filters, + ).model_dump(), + } + + enriched = _buildTransactionsList(ctx, targetMandateId) + return {"items": enriched, "pagination": None} + + except HTTPException: + raise except Exception as e: logger.error(f"Error getting billing transactions for mandate {targetMandateId}: {e}") raise HTTPException(status_code=500, detail=str(e)) +@router.get("/admin/transactions/{targetMandateId}/filter-values") +@limiter.limit("60/minute") +def getTransactionFilterValues( + request: Request, + targetMandateId: str = Path(..., description="Mandate ID"), + column: str = Query(..., description="Column key"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters"), + ctx: RequestContext = Depends(getRequestContext), +): + """Return distinct filter values for a column in mandate transactions.""" + if not _isAdminOfMandate(ctx, targetMandateId): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin role required for this mandate") + try: + crossFilterParams: Optional[PaginationParams] = None + if pagination: + try: + paginationDict = json.loads(pagination) + if paginationDict: + paginationDict = normalize_pagination_dict(paginationDict) + filters = paginationDict.get("filters", {}) + filters.pop(column, None) + paginationDict["filters"] = filters + paginationDict.pop("sort", None) + crossFilterParams = PaginationParams(**paginationDict) + except (json.JSONDecodeError, ValueError): + pass + + # Try SQL DISTINCT for native DB columns; fallback to in-memory for enriched columns (e.g. userName) + try: + rootBillingInterface = _getRootInterface() + recordFilter = {"mandateId": targetMandateId} + values = rootBillingInterface.db.getDistinctColumnValues( + BillingTransaction, column, crossFilterParams, recordFilter + ) + return sorted(values, key=lambda v: str(v).lower()) + except Exception: + enriched = _buildTransactionsList(ctx, targetMandateId) + crossFiltered = _applyFiltersAndSort(enriched, crossFilterParams) + return _extractDistinctValues(crossFiltered, column) + except Exception as e: + logger.error(f"Error getting filter values for transactions: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + # ============================================================================= # Mandate View Endpoints (for Admins) # ============================================================================= @@ -1625,3 +2032,50 @@ def getUserViewTransactions( except Exception as e: logger.error(f"Error getting user view transactions: {e}") raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/view/users/transactions/filter-values") +@limiter.limit("60/minute") +def getUserViewTransactionsFilterValues( + request: Request, + column: str = Query(..., description="Column key"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters"), + ctx: RequestContext = Depends(getRequestContext) +): + """Return distinct filter values for a column in user transactions.""" + try: + billingInterface = getBillingInterface(ctx.user, ctx.mandateId) + scope = _getBillingDataScope(ctx.user) + if scope.isGlobalAdmin: + mandateIds = None + else: + mandateIds = scope.adminMandateIds + scope.memberMandateIds + if not mandateIds: + return [] + allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=10000) + allTransactions = _filterTransactionsByScope(allTransactions, scope) + transactionDicts = [] + for t in allTransactions: + transactionDicts.append({ + "id": t.get("id"), + "accountId": t.get("accountId"), + "transactionType": t.get("transactionType", "DEBIT"), + "amount": t.get("amount", 0.0), + "description": t.get("description", ""), + "referenceType": t.get("referenceType"), + "workflowId": t.get("workflowId"), + "featureCode": t.get("featureCode"), + "featureInstanceId": t.get("featureInstanceId"), + "aicoreProvider": t.get("aicoreProvider"), + "aicoreModel": t.get("aicoreModel"), + "createdByUserId": t.get("createdByUserId"), + "createdAt": t.get("_createdAt"), + "mandateId": t.get("mandateId"), + "mandateName": t.get("mandateName"), + "userId": t.get("userId"), + "userName": t.get("userName"), + }) + return _handleFilterValuesRequest(transactionDicts, column, pagination) + except Exception as e: + logger.error(f"Error getting filter values for user transactions: {e}") + raise HTTPException(status_code=500, detail=str(e)) diff --git a/modules/routes/routeDataConnections.py b/modules/routes/routeDataConnections.py index 2a6be738..17ef0115 100644 --- a/modules/routes/routeDataConnections.py +++ b/modules/routes/routeDataConnections.py @@ -147,6 +147,11 @@ async def get_connections( try: interface = getInterface(currentUser) + # NOTE: Cannot use db.getRecordsetPaginated() here because each connection + # is enriched with computed tokenStatus/tokenExpiresAt (requires per-row DB lookup). + # Token refresh also may trigger re-fetch. Connections per user are typically < 10, + # so in-memory pagination is acceptable. + # Parse pagination parameter paginationParams = None if pagination: @@ -287,6 +292,42 @@ async def get_connections( detail=f"Failed to get connections: {str(e)}" ) +@router.get("/filter-values") +@limiter.limit("60/minute") +def get_connection_filter_values( + request: Request, + column: str = Query(..., description="Column key"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters"), + currentUser: User = Depends(getCurrentUser) +) -> List[str]: + """Return distinct filter values for a column in connections.""" + try: + from modules.routes.routeDataUsers import _handleFilterValuesRequest + interface = getInterface(currentUser) + connections = interface.getUserConnections(currentUser.id) + items = [] + for connection in connections: + tokenStatus, tokenExpiresAt = getTokenStatusForConnection(interface, connection.id) + items.append({ + "id": connection.id, + "userId": connection.userId, + "authority": connection.authority.value if hasattr(connection.authority, 'value') else str(connection.authority), + "externalId": connection.externalId, + "externalUsername": connection.externalUsername or "", + "externalEmail": connection.externalEmail, + "status": connection.status.value if hasattr(connection.status, 'value') else str(connection.status), + "connectedAt": connection.connectedAt, + "lastChecked": connection.lastChecked, + "expiresAt": connection.expiresAt, + "tokenStatus": tokenStatus, + "tokenExpiresAt": tokenExpiresAt + }) + return _handleFilterValuesRequest(items, column, pagination) + except Exception as e: + logger.error(f"Error getting filter values for connections: {str(e)}") + raise HTTPException(status_code=500, detail=str(e)) + + @router.post("/", response_model=UserConnection) @limiter.limit("10/minute") def create_connection( diff --git a/modules/routes/routeDataFiles.py b/modules/routes/routeDataFiles.py index 470793bb..999d07df 100644 --- a/modules/routes/routeDataFiles.py +++ b/modules/routes/routeDataFiles.py @@ -37,6 +37,17 @@ async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user): mgmtInterface.updateFile(fileId, {"status": "active"}) return + file_meta = mgmtInterface.getFile(fileId) + feature_instance_id = "" + mandate_id = "" + if file_meta: + if isinstance(file_meta, dict): + feature_instance_id = file_meta.get("featureInstanceId") or "" + mandate_id = file_meta.get("mandateId") or "" + else: + feature_instance_id = getattr(file_meta, "featureInstanceId", None) or "" + mandate_id = getattr(file_meta, "mandateId", None) or "" + logger.info(f"Auto-index starting for {fileName} ({len(rawBytes)} bytes, {mimeType})") # Step 1: Structure Pre-Scan (AI-free) @@ -47,6 +58,8 @@ async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user): fileId=fileId, fileName=fileName, userId=userId, + featureInstanceId=str(feature_instance_id) if feature_instance_id else "", + mandateId=str(mandate_id) if mandate_id else "", ) logger.info( f"Pre-scan complete for {fileName}: " @@ -105,7 +118,11 @@ async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user): from modules.serviceCenter import getService from modules.serviceCenter.context import ServiceCenterContext - ctx = ServiceCenterContext(user=user, mandate_id="", feature_instance_id="") + ctx = ServiceCenterContext( + user=user, + mandate_id=str(mandate_id) if mandate_id else "", + feature_instance_id=str(feature_instance_id) if feature_instance_id else "", + ) knowledgeService = getService("knowledge", ctx) await knowledgeService.indexFile( @@ -113,6 +130,8 @@ async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user): fileName=fileName, mimeType=mimeType, userId=userId, + featureInstanceId=str(feature_instance_id) if feature_instance_id else "", + mandateId=str(mandate_id) if mandate_id else "", contentObjects=contentObjects, structure=contentIndex.structure, ) @@ -214,6 +233,56 @@ def get_files( detail=f"Failed to get files: {str(e)}" ) +@router.get("/list/filter-values") +@limiter.limit("60/minute") +def get_file_filter_values( + request: Request, + column: str = Query(..., description="Column key"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters"), + currentUser: User = Depends(getCurrentUser), + context: RequestContext = Depends(getRequestContext) +) -> list: + """Return distinct filter values for a column in files.""" + try: + managementInterface = interfaceDbManagement.getInterface( + currentUser, + mandateId=str(context.mandateId) if context.mandateId else None, + featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None + ) + + crossFilterPagination = None + if pagination: + try: + paginationDict = json.loads(pagination) + if paginationDict: + paginationDict = normalize_pagination_dict(paginationDict) + filters = paginationDict.get("filters", {}) + filters.pop(column, None) + paginationDict["filters"] = filters + paginationDict.pop("sort", None) + crossFilterPagination = PaginationParams(**paginationDict) + except (json.JSONDecodeError, ValueError): + pass + + try: + recordFilter = {"_createdBy": managementInterface.userId} + values = managementInterface.db.getDistinctColumnValues( + FileItem, column, crossFilterPagination, recordFilter + ) + return sorted(values, key=lambda v: str(v).lower()) + except Exception: + from modules.routes.routeDataUsers import _handleFilterValuesRequest + result = managementInterface.getAllFiles(pagination=None) + items = [r.model_dump() if hasattr(r, 'model_dump') else r for r in result] + return _handleFilterValuesRequest(items, column, pagination) + except Exception as e: + logger.error(f"Error getting filter values for files: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(e) + ) + + @router.post("/upload", status_code=status.HTTP_201_CREATED) @limiter.limit("10/minute") async def upload_file( diff --git a/modules/routes/routeDataMandates.py b/modules/routes/routeDataMandates.py index 16b7166d..2c2bd31c 100644 --- a/modules/routes/routeDataMandates.py +++ b/modules/routes/routeDataMandates.py @@ -164,6 +164,66 @@ def get_mandates( detail=f"Failed to get mandates: {str(e)}" ) +@router.get("/filter-values") +@limiter.limit("60/minute") +def get_mandate_filter_values( + request: Request, + column: str = Query(..., description="Column key"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters"), + context: RequestContext = Depends(getRequestContext) +) -> list: + """Return distinct filter values for a column in mandates.""" + try: + from modules.routes.routeDataUsers import _handleFilterValuesRequest + isSysAdmin = context.hasSysAdminRole + if not isSysAdmin: + adminMandateIds = _getAdminMandateIds(context) + if not adminMandateIds: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin role required") + + appInterface = interfaceDbApp.getRootInterface() + + if isSysAdmin: + # SysAdmin: try SQL DISTINCT for DB columns + crossFilterPagination = None + if pagination: + try: + paginationDict = json.loads(pagination) + if paginationDict: + paginationDict = normalize_pagination_dict(paginationDict) + filters = paginationDict.get("filters", {}) + filters.pop(column, None) + paginationDict["filters"] = filters + paginationDict.pop("sort", None) + crossFilterPagination = PaginationParams(**paginationDict) + except (json.JSONDecodeError, ValueError): + pass + try: + values = appInterface.db.getDistinctColumnValues( + Mandate, column, crossFilterPagination + ) + return sorted(values, key=lambda v: str(v).lower()) + except Exception: + result = appInterface.getAllMandates(pagination=None) + items = result if isinstance(result, list) else (result.items if hasattr(result, 'items') else result) + items = [i.model_dump() if hasattr(i, 'model_dump') else i for i in items] + return _handleFilterValuesRequest(items, column, pagination) + else: + # MandateAdmin: in-memory (small set of individual mandate lookups) + result = [] + for mid in adminMandateIds: + mandate = appInterface.getMandate(mid) + if mandate: + result.append(mandate if isinstance(mandate, dict) else mandate.model_dump() if hasattr(mandate, 'model_dump') else vars(mandate)) + items = [i.model_dump() if hasattr(i, 'model_dump') else i for i in result] + return _handleFilterValuesRequest(items, column, pagination) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting filter values for mandates: {str(e)}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) + + @router.get("/{targetMandateId}", response_model=Mandate) @limiter.limit("30/minute") def get_mandate( @@ -540,6 +600,63 @@ def list_mandate_users( ) +@router.get("/{targetMandateId}/users/filter-values") +@limiter.limit("60/minute") +def get_mandate_users_filter_values( + request: Request, + targetMandateId: str = Path(..., description="ID of the mandate"), + column: str = Query(..., description="Column key"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters"), + context: RequestContext = Depends(getRequestContext) +) -> list: + """Return distinct filter values for a column in mandate users.""" + if not _hasMandateAdminRole(context, targetMandateId) and not context.hasSysAdminRole: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Mandate-Admin role required") + + try: + from modules.routes.routeDataUsers import _handleFilterValuesRequest + rootInterface = interfaceDbApp.getRootInterface() + mandate = rootInterface.getMandate(targetMandateId) + if not mandate: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Mandate {targetMandateId} not found") + + userMandates = rootInterface.getUserMandatesByMandate(targetMandateId) + result = [] + for um in userMandates: + user = rootInterface.getUser(str(um.userId)) + if not user: + continue + roleIds = rootInterface.getRoleIdsForUserMandate(str(um.id)) + roleLabels = [] + filteredRoleIds = [] + seenLabels = set() + for roleId in roleIds: + role = rootInterface.getRole(roleId) + if role: + if role.featureInstanceId: + continue + filteredRoleIds.append(roleId) + if role.roleLabel not in seenLabels: + roleLabels.append(role.roleLabel) + seenLabels.add(role.roleLabel) + result.append({ + "id": str(um.id), + "userId": str(user.id), + "username": user.username, + "email": user.email, + "fullName": user.fullName, + "roleIds": filteredRoleIds, + "roleLabels": roleLabels, + "enabled": um.enabled + }) + return _handleFilterValuesRequest(result, column, pagination) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting filter values for mandate users: {str(e)}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) + + @router.post("/{targetMandateId}/users", response_model=UserMandateResponse) @limiter.limit("30/minute") def add_user_to_mandate( diff --git a/modules/routes/routeDataPrompts.py b/modules/routes/routeDataPrompts.py index faf692fe..f9246ab6 100644 --- a/modules/routes/routeDataPrompts.py +++ b/modules/routes/routeDataPrompts.py @@ -81,6 +81,29 @@ def get_prompts( pagination=None ) +@router.get("/filter-values") +@limiter.limit("60/minute") +def get_prompt_filter_values( + request: Request, + column: str = Query(..., description="Column key"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters"), + currentUser: User = Depends(getCurrentUser) +) -> list: + """Return distinct filter values for a column in prompts. + + NOTE: Cannot use db.getDistinctColumnValues() because visibility rules + (own + system for regular users) require pre-filtering the recordset. + """ + try: + from modules.routes.routeDataUsers import _handleFilterValuesRequest + managementInterface = interfaceDbManagement.getInterface(currentUser) + result = managementInterface.getAllPrompts(pagination=None) + items = [r.model_dump() if hasattr(r, 'model_dump') else r for r in result] + return _handleFilterValuesRequest(items, column, pagination) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @router.post("", response_model=Prompt) @limiter.limit("10/minute") def create_prompt( diff --git a/modules/routes/routeDataUsers.py b/modules/routes/routeDataUsers.py index 16742388..7e903466 100644 --- a/modules/routes/routeDataUsers.py +++ b/modules/routes/routeDataUsers.py @@ -69,6 +69,55 @@ def _isAdminForUser(context: RequestContext, targetUserId: str) -> bool: return False +def _extractDistinctValues(items: List[Dict[str, Any]], columnKey: str) -> List[str]: + """Extract sorted distinct display values for a column from enriched items.""" + values = set() + for item in items: + val = item.get(columnKey) + if val is None or val == "": + continue + if isinstance(val, bool): + values.add("true" if val else "false") + elif isinstance(val, (int, float)): + values.add(str(val)) + elif isinstance(val, dict): + text = val.get("en") or next((v for v in val.values() if isinstance(v, str) and v), None) + if text: + values.add(str(text)) + else: + values.add(str(val)) + return sorted(values, key=lambda v: v.lower()) + + +def _handleFilterValuesRequest( + items: List[Dict[str, Any]], + column: str, + paginationJson: Optional[str] = None, +) -> List[str]: + """ + Generic handler for /filter-values endpoints. + Applies all active filters EXCEPT the one for the requested column (cross-filtering), + then extracts distinct values for that column. + """ + crossFilterParams: Optional[PaginationParams] = None + if paginationJson: + try: + import json + paginationDict = json.loads(paginationJson) + if paginationDict: + paginationDict = normalize_pagination_dict(paginationDict) + filters = paginationDict.get("filters", {}) + filters.pop(column, None) + paginationDict["filters"] = filters + paginationDict.pop("sort", None) + crossFilterParams = PaginationParams(**paginationDict) + except (json.JSONDecodeError, ValueError): + pass + + crossFiltered = _applyFiltersAndSort(items, crossFilterParams) + return _extractDistinctValues(crossFiltered, column) + + def _applyFiltersAndSort(items: List[Dict[str, Any]], paginationParams: Optional[PaginationParams]) -> List[Dict[str, Any]]: """ Apply filters and sorting to a list of items. @@ -121,7 +170,6 @@ def _applyFiltersAndSort(items: List[Dict[str, Any]], paginationParams: Optional if itemValue is None: return False - # Convert to string for comparison if needed itemStr = str(itemValue).lower() valueStr = str(v).lower() @@ -147,6 +195,42 @@ def _applyFiltersAndSort(items: List[Dict[str, Any]], paginationParams: Optional return itemNum <= valueNum except (ValueError, TypeError): return False + elif op == 'between': + if isinstance(v, dict): + fromVal = v.get('from', '') + toVal = v.get('to', '') + if not fromVal and not toVal: + return True + # Date range: from/to are YYYY-MM-DD strings, itemValue may be Unix timestamp + try: + from datetime import datetime, timezone + fromTs = None + toTs = None + if fromVal: + fromTs = datetime.strptime(str(fromVal), '%Y-%m-%d').replace(tzinfo=timezone.utc).timestamp() + if toVal: + toTs = datetime.strptime(str(toVal), '%Y-%m-%d').replace(hour=23, minute=59, second=59, tzinfo=timezone.utc).timestamp() + itemNum = float(itemValue) if not isinstance(itemValue, (int, float)) else itemValue + # Normalize: if item looks like a millisecond timestamp, convert to seconds + if itemNum > 10000000000: + itemNum = itemNum / 1000 + if fromTs is not None and toTs is not None: + return fromTs <= itemNum <= toTs + elif fromTs is not None: + return itemNum >= fromTs + elif toTs is not None: + return itemNum <= toTs + except (ValueError, TypeError): + # Fallback: string comparison (for non-numeric date fields) + fromStr = str(fromVal).lower() if fromVal else '' + toStr = str(toVal).lower() if toVal else '' + if fromStr and toStr: + return fromStr <= itemStr <= toStr + elif fromStr: + return itemStr >= fromStr + elif toStr: + return itemStr <= toStr + return True elif op == 'in': if isinstance(v, list): return itemStr in [str(x).lower() for x in v] @@ -159,23 +243,25 @@ def _applyFiltersAndSort(items: List[Dict[str, Any]], paginationParams: Optional result = [item for item in result if matchesFilter(item, field, operator, value)] - # Apply sorting + # Apply sorting — None values always last if paginationParams.sort: for sortField in reversed(paginationParams.sort): fieldName = sortField.field ascending = sortField.direction == 'asc' - def getSortKey(item: Dict[str, Any]): - value = item.get(fieldName) - if value is None: - return (1, '') # Nulls last - if isinstance(value, bool): - return (0, not value if ascending else value) - if isinstance(value, (int, float)): - return (0, value) - return (0, str(value).lower()) + noneItems = [item for item in result if item.get(fieldName) is None] + nonNoneItems = [item for item in result if item.get(fieldName) is not None] - result = sorted(result, key=getSortKey, reverse=not ascending) + def getSortKey(item: Dict[str, Any], _fn=fieldName): + value = item.get(_fn) + if isinstance(value, bool): + return (0, int(value), '') + if isinstance(value, (int, float)): + return (0, value, '') + return (1, 0, str(value).lower()) + + nonNoneItems = sorted(nonNoneItems, key=getSortKey, reverse=not ascending) + result = nonNoneItems + noneItems return result @@ -291,38 +377,23 @@ def get_users( pagination=None ) elif context.hasSysAdminRole: - # SysAdmin without mandateId sees all users - # Get all users via interface method (returns Pydantic User models) - allUserModels = appInterface.getAllUsers() - # Convert to dictionaries for filtering/sorting - cleanedUsers = [u.model_dump() for u in allUserModels] + # SysAdmin without mandateId — DB-level pagination via interface + result = appInterface.getAllUsers(paginationParams) - # Apply server-side filtering and sorting - filteredUsers = _applyFiltersAndSort(cleanedUsers, paginationParams) - - # Convert to User objects - users = [User(**u) for u in filteredUsers] - - if paginationParams: - import math - totalItems = len(users) - totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0 - startIdx = (paginationParams.page - 1) * paginationParams.pageSize - endIdx = startIdx + paginationParams.pageSize - paginatedUsers = users[startIdx:endIdx] - + if paginationParams and hasattr(result, 'items'): return PaginatedResponse( - items=paginatedUsers, + items=result.items, pagination=PaginationMetadata( currentPage=paginationParams.page, pageSize=paginationParams.pageSize, - totalItems=totalItems, - totalPages=totalPages, + totalItems=result.totalItems, + totalPages=result.totalPages, sort=paginationParams.sort, filters=paginationParams.filters ) ) else: + users = result if isinstance(result, list) else (result.items if hasattr(result, 'items') else []) return PaginatedResponse( items=users, pagination=None @@ -407,6 +478,88 @@ def get_users( detail=f"Failed to get users: {str(e)}" ) +@router.get("/filter-values") +@limiter.limit("60/minute") +def get_user_filter_values( + request: Request, + column: str = Query(..., description="Column key"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters"), + context: RequestContext = Depends(getRequestContext) +) -> list: + """Return distinct filter values for a column in users.""" + try: + appInterface = interfaceDbApp.getInterface(context.user, mandateId=context.mandateId) + + # Build cross-filter pagination (all filters except the requested column) + crossFilterPagination = None + if pagination: + try: + paginationDict = json.loads(pagination) + if paginationDict: + paginationDict = normalize_pagination_dict(paginationDict) + filters = paginationDict.get("filters", {}) + filters.pop(column, None) + paginationDict["filters"] = filters + paginationDict.pop("sort", None) + crossFilterPagination = PaginationParams(**paginationDict) + except (json.JSONDecodeError, ValueError): + pass + + if context.mandateId: + # Mandate-scoped: in-memory (users require UserMandate join) + result = appInterface.getUsersByMandate(str(context.mandateId), None) + users = result if isinstance(result, list) else (result.items if hasattr(result, 'items') else []) + items = [u.model_dump() if hasattr(u, 'model_dump') else u for u in users] + return _handleFilterValuesRequest(items, column, pagination) + elif context.hasSysAdminRole: + # SysAdmin: use SQL DISTINCT for DB columns + try: + rootInterface = getRootInterface() + values = rootInterface.db.getDistinctColumnValues( + UserInDB, column, crossFilterPagination + ) + return sorted(values, key=lambda v: v.lower()) + except Exception: + users = appInterface.getAllUsers() + items = [u.model_dump() if hasattr(u, 'model_dump') else u for u in users] + return _handleFilterValuesRequest(items, column, pagination) + else: + # Non-admin multi-mandate: aggregate across admin mandates (in-memory) + rootInterface = getRootInterface() + userMandates = rootInterface.getUserMandates(str(context.user.id)) + adminMandateIds = [] + for um in userMandates: + umId = getattr(um, 'id', None) + mandateId = getattr(um, 'mandateId', None) + if not umId or not mandateId: + continue + roleIds = rootInterface.getRoleIdsForUserMandate(str(umId)) + for roleId in roleIds: + role = rootInterface.getRole(roleId) + if role and role.roleLabel == "admin" and not role.featureInstanceId: + adminMandateIds.append(str(mandateId)) + break + if not adminMandateIds: + return [] + seenUserIds = set() + users = [] + for mid in adminMandateIds: + mandateUsers = rootInterface.getUsersByMandate(mid) + uList = mandateUsers if isinstance(mandateUsers, list) else (mandateUsers.items if hasattr(mandateUsers, 'items') else []) + for u in uList: + uid = u.get("id") if isinstance(u, dict) else getattr(u, "id", None) + if uid and uid not in seenUserIds: + seenUserIds.add(uid) + users.append(u) + items = [u.model_dump() if hasattr(u, 'model_dump') else u for u in users] + return _handleFilterValuesRequest(items, column, pagination) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting filter values for users: {str(e)}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) + + @router.get("/{userId}", response_model=User) @limiter.limit("30/minute") def get_user( diff --git a/modules/routes/routeInvitations.py b/modules/routes/routeInvitations.py index 906b11e7..cb913137 100644 --- a/modules/routes/routeInvitations.py +++ b/modules/routes/routeInvitations.py @@ -420,6 +420,13 @@ def list_invitations( Requires Mandate-Admin role. Returns all invitations created for this mandate. + NOTE: Cannot use db.getRecordsetPaginated() because: + - Computed status fields (isExpired, isUsedUp) are derived in-memory + - Filtering by revoked/used/expired requires post-fetch logic + - Invitation volume per mandate is typically low (< 100) + When this endpoint needs FormGeneratorTable pagination, add PaginatedResponse + support with in-memory slicing (similar to routeDataConnections). + Args: includeUsed: Include invitations that have reached maxUses includeExpired: Include expired invitations @@ -485,6 +492,54 @@ def list_invitations( ) +@router.get("/filter-values") +@limiter.limit("60/minute") +def get_invitation_filter_values( + request: Request, + column: str = Query(..., description="Column key"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters"), + frontendUrl: str = Query("", description="Frontend URL for building invite links"), + includeUsed: bool = Query(False, description="Include already used invitations"), + includeExpired: bool = Query(False, description="Include expired invitations"), + context: RequestContext = Depends(getRequestContext) +) -> list: + """Return distinct filter values for a column in invitations.""" + if not context.mandateId: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="X-Mandate-Id header is required") + if not _hasMandateAdminRole(context): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Mandate-Admin role required") + try: + from modules.routes.routeDataUsers import _handleFilterValuesRequest + rootInterface = getRootInterface() + allInvitations = rootInterface.getInvitationsByMandate(str(context.mandateId)) + currentTime = getUtcTimestamp() + result = [] + for inv in allInvitations: + if inv.revokedAt: + continue + currentUses = inv.currentUses or 0 + maxUses = inv.maxUses or 1 + if not includeUsed and currentUses >= maxUses: + continue + expiresAt = inv.expiresAt or 0 + if not includeExpired and expiresAt < currentTime: + continue + baseUrl = frontendUrl.rstrip("/") if frontendUrl else "" + inviteUrl = f"{baseUrl}/invite/{inv.token}" if baseUrl else "" + result.append({ + **inv.model_dump(), + "inviteUrl": inviteUrl, + "isExpired": expiresAt < currentTime, + "isUsedUp": currentUses >= maxUses + }) + return _handleFilterValuesRequest(result, column, pagination) + except HTTPException: + raise + except Exception as e: + logger.error(f"Error getting filter values for invitations: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) + + @router.delete("/{invitationId}", response_model=Dict[str, str]) @limiter.limit("30/minute") def revoke_invitation( diff --git a/modules/routes/routeMessaging.py b/modules/routes/routeMessaging.py index 0b4784ff..42e15f0e 100644 --- a/modules/routes/routeMessaging.py +++ b/modules/routes/routeMessaging.py @@ -21,7 +21,7 @@ from modules.datamodels.datamodelMessaging import ( MessagingSubscriptionExecutionResult ) from modules.datamodels.datamodelUam import User -from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata +from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict # Configure logger logger = logging.getLogger(__name__) @@ -48,6 +48,8 @@ def get_subscriptions( if pagination: try: paginationDict = json.loads(pagination) + if paginationDict: + paginationDict = normalize_pagination_dict(paginationDict) paginationParams = PaginationParams(**paginationDict) if paginationDict else None except (json.JSONDecodeError, ValueError) as e: raise HTTPException( @@ -185,6 +187,8 @@ def get_subscription_registrations( if pagination: try: paginationDict = json.loads(pagination) + if paginationDict: + paginationDict = normalize_pagination_dict(paginationDict) paginationParams = PaginationParams(**paginationDict) if paginationDict else None except (json.JSONDecodeError, ValueError) as e: raise HTTPException( @@ -277,6 +281,8 @@ def get_my_registrations( if pagination: try: paginationDict = json.loads(pagination) + if paginationDict: + paginationDict = normalize_pagination_dict(paginationDict) paginationParams = PaginationParams(**paginationDict) if paginationDict else None except (json.JSONDecodeError, ValueError) as e: raise HTTPException( @@ -450,6 +456,8 @@ def get_deliveries( if pagination: try: paginationDict = json.loads(pagination) + if paginationDict: + paginationDict = normalize_pagination_dict(paginationDict) paginationParams = PaginationParams(**paginationDict) if paginationDict else None except (json.JSONDecodeError, ValueError) as e: raise HTTPException( diff --git a/modules/routes/routeRealEstate.py b/modules/routes/routeRealEstate.py index 881a77dc..a3466aca 100644 --- a/modules/routes/routeRealEstate.py +++ b/modules/routes/routeRealEstate.py @@ -227,33 +227,22 @@ async def get_projects( context.user, mandateId=mandateId, featureInstanceId=instanceId ) recordFilter = {"featureInstanceId": instanceId} - items = interface.getProjekte(recordFilter=recordFilter) paginationParams = _parsePagination(pagination) if paginationParams: - if paginationParams.sort: - for sort_field in reversed(paginationParams.sort): - field_name = sort_field.field - direction = sort_field.direction.lower() - items.sort( - key=lambda x: getattr(x, field_name, None), - reverse=(direction == "desc") + result = interface.getProjekte(pagination=paginationParams, recordFilter=recordFilter) + if hasattr(result, 'items'): + return PaginatedResponse( + items=result.items, + pagination=PaginationMetadata( + currentPage=paginationParams.page, + pageSize=paginationParams.pageSize, + totalItems=result.totalItems, + totalPages=result.totalPages, + sort=paginationParams.sort or [], + filters=paginationParams.filters ) - total_items = len(items) - total_pages = (total_items + paginationParams.pageSize - 1) // paginationParams.pageSize - start_idx = (paginationParams.page - 1) * paginationParams.pageSize - end_idx = start_idx + paginationParams.pageSize - paginated_items = items[start_idx:end_idx] - return PaginatedResponse( - items=paginated_items, - pagination=PaginationMetadata( - currentPage=paginationParams.page, - pageSize=paginationParams.pageSize, - totalItems=total_items, - totalPages=total_pages, - sort=paginationParams.sort or [], - filters=paginationParams.filters ) - ) + items = interface.getProjekte(recordFilter=recordFilter) return PaginatedResponse(items=items, pagination=None) @@ -359,33 +348,22 @@ async def get_parcels( context.user, mandateId=mandateId, featureInstanceId=instanceId ) recordFilter = {"featureInstanceId": instanceId} - items = interface.getParzellen(recordFilter=recordFilter) paginationParams = _parsePagination(pagination) if paginationParams: - if paginationParams.sort: - for sort_field in reversed(paginationParams.sort): - field_name = sort_field.field - direction = sort_field.direction.lower() - items.sort( - key=lambda x: getattr(x, field_name, None), - reverse=(direction == "desc") + result = interface.getParzellen(pagination=paginationParams, recordFilter=recordFilter) + if hasattr(result, 'items'): + return PaginatedResponse( + items=result.items, + pagination=PaginationMetadata( + currentPage=paginationParams.page, + pageSize=paginationParams.pageSize, + totalItems=result.totalItems, + totalPages=result.totalPages, + sort=paginationParams.sort or [], + filters=paginationParams.filters ) - total_items = len(items) - total_pages = (total_items + paginationParams.pageSize - 1) // paginationParams.pageSize - start_idx = (paginationParams.page - 1) * paginationParams.pageSize - end_idx = start_idx + paginationParams.pageSize - paginated_items = items[start_idx:end_idx] - return PaginatedResponse( - items=paginated_items, - pagination=PaginationMetadata( - currentPage=paginationParams.page, - pageSize=paginationParams.pageSize, - totalItems=total_items, - totalPages=total_pages, - sort=paginationParams.sort or [], - filters=paginationParams.filters ) - ) + items = interface.getParzellen(recordFilter=recordFilter) return PaginatedResponse(items=items, pagination=None) diff --git a/modules/routes/routeSubscription.py b/modules/routes/routeSubscription.py new file mode 100644 index 00000000..8334a8c0 --- /dev/null +++ b/modules/routes/routeSubscription.py @@ -0,0 +1,437 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +""" +Subscription routes — ID-based, state-machine-driven. + +Endpoints: +- GET /api/subscription/plans — list selectable plans +- GET /api/subscription/status — operative + scheduled subscription for current mandate +- POST /api/subscription/activate — start checkout for a plan +- POST /api/subscription/cancel — cancel a specific subscription (by ID) +- POST /api/subscription/reactivate — reactivate a cancelled subscription (by ID) +- POST /api/subscription/force-cancel — sysadmin immediate cancel (by ID) +""" + +from fastapi import APIRouter, HTTPException, Depends, Request, Query +from fastapi import status +from typing import Dict, Any, List, Optional +import logging +import json +import math +from pydantic import BaseModel, Field + +from modules.auth import limiter, getRequestContext, RequestContext +from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict +from modules.routes.routeDataUsers import _applyFiltersAndSort, _extractDistinctValues + +logger = logging.getLogger(__name__) + + +def _resolveMandateId(context: RequestContext) -> str: + if context.mandateId: + return str(context.mandateId) + return "" + + +def _assertMandateAdmin(context: RequestContext, mandateId: str) -> None: + if context.hasSysAdminRole: + return + try: + from modules.interfaces.interfaceDbApp import getRootInterface + rootInterface = getRootInterface() + userMandates = rootInterface.getUserMandates(str(context.user.id)) + for um in userMandates: + if str(getattr(um, "mandateId", None)) != str(mandateId): + continue + if not getattr(um, "enabled", True): + continue + umId = str(getattr(um, "id", "")) + roleIds = rootInterface.getRoleIdsForUserMandate(umId) + for roleId in roleIds: + role = rootInterface.getRole(roleId) + if role and role.roleLabel == "admin" and not role.featureInstanceId: + return + except Exception: + pass + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Mandate admin role required") + + +# ============================================================================= +# Request / Response models +# ============================================================================= + +class ActivatePlanRequest(BaseModel): + planKey: str = Field(..., description="Key of the plan to activate") + returnUrl: str = Field(..., description="Frontend URL to redirect back to after Stripe Checkout") + +class CancelRequest(BaseModel): + subscriptionId: str = Field(..., description="ID of the subscription to cancel") + +class ReactivateRequest(BaseModel): + subscriptionId: str = Field(..., description="ID of the subscription to reactivate") + +class ForceCancelRequest(BaseModel): + subscriptionId: str = Field(..., description="ID of the subscription to force-cancel") + +class VerifyCheckoutRequest(BaseModel): + sessionId: str = Field(..., description="Stripe Checkout Session ID to verify") + +class SubscriptionStatusResponse(BaseModel): + active: bool + subscription: Optional[Dict[str, Any]] = None + plan: Optional[Dict[str, Any]] = None + scheduled: Optional[Dict[str, Any]] = None + + +# ============================================================================= +# Router +# ============================================================================= + +router = APIRouter( + prefix="/api/subscription", + tags=["Subscription"], + responses={404: {"description": "Not found"}}, +) + + +# ============================================================================= +# Endpoints +# ============================================================================= + +@router.get("/plans", response_model=List[Dict[str, Any]]) +@limiter.limit("30/minute") +def getPlans(request: Request, context: RequestContext = Depends(getRequestContext)): + from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import ( + getService as getSubscriptionService, + ) + try: + mandateId = _resolveMandateId(context) + subService = getSubscriptionService(context.user, mandateId) + plans = subService.getSelectablePlans() + return [p.model_dump() for p in plans] + except Exception as e: + logger.error("Error fetching plans: %s", e) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/status", response_model=SubscriptionStatusResponse) +@limiter.limit("60/minute") +def getStatus(request: Request, context: RequestContext = Depends(getRequestContext)): + """Return the operative subscription and any scheduled successor for the current mandate.""" + from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import ( + getService as getSubscriptionService, + ) + mandateId = _resolveMandateId(context) + if not mandateId: + return SubscriptionStatusResponse(active=False) + _assertMandateAdmin(context, mandateId) + + try: + subService = getSubscriptionService(context.user, mandateId) + operative = subService.getOperativeSubscription(mandateId) + scheduled = subService.getScheduledSubscription(mandateId) + + if not operative: + from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum + pending = subService.listSubscriptions(mandateId, [SubscriptionStatusEnum.PENDING]) + if pending: + sub = pending[0] + plan = subService.getPlan(sub.get("planKey", "")) + return SubscriptionStatusResponse( + active=False, + subscription=sub, + plan=plan.model_dump() if plan else None, + scheduled=scheduled, + ) + return SubscriptionStatusResponse(active=False, scheduled=scheduled) + + plan = subService.getPlan(operative.get("planKey", "")) + return SubscriptionStatusResponse( + active=True, + subscription=operative, + plan=plan.model_dump() if plan else None, + scheduled=scheduled, + ) + except Exception as e: + logger.error("Error fetching status: %s", e) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/activate", response_model=Dict[str, Any]) +@limiter.limit("10/minute") +def activatePlan( + request: Request, + data: ActivatePlanRequest, + context: RequestContext = Depends(getRequestContext), +): + from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import ( + getService as getSubscriptionService, + ) + mandateId = _resolveMandateId(context) + if not mandateId: + raise HTTPException(status_code=400, detail="X-Mandate-Id header required") + _assertMandateAdmin(context, mandateId) + + try: + subService = getSubscriptionService(context.user, mandateId) + return subService.activatePlan(mandateId, data.planKey, returnUrl=data.returnUrl) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error("Error activating plan %s: %s", data.planKey, e) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/cancel", response_model=Dict[str, Any]) +@limiter.limit("5/minute") +def cancelSubscription( + request: Request, + data: CancelRequest, + context: RequestContext = Depends(getRequestContext), +): + """Cancel a specific subscription by its ID.""" + from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import ( + getService as getSubscriptionService, + ) + mandateId = _resolveMandateId(context) + if not mandateId: + raise HTTPException(status_code=400, detail="X-Mandate-Id header required") + _assertMandateAdmin(context, mandateId) + + try: + subService = getSubscriptionService(context.user, mandateId) + return subService.cancelSubscription(data.subscriptionId) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error("Error cancelling subscription %s: %s", data.subscriptionId, e) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/reactivate", response_model=Dict[str, Any]) +@limiter.limit("5/minute") +def reactivateSubscription( + request: Request, + data: ReactivateRequest, + context: RequestContext = Depends(getRequestContext), +): + """Reactivate a cancelled (non-recurring) subscription before its period ends.""" + from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import ( + getService as getSubscriptionService, + ) + mandateId = _resolveMandateId(context) + if not mandateId: + raise HTTPException(status_code=400, detail="X-Mandate-Id header required") + _assertMandateAdmin(context, mandateId) + + try: + subService = getSubscriptionService(context.user, mandateId) + return subService.reactivateSubscription(data.subscriptionId) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error("Error reactivating subscription %s: %s", data.subscriptionId, e) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/force-cancel", response_model=Dict[str, Any]) +@limiter.limit("5/minute") +def forceCancel( + request: Request, + data: ForceCancelRequest, + context: RequestContext = Depends(getRequestContext), +): + """Sysadmin: immediately expire any non-terminal subscription.""" + if not context.hasSysAdminRole: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required") + + from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import ( + getService as getSubscriptionService, + ) + from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface + sub = getSubRootInterface().getById(data.subscriptionId) + if not sub: + raise HTTPException(status_code=404, detail="Subscription not found") + mandateId = sub["mandateId"] + + try: + subService = getSubscriptionService(context.user, mandateId) + return subService.forceCancel(data.subscriptionId) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + logger.error("Error force-cancelling subscription %s: %s", data.subscriptionId, e) + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/checkout/verify", response_model=Dict[str, Any]) +@limiter.limit("20/minute") +def verifyCheckout( + request: Request, + data: VerifyCheckoutRequest, + context: RequestContext = Depends(getRequestContext), +): + """Verify a Stripe Checkout Session and activate the subscription if paid. + + This is the synchronous counterpart to the checkout.session.completed webhook. + It's called by the frontend immediately after returning from Stripe to handle + environments where webhooks may be delayed or unavailable (e.g. localhost dev). + The logic is idempotent — if the webhook already processed the session, this is a no-op. + """ + mandateId = _resolveMandateId(context) + if not mandateId: + raise HTTPException(status_code=400, detail="X-Mandate-Id header required") + _assertMandateAdmin(context, mandateId) + + try: + from modules.shared.stripeClient import getStripeClient + stripe = getStripeClient() + session = stripe.checkout.Session.retrieve(data.sessionId) + except Exception as e: + logger.error("Failed to retrieve checkout session %s: %s", data.sessionId, e) + raise HTTPException(status_code=400, detail="Invalid session ID") + + if session.get("status") != "complete" or session.get("payment_status") != "paid": + return {"status": "pending", "message": "Checkout not yet completed"} + + if session.get("mode") != "subscription": + raise HTTPException(status_code=400, detail="Not a subscription checkout session") + + from modules.routes.routeBilling import _handleSubscriptionCheckoutCompleted + _handleSubscriptionCheckoutCompleted(session, f"verify-{data.sessionId}") + + return {"status": "activated", "message": "Subscription activated"} + + +# ============================================================================= +# SysAdmin: global subscription overview +# ============================================================================= + +def _buildEnrichedSubscriptions() -> List[Dict[str, Any]]: + """Build the full enriched subscription list (shared by list + filter-values endpoints).""" + from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface + from modules.datamodels.datamodelSubscription import BUILTIN_PLANS, OPERATIVE_STATUSES + + subInterface = getSubRootInterface() + allSubs = subInterface.listAll() + + mandateNames: Dict[str, str] = {} + try: + from modules.datamodels.datamodelUam import Mandate + from modules.security.rootAccess import getRootDbAppConnector + appDb = getRootDbAppConnector() + for row in appDb.getRecordset(Mandate): + r = dict(row) + mid = r.get("id", "") + mandateNames[mid] = r.get("label") or r.get("name") or mid[:8] + except Exception as e: + logger.warning("Could not bulk-resolve mandate names: %s", e) + + operativeValues = {s.value for s in OPERATIVE_STATUSES} + + enriched = [] + for sub in allSubs: + mid = sub.get("mandateId", "") + planKey = sub.get("planKey", "") + plan = BUILTIN_PLANS.get(planKey) + + sub["mandateName"] = mandateNames.get(mid, mid[:8]) + sub["planTitle"] = (plan.title.get("de") or plan.title.get("en") or planKey) if plan else planKey + + if sub.get("status") in operativeValues: + userPrice = sub.get("snapshotPricePerUserCHF", 0) or 0 + instPrice = sub.get("snapshotPricePerInstanceCHF", 0) or 0 + try: + userCount = subInterface.countActiveUsers(mid) + instanceCount = subInterface.countActiveFeatureInstances(mid) + except Exception: + userCount = 0 + instanceCount = 0 + sub["monthlyRevenueCHF"] = round(userPrice * userCount + instPrice * instanceCount, 2) + sub["activeUsers"] = userCount + sub["activeInstances"] = instanceCount + else: + sub["monthlyRevenueCHF"] = 0 + sub["activeUsers"] = 0 + sub["activeInstances"] = 0 + + enriched.append(sub) + + return enriched + + +@router.get("/admin/all") +@limiter.limit("30/minute") +def getAllSubscriptions( + request: Request, + pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"), + context: RequestContext = Depends(getRequestContext), +): + """SysAdmin: list ALL subscriptions across all mandates with enriched metadata.""" + if not context.hasSysAdminRole: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required") + + paginationParams: Optional[PaginationParams] = None + if pagination: + try: + paginationDict = json.loads(pagination) + if paginationDict: + paginationDict = normalize_pagination_dict(paginationDict) + paginationParams = PaginationParams(**paginationDict) + except (json.JSONDecodeError, ValueError) as e: + raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}") + + enriched = _buildEnrichedSubscriptions() + filtered = _applyFiltersAndSort(enriched, paginationParams) + + if paginationParams: + totalItems = len(filtered) + totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0 + startIdx = (paginationParams.page - 1) * paginationParams.pageSize + endIdx = startIdx + paginationParams.pageSize + pageItems = filtered[startIdx:endIdx] + return { + "items": pageItems, + "pagination": PaginationMetadata( + currentPage=paginationParams.page, + pageSize=paginationParams.pageSize, + totalItems=totalItems, + totalPages=totalPages, + sort=paginationParams.sort, + filters=paginationParams.filters, + ).model_dump(), + } + + return {"items": enriched, "pagination": None} + + +@router.get("/admin/all/filter-values") +@limiter.limit("60/minute") +def getFilterValues( + request: Request, + column: str = Query(..., description="Column key to extract distinct values for"), + pagination: Optional[str] = Query(None, description="JSON-encoded current filters (applied except for the requested column)"), + context: RequestContext = Depends(getRequestContext), +): + """Return distinct values for a column, respecting all active filters except the requested one.""" + if not context.hasSysAdminRole: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required") + + crossFilterParams: Optional[PaginationParams] = None + if pagination: + try: + paginationDict = json.loads(pagination) + if paginationDict: + paginationDict = normalize_pagination_dict(paginationDict) + filters = paginationDict.get("filters", {}) + filters.pop(column, None) + paginationDict["filters"] = filters + paginationDict.pop("sort", None) + crossFilterParams = PaginationParams(**paginationDict) + except (json.JSONDecodeError, ValueError) as e: + raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}") + + enriched = _buildEnrichedSubscriptions() + crossFiltered = _applyFiltersAndSort(enriched, crossFilterParams) + + return _extractDistinctValues(crossFiltered, column) diff --git a/modules/routes/routeSystem.py b/modules/routes/routeSystem.py index c2be23a6..60e498bd 100644 --- a/modules/routes/routeSystem.py +++ b/modules/routes/routeSystem.py @@ -105,6 +105,9 @@ def _getFeatureUiObjects(featureCode: str) -> List[Dict[str, Any]]: elif featureCode == "automation": from modules.features.automation.mainAutomation import UI_OBJECTS return UI_OBJECTS + elif featureCode == "automation2": + from modules.features.automation2.mainAutomation2 import UI_OBJECTS + return UI_OBJECTS elif featureCode == "teamsbot": from modules.features.teamsbot.mainTeamsbot import UI_OBJECTS return UI_OBJECTS diff --git a/modules/serviceCenter/registry.py b/modules/serviceCenter/registry.py index 900f9f0e..cdf57304 100644 --- a/modules/serviceCenter/registry.py +++ b/modules/serviceCenter/registry.py @@ -45,10 +45,17 @@ IMPORTABLE_SERVICES: Dict[str, Dict[str, Any]] = { "billing": { "module": "modules.serviceCenter.services.serviceBilling.mainServiceBilling", "class": "BillingService", - "dependencies": [], + "dependencies": ["subscription"], "objectKey": "service.billing", "label": {"en": "Billing", "de": "Abrechnung", "fr": "Facturation"}, }, + "subscription": { + "module": "modules.serviceCenter.services.serviceSubscription.mainServiceSubscription", + "class": "SubscriptionService", + "dependencies": [], + "objectKey": "service.subscription", + "label": {"en": "Subscription", "de": "Abonnement", "fr": "Abonnement"}, + }, "sharepoint": { "module": "modules.serviceCenter.services.serviceSharepoint.mainServiceSharepoint", "class": "SharepointService", @@ -92,7 +99,7 @@ IMPORTABLE_SERVICES: Dict[str, Dict[str, Any]] = { "label": {"en": "Web Research", "de": "Web-Recherche", "fr": "Recherche Web"}, }, "neutralization": { - "module": "modules.serviceCenter.services.serviceNeutralization.mainServiceNeutralization", + "module": "modules.features.neutralization.serviceNeutralization.mainServiceNeutralization", "class": "NeutralizationService", "dependencies": ["extraction", "generation"], "objectKey": "service.neutralization", diff --git a/modules/serviceCenter/services/serviceAgent/agentLoop.py b/modules/serviceCenter/services/serviceAgent/agentLoop.py index bee03424..c196d237 100644 --- a/modules/serviceCenter/services/serviceAgent/agentLoop.py +++ b/modules/serviceCenter/services/serviceAgent/agentLoop.py @@ -25,6 +25,9 @@ from modules.shared.jsonUtils import closeJsonStructures from modules.serviceCenter.services.serviceBilling.mainServiceBilling import ( InsufficientBalanceException, ) +from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import ( + SubscriptionInactiveException, +) logger = logging.getLogger(__name__) @@ -191,6 +194,18 @@ async def runAgentLoop( else: aiResponse = await aiCallFn(aiRequest) + except SubscriptionInactiveException as e: + logger.warning( + f"Subscription inactive in round {state.currentRound} (mandate={mandateId}): {e.message}" + ) + state.status = AgentStatusEnum.ERROR + state.abortReason = e.message + yield AgentEvent( + type=AgentEventTypeEnum.ERROR, + content=e.message, + data=e.toClientDict(), + ) + break except InsufficientBalanceException as e: logger.warning( f"Insufficient balance in round {state.currentRound} (mandate={mandateId}): {e.message}" diff --git a/modules/serviceCenter/services/serviceAgent/mainServiceAgent.py b/modules/serviceCenter/services/serviceAgent/mainServiceAgent.py index 03b8598e..78c69ff3 100644 --- a/modules/serviceCenter/services/serviceAgent/mainServiceAgent.py +++ b/modules/serviceCenter/services/serviceAgent/mainServiceAgent.py @@ -1432,21 +1432,55 @@ def _registerCoreTools(registry: ToolRegistry, services): return ToolResult(toolCallId="", toolName="uploadToExternal", success=False, error=str(e)) async def _sendMail(args: Dict[str, Any], context: Dict[str, Any]): + import base64 as _b64 + connectionId = args.get("connectionId", "") to = args.get("to", []) subject = args.get("subject", "") body = args.get("body", "") + bodyType = "HTML" if args.get("bodyType", "text").lower() == "html" else "Text" + draft = args.get("draft", False) + attachmentFileIds = args.get("attachmentFileIds") or [] + if not connectionId or not to or not subject: return ToolResult(toolCallId="", toolName="sendMail", success=False, error="connectionId, to, and subject are required") try: + graphAttachments: List[Dict[str, Any]] = [] + if attachmentFileIds: + chatService = services.chat + dbMgmt = chatService.interfaceDbComponent + for fid in attachmentFileIds: + fileRow = dbMgmt.getFile(fid) + if not fileRow: + return ToolResult(toolCallId="", toolName="sendMail", success=False, error=f"Attachment file not found: {fid}") + rawBytes = dbMgmt.getFileData(fid) + if not rawBytes: + return ToolResult(toolCallId="", toolName="sendMail", success=False, error=f"Attachment file has no data: {fid}") + graphAttachments.append({ + "name": fileRow.fileName, + "contentBytes": _b64.b64encode(rawBytes).decode("ascii"), + "contentType": getattr(fileRow, "mimeType", "application/octet-stream"), + }) + from modules.connectors.connectorResolver import ConnectorResolver resolver = ConnectorResolver( services.getService("security"), _buildResolverDb(), ) adapter = await resolver.resolveService(connectionId, "outlook") + + if draft and hasattr(adapter, "createDraft"): + result = await adapter.createDraft( + to=to, subject=subject, body=body, bodyType=bodyType, + cc=args.get("cc"), attachments=graphAttachments or None, + ) + return ToolResult(toolCallId="", toolName="sendMail", success=True, data=str(result)) + if hasattr(adapter, "sendMail"): - result = await adapter.sendMail(to=to, subject=subject, body=body, cc=args.get("cc")) + result = await adapter.sendMail( + to=to, subject=subject, body=body, bodyType=bodyType, + cc=args.get("cc"), attachments=graphAttachments or None, + ) return ToolResult(toolCallId="", toolName="sendMail", success=True, data=str(result)) return ToolResult(toolCallId="", toolName="sendMail", success=False, error="Mail not supported by this adapter") except Exception as e: @@ -1484,15 +1518,26 @@ def _registerCoreTools(registry: ToolRegistry, services): registry.register( "sendMail", _sendMail, - description="Send an email via a connected mail service (Outlook, Gmail). Use listConnections to find the connectionId.", + description=( + "Send or draft an email via a connected mail service (Outlook). " + "Supports HTML body and file attachments from the workspace. " + "Set draft=true to save as draft without sending. " + "Use listConnections to find the connectionId." + ), parameters={ "type": "object", "properties": { "connectionId": {"type": "string", "description": "UserConnection ID"}, "to": {"type": "array", "items": {"type": "string"}, "description": "Recipient email addresses"}, "subject": {"type": "string", "description": "Email subject"}, - "body": {"type": "string", "description": "Email body text"}, + "body": {"type": "string", "description": "Email body — plain text or HTML markup"}, + "bodyType": {"type": "string", "enum": ["text", "html"], "description": "Body format: 'text' (default) or 'html'"}, "cc": {"type": "array", "items": {"type": "string"}, "description": "CC addresses"}, + "attachmentFileIds": { + "type": "array", "items": {"type": "string"}, + "description": "File IDs from the workspace to attach (use listFiles to find IDs)", + }, + "draft": {"type": "boolean", "description": "If true, save as draft in Drafts folder instead of sending"}, }, "required": ["connectionId", "to", "subject", "body"], }, @@ -2471,16 +2516,16 @@ def _registerCoreTools(registry: ToolRegistry, services): if not voiceName: try: - from modules.interfaces import interfaceDbManagement + from modules.features.workspace import interfaceFeatureWorkspace featureInstanceId = context.get("featureInstanceId", "") userId = context.get("userId", "") if userId: - dbMgmt = interfaceDbManagement.getInterface( + wsIf = interfaceFeatureWorkspace.getInterface( services.user, mandateId=mandateId or None, featureInstanceId=featureInstanceId or None, ) - vs = dbMgmt.getVoiceSettings(userId) if dbMgmt and hasattr(dbMgmt, "getVoiceSettings") else None + vs = wsIf.getVoiceSettings(userId) if wsIf else None if vs: voiceMap = {} if hasattr(vs, "ttsVoiceMap") and vs.ttsVoiceMap: @@ -2914,6 +2959,8 @@ def _registerCoreTools(registry: ToolRegistry, services): neutralizationService = services.getService("neutralization") if not neutralizationService: return ToolResult(toolCallId="", toolName="neutralizeData", success=False, error="Neutralization service not available") + if not neutralizationService.interfaceDbComponent: + neutralizationService.interfaceDbComponent = services.chat.interfaceDbComponent if text: result = neutralizationService.processText(text) else: diff --git a/modules/serviceCenter/services/serviceAi/mainServiceAi.py b/modules/serviceCenter/services/serviceAi/mainServiceAi.py index 90fd4d9a..09e2d708 100644 --- a/modules/serviceCenter/services/serviceAi/mainServiceAi.py +++ b/modules/serviceCenter/services/serviceAi/mainServiceAi.py @@ -27,6 +27,10 @@ from modules.serviceCenter.services.serviceBilling.mainServiceBilling import ( ProviderNotAllowedException, BillingContextError ) +from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import ( + SubscriptionInactiveException, + SUBSCRIPTION_REASONS, +) logger = logging.getLogger(__name__) @@ -590,11 +594,25 @@ detectedIntent-Werte: balanceCheck = billingService.checkBalance(estimatedCost) if not balanceCheck.allowed: + reason = balanceCheck.reason or "" + + if reason in SUBSCRIPTION_REASONS: + from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum + statusMap = { + "SUBSCRIPTION_PAYMENT_REQUIRED": SubscriptionStatusEnum.PAST_DUE, + "SUBSCRIPTION_EXPIRED": SubscriptionStatusEnum.EXPIRED, + "SUBSCRIPTION_INACTIVE": SubscriptionStatusEnum.EXPIRED, + } + raise SubscriptionInactiveException( + status=statusMap.get(reason, SubscriptionStatusEnum.EXPIRED), + mandateId=str(mandateId), + ) + balance_str = f"{(balanceCheck.currentBalance or 0):.2f}" logger.warning( f"Billing check failed for user {user.id}: " f"Balance {balance_str} CHF, " - f"Reason: {balanceCheck.reason}" + f"Reason: {reason}" ) if balanceCheck.billingModel == BillingModelEnum.PREPAY_MANDATE: ulabel = (getattr(user, "email", None) or getattr(user, "username", None) or str(user.id)) @@ -651,6 +669,8 @@ detectedIntent-Werte: logger.debug(f"Provider check passed: {len(rbacAllowedProviders)} providers allowed") + except SubscriptionInactiveException: + raise except InsufficientBalanceException: raise except ProviderNotAllowedException: @@ -658,7 +678,6 @@ detectedIntent-Werte: except BillingContextError: raise except Exception as e: - # FAIL-SAFE: Don't silently swallow errors - log at ERROR level logger.error(f"BILLING FAIL-SAFE: Billing check failed with unexpected error: {e}") raise BillingContextError(f"Billing check failed: {e}") diff --git a/modules/serviceCenter/services/serviceBilling/billingExhaustedNotify.py b/modules/serviceCenter/services/serviceBilling/billingExhaustedNotify.py index aba08b89..d8f2acc4 100644 --- a/modules/serviceCenter/services/serviceBilling/billingExhaustedNotify.py +++ b/modules/serviceCenter/services/serviceBilling/billingExhaustedNotify.py @@ -1,23 +1,17 @@ # Copyright (c) 2025 Patrick Motsch # All rights reserved. """ -When the shared mandate pool (PREPAY_MANDATE) is exhausted, notify billing contacts. +When the shared mandate pool (PREPAY_MANDATE) is exhausted, notify mandate admins. -Recipients: BillingSettings.notifyEmails for the mandate (configure as mandate owner / finance). +Uses the central notifyMandateAdmins() function for recipient resolution and delivery. Emails are throttled per mandate to avoid spam (one notification per cooldown window). """ from __future__ import annotations -import html import logging import time -from typing import Any, Dict, List, Optional - -from modules.datamodels.datamodelMessaging import MessagingChannel -from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface -from modules.interfaces.interfaceMessaging import getInterface as getMessagingInterface -from modules.security.rootAccess import getRootUser +from typing import Dict logger = logging.getLogger(__name__) @@ -26,29 +20,6 @@ _poolExhaustedEmailLastSent: Dict[str, float] = {} _DEFAULT_COOLDOWN_SEC = 3600 -def _normalizeNotifyEmails(raw: Any) -> List[str]: - if raw is None: - return [] - if isinstance(raw, list): - return [str(e).strip() for e in raw if str(e).strip()] - if isinstance(raw, str): - s = raw.strip() - if not s: - return [] - # JSON array string - if s.startswith("["): - try: - import json - - parsed = json.loads(s) - if isinstance(parsed, list): - return [str(e).strip() for e in parsed if str(e).strip()] - except Exception: - pass - return [s] - return [] - - def maybeEmailMandatePoolExhausted( mandateId: str, triggeringUserId: str, @@ -58,7 +29,7 @@ def maybeEmailMandatePoolExhausted( cooldownSec: float = _DEFAULT_COOLDOWN_SEC, ) -> None: """ - Send one email per mandate per cooldown to BillingSettings.notifyEmails. + Send one notification per mandate per cooldown window when the pool is exhausted. Args: mandateId: Mandate whose pool is empty. @@ -82,59 +53,22 @@ def maybeEmailMandatePoolExhausted( return try: - billing = getBillingInterface(getRootUser(), mandateId) - settings = billing.getSettings(mandateId) or {} - recipients = _normalizeNotifyEmails(settings.get("notifyEmails")) - if not recipients: - logger.warning( - "PREPAY_MANDATE pool exhausted for mandate %s but notifyEmails is empty — " - "configure BillingSettings.notifyEmails for owner alerts", - mandateId, - ) - return + from modules.shared.notifyMandateAdmins import notifyMandateAdmins - subject = f"[PowerOn] Mandanten-Budget aufgebraucht (Mandant {mandateId[:8]}…)" - body = ( - f"Das gemeinsame Guthaben (PREPAY_MANDATE) für diesen Mandanten ist nicht mehr ausreichend.\n\n" - f"Mandanten-ID: {mandateId}\n" - f"Aktuelles Guthaben (Pool): CHF {currentBalance:.2f}\n" - f"Benötigt (mind.): CHF {requiredAmount:.2f}\n\n" - f"Auslösende/r Benutzer/in: {triggeringUserLabel} (ID: {triggeringUserId})\n\n" - f"Bitte laden Sie das Mandats-Guthaben in der Billing-Verwaltung auf, " - f"damit Benutzer wieder AI-Funktionen nutzen können.\n" + sent = notifyMandateAdmins( + mandateId, + "[PowerOn] Mandanten-Budget aufgebraucht", + "Budget aufgebraucht", + [ + "Das gemeinsame Guthaben (Prepaid-Pool) für diesen Mandanten ist nicht mehr ausreichend.", + f"Aktuelles Guthaben: CHF {currentBalance:.2f}\n" + f"Benötigt (mindestens): CHF {requiredAmount:.2f}", + f"Ausgelöst durch: {triggeringUserLabel}", + "Bitte laden Sie das Mandats-Guthaben in der Billing-Verwaltung auf, " + "damit Benutzer wieder AI-Funktionen nutzen können.", + ], ) - escaped = html.escape(body) - # Cannot use '\\n' inside f-string {…} expression (SyntaxError); build replacement outside. - brWithNl = "
" + "\n" - htmlMessage = f""" - - -{escaped.replace(chr(10), brWithNl)} -""" - - messaging = getMessagingInterface() - any_ok = False - for to in recipients: - try: - ok = messaging.send( - channel=MessagingChannel.EMAIL, - recipient=to, - subject=subject, - message=htmlMessage, - ) - if ok: - any_ok = True - else: - logger.warning("Pool exhausted email failed for %s", to) - except Exception as send_err: - logger.error("Error sending pool exhausted email to %s: %s", to, send_err) - - if any_ok: + if sent > 0: _poolExhaustedEmailLastSent[mandateId] = now - logger.info( - "Sent mandate pool exhausted notification for mandate %s to %s recipient(s)", - mandateId, - len(recipients), - ) except Exception as e: logger.error("maybeEmailMandatePoolExhausted failed: %s", e, exc_info=True) diff --git a/modules/serviceCenter/services/serviceBilling/mainServiceBilling.py b/modules/serviceCenter/services/serviceBilling/mainServiceBilling.py index d0325a5e..790612ed 100644 --- a/modules/serviceCenter/services/serviceBilling/mainServiceBilling.py +++ b/modules/serviceCenter/services/serviceBilling/mainServiceBilling.py @@ -29,7 +29,7 @@ from modules.interfaces.interfaceDbBilling import getInterface as getBillingInte logger = logging.getLogger(__name__) # Markup percentage for internal pricing (+50% für Infrastruktur und Platform Service + 50% für Währungsrisiko ==> Faktor 2.0) -BILLING_MARKUP_PERCENT = 100 +BILLING_MARKUP_PERCENT = 400 # Singleton cache _billingServices: Dict[str, "BillingService"] = {} @@ -160,6 +160,10 @@ class BillingService: def checkBalance(self, estimatedCost: float = 0.0) -> BillingCheckResult: """ Check if the current user/mandate has sufficient balance. + + Gate order: + 1. Subscription active? (fast, cached) — blocks AI if not + 2. Budget sufficient? (existing prepaid logic) Args: estimatedCost: Estimated cost of the operation (with markup applied) @@ -167,11 +171,42 @@ class BillingService: Returns: BillingCheckResult indicating if operation is allowed """ + subResult = self._checkSubscription() + if subResult is not None: + return subResult + return self._billingInterface.checkBalance( self.mandateId, self.currentUser.id, estimatedCost ) + + def _checkSubscription(self) -> Optional[BillingCheckResult]: + """Return a failing BillingCheckResult if subscription is not active, else None.""" + try: + from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum + from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import ( + getService as getSubscriptionService, + _subscriptionReasonForStatus, + _subscriptionUserActionForStatus, + ) + + subService = getSubscriptionService(self.currentUser, self.mandateId) + status = subService.assertActive(self.mandateId) + + if status in (SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.TRIALING, SubscriptionStatusEnum.PAST_DUE): + return None + + return BillingCheckResult( + allowed=False, + reason=_subscriptionReasonForStatus(status), + upgradeRequired=True, + subscriptionUiPath="/admin/billing?tab=subscription", + userAction=_subscriptionUserActionForStatus(status), + ) + except Exception as e: + logger.warning(f"Subscription check failed (allowing): {e}") + return None def hasBalance(self, estimatedCost: float = 0.0) -> bool: """ diff --git a/modules/serviceCenter/services/serviceBilling/stripeCheckout.py b/modules/serviceCenter/services/serviceBilling/stripeCheckout.py index 9f3f7e68..8d6b4a57 100644 --- a/modules/serviceCenter/services/serviceBilling/stripeCheckout.py +++ b/modules/serviceCenter/services/serviceBilling/stripeCheckout.py @@ -82,17 +82,8 @@ def create_checkout_session( f"Invalid amount {amount_chf} CHF. Allowed: {ALLOWED_AMOUNTS_CHF}" ) - # Pin API version from config (match Stripe Dashboard) - api_version = APP_CONFIG.get("STRIPE_API_VERSION") - if api_version: - stripe.api_version = api_version - - # Get secrets - secret_key = APP_CONFIG.get("STRIPE_SECRET_KEY_SECRET") or APP_CONFIG.get("STRIPE_SECRET_KEY") - if not secret_key: - raise ValueError("STRIPE_SECRET_KEY_SECRET not configured") - - stripe.api_key = secret_key + from modules.shared.stripeClient import getStripeClient + stripe = getStripeClient() base_return_url = _normalizeReturnUrl(return_url) query_separator = "&" if "?" in base_return_url else "?" diff --git a/modules/serviceCenter/services/serviceChat/mainServiceChat.py b/modules/serviceCenter/services/serviceChat/mainServiceChat.py index 5cc1eb66..b05b0c64 100644 --- a/modules/serviceCenter/services/serviceChat/mainServiceChat.py +++ b/modules/serviceCenter/services/serviceChat/mainServiceChat.py @@ -253,7 +253,8 @@ class ChatService: logger.error(f"No messages found with documentsLabel: {docRef}") raise ValueError(f"Document reference not found: {docRef}") - logger.debug(f"Resolved {len(allDocuments)} documents from document list: {documentList}") + ref_count = len(getattr(documentList, 'references', [])) if documentList else 0 + logger.debug(f"Resolved {len(allDocuments)} documents from document list ({ref_count} refs)") return allDocuments except Exception as e: logger.error(f"Error getting documents from document list: {str(e)}") diff --git a/modules/serviceCenter/services/serviceExtraction/mainServiceExtraction.py b/modules/serviceCenter/services/serviceExtraction/mainServiceExtraction.py index a8468636..a227e66f 100644 --- a/modules/serviceCenter/services/serviceExtraction/mainServiceExtraction.py +++ b/modules/serviceCenter/services/serviceExtraction/mainServiceExtraction.py @@ -50,6 +50,38 @@ class ExtractionService: if model is None or model.calculatepriceCHF is None: raise RuntimeError(f"FATAL: Required internal model '{modelDisplayName}' is not available. Check connector registration.") + def extractContentFromBytes( + self, + documentBytes: bytes, + fileName: str, + mimeType: str, + documentId: Optional[str] = None, + options: Optional[ExtractionOptions] = None, + ) -> ContentExtracted: + """Extract content from raw bytes (no persistence). + Used for inline ActionDocuments from SharePoint/email in automation2.""" + opts = options or ExtractionOptions(prompt="", mergeStrategy=MergeStrategy()) + doc_id = documentId or str(uuid.uuid4()) + ec = runExtraction( + self._extractorRegistry, + self._chunkerRegistry, + documentBytes, + fileName, + mimeType, + opts, + ) + for p in ec.parts: + if not p.metadata: + p.metadata = {} + p.metadata.setdefault("documentId", doc_id) + p.metadata.setdefault("documentMimeType", mimeType) + p.metadata.setdefault("originalFileName", fileName) + p.metadata.setdefault("contentFormat", "extracted") + p.metadata.setdefault("intent", "extract") + p.metadata.setdefault("usageHint", f"Use extracted content from {fileName}") + p.metadata.setdefault("sourceAction", "extraction.extractContentFromBytes") + return ec + def extractContent( self, documents: List[ChatDocument], diff --git a/modules/serviceCenter/services/serviceGeneration/paths/documentPath.py b/modules/serviceCenter/services/serviceGeneration/paths/documentPath.py index 72838918..4fc6c9d5 100644 --- a/modules/serviceCenter/services/serviceGeneration/paths/documentPath.py +++ b/modules/serviceCenter/services/serviceGeneration/paths/documentPath.py @@ -56,15 +56,15 @@ class DocumentGenerationPath: try: # Schritt 5A: Kläre Dokument-Intents - documents = [] + doc_list = [] if documentList: - documents = self.services.chat.getChatDocumentsFromDocumentList(documentList) + doc_list = self.services.chat.getChatDocumentsFromDocumentList(documentList) # Filter: Entferne Original-Dokumente, wenn bereits Pre-Extracted JSONs existieren # (um Duplikate zu vermeiden - Pre-Extracted JSONs enthalten bereits die ContentParts) # Schritt 1: Identifiziere alle Original-Dokument-IDs, die durch Pre-Extracted JSONs abgedeckt werden originalDocIdsCoveredByPreExtracted = set() - for doc in documents: + for doc in doc_list: preExtracted = self.services.ai.intentAnalyzer.resolvePreExtractedDocument(doc) if preExtracted: originalDocId = preExtracted["originalDocument"]["id"] @@ -73,7 +73,7 @@ class DocumentGenerationPath: # Schritt 2: Filtere Dokumente - entferne Original-Dokumente, die bereits durch Pre-Extracted JSONs abgedeckt werden filteredDocuments = [] - for doc in documents: + for doc in doc_list: preExtracted = self.services.ai.intentAnalyzer.resolvePreExtractedDocument(doc) if preExtracted: # Pre-Extracted JSON behalten @@ -85,13 +85,13 @@ class DocumentGenerationPath: # Normales Dokument ohne Pre-Extracted JSON - behalten filteredDocuments.append(doc) - documents = filteredDocuments + doc_list = filteredDocuments checkWorkflowStopped(self.services) - if not documentIntents and documents: + if not documentIntents and doc_list: documentIntents = await self.services.ai.clarifyDocumentIntents( - documents, + doc_list, userPrompt, {"outputFormat": outputFormat}, docOperationId @@ -100,9 +100,9 @@ class DocumentGenerationPath: checkWorkflowStopped(self.services) # Schritt 5B: Extrahiere und bereite Content vor - if documents: + if doc_list: preparedContentParts = await self.services.ai.extractAndPrepareContent( - documents, + doc_list, documentIntents or [], docOperationId ) diff --git a/modules/serviceCenter/services/serviceGeneration/renderers/registry.py b/modules/serviceCenter/services/serviceGeneration/renderers/registry.py index b0c96e80..adb83275 100644 --- a/modules/serviceCenter/services/serviceGeneration/renderers/registry.py +++ b/modules/serviceCenter/services/serviceGeneration/renderers/registry.py @@ -72,11 +72,18 @@ class RendererRegistry: """Register a renderer class keyed by (format, outputStyle).""" try: supportedFormats = rendererClass.getSupportedFormats() - outputStyle = rendererClass.getOutputStyle() if hasattr(rendererClass, 'getOutputStyle') else 'document' priority = rendererClass.getPriority() if hasattr(rendererClass, 'getPriority') else 0 for formatName in supportedFormats: formatKey = formatName.lower() + # Per-format output style when renderer supports it (e.g. RendererText: txt→document, js→code) + if hasattr(rendererClass, 'getOutputStyle'): + try: + outputStyle = rendererClass.getOutputStyle(formatKey) + except TypeError: + outputStyle = rendererClass.getOutputStyle() if callable(getattr(rendererClass, 'getOutputStyle')) else 'document' + else: + outputStyle = 'document' registryKey = (formatKey, outputStyle) if registryKey in self._renderers: diff --git a/modules/serviceCenter/services/serviceSharepoint/mainServiceSharepoint.py b/modules/serviceCenter/services/serviceSharepoint/mainServiceSharepoint.py index de8f176a..483d7fbe 100644 --- a/modules/serviceCenter/services/serviceSharepoint/mainServiceSharepoint.py +++ b/modules/serviceCenter/services/serviceSharepoint/mainServiceSharepoint.py @@ -692,12 +692,35 @@ class SharepointService: logger.error(f"Error extracting site from standard path '{pathQuery}': {str(e)}") return None + def _isGraphSiteId(self, sitePath: str) -> bool: + """Check if sitePath is a Graph API site ID (hostname,siteId,webId format with 2 commas).""" + if not sitePath or sitePath.count(',') != 2: + return False + parts = sitePath.split(',') + return len(parts) == 3 and all(p.strip() for p in parts) + async def getSiteByStandardPath(self, sitePath: str, allSites: Optional[List[Dict[str, Any]]] = None) -> Optional[Dict[str, Any]]: - """Get SharePoint site directly by Microsoft-standard path (/sites/SiteName).""" + """Get SharePoint site directly by Microsoft-standard path (/sites/SiteName) or by site ID.""" try: from urllib.parse import urlparse - hostname = None + # When sitePath is a Graph API site ID (host,siteId,webId), use sites/{id} directly + if self._isGraphSiteId(sitePath): + endpoint = f"sites/{sitePath}" + result = await self._makeGraphApiCall(endpoint) + if result and "error" not in result: + return { + "id": result.get("id"), + "displayName": result.get("displayName"), + "name": result.get("name"), + "webUrl": result.get("webUrl"), + "description": result.get("description"), + "createdDateTime": result.get("createdDateTime"), + "lastModifiedDateTime": result.get("lastModifiedDateTime") + } + return None + + hostname = None if allSites and len(allSites) > 0: webUrl = allSites[0].get("webUrl", "") hostname = urlparse(webUrl).hostname if webUrl else None @@ -777,6 +800,14 @@ class SharepointService: parsedPath = self.extractSiteFromStandardPath(pathQuery) if parsedPath: siteName = parsedPath.get("siteName") + # When siteName is Graph API composite ID (host,siteId,webId), match by exact id + if siteName and ',' in siteName: + exact = [s for s in allSites if s.get("id") == siteName] + if exact: + logger.info(f"Resolved site by exact ID: {siteName}") + return exact + logger.warning(f"No site found with exact ID '{siteName}'") + return [] sites = self.filterSitesByHint(allSites, siteName) if not sites: logger.warning(f"No SharePoint site found matching '{siteName}'") diff --git a/modules/serviceCenter/services/serviceSubscription/__init__.py b/modules/serviceCenter/services/serviceSubscription/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modules/serviceCenter/services/serviceSubscription/mainServiceSubscription.py b/modules/serviceCenter/services/serviceSubscription/mainServiceSubscription.py new file mode 100644 index 00000000..944da4f7 --- /dev/null +++ b/modules/serviceCenter/services/serviceSubscription/mainServiceSubscription.py @@ -0,0 +1,758 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +""" +Subscription Service — state-machine-based lifecycle management. + +Every mutation takes an explicit subscriptionId. No status-scan guessing. +See wiki/concepts/Subscription-State-Machine.md for the full state machine. +""" + +import logging +import time +from typing import Dict, Any, List, Optional +from datetime import datetime, timezone, timedelta + +from modules.datamodels.datamodelUam import User +from modules.datamodels.datamodelSubscription import ( + SubscriptionPlan, + MandateSubscription, + SubscriptionStatusEnum, + BillingPeriodEnum, + OPERATIVE_STATUSES, + _getPlan, + _getSelectablePlans, +) +from modules.interfaces.interfaceDbSubscription import ( + getInterface as getSubscriptionInterface, + InvalidTransitionError, +) + +logger = logging.getLogger(__name__) + +SUBSCRIPTION_CACHE_TTL_SECONDS = 60 +_STALE_PENDING_SECONDS = 30 * 60 + +_subscriptionServices: Dict[str, "SubscriptionService"] = {} +_statusCache: Dict[str, tuple] = {} + + +def getService(currentUser: User, mandateId: str) -> "SubscriptionService": + cacheKey = f"{currentUser.id}_{mandateId}" + if cacheKey not in _subscriptionServices: + _subscriptionServices[cacheKey] = SubscriptionService(currentUser, mandateId) + else: + _subscriptionServices[cacheKey].setContext(currentUser, mandateId) + return _subscriptionServices[cacheKey] + + +class SubscriptionService: + """State-machine-based subscription service. + All mutations use explicit subscriptionId. No scan-based writes.""" + + def __init__(self, contextOrUser, mandateId=None, get_service=None): + if mandateId is not None and callable(mandateId): + ctx = contextOrUser + self.currentUser = ctx.user + self.mandateId = ctx.mandate_id or "" + elif get_service is not None and hasattr(contextOrUser, "user"): + ctx = contextOrUser + self.currentUser = ctx.user + self.mandateId = ctx.mandate_id or "" + else: + self.currentUser = contextOrUser + self.mandateId = mandateId or "" + self._interface = getSubscriptionInterface(self.currentUser, self.mandateId) + + def setContext(self, currentUser: User, mandateId: str): + self.currentUser = currentUser + self.mandateId = mandateId + self._interface = getSubscriptionInterface(currentUser, mandateId) + + # ========================================================================= + # Billing gate (cached, read-only) + # ========================================================================= + + def assertActive(self, mandateId: str = None) -> SubscriptionStatusEnum: + """Return subscription status for billing decisions. Uses TTL cache. + This is the ONLY method that works by mandateId (read-only).""" + mid = mandateId or self.mandateId + now = time.monotonic() + + cached = _statusCache.get(mid) + if cached and cached[1] > now: + return cached[0] + + status = self._interface.assertActive(mid) + _statusCache[mid] = (status, now + SUBSCRIPTION_CACHE_TTL_SECONDS) + return status + + def invalidateCache(self, mandateId: str = None): + mid = mandateId or self.mandateId + _statusCache.pop(mid, None) + + # ========================================================================= + # Capacity (delegation) + # ========================================================================= + + def assertCapacity(self, mandateId: str, resourceType: str, delta: int = 1) -> bool: + return self._interface.assertCapacity(mandateId or self.mandateId, resourceType, delta) + + # ========================================================================= + # Read operations + # ========================================================================= + + def getById(self, subscriptionId: str) -> Optional[Dict[str, Any]]: + return self._interface.getById(subscriptionId) + + def getOperativeSubscription(self, mandateId: str = None) -> Optional[Dict[str, Any]]: + return self._interface.getOperativeForMandate(mandateId or self.mandateId) + + def getScheduledSubscription(self, mandateId: str = None) -> Optional[Dict[str, Any]]: + return self._interface.getScheduledForMandate(mandateId or self.mandateId) + + def listSubscriptions(self, mandateId: str = None, statusFilter=None) -> List[Dict[str, Any]]: + return self._interface.listForMandate(mandateId or self.mandateId, statusFilter) + + def getSelectablePlans(self) -> List[SubscriptionPlan]: + return _getSelectablePlans() + + def getPlan(self, planKey: str) -> Optional[SubscriptionPlan]: + return _getPlan(planKey) + + # ========================================================================= + # T1/T2: Plan activation (creates PENDING, returns checkout URL) + # ========================================================================= + + def activatePlan(self, mandateId: str, planKey: str, returnUrl: str) -> Dict[str, Any]: + """Create a new subscription as PENDING and start the checkout flow. + + - Free/trial plans: immediately ACTIVE/TRIALING (no checkout). + - Paid plans with active predecessor: PENDING -> checkout -> SCHEDULED on confirmation. + - Paid plans without predecessor: PENDING -> checkout -> ACTIVE on confirmation. + + Cleans up any existing PENDING/SCHEDULED for this mandate first (by ID).""" + mid = mandateId or self.mandateId + plan = _getPlan(planKey) + if not plan: + raise ValueError(f"Unknown plan: {planKey}") + + isPaid = plan.billingPeriod != BillingPeriodEnum.NONE and not plan.trialDays + currentOperative = self._interface.getOperativeForMandate(mid) + + self._cleanupPreparatorySubscriptions(mid) + + now = datetime.now(timezone.utc) + if plan.trialDays: + initialStatus = SubscriptionStatusEnum.TRIALING + elif isPaid: + initialStatus = SubscriptionStatusEnum.PENDING + else: + initialStatus = SubscriptionStatusEnum.ACTIVE + + sub = MandateSubscription( + mandateId=mid, + planKey=planKey, + status=initialStatus, + recurring=plan.autoRenew and not plan.trialDays, + startedAt=now, + currentPeriodStart=now, + snapshotPricePerUserCHF=plan.pricePerUserCHF, + snapshotPricePerInstanceCHF=plan.pricePerFeatureInstanceCHF, + ) + + if plan.trialDays: + sub.trialEndsAt = now + timedelta(days=plan.trialDays) + + if plan.billingPeriod == BillingPeriodEnum.MONTHLY: + sub.currentPeriodEnd = now + timedelta(days=30) + elif plan.billingPeriod == BillingPeriodEnum.YEARLY: + sub.currentPeriodEnd = now + timedelta(days=365) + + created = self._interface.createSubscription(sub) + + from urllib.parse import urlparse + parsed = urlparse(returnUrl) if returnUrl else None + pUrl = f"{parsed.scheme}://{parsed.netloc}" if parsed and parsed.scheme else "" + + if isPaid: + try: + checkoutUrl = self._createCheckoutSession(mid, plan, created, currentOperative, returnUrl) + created["redirectUrl"] = checkoutUrl + except Exception as e: + self._interface.forceExpire(created["id"]) + self.invalidateCache(mid) + raise ValueError(f"Subscription konnte nicht erstellt werden: {e}") from e + else: + if currentOperative: + self._expireOperative(currentOperative["id"], mid) + _notifySubscriptionChange(mid, "activated", plan, subscriptionRecord=created, platformUrl=pUrl) + + self.invalidateCache(mid) + return created + + def _cleanupPreparatorySubscriptions(self, mandateId: str) -> None: + """Expire any existing PENDING or SCHEDULED subscriptions for this mandate (by ID).""" + preparatory = self._interface.listForMandate( + mandateId, [SubscriptionStatusEnum.PENDING, SubscriptionStatusEnum.SCHEDULED], + ) + for sub in preparatory: + subId = sub["id"] + currentStatus = SubscriptionStatusEnum(sub["status"]) + stripeSubId = sub.get("stripeSubscriptionId") + + if stripeSubId and currentStatus == SubscriptionStatusEnum.SCHEDULED: + try: + from modules.shared.stripeClient import getStripeClient + stripe = getStripeClient() + stripe.Subscription.cancel(stripeSubId) + except Exception as e: + logger.error("Failed to cancel Stripe sub %s during cleanup: %s", stripeSubId, e) + + self._interface.transitionStatus(subId, currentStatus, SubscriptionStatusEnum.EXPIRED) + logger.info("Cleaned up %s subscription %s for mandate %s", currentStatus.value, subId, mandateId) + + def _expireOperative(self, subscriptionId: str, mandateId: str) -> None: + """Expire the current operative subscription (used when a free/trial plan replaces it).""" + sub = self._interface.getById(subscriptionId) + if not sub: + return + currentStatus = SubscriptionStatusEnum(sub["status"]) + if currentStatus in OPERATIVE_STATUSES: + stripeSubId = sub.get("stripeSubscriptionId") + if stripeSubId: + try: + from modules.shared.stripeClient import getStripeClient + stripe = getStripeClient() + stripe.Subscription.cancel(stripeSubId) + except Exception as e: + logger.error("Failed to cancel Stripe sub %s: %s", stripeSubId, e) + self._interface.transitionStatus(subscriptionId, currentStatus, SubscriptionStatusEnum.EXPIRED) + + def _createCheckoutSession( + self, mandateId: str, plan: SubscriptionPlan, subRecord: Dict[str, Any], + currentOperative: Optional[Dict[str, Any]], returnUrl: str, + ) -> str: + """Create a Stripe Checkout Session. If a predecessor exists, delays billing + via trial_end to start after the predecessor's period end.""" + from modules.shared.stripeClient import getStripeClient + from modules.serviceCenter.services.serviceSubscription.stripeBootstrap import getStripePricesForPlan + + stripe = getStripeClient() + priceMapping = getStripePricesForPlan(plan.planKey) + if not priceMapping or (not priceMapping.stripePriceIdUsers and not priceMapping.stripePriceIdInstances): + raise ValueError(f"Stripe Price IDs not provisioned for plan {plan.planKey}") + + stripeCustomerId = self._resolveStripeCustomer(mandateId) + if not stripeCustomerId: + raise ValueError(f"Could not resolve Stripe customer for mandate {mandateId}") + + activeUsers = self._interface.countActiveUsers(mandateId) + activeInstances = self._interface.countActiveFeatureInstances(mandateId) + + lineItems = [] + if priceMapping.stripePriceIdUsers: + lineItems.append({"price": priceMapping.stripePriceIdUsers, "quantity": max(activeUsers, 1)}) + if priceMapping.stripePriceIdInstances and activeInstances > 0: + lineItems.append({"price": priceMapping.stripePriceIdInstances, "quantity": activeInstances}) + + if not returnUrl: + raise ValueError("returnUrl is required for paid subscription checkout") + + from urllib.parse import urlparse + parsedReturn = urlparse(returnUrl) + platformUrl = f"{parsedReturn.scheme}://{parsedReturn.netloc}" if parsedReturn.scheme else "" + + separator = "&" if "?" in returnUrl else "?" + successUrl = f"{returnUrl}{separator}success=true&session_id={{CHECKOUT_SESSION_ID}}" + cancelUrl = f"{returnUrl}{separator}canceled=true" + + subscriptionData: Dict[str, Any] = { + "metadata": { + "mandateId": mandateId, + "subscriptionRecordId": subRecord["id"], + "planKey": plan.planKey, + "platformUrl": platformUrl, + }, + } + + if currentOperative and currentOperative.get("currentPeriodEnd"): + periodEnd = currentOperative["currentPeriodEnd"] + if isinstance(periodEnd, str): + periodEnd = datetime.fromisoformat(periodEnd) + trialEndTs = int(periodEnd.timestamp()) + subscriptionData["trial_end"] = trialEndTs + self._interface.updateFields(subRecord["id"], {"effectiveFrom": periodEnd.isoformat()}) + + session = stripe.checkout.Session.create( + mode="subscription", + customer=stripeCustomerId, + line_items=lineItems, + success_url=successUrl, + cancel_url=cancelUrl, + subscription_data=subscriptionData, + ) + + if not session or not session.url: + raise ValueError("Stripe Checkout Session creation failed") + + logger.info("Checkout session %s created for mandate %s, plan %s", session.id, mandateId, plan.planKey) + return session.url + + def _resolveStripeCustomer(self, mandateId: str) -> Optional[str]: + try: + from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface + billingIf = getBillingInterface(self.currentUser, mandateId) + settings = billingIf.getSettings(mandateId) + if not settings: + return None + customerId = settings.get("stripeCustomerId") + if customerId: + return customerId + + from modules.shared.stripeClient import getStripeClient + stripe = getStripeClient() + + mandateLabel = mandateId + try: + from modules.datamodels.datamodelUam import Mandate + from modules.security.rootAccess import getRootDbAppConnector + appDb = getRootDbAppConnector() + rows = appDb.getRecordset(Mandate, recordFilter={"id": mandateId}) + if rows: + mandateLabel = rows[0].get("label") or rows[0].get("name") or mandateId + except Exception: + pass + + customer = stripe.Customer.create(name=mandateLabel, metadata={"mandateId": mandateId}) + billingIf.updateSettings(settings["id"], {"stripeCustomerId": customer.id}) + logger.info("Stripe customer %s created for mandate %s", customer.id, mandateId) + return customer.id + except Exception as e: + logger.error("_resolveStripeCustomer(%s) failed: %s", mandateId, e) + return None + + # ========================================================================= + # T7: Cancel (set recurring=false) + # ========================================================================= + + def cancelSubscription(self, subscriptionId: str) -> Dict[str, Any]: + """Cancel a subscription (T7: set recurring=false, Stripe cancel_at_period_end). + The subscription stays ACTIVE until its period ends.""" + sub = self._interface.getById(subscriptionId) + if not sub: + raise ValueError(f"Subscription {subscriptionId} not found") + + status = sub.get("status", "") + mandateId = sub["mandateId"] + + if status == SubscriptionStatusEnum.PENDING.value: + result = self._interface.transitionStatus( + subscriptionId, SubscriptionStatusEnum.PENDING, SubscriptionStatusEnum.EXPIRED, + ) + self.invalidateCache(mandateId) + return result + + if status == SubscriptionStatusEnum.SCHEDULED.value: + stripeSubId = sub.get("stripeSubscriptionId") + if stripeSubId: + try: + from modules.shared.stripeClient import getStripeClient + stripe = getStripeClient() + stripe.Subscription.cancel(stripeSubId) + except Exception as e: + logger.error("Failed to cancel Stripe sub %s: %s", stripeSubId, e) + result = self._interface.transitionStatus( + subscriptionId, SubscriptionStatusEnum.SCHEDULED, SubscriptionStatusEnum.EXPIRED, + ) + self.invalidateCache(mandateId) + return result + + if status != SubscriptionStatusEnum.ACTIVE.value: + raise ValueError(f"Cannot cancel subscription in status {status}") + + if not sub.get("recurring", True): + raise ValueError("Subscription is already cancelled (non-recurring)") + + stripeSubId = sub.get("stripeSubscriptionId") + pUrl = "" + if stripeSubId: + try: + from modules.shared.stripeClient import getStripeClient + stripe = getStripeClient() + stripeSub = stripe.Subscription.modify(stripeSubId, cancel_at_period_end=True) + pUrl = (stripeSub.get("metadata") or {}).get("platformUrl", "") + except Exception as e: + logger.error("Failed to set cancel_at_period_end for %s: %s", stripeSubId, e) + + result = self._interface.updateFields(subscriptionId, {"recurring": False}) + self.invalidateCache(mandateId) + + plan = _getPlan(sub.get("planKey", "")) + _notifySubscriptionChange(mandateId, "cancelled", plan, subscriptionRecord=sub, platformUrl=pUrl) + return result + + # ========================================================================= + # T8: Reactivate (set recurring=true) + # ========================================================================= + + def reactivateSubscription(self, subscriptionId: str) -> Dict[str, Any]: + """Reactivate a cancelled subscription before its period ends (T8: recurring=true).""" + sub = self._interface.getById(subscriptionId) + if not sub: + raise ValueError(f"Subscription {subscriptionId} not found") + + if sub.get("status") != SubscriptionStatusEnum.ACTIVE.value: + raise ValueError(f"Can only reactivate ACTIVE subscriptions, got {sub.get('status')}") + if sub.get("recurring", True): + raise ValueError("Subscription is already recurring") + + periodEnd = sub.get("currentPeriodEnd") + if periodEnd: + if isinstance(periodEnd, str): + periodEnd = datetime.fromisoformat(periodEnd) + if periodEnd <= datetime.now(timezone.utc): + raise ValueError("Cannot reactivate — period has already ended") + + stripeSubId = sub.get("stripeSubscriptionId") + if stripeSubId: + try: + from modules.shared.stripeClient import getStripeClient + stripe = getStripeClient() + stripe.Subscription.modify(stripeSubId, cancel_at_period_end=False) + except Exception as e: + logger.error("Failed to reactivate Stripe sub %s: %s", stripeSubId, e) + + result = self._interface.updateFields(subscriptionId, {"recurring": True}) + self.invalidateCache(sub["mandateId"]) + return result + + # ========================================================================= + # T13: Sysadmin force-cancel + # ========================================================================= + + def forceCancel(self, subscriptionId: str) -> Dict[str, Any]: + """Sysadmin force-cancel: immediately expire any non-terminal subscription.""" + sub = self._interface.getById(subscriptionId) + if not sub: + raise ValueError(f"Subscription {subscriptionId} not found") + + stripeSubId = sub.get("stripeSubscriptionId") + pUrl = "" + if stripeSubId: + try: + from modules.shared.stripeClient import getStripeClient + stripe = getStripeClient() + stripeSub = stripe.Subscription.retrieve(stripeSubId) + pUrl = (stripeSub.get("metadata") or {}).get("platformUrl", "") + stripe.Subscription.cancel(stripeSubId) + except Exception as e: + logger.error("Failed to cancel Stripe sub %s: %s", stripeSubId, e) + + result = self._interface.forceExpire(subscriptionId) + mandateId = sub["mandateId"] + self.invalidateCache(mandateId) + + plan = _getPlan(sub.get("planKey", "")) + _notifySubscriptionChange(mandateId, "force_cancelled", plan, subscriptionRecord=sub, platformUrl=pUrl) + return result + + # ========================================================================= + # T6: Trial expiry + # ========================================================================= + + def handleTrialExpiry(self, subscriptionId: str) -> None: + """Expire a trial subscription (T6: TRIALING -> EXPIRED).""" + sub = self._interface.getById(subscriptionId) + if not sub or sub.get("status") != SubscriptionStatusEnum.TRIALING.value: + return + + self._interface.transitionStatus( + subscriptionId, SubscriptionStatusEnum.TRIALING, SubscriptionStatusEnum.EXPIRED, + ) + self.invalidateCache(sub["mandateId"]) + + plan = _getPlan(sub.get("planKey", "")) + successorPlan = _getPlan(plan.successorPlanKey) if plan and plan.successorPlanKey else None + _notifySubscriptionChange(sub["mandateId"], "trial_expired", successorPlan) + logger.info("Trial expired for subscription %s", subscriptionId) + + # ========================================================================= + # Stripe quantity sync + # ========================================================================= + + def syncStripeQuantity(self, subscriptionId: str): + self._interface.syncQuantityToStripe(subscriptionId) + + +# ============================================================================ +# Notifications +# ============================================================================ + +def _notifySubscriptionChange( + mandateId: str, + event: str, + plan: Optional[SubscriptionPlan] = None, + subscriptionRecord: Optional[Dict[str, Any]] = None, + platformUrl: str = "", +) -> None: + try: + from modules.shared.notifyMandateAdmins import notifyMandateAdmins + + planLabel = (plan.title.get("de") or plan.title.get("en") or plan.planKey) if plan else "—" + platformHint = f"Plattform: {platformUrl}" if platformUrl else "" + + rawHtmlBlock: Optional[str] = None + + if event == "activated" and plan and subscriptionRecord: + rawHtmlBlock = _buildInvoiceSummaryHtml(plan, subscriptionRecord, mandateId, platformUrl) + elif event in ("cancelled", "force_cancelled") and subscriptionRecord: + rawHtmlBlock = _buildCancelSummaryHtml(subscriptionRecord, platformUrl) + + templates: Dict[str, Dict[str, Any]] = { + "activated": { + "subject": f"[PowerOn] Abonnement aktiviert — {planLabel}", + "headline": "Abonnement aktiviert", + "paragraphs": [ + p for p in [ + f"Das Abonnement wurde auf den Plan «{planLabel}» aktiviert.", + platformHint, + "Sie können Ihr Abonnement jederzeit unter Billing-Verwaltung › Abonnement einsehen und verwalten.", + ] if p + ], + }, + "cancelled": { + "subject": f"[PowerOn] Abonnement gekündigt — {planLabel}", + "headline": "Abonnement gekündigt", + "paragraphs": [ + p for p in [ + f"Das Abonnement «{planLabel}» wurde gekündigt.", + platformHint, + "Die Kündigung wird zum Ende der aktuellen bezahlten Periode wirksam. Bis dahin bleibt der volle Zugang bestehen.", + ] if p + ], + }, + "force_cancelled": { + "subject": f"[PowerOn] Abonnement sofort beendet — {planLabel}", + "headline": "Abonnement sofort beendet", + "paragraphs": [ + p for p in [ + f"Das Abonnement «{planLabel}» wurde durch den Plattform-Administrator sofort beendet.", + platformHint, + "Der Zugang wurde per sofort deaktiviert. Bei Fragen wenden Sie sich an den Plattform-Support.", + ] if p + ], + }, + "trial_expired": { + "subject": "[PowerOn] Testphase abgelaufen", + "headline": "Testphase abgelaufen", + "paragraphs": [ + p for p in [ + "Die kostenlose Testphase ist abgelaufen.", + platformHint, + "Bitte wählen Sie einen Plan unter Billing-Verwaltung › Abonnement, damit der Zugang nicht unterbrochen wird.", + ] if p + ], + }, + "payment_failed": { + "subject": f"[PowerOn] Zahlung fehlgeschlagen — {planLabel}", + "headline": "Zahlung fehlgeschlagen", + "paragraphs": [ + p for p in [ + f"Die Zahlung für das Abonnement «{planLabel}» ist fehlgeschlagen.", + platformHint, + "Bitte aktualisieren Sie Ihr Zahlungsmittel unter Billing-Verwaltung.", + ] if p + ], + }, + } + + tpl = templates.get(event, { + "subject": f"[PowerOn] Abonnement-Änderung — {planLabel}", + "headline": "Abonnement-Änderung", + "paragraphs": [f"Änderung am Abonnement «{planLabel}»."], + }) + + notifyMandateAdmins( + mandateId, tpl["subject"], tpl["headline"], tpl["paragraphs"], + rawHtmlBlock=rawHtmlBlock, + ) + except Exception as e: + logger.error("_notifySubscriptionChange failed for mandate %s event %s: %s", mandateId, event, e) + + +def _buildInvoiceSummaryHtml( + plan: SubscriptionPlan, + subRecord: Dict[str, Any], + mandateId: str, + platformUrl: str = "", +) -> str: + """Build an HTML invoice summary block for inclusion in the activation email.""" + import html as htmlmod + from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface + + subInterface = getSubRootInterface() + userCount = subInterface.countActiveUsers(mandateId) + instanceCount = subInterface.countActiveFeatureInstances(mandateId) + + userPrice = plan.pricePerUserCHF + instancePrice = plan.pricePerFeatureInstanceCHF + userTotal = userCount * userPrice + instanceTotal = instanceCount * instancePrice + netTotal = userTotal + instanceTotal + + periodLabel = {"MONTHLY": "Monatlich", "YEARLY": "Jährlich"}.get(plan.billingPeriod, plan.billingPeriod) + + def _chf(amount: float) -> str: + return f"CHF {amount:,.2f}".replace(",", "'") + + rows = "" + if userPrice > 0: + rows += ( + f'Benutzer-Lizenzen' + f'{userCount} × {_chf(userPrice)}' + f'{_chf(userTotal)}\n' + ) + if instancePrice > 0: + rows += ( + f'Feature-Instanzen' + f'{instanceCount} × {_chf(instancePrice)}' + f'{_chf(instanceTotal)}\n' + ) + + invoiceLink = "" + stripeSubId = subRecord.get("stripeSubscriptionId") + if stripeSubId: + try: + from modules.shared.stripeClient import getStripeClient + stripe = getStripeClient() + invoices = stripe.Invoice.list(subscription=stripeSubId, limit=1) + if invoices.data: + hostedUrl = invoices.data[0].get("hosted_invoice_url", "") + if hostedUrl: + invoiceLink = ( + f'

' + f'' + f'Vollständige Rechnung mit MwSt-Ausweis anzeigen

\n' + ) + except Exception as e: + logger.warning("Could not fetch Stripe invoice URL for sub %s: %s", stripeSubId, e) + + return ( + f'' + f'' + f'' + f'' + f'' + f'' + f'{rows}' + f'' + f'' + f'' + f'' + f'' + f'
PositionMenge × PreisTotal
Netto-Total ({periodLabel}){_chf(netTotal)}
' + f'{invoiceLink}' + ) + + +def _buildCancelSummaryHtml(subRecord: Dict[str, Any], platformUrl: str = "") -> str: + """Build an HTML block with billing link and Stripe invoice link for cancel emails.""" + import html as htmlmod + + parts: list[str] = [] + + stripeSubId = subRecord.get("stripeSubscriptionId") + if stripeSubId: + try: + from modules.shared.stripeClient import getStripeClient + stripe = getStripeClient() + invoices = stripe.Invoice.list(subscription=stripeSubId, limit=1) + if invoices.data: + hostedUrl = invoices.data[0].get("hosted_invoice_url", "") + if hostedUrl: + parts.append( + f'

' + f'' + f'Letzte Stripe-Rechnung anzeigen

' + ) + except Exception as e: + logger.warning("Could not fetch Stripe invoice URL for sub %s: %s", stripeSubId, e) + + return "\n".join(parts) if parts else "" + + +# ============================================================================ +# Exception Classes +# ============================================================================ + +SUBSCRIPTION_USER_ACTION_UPGRADE = "UPGRADE_SUBSCRIPTION" +SUBSCRIPTION_USER_ACTION_REACTIVATE = "REACTIVATE_SUBSCRIPTION" +SUBSCRIPTION_USER_ACTION_ADD_PAYMENT = "ADD_PAYMENT_METHOD" + +SUBSCRIPTION_REASONS = { + "SUBSCRIPTION_INACTIVE", + "SUBSCRIPTION_PAYMENT_REQUIRED", + "SUBSCRIPTION_PAYMENT_PENDING", + "SUBSCRIPTION_EXPIRED", +} + + +def _subscriptionReasonForStatus(status: SubscriptionStatusEnum) -> str: + if status == SubscriptionStatusEnum.PENDING: + return "SUBSCRIPTION_PAYMENT_PENDING" + if status == SubscriptionStatusEnum.PAST_DUE: + return "SUBSCRIPTION_PAYMENT_REQUIRED" + if status == SubscriptionStatusEnum.EXPIRED: + return "SUBSCRIPTION_EXPIRED" + return "SUBSCRIPTION_INACTIVE" + + +def _subscriptionUserActionForStatus(status: SubscriptionStatusEnum) -> str: + if status in (SubscriptionStatusEnum.PAST_DUE, SubscriptionStatusEnum.PENDING): + return SUBSCRIPTION_USER_ACTION_ADD_PAYMENT + return SUBSCRIPTION_USER_ACTION_UPGRADE + + +class SubscriptionInactiveException(Exception): + def __init__(self, status: SubscriptionStatusEnum, mandateId: str = "", message: Optional[str] = None): + self.status = status + self.mandateId = mandateId + self.reason = _subscriptionReasonForStatus(status) + self.userAction = _subscriptionUserActionForStatus(status) + self.message = message or ( + "Kein aktives Abonnement für diesen Mandanten. Bitte wählen Sie einen Plan unter Billing." + ) + super().__init__(self.message) + + def toClientDict(self) -> Dict[str, Any]: + out: Dict[str, Any] = { + "error": self.reason, "message": self.message, + "userAction": self.userAction, "subscriptionUiPath": "/admin/billing?tab=subscription", + } + if self.mandateId: + out["mandateId"] = self.mandateId + return out + + +class SubscriptionCapacityException(Exception): + def __init__(self, resourceType: str, currentCount: int, maxAllowed: int, message: Optional[str] = None): + self.resourceType = resourceType + self.currentCount = currentCount + self.maxAllowed = maxAllowed + self.message = message or ( + f"Ihr Plan erlaubt maximal {maxAllowed} {'Benutzer' if resourceType == 'users' else 'Feature-Instanzen'} " + f"(aktuell {currentCount}). Bitte wechseln Sie zu einem grösseren Plan." + ) + super().__init__(self.message) + + def toClientDict(self) -> Dict[str, Any]: + return { + "error": f"SUBSCRIPTION_{self.resourceType.upper()}_LIMIT", + "currentCount": self.currentCount, "maxAllowed": self.maxAllowed, + "message": self.message, "userAction": SUBSCRIPTION_USER_ACTION_UPGRADE, + "subscriptionUiPath": "/admin/billing?tab=subscription", + } + + +SubscriptionService.SubscriptionInactiveException = SubscriptionInactiveException +SubscriptionService.SubscriptionCapacityException = SubscriptionCapacityException diff --git a/modules/serviceCenter/services/serviceSubscription/stripeBootstrap.py b/modules/serviceCenter/services/serviceSubscription/stripeBootstrap.py new file mode 100644 index 00000000..38ac29e1 --- /dev/null +++ b/modules/serviceCenter/services/serviceSubscription/stripeBootstrap.py @@ -0,0 +1,273 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +""" +Auto-provision Stripe Products and Prices from the built-in plan catalog. + +Creates separate Stripe Products for user licenses and feature instances +so that invoice line items show clear, descriptive names: + - "Benutzer-Lizenzen" + - "Feature-Instanzen" + +Idempotent — safe to call on every startup. +""" + +import logging +from typing import Dict, Optional + +from modules.connectors.connectorDbPostgre import DatabaseConnector +from modules.shared.configuration import APP_CONFIG +from modules.datamodels.datamodelSubscription import ( + BUILTIN_PLANS, + SubscriptionPlan, + BillingPeriodEnum, + StripePlanPrice, +) + +logger = logging.getLogger(__name__) + +_BILLING_DATABASE = "poweron_billing" +_METADATA_KEY = "poweron_plan_key" +_METADATA_LINE_TYPE = "poweron_line_type" + +_PERIOD_TO_STRIPE = { + BillingPeriodEnum.MONTHLY: {"interval": "month", "interval_count": 1}, + BillingPeriodEnum.YEARLY: {"interval": "year", "interval_count": 1}, +} + + +def _getBillingDb() -> DatabaseConnector: + return DatabaseConnector( + dbDatabase=_BILLING_DATABASE, + dbHost=APP_CONFIG.get("DB_HOST", "localhost"), + dbPort=int(APP_CONFIG.get("DB_PORT", "5432")), + dbUser=APP_CONFIG.get("DB_USER"), + dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"), + ) + + +def _loadExistingMappings(db: DatabaseConnector) -> Dict[str, StripePlanPrice]: + try: + rows = db.getRecordset(StripePlanPrice) + result = {} + for row in rows: + pk = row.get("planKey") + if pk: + result[pk] = StripePlanPrice(**{k: v for k, v in row.items() if not k.startswith("_")}) + return result + except Exception as e: + logger.warning("Could not load StripePlanPrice records: %s", e) + return {} + + +def _findStripeProduct(stripe, planKey: str, lineType: str) -> Optional[str]: + """Search Stripe for a product tagged with plan key + line type.""" + try: + products = stripe.Product.search( + query=f'metadata["{_METADATA_KEY}"]:"{planKey}" AND metadata["{_METADATA_LINE_TYPE}"]:"{lineType}"', + limit=1, + ) + if products.data: + return products.data[0].id + except Exception: + try: + products = stripe.Product.search( + query=f'metadata["{_METADATA_KEY}"]:"{planKey}"', + limit=10, + ) + for p in products.data: + meta = p.get("metadata") or {} + if meta.get(_METADATA_LINE_TYPE) == lineType: + return p.id + except Exception: + pass + return None + + +def _createStripeProduct(stripe, name: str, description: str, planKey: str, lineType: str) -> str: + product = stripe.Product.create( + name=name, + description=description, + metadata={_METADATA_KEY: planKey, _METADATA_LINE_TYPE: lineType}, + ) + logger.info("Created Stripe Product %s: %s (%s/%s)", product.id, name, planKey, lineType) + return product.id + + +def _findExistingStripePrice(stripe, productId: str, unitAmount: int, interval: str) -> Optional[str]: + try: + prices = stripe.Price.list(product=productId, active=True, limit=50) + for p in prices.data: + recurring = p.get("recurring") or {} + if p.get("unit_amount") == unitAmount and recurring.get("interval") == interval: + return p.id + except Exception: + pass + return None + + +def _getStripePriceAmount(stripe, priceId: str) -> Optional[int]: + """Retrieve the unit_amount (in Rappen) of an existing Stripe Price.""" + try: + price = stripe.Price.retrieve(priceId) + return price.get("unit_amount") if price else None + except Exception: + return None + + +def _reconcilePrice(stripe, productId: str, oldPriceId: str, expectedCHF: float, interval: str, nickname: str) -> str: + """If the stored Stripe Price has a different amount, create a new one and deactivate the old.""" + expectedCents = int(expectedCHF * 100) + actualCents = _getStripePriceAmount(stripe, oldPriceId) + + if actualCents == expectedCents: + return oldPriceId + + logger.warning( + "Price drift detected for %s: Stripe has %s Rappen, catalog expects %s Rappen. Rotating price.", + oldPriceId, actualCents, expectedCents, + ) + + existingMatch = _findExistingStripePrice(stripe, productId, expectedCents, interval) + if existingMatch: + newPriceId = existingMatch + else: + newPriceId = _createStripePrice(stripe, productId, expectedCHF, interval, nickname) + + try: + stripe.Price.modify(oldPriceId, active=False) + logger.info("Deactivated old Stripe Price %s", oldPriceId) + except Exception as e: + logger.warning("Could not deactivate old price %s: %s", oldPriceId, e) + + return newPriceId + + +def _createStripePrice(stripe, productId: str, unitAmountCHF: float, interval: str, nickname: str) -> str: + price = stripe.Price.create( + product=productId, + unit_amount=int(unitAmountCHF * 100), + currency="chf", + recurring={"interval": interval}, + nickname=nickname, + ) + logger.info("Created Stripe Price %s (%s, %s CHF/%s)", price.id, nickname, unitAmountCHF, interval) + return price.id + + +def bootstrapStripePrices() -> None: + """Ensure all paid plans have separate Stripe Products for users and instances.""" + try: + from modules.shared.stripeClient import getStripeClient + stripe = getStripeClient() + except ValueError as e: + logger.error("Stripe not configured — cannot bootstrap subscription prices: %s", e) + return + + db = _getBillingDb() + existing = _loadExistingMappings(db) + + for planKey, plan in BUILTIN_PLANS.items(): + if plan.billingPeriod == BillingPeriodEnum.NONE: + continue + if plan.pricePerUserCHF == 0 and plan.pricePerFeatureInstanceCHF == 0: + continue + + stripePeriod = _PERIOD_TO_STRIPE.get(plan.billingPeriod) + if not stripePeriod: + continue + + interval = stripePeriod["interval"] + + if planKey in existing: + mapping = existing[planKey] + hasAllPrices = mapping.stripePriceIdUsers and mapping.stripePriceIdInstances + hasAllProducts = mapping.stripeProductIdUsers and mapping.stripeProductIdInstances + if hasAllPrices and hasAllProducts: + changed = False + reconciledUsers = _reconcilePrice( + stripe, mapping.stripeProductIdUsers, mapping.stripePriceIdUsers, + plan.pricePerUserCHF, interval, f"{planKey} — Benutzer-Lizenz", + ) + if reconciledUsers != mapping.stripePriceIdUsers: + changed = True + + reconciledInstances = _reconcilePrice( + stripe, mapping.stripeProductIdInstances, mapping.stripePriceIdInstances, + plan.pricePerFeatureInstanceCHF, interval, f"{planKey} — Feature-Instanz", + ) + if reconciledInstances != mapping.stripePriceIdInstances: + changed = True + + if changed: + db.recordModify(StripePlanPrice, mapping.id, { + "stripePriceIdUsers": reconciledUsers, + "stripePriceIdInstances": reconciledInstances, + }) + logger.info("Reconciled Stripe prices for plan %s: users=%s, instances=%s", planKey, reconciledUsers, reconciledInstances) + else: + logger.debug("Stripe prices up-to-date for plan %s", planKey) + continue + + productIdUsers = None + productIdInstances = None + priceIdUsers = None + priceIdInstances = None + + if plan.pricePerUserCHF > 0: + productIdUsers = _findStripeProduct(stripe, planKey, "users") + if not productIdUsers: + productIdUsers = _createStripeProduct( + stripe, "Benutzer-Lizenzen", f"Benutzer-Lizenzen für {plan.title.get('de', planKey)}", + planKey, "users", + ) + priceIdUsers = _findExistingStripePrice(stripe, productIdUsers, int(plan.pricePerUserCHF * 100), interval) + if not priceIdUsers: + priceIdUsers = _createStripePrice( + stripe, productIdUsers, plan.pricePerUserCHF, interval, f"{planKey} — Benutzer-Lizenz", + ) + + if plan.pricePerFeatureInstanceCHF > 0: + productIdInstances = _findStripeProduct(stripe, planKey, "instances") + if not productIdInstances: + productIdInstances = _createStripeProduct( + stripe, "Feature-Instanzen", f"Feature-Instanzen für {plan.title.get('de', planKey)}", + planKey, "instances", + ) + priceIdInstances = _findExistingStripePrice( + stripe, productIdInstances, int(plan.pricePerFeatureInstanceCHF * 100), interval, + ) + if not priceIdInstances: + priceIdInstances = _createStripePrice( + stripe, productIdInstances, plan.pricePerFeatureInstanceCHF, interval, + f"{planKey} — Feature-Instanz", + ) + + persistData = { + "stripeProductId": "", + "stripeProductIdUsers": productIdUsers, + "stripeProductIdInstances": productIdInstances, + "stripePriceIdUsers": priceIdUsers, + "stripePriceIdInstances": priceIdInstances, + } + + if planKey in existing: + db.recordModify(StripePlanPrice, existing[planKey].id, persistData) + else: + db.recordCreate(StripePlanPrice, StripePlanPrice(planKey=planKey, **persistData).model_dump()) + + logger.info( + "Stripe bootstrapped for %s: users=%s/%s, instances=%s/%s", + planKey, productIdUsers, priceIdUsers, productIdInstances, priceIdInstances, + ) + + +def getStripePricesForPlan(planKey: str) -> Optional[StripePlanPrice]: + """Load the persisted Stripe IDs for a plan.""" + try: + db = _getBillingDb() + rows = db.getRecordset(StripePlanPrice, recordFilter={"planKey": planKey}) + if rows: + return StripePlanPrice(**{k: v for k, v in rows[0].items() if not k.startswith("_")}) + except Exception as e: + logger.error("Error loading Stripe prices for plan %s: %s", planKey, e) + return None diff --git a/modules/shared/attributeUtils.py b/modules/shared/attributeUtils.py index 863d7f36..6a857d85 100644 --- a/modules/shared/attributeUtils.py +++ b/modules/shared/attributeUtils.py @@ -258,6 +258,27 @@ def getModelAttributeDefinitions(modelClass: Type[BaseModel] = None, userLanguag attributes.append(attr_def) + # Append system timestamp fields (set automatically by DatabaseConnector) + systemTimestampFields = [ + ("_createdAt", {"en": "Created at", "de": "Erstellt am", "fr": "Créé le"}), + ("_modifiedAt", {"en": "Modified at", "de": "Geändert am", "fr": "Modifié le"}), + ] + for sysName, sysLabels in systemTimestampFields: + attributes.append({ + "name": sysName, + "type": "timestamp", + "required": False, + "description": "", + "label": sysLabels.get(userLanguage, sysLabels["en"]), + "placeholder": "", + "editable": False, + "visible": True, + "order": len(attributes), + "readonly": True, + "options": None, + "default": None, + }) + return {"model": model_label, "attributes": attributes} diff --git a/modules/shared/configuration.py b/modules/shared/configuration.py index 73e2da90..721ce448 100644 --- a/modules/shared/configuration.py +++ b/modules/shared/configuration.py @@ -66,7 +66,7 @@ class Configuration: self._configMtime = currentMtime try: - with open(configPath, 'r') as f: + with open(configPath, 'r', encoding='utf-8') as f: lines = f.readlines() i = 0 diff --git a/modules/shared/notifyMandateAdmins.py b/modules/shared/notifyMandateAdmins.py new file mode 100644 index 00000000..27445afb --- /dev/null +++ b/modules/shared/notifyMandateAdmins.py @@ -0,0 +1,285 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +""" +Central notification utility for mandate administrators. + +All mandate-level notifications (subscription changes, billing warnings, etc.) +MUST go through notifyMandateAdmins() to ensure consistent recipient resolution +and delivery. + +Recipients are the union of: +1. BillingSettings.notifyEmails for the mandate (configured contact addresses) +2. All users with the mandate-level "admin" RBAC role +""" + +from __future__ import annotations + +import html +import json +import logging +from typing import Any, Dict, List, Optional, Set + +from modules.datamodels.datamodelMessaging import MessagingChannel +from modules.interfaces.interfaceMessaging import getInterface as getMessagingInterface + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Recipient resolution +# ============================================================================ + + +def _normalizeEmailList(raw: Any) -> List[str]: + """Parse notifyEmails which can be a list, JSON string, or single address.""" + if raw is None: + return [] + if isinstance(raw, list): + return [str(e).strip().lower() for e in raw if str(e).strip()] + if isinstance(raw, str): + s = raw.strip() + if not s: + return [] + if s.startswith("["): + try: + parsed = json.loads(s) + if isinstance(parsed, list): + return [str(e).strip().lower() for e in parsed if str(e).strip()] + except Exception: + pass + return [s.lower()] + return [] + + +def _resolveMandateContactEmails(mandateId: str) -> List[str]: + """Get the configured notifyEmails from BillingSettings.""" + try: + from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface + from modules.security.rootAccess import getRootUser + billingIf = getBillingInterface(getRootUser(), mandateId) + settings = billingIf.getSettings(mandateId) or {} + return _normalizeEmailList(settings.get("notifyEmails")) + except Exception as e: + logger.warning("Could not resolve BillingSettings.notifyEmails for mandate %s: %s", mandateId, e) + return [] + + +def _resolveMandateAdminEmails(mandateId: str) -> List[str]: + """Resolve all admin users of a mandate via RBAC and return their emails.""" + emails: List[str] = [] + try: + from modules.interfaces.interfaceDbApp import getRootInterface + rootIf = getRootInterface() + userMandates = rootIf.getUserMandatesByMandate(mandateId) + for um in userMandates: + if not getattr(um, "enabled", True): + continue + umId = str(getattr(um, "id", "")) + userId = getattr(um, "userId", None) + if not userId: + continue + roleIds = rootIf.getRoleIdsForUserMandate(umId) + isAdmin = False + for roleId in roleIds: + role = rootIf.getRole(roleId) + if role and role.roleLabel == "admin" and not role.featureInstanceId: + isAdmin = True + break + if not isAdmin: + continue + user = rootIf.getUser(str(userId)) + if user and user.email: + emails.append(user.email.strip().lower()) + except Exception as e: + logger.warning("Could not resolve admin emails for mandate %s: %s", mandateId, e) + return emails + + +def _resolveAllRecipients(mandateId: str) -> List[str]: + """Union of BillingSettings.notifyEmails + all mandate admin user emails, deduplicated.""" + seen: Set[str] = set() + result: List[str] = [] + for email in _resolveMandateContactEmails(mandateId) + _resolveMandateAdminEmails(mandateId): + if email and email not in seen: + seen.add(email) + result.append(email) + return result + + +# ============================================================================ +# Mandate name resolution +# ============================================================================ + + +def _resolveMandateName(mandateId: str) -> str: + """Return the human-readable mandate name (label or name), falling back to a short ID.""" + try: + from modules.datamodels.datamodelUam import Mandate + from modules.security.rootAccess import getRootDbAppConnector + appDb = getRootDbAppConnector() + rows = appDb.getRecordset(Mandate, recordFilter={"id": mandateId}) + if rows: + return rows[0].get("label") or rows[0].get("name") or mandateId[:8] + except Exception as e: + logger.warning("Could not resolve mandate name for %s: %s", mandateId, e) + return mandateId[:8] + + +# ============================================================================ +# HTML email rendering +# ============================================================================ + + +def _getOperatorInfo() -> Dict[str, str]: + """Load operator company data from config.ini.""" + try: + from modules.shared.configuration import APP_CONFIG + return { + "companyName": APP_CONFIG.get("Operator_CompanyName", ""), + "address": APP_CONFIG.get("Operator_Address", ""), + "vatNumber": APP_CONFIG.get("Operator_VatNumber", ""), + } + except Exception: + return {"companyName": "", "address": "", "vatNumber": ""} + + +def _renderHtmlEmail( + headline: str, + bodyParagraphs: List[str], + mandateName: str, + footerNote: Optional[str] = None, + rawHtmlBlock: Optional[str] = None, +) -> str: + """Render a clean, professional HTML notification email. + + Args: + rawHtmlBlock: Optional pre-formatted HTML inserted after bodyParagraphs (e.g. invoice table). + """ + hl = html.escape(headline) + mn = html.escape(mandateName) + + paragraphsHtml = "" + for p in bodyParagraphs: + escaped = html.escape(p).replace("\n", "
") + paragraphsHtml += f'

{escaped}

\n' + + rawBlock = "" + if rawHtmlBlock: + rawBlock = f'
{rawHtmlBlock}
\n' + + footer = "" + if footerNote: + footer = ( + f'

' + f'{html.escape(footerNote)}

\n' + ) + + operator = _getOperatorInfo() + operatorLine = "" + parts = [p for p in [operator["companyName"], operator["address"], operator["vatNumber"]] if p] + if parts: + operatorLine = ( + f'

' + f'{html.escape(" | ".join(parts))}

\n' + ) + + return f""" + + + + + +
+ + + + + + + +
+

PowerOn

+
+

{hl}

+

Mandant: {mn}

+
+ {paragraphsHtml} + {rawBlock} +
+ {footer} +
+

+ Diese E-Mail wurde automatisch von PowerOn versendet. +

+ {operatorLine} +
+
+ +""" + + +# ============================================================================ +# Public API +# ============================================================================ + + +def notifyMandateAdmins( + mandateId: str, + subject: str, + headline: str, + bodyParagraphs: List[str], + *, + footerNote: Optional[str] = None, + rawHtmlBlock: Optional[str] = None, +) -> int: + """ + Send a styled HTML notification to all mandate admins and configured contacts. + + Args: + mandateId: The mandate to notify admins for. + subject: Email subject line. + headline: Bold headline inside the email body. + bodyParagraphs: List of paragraph strings (plain text, auto-escaped). + footerNote: Optional small-print note below the main content. + rawHtmlBlock: Optional pre-formatted HTML block (e.g. invoice summary table). + + Returns: + Number of recipients that were successfully notified. + """ + if not mandateId: + return 0 + + recipients = _resolveAllRecipients(mandateId) + if not recipients: + logger.warning( + "notifyMandateAdmins: no recipients found for mandate %s " + "(no notifyEmails configured and no admin users with email)", + mandateId, + ) + return 0 + + mandateName = _resolveMandateName(mandateId) + htmlMessage = _renderHtmlEmail(headline, bodyParagraphs, mandateName, footerNote, rawHtmlBlock) + messaging = getMessagingInterface() + successCount = 0 + + for to in recipients: + try: + ok = messaging.send( + channel=MessagingChannel.EMAIL, + recipient=to, + subject=subject, + message=htmlMessage, + ) + if ok: + successCount += 1 + else: + logger.warning("notifyMandateAdmins: send failed for %s", to) + except Exception as e: + logger.error("notifyMandateAdmins: error sending to %s: %s", to, e) + + logger.info( + "notifyMandateAdmins: sent '%s' to %d/%d recipients for mandate %s (%s)", + subject, successCount, len(recipients), mandateId, mandateName, + ) + return successCount diff --git a/modules/shared/progressLogger.py b/modules/shared/progressLogger.py index d12a1562..1b67f73e 100644 --- a/modules/shared/progressLogger.py +++ b/modules/shared/progressLogger.py @@ -135,8 +135,8 @@ class ProgressLogger: message = f"{op['service']}" workflow = self.services.workflow - if not workflow: - logger.warning(f"Cannot log progress: no workflow available") + if not workflow or not getattr(workflow, "id", None): + # No workflow or no workflow.id (e.g. automation2 placeholder) - skip progress logging return None # Validate parentOperationId exists in activeOperations or finishedOperations diff --git a/modules/shared/stripeClient.py b/modules/shared/stripeClient.py new file mode 100644 index 00000000..9c7b4c67 --- /dev/null +++ b/modules/shared/stripeClient.py @@ -0,0 +1,38 @@ +# Copyright (c) 2025 Patrick Motsch +# All rights reserved. +""" +Central Stripe SDK initialization. + +All Stripe interactions MUST use getStripeClient() to ensure consistent +API key, API version, and fallback handling across billing and subscription flows. +""" + +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + +_stripeInitialized = False + + +def getStripeClient(): + """ + Initialize and return the configured Stripe SDK module. + + Raises ValueError if no Stripe secret key is configured. + """ + import stripe + from modules.shared.configuration import APP_CONFIG + + apiVersion = APP_CONFIG.get("STRIPE_API_VERSION") + if apiVersion: + stripe.api_version = apiVersion + + secretKey = APP_CONFIG.get("STRIPE_SECRET_KEY_SECRET") or APP_CONFIG.get("STRIPE_SECRET_KEY") + if not secretKey: + raise ValueError("STRIPE_SECRET_KEY_SECRET not configured") + + stripe.api_key = secretKey + return stripe + + diff --git a/modules/system/mainSystem.py b/modules/system/mainSystem.py index a4b92ca4..a424c973 100644 --- a/modules/system/mainSystem.py +++ b/modules/system/mainSystem.py @@ -190,7 +190,15 @@ NAVIGATION_SECTIONS = [ "path": "/admin/billing", "order": 40, "adminOnly": True, - "sysAdminOnly": True, + }, + { + "id": "admin-subscriptions", + "objectKey": "ui.admin.subscriptions", + "label": {"en": "Subscriptions", "de": "Abonnements", "fr": "Abonnements"}, + "icon": "FaFileContract", + "path": "/admin/subscriptions", + "order": 50, + "adminOnly": True, }, ], }, @@ -274,6 +282,16 @@ NAVIGATION_SECTIONS = [ "adminOnly": True, "sysAdminOnly": True, }, + { + "id": "admin-automation-logs", + "objectKey": "ui.admin.automationLogs", + "label": {"en": "Execution Logs", "de": "Ausführungsprotokolle", "fr": "Journaux d'exécution"}, + "icon": "FaClipboardList", + "path": "/admin/automation-logs", + "order": 85, + "adminOnly": True, + "sysAdminOnly": True, + }, { "id": "admin-logs", "objectKey": "ui.admin.logs", @@ -437,6 +455,11 @@ RESOURCE_OBJECTS = [ "label": {"en": "Store: Automation", "de": "Store: Automation", "fr": "Store: Automatisation"}, "meta": {"category": "store", "featureCode": "automation"} }, + { + "objectKey": "resource.store.automation2", + "label": {"en": "Store: Automation 2", "de": "Store: Automation 2", "fr": "Store: Automatisation 2"}, + "meta": {"category": "store", "featureCode": "automation2"} + }, { "objectKey": "resource.store.teamsbot", "label": {"en": "Store: Teams Bot", "de": "Store: Teams Bot", "fr": "Store: Teams Bot"}, diff --git a/modules/workflows/automation2/__init__.py b/modules/workflows/automation2/__init__.py new file mode 100644 index 00000000..0656ab39 --- /dev/null +++ b/modules/workflows/automation2/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2025 Patrick Motsch +# automation2 - n8n-style graph execution engine. diff --git a/modules/workflows/automation2/executionEngine.py b/modules/workflows/automation2/executionEngine.py new file mode 100644 index 00000000..2e799707 --- /dev/null +++ b/modules/workflows/automation2/executionEngine.py @@ -0,0 +1,244 @@ +# Copyright (c) 2025 Patrick Motsch +# Main execution engine for automation2 graphs. + +import logging +from datetime import datetime, timezone +from typing import Dict, Any, List, Set, Optional + +from modules.workflows.automation2.graphUtils import ( + parseGraph, + buildConnectionMap, + validateGraph, + topoSort, + getInputSources, +) + +from modules.workflows.automation2.executors import ( + TriggerExecutor, + FlowExecutor, + DataExecutor, + ActionNodeExecutor, + InputExecutor, + PauseForHumanTaskError, + PauseForEmailWaitError, +) +from modules.features.automation2.nodeDefinitions import STATIC_NODE_TYPES + +logger = logging.getLogger(__name__) + + +def _getNodeTypeIds(services: Any = None) -> Set[str]: + """Collect all known node type IDs from static definitions.""" + return {n["id"] for n in STATIC_NODE_TYPES} + + +def _getExecutor( + nodeType: str, + services: Any, + automation2_interface: Optional[Any] = None, +) -> Any: + """Dispatch to correct executor based on node type.""" + if nodeType.startswith("trigger."): + return TriggerExecutor() + if nodeType.startswith("flow."): + return FlowExecutor() + if nodeType.startswith("data."): + return DataExecutor() + if nodeType.startswith("ai.") or nodeType.startswith("email.") or nodeType.startswith("sharepoint."): + return ActionNodeExecutor(services) + if nodeType.startswith("input.") and automation2_interface: + return InputExecutor(automation2_interface) + return None + + +async def executeGraph( + graph: Dict[str, Any], + services: Any, + workflowId: str = None, + instanceId: str = None, + userId: str = None, + mandateId: str = None, + automation2_interface: Optional[Any] = None, + initialNodeOutputs: Optional[Dict[str, Any]] = None, + startAfterNodeId: Optional[str] = None, + runId: Optional[str] = None, +) -> Dict[str, Any]: + """ + Execute automation2 graph. Returns { success, nodeOutputs, error?, stopped? }. + When an input node is reached and automation2_interface is provided, creates a task, + pauses the run, and returns { success: False, paused: True, taskId, runId }. + For resume: pass initialNodeOutputs (with result for the human node) and startAfterNodeId. + """ + logger.info( + "executeGraph start: instanceId=%s workflowId=%s userId=%s mandateId=%s resume=%s", + instanceId, + workflowId, + userId, + mandateId, + startAfterNodeId is not None, + ) + from modules.workflows.processing.shared.methodDiscovery import discoverMethods + discoverMethods(services) + nodeTypeIds = _getNodeTypeIds(services) + logger.debug("executeGraph nodeTypeIds (%d): %s", len(nodeTypeIds), sorted(nodeTypeIds)) + errors = validateGraph(graph, nodeTypeIds) + if errors: + logger.warning("executeGraph validation failed: %s", errors) + return {"success": False, "error": "; ".join(errors), "nodeOutputs": {}} + + nodes, connections = parseGraph(graph)[:2] + connectionMap = buildConnectionMap(connections) + inputSources = {n["id"]: getInputSources(n["id"], connectionMap) for n in nodes if n.get("id")} + logger.info( + "executeGraph parsed: nodes=%d connections=%d connectionMap_targets=%s", + len(nodes), + len(connections), + list(connectionMap.keys()), + ) + + ordered = topoSort(nodes, connectionMap) + ordered_ids = [n.get("id") for n in ordered if n.get("id")] + logger.info("executeGraph topoSort order: %s", ordered_ids) + + nodeOutputs: Dict[str, Any] = dict(initialNodeOutputs or {}) + is_resume = startAfterNodeId is not None + if not runId and automation2_interface and workflowId and not is_resume: + run_context = { + "connectionMap": connectionMap, + "inputSources": inputSources, + "orderedNodeIds": ordered_ids, + } + if userId: + run_context["ownerId"] = userId + if mandateId: + run_context["mandateId"] = mandateId + if instanceId: + run_context["instanceId"] = instanceId + run = automation2_interface.createRun( + workflowId=workflowId, + nodeOutputs=nodeOutputs, + context=run_context, + ) + runId = run.get("id") if run else None + logger.info("executeGraph created run %s", runId) + + context = { + "workflowId": workflowId, + "instanceId": instanceId, + "userId": userId, + "mandateId": mandateId, + "nodeOutputs": nodeOutputs, + "connectionMap": connectionMap, + "inputSources": inputSources, + "services": services, + "_runId": runId, + "_orderedNodes": ordered, + } + + skip_until_passed = bool(startAfterNodeId) + for i, node in enumerate(ordered): + if skip_until_passed: + if node.get("id") == startAfterNodeId: + skip_until_passed = False + continue + if context.get("_stopped"): + logger.info("executeGraph stopped early (flow.stop) at step %d", i) + break + nodeId = node.get("id") + nodeType = node.get("type", "") + executor = _getExecutor(nodeType, services, automation2_interface) + logger.info( + "executeGraph step %d/%d: nodeId=%s nodeType=%s executor=%s", + i + 1, + len(ordered), + nodeId, + nodeType, + type(executor).__name__ if executor else "None", + ) + if not executor: + nodeOutputs[nodeId] = None + logger.debug("executeGraph node %s: no executor, output=None", nodeId) + continue + try: + result = await executor.execute(node, context) + nodeOutputs[nodeId] = result + logger.info( + "executeGraph node %s done: result_type=%s result_keys=%s", + nodeId, + type(result).__name__, + list(result.keys()) if isinstance(result, dict) else "n/a", + ) + except PauseForHumanTaskError as e: + logger.info("executeGraph paused for human task %s", e.taskId) + return { + "success": False, + "paused": True, + "taskId": e.taskId, + "runId": e.runId, + "nodeId": e.nodeId, + "nodeOutputs": dict(nodeOutputs), + } + except PauseForEmailWaitError as e: + logger.info("executeGraph paused for email wait (run %s, node %s)", e.runId, e.nodeId) + # Start email poller on-demand (only runs while workflows wait for email) + try: + from modules.interfaces.interfaceDbApp import getRootInterface + from modules.features.automation2.emailPoller import ensureRunning + root = getRootInterface() + event_user = root.getUserByUsername("event") if root else None + if event_user: + ensureRunning(event_user) + except Exception as poll_err: + logger.warning("Could not start email poller: %s", poll_err) + paused_at = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") + run_ctx = { + "connectionMap": context.get("connectionMap"), + "inputSources": context.get("inputSources"), + "orderedNodeIds": [n.get("id") for n in context.get("_orderedNodes", []) if n.get("id")], + "waitReason": "email", + "waitConfig": e.waitConfig, + "pausedAt": paused_at, + "lastCheckedAt": None, + "ownerId": context.get("userId"), + "mandateId": context.get("mandateId"), + "instanceId": context.get("instanceId"), + } + automation2_interface.updateRun( + e.runId, + status="paused", + nodeOutputs=dict(nodeOutputs), + currentNodeId=e.nodeId, + context=run_ctx, + ) + return { + "success": False, + "paused": True, + "waitReason": "email", + "runId": e.runId, + "nodeId": e.nodeId, + "nodeOutputs": dict(nodeOutputs), + } + except Exception as e: + logger.exception("executeGraph node %s (%s) FAILED: %s", nodeId, nodeType, e) + nodeOutputs[nodeId] = {"error": str(e), "success": False} + if runId and automation2_interface: + automation2_interface.updateRun(runId, status="failed", nodeOutputs=nodeOutputs) + return { + "success": False, + "error": str(e), + "nodeOutputs": nodeOutputs, + "failedNode": nodeId, + } + + if runId and automation2_interface: + automation2_interface.updateRun(runId, status="completed", nodeOutputs=nodeOutputs) + logger.info( + "executeGraph complete: success=True nodeOutputs_keys=%s stopped=%s", + list(nodeOutputs.keys()), + context.get("_stopped", False), + ) + return { + "success": True, + "nodeOutputs": nodeOutputs, + "stopped": context.get("_stopped", False), + } diff --git a/modules/workflows/automation2/executors/__init__.py b/modules/workflows/automation2/executors/__init__.py new file mode 100644 index 00000000..c147a0d0 --- /dev/null +++ b/modules/workflows/automation2/executors/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2025 Patrick Motsch +# Executors for automation2 node types. + +from .triggerExecutor import TriggerExecutor +from .flowExecutor import FlowExecutor +from .dataExecutor import DataExecutor +from .actionNodeExecutor import ActionNodeExecutor +from .inputExecutor import InputExecutor, PauseForHumanTaskError, PauseForEmailWaitError + +__all__ = [ + "TriggerExecutor", + "FlowExecutor", + "DataExecutor", + "ActionNodeExecutor", + "InputExecutor", + "PauseForHumanTaskError", + "PauseForEmailWaitError", +] diff --git a/modules/workflows/automation2/executors/actionNodeExecutor.py b/modules/workflows/automation2/executors/actionNodeExecutor.py new file mode 100644 index 00000000..504fb34e --- /dev/null +++ b/modules/workflows/automation2/executors/actionNodeExecutor.py @@ -0,0 +1,599 @@ +# Copyright (c) 2025 Patrick Motsch +# Action node executor - maps ai.*, email.*, sharepoint.* to method actions via ActionExecutor. + +import logging +from typing import Dict, Any, List, Optional + +logger = logging.getLogger(__name__) + + +def _getNodeDefinition(nodeType: str) -> Optional[Dict[str, Any]]: + """Get node definition by type id for _method, _action, _paramMap.""" + from modules.features.automation2.nodeDefinitions import STATIC_NODE_TYPES + for node in STATIC_NODE_TYPES: + if node.get("id") == nodeType: + return node + return None + + +def _resolveConnectionIdToReference(chatService, connectionId: str, services=None) -> Optional[str]: + """ + Resolve connectionId (UserConnection.id) to connectionReference format. + connectionReference format: connection:{authority}:{externalUsername} + Falls back to interfaceDbApp.getUserConnectionById when chatService resolution fails. + """ + if not connectionId: + return None + # Already in reference format + if isinstance(connectionId, str) and connectionId.startswith("connection:"): + return connectionId + # Try chatService first + if chatService: + try: + connections = chatService.getUserConnections() + for c in connections or []: + conn = c if isinstance(c, dict) else (c.model_dump() if hasattr(c, "model_dump") else {}) + if str(conn.get("id")) == str(connectionId): + authority = conn.get("authority") + if hasattr(authority, "value"): + authority = authority.value + username = conn.get("externalUsername", "") + return f"connection:{authority}:{username}" + except Exception as e: + logger.debug("_resolveConnectionIdToReference chatService: %s", e) + # Fallback: interfaceDbApp.getUserConnectionById (automation2 may not have chat.getUserConnections) + app = getattr(services, "interfaceDbApp", None) if services else None + if app and hasattr(app, "getUserConnectionById"): + try: + conn = app.getUserConnectionById(str(connectionId)) + if conn: + authority = getattr(conn, "authority", None) + if hasattr(authority, "value"): + authority = authority.value + else: + authority = str(authority) if authority else "outlook" + username = getattr(conn, "externalUsername", "") or "" + return f"connection:{authority}:{username}" + except Exception as e: + logger.debug("_resolveConnectionIdToReference getUserConnectionById: %s", e) + return None + + +def _extractEmailContentFromUpstream(inp: Any) -> Optional[Dict[str, Any]]: + """ + Extract {subject, body, to} from upstream node output (e.g. AI node returning JSON). + Expects JSON like {"subject": "...", "body": "...", "to": "..."} in documentData. + """ + if not inp: + return None + import json + docs = inp.get("documents", inp.get("documentList", [])) if isinstance(inp, dict) else [] + if not docs: + return None + doc = docs[0] if isinstance(docs, list) else docs + raw = getattr(doc, "documentData", None) if hasattr(doc, "documentData") else (doc.get("documentData") if isinstance(doc, dict) else None) + if not raw: + return None + try: + data = json.loads(raw) if isinstance(raw, str) else raw + if isinstance(data, dict) and data.get("subject") and data.get("body"): + return { + "subject": str(data.get("subject", "")), + "body": str(data.get("body", "")), + "to": data.get("to"), + } + except (json.JSONDecodeError, TypeError): + pass + return None + + +def _extractContextFromUpstream(inp: Any) -> Optional[str]: + """ + Extract plain text context from upstream node output (e.g. AI node returning txt). + Use when _extractEmailContentFromUpstream returns None – the generated document content + (email body, summary, etc.) should be passed as context to email.draftEmail. + """ + if not inp: + return None + docs = None + if isinstance(inp, dict): + docs = inp.get("documents") or inp.get("documentList") + if not docs and isinstance(inp.get("data"), dict): + docs = inp.get("data", {}).get("documents") + if not docs or not isinstance(docs, (list, tuple)): + return None + doc = docs[0] if docs else None + if not doc: + return None + raw = getattr(doc, "documentData", None) if hasattr(doc, "documentData") else (doc.get("documentData") or doc.get("content") if isinstance(doc, dict) else None) + if not raw: + return None + if isinstance(raw, bytes): + return raw.decode("utf-8", errors="replace").strip() + s = str(raw).strip() + return s if s else None + + +def _gatherAttachmentDocumentsFromUpstream( + nodeId: str, + inputSources: Dict[str, Dict[int, tuple]], + nodeOutputs: Dict[str, Any], + orderedNodes: List[Dict], + visited: Optional[set] = None, +) -> List[Any]: + """ + Walk upstream from nodeId through AI nodes to collect file documents (e.g. from sharepoint.downloadFile). + Used when email.draftEmail has AI upstream – attachments come from file nodes, not AI output. + """ + visited = visited or set() + if nodeId in visited: + return [] + visited.add(nodeId) + docs = [] + src = inputSources.get(nodeId, {}).get(0) + if not src: + return [] + srcId, _ = src + srcNode = next((n for n in (orderedNodes or []) if n.get("id") == srcId), None) + srcType = (srcNode or {}).get("type", "") + out = nodeOutputs.get(srcId) + + if srcType in ("sharepoint.downloadFile", "sharepoint.readFile"): + if isinstance(out, dict): + for d in out.get("documents") or out.get("documentList") or []: + if isinstance(d, dict) and (d.get("documentData") or (d.get("validationMetadata") or {}).get("fileId")): + docs.append(d) + elif hasattr(d, "documentData") or (getattr(d, "validationMetadata", None) or {}).get("fileId"): + docs.append(d.model_dump() if hasattr(d, "model_dump") else d) + elif srcType.startswith("ai."): + docs.extend( + _gatherAttachmentDocumentsFromUpstream(srcId, inputSources, nodeOutputs, orderedNodes, visited) + ) + return docs + + +def _getIncomingEmailFromUpstream( + nodeId: str, + inputSources: Dict[str, Dict[int, tuple]], + nodeOutputs: Dict[str, Any], + orderedNodes: List[Dict], +) -> Optional[tuple]: + """ + Walk upstream from draftEmail to find email.checkEmail/searchEmail and return (context, documentList). + context = formatted incoming email(s) for composeAndDraftEmail. + documentList = documents from the email node for attachment/context. + """ + src = inputSources.get(nodeId, {}).get(0) + if not src: + return None + srcId, _ = src + srcNode = next((n for n in (orderedNodes or []) if n.get("id") == srcId), None) + srcType = (srcNode or {}).get("type", "") + + # Direct connection to email node + if srcType in ("email.checkEmail", "email.searchEmail"): + out = nodeOutputs.get(srcId) + return _formatEmailOutputAsContext(out) + + # Connected via AI node: walk one more step to email source + if srcType.startswith("ai."): + src2 = inputSources.get(srcId, {}).get(0) + if not src2: + return None + emailNodeId, _ = src2 + emailNode = next((n for n in (orderedNodes or []) if n.get("id") == emailNodeId), None) + if (emailNode or {}).get("type") in ("email.checkEmail", "email.searchEmail"): + out = nodeOutputs.get(emailNodeId) + return _formatEmailOutputAsContext(out) + return None + + +def _formatEmailOutputAsContext(out: Any) -> Optional[tuple]: + """Format email node output as (context, documentList, reply_to) for composeAndDraftEmail. + reply_to = sender address of first email (recipient for the reply). + """ + if not out: + return None + docs = out.get("documents", out.get("documentList", [])) if isinstance(out, dict) else [] + if not docs: + return None + doc = docs[0] if isinstance(docs, list) else docs + raw = getattr(doc, "documentData", None) if hasattr(doc, "documentData") else (doc.get("documentData") if isinstance(doc, dict) else None) + if not raw: + return None + import json + try: + data = json.loads(raw) if isinstance(raw, str) else raw + except (json.JSONDecodeError, TypeError): + return None + if not isinstance(data, dict): + return None + # readEmails: data.emails.emails | searchEmails: data.searchResults.results + emails_data = data.get("emails") or {} + emails_list = emails_data.get("emails", []) if isinstance(emails_data, dict) else [] + if not emails_list: + search_results = data.get("searchResults") or {} + emails_list = search_results.get("results", []) if isinstance(search_results, dict) else [] + if not emails_list: + return None + reply_to = None + parts = ["Reply to the following email(s):", ""] + for i, em in enumerate(emails_list[:5]): # max 5 + if not isinstance(em, dict): + continue + fr = em.get("from", em.get("sender", {})) + addr = fr.get("emailAddress", {}) if isinstance(fr, dict) else {} + from_str = addr.get("address", "") or addr.get("name", "") + if from_str and not reply_to: + reply_to = addr.get("address", "") or from_str + subj = em.get("subject", "") + body = em.get("bodyPreview", "") or (em.get("body") or {}).get("content", "") if isinstance(em.get("body"), dict) else "" + if body and len(str(body)) > 1500: + body = str(body)[:1500] + "..." + parts.append(f"From: {from_str}") + parts.append(f"Subject: {subj}") + parts.append(f"Content:\n{body}") + parts.append("") + if reply_to: + parts.insert(2, f"Recipient (reply to this address): {reply_to}") + parts.insert(3, "") + context = "\n".join(parts).strip() + return (context, docs, reply_to) + + +def _buildSearchQuery( + query: str = None, + fromAddress: str = None, + toAddress: str = None, + subjectContains: str = None, + bodyContains: str = None, + hasAttachment: bool = None, + filter: str = None, +) -> str: + """ + Build Microsoft Graph $search query from discrete params. + Uses KQL: from:, to:, subject:, body:, hasattachments: (supported by Graph API). + """ + if filter and str(filter).strip(): + return str(filter).strip() + parts = [] + if query and str(query).strip(): + parts.append(str(query).strip()) + if fromAddress and str(fromAddress).strip(): + safe = str(fromAddress).strip().replace('"', '') + parts.append(f'from:{safe}') + if toAddress and str(toAddress).strip(): + safe = str(toAddress).strip().replace('"', '') + parts.append(f'to:{safe}') + if subjectContains and str(subjectContains).strip(): + safe = str(subjectContains).strip().replace('"', '') + parts.append(f'subject:{safe}') + if bodyContains and str(bodyContains).strip(): + safe = str(bodyContains).strip().replace('"', '') + parts.append(f'body:{safe}') + if hasAttachment is True: + parts.append("hasattachments:true") + return " ".join(parts) if parts else "*" + + +def _buildEmailFilter(fromAddress: str = None, subjectContains: str = None, hasAttachment: bool = None) -> str: + """ + Build Microsoft Graph API $filter string from discrete email filter params. + Used for email.checkEmail (and trigger.newEmail). + """ + parts = [] + if fromAddress and str(fromAddress).strip(): + safe = str(fromAddress).strip().replace("'", "''") + parts.append(f"from/emailAddress/address eq '{safe}'") + if subjectContains and str(subjectContains).strip(): + safe = str(subjectContains).strip().replace("'", "''") + parts.append(f"contains(subject,'{safe}')") + if hasAttachment is True: + parts.append("hasAttachments eq true") + return " and ".join(parts) if parts else "" + + +def _buildActionParams( + node: Dict[str, Any], + nodeDef: Dict[str, Any], + resolvedParams: Dict[str, Any], + chatService, + services=None, +) -> Dict[str, Any]: + """ + Build params for ActionExecutor from node parameters using _paramMap. + Resolves connectionId -> connectionReference. + Handles _contextFrom for composite params (e.g. email.draftEmail subject+body -> context). + """ + params = dict(resolvedParams) + paramMap = nodeDef.get("_paramMap") or {} + contextFrom = nodeDef.get("_contextFrom") or [] + + # email.checkEmail: build filter from discrete params (fromAddress, subjectContains, hasAttachment) + nodeType = node.get("type", "") + if nodeType == "email.checkEmail": + built = _buildEmailFilter( + fromAddress=params.get("fromAddress"), + subjectContains=params.get("subjectContains"), + hasAttachment=params.get("hasAttachment"), + ) + raw_filter = (params.get("filter") or "").strip() + params["filter"] = built if built else (raw_filter if raw_filter else None) + params.pop("fromAddress", None) + params.pop("subjectContains", None) + params.pop("hasAttachment", None) + + # email.searchEmail: build query from discrete params (fromAddress, toAddress, subjectContains, bodyContains, hasAttachment) + if nodeType == "email.searchEmail": + built = _buildSearchQuery( + query=params.get("query"), + fromAddress=params.get("fromAddress"), + toAddress=params.get("toAddress"), + subjectContains=params.get("subjectContains"), + bodyContains=params.get("bodyContains"), + hasAttachment=params.get("hasAttachment"), + filter=params.get("filter"), + ) + params["query"] = built + params.pop("fromAddress", None) + params.pop("toAddress", None) + params.pop("subjectContains", None) + params.pop("bodyContains", None) + params.pop("hasAttachment", None) + params.pop("filter", None) + + # Resolve connectionId to connectionReference + if "connectionId" in params: + connId = params.get("connectionId") + if connId: + ref = _resolveConnectionIdToReference(chatService, connId, services) + if ref: + params["connectionReference"] = ref + else: + logger.warning(f"Could not resolve connectionId {connId} to connectionReference") + params.pop("connectionId", None) + + # Build context from multiple params (e.g. subject + body for draft email) + if contextFrom: + parts = [] + for key in contextFrom: + val = params.get(key) + if val: + if key == "subject": + parts.append(f"Subject: {val}") + elif key == "body": + parts.append(f"Body:\n{val}") + else: + parts.append(str(val)) + if parts: + params["context"] = "\n\n".join(parts) + for k in contextFrom: + params.pop(k, None) + + # Apply paramMap: node param name -> action param name + result = {} + mappedNodeKeys = {nodeKey for nodeKey, actionKey in paramMap.items() if actionKey and nodeKey in params} + for nodeKey, actionKey in paramMap.items(): + if nodeKey in params and actionKey: + result[actionKey] = params[nodeKey] + # Pass through params not used as source for mapping + for k, v in params.items(): + if k not in mappedNodeKeys and k not in result: + result[k] = v + return result + + +class ActionNodeExecutor: + """Execute ai.*, email.*, sharepoint.* nodes by mapping to method actions.""" + + def __init__(self, services: Any): + self.services = services + + async def execute( + self, + node: Dict[str, Any], + context: Dict[str, Any], + ) -> Any: + from modules.features.automation2.nodeRegistry import getNodeTypeToMethodAction + from modules.workflows.automation2.graphUtils import resolveParameterReferences + from modules.workflows.processing.core.actionExecutor import ActionExecutor + + nodeType = node.get("type", "") + nodeId = node.get("id", "") + logger.info("ActionNodeExecutor node %s type=%s", nodeId, nodeType) + + mapping = getNodeTypeToMethodAction() + methodAction = mapping.get(nodeType) + if not methodAction: + logger.debug("ActionNodeExecutor node %s not in mapping -> None", nodeId) + return None + + methodName, actionName = methodAction + logger.info("ActionNodeExecutor node %s method=%s action=%s", nodeId, methodName, actionName) + + nodeDef = _getNodeDefinition(nodeType) + params = dict(node.get("parameters") or {}) + resolvedParams = resolveParameterReferences(params, context.get("nodeOutputs", {})) + + # Merge input from connected nodes (documentList, etc.) + inputSources = context.get("inputSources", {}).get(nodeId, {}) + if 0 in inputSources: + srcId, _ = inputSources[0] + inp = context.get("nodeOutputs", {}).get(srcId) + if isinstance(inp, dict): + resolvedParams.setdefault("documentList", inp.get("documents", inp.get("documentList", []))) + elif inp is not None: + resolvedParams.setdefault("input", inp) + + # ai.prompt with email upstream: inject actual email content into prompt so AI has context + # (getChatDocumentsFromDocumentList fails in automation2 – workflow has no messages) + if nodeType.startswith("ai."): + orderedNodes = context.get("_orderedNodes") or [] + if 0 in inputSources: + srcId, _ = inputSources[0] + srcNode = next((n for n in orderedNodes if n.get("id") == srcId), None) + srcType = (srcNode or {}).get("type", "") + if srcType in ("email.checkEmail", "email.searchEmail"): + incoming = _getIncomingEmailFromUpstream( + nodeId, + context.get("inputSources", {}), + context.get("nodeOutputs", {}), + orderedNodes, + ) + if incoming: + ctx, _doc_list, _reply_to = incoming + if ctx and ctx.strip(): + base_prompt = (resolvedParams.get("aiPrompt") or "").strip() + resolvedParams["aiPrompt"] = ( + f"Eingehende E-Mail:\n{ctx}\n\nAufgabe: {base_prompt}" + if base_prompt + else f"Eingehende E-Mail:\n{ctx}" + ) + logger.debug("ai.prompt: injected email context from upstream %s", srcType) + + chatService = getattr(self.services, "chat", None) + actionParams = _buildActionParams(node, nodeDef or {}, resolvedParams, chatService, self.services) + + # email.checkEmail: pause and wait for new email (background poller will resume) + if nodeType == "email.checkEmail": + runId = context.get("_runId") + workflowId = context.get("workflowId") + connRef = actionParams.get("connectionReference") + if runId and workflowId and connRef: + from modules.workflows.automation2.executors import PauseForEmailWaitError + waitConfig = { + "connectionReference": connRef, + "folder": actionParams.get("folder", "Inbox"), + "limit": min(int(actionParams.get("limit") or 10), 50), + "filter": actionParams.get("filter"), + } + raise PauseForEmailWaitError(runId=runId, nodeId=nodeId, waitConfig=waitConfig) + # Fallback: no pause (calls readEmails directly) – needs runId, workflowId, connectionReference + if not runId or not workflowId: + logger.warning( + "email.checkEmail not pausing (runId=%s workflowId=%s) – run must be saved/executed as workflow", + runId, + workflowId, + ) + elif not connRef: + logger.warning( + "email.checkEmail not pausing – connectionReference missing (check connectionId/config)", + ) + + # email.draftEmail: use AI output as emailContent if available; else pass incoming email as context + if nodeType == "email.draftEmail": + inputSources = context.get("inputSources", {}) + nodeOutputs = context.get("nodeOutputs", {}) + orderedNodes = context.get("_orderedNodes") or [] + if 0 in inputSources.get(nodeId, {}): + srcId, _ = inputSources[nodeId][0] + srcNode = next((n for n in orderedNodes if n.get("id") == srcId), None) + srcType = (srcNode or {}).get("type", "") + if srcType.startswith("ai."): + inp = nodeOutputs.get(srcId) + email_content = _extractEmailContentFromUpstream(inp) + if email_content: + actionParams["emailContent"] = email_content + actionParams["context"] = email_content.get("body", "") or "(from connected AI node)" + # Attachments: gather from file nodes upstream of AI (e.g. downloadFile -> AI -> email) + attachment_docs = _gatherAttachmentDocumentsFromUpstream( + nodeId, inputSources, nodeOutputs, orderedNodes + ) + if attachment_docs: + existing = actionParams.get("documentList") or [] + # Prefer file docs from upstream; append any existing that look like binary attachments + def _is_binary_attachment(d): + if isinstance(d, dict) and d.get("documentData"): + try: + import json + json.loads(d["documentData"]) + return False # JSON = email content, not attachment + except (TypeError, ValueError): + return True + return bool(isinstance(d, dict) and (d.get("validationMetadata") or {}).get("fileId")) + extra = [x for x in (existing if isinstance(existing, list) else []) if _is_binary_attachment(x)] + actionParams["documentList"] = attachment_docs + extra + if not email_content: + # AI returns plain text (e.g. email.txt): use as email body directly (no extra AI call) + ctx = _extractContextFromUpstream(inp) + if ctx: + actionParams["emailContent"] = { + "subject": actionParams.get("subject", "Draft"), + "body": ctx, + "to": actionParams.get("to"), + } + actionParams["context"] = ctx + else: + # Fallback: incoming email from upstream (if flow is email->AI->draft) + incoming = _getIncomingEmailFromUpstream(nodeId, inputSources, nodeOutputs, orderedNodes) + if incoming: + ctx, doc_list, reply_to = incoming + actionParams["context"] = ctx + if doc_list and not actionParams.get("documentList"): + actionParams["documentList"] = doc_list + if reply_to and not actionParams.get("to"): + actionParams["to"] = [reply_to] + else: + doc_count = len(inp.get("documents", [])) if isinstance(inp, dict) else 0 + logger.warning( + "email.draftEmail: AI upstream returned %d doc(s) but context extraction failed (no subject/body, no plain text). " + "Ensure AI node outputs document with documentData.", + doc_count, + ) + actionParams["context"] = "(no context extracted from upstream – check AI node output)" + elif srcType in ("sharepoint.downloadFile", "sharepoint.readFile"): + # File itself is the context: pass as attachment, use filename as minimal context (no content extraction) + if not actionParams.get("context"): + inp = nodeOutputs.get(srcId) + docs = (inp.get("documents") or inp.get("documentList", [])) if isinstance(inp, dict) else [] + doc = docs[0] if docs else None + name = None + if isinstance(doc, dict): + name = doc.get("documentName") or doc.get("fileName") + elif doc and hasattr(doc, "documentName"): + name = getattr(doc, "documentName", None) or getattr(doc, "fileName", None) + ctx = name if name else "Attachment" + actionParams["context"] = ctx + actionParams["emailContent"] = { + "subject": actionParams.get("subject", "Draft"), + "body": ctx, + "to": actionParams.get("to"), + } + # documentList already merged from upstream (file as attachment) + else: + # Direct connection to email.checkEmail/searchEmail: use incoming email as context + if not actionParams.get("context"): + incoming = _getIncomingEmailFromUpstream(nodeId, inputSources, nodeOutputs, orderedNodes) + if incoming: + ctx, doc_list, reply_to = incoming + actionParams["context"] = ctx + if doc_list and not actionParams.get("documentList"): + actionParams["documentList"] = doc_list + if reply_to and not actionParams.get("to"): + actionParams["to"] = [reply_to] + + # Generic context handover: when upstream provides documents, pass first doc as content for actions that expect it + docList = actionParams.get("documentList") or resolvedParams.get("documentList") + if docList and "content" not in actionParams: + first = docList[0] if isinstance(docList, list) and docList else docList + # Actions like sharepoint.uploadFile consume content from context + actionParams["content"] = first + + executor = ActionExecutor(self.services) + logger.info("ActionNodeExecutor node %s calling executeAction(%s, %s)", nodeId, methodName, actionName) + result = await executor.executeAction(methodName, actionName, actionParams) + + out = { + "success": result.success, + "error": result.error, + "documents": [d.model_dump() if hasattr(d, "model_dump") else d for d in (result.documents or [])], + "data": result.model_dump() if hasattr(result, "model_dump") else {"success": result.success, "error": result.error}, + } + logger.info( + "ActionNodeExecutor node %s result: success=%s error=%s doc_count=%d", + nodeId, + result.success, + result.error, + len(out.get("documents", [])), + ) + return out diff --git a/modules/workflows/automation2/executors/dataExecutor.py b/modules/workflows/automation2/executors/dataExecutor.py new file mode 100644 index 00000000..386c8abd --- /dev/null +++ b/modules/workflows/automation2/executors/dataExecutor.py @@ -0,0 +1,120 @@ +# Copyright (c) 2025 Patrick Motsch +# Data transformation node executor (setFields, filter, parseJson, template). + +import json +import logging +import re +from typing import Dict, Any, List + +logger = logging.getLogger(__name__) + + +def _get_nested(obj: Any, path: str) -> Any: + """Get nested key from obj, e.g. 'data.items'.""" + for k in path.split("."): + if not k: + continue + if isinstance(obj, dict) and k in obj: + obj = obj[k] + elif isinstance(obj, (list, tuple)) and k.isdigit(): + obj = obj[int(k)] + else: + return None + return obj + + +class DataExecutor: + """Execute data transformation nodes.""" + + async def execute( + self, + node: Dict[str, Any], + context: Dict[str, Any], + ) -> Any: + nodeType = node.get("type", "") + nodeOutputs = context.get("nodeOutputs", {}) + nodeId = node.get("id", "") + inputSources = context.get("inputSources", {}).get(nodeId, {}) + params = node.get("parameters") or {} + logger.info( + "DataExecutor node %s type=%s inputSources=%s params=%s", + nodeId, + nodeType, + inputSources, + params, + ) + + inp = None + if 0 in inputSources: + srcId, _ = inputSources[0] + inp = nodeOutputs.get(srcId) + + from modules.workflows.automation2.graphUtils import resolveParameterReferences + resolvedParams = {k: resolveParameterReferences(v, nodeOutputs) for k, v in params.items()} + + if nodeType == "data.setFields": + out = self._setFields(inp, resolvedParams) + logger.info("DataExecutor node %s setFields inp=%s -> %s", nodeId, type(inp).__name__, out) + return out + if nodeType == "data.filter": + out = self._filter(inp, resolvedParams) + logger.info("DataExecutor node %s filter inp=%s -> len=%d", nodeId, type(inp).__name__, len(out) if isinstance(out, list) else -1) + return out + if nodeType == "data.parseJson": + out = self._parseJson(inp, resolvedParams) + logger.info("DataExecutor node %s parseJson -> %s", nodeId, type(out).__name__) + return out + if nodeType == "data.template": + out = self._template(inp, resolvedParams, nodeOutputs) + logger.info("DataExecutor node %s template -> %s", nodeId, out) + return out + + logger.debug("DataExecutor node %s unhandled type %s -> passThrough", nodeId, nodeType) + return inp + + def _setFields(self, inp: Any, params: Dict) -> Any: + fields = params.get("fields", {}) + if not isinstance(fields, dict): + return inp + base = dict(inp) if isinstance(inp, dict) else {} + base.update(fields) + return base + + def _filter(self, inp: Any, params: Dict) -> Any: + itemsPath = (params.get("itemsPath") or "").strip() + condition = params.get("condition", "True") + items = inp + if itemsPath: + items = _get_nested(inp, itemsPath) + if not isinstance(items, list): + items = [inp] if inp is not None else [] + out = [] + for i, item in enumerate(items): + try: + local = {"item": item, "index": i, "input": inp} + ok = bool(eval(condition, {"__builtins__": {}}, local)) + if ok: + out.append(item) + except Exception: + pass + return out + + def _parseJson(self, inp: Any, params: Dict) -> Any: + jsonPath = (params.get("jsonPath") or "").strip() + raw = inp + if jsonPath: + raw = _get_nested(inp, jsonPath) if isinstance(inp, dict) else inp + if isinstance(raw, dict): + return raw + if isinstance(raw, str): + try: + return json.loads(raw) + except json.JSONDecodeError: + return {"error": "Invalid JSON", "raw": raw[:200]} + return inp + + def _template(self, inp: Any, params: Dict, nodeOutputs: Dict) -> Any: + tpl = params.get("template", "") + from modules.workflows.automation2.graphUtils import resolveParameterReferences + result = resolveParameterReferences(tpl, nodeOutputs) + return {"text": result, "template": tpl} diff --git a/modules/workflows/automation2/executors/flowExecutor.py b/modules/workflows/automation2/executors/flowExecutor.py new file mode 100644 index 00000000..de5789e5 --- /dev/null +++ b/modules/workflows/automation2/executors/flowExecutor.py @@ -0,0 +1,146 @@ +# Copyright (c) 2025 Patrick Motsch +# Flow control node executor (ifElse, merge, wait, stop). + +import asyncio +import logging +from typing import Dict, Any + +logger = logging.getLogger(__name__) + + +class FlowExecutor: + """Execute flow control nodes.""" + + async def execute( + self, + node: Dict[str, Any], + context: Dict[str, Any], + ) -> Any: + nodeType = node.get("type", "") + nodeOutputs = context.get("nodeOutputs", {}) + connectionMap = context.get("connectionMap", {}) + nodeId = node.get("id", "") + inputSources = context.get("inputSources", {}).get(nodeId, {}) + logger.info( + "FlowExecutor node %s type=%s inputSources=%s params=%s", + nodeId, + nodeType, + inputSources, + node.get("parameters"), + ) + + if nodeType == "flow.ifElse": + out = await self._ifElse(node, nodeOutputs, nodeId, inputSources) + logger.info("FlowExecutor node %s ifElse -> %s", nodeId, out) + return out + if nodeType == "flow.merge": + out = await self._merge(node, nodeOutputs, nodeId, inputSources) + logger.info("FlowExecutor node %s merge -> %s", nodeId, out) + return out + if nodeType == "flow.wait": + out = await self._wait(node, nodeOutputs, nodeId, inputSources) + logger.info("FlowExecutor node %s wait -> %s", nodeId, out) + return out + if nodeType == "flow.stop": + context["_stopped"] = True + logger.info("FlowExecutor node %s -> STOP", nodeId) + return {"stopped": True} + if nodeType == "flow.switch": + out = await self._switch(node, nodeOutputs, nodeId, inputSources) + logger.info("FlowExecutor node %s switch -> %s", nodeId, out) + return out + if nodeType == "flow.loop": + out = await self._loop(node, nodeOutputs, nodeId, inputSources) + logger.info("FlowExecutor node %s loop -> %s", nodeId, out) + return out + + logger.debug("FlowExecutor node %s unhandled type %s -> None", nodeId, nodeType) + return None + + def _getInputData(self, nodeId: str, inputSources: Dict, nodeOutputs: Dict, outputIndex: int = 0) -> Any: + """Get data from the connected source node.""" + sources = inputSources.get(nodeId, {}) + if 0 not in sources: + return None + srcId, srcOut = sources[0] + return nodeOutputs.get(srcId) + + async def _ifElse( + self, + node: Dict, + nodeOutputs: Dict, + nodeId: str, + inputSources: Dict, + ) -> Any: + condExpr = (node.get("parameters") or {}).get("condition", "") + inp = self._getInputData(nodeId, {nodeId: inputSources}, nodeOutputs) + # Simple eval - in production use safe evaluation + try: + # Replace {{nodeId}} refs with actual values + from modules.workflows.automation2.graphUtils import resolveParameterReferences + resolved = resolveParameterReferences(condExpr, nodeOutputs) + # Minimal eval for simple comparisons (e.g. "True", "1 > 0") + ok = bool(eval(resolved)) if resolved else False + except Exception: + ok = False + return {"branch": 0 if ok else 1, "conditionResult": ok, "input": inp} + + async def _merge(self, node: Dict, nodeOutputs: Dict, nodeId: str, inputSources: Dict) -> Any: + mode = (node.get("parameters") or {}).get("mode", "append") + sources = inputSources + items = [] + for inpIdx in sorted(sources.keys()): + srcId, _ = sources[inpIdx] + data = nodeOutputs.get(srcId) + if data is not None: + if isinstance(data, list): + items.extend(data) + else: + items.append(data) + if mode == "combine" and len(items) == 2: + if isinstance(items[0], dict) and isinstance(items[1], dict): + return {**items[0], **items[1]} + return items + + async def _wait(self, node: Dict, nodeOutputs: Dict) -> Any: + secs = (node.get("parameters") or {}).get("seconds", 0) + if secs > 0: + await asyncio.sleep(min(float(secs), 300)) + nodeId = node.get("id") + from modules.workflows.automation2.graphUtils import getInputSources + # Input comes from context + inp = context.get("_inputData") if "context" in dir() else None + return nodeOutputs.get(nodeId, {}) + + async def _wait( + self, + node: Dict, + nodeOutputs: Dict, + nodeId: str, + inputSources: Dict, + ) -> Any: + secs = (node.get("parameters") or {}).get("seconds", 0) + if secs > 0: + await asyncio.sleep(min(float(secs), 300)) + if 0 in inputSources: + srcId, _ = inputSources[0] + return nodeOutputs.get(srcId) + return None + + async def _switch(self, node: Dict, nodeOutputs: Dict, nodeId: str, inputSources: Dict) -> Any: + valueExpr = (node.get("parameters") or {}).get("value", "") + from modules.workflows.automation2.graphUtils import resolveParameterReferences + value = resolveParameterReferences(valueExpr, nodeOutputs) + cases = (node.get("parameters") or {}).get("cases", []) + for i, c in enumerate(cases): + if c == value: + return {"match": i, "value": value} + return {"match": -1, "value": value} + + async def _loop(self, node: Dict, nodeOutputs: Dict, nodeId: str, inputSources: Dict) -> Any: + itemsPath = (node.get("parameters") or {}).get("items", "[]") + from modules.workflows.automation2.graphUtils import resolveParameterReferences + items = resolveParameterReferences(itemsPath, nodeOutputs) + if not isinstance(items, list): + items = [items] if items is not None else [] + return {"items": items, "count": len(items)} diff --git a/modules/workflows/automation2/executors/inputExecutor.py b/modules/workflows/automation2/executors/inputExecutor.py new file mode 100644 index 00000000..22fa2eba --- /dev/null +++ b/modules/workflows/automation2/executors/inputExecutor.py @@ -0,0 +1,80 @@ +# Copyright (c) 2025 Patrick Motsch +# Input/Human node executor - creates tasks and pauses execution. + +import logging +from typing import Dict, Any + +logger = logging.getLogger(__name__) + + +class PauseForHumanTaskError(Exception): + """Raised when execution must pause for a human task. Contains runId, taskId.""" + + def __init__(self, runId: str, taskId: str, nodeId: str): + self.runId = runId + self.taskId = taskId + self.nodeId = nodeId + super().__init__(f"Pause for human task {taskId} (run {runId}, node {nodeId})") + + +class PauseForEmailWaitError(Exception): + """Raised when execution must pause waiting for a new email. Background poller will resume.""" + + def __init__(self, runId: str, nodeId: str, waitConfig: Dict[str, Any]): + self.runId = runId + self.nodeId = nodeId + self.waitConfig = waitConfig + super().__init__(f"Pause for email wait (run {runId}, node {nodeId})") + + +class InputExecutor: + """ + Execute input/human nodes. Creates a HumanTask, pauses the run, and raises + PauseForHumanTaskError so the engine returns { paused: true, taskId, runId }. + """ + + def __init__(self, automation2_interface: Any): + self.automation2 = automation2_interface + + async def execute( + self, + node: Dict[str, Any], + context: Dict[str, Any], + ) -> Any: + nodeType = node.get("type", "") + nodeId = node.get("id", "") + runId = context.get("_runId") + workflowId = context.get("workflowId") + instanceId = context.get("instanceId") + userId = context.get("userId") + + if not runId or not workflowId: + logger.error("InputExecutor: runId/workflowId missing in context - cannot create task") + return {"error": "Missing run context", "success": False} + + config = dict(node.get("parameters") or {}) + logger.info("InputExecutor node %s type=%s creating task", nodeId, nodeType) + + task = self.automation2.createTask( + runId=runId, + workflowId=workflowId, + nodeId=nodeId, + nodeType=nodeType, + config=config, + assigneeId=userId, + ) + taskId = task.get("id") + + self.automation2.updateRun( + runId, + status="paused", + nodeOutputs=context.get("nodeOutputs"), + currentNodeId=nodeId, + context={ + "connectionMap": context.get("connectionMap"), + "inputSources": context.get("inputSources"), + "orderedNodeIds": [n.get("id") for n in context.get("_orderedNodes", []) if n.get("id")], + }, + ) + logger.info("InputExecutor node %s: created task %s, run %s paused", nodeId, taskId, runId) + raise PauseForHumanTaskError(runId=runId, taskId=taskId, nodeId=nodeId) diff --git a/modules/workflows/automation2/executors/ioExecutor.py b/modules/workflows/automation2/executors/ioExecutor.py new file mode 100644 index 00000000..eb006c7e --- /dev/null +++ b/modules/workflows/automation2/executors/ioExecutor.py @@ -0,0 +1,69 @@ +# Copyright (c) 2025 Patrick Motsch +# I/O node executor - delegates to ActionExecutor. + +import logging +from typing import Dict, Any + +logger = logging.getLogger(__name__) + + +class IOExecutor: + """Execute I/O nodes by calling ActionExecutor.executeAction(method, action, params).""" + + def __init__(self, services: Any): + self.services = services + + async def execute( + self, + node: Dict[str, Any], + context: Dict[str, Any], + ) -> Any: + from modules.workflows.processing.core.actionExecutor import ActionExecutor + + nodeType = node.get("type", "") + nodeId = node.get("id", "") + logger.info("IOExecutor node %s type=%s", nodeId, nodeType) + if not nodeType.startswith("io."): + logger.debug("IOExecutor node %s not io.* -> None", nodeId) + return None + + parts = nodeType.split(".", 2) + if len(parts) < 3: + logger.debug("IOExecutor node %s invalid type parts -> None", nodeId) + return None + _, methodName, actionName = parts + logger.info("IOExecutor node %s method=%s action=%s", nodeId, methodName, actionName) + + nodeOutputs = context.get("nodeOutputs", {}) + params = dict(node.get("parameters") or {}) + + from modules.workflows.automation2.graphUtils import resolveParameterReferences + resolvedParams = resolveParameterReferences(params, nodeOutputs) + logger.info("IOExecutor node %s resolvedParams keys=%s", nodeId, list(resolvedParams.keys())) + + inputSources = context.get("inputSources", {}).get(nodeId, {}) + if 0 in inputSources: + srcId, _ = inputSources[0] + inp = nodeOutputs.get(srcId) + if isinstance(inp, dict): + resolvedParams.setdefault("documentList", inp.get("documents", inp.get("documentList", []))) + elif inp is not None: + resolvedParams.setdefault("input", inp) + + executor = ActionExecutor(self.services) + logger.info("IOExecutor node %s calling executeAction(%s, %s)", nodeId, methodName, actionName) + result = await executor.executeAction(methodName, actionName, resolvedParams) + out = { + "success": result.success, + "error": result.error, + "documents": [d.model_dump() if hasattr(d, "model_dump") else d for d in (result.documents or [])], + "data": result.model_dump() if hasattr(result, "model_dump") else {"success": result.success, "error": result.error}, + } + logger.info( + "IOExecutor node %s result: success=%s error=%s doc_count=%d", + nodeId, + result.success, + result.error, + len(out.get("documents", [])), + ) + return out diff --git a/modules/workflows/automation2/executors/triggerExecutor.py b/modules/workflows/automation2/executors/triggerExecutor.py new file mode 100644 index 00000000..87ac359e --- /dev/null +++ b/modules/workflows/automation2/executors/triggerExecutor.py @@ -0,0 +1,37 @@ +# Copyright (c) 2025 Patrick Motsch +# Trigger node executor. + +import logging +from typing import Dict, Any + +logger = logging.getLogger(__name__) + + +class TriggerExecutor: + """Execute trigger nodes (manual, schedule, formSubmit).""" + + async def execute( + self, + node: Dict[str, Any], + context: Dict[str, Any], + ) -> Any: + nodeType = node.get("type", "") + nodeId = node.get("id", "") + logger.info("TriggerExecutor node %s type=%s parameters=%s", nodeId, nodeType, node.get("parameters")) + if nodeType == "trigger.manual": + out = {"triggered": True, "source": "manual"} + logger.info("TriggerExecutor node %s -> manual trigger: %s", nodeId, out) + return out + if nodeType == "trigger.schedule": + out = {"triggered": True, "source": "schedule"} + logger.info("TriggerExecutor node %s -> schedule trigger: %s", nodeId, out) + return out + if nodeType == "trigger.formSubmit": + params = node.get("parameters") or {} + formId = params.get("formId", "") + out = {"triggered": True, "source": "formSubmit", "formId": formId} + logger.info("TriggerExecutor node %s -> formSubmit: %s", nodeId, out) + return out + out = {"triggered": True, "source": "unknown"} + logger.info("TriggerExecutor node %s -> unknown: %s", nodeId, out) + return out diff --git a/modules/workflows/automation2/graphUtils.py b/modules/workflows/automation2/graphUtils.py new file mode 100644 index 00000000..ad58c69c --- /dev/null +++ b/modules/workflows/automation2/graphUtils.py @@ -0,0 +1,177 @@ +# Copyright (c) 2025 Patrick Motsch +# Graph parsing, validation, and topological sort for automation2. + +import logging +from typing import Dict, List, Any, Tuple, Set + +logger = logging.getLogger(__name__) + + +def parseGraph(graph: Dict[str, Any]) -> Tuple[List[Dict], List[Dict], Set[str]]: + """ + Parse graph into nodes, connections, and node IDs. + graph: { nodes: [...], connections: [...] } + Returns (nodes, connections, node_ids). + """ + nodes = graph.get("nodes") or [] + connections = graph.get("connections") or [] + nodeIds = {n.get("id") for n in nodes if n.get("id")} + logger.debug( + "parseGraph: nodes=%d connections=%d nodeIds=%s", + len(nodes), + len(connections), + sorted(nodeIds), + ) + return nodes, connections, nodeIds + + +def buildConnectionMap(connections: List[Dict]) -> Dict[str, List[Tuple[str, int, int]]]: + """ + Build map: targetNodeId -> [(sourceNodeId, sourceOutput, targetInput), ...] + connection: { source, sourceOutput?, target, targetInput? } + """ + out: Dict[str, List[Tuple[str, int, int]]] = {} + for i, c in enumerate(connections): + src = c.get("source") or c.get("sourceNode") + tgt = c.get("target") or c.get("targetNode") + if not src or not tgt: + logger.debug("buildConnectionMap skip conn[%d]: missing source/target %r", i, c) + continue + so = c.get("sourceOutput", 0) + ti = c.get("targetInput", 0) + if tgt not in out: + out[tgt] = [] + out[tgt].append((src, so, ti)) + logger.debug("buildConnectionMap conn[%d]: %s -> %s (so=%d ti=%d)", i, src, tgt, so, ti) + logger.debug("buildConnectionMap result: %s", {k: v for k, v in out.items()}) + return out + + +def getInputSources(nodeId: str, connectionMap: Dict[str, List[Tuple[str, int, int]]]) -> Dict[int, Tuple[str, int]]: + """ + For a node, return targetInput -> (sourceNodeId, sourceOutput). + """ + result: Dict[int, Tuple[str, int]] = {} + for src, so, ti in connectionMap.get(nodeId, []): + result[ti] = (src, so) + return result + + +def getTriggerNodes(nodes: List[Dict]) -> List[Dict]: + """Return nodes with category=trigger or type starting with trigger.""" + return [n for n in nodes if (n.get("type", "").startswith("trigger.") or n.get("category") == "trigger")] + + +def validateGraph(graph: Dict[str, Any], nodeTypeIds: Set[str]) -> List[str]: + """ + Validate graph: all node IDs referenced in connections exist, all node types in registry. + Returns list of error messages (empty if valid). + """ + errors = [] + nodes, connections, nodeIds = parseGraph(graph) + + for n in nodes: + nid = n.get("id") + ntype = n.get("type") + if not nid: + errors.append("Node missing id") + continue + if not ntype: + errors.append(f"Node {nid} missing type") + continue + if ntype not in nodeTypeIds: + errors.append(f"Unknown node type '{ntype}' for node {nid}") + + connMap = buildConnectionMap(connections) + allReferred = set() + for tgt, pairs in connMap.items(): + allReferred.add(tgt) + for src, _, _ in pairs: + allReferred.add(src) + for nid in allReferred: + if nid not in nodeIds: + errors.append(f"Connection references non-existent node {nid}") + + if errors: + logger.debug("validateGraph errors: %s", errors) + else: + logger.debug("validateGraph: OK") + return errors + + +def topoSort(nodes: List[Dict], connectionMap: Dict[str, List[Tuple[str, int, int]]]) -> List[Dict]: + """ + Topological sort: start from trigger nodes, then BFS by connections. + Returns ordered list of nodes (trigger first, then downstream). + """ + nodeById = {n["id"]: n for n in nodes if n.get("id")} + triggers = getTriggerNodes(nodes) + if not triggers: + return list(nodes) + + visited: Set[str] = set() + order: List[Dict] = [] + + def bfs(startIds: List[str]) -> None: + from collections import deque + q = deque(startIds) + for nid in startIds: + visited.add(nid) + if nid in nodeById: + order.append(nodeById[nid]) + while q: + nid = q.popleft() + # Find all nodes that receive from nid + for tgt, pairs in connectionMap.items(): + for src, _, _ in pairs: + if src == nid and tgt not in visited: + visited.add(tgt) + q.append(tgt) + if tgt in nodeById: + order.append(nodeById[tgt]) + + triggerIds = [t["id"] for t in triggers] + logger.debug("topoSort triggers: %s", triggerIds) + bfs(triggerIds) + + # Append any orphan nodes (e.g. disconnected) + for n in nodes: + if n.get("id") and n["id"] not in visited: + order.append(n) + logger.debug("topoSort order (%d nodes): %s", len(order), [n.get("id") for n in order]) + return order + + +def resolveParameterReferences(value: Any, nodeOutputs: Dict[str, Any]) -> Any: + """ + Resolve {{nodeId.output}} or {{nodeId.output.path}} in strings/structures. + """ + import json + import re + if isinstance(value, str): + def repl(m): + ref = m.group(1).strip() + parts = ref.split(".") + nodeId = parts[0] + data = nodeOutputs.get(nodeId) + if data is None: + return m.group(0) + if len(parts) < 2: + return json.dumps(data) if isinstance(data, (dict, list)) else str(data) + rest = ".".join(parts[1:]) + if data is None: + return m.group(0) + for k in rest.split("."): + if isinstance(data, dict) and k in data: + data = data[k] + elif isinstance(data, (list, tuple)) and k.isdigit(): + data = data[int(k)] + else: + return m.group(0) + return str(data) if data is not None else m.group(0) + return re.sub(r"\{\{\s*([^}]+)\s*\}\}", repl, value) + if isinstance(value, dict): + return {k: resolveParameterReferences(v, nodeOutputs) for k, v in value.items()} + if isinstance(value, list): + return [resolveParameterReferences(v, nodeOutputs) for v in value] + return value diff --git a/modules/workflows/methods/methodAi/actions/process.py b/modules/workflows/methods/methodAi/actions/process.py index b4157f13..c3cca62b 100644 --- a/modules/workflows/methods/methodAi/actions/process.py +++ b/modules/workflows/methods/methodAi/actions/process.py @@ -1,9 +1,11 @@ # Copyright (c) 2025 Patrick Motsch # All rights reserved. +import base64 import logging import time import json +import uuid from typing import Dict, Any, List, Optional from modules.datamodels.datamodelChat import ActionResult, ActionDocument from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, ProcessingModeEnum @@ -11,6 +13,64 @@ from modules.datamodels.datamodelExtraction import ContentPart logger = logging.getLogger(__name__) + +def _is_action_document_like(obj: Any) -> bool: + """Check if object is ActionDocument-like (has documentData for inline workflow documents).""" + if obj is None: + return False + data = None + if isinstance(obj, dict): + data = obj.get("documentData") or obj.get("document_data") + else: + data = getattr(obj, "documentData", None) or getattr(obj, "document_data", None) + if data is None: + return False + if isinstance(data, bytes): + return len(data) > 0 + if isinstance(data, str): + return len(data.strip()) > 0 + return True + + +def _action_docs_to_content_parts(services, docs: List[Any]) -> List[ContentPart]: + """Extract content from ActionDocument-like objects in memory (no persistence). + Decodes base64, runs extraction pipeline, returns ContentParts for AI. + """ + from modules.datamodels.datamodelExtraction import ExtractionOptions, MergeStrategy + + all_parts = [] + extraction = getattr(services, "extraction", None) + if not extraction: + logger.warning("ai.process: No extraction service - cannot extract from inline documents") + return [] + opts = ExtractionOptions(prompt="", mergeStrategy=MergeStrategy()) + for i, doc in enumerate(docs): + raw = (doc.get("documentData") or doc.get("document_data")) if isinstance(doc, dict) else (getattr(doc, "documentData", None) or getattr(doc, "document_data", None)) + if not raw: + continue + name = doc.get("documentName", doc.get("fileName", f"document_{i}")) + mime = doc.get("mimeType", "application/octet-stream") + if isinstance(raw, str): + try: + content = base64.b64decode(raw, validate=True) + except Exception: + content = raw.encode("utf-8") + else: + content = raw if isinstance(raw, bytes) else bytes(raw) + ec = extraction.extractContentFromBytes( + documentBytes=content, + fileName=name, + mimeType=mime, + documentId=str(uuid.uuid4()), + options=opts, + ) + for p in ec.parts: + if p.data or getattr(p, "typeGroup", "") == "image": + p.metadata.setdefault("originalFileName", name) + all_parts.append(p) + logger.info(f"ai.process: Extracted {len(ec.parts)} parts from {name} (no persistence)") + return all_parts + async def process(self, parameters: Dict[str, Any]) -> ActionResult: operationId = None try: @@ -41,8 +101,21 @@ async def process(self, parameters: Dict[str, Any]) -> ActionResult: from modules.datamodels.datamodelDocref import DocumentReferenceList documentListParam = parameters.get("documentList") - # Convert to DocumentReferenceList if needed - if documentListParam is None: + inline_content_parts: Optional[List[ContentPart]] = None + + # Handle inline ActionDocuments (e.g. from SharePoint/email in automation2 – no persistence) + is_inline = ( + isinstance(documentListParam, list) + and len(documentListParam) > 0 + and _is_action_document_like(documentListParam[0]) + ) + if is_inline: + inline_content_parts = _action_docs_to_content_parts(self.services, documentListParam) + documentList = DocumentReferenceList(references=[]) + logger.info( + f"ai.process: Extracted {len(inline_content_parts)} ContentParts from {len(documentListParam)} inline ActionDocuments (no persistence)" + ) + elif documentListParam is None: documentList = DocumentReferenceList(references=[]) logger.debug(f"ai.process: documentList is None, using empty DocumentReferenceList") elif isinstance(documentListParam, DocumentReferenceList): @@ -54,6 +127,11 @@ async def process(self, parameters: Dict[str, Any]) -> ActionResult: documentList = DocumentReferenceList.from_string_list([documentListParam]) logger.info(f"ai.process: Converted string to DocumentReferenceList with {len(documentList.references)} references") elif isinstance(documentListParam, list): + first = documentListParam[0] if documentListParam else None + logger.info( + f"ai.process: documentList is list of {len(documentListParam)} items, " + f"first type={type(first).__name__}, has_documentData={_is_action_document_like(first) if first else False}" + ) documentList = DocumentReferenceList.from_string_list(documentListParam) logger.info(f"ai.process: Converted list to DocumentReferenceList with {len(documentList.references)} references") else: @@ -65,7 +143,9 @@ async def process(self, parameters: Dict[str, Any]) -> ActionResult: simpleMode = parameters.get("simpleMode", False) if not aiPrompt: - logger.error(f"aiPrompt is missing or empty. Parameters: {parameters}") + param_keys = list(parameters.keys()) if isinstance(parameters, dict) else [] + doc_count = len(parameters.get("documentList") or []) if isinstance(parameters.get("documentList"), (list, tuple)) else 0 + logger.error(f"aiPrompt is missing or empty. Parameter keys: {param_keys}, documentList: {doc_count} item(s)") return ActionResult.isFailure( error="AI prompt is required" ) @@ -86,10 +166,9 @@ async def process(self, parameters: Dict[str, Any]) -> ActionResult: mimeMap = {"txt": "text/plain", "json": "application/json", "html": "text/html", "md": "text/markdown", "csv": "text/csv", "xml": "application/xml"} output_mime_type = mimeMap.get(normalized_result_type, "text/plain") if normalized_result_type else "text/plain" - # Phase 7.3: Pass both documentList and contentParts to AI service - # (Extraction logic removed - handled by AI service) - contentParts: Optional[List[ContentPart]] = None - if "contentParts" in parameters: + # Phase 7.3: Pass documentList and/or contentParts to AI service + contentParts: Optional[List[ContentPart]] = inline_content_parts + if "contentParts" in parameters and not inline_content_parts: contentPartsParam = parameters.get("contentParts") if contentPartsParam: if isinstance(contentPartsParam, list): @@ -203,8 +282,8 @@ async def process(self, parameters: Dict[str, Any]) -> ActionResult: aiResponse = await self.services.ai.callAiContent( prompt=aiPrompt, options=options, - documentList=documentList, # Pass documentList - AI service handles extraction - contentParts=contentParts, # Pass contentParts if provided (or None) + documentList=documentList if not inline_content_parts else None, # Skip if using inline extracted parts + contentParts=contentParts, # From inline ActionDocuments (extracted in memory) or parameters outputFormat=output_format, # Can be None - AI determines from prompt parentOperationId=operationId, generationIntent=generationIntent # REQUIRED for DATA_GENERATE diff --git a/modules/workflows/methods/methodOutlook/actions/composeAndDraftEmailWithContext.py b/modules/workflows/methods/methodOutlook/actions/composeAndDraftEmailWithContext.py index e8bc94b3..09cdd1dd 100644 --- a/modules/workflows/methods/methodOutlook/actions/composeAndDraftEmailWithContext.py +++ b/modules/workflows/methods/methodOutlook/actions/composeAndDraftEmailWithContext.py @@ -20,11 +20,44 @@ async def composeAndDraftEmailWithContext(self, parameters: Dict[str, Any]) -> A bcc = parameters.get("bcc") or [] emailStyle = parameters.get("emailStyle") or "business" maxLength = parameters.get("maxLength") or 1000 - - # Only connectionReference and context are required - to is optional for drafts - if not connectionReference or not context: - return ActionResult.isFailure(error="connectionReference and context are required") - + + # Direct content from upstream (e.g. AI node): skip internal AI, use subject/body/to directly + email_content = parameters.get("emailContent") + if isinstance(email_content, dict): + direct_subject = email_content.get("subject") + direct_body = email_content.get("body") + direct_to = email_content.get("to") + if direct_subject and direct_body: + subject = str(direct_subject).strip() + body = str(direct_body).strip() + to = [direct_to] if isinstance(direct_to, str) else (direct_to or []) + if isinstance(to, str): + to = [to] + ai_attachments = [] + # Jump to create-email section (see below) + else: + direct_subject = parameters.get("subject") + direct_body = parameters.get("body") + if direct_subject and direct_body: + subject = str(direct_subject).strip() + body = str(direct_body).strip() + if isinstance(to, str): + to = [to] + ai_attachments = [] + else: + subject = None + body = None + ai_attachments = None + + use_direct_content = bool(subject and body) + + if not use_direct_content: + # Original path: require connectionReference and context + if not connectionReference or not context: + return ActionResult.isFailure(error="connectionReference and context are required") + elif not connectionReference: + return ActionResult.isFailure(error="connectionReference is required") + # Convert single values to lists for all recipient parameters if isinstance(to, str): to = [to] @@ -45,10 +78,10 @@ async def composeAndDraftEmailWithContext(self, parameters: Dict[str, Any]) -> A if not permissions_ok: return ActionResult.isFailure(error="Connection lacks necessary permissions for Outlook operations") - # Prepare documents for AI processing + # Prepare documents for AI processing (only when using AI path) from modules.datamodels.datamodelDocref import DocumentReferenceList chatDocuments = [] - if documentList: + if not use_direct_content and documentList: # Convert to DocumentReferenceList if needed if isinstance(documentList, DocumentReferenceList): docRefList = documentList @@ -60,33 +93,34 @@ async def composeAndDraftEmailWithContext(self, parameters: Dict[str, Any]) -> A docRefList = DocumentReferenceList(references=[]) chatDocuments = self.services.chat.getChatDocumentsFromDocumentList(docRefList) - # Create AI prompt for email composition - # Build document reference list for AI with expanded list contents when possible - doc_references = documentList - doc_list_text = "" - if doc_references: - lines = ["Available_Document_References:"] - for ref in doc_references: - # Each item is a label: resolve to its document list and render contained items - from modules.datamodels.datamodelDocref import DocumentReferenceList - list_docs = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list([ref])) or [] - if list_docs: - for d in list_docs: - doc_ref_label = self.services.chat.getDocumentReferenceFromChatDocument(d) - lines.append(f"- {doc_ref_label}") - else: - lines.append(" - (no documents)") - doc_list_text = "\n" + "\n".join(lines) - else: - doc_list_text = "Available_Document_References: (No documents available for attachment)" - - # Escape only the user-controlled context to prevent prompt injection - escaped_context = context.replace('"', '\\"').replace('\n', '\\n').replace('\r', '\\r') - - # Build recipients text for prompt - recipients_text = f"Recipients: {to}" if to else "Recipients: (not specified - this is a draft)" - - ai_prompt = f"""Compose an email based on this context: + if not use_direct_content: + # Create AI prompt for email composition + # Build document reference list for AI with expanded list contents when possible + doc_references = documentList + doc_list_text = "" + if doc_references: + lines = ["Available_Document_References:"] + for ref in doc_references: + # Each item is a label: resolve to its document list and render contained items + from modules.datamodels.datamodelDocref import DocumentReferenceList + list_docs = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list([ref])) or [] + if list_docs: + for d in list_docs: + doc_ref_label = self.services.chat.getDocumentReferenceFromChatDocument(d) + lines.append(f"- {doc_ref_label}") + else: + lines.append(" - (no documents)") + doc_list_text = "\n" + "\n".join(lines) + else: + doc_list_text = "Available_Document_References: (No documents available for attachment)" + + # Escape only the user-controlled context to prevent prompt injection + escaped_context = context.replace('"', '\\"').replace('\n', '\\n').replace('\r', '\\r') + + # Build recipients text for prompt + recipients_text = f"Recipients: {to}" if to else "Recipients: (not specified - this is a draft)" + + ai_prompt = f"""Compose an email based on this context: ------- {escaped_context} ------- @@ -107,93 +141,93 @@ Return JSON: "attachments": ["docItem::"] }} """ - - # Call AI service to generate email content - try: - ai_response = await self.services.ai.callAiPlanning( - prompt=ai_prompt, - placeholders=None, - debugType="email_composition" - ) - - # Parse AI response + + # Call AI service to generate email content try: - ai_content = ai_response - # Extract JSON from AI response - if "```json" in ai_content: - json_start = ai_content.find("```json") + 7 - json_end = ai_content.find("```", json_start) - json_content = ai_content[json_start:json_end].strip() - elif "{" in ai_content and "}" in ai_content: - json_start = ai_content.find("{") - json_end = ai_content.rfind("}") + 1 - json_content = ai_content[json_start:json_end] - else: - json_content = ai_content - - email_data = json.loads(json_content) - subject = email_data.get("subject", "") - body = email_data.get("body", "") - ai_attachments = email_data.get("attachments", []) - - if not subject or not body: - return ActionResult.isFailure(error="AI did not generate valid subject and body") - - # Use AI-selected attachments if provided, otherwise use all documents - normalized_ai_attachments = [] - if documentList: - try: - available_refs = [documentList] if isinstance(documentList, str) else documentList - from modules.datamodels.datamodelDocref import DocumentReferenceList - available_docs = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list(available_refs)) or [] - except Exception: - available_docs = [] + ai_response = await self.services.ai.callAiPlanning( + prompt=ai_prompt, + placeholders=None, + debugType="email_composition" + ) - # Normalize AI attachments to a list of strings - if isinstance(ai_attachments, str): - ai_attachments = [ai_attachments] - elif isinstance(ai_attachments, list): - ai_attachments = [a for a in ai_attachments if isinstance(a, str)] + # Parse AI response + try: + ai_content = ai_response + # Extract JSON from AI response + if "```json" in ai_content: + json_start = ai_content.find("```json") + 7 + json_end = ai_content.find("```", json_start) + json_content = ai_content[json_start:json_end].strip() + elif "{" in ai_content and "}" in ai_content: + json_start = ai_content.find("{") + json_end = ai_content.rfind("}") + 1 + json_content = ai_content[json_start:json_end] + else: + json_content = ai_content - if ai_attachments: + email_data = json.loads(json_content) + subject = email_data.get("subject", "") + body = email_data.get("body", "") + ai_attachments = email_data.get("attachments", []) + + if not subject or not body: + return ActionResult.isFailure(error="AI did not generate valid subject and body") + + # Use AI-selected attachments if provided, otherwise use all documents + normalized_ai_attachments = [] + if documentList: try: - ai_refs = [ai_attachments] if isinstance(ai_attachments, str) else ai_attachments + available_refs = [documentList] if isinstance(documentList, str) else documentList from modules.datamodels.datamodelDocref import DocumentReferenceList - ai_docs = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list(ai_refs)) or [] + available_docs = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list(available_refs)) or [] except Exception: - ai_docs = [] + available_docs = [] - # Intersect by document id - available_ids = {getattr(d, 'id', None) for d in available_docs} - selected_docs = [d for d in ai_docs if getattr(d, 'id', None) in available_ids] + # Normalize AI attachments to a list of strings + if isinstance(ai_attachments, str): + ai_attachments = [ai_attachments] + elif isinstance(ai_attachments, list): + ai_attachments = [a for a in ai_attachments if isinstance(a, str)] - if selected_docs: - # Map selected ChatDocuments back to docItem references (with full filename) - documentList = [self.services.chat.getDocumentReferenceFromChatDocument(d) for d in selected_docs] - # Normalize ai_attachments to full format for storage - normalized_ai_attachments = documentList.copy() - logger.info(f"AI selected {len(documentList)} documents for attachment (resolved via ChatDocuments)") + if ai_attachments: + try: + ai_refs = [ai_attachments] if isinstance(ai_attachments, str) else ai_attachments + from modules.datamodels.datamodelDocref import DocumentReferenceList + ai_docs = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list(ai_refs)) or [] + except Exception: + ai_docs = [] + + # Intersect by document id + available_ids = {getattr(d, 'id', None) for d in available_docs} + selected_docs = [d for d in ai_docs if getattr(d, 'id', None) in available_ids] + + if selected_docs: + # Map selected ChatDocuments back to docItem references (with full filename) + documentList = [self.services.chat.getDocumentReferenceFromChatDocument(d) for d in selected_docs] + # Normalize ai_attachments to full format for storage + normalized_ai_attachments = documentList.copy() + logger.info(f"AI selected {len(documentList)} documents for attachment (resolved via ChatDocuments)") + else: + # No intersection; use all available documents + documentList = [self.services.chat.getDocumentReferenceFromChatDocument(d) for d in available_docs] + normalized_ai_attachments = documentList.copy() + logger.warning("AI selected attachments not found in available documents, using all documents") else: - # No intersection; use all available documents + # No AI selection; use all available documents documentList = [self.services.chat.getDocumentReferenceFromChatDocument(d) for d in available_docs] normalized_ai_attachments = documentList.copy() - logger.warning("AI selected attachments not found in available documents, using all documents") + logger.warning("AI did not specify attachments, using all available documents") else: - # No AI selection; use all available documents - documentList = [self.services.chat.getDocumentReferenceFromChatDocument(d) for d in available_docs] - normalized_ai_attachments = documentList.copy() - logger.warning("AI did not specify attachments, using all available documents") - else: - logger.info("No documents provided in documentList; skipping attachment processing") - - except json.JSONDecodeError as e: - logger.error(f"Failed to parse AI response as JSON: {str(e)}") - logger.error(f"AI response content: {ai_response}") - return ActionResult.isFailure(error="AI response was not valid JSON format") - - except Exception as e: - logger.error(f"Error calling AI service: {str(e)}") - return ActionResult.isFailure(error=f"Failed to generate email content: {str(e)}") + logger.info("No documents provided in documentList; skipping attachment processing") + + except json.JSONDecodeError as e: + logger.error(f"Failed to parse AI response as JSON: {str(e)}") + logger.error(f"AI response content: {ai_response}") + return ActionResult.isFailure(error="AI response was not valid JSON format") + + except Exception as e: + logger.error(f"Error calling AI service: {str(e)}") + return ActionResult.isFailure(error=f"Failed to generate email content: {str(e)}") # Now create the email with AI-generated content try: @@ -227,35 +261,78 @@ Return JSON: } # Add documents as attachments if provided + # Supports: 1) inline ActionDocuments (dict with documentData from e.g. sharepoint.downloadFile) + # 2) docItem:... references (chat workflow documents) if documentList: message["attachments"] = [] for attachment_ref in documentList: - # Get attachment document from service center - from modules.datamodels.datamodelDocref import DocumentReferenceList - attachment_docs = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list([attachment_ref])) + base64_content = None + attach_name = "attachment" + attach_mime = "application/octet-stream" + + # Inline document: dict/object with documentData (from automation2 upstream, e.g. sharepoint.downloadFile) + is_inline = isinstance(attachment_ref, dict) and attachment_ref.get("documentData") + if not is_inline and hasattr(attachment_ref, "documentData"): + is_inline = bool(getattr(attachment_ref, "documentData", None)) + if is_inline: + doc = attachment_ref + base64_content = doc.get("documentData") if isinstance(doc, dict) else getattr(doc, "documentData", None) + attach_name = (doc.get("documentName") or doc.get("fileName")) if isinstance(doc, dict) else (getattr(doc, "documentName", None) or getattr(doc, "fileName", "attachment")) + attach_mime = (doc.get("mimeType") or attach_mime) if isinstance(doc, dict) else (getattr(doc, "mimeType", None) or attach_mime) + if base64_content and attach_name: + message["attachments"].append({ + "@odata.type": "#microsoft.graph.fileAttachment", + "name": attach_name, + "contentType": attach_mime, + "contentBytes": base64_content + }) + continue + # fileId in validationMetadata: resolve via getFileData (avoids large base64 in pipeline) + file_id = None + if isinstance(attachment_ref, dict): + vm = attachment_ref.get("validationMetadata") or {} + file_id = vm.get("fileId") + elif hasattr(attachment_ref, "validationMetadata"): + vm = getattr(attachment_ref, "validationMetadata") or {} + file_id = vm.get("fileId") if isinstance(vm, dict) else None + if file_id: + try: + file_content = self.services.chat.getFileData(file_id) + if file_content: + base64_content = base64.b64encode(file_content if isinstance(file_content, bytes) else str(file_content).encode("utf-8")).decode("utf-8") + name = (attachment_ref.get("documentName") or attachment_ref.get("fileName", "attachment")) if isinstance(attachment_ref, dict) else (getattr(attachment_ref, "documentName", None) or getattr(attachment_ref, "fileName", "attachment")) + mime = (attachment_ref.get("mimeType") or attach_mime) if isinstance(attachment_ref, dict) else (getattr(attachment_ref, "mimeType", None) or attach_mime) + message["attachments"].append({ + "@odata.type": "#microsoft.graph.fileAttachment", + "name": name, + "contentType": mime, + "contentBytes": base64_content + }) + continue + except Exception as e: + logger.warning("Could not load file %s for attachment: %s", file_id, e) + + # docItem:... reference (chat workflow) – only when it's a string ref + attachment_docs = [] + if isinstance(attachment_ref, str) and attachment_ref.strip(): + from modules.datamodels.datamodelDocref import DocumentReferenceList + attachment_docs = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list([attachment_ref])) if attachment_docs: for doc in attachment_docs: - file_id = getattr(doc, 'fileId', None) - if file_id: + fid = getattr(doc, 'fileId', None) + if fid: try: - file_content = self.services.chat.getFileData(file_id) + file_content = self.services.chat.getFileData(fid) if file_content: - if isinstance(file_content, bytes): - content_bytes = file_content - else: - content_bytes = str(file_content).encode('utf-8') - - base64_content = base64.b64encode(content_bytes).decode('utf-8') - - attachment = { + cb = file_content if isinstance(file_content, bytes) else str(file_content).encode('utf-8') + message["attachments"].append({ "@odata.type": "#microsoft.graph.fileAttachment", "name": doc.fileName, "contentType": doc.mimeType or "application/octet-stream", - "contentBytes": base64_content - } - message["attachments"].append(attachment) + "contentBytes": base64.b64encode(cb).decode('utf-8') + }) except Exception as e: - logger.error(f"Error reading attachment file {doc.fileName}: {str(e)}") + logger.error("Error reading attachment file %s: %s", doc.fileName, e) # Create the draft message drafts_folder_id = self.folderManagement.getFolderId("Drafts", connection) @@ -297,14 +374,21 @@ Return JSON: attachmentFilenames = [] attachmentReferences = [] if documentList: - try: - from modules.datamodels.datamodelDocref import DocumentReferenceList - attached_docs = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list(documentList)) or [] - attachmentFilenames = [getattr(doc, 'fileName', '') for doc in attached_docs if getattr(doc, 'fileName', None)] - # Store normalized document references (with filenames) - use normalized_ai_attachments if available - attachmentReferences = normalized_ai_attachments if normalized_ai_attachments else [self.services.chat.getDocumentReferenceFromChatDocument(d) for d in attached_docs] - except Exception: - pass + # Inline docs (dict with documentName): use directly + string_refs = [r for r in documentList if isinstance(r, str)] + inline_docs = [r for r in documentList if isinstance(r, dict)] + for d in inline_docs: + name = d.get("documentName") or d.get("fileName") + if name: + attachmentFilenames.append(name) + if string_refs: + try: + from modules.datamodels.datamodelDocref import DocumentReferenceList + attached_docs = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list(string_refs)) or [] + attachmentFilenames.extend(getattr(doc, 'fileName', '') for doc in attached_docs if getattr(doc, 'fileName', None)) + attachmentReferences = normalized_ai_attachments if normalized_ai_attachments else [self.services.chat.getDocumentReferenceFromChatDocument(d) for d in attached_docs] + except Exception: + pass # Create validation metadata for content validator validationMetadata = { diff --git a/modules/workflows/methods/methodOutlook/actions/readEmails.py b/modules/workflows/methods/methodOutlook/actions/readEmails.py index f388f818..5620a62d 100644 --- a/modules/workflows/methods/methodOutlook/actions/readEmails.py +++ b/modules/workflows/methods/methodOutlook/actions/readEmails.py @@ -49,9 +49,9 @@ async def readEmails(self, parameters: Dict[str, Any]) -> ActionResult: if filter: # Remove any potentially dangerous characters that could break the filter filter = filter.strip() - if len(filter) > 100: - logger.warning(f"Filter too long ({len(filter)} chars), truncating to 100 characters") - filter = filter[:100] + if len(filter) > 500: + logger.warning(f"Filter too long ({len(filter)} chars), truncating to 500 characters") + filter = filter[:500] # Get Microsoft connection diff --git a/modules/workflows/methods/methodOutlook/actions/searchEmails.py b/modules/workflows/methods/methodOutlook/actions/searchEmails.py index c7f839b6..f12c6d71 100644 --- a/modules/workflows/methods/methodOutlook/actions/searchEmails.py +++ b/modules/workflows/methods/methodOutlook/actions/searchEmails.py @@ -73,8 +73,12 @@ async def searchEmails(self, parameters: Dict[str, Any]) -> ActionResult: logger.warning(f"Could not find folder ID for '{folder}', using folder name directly") # Build the search API request - api_url = f"{graph_url}/me/messages" params = self.emailProcessing.buildSearchParameters(query, folder_id or folder, limit) + # Use folder-specific URL when we have folder_id and $search - avoids InefficientFilter + if folder_id and params.get("$search"): + api_url = f"{graph_url}/me/mailFolders/{folder_id}/messages" + else: + api_url = f"{graph_url}/me/messages" # Log search parameters for debugging logger.debug(f"Search query: '{query}'") diff --git a/modules/workflows/methods/methodOutlook/helpers/emailProcessing.py b/modules/workflows/methods/methodOutlook/helpers/emailProcessing.py index f1736221..d34bb778 100644 --- a/modules/workflows/methods/methodOutlook/helpers/emailProcessing.py +++ b/modules/workflows/methods/methodOutlook/helpers/emailProcessing.py @@ -53,7 +53,7 @@ class EmailProcessingHelper: # Handle common search operators # Recognize Graph operators including both singular and plural forms for hasAttachments lowered = clean_query.lower() - if any(op in lowered for op in ['from:', 'to:', 'subject:', 'received:', 'hasattachment:', 'hasattachments:']): + if any(op in lowered for op in ['from:', 'to:', 'subject:', 'body:', 'received:', 'hasattachment:', 'hasattachments:']): # This is an advanced search query, return as-is return clean_query @@ -104,7 +104,7 @@ class EmailProcessingHelper: # Check if this is a complex search query with multiple operators # Recognize Graph operators including both singular and plural forms for hasAttachments lowered = clean_query.lower() - if any(op in lowered for op in ['from:', 'to:', 'subject:', 'received:', 'hasattachment:', 'hasattachments:']): + if any(op in lowered for op in ['from:', 'to:', 'subject:', 'body:', 'received:', 'hasattachment:', 'hasattachments:']): # This is an advanced search query, use $search # Microsoft Graph API supports complex search syntax params["$search"] = f'"{clean_query}"' @@ -113,34 +113,20 @@ class EmailProcessingHelper: # We'll need to filter results after the API call # Folder filtering will be done after the API call else: - # Use $filter for basic text search, but keep it simple to avoid "InefficientFilter" error - # Microsoft Graph API has limitations on complex filters + # Use $search (KQL) instead of $filter to avoid "InefficientFilter" - Graph rejects + # contains(subject,x) + parentFolderId + orderby. $search handles subject:query. if len(clean_query) > 50: - # If query is too long, truncate it to avoid complex filter issues clean_query = clean_query[:50] - - - # Use only subject search to keep filter simple - # Handle wildcard queries specially if clean_query == "*" or clean_query == "": - # For wildcard or empty query, don't use contains filter - # Just use folder filter if specified if folder and folder.lower() != "all": params["$filter"] = f"parentFolderId eq '{folder}'" - else: - # No filter needed for wildcard search across all folders - pass + params["$orderby"] = "receivedDateTime desc" else: - params["$filter"] = f"contains(subject,'{clean_query}')" - - # Add folder filter if specified - if folder and folder.lower() != "all": - params["$filter"] = f"{params['$filter']} and parentFolderId eq '{folder}'" - - # Add orderby for basic queries - params["$orderby"] = "receivedDateTime desc" + # Use $search with subject: to avoid InefficientFilter + safe = clean_query.replace('"', '') + params["$search"] = f'"subject:{safe}"' + # Folder filtering done post-API in searchEmails when $search is used - return params def buildGraphFilter(self, filter_text: str) -> Dict[str, str]: @@ -168,7 +154,7 @@ class EmailProcessingHelper: # Handle search queries (from:, to:, subject:, etc.) - check this FIRST # Support both singular and plural forms for hasAttachments lt = filter_text.lower() - if any(lt.startswith(prefix) for prefix in ['from:', 'to:', 'subject:', 'received:', 'hasattachment:', 'hasattachments:']): + if any(lt.startswith(prefix) for prefix in ['from:', 'to:', 'subject:', 'body:', 'received:', 'hasattachment:', 'hasattachments:']): return {"$search": f'"{filter_text}"'} # Handle email address filters (only if it's NOT a search query) diff --git a/modules/workflows/methods/methodOutlook/helpers/folderManagement.py b/modules/workflows/methods/methodOutlook/helpers/folderManagement.py index 1ca7be87..47309a8b 100644 --- a/modules/workflows/methods/methodOutlook/helpers/folderManagement.py +++ b/modules/workflows/methods/methodOutlook/helpers/folderManagement.py @@ -27,10 +27,15 @@ class FolderManagementHelper: def getFolderId(self, folder_name: str, connection: Dict[str, Any]) -> Optional[str]: """ - Get the folder ID for a given folder name - - This is needed for proper filtering when using advanced search queries + Get the folder ID for a given folder name or ID. + Returns the input as-is if it already looks like a Microsoft Graph folder ID. """ + if not folder_name or not str(folder_name).strip(): + return None + # Graph folder IDs are base64-like strings (e.g. AQMk...); return as-is + s = str(folder_name).strip() + if s.startswith("AQMk") and len(s) > 20 and " " not in s: + return s try: graph_url = "https://graph.microsoft.com/v1.0" headers = { diff --git a/modules/workflows/methods/methodOutlook/methodOutlook.py b/modules/workflows/methods/methodOutlook/methodOutlook.py index 8d80cef5..633f396d 100644 --- a/modules/workflows/methods/methodOutlook/methodOutlook.py +++ b/modules/workflows/methods/methodOutlook/methodOutlook.py @@ -157,15 +157,22 @@ class MethodOutlook(MethodBase): name="context", type="str", frontendType=FrontendType.TEXTAREA, - required=True, - description="Detailed context for composing the email" + required=False, + description="Detailed context for AI composition (omit when emailContent provided)" + ), + "emailContent": WorkflowActionParameter( + name="emailContent", + type="dict", + frontendType=FrontendType.HIDDEN, + required=False, + description="Direct subject/body/to from upstream (skips AI composition)" ), "documentList": WorkflowActionParameter( name="documentList", - type="List[str]", + type="List[Any]", frontendType=FrontendType.DOCUMENT_REFERENCE, required=False, - description="Document references for context/attachments" + description="Document references or inline ActionDocuments for attachments" ), "cc": WorkflowActionParameter( name="cc", diff --git a/modules/workflows/methods/methodSharepoint/actions/copyFile.py b/modules/workflows/methods/methodSharepoint/actions/copyFile.py index f149e482..92ce88a2 100644 --- a/modules/workflows/methods/methodSharepoint/actions/copyFile.py +++ b/modules/workflows/methods/methodSharepoint/actions/copyFile.py @@ -14,52 +14,69 @@ async def copyFile(self, parameters: Dict[str, Any]) -> ActionResult: if not connectionReference: return ActionResult.isFailure(error="connectionReference parameter is required") - siteIdParam = parameters.get("siteId") - if not siteIdParam: - return ActionResult.isFailure(error="siteId parameter is required") - - sourceFolder = parameters.get("sourceFolder") - if not sourceFolder: - return ActionResult.isFailure(error="sourceFolder parameter is required") - - sourceFile = parameters.get("sourceFile") - if not sourceFile: - return ActionResult.isFailure(error="sourceFile parameter is required") - - destFolder = parameters.get("destFolder") - if not destFolder: - return ActionResult.isFailure(error="destFolder parameter is required") - - destFile = parameters.get("destFile") - if not destFile: - return ActionResult.isFailure(error="destFile parameter is required") - - # Extract siteId from document if it's a reference - siteId = None - if isinstance(siteIdParam, str): - from modules.datamodels.datamodelDocref import DocumentReferenceList - try: - docList = DocumentReferenceList.from_string_list([siteIdParam]) - chatDocuments = self.services.chat.getChatDocumentsFromDocumentList(docList) - if chatDocuments and len(chatDocuments) > 0: - siteInfoJson = json.loads(chatDocuments[0].documentData) - siteId = siteInfoJson.get("id") - except: - pass - - if not siteId: - siteId = siteIdParam - else: - siteId = siteIdParam - - if not siteId: - return ActionResult.isFailure(error="Could not extract siteId from parameter") - - # Get Microsoft connection + # Set SharePoint access token first – required before siteDiscovery/sharepoint calls connection = self.connection.getMicrosoftConnection(connectionReference) if not connection: return ActionResult.isFailure(error="No valid Microsoft connection found for the provided connection reference") + sourcePath = (parameters.get("sourcePath") or parameters.get("sourcePathQuery") or "").strip() + destPath = (parameters.get("destPath") or parameters.get("destPathQuery") or "").strip() + + siteId = None + sourceFolder = None + sourceFile = None + destFolder = None + destFile = None + + if sourcePath and destPath and sourcePath.startswith("/sites/") and destPath.startswith("/sites/"): + parsedSource = self.services.sharepoint.extractSiteFromStandardPath(sourcePath) + parsedDest = self.services.sharepoint.extractSiteFromStandardPath(destPath) + if parsedSource and parsedDest: + innerSrc = (parsedSource.get("innerPath") or "").strip().rstrip("/") + innerDest = (parsedDest.get("innerPath") or "").strip().rstrip("/") + if innerSrc: + if "/" in innerSrc: + sourceFolder = innerSrc.rsplit("/", 1)[0] + sourceFile = innerSrc.rsplit("/", 1)[-1] + else: + sourceFolder = "" + sourceFile = innerSrc + destFolder = innerDest + destFile = sourceFile + sites, _ = await self.siteDiscovery.resolveSitesFromPathQuery(sourcePath) + if sites: + siteId = sites[0].get("id") + + if not siteId or not sourceFolder or not sourceFile or not destFolder: + siteIdParam = parameters.get("siteId") + sourceFolder = parameters.get("sourceFolder") + sourceFile = parameters.get("sourceFile") + destFolder = parameters.get("destFolder") + destFile = parameters.get("destFile") + if not siteIdParam: + return ActionResult.isFailure(error="Either sourcePath+destPath or siteId, sourceFolder, sourceFile, destFolder, destFile are required") + if not sourceFolder or not sourceFile or not destFolder or not destFile: + return ActionResult.isFailure(error="sourceFolder, sourceFile, destFolder, and destFile are required") + if not destFile: + destFile = sourceFile + if isinstance(siteIdParam, str): + from modules.datamodels.datamodelDocref import DocumentReferenceList + try: + docList = DocumentReferenceList.from_string_list([siteIdParam]) + chatDocuments = self.services.chat.getChatDocumentsFromDocumentList(docList) + if chatDocuments and len(chatDocuments) > 0: + siteInfoJson = json.loads(chatDocuments[0].documentData) + siteId = siteInfoJson.get("id") + except Exception: + pass + if not siteId: + siteId = siteIdParam + else: + siteId = siteIdParam + + if not siteId: + return ActionResult.isFailure(error="Could not resolve siteId") + # Copy file await self.services.sharepoint.copyFileAsync( siteId=siteId, diff --git a/modules/workflows/methods/methodSharepoint/actions/downloadFileByPath.py b/modules/workflows/methods/methodSharepoint/actions/downloadFileByPath.py index c64a6637..447d8c08 100644 --- a/modules/workflows/methods/methodSharepoint/actions/downloadFileByPath.py +++ b/modules/workflows/methods/methodSharepoint/actions/downloadFileByPath.py @@ -16,43 +16,59 @@ async def downloadFileByPath(self, parameters: Dict[str, Any]) -> ActionResult: if not connectionReference: return ActionResult.isFailure(error="connectionReference parameter is required") - siteIdParam = parameters.get("siteId") - if not siteIdParam: - return ActionResult.isFailure(error="siteId parameter is required") - - filePath = parameters.get("filePath") - if not filePath: - return ActionResult.isFailure(error="filePath parameter is required") - - # Extract siteId from document if it's a reference - siteId = None - if isinstance(siteIdParam, str): - # Try to parse from document reference - from modules.datamodels.datamodelDocref import DocumentReferenceList - try: - docList = DocumentReferenceList.from_string_list([siteIdParam]) - chatDocuments = self.services.chat.getChatDocumentsFromDocumentList(docList) - if chatDocuments and len(chatDocuments) > 0: - siteInfoJson = json.loads(chatDocuments[0].documentData) - siteId = siteInfoJson.get("id") - except: - pass - - if not siteId: - # Assume it's the site ID directly - siteId = siteIdParam - else: - siteId = siteIdParam - - if not siteId: - return ActionResult.isFailure(error="Could not extract siteId from parameter") - - # Get Microsoft connection + # Set SharePoint access token first – required before siteDiscovery/sharepoint calls connection = self.connection.getMicrosoftConnection(connectionReference) if not connection: return ActionResult.isFailure(error="No valid Microsoft connection found for the provided connection reference") - # Download file + pathQuery = (parameters.get("pathQuery") or parameters.get("path") or "").strip() + siteIdParam = parameters.get("siteId") + filePath = parameters.get("filePath") + # If filePath looks like full SharePoint path, use as pathQuery fallback + if not pathQuery and filePath and isinstance(filePath, str) and filePath.strip().startswith("/sites/"): + pathQuery = filePath.strip() + + siteId = None + innerPath = None + + # Option 1: pathQuery provided (e.g. /sites/SiteName/Shared Documents/file.pdf) – resolve site and inner path + if pathQuery and pathQuery != "*": + sites, errorMsg = await self.siteDiscovery.resolveSitesFromPathQuery(pathQuery) + if errorMsg: + return ActionResult.isFailure(error=errorMsg) + if not sites: + return ActionResult.isFailure(error="Could not resolve site from pathQuery") + parsedPath = self.services.sharepoint.extractSiteFromStandardPath(pathQuery) + if not parsedPath: + return ActionResult.isFailure(error="pathQuery must be a standard SharePoint path (e.g. /sites/SiteName/Shared Documents/file.pdf)") + innerPath = (parsedPath.get("innerPath") or "").strip() + if not innerPath: + return ActionResult.isFailure(error="pathQuery must include a file path (e.g. /sites/SiteName/Shared Documents/file.pdf)") + siteId = sites[0].get("id") + filePath = innerPath + elif siteIdParam and filePath: + # Option 2: siteId + filePath provided directly + if isinstance(siteIdParam, str): + from modules.datamodels.datamodelDocref import DocumentReferenceList + try: + docList = DocumentReferenceList.from_string_list([siteIdParam]) + chatDocuments = self.services.chat.getChatDocumentsFromDocumentList(docList) + if chatDocuments and len(chatDocuments) > 0: + siteInfoJson = json.loads(chatDocuments[0].documentData) + siteId = siteInfoJson.get("id") + except Exception: + pass + if not siteId: + siteId = siteIdParam + else: + siteId = siteIdParam + else: + return ActionResult.isFailure(error="Either pathQuery (e.g. /sites/SiteName/Shared Documents/file.pdf) or both siteId and filePath are required") + + if not siteId or not filePath: + return ActionResult.isFailure(error="Could not resolve siteId and file path from parameters") + + # Download file (connection/token already set above) fileContent = await self.services.sharepoint.downloadFileByPath( siteId=siteId, filePath=filePath @@ -73,7 +89,19 @@ async def downloadFileByPath(self, parameters: Dict[str, Any]) -> ActionResult: "downloadFileByPath" ) - # Encode as base64 + # Save to user's Files (FileItem + FileData) via interfaceDbComponent – appears in Files UI + fileItem = None + db = getattr(self.services, "interfaceDbComponent", None) + if db: + try: + mimeType = db.getMimeType(filename) if hasattr(db, "getMimeType") else "application/octet-stream" + fileItem = db.createFile(name=filename, mimeType=mimeType, content=fileContent) + db.createFileData(fileItem.id, fileContent) + logger.info(f"Saved SharePoint file to user Files: {filename} (id={fileItem.id})") + except Exception as e: + logger.warning(f"Could not save to user Files: {e}") + + # Encode as base64 for workflow context (AI, data nodes) fileBase64 = base64.b64encode(fileContent).decode('utf-8') validationMetadata = self._createValidationMetadata( @@ -82,6 +110,8 @@ async def downloadFileByPath(self, parameters: Dict[str, Any]) -> ActionResult: filePath=filePath, fileSize=len(fileContent) ) + if fileItem: + validationMetadata["fileId"] = fileItem.id document = ActionDocument( documentName=filename, diff --git a/modules/workflows/methods/methodSharepoint/actions/readDocuments.py b/modules/workflows/methods/methodSharepoint/actions/readDocuments.py index 73cdb730..542ab2e8 100644 --- a/modules/workflows/methods/methodSharepoint/actions/readDocuments.py +++ b/modules/workflows/methods/methodSharepoint/actions/readDocuments.py @@ -244,7 +244,56 @@ async def readDocuments(self, parameters: Dict[str, Any]) -> ActionResult: self.services.chat.progressLogFinish(operationId, False) return ActionResult.isFailure(error="Either documentList must contain findDocumentPath result with file information, or pathQuery must be provided. Use findDocumentPath first to get file paths, or provide pathQuery directly.") - # This should never be reached if logic above is correct + # When we have pathQuery + sites but no sharePointFileIds (e.g. user selected from browse tree), + # download the file by path + if pathQuery and pathQuery.strip() != "" and pathQuery.strip() != "*" and sites and not sharePointFileIds: + parsedPath = self.services.sharepoint.extractSiteFromStandardPath(pathQuery) + if parsedPath: + innerPath = (parsedPath.get("innerPath") or "").strip() + if not innerPath: + if operationId: + self.services.chat.progressLogFinish(operationId, False) + return ActionResult.isFailure(error="pathQuery must include a file path (e.g. /sites/SiteName/Shared Documents/file.pdf)") + siteId = sites[0].get("id") + if not siteId: + if operationId: + self.services.chat.progressLogFinish(operationId, False) + return ActionResult.isFailure(error="Could not resolve site ID from pathQuery") + self.services.chat.progressLogUpdate(operationId, 0.5, f"Reading file from path: {innerPath}") + fileContent = await self.services.sharepoint.downloadFileByPath(siteId=siteId, filePath=innerPath) + if fileContent is None: + if operationId: + self.services.chat.progressLogFinish(operationId, False) + return ActionResult.isFailure(error=f"File not found or could not be downloaded: {innerPath}") + fileName = innerPath.split("/")[-1] if "/" in innerPath else innerPath + mimeType = "application/octet-stream" + if fileName.endswith(".pdf"): + mimeType = "application/pdf" + elif fileName.endswith(".txt"): + mimeType = "text/plain" + elif fileName.endswith(".json"): + mimeType = "application/json" + base64Content = base64.b64encode(fileContent).decode("utf-8") + validationMetadata = { + "actionType": "sharepoint.readDocuments", + "fileName": fileName, + "sharepointFileId": None, + "siteName": sites[0].get("displayName"), + "mimeType": mimeType, + "contentType": "binary", + "size": len(fileContent), + "includeMetadata": includeMetadata + } + actionDoc = ActionDocument( + documentName=fileName, + documentData=base64Content, + mimeType=mimeType, + validationMetadata=validationMetadata + ) + self.services.chat.progressLogUpdate(operationId, 0.9, f"Read 1 document(s)") + self.services.chat.progressLogFinish(operationId, True) + return ActionResult.isSuccess(documents=[actionDoc]) + if operationId: self.services.chat.progressLogFinish(operationId, False) return ActionResult.isFailure(error="Unexpected error: could not process documentList or pathQuery") diff --git a/modules/workflows/methods/methodSharepoint/actions/uploadFile.py b/modules/workflows/methods/methodSharepoint/actions/uploadFile.py index 1f469b80..86d5787d 100644 --- a/modules/workflows/methods/methodSharepoint/actions/uploadFile.py +++ b/modules/workflows/methods/methodSharepoint/actions/uploadFile.py @@ -15,51 +15,84 @@ async def uploadFile(self, parameters: Dict[str, Any]) -> ActionResult: if not connectionReference: return ActionResult.isFailure(error="connectionReference parameter is required") - siteIdParam = parameters.get("siteId") - if not siteIdParam: - return ActionResult.isFailure(error="siteId parameter is required") - - folderPath = parameters.get("folderPath") - if not folderPath: - return ActionResult.isFailure(error="folderPath parameter is required") - - fileName = parameters.get("fileName") - if not fileName: - return ActionResult.isFailure(error="fileName parameter is required") - contentParam = parameters.get("content") if not contentParam: return ActionResult.isFailure(error="content parameter is required") - # Extract siteId from document if it's a reference + # Resolve siteId and folderPath: pathQuery (path) or explicit siteId+folderPath + pathQuery = (parameters.get("pathQuery") or parameters.get("path") or "").strip() + siteIdParam = parameters.get("siteId") + folderPath = parameters.get("folderPath") siteId = None - if isinstance(siteIdParam, str): - from modules.datamodels.datamodelDocref import DocumentReferenceList - try: - docList = DocumentReferenceList.from_string_list([siteIdParam]) - chatDocuments = self.services.chat.getChatDocumentsFromDocumentList(docList) - if chatDocuments and len(chatDocuments) > 0: - siteInfoJson = json.loads(chatDocuments[0].documentData) - siteId = siteInfoJson.get("id") - except: - pass - - if not siteId: + + if pathQuery and pathQuery != "*": + # Option 1: pathQuery (e.g. /sites/host,siteId,webId/15. Persoenliche Ordner/Ida Dittrich/) + sites, errorMsg = await self.siteDiscovery.resolveSitesFromPathQuery(pathQuery) + if errorMsg: + return ActionResult.isFailure(error=errorMsg) + if not sites: + return ActionResult.isFailure(error="Could not resolve site from path") + parsedPath = self.services.sharepoint.extractSiteFromStandardPath(pathQuery) + if not parsedPath: + return ActionResult.isFailure(error="path must be a standard SharePoint path (e.g. /sites/SiteName/Shared Documents/Folder)") + innerPath = (parsedPath.get("innerPath") or "").strip().rstrip("/") + siteId = sites[0].get("id") + folderPath = innerPath + elif siteIdParam and folderPath: + # Option 2: explicit siteId + folderPath + if isinstance(siteIdParam, str): + from modules.datamodels.datamodelDocref import DocumentReferenceList + try: + docList = DocumentReferenceList.from_string_list([siteIdParam]) + chatDocuments = self.services.chat.getChatDocumentsFromDocumentList(docList) + if chatDocuments and len(chatDocuments) > 0: + siteInfoJson = json.loads(chatDocuments[0].documentData) + siteId = siteInfoJson.get("id") + except Exception: + pass + if not siteId: + siteId = siteIdParam + else: siteId = siteIdParam else: - siteId = siteIdParam + return ActionResult.isFailure(error="Either path (e.g. /sites/.../Folder) or both siteId and folderPath are required") if not siteId: - return ActionResult.isFailure(error="Could not extract siteId from parameter") + return ActionResult.isFailure(error="Could not resolve siteId") - # Get file content from document - from modules.datamodels.datamodelDocref import DocumentReferenceList - docList = DocumentReferenceList.from_string_list([contentParam] if isinstance(contentParam, str) else contentParam) - chatDocuments = self.services.chat.getChatDocumentsFromDocumentList(docList) - if not chatDocuments or len(chatDocuments) == 0: - return ActionResult.isFailure(error="Could not get file content from document reference") + # fileName: from param or from content document + fileName = parameters.get("fileName") + if not fileName and contentParam: + content = contentParam[0] if isinstance(contentParam, (list, tuple)) and contentParam else contentParam + if isinstance(content, dict): + fileName = content.get("documentName") or content.get("fileName") + elif hasattr(content, "documentName"): + fileName = getattr(content, "documentName", None) or getattr(content, "fileName", None) + if not fileName: + fileName = "file" - fileContentBase64 = chatDocuments[0].documentData + # Get file content: support inline ActionDocument (from automation2 e.g. sharepoint.downloadFile) + # or docItem references (chat workflow) + content = contentParam[0] if isinstance(contentParam, (list, tuple)) and contentParam else contentParam + fileContentBase64 = None + if isinstance(content, dict) and content.get("documentData"): + fileContentBase64 = content.get("documentData") + elif hasattr(content, "documentData") and content.documentData: + fileContentBase64 = content.documentData + elif isinstance(content, dict) and (content.get("validationMetadata") or {}).get("fileId"): + file_id = content["validationMetadata"]["fileId"] + try: + raw = self.services.chat.getFileData(file_id) + fileContentBase64 = base64.b64encode(raw if isinstance(raw, bytes) else str(raw).encode("utf-8")).decode("utf-8") + except Exception as e: + return ActionResult.isFailure(error=f"Could not load file content from fileId {file_id}: {e}") + if not fileContentBase64: + from modules.datamodels.datamodelDocref import DocumentReferenceList + docList = DocumentReferenceList.from_string_list([content] if isinstance(content, str) else content) + chatDocuments = self.services.chat.getChatDocumentsFromDocumentList(docList) + if not chatDocuments or len(chatDocuments) == 0: + return ActionResult.isFailure(error="Could not get file content from document reference") + fileContentBase64 = chatDocuments[0].documentData # Decode base64 try: diff --git a/modules/workflows/methods/methodSharepoint/methodSharepoint.py b/modules/workflows/methods/methodSharepoint/methodSharepoint.py index 6ee07736..0fa0aca8 100644 --- a/modules/workflows/methods/methodSharepoint/methodSharepoint.py +++ b/modules/workflows/methods/methodSharepoint/methodSharepoint.py @@ -274,19 +274,26 @@ class MethodSharepoint(MethodBase): required=True, description="Microsoft connection label" ), + "pathQuery": WorkflowActionParameter( + name="pathQuery", + type="str", + frontendType=FrontendType.TEXT, + required=False, + description="Full SharePoint path (e.g. /sites/SiteName/Shared Documents/file.pdf). When provided, siteId and filePath are derived automatically." + ), "siteId": WorkflowActionParameter( name="siteId", type="str", frontendType=FrontendType.TEXT, - required=True, - description="SharePoint site ID (from findSiteByUrl result) or document reference containing site info" + required=False, + description="SharePoint site ID (optional when pathQuery is provided)" ), "filePath": WorkflowActionParameter( name="filePath", type="str", frontendType=FrontendType.TEXT, - required=True, - description="Full file path relative to site root (e.g., /General/50 Docs hosted by SELISE/file.xlsx)" + required=False, + description="File path relative to site root (optional when pathQuery is provided)" ) }, execute=downloadFileByPath.__get__(self, self.__class__) @@ -303,40 +310,19 @@ class MethodSharepoint(MethodBase): required=True, description="Microsoft connection label" ), - "siteId": WorkflowActionParameter( - name="siteId", + "sourcePath": WorkflowActionParameter( + name="sourcePath", type="str", frontendType=FrontendType.TEXT, required=True, - description="SharePoint site ID (from findSiteByUrl result) or document reference containing site info" + description="Full path to source file (e.g. /sites/.../folder/file.pdf)" ), - "sourceFolder": WorkflowActionParameter( - name="sourceFolder", + "destPath": WorkflowActionParameter( + name="destPath", type="str", frontendType=FrontendType.TEXT, required=True, - description="Source folder path relative to site root" - ), - "sourceFile": WorkflowActionParameter( - name="sourceFile", - type="str", - frontendType=FrontendType.TEXT, - required=True, - description="Source file name" - ), - "destFolder": WorkflowActionParameter( - name="destFolder", - type="str", - frontendType=FrontendType.TEXT, - required=True, - description="Destination folder path relative to site root" - ), - "destFile": WorkflowActionParameter( - name="destFile", - type="str", - frontendType=FrontendType.TEXT, - required=True, - description="Destination file name" + description="Full path to destination folder (e.g. /sites/.../folder)" ) }, execute=copyFile.__get__(self, self.__class__) @@ -353,33 +339,40 @@ class MethodSharepoint(MethodBase): required=True, description="Microsoft connection label" ), + "pathQuery": WorkflowActionParameter( + name="pathQuery", + type="str", + frontendType=FrontendType.TEXT, + required=False, + description="Target folder path (e.g. /sites/.../Folder). When provided, siteId and folderPath are derived. Alternative to explicit siteId+folderPath." + ), "siteId": WorkflowActionParameter( name="siteId", type="str", frontendType=FrontendType.TEXT, - required=True, - description="SharePoint site ID (from findSiteByUrl result) or document reference containing site info" + required=False, + description="SharePoint site ID (when not using pathQuery)" ), "folderPath": WorkflowActionParameter( name="folderPath", type="str", frontendType=FrontendType.TEXT, - required=True, - description="Folder path relative to site root" + required=False, + description="Folder path relative to site root (when not using pathQuery)" ), "fileName": WorkflowActionParameter( name="fileName", type="str", frontendType=FrontendType.TEXT, - required=True, - description="File name" + required=False, + description="File name (defaults to document name when content from context)" ), "content": WorkflowActionParameter( name="content", - type="str", + type="Any", frontendType=FrontendType.DOCUMENT_REFERENCE, required=True, - description="Document reference containing file content as base64-encoded bytes" + description="File content from context (upstream document) or document reference" ) }, execute=uploadFile.__get__(self, self.__class__) diff --git a/modules/workflows/processing/core/actionExecutor.py b/modules/workflows/processing/core/actionExecutor.py index 6b1e3544..92813213 100644 --- a/modules/workflows/processing/core/actionExecutor.py +++ b/modules/workflows/processing/core/actionExecutor.py @@ -69,15 +69,14 @@ class ActionExecutor: logger.info(f"=== TASK {taskNum} ACTION {actionNum}: {action.execMethod}.{action.execAction} ===") - # Log input parameters + # Log input parameters (redact documentData to avoid dumping base64 in logs) inputDocs = action.execParameters.get('documentList', []) inputConnections = action.execParameters.get('connections', []) - logger.info(f"Input documents: {inputDocs} (type: {type(inputDocs)})") + doc_preview = f"{len(inputDocs)} item(s)" if isinstance(inputDocs, (list, tuple)) else str(type(inputDocs)) + logger.info(f"Input documents: {doc_preview} (type: {type(inputDocs).__name__})") if inputConnections: logger.info(f"Input connections: {inputConnections}") - - # Log all action parameters for debugging - logger.info(f"All action parameters: {action.execParameters}") + logger.debug(f"Action parameters keys: {list(action.execParameters.keys())}") enhancedParameters = action.execParameters.copy() if action.expectedDocumentFormats: diff --git a/requirements.lock b/requirements.lock index 303535a4..d55fc5f7 100644 --- a/requirements.lock +++ b/requirements.lock @@ -614,3 +614,6 @@ zipp==3.23.0 # via importlib-metadata zstandard==0.25.0 # via langsmith + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/requirements.txt b/requirements.txt index cf6fe53a..cb1dd467 100644 --- a/requirements.txt +++ b/requirements.txt @@ -56,6 +56,7 @@ Pillow>=10.0.0 # Für Bildverarbeitung (als PIL importiert) # Audio format conversion handled by pure Python implementation ## Utilities & Timezone Support +setuptools>=65.0.0,<82.0.0 # Provides pkg_resources (removed in 82+); required by google-cloud-translate python-dateutil==2.8.2 python-dotenv==1.0.0 pytz>=2023.3 # For timezone handling and UTC operations diff --git a/scripts/.$import_diagram.drawio.bkp b/scripts/.$import_diagram.drawio.bkp deleted file mode 100644 index 588285b0..00000000 --- a/scripts/.$import_diagram.drawio.bkp +++ /dev/null @@ -1,2723 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/scripts/.$import_diagram_containers.drawio.bkp b/scripts/.$import_diagram_containers.drawio.bkp deleted file mode 100644 index cad95eef..00000000 --- a/scripts/.$import_diagram_containers.drawio.bkp +++ /dev/null @@ -1,317 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/scripts/script_db_init_automation2.py b/scripts/script_db_init_automation2.py new file mode 100644 index 00000000..56d0daaf --- /dev/null +++ b/scripts/script_db_init_automation2.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +""" +Initialize poweron_automation2 database for the Automation2 feature. + +Creates the poweron_automation2 database if it does not exist. +Uses DB_* config. Tables (Automation2Workflow, Automation2WorkflowRun, +Automation2HumanTask) are auto-created by the connector on first use. + +Usage: + python scripts/script_db_init_automation2.py [--dry-run] +""" + +import os +import sys +import argparse +import logging +from pathlib import Path + +scriptPath = Path(__file__).resolve() +gatewayPath = scriptPath.parent.parent +sys.path.insert(0, str(gatewayPath)) +os.chdir(str(gatewayPath)) + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + +import psycopg2 +from modules.shared.configuration import APP_CONFIG + +DB_NAME = "poweron_automation2" + + +def _get_config(): + """Get DB config from APP_CONFIG.""" + host = APP_CONFIG.get("DB_HOST", "localhost") + port = int(APP_CONFIG.get("DB_PORT", "5432")) + user = APP_CONFIG.get("DB_USER") + password = ( + APP_CONFIG.get("DB_PASSWORD_SECRET") or APP_CONFIG.get("DB_PASSWORD") + ) + return {"host": host, "port": port, "user": user, "password": password} + + +def init_automation2_db(dry_run: bool = False) -> bool: + """Create poweron_automation2 database if it does not exist.""" + config = _get_config() + if not config["user"] or not config["password"]: + logger.error("DB_USER and DB_PASSWORD required") + return False + + try: + conn = psycopg2.connect( + host=config["host"], + port=config["port"], + database="postgres", + user=config["user"], + password=config["password"], + ) + conn.autocommit = True + + with conn.cursor() as cur: + cur.execute( + "SELECT 1 FROM pg_database WHERE datname = %s", + (DB_NAME,), + ) + exists = cur.fetchone() is not None + + if exists: + logger.info("Database %s already exists", DB_NAME) + else: + if dry_run: + logger.info("[DRY-RUN] Would create database %s", DB_NAME) + else: + cur.execute(f'CREATE DATABASE "{DB_NAME}"') + logger.info("Created database %s", DB_NAME) + + conn.close() + return True + except Exception as e: + logger.error("Failed to init %s: %s", DB_NAME, e) + return False + + +def main(): + parser = argparse.ArgumentParser(description="Initialize poweron_automation2 database") + parser.add_argument("--dry-run", action="store_true", help="Do not create, only report") + args = parser.parse_args() + ok = init_automation2_db(dry_run=args.dry_run) + sys.exit(0 if ok else 1) + + +if __name__ == "__main__": + main() diff --git a/tests/unit/rbac/test_rbac_bootstrap.py b/tests/unit/rbac/test_rbac_bootstrap.py index e8b04f07..0e69b802 100644 --- a/tests/unit/rbac/test_rbac_bootstrap.py +++ b/tests/unit/rbac/test_rbac_bootstrap.py @@ -5,165 +5,155 @@ Unit tests for RBAC bootstrap initialization. Tests that bootstrap creates correct rules and initial data. """ -import pytest -from unittest.mock import Mock, MagicMock, patch +from unittest.mock import Mock, patch + from modules.interfaces.interfaceBootstrap import ( - initBootstrap, initRootMandate, initAdminUser, initEventUser, initRbacRules, - createDefaultRoleRules, - createTableSpecificRules + _createDefaultRoleRules, + _createTableSpecificRules, ) -from modules.datamodels.datamodelUam import UserInDB, Mandate, AuthAuthority +from modules.datamodels.datamodelUam import UserInDB, Mandate from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext from modules.datamodels.datamodelUam import AccessLevel class TestRbacBootstrap: """Test RBAC bootstrap initialization.""" - + def testInitRootMandateCreatesIfNotExists(self): """Test that initRootMandate creates mandate if it doesn't exist.""" db = Mock() - db.getRecordset = Mock(return_value=[]) # No existing mandates + db.getRecordset = Mock(return_value=[]) db.recordCreate = Mock(return_value={"id": "mandate1", "name": "Root"}) - + mandateId = initRootMandate(db) - + assert mandateId == "mandate1" db.recordCreate.assert_called_once() callArgs = db.recordCreate.call_args assert isinstance(callArgs[0][1], Mandate) - assert callArgs[0][1].name == "Root" - + assert callArgs[0][1].name == "root" + assert callArgs[0][1].label == "Root" + def testInitRootMandateReturnsExisting(self): """Test that initRootMandate returns existing mandate ID.""" db = Mock() db.getRecordset = Mock(return_value=[{"id": "existing_mandate"}]) - + mandateId = initRootMandate(db) - + assert mandateId == "existing_mandate" db.recordCreate.assert_not_called() - + def testInitAdminUserCreatesWithSysadminRole(self): - """Test that initAdminUser creates user with sysadmin role.""" + """Test that initAdminUser creates user with isSysAdmin=True.""" db = Mock() - db.getRecordset = Mock(return_value=[]) # No existing users + db.getRecordset = Mock(return_value=[]) db.recordCreate = Mock(return_value={"id": "admin1", "username": "admin"}) - - with patch('modules.interfaces.interfaceBootstrap._getPasswordHash', return_value="hashed"): + + with patch("modules.interfaces.interfaceBootstrap._getPasswordHash", return_value="hashed"): userId = initAdminUser(db, "mandate1") - + assert userId == "admin1" db.recordCreate.assert_called_once() callArgs = db.recordCreate.call_args user = callArgs[0][1] assert isinstance(user, UserInDB) assert user.username == "admin" - assert "sysadmin" in user.roleLabels - + assert user.isSysAdmin is True + def testInitEventUserCreatesWithSysadminRole(self): - """Test that initEventUser creates user with sysadmin role.""" + """Test that initEventUser creates user with isSysAdmin=True.""" db = Mock() - db.getRecordset = Mock(return_value=[]) # No existing users + db.getRecordset = Mock(return_value=[]) db.recordCreate = Mock(return_value={"id": "event1", "username": "event"}) - - with patch('modules.interfaces.interfaceBootstrap._getPasswordHash', return_value="hashed"): + + with patch("modules.interfaces.interfaceBootstrap._getPasswordHash", return_value="hashed"): userId = initEventUser(db, "mandate1") - + assert userId == "event1" db.recordCreate.assert_called_once() callArgs = db.recordCreate.call_args user = callArgs[0][1] assert isinstance(user, UserInDB) assert user.username == "event" - assert "sysadmin" in user.roleLabels - + assert user.isSysAdmin is True + def testCreateDefaultRoleRules(self): - """Test that createDefaultRoleRules creates correct default rules.""" + """Test that _createDefaultRoleRules creates admin + viewer generic DATA rules.""" db = Mock() db.recordCreate = Mock() - - createDefaultRoleRules(db) - - # Should create 4 default rules (sysadmin, admin, user, viewer) - assert db.recordCreate.call_count == 4 - - # Check sysadmin rule - sysadminCall = [call for call in db.recordCreate.call_args_list - if call[0][1].roleLabel == "sysadmin"][0] - sysadminRule = sysadminCall[0][1] - assert sysadminRule.context == AccessRuleContext.DATA - assert sysadminRule.item is None - assert sysadminRule.view == True - assert sysadminRule.read == AccessLevel.ALL - assert sysadminRule.create == AccessLevel.ALL - - # Check user rule - userCall = [call for call in db.recordCreate.call_args_list - if call[0][1].roleLabel == "user"][0] - userRule = userCall[0][1] - assert userRule.read == AccessLevel.MY - assert userRule.create == AccessLevel.MY - + + with patch( + "modules.interfaces.interfaceBootstrap._getRoleId", + side_effect=lambda d, label: f"rid-{label}", + ): + _createDefaultRoleRules(db) + + assert db.recordCreate.call_count == 2 + created = [c[0][1] for c in db.recordCreate.call_args_list] + byRoleId = {r.roleId: r for r in created} + + adminRule = byRoleId["rid-admin"] + assert adminRule.context == AccessRuleContext.DATA + assert adminRule.item is None + assert adminRule.view is True + assert adminRule.read == AccessLevel.GROUP + assert adminRule.create == AccessLevel.GROUP + + viewerRule = byRoleId["rid-viewer"] + assert viewerRule.read == AccessLevel.GROUP + assert viewerRule.create == AccessLevel.NONE + def testCreateTableSpecificRules(self): - """Test that createTableSpecificRules creates table-specific rules.""" + """Test that _createTableSpecificRules creates table-specific rules.""" db = Mock() db.recordCreate = Mock() - - createTableSpecificRules(db) - - # Should create multiple rules for different tables + + with patch( + "modules.interfaces.interfaceBootstrap._getRoleId", + side_effect=lambda d, label: f"rid-{label}", + ): + _createTableSpecificRules(db) + assert db.recordCreate.call_count > 0 - - # Check that Mandate table rules are created with full objectKey (UAM namespace) - mandateCalls = [call for call in db.recordCreate.call_args_list - if call[0][1].item == "data.uam.Mandate"] + + mandateCalls = [ + call for call in db.recordCreate.call_args_list if call[0][1].item == "data.uam.Mandate" + ] assert len(mandateCalls) > 0 - - # Check that all roles have view=False and no access for Mandate - # (SysAdmin bypasses RBAC via isSysAdmin flag, not via roles) + for call in mandateCalls: rule = call[0][1] - assert rule.view == False + assert rule.view is False assert rule.read == AccessLevel.NONE - + def testInitRbacRulesSkipsIfExists(self): - """Test that initRbacRules skips default rule creation if rules already exist, but adds missing table-specific rules.""" + """When AccessRules already exist, init skips full init; ensure-* hooks are not exercised here.""" db = Mock() - # Mock existing rules - include rules for ChatWorkflow and AutomationDefinition to prevent adding missing rules - # Need rules for all required roles to fully prevent creation - # Using semantic namespace format: data.chat.{TableName}, data.automation.{TableName} - existingRules = [] - for table in ["data.chat.ChatWorkflow", "data.automation.AutomationDefinition"]: - for role in ["admin", "user", "viewer"]: - existingRules.append({ - "id": f"rule_{table}_{role}", - "item": table, - "context": AccessRuleContext.DATA.value, - "roleLabel": role - }) - db.getRecordset = Mock(return_value=existingRules) + db.getRecordset = Mock(return_value=[{"id": "existing_rule"}]) db.recordCreate = Mock() - - initRbacRules(db) - - # Should not create new rules since all required tables already have rules for all roles + + with patch("modules.interfaces.interfaceBootstrap._ensureUiContextRules"): + with patch("modules.interfaces.interfaceBootstrap._ensureDataContextRules"): + initRbacRules(db) + db.recordCreate.assert_not_called() - + def testInitRbacRulesCreatesIfNotExists(self): - """Test that initRbacRules creates rules if they don't exist.""" + """Test that initRbacRules creates rules when the AccessRule table is empty.""" db = Mock() - db.getRecordset = Mock(side_effect=[ - [], # No existing rules - [] # After creating default rules - ]) db.recordCreate = Mock() - - initRbacRules(db) - - # Should create rules + db.recordModify = Mock() + db.getRecordset = Mock(return_value=[]) + + with patch( + "modules.interfaces.interfaceBootstrap._getRoleId", + side_effect=lambda d, label: f"rid-{label}", + ): + initRbacRules(db) + assert db.recordCreate.call_count > 0 diff --git a/tests/unit/rbac/test_rbac_permissions.py b/tests/unit/rbac/test_rbac_permissions.py index b40bebe3..49458367 100644 --- a/tests/unit/rbac/test_rbac_permissions.py +++ b/tests/unit/rbac/test_rbac_permissions.py @@ -5,411 +5,397 @@ Unit tests for RBAC permission resolution. Tests rule specificity, multiple roles, and permission combination logic. """ -import pytest -from modules.datamodels.datamodelUam import User, AccessLevel, UserPermissions +from unittest.mock import Mock + +from modules.datamodels.datamodelUam import User, AccessLevel from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext from modules.security.rbac import RbacClass from modules.connectors.connectorDbPostgre import DatabaseConnector -from unittest.mock import Mock, MagicMock + +_TEST_ROLE_USER = "test-rid-user" +_TEST_ROLE_VIEWER = "test-rid-viewer" + + +def _patchRbacResolution(rbac, roleIds, rulesWithPriority): + """Stub multi-tenant role loading and rule fetch so tests exercise merge logic only.""" + rbac._getRoleIdsForUser = Mock(return_value=roleIds) + rbac._getRulesForRoleIds = Mock(return_value=rulesWithPriority) class TestRbacPermissionResolution: """Test RBAC permission resolution logic.""" - + def testSingleRoleGenericRule(self): """Test permission resolution with a single role and generic rule.""" - # Mock database connector db = Mock(spec=DatabaseConnector) dbApp = Mock(spec=DatabaseConnector) - - # Create RBAC interface rbac = RbacClass(db, dbApp=dbApp) - - # Create user with single role + user = User( id="user1", username="testuser", roleLabels=["user"], - mandateId="mandate1" + mandateId="mandate1", ) - - # Mock rules for "user" role - def mockGetRulesForRole(roleLabel, context): - if roleLabel == "user" and context == AccessRuleContext.DATA: - return [ - AccessRule( - roleLabel="user", - context=AccessRuleContext.DATA, - item=None, # Generic rule - view=True, - read=AccessLevel.MY, - create=AccessLevel.MY, - update=AccessLevel.MY, - delete=AccessLevel.MY - ) - ] - return [] - - rbac._getRulesForRole = mockGetRulesForRole - - # Get permissions for generic table + + rules = [ + ( + 1, + AccessRule( + roleId=_TEST_ROLE_USER, + context=AccessRuleContext.DATA, + item=None, + view=True, + read=AccessLevel.MY, + create=AccessLevel.MY, + update=AccessLevel.MY, + delete=AccessLevel.MY, + ), + ) + ] + _patchRbacResolution(rbac, [_TEST_ROLE_USER], rules) + permissions = rbac.getUserPermissions( user, AccessRuleContext.DATA, - "SomeTable" + "SomeTable", ) - - assert permissions.view == True + + assert permissions.view is True assert permissions.read == AccessLevel.MY assert permissions.create == AccessLevel.MY assert permissions.update == AccessLevel.MY assert permissions.delete == AccessLevel.MY - + def testRuleSpecificityMostSpecificWins(self): """Test that most specific rule wins within a single role.""" db = Mock(spec=DatabaseConnector) dbApp = Mock(spec=DatabaseConnector) rbac = RbacClass(db, dbApp=dbApp) - + user = User( id="user1", username="testuser", roleLabels=["user"], - mandateId="mandate1" + mandateId="mandate1", ) - - def mockGetRulesForRole(roleLabel, context): - if roleLabel == "user" and context == AccessRuleContext.DATA: - return [ - AccessRule( - roleLabel="user", - context=AccessRuleContext.DATA, - item=None, # Generic rule - view=True, - read=AccessLevel.GROUP, - create=AccessLevel.GROUP, - update=AccessLevel.GROUP, - delete=AccessLevel.GROUP - ), - AccessRule( - roleLabel="user", - context=AccessRuleContext.DATA, - item="data.uam.UserInDB", # Specific rule with UAM namespace - view=True, - read=AccessLevel.MY, - create=AccessLevel.NONE, - update=AccessLevel.MY, - delete=AccessLevel.NONE - ) - ] - return [] - - rbac._getRulesForRole = mockGetRulesForRole - - # Get permissions for UserInDB table - should use specific rule - # Using UAM namespace: data.uam.UserInDB + + rules = [ + ( + 1, + AccessRule( + roleId=_TEST_ROLE_USER, + context=AccessRuleContext.DATA, + item=None, + view=True, + read=AccessLevel.GROUP, + create=AccessLevel.GROUP, + update=AccessLevel.GROUP, + delete=AccessLevel.GROUP, + ), + ), + ( + 1, + AccessRule( + roleId=_TEST_ROLE_USER, + context=AccessRuleContext.DATA, + item="data.uam.UserInDB", + view=True, + read=AccessLevel.MY, + create=AccessLevel.NONE, + update=AccessLevel.MY, + delete=AccessLevel.NONE, + ), + ), + ] + _patchRbacResolution(rbac, [_TEST_ROLE_USER], rules) + permissions = rbac.getUserPermissions( user, AccessRuleContext.DATA, - "data.uam.UserInDB" + "data.uam.UserInDB", ) - - # Most specific rule should win + assert permissions.read == AccessLevel.MY assert permissions.create == AccessLevel.NONE assert permissions.update == AccessLevel.MY assert permissions.delete == AccessLevel.NONE - + def testMultipleRolesUnionLogic(self): - """Test that multiple roles use union (opening) logic.""" + """Test that multiple roles use union (opening) logic for view.""" db = Mock(spec=DatabaseConnector) dbApp = Mock(spec=DatabaseConnector) rbac = RbacClass(db, dbApp=dbApp) - - # User with multiple roles + user = User( id="user1", username="testuser", roleLabels=["user", "viewer"], - mandateId="mandate1" + mandateId="mandate1", ) - - def mockGetRulesForRole(roleLabel, context): - if context == AccessRuleContext.UI: - if roleLabel == "user": - return [ - AccessRule( - roleLabel="user", - context=AccessRuleContext.UI, - item="playground", - view=False # User role hides playground - ) - ] - elif roleLabel == "viewer": - return [ - AccessRule( - roleLabel="viewer", - context=AccessRuleContext.UI, - item="playground", - view=True # Viewer role shows playground - ) - ] - return [] - - rbac._getRulesForRole = mockGetRulesForRole - - # Get permissions - union logic should make playground visible + + rules = [ + ( + 1, + AccessRule( + roleId=_TEST_ROLE_USER, + context=AccessRuleContext.UI, + item="playground", + view=False, + ), + ), + ( + 1, + AccessRule( + roleId=_TEST_ROLE_VIEWER, + context=AccessRuleContext.UI, + item="playground", + view=True, + ), + ), + ] + _patchRbacResolution(rbac, [_TEST_ROLE_USER, _TEST_ROLE_VIEWER], rules) + permissions = rbac.getUserPermissions( user, AccessRuleContext.UI, - "playground" + "playground", ) - - # Union logic: if ANY role has view=true, then view=true - assert permissions.view == True - + + assert permissions.view is True + def testViewFalseOverridesGeneric(self): """Test that specific view=false overrides generic view=true.""" db = Mock(spec=DatabaseConnector) dbApp = Mock(spec=DatabaseConnector) rbac = RbacClass(db, dbApp=dbApp) - + user = User( id="user1", username="testuser", roleLabels=["user"], - mandateId="mandate1" + mandateId="mandate1", ) - - def mockGetRulesForRole(roleLabel, context): - if roleLabel == "user" and context == AccessRuleContext.UI: - return [ - AccessRule( - roleLabel="user", - context=AccessRuleContext.UI, - item=None, # Generic: view all UI - view=True - ), - AccessRule( - roleLabel="user", - context=AccessRuleContext.UI, - item="playground.voice.settings", # Specific: hide this - view=False - ) - ] - return [] - - rbac._getRulesForRole = mockGetRulesForRole - - # Get permissions for specific UI element + + rules = [ + ( + 1, + AccessRule( + roleId=_TEST_ROLE_USER, + context=AccessRuleContext.UI, + item=None, + view=True, + ), + ), + ( + 1, + AccessRule( + roleId=_TEST_ROLE_USER, + context=AccessRuleContext.UI, + item="playground.voice.settings", + view=False, + ), + ), + ] + _patchRbacResolution(rbac, [_TEST_ROLE_USER], rules) + permissions = rbac.getUserPermissions( user, AccessRuleContext.UI, - "playground.voice.settings" + "playground.voice.settings", ) - - # Specific rule should override generic - assert permissions.view == False - + + assert permissions.view is False + def testNoRolesReturnsNoAccess(self): - """Test that user with no roles gets no access.""" + """Test that user with no resolved role IDs gets no access.""" db = Mock(spec=DatabaseConnector) dbApp = Mock(spec=DatabaseConnector) rbac = RbacClass(db, dbApp=dbApp) - + dbApp.getRecordset = Mock(return_value=[]) + user = User( id="user1", username="testuser", - roleLabels=[], # No roles - mandateId="mandate1" + roleLabels=[], + mandateId="mandate1", ) - + permissions = rbac.getUserPermissions( user, AccessRuleContext.DATA, - "SomeTable" + "SomeTable", ) - - assert permissions.view == False + + assert permissions.view is False assert permissions.read == AccessLevel.NONE assert permissions.create == AccessLevel.NONE assert permissions.update == AccessLevel.NONE assert permissions.delete == AccessLevel.NONE - + def testFindMostSpecificRule(self): """Test findMostSpecificRule method.""" db = Mock(spec=DatabaseConnector) dbApp = Mock(spec=DatabaseConnector) rbac = RbacClass(db, dbApp=dbApp) - + rules = [ AccessRule( - roleLabel="user", + roleId=_TEST_ROLE_USER, context=AccessRuleContext.DATA, - item=None, # Generic + item=None, view=True, - read=AccessLevel.GROUP + read=AccessLevel.GROUP, ), AccessRule( - roleLabel="user", + roleId=_TEST_ROLE_USER, context=AccessRuleContext.DATA, - item="data.uam.UserInDB", # Table-level with UAM namespace + item="data.uam.UserInDB", view=True, - read=AccessLevel.MY + read=AccessLevel.MY, ), AccessRule( - roleLabel="user", + roleId=_TEST_ROLE_USER, context=AccessRuleContext.DATA, - item="data.uam.UserInDB.email", # Field-level - most specific + item="data.uam.UserInDB.email", view=True, - read=AccessLevel.NONE - ) + read=AccessLevel.NONE, + ), ] - - # Test exact match + rule = rbac.findMostSpecificRule(rules, "data.uam.UserInDB.email") assert rule is not None assert rule.item == "data.uam.UserInDB.email" assert rule.read == AccessLevel.NONE - - # Test table-level match + rule = rbac.findMostSpecificRule(rules, "data.uam.UserInDB") assert rule is not None assert rule.item == "data.uam.UserInDB" assert rule.read == AccessLevel.MY - - # Test generic fallback + rule = rbac.findMostSpecificRule(rules, "OtherTable") assert rule is not None assert rule.item is None assert rule.read == AccessLevel.GROUP - + def testValidateAccessRuleOpeningRights(self): """Test that CUD permissions respect read permission level.""" db = Mock(spec=DatabaseConnector) dbApp = Mock(spec=DatabaseConnector) rbac = RbacClass(db, dbApp=dbApp) - - # Valid: Read=MY, Create=MY (allowed) + rule1 = AccessRule( - roleLabel="user", + roleId=_TEST_ROLE_USER, context=AccessRuleContext.DATA, item="data.uam.UserInDB", view=True, read=AccessLevel.MY, create=AccessLevel.MY, update=AccessLevel.MY, - delete=AccessLevel.MY + delete=AccessLevel.MY, ) - assert rbac.validateAccessRule(rule1) == True - - # Invalid: Read=MY, Create=GROUP (not allowed - GROUP > MY) + assert rbac.validateAccessRule(rule1) is True + rule2 = AccessRule( - roleLabel="user", + roleId=_TEST_ROLE_USER, context=AccessRuleContext.DATA, item="data.uam.UserInDB", view=True, read=AccessLevel.MY, - create=AccessLevel.GROUP, # Not allowed + create=AccessLevel.GROUP, update=AccessLevel.MY, - delete=AccessLevel.MY + delete=AccessLevel.MY, ) - assert rbac.validateAccessRule(rule2) == False - - # Valid: Read=GROUP, Create=GROUP (allowed) + assert rbac.validateAccessRule(rule2) is False + rule3 = AccessRule( - roleLabel="admin", + roleId="test-rid-admin", context=AccessRuleContext.DATA, item="data.uam.UserInDB", view=True, read=AccessLevel.GROUP, create=AccessLevel.GROUP, update=AccessLevel.GROUP, - delete=AccessLevel.GROUP + delete=AccessLevel.GROUP, ) - assert rbac.validateAccessRule(rule3) == True - - # Invalid: Read=NONE, Create=MY (not allowed - no read access) + assert rbac.validateAccessRule(rule3) is True + rule4 = AccessRule( - roleLabel="user", + roleId=_TEST_ROLE_USER, context=AccessRuleContext.DATA, item="data.uam.UserInDB", view=True, read=AccessLevel.NONE, - create=AccessLevel.MY, # Not allowed without read + create=AccessLevel.MY, update=AccessLevel.MY, - delete=AccessLevel.MY + delete=AccessLevel.MY, ) - assert rbac.validateAccessRule(rule4) == False - + assert rbac.validateAccessRule(rule4) is False + def testUiContextOnlyViewMatters(self): """Test that UI context only checks view permission.""" db = Mock(spec=DatabaseConnector) dbApp = Mock(spec=DatabaseConnector) rbac = RbacClass(db, dbApp=dbApp) - + user = User( id="user1", username="testuser", roleLabels=["user"], - mandateId="mandate1" + mandateId="mandate1", ) - - def mockGetRulesForRole(roleLabel, context): - if roleLabel == "user" and context == AccessRuleContext.UI: - return [ - AccessRule( - roleLabel="user", - context=AccessRuleContext.UI, - item="playground", - view=True - # No read/create/update/delete for UI context - ) - ] - return [] - - rbac._getRulesForRole = mockGetRulesForRole - + + rules = [ + ( + 1, + AccessRule( + roleId=_TEST_ROLE_USER, + context=AccessRuleContext.UI, + item="playground", + view=True, + ), + ) + ] + _patchRbacResolution(rbac, [_TEST_ROLE_USER], rules) + permissions = rbac.getUserPermissions( user, AccessRuleContext.UI, - "playground" + "playground", ) - - assert permissions.view == True - # Other permissions don't matter for UI context - + + assert permissions.view is True + def testResourceContextOnlyViewMatters(self): """Test that RESOURCE context only checks view permission.""" db = Mock(spec=DatabaseConnector) dbApp = Mock(spec=DatabaseConnector) rbac = RbacClass(db, dbApp=dbApp) - + user = User( id="user1", username="testuser", roleLabels=["user"], - mandateId="mandate1" + mandateId="mandate1", ) - - def mockGetRulesForRole(roleLabel, context): - if roleLabel == "user" and context == AccessRuleContext.RESOURCE: - return [ - AccessRule( - roleLabel="user", - context=AccessRuleContext.RESOURCE, - item="ai.model.anthropic", - view=True - ) - ] - return [] - - rbac._getRulesForRole = mockGetRulesForRole - + + rules = [ + ( + 1, + AccessRule( + roleId=_TEST_ROLE_USER, + context=AccessRuleContext.RESOURCE, + item="ai.model.anthropic", + view=True, + ), + ) + ] + _patchRbacResolution(rbac, [_TEST_ROLE_USER], rules) + permissions = rbac.getUserPermissions( user, AccessRuleContext.RESOURCE, - "ai.model.anthropic" + "ai.model.anthropic", ) - - assert permissions.view == True + + assert permissions.view is True