commit
f89e4ab0c4
113 changed files with 11318 additions and 4787 deletions
10
app.py
10
app.py
|
|
@ -21,6 +21,8 @@ from datetime import datetime
|
|||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.shared.eventManagement import eventManager
|
||||
from modules.workflows.automation import subAutomationSchedule
|
||||
from modules.features.automation2.emailPoller import start as startAutomation2EmailPoller
|
||||
from modules.features.automation2.emailPoller import stop as stopAutomation2EmailPoller
|
||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||
from modules.system.registry import loadFeatureMainModules
|
||||
|
||||
|
|
@ -354,6 +356,7 @@ async def lifespan(app: FastAPI):
|
|||
|
||||
# --- Init Managers ---
|
||||
subAutomationSchedule.start(eventUser) # Automation scheduler
|
||||
# Automation2 email poller: started on-demand when a run pauses for email.checkEmail
|
||||
eventManager.start()
|
||||
|
||||
# Register audit log cleanup scheduler
|
||||
|
|
@ -382,6 +385,7 @@ async def lifespan(app: FastAPI):
|
|||
yield
|
||||
|
||||
# --- Stop Managers ---
|
||||
stopAutomation2EmailPoller(eventUser) # Automation2 email poller (no-op if not running)
|
||||
eventManager.stop()
|
||||
subAutomationSchedule.stop(eventUser) # Automation scheduler
|
||||
|
||||
|
|
@ -568,6 +572,9 @@ app.include_router(sharepointRouter)
|
|||
from modules.routes.routeAdminAutomationEvents import router as adminAutomationEventsRouter
|
||||
app.include_router(adminAutomationEventsRouter)
|
||||
|
||||
from modules.routes.routeAdminAutomationLogs import router as adminAutomationLogsRouter
|
||||
app.include_router(adminAutomationLogsRouter)
|
||||
|
||||
from modules.routes.routeAdminLogs import router as adminLogsRouter
|
||||
app.include_router(adminLogsRouter)
|
||||
|
||||
|
|
@ -601,6 +608,9 @@ app.include_router(gdprRouter)
|
|||
from modules.routes.routeBilling import router as billingRouter
|
||||
app.include_router(billingRouter)
|
||||
|
||||
from modules.routes.routeSubscription import router as subscriptionRouter
|
||||
app.include_router(subscriptionRouter)
|
||||
|
||||
# ============================================================================
|
||||
# SYSTEM ROUTES (Navigation, etc.)
|
||||
# ============================================================================
|
||||
|
|
|
|||
|
|
@ -44,3 +44,8 @@ Connector_StacSwisstopo_TIMEOUT = 30
|
|||
Connector_StacSwisstopo_MAX_RETRIES = 3
|
||||
Connector_StacSwisstopo_RETRY_DELAY = 1.0
|
||||
Connector_StacSwisstopo_ENABLE_CACHE = True
|
||||
|
||||
# Operator company information (shown on invoice emails)
|
||||
Operator_CompanyName = PowerOn AG
|
||||
Operator_Address = Birmensdorferstrasse 94, 8003 Zürich
|
||||
Operator_VatNumber = CHE491.960.195
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
import contextvars
|
||||
import re
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
import logging
|
||||
|
|
@ -92,19 +93,27 @@ def _get_model_fields(model_class) -> Dict[str, str]:
|
|||
fields[field_name] = extra["db_type"]
|
||||
continue
|
||||
|
||||
# Check for JSONB fields (Dict, List, or complex types)
|
||||
# Unwrap Optional[X] → X (handles both typing.Union and types.UnionType)
|
||||
origin = get_origin(field_type)
|
||||
if origin is Union:
|
||||
args = [a for a in get_args(field_type) if a is not type(None)]
|
||||
if len(args) == 1:
|
||||
field_type = args[0]
|
||||
elif hasattr(field_type, '__args__') and type(None) in getattr(field_type, '__args__', ()):
|
||||
args = [a for a in field_type.__args__ if a is not type(None)]
|
||||
if len(args) == 1:
|
||||
field_type = args[0]
|
||||
|
||||
if _isJsonbType(field_type):
|
||||
fields[field_name] = "JSONB"
|
||||
elif field_type in (str, type(None)) or (
|
||||
get_origin(field_type) is Union and type(None) in get_args(field_type)
|
||||
):
|
||||
fields[field_name] = "TEXT"
|
||||
elif field_type == int:
|
||||
fields[field_name] = "INTEGER"
|
||||
elif field_type == float:
|
||||
fields[field_name] = "DOUBLE PRECISION"
|
||||
elif field_type == bool:
|
||||
elif field_type is bool:
|
||||
fields[field_name] = "BOOLEAN"
|
||||
elif field_type is int:
|
||||
fields[field_name] = "INTEGER"
|
||||
elif field_type is float:
|
||||
fields[field_name] = "DOUBLE PRECISION"
|
||||
elif field_type in (str, type(None)):
|
||||
fields[field_name] = "TEXT"
|
||||
else:
|
||||
fields[field_name] = "TEXT"
|
||||
|
||||
|
|
@ -135,6 +144,9 @@ def _parseRecordFields(record: Dict[str, Any], fields: Dict[str, str], context:
|
|||
elif isinstance(value, list):
|
||||
pass # already a list
|
||||
|
||||
elif fieldType == "BOOLEAN":
|
||||
record[fieldName] = bool(value) if value is not None else False
|
||||
|
||||
elif fieldType == "JSONB" and value is not None:
|
||||
try:
|
||||
if isinstance(value, str):
|
||||
|
|
@ -969,6 +981,252 @@ class DatabaseConnector:
|
|||
logger.error(f"Error loading records from table {table}: {e}")
|
||||
return []
|
||||
|
||||
def _buildPaginationClauses(
|
||||
self,
|
||||
model_class: type,
|
||||
pagination,
|
||||
recordFilter: Dict[str, Any] = None,
|
||||
):
|
||||
"""
|
||||
Translate PaginationParams + recordFilter into SQL clauses.
|
||||
Returns (where_clause, order_clause, limit_clause, values, count_values).
|
||||
"""
|
||||
fields = _get_model_fields(model_class)
|
||||
fields["_createdAt"] = "DOUBLE PRECISION"
|
||||
fields["_modifiedAt"] = "DOUBLE PRECISION"
|
||||
fields["_createdBy"] = "TEXT"
|
||||
fields["_modifiedBy"] = "TEXT"
|
||||
validColumns = set(fields.keys())
|
||||
where_parts: List[str] = []
|
||||
values: List[Any] = []
|
||||
|
||||
if recordFilter:
|
||||
for field, value in recordFilter.items():
|
||||
if value is None:
|
||||
where_parts.append(f'"{field}" IS NULL')
|
||||
elif isinstance(value, list):
|
||||
where_parts.append(f'"{field}" = ANY(%s)')
|
||||
values.append(value)
|
||||
else:
|
||||
where_parts.append(f'"{field}" = %s')
|
||||
values.append(value)
|
||||
|
||||
if pagination and pagination.filters:
|
||||
for key, val in pagination.filters.items():
|
||||
if key == "search" and isinstance(val, str) and val.strip():
|
||||
term = f"%{val.strip()}%"
|
||||
textCols = [c for c, t in fields.items() if t == "TEXT"]
|
||||
if textCols:
|
||||
orParts = [f'COALESCE("{c}"::TEXT, \'\') ILIKE %s' for c in textCols]
|
||||
where_parts.append(f"({' OR '.join(orParts)})")
|
||||
values.extend([term] * len(textCols))
|
||||
continue
|
||||
if key not in validColumns:
|
||||
logger.debug(f"_buildPaginationClauses: key '{key}' NOT in validColumns {list(validColumns)[:10]}")
|
||||
continue
|
||||
colType = fields.get(key, "TEXT")
|
||||
logger.debug(f"_buildPaginationClauses: filter key='{key}' val={val!r} type(val)={type(val).__name__} colType={colType}")
|
||||
if isinstance(val, dict):
|
||||
op = val.get("operator", "equals")
|
||||
v = val.get("value", "")
|
||||
if op in ("equals", "eq"):
|
||||
if colType == "BOOLEAN":
|
||||
where_parts.append(f'COALESCE("{key}", FALSE) = %s')
|
||||
values.append(str(v).lower() == "true")
|
||||
else:
|
||||
where_parts.append(f'"{key}"::TEXT = %s')
|
||||
values.append(str(v))
|
||||
elif op == "contains":
|
||||
where_parts.append(f'"{key}"::TEXT ILIKE %s')
|
||||
values.append(f"%{v}%")
|
||||
elif op == "startsWith":
|
||||
where_parts.append(f'"{key}"::TEXT ILIKE %s')
|
||||
values.append(f"{v}%")
|
||||
elif op == "endsWith":
|
||||
where_parts.append(f'"{key}"::TEXT ILIKE %s')
|
||||
values.append(f"%{v}")
|
||||
elif op in ("gt", "gte", "lt", "lte"):
|
||||
sqlOp = {"gt": ">", "gte": ">=", "lt": "<", "lte": "<="}[op]
|
||||
where_parts.append(f'"{key}"::TEXT {sqlOp} %s')
|
||||
values.append(str(v))
|
||||
elif op == "between":
|
||||
fromVal = v.get("from", "") if isinstance(v, dict) else ""
|
||||
toVal = v.get("to", "") if isinstance(v, dict) else ""
|
||||
if not fromVal and not toVal:
|
||||
continue
|
||||
colType = fields.get(key, "TEXT")
|
||||
isNumericCol = colType in ("INTEGER", "DOUBLE PRECISION")
|
||||
isDateVal = bool(fromVal and re.match(r'^\d{4}-\d{2}-\d{2}$', str(fromVal))) or \
|
||||
bool(toVal and re.match(r'^\d{4}-\d{2}-\d{2}$', str(toVal)))
|
||||
if isNumericCol and isDateVal:
|
||||
from datetime import datetime as _dt, timezone as _tz
|
||||
if fromVal and toVal:
|
||||
fromTs = _dt.strptime(str(fromVal), '%Y-%m-%d').replace(tzinfo=_tz.utc).timestamp()
|
||||
toTs = _dt.strptime(str(toVal), '%Y-%m-%d').replace(hour=23, minute=59, second=59, tzinfo=_tz.utc).timestamp()
|
||||
where_parts.append(f'"{key}" >= %s AND "{key}" <= %s')
|
||||
values.extend([fromTs, toTs])
|
||||
elif fromVal:
|
||||
fromTs = _dt.strptime(str(fromVal), '%Y-%m-%d').replace(tzinfo=_tz.utc).timestamp()
|
||||
where_parts.append(f'"{key}" >= %s')
|
||||
values.append(fromTs)
|
||||
else:
|
||||
toTs = _dt.strptime(str(toVal), '%Y-%m-%d').replace(hour=23, minute=59, second=59, tzinfo=_tz.utc).timestamp()
|
||||
where_parts.append(f'"{key}" <= %s')
|
||||
values.append(toTs)
|
||||
else:
|
||||
if fromVal and toVal:
|
||||
where_parts.append(f'"{key}"::TEXT >= %s AND "{key}"::TEXT <= %s')
|
||||
values.extend([str(fromVal), str(toVal)])
|
||||
elif fromVal:
|
||||
where_parts.append(f'"{key}"::TEXT >= %s')
|
||||
values.append(str(fromVal))
|
||||
elif toVal:
|
||||
where_parts.append(f'"{key}"::TEXT <= %s')
|
||||
values.append(str(toVal))
|
||||
else:
|
||||
if colType == "BOOLEAN":
|
||||
where_parts.append(f'COALESCE("{key}", FALSE) = %s')
|
||||
values.append(str(val).lower() == "true")
|
||||
else:
|
||||
where_parts.append(f'"{key}"::TEXT ILIKE %s')
|
||||
values.append(str(val))
|
||||
|
||||
where_clause = " WHERE " + " AND ".join(where_parts) if where_parts else ""
|
||||
count_values = list(values)
|
||||
|
||||
orderParts: List[str] = []
|
||||
if pagination and pagination.sort:
|
||||
for sf in pagination.sort:
|
||||
if sf.field in validColumns:
|
||||
direction = "DESC" if sf.direction.lower() == "desc" else "ASC"
|
||||
colType = fields.get(sf.field, "TEXT")
|
||||
if colType == "BOOLEAN":
|
||||
orderParts.append(f'COALESCE("{sf.field}", FALSE) {direction}')
|
||||
else:
|
||||
orderParts.append(f'"{sf.field}" {direction} NULLS LAST')
|
||||
if not orderParts:
|
||||
orderParts.append('"id"')
|
||||
order_clause = " ORDER BY " + ", ".join(orderParts)
|
||||
|
||||
limit_clause = ""
|
||||
if pagination:
|
||||
offset = (pagination.page - 1) * pagination.pageSize
|
||||
limit_clause = f" LIMIT {pagination.pageSize} OFFSET {offset}"
|
||||
|
||||
return where_clause, order_clause, limit_clause, values, count_values
|
||||
|
||||
def getRecordsetPaginated(
|
||||
self,
|
||||
model_class: type,
|
||||
pagination=None,
|
||||
recordFilter: Dict[str, Any] = None,
|
||||
fieldFilter: List[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns paginated records with filtering + sorting at the SQL level.
|
||||
Returns { "items": [...], "totalItems": int, "totalPages": int }.
|
||||
If pagination is None, returns all records (no LIMIT/OFFSET).
|
||||
"""
|
||||
from modules.datamodels.datamodelPagination import PaginationParams
|
||||
import math
|
||||
|
||||
table = model_class.__name__
|
||||
|
||||
try:
|
||||
if not self._ensureTableExists(model_class):
|
||||
return {"items": [], "totalItems": 0, "totalPages": 0}
|
||||
|
||||
where_clause, order_clause, limit_clause, values, count_values = \
|
||||
self._buildPaginationClauses(model_class, pagination, recordFilter)
|
||||
|
||||
with self.connection.cursor() as cursor:
|
||||
countSql = f'SELECT COUNT(*) FROM "{table}"{where_clause}'
|
||||
cursor.execute(countSql, count_values)
|
||||
totalItems = cursor.fetchone()["count"]
|
||||
|
||||
dataSql = f'SELECT * FROM "{table}"{where_clause}{order_clause}{limit_clause}'
|
||||
cursor.execute(dataSql, values)
|
||||
records = [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
fields = _get_model_fields(model_class)
|
||||
modelFields = model_class.model_fields
|
||||
for record in records:
|
||||
_parseRecordFields(record, fields, f"table {table}")
|
||||
for fieldName, fieldType in fields.items():
|
||||
if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
|
||||
fieldInfo = modelFields.get(fieldName)
|
||||
if fieldInfo:
|
||||
fieldAnnotation = fieldInfo.annotation
|
||||
if (fieldAnnotation == list or
|
||||
(hasattr(fieldAnnotation, "__origin__") and
|
||||
fieldAnnotation.__origin__ is list)):
|
||||
record[fieldName] = []
|
||||
elif (fieldAnnotation == dict or
|
||||
(hasattr(fieldAnnotation, "__origin__") and
|
||||
fieldAnnotation.__origin__ is dict)):
|
||||
record[fieldName] = {}
|
||||
|
||||
if fieldFilter and isinstance(fieldFilter, list):
|
||||
records = [{f: r[f] for f in fieldFilter if f in r} for r in records]
|
||||
|
||||
pageSize = pagination.pageSize if pagination else max(totalItems, 1)
|
||||
totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0
|
||||
|
||||
return {"items": records, "totalItems": totalItems, "totalPages": totalPages}
|
||||
except Exception as e:
|
||||
logger.error(f"Error in getRecordsetPaginated for table {table}: {e}")
|
||||
return {"items": [], "totalItems": 0, "totalPages": 0}
|
||||
|
||||
def getDistinctColumnValues(
|
||||
self,
|
||||
model_class: type,
|
||||
column: str,
|
||||
pagination=None,
|
||||
recordFilter: Dict[str, Any] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Returns sorted distinct non-null values for a column using SQL DISTINCT.
|
||||
Applies cross-filtering (all filters except the requested column).
|
||||
"""
|
||||
table = model_class.__name__
|
||||
fields = _get_model_fields(model_class)
|
||||
fields["_createdAt"] = "DOUBLE PRECISION"
|
||||
fields["_modifiedAt"] = "DOUBLE PRECISION"
|
||||
fields["_createdBy"] = "TEXT"
|
||||
fields["_modifiedBy"] = "TEXT"
|
||||
|
||||
if column not in fields:
|
||||
return []
|
||||
|
||||
try:
|
||||
if not self._ensureTableExists(model_class):
|
||||
return []
|
||||
|
||||
if pagination:
|
||||
if pagination.filters and column in pagination.filters:
|
||||
import copy
|
||||
pagination = copy.deepcopy(pagination)
|
||||
pagination.filters.pop(column, None)
|
||||
|
||||
where_clause, _, _, values, _ = \
|
||||
self._buildPaginationClauses(model_class, pagination, recordFilter)
|
||||
|
||||
sql = (
|
||||
f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{where_clause} '
|
||||
f'WHERE "{column}" IS NOT NULL AND "{column}"::TEXT != \'\' '
|
||||
if not where_clause else
|
||||
f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{where_clause} '
|
||||
f'AND "{column}" IS NOT NULL AND "{column}"::TEXT != \'\' '
|
||||
)
|
||||
sql += 'ORDER BY val'
|
||||
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(sql, values)
|
||||
return [row["val"] for row in cursor.fetchall()]
|
||||
except Exception as e:
|
||||
logger.error(f"Error in getDistinctColumnValues for {table}.{column}: {e}")
|
||||
return []
|
||||
|
||||
def recordCreate(
|
||||
self, model_class: type, record: Union[Dict[str, Any], BaseModel]
|
||||
) -> Dict[str, Any]:
|
||||
|
|
|
|||
|
|
@ -299,26 +299,65 @@ class OutlookAdapter(_GraphApiMixin, ServiceAdapter):
|
|||
for m in result.get("value", [])
|
||||
]
|
||||
|
||||
async def sendMail(
|
||||
def _buildMessage(
|
||||
self, to: List[str], subject: str, body: str,
|
||||
cc: Optional[List[str]] = None, attachments: Optional[List[Dict]] = None
|
||||
bodyType: str = "Text",
|
||||
cc: Optional[List[str]] = None,
|
||||
attachments: Optional[List[Dict]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Send an email via Microsoft Graph."""
|
||||
import json
|
||||
"""Build a Graph API message object.
|
||||
|
||||
attachments: list of {"name": str, "contentBytes": str (base64), "contentType": str}
|
||||
"""
|
||||
message: Dict[str, Any] = {
|
||||
"subject": subject,
|
||||
"body": {"contentType": "Text", "content": body},
|
||||
"body": {"contentType": bodyType, "content": body},
|
||||
"toRecipients": [{"emailAddress": {"address": addr}} for addr in to],
|
||||
}
|
||||
if cc:
|
||||
message["ccRecipients"] = [{"emailAddress": {"address": addr}} for addr in cc]
|
||||
if attachments:
|
||||
message["attachments"] = [
|
||||
{
|
||||
"@odata.type": "#microsoft.graph.fileAttachment",
|
||||
"name": att["name"],
|
||||
"contentBytes": att["contentBytes"],
|
||||
"contentType": att.get("contentType", "application/octet-stream"),
|
||||
}
|
||||
for att in attachments
|
||||
]
|
||||
return message
|
||||
|
||||
async def sendMail(
|
||||
self, to: List[str], subject: str, body: str,
|
||||
bodyType: str = "Text",
|
||||
cc: Optional[List[str]] = None,
|
||||
attachments: Optional[List[Dict]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Send an email via Microsoft Graph. bodyType: 'Text' or 'HTML'."""
|
||||
import json
|
||||
message = self._buildMessage(to, subject, body, bodyType, cc, attachments)
|
||||
payload = json.dumps({"message": message, "saveToSentItems": True}).encode("utf-8")
|
||||
result = await self._graphPost("me/sendMail", payload)
|
||||
if "error" in result:
|
||||
return result
|
||||
return {"success": True}
|
||||
|
||||
async def createDraft(
|
||||
self, to: List[str], subject: str, body: str,
|
||||
bodyType: str = "Text",
|
||||
cc: Optional[List[str]] = None,
|
||||
attachments: Optional[List[Dict]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Create a draft email in the user's Drafts folder via Microsoft Graph."""
|
||||
import json
|
||||
message = self._buildMessage(to, subject, body, bodyType, cc, attachments)
|
||||
payload = json.dumps(message).encode("utf-8")
|
||||
result = await self._graphPost("me/messages", payload)
|
||||
if "error" in result:
|
||||
return result
|
||||
return {"success": True, "draft": True, "messageId": result.get("id", "")}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Teams Adapter (Stub)
|
||||
|
|
|
|||
|
|
@ -143,6 +143,9 @@ class BillingSettings(BaseModel):
|
|||
)
|
||||
warningThresholdPercent: float = Field(default=10.0, description="Warning threshold as percentage")
|
||||
|
||||
# Stripe
|
||||
stripeCustomerId: Optional[str] = Field(None, description="Stripe Customer ID (cus_xxx) — one per mandate")
|
||||
|
||||
# Notifications (e.g. mandate owner / finance — also used when PREPAY_MANDATE pool is exhausted)
|
||||
notifyEmails: List[str] = Field(
|
||||
default_factory=list,
|
||||
|
|
@ -163,6 +166,7 @@ registerModelLabels(
|
|||
"de": "Startguthaben nur Root-Mandant (CHF)",
|
||||
},
|
||||
"warningThresholdPercent": {"en": "Warning Threshold (%)", "de": "Warnschwelle (%)"},
|
||||
"stripeCustomerId": {"en": "Stripe Customer ID", "de": "Stripe-Kunden-ID"},
|
||||
"notifyEmails": {
|
||||
"en": "Billing notification emails (owner / admin)",
|
||||
"de": "E-Mails für Billing-Alerts (Inhaber/Admin)",
|
||||
|
|
@ -260,12 +264,15 @@ class BillingStatisticsResponse(BaseModel):
|
|||
|
||||
|
||||
class BillingCheckResult(BaseModel):
|
||||
"""Result of a billing balance check."""
|
||||
"""Result of a billing balance check (budget + subscription gate)."""
|
||||
allowed: bool
|
||||
reason: Optional[str] = None
|
||||
currentBalance: Optional[float] = None
|
||||
requiredAmount: Optional[float] = None
|
||||
billingModel: Optional[BillingModelEnum] = None
|
||||
upgradeRequired: Optional[bool] = None
|
||||
subscriptionUiPath: Optional[str] = None
|
||||
userAction: Optional[str] = None
|
||||
|
||||
|
||||
def parseBillingModelFromStoredValue(raw: Optional[str]) -> BillingModelEnum:
|
||||
|
|
|
|||
|
|
@ -98,6 +98,12 @@ def normalize_pagination_dict(pagination_dict: Dict[str, Any]) -> Dict[str, Any]
|
|||
# Create a copy to avoid modifying the original
|
||||
normalized = dict(pagination_dict)
|
||||
|
||||
# Ensure required fields have sensible defaults
|
||||
if "page" not in normalized:
|
||||
normalized["page"] = 1
|
||||
if "pageSize" not in normalized:
|
||||
normalized["pageSize"] = 25
|
||||
|
||||
# Move top-level "search" into filters if present
|
||||
if "search" in normalized:
|
||||
if "filters" not in normalized or normalized["filters"] is None:
|
||||
|
|
|
|||
235
modules/datamodels/datamodelSubscription.py
Normal file
235
modules/datamodels/datamodelSubscription.py
Normal file
|
|
@ -0,0 +1,235 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""Subscription models: SubscriptionPlan (catalog), MandateSubscription (instance per mandate),
|
||||
StripePlanPrice (persisted Stripe IDs per plan).
|
||||
|
||||
State Machine: see wiki/concepts/Subscription-State-Machine.md
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
from enum import Enum
|
||||
from datetime import datetime, timezone
|
||||
from pydantic import BaseModel, Field
|
||||
from modules.shared.attributeUtils import registerModelLabels
|
||||
import uuid
|
||||
|
||||
|
||||
class SubscriptionStatusEnum(str, Enum):
|
||||
"""Lifecycle status of a mandate subscription.
|
||||
See wiki/concepts/Subscription-State-Machine.md for transition rules."""
|
||||
PENDING = "PENDING"
|
||||
SCHEDULED = "SCHEDULED"
|
||||
TRIALING = "TRIALING"
|
||||
ACTIVE = "ACTIVE"
|
||||
PAST_DUE = "PAST_DUE"
|
||||
EXPIRED = "EXPIRED"
|
||||
|
||||
|
||||
TERMINAL_STATUSES = {SubscriptionStatusEnum.EXPIRED}
|
||||
OPERATIVE_STATUSES = {SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.TRIALING, SubscriptionStatusEnum.PAST_DUE}
|
||||
|
||||
ALLOWED_TRANSITIONS = {
|
||||
(SubscriptionStatusEnum.PENDING, SubscriptionStatusEnum.ACTIVE),
|
||||
(SubscriptionStatusEnum.PENDING, SubscriptionStatusEnum.SCHEDULED),
|
||||
(SubscriptionStatusEnum.PENDING, SubscriptionStatusEnum.EXPIRED),
|
||||
(SubscriptionStatusEnum.SCHEDULED, SubscriptionStatusEnum.ACTIVE),
|
||||
(SubscriptionStatusEnum.SCHEDULED, SubscriptionStatusEnum.EXPIRED),
|
||||
(SubscriptionStatusEnum.TRIALING, SubscriptionStatusEnum.EXPIRED),
|
||||
(SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.PAST_DUE),
|
||||
(SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.EXPIRED),
|
||||
(SubscriptionStatusEnum.PAST_DUE, SubscriptionStatusEnum.ACTIVE),
|
||||
(SubscriptionStatusEnum.PAST_DUE, SubscriptionStatusEnum.EXPIRED),
|
||||
}
|
||||
|
||||
|
||||
class BillingPeriodEnum(str, Enum):
|
||||
"""Recurring billing interval."""
|
||||
MONTHLY = "MONTHLY"
|
||||
YEARLY = "YEARLY"
|
||||
NONE = "NONE"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Catalog: SubscriptionPlan (static, in-memory)
|
||||
# ============================================================================
|
||||
|
||||
class SubscriptionPlan(BaseModel):
|
||||
"""Plan definition (catalog entry). Not stored per mandate — static."""
|
||||
planKey: str = Field(..., description="Unique plan identifier")
|
||||
selectableByUser: bool = Field(default=True, description="Whether users can choose this plan in the UI")
|
||||
|
||||
title: Dict[str, str] = Field(default_factory=dict, description="Multilingual title (en/de/fr)")
|
||||
description: Dict[str, str] = Field(default_factory=dict, description="Multilingual description")
|
||||
|
||||
currency: str = Field(default="CHF", description="Billing currency")
|
||||
billingPeriod: BillingPeriodEnum = Field(default=BillingPeriodEnum.MONTHLY, description="Recurring interval")
|
||||
pricePerUserCHF: float = Field(default=0.0, description="Price per active user per period")
|
||||
pricePerFeatureInstanceCHF: float = Field(default=0.0, description="Price per active feature instance per period")
|
||||
autoRenew: bool = Field(default=True, description="Stripe renews automatically at period end")
|
||||
|
||||
maxUsers: Optional[int] = Field(None, description="Hard cap on active users (None = unlimited)")
|
||||
maxFeatureInstances: Optional[int] = Field(None, description="Hard cap on active feature instances (None = unlimited)")
|
||||
trialDays: Optional[int] = Field(None, description="Trial duration in days (only for trial plans)")
|
||||
successorPlanKey: Optional[str] = Field(None, description="Plan to transition to when trial ends")
|
||||
|
||||
|
||||
registerModelLabels(
|
||||
"SubscriptionPlan",
|
||||
{"en": "Subscription Plan", "de": "Abonnement-Plan", "fr": "Plan d'abonnement"},
|
||||
{
|
||||
"planKey": {"en": "Plan", "de": "Plan", "fr": "Plan"},
|
||||
"selectableByUser": {"en": "Selectable", "de": "Wählbar", "fr": "Sélectionnable"},
|
||||
"billingPeriod": {"en": "Billing Period", "de": "Abrechnungszeitraum", "fr": "Période de facturation"},
|
||||
"pricePerUserCHF": {"en": "Price per User (CHF)", "de": "Preis pro User (CHF)"},
|
||||
"pricePerFeatureInstanceCHF": {"en": "Price per Instance (CHF)", "de": "Preis pro Instanz (CHF)"},
|
||||
"maxUsers": {"en": "Max Users", "de": "Max. Benutzer", "fr": "Max. utilisateurs"},
|
||||
"maxFeatureInstances": {"en": "Max Instances", "de": "Max. Instanzen", "fr": "Max. instances"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Stripe Price mapping (persisted in DB, auto-created at bootstrap)
|
||||
# ============================================================================
|
||||
|
||||
class StripePlanPrice(BaseModel):
|
||||
"""Persisted mapping from planKey to Stripe Product/Price IDs.
|
||||
Auto-created at startup — no manual configuration needed.
|
||||
Uses separate Stripe Products for users and instances for clear invoice labels."""
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
|
||||
planKey: str = Field(..., description="Reference to SubscriptionPlan.planKey")
|
||||
stripeProductId: str = Field("", description="Legacy single-product ID (unused)")
|
||||
stripeProductIdUsers: Optional[str] = Field(None, description="Stripe Product ID for user licenses")
|
||||
stripeProductIdInstances: Optional[str] = Field(None, description="Stripe Product ID for feature instances")
|
||||
stripePriceIdUsers: Optional[str] = Field(None, description="Stripe Price ID for user-seat line item")
|
||||
stripePriceIdInstances: Optional[str] = Field(None, description="Stripe Price ID for instance line item")
|
||||
|
||||
|
||||
registerModelLabels(
|
||||
"StripePlanPrice",
|
||||
{"en": "Stripe Plan Prices", "de": "Stripe-Planpreise"},
|
||||
{
|
||||
"planKey": {"en": "Plan", "de": "Plan"},
|
||||
"stripeProductIdUsers": {"en": "Product (Users)", "de": "Produkt (User)"},
|
||||
"stripeProductIdInstances": {"en": "Product (Instances)", "de": "Produkt (Instanzen)"},
|
||||
"stripePriceIdUsers": {"en": "Price ID (Users)", "de": "Preis-ID (User)"},
|
||||
"stripePriceIdInstances": {"en": "Price ID (Instances)", "de": "Preis-ID (Instanzen)"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Instance: MandateSubscription
|
||||
# ============================================================================
|
||||
|
||||
class MandateSubscription(BaseModel):
|
||||
"""A subscription instance bound to a specific mandate.
|
||||
See wiki/concepts/Subscription-State-Machine.md for state transitions."""
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key")
|
||||
mandateId: str = Field(..., description="Foreign key to Mandate")
|
||||
planKey: str = Field(..., description="Reference to SubscriptionPlan.planKey")
|
||||
|
||||
status: SubscriptionStatusEnum = Field(default=SubscriptionStatusEnum.PENDING, description="Current lifecycle status")
|
||||
recurring: bool = Field(default=True, description="True: auto-renews at period end. False: expires at period end (gekuendigt).")
|
||||
|
||||
startedAt: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), description="Record creation timestamp")
|
||||
effectiveFrom: Optional[datetime] = Field(None, description="When this subscription becomes operative. None = immediate. Set for SCHEDULED subs.")
|
||||
endedAt: Optional[datetime] = Field(None, description="When subscription ended (terminal)")
|
||||
currentPeriodStart: Optional[datetime] = Field(None, description="Current billing period start (synced from Stripe)")
|
||||
currentPeriodEnd: Optional[datetime] = Field(None, description="Current billing period end (synced from Stripe)")
|
||||
trialEndsAt: Optional[datetime] = Field(None, description="Trial expiry timestamp")
|
||||
|
||||
snapshotPricePerUserCHF: float = Field(default=0.0, description="Price snapshot at activation (for invoice history)")
|
||||
snapshotPricePerInstanceCHF: float = Field(default=0.0, description="Price snapshot at activation")
|
||||
|
||||
stripeSubscriptionId: Optional[str] = Field(None, description="Stripe Subscription ID (sub_xxx)")
|
||||
stripeItemIdUsers: Optional[str] = Field(None, description="Stripe Subscription Item ID for user seats")
|
||||
stripeItemIdInstances: Optional[str] = Field(None, description="Stripe Subscription Item ID for feature instances")
|
||||
|
||||
|
||||
registerModelLabels(
|
||||
"MandateSubscription",
|
||||
{"en": "Mandate Subscription", "de": "Mandanten-Abonnement", "fr": "Abonnement du mandat"},
|
||||
{
|
||||
"id": {"en": "ID", "de": "ID"},
|
||||
"mandateId": {"en": "Mandate ID", "de": "Mandanten-ID"},
|
||||
"planKey": {"en": "Plan", "de": "Plan"},
|
||||
"status": {"en": "Status", "de": "Status"},
|
||||
"recurring": {"en": "Recurring", "de": "Wiederkehrend"},
|
||||
"startedAt": {"en": "Started", "de": "Gestartet"},
|
||||
"effectiveFrom": {"en": "Effective From", "de": "Wirksam ab"},
|
||||
"endedAt": {"en": "Ended", "de": "Beendet"},
|
||||
"currentPeriodStart": {"en": "Period Start", "de": "Periodenbeginn"},
|
||||
"currentPeriodEnd": {"en": "Period End", "de": "Periodenende"},
|
||||
"trialEndsAt": {"en": "Trial Ends", "de": "Trial endet"},
|
||||
"snapshotPricePerUserCHF": {"en": "Price/User (CHF)", "de": "Preis/User (CHF)"},
|
||||
"snapshotPricePerInstanceCHF": {"en": "Price/Instance (CHF)", "de": "Preis/Instanz (CHF)"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Built-in plan catalog (static, no env dependency)
|
||||
# ============================================================================
|
||||
|
||||
BUILTIN_PLANS: Dict[str, SubscriptionPlan] = {
|
||||
"ROOT": SubscriptionPlan(
|
||||
planKey="ROOT",
|
||||
selectableByUser=False,
|
||||
title={"en": "Root (System)", "de": "Root (System)", "fr": "Root (Système)"},
|
||||
description={"en": "Internal system plan — no billing.", "de": "Interner Systemplan — keine Verrechnung."},
|
||||
billingPeriod=BillingPeriodEnum.NONE,
|
||||
autoRenew=False,
|
||||
maxUsers=None,
|
||||
maxFeatureInstances=None,
|
||||
),
|
||||
"TRIAL_7D": SubscriptionPlan(
|
||||
planKey="TRIAL_7D",
|
||||
selectableByUser=False,
|
||||
title={"en": "Free Trial (7 days)", "de": "Gratis-Testphase (7 Tage)", "fr": "Essai gratuit (7 jours)"},
|
||||
description={
|
||||
"en": "Try the platform for 7 days — 1 user, up to 3 feature instances.",
|
||||
"de": "Plattform 7 Tage testen — 1 User, bis zu 3 Feature-Instanzen.",
|
||||
},
|
||||
billingPeriod=BillingPeriodEnum.NONE,
|
||||
autoRenew=False,
|
||||
maxUsers=1,
|
||||
maxFeatureInstances=3,
|
||||
trialDays=7,
|
||||
successorPlanKey="STANDARD_MONTHLY",
|
||||
),
|
||||
"STANDARD_MONTHLY": SubscriptionPlan(
|
||||
planKey="STANDARD_MONTHLY",
|
||||
selectableByUser=True,
|
||||
title={"en": "Standard (Monthly)", "de": "Standard (Monatlich)", "fr": "Standard (Mensuel)"},
|
||||
description={
|
||||
"en": "Usage-based billing per active user and feature instance, billed monthly.",
|
||||
"de": "Nutzungsbasierte Abrechnung pro aktivem User und Feature-Instanz, monatlich.",
|
||||
},
|
||||
billingPeriod=BillingPeriodEnum.MONTHLY,
|
||||
pricePerUserCHF=90.0,
|
||||
pricePerFeatureInstanceCHF=150.0,
|
||||
),
|
||||
"STANDARD_YEARLY": SubscriptionPlan(
|
||||
planKey="STANDARD_YEARLY",
|
||||
selectableByUser=True,
|
||||
title={"en": "Standard (Yearly)", "de": "Standard (Jährlich)", "fr": "Standard (Annuel)"},
|
||||
description={
|
||||
"en": "Usage-based billing per active user and feature instance, billed yearly.",
|
||||
"de": "Nutzungsbasierte Abrechnung pro aktivem User und Feature-Instanz, jährlich.",
|
||||
},
|
||||
billingPeriod=BillingPeriodEnum.YEARLY,
|
||||
pricePerUserCHF=1080.0,
|
||||
pricePerFeatureInstanceCHF=1800.0,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _getPlan(planKey: str) -> Optional[SubscriptionPlan]:
|
||||
"""Resolve a plan by key from the built-in catalog."""
|
||||
return BUILTIN_PLANS.get(planKey)
|
||||
|
||||
|
||||
def _getSelectablePlans() -> List[SubscriptionPlan]:
|
||||
"""Return plans that users can choose in the UI."""
|
||||
return [p for p in BUILTIN_PLANS.values() if p.selectableByUser]
|
||||
|
|
@ -1,46 +1,7 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""Voice settings datamodel."""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from modules.shared.attributeUtils import registerModelLabels
|
||||
from modules.shared.timeUtils import getUtcTimestamp
|
||||
import uuid
|
||||
|
||||
|
||||
class VoiceSettings(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
|
||||
userId: str = Field(description="ID of the user these settings belong to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True})
|
||||
mandateId: str = Field(description="ID of the mandate these settings belong to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True})
|
||||
featureInstanceId: str = Field(description="ID of the feature instance these settings belong to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True})
|
||||
sttLanguage: str = Field(default="de-DE", description="Speech-to-Text language", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True})
|
||||
ttsLanguage: str = Field(default="de-DE", description="Text-to-Speech language", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True})
|
||||
ttsVoice: str = Field(default="de-DE-KatjaNeural", description="Text-to-Speech voice", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True})
|
||||
ttsVoiceMap: Dict[str, Any] = Field(default_factory=dict, description="Per-language voice mapping, e.g. {'de-DE': {'voiceName': 'de-DE-Wavenet-A'}, 'en-US': {'voiceName': 'en-US-Wavenet-C'}}", json_schema_extra={"frontend_type": "json", "frontend_readonly": False, "frontend_required": False})
|
||||
translationEnabled: bool = Field(default=True, description="Whether translation is enabled", json_schema_extra={"frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False})
|
||||
targetLanguage: str = Field(default="en-US", description="Target language for translation", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False})
|
||||
creationDate: float = Field(default_factory=getUtcTimestamp, description="Date when the settings were created (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})
|
||||
lastModified: float = Field(default_factory=getUtcTimestamp, description="Date when the settings were last modified (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})
|
||||
|
||||
|
||||
registerModelLabels(
|
||||
"VoiceSettings",
|
||||
{"en": "Voice Settings", "fr": "Paramètres vocaux"},
|
||||
{
|
||||
"id": {"en": "ID", "fr": "ID"},
|
||||
"userId": {"en": "User ID", "fr": "ID utilisateur"},
|
||||
"mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
|
||||
"featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance de fonctionnalité"},
|
||||
"sttLanguage": {"en": "STT Language", "fr": "Langue STT"},
|
||||
"ttsLanguage": {"en": "TTS Language", "fr": "Langue TTS"},
|
||||
"ttsVoice": {"en": "TTS Voice", "fr": "Voix TTS"},
|
||||
"ttsVoiceMap": {"en": "TTS Voice Map", "fr": "Carte des voix TTS"},
|
||||
"translationEnabled": {"en": "Translation Enabled", "fr": "Traduction activée"},
|
||||
"targetLanguage": {"en": "Target Language", "fr": "Langue cible"},
|
||||
"creationDate": {"en": "Creation Date", "fr": "Date de création"},
|
||||
"lastModified": {"en": "Last Modified", "fr": "Dernière modification"},
|
||||
},
|
||||
)
|
||||
"""Voice settings datamodel — re-exported from workspace feature for backward compatibility."""
|
||||
|
||||
from modules.features.workspace.datamodelFeatureWorkspace import VoiceSettings
|
||||
|
||||
__all__ = ["VoiceSettings"]
|
||||
|
|
|
|||
|
|
@ -270,7 +270,7 @@ class AutomationObjects:
|
|||
if value.lower() not in itemValue.lower():
|
||||
match = False
|
||||
break
|
||||
elif itemValue != value:
|
||||
elif str(itemValue).lower() != str(value).lower():
|
||||
match = False
|
||||
break
|
||||
if match:
|
||||
|
|
@ -418,6 +418,8 @@ class AutomationObjects:
|
|||
if not self.checkRbacPermission(AutomationDefinition, "update", automationId):
|
||||
raise PermissionError(f"No permission to modify automation {automationId}")
|
||||
|
||||
automationData.pop("executionLogs", None)
|
||||
|
||||
# If deactivating: immediately remove scheduler job (don't rely on async callback)
|
||||
isBeingDeactivated = "active" in automationData and not automationData["active"]
|
||||
if isBeingDeactivated:
|
||||
|
|
|
|||
|
|
@ -27,11 +27,6 @@ UI_OBJECTS = [
|
|||
"label": {"en": "Templates", "de": "Vorlagen", "fr": "Modèles"},
|
||||
"meta": {"area": "templates"}
|
||||
},
|
||||
{
|
||||
"objectKey": "ui.feature.automation.logs",
|
||||
"label": {"en": "Execution Logs", "de": "Ausführungsprotokolle", "fr": "Journaux d'exécution"},
|
||||
"meta": {"area": "logs"}
|
||||
},
|
||||
]
|
||||
|
||||
# Resource Objects for RBAC catalog
|
||||
|
|
|
|||
|
|
@ -109,6 +109,26 @@ def get_automations(
|
|||
detail=f"Error getting automations: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def get_automation_filter_values(
|
||||
request: Request,
|
||||
column: str = Query(..., description="Column key"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> list:
|
||||
"""Return distinct filter values for a column in automations."""
|
||||
try:
|
||||
from modules.routes.routeDataUsers import _handleFilterValuesRequest
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None)
|
||||
result = chatInterface.getAllAutomationDefinitions(pagination=None)
|
||||
items = result if isinstance(result, list) else [r if isinstance(r, dict) else r.model_dump() if hasattr(r, 'model_dump') else r for r in result]
|
||||
return _handleFilterValuesRequest(items, column, pagination)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting filter values for automations: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("", response_model=AutomationDefinition)
|
||||
@limiter.limit("10/minute")
|
||||
def create_automation(
|
||||
|
|
@ -765,6 +785,7 @@ def update_automation(
|
|||
try:
|
||||
chatInterface = getAutomationInterface(context.user, mandateId=str(context.mandateId) if context.mandateId else None)
|
||||
automationData = automation.model_dump()
|
||||
automationData.pop("executionLogs", None)
|
||||
updated = chatInterface.updateAutomationDefinition(automationId, automationData)
|
||||
return updated
|
||||
except HTTPException:
|
||||
|
|
@ -1017,6 +1038,30 @@ def get_db_templates(
|
|||
)
|
||||
|
||||
|
||||
@templateRouter.get("/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def get_template_filter_values(
|
||||
request: Request,
|
||||
column: str = Query(..., description="Column key"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> list:
|
||||
"""Return distinct filter values for a column in automation templates."""
|
||||
try:
|
||||
from modules.routes.routeDataUsers import _handleFilterValuesRequest
|
||||
chatInterface = getAutomationInterface(
|
||||
context.user,
|
||||
mandateId=str(context.mandateId) if context.mandateId else None,
|
||||
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None
|
||||
)
|
||||
result = chatInterface.getAllAutomationTemplates(pagination=None)
|
||||
items = [r if isinstance(r, dict) else r.model_dump() if hasattr(r, 'model_dump') else r for r in result]
|
||||
return _handleFilterValuesRequest(items, column, pagination)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting filter values for automation templates: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@templateRouter.get("/attributes", response_model=Dict[str, Any])
|
||||
def get_template_attributes(
|
||||
request: Request
|
||||
|
|
|
|||
2
modules/features/automation2/__init__.py
Normal file
2
modules/features/automation2/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# Automation2 feature - n8n-style flow automation (backup/parallel to legacy automation)
|
||||
159
modules/features/automation2/datamodelFeatureAutomation2.py
Normal file
159
modules/features/automation2/datamodelFeatureAutomation2.py
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""Automation2 models: Automation2Workflow, Automation2WorkflowRun, Automation2HumanTask."""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from modules.shared.attributeUtils import registerModelLabels
|
||||
import uuid
|
||||
|
||||
|
||||
class Automation2Workflow(BaseModel):
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
description="Primary key",
|
||||
json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False},
|
||||
)
|
||||
mandateId: str = Field(
|
||||
description="Mandate ID",
|
||||
json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False},
|
||||
)
|
||||
featureInstanceId: str = Field(
|
||||
description="Feature instance ID",
|
||||
json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False},
|
||||
)
|
||||
label: str = Field(
|
||||
description="User-friendly workflow name",
|
||||
json_schema_extra={"frontend_type": "text", "frontend_required": True},
|
||||
)
|
||||
graph: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Graph with nodes and connections (incl. node parameters)",
|
||||
json_schema_extra={"frontend_type": "textarea", "frontend_required": True},
|
||||
)
|
||||
active: bool = Field(
|
||||
default=True,
|
||||
description="Whether workflow is active",
|
||||
json_schema_extra={"frontend_type": "checkbox", "frontend_required": False},
|
||||
)
|
||||
|
||||
|
||||
registerModelLabels(
|
||||
"Automation2Workflow",
|
||||
{"en": "Automation2 Workflow", "de": "Automation2 Workflow", "fr": "Workflow Automation2"},
|
||||
{
|
||||
"id": {"en": "ID", "de": "ID", "fr": "ID"},
|
||||
"mandateId": {"en": "Mandate ID", "de": "Mandanten-ID", "fr": "ID du mandat"},
|
||||
"featureInstanceId": {"en": "Feature Instance ID", "de": "Feature-Instanz-ID", "fr": "ID instance"},
|
||||
"label": {"en": "Label", "de": "Bezeichnung", "fr": "Libellé"},
|
||||
"graph": {"en": "Graph", "de": "Graph", "fr": "Graphe"},
|
||||
"active": {"en": "Active", "de": "Aktiv", "fr": "Actif"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class Automation2WorkflowRun(BaseModel):
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
description="Primary key",
|
||||
json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False},
|
||||
)
|
||||
workflowId: str = Field(
|
||||
description="Workflow ID",
|
||||
json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True},
|
||||
)
|
||||
status: str = Field(
|
||||
default="running",
|
||||
description="Status: running|paused|completed|failed",
|
||||
json_schema_extra={"frontend_type": "text", "frontend_required": False},
|
||||
)
|
||||
nodeOutputs: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Outputs from executed nodes",
|
||||
json_schema_extra={"frontend_type": "textarea", "frontend_required": False},
|
||||
)
|
||||
currentNodeId: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Node ID when paused (human task)",
|
||||
json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False},
|
||||
)
|
||||
context: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Context for resume (connectionMap, inputSources, etc.)",
|
||||
json_schema_extra={"frontend_type": "textarea", "frontend_required": False},
|
||||
)
|
||||
|
||||
|
||||
registerModelLabels(
|
||||
"Automation2WorkflowRun",
|
||||
{"en": "Automation2 Workflow Run", "de": "Automation2 Workflow-Ausführung", "fr": "Exécution workflow"},
|
||||
{
|
||||
"id": {"en": "ID", "de": "ID", "fr": "ID"},
|
||||
"workflowId": {"en": "Workflow ID", "de": "Workflow-ID", "fr": "ID workflow"},
|
||||
"status": {"en": "Status", "de": "Status", "fr": "Statut"},
|
||||
"nodeOutputs": {"en": "Node Outputs", "de": "Node-Ausgaben", "fr": "Sorties nœuds"},
|
||||
"currentNodeId": {"en": "Current Node", "de": "Aktueller Knoten", "fr": "Nœud actuel"},
|
||||
"context": {"en": "Context", "de": "Kontext", "fr": "Contexte"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class Automation2HumanTask(BaseModel):
|
||||
id: str = Field(
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
description="Primary key",
|
||||
json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False},
|
||||
)
|
||||
runId: str = Field(
|
||||
description="Workflow run ID",
|
||||
json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True},
|
||||
)
|
||||
workflowId: str = Field(
|
||||
description="Workflow ID",
|
||||
json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True},
|
||||
)
|
||||
nodeId: str = Field(
|
||||
description="Node ID in the graph",
|
||||
json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True},
|
||||
)
|
||||
nodeType: str = Field(
|
||||
description="Node type: form|approval|upload|comment|review|selection|confirmation",
|
||||
json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True},
|
||||
)
|
||||
config: Dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Node config (form schema, approval text, etc.)",
|
||||
json_schema_extra={"frontend_type": "textarea", "frontend_required": False},
|
||||
)
|
||||
assigneeId: Optional[str] = Field(
|
||||
default=None,
|
||||
description="User ID assigned to complete the task",
|
||||
json_schema_extra={"frontend_type": "text", "frontend_readonly": False, "frontend_required": False},
|
||||
)
|
||||
status: str = Field(
|
||||
default="pending",
|
||||
description="Status: pending|completed|rejected",
|
||||
json_schema_extra={"frontend_type": "text", "frontend_required": False},
|
||||
)
|
||||
result: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Task result (form data, approval decision, etc.)",
|
||||
json_schema_extra={"frontend_type": "textarea", "frontend_required": False},
|
||||
)
|
||||
|
||||
|
||||
registerModelLabels(
|
||||
"Automation2HumanTask",
|
||||
{"en": "Automation2 Human Task", "de": "Automation2 Benutzer-Aufgabe", "fr": "Tâche utilisateur"},
|
||||
{
|
||||
"id": {"en": "ID", "de": "ID", "fr": "ID"},
|
||||
"runId": {"en": "Run ID", "de": "Lauf-ID", "fr": "ID exécution"},
|
||||
"workflowId": {"en": "Workflow ID", "de": "Workflow-ID", "fr": "ID workflow"},
|
||||
"nodeId": {"en": "Node ID", "de": "Knoten-ID", "fr": "ID nœud"},
|
||||
"nodeType": {"en": "Node Type", "de": "Knotentyp", "fr": "Type nœud"},
|
||||
"config": {"en": "Config", "de": "Konfiguration", "fr": "Configuration"},
|
||||
"assigneeId": {"en": "Assignee", "de": "Zugewiesen an", "fr": "Assigné à"},
|
||||
"status": {"en": "Status", "de": "Status", "fr": "Statut"},
|
||||
"result": {"en": "Result", "de": "Ergebnis", "fr": "Résultat"},
|
||||
},
|
||||
)
|
||||
268
modules/features/automation2/emailPoller.py
Normal file
268
modules/features/automation2/emailPoller.py
Normal file
|
|
@ -0,0 +1,268 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""
|
||||
Background email poller for automation2.
|
||||
Checks paused runs waiting for email (email.checkEmail node) and resumes when a new matching email arrives.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Job ID for scheduler
|
||||
POLLER_JOB_ID = "automation2_email_poller"
|
||||
POLL_INTERVAL_MINUTES = 2
|
||||
|
||||
|
||||
async def _pollEmailWaits(eventUser) -> None:
|
||||
"""
|
||||
Poll for new emails for runs waiting on email.checkEmail.
|
||||
Uses eventUser for DB access; loads owner user for each run to call readEmails.
|
||||
Stops the poller when no runs are waiting.
|
||||
"""
|
||||
try:
|
||||
from modules.features.automation2.interfaceFeatureAutomation2 import getAutomation2Interface
|
||||
from modules.features.automation2.mainAutomation2 import getAutomation2Services
|
||||
from modules.workflows.automation2.executionEngine import executeGraph
|
||||
from modules.workflows.processing.shared.methodDiscovery import discoverMethods
|
||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||
|
||||
root = getRootInterface()
|
||||
if not root:
|
||||
logger.warning("Email poller: root interface not available")
|
||||
return
|
||||
# Use eventUser - getRunsWaitingForEmail queries by status only
|
||||
a2 = getAutomation2Interface(eventUser, mandateId="", featureInstanceId="")
|
||||
runs = a2.getRunsWaitingForEmail()
|
||||
if not runs:
|
||||
# No workflows waiting for email - stop the poller
|
||||
stop(eventUser)
|
||||
return
|
||||
logger.info("Automation2 email poller: checking %d run(s) waiting for email", len(runs))
|
||||
|
||||
for run in runs:
|
||||
run_id = run.get("id")
|
||||
workflow_id = run.get("workflowId")
|
||||
context = run.get("context") or {}
|
||||
wait_config = context.get("waitConfig") or {}
|
||||
node_id = run.get("currentNodeId") or context.get("waitConfig", {}).get("_nodeId")
|
||||
owner_id = context.get("ownerId")
|
||||
mandate_id = context.get("mandateId")
|
||||
instance_id = context.get("instanceId")
|
||||
last_checked = context.get("lastCheckedAt")
|
||||
|
||||
if not owner_id or not mandate_id or not instance_id or not workflow_id or not node_id:
|
||||
logger.warning("Email wait run %s missing ownerId/mandateId/instanceId/workflowId/nodeId - skipping", run_id)
|
||||
continue
|
||||
|
||||
# First poll: use pausedAt (or now - 5 min) as baseline so we don't miss emails
|
||||
# that arrived between pause and first poll
|
||||
if last_checked is None:
|
||||
paused_at = context.get("pausedAt")
|
||||
if paused_at:
|
||||
baseline = paused_at
|
||||
else:
|
||||
# Fallback: look back 5 minutes for runs created before pausedAt existed
|
||||
baseline = (datetime.now(timezone.utc) - timedelta(minutes=5)).strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
last_checked = baseline
|
||||
|
||||
# Load owner user (root interface has broad access)
|
||||
owner = root.getUser(owner_id) if hasattr(root, "getUser") else None
|
||||
if not owner:
|
||||
logger.warning("Email wait run %s: owner user %s not found", run_id, owner_id)
|
||||
continue
|
||||
|
||||
# Get workflow (need scoped interface for mandate/instance)
|
||||
a2_scoped = getAutomation2Interface(eventUser, mandateId=mandate_id, featureInstanceId=instance_id)
|
||||
wf = a2_scoped.getWorkflow(workflow_id)
|
||||
if not wf or not wf.get("graph"):
|
||||
logger.warning("Email wait run %s: workflow %s not found or has no graph", run_id, workflow_id)
|
||||
continue
|
||||
|
||||
# Only process runs paused at email.checkEmail – searchEmail never waits, it searches all immediately
|
||||
nodes = (wf.get("graph") or {}).get("nodes") or []
|
||||
paused_node = next((n for n in nodes if n.get("id") == node_id), None)
|
||||
if paused_node and paused_node.get("type") == "email.searchEmail":
|
||||
logger.warning("Email wait run %s: paused at email.searchEmail (should not wait) – skipping", run_id)
|
||||
continue
|
||||
|
||||
services = getAutomation2Services(owner, mandateId=mandate_id, featureInstanceId=instance_id)
|
||||
discoverMethods(services)
|
||||
|
||||
# Build filter with receivedDateTime – only emails received at or after baseline (new emails)
|
||||
base_filter = wait_config.get("filter") or ""
|
||||
dt_filter = f"receivedDateTime ge {last_checked}"
|
||||
combined_filter = f"({base_filter}) and {dt_filter}" if base_filter else dt_filter
|
||||
logger.debug("Email wait run %s: fetch filter (new emails only) %s", run_id, combined_filter)
|
||||
|
||||
from modules.workflows.processing.core.actionExecutor import ActionExecutor
|
||||
executor = ActionExecutor(services)
|
||||
params = {
|
||||
"connectionReference": wait_config.get("connectionReference"),
|
||||
"folder": wait_config.get("folder", "Inbox"),
|
||||
"limit": min(int(wait_config.get("limit", 10)), 50),
|
||||
"filter": combined_filter,
|
||||
}
|
||||
|
||||
try:
|
||||
result = await executor.executeAction("outlook", "readEmails", params)
|
||||
except Exception as e:
|
||||
logger.warning("Email wait run %s: readEmails failed: %s", run_id, e)
|
||||
continue
|
||||
|
||||
# readEmails always returns 1 document (JSON wrapper); check actual email count
|
||||
email_count = 0
|
||||
if result and result.documents:
|
||||
doc = result.documents[0]
|
||||
meta = getattr(doc, "validationMetadata", None)
|
||||
if not meta and isinstance(doc, dict):
|
||||
meta = doc.get("validationMetadata")
|
||||
if meta and isinstance(meta, dict):
|
||||
email_count = int(meta.get("emailCount", 0))
|
||||
else:
|
||||
try:
|
||||
data = json.loads(getattr(doc, "documentData", "") or "{}")
|
||||
email_count = len(data.get("emails", {}).get("emails", []))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not result or not result.success or email_count == 0:
|
||||
# No new emails - persist lastCheckedAt so next poll uses this as baseline
|
||||
now_iso = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
ctx = dict(context)
|
||||
ctx["lastCheckedAt"] = now_iso
|
||||
a2_scoped.updateRun(run_id, context=ctx)
|
||||
continue
|
||||
|
||||
# Only pass NEW emails (receivedDateTime >= last_checked) – filter server-side as safeguard
|
||||
doc = result.documents[0]
|
||||
raw_data = json.loads(getattr(doc, "documentData", "") or "{}")
|
||||
emails_data = raw_data.get("emails", {})
|
||||
all_emails = emails_data.get("emails", [])
|
||||
new_emails = [
|
||||
e for e in all_emails
|
||||
if (e.get("receivedDateTime") or "") >= last_checked
|
||||
]
|
||||
if not new_emails:
|
||||
now_iso = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
ctx = dict(context)
|
||||
ctx["lastCheckedAt"] = now_iso
|
||||
a2_scoped.updateRun(run_id, context=ctx)
|
||||
continue
|
||||
|
||||
# Rebuild document with only new emails for downstream nodes
|
||||
result_data = dict(raw_data)
|
||||
result_data["emails"] = dict(emails_data)
|
||||
result_data["emails"]["emails"] = new_emails
|
||||
result_data["emails"]["count"] = len(new_emails)
|
||||
|
||||
from modules.datamodels.datamodelChat import ActionDocument
|
||||
filtered_doc = ActionDocument(
|
||||
documentName=getattr(doc, "documentName", "outlook_emails.json"),
|
||||
documentData=json.dumps(result_data, indent=2),
|
||||
mimeType=getattr(doc, "mimeType", "application/json"),
|
||||
validationMetadata={**(getattr(doc, "validationMetadata") or {}), "emailCount": len(new_emails)},
|
||||
)
|
||||
|
||||
# Build node output in same format as ActionNodeExecutor for readEmails
|
||||
node_output = {
|
||||
"success": result.success,
|
||||
"error": result.error,
|
||||
"documents": [filtered_doc.model_dump() if hasattr(filtered_doc, "model_dump") else filtered_doc],
|
||||
"data": result.model_dump() if hasattr(result, "model_dump") else {"success": result.success, "error": result.error},
|
||||
}
|
||||
|
||||
# Update lastCheckedAt before resume
|
||||
now_iso = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
ctx = dict(context)
|
||||
ctx["lastCheckedAt"] = now_iso
|
||||
a2_scoped.updateRun(run_id, status="running", context=ctx)
|
||||
|
||||
node_outputs = dict(run.get("nodeOutputs") or {})
|
||||
node_outputs[node_id] = node_output
|
||||
|
||||
logger.info("Email wait run %s: found new email, resuming from node %s", run_id, node_id)
|
||||
|
||||
resume_result = await executeGraph(
|
||||
graph=wf["graph"],
|
||||
services=services,
|
||||
workflowId=workflow_id,
|
||||
instanceId=instance_id,
|
||||
userId=owner_id,
|
||||
mandateId=mandate_id,
|
||||
automation2_interface=a2_scoped,
|
||||
initialNodeOutputs=node_outputs,
|
||||
startAfterNodeId=node_id,
|
||||
runId=run_id,
|
||||
)
|
||||
|
||||
if resume_result.get("success"):
|
||||
logger.info("Email wait run %s: completed successfully", run_id)
|
||||
elif resume_result.get("paused"):
|
||||
logger.info("Email wait run %s: paused again (e.g. human task)", run_id)
|
||||
else:
|
||||
logger.warning("Email wait run %s: failed: %s", run_id, resume_result.get("error"))
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Email poller failed: %s", e)
|
||||
|
||||
|
||||
def _runPollSync(ev_user):
|
||||
"""Sync job for scheduler - runs async poll. Thread-safe for both main loop and worker threads."""
|
||||
try:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
# Already in event loop - schedule, don't block
|
||||
loop.create_task(_pollEmailWaits(ev_user))
|
||||
except RuntimeError:
|
||||
# No running loop (worker thread) - run in new loop
|
||||
asyncio.run(_pollEmailWaits(ev_user))
|
||||
except Exception as e:
|
||||
logger.exception("Automation2 email poller job failed: %s", e)
|
||||
|
||||
|
||||
def ensureRunning(eventUser) -> bool:
|
||||
"""Start the poller if not already running. Called when a run pauses for email.checkEmail."""
|
||||
return start(eventUser)
|
||||
|
||||
|
||||
def start(eventUser) -> bool:
|
||||
"""Register the email poller interval job."""
|
||||
if not eventUser:
|
||||
logger.warning("Automation2 email poller: no eventUser, not registering")
|
||||
return False
|
||||
try:
|
||||
from modules.shared.eventManagement import eventManager
|
||||
|
||||
# Use sync wrapper - APScheduler may run jobs in thread pool where async doesn't work
|
||||
job_func = lambda: _runPollSync(eventUser)
|
||||
eventManager.registerInterval(
|
||||
POLLER_JOB_ID,
|
||||
job_func,
|
||||
seconds=0,
|
||||
minutes=POLL_INTERVAL_MINUTES,
|
||||
hours=0,
|
||||
)
|
||||
logger.info("Automation2 email poller started (interval=%s min)", POLL_INTERVAL_MINUTES)
|
||||
# Run once immediately so we don't wait 2 minutes for the first check
|
||||
_runPollSync(eventUser)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Failed to register automation2 email poller: %s", e)
|
||||
return False
|
||||
|
||||
|
||||
def stop(eventUser) -> bool:
|
||||
"""Remove the email poller job."""
|
||||
try:
|
||||
from modules.shared.eventManagement import eventManager
|
||||
eventManager.remove(POLLER_JOB_ID)
|
||||
logger.info("Automation2 email poller removed")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning("Error removing automation2 email poller: %s", e)
|
||||
return True
|
||||
311
modules/features/automation2/interfaceFeatureAutomation2.py
Normal file
311
modules/features/automation2/interfaceFeatureAutomation2.py
Normal file
|
|
@ -0,0 +1,311 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""
|
||||
Interface for Automation2 feature - Workflows, Runs, Human Tasks.
|
||||
Uses PostgreSQL poweron_automation2 database.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
|
||||
def _make_json_serializable(obj: Any) -> Any:
|
||||
"""
|
||||
Recursively convert bytes to base64 strings so structures can be JSON-serialized
|
||||
for storage in JSONB columns.
|
||||
"""
|
||||
if isinstance(obj, bytes):
|
||||
return base64.b64encode(obj).decode("ascii")
|
||||
if isinstance(obj, dict):
|
||||
return {k: _make_json_serializable(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_make_json_serializable(v) for v in obj]
|
||||
return obj
|
||||
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.features.automation2.datamodelFeatureAutomation2 import (
|
||||
Automation2Workflow,
|
||||
Automation2WorkflowRun,
|
||||
Automation2HumanTask,
|
||||
)
|
||||
from modules.connectors.connectorDbPostgre import DatabaseConnector
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def getAutomation2Interface(
|
||||
currentUser: User,
|
||||
mandateId: str,
|
||||
featureInstanceId: str,
|
||||
) -> "Automation2Objects":
|
||||
"""Factory for Automation2 interface with user context."""
|
||||
return Automation2Objects(
|
||||
currentUser=currentUser,
|
||||
mandateId=mandateId,
|
||||
featureInstanceId=featureInstanceId,
|
||||
)
|
||||
|
||||
|
||||
class Automation2Objects:
|
||||
"""Interface for Automation2 database operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
currentUser: User,
|
||||
mandateId: str,
|
||||
featureInstanceId: str,
|
||||
):
|
||||
self.currentUser = currentUser
|
||||
self.mandateId = mandateId
|
||||
self.featureInstanceId = featureInstanceId
|
||||
self.userId = currentUser.id if currentUser else None
|
||||
self._init_db()
|
||||
if hasattr(self.db, "updateContext") and self.userId:
|
||||
self.db.updateContext(self.userId)
|
||||
|
||||
def _init_db(self):
|
||||
"""Initialize database connection to poweron_automation2."""
|
||||
dbHost = APP_CONFIG.get("DB_HOST", "localhost")
|
||||
dbDatabase = "poweron_automation2"
|
||||
dbUser = APP_CONFIG.get("DB_USER")
|
||||
dbPassword = APP_CONFIG.get("DB_PASSWORD_SECRET") or APP_CONFIG.get("DB_PASSWORD")
|
||||
dbPort = int(APP_CONFIG.get("DB_PORT", 5432))
|
||||
self.db = DatabaseConnector(
|
||||
dbHost=dbHost,
|
||||
dbDatabase=dbDatabase,
|
||||
dbUser=dbUser,
|
||||
dbPassword=dbPassword,
|
||||
dbPort=dbPort,
|
||||
userId=self.userId,
|
||||
)
|
||||
logger.debug("Automation2 database initialized for user %s", self.userId)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Workflow CRUD
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def getWorkflows(self) -> List[Dict[str, Any]]:
|
||||
"""Get all workflows for this mandate and feature instance."""
|
||||
if not self.db._ensureTableExists(Automation2Workflow):
|
||||
return []
|
||||
records = self.db.getRecordset(
|
||||
Automation2Workflow,
|
||||
recordFilter={
|
||||
"mandateId": self.mandateId,
|
||||
"featureInstanceId": self.featureInstanceId,
|
||||
},
|
||||
)
|
||||
return [dict(r) for r in records] if records else []
|
||||
|
||||
def getWorkflow(self, workflowId: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a single workflow by ID."""
|
||||
if not self.db._ensureTableExists(Automation2Workflow):
|
||||
return None
|
||||
records = self.db.getRecordset(
|
||||
Automation2Workflow,
|
||||
recordFilter={
|
||||
"id": workflowId,
|
||||
"mandateId": self.mandateId,
|
||||
"featureInstanceId": self.featureInstanceId,
|
||||
},
|
||||
)
|
||||
if not records:
|
||||
return None
|
||||
return dict(records[0])
|
||||
|
||||
def createWorkflow(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Create a new workflow."""
|
||||
if "id" not in data or not data.get("id"):
|
||||
data["id"] = str(uuid.uuid4())
|
||||
data["mandateId"] = self.mandateId
|
||||
data["featureInstanceId"] = self.featureInstanceId
|
||||
created = self.db.recordCreate(Automation2Workflow, data)
|
||||
return dict(created)
|
||||
|
||||
def updateWorkflow(self, workflowId: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""Update an existing workflow."""
|
||||
existing = self.getWorkflow(workflowId)
|
||||
if not existing:
|
||||
return None
|
||||
# Don't overwrite mandateId/featureInstanceId
|
||||
data.pop("mandateId", None)
|
||||
data.pop("featureInstanceId", None)
|
||||
updated = self.db.recordModify(Automation2Workflow, workflowId, data)
|
||||
return dict(updated)
|
||||
|
||||
def deleteWorkflow(self, workflowId: str) -> bool:
|
||||
"""Delete a workflow."""
|
||||
existing = self.getWorkflow(workflowId)
|
||||
if not existing:
|
||||
return False
|
||||
self.db.recordDelete(Automation2Workflow, workflowId)
|
||||
return True
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Workflow Runs
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def createRun(self, workflowId: str, nodeOutputs: Dict = None, context: Dict = None) -> Dict[str, Any]:
|
||||
"""Create a new workflow run."""
|
||||
data = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"workflowId": workflowId,
|
||||
"status": "running",
|
||||
"nodeOutputs": _make_json_serializable(nodeOutputs or {}),
|
||||
"currentNodeId": None,
|
||||
"context": context or {},
|
||||
}
|
||||
created = self.db.recordCreate(Automation2WorkflowRun, data)
|
||||
return dict(created)
|
||||
|
||||
def getRun(self, runId: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a run by ID."""
|
||||
if not self.db._ensureTableExists(Automation2WorkflowRun):
|
||||
return None
|
||||
records = self.db.getRecordset(
|
||||
Automation2WorkflowRun,
|
||||
recordFilter={"id": runId},
|
||||
)
|
||||
if not records:
|
||||
return None
|
||||
return dict(records[0])
|
||||
|
||||
def updateRun(
|
||||
self,
|
||||
runId: str,
|
||||
status: str = None,
|
||||
nodeOutputs: Dict = None,
|
||||
currentNodeId: str = None,
|
||||
context: Dict = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Update a run."""
|
||||
run = self.getRun(runId)
|
||||
if not run:
|
||||
return None
|
||||
updates = {}
|
||||
if status is not None:
|
||||
updates["status"] = status
|
||||
if nodeOutputs is not None:
|
||||
updates["nodeOutputs"] = _make_json_serializable(nodeOutputs)
|
||||
if currentNodeId is not None:
|
||||
updates["currentNodeId"] = currentNodeId
|
||||
if context is not None:
|
||||
updates["context"] = context
|
||||
if not updates:
|
||||
return run
|
||||
updated = self.db.recordModify(Automation2WorkflowRun, runId, updates)
|
||||
return dict(updated)
|
||||
|
||||
def getRunsByWorkflow(self, workflowId: str) -> List[Dict[str, Any]]:
|
||||
"""Get all runs for a workflow."""
|
||||
if not self.db._ensureTableExists(Automation2WorkflowRun):
|
||||
return []
|
||||
records = self.db.getRecordset(
|
||||
Automation2WorkflowRun,
|
||||
recordFilter={"workflowId": workflowId},
|
||||
)
|
||||
return [dict(r) for r in records] if records else []
|
||||
|
||||
def getRunsWaitingForEmail(self) -> List[Dict[str, Any]]:
|
||||
"""Get all paused runs waiting for a new email (for background poller)."""
|
||||
if not self.db._ensureTableExists(Automation2WorkflowRun):
|
||||
return []
|
||||
records = self.db.getRecordset(
|
||||
Automation2WorkflowRun,
|
||||
recordFilter={"status": "paused"},
|
||||
)
|
||||
if not records:
|
||||
return []
|
||||
result = []
|
||||
for r in records:
|
||||
rec = dict(r)
|
||||
ctx = rec.get("context") or {}
|
||||
if ctx.get("waitReason") == "email":
|
||||
result.append(rec)
|
||||
return result
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Human Tasks
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def createTask(
|
||||
self,
|
||||
runId: str,
|
||||
workflowId: str,
|
||||
nodeId: str,
|
||||
nodeType: str,
|
||||
config: Dict[str, Any],
|
||||
assigneeId: str = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Create a human task."""
|
||||
data = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"runId": runId,
|
||||
"workflowId": workflowId,
|
||||
"nodeId": nodeId,
|
||||
"nodeType": nodeType,
|
||||
"config": config,
|
||||
"assigneeId": assigneeId,
|
||||
"status": "pending",
|
||||
"result": None,
|
||||
}
|
||||
created = self.db.recordCreate(Automation2HumanTask, data)
|
||||
return dict(created)
|
||||
|
||||
def getTask(self, taskId: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a task by ID."""
|
||||
if not self.db._ensureTableExists(Automation2HumanTask):
|
||||
return None
|
||||
records = self.db.getRecordset(
|
||||
Automation2HumanTask,
|
||||
recordFilter={"id": taskId},
|
||||
)
|
||||
if not records:
|
||||
return None
|
||||
return dict(records[0])
|
||||
|
||||
def updateTask(self, taskId: str, status: str = None, result: Dict = None) -> Optional[Dict[str, Any]]:
|
||||
"""Update a task (e.g. complete with result)."""
|
||||
task = self.getTask(taskId)
|
||||
if not task:
|
||||
return None
|
||||
updates = {}
|
||||
if status is not None:
|
||||
updates["status"] = status
|
||||
if result is not None:
|
||||
updates["result"] = result
|
||||
if not updates:
|
||||
return task
|
||||
updated = self.db.recordModify(Automation2HumanTask, taskId, updates)
|
||||
return dict(updated)
|
||||
|
||||
def getTasks(
|
||||
self,
|
||||
workflowId: str = None,
|
||||
runId: str = None,
|
||||
status: str = None,
|
||||
assigneeId: str = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get tasks with optional filters. AssigneeId filters to that user; None returns all."""
|
||||
if not self.db._ensureTableExists(Automation2HumanTask):
|
||||
return []
|
||||
rf = {}
|
||||
if workflowId:
|
||||
rf["workflowId"] = workflowId
|
||||
if runId:
|
||||
rf["runId"] = runId
|
||||
if status:
|
||||
rf["status"] = status
|
||||
if assigneeId:
|
||||
rf["assigneeId"] = assigneeId
|
||||
records = self.db.getRecordset(
|
||||
Automation2HumanTask,
|
||||
recordFilter=rf if rf else None,
|
||||
)
|
||||
items = [dict(r) for r in records] if records else []
|
||||
workflows = {w["id"]: w for w in self.getWorkflows()}
|
||||
filtered = [t for t in items if t.get("workflowId") in workflows]
|
||||
return filtered
|
||||
303
modules/features/automation2/mainAutomation2.py
Normal file
303
modules/features/automation2/mainAutomation2.py
Normal file
|
|
@ -0,0 +1,303 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""
|
||||
Automation2 Feature - n8n-style flow automation.
|
||||
Minimal bootstrap for feature instance creation. Build from here.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FEATURE_CODE = "automation2"
|
||||
|
||||
# Services required for automation2 (methodDiscovery, ActionExecutor, etc.)
|
||||
REQUIRED_SERVICES = [
|
||||
{"serviceKey": "chat", "meta": {"usage": "Interfaces, RBAC"}},
|
||||
{"serviceKey": "utils", "meta": {"usage": "Timestamps, utilities"}},
|
||||
{"serviceKey": "ai", "meta": {"usage": "AI nodes"}},
|
||||
{"serviceKey": "extraction", "meta": {"usage": "Workflow method actions"}},
|
||||
{"serviceKey": "sharepoint", "meta": {"usage": "SharePoint actions"}},
|
||||
]
|
||||
FEATURE_LABEL = {"en": "Automation 2", "de": "Automatisierung 2", "fr": "Automatisation 2"}
|
||||
FEATURE_ICON = "mdi-sitemap"
|
||||
|
||||
UI_OBJECTS = [
|
||||
{
|
||||
"objectKey": "ui.feature.automation2.editor",
|
||||
"label": {"en": "Editor", "de": "Editor", "fr": "Éditeur"},
|
||||
"meta": {"area": "editor"}
|
||||
},
|
||||
{
|
||||
"objectKey": "ui.feature.automation2.workflows",
|
||||
"label": {"en": "Workflows", "de": "Workflows", "fr": "Workflows"},
|
||||
"meta": {"area": "workflows"}
|
||||
},
|
||||
{
|
||||
"objectKey": "ui.feature.automation2.workflows-tasks",
|
||||
"label": {"en": "Tasks", "de": "Tasks", "fr": "Tâches"},
|
||||
"meta": {"area": "tasks"}
|
||||
},
|
||||
]
|
||||
|
||||
RESOURCE_OBJECTS = [
|
||||
{
|
||||
"objectKey": "resource.feature.automation2.dashboard",
|
||||
"label": {"en": "Access Dashboard", "de": "Dashboard aufrufen", "fr": "Acceder au tableau de bord"},
|
||||
"meta": {"endpoint": "/api/automation2/{instanceId}/info", "method": "GET"}
|
||||
},
|
||||
{
|
||||
"objectKey": "resource.feature.automation2.node-types",
|
||||
"label": {"en": "Get Node Types", "de": "Node-Typen abrufen", "fr": "Obtenir types de nœuds"},
|
||||
"meta": {"endpoint": "/api/automation2/{instanceId}/node-types", "method": "GET"}
|
||||
},
|
||||
{
|
||||
"objectKey": "resource.feature.automation2.execute",
|
||||
"label": {"en": "Execute Workflow", "de": "Workflow ausführen", "fr": "Exécuter le workflow"},
|
||||
"meta": {"endpoint": "/api/automation2/{instanceId}/execute", "method": "POST"}
|
||||
},
|
||||
]
|
||||
|
||||
TEMPLATE_ROLES = [
|
||||
{
|
||||
"roleLabel": "automation2-user",
|
||||
"description": {
|
||||
"en": "Automation2 User - Use automation2 flow builder",
|
||||
"de": "Automation2 Benutzer - Flow-Builder nutzen",
|
||||
"fr": "Utilisateur Automation2 - Utiliser le flow builder"
|
||||
},
|
||||
"accessRules": [
|
||||
{"context": "UI", "item": "ui.feature.automation2.editor", "view": True},
|
||||
{"context": "UI", "item": "ui.feature.automation2.workflows", "view": True},
|
||||
{"context": "UI", "item": "ui.feature.automation2.workflows-tasks", "view": True},
|
||||
{"context": "RESOURCE", "item": "resource.feature.automation2.dashboard", "view": True},
|
||||
{"context": "RESOURCE", "item": "resource.feature.automation2.node-types", "view": True},
|
||||
{"context": "RESOURCE", "item": "resource.feature.automation2.execute", "view": True},
|
||||
{"context": "DATA", "item": None, "view": True, "read": "m", "create": "m", "update": "m", "delete": "m"},
|
||||
]
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def getRequiredServiceKeys() -> List[str]:
|
||||
"""Return list of service keys this feature requires."""
|
||||
return [s["serviceKey"] for s in REQUIRED_SERVICES]
|
||||
|
||||
|
||||
def getAutomation2Services(
|
||||
user,
|
||||
mandateId: Optional[str] = None,
|
||||
featureInstanceId: Optional[str] = None,
|
||||
workflow=None,
|
||||
) -> "_Automation2ServiceHub":
|
||||
"""
|
||||
Get a service hub for automation2 using the service center.
|
||||
Used for methodDiscovery (I/O nodes) and execution (ActionExecutor).
|
||||
"""
|
||||
from modules.serviceCenter import getService
|
||||
from modules.serviceCenter.context import ServiceCenterContext
|
||||
|
||||
_workflow = workflow
|
||||
if _workflow is None:
|
||||
_workflow = type(
|
||||
"_Placeholder",
|
||||
(),
|
||||
{"featureCode": FEATURE_CODE, "id": None, "workflowMode": None, "messages": []},
|
||||
)()
|
||||
|
||||
ctx = ServiceCenterContext(
|
||||
user=user,
|
||||
mandate_id=mandateId,
|
||||
feature_instance_id=featureInstanceId,
|
||||
workflow=_workflow,
|
||||
)
|
||||
|
||||
hub = _Automation2ServiceHub()
|
||||
hub.user = user
|
||||
hub.mandateId = mandateId
|
||||
hub.featureInstanceId = featureInstanceId
|
||||
hub._service_context = ctx
|
||||
hub.workflow = workflow
|
||||
hub.featureCode = FEATURE_CODE
|
||||
|
||||
for spec in REQUIRED_SERVICES:
|
||||
key = spec["serviceKey"]
|
||||
try:
|
||||
svc = getService(key, ctx)
|
||||
setattr(hub, key, svc)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not resolve service '{key}' for automation2: {e}")
|
||||
setattr(hub, key, None)
|
||||
|
||||
if hub.chat:
|
||||
hub.interfaceDbApp = getattr(hub.chat, "interfaceDbApp", None)
|
||||
hub.interfaceDbComponent = getattr(hub.chat, "interfaceDbComponent", None)
|
||||
hub.interfaceDbChat = getattr(hub.chat, "interfaceDbChat", None)
|
||||
hub.rbac = getattr(hub.interfaceDbApp, "rbac", None) if getattr(hub, "interfaceDbApp", None) else None
|
||||
|
||||
return hub
|
||||
|
||||
|
||||
class _Automation2ServiceHub:
|
||||
"""Lightweight hub for automation2 (methodDiscovery, execution)."""
|
||||
|
||||
user = None
|
||||
mandateId = None
|
||||
featureInstanceId = None
|
||||
_service_context = None
|
||||
workflow = None
|
||||
featureCode = FEATURE_CODE
|
||||
interfaceDbApp = None
|
||||
interfaceDbComponent = None
|
||||
interfaceDbChat = None
|
||||
rbac = None
|
||||
chat = None
|
||||
ai = None
|
||||
utils = None
|
||||
extraction = None
|
||||
sharepoint = None
|
||||
|
||||
|
||||
async def onStart(eventUser) -> None:
|
||||
"""Feature startup. Email poller is started on-demand when a run pauses for email.checkEmail."""
|
||||
|
||||
|
||||
async def onStop(eventUser) -> None:
|
||||
"""Feature shutdown - remove email poller if running."""
|
||||
from modules.features.automation2.emailPoller import stop as stopEmailPoller
|
||||
stopEmailPoller(eventUser)
|
||||
|
||||
|
||||
def getFeatureDefinition() -> Dict[str, Any]:
|
||||
"""Return the feature definition for registration."""
|
||||
return {
|
||||
"code": FEATURE_CODE,
|
||||
"label": FEATURE_LABEL,
|
||||
"icon": FEATURE_ICON,
|
||||
"autoCreateInstance": True,
|
||||
}
|
||||
|
||||
|
||||
def getUiObjects() -> List[Dict[str, Any]]:
|
||||
"""Return UI objects for RBAC catalog registration."""
|
||||
return UI_OBJECTS
|
||||
|
||||
|
||||
def getResourceObjects() -> List[Dict[str, Any]]:
|
||||
"""Return resource objects for RBAC catalog registration."""
|
||||
return RESOURCE_OBJECTS
|
||||
|
||||
|
||||
def getTemplateRoles() -> List[Dict[str, Any]]:
|
||||
"""Return template roles for this feature."""
|
||||
return TEMPLATE_ROLES
|
||||
|
||||
|
||||
def registerFeature(catalogService) -> bool:
|
||||
"""Register this feature's RBAC objects in the catalog."""
|
||||
try:
|
||||
for uiObj in UI_OBJECTS:
|
||||
catalogService.registerUiObject(
|
||||
featureCode=FEATURE_CODE,
|
||||
objectKey=uiObj["objectKey"],
|
||||
label=uiObj["label"],
|
||||
meta=uiObj.get("meta")
|
||||
)
|
||||
for resObj in RESOURCE_OBJECTS:
|
||||
catalogService.registerResourceObject(
|
||||
featureCode=FEATURE_CODE,
|
||||
objectKey=resObj["objectKey"],
|
||||
label=resObj["label"],
|
||||
meta=resObj.get("meta")
|
||||
)
|
||||
_syncTemplateRolesToDb()
|
||||
logger.info(f"Feature '{FEATURE_CODE}' registered {len(UI_OBJECTS)} UI objects and {len(RESOURCE_OBJECTS)} resource objects")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register feature '{FEATURE_CODE}': {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _syncTemplateRolesToDb() -> int:
|
||||
"""Sync template roles and their AccessRules to database.
|
||||
Also syncs rules to mandate-specific roles (same roleLabel) so new UI objects
|
||||
become visible after gateway restart without manual role update.
|
||||
"""
|
||||
try:
|
||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||
from modules.datamodels.datamodelRbac import Role
|
||||
|
||||
rootInterface = getRootInterface()
|
||||
existingRoles = rootInterface.getRolesByFeatureCode(FEATURE_CODE)
|
||||
existingLabels = {r.roleLabel: str(r.id) for r in existingRoles if r.mandateId is None}
|
||||
created = 0
|
||||
|
||||
for template in TEMPLATE_ROLES:
|
||||
roleLabel = template["roleLabel"]
|
||||
if roleLabel in existingLabels:
|
||||
roleId = existingLabels[roleLabel]
|
||||
else:
|
||||
newRole = Role(
|
||||
roleLabel=roleLabel,
|
||||
description=template.get("description", {}),
|
||||
featureCode=FEATURE_CODE,
|
||||
mandateId=None,
|
||||
featureInstanceId=None,
|
||||
isSystemRole=False
|
||||
)
|
||||
rec = rootInterface.db.recordCreate(Role, newRole.model_dump())
|
||||
roleId = rec.get("id")
|
||||
created += 1
|
||||
logger.info(f"Created template role '{roleLabel}' for {FEATURE_CODE}")
|
||||
|
||||
_ensureAccessRulesForRole(rootInterface, roleId, template.get("accessRules", []))
|
||||
|
||||
# Sync same rules to mandate-specific roles (so Workflows & Tasks etc. appear in sidebar)
|
||||
for r in existingRoles:
|
||||
if r.mandateId and r.roleLabel == roleLabel:
|
||||
added = _ensureAccessRulesForRole(
|
||||
rootInterface, str(r.id), template.get("accessRules", [])
|
||||
)
|
||||
if added:
|
||||
logger.debug(f"Added {added} access rules to mandate role {r.id}")
|
||||
return created
|
||||
except Exception as e:
|
||||
logger.warning(f"Template role sync for {FEATURE_CODE}: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def _ensureAccessRulesForRole(rootInterface, roleId: str, ruleTemplates: List[Dict[str, Any]]) -> int:
|
||||
"""Ensure AccessRules exist for a role based on templates."""
|
||||
from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext
|
||||
|
||||
existingRules = rootInterface.getAccessRulesByRole(roleId)
|
||||
existingSignatures = {
|
||||
(r.context.value if r.context else None, r.item)
|
||||
for r in existingRules
|
||||
}
|
||||
created = 0
|
||||
for t in ruleTemplates:
|
||||
context = t.get("context", "UI")
|
||||
item = t.get("item")
|
||||
sig = (context, item)
|
||||
if sig in existingSignatures:
|
||||
continue
|
||||
ctx_enum = (
|
||||
AccessRuleContext.UI if context == "UI" else
|
||||
AccessRuleContext.DATA if context == "DATA" else
|
||||
AccessRuleContext.RESOURCE if context == "RESOURCE" else context
|
||||
)
|
||||
newRule = AccessRule(
|
||||
roleId=roleId,
|
||||
context=ctx_enum,
|
||||
item=item,
|
||||
view=t.get("view", False),
|
||||
read=t.get("read"),
|
||||
create=t.get("create"),
|
||||
update=t.get("update"),
|
||||
delete=t.get("delete"),
|
||||
)
|
||||
rootInterface.db.recordCreate(AccessRule, newRule.model_dump())
|
||||
created += 1
|
||||
return created
|
||||
20
modules/features/automation2/nodeDefinitions/__init__.py
Normal file
20
modules/features/automation2/nodeDefinitions/__init__.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# Node type definitions for automation2 flow builder.
|
||||
|
||||
from .triggers import TRIGGER_NODES
|
||||
from .flow import FLOW_NODES
|
||||
from .data import DATA_NODES
|
||||
from .input import INPUT_NODES
|
||||
from .ai import AI_NODES
|
||||
from .email import EMAIL_NODES
|
||||
from .sharepoint import SHAREPOINT_NODES
|
||||
|
||||
STATIC_NODE_TYPES = (
|
||||
TRIGGER_NODES
|
||||
+ FLOW_NODES
|
||||
+ DATA_NODES
|
||||
+ INPUT_NODES
|
||||
+ AI_NODES
|
||||
+ EMAIL_NODES
|
||||
+ SHAREPOINT_NODES
|
||||
)
|
||||
113
modules/features/automation2/nodeDefinitions/ai.py
Normal file
113
modules/features/automation2/nodeDefinitions/ai.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# AI node definitions - map to methodAi actions.
|
||||
|
||||
AI_NODES = [
|
||||
{
|
||||
"id": "ai.prompt",
|
||||
"category": "ai",
|
||||
"label": {"en": "Prompt", "de": "Prompt", "fr": "Invite"},
|
||||
"description": {"en": "Enter a prompt and AI does something", "de": "Prompt eingeben und KI führt aus", "fr": "Entrer une invite et l'IA exécute"},
|
||||
"parameters": [
|
||||
{"name": "prompt", "type": "string", "required": True, "description": {"en": "AI prompt", "de": "KI-Prompt", "fr": "Invite IA"}},
|
||||
{"name": "resultType", "type": "string", "required": False, "description": {"en": "Output format (txt, json, md, etc.)", "de": "Ausgabeformat", "fr": "Format de sortie"}, "default": "txt"},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"meta": {"icon": "mdi-robot", "color": "#9C27B0"},
|
||||
"_method": "ai",
|
||||
"_action": "process",
|
||||
"_paramMap": {"prompt": "aiPrompt"},
|
||||
},
|
||||
{
|
||||
"id": "ai.webResearch",
|
||||
"category": "ai",
|
||||
"label": {"en": "Web Research", "de": "Web-Recherche", "fr": "Recherche web"},
|
||||
"description": {"en": "Research on the web", "de": "Recherche im Web", "fr": "Recherche sur le web"},
|
||||
"parameters": [
|
||||
{"name": "query", "type": "string", "required": True, "description": {"en": "Research query", "de": "Recherche-Anfrage", "fr": "Requête de recherche"}},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"meta": {"icon": "mdi-magnify", "color": "#9C27B0"},
|
||||
"_method": "ai",
|
||||
"_action": "webResearch",
|
||||
"_paramMap": {"query": "prompt"},
|
||||
},
|
||||
{
|
||||
"id": "ai.summarizeDocument",
|
||||
"category": "ai",
|
||||
"label": {"en": "Summarize Document", "de": "Dokument zusammenfassen", "fr": "Résumer document"},
|
||||
"description": {"en": "Summarize document content", "de": "Dokumentinhalt zusammenfassen", "fr": "Résumer le contenu du document"},
|
||||
"parameters": [
|
||||
{"name": "summaryLength", "type": "string", "required": False, "description": {"en": "Short, medium, or long", "de": "Kurz, mittel oder lang", "fr": "Court, moyen ou long"}, "default": "medium"},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"meta": {"icon": "mdi-file-document-outline", "color": "#9C27B0"},
|
||||
"_method": "ai",
|
||||
"_action": "summarizeDocument",
|
||||
"_paramMap": {},
|
||||
},
|
||||
{
|
||||
"id": "ai.translateDocument",
|
||||
"category": "ai",
|
||||
"label": {"en": "Translate Document", "de": "Dokument übersetzen", "fr": "Traduire document"},
|
||||
"description": {"en": "Translate document to target language", "de": "Dokument in Zielsprache übersetzen", "fr": "Traduire le document"},
|
||||
"parameters": [
|
||||
{"name": "targetLanguage", "type": "string", "required": True, "description": {"en": "Target language (e.g. en, de, fr)", "de": "Zielsprache", "fr": "Langue cible"}},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"meta": {"icon": "mdi-translate", "color": "#9C27B0"},
|
||||
"_method": "ai",
|
||||
"_action": "translateDocument",
|
||||
"_paramMap": {"targetLanguage": "targetLanguage"},
|
||||
},
|
||||
{
|
||||
"id": "ai.convertDocument",
|
||||
"category": "ai",
|
||||
"label": {"en": "Convert Document", "de": "Dokument konvertieren", "fr": "Convertir document"},
|
||||
"description": {"en": "Convert document to another format", "de": "Dokument in anderes Format konvertieren", "fr": "Convertir le document"},
|
||||
"parameters": [
|
||||
{"name": "targetFormat", "type": "string", "required": True, "description": {"en": "Target format (pdf, docx, txt, etc.)", "de": "Zielformat", "fr": "Format cible"}},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"meta": {"icon": "mdi-file-convert", "color": "#9C27B0"},
|
||||
"_method": "ai",
|
||||
"_action": "convertDocument",
|
||||
"_paramMap": {"targetFormat": "targetFormat"},
|
||||
},
|
||||
{
|
||||
"id": "ai.generateDocument",
|
||||
"category": "ai",
|
||||
"label": {"en": "Generate Document", "de": "Dokument generieren", "fr": "Générer document"},
|
||||
"description": {"en": "Generate document from prompt", "de": "Dokument aus Prompt generieren", "fr": "Générer un document"},
|
||||
"parameters": [
|
||||
{"name": "prompt", "type": "string", "required": True, "description": {"en": "Generation prompt", "de": "Generierungs-Prompt", "fr": "Invite de génération"}},
|
||||
{"name": "format", "type": "string", "required": False, "description": {"en": "Output format", "de": "Ausgabeformat", "fr": "Format de sortie"}, "default": "docx"},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"meta": {"icon": "mdi-file-plus", "color": "#9C27B0"},
|
||||
"_method": "ai",
|
||||
"_action": "generateDocument",
|
||||
"_paramMap": {"prompt": "prompt", "format": "format"},
|
||||
},
|
||||
{
|
||||
"id": "ai.generateCode",
|
||||
"category": "ai",
|
||||
"label": {"en": "Generate Code", "de": "Code generieren", "fr": "Générer code"},
|
||||
"description": {"en": "Generate code from description", "de": "Code aus Beschreibung generieren", "fr": "Générer du code"},
|
||||
"parameters": [
|
||||
{"name": "prompt", "type": "string", "required": True, "description": {"en": "Code generation prompt", "de": "Code-Generierungs-Prompt", "fr": "Invite de génération de code"}},
|
||||
{"name": "language", "type": "string", "required": False, "description": {"en": "Programming language", "de": "Programmiersprache", "fr": "Langage de programmation"}, "default": "python"},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"meta": {"icon": "mdi-code-tags", "color": "#9C27B0"},
|
||||
"_method": "ai",
|
||||
"_action": "generateCode",
|
||||
"_paramMap": {"prompt": "prompt", "language": "language"},
|
||||
},
|
||||
]
|
||||
58
modules/features/automation2/nodeDefinitions/data.py
Normal file
58
modules/features/automation2/nodeDefinitions/data.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# Data transformation node definitions.
|
||||
|
||||
DATA_NODES = [
|
||||
{
|
||||
"id": "data.setFields",
|
||||
"category": "data",
|
||||
"label": {"en": "Set Fields", "de": "Felder setzen", "fr": "Définir champs"},
|
||||
"description": {"en": "Set or override fields on payload", "de": "Felder setzen oder überschreiben", "fr": "Définir ou écraser des champs"},
|
||||
"parameters": [
|
||||
{"name": "fields", "type": "object", "required": True, "description": {"en": "Key-value pairs", "de": "Schlüssel-Wert-Paare", "fr": "Paires clé-valeur"}},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"executor": "data",
|
||||
"meta": {"icon": "mdi-pencil", "color": "#673AB7"},
|
||||
},
|
||||
{
|
||||
"id": "data.filter",
|
||||
"category": "data",
|
||||
"label": {"en": "Filter", "de": "Filtern", "fr": "Filtrer"},
|
||||
"description": {"en": "Filter array by condition", "de": "Array nach Bedingung filtern", "fr": "Filtrer tableau par condition"},
|
||||
"parameters": [
|
||||
{"name": "condition", "type": "string", "required": True, "description": {"en": "Expression (e.g. item.active == true)", "de": "Bedingung", "fr": "Condition"}},
|
||||
{"name": "itemsPath", "type": "string", "required": False, "description": {"en": "Path to array", "de": "Pfad zum Array", "fr": "Chemin vers le tableau"}},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"executor": "data",
|
||||
"meta": {"icon": "mdi-filter", "color": "#673AB7"},
|
||||
},
|
||||
{
|
||||
"id": "data.parseJson",
|
||||
"category": "data",
|
||||
"label": {"en": "Parse JSON", "de": "JSON parsen", "fr": "Parser JSON"},
|
||||
"description": {"en": "Parse JSON string to object", "de": "JSON-String in Objekt parsen", "fr": "Parser chaîne JSON en objet"},
|
||||
"parameters": [
|
||||
{"name": "jsonPath", "type": "string", "required": False, "description": {"en": "Path to JSON string (default: input)", "de": "Pfad zum JSON", "fr": "Chemin vers JSON"}},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"executor": "data",
|
||||
"meta": {"icon": "mdi-code-json", "color": "#673AB7"},
|
||||
},
|
||||
{
|
||||
"id": "data.template",
|
||||
"category": "data",
|
||||
"label": {"en": "Template / Interpolation", "de": "Vorlage / Interpolation", "fr": "Modèle / Interpolation"},
|
||||
"description": {"en": "Text with {{placeholder}} substitution", "de": "Text mit {{platzhalter}}-Ersetzung", "fr": "Texte avec substitution {{placeholder}}"},
|
||||
"parameters": [
|
||||
{"name": "template", "type": "string", "required": True, "description": {"en": "Template (use {{path}} for values)", "de": "Vorlage", "fr": "Modèle"}},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"executor": "data",
|
||||
"meta": {"icon": "mdi-format-text", "color": "#673AB7"},
|
||||
},
|
||||
]
|
||||
70
modules/features/automation2/nodeDefinitions/email.py
Normal file
70
modules/features/automation2/nodeDefinitions/email.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# Email node definitions - map to methodOutlook actions.
|
||||
# Use connectionId from user connections (like AI workspace sources).
|
||||
|
||||
EMAIL_NODES = [
|
||||
{
|
||||
"id": "email.checkEmail",
|
||||
"category": "email",
|
||||
"label": {"en": "Check Email", "de": "E-Mail prüfen", "fr": "Vérifier email"},
|
||||
"description": {"en": "Check for new emails (general or from specific account)", "de": "Neue E-Mails prüfen", "fr": "Vérifier les nouveaux emails"},
|
||||
"parameters": [
|
||||
{"name": "connectionId", "type": "string", "required": True, "description": {"en": "Email account connection", "de": "E-Mail-Konto Verbindung", "fr": "Connexion compte email"}},
|
||||
{"name": "folder", "type": "string", "required": False, "description": {"en": "Folder (e.g. Inbox)", "de": "Ordner (z.B. Posteingang)", "fr": "Dossier (ex. Boîte de réception)"}, "default": "Inbox"},
|
||||
{"name": "limit", "type": "number", "required": False, "description": {"en": "Max emails to fetch", "de": "Max E-Mails", "fr": "Max emails"}, "default": 100},
|
||||
{"name": "fromAddress", "type": "string", "required": False, "description": {"en": "Only emails from this address", "de": "Nur E-Mails von dieser Adresse", "fr": "Seulement les e-mails de cette adresse"}, "default": ""},
|
||||
{"name": "subjectContains", "type": "string", "required": False, "description": {"en": "Subject must contain this text", "de": "Betreff muss diesen Text enthalten", "fr": "Le sujet doit contenir ce texte"}, "default": ""},
|
||||
{"name": "hasAttachment", "type": "boolean", "required": False, "description": {"en": "Only emails with attachments", "de": "Nur E-Mails mit Anhängen", "fr": "Seulement les e-mails avec pièces jointes"}, "default": False},
|
||||
{"name": "filter", "type": "string", "required": False, "description": {"en": "Advanced: raw filter (overrides above if set)", "de": "Erweitert: Filter-Text (überschreibt obige)", "fr": "Avancé: filtre brut"}, "default": ""},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"meta": {"icon": "mdi-email-check", "color": "#1976D2"},
|
||||
"_method": "outlook",
|
||||
"_action": "readEmails",
|
||||
"_paramMap": {"connectionId": "connectionReference", "folder": "folder", "limit": "limit", "filter": "filter"},
|
||||
},
|
||||
{
|
||||
"id": "email.searchEmail",
|
||||
"category": "email",
|
||||
"label": {"en": "Search Email", "de": "E-Mail suchen", "fr": "Rechercher email"},
|
||||
"description": {"en": "Search or find emails", "de": "E-Mails suchen oder finden", "fr": "Rechercher des emails"},
|
||||
"parameters": [
|
||||
{"name": "connectionId", "type": "string", "required": True, "description": {"en": "Email account connection", "de": "E-Mail-Konto Verbindung", "fr": "Connexion compte email"}},
|
||||
{"name": "query", "type": "string", "required": False, "description": {"en": "General search term (searches subject, body, from)", "de": "Suchbegriff (durchsucht Betreff, Inhalt, Absender)", "fr": "Terme de recherche (sujet, corps, expéditeur)"}, "default": ""},
|
||||
{"name": "folder", "type": "string", "required": False, "description": {"en": "Folder to search", "de": "Ordner zum Suchen", "fr": "Dossier à rechercher"}, "default": "Inbox"},
|
||||
{"name": "limit", "type": "number", "required": False, "description": {"en": "Max emails to return", "de": "Max E-Mails", "fr": "Max emails"}, "default": 100},
|
||||
{"name": "fromAddress", "type": "string", "required": False, "description": {"en": "Only emails from this address", "de": "Nur E-Mails von dieser Adresse", "fr": "Seulement les e-mails de cette adresse"}, "default": ""},
|
||||
{"name": "toAddress", "type": "string", "required": False, "description": {"en": "Only emails to this recipient", "de": "Nur E-Mails an diesen Empfänger", "fr": "Seulement les e-mails à ce destinataire"}, "default": ""},
|
||||
{"name": "subjectContains", "type": "string", "required": False, "description": {"en": "Subject must contain this text", "de": "Betreff muss diesen Text enthalten", "fr": "Le sujet doit contenir ce texte"}, "default": ""},
|
||||
{"name": "bodyContains", "type": "string", "required": False, "description": {"en": "Body/content must contain this text", "de": "Inhalt muss diesen Text enthalten", "fr": "Le corps doit contenir ce texte"}, "default": ""},
|
||||
{"name": "hasAttachment", "type": "boolean", "required": False, "description": {"en": "Only emails with attachments", "de": "Nur E-Mails mit Anhängen", "fr": "Seulement les e-mails avec pièces jointes"}, "default": False},
|
||||
{"name": "filter", "type": "string", "required": False, "description": {"en": "Advanced: raw KQL (overrides above if set)", "de": "Erweitert: KQL-Filter (überschreibt obige)", "fr": "Avancé: filtre KQL brut"}, "default": ""},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"meta": {"icon": "mdi-email-search", "color": "#1976D2"},
|
||||
"_method": "outlook",
|
||||
"_action": "searchEmails",
|
||||
"_paramMap": {"connectionId": "connectionReference", "query": "query", "folder": "folder", "limit": "limit", "filter": "filter"},
|
||||
},
|
||||
{
|
||||
"id": "email.draftEmail",
|
||||
"category": "email",
|
||||
"label": {"en": "Draft Email", "de": "E-Mail entwerfen", "fr": "Brouillon email"},
|
||||
"description": {"en": "Create a draft email", "de": "E-Mail-Entwurf erstellen", "fr": "Créer un brouillon d'email"},
|
||||
"parameters": [
|
||||
{"name": "connectionId", "type": "string", "required": True, "description": {"en": "Email account connection", "de": "E-Mail-Konto Verbindung", "fr": "Connexion compte email"}},
|
||||
{"name": "subject", "type": "string", "required": True, "description": {"en": "Email subject", "de": "E-Mail-Betreff", "fr": "Sujet"}},
|
||||
{"name": "body", "type": "string", "required": True, "description": {"en": "Email body", "de": "E-Mail-Text", "fr": "Corps de l'email"}},
|
||||
{"name": "to", "type": "string", "required": False, "description": {"en": "Recipient(s)", "de": "Empfänger", "fr": "Destinataire(s)"}, "default": ""},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"meta": {"icon": "mdi-email-edit", "color": "#1976D2"},
|
||||
"_method": "outlook",
|
||||
"_action": "composeAndDraftEmailWithContext",
|
||||
"_paramMap": {"connectionId": "connectionReference", "to": "to"},
|
||||
"_contextFrom": ["subject", "body"],
|
||||
},
|
||||
]
|
||||
82
modules/features/automation2/nodeDefinitions/flow.py
Normal file
82
modules/features/automation2/nodeDefinitions/flow.py
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# Flow control node definitions.
|
||||
|
||||
FLOW_NODES = [
|
||||
{
|
||||
"id": "flow.ifElse",
|
||||
"category": "flow",
|
||||
"label": {"en": "If / Else", "de": "Wenn / Sonst", "fr": "Si / Sinon"},
|
||||
"description": {"en": "Branch based on condition", "de": "Verzweigung nach Bedingung", "fr": "Branche selon condition"},
|
||||
"parameters": [
|
||||
{"name": "condition", "type": "string", "required": True, "description": {"en": "Expression to evaluate (e.g. {{value}} > 0)", "de": "Bedingung", "fr": "Condition"}},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 2,
|
||||
"executor": "flow",
|
||||
"meta": {"icon": "mdi-source-branch", "color": "#FF9800"},
|
||||
},
|
||||
{
|
||||
"id": "flow.switch",
|
||||
"category": "flow",
|
||||
"label": {"en": "Switch", "de": "Switch", "fr": "Switch"},
|
||||
"description": {"en": "Multiple branches based on value", "de": "Mehrere Zweige nach Wert", "fr": "Branches multiples selon valeur"},
|
||||
"parameters": [
|
||||
{"name": "value", "type": "string", "required": True, "description": {"en": "Value to match", "de": "Zu vergleichender Wert", "fr": "Valeur à comparer"}},
|
||||
{"name": "cases", "type": "array", "required": False, "description": {"en": "List of cases", "de": "Fälle", "fr": "Cas"}},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"executor": "flow",
|
||||
"meta": {"icon": "mdi-swap-horizontal", "color": "#FF9800"},
|
||||
},
|
||||
{
|
||||
"id": "flow.merge",
|
||||
"category": "flow",
|
||||
"label": {"en": "Merge", "de": "Zusammenführen", "fr": "Fusionner"},
|
||||
"description": {"en": "Merge multiple inputs", "de": "Mehrere Eingaben zusammenführen", "fr": "Fusionner plusieurs entrées"},
|
||||
"parameters": [
|
||||
{"name": "mode", "type": "string", "required": False, "description": {"en": "append | combine", "de": "Modus", "fr": "Mode"}},
|
||||
],
|
||||
"inputs": 2,
|
||||
"outputs": 1,
|
||||
"executor": "flow",
|
||||
"meta": {"icon": "mdi-merge", "color": "#FF9800"},
|
||||
},
|
||||
{
|
||||
"id": "flow.loop",
|
||||
"category": "flow",
|
||||
"label": {"en": "Loop / For Each", "de": "Schleife / Für Jedes", "fr": "Boucle / Pour Chaque"},
|
||||
"description": {"en": "Iterate over array items", "de": "Über Array-Elemente iterieren", "fr": "Itérer sur les éléments"},
|
||||
"parameters": [
|
||||
{"name": "items", "type": "string", "required": True, "description": {"en": "Path to array (e.g. {{input.items}})", "de": "Pfad zum Array", "fr": "Chemin vers le tableau"}},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"executor": "flow",
|
||||
"meta": {"icon": "mdi-repeat", "color": "#FF9800"},
|
||||
},
|
||||
{
|
||||
"id": "flow.wait",
|
||||
"category": "flow",
|
||||
"label": {"en": "Wait / Delay", "de": "Warten / Verzögerung", "fr": "Attendre / Délai"},
|
||||
"description": {"en": "Pause for duration", "de": "Pause für Dauer", "fr": "Pause pour durée"},
|
||||
"parameters": [
|
||||
{"name": "seconds", "type": "number", "required": True, "description": {"en": "Seconds to wait", "de": "Sekunden", "fr": "Secondes"}},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"executor": "flow",
|
||||
"meta": {"icon": "mdi-timer", "color": "#FF9800"},
|
||||
},
|
||||
{
|
||||
"id": "flow.stop",
|
||||
"category": "flow",
|
||||
"label": {"en": "Stop / Terminate", "de": "Stopp / Beenden", "fr": "Arrêter / Terminer"},
|
||||
"description": {"en": "Stop workflow execution", "de": "Workflow-Ausführung beenden", "fr": "Arrêter l'exécution"},
|
||||
"parameters": [],
|
||||
"inputs": 1,
|
||||
"outputs": 0,
|
||||
"executor": "flow",
|
||||
"meta": {"icon": "mdi-stop", "color": "#F44336"},
|
||||
},
|
||||
]
|
||||
117
modules/features/automation2/nodeDefinitions/input.py
Normal file
117
modules/features/automation2/nodeDefinitions/input.py
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# Input/Human node definitions - nodes that require user action.
|
||||
|
||||
INPUT_NODES = [
|
||||
{
|
||||
"id": "input.form",
|
||||
"category": "input",
|
||||
"label": {"en": "Form", "de": "Formular", "fr": "Formulaire"},
|
||||
"description": {"en": "User fills out a form", "de": "Benutzer füllt ein Formular aus", "fr": "L'utilisateur remplit un formulaire"},
|
||||
"parameters": [
|
||||
{
|
||||
"name": "fields",
|
||||
"type": "json",
|
||||
"required": True,
|
||||
"description": {"en": "Form fields: [{name, type, label, required, options?}]", "de": "Formularfelder", "fr": "Champs du formulaire"},
|
||||
"default": [],
|
||||
},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"executor": "input",
|
||||
"meta": {"icon": "mdi-form-textbox", "color": "#9C27B0"},
|
||||
},
|
||||
{
|
||||
"id": "input.approval",
|
||||
"category": "input",
|
||||
"label": {"en": "Approval", "de": "Genehmigung", "fr": "Approbation"},
|
||||
"description": {"en": "User approves or rejects", "de": "Benutzer genehmigt oder lehnt ab", "fr": "L'utilisateur approuve ou rejette"},
|
||||
"parameters": [
|
||||
{"name": "title", "type": "string", "required": True, "description": {"en": "Approval title", "de": "Genehmigungstitel", "fr": "Titre"}},
|
||||
{"name": "description", "type": "string", "required": False, "description": {"en": "What to approve", "de": "Was genehmigt werden soll", "fr": "Ce qu'il faut approuver"}},
|
||||
{"name": "approvalType", "type": "string", "required": False, "description": {"en": "Type: document or generic", "de": "Typ: document oder generic", "fr": "Type: document ou generic"}, "default": "generic"},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"executor": "input",
|
||||
"meta": {"icon": "mdi-check-decagram", "color": "#4CAF50"},
|
||||
},
|
||||
{
|
||||
"id": "input.upload",
|
||||
"category": "input",
|
||||
"label": {"en": "Upload", "de": "Upload", "fr": "Téléversement"},
|
||||
"description": {"en": "User uploads file(s)", "de": "Benutzer lädt Datei(en) hoch", "fr": "L'utilisateur téléverse des fichiers"},
|
||||
"parameters": [
|
||||
{"name": "accept", "type": "string", "required": False, "description": {"en": "MIME types (e.g. .pdf,image/*)", "de": "MIME-Typen", "fr": "Types MIME"}, "default": ""},
|
||||
{"name": "maxSize", "type": "number", "required": False, "description": {"en": "Max file size in MB", "de": "Max. Dateigröße in MB", "fr": "Taille max en Mo"}, "default": 10},
|
||||
{"name": "multiple", "type": "boolean", "required": False, "description": {"en": "Allow multiple files", "de": "Mehrere Dateien erlauben", "fr": "Autoriser plusieurs fichiers"}, "default": False},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"executor": "input",
|
||||
"meta": {"icon": "mdi-upload", "color": "#2196F3"},
|
||||
},
|
||||
{
|
||||
"id": "input.comment",
|
||||
"category": "input",
|
||||
"label": {"en": "Comment", "de": "Kommentar", "fr": "Commentaire"},
|
||||
"description": {"en": "User adds a comment", "de": "Benutzer fügt einen Kommentar hinzu", "fr": "L'utilisateur ajoute un commentaire"},
|
||||
"parameters": [
|
||||
{"name": "placeholder", "type": "string", "required": False, "description": {"en": "Placeholder text", "de": "Platzhalter", "fr": "Texte indicatif"}, "default": ""},
|
||||
{"name": "required", "type": "boolean", "required": False, "description": {"en": "Comment required", "de": "Kommentar erforderlich", "fr": "Commentaire requis"}, "default": True},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"executor": "input",
|
||||
"meta": {"icon": "mdi-comment-text", "color": "#FF9800"},
|
||||
},
|
||||
{
|
||||
"id": "input.review",
|
||||
"category": "input",
|
||||
"label": {"en": "Review", "de": "Prüfung", "fr": "Revue"},
|
||||
"description": {"en": "User reviews content", "de": "Benutzer prüft Inhalt", "fr": "L'utilisateur révise le contenu"},
|
||||
"parameters": [
|
||||
{"name": "contentRef", "type": "string", "required": True, "description": {"en": "Reference to content (e.g. {{nodeId.field}})", "de": "Referenz auf Inhalt", "fr": "Référence au contenu"}},
|
||||
{"name": "reviewType", "type": "string", "required": False, "description": {"en": "Type of review", "de": "Art der Prüfung", "fr": "Type de revue"}, "default": "generic"},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"executor": "input",
|
||||
"meta": {"icon": "mdi-magnify-scan", "color": "#673AB7"},
|
||||
},
|
||||
{
|
||||
"id": "input.selection",
|
||||
"category": "input",
|
||||
"label": {"en": "Selection", "de": "Auswahl", "fr": "Sélection"},
|
||||
"description": {"en": "User selects from options", "de": "Benutzer wählt aus Optionen", "fr": "L'utilisateur choisit parmi les options"},
|
||||
"parameters": [
|
||||
{
|
||||
"name": "options",
|
||||
"type": "json",
|
||||
"required": True,
|
||||
"description": {"en": "Options: [{value, label}]", "de": "Optionen", "fr": "Options"},
|
||||
"default": [],
|
||||
},
|
||||
{"name": "multiple", "type": "boolean", "required": False, "description": {"en": "Allow multiple selection", "de": "Mehrfachauswahl erlauben", "fr": "Sélection multiple"}, "default": False},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"executor": "input",
|
||||
"meta": {"icon": "mdi-format-list-checks", "color": "#009688"},
|
||||
},
|
||||
{
|
||||
"id": "input.confirmation",
|
||||
"category": "input",
|
||||
"label": {"en": "Confirmation", "de": "Bestätigung", "fr": "Confirmation"},
|
||||
"description": {"en": "User confirms yes/no", "de": "Benutzer bestätigt Ja/Nein", "fr": "L'utilisateur confirme oui/non"},
|
||||
"parameters": [
|
||||
{"name": "question", "type": "string", "required": True, "description": {"en": "Question to confirm", "de": "Zu bestätigende Frage", "fr": "Question à confirmer"}},
|
||||
{"name": "confirmLabel", "type": "string", "required": False, "description": {"en": "Label for confirm button", "de": "Label für Bestätigen-Button", "fr": "Libellé du bouton confirmer"}, "default": "Confirm"},
|
||||
{"name": "rejectLabel", "type": "string", "required": False, "description": {"en": "Label for reject button", "de": "Label für Ablehnen-Button", "fr": "Libellé du bouton refuser"}, "default": "Reject"},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"executor": "input",
|
||||
"meta": {"icon": "mdi-checkbox-marked-circle", "color": "#8BC34A"},
|
||||
},
|
||||
]
|
||||
105
modules/features/automation2/nodeDefinitions/sharepoint.py
Normal file
105
modules/features/automation2/nodeDefinitions/sharepoint.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# SharePoint node definitions - map to methodSharepoint actions.
|
||||
# Use connectionId and path from connection selector (like workflow folder view).
|
||||
|
||||
SHAREPOINT_NODES = [
|
||||
{
|
||||
"id": "sharepoint.findFile",
|
||||
"category": "sharepoint",
|
||||
"label": {"en": "Find File", "de": "Datei finden", "fr": "Trouver fichier"},
|
||||
"description": {"en": "Find file by path or search", "de": "Datei nach Pfad oder Suche finden", "fr": "Trouver fichier par chemin ou recherche"},
|
||||
"parameters": [
|
||||
{"name": "connectionId", "type": "string", "required": True, "description": {"en": "SharePoint connection", "de": "SharePoint-Verbindung", "fr": "Connexion SharePoint"}},
|
||||
{"name": "searchQuery", "type": "string", "required": True, "description": {"en": "Search query or path", "de": "Suchanfrage oder Pfad", "fr": "Requête ou chemin"}},
|
||||
{"name": "site", "type": "string", "required": False, "description": {"en": "Optional site hint", "de": "Optionaler Site-Hinweis", "fr": "Indication de site"}, "default": ""},
|
||||
{"name": "maxResults", "type": "number", "required": False, "description": {"en": "Max results", "de": "Max Ergebnisse", "fr": "Max résultats"}, "default": 1000},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"meta": {"icon": "mdi-file-search", "color": "#0078D4"},
|
||||
"_method": "sharepoint",
|
||||
"_action": "findDocumentPath",
|
||||
"_paramMap": {"connectionId": "connectionReference", "searchQuery": "searchQuery", "site": "site", "maxResults": "maxResults"},
|
||||
},
|
||||
{
|
||||
"id": "sharepoint.readFile",
|
||||
"category": "sharepoint",
|
||||
"label": {"en": "Read File", "de": "Datei lesen", "fr": "Lire fichier"},
|
||||
"description": {"en": "Extract content from file", "de": "Inhalt aus Datei extrahieren", "fr": "Extraire le contenu du fichier"},
|
||||
"parameters": [
|
||||
{"name": "connectionId", "type": "string", "required": True, "description": {"en": "SharePoint connection", "de": "SharePoint-Verbindung", "fr": "Connexion SharePoint"}},
|
||||
{"name": "path", "type": "string", "required": True, "description": {"en": "File path or documentList from find file", "de": "Dateipfad oder documentList von Find", "fr": "Chemin ou documentList"}},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"meta": {"icon": "mdi-file-document", "color": "#0078D4"},
|
||||
"_method": "sharepoint",
|
||||
"_action": "readDocuments",
|
||||
"_paramMap": {"connectionId": "connectionReference", "path": "pathQuery"},
|
||||
},
|
||||
{
|
||||
"id": "sharepoint.uploadFile",
|
||||
"category": "sharepoint",
|
||||
"label": {"en": "Upload File", "de": "Datei hochladen", "fr": "Téléverser fichier"},
|
||||
"description": {"en": "Upload file to SharePoint", "de": "Datei zu SharePoint hochladen", "fr": "Téléverser fichier vers SharePoint"},
|
||||
"parameters": [
|
||||
{"name": "connectionId", "type": "string", "required": True, "description": {"en": "SharePoint connection", "de": "SharePoint-Verbindung", "fr": "Connexion SharePoint"}},
|
||||
{"name": "path", "type": "string", "required": True, "description": {"en": "Target folder path (e.g. /sites/.../Folder)", "de": "Zielordner-Pfad", "fr": "Chemin du dossier cible"}},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"meta": {"icon": "mdi-upload", "color": "#0078D4"},
|
||||
"_method": "sharepoint",
|
||||
"_action": "uploadFile",
|
||||
"_paramMap": {"connectionId": "connectionReference", "path": "pathQuery"},
|
||||
},
|
||||
{
|
||||
"id": "sharepoint.listFiles",
|
||||
"category": "sharepoint",
|
||||
"label": {"en": "List Files", "de": "Dateien auflisten", "fr": "Lister fichiers"},
|
||||
"description": {"en": "List files in folder or SharePoint", "de": "Dateien in Ordner oder SharePoint auflisten", "fr": "Lister les fichiers dans un dossier"},
|
||||
"parameters": [
|
||||
{"name": "connectionId", "type": "string", "required": True, "description": {"en": "SharePoint connection", "de": "SharePoint-Verbindung", "fr": "Connexion SharePoint"}},
|
||||
{"name": "path", "type": "string", "required": False, "description": {"en": "Folder path (e.g. /sites/SiteName/Shared Documents)", "de": "Ordnerpfad", "fr": "Chemin du dossier"}, "default": "/"},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"meta": {"icon": "mdi-folder-open", "color": "#0078D4"},
|
||||
"_method": "sharepoint",
|
||||
"_action": "listDocuments",
|
||||
"_paramMap": {"connectionId": "connectionReference", "path": "pathQuery"},
|
||||
},
|
||||
{
|
||||
"id": "sharepoint.downloadFile",
|
||||
"category": "sharepoint",
|
||||
"label": {"en": "Download File", "de": "Datei herunterladen", "fr": "Télécharger fichier"},
|
||||
"description": {"en": "Download file from path (e.g. /sites/SiteName/Shared Documents/file.pdf)", "de": "Datei vom Pfad herunterladen", "fr": "Télécharger le fichier"},
|
||||
"parameters": [
|
||||
{"name": "connectionId", "type": "string", "required": True, "description": {"en": "SharePoint connection", "de": "SharePoint-Verbindung", "fr": "Connexion SharePoint"}},
|
||||
{"name": "path", "type": "string", "required": True, "description": {"en": "Full file path (e.g. /sites/SiteName/Shared Documents/file.pdf)", "de": "Vollständiger Dateipfad", "fr": "Chemin complet du fichier"}},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"meta": {"icon": "mdi-download", "color": "#0078D4"},
|
||||
"_method": "sharepoint",
|
||||
"_action": "downloadFileByPath",
|
||||
"_paramMap": {"connectionId": "connectionReference", "path": "pathQuery", "siteId": "siteId", "filePath": "filePath"},
|
||||
},
|
||||
{
|
||||
"id": "sharepoint.copyFile",
|
||||
"category": "sharepoint",
|
||||
"label": {"en": "Copy File", "de": "Datei kopieren", "fr": "Copier fichier"},
|
||||
"description": {"en": "Copy file to destination", "de": "Datei an Ziel kopieren", "fr": "Copier le fichier"},
|
||||
"parameters": [
|
||||
{"name": "connectionId", "type": "string", "required": True, "description": {"en": "SharePoint connection", "de": "SharePoint-Verbindung", "fr": "Connexion SharePoint"}},
|
||||
{"name": "sourcePath", "type": "string", "required": True, "description": {"en": "Source file path (from browse)", "de": "Quelldatei-Pfad", "fr": "Chemin fichier source"}},
|
||||
{"name": "destPath", "type": "string", "required": True, "description": {"en": "Destination folder path (from browse)", "de": "Zielordner-Pfad", "fr": "Chemin dossier cible"}},
|
||||
],
|
||||
"inputs": 1,
|
||||
"outputs": 1,
|
||||
"meta": {"icon": "mdi-content-copy", "color": "#0078D4"},
|
||||
"_method": "sharepoint",
|
||||
"_action": "copyFile",
|
||||
"_paramMap": {"connectionId": "connectionReference", "sourcePath": "sourcePath", "destPath": "destPath"},
|
||||
},
|
||||
]
|
||||
42
modules/features/automation2/nodeDefinitions/triggers.py
Normal file
42
modules/features/automation2/nodeDefinitions/triggers.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# Trigger node definitions - workflow entry points.
|
||||
|
||||
TRIGGER_NODES = [
|
||||
{
|
||||
"id": "trigger.manual",
|
||||
"category": "trigger",
|
||||
"label": {"en": "Manual Trigger", "de": "Manueller Trigger", "fr": "Déclencheur manuel"},
|
||||
"description": {"en": "Start workflow on button press", "de": "Startet den Workflow bei Knopfdruck", "fr": "Démarre le workflow sur clic"},
|
||||
"parameters": [],
|
||||
"inputs": 0,
|
||||
"outputs": 1,
|
||||
"executor": "trigger",
|
||||
"meta": {"icon": "mdi-play", "color": "#4CAF50"},
|
||||
},
|
||||
{
|
||||
"id": "trigger.schedule",
|
||||
"category": "trigger",
|
||||
"label": {"en": "Schedule", "de": "Zeitplan", "fr": "Planification"},
|
||||
"description": {"en": "Run on a cron schedule", "de": "Läuft nach Cron-Zeitplan", "fr": "S'exécute selon un cron"},
|
||||
"parameters": [
|
||||
{"name": "cron", "type": "string", "required": True, "description": {"en": "Cron expression (e.g. 0 9 * * * for daily at 9)", "de": "Cron-Ausdruck", "fr": "Expression cron"}},
|
||||
],
|
||||
"inputs": 0,
|
||||
"outputs": 1,
|
||||
"executor": "trigger",
|
||||
"meta": {"icon": "mdi-clock", "color": "#2196F3"},
|
||||
},
|
||||
{
|
||||
"id": "trigger.formSubmit",
|
||||
"category": "trigger",
|
||||
"label": {"en": "Form Submit", "de": "Formular-Absendung", "fr": "Soumission formulaire"},
|
||||
"description": {"en": "Start when form is submitted", "de": "Startet bei Formular-Absendung", "fr": "Démarre à la soumission du formulaire"},
|
||||
"parameters": [
|
||||
{"name": "formId", "type": "string", "required": True, "description": {"en": "Form identifier", "de": "Formular-ID", "fr": "Identifiant du formulaire"}},
|
||||
],
|
||||
"inputs": 0,
|
||||
"outputs": 1,
|
||||
"executor": "trigger",
|
||||
"meta": {"icon": "mdi-form-select", "color": "#9C27B0"},
|
||||
},
|
||||
]
|
||||
81
modules/features/automation2/nodeRegistry.py
Normal file
81
modules/features/automation2/nodeRegistry.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""
|
||||
Node Type Registry for automation2 - static node definitions (ai, email, sharepoint, trigger, flow, data, input).
|
||||
Nodes are defined first; IO/method actions are used at execution time.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from modules.features.automation2.nodeDefinitions import STATIC_NODE_TYPES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def getNodeTypes(
|
||||
services: Any = None,
|
||||
language: str = "en",
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Return static node types. No dynamic I/O derivation from methodDiscovery.
|
||||
services: Optional (kept for API compatibility, not used).
|
||||
"""
|
||||
return list(STATIC_NODE_TYPES)
|
||||
|
||||
|
||||
def _localizeNode(node: Dict[str, Any], language: str) -> Dict[str, Any]:
|
||||
"""Apply language to label/description/parameters."""
|
||||
lang = language if language in ("en", "de", "fr") else "en"
|
||||
out = dict(node)
|
||||
# Strip internal keys for API response
|
||||
for key in list(out.keys()):
|
||||
if key.startswith("_"):
|
||||
del out[key]
|
||||
if isinstance(node.get("label"), dict):
|
||||
out["label"] = node["label"].get(lang, node["label"].get("en", str(node["label"])))
|
||||
if isinstance(node.get("description"), dict):
|
||||
out["description"] = node["description"].get(lang, node["description"].get("en", str(node["description"])))
|
||||
params = []
|
||||
for p in node.get("parameters", []):
|
||||
pc = dict(p)
|
||||
if isinstance(p.get("description"), dict):
|
||||
pc["description"] = p["description"].get(lang, p["description"].get("en", str(p.get("description", ""))))
|
||||
params.append(pc)
|
||||
out["parameters"] = params
|
||||
return out
|
||||
|
||||
|
||||
def getNodeTypesForApi(
|
||||
services: Any,
|
||||
language: str = "en",
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
API-ready response: nodeTypes with localized strings, plus categories list.
|
||||
"""
|
||||
nodes = getNodeTypes(services, language)
|
||||
localized = [_localizeNode(n, language) for n in nodes]
|
||||
categories = [
|
||||
{"id": "trigger", "label": {"en": "Trigger", "de": "Trigger", "fr": "Déclencheur"}},
|
||||
{"id": "input", "label": {"en": "Input/Human", "de": "Eingabe/Mensch", "fr": "Entrée/Humain"}},
|
||||
{"id": "flow", "label": {"en": "Flow", "de": "Ablauf", "fr": "Flux"}},
|
||||
{"id": "data", "label": {"en": "Data", "de": "Daten", "fr": "Données"}},
|
||||
{"id": "ai", "label": {"en": "AI", "de": "KI", "fr": "IA"}},
|
||||
{"id": "email", "label": {"en": "Email", "de": "E-Mail", "fr": "Email"}},
|
||||
{"id": "sharepoint", "label": {"en": "SharePoint", "de": "SharePoint", "fr": "SharePoint"}},
|
||||
]
|
||||
return {"nodeTypes": localized, "categories": categories}
|
||||
|
||||
|
||||
def getNodeTypeToMethodAction() -> Dict[str, tuple]:
|
||||
"""
|
||||
Mapping from node type id to (method, action) for execution.
|
||||
Used by ActionNodeExecutor.
|
||||
"""
|
||||
mapping = {}
|
||||
for node in STATIC_NODE_TYPES:
|
||||
method = node.get("_method")
|
||||
action = node.get("_action")
|
||||
if method and action:
|
||||
mapping[node["id"]] = (method, action)
|
||||
return mapping
|
||||
602
modules/features/automation2/routeFeatureAutomation2.py
Normal file
602
modules/features/automation2/routeFeatureAutomation2.py
Normal file
|
|
@ -0,0 +1,602 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""
|
||||
Automation2 routes - node-types, execute, workflows, runs, tasks, connections, browse.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, Path, Query, Body, Request, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from modules.auth import limiter, getRequestContext, RequestContext
|
||||
|
||||
from modules.features.automation2.mainAutomation2 import getAutomation2Services
|
||||
from modules.features.automation2.nodeRegistry import getNodeTypesForApi
|
||||
from modules.features.automation2.interfaceFeatureAutomation2 import getAutomation2Interface
|
||||
from modules.workflows.automation2.executionEngine import executeGraph
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/automation2",
|
||||
tags=["Automation2"],
|
||||
responses={404: {"description": "Not found"}, 403: {"description": "Forbidden"}},
|
||||
)
|
||||
|
||||
|
||||
def _validateInstanceAccess(instanceId: str, context: RequestContext) -> str:
|
||||
"""Validate user has access to the automation2 feature instance. Returns mandateId."""
|
||||
from fastapi import HTTPException
|
||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||
|
||||
rootInterface = getRootInterface()
|
||||
instance = rootInterface.getFeatureInstance(instanceId)
|
||||
if not instance:
|
||||
raise HTTPException(status_code=404, detail=f"Feature instance {instanceId} not found")
|
||||
featureAccess = rootInterface.getFeatureAccess(str(context.user.id), instanceId)
|
||||
if not featureAccess or not featureAccess.enabled:
|
||||
raise HTTPException(status_code=403, detail="Access denied to this feature instance")
|
||||
return str(instance.mandateId) if instance.mandateId else ""
|
||||
|
||||
|
||||
@router.get("/{instanceId}/info")
|
||||
@limiter.limit("60/minute")
|
||||
def get_automation2_info(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature instance ID"),
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
) -> dict:
|
||||
"""Minimal info endpoint - proves the feature works."""
|
||||
_validateInstanceAccess(instanceId, context)
|
||||
return {
|
||||
"featureCode": "automation2",
|
||||
"instanceId": instanceId,
|
||||
"status": "ok",
|
||||
"message": "Automation2 feature ready. Build from here.",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{instanceId}/node-types")
|
||||
@limiter.limit("60/minute")
|
||||
def get_node_types(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature instance ID"),
|
||||
language: str = Query("en", description="Localization (en, de, fr)"),
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
) -> dict:
|
||||
"""Return node types for the flow builder: static + I/O from methodDiscovery."""
|
||||
logger.info("automation2 node-types request: instanceId=%s language=%s", instanceId, language)
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
services = getAutomation2Services(
|
||||
context.user,
|
||||
mandateId=mandateId,
|
||||
featureInstanceId=instanceId,
|
||||
)
|
||||
result = getNodeTypesForApi(services, language=language)
|
||||
logger.info(
|
||||
"automation2 node-types response: %d nodeTypes %d categories",
|
||||
len(result.get("nodeTypes", [])),
|
||||
len(result.get("categories", [])),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/{instanceId}/execute")
|
||||
@limiter.limit("30/minute")
|
||||
async def post_execute(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature instance ID"),
|
||||
body: dict = Body(..., description="{ workflowId?, graph: { nodes, connections } }"),
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
) -> dict:
|
||||
"""Execute automation2 graph. Body: { workflowId?, graph: { nodes, connections } }."""
|
||||
userId = str(context.user.id) if context.user else None
|
||||
logger.info(
|
||||
"automation2 execute request: instanceId=%s userId=%s body_keys=%s",
|
||||
instanceId,
|
||||
userId,
|
||||
list(body.keys()),
|
||||
)
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
services = getAutomation2Services(
|
||||
context.user,
|
||||
mandateId=mandateId,
|
||||
featureInstanceId=instanceId,
|
||||
)
|
||||
# Ensure workflow methods (outlook, ai, sharepoint, etc.) are discovered for ActionExecutor
|
||||
from modules.workflows.processing.shared.methodDiscovery import discoverMethods
|
||||
discoverMethods(services)
|
||||
|
||||
graph = body.get("graph") or body
|
||||
workflowId = body.get("workflowId")
|
||||
req_nodes = graph.get("nodes") or []
|
||||
# When workflowId is set: prefer graph from request (current editor state) if it has nodes.
|
||||
# Only fall back to stored workflow graph when request graph is empty (e.g. resume from email).
|
||||
if workflowId and len(req_nodes) == 0:
|
||||
a2 = getAutomation2Interface(context.user, mandateId, instanceId)
|
||||
wf = a2.getWorkflow(workflowId)
|
||||
if wf and wf.get("graph"):
|
||||
graph = wf["graph"]
|
||||
logger.info("automation2 execute: loaded graph from workflow %s", workflowId)
|
||||
# Use transient workflowId when none provided (e.g. execute from editor without save)
|
||||
# Required for email.checkEmail pause/resume - run must be created
|
||||
if not workflowId:
|
||||
import uuid
|
||||
workflowId = f"transient-{uuid.uuid4().hex[:12]}"
|
||||
logger.info("automation2 execute: using transient workflowId=%s", workflowId)
|
||||
nodes_count = len(graph.get("nodes") or [])
|
||||
connections_count = len(graph.get("connections") or [])
|
||||
logger.info(
|
||||
"automation2 execute: graph nodes=%d connections=%d workflowId=%s mandateId=%s",
|
||||
nodes_count,
|
||||
connections_count,
|
||||
workflowId,
|
||||
mandateId,
|
||||
)
|
||||
a2_interface = getAutomation2Interface(context.user, mandateId, instanceId)
|
||||
result = await executeGraph(
|
||||
graph=graph,
|
||||
services=services,
|
||||
workflowId=workflowId,
|
||||
instanceId=instanceId,
|
||||
userId=userId,
|
||||
mandateId=mandateId,
|
||||
automation2_interface=a2_interface,
|
||||
)
|
||||
logger.info(
|
||||
"automation2 execute result: success=%s error=%s nodeOutputs_keys=%s failedNode=%s paused=%s",
|
||||
result.get("success"),
|
||||
result.get("error"),
|
||||
list(result.get("nodeOutputs", {}).keys()) if result.get("nodeOutputs") else [],
|
||||
result.get("failedNode"),
|
||||
result.get("paused"),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Connections and Browse (for Email/SharePoint node config - like workspace)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _buildResolverDbInterface(chatService):
|
||||
"""Build a DB adapter that ConnectorResolver can use to load UserConnections."""
|
||||
class _ResolverDbAdapter:
|
||||
def __init__(self, appInterface):
|
||||
self._app = appInterface
|
||||
|
||||
def getUserConnection(self, connectionId: str):
|
||||
if hasattr(self._app, "getUserConnectionById"):
|
||||
return self._app.getUserConnectionById(connectionId)
|
||||
return None
|
||||
|
||||
appIf = getattr(chatService, "interfaceDbApp", None)
|
||||
if appIf:
|
||||
return _ResolverDbAdapter(appIf)
|
||||
return getattr(chatService, "interfaceDbComponent", None)
|
||||
|
||||
|
||||
@router.get("/{instanceId}/connections")
|
||||
@limiter.limit("300/minute")
|
||||
def list_automation2_connections(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature instance ID"),
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
) -> dict:
|
||||
"""Return the user's active connections (UserConnections) for Email/SharePoint node config."""
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
from modules.serviceCenter import getService
|
||||
from modules.serviceCenter.context import ServiceCenterContext
|
||||
ctx = ServiceCenterContext(
|
||||
user=context.user,
|
||||
mandate_id=str(context.mandateId) if context.mandateId else mandateId,
|
||||
feature_instance_id=instanceId,
|
||||
)
|
||||
chatService = getService("chat", ctx)
|
||||
connections = chatService.getUserConnections()
|
||||
items = []
|
||||
for c in connections or []:
|
||||
conn = c if isinstance(c, dict) else (c.model_dump() if hasattr(c, "model_dump") else {})
|
||||
authority = conn.get("authority")
|
||||
if hasattr(authority, "value"):
|
||||
authority = authority.value
|
||||
status = conn.get("status")
|
||||
if hasattr(status, "value"):
|
||||
status = status.value
|
||||
items.append({
|
||||
"id": conn.get("id"),
|
||||
"authority": authority,
|
||||
"externalUsername": conn.get("externalUsername"),
|
||||
"externalEmail": conn.get("externalEmail"),
|
||||
"status": status,
|
||||
})
|
||||
return {"connections": items}
|
||||
|
||||
|
||||
@router.get("/{instanceId}/connections/{connectionId}/services")
|
||||
@limiter.limit("120/minute")
|
||||
async def list_connection_services(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature instance ID"),
|
||||
connectionId: str = Path(..., description="Connection ID"),
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
) -> dict:
|
||||
"""Return the available services for a specific UserConnection."""
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
try:
|
||||
from modules.connectors.connectorResolver import ConnectorResolver
|
||||
from modules.serviceCenter import getService as getSvc
|
||||
from modules.serviceCenter.context import ServiceCenterContext
|
||||
ctx = ServiceCenterContext(
|
||||
user=context.user,
|
||||
mandate_id=str(context.mandateId) if context.mandateId else mandateId,
|
||||
feature_instance_id=instanceId,
|
||||
)
|
||||
chatService = getSvc("chat", ctx)
|
||||
securityService = getSvc("security", ctx)
|
||||
dbInterface = _buildResolverDbInterface(chatService)
|
||||
resolver = ConnectorResolver(securityService, dbInterface)
|
||||
provider = await resolver.resolve(connectionId)
|
||||
services = provider.getAvailableServices()
|
||||
_serviceLabels = {
|
||||
"sharepoint": "SharePoint",
|
||||
"outlook": "Outlook",
|
||||
"teams": "Teams",
|
||||
"onedrive": "OneDrive",
|
||||
"drive": "Google Drive",
|
||||
"gmail": "Gmail",
|
||||
"files": "Files (FTP)",
|
||||
}
|
||||
_serviceIcons = {
|
||||
"sharepoint": "sharepoint",
|
||||
"outlook": "mail",
|
||||
"teams": "chat",
|
||||
"onedrive": "cloud",
|
||||
"drive": "cloud",
|
||||
"gmail": "mail",
|
||||
"files": "folder",
|
||||
}
|
||||
items = [
|
||||
{"service": s, "label": _serviceLabels.get(s, s), "icon": _serviceIcons.get(s, "folder")}
|
||||
for s in services
|
||||
]
|
||||
return {"services": items}
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing services for connection {connectionId}: {e}")
|
||||
return JSONResponse({"services": [], "error": str(e)}, status_code=400)
|
||||
|
||||
|
||||
@router.get("/{instanceId}/connections/{connectionId}/browse")
|
||||
@limiter.limit("300/minute")
|
||||
async def browse_connection_service(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature instance ID"),
|
||||
connectionId: str = Path(..., description="Connection ID"),
|
||||
service: str = Query(..., description="Service name (e.g. sharepoint, onedrive, outlook)"),
|
||||
path: str = Query("/", description="Path within the service to browse"),
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
) -> dict:
|
||||
"""Browse folders/items within a connection's service at a given path."""
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
try:
|
||||
from modules.connectors.connectorResolver import ConnectorResolver
|
||||
from modules.serviceCenter import getService as getSvc
|
||||
from modules.serviceCenter.context import ServiceCenterContext
|
||||
ctx = ServiceCenterContext(
|
||||
user=context.user,
|
||||
mandate_id=str(context.mandateId) if context.mandateId else mandateId,
|
||||
feature_instance_id=instanceId,
|
||||
)
|
||||
chatService = getSvc("chat", ctx)
|
||||
securityService = getSvc("security", ctx)
|
||||
dbInterface = _buildResolverDbInterface(chatService)
|
||||
resolver = ConnectorResolver(securityService, dbInterface)
|
||||
adapter = await resolver.resolveService(connectionId, service)
|
||||
entries = await adapter.browse(path, filter=None)
|
||||
items = []
|
||||
for entry in (entries or []):
|
||||
items.append({
|
||||
"name": entry.name,
|
||||
"path": entry.path,
|
||||
"isFolder": entry.isFolder,
|
||||
"size": entry.size,
|
||||
"mimeType": entry.mimeType,
|
||||
"metadata": entry.metadata if hasattr(entry, "metadata") else {},
|
||||
})
|
||||
return {"items": items, "path": path, "service": service}
|
||||
except Exception as e:
|
||||
logger.error(f"Error browsing {service} for connection {connectionId} at '{path}': {e}")
|
||||
return JSONResponse({"items": [], "error": str(e)}, status_code=400)
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Workflow CRUD
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_node_label_from_graph(graph: dict, nodeId: str) -> str:
|
||||
"""Extract human-readable label for a node from graph."""
|
||||
if not graph or not nodeId:
|
||||
return nodeId or ""
|
||||
nodes = graph.get("nodes") or []
|
||||
for n in nodes:
|
||||
if n.get("id") == nodeId:
|
||||
params = n.get("parameters") or {}
|
||||
config = params.get("config") or {}
|
||||
if isinstance(config, dict):
|
||||
label = config.get("title") or config.get("label")
|
||||
else:
|
||||
label = None
|
||||
return (
|
||||
n.get("title")
|
||||
or label
|
||||
or params.get("title")
|
||||
or params.get("label")
|
||||
or n.get("type", "")
|
||||
or nodeId
|
||||
)
|
||||
return nodeId or ""
|
||||
|
||||
|
||||
@router.get("/{instanceId}/workflows")
|
||||
@limiter.limit("60/minute")
|
||||
def get_workflows(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature instance ID"),
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
) -> dict:
|
||||
"""List all workflows for this feature instance.
|
||||
Enriches each workflow with runCount, isRunning, stuckAtNodeId, stuckAtNodeLabel,
|
||||
createdAt, lastStartedAt.
|
||||
"""
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
a2 = getAutomation2Interface(context.user, mandateId, instanceId)
|
||||
items = a2.getWorkflows()
|
||||
enriched = []
|
||||
for wf in items:
|
||||
wf_id = wf.get("id")
|
||||
runs = a2.getRunsByWorkflow(wf_id) if wf_id else []
|
||||
run_count = len(runs)
|
||||
active_run = None
|
||||
last_started_at = None
|
||||
for r in runs:
|
||||
ts = r.get("_createdAt")
|
||||
if ts and (last_started_at is None or ts > last_started_at):
|
||||
last_started_at = ts
|
||||
if r.get("status") in ("running", "paused"):
|
||||
active_run = r
|
||||
stuck_at_node_id = active_run.get("currentNodeId") if active_run else None
|
||||
stuck_at_node_label = ""
|
||||
if stuck_at_node_id and wf.get("graph"):
|
||||
stuck_at_node_label = _get_node_label_from_graph(wf["graph"], stuck_at_node_id)
|
||||
enriched.append({
|
||||
**wf,
|
||||
"runCount": run_count,
|
||||
"isRunning": active_run is not None,
|
||||
"runStatus": active_run.get("status") if active_run else None,
|
||||
"stuckAtNodeId": stuck_at_node_id,
|
||||
"stuckAtNodeLabel": stuck_at_node_label or stuck_at_node_id or "",
|
||||
"createdAt": wf.get("_createdAt"),
|
||||
"lastStartedAt": last_started_at,
|
||||
})
|
||||
return {"workflows": enriched}
|
||||
|
||||
|
||||
@router.get("/{instanceId}/workflows/{workflowId}")
|
||||
@limiter.limit("60/minute")
|
||||
def get_workflow(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature instance ID"),
|
||||
workflowId: str = Path(..., description="Workflow ID"),
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
) -> dict:
|
||||
"""Get a single workflow by ID."""
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
a2 = getAutomation2Interface(context.user, mandateId, instanceId)
|
||||
wf = a2.getWorkflow(workflowId)
|
||||
if not wf:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
return wf
|
||||
|
||||
|
||||
@router.post("/{instanceId}/workflows")
|
||||
@limiter.limit("30/minute")
|
||||
def create_workflow(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature instance ID"),
|
||||
body: dict = Body(..., description="{ label, graph }"),
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
) -> dict:
|
||||
"""Create a new workflow."""
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
a2 = getAutomation2Interface(context.user, mandateId, instanceId)
|
||||
created = a2.createWorkflow(body)
|
||||
return created
|
||||
|
||||
|
||||
@router.put("/{instanceId}/workflows/{workflowId}")
|
||||
@limiter.limit("30/minute")
|
||||
def update_workflow(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature instance ID"),
|
||||
workflowId: str = Path(..., description="Workflow ID"),
|
||||
body: dict = Body(..., description="{ label?, graph? }"),
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
) -> dict:
|
||||
"""Update a workflow."""
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
a2 = getAutomation2Interface(context.user, mandateId, instanceId)
|
||||
updated = a2.updateWorkflow(workflowId, body)
|
||||
if not updated:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
return updated
|
||||
|
||||
|
||||
@router.delete("/{instanceId}/workflows/{workflowId}")
|
||||
@limiter.limit("30/minute")
|
||||
def delete_workflow(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature instance ID"),
|
||||
workflowId: str = Path(..., description="Workflow ID"),
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
) -> dict:
|
||||
"""Delete a workflow."""
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
a2 = getAutomation2Interface(context.user, mandateId, instanceId)
|
||||
if not a2.deleteWorkflow(workflowId):
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
return {"success": True}
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Runs and Resume
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/{instanceId}/workflows/{workflowId}/runs")
|
||||
@limiter.limit("60/minute")
|
||||
def get_workflow_runs(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature instance ID"),
|
||||
workflowId: str = Path(..., description="Workflow ID"),
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
) -> dict:
|
||||
"""Get runs for a workflow."""
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
a2 = getAutomation2Interface(context.user, mandateId, instanceId)
|
||||
if not a2.getWorkflow(workflowId):
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
runs = a2.getRunsByWorkflow(workflowId)
|
||||
return {"runs": runs}
|
||||
|
||||
|
||||
@router.post("/{instanceId}/runs/{runId}/resume")
|
||||
@limiter.limit("30/minute")
|
||||
async def resume_run(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature instance ID"),
|
||||
runId: str = Path(..., description="Run ID"),
|
||||
body: dict = Body(..., description="{ taskId, result }"),
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
) -> dict:
|
||||
"""Resume a paused run after task completion."""
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
a2 = getAutomation2Interface(context.user, mandateId, instanceId)
|
||||
run = a2.getRun(runId)
|
||||
if not run:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
taskId = body.get("taskId")
|
||||
result = body.get("result")
|
||||
if not taskId or result is None:
|
||||
raise HTTPException(status_code=400, detail="taskId and result required")
|
||||
task = a2.getTask(taskId)
|
||||
if not task or task.get("runId") != runId:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
if task.get("status") != "pending":
|
||||
raise HTTPException(status_code=400, detail="Task already completed")
|
||||
a2.updateTask(taskId, status="completed", result=result)
|
||||
nodeId = task.get("nodeId")
|
||||
nodeOutputs = dict(run.get("nodeOutputs") or {})
|
||||
nodeOutputs[nodeId] = result
|
||||
runContext = run.get("context") or {}
|
||||
connectionMap = runContext.get("connectionMap", {})
|
||||
inputSources = runContext.get("inputSources", {})
|
||||
workflowId = run.get("workflowId")
|
||||
wf = a2.getWorkflow(workflowId) if workflowId else None
|
||||
if not wf or not wf.get("graph"):
|
||||
raise HTTPException(status_code=400, detail="Workflow graph not found")
|
||||
graph = wf["graph"]
|
||||
services = getAutomation2Services(context.user, mandateId=mandateId, featureInstanceId=instanceId)
|
||||
resume_result = await executeGraph(
|
||||
graph=graph,
|
||||
services=services,
|
||||
workflowId=workflowId,
|
||||
instanceId=instanceId,
|
||||
userId=str(context.user.id) if context.user else None,
|
||||
mandateId=mandateId,
|
||||
automation2_interface=a2,
|
||||
initialNodeOutputs=nodeOutputs,
|
||||
startAfterNodeId=nodeId,
|
||||
runId=runId,
|
||||
)
|
||||
return resume_result
|
||||
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Tasks
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/{instanceId}/tasks")
|
||||
@limiter.limit("60/minute")
|
||||
def get_tasks(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature instance ID"),
|
||||
workflowId: str = Query(None, description="Filter by workflow ID"),
|
||||
status: str = Query(None, description="Filter: pending, completed, rejected"),
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
) -> dict:
|
||||
"""Get tasks - by default those assigned to current user, or all if no assignee filter.
|
||||
Enriches each task with workflowLabel and createdAt (_createdAt).
|
||||
"""
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
a2 = getAutomation2Interface(context.user, mandateId, instanceId)
|
||||
assigneeId = str(context.user.id) if context.user else None
|
||||
items = a2.getTasks(workflowId=workflowId, status=status, assigneeId=assigneeId)
|
||||
workflows = {w["id"]: w for w in a2.getWorkflows()}
|
||||
enriched = []
|
||||
for t in items:
|
||||
wf = workflows.get(t.get("workflowId") or "")
|
||||
enriched.append({
|
||||
**t,
|
||||
"workflowLabel": wf.get("label", t.get("workflowId", "")) if wf else t.get("workflowId", ""),
|
||||
"createdAt": t.get("_createdAt"),
|
||||
})
|
||||
return {"tasks": enriched}
|
||||
|
||||
|
||||
@router.post("/{instanceId}/tasks/{taskId}/complete")
|
||||
@limiter.limit("30/minute")
|
||||
async def complete_task(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature instance ID"),
|
||||
taskId: str = Path(..., description="Task ID"),
|
||||
body: dict = Body(..., description="{ result }"),
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
) -> dict:
|
||||
"""Complete a task and resume the run."""
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
a2 = getAutomation2Interface(context.user, mandateId, instanceId)
|
||||
task = a2.getTask(taskId)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
runId = task.get("runId")
|
||||
result = body.get("result")
|
||||
if result is None:
|
||||
raise HTTPException(status_code=400, detail="result required")
|
||||
run = a2.getRun(runId)
|
||||
if not run:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
if task.get("status") != "pending":
|
||||
raise HTTPException(status_code=400, detail="Task already completed")
|
||||
a2.updateTask(taskId, status="completed", result=result)
|
||||
nodeId = task.get("nodeId")
|
||||
nodeOutputs = dict(run.get("nodeOutputs") or {})
|
||||
nodeOutputs[nodeId] = result
|
||||
workflowId = run.get("workflowId")
|
||||
wf = a2.getWorkflow(workflowId) if workflowId else None
|
||||
if not wf or not wf.get("graph"):
|
||||
raise HTTPException(status_code=400, detail="Workflow graph not found")
|
||||
graph = wf["graph"]
|
||||
services = getAutomation2Services(context.user, mandateId=mandateId, featureInstanceId=instanceId)
|
||||
return await executeGraph(
|
||||
graph=graph,
|
||||
services=services,
|
||||
workflowId=workflowId,
|
||||
instanceId=instanceId,
|
||||
userId=str(context.user.id) if context.user else None,
|
||||
mandateId=mandateId,
|
||||
automation2_interface=a2,
|
||||
initialNodeOutputs=nodeOutputs,
|
||||
startAfterNodeId=nodeId,
|
||||
runId=runId,
|
||||
)
|
||||
|
|
@ -514,7 +514,7 @@ class ChatObjects:
|
|||
|
||||
# Handle simple value (equals operator)
|
||||
if not isinstance(filter_value, dict):
|
||||
if record_value != filter_value:
|
||||
if str(record_value).lower() != str(filter_value).lower():
|
||||
matches = False
|
||||
break
|
||||
continue
|
||||
|
|
@ -524,7 +524,7 @@ class ChatObjects:
|
|||
filter_val = filter_value.get("value")
|
||||
|
||||
if operator in ["equals", "eq"]:
|
||||
if record_value != filter_val:
|
||||
if str(record_value).lower() != str(filter_val).lower():
|
||||
matches = False
|
||||
break
|
||||
|
||||
|
|
@ -609,7 +609,7 @@ class ChatObjects:
|
|||
|
||||
else:
|
||||
# Unknown operator - default to equals
|
||||
if record_value != filter_val:
|
||||
if str(record_value).lower() != str(filter_val).lower():
|
||||
matches = False
|
||||
break
|
||||
|
||||
|
|
|
|||
|
|
@ -1192,24 +1192,79 @@ async def deleteDocumentRoute(
|
|||
|
||||
|
||||
def _extractText(content: bytes, mimeType: str, fileName: str) -> Optional[str]:
|
||||
"""Extract text from uploaded file content."""
|
||||
"""Extract text from uploaded file content (TXT, MD, HTML, PDF, DOCX, XLSX, PPTX)."""
|
||||
import io
|
||||
|
||||
lowerName = fileName.lower()
|
||||
try:
|
||||
if mimeType == "text/plain" or fileName.endswith(".txt"):
|
||||
if mimeType in ("text/plain",) or lowerName.endswith(".txt"):
|
||||
return content.decode("utf-8", errors="replace")
|
||||
if mimeType == "text/markdown" or fileName.endswith(".md"):
|
||||
|
||||
if mimeType in ("text/markdown",) or lowerName.endswith(".md"):
|
||||
return content.decode("utf-8", errors="replace")
|
||||
if "pdf" in mimeType or fileName.endswith(".pdf"):
|
||||
|
||||
if mimeType in ("text/html",) or lowerName.endswith((".html", ".htm")):
|
||||
from html.parser import HTMLParser
|
||||
class _Strip(HTMLParser):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._parts: list[str] = []
|
||||
def handle_data(self, d):
|
||||
self._parts.append(d)
|
||||
def result(self):
|
||||
return " ".join(self._parts)
|
||||
parser = _Strip()
|
||||
parser.feed(content.decode("utf-8", errors="replace"))
|
||||
return parser.result()
|
||||
|
||||
if "pdf" in mimeType or lowerName.endswith(".pdf"):
|
||||
try:
|
||||
import io
|
||||
from PyPDF2 import PdfReader
|
||||
reader = PdfReader(io.BytesIO(content))
|
||||
text = ""
|
||||
for page in reader.pages:
|
||||
text += page.extract_text() or ""
|
||||
return text
|
||||
return "".join(page.extract_text() or "" for page in reader.pages)
|
||||
except ImportError:
|
||||
logger.warning("PyPDF2 not installed, cannot extract PDF text")
|
||||
return None
|
||||
|
||||
if "wordprocessingml" in mimeType or lowerName.endswith(".docx"):
|
||||
try:
|
||||
from docx import Document
|
||||
doc = Document(io.BytesIO(content))
|
||||
return "\n".join(p.text for p in doc.paragraphs if p.text)
|
||||
except ImportError:
|
||||
logger.warning("python-docx not installed, cannot extract DOCX text")
|
||||
return None
|
||||
|
||||
if "spreadsheetml" in mimeType or lowerName.endswith(".xlsx"):
|
||||
try:
|
||||
from openpyxl import load_workbook
|
||||
wb = load_workbook(io.BytesIO(content), read_only=True, data_only=True)
|
||||
parts: list[str] = []
|
||||
for ws in wb.worksheets:
|
||||
for row in ws.iter_rows(values_only=True):
|
||||
cells = [str(c) for c in row if c is not None]
|
||||
if cells:
|
||||
parts.append("\t".join(cells))
|
||||
return "\n".join(parts)
|
||||
except ImportError:
|
||||
logger.warning("openpyxl not installed, cannot extract XLSX text")
|
||||
return None
|
||||
|
||||
if "presentationml" in mimeType or lowerName.endswith(".pptx"):
|
||||
try:
|
||||
from pptx import Presentation
|
||||
prs = Presentation(io.BytesIO(content))
|
||||
parts = []
|
||||
for slide in prs.slides:
|
||||
for shape in slide.shapes:
|
||||
if shape.has_text_frame:
|
||||
parts.append(shape.text_frame.text)
|
||||
return "\n".join(parts)
|
||||
except ImportError:
|
||||
logger.warning("python-pptx not installed, cannot extract PPTX text")
|
||||
return None
|
||||
|
||||
logger.info(f"No text extractor for {fileName} (mime={mimeType})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Text extraction failed for {fileName}: {e}")
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -331,7 +331,8 @@ def _getDocumentSummaries(contextId: str, userId: str, interface) -> Optional[Li
|
|||
elif doc.get("extractedText"):
|
||||
summaries.append(f"[{doc.get('fileName', 'Dokument')}] {doc['extractedText'][:200]}...")
|
||||
return summaries if summaries else None
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load document summaries for context {contextId}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -39,19 +39,21 @@ EXTRACTABLE_BINARY_MIME_TYPES = frozenset({
|
|||
class NeutralizationService:
|
||||
"""Service for handling data neutralization operations"""
|
||||
|
||||
def __init__(self, serviceCenter=None, NamesToParse: List[str] = None):
|
||||
def __init__(self, serviceCenter=None, getServiceFn=None, NamesToParse: List[str] = None):
|
||||
"""Initialize the service with user context and anonymization processors
|
||||
|
||||
Args:
|
||||
serviceCenter: Service center instance for accessing other services
|
||||
serviceCenter: Service center context or legacy service center instance
|
||||
getServiceFn: Service resolver function (injected by ServiceCenter resolver)
|
||||
NamesToParse: List of names to parse and replace (case-insensitive)
|
||||
"""
|
||||
self.services = serviceCenter
|
||||
self.interfaceDbComponent = serviceCenter.interfaceDbComponent
|
||||
self._getService = getServiceFn
|
||||
self.interfaceDbComponent = getattr(serviceCenter, "interfaceDbComponent", None)
|
||||
|
||||
# Create feature-specific interface for neutralizer DB operations
|
||||
self.interfaceNeutralizer: InterfaceFeatureNeutralizer = None
|
||||
if serviceCenter and serviceCenter.interfaceDbApp:
|
||||
if serviceCenter and getattr(serviceCenter, "interfaceDbApp", None):
|
||||
dbApp = serviceCenter.interfaceDbApp
|
||||
self.interfaceNeutralizer = getNeutralizerInterface(
|
||||
currentUser=dbApp.currentUser,
|
||||
|
|
@ -59,10 +61,10 @@ class NeutralizationService:
|
|||
featureInstanceId=getattr(serviceCenter, 'featureInstanceId', None) or getattr(dbApp, 'featureInstanceId', None)
|
||||
)
|
||||
|
||||
# Initialize anonymization processors
|
||||
self.NamesToParse = NamesToParse or []
|
||||
self.textProcessor = TextProcessor(NamesToParse)
|
||||
self.listProcessor = ListProcessor(NamesToParse)
|
||||
namesList = NamesToParse if isinstance(NamesToParse, list) else []
|
||||
self.NamesToParse = namesList
|
||||
self.textProcessor = TextProcessor(namesList)
|
||||
self.listProcessor = ListProcessor(namesList)
|
||||
self.binaryProcessor = BinaryProcessor()
|
||||
self.commonUtils = CommonUtils()
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,8 @@ from modules.shared.configuration import APP_CONFIG
|
|||
from modules.security.rbac import RbacClass
|
||||
from modules.datamodels.datamodelRbac import AccessRuleContext
|
||||
from modules.datamodels.datamodelUam import AccessLevel
|
||||
from modules.interfaces.interfaceRbac import getRecordsetWithRBAC
|
||||
from modules.interfaces.interfaceRbac import getRecordsetWithRBAC, getRecordsetPaginatedWithRBAC, getDistinctColumnValuesWithRBAC
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -175,17 +176,21 @@ class RealEstateObjects:
|
|||
|
||||
return Projekt(**records[0])
|
||||
|
||||
def getProjekte(self, recordFilter: Optional[Dict[str, Any]] = None) -> List[Projekt]:
|
||||
"""Get all projects matching the filter."""
|
||||
records = getRecordsetWithRBAC(
|
||||
def getProjekte(self, recordFilter: Optional[Dict[str, Any]] = None, pagination: Optional[PaginationParams] = None) -> Union[List[Projekt], PaginatedResult]:
|
||||
"""Get all projects matching the filter with optional DB-level pagination."""
|
||||
result = getRecordsetPaginatedWithRBAC(
|
||||
self.db,
|
||||
Projekt,
|
||||
self.currentUser,
|
||||
pagination=pagination,
|
||||
recordFilter=recordFilter or {},
|
||||
featureCode=self.FEATURE_CODE
|
||||
)
|
||||
|
||||
return [Projekt(**r) for r in records]
|
||||
|
||||
if isinstance(result, PaginatedResult):
|
||||
result.items = [Projekt(**r) for r in result.items]
|
||||
return result
|
||||
return [Projekt(**r) for r in result]
|
||||
|
||||
def updateProjekt(self, projektId_or_projekt: Union[str, Projekt], updateData: Optional[Dict[str, Any]] = None) -> Optional[Projekt]:
|
||||
"""Update a project.
|
||||
|
|
@ -265,20 +270,23 @@ class RealEstateObjects:
|
|||
|
||||
return Parzelle(**records[0])
|
||||
|
||||
def getParzellen(self, recordFilter: Optional[Dict[str, Any]] = None) -> List[Parzelle]:
|
||||
"""Get all plots matching the filter."""
|
||||
# Resolve location names to IDs if needed
|
||||
def getParzellen(self, recordFilter: Optional[Dict[str, Any]] = None, pagination: Optional[PaginationParams] = None) -> Union[List[Parzelle], PaginatedResult]:
|
||||
"""Get all plots matching the filter with optional DB-level pagination."""
|
||||
if recordFilter:
|
||||
recordFilter = self._resolveLocationFilters(recordFilter)
|
||||
|
||||
records = getRecordsetWithRBAC(
|
||||
|
||||
result = getRecordsetPaginatedWithRBAC(
|
||||
self.db,
|
||||
Parzelle,
|
||||
self.currentUser,
|
||||
pagination=pagination,
|
||||
recordFilter=recordFilter or {}
|
||||
)
|
||||
|
||||
return [Parzelle(**r) for r in records]
|
||||
|
||||
if isinstance(result, PaginatedResult):
|
||||
result.items = [Parzelle(**r) for r in result.items]
|
||||
return result
|
||||
return [Parzelle(**r) for r in result]
|
||||
|
||||
def _resolveLocationFilters(self, recordFilter: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
|
|
@ -477,18 +485,23 @@ class RealEstateObjects:
|
|||
|
||||
return Dokument(**records[0])
|
||||
|
||||
def getDokumente(self, recordFilter: Optional[Dict[str, Any]] = None) -> List[Dokument]:
|
||||
"""Get all documents matching the filter."""
|
||||
records = getRecordsetWithRBAC(
|
||||
def getDokumente(self, recordFilter: Optional[Dict[str, Any]] = None, pagination: Optional[PaginationParams] = None) -> Union[List[Dokument], PaginatedResult]:
|
||||
"""Get all documents matching the filter with optional DB-level pagination."""
|
||||
result = getRecordsetPaginatedWithRBAC(
|
||||
self.db,
|
||||
Dokument,
|
||||
self.currentUser,
|
||||
pagination=pagination,
|
||||
recordFilter=recordFilter or {},
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId,
|
||||
featureCode=self.FEATURE_CODE
|
||||
)
|
||||
return [Dokument(**r) for r in records]
|
||||
|
||||
if isinstance(result, PaginatedResult):
|
||||
result.items = [Dokument(**r) for r in result.items]
|
||||
return result
|
||||
return [Dokument(**r) for r in result]
|
||||
|
||||
def updateDokument(self, dokumentId: str, updateData: Dict[str, Any]) -> Optional[Dokument]:
|
||||
"""Update a document."""
|
||||
|
|
@ -552,18 +565,23 @@ class RealEstateObjects:
|
|||
|
||||
return Gemeinde(**records[0])
|
||||
|
||||
def getGemeinden(self, recordFilter: Optional[Dict[str, Any]] = None) -> List[Gemeinde]:
|
||||
"""Get all municipalities matching the filter."""
|
||||
records = getRecordsetWithRBAC(
|
||||
def getGemeinden(self, recordFilter: Optional[Dict[str, Any]] = None, pagination: Optional[PaginationParams] = None) -> Union[List[Gemeinde], PaginatedResult]:
|
||||
"""Get all municipalities matching the filter with optional DB-level pagination."""
|
||||
result = getRecordsetPaginatedWithRBAC(
|
||||
self.db,
|
||||
Gemeinde,
|
||||
self.currentUser,
|
||||
pagination=pagination,
|
||||
recordFilter=recordFilter or {},
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId,
|
||||
featureCode=self.FEATURE_CODE
|
||||
)
|
||||
return [Gemeinde(**r) for r in records]
|
||||
|
||||
if isinstance(result, PaginatedResult):
|
||||
result.items = [Gemeinde(**r) for r in result.items]
|
||||
return result
|
||||
return [Gemeinde(**r) for r in result]
|
||||
|
||||
def updateGemeinde(self, gemeindeId: str, updateData: Dict[str, Any]) -> Optional[Gemeinde]:
|
||||
"""Update a municipality."""
|
||||
|
|
@ -627,18 +645,23 @@ class RealEstateObjects:
|
|||
|
||||
return Kanton(**records[0])
|
||||
|
||||
def getKantone(self, recordFilter: Optional[Dict[str, Any]] = None) -> List[Kanton]:
|
||||
"""Get all cantons matching the filter."""
|
||||
records = getRecordsetWithRBAC(
|
||||
def getKantone(self, recordFilter: Optional[Dict[str, Any]] = None, pagination: Optional[PaginationParams] = None) -> Union[List[Kanton], PaginatedResult]:
|
||||
"""Get all cantons matching the filter with optional DB-level pagination."""
|
||||
result = getRecordsetPaginatedWithRBAC(
|
||||
self.db,
|
||||
Kanton,
|
||||
self.currentUser,
|
||||
pagination=pagination,
|
||||
recordFilter=recordFilter or {},
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId,
|
||||
featureCode=self.FEATURE_CODE
|
||||
)
|
||||
return [Kanton(**r) for r in records]
|
||||
|
||||
if isinstance(result, PaginatedResult):
|
||||
result.items = [Kanton(**r) for r in result.items]
|
||||
return result
|
||||
return [Kanton(**r) for r in result]
|
||||
|
||||
def updateKanton(self, kantonId: str, updateData: Dict[str, Any]) -> Optional[Kanton]:
|
||||
"""Update a canton."""
|
||||
|
|
@ -700,16 +723,21 @@ class RealEstateObjects:
|
|||
|
||||
return Land(**records[0])
|
||||
|
||||
def getLaender(self, recordFilter: Optional[Dict[str, Any]] = None) -> List[Land]:
|
||||
"""Get all countries matching the filter."""
|
||||
records = getRecordsetWithRBAC(
|
||||
def getLaender(self, recordFilter: Optional[Dict[str, Any]] = None, pagination: Optional[PaginationParams] = None) -> Union[List[Land], PaginatedResult]:
|
||||
"""Get all countries matching the filter with optional DB-level pagination."""
|
||||
result = getRecordsetPaginatedWithRBAC(
|
||||
self.db,
|
||||
Land,
|
||||
self.currentUser,
|
||||
pagination=pagination,
|
||||
recordFilter=recordFilter or {},
|
||||
featureCode=self.FEATURE_CODE
|
||||
)
|
||||
return [Land(**r) for r in records]
|
||||
|
||||
if isinstance(result, PaginatedResult):
|
||||
result.items = [Land(**r) for r in result.items]
|
||||
return result
|
||||
return [Land(**r) for r in result]
|
||||
|
||||
def updateLand(self, landId: str, updateData: Dict[str, Any]) -> Optional[Land]:
|
||||
"""Update a country."""
|
||||
|
|
|
|||
|
|
@ -252,6 +252,31 @@ def get_projects(
|
|||
return PaginatedResponse(items=items, pagination=None)
|
||||
|
||||
|
||||
@router.get("/{instanceId}/projects/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def get_project_filter_values(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature Instance ID"),
|
||||
column: str = Query(..., description="Column key"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> list:
|
||||
"""Return distinct filter values for a column in real estate projects."""
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
try:
|
||||
from modules.routes.routeDataUsers import _handleFilterValuesRequest
|
||||
interface = getRealEstateInterface(
|
||||
context.user, mandateId=mandateId, featureInstanceId=instanceId
|
||||
)
|
||||
recordFilter = {"featureInstanceId": instanceId}
|
||||
items = interface.getProjekte(recordFilter=recordFilter)
|
||||
itemDicts = [i.model_dump() if hasattr(i, 'model_dump') else i for i in items]
|
||||
return _handleFilterValuesRequest(itemDicts, column, pagination)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting filter values for projects: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{instanceId}/projects/{projectId}", response_model=Projekt)
|
||||
@limiter.limit("30/minute")
|
||||
def get_project_by_id(
|
||||
|
|
@ -384,6 +409,31 @@ def get_parcels(
|
|||
return PaginatedResponse(items=items, pagination=None)
|
||||
|
||||
|
||||
@router.get("/{instanceId}/parcels/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def get_parcel_filter_values(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature Instance ID"),
|
||||
column: str = Query(..., description="Column key"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> list:
|
||||
"""Return distinct filter values for a column in real estate parcels."""
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
try:
|
||||
from modules.routes.routeDataUsers import _handleFilterValuesRequest
|
||||
interface = getRealEstateInterface(
|
||||
context.user, mandateId=mandateId, featureInstanceId=instanceId
|
||||
)
|
||||
recordFilter = {"featureInstanceId": instanceId}
|
||||
items = interface.getParzellen(recordFilter=recordFilter)
|
||||
itemDicts = [i.model_dump() if hasattr(i, 'model_dump') else i for i in items]
|
||||
return _handleFilterValuesRequest(itemDicts, column, pagination)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting filter values for parcels: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{instanceId}/parcels/{parcelId}", response_model=Parzelle)
|
||||
@limiter.limit("30/minute")
|
||||
def get_parcel_by_id(
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from pydantic import ValidationError
|
|||
|
||||
from modules.connectors.connectorDbPostgre import DatabaseConnector
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.interfaces.interfaceRbac import getRecordsetWithRBAC
|
||||
from modules.interfaces.interfaceRbac import getRecordsetWithRBAC, getRecordsetPaginatedWithRBAC, getDistinctColumnValuesWithRBAC
|
||||
from modules.security.rbac import RbacClass
|
||||
from modules.datamodels.datamodelUam import User, AccessLevel
|
||||
from modules.datamodels.datamodelRbac import AccessRuleContext
|
||||
|
|
@ -414,7 +414,7 @@ class TrusteeObjects:
|
|||
|
||||
# Handle simple value (equals operator)
|
||||
if not isinstance(filterValue, dict):
|
||||
if recordValue == filterValue:
|
||||
if str(recordValue).lower() == str(filterValue).lower():
|
||||
fieldFiltered.append(record)
|
||||
continue
|
||||
|
||||
|
|
@ -424,7 +424,7 @@ class TrusteeObjects:
|
|||
|
||||
matches = False
|
||||
if operator in ["equals", "eq"]:
|
||||
matches = recordValue == filterVal
|
||||
matches = str(recordValue).lower() == str(filterVal).lower()
|
||||
|
||||
elif operator == "contains":
|
||||
recordStr = str(recordValue).lower() if recordValue is not None else ""
|
||||
|
|
@ -577,8 +577,8 @@ class TrusteeObjects:
|
|||
return None
|
||||
return TrusteeOrganisation(**{k: v for k, v in records[0].items() if not k.startswith("_")})
|
||||
|
||||
def getAllOrganisations(self, params: Optional[PaginationParams] = None) -> PaginatedResult:
|
||||
"""Get all organisations with RBAC filtering.
|
||||
def getAllOrganisations(self, params: Optional[PaginationParams] = None) -> Union[List[Dict], PaginatedResult]:
|
||||
"""Get all organisations with RBAC filtering and optional DB-level pagination.
|
||||
|
||||
Note: Organisations are managed at system level (by mandate).
|
||||
Feature-level filtering (trustee.access) is NOT applied here because:
|
||||
|
|
@ -586,42 +586,18 @@ class TrusteeObjects:
|
|||
- trustee.access grants access to specific orgs for other users
|
||||
- New organisations wouldn't be visible without an access record
|
||||
"""
|
||||
# Debug: Log user info and permissions
|
||||
logger.debug(f"getAllOrganisations called for user {self.userId}, mandateId: {self.mandateId}")
|
||||
|
||||
# System RBAC filtering (filters by mandate for GROUP access level)
|
||||
records = getRecordsetWithRBAC(
|
||||
return getRecordsetPaginatedWithRBAC(
|
||||
connector=self.db,
|
||||
modelClass=TrusteeOrganisation,
|
||||
currentUser=self.currentUser,
|
||||
pagination=params,
|
||||
recordFilter=None,
|
||||
orderBy="id",
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId,
|
||||
featureCode=self.FEATURE_CODE
|
||||
)
|
||||
logger.debug(f"getAllOrganisations: getRecordsetWithRBAC returned {len(records)} records")
|
||||
|
||||
# Apply pagination
|
||||
totalItems = len(records)
|
||||
if params:
|
||||
pageSize = params.pageSize or 20
|
||||
page = params.page or 1
|
||||
startIdx = (page - 1) * pageSize
|
||||
endIdx = startIdx + pageSize
|
||||
items = records[startIdx:endIdx]
|
||||
totalPages = math.ceil(totalItems / pageSize) if pageSize > 0 else 1
|
||||
else:
|
||||
items = records
|
||||
totalPages = 1
|
||||
page = 1
|
||||
pageSize = totalItems
|
||||
|
||||
return PaginatedResult(
|
||||
items=items,
|
||||
totalItems=totalItems,
|
||||
totalPages=totalPages
|
||||
)
|
||||
|
||||
def updateOrganisation(self, orgId: str, data: Dict[str, Any]) -> Optional[TrusteeOrganisation]:
|
||||
"""Update an organisation."""
|
||||
|
|
@ -677,51 +653,40 @@ class TrusteeObjects:
|
|||
return None
|
||||
return TrusteeRole(**{k: v for k, v in records[0].items() if not k.startswith("_")})
|
||||
|
||||
def getAllRoles(self, params: Optional[PaginationParams] = None) -> PaginatedResult:
|
||||
"""Get all roles with RBAC filtering.
|
||||
def getAllRoles(self, params: Optional[PaginationParams] = None) -> Union[List[Dict], PaginatedResult]:
|
||||
"""Get all roles with RBAC filtering and optional DB-level pagination.
|
||||
|
||||
Note: Roles are available to all users with trustee access.
|
||||
They are not filtered by organisation since they define the role types.
|
||||
Users with ALL access level see all roles; others need trustee.access records.
|
||||
|
||||
NOTE(post-filter): Feature-level access check runs after the paginated query.
|
||||
When pagination is active the totals may overcount if the user lacks any
|
||||
trustee.access record (in that case the entire result is emptied).
|
||||
"""
|
||||
records = getRecordsetWithRBAC(
|
||||
result = getRecordsetPaginatedWithRBAC(
|
||||
connector=self.db,
|
||||
modelClass=TrusteeRole,
|
||||
currentUser=self.currentUser,
|
||||
pagination=params,
|
||||
recordFilter=None,
|
||||
orderBy="id",
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId,
|
||||
featureCode=self.FEATURE_CODE
|
||||
)
|
||||
|
||||
# Users with ALL access level (from system RBAC) see all roles
|
||||
# Others need at least one trustee.access record
|
||||
accessLevel = self.getRbacAccessLevel(TrusteeRole, "read")
|
||||
if accessLevel != AccessLevel.ALL:
|
||||
userAccess = self.getAllUserAccess(self.userId)
|
||||
if not userAccess:
|
||||
records = [] # No trustee access at all
|
||||
if isinstance(result, PaginatedResult):
|
||||
result.items = []
|
||||
result.totalItems = 0
|
||||
result.totalPages = 0
|
||||
return result
|
||||
return []
|
||||
|
||||
totalItems = len(records)
|
||||
if params:
|
||||
pageSize = params.pageSize or 20
|
||||
page = params.page or 1
|
||||
startIdx = (page - 1) * pageSize
|
||||
endIdx = startIdx + pageSize
|
||||
items = records[startIdx:endIdx]
|
||||
totalPages = math.ceil(totalItems / pageSize) if pageSize > 0 else 1
|
||||
else:
|
||||
items = records
|
||||
totalPages = 1
|
||||
page = 1
|
||||
pageSize = totalItems
|
||||
|
||||
return PaginatedResult(
|
||||
items=items,
|
||||
totalItems=totalItems,
|
||||
totalPages=totalPages
|
||||
)
|
||||
return result
|
||||
|
||||
def updateRole(self, roleId: str, data: Dict[str, Any]) -> Optional[TrusteeRole]:
|
||||
"""Update a role (sysadmin only)."""
|
||||
|
|
@ -788,116 +753,113 @@ class TrusteeObjects:
|
|||
return None
|
||||
return TrusteeAccess(**{k: v for k, v in records[0].items() if not k.startswith("_")})
|
||||
|
||||
def getAllAccess(self, params: Optional[PaginationParams] = None) -> PaginatedResult:
|
||||
def getAllAccess(self, params: Optional[PaginationParams] = None) -> Union[List[Dict], PaginatedResult]:
|
||||
"""Get all access records with RBAC filtering + feature-level filtering.
|
||||
|
||||
Users with ALL access level see all access records.
|
||||
Others can only see access records for organisations they have admin access to.
|
||||
|
||||
NOTE(post-filter): Feature-level admin-org filtering runs after the paginated
|
||||
query, so totals may overcount when the user has restricted org access.
|
||||
"""
|
||||
# Step 1: System RBAC filtering
|
||||
records = getRecordsetWithRBAC(
|
||||
result = getRecordsetPaginatedWithRBAC(
|
||||
connector=self.db,
|
||||
modelClass=TrusteeAccess,
|
||||
currentUser=self.currentUser,
|
||||
pagination=params,
|
||||
recordFilter=None,
|
||||
orderBy="id",
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId,
|
||||
featureCode=self.FEATURE_CODE
|
||||
)
|
||||
|
||||
# Users with ALL access level (from system RBAC) see all records
|
||||
accessLevel = self.getRbacAccessLevel(TrusteeAccess, "read")
|
||||
|
||||
|
||||
if accessLevel != AccessLevel.ALL:
|
||||
# Step 2: Feature-level filtering - only see access for organisations where user is admin
|
||||
userAccess = self.getAllUserAccess(self.userId)
|
||||
|
||||
# Get organisations where user has admin role
|
||||
adminOrgs = set()
|
||||
for access in userAccess:
|
||||
if access.get("roleId") == "admin":
|
||||
adminOrgs.add(access.get("organisationId"))
|
||||
|
||||
# Filter records to only show those in admin organisations
|
||||
|
||||
if isinstance(result, PaginatedResult):
|
||||
if adminOrgs:
|
||||
result.items = [r for r in result.items if r.get("organisationId") in adminOrgs]
|
||||
else:
|
||||
result.items = []
|
||||
return result
|
||||
|
||||
if adminOrgs:
|
||||
records = [r for r in records if r.get("organisationId") in adminOrgs]
|
||||
result = [r for r in result if r.get("organisationId") in adminOrgs]
|
||||
else:
|
||||
records = []
|
||||
result = []
|
||||
|
||||
totalItems = len(records)
|
||||
if params:
|
||||
pageSize = params.pageSize or 20
|
||||
page = params.page or 1
|
||||
startIdx = (page - 1) * pageSize
|
||||
endIdx = startIdx + pageSize
|
||||
items = records[startIdx:endIdx]
|
||||
totalPages = math.ceil(totalItems / pageSize) if pageSize > 0 else 1
|
||||
else:
|
||||
items = records
|
||||
totalPages = 1
|
||||
page = 1
|
||||
pageSize = totalItems
|
||||
return result
|
||||
|
||||
return PaginatedResult(
|
||||
items=items,
|
||||
totalItems=totalItems,
|
||||
totalPages=totalPages
|
||||
)
|
||||
|
||||
def getAccessByOrganisation(self, organisationId: str) -> List[TrusteeAccess]:
|
||||
def getAccessByOrganisation(self, organisationId: str, pagination: Optional[PaginationParams] = None) -> Union[List[TrusteeAccess], PaginatedResult]:
|
||||
"""Get all access records for a specific organisation.
|
||||
|
||||
Requires admin role for the organisation.
|
||||
"""
|
||||
# Check if user has admin access for this organisation
|
||||
if not self.checkUserTrusteePermission(self.userId, organisationId, "admin"):
|
||||
logger.warning(f"User {self.userId} lacks admin role for organisation {organisationId}")
|
||||
return []
|
||||
|
||||
records = getRecordsetWithRBAC(
|
||||
|
||||
result = getRecordsetPaginatedWithRBAC(
|
||||
connector=self.db,
|
||||
modelClass=TrusteeAccess,
|
||||
currentUser=self.currentUser,
|
||||
pagination=pagination,
|
||||
recordFilter={"organisationId": organisationId},
|
||||
orderBy="id",
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId,
|
||||
featureCode=self.FEATURE_CODE
|
||||
)
|
||||
return [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in records]
|
||||
|
||||
def getAccessByUser(self, userId: str) -> List[TrusteeAccess]:
|
||||
if isinstance(result, PaginatedResult):
|
||||
result.items = [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in result.items]
|
||||
return result
|
||||
return [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in result]
|
||||
|
||||
def getAccessByUser(self, userId: str, pagination: Optional[PaginationParams] = None) -> Union[List[TrusteeAccess], PaginatedResult]:
|
||||
"""Get all access records for a specific user.
|
||||
|
||||
Users with ALL access level see all access records.
|
||||
Others can only see access records for organisations where they have admin role.
|
||||
|
||||
NOTE(post-filter): Admin-org filtering runs after the paginated query,
|
||||
so totals may overcount when the user has restricted org access.
|
||||
"""
|
||||
records = getRecordsetWithRBAC(
|
||||
result = getRecordsetPaginatedWithRBAC(
|
||||
connector=self.db,
|
||||
modelClass=TrusteeAccess,
|
||||
currentUser=self.currentUser,
|
||||
pagination=pagination,
|
||||
recordFilter={"userId": userId},
|
||||
orderBy="id",
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId,
|
||||
featureCode=self.FEATURE_CODE
|
||||
)
|
||||
|
||||
# Users with ALL access level (from system RBAC) see all records
|
||||
|
||||
accessLevel = self.getRbacAccessLevel(TrusteeAccess, "read")
|
||||
if accessLevel == AccessLevel.ALL:
|
||||
return [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in records]
|
||||
|
||||
# Filter to only organisations where current user has admin role
|
||||
userAccess = self.getAllUserAccess(self.userId)
|
||||
adminOrgs = set()
|
||||
for access in userAccess:
|
||||
if access.get("roleId") == "admin":
|
||||
adminOrgs.add(access.get("organisationId"))
|
||||
|
||||
filtered = [r for r in records if r.get("organisationId") in adminOrgs]
|
||||
return [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in filtered]
|
||||
|
||||
if accessLevel != AccessLevel.ALL:
|
||||
userAccess = self.getAllUserAccess(self.userId)
|
||||
adminOrgs = set()
|
||||
for access in userAccess:
|
||||
if access.get("roleId") == "admin":
|
||||
adminOrgs.add(access.get("organisationId"))
|
||||
|
||||
if isinstance(result, PaginatedResult):
|
||||
result.items = [r for r in result.items if r.get("organisationId") in adminOrgs]
|
||||
result.items = [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in result.items]
|
||||
return result
|
||||
result = [r for r in result if r.get("organisationId") in adminOrgs]
|
||||
|
||||
if isinstance(result, PaginatedResult):
|
||||
result.items = [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in result.items]
|
||||
return result
|
||||
return [TrusteeAccess(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in result]
|
||||
|
||||
def updateAccess(self, accessId: str, data: Dict[str, Any]) -> Optional[TrusteeAccess]:
|
||||
"""Update an access record. Requires admin role for the organisation or ALL access level."""
|
||||
|
|
@ -988,54 +950,36 @@ class TrusteeObjects:
|
|||
return None
|
||||
return TrusteeContract(**{k: v for k, v in records[0].items() if not k.startswith("_")})
|
||||
|
||||
def getAllContracts(self, params: Optional[PaginationParams] = None) -> PaginatedResult:
|
||||
"""Get all contracts with RBAC filtering + feature-level access filtering."""
|
||||
# Step 1: System RBAC filtering
|
||||
records = getRecordsetWithRBAC(
|
||||
def getAllContracts(self, params: Optional[PaginationParams] = None) -> Union[List[Dict], PaginatedResult]:
|
||||
"""Get all contracts with RBAC filtering and optional DB-level pagination."""
|
||||
return getRecordsetPaginatedWithRBAC(
|
||||
connector=self.db,
|
||||
modelClass=TrusteeContract,
|
||||
currentUser=self.currentUser,
|
||||
pagination=params,
|
||||
recordFilter=None,
|
||||
orderBy="id",
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId,
|
||||
featureCode=self.FEATURE_CODE
|
||||
)
|
||||
|
||||
totalItems = len(records)
|
||||
if params:
|
||||
pageSize = params.pageSize or 20
|
||||
page = params.page or 1
|
||||
startIdx = (page - 1) * pageSize
|
||||
endIdx = startIdx + pageSize
|
||||
items = records[startIdx:endIdx]
|
||||
totalPages = math.ceil(totalItems / pageSize) if pageSize > 0 else 1
|
||||
else:
|
||||
items = records
|
||||
totalPages = 1
|
||||
page = 1
|
||||
pageSize = totalItems
|
||||
|
||||
return PaginatedResult(
|
||||
items=items,
|
||||
totalItems=totalItems,
|
||||
totalPages=totalPages
|
||||
)
|
||||
|
||||
def getContractsByOrganisation(self, organisationId: str) -> List[TrusteeContract]:
|
||||
def getContractsByOrganisation(self, organisationId: str, pagination: Optional[PaginationParams] = None) -> Union[List[TrusteeContract], PaginatedResult]:
|
||||
"""Get all contracts for a specific organisation."""
|
||||
# Step 1: System RBAC filtering
|
||||
records = getRecordsetWithRBAC(
|
||||
result = getRecordsetPaginatedWithRBAC(
|
||||
connector=self.db,
|
||||
modelClass=TrusteeContract,
|
||||
currentUser=self.currentUser,
|
||||
pagination=pagination,
|
||||
recordFilter={"organisationId": organisationId},
|
||||
orderBy="label",
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId,
|
||||
featureCode=self.FEATURE_CODE
|
||||
)
|
||||
return [TrusteeContract(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in records]
|
||||
|
||||
if isinstance(result, PaginatedResult):
|
||||
result.items = [TrusteeContract(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in result.items]
|
||||
return result
|
||||
return [TrusteeContract(**{k: v for k, v in r.items() if not k.startswith("_")}) for r in result]
|
||||
|
||||
def updateContract(self, contractId: str, data: Dict[str, Any]) -> Optional[TrusteeContract]:
|
||||
"""Update a contract (organisationId is immutable)."""
|
||||
|
|
@ -1142,75 +1086,58 @@ class TrusteeObjects:
|
|||
# Legacy fallback: documentData was stored directly (for migration)
|
||||
return record.get("documentData")
|
||||
|
||||
def getAllDocuments(self, params: Optional[PaginationParams] = None) -> PaginatedResult:
|
||||
"""Get all documents with RBAC filtering + feature-level access filtering (metadata only)."""
|
||||
# Step 1: System RBAC filtering
|
||||
records = getRecordsetWithRBAC(
|
||||
def getAllDocuments(self, params: Optional[PaginationParams] = None) -> Union[List[Dict], PaginatedResult]:
|
||||
"""Get all documents with RBAC filtering and optional DB-level pagination (metadata only).
|
||||
|
||||
Filtering, sorting, and pagination are handled at the SQL level by
|
||||
getRecordsetPaginatedWithRBAC. Binary documentData is stripped from
|
||||
the returned items.
|
||||
"""
|
||||
result = getRecordsetPaginatedWithRBAC(
|
||||
connector=self.db,
|
||||
modelClass=TrusteeDocument,
|
||||
currentUser=self.currentUser,
|
||||
pagination=params,
|
||||
recordFilter=None,
|
||||
orderBy="documentName",
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId,
|
||||
featureCode=self.FEATURE_CODE
|
||||
)
|
||||
|
||||
# Clean records (remove binary data and internal fields) - keep as dicts for filtering/sorting
|
||||
cleanedRecords = []
|
||||
for record in records:
|
||||
cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_") and k != "documentData"}
|
||||
cleanedRecords.append(cleanedRecord)
|
||||
def _cleanDocumentRecords(records):
|
||||
return [
|
||||
TrusteeDocument(**{k: v for k, v in r.items() if not k.startswith("_") and k != "documentData"})
|
||||
for r in records
|
||||
]
|
||||
|
||||
# Step 2: Apply filters (search and field filters)
|
||||
filteredRecords = self._applyFilters(cleanedRecords, params)
|
||||
|
||||
# Step 3: Apply sorting
|
||||
sortedRecords = self._applySorting(filteredRecords, params)
|
||||
|
||||
# Step 4: Convert to Pydantic objects
|
||||
pydanticItems = [TrusteeDocument(**r) for r in sortedRecords]
|
||||
if isinstance(result, PaginatedResult):
|
||||
result.items = _cleanDocumentRecords(result.items)
|
||||
return result
|
||||
return _cleanDocumentRecords(result)
|
||||
|
||||
# Step 5: Apply pagination
|
||||
totalItems = len(pydanticItems)
|
||||
if params:
|
||||
pageSize = params.pageSize or 20
|
||||
page = params.page or 1
|
||||
startIdx = (page - 1) * pageSize
|
||||
endIdx = startIdx + pageSize
|
||||
items = pydanticItems[startIdx:endIdx]
|
||||
totalPages = math.ceil(totalItems / pageSize) if pageSize > 0 else 1
|
||||
else:
|
||||
items = pydanticItems
|
||||
totalPages = 1
|
||||
page = 1
|
||||
pageSize = totalItems
|
||||
|
||||
return PaginatedResult(
|
||||
items=items,
|
||||
totalItems=totalItems,
|
||||
totalPages=totalPages
|
||||
)
|
||||
|
||||
def getDocumentsByContract(self, contractId: str) -> List[TrusteeDocument]:
|
||||
def getDocumentsByContract(self, contractId: str, pagination: Optional[PaginationParams] = None) -> Union[List[TrusteeDocument], PaginatedResult]:
|
||||
"""Get all documents for a specific contract."""
|
||||
# Step 1: System RBAC filtering
|
||||
records = getRecordsetWithRBAC(
|
||||
result = getRecordsetPaginatedWithRBAC(
|
||||
connector=self.db,
|
||||
modelClass=TrusteeDocument,
|
||||
currentUser=self.currentUser,
|
||||
pagination=pagination,
|
||||
recordFilter={"contractId": contractId},
|
||||
orderBy="documentName",
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId,
|
||||
featureCode=self.FEATURE_CODE
|
||||
)
|
||||
|
||||
result = []
|
||||
for record in records:
|
||||
cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_") and k != "documentData"}
|
||||
result.append(TrusteeDocument(**cleanedRecord))
|
||||
return result
|
||||
|
||||
def _cleanDocumentRecords(records):
|
||||
return [
|
||||
TrusteeDocument(**{k: v for k, v in r.items() if not k.startswith("_") and k != "documentData"})
|
||||
for r in records
|
||||
]
|
||||
|
||||
if isinstance(result, PaginatedResult):
|
||||
result.items = _cleanDocumentRecords(result.items)
|
||||
return result
|
||||
return _cleanDocumentRecords(result)
|
||||
|
||||
def updateDocument(self, documentId: str, data: Dict[str, Any]) -> Optional[TrusteeDocument]:
|
||||
"""Update a document.
|
||||
|
|
@ -1340,101 +1267,94 @@ class TrusteeObjects:
|
|||
return None
|
||||
return self._toTrusteePositionOrDelete(records[0], deleteCorrupt=True)
|
||||
|
||||
def getAllPositions(self, params: Optional[PaginationParams] = None) -> PaginatedResult:
|
||||
"""Get all positions with RBAC filtering + feature-level access filtering."""
|
||||
# Step 1: System RBAC filtering
|
||||
records = getRecordsetWithRBAC(
|
||||
def getAllPositions(self, params: Optional[PaginationParams] = None) -> Union[List[Dict], PaginatedResult]:
|
||||
"""Get all positions with RBAC filtering and optional DB-level pagination.
|
||||
|
||||
Filtering, sorting, and pagination are handled at the SQL level.
|
||||
Post-processing cleans internal fields (keeps _createdAt) and validates
|
||||
each record via _toTrusteePositionOrDelete (corrupt rows are deleted).
|
||||
|
||||
NOTE(post-process): totalItems may slightly overcount when corrupt legacy
|
||||
records are removed from the current page.
|
||||
"""
|
||||
result = getRecordsetPaginatedWithRBAC(
|
||||
connector=self.db,
|
||||
modelClass=TrusteePosition,
|
||||
currentUser=self.currentUser,
|
||||
pagination=params,
|
||||
recordFilter=None,
|
||||
orderBy="valuta",
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId,
|
||||
featureCode=self.FEATURE_CODE
|
||||
)
|
||||
|
||||
# Clean records (remove internal fields except _createdAt) - keep as dicts for filtering/sorting
|
||||
# Keep _createdAt for display in frontend
|
||||
keepFields = {'_createdAt'}
|
||||
cleanedRecords = []
|
||||
for record in records:
|
||||
cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_") or k in keepFields}
|
||||
cleanedRecords.append(cleanedRecord)
|
||||
|
||||
# Step 2: Apply filters (search and field filters)
|
||||
filteredRecords = self._applyFilters(cleanedRecords, params)
|
||||
|
||||
# Step 3: Apply sorting
|
||||
sortedRecords = self._applySorting(filteredRecords, params)
|
||||
|
||||
# Step 4: Convert to Pydantic objects and cleanup corrupt legacy records.
|
||||
pydanticItems = []
|
||||
for record in sortedRecords:
|
||||
position = self._toTrusteePositionOrDelete(record, deleteCorrupt=True)
|
||||
if position is not None:
|
||||
pydanticItems.append(position)
|
||||
def _cleanAndValidate(records):
|
||||
items = []
|
||||
for record in records:
|
||||
cleanedRecord = {k: v for k, v in record.items() if not k.startswith("_") or k in keepFields}
|
||||
position = self._toTrusteePositionOrDelete(cleanedRecord, deleteCorrupt=True)
|
||||
if position is not None:
|
||||
items.append(position)
|
||||
return items
|
||||
|
||||
# Step 5: Apply pagination
|
||||
totalItems = len(pydanticItems)
|
||||
if params:
|
||||
pageSize = params.pageSize or 20
|
||||
page = params.page or 1
|
||||
startIdx = (page - 1) * pageSize
|
||||
endIdx = startIdx + pageSize
|
||||
items = pydanticItems[startIdx:endIdx]
|
||||
totalPages = math.ceil(totalItems / pageSize) if pageSize > 0 else 1
|
||||
else:
|
||||
items = pydanticItems
|
||||
totalPages = 1
|
||||
page = 1
|
||||
pageSize = totalItems
|
||||
if isinstance(result, PaginatedResult):
|
||||
result.items = _cleanAndValidate(result.items)
|
||||
return result
|
||||
return _cleanAndValidate(result)
|
||||
|
||||
return PaginatedResult(
|
||||
items=items,
|
||||
totalItems=totalItems,
|
||||
totalPages=totalPages
|
||||
)
|
||||
|
||||
def getPositionsByContract(self, contractId: str) -> List[TrusteePosition]:
|
||||
def getPositionsByContract(self, contractId: str, pagination: Optional[PaginationParams] = None) -> Union[List[TrusteePosition], PaginatedResult]:
|
||||
"""Get all positions for a specific contract."""
|
||||
# Step 1: System RBAC filtering
|
||||
records = getRecordsetWithRBAC(
|
||||
result = getRecordsetPaginatedWithRBAC(
|
||||
connector=self.db,
|
||||
modelClass=TrusteePosition,
|
||||
currentUser=self.currentUser,
|
||||
pagination=pagination,
|
||||
recordFilter={"contractId": contractId},
|
||||
orderBy="valuta",
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId,
|
||||
featureCode=self.FEATURE_CODE
|
||||
)
|
||||
safeItems = []
|
||||
for record in records:
|
||||
position = self._toTrusteePositionOrDelete(record, deleteCorrupt=True)
|
||||
if position is not None:
|
||||
safeItems.append(position)
|
||||
return safeItems
|
||||
|
||||
def getPositionsByOrganisation(self, organisationId: str) -> List[TrusteePosition]:
|
||||
def _validatePositions(records):
|
||||
items = []
|
||||
for record in records:
|
||||
position = self._toTrusteePositionOrDelete(record, deleteCorrupt=True)
|
||||
if position is not None:
|
||||
items.append(position)
|
||||
return items
|
||||
|
||||
if isinstance(result, PaginatedResult):
|
||||
result.items = _validatePositions(result.items)
|
||||
return result
|
||||
return _validatePositions(result)
|
||||
|
||||
def getPositionsByOrganisation(self, organisationId: str, pagination: Optional[PaginationParams] = None) -> Union[List[TrusteePosition], PaginatedResult]:
|
||||
"""Get all positions for a specific organisation."""
|
||||
# Step 1: System RBAC filtering
|
||||
records = getRecordsetWithRBAC(
|
||||
result = getRecordsetPaginatedWithRBAC(
|
||||
connector=self.db,
|
||||
modelClass=TrusteePosition,
|
||||
currentUser=self.currentUser,
|
||||
pagination=pagination,
|
||||
recordFilter={"organisationId": organisationId},
|
||||
orderBy="valuta",
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId,
|
||||
featureCode=self.FEATURE_CODE
|
||||
)
|
||||
safeItems = []
|
||||
for record in records:
|
||||
position = self._toTrusteePositionOrDelete(record, deleteCorrupt=True)
|
||||
if position is not None:
|
||||
safeItems.append(position)
|
||||
return safeItems
|
||||
|
||||
def _validatePositions(records):
|
||||
items = []
|
||||
for record in records:
|
||||
position = self._toTrusteePositionOrDelete(record, deleteCorrupt=True)
|
||||
if position is not None:
|
||||
items.append(position)
|
||||
return items
|
||||
|
||||
if isinstance(result, PaginatedResult):
|
||||
result.items = _validatePositions(result.items)
|
||||
return result
|
||||
return _validatePositions(result)
|
||||
|
||||
def updatePosition(self, positionId: str, data: Dict[str, Any]) -> Optional[TrusteePosition]:
|
||||
"""Update a position.
|
||||
|
|
@ -1481,24 +1401,31 @@ class TrusteeObjects:
|
|||
|
||||
# ===== Position-Document Queries =====
|
||||
|
||||
def getPositionsByDocument(self, documentId: str) -> List[TrusteePosition]:
|
||||
def getPositionsByDocument(self, documentId: str, pagination: Optional[PaginationParams] = None) -> Union[List[TrusteePosition], PaginatedResult]:
|
||||
"""Get all positions that reference a specific document (1:N via documentId FK)."""
|
||||
records = getRecordsetWithRBAC(
|
||||
result = getRecordsetPaginatedWithRBAC(
|
||||
connector=self.db,
|
||||
modelClass=TrusteePosition,
|
||||
currentUser=self.currentUser,
|
||||
pagination=pagination,
|
||||
recordFilter={"documentId": documentId},
|
||||
orderBy="valuta",
|
||||
mandateId=self.mandateId,
|
||||
featureInstanceId=self.featureInstanceId,
|
||||
featureCode=self.FEATURE_CODE
|
||||
)
|
||||
safeItems = []
|
||||
for record in records:
|
||||
position = self._toTrusteePositionOrDelete(record, deleteCorrupt=True)
|
||||
if position is not None:
|
||||
safeItems.append(position)
|
||||
return safeItems
|
||||
|
||||
def _validatePositions(records):
|
||||
items = []
|
||||
for record in records:
|
||||
position = self._toTrusteePositionOrDelete(record, deleteCorrupt=True)
|
||||
if position is not None:
|
||||
items.append(position)
|
||||
return items
|
||||
|
||||
if isinstance(result, PaginatedResult):
|
||||
result.items = _validatePositions(result.items)
|
||||
return result
|
||||
return _validatePositions(result)
|
||||
|
||||
# ===== Trustee-specific Access Check =====
|
||||
|
||||
|
|
|
|||
|
|
@ -338,7 +338,7 @@ def get_organisations(
|
|||
interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId)
|
||||
result = interface.getAllOrganisations(paginationParams)
|
||||
|
||||
if paginationParams:
|
||||
if paginationParams and hasattr(result, 'items'):
|
||||
return PaginatedResponse(
|
||||
items=result.items,
|
||||
pagination=PaginationMetadata(
|
||||
|
|
@ -350,7 +350,7 @@ def get_organisations(
|
|||
filters=paginationParams.filters if paginationParams else None
|
||||
)
|
||||
)
|
||||
return PaginatedResponse(items=result.items, pagination=None)
|
||||
return PaginatedResponse(items=result if isinstance(result, list) else result.items, pagination=None)
|
||||
|
||||
|
||||
@router.get("/{instanceId}/organisations/{orgId}", response_model=TrusteeOrganisation)
|
||||
|
|
@ -451,7 +451,7 @@ def get_roles(
|
|||
interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId)
|
||||
result = interface.getAllRoles(paginationParams)
|
||||
|
||||
if paginationParams:
|
||||
if paginationParams and hasattr(result, 'items'):
|
||||
return PaginatedResponse(
|
||||
items=result.items,
|
||||
pagination=PaginationMetadata(
|
||||
|
|
@ -463,7 +463,7 @@ def get_roles(
|
|||
filters=paginationParams.filters if paginationParams else None
|
||||
)
|
||||
)
|
||||
return PaginatedResponse(items=result.items, pagination=None)
|
||||
return PaginatedResponse(items=result if isinstance(result, list) else result.items, pagination=None)
|
||||
|
||||
|
||||
@router.get("/{instanceId}/roles/{roleId}", response_model=TrusteeRole)
|
||||
|
|
@ -564,7 +564,7 @@ def get_all_access(
|
|||
interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId)
|
||||
result = interface.getAllAccess(paginationParams)
|
||||
|
||||
if paginationParams:
|
||||
if paginationParams and hasattr(result, 'items'):
|
||||
return PaginatedResponse(
|
||||
items=result.items,
|
||||
pagination=PaginationMetadata(
|
||||
|
|
@ -576,7 +576,7 @@ def get_all_access(
|
|||
filters=paginationParams.filters if paginationParams else None
|
||||
)
|
||||
)
|
||||
return PaginatedResponse(items=result.items, pagination=None)
|
||||
return PaginatedResponse(items=result if isinstance(result, list) else result.items, pagination=None)
|
||||
|
||||
|
||||
@router.get("/{instanceId}/access/{accessId}", response_model=TrusteeAccess)
|
||||
|
|
@ -707,7 +707,7 @@ def get_contracts(
|
|||
interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId)
|
||||
result = interface.getAllContracts(paginationParams)
|
||||
|
||||
if paginationParams:
|
||||
if paginationParams and hasattr(result, 'items'):
|
||||
return PaginatedResponse(
|
||||
items=result.items,
|
||||
pagination=PaginationMetadata(
|
||||
|
|
@ -719,7 +719,7 @@ def get_contracts(
|
|||
filters=paginationParams.filters if paginationParams else None
|
||||
)
|
||||
)
|
||||
return PaginatedResponse(items=result.items, pagination=None)
|
||||
return PaginatedResponse(items=result if isinstance(result, list) else result.items, pagination=None)
|
||||
|
||||
|
||||
@router.get("/{instanceId}/contracts/{contractId}", response_model=TrusteeContract)
|
||||
|
|
@ -835,7 +835,7 @@ def get_documents(
|
|||
interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId)
|
||||
result = interface.getAllDocuments(paginationParams)
|
||||
|
||||
if paginationParams:
|
||||
if paginationParams and hasattr(result, 'items'):
|
||||
return PaginatedResponse(
|
||||
items=result.items,
|
||||
pagination=PaginationMetadata(
|
||||
|
|
@ -847,7 +847,59 @@ def get_documents(
|
|||
filters=paginationParams.filters if paginationParams else None
|
||||
)
|
||||
)
|
||||
return PaginatedResponse(items=result.items, pagination=None)
|
||||
return PaginatedResponse(items=result if isinstance(result, list) else result.items, pagination=None)
|
||||
|
||||
|
||||
@router.get("/{instanceId}/documents/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def get_document_filter_values(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature Instance ID"),
|
||||
column: str = Query(..., description="Column key"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> list:
|
||||
"""Return distinct filter values for a column in trustee documents."""
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
try:
|
||||
from modules.interfaces.interfaceRbac import getDistinctColumnValuesWithRBAC
|
||||
interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId)
|
||||
|
||||
crossFilterPagination = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
if paginationDict:
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
filters = paginationDict.get("filters", {})
|
||||
filters.pop(column, None)
|
||||
paginationDict["filters"] = filters
|
||||
paginationDict.pop("sort", None)
|
||||
crossFilterPagination = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
try:
|
||||
values = getDistinctColumnValuesWithRBAC(
|
||||
connector=interface.db,
|
||||
modelClass=TrusteeDocument,
|
||||
column=column,
|
||||
currentUser=interface.currentUser,
|
||||
pagination=crossFilterPagination,
|
||||
recordFilter=None,
|
||||
mandateId=interface.mandateId,
|
||||
featureInstanceId=interface.featureInstanceId,
|
||||
featureCode=interface.FEATURE_CODE
|
||||
)
|
||||
return sorted(values, key=lambda v: str(v).lower())
|
||||
except Exception:
|
||||
from modules.routes.routeDataUsers import _handleFilterValuesRequest
|
||||
result = interface.getAllDocuments(None)
|
||||
items = [r.model_dump() if hasattr(r, 'model_dump') else r for r in (result.items if hasattr(result, 'items') else result)]
|
||||
return _handleFilterValuesRequest(items, column, pagination)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting filter values for trustee documents: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{instanceId}/documents/{documentId}", response_model=TrusteeDocument)
|
||||
|
|
@ -1039,7 +1091,7 @@ def get_positions(
|
|||
interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId)
|
||||
result = interface.getAllPositions(paginationParams)
|
||||
|
||||
if paginationParams:
|
||||
if paginationParams and hasattr(result, 'items'):
|
||||
return PaginatedResponse(
|
||||
items=result.items,
|
||||
pagination=PaginationMetadata(
|
||||
|
|
@ -1051,7 +1103,59 @@ def get_positions(
|
|||
filters=paginationParams.filters if paginationParams else None
|
||||
)
|
||||
)
|
||||
return PaginatedResponse(items=result.items, pagination=None)
|
||||
return PaginatedResponse(items=result if isinstance(result, list) else result.items, pagination=None)
|
||||
|
||||
|
||||
@router.get("/{instanceId}/positions/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def get_position_filter_values(
|
||||
request: Request,
|
||||
instanceId: str = Path(..., description="Feature Instance ID"),
|
||||
column: str = Query(..., description="Column key"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> list:
|
||||
"""Return distinct filter values for a column in trustee positions."""
|
||||
mandateId = _validateInstanceAccess(instanceId, context)
|
||||
try:
|
||||
from modules.interfaces.interfaceRbac import getDistinctColumnValuesWithRBAC
|
||||
interface = getInterface(context.user, mandateId=mandateId, featureInstanceId=instanceId)
|
||||
|
||||
crossFilterPagination = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
if paginationDict:
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
filters = paginationDict.get("filters", {})
|
||||
filters.pop(column, None)
|
||||
paginationDict["filters"] = filters
|
||||
paginationDict.pop("sort", None)
|
||||
crossFilterPagination = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
try:
|
||||
values = getDistinctColumnValuesWithRBAC(
|
||||
connector=interface.db,
|
||||
modelClass=TrusteePosition,
|
||||
column=column,
|
||||
currentUser=interface.currentUser,
|
||||
pagination=crossFilterPagination,
|
||||
recordFilter=None,
|
||||
mandateId=interface.mandateId,
|
||||
featureInstanceId=interface.featureInstanceId,
|
||||
featureCode=interface.FEATURE_CODE
|
||||
)
|
||||
return sorted(values, key=lambda v: str(v).lower())
|
||||
except Exception:
|
||||
from modules.routes.routeDataUsers import _handleFilterValuesRequest
|
||||
result = interface.getAllPositions(None)
|
||||
items = [r.model_dump() if hasattr(r, 'model_dump') else r for r in (result.items if hasattr(result, 'items') else result)]
|
||||
return _handleFilterValuesRequest(items, column, pagination)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting filter values for trustee positions: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{instanceId}/positions/{positionId}", response_model=TrusteePosition)
|
||||
|
|
|
|||
65
modules/features/workspace/datamodelFeatureWorkspace.py
Normal file
65
modules/features/workspace/datamodelFeatureWorkspace.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""Workspace feature data models — VoiceSettings and WorkspaceUserSettings."""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from modules.shared.attributeUtils import registerModelLabels
|
||||
from modules.shared.timeUtils import getUtcTimestamp
|
||||
import uuid
|
||||
|
||||
|
||||
class VoiceSettings(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
|
||||
userId: str = Field(description="ID of the user these settings belong to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True})
|
||||
mandateId: str = Field(description="ID of the mandate these settings belong to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True})
|
||||
featureInstanceId: str = Field(description="ID of the feature instance these settings belong to", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True})
|
||||
sttLanguage: str = Field(default="de-DE", description="Speech-to-Text language", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True})
|
||||
ttsLanguage: str = Field(default="de-DE", description="Text-to-Speech language", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True})
|
||||
ttsVoice: str = Field(default="de-DE-KatjaNeural", description="Text-to-Speech voice", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": True})
|
||||
ttsVoiceMap: Dict[str, Any] = Field(default_factory=dict, description="Per-language voice mapping, e.g. {'de-DE': {'voiceName': 'de-DE-Wavenet-A'}, 'en-US': {'voiceName': 'en-US-Wavenet-C'}}", json_schema_extra={"frontend_type": "json", "frontend_readonly": False, "frontend_required": False})
|
||||
translationEnabled: bool = Field(default=True, description="Whether translation is enabled", json_schema_extra={"frontend_type": "checkbox", "frontend_readonly": False, "frontend_required": False})
|
||||
targetLanguage: str = Field(default="en-US", description="Target language for translation", json_schema_extra={"frontend_type": "select", "frontend_readonly": False, "frontend_required": False})
|
||||
creationDate: float = Field(default_factory=getUtcTimestamp, description="Date when the settings were created (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})
|
||||
lastModified: float = Field(default_factory=getUtcTimestamp, description="Date when the settings were last modified (UTC timestamp in seconds)", json_schema_extra={"frontend_type": "timestamp", "frontend_readonly": True, "frontend_required": False})
|
||||
|
||||
|
||||
class WorkspaceUserSettings(BaseModel):
|
||||
"""Per-user workspace settings. None values mean 'use instance default'."""
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="Primary key", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": False})
|
||||
userId: str = Field(description="User ID", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True})
|
||||
mandateId: str = Field(description="Mandate ID", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True})
|
||||
featureInstanceId: str = Field(description="Feature Instance ID", json_schema_extra={"frontend_type": "text", "frontend_readonly": True, "frontend_required": True})
|
||||
maxAgentRounds: Optional[int] = Field(default=None, description="Max agent rounds override (None = instance default)", json_schema_extra={"frontend_type": "number", "frontend_readonly": False, "frontend_required": False})
|
||||
|
||||
|
||||
registerModelLabels(
|
||||
"VoiceSettings",
|
||||
{"en": "Voice Settings", "fr": "Paramètres vocaux"},
|
||||
{
|
||||
"id": {"en": "ID", "fr": "ID"},
|
||||
"userId": {"en": "User ID", "fr": "ID utilisateur"},
|
||||
"mandateId": {"en": "Mandate ID", "fr": "ID du mandat"},
|
||||
"featureInstanceId": {"en": "Feature Instance ID", "fr": "ID de l'instance de fonctionnalité"},
|
||||
"sttLanguage": {"en": "STT Language", "fr": "Langue STT"},
|
||||
"ttsLanguage": {"en": "TTS Language", "fr": "Langue TTS"},
|
||||
"ttsVoice": {"en": "TTS Voice", "fr": "Voix TTS"},
|
||||
"ttsVoiceMap": {"en": "TTS Voice Map", "fr": "Carte des voix TTS"},
|
||||
"translationEnabled": {"en": "Translation Enabled", "fr": "Traduction activée"},
|
||||
"targetLanguage": {"en": "Target Language", "fr": "Langue cible"},
|
||||
"creationDate": {"en": "Creation Date", "fr": "Date de création"},
|
||||
"lastModified": {"en": "Last Modified", "fr": "Dernière modification"},
|
||||
},
|
||||
)
|
||||
|
||||
registerModelLabels(
|
||||
"WorkspaceUserSettings",
|
||||
{"en": "Workspace User Settings", "de": "Workspace Benutzereinstellungen"},
|
||||
{
|
||||
"id": {"en": "ID", "de": "ID"},
|
||||
"userId": {"en": "User ID", "de": "Benutzer-ID"},
|
||||
"mandateId": {"en": "Mandate ID", "de": "Mandanten-ID"},
|
||||
"featureInstanceId": {"en": "Feature Instance ID", "de": "Feature-Instanz-ID"},
|
||||
"maxAgentRounds": {"en": "Max Agent Rounds", "de": "Max. Agenten-Runden"},
|
||||
},
|
||||
)
|
||||
248
modules/features/workspace/interfaceFeatureWorkspace.py
Normal file
248
modules/features/workspace/interfaceFeatureWorkspace.py
Normal file
|
|
@ -0,0 +1,248 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""
|
||||
Interface for Workspace feature — manages VoiceSettings and WorkspaceUserSettings.
|
||||
Uses a dedicated poweron_workspace database.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from modules.connectors.connectorDbPostgre import DatabaseConnector
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.features.workspace.datamodelFeatureWorkspace import VoiceSettings, WorkspaceUserSettings
|
||||
from modules.interfaces.interfaceRbac import getRecordsetWithRBAC
|
||||
from modules.security.rbac import RbacClass
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.shared.timeUtils import getUtcTimestamp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_workspaceInterfaces: Dict[str, "WorkspaceObjects"] = {}
|
||||
|
||||
|
||||
class WorkspaceObjects:
|
||||
"""Interface for Workspace-specific database operations (voice + general settings)."""
|
||||
|
||||
def __init__(self, currentUser: User, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None):
|
||||
self.currentUser = currentUser
|
||||
self.mandateId = mandateId
|
||||
self.featureInstanceId = featureInstanceId
|
||||
self.userId = currentUser.id if currentUser else None
|
||||
|
||||
self._initializeDatabase()
|
||||
|
||||
from modules.security.rootAccess import getRootDbAppConnector
|
||||
dbApp = getRootDbAppConnector()
|
||||
self.rbac = RbacClass(self.db, dbApp=dbApp)
|
||||
|
||||
self.db.updateContext(self.userId)
|
||||
|
||||
def _initializeDatabase(self):
|
||||
dbHost = APP_CONFIG.get("DB_HOST", "_no_config_default_data")
|
||||
dbDatabase = "poweron_workspace"
|
||||
dbUser = APP_CONFIG.get("DB_USER")
|
||||
dbPassword = APP_CONFIG.get("DB_PASSWORD_SECRET")
|
||||
dbPort = int(APP_CONFIG.get("DB_PORT", 5432))
|
||||
|
||||
self.db = DatabaseConnector(
|
||||
dbHost=dbHost,
|
||||
dbDatabase=dbDatabase,
|
||||
dbUser=dbUser,
|
||||
dbPassword=dbPassword,
|
||||
dbPort=dbPort,
|
||||
userId=self.userId,
|
||||
)
|
||||
logger.debug(f"Workspace database initialized for user {self.userId}")
|
||||
|
||||
def setUserContext(self, currentUser: User, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None):
|
||||
self.currentUser = currentUser
|
||||
self.userId = currentUser.id if currentUser else None
|
||||
self.mandateId = mandateId
|
||||
self.featureInstanceId = featureInstanceId
|
||||
self.db.updateContext(self.userId)
|
||||
|
||||
# =========================================================================
|
||||
# VoiceSettings CRUD
|
||||
# =========================================================================
|
||||
|
||||
def getVoiceSettings(self, userId: Optional[str] = None) -> Optional[VoiceSettings]:
|
||||
try:
|
||||
targetUserId = userId or self.userId
|
||||
if not targetUserId:
|
||||
logger.error("No user ID provided for voice settings")
|
||||
return None
|
||||
|
||||
recordFilter: Dict[str, Any] = {"userId": targetUserId}
|
||||
if self.featureInstanceId:
|
||||
recordFilter["featureInstanceId"] = self.featureInstanceId
|
||||
|
||||
filteredSettings = getRecordsetWithRBAC(
|
||||
self.db, VoiceSettings, self.currentUser,
|
||||
recordFilter=recordFilter, mandateId=self.mandateId,
|
||||
)
|
||||
|
||||
if not filteredSettings:
|
||||
return None
|
||||
|
||||
settingsData = filteredSettings[0]
|
||||
if not settingsData.get("creationDate"):
|
||||
settingsData["creationDate"] = getUtcTimestamp()
|
||||
if not settingsData.get("lastModified"):
|
||||
settingsData["lastModified"] = getUtcTimestamp()
|
||||
|
||||
return VoiceSettings(**settingsData)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting voice settings: {e}")
|
||||
return None
|
||||
|
||||
def createVoiceSettings(self, settingsData: Dict[str, Any]) -> Dict[str, Any]:
|
||||
try:
|
||||
if "userId" not in settingsData:
|
||||
settingsData["userId"] = self.userId
|
||||
if "mandateId" not in settingsData:
|
||||
settingsData["mandateId"] = self.mandateId
|
||||
if "featureInstanceId" not in settingsData:
|
||||
settingsData["featureInstanceId"] = self.featureInstanceId
|
||||
|
||||
existing = self.getVoiceSettings(settingsData["userId"])
|
||||
if existing:
|
||||
raise ValueError(f"Voice settings already exist for user {settingsData['userId']}")
|
||||
|
||||
createdRecord = self.db.recordCreate(VoiceSettings, settingsData)
|
||||
if not createdRecord or not createdRecord.get("id"):
|
||||
raise ValueError("Failed to create voice settings record")
|
||||
|
||||
logger.info(f"Created voice settings for user {settingsData['userId']}")
|
||||
return createdRecord
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating voice settings: {e}")
|
||||
raise
|
||||
|
||||
def updateVoiceSettings(self, userId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
|
||||
try:
|
||||
existing = self.getVoiceSettings(userId)
|
||||
if not existing:
|
||||
raise ValueError(f"Voice settings not found for user {userId}")
|
||||
|
||||
updateData["lastModified"] = getUtcTimestamp()
|
||||
success = self.db.recordModify(VoiceSettings, existing.id, updateData)
|
||||
if not success:
|
||||
raise ValueError("Failed to update voice settings record")
|
||||
|
||||
updated = self.getVoiceSettings(userId)
|
||||
if not updated:
|
||||
raise ValueError("Failed to retrieve updated voice settings")
|
||||
|
||||
logger.info(f"Updated voice settings for user {userId}")
|
||||
return updated.model_dump()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating voice settings: {e}")
|
||||
raise
|
||||
|
||||
def deleteVoiceSettings(self, userId: str) -> bool:
|
||||
try:
|
||||
existing = self.getVoiceSettings(userId)
|
||||
if not existing:
|
||||
return False
|
||||
success = self.db.recordDelete(VoiceSettings, existing.id)
|
||||
if success:
|
||||
logger.info(f"Deleted voice settings for user {userId}")
|
||||
return success
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting voice settings: {e}")
|
||||
return False
|
||||
|
||||
def getOrCreateVoiceSettings(self, userId: Optional[str] = None) -> VoiceSettings:
|
||||
targetUserId = userId or self.userId
|
||||
if not targetUserId:
|
||||
raise ValueError("No user ID provided for voice settings")
|
||||
|
||||
existing = self.getVoiceSettings(targetUserId)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
defaultSettings = {
|
||||
"userId": targetUserId,
|
||||
"mandateId": self.mandateId,
|
||||
"featureInstanceId": self.featureInstanceId,
|
||||
"sttLanguage": "de-DE",
|
||||
"ttsLanguage": "de-DE",
|
||||
"ttsVoice": "de-DE-KatjaNeural",
|
||||
"translationEnabled": True,
|
||||
"targetLanguage": "en-US",
|
||||
}
|
||||
createdRecord = self.createVoiceSettings(defaultSettings)
|
||||
return VoiceSettings(**createdRecord)
|
||||
|
||||
# =========================================================================
|
||||
# WorkspaceUserSettings CRUD
|
||||
# =========================================================================
|
||||
|
||||
def getWorkspaceUserSettings(self, userId: Optional[str] = None) -> Optional[WorkspaceUserSettings]:
|
||||
targetUserId = userId or self.userId
|
||||
if not targetUserId:
|
||||
return None
|
||||
|
||||
try:
|
||||
recordFilter: Dict[str, Any] = {"userId": targetUserId}
|
||||
if self.featureInstanceId:
|
||||
recordFilter["featureInstanceId"] = self.featureInstanceId
|
||||
|
||||
records = getRecordsetWithRBAC(
|
||||
self.db, WorkspaceUserSettings, self.currentUser,
|
||||
recordFilter=recordFilter, mandateId=self.mandateId,
|
||||
)
|
||||
if not records:
|
||||
return None
|
||||
return WorkspaceUserSettings(**records[0])
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting workspace user settings: {e}")
|
||||
return None
|
||||
|
||||
def saveWorkspaceUserSettings(self, data: Dict[str, Any]) -> WorkspaceUserSettings:
|
||||
"""Upsert: create or update workspace user settings for the current user."""
|
||||
targetUserId = data.get("userId", self.userId)
|
||||
if not targetUserId:
|
||||
raise ValueError("No user ID provided")
|
||||
|
||||
existing = self.getWorkspaceUserSettings(targetUserId)
|
||||
if existing:
|
||||
updateData = {k: v for k, v in data.items() if k not in ("id", "userId", "mandateId", "featureInstanceId")}
|
||||
self.db.recordModify(WorkspaceUserSettings, existing.id, updateData)
|
||||
updated = self.getWorkspaceUserSettings(targetUserId)
|
||||
if not updated:
|
||||
raise ValueError("Failed to retrieve updated workspace user settings")
|
||||
return updated
|
||||
|
||||
createData = {
|
||||
"userId": targetUserId,
|
||||
"mandateId": data.get("mandateId", self.mandateId),
|
||||
"featureInstanceId": data.get("featureInstanceId", self.featureInstanceId),
|
||||
}
|
||||
createData.update({k: v for k, v in data.items() if k not in ("id", "userId", "mandateId", "featureInstanceId")})
|
||||
createdRecord = self.db.recordCreate(WorkspaceUserSettings, createData)
|
||||
if not createdRecord or not createdRecord.get("id"):
|
||||
raise ValueError("Failed to create workspace user settings")
|
||||
return WorkspaceUserSettings(**createdRecord)
|
||||
|
||||
|
||||
def getInterface(currentUser: Optional[User] = None, mandateId: Optional[str] = None, featureInstanceId: Optional[str] = None) -> WorkspaceObjects:
|
||||
if not currentUser:
|
||||
raise ValueError("Invalid user context: user is required")
|
||||
|
||||
effectiveMandateId = str(mandateId) if mandateId else None
|
||||
effectiveFeatureInstanceId = str(featureInstanceId) if featureInstanceId else None
|
||||
|
||||
contextKey = f"workspace_{effectiveMandateId}_{effectiveFeatureInstanceId}_{currentUser.id}"
|
||||
|
||||
if contextKey not in _workspaceInterfaces:
|
||||
_workspaceInterfaces[contextKey] = WorkspaceObjects(currentUser, mandateId=effectiveMandateId, featureInstanceId=effectiveFeatureInstanceId)
|
||||
else:
|
||||
_workspaceInterfaces[contextKey].setUserContext(currentUser, mandateId=effectiveMandateId, featureInstanceId=effectiveFeatureInstanceId)
|
||||
|
||||
return _workspaceInterfaces[contextKey]
|
||||
|
|
@ -31,6 +31,15 @@ UI_OBJECTS = [
|
|||
"label": {"en": "Settings", "de": "Einstellungen", "fr": "Parametres"},
|
||||
"meta": {"area": "settings"}
|
||||
},
|
||||
{
|
||||
"objectKey": "ui.feature.workspace.rag-insights",
|
||||
"label": {
|
||||
"en": "Knowledge insights",
|
||||
"de": "Wissens-Insights",
|
||||
"fr": "Aperçu des connaissances",
|
||||
},
|
||||
"meta": {"area": "rag-insights"},
|
||||
},
|
||||
]
|
||||
|
||||
RESOURCE_OBJECTS = [
|
||||
|
|
@ -83,6 +92,7 @@ TEMPLATE_ROLES = [
|
|||
{"context": "UI", "item": "ui.feature.workspace.dashboard", "view": True},
|
||||
{"context": "UI", "item": "ui.feature.workspace.editor", "view": True},
|
||||
{"context": "UI", "item": "ui.feature.workspace.settings", "view": True},
|
||||
{"context": "UI", "item": "ui.feature.workspace.rag-insights", "view": True},
|
||||
{"context": "DATA", "item": None, "view": True, "read": "m", "create": "n", "update": "n", "delete": "n"},
|
||||
]
|
||||
},
|
||||
|
|
@ -97,6 +107,7 @@ TEMPLATE_ROLES = [
|
|||
{"context": "UI", "item": "ui.feature.workspace.dashboard", "view": True},
|
||||
{"context": "UI", "item": "ui.feature.workspace.editor", "view": True},
|
||||
{"context": "UI", "item": "ui.feature.workspace.settings", "view": True},
|
||||
{"context": "UI", "item": "ui.feature.workspace.rag-insights", "view": True},
|
||||
{"context": "RESOURCE", "item": "resource.feature.workspace.start", "view": True},
|
||||
{"context": "RESOURCE", "item": "resource.feature.workspace.stop", "view": True},
|
||||
{"context": "RESOURCE", "item": "resource.feature.workspace.files", "view": True},
|
||||
|
|
@ -215,6 +226,8 @@ def _syncTemplateRolesToDb() -> int:
|
|||
if createdCount > 0:
|
||||
logger.info(f"Feature '{FEATURE_CODE}': Created {createdCount} template roles")
|
||||
|
||||
_repairWorkspaceUserInstanceUiNav(rootInterface)
|
||||
|
||||
return createdCount
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -222,6 +235,57 @@ def _syncTemplateRolesToDb() -> int:
|
|||
return 0
|
||||
|
||||
|
||||
def _repairWorkspaceUserInstanceUiNav(rootInterface) -> int:
|
||||
"""
|
||||
Ensure every instance-scoped workspace-user role grants UI view on Editor and Settings.
|
||||
Covers older instance roles copied before template updates (bootstrap / app startup).
|
||||
"""
|
||||
from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext, Role
|
||||
|
||||
workspaceNavObjectKeys = (
|
||||
"ui.feature.workspace.editor",
|
||||
"ui.feature.workspace.settings",
|
||||
)
|
||||
repairCount = 0
|
||||
try:
|
||||
userRoleRecords = rootInterface.db.getRecordset(
|
||||
Role,
|
||||
recordFilter={"featureCode": FEATURE_CODE, "roleLabel": "workspace-user"},
|
||||
)
|
||||
for roleRec in userRoleRecords or []:
|
||||
if not roleRec.get("featureInstanceId"):
|
||||
continue
|
||||
roleId = str(roleRec.get("id"))
|
||||
accessRules = rootInterface.getAccessRulesByRole(roleId)
|
||||
rulesByUiKey = {(r.context, r.item): r for r in accessRules}
|
||||
for objectKey in workspaceNavObjectKeys:
|
||||
uiKey = (AccessRuleContext.UI, objectKey)
|
||||
existingRule = rulesByUiKey.get(uiKey)
|
||||
if existingRule is None:
|
||||
newRule = AccessRule(
|
||||
roleId=roleId,
|
||||
context=AccessRuleContext.UI,
|
||||
item=objectKey,
|
||||
view=True,
|
||||
read=None,
|
||||
create=None,
|
||||
update=None,
|
||||
delete=None,
|
||||
)
|
||||
rootInterface.db.recordCreate(AccessRule, newRule.model_dump())
|
||||
repairCount += 1
|
||||
elif not existingRule.view:
|
||||
rootInterface.db.recordModify(AccessRule, str(existingRule.id), {"view": True})
|
||||
repairCount += 1
|
||||
if repairCount:
|
||||
logger.info(
|
||||
f"Feature '{FEATURE_CODE}': Repaired {repairCount} UI AccessRules for instance workspace-user roles (Editor/Settings)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Feature '{FEATURE_CODE}': workspace-user UI nav repair skipped: {e}")
|
||||
return repairCount
|
||||
|
||||
|
||||
def _ensureAccessRulesForRole(rootInterface, roleId: str, ruleTemplates: List[Dict[str, Any]]) -> int:
|
||||
"""Ensure AccessRules exist for a role based on templates."""
|
||||
from modules.datamodels.datamodelRbac import AccessRule, AccessRuleContext
|
||||
|
|
|
|||
|
|
@ -19,7 +19,12 @@ from modules.auth import limiter, getRequestContext, RequestContext
|
|||
from modules.serviceCenter.services.serviceBilling.mainServiceBilling import (
|
||||
InsufficientBalanceException,
|
||||
)
|
||||
from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
|
||||
SubscriptionInactiveException,
|
||||
)
|
||||
from modules.interfaces import interfaceDbChat, interfaceDbManagement
|
||||
from modules.features.workspace import interfaceFeatureWorkspace
|
||||
from modules.interfaces.interfaceDbKnowledge import getInterface as getKnowledgeInterface
|
||||
from modules.interfaces.interfaceAiObjects import AiObjects
|
||||
from modules.serviceCenter.core.serviceStreaming import get_event_manager
|
||||
from modules.serviceCenter.services.serviceAgent.datamodelAgent import AgentEventTypeEnum, PendingFileEdit
|
||||
|
|
@ -142,6 +147,14 @@ def _getDbManagement(context: RequestContext, featureInstanceId: str = None):
|
|||
)
|
||||
|
||||
|
||||
def _getWorkspaceInterface(context: RequestContext, featureInstanceId: str = None):
|
||||
return interfaceFeatureWorkspace.getInterface(
|
||||
context.user,
|
||||
mandateId=str(context.mandateId) if context.mandateId else None,
|
||||
featureInstanceId=featureInstanceId,
|
||||
)
|
||||
|
||||
|
||||
_SOURCE_TYPE_TO_SERVICE = {
|
||||
"sharepointFolder": "sharepoint",
|
||||
"onedriveFolder": "onedrive",
|
||||
|
|
@ -698,7 +711,17 @@ async def _runWorkspaceAgent(
|
|||
_toolSet = _cfg.get("toolSet", "core")
|
||||
_agentCfg = _cfg.get("agentConfig")
|
||||
from modules.serviceCenter.services.serviceAgent.datamodelAgent import AgentConfig
|
||||
agentConfig = AgentConfig(**_agentCfg) if isinstance(_agentCfg, dict) else None
|
||||
|
||||
agentCfgDict = dict(_agentCfg) if isinstance(_agentCfg, dict) else {}
|
||||
try:
|
||||
wsIf = interfaceFeatureWorkspace.getInterface(user, mandateId=mandateId, featureInstanceId=instanceId)
|
||||
userSettings = wsIf.getWorkspaceUserSettings(user.id if user else None)
|
||||
if userSettings and userSettings.maxAgentRounds is not None:
|
||||
agentCfgDict["maxRounds"] = userSettings.maxAgentRounds
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not load workspace user settings for agent config: {e}")
|
||||
|
||||
agentConfig = AgentConfig(**agentCfgDict) if agentCfgDict else None
|
||||
|
||||
async for event in agentService.runAgent(
|
||||
prompt=enrichedPrompt,
|
||||
|
|
@ -803,7 +826,15 @@ async def _runWorkspaceAgent(
|
|||
})
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, InsufficientBalanceException):
|
||||
if isinstance(e, SubscriptionInactiveException):
|
||||
logger.warning(f"Workspace blocked by subscription: {e.message}")
|
||||
await eventManager.emit_event(queueId, "error", {
|
||||
"type": "error",
|
||||
"content": e.message,
|
||||
"workflowId": workflowId,
|
||||
"item": e.toClientDict(),
|
||||
})
|
||||
elif isinstance(e, InsufficientBalanceException):
|
||||
logger.warning(f"Workspace blocked by billing: {e.message}")
|
||||
await eventManager.emit_event(queueId, "error", {
|
||||
"type": "error",
|
||||
|
|
@ -1564,13 +1595,13 @@ async def getVoiceSettings(
|
|||
):
|
||||
"""Load voice settings for the current user and instance."""
|
||||
_validateInstanceAccess(instanceId, context)
|
||||
dbMgmt = _getDbManagement(context, instanceId)
|
||||
wsInterface = _getWorkspaceInterface(context, instanceId)
|
||||
userId = str(context.user.id)
|
||||
try:
|
||||
vs = dbMgmt.getVoiceSettings(userId)
|
||||
vs = wsInterface.getVoiceSettings(userId)
|
||||
if not vs:
|
||||
logger.info(f"GET voice settings: not found for user={userId}, creating defaults")
|
||||
vs = dbMgmt.getOrCreateVoiceSettings(userId)
|
||||
vs = wsInterface.getOrCreateVoiceSettings(userId)
|
||||
result = vs.model_dump() if vs else {}
|
||||
mapKeys = list(result.get("ttsVoiceMap", {}).keys()) if result else []
|
||||
logger.info(f"GET voice settings for user={userId}: ttsVoiceMap languages={mapKeys}")
|
||||
|
|
@ -1590,12 +1621,12 @@ async def updateVoiceSettings(
|
|||
):
|
||||
"""Update voice settings for the current user and instance."""
|
||||
_validateInstanceAccess(instanceId, context)
|
||||
dbMgmt = _getDbManagement(context, instanceId)
|
||||
wsInterface = _getWorkspaceInterface(context, instanceId)
|
||||
userId = str(context.user.id)
|
||||
|
||||
try:
|
||||
logger.info(f"PUT voice settings for user={userId}, instance={instanceId}, body keys={list(body.keys())}")
|
||||
vs = dbMgmt.getVoiceSettings(userId)
|
||||
vs = wsInterface.getVoiceSettings(userId)
|
||||
if not vs:
|
||||
logger.info(f"No existing voice settings, creating new for user={userId}")
|
||||
createData = {
|
||||
|
|
@ -1604,13 +1635,13 @@ async def updateVoiceSettings(
|
|||
"featureInstanceId": instanceId,
|
||||
}
|
||||
createData.update(body)
|
||||
created = dbMgmt.createVoiceSettings(createData)
|
||||
created = wsInterface.createVoiceSettings(createData)
|
||||
logger.info(f"Created voice settings for user={userId}, ttsVoiceMap keys={list((created or {}).get('ttsVoiceMap', {}).keys())}")
|
||||
return JSONResponse(created)
|
||||
|
||||
updateData = {k: v for k, v in body.items() if k not in ("id", "userId", "mandateId", "featureInstanceId", "creationDate")}
|
||||
logger.info(f"Updating voice settings for user={userId}, update keys={list(updateData.keys())}")
|
||||
updated = dbMgmt.updateVoiceSettings(userId, updateData)
|
||||
updated = wsInterface.updateVoiceSettings(userId, updateData)
|
||||
logger.info(f"Updated voice settings for user={userId}, ttsVoiceMap keys={list((updated or {}).get('ttsVoiceMap', {}).keys())}")
|
||||
return JSONResponse(updated)
|
||||
except Exception as e:
|
||||
|
|
@ -1816,3 +1847,113 @@ async def rejectAllEdits(
|
|||
|
||||
logger.info(f"Rejected {len(rejected)} edits for instance {instanceId}")
|
||||
return JSONResponse({"rejected": rejected})
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# General Settings Endpoints (per-user workspace settings)
|
||||
# =========================================================================
|
||||
|
||||
@router.get("/{instanceId}/settings/general")
|
||||
@limiter.limit("120/minute")
|
||||
async def getGeneralSettings(
|
||||
request: Request,
|
||||
instanceId: str = Path(...),
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
):
|
||||
"""Load general workspace settings for the current user, with effective values."""
|
||||
_mandateId, instanceConfig = _validateInstanceAccess(instanceId, context)
|
||||
wsInterface = _getWorkspaceInterface(context, instanceId)
|
||||
userId = str(context.user.id)
|
||||
|
||||
userSettings = wsInterface.getWorkspaceUserSettings(userId)
|
||||
|
||||
agentCfg = (instanceConfig or {}).get("agentConfig", {})
|
||||
instanceDefault = agentCfg.get("maxRounds", 25) if isinstance(agentCfg, dict) else 25
|
||||
|
||||
userOverride = userSettings.maxAgentRounds if userSettings else None
|
||||
effective = userOverride if userOverride is not None else instanceDefault
|
||||
|
||||
return JSONResponse({
|
||||
"maxAgentRounds": {
|
||||
"effective": effective,
|
||||
"userOverride": userOverride,
|
||||
"instanceDefault": instanceDefault,
|
||||
},
|
||||
})
|
||||
|
||||
|
||||
@router.put("/{instanceId}/settings/general")
|
||||
@limiter.limit("120/minute")
|
||||
async def updateGeneralSettings(
|
||||
request: Request,
|
||||
instanceId: str = Path(...),
|
||||
body: dict = Body(...),
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
):
|
||||
"""Update general workspace settings for the current user."""
|
||||
_validateInstanceAccess(instanceId, context)
|
||||
wsInterface = _getWorkspaceInterface(context, instanceId)
|
||||
userId = str(context.user.id)
|
||||
|
||||
data = {
|
||||
"userId": userId,
|
||||
"mandateId": str(context.mandateId) if context.mandateId else "",
|
||||
"featureInstanceId": instanceId,
|
||||
}
|
||||
if "maxAgentRounds" in body:
|
||||
val = body["maxAgentRounds"]
|
||||
data["maxAgentRounds"] = int(val) if val is not None else None
|
||||
|
||||
wsInterface.saveWorkspaceUserSettings(data)
|
||||
|
||||
return await getGeneralSettings(request, instanceId, context)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# RAG / Knowledge — anonymised instance statistics (presentation / KPIs)
|
||||
# =========================================================================
|
||||
|
||||
def _collectWorkspaceFileIdsForStats(instanceId: str, mandateId: Optional[str]) -> List[str]:
|
||||
"""All FileItem ids for this feature instance (any user). Knowledge rows are often stored
|
||||
without featureInstanceId; we correlate by file id from the Management DB."""
|
||||
from modules.datamodels.datamodelFiles import FileItem
|
||||
from modules.interfaces.interfaceDbManagement import ComponentObjects
|
||||
|
||||
co = ComponentObjects()
|
||||
rows = co.db.getRecordset(FileItem, recordFilter={"featureInstanceId": instanceId})
|
||||
out: List[str] = []
|
||||
m = str(mandateId) if mandateId else ""
|
||||
for r in rows or []:
|
||||
rid = r.get("id") if isinstance(r, dict) else getattr(r, "id", None)
|
||||
if not rid:
|
||||
continue
|
||||
if m:
|
||||
mid = r.get("mandateId") if isinstance(r, dict) else getattr(r, "mandateId", "") or ""
|
||||
if mid and mid != m:
|
||||
continue
|
||||
out.append(str(rid))
|
||||
return out
|
||||
|
||||
|
||||
@router.get("/{instanceId}/rag-statistics")
|
||||
@limiter.limit("60/minute")
|
||||
async def getRagStatistics(
|
||||
request: Request,
|
||||
instanceId: str = Path(...),
|
||||
days: int = Query(90, ge=7, le=365, description="Timeline window in days"),
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
):
|
||||
"""Aggregated, non-identifying knowledge-store metrics for this workspace instance."""
|
||||
mandateId, _instanceConfig = _validateInstanceAccess(instanceId, context)
|
||||
workspaceFileIds = _collectWorkspaceFileIdsForStats(instanceId, mandateId)
|
||||
kdb = getKnowledgeInterface(context.user)
|
||||
stats = kdb.getRagStatisticsForInstance(
|
||||
featureInstanceId=instanceId,
|
||||
mandateId=str(mandateId) if mandateId else "",
|
||||
timelineDays=days,
|
||||
workspaceFileIds=workspaceFileIds,
|
||||
)
|
||||
if isinstance(stats, dict):
|
||||
stats.setdefault("scope", {})
|
||||
stats["scope"]["workspaceFileIdsResolved"] = len(workspaceFileIds)
|
||||
return JSONResponse(stats)
|
||||
|
|
|
|||
|
|
@ -103,6 +103,13 @@ def initBootstrap(db: DatabaseConnector) -> None:
|
|||
if mandateId:
|
||||
initRootMandateBilling(mandateId)
|
||||
|
||||
# Initialize subscription for root mandate
|
||||
if mandateId:
|
||||
_initRootMandateSubscription(mandateId)
|
||||
|
||||
# Auto-provision Stripe Products/Prices for paid plans (idempotent)
|
||||
_bootstrapStripePrices()
|
||||
|
||||
|
||||
def initAutomationTemplates(dbApp: DatabaseConnector, adminUserId: Optional[str] = None) -> None:
|
||||
"""
|
||||
|
|
@ -1864,7 +1871,7 @@ def _createAicoreProviderRules(db: DatabaseConnector) -> None:
|
|||
|
||||
Provider access per role:
|
||||
- admin: all providers allowed
|
||||
- user: all providers EXCEPT anthropic (view=False)
|
||||
- user: all providers allowed
|
||||
- viewer: NO provider access (viewer has no RESOURCE permissions)
|
||||
|
||||
NOTE: Provider list is dynamically discovered from AICore model registry.
|
||||
|
|
@ -1909,7 +1916,7 @@ def _createAicoreProviderRules(db: DatabaseConnector) -> None:
|
|||
read=None, create=None, update=None, delete=None,
|
||||
))
|
||||
|
||||
# User: access to all providers EXCEPT anthropic
|
||||
# User: access to all providers (same provider keys as admin)
|
||||
userId = _getRoleId(db, "user")
|
||||
if userId:
|
||||
for provider in providers:
|
||||
|
|
@ -1923,13 +1930,11 @@ def _createAicoreProviderRules(db: DatabaseConnector) -> None:
|
|||
}
|
||||
)
|
||||
if not existingRules:
|
||||
# Anthropic is not allowed for user role
|
||||
isAllowed = provider != "anthropic"
|
||||
providerRules.append(AccessRule(
|
||||
roleId=userId,
|
||||
context=AccessRuleContext.RESOURCE,
|
||||
item=resourceKey,
|
||||
view=isAllowed,
|
||||
view=True,
|
||||
read=None, create=None, update=None, delete=None,
|
||||
))
|
||||
|
||||
|
|
@ -2069,6 +2074,47 @@ def initRootMandateBilling(mandateId: str) -> None:
|
|||
logger.warning(f"Failed to initialize root mandate billing (non-critical): {e}")
|
||||
|
||||
|
||||
def _initRootMandateSubscription(mandateId: str) -> None:
|
||||
"""
|
||||
Ensure the root mandate has an active ROOT subscription.
|
||||
Called during bootstrap after billing init.
|
||||
"""
|
||||
try:
|
||||
from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
|
||||
from modules.datamodels.datamodelSubscription import (
|
||||
MandateSubscription,
|
||||
SubscriptionStatusEnum,
|
||||
)
|
||||
|
||||
subInterface = getSubRootInterface()
|
||||
existing = subInterface.getOperativeForMandate(mandateId)
|
||||
if existing:
|
||||
logger.info("Root mandate subscription already exists")
|
||||
return
|
||||
|
||||
sub = MandateSubscription(
|
||||
mandateId=mandateId,
|
||||
planKey="ROOT",
|
||||
status=SubscriptionStatusEnum.ACTIVE,
|
||||
recurring=False,
|
||||
)
|
||||
subInterface.createSubscription(sub)
|
||||
logger.info("Created ROOT subscription for root mandate")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize root mandate subscription (non-critical): {e}")
|
||||
|
||||
|
||||
def _bootstrapStripePrices() -> None:
|
||||
"""Auto-create Stripe Products and Prices for all paid plans.
|
||||
Idempotent — safe on every startup. IDs are persisted in the StripePlanPrice table."""
|
||||
try:
|
||||
from modules.serviceCenter.services.serviceSubscription.stripeBootstrap import bootstrapStripePrices
|
||||
bootstrapStripePrices()
|
||||
except Exception as e:
|
||||
logger.error(f"Stripe price bootstrap failed (subscriptions will not work for paid plans): {e}")
|
||||
|
||||
|
||||
def assignInitialUserMemberships(
|
||||
db: DatabaseConnector,
|
||||
mandateId: str,
|
||||
|
|
|
|||
|
|
@ -423,7 +423,7 @@ class AppObjects:
|
|||
|
||||
else:
|
||||
# Unknown operator - default to equals
|
||||
if record_value != filter_val:
|
||||
if str(record_value).lower() != str(filter_val).lower():
|
||||
matches = False
|
||||
break
|
||||
|
||||
|
|
@ -512,53 +512,34 @@ class AppObjects:
|
|||
If pagination is None: List[User]
|
||||
If pagination is provided: PaginatedResult with items and metadata
|
||||
"""
|
||||
# Get user IDs via UserMandate junction table (UserInDB has no mandateId column)
|
||||
userMandates = self.db.getRecordset(UserMandate, recordFilter={"mandateId": mandateId})
|
||||
userIds = [um.get("userId") for um in userMandates if um.get("userId")]
|
||||
|
||||
# Fetch each user by ID
|
||||
filteredUsers = []
|
||||
for userId in userIds:
|
||||
userRecords = self.db.getRecordset(UserInDB, recordFilter={"id": userId})
|
||||
if userRecords:
|
||||
cleanedUser = {k: v for k, v in userRecords[0].items() if not k.startswith("_")}
|
||||
if cleanedUser.get("roleLabels") is None:
|
||||
cleanedUser["roleLabels"] = []
|
||||
filteredUsers.append(cleanedUser)
|
||||
|
||||
# If no pagination requested, return all items
|
||||
if not userIds:
|
||||
if pagination is None:
|
||||
return []
|
||||
return PaginatedResult(items=[], totalItems=0, totalPages=0)
|
||||
|
||||
result = self.db.getRecordsetPaginated(
|
||||
UserInDB,
|
||||
pagination=pagination,
|
||||
recordFilter={"id": userIds}
|
||||
)
|
||||
|
||||
items = []
|
||||
for record in result["items"]:
|
||||
cleanedUser = {k: v for k, v in record.items() if not k.startswith("_")}
|
||||
if cleanedUser.get("roleLabels") is None:
|
||||
cleanedUser["roleLabels"] = []
|
||||
items.append(User(**cleanedUser))
|
||||
|
||||
if pagination is None:
|
||||
return [User(**user) for user in filteredUsers]
|
||||
|
||||
# Apply filtering (if filters provided)
|
||||
if pagination.filters:
|
||||
filteredUsers = self._applyFilters(filteredUsers, pagination.filters)
|
||||
|
||||
# Apply sorting (in order of sortFields)
|
||||
if pagination.sort:
|
||||
filteredUsers = self._applySorting(filteredUsers, pagination.sort)
|
||||
|
||||
# Count total items after filters
|
||||
totalItems = len(filteredUsers)
|
||||
totalPages = math.ceil(totalItems / pagination.pageSize) if totalItems > 0 else 0
|
||||
|
||||
# Apply pagination (skip/limit)
|
||||
startIdx = (pagination.page - 1) * pagination.pageSize
|
||||
endIdx = startIdx + pagination.pageSize
|
||||
pagedUsers = filteredUsers[startIdx:endIdx]
|
||||
|
||||
# Ensure roleLabels is always a list for paginated results too
|
||||
for user in pagedUsers:
|
||||
if user.get("roleLabels") is None:
|
||||
user["roleLabels"] = []
|
||||
|
||||
# Convert to model objects
|
||||
items = [User(**user) for user in pagedUsers]
|
||||
|
||||
return items
|
||||
|
||||
return PaginatedResult(
|
||||
items=items,
|
||||
totalItems=totalItems,
|
||||
totalPages=totalPages
|
||||
totalItems=result["totalItems"],
|
||||
totalPages=result["totalPages"]
|
||||
)
|
||||
|
||||
def getUserByUsername(self, username: str) -> Optional[User]:
|
||||
|
|
@ -1615,7 +1596,10 @@ class AppObjects:
|
|||
existing = self.getUserMandate(userId, mandateId)
|
||||
if existing:
|
||||
raise ValueError(f"User {userId} is already member of mandate {mandateId}")
|
||||
|
||||
|
||||
# Subscription capacity check (before insert)
|
||||
self._checkSubscriptionCapacity(mandateId, "users", delta=1)
|
||||
|
||||
# Create UserMandate
|
||||
userMandate = UserMandate(
|
||||
userId=userId,
|
||||
|
|
@ -1636,7 +1620,10 @@ class AppObjects:
|
|||
|
||||
# Create billing account for user if billing is configured
|
||||
self._ensureUserBillingAccount(userId, mandateId)
|
||||
|
||||
|
||||
# Sync Stripe quantity after successful insert
|
||||
self._syncSubscriptionQuantity(mandateId)
|
||||
|
||||
cleanedRecord = {k: v for k, v in createdRecord.items() if not k.startswith("_")}
|
||||
return UserMandate(**cleanedRecord)
|
||||
except Exception as e:
|
||||
|
|
@ -1686,6 +1673,28 @@ class AppObjects:
|
|||
except Exception as e:
|
||||
logger.warning(f"Failed to create billing account for user {userId} (non-critical): {e}")
|
||||
|
||||
def _checkSubscriptionCapacity(self, mandateId: str, resourceType: str, delta: int = 1) -> None:
|
||||
"""Check subscription capacity before creating a resource. Raises on cap violation."""
|
||||
try:
|
||||
from modules.interfaces.interfaceDbSubscription import getInterface as getSubInterface
|
||||
from modules.security.rootAccess import getRootUser
|
||||
subIf = getSubInterface(getRootUser(), mandateId)
|
||||
subIf.assertCapacity(mandateId, resourceType, delta)
|
||||
except Exception as e:
|
||||
if "SubscriptionCapacityException" in type(e).__name__:
|
||||
raise
|
||||
logger.debug(f"Subscription capacity check skipped: {e}")
|
||||
|
||||
def _syncSubscriptionQuantity(self, mandateId: str) -> None:
|
||||
"""Sync Stripe subscription quantities after a resource mutation."""
|
||||
try:
|
||||
from modules.interfaces.interfaceDbSubscription import getInterface as getSubInterface
|
||||
from modules.security.rootAccess import getRootUser
|
||||
subIf = getSubInterface(getRootUser(), mandateId)
|
||||
subIf.syncQuantityToStripe(mandateId)
|
||||
except Exception as e:
|
||||
logger.debug(f"Subscription quantity sync skipped: {e}")
|
||||
|
||||
def deleteUserMandate(self, userId: str, mandateId: str) -> bool:
|
||||
"""
|
||||
Delete a UserMandate record (remove user from mandate).
|
||||
|
|
@ -2284,30 +2293,43 @@ class AppObjects:
|
|||
# Additional Helper Methods
|
||||
# ============================================
|
||||
|
||||
def getAllUsers(self) -> List[User]:
|
||||
def getAllUsers(self, pagination: Optional[PaginationParams] = None) -> Union[List[User], PaginatedResult]:
|
||||
"""
|
||||
Get all users (for SysAdmin only).
|
||||
|
||||
Args:
|
||||
pagination: Optional pagination parameters. If None, returns all items.
|
||||
|
||||
Returns:
|
||||
List of User objects (without sensitive fields)
|
||||
If pagination is None: List[User] (without sensitive fields)
|
||||
If pagination is provided: PaginatedResult with items and metadata
|
||||
"""
|
||||
try:
|
||||
records = self.db.getRecordset(UserInDB)
|
||||
result = []
|
||||
for record in records:
|
||||
# Filter out sensitive and internal fields
|
||||
result = self.db.getRecordsetPaginated(UserInDB, pagination=pagination)
|
||||
|
||||
items = []
|
||||
for record in result["items"]:
|
||||
cleanedRecord = {
|
||||
k: v for k, v in record.items()
|
||||
k: v for k, v in record.items()
|
||||
if not k.startswith("_") and k not in ["hashedPassword", "resetToken", "resetTokenExpires"]
|
||||
}
|
||||
# Ensure roleLabels is a list
|
||||
if cleanedRecord.get("roleLabels") is None:
|
||||
cleanedRecord["roleLabels"] = []
|
||||
result.append(User(**cleanedRecord))
|
||||
return result
|
||||
items.append(User(**cleanedRecord))
|
||||
|
||||
if pagination is None:
|
||||
return items
|
||||
|
||||
return PaginatedResult(
|
||||
items=items,
|
||||
totalItems=result["totalItems"],
|
||||
totalPages=result["totalPages"]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting all users: {e}")
|
||||
return []
|
||||
if pagination is None:
|
||||
return []
|
||||
return PaginatedResult(items=[], totalItems=0, totalPages=0)
|
||||
|
||||
def getUserMandateById(self, userMandateId: str) -> Optional[UserMandate]:
|
||||
"""
|
||||
|
|
@ -3293,50 +3315,26 @@ class AppObjects:
|
|||
If pagination is provided: PaginatedResult with items and metadata
|
||||
"""
|
||||
try:
|
||||
# Get all roles from database
|
||||
roles = self.db.getRecordset(Role)
|
||||
|
||||
# Filter out database-specific fields
|
||||
filteredRoles = []
|
||||
for role in roles:
|
||||
cleanedRole = {k: v for k, v in role.items() if not k.startswith("_")}
|
||||
filteredRoles.append(cleanedRole)
|
||||
|
||||
# If no pagination requested, return all items
|
||||
result = self.db.getRecordsetPaginated(Role, pagination=pagination)
|
||||
|
||||
items = []
|
||||
for record in result["items"]:
|
||||
cleanedRole = {k: v for k, v in record.items() if not k.startswith("_")}
|
||||
items.append(Role(**cleanedRole))
|
||||
|
||||
if pagination is None:
|
||||
return [Role(**role) for role in filteredRoles]
|
||||
|
||||
# Apply filtering (if filters provided)
|
||||
if pagination.filters:
|
||||
filteredRoles = self._applyFilters(filteredRoles, pagination.filters)
|
||||
|
||||
# Apply sorting (in order of sortFields)
|
||||
if pagination.sort:
|
||||
filteredRoles = self._applySorting(filteredRoles, pagination.sort)
|
||||
|
||||
# Count total items after filters
|
||||
totalItems = len(filteredRoles)
|
||||
totalPages = math.ceil(totalItems / pagination.pageSize) if totalItems > 0 else 0
|
||||
|
||||
# Apply pagination (skip/limit)
|
||||
startIdx = (pagination.page - 1) * pagination.pageSize
|
||||
endIdx = startIdx + pagination.pageSize
|
||||
pagedRoles = filteredRoles[startIdx:endIdx]
|
||||
|
||||
# Convert to model objects
|
||||
items = [Role(**role) for role in pagedRoles]
|
||||
|
||||
return items
|
||||
|
||||
return PaginatedResult(
|
||||
items=items,
|
||||
totalItems=totalItems,
|
||||
totalPages=totalPages
|
||||
totalItems=result["totalItems"],
|
||||
totalPages=result["totalPages"]
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting all roles: {str(e)}")
|
||||
if pagination is None:
|
||||
return []
|
||||
else:
|
||||
return PaginatedResult(items=[], totalItems=0, totalPages=0)
|
||||
return PaginatedResult(items=[], totalItems=0, totalPages=0)
|
||||
|
||||
def countRoleAssignments(self) -> Dict[str, int]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ All billing data is stored in the poweron_billing database.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
from datetime import date, datetime, timedelta
|
||||
import uuid
|
||||
|
||||
|
|
@ -17,6 +17,7 @@ from modules.shared.configuration import APP_CONFIG
|
|||
from modules.shared.timeUtils import getUtcTimestamp
|
||||
from modules.datamodels.datamodelUam import User, Mandate
|
||||
from modules.datamodels.datamodelMembership import UserMandate
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResult
|
||||
from modules.datamodels.datamodelBilling import (
|
||||
BillingAccount,
|
||||
BillingTransaction,
|
||||
|
|
@ -633,26 +634,43 @@ class BillingObjects:
|
|||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
startDate: date = None,
|
||||
endDate: date = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
endDate: date = None,
|
||||
pagination: PaginationParams = None
|
||||
) -> Union[List[Dict[str, Any]], PaginatedResult]:
|
||||
"""
|
||||
Get transactions for an account.
|
||||
|
||||
When pagination is provided, uses database-level pagination and returns
|
||||
PaginatedResult. Otherwise falls back to in-memory filtering/sorting/slicing.
|
||||
|
||||
Args:
|
||||
accountId: Account ID
|
||||
limit: Maximum number of results
|
||||
offset: Offset for pagination
|
||||
startDate: Filter by start date
|
||||
endDate: Filter by end date
|
||||
limit: Maximum number of results (legacy path)
|
||||
offset: Offset for pagination (legacy path)
|
||||
startDate: Filter by start date (legacy path)
|
||||
endDate: Filter by end date (legacy path)
|
||||
pagination: PaginationParams for DB-level pagination
|
||||
|
||||
Returns:
|
||||
List of transaction dicts
|
||||
PaginatedResult when pagination is provided, List of dicts otherwise
|
||||
"""
|
||||
try:
|
||||
if pagination:
|
||||
recordFilter = {"accountId": accountId}
|
||||
result = self.db.getRecordsetPaginated(
|
||||
BillingTransaction,
|
||||
pagination=pagination,
|
||||
recordFilter=recordFilter
|
||||
)
|
||||
return PaginatedResult(
|
||||
items=result["items"],
|
||||
totalItems=result["totalItems"],
|
||||
totalPages=result["totalPages"]
|
||||
)
|
||||
|
||||
filterDict = {"accountId": accountId}
|
||||
results = self.db.getRecordset(BillingTransaction, recordFilter=filterDict)
|
||||
|
||||
# Apply date filters if provided
|
||||
if startDate or endDate:
|
||||
filtered = []
|
||||
for t in results:
|
||||
|
|
@ -666,35 +684,61 @@ class BillingObjects:
|
|||
filtered.append(t)
|
||||
results = filtered
|
||||
|
||||
# Sort by creation date descending
|
||||
results.sort(key=lambda x: x.get("_createdAt", ""), reverse=True)
|
||||
|
||||
# Apply pagination
|
||||
return results[offset:offset + limit]
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting transactions: {e}")
|
||||
if pagination:
|
||||
return PaginatedResult(items=[], totalItems=0, totalPages=0)
|
||||
return []
|
||||
|
||||
def getTransactionsByMandate(self, mandateId: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
def getTransactionsByMandate(
|
||||
self,
|
||||
mandateId: str,
|
||||
limit: int = 100,
|
||||
pagination: PaginationParams = None
|
||||
) -> Union[List[Dict[str, Any]], PaginatedResult]:
|
||||
"""
|
||||
Get all transactions for a mandate (across all accounts).
|
||||
|
||||
When pagination is provided, collects all accountIds for the mandate and
|
||||
issues a single DB query with SQL-level filtering, sorting, and pagination.
|
||||
Otherwise falls back to per-account querying and in-memory merging.
|
||||
|
||||
Args:
|
||||
mandateId: Mandate ID
|
||||
limit: Maximum number of results
|
||||
limit: Maximum number of results (legacy path)
|
||||
pagination: PaginationParams for DB-level pagination
|
||||
|
||||
Returns:
|
||||
List of transaction dicts
|
||||
PaginatedResult when pagination is provided, List of dicts otherwise
|
||||
"""
|
||||
# Get all accounts for mandate
|
||||
accounts = self.db.getRecordset(BillingAccount, recordFilter={"mandateId": mandateId})
|
||||
|
||||
accountIds = [acc["id"] for acc in accounts if acc.get("id")]
|
||||
|
||||
if not accountIds:
|
||||
if pagination:
|
||||
return PaginatedResult(items=[], totalItems=0, totalPages=0)
|
||||
return []
|
||||
|
||||
if pagination:
|
||||
result = self.db.getRecordsetPaginated(
|
||||
BillingTransaction,
|
||||
pagination=pagination,
|
||||
recordFilter={"accountId": accountIds}
|
||||
)
|
||||
return PaginatedResult(
|
||||
items=result["items"],
|
||||
totalItems=result["totalItems"],
|
||||
totalPages=result["totalPages"]
|
||||
)
|
||||
|
||||
allTransactions = []
|
||||
for account in accounts:
|
||||
transactions = self.getTransactions(account["id"], limit=limit)
|
||||
allTransactions.extend(transactions)
|
||||
|
||||
# Sort by creation date descending and limit
|
||||
allTransactions.sort(key=lambda x: x.get("_createdAt", ""), reverse=True)
|
||||
return allTransactions[:limit]
|
||||
|
||||
|
|
|
|||
|
|
@ -450,7 +450,7 @@ class ChatObjects:
|
|||
|
||||
# Handle simple value (equals operator)
|
||||
if not isinstance(filter_value, dict):
|
||||
if record_value != filter_value:
|
||||
if str(record_value).lower() != str(filter_value).lower():
|
||||
matches = False
|
||||
break
|
||||
continue
|
||||
|
|
@ -460,7 +460,7 @@ class ChatObjects:
|
|||
filter_val = filter_value.get("value")
|
||||
|
||||
if operator in ["equals", "eq"]:
|
||||
if record_value != filter_val:
|
||||
if str(record_value).lower() != str(filter_val).lower():
|
||||
matches = False
|
||||
break
|
||||
|
||||
|
|
@ -545,7 +545,7 @@ class ChatObjects:
|
|||
|
||||
else:
|
||||
# Unknown operator - default to equals
|
||||
if record_value != filter_val:
|
||||
if str(record_value).lower() != str(filter_val).lower():
|
||||
matches = False
|
||||
break
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ and semantic search via pgvector.
|
|||
"""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from modules.connectors.connectorDbPostgre import _get_cached_connector
|
||||
|
|
@ -286,6 +288,183 @@ class KnowledgeObjects:
|
|||
minScore=minScore,
|
||||
)
|
||||
|
||||
def getRagStatisticsForInstance(
|
||||
self,
|
||||
featureInstanceId: str,
|
||||
mandateId: str,
|
||||
timelineDays: int = 90,
|
||||
workspaceFileIds: Optional[List[str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Aggregate anonymised RAG / knowledge-store metrics for one workspace instance.
|
||||
|
||||
No file names, user identifiers, or chunk text are returned — only counts and
|
||||
distributions suitable for dashboards and presentations.
|
||||
|
||||
workspaceFileIds: optional list of FileItem ids for this feature instance (from Management DB).
|
||||
Index pipelines often stored rows with empty featureInstanceId; linking by file id fixes stats.
|
||||
"""
|
||||
if not featureInstanceId:
|
||||
return {"error": "featureInstanceId required"}
|
||||
|
||||
ws_ids = [x for x in (workspaceFileIds or []) if x]
|
||||
ws_id_set = set(ws_ids)
|
||||
|
||||
files_inst = self.db.getRecordset(
|
||||
FileContentIndex,
|
||||
recordFilter={"featureInstanceId": featureInstanceId},
|
||||
)
|
||||
files_shared: List[Dict[str, Any]] = []
|
||||
if mandateId:
|
||||
files_shared = self.db.getRecordset(
|
||||
FileContentIndex,
|
||||
recordFilter={"mandateId": mandateId, "isShared": True},
|
||||
)
|
||||
|
||||
by_id: Dict[str, Dict[str, Any]] = {}
|
||||
for row in files_inst + files_shared:
|
||||
rid = row.get("id")
|
||||
if rid and rid not in by_id:
|
||||
by_id[rid] = row
|
||||
|
||||
for fid in ws_ids:
|
||||
if fid in by_id:
|
||||
continue
|
||||
row = self.getFileContentIndex(fid)
|
||||
if row:
|
||||
by_id[fid] = row
|
||||
|
||||
files = list(by_id.values())
|
||||
|
||||
chunks_by_id: Dict[str, Dict[str, Any]] = {}
|
||||
inst_chunks = self.db.getRecordset(
|
||||
ContentChunk,
|
||||
recordFilter={"featureInstanceId": featureInstanceId},
|
||||
)
|
||||
for c in inst_chunks:
|
||||
cid = c.get("id")
|
||||
if cid:
|
||||
chunks_by_id[cid] = c
|
||||
|
||||
for fid in ws_id_set:
|
||||
for c in self.getContentChunks(fid):
|
||||
cid = c.get("id")
|
||||
if cid and cid not in chunks_by_id:
|
||||
chunks_by_id[cid] = c
|
||||
|
||||
covered_file_ids = {c.get("fileId") for c in chunks_by_id.values() if c.get("fileId")}
|
||||
for row in files:
|
||||
fid = row.get("id")
|
||||
if fid and fid not in covered_file_ids:
|
||||
for c in self.getContentChunks(fid):
|
||||
cid = c.get("id")
|
||||
if cid and cid not in chunks_by_id:
|
||||
chunks_by_id[cid] = c
|
||||
|
||||
chunks = list(chunks_by_id.values())
|
||||
|
||||
def _mimeCategory(mime: str) -> str:
|
||||
m = (mime or "").lower()
|
||||
if "pdf" in m:
|
||||
return "pdf"
|
||||
if "wordprocessing" in m or "msword" in m or "officedocument.wordprocessing" in m:
|
||||
return "office_doc"
|
||||
if "spreadsheet" in m or "excel" in m or "officedocument.spreadsheet" in m:
|
||||
return "office_sheet"
|
||||
if "presentation" in m or "officedocument.presentation" in m:
|
||||
return "office_slides"
|
||||
if m.startswith("text/"):
|
||||
return "text"
|
||||
if m.startswith("image/"):
|
||||
return "image"
|
||||
if "html" in m:
|
||||
return "html"
|
||||
return "other"
|
||||
|
||||
def _utcDay(ts: Any) -> str:
|
||||
if ts is None:
|
||||
return ""
|
||||
try:
|
||||
return datetime.fromtimestamp(float(ts), tz=timezone.utc).strftime("%Y-%m-%d")
|
||||
except (TypeError, ValueError, OSError):
|
||||
return ""
|
||||
|
||||
status_counts: Dict[str, int] = defaultdict(int)
|
||||
mime_counts: Dict[str, int] = defaultdict(int)
|
||||
extracted_by_day: Dict[str, int] = defaultdict(int)
|
||||
total_bytes = 0
|
||||
user_ids = set()
|
||||
|
||||
for row in files:
|
||||
st = row.get("status") or "unknown"
|
||||
status_counts[st] += 1
|
||||
mime_counts[_mimeCategory(row.get("mimeType") or "")] += 1
|
||||
day = _utcDay(row.get("extractedAt"))
|
||||
if day:
|
||||
extracted_by_day[day] += 1
|
||||
try:
|
||||
total_bytes += int(row.get("totalSize") or 0)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
uid = row.get("userId")
|
||||
if uid:
|
||||
user_ids.add(str(uid))
|
||||
|
||||
content_type_counts: Dict[str, int] = defaultdict(int)
|
||||
chunks_with_embedding = 0
|
||||
for c in chunks:
|
||||
ct = c.get("contentType") or "other"
|
||||
content_type_counts[ct] += 1
|
||||
emb = c.get("embedding")
|
||||
if emb is not None and (
|
||||
(isinstance(emb, list) and len(emb) > 0)
|
||||
or (isinstance(emb, str) and len(emb) > 10)
|
||||
):
|
||||
chunks_with_embedding += 1
|
||||
|
||||
wf_mem = self.db.getRecordset(
|
||||
WorkflowMemory,
|
||||
recordFilter={"featureInstanceId": featureInstanceId},
|
||||
)
|
||||
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=max(1, int(timelineDays)))
|
||||
cutoff_ts = cutoff.timestamp()
|
||||
|
||||
timeline: List[Dict[str, Any]] = []
|
||||
for day in sorted(extracted_by_day.keys()):
|
||||
try:
|
||||
d = datetime.strptime(day, "%Y-%m-%d").replace(tzinfo=timezone.utc)
|
||||
except ValueError:
|
||||
continue
|
||||
if d.timestamp() >= cutoff_ts:
|
||||
timeline.append({"date": day, "indexedDocuments": extracted_by_day[day]})
|
||||
|
||||
if len(timeline) > 120:
|
||||
timeline = timeline[-120:]
|
||||
|
||||
total_chunks = len(chunks)
|
||||
embedding_pct = round(100.0 * chunks_with_embedding / total_chunks, 1) if total_chunks else 0.0
|
||||
|
||||
return {
|
||||
"scope": {
|
||||
"featureInstanceId": featureInstanceId,
|
||||
"mandateScopedShared": bool(mandateId),
|
||||
},
|
||||
"kpis": {
|
||||
"indexedDocuments": len(files),
|
||||
"indexedBytesTotal": total_bytes,
|
||||
"contributorUsers": len(user_ids),
|
||||
"contentChunks": total_chunks,
|
||||
"chunksWithEmbedding": chunks_with_embedding,
|
||||
"embeddingCoveragePercent": embedding_pct,
|
||||
"workflowEntities": len(wf_mem),
|
||||
},
|
||||
"indexedDocumentsByStatus": dict(sorted(status_counts.items())),
|
||||
"documentsByMimeCategory": dict(sorted(mime_counts.items(), key=lambda x: -x[1])),
|
||||
"chunksByContentType": dict(sorted(content_type_counts.items())),
|
||||
"timelineIndexedDocuments": timeline,
|
||||
"generatedAtUtc": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
|
||||
|
||||
def getInterface(currentUser: Optional[User] = None) -> KnowledgeObjects:
|
||||
"""Get or create a KnowledgeObjects singleton."""
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ from modules.datamodels.datamodelUam import AccessLevel
|
|||
from modules.datamodels.datamodelFiles import FilePreview, FileItem, FileData
|
||||
from modules.datamodels.datamodelFileFolder import FileFolder
|
||||
from modules.datamodels.datamodelUtils import Prompt
|
||||
from modules.datamodels.datamodelVoice import VoiceSettings
|
||||
from modules.datamodels.datamodelMessaging import (
|
||||
MessagingSubscription,
|
||||
MessagingSubscriptionRegistration,
|
||||
|
|
@ -390,7 +389,7 @@ class ComponentObjects:
|
|||
|
||||
# Handle simple value (equals operator)
|
||||
if not isinstance(filter_value, dict):
|
||||
if record_value != filter_value:
|
||||
if str(record_value).lower() != str(filter_value).lower():
|
||||
matches = False
|
||||
break
|
||||
continue
|
||||
|
|
@ -400,7 +399,7 @@ class ComponentObjects:
|
|||
filter_val = filter_value.get("value")
|
||||
|
||||
if operator in ["equals", "eq"]:
|
||||
if record_value != filter_val:
|
||||
if str(record_value).lower() != str(filter_val).lower():
|
||||
matches = False
|
||||
break
|
||||
|
||||
|
|
@ -650,6 +649,10 @@ class ComponentObjects:
|
|||
- Regular user: sees own prompts + system prompts (isSystem=True), can only CRUD own
|
||||
- Row-level _permissions control edit/delete buttons in the UI
|
||||
|
||||
NOTE: Cannot use db.getRecordsetPaginated() because visibility rules
|
||||
(_getPromptsForUser: own + system for regular, all for SysAdmin) and
|
||||
per-row _permissions enrichment require loading all records first.
|
||||
|
||||
Args:
|
||||
pagination: Optional pagination parameters. If None, returns all items.
|
||||
|
||||
|
|
@ -914,7 +917,7 @@ class ComponentObjects:
|
|||
"""
|
||||
Returns files owned by the current user (user-scoped, not RBAC-based).
|
||||
Every user (including SysAdmin) only sees their own files.
|
||||
Supports optional pagination, sorting, and filtering.
|
||||
Supports optional pagination, sorting, and filtering via database-level queries.
|
||||
|
||||
Args:
|
||||
pagination: Optional pagination parameters. If None, returns all items.
|
||||
|
|
@ -923,24 +926,21 @@ class ComponentObjects:
|
|||
If pagination is None: List[FileItem]
|
||||
If pagination is provided: PaginatedResult with items and metadata
|
||||
"""
|
||||
# Files are always user-scoped: filter by _createdBy (bypasses RBAC SysAdmin override)
|
||||
filteredFiles = self._getFilesByCurrentUser()
|
||||
|
||||
# Convert database records to FileItem instances (extra='allow' preserves system fields like _createdBy)
|
||||
def convertFileItems(files):
|
||||
# User-scoping filter: every user only sees their own files (bypasses RBAC SysAdmin override)
|
||||
recordFilter = {"_createdBy": self.userId}
|
||||
|
||||
def _convertFileItems(files):
|
||||
fileItems = []
|
||||
for file in files:
|
||||
try:
|
||||
# Ensure proper values, use defaults for invalid data
|
||||
creationDate = file.get("creationDate")
|
||||
if creationDate is None or not isinstance(creationDate, (int, float)) or creationDate <= 0:
|
||||
file["creationDate"] = getUtcTimestamp()
|
||||
|
||||
fileName = file.get("fileName")
|
||||
if not fileName or fileName == "None":
|
||||
continue # Skip records with invalid fileName
|
||||
continue
|
||||
|
||||
# Use **file to pass all fields including system fields (_createdBy, etc.)
|
||||
fileItem = FileItem(**file)
|
||||
fileItems.append(fileItem)
|
||||
except Exception as e:
|
||||
|
|
@ -948,34 +948,23 @@ class ComponentObjects:
|
|||
continue
|
||||
return fileItems
|
||||
|
||||
# If no pagination requested, return all items
|
||||
if pagination is None:
|
||||
return convertFileItems(filteredFiles)
|
||||
allFiles = self._getFilesByCurrentUser()
|
||||
return _convertFileItems(allFiles)
|
||||
|
||||
# Apply filtering (if filters provided)
|
||||
if pagination.filters:
|
||||
filteredFiles = self._applyFilters(filteredFiles, pagination.filters)
|
||||
# Database-level pagination: filtering, sorting, and LIMIT/OFFSET happen in SQL
|
||||
result = self.db.getRecordsetPaginated(
|
||||
FileItem,
|
||||
pagination=pagination,
|
||||
recordFilter=recordFilter
|
||||
)
|
||||
|
||||
# Apply sorting (in order of sortFields)
|
||||
if pagination.sort:
|
||||
filteredFiles = self._applySorting(filteredFiles, pagination.sort)
|
||||
|
||||
# Count total items after filters
|
||||
totalItems = len(filteredFiles)
|
||||
totalPages = math.ceil(totalItems / pagination.pageSize) if totalItems > 0 else 0
|
||||
|
||||
# Apply pagination (skip/limit)
|
||||
startIdx = (pagination.page - 1) * pagination.pageSize
|
||||
endIdx = startIdx + pagination.pageSize
|
||||
pagedFiles = filteredFiles[startIdx:endIdx]
|
||||
|
||||
# Convert to model objects (extra='allow' on FileItem preserves system fields)
|
||||
items = convertFileItems(pagedFiles)
|
||||
items = _convertFileItems(result["items"])
|
||||
|
||||
return PaginatedResult(
|
||||
items=items,
|
||||
totalItems=totalItems,
|
||||
totalPages=totalPages
|
||||
totalItems=result["totalItems"],
|
||||
totalPages=result["totalPages"]
|
||||
)
|
||||
|
||||
def getFile(self, fileId: str) -> Optional[FileItem]:
|
||||
|
|
@ -1733,158 +1722,6 @@ class ComponentObjects:
|
|||
logger.error(f"Error in saveUploadedFile for {fileName}: {str(e)}", exc_info=True)
|
||||
raise FileStorageError(f"Error saving file: {str(e)}")
|
||||
|
||||
# VoiceSettings methods
|
||||
|
||||
def getVoiceSettings(self, userId: Optional[str] = None) -> Optional[VoiceSettings]:
|
||||
"""Returns voice settings for a user if user has access."""
|
||||
try:
|
||||
targetUserId = userId or self.userId
|
||||
if not targetUserId:
|
||||
logger.error("No user ID provided for voice settings")
|
||||
return None
|
||||
|
||||
recordFilter: Dict[str, Any] = {"userId": targetUserId}
|
||||
if self.featureInstanceId:
|
||||
recordFilter["featureInstanceId"] = self.featureInstanceId
|
||||
|
||||
# Get voice settings for the user (scoped to current feature instance if available), filtered by RBAC
|
||||
filteredSettings = getRecordsetWithRBAC(self.db,
|
||||
VoiceSettings,
|
||||
self.currentUser,
|
||||
recordFilter=recordFilter,
|
||||
mandateId=self.mandateId
|
||||
)
|
||||
|
||||
if not filteredSettings:
|
||||
logger.warning(f"No access to voice settings for user {targetUserId}")
|
||||
return None
|
||||
|
||||
# Ensure timestamps are set for validation
|
||||
settings_data = filteredSettings[0]
|
||||
if not settings_data.get("creationDate"):
|
||||
settings_data["creationDate"] = getUtcTimestamp()
|
||||
if not settings_data.get("lastModified"):
|
||||
settings_data["lastModified"] = getUtcTimestamp()
|
||||
|
||||
return VoiceSettings(**settings_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting voice settings: {str(e)}")
|
||||
return None
|
||||
|
||||
def createVoiceSettings(self, settingsData: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Creates voice settings for a user if user has permission."""
|
||||
try:
|
||||
if not self.checkRbacPermission(VoiceSettings, "update"):
|
||||
raise PermissionError("No permission to create voice settings")
|
||||
|
||||
# Ensure userId is set
|
||||
if "userId" not in settingsData:
|
||||
settingsData["userId"] = self.userId
|
||||
|
||||
# Ensure mandateId and featureInstanceId are set from context
|
||||
if "mandateId" not in settingsData:
|
||||
settingsData["mandateId"] = self.mandateId
|
||||
if "featureInstanceId" not in settingsData:
|
||||
settingsData["featureInstanceId"] = self.featureInstanceId
|
||||
|
||||
# Check if settings already exist for this user
|
||||
existingSettings = self.getVoiceSettings(settingsData["userId"])
|
||||
if existingSettings:
|
||||
raise ValueError(f"Voice settings already exist for user {settingsData['userId']}")
|
||||
|
||||
# Create voice settings record
|
||||
createdRecord = self.db.recordCreate(VoiceSettings, settingsData)
|
||||
if not createdRecord or not createdRecord.get("id"):
|
||||
raise ValueError("Failed to create voice settings record")
|
||||
|
||||
logger.info(f"Created voice settings for user {settingsData['userId']}")
|
||||
return createdRecord
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating voice settings: {str(e)}")
|
||||
raise
|
||||
|
||||
def updateVoiceSettings(self, userId: str, updateData: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Updates voice settings for a user if user has access."""
|
||||
try:
|
||||
# Get existing settings
|
||||
existingSettings = self.getVoiceSettings(userId)
|
||||
if not existingSettings:
|
||||
raise ValueError(f"Voice settings not found for user {userId}")
|
||||
|
||||
# Update lastModified timestamp
|
||||
updateData["lastModified"] = getUtcTimestamp()
|
||||
|
||||
# Update voice settings record
|
||||
success = self.db.recordModify(VoiceSettings, existingSettings.id, updateData)
|
||||
if not success:
|
||||
raise ValueError("Failed to update voice settings record")
|
||||
|
||||
# Get updated settings
|
||||
updatedSettings = self.getVoiceSettings(userId)
|
||||
if not updatedSettings:
|
||||
raise ValueError("Failed to retrieve updated voice settings")
|
||||
|
||||
logger.info(f"Updated voice settings for user {userId}")
|
||||
return updatedSettings.model_dump()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating voice settings: {str(e)}")
|
||||
raise
|
||||
|
||||
def deleteVoiceSettings(self, userId: str) -> bool:
|
||||
"""Deletes voice settings for a user if user has access."""
|
||||
try:
|
||||
# Get existing settings
|
||||
existingSettings = self.getVoiceSettings(userId)
|
||||
if not existingSettings:
|
||||
logger.warning(f"Voice settings not found for user {userId}")
|
||||
return False
|
||||
|
||||
# Delete voice settings
|
||||
success = self.db.recordDelete(VoiceSettings, existingSettings.id)
|
||||
if success:
|
||||
logger.info(f"Deleted voice settings for user {userId}")
|
||||
else:
|
||||
logger.error(f"Failed to delete voice settings for user {userId}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting voice settings: {str(e)}")
|
||||
return False
|
||||
|
||||
def getOrCreateVoiceSettings(self, userId: Optional[str] = None) -> VoiceSettings:
|
||||
"""Gets existing voice settings or creates default ones for a user."""
|
||||
try:
|
||||
targetUserId = userId or self.userId
|
||||
if not targetUserId:
|
||||
raise ValueError("No user ID provided for voice settings")
|
||||
|
||||
# Try to get existing settings
|
||||
existingSettings = self.getVoiceSettings(targetUserId)
|
||||
if existingSettings:
|
||||
return existingSettings
|
||||
|
||||
# Create default settings
|
||||
defaultSettings = {
|
||||
"userId": targetUserId,
|
||||
"mandateId": self.mandateId,
|
||||
"sttLanguage": "de-DE",
|
||||
"ttsLanguage": "de-DE",
|
||||
"ttsVoice": "de-DE-KatjaNeural",
|
||||
"translationEnabled": True,
|
||||
"targetLanguage": "en-US"
|
||||
}
|
||||
|
||||
createdRecord = self.createVoiceSettings(defaultSettings)
|
||||
return VoiceSettings(**createdRecord)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting or creating voice settings: {str(e)}")
|
||||
raise
|
||||
|
||||
# Messaging Subscription methods
|
||||
|
||||
def getAllSubscriptions(self, pagination: Optional[PaginationParams] = None) -> Union[List[MessagingSubscription], PaginatedResult]:
|
||||
|
|
|
|||
353
modules/interfaces/interfaceDbSubscription.py
Normal file
353
modules/interfaces/interfaceDbSubscription.py
Normal file
|
|
@ -0,0 +1,353 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""
|
||||
Interface for Subscription operations — ID-based, deterministic.
|
||||
|
||||
Every write operation takes an explicit subscriptionId.
|
||||
No status-scan guessing. See wiki/concepts/Subscription-State-Machine.md.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from modules.connectors.connectorDbPostgre import DatabaseConnector
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.datamodels.datamodelMembership import UserMandate
|
||||
from modules.datamodels.datamodelSubscription import (
|
||||
SubscriptionPlan,
|
||||
MandateSubscription,
|
||||
SubscriptionStatusEnum,
|
||||
BillingPeriodEnum,
|
||||
ALLOWED_TRANSITIONS,
|
||||
TERMINAL_STATUSES,
|
||||
OPERATIVE_STATUSES,
|
||||
BUILTIN_PLANS,
|
||||
_getPlan,
|
||||
_getSelectablePlans,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SUBSCRIPTION_DATABASE = "poweron_billing"
|
||||
|
||||
_subscriptionInterfaces: Dict[str, "SubscriptionObjects"] = {}
|
||||
|
||||
|
||||
class InvalidTransitionError(Exception):
|
||||
"""Raised when a state transition is not allowed by the state machine."""
|
||||
def __init__(self, subscriptionId: str, fromStatus: str, toStatus: str):
|
||||
self.subscriptionId = subscriptionId
|
||||
self.fromStatus = fromStatus
|
||||
self.toStatus = toStatus
|
||||
super().__init__(f"Invalid transition {fromStatus} -> {toStatus} for subscription {subscriptionId}")
|
||||
|
||||
|
||||
def getInterface(currentUser: User, mandateId: str = None) -> "SubscriptionObjects":
|
||||
cacheKey = f"{currentUser.id}_{mandateId}"
|
||||
if cacheKey not in _subscriptionInterfaces:
|
||||
_subscriptionInterfaces[cacheKey] = SubscriptionObjects(currentUser, mandateId)
|
||||
else:
|
||||
_subscriptionInterfaces[cacheKey].setUserContext(currentUser, mandateId)
|
||||
return _subscriptionInterfaces[cacheKey]
|
||||
|
||||
|
||||
def _getRootInterface() -> "SubscriptionObjects":
|
||||
from modules.security.rootAccess import getRootUser
|
||||
return SubscriptionObjects(getRootUser(), mandateId=None)
|
||||
|
||||
|
||||
def _getAppDatabaseConnector() -> DatabaseConnector:
|
||||
return DatabaseConnector(
|
||||
dbDatabase=APP_CONFIG.get("DB_DATABASE", "poweron_app"),
|
||||
dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
|
||||
dbPort=int(APP_CONFIG.get("DB_PORT", "5432")),
|
||||
dbUser=APP_CONFIG.get("DB_USER"),
|
||||
dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"),
|
||||
)
|
||||
|
||||
|
||||
class SubscriptionObjects:
|
||||
"""Interface for subscription operations: CRUD, gate checks, Stripe sync.
|
||||
All writes are ID-based. All status changes go through transitionStatus()."""
|
||||
|
||||
def __init__(self, currentUser: Optional[User] = None, mandateId: str = None):
|
||||
self.currentUser = currentUser
|
||||
self.userId = currentUser.id if currentUser else None
|
||||
self.mandateId = mandateId
|
||||
self.db = DatabaseConnector(
|
||||
dbDatabase=SUBSCRIPTION_DATABASE,
|
||||
dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
|
||||
dbPort=int(APP_CONFIG.get("DB_PORT", "5432")),
|
||||
dbUser=APP_CONFIG.get("DB_USER"),
|
||||
dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"),
|
||||
)
|
||||
|
||||
def setUserContext(self, currentUser: User, mandateId: str = None):
|
||||
self.currentUser = currentUser
|
||||
self.userId = currentUser.id if currentUser else None
|
||||
self.mandateId = mandateId
|
||||
|
||||
# =========================================================================
|
||||
# Plan catalog (in-memory)
|
||||
# =========================================================================
|
||||
|
||||
def getPlan(self, planKey: str) -> Optional[SubscriptionPlan]:
|
||||
return _getPlan(planKey)
|
||||
|
||||
def getSelectablePlans(self) -> List[SubscriptionPlan]:
|
||||
return _getSelectablePlans()
|
||||
|
||||
# =========================================================================
|
||||
# Read: by ID (primary access pattern)
|
||||
# =========================================================================
|
||||
|
||||
def getById(self, subscriptionId: str) -> Optional[Dict[str, Any]]:
|
||||
"""Load a single subscription by its primary key."""
|
||||
try:
|
||||
results = self.db.getRecordset(MandateSubscription, recordFilter={"id": subscriptionId})
|
||||
return dict(results[0]) if results else None
|
||||
except Exception as e:
|
||||
logger.error("getById(%s) failed: %s", subscriptionId, e)
|
||||
return None
|
||||
|
||||
def getByStripeSubscriptionId(self, stripeSubId: str) -> Optional[Dict[str, Any]]:
|
||||
"""Load subscription by Stripe subscription ID — the webhook resolution path."""
|
||||
try:
|
||||
results = self.db.getRecordset(MandateSubscription, recordFilter={"stripeSubscriptionId": stripeSubId})
|
||||
return dict(results[0]) if results else None
|
||||
except Exception as e:
|
||||
logger.error("getByStripeSubscriptionId(%s) failed: %s", stripeSubId, e)
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# Read: by mandate (list queries)
|
||||
# =========================================================================
|
||||
|
||||
def listForMandate(self, mandateId: str, statusFilter: List[SubscriptionStatusEnum] = None) -> List[Dict[str, Any]]:
|
||||
"""Return all subscriptions for a mandate, optionally filtered by status.
|
||||
Sorted newest-first by startedAt."""
|
||||
try:
|
||||
results = self.db.getRecordset(MandateSubscription, recordFilter={"mandateId": mandateId})
|
||||
rows = [dict(r) for r in results]
|
||||
if statusFilter:
|
||||
filterValues = {s.value for s in statusFilter}
|
||||
rows = [r for r in rows if r.get("status") in filterValues]
|
||||
rows.sort(key=lambda r: r.get("startedAt", ""), reverse=True)
|
||||
return rows
|
||||
except Exception as e:
|
||||
logger.error("listForMandate(%s) failed: %s", mandateId, e)
|
||||
return []
|
||||
|
||||
def getOperativeForMandate(self, mandateId: str) -> Optional[Dict[str, Any]]:
|
||||
"""Return the single operative subscription (ACTIVE, TRIALING, or PAST_DUE).
|
||||
This is a read-only query for the billing gate. Returns None if no operative sub exists."""
|
||||
for row in self.listForMandate(mandateId):
|
||||
if row.get("status") in {s.value for s in OPERATIVE_STATUSES}:
|
||||
return row
|
||||
return None
|
||||
|
||||
def getScheduledForMandate(self, mandateId: str) -> Optional[Dict[str, Any]]:
|
||||
"""Return a SCHEDULED subscription if one exists (next sub waiting to start)."""
|
||||
for row in self.listForMandate(mandateId, [SubscriptionStatusEnum.SCHEDULED]):
|
||||
return row
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# Read: global (SysAdmin)
|
||||
# =========================================================================
|
||||
|
||||
def listAll(self, statusFilter: List[SubscriptionStatusEnum] = None) -> List[Dict[str, Any]]:
|
||||
"""Return ALL subscriptions across all mandates, newest-first. SysAdmin use only."""
|
||||
try:
|
||||
results = self.db.getRecordset(MandateSubscription)
|
||||
rows = [dict(r) for r in results]
|
||||
if statusFilter:
|
||||
filterValues = {s.value for s in statusFilter}
|
||||
rows = [r for r in rows if r.get("status") in filterValues]
|
||||
rows.sort(key=lambda r: r.get("startedAt", ""), reverse=True)
|
||||
return rows
|
||||
except Exception as e:
|
||||
logger.error("listAll() failed: %s", e)
|
||||
return []
|
||||
|
||||
# =========================================================================
|
||||
# Write: create
|
||||
# =========================================================================
|
||||
|
||||
def createSubscription(self, sub: MandateSubscription) -> Dict[str, Any]:
|
||||
"""Persist a new MandateSubscription record."""
|
||||
return self.db.recordCreate(MandateSubscription, sub.model_dump())
|
||||
|
||||
# =========================================================================
|
||||
# Write: update fields (no status change)
|
||||
# =========================================================================
|
||||
|
||||
def updateFields(self, subscriptionId: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Update non-status fields on a subscription (e.g. recurring, Stripe IDs, periods).
|
||||
Must NOT be used for status changes — use transitionStatus() for that."""
|
||||
if "status" in data:
|
||||
raise ValueError("updateFields must not change status — use transitionStatus()")
|
||||
return self.db.recordModify(MandateSubscription, subscriptionId, data)
|
||||
|
||||
# =========================================================================
|
||||
# Write: status transition (guarded)
|
||||
# =========================================================================
|
||||
|
||||
def transitionStatus(
|
||||
self,
|
||||
subscriptionId: str,
|
||||
expectedFromStatus: SubscriptionStatusEnum,
|
||||
toStatus: SubscriptionStatusEnum,
|
||||
additionalData: Dict[str, Any] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute a guarded status transition.
|
||||
|
||||
1. Load the record by ID
|
||||
2. Verify current status matches expectedFromStatus
|
||||
3. Verify the transition is allowed by the state machine
|
||||
4. Apply the update
|
||||
"""
|
||||
sub = self.getById(subscriptionId)
|
||||
if not sub:
|
||||
raise ValueError(f"Subscription {subscriptionId} not found")
|
||||
|
||||
currentStatus = sub.get("status", "")
|
||||
if currentStatus != expectedFromStatus.value:
|
||||
raise InvalidTransitionError(subscriptionId, currentStatus, toStatus.value)
|
||||
|
||||
if (expectedFromStatus, toStatus) not in ALLOWED_TRANSITIONS:
|
||||
raise InvalidTransitionError(subscriptionId, expectedFromStatus.value, toStatus.value)
|
||||
|
||||
updateData = {"status": toStatus.value}
|
||||
if toStatus in TERMINAL_STATUSES and not (additionalData or {}).get("endedAt"):
|
||||
updateData["endedAt"] = datetime.now(timezone.utc).isoformat()
|
||||
if additionalData:
|
||||
updateData.update(additionalData)
|
||||
|
||||
result = self.db.recordModify(MandateSubscription, subscriptionId, updateData)
|
||||
logger.info("Transition %s -> %s for subscription %s", expectedFromStatus.value, toStatus.value, subscriptionId)
|
||||
return result
|
||||
|
||||
def forceExpire(self, subscriptionId: str) -> Dict[str, Any]:
|
||||
"""Sysadmin force-expire: ANY non-terminal -> EXPIRED. Bypasses normal transition guards."""
|
||||
sub = self.getById(subscriptionId)
|
||||
if not sub:
|
||||
raise ValueError(f"Subscription {subscriptionId} not found")
|
||||
|
||||
currentStatus = sub.get("status", "")
|
||||
if currentStatus == SubscriptionStatusEnum.EXPIRED.value:
|
||||
raise ValueError(f"Subscription {subscriptionId} is already EXPIRED")
|
||||
|
||||
result = self.db.recordModify(MandateSubscription, subscriptionId, {
|
||||
"status": SubscriptionStatusEnum.EXPIRED.value,
|
||||
"endedAt": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
logger.info("Force-expired subscription %s (was %s)", subscriptionId, currentStatus)
|
||||
return result
|
||||
|
||||
# =========================================================================
|
||||
# Gate: assertActive (read-only, for billing gate)
|
||||
# =========================================================================
|
||||
|
||||
def assertActive(self, mandateId: str) -> SubscriptionStatusEnum:
|
||||
"""Return effective status for billing decisions.
|
||||
Returns the operative subscription's status, or EXPIRED if none exists.
|
||||
This is the ONLY read-by-mandate operation used in the hot path."""
|
||||
sub = self.getOperativeForMandate(mandateId)
|
||||
if sub:
|
||||
return SubscriptionStatusEnum(sub["status"])
|
||||
return SubscriptionStatusEnum.EXPIRED
|
||||
|
||||
# =========================================================================
|
||||
# Gate: assertCapacity
|
||||
# =========================================================================
|
||||
|
||||
def assertCapacity(self, mandateId: str, resourceType: str, delta: int = 1) -> bool:
|
||||
sub = self.getOperativeForMandate(mandateId)
|
||||
if not sub:
|
||||
from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import SubscriptionCapacityException
|
||||
raise SubscriptionCapacityException(
|
||||
resourceType=resourceType, currentCount=0, maxAllowed=0,
|
||||
message="No active subscription for this mandate.",
|
||||
)
|
||||
|
||||
plan = self.getPlan(sub.get("planKey", ""))
|
||||
if not plan:
|
||||
return True
|
||||
|
||||
if resourceType == "users":
|
||||
cap = plan.maxUsers
|
||||
if cap is None:
|
||||
return True
|
||||
current = self.countActiveUsers(mandateId)
|
||||
if current + delta > cap:
|
||||
from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import SubscriptionCapacityException
|
||||
raise SubscriptionCapacityException(resourceType=resourceType, currentCount=current, maxAllowed=cap)
|
||||
elif resourceType == "featureInstances":
|
||||
cap = plan.maxFeatureInstances
|
||||
if cap is None:
|
||||
return True
|
||||
current = self.countActiveFeatureInstances(mandateId)
|
||||
if current + delta > cap:
|
||||
from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import SubscriptionCapacityException
|
||||
raise SubscriptionCapacityException(resourceType=resourceType, currentCount=current, maxAllowed=cap)
|
||||
|
||||
return True
|
||||
|
||||
# =========================================================================
|
||||
# Counting (cross-DB queries against poweron_app)
|
||||
# =========================================================================
|
||||
|
||||
def countActiveUsers(self, mandateId: str) -> int:
|
||||
try:
|
||||
appDb = _getAppDatabaseConnector()
|
||||
return len(appDb.getRecordset(UserMandate, recordFilter={"mandateId": mandateId}))
|
||||
except Exception as e:
|
||||
logger.error("countActiveUsers(%s) failed: %s", mandateId, e)
|
||||
return 0
|
||||
|
||||
def countActiveFeatureInstances(self, mandateId: str) -> int:
|
||||
try:
|
||||
from modules.datamodels.datamodelFeatures import FeatureInstance
|
||||
appDb = _getAppDatabaseConnector()
|
||||
return len(appDb.getRecordset(FeatureInstance, recordFilter={"mandateId": mandateId, "enabled": True}))
|
||||
except Exception as e:
|
||||
logger.error("countActiveFeatureInstances(%s) failed: %s", mandateId, e)
|
||||
return 0
|
||||
|
||||
# =========================================================================
|
||||
# Stripe quantity sync
|
||||
# =========================================================================
|
||||
|
||||
def syncQuantityToStripe(self, subscriptionId: str) -> None:
|
||||
"""Update Stripe subscription item quantities to match actual active counts.
|
||||
Takes subscriptionId, not mandateId."""
|
||||
sub = self.getById(subscriptionId)
|
||||
if not sub or not sub.get("stripeSubscriptionId"):
|
||||
return
|
||||
|
||||
mandateId = sub["mandateId"]
|
||||
itemIdUsers = sub.get("stripeItemIdUsers")
|
||||
itemIdInstances = sub.get("stripeItemIdInstances")
|
||||
|
||||
try:
|
||||
from modules.shared.stripeClient import getStripeClient
|
||||
stripe = getStripeClient()
|
||||
|
||||
activeUsers = self.countActiveUsers(mandateId)
|
||||
activeInstances = self.countActiveFeatureInstances(mandateId)
|
||||
|
||||
if itemIdUsers:
|
||||
stripe.SubscriptionItem.modify(
|
||||
itemIdUsers, quantity=max(activeUsers, 0), proration_behavior="create_prorations",
|
||||
)
|
||||
if itemIdInstances:
|
||||
stripe.SubscriptionItem.modify(
|
||||
itemIdInstances, quantity=max(activeInstances, 0), proration_behavior="create_prorations",
|
||||
)
|
||||
|
||||
logger.info("Stripe quantity synced for sub %s: users=%d, instances=%d", subscriptionId, activeUsers, activeInstances)
|
||||
except Exception as e:
|
||||
logger.error("syncQuantityToStripe(%s) failed: %s", subscriptionId, e)
|
||||
|
|
@ -23,10 +23,12 @@ GROUP-Berechtigung:
|
|||
|
||||
import logging
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional, Type
|
||||
import math
|
||||
from typing import List, Dict, Any, Optional, Type, Union
|
||||
from pydantic import BaseModel
|
||||
from modules.datamodels.datamodelRbac import AccessRuleContext
|
||||
from modules.datamodels.datamodelUam import User, UserPermissions, AccessLevel
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResult
|
||||
from modules.security.rbac import RbacClass
|
||||
from modules.security.rootAccess import getRootDbAppConnector
|
||||
|
||||
|
|
@ -72,6 +74,10 @@ TABLE_NAMESPACE = {
|
|||
# Automation - benutzer-eigen
|
||||
"AutomationDefinition": "automation",
|
||||
"AutomationTemplate": "automation",
|
||||
# Automation2 - feature-scoped
|
||||
"Automation2Workflow": "automation2",
|
||||
"Automation2WorkflowRun": "automation2",
|
||||
"Automation2HumanTask": "automation2",
|
||||
# Knowledge Store - benutzer-eigen
|
||||
"FileContentIndex": "knowledge",
|
||||
"ContentChunk": "knowledge",
|
||||
|
|
@ -297,6 +303,309 @@ def getRecordsetWithRBAC(
|
|||
return []
|
||||
|
||||
|
||||
def getRecordsetPaginatedWithRBAC(
|
||||
connector,
|
||||
modelClass: Type[BaseModel],
|
||||
currentUser: User,
|
||||
pagination: Optional[PaginationParams] = None,
|
||||
recordFilter: Dict[str, Any] = None,
|
||||
mandateId: Optional[str] = None,
|
||||
featureInstanceId: Optional[str] = None,
|
||||
enrichPermissions: bool = False,
|
||||
featureCode: Optional[str] = None,
|
||||
) -> Union[List[Dict[str, Any]], PaginatedResult]:
|
||||
"""
|
||||
Get records with RBAC filtering and SQL-level pagination.
|
||||
When pagination is None, returns a plain list (backward compatible).
|
||||
When pagination is provided, returns PaginatedResult with COUNT + LIMIT/OFFSET at SQL level.
|
||||
"""
|
||||
table = modelClass.__name__
|
||||
objectKey = buildDataObjectKey(table, featureCode)
|
||||
effectiveMandateId = mandateId
|
||||
|
||||
try:
|
||||
if not connector._ensureTableExists(modelClass):
|
||||
return PaginatedResult(items=[], totalItems=0, totalPages=0) if pagination else []
|
||||
|
||||
dbApp = getRootDbAppConnector()
|
||||
rbacInstance = RbacClass(connector, dbApp=dbApp)
|
||||
permissions = rbacInstance.getUserPermissions(
|
||||
currentUser,
|
||||
AccessRuleContext.DATA,
|
||||
objectKey,
|
||||
mandateId=effectiveMandateId,
|
||||
featureInstanceId=featureInstanceId
|
||||
)
|
||||
|
||||
if not permissions.view:
|
||||
return PaginatedResult(items=[], totalItems=0, totalPages=0) if pagination else []
|
||||
|
||||
whereConditions = []
|
||||
whereValues = []
|
||||
|
||||
featureInstanceIdForQuery = featureInstanceId
|
||||
if featureInstanceId and hasattr(modelClass, 'model_fields') and "featureInstanceId" not in modelClass.model_fields:
|
||||
featureInstanceIdForQuery = None
|
||||
|
||||
rbacWhereClause = buildRbacWhereClause(
|
||||
permissions, currentUser, table, connector,
|
||||
mandateId=effectiveMandateId,
|
||||
featureInstanceId=featureInstanceIdForQuery
|
||||
)
|
||||
if rbacWhereClause:
|
||||
whereConditions.append(rbacWhereClause["condition"])
|
||||
whereValues.extend(rbacWhereClause["values"])
|
||||
|
||||
if recordFilter:
|
||||
for field, value in recordFilter.items():
|
||||
if isinstance(value, (list, tuple)):
|
||||
if len(value) == 0:
|
||||
whereConditions.append("1 = 0")
|
||||
else:
|
||||
whereConditions.append(f'"{field}" = ANY(%s)')
|
||||
whereValues.append(list(value))
|
||||
elif value is None:
|
||||
whereConditions.append(f'"{field}" IS NULL')
|
||||
else:
|
||||
whereConditions.append(f'"{field}" = %s')
|
||||
whereValues.append(value)
|
||||
|
||||
if pagination and pagination.filters:
|
||||
from modules.connectors.connectorDbPostgre import _get_model_fields
|
||||
fields = _get_model_fields(modelClass)
|
||||
validColumns = set(fields.keys())
|
||||
for key, val in pagination.filters.items():
|
||||
if key == "search" and isinstance(val, str) and val.strip():
|
||||
term = f"%{val.strip()}%"
|
||||
textCols = [c for c, t in fields.items() if t == "TEXT"]
|
||||
if textCols:
|
||||
orParts = [f'COALESCE("{c}"::TEXT, \'\') ILIKE %s' for c in textCols]
|
||||
whereConditions.append(f"({' OR '.join(orParts)})")
|
||||
whereValues.extend([term] * len(textCols))
|
||||
continue
|
||||
if key not in validColumns:
|
||||
continue
|
||||
if isinstance(val, dict):
|
||||
op = val.get("operator", "equals")
|
||||
v = val.get("value", "")
|
||||
if op in ("equals", "eq"):
|
||||
whereConditions.append(f'"{key}"::TEXT = %s')
|
||||
whereValues.append(str(v))
|
||||
elif op == "contains":
|
||||
whereConditions.append(f'"{key}"::TEXT ILIKE %s')
|
||||
whereValues.append(f"%{v}%")
|
||||
elif op == "startsWith":
|
||||
whereConditions.append(f'"{key}"::TEXT ILIKE %s')
|
||||
whereValues.append(f"{v}%")
|
||||
elif op == "endsWith":
|
||||
whereConditions.append(f'"{key}"::TEXT ILIKE %s')
|
||||
whereValues.append(f"%{v}")
|
||||
elif op in ("gt", "gte", "lt", "lte"):
|
||||
sqlOp = {"gt": ">", "gte": ">=", "lt": "<", "lte": "<="}[op]
|
||||
whereConditions.append(f'"{key}"::TEXT {sqlOp} %s')
|
||||
whereValues.append(str(v))
|
||||
elif op == "between":
|
||||
fromVal = v.get("from", "") if isinstance(v, dict) else ""
|
||||
toVal = v.get("to", "") if isinstance(v, dict) else ""
|
||||
if fromVal and toVal:
|
||||
whereConditions.append(f'"{key}"::TEXT >= %s AND "{key}"::TEXT <= %s')
|
||||
whereValues.extend([str(fromVal), str(toVal)])
|
||||
elif fromVal:
|
||||
whereConditions.append(f'"{key}"::TEXT >= %s')
|
||||
whereValues.append(str(fromVal))
|
||||
elif toVal:
|
||||
whereConditions.append(f'"{key}"::TEXT <= %s')
|
||||
whereValues.append(str(toVal))
|
||||
else:
|
||||
whereConditions.append(f'"{key}"::TEXT ILIKE %s')
|
||||
whereValues.append(str(val))
|
||||
|
||||
whereClause = " WHERE " + " AND ".join(whereConditions) if whereConditions else ""
|
||||
countValues = list(whereValues)
|
||||
|
||||
orderParts: List[str] = []
|
||||
if pagination and pagination.sort:
|
||||
from modules.connectors.connectorDbPostgre import _get_model_fields
|
||||
validColumns = set(_get_model_fields(modelClass).keys())
|
||||
for sf in pagination.sort:
|
||||
if sf.field in validColumns:
|
||||
direction = "DESC" if sf.direction.lower() == "desc" else "ASC"
|
||||
orderParts.append(f'"{sf.field}" {direction}')
|
||||
if not orderParts:
|
||||
orderParts.append('"id"')
|
||||
orderByClause = " ORDER BY " + ", ".join(orderParts)
|
||||
|
||||
limitClause = ""
|
||||
if pagination:
|
||||
offset = (pagination.page - 1) * pagination.pageSize
|
||||
limitClause = f" LIMIT {pagination.pageSize} OFFSET {offset}"
|
||||
|
||||
with connector.connection.cursor() as cursor:
|
||||
countSql = f'SELECT COUNT(*) FROM "{table}"{whereClause}'
|
||||
cursor.execute(countSql, countValues)
|
||||
totalItems = cursor.fetchone()["count"]
|
||||
|
||||
dataSql = f'SELECT * FROM "{table}"{whereClause}{orderByClause}{limitClause}'
|
||||
cursor.execute(dataSql, whereValues)
|
||||
records = [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
from modules.connectors.connectorDbPostgre import _get_model_fields, _parseRecordFields
|
||||
fields = _get_model_fields(modelClass)
|
||||
for record in records:
|
||||
_parseRecordFields(record, fields, f"table {table}")
|
||||
for fieldName, fieldType in fields.items():
|
||||
if fieldType == "JSONB" and fieldName in record and record[fieldName] is None:
|
||||
modelFields = modelClass.model_fields
|
||||
fieldInfo = modelFields.get(fieldName)
|
||||
if fieldInfo:
|
||||
fieldAnnotation = fieldInfo.annotation
|
||||
if (fieldAnnotation == list or
|
||||
(hasattr(fieldAnnotation, "__origin__") and fieldAnnotation.__origin__ is list)):
|
||||
record[fieldName] = []
|
||||
elif (fieldAnnotation == dict or
|
||||
(hasattr(fieldAnnotation, "__origin__") and fieldAnnotation.__origin__ is dict)):
|
||||
record[fieldName] = {}
|
||||
|
||||
if enrichPermissions:
|
||||
records = _enrichRecordsWithPermissions(records, permissions, currentUser)
|
||||
|
||||
if pagination:
|
||||
pageSize = pagination.pageSize
|
||||
totalPages = math.ceil(totalItems / pageSize) if totalItems > 0 else 0
|
||||
return PaginatedResult(items=records, totalItems=totalItems, totalPages=totalPages)
|
||||
|
||||
return records
|
||||
except Exception as e:
|
||||
logger.error(f"Error in getRecordsetPaginatedWithRBAC for table {table}: {e}")
|
||||
return PaginatedResult(items=[], totalItems=0, totalPages=0) if pagination else []
|
||||
|
||||
|
||||
def getDistinctColumnValuesWithRBAC(
|
||||
connector,
|
||||
modelClass: Type[BaseModel],
|
||||
currentUser: User,
|
||||
column: str,
|
||||
pagination: Optional[PaginationParams] = None,
|
||||
recordFilter: Dict[str, Any] = None,
|
||||
mandateId: Optional[str] = None,
|
||||
featureInstanceId: Optional[str] = None,
|
||||
featureCode: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get sorted distinct values for a column with RBAC filtering at SQL level.
|
||||
Cross-filtering: removes the requested column from active filters.
|
||||
"""
|
||||
import copy
|
||||
table = modelClass.__name__
|
||||
objectKey = buildDataObjectKey(table, featureCode)
|
||||
|
||||
try:
|
||||
if not connector._ensureTableExists(modelClass):
|
||||
return []
|
||||
|
||||
from modules.connectors.connectorDbPostgre import _get_model_fields
|
||||
fields = _get_model_fields(modelClass)
|
||||
if column not in fields:
|
||||
return []
|
||||
|
||||
dbApp = getRootDbAppConnector()
|
||||
rbacInstance = RbacClass(connector, dbApp=dbApp)
|
||||
permissions = rbacInstance.getUserPermissions(
|
||||
currentUser, AccessRuleContext.DATA, objectKey,
|
||||
mandateId=mandateId, featureInstanceId=featureInstanceId
|
||||
)
|
||||
if not permissions.view:
|
||||
return []
|
||||
|
||||
whereConditions = []
|
||||
whereValues = []
|
||||
|
||||
featureInstanceIdForQuery = featureInstanceId
|
||||
if featureInstanceId and hasattr(modelClass, 'model_fields') and "featureInstanceId" not in modelClass.model_fields:
|
||||
featureInstanceIdForQuery = None
|
||||
|
||||
rbacWhereClause = buildRbacWhereClause(
|
||||
permissions, currentUser, table, connector,
|
||||
mandateId=mandateId, featureInstanceId=featureInstanceIdForQuery
|
||||
)
|
||||
if rbacWhereClause:
|
||||
whereConditions.append(rbacWhereClause["condition"])
|
||||
whereValues.extend(rbacWhereClause["values"])
|
||||
|
||||
if recordFilter:
|
||||
for field, value in recordFilter.items():
|
||||
if isinstance(value, (list, tuple)):
|
||||
if not value:
|
||||
whereConditions.append("1 = 0")
|
||||
else:
|
||||
whereConditions.append(f'"{field}" = ANY(%s)')
|
||||
whereValues.append(list(value))
|
||||
elif value is None:
|
||||
whereConditions.append(f'"{field}" IS NULL')
|
||||
else:
|
||||
whereConditions.append(f'"{field}" = %s')
|
||||
whereValues.append(value)
|
||||
|
||||
crossPagination = copy.deepcopy(pagination) if pagination else None
|
||||
if crossPagination and crossPagination.filters:
|
||||
crossPagination.filters.pop(column, None)
|
||||
validColumns = set(fields.keys())
|
||||
for key, val in crossPagination.filters.items():
|
||||
if key == "search" and isinstance(val, str) and val.strip():
|
||||
term = f"%{val.strip()}%"
|
||||
textCols = [c for c, t in fields.items() if t == "TEXT"]
|
||||
if textCols:
|
||||
orParts = [f'COALESCE("{c}"::TEXT, \'\') ILIKE %s' for c in textCols]
|
||||
whereConditions.append(f"({' OR '.join(orParts)})")
|
||||
whereValues.extend([term] * len(textCols))
|
||||
continue
|
||||
if key not in validColumns:
|
||||
continue
|
||||
if isinstance(val, dict):
|
||||
op = val.get("operator", "equals")
|
||||
v = val.get("value", "")
|
||||
if op in ("equals", "eq"):
|
||||
whereConditions.append(f'"{key}"::TEXT = %s')
|
||||
whereValues.append(str(v))
|
||||
elif op == "contains":
|
||||
whereConditions.append(f'"{key}"::TEXT ILIKE %s')
|
||||
whereValues.append(f"%{v}%")
|
||||
elif op == "between":
|
||||
fromVal = v.get("from", "") if isinstance(v, dict) else ""
|
||||
toVal = v.get("to", "") if isinstance(v, dict) else ""
|
||||
if fromVal and toVal:
|
||||
whereConditions.append(f'"{key}"::TEXT >= %s AND "{key}"::TEXT <= %s')
|
||||
whereValues.extend([str(fromVal), str(toVal)])
|
||||
elif fromVal:
|
||||
whereConditions.append(f'"{key}"::TEXT >= %s')
|
||||
whereValues.append(str(fromVal))
|
||||
elif toVal:
|
||||
whereConditions.append(f'"{key}"::TEXT <= %s')
|
||||
whereValues.append(str(toVal))
|
||||
else:
|
||||
whereConditions.append(f'"{key}"::TEXT ILIKE %s')
|
||||
whereValues.append(str(v) if isinstance(v, str) else str(val))
|
||||
else:
|
||||
whereConditions.append(f'"{key}"::TEXT ILIKE %s')
|
||||
whereValues.append(str(val))
|
||||
|
||||
whereClause = " WHERE " + " AND ".join(whereConditions) if whereConditions else ""
|
||||
notNullCond = f'"{column}" IS NOT NULL AND "{column}"::TEXT != \'\''
|
||||
if whereClause:
|
||||
whereClause += f" AND {notNullCond}"
|
||||
else:
|
||||
whereClause = f" WHERE {notNullCond}"
|
||||
|
||||
sql = f'SELECT DISTINCT "{column}"::TEXT AS val FROM "{table}"{whereClause} ORDER BY val'
|
||||
|
||||
with connector.connection.cursor() as cursor:
|
||||
cursor.execute(sql, whereValues)
|
||||
return [row["val"] for row in cursor.fetchall()]
|
||||
except Exception as e:
|
||||
logger.error(f"Error in getDistinctColumnValuesWithRBAC for {table}.{column}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def buildRbacWhereClause(
|
||||
permissions: UserPermissions,
|
||||
currentUser: User,
|
||||
|
|
|
|||
|
|
@ -5,15 +5,19 @@ Admin automation events routes for the backend API.
|
|||
Sysadmin-only endpoints for viewing and controlling scheduler events.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Path, Request, Response
|
||||
from typing import List, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, Depends, Path, Request, Response, Query
|
||||
from typing import List, Dict, Any, Optional
|
||||
from fastapi import status
|
||||
import logging
|
||||
import json
|
||||
import math
|
||||
|
||||
# Import interfaces and models from feature containers
|
||||
import modules.features.automation.interfaceFeatureAutomation as interfaceAutomation
|
||||
from modules.auth import limiter, getRequestContext, requireSysAdminRole, RequestContext
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginationMetadata, normalize_pagination_dict
|
||||
from modules.routes.routeDataUsers import _applyFiltersAndSort, _extractDistinctValues
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -31,122 +35,176 @@ router = APIRouter(
|
|||
}
|
||||
)
|
||||
|
||||
def _buildEnrichedAutomationEvents(currentUser: User) -> List[Dict[str, Any]]:
|
||||
"""Build the full enriched automation events list."""
|
||||
from modules.shared.eventManagement import eventManager
|
||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||
from modules.features.automation.mainAutomation import getAutomationServices
|
||||
|
||||
if not eventManager.scheduler:
|
||||
return []
|
||||
|
||||
jobs = []
|
||||
for job in eventManager.scheduler.get_jobs():
|
||||
if job.id.startswith("automation."):
|
||||
automationId = job.id.replace("automation.", "")
|
||||
jobs.append({
|
||||
"eventId": job.id,
|
||||
"id": job.id,
|
||||
"automationId": automationId,
|
||||
"nextRunTime": str(job.next_run_time) if job.next_run_time else None,
|
||||
"trigger": str(job.trigger) if job.trigger else None,
|
||||
"name": "",
|
||||
"createdBy": "",
|
||||
"mandate": "",
|
||||
"featureInstance": ""
|
||||
})
|
||||
|
||||
if jobs:
|
||||
try:
|
||||
rootInterface = getRootInterface()
|
||||
eventUser = rootInterface.getUserByUsername("event")
|
||||
if eventUser:
|
||||
services = getAutomationServices(currentUser, mandateId=None, featureInstanceId=None)
|
||||
allAutomations = services.interfaceDbAutomation.getAllAutomationDefinitionsWithRBAC(eventUser)
|
||||
|
||||
automationLookup = {}
|
||||
for a in allAutomations:
|
||||
aId = a.get("id", "") if isinstance(a, dict) else getattr(a, "id", "")
|
||||
automationLookup[aId] = a
|
||||
|
||||
_userCache: Dict[str, str] = {}
|
||||
_mandateCache: Dict[str, str] = {}
|
||||
_featureCache: Dict[str, str] = {}
|
||||
|
||||
def _resolveUsername(userId):
|
||||
if not userId: return ""
|
||||
if userId not in _userCache:
|
||||
try:
|
||||
user = rootInterface.getUser(userId)
|
||||
_userCache[userId] = user.username if user else userId[:8]
|
||||
except Exception:
|
||||
_userCache[userId] = userId[:8]
|
||||
return _userCache[userId]
|
||||
|
||||
def _resolveMandateLabel(mandateId):
|
||||
if not mandateId: return ""
|
||||
if mandateId not in _mandateCache:
|
||||
try:
|
||||
mandate = rootInterface.getMandate(mandateId)
|
||||
_mandateCache[mandateId] = getattr(mandate, "label", None) or mandateId[:8]
|
||||
except Exception:
|
||||
_mandateCache[mandateId] = mandateId[:8]
|
||||
return _mandateCache[mandateId]
|
||||
|
||||
def _resolveFeatureLabel(featureInstanceId):
|
||||
if not featureInstanceId: return ""
|
||||
if featureInstanceId not in _featureCache:
|
||||
try:
|
||||
instance = rootInterface.getFeatureInstance(featureInstanceId)
|
||||
_featureCache[featureInstanceId] = getattr(instance, "label", None) or getattr(instance, "featureCode", None) or featureInstanceId[:8]
|
||||
except Exception:
|
||||
_featureCache[featureInstanceId] = featureInstanceId[:8]
|
||||
return _featureCache[featureInstanceId]
|
||||
|
||||
for job in jobs:
|
||||
automation = automationLookup.get(job["automationId"])
|
||||
if automation:
|
||||
if isinstance(automation, dict):
|
||||
job["name"] = automation.get("label", "")
|
||||
job["createdBy"] = _resolveUsername(automation.get("_createdBy", ""))
|
||||
job["mandate"] = _resolveMandateLabel(automation.get("mandateId", ""))
|
||||
job["featureInstance"] = _resolveFeatureLabel(automation.get("featureInstanceId", ""))
|
||||
else:
|
||||
job["name"] = getattr(automation, "label", "")
|
||||
job["createdBy"] = _resolveUsername(getattr(automation, "_createdBy", ""))
|
||||
job["mandate"] = _resolveMandateLabel(getattr(automation, "mandateId", ""))
|
||||
job["featureInstance"] = _resolveFeatureLabel(getattr(automation, "featureInstanceId", ""))
|
||||
else:
|
||||
job["name"] = "(orphaned)"
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not enrich automation events with context: {e}")
|
||||
|
||||
return jobs
|
||||
|
||||
|
||||
@router.get("")
|
||||
@limiter.limit("30/minute")
|
||||
def get_all_automation_events(
|
||||
request: Request,
|
||||
currentUser: User = Depends(requireSysAdminRole)
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all active scheduler jobs (sysadmin only).
|
||||
Each job is enriched with context from its automation definition
|
||||
(name, mandate, feature instance, creator) for readability.
|
||||
"""
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams"),
|
||||
currentUser: User = Depends(requireSysAdminRole),
|
||||
):
|
||||
"""Get all active scheduler jobs with pagination support (sysadmin only)."""
|
||||
try:
|
||||
from modules.shared.eventManagement import eventManager
|
||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||
from modules.features.automation.mainAutomation import getAutomationServices
|
||||
|
||||
if not eventManager.scheduler:
|
||||
return []
|
||||
|
||||
# 1. Collect all scheduler jobs
|
||||
jobs = []
|
||||
automationIds = []
|
||||
for job in eventManager.scheduler.get_jobs():
|
||||
if job.id.startswith("automation."):
|
||||
automationId = job.id.replace("automation.", "")
|
||||
automationIds.append(automationId)
|
||||
jobs.append({
|
||||
"eventId": job.id,
|
||||
"automationId": automationId,
|
||||
"nextRunTime": str(job.next_run_time) if job.next_run_time else None,
|
||||
"trigger": str(job.trigger) if job.trigger else None,
|
||||
"name": "",
|
||||
"createdBy": "",
|
||||
"mandate": "",
|
||||
"featureInstance": ""
|
||||
})
|
||||
|
||||
# 2. Enrich with context from automation definitions
|
||||
if jobs:
|
||||
paginationParams: Optional[PaginationParams] = None
|
||||
if pagination:
|
||||
try:
|
||||
rootInterface = getRootInterface()
|
||||
eventUser = rootInterface.getUserByUsername("event")
|
||||
if eventUser:
|
||||
services = getAutomationServices(currentUser, mandateId=None, featureInstanceId=None)
|
||||
allAutomations = services.interfaceDbAutomation.getAllAutomationDefinitionsWithRBAC(eventUser)
|
||||
|
||||
# Build lookup by automation ID
|
||||
automationLookup = {}
|
||||
for a in allAutomations:
|
||||
aId = a.get("id", "") if isinstance(a, dict) else getattr(a, "id", "")
|
||||
automationLookup[aId] = a
|
||||
|
||||
# Caches for resolving UUIDs to names
|
||||
_userCache = {}
|
||||
_mandateCache = {}
|
||||
_featureCache = {}
|
||||
|
||||
def _resolveUsername(userId):
|
||||
if not userId:
|
||||
return ""
|
||||
if userId not in _userCache:
|
||||
try:
|
||||
user = rootInterface.getUser(userId)
|
||||
_userCache[userId] = user.username if user else userId[:8]
|
||||
except Exception:
|
||||
_userCache[userId] = userId[:8]
|
||||
return _userCache[userId]
|
||||
|
||||
def _resolveMandateLabel(mandateId):
|
||||
if not mandateId:
|
||||
return ""
|
||||
if mandateId not in _mandateCache:
|
||||
try:
|
||||
mandate = rootInterface.getMandate(mandateId)
|
||||
_mandateCache[mandateId] = getattr(mandate, "label", None) or mandateId[:8]
|
||||
except Exception:
|
||||
_mandateCache[mandateId] = mandateId[:8]
|
||||
return _mandateCache[mandateId]
|
||||
|
||||
def _resolveFeatureLabel(featureInstanceId):
|
||||
if not featureInstanceId:
|
||||
return ""
|
||||
if featureInstanceId not in _featureCache:
|
||||
try:
|
||||
instance = rootInterface.getFeatureInstance(featureInstanceId)
|
||||
_featureCache[featureInstanceId] = getattr(instance, "label", None) or getattr(instance, "featureCode", None) or featureInstanceId[:8]
|
||||
except Exception:
|
||||
_featureCache[featureInstanceId] = featureInstanceId[:8]
|
||||
return _featureCache[featureInstanceId]
|
||||
|
||||
# Enrich each job
|
||||
for job in jobs:
|
||||
automation = automationLookup.get(job["automationId"])
|
||||
if automation:
|
||||
if isinstance(automation, dict):
|
||||
job["name"] = automation.get("label", "")
|
||||
job["createdBy"] = _resolveUsername(automation.get("_createdBy", ""))
|
||||
job["mandate"] = _resolveMandateLabel(automation.get("mandateId", ""))
|
||||
job["featureInstance"] = _resolveFeatureLabel(automation.get("featureInstanceId", ""))
|
||||
else:
|
||||
job["name"] = getattr(automation, "label", "")
|
||||
job["createdBy"] = _resolveUsername(getattr(automation, "_createdBy", ""))
|
||||
job["mandate"] = _resolveMandateLabel(getattr(automation, "mandateId", ""))
|
||||
job["featureInstance"] = _resolveFeatureLabel(getattr(automation, "featureInstanceId", ""))
|
||||
else:
|
||||
job["name"] = "(orphaned)"
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not enrich automation events with context: {e}")
|
||||
|
||||
return jobs
|
||||
paginationDict = json.loads(pagination)
|
||||
if paginationDict:
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
paginationParams = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")
|
||||
|
||||
enriched = _buildEnrichedAutomationEvents(currentUser)
|
||||
filtered = _applyFiltersAndSort(enriched, paginationParams)
|
||||
|
||||
if paginationParams:
|
||||
totalItems = len(filtered)
|
||||
totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0
|
||||
startIdx = (paginationParams.page - 1) * paginationParams.pageSize
|
||||
endIdx = startIdx + paginationParams.pageSize
|
||||
return {
|
||||
"items": filtered[startIdx:endIdx],
|
||||
"pagination": PaginationMetadata(
|
||||
currentPage=paginationParams.page,
|
||||
pageSize=paginationParams.pageSize,
|
||||
totalItems=totalItems,
|
||||
totalPages=totalPages,
|
||||
sort=paginationParams.sort,
|
||||
filters=paginationParams.filters,
|
||||
).model_dump(),
|
||||
}
|
||||
|
||||
return {"items": enriched, "pagination": None}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting automation events: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error getting automation events: {str(e)}"
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=f"Error getting automation events: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def get_automation_event_filter_values(
|
||||
request: Request,
|
||||
column: str = Query(..., description="Column key"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
|
||||
currentUser: User = Depends(requireSysAdminRole),
|
||||
):
|
||||
"""Return distinct filter values for a column in automation events."""
|
||||
try:
|
||||
crossFilterParams: Optional[PaginationParams] = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
if paginationDict:
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
filters = paginationDict.get("filters", {})
|
||||
filters.pop(column, None)
|
||||
paginationDict["filters"] = filters
|
||||
paginationDict.pop("sort", None)
|
||||
crossFilterParams = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
enriched = _buildEnrichedAutomationEvents(currentUser)
|
||||
crossFiltered = _applyFiltersAndSort(enriched, crossFilterParams)
|
||||
return _extractDistinctValues(crossFiltered, column)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting filter values: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/sync")
|
||||
@limiter.limit("5/minute")
|
||||
|
|
|
|||
207
modules/routes/routeAdminAutomationLogs.py
Normal file
207
modules/routes/routeAdminAutomationLogs.py
Normal file
|
|
@ -0,0 +1,207 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""
|
||||
Admin automation execution logs routes.
|
||||
SysAdmin-only endpoints for viewing consolidated automation execution history
|
||||
across all mandates and feature instances.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request, Query
|
||||
from typing import List, Dict, Any, Optional
|
||||
import logging
|
||||
import json
|
||||
import math
|
||||
import uuid
|
||||
|
||||
from modules.auth import limiter, requireSysAdminRole
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginationMetadata, normalize_pagination_dict
|
||||
from modules.routes.routeDataUsers import _applyFiltersAndSort, _extractDistinctValues
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/admin/automation-logs",
|
||||
tags=["Admin Automation Logs"],
|
||||
responses={
|
||||
401: {"description": "Unauthorized"},
|
||||
403: {"description": "Forbidden - Sysadmin only"},
|
||||
500: {"description": "Internal server error"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _buildFlattenedExecutionLogs(currentUser: User) -> List[Dict[str, Any]]:
|
||||
"""Flatten executionLogs from all AutomationDefinitions across all mandates.
|
||||
Called from a SysAdmin-only endpoint — bypasses RBAC, reads directly from DB."""
|
||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||
from modules.features.automation.mainAutomation import getAutomationServices
|
||||
from modules.features.automation.datamodelFeatureAutomation import AutomationDefinition
|
||||
|
||||
rootInterface = getRootInterface()
|
||||
services = getAutomationServices(currentUser, mandateId=None, featureInstanceId=None)
|
||||
allAutomations = services.interfaceDbAutomation.db.getRecordset(AutomationDefinition)
|
||||
|
||||
userCache: Dict[str, str] = {}
|
||||
mandateCache: Dict[str, str] = {}
|
||||
featureCache: Dict[str, str] = {}
|
||||
|
||||
def _resolveUsername(userId: str) -> str:
|
||||
if not userId:
|
||||
return ""
|
||||
if userId not in userCache:
|
||||
try:
|
||||
user = rootInterface.getUser(userId)
|
||||
userCache[userId] = user.username if user else userId[:8]
|
||||
except Exception:
|
||||
userCache[userId] = userId[:8]
|
||||
return userCache[userId]
|
||||
|
||||
def _resolveMandateLabel(mandateId: str) -> str:
|
||||
if not mandateId:
|
||||
return ""
|
||||
if mandateId not in mandateCache:
|
||||
try:
|
||||
mandate = rootInterface.getMandate(mandateId)
|
||||
mandateCache[mandateId] = getattr(mandate, "label", None) or mandateId[:8]
|
||||
except Exception:
|
||||
mandateCache[mandateId] = mandateId[:8]
|
||||
return mandateCache[mandateId]
|
||||
|
||||
def _resolveFeatureLabel(featureInstanceId: str) -> str:
|
||||
if not featureInstanceId:
|
||||
return ""
|
||||
if featureInstanceId not in featureCache:
|
||||
try:
|
||||
instance = rootInterface.getFeatureInstance(featureInstanceId)
|
||||
featureCache[featureInstanceId] = (
|
||||
getattr(instance, "label", None)
|
||||
or getattr(instance, "featureCode", None)
|
||||
or featureInstanceId[:8]
|
||||
)
|
||||
except Exception:
|
||||
featureCache[featureInstanceId] = featureInstanceId[:8]
|
||||
return featureCache[featureInstanceId]
|
||||
|
||||
flatLogs: List[Dict[str, Any]] = []
|
||||
|
||||
for automation in allAutomations:
|
||||
if isinstance(automation, dict):
|
||||
automationId = automation.get("id", "")
|
||||
automationLabel = automation.get("label", "")
|
||||
mandateId = automation.get("mandateId", "")
|
||||
featureInstanceId = automation.get("featureInstanceId", "")
|
||||
createdBy = automation.get("_createdBy", "")
|
||||
logs = automation.get("executionLogs") or []
|
||||
else:
|
||||
automationId = getattr(automation, "id", "")
|
||||
automationLabel = getattr(automation, "label", "")
|
||||
mandateId = getattr(automation, "mandateId", "")
|
||||
featureInstanceId = getattr(automation, "featureInstanceId", "")
|
||||
createdBy = getattr(automation, "_createdBy", "")
|
||||
logs = getattr(automation, "executionLogs", None) or []
|
||||
|
||||
mandateName = _resolveMandateLabel(mandateId)
|
||||
featureInstanceName = _resolveFeatureLabel(featureInstanceId)
|
||||
executedByName = _resolveUsername(createdBy)
|
||||
|
||||
for log in logs:
|
||||
timestamp = log.get("timestamp", 0) if isinstance(log, dict) else 0
|
||||
status = log.get("status", "") if isinstance(log, dict) else ""
|
||||
workflowId = log.get("workflowId", "") if isinstance(log, dict) else ""
|
||||
messages = log.get("messages", []) if isinstance(log, dict) else []
|
||||
|
||||
flatLogs.append({
|
||||
"id": str(uuid.uuid4()),
|
||||
"timestamp": timestamp,
|
||||
"automationId": automationId,
|
||||
"automationLabel": automationLabel,
|
||||
"mandateName": mandateName,
|
||||
"featureInstanceName": featureInstanceName,
|
||||
"executedBy": executedByName,
|
||||
"status": status,
|
||||
"workflowId": workflowId,
|
||||
"messages": "; ".join(messages) if messages else "",
|
||||
})
|
||||
|
||||
flatLogs.sort(key=lambda x: x.get("timestamp", 0), reverse=True)
|
||||
return flatLogs
|
||||
|
||||
|
||||
@router.get("")
|
||||
@limiter.limit("30/minute")
|
||||
def get_all_automation_logs(
|
||||
request: Request,
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams"),
|
||||
currentUser: User = Depends(requireSysAdminRole),
|
||||
):
|
||||
"""Get consolidated execution logs from all automations (sysadmin only)."""
|
||||
try:
|
||||
paginationParams: Optional[PaginationParams] = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
if paginationDict:
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
paginationParams = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")
|
||||
|
||||
logs = _buildFlattenedExecutionLogs(currentUser)
|
||||
filtered = _applyFiltersAndSort(logs, paginationParams)
|
||||
|
||||
if paginationParams:
|
||||
totalItems = len(filtered)
|
||||
totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0
|
||||
startIdx = (paginationParams.page - 1) * paginationParams.pageSize
|
||||
endIdx = startIdx + paginationParams.pageSize
|
||||
return {
|
||||
"items": filtered[startIdx:endIdx],
|
||||
"pagination": PaginationMetadata(
|
||||
currentPage=paginationParams.page,
|
||||
pageSize=paginationParams.pageSize,
|
||||
totalItems=totalItems,
|
||||
totalPages=totalPages,
|
||||
sort=paginationParams.sort,
|
||||
filters=paginationParams.filters,
|
||||
).model_dump(),
|
||||
}
|
||||
|
||||
return {"items": logs, "pagination": None}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting automation logs: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Error getting automation logs: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def get_automation_log_filter_values(
|
||||
request: Request,
|
||||
column: str = Query(..., description="Column key"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
|
||||
currentUser: User = Depends(requireSysAdminRole),
|
||||
):
|
||||
"""Return distinct filter values for a column in automation logs."""
|
||||
try:
|
||||
crossFilterParams: Optional[PaginationParams] = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
if paginationDict:
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
filters = paginationDict.get("filters", {})
|
||||
filters.pop(column, None)
|
||||
paginationDict["filters"] = filters
|
||||
paginationDict.pop("sort", None)
|
||||
crossFilterParams = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
logs = _buildFlattenedExecutionLogs(currentUser)
|
||||
crossFiltered = _applyFiltersAndSort(logs, crossFilterParams)
|
||||
return _extractDistinctValues(crossFiltered, column)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting filter values: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
|
@ -14,7 +14,11 @@ from fastapi import APIRouter, HTTPException, Depends, Request, Query
|
|||
from typing import List, Dict, Any, Optional
|
||||
from fastapi import status
|
||||
import logging
|
||||
import json
|
||||
import math
|
||||
from pydantic import BaseModel, Field
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginationMetadata, normalize_pagination_dict
|
||||
from modules.routes.routeDataUsers import _applyFiltersAndSort, _extractDistinctValues
|
||||
|
||||
from modules.auth import limiter, getRequestContext, RequestContext, requireSysAdminRole
|
||||
from modules.datamodels.datamodelUam import User, UserInDB
|
||||
|
|
@ -433,6 +437,35 @@ def list_feature_instances(
|
|||
)
|
||||
|
||||
|
||||
@router.get("/instances/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def get_feature_instance_filter_values(
|
||||
request: Request,
|
||||
column: str = Query(..., description="Column key"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
|
||||
featureCode: Optional[str] = Query(None, description="Filter by feature code"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> list:
|
||||
"""Return distinct filter values for a column in feature instances."""
|
||||
if not context.mandateId:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="X-Mandate-Id header is required")
|
||||
try:
|
||||
from modules.routes.routeDataUsers import _handleFilterValuesRequest
|
||||
rootInterface = getRootInterface()
|
||||
featureInterface = getFeatureInterface(rootInterface.db)
|
||||
instances = featureInterface.getFeatureInstancesForMandate(
|
||||
mandateId=str(context.mandateId),
|
||||
featureCode=featureCode
|
||||
)
|
||||
items = [inst.model_dump() for inst in instances]
|
||||
return _handleFilterValuesRequest(items, column, pagination)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting filter values for feature instances: {e}")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/instances/{instanceId}", response_model=Dict[str, Any])
|
||||
@limiter.limit("60/minute")
|
||||
def get_feature_instance(
|
||||
|
|
@ -518,15 +551,40 @@ def create_feature_instance(
|
|||
detail=f"Feature '{data.featureCode}' not found"
|
||||
)
|
||||
|
||||
# Subscription capacity check
|
||||
mandateIdStr = str(context.mandateId)
|
||||
try:
|
||||
from modules.interfaces.interfaceDbSubscription import getInterface as _getSubIf
|
||||
from modules.security.rootAccess import getRootUser
|
||||
_subIf = _getSubIf(getRootUser(), mandateIdStr)
|
||||
_subIf.assertCapacity(mandateIdStr, "featureInstances", delta=1)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as capErr:
|
||||
if "SubscriptionCapacityException" in type(capErr).__name__:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(capErr),
|
||||
)
|
||||
|
||||
instance = featureInterface.createFeatureInstance(
|
||||
featureCode=data.featureCode,
|
||||
mandateId=str(context.mandateId),
|
||||
mandateId=mandateIdStr,
|
||||
label=data.label,
|
||||
enabled=data.enabled,
|
||||
copyTemplateRoles=data.copyTemplateRoles,
|
||||
config=data.config
|
||||
)
|
||||
|
||||
|
||||
# Sync Stripe quantity after successful creation
|
||||
try:
|
||||
from modules.interfaces.interfaceDbSubscription import getInterface as _getSubIf2
|
||||
from modules.security.rootAccess import getRootUser as _getRU
|
||||
_subIf2 = _getSubIf2(_getRU(), mandateIdStr)
|
||||
_subIf2.syncQuantityToStripe(mandateIdStr)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info(
|
||||
f"User {context.user.id} created feature instance '{data.label}' "
|
||||
f"for feature '{data.featureCode}' in mandate {context.mandateId}"
|
||||
|
|
@ -758,27 +816,55 @@ def sync_instance_roles(
|
|||
# Template Role Endpoints (SysAdmin only)
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/templates/roles", response_model=List[Dict[str, Any]])
|
||||
def _buildTemplateRolesList(featureCode: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""Build the full template roles list."""
|
||||
rootInterface = getRootInterface()
|
||||
featureInterface = getFeatureInterface(rootInterface.db)
|
||||
roles = featureInterface.getTemplateRoles(featureCode)
|
||||
return [r.model_dump() for r in roles]
|
||||
|
||||
|
||||
@router.get("/templates/roles")
|
||||
@limiter.limit("60/minute")
|
||||
def list_template_roles(
|
||||
request: Request,
|
||||
featureCode: Optional[str] = Query(None, description="Filter by feature code"),
|
||||
sysAdmin: User = Depends(requireSysAdminRole)
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
List global template roles.
|
||||
|
||||
SysAdmin only - returns template roles that are copied to new feature instances.
|
||||
|
||||
Args:
|
||||
featureCode: Optional filter by feature code
|
||||
"""
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams"),
|
||||
sysAdmin: User = Depends(requireSysAdminRole),
|
||||
):
|
||||
"""List global template roles with pagination support."""
|
||||
try:
|
||||
rootInterface = getRootInterface()
|
||||
featureInterface = getFeatureInterface(rootInterface.db)
|
||||
|
||||
roles = featureInterface.getTemplateRoles(featureCode)
|
||||
return [r.model_dump() for r in roles]
|
||||
paginationParams: Optional[PaginationParams] = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
if paginationDict:
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
paginationParams = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")
|
||||
|
||||
enriched = _buildTemplateRolesList(featureCode)
|
||||
filtered = _applyFiltersAndSort(enriched, paginationParams)
|
||||
|
||||
if paginationParams:
|
||||
totalItems = len(filtered)
|
||||
totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0
|
||||
startIdx = (paginationParams.page - 1) * paginationParams.pageSize
|
||||
endIdx = startIdx + paginationParams.pageSize
|
||||
return {
|
||||
"items": filtered[startIdx:endIdx],
|
||||
"pagination": PaginationMetadata(
|
||||
currentPage=paginationParams.page,
|
||||
pageSize=paginationParams.pageSize,
|
||||
totalItems=totalItems,
|
||||
totalPages=totalPages,
|
||||
sort=paginationParams.sort,
|
||||
filters=paginationParams.filters,
|
||||
).model_dump(),
|
||||
}
|
||||
|
||||
return {"items": enriched, "pagination": None}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing template roles: {e}")
|
||||
|
|
@ -788,6 +874,39 @@ def list_template_roles(
|
|||
)
|
||||
|
||||
|
||||
@router.get("/templates/roles/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def get_template_role_filter_values(
|
||||
request: Request,
|
||||
column: str = Query(..., description="Column key"),
|
||||
featureCode: Optional[str] = Query(None, description="Filter by feature code"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
|
||||
sysAdmin: User = Depends(requireSysAdminRole),
|
||||
):
|
||||
"""Return distinct filter values for a column in template roles."""
|
||||
try:
|
||||
crossFilterParams: Optional[PaginationParams] = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
if paginationDict:
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
filters = paginationDict.get("filters", {})
|
||||
filters.pop(column, None)
|
||||
paginationDict["filters"] = filters
|
||||
paginationDict.pop("sort", None)
|
||||
crossFilterParams = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
enriched = _buildTemplateRolesList(featureCode)
|
||||
crossFiltered = _applyFiltersAndSort(enriched, crossFilterParams)
|
||||
return _extractDistinctValues(crossFiltered, column)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting filter values: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/templates/roles", response_model=Dict[str, Any])
|
||||
@limiter.limit("10/minute")
|
||||
def create_template_role(
|
||||
|
|
@ -951,6 +1070,56 @@ def list_feature_instance_users(
|
|||
)
|
||||
|
||||
|
||||
@router.get("/instances/{instanceId}/users/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def get_feature_instance_users_filter_values(
|
||||
request: Request,
|
||||
instanceId: str,
|
||||
column: str = Query(..., description="Column key"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> list:
|
||||
"""Return distinct filter values for a column in feature instance users."""
|
||||
try:
|
||||
from modules.routes.routeDataUsers import _handleFilterValuesRequest
|
||||
rootInterface = getRootInterface()
|
||||
featureInterface = getFeatureInterface(rootInterface.db)
|
||||
instance = featureInterface.getFeatureInstance(instanceId)
|
||||
if not instance:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Feature instance '{instanceId}' not found")
|
||||
if context.mandateId and str(instance.mandateId) != str(context.mandateId):
|
||||
if not context.hasSysAdminRole:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied to this feature instance")
|
||||
featureAccesses = rootInterface.getFeatureAccessesByInstance(instanceId)
|
||||
result = []
|
||||
for fa in featureAccesses:
|
||||
user = rootInterface.getUser(str(fa.userId))
|
||||
if not user:
|
||||
continue
|
||||
roleIds = rootInterface.getRoleIdsForFeatureAccess(str(fa.id))
|
||||
roleLabels = []
|
||||
for roleId in roleIds:
|
||||
role = rootInterface.getRole(roleId)
|
||||
if role:
|
||||
roleLabels.append(role.roleLabel)
|
||||
result.append({
|
||||
"id": str(fa.id),
|
||||
"userId": str(fa.userId),
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
"fullName": user.fullName,
|
||||
"roleIds": roleIds,
|
||||
"roleLabels": roleLabels,
|
||||
"enabled": fa.enabled
|
||||
})
|
||||
return _handleFilterValuesRequest(result, column, pagination)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting filter values for feature instance users: {e}")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/instances/{instanceId}/users", response_model=Dict[str, Any])
|
||||
@limiter.limit("30/minute")
|
||||
def add_user_to_feature_instance(
|
||||
|
|
@ -1239,7 +1408,7 @@ def update_feature_instance_user_roles(
|
|||
"userId": userId,
|
||||
"featureInstanceId": instanceId,
|
||||
"roleIds": data.roleIds,
|
||||
"enabled": data.enabled if data.enabled is not None else existingAccess[0].get("enabled", True)
|
||||
"enabled": data.enabled if data.enabled is not None else bool(existingAccess.enabled),
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
|
|
|
|||
|
|
@ -394,7 +394,10 @@ def get_access_rules(
|
|||
)
|
||||
|
||||
# Get rules with optional pagination
|
||||
# MandateAdmin: fetch all then filter by admin's mandates
|
||||
# MandateAdmin: fetch all then filter by admin's mandates.
|
||||
# NOTE: Cannot use DB-level pagination for MandateAdmin because
|
||||
# _isRoleInAdminMandates requires joining Role → mandateId which
|
||||
# isn't expressible via getRecordsetPaginated's recordFilter.
|
||||
if not isSysAdmin:
|
||||
allRules = interface.getAccessRules(
|
||||
roleLabel=roleLabel,
|
||||
|
|
@ -814,6 +817,11 @@ def list_roles(
|
|||
By default, only returns true global roles (mandateId=None, featureInstanceId=None, featureCode=None).
|
||||
Feature template roles are managed via /api/features/templates/roles.
|
||||
|
||||
NOTE: Base query (getAllRoles) already uses db.getRecordsetPaginated() internally.
|
||||
However pagination=None is passed here because post-processing adds computed fields
|
||||
(userCount, scopeType) and applies scope/mandate/template filtering that cannot run
|
||||
at the DB level. In-memory pagination is applied after all transformations.
|
||||
|
||||
Args:
|
||||
pagination: Optional pagination parameters (includes search, filters, sort)
|
||||
includeTemplates: If True, also include feature template roles (featureCode != None)
|
||||
|
|
@ -983,6 +991,77 @@ def list_roles(
|
|||
)
|
||||
|
||||
|
||||
@router.get("/roles/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def get_roles_filter_values(
|
||||
request: Request,
|
||||
column: str = Query(..., description="Column key"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
|
||||
includeTemplates: bool = Query(False, description="Include feature template roles"),
|
||||
mandateId: Optional[str] = Query(None, description="Include mandate-specific roles for this mandate"),
|
||||
scopeFilter: Optional[str] = Query(None, description="Filter by scope: 'all', 'mandate', 'global', 'system'"),
|
||||
reqContext: RequestContext = Depends(getRequestContext)
|
||||
) -> list:
|
||||
"""Return distinct filter values for a column in roles."""
|
||||
try:
|
||||
from modules.routes.routeDataUsers import _handleFilterValuesRequest
|
||||
isSysAdmin = reqContext.hasSysAdminRole
|
||||
adminMandateIds = [] if isSysAdmin else _getAdminMandateIds(reqContext)
|
||||
if not isSysAdmin and not adminMandateIds:
|
||||
raise HTTPException(status_code=403, detail="Admin role required")
|
||||
|
||||
interface = getRootInterface()
|
||||
dbRoles = interface.getAllRoles(pagination=None)
|
||||
roleCounts = interface.countRoleAssignments()
|
||||
|
||||
def _computeScopeType(role) -> str:
|
||||
if role.mandateId:
|
||||
return "mandate"
|
||||
if role.isSystemRole:
|
||||
return "system"
|
||||
return "global"
|
||||
|
||||
result = []
|
||||
for role in dbRoles:
|
||||
if role.featureInstanceId is not None:
|
||||
continue
|
||||
if mandateId:
|
||||
if role.mandateId != mandateId:
|
||||
continue
|
||||
else:
|
||||
if role.mandateId is not None:
|
||||
continue
|
||||
if not includeTemplates and role.featureCode is not None:
|
||||
continue
|
||||
scopeType = _computeScopeType(role)
|
||||
if scopeFilter and scopeFilter != 'all':
|
||||
if scopeFilter == 'mandate' and scopeType != 'mandate':
|
||||
continue
|
||||
if scopeFilter == 'global' and scopeType not in ('global', 'system'):
|
||||
continue
|
||||
if scopeFilter == 'system' and scopeType != 'system':
|
||||
continue
|
||||
result.append({
|
||||
"id": role.id,
|
||||
"roleLabel": role.roleLabel,
|
||||
"description": role.description,
|
||||
"mandateId": role.mandateId,
|
||||
"featureInstanceId": role.featureInstanceId,
|
||||
"featureCode": role.featureCode,
|
||||
"userCount": roleCounts.get(str(role.id), 0),
|
||||
"isSystemRole": role.isSystemRole,
|
||||
"scopeType": scopeType
|
||||
})
|
||||
if not isSysAdmin:
|
||||
result = [r for r in result if r.get("mandateId") and str(r["mandateId"]) in adminMandateIds]
|
||||
return _handleFilterValuesRequest(result, column, pagination)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting filter values for roles: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/roles", response_model=Dict[str, Any])
|
||||
@limiter.limit("30/minute")
|
||||
def create_role(
|
||||
|
|
|
|||
|
|
@ -278,7 +278,8 @@ def getUserAccessOverview(
|
|||
# Get mandate name
|
||||
mandate = interface.getMandate(umMandateId)
|
||||
mandateName = mandate.name if mandate else umMandateId
|
||||
|
||||
mandateLabel = (mandate.label or None) if mandate else None
|
||||
|
||||
# Get roles for this UserMandate using interface method
|
||||
umRoles = interface.getUserMandateRoles(umId)
|
||||
|
||||
|
|
@ -368,6 +369,7 @@ def getUserAccessOverview(
|
|||
mandatesInfo.append({
|
||||
"id": umMandateId,
|
||||
"name": mandateName,
|
||||
"label": mandateLabel,
|
||||
"roleIds": mandateRoleIds,
|
||||
"featureInstances": featureInstancesInfo,
|
||||
})
|
||||
|
|
|
|||
|
|
@ -22,8 +22,10 @@ from modules.auth import limiter, requireSysAdminRole, getRequestContext, Reques
|
|||
# Import billing components
|
||||
from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface, _getRootInterface
|
||||
from modules.serviceCenter.services.serviceBilling.mainServiceBilling import getService as getBillingService
|
||||
import json
|
||||
import math
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
|
||||
from modules.routes.routeDataUsers import _applyFiltersAndSort
|
||||
from modules.routes.routeDataUsers import _applyFiltersAndSort, _extractDistinctValues, _handleFilterValuesRequest
|
||||
from modules.datamodels.datamodelBilling import (
|
||||
BillingAccount,
|
||||
BillingTransaction,
|
||||
|
|
@ -226,9 +228,9 @@ def _filterTransactionsByScope(transactions: list, scope: BillingDataScope) -> l
|
|||
# =============================================================================
|
||||
|
||||
class CreditAddRequest(BaseModel):
|
||||
"""Request model for adding credit to an account."""
|
||||
"""Request model for adding or deducting credit from an account."""
|
||||
userId: Optional[str] = Field(None, description="Target user ID (for PREPAY_USER model)")
|
||||
amount: float = Field(..., gt=0, description="Amount to credit in CHF")
|
||||
amount: float = Field(..., description="Amount in CHF. Positive = credit, negative = deduction. Must not be zero.")
|
||||
description: str = Field(default="Manual credit", description="Transaction description")
|
||||
|
||||
|
||||
|
|
@ -358,19 +360,8 @@ class UserTransactionResponse(BaseModel):
|
|||
|
||||
def _getStripeClient():
|
||||
"""Initialize and return configured Stripe SDK module."""
|
||||
import stripe
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
api_version = APP_CONFIG.get("STRIPE_API_VERSION")
|
||||
if api_version:
|
||||
stripe.api_version = api_version
|
||||
|
||||
secret_key = APP_CONFIG.get("STRIPE_SECRET_KEY_SECRET") or APP_CONFIG.get("STRIPE_SECRET_KEY")
|
||||
if not secret_key:
|
||||
raise ValueError("STRIPE_SECRET_KEY_SECRET not configured")
|
||||
|
||||
stripe.api_key = secret_key
|
||||
return stripe
|
||||
from modules.shared.stripeClient import getStripeClient
|
||||
return getStripeClient()
|
||||
|
||||
|
||||
def _creditStripeSessionIfNeeded(
|
||||
|
|
@ -835,20 +826,27 @@ def addCredit(
|
|||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Cannot add credit to {billingModel.value} billing model")
|
||||
|
||||
# Create credit transaction
|
||||
if creditRequest.amount == 0:
|
||||
raise HTTPException(status_code=400, detail="Amount must not be zero")
|
||||
|
||||
from modules.datamodels.datamodelBilling import BillingTransaction
|
||||
|
||||
|
||||
isDeduction = creditRequest.amount < 0
|
||||
txType = TransactionTypeEnum.DEBIT if isDeduction else TransactionTypeEnum.CREDIT
|
||||
absAmount = abs(creditRequest.amount)
|
||||
|
||||
transaction = BillingTransaction(
|
||||
accountId=account["id"],
|
||||
transactionType=TransactionTypeEnum.CREDIT,
|
||||
amount=creditRequest.amount,
|
||||
transactionType=txType,
|
||||
amount=absAmount,
|
||||
description=creditRequest.description,
|
||||
referenceType=ReferenceTypeEnum.ADMIN
|
||||
)
|
||||
|
||||
result = billingInterface.createTransaction(transaction)
|
||||
|
||||
logger.info(f"Added {creditRequest.amount} CHF credit to account {account['id']} in mandate {targetMandateId}")
|
||||
action = "Deducted" if isDeduction else "Added"
|
||||
logger.info(f"{action} {absAmount} CHF to account {account['id']} in mandate {targetMandateId}")
|
||||
|
||||
return result
|
||||
|
||||
|
|
@ -1006,13 +1004,33 @@ async def stripeWebhook(
|
|||
|
||||
logger.info(f"Stripe webhook received: event={event.id}, type={event.type}")
|
||||
|
||||
accepted_event_types = {"checkout.session.completed", "checkout.session.async_payment_succeeded"}
|
||||
if event.type not in accepted_event_types:
|
||||
# Subscription-related events
|
||||
subscriptionEventTypes = {
|
||||
"customer.subscription.updated",
|
||||
"customer.subscription.deleted",
|
||||
"invoice.paid",
|
||||
"invoice.payment_failed",
|
||||
"customer.subscription.trial_will_end",
|
||||
}
|
||||
|
||||
# Checkout events (existing)
|
||||
checkoutEventTypes = {"checkout.session.completed", "checkout.session.async_payment_succeeded"}
|
||||
|
||||
if event.type in subscriptionEventTypes:
|
||||
_handleSubscriptionWebhook(event)
|
||||
return {"received": True}
|
||||
|
||||
|
||||
if event.type not in checkoutEventTypes:
|
||||
return {"received": True}
|
||||
|
||||
session = event.data.object
|
||||
event_id = event.id
|
||||
|
||||
sessionMode = session.get("mode") if hasattr(session, "get") else getattr(session, "mode", None)
|
||||
if sessionMode == "subscription":
|
||||
_handleSubscriptionCheckoutCompleted(session, event_id)
|
||||
return {"received": True}
|
||||
|
||||
billingInterface = _getRootInterface()
|
||||
if billingInterface.getStripeWebhookEventByEventId(event_id):
|
||||
logger.info(f"Stripe event {event_id} already processed, skipping")
|
||||
|
|
@ -1027,6 +1045,257 @@ async def stripeWebhook(
|
|||
return {"received": True}
|
||||
|
||||
|
||||
def _handleSubscriptionCheckoutCompleted(session, eventId: str) -> None:
|
||||
"""Handle checkout.session.completed for mode=subscription.
|
||||
Resolves the local PENDING record by ID from webhook metadata and transitions it."""
|
||||
from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
|
||||
from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum, _getPlan
|
||||
from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
|
||||
getService as getSubscriptionService,
|
||||
_notifySubscriptionChange,
|
||||
)
|
||||
from modules.security.rootAccess import getRootUser
|
||||
from datetime import datetime, timezone
|
||||
|
||||
metadata = {}
|
||||
if hasattr(session, "get"):
|
||||
metadata = session.get("metadata") or {}
|
||||
subscriptionRecordId = metadata.get("subscriptionRecordId")
|
||||
mandateId = metadata.get("mandateId")
|
||||
planKey = metadata.get("planKey", "")
|
||||
|
||||
platformUrl = metadata.get("platformUrl", "")
|
||||
|
||||
if not subscriptionRecordId:
|
||||
stripeSub = session.get("subscription")
|
||||
if stripeSub:
|
||||
try:
|
||||
from modules.shared.stripeClient import getStripeClient
|
||||
stripe = getStripeClient()
|
||||
subObj = stripe.Subscription.retrieve(stripeSub)
|
||||
metadata = subObj.get("metadata") or {}
|
||||
subscriptionRecordId = metadata.get("subscriptionRecordId")
|
||||
mandateId = metadata.get("mandateId")
|
||||
planKey = metadata.get("planKey", "")
|
||||
platformUrl = platformUrl or metadata.get("platformUrl", "")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
stripeSubId = session.get("subscription")
|
||||
|
||||
if not mandateId or not subscriptionRecordId:
|
||||
logger.warning("Subscription checkout missing metadata: %s", metadata)
|
||||
return
|
||||
|
||||
subInterface = getSubRootInterface()
|
||||
rootUser = getRootUser()
|
||||
|
||||
sub = subInterface.getById(subscriptionRecordId)
|
||||
if not sub:
|
||||
logger.error("Subscription record %s not found for checkout webhook", subscriptionRecordId)
|
||||
return
|
||||
if sub.get("status") != SubscriptionStatusEnum.PENDING.value:
|
||||
logger.warning("Subscription %s is %s, expected PENDING — skipping", subscriptionRecordId, sub.get("status"))
|
||||
return
|
||||
|
||||
stripeData: Dict[str, Any] = {}
|
||||
if stripeSubId:
|
||||
stripeData["stripeSubscriptionId"] = stripeSubId
|
||||
try:
|
||||
from modules.shared.stripeClient import getStripeClient
|
||||
stripe = getStripeClient()
|
||||
stripeSub = stripe.Subscription.retrieve(stripeSubId, expand=["items"])
|
||||
|
||||
if stripeSub.get("current_period_start"):
|
||||
stripeData["currentPeriodStart"] = datetime.fromtimestamp(
|
||||
stripeSub["current_period_start"], tz=timezone.utc
|
||||
).isoformat()
|
||||
if stripeSub.get("current_period_end"):
|
||||
stripeData["currentPeriodEnd"] = datetime.fromtimestamp(
|
||||
stripeSub["current_period_end"], tz=timezone.utc
|
||||
).isoformat()
|
||||
|
||||
from modules.serviceCenter.services.serviceSubscription.stripeBootstrap import getStripePricesForPlan
|
||||
priceMapping = getStripePricesForPlan(planKey)
|
||||
for item in stripeSub.get("items", {}).get("data", []):
|
||||
priceId = item.get("price", {}).get("id", "")
|
||||
if priceMapping and priceId == priceMapping.stripePriceIdUsers:
|
||||
stripeData["stripeItemIdUsers"] = item["id"]
|
||||
elif priceMapping and priceId == priceMapping.stripePriceIdInstances:
|
||||
stripeData["stripeItemIdInstances"] = item["id"]
|
||||
except Exception as e:
|
||||
logger.error("Error retrieving Stripe subscription %s: %s", stripeSubId, e)
|
||||
|
||||
if stripeData:
|
||||
subInterface.updateFields(subscriptionRecordId, stripeData)
|
||||
|
||||
operative = subInterface.getOperativeForMandate(mandateId)
|
||||
hasActivePredecessor = operative is not None and operative["id"] != subscriptionRecordId
|
||||
|
||||
if hasActivePredecessor:
|
||||
toStatus = SubscriptionStatusEnum.SCHEDULED
|
||||
if operative.get("recurring", True):
|
||||
operativeStripeId = operative.get("stripeSubscriptionId")
|
||||
if operativeStripeId:
|
||||
try:
|
||||
from modules.shared.stripeClient import getStripeClient
|
||||
stripe = getStripeClient()
|
||||
stripe.Subscription.modify(operativeStripeId, cancel_at_period_end=True)
|
||||
except Exception as e:
|
||||
logger.error("Failed to set cancel_at_period_end on predecessor %s: %s", operativeStripeId, e)
|
||||
subInterface.updateFields(operative["id"], {"recurring": False})
|
||||
effectiveFrom = operative.get("currentPeriodEnd")
|
||||
if effectiveFrom:
|
||||
subInterface.updateFields(subscriptionRecordId, {"effectiveFrom": effectiveFrom})
|
||||
else:
|
||||
toStatus = SubscriptionStatusEnum.ACTIVE
|
||||
|
||||
try:
|
||||
subInterface.transitionStatus(
|
||||
subscriptionRecordId, SubscriptionStatusEnum.PENDING, toStatus,
|
||||
{"recurring": True},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to transition subscription %s: %s", subscriptionRecordId, e)
|
||||
return
|
||||
|
||||
subService = getSubscriptionService(rootUser, mandateId)
|
||||
subService.invalidateCache(mandateId)
|
||||
|
||||
if toStatus == SubscriptionStatusEnum.ACTIVE:
|
||||
plan = _getPlan(planKey)
|
||||
updatedSub = subInterface.getById(subscriptionRecordId)
|
||||
_notifySubscriptionChange(mandateId, "activated", plan, subscriptionRecord=updatedSub, platformUrl=platformUrl)
|
||||
|
||||
logger.info(
|
||||
"Checkout completed: sub=%s -> %s, mandate=%s, plan=%s",
|
||||
subscriptionRecordId, toStatus.value, mandateId, planKey,
|
||||
)
|
||||
|
||||
|
||||
def _handleSubscriptionWebhook(event) -> None:
|
||||
"""Process Stripe subscription webhook events.
|
||||
All record resolution is by stripeSubscriptionId — no mandate-based guessing."""
|
||||
from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
|
||||
from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum, _getPlan
|
||||
from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
|
||||
getService as getSubscriptionService,
|
||||
_notifySubscriptionChange,
|
||||
)
|
||||
from modules.security.rootAccess import getRootUser
|
||||
from datetime import datetime, timezone
|
||||
|
||||
obj = event.data.object
|
||||
stripeSubId = obj.get("id") if event.type.startswith("customer.subscription") else obj.get("subscription")
|
||||
if not stripeSubId:
|
||||
logger.warning("Subscription webhook %s has no subscription ID", event.type)
|
||||
return
|
||||
|
||||
subInterface = getSubRootInterface()
|
||||
sub = subInterface.getByStripeSubscriptionId(stripeSubId)
|
||||
if not sub:
|
||||
logger.warning("No local record for Stripe subscription %s (event: %s)", stripeSubId, event.type)
|
||||
return
|
||||
|
||||
subId = sub["id"]
|
||||
mandateId = sub["mandateId"]
|
||||
currentStatus = SubscriptionStatusEnum(sub["status"])
|
||||
rootUser = getRootUser()
|
||||
subService = getSubscriptionService(rootUser, mandateId)
|
||||
|
||||
subMetadata = obj.get("metadata") or {}
|
||||
webhookPlatformUrl = subMetadata.get("platformUrl", "")
|
||||
|
||||
if event.type == "customer.subscription.updated":
|
||||
stripeStatus = obj.get("status", "")
|
||||
|
||||
periodData: Dict[str, Any] = {}
|
||||
if obj.get("current_period_start"):
|
||||
periodData["currentPeriodStart"] = datetime.fromtimestamp(
|
||||
obj["current_period_start"], tz=timezone.utc
|
||||
).isoformat()
|
||||
if obj.get("current_period_end"):
|
||||
periodData["currentPeriodEnd"] = datetime.fromtimestamp(
|
||||
obj["current_period_end"], tz=timezone.utc
|
||||
).isoformat()
|
||||
if periodData:
|
||||
subInterface.updateFields(subId, periodData)
|
||||
|
||||
if stripeStatus == "active" and currentStatus == SubscriptionStatusEnum.SCHEDULED:
|
||||
subInterface.transitionStatus(subId, SubscriptionStatusEnum.SCHEDULED, SubscriptionStatusEnum.ACTIVE)
|
||||
subService.invalidateCache(mandateId)
|
||||
plan = _getPlan(sub.get("planKey", ""))
|
||||
refreshedSub = subInterface.getById(subId)
|
||||
_notifySubscriptionChange(mandateId, "activated", plan, subscriptionRecord=refreshedSub, platformUrl=webhookPlatformUrl)
|
||||
logger.info("SCHEDULED -> ACTIVE for sub %s (mandate %s)", subId, mandateId)
|
||||
|
||||
elif stripeStatus == "active" and currentStatus == SubscriptionStatusEnum.PAST_DUE:
|
||||
subInterface.transitionStatus(subId, SubscriptionStatusEnum.PAST_DUE, SubscriptionStatusEnum.ACTIVE)
|
||||
subService.invalidateCache(mandateId)
|
||||
logger.info("PAST_DUE -> ACTIVE for sub %s (mandate %s)", subId, mandateId)
|
||||
|
||||
elif stripeStatus == "past_due" and currentStatus == SubscriptionStatusEnum.ACTIVE:
|
||||
subInterface.transitionStatus(subId, SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.PAST_DUE)
|
||||
subService.invalidateCache(mandateId)
|
||||
logger.info("ACTIVE -> PAST_DUE for sub %s (mandate %s)", subId, mandateId)
|
||||
|
||||
elif stripeStatus == "active" and currentStatus == SubscriptionStatusEnum.ACTIVE:
|
||||
subService.invalidateCache(mandateId)
|
||||
logger.info("Period renewed for sub %s (mandate %s)", subId, mandateId)
|
||||
|
||||
elif event.type == "customer.subscription.deleted":
|
||||
if currentStatus not in (SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.PAST_DUE,
|
||||
SubscriptionStatusEnum.SCHEDULED):
|
||||
logger.info("Ignoring deletion for sub %s in status %s", subId, currentStatus.value)
|
||||
return
|
||||
|
||||
subInterface.transitionStatus(subId, currentStatus, SubscriptionStatusEnum.EXPIRED)
|
||||
subService.invalidateCache(mandateId)
|
||||
logger.info("Sub %s -> EXPIRED (Stripe deleted, mandate %s)", subId, mandateId)
|
||||
|
||||
scheduled = subInterface.getScheduledForMandate(mandateId)
|
||||
if scheduled:
|
||||
try:
|
||||
subInterface.transitionStatus(
|
||||
scheduled["id"], SubscriptionStatusEnum.SCHEDULED, SubscriptionStatusEnum.ACTIVE,
|
||||
)
|
||||
subService.invalidateCache(mandateId)
|
||||
plan = _getPlan(scheduled.get("planKey", ""))
|
||||
refreshedScheduled = subInterface.getById(scheduled["id"])
|
||||
_notifySubscriptionChange(mandateId, "activated", plan, subscriptionRecord=refreshedScheduled, platformUrl=webhookPlatformUrl)
|
||||
logger.info("Promoted SCHEDULED sub %s -> ACTIVE (mandate %s)", scheduled["id"], mandateId)
|
||||
except Exception as e:
|
||||
logger.error("Failed to promote SCHEDULED sub %s: %s", scheduled["id"], e)
|
||||
|
||||
elif event.type == "invoice.payment_failed":
|
||||
if currentStatus == SubscriptionStatusEnum.ACTIVE:
|
||||
subInterface.transitionStatus(subId, SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.PAST_DUE)
|
||||
subService.invalidateCache(mandateId)
|
||||
plan = _getPlan(sub.get("planKey", ""))
|
||||
_notifySubscriptionChange(mandateId, "payment_failed", plan, subscriptionRecord=sub, platformUrl=webhookPlatformUrl)
|
||||
logger.info("Payment failed for sub %s (mandate %s)", subId, mandateId)
|
||||
|
||||
elif event.type == "customer.subscription.trial_will_end":
|
||||
logger.info("Trial ending soon for sub %s (mandate %s)", subId, mandateId)
|
||||
try:
|
||||
from modules.shared.notifyMandateAdmins import notifyMandateAdmins
|
||||
notifyMandateAdmins(
|
||||
mandateId,
|
||||
"[PowerOn] Testphase endet bald",
|
||||
"Testphase endet bald",
|
||||
[
|
||||
"Die kostenlose Testphase für Ihren Mandanten endet in Kürze.",
|
||||
"Bitte wählen Sie einen Plan unter Billing-Verwaltung › Abonnement.",
|
||||
],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to notify about trial ending: %s", e)
|
||||
|
||||
elif event.type == "invoice.paid":
|
||||
logger.info("Invoice paid for sub %s (mandate %s)", subId, mandateId)
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/admin/accounts/{targetMandateId}", response_model=List[AccountSummary])
|
||||
@limiter.limit("30/minute")
|
||||
def getAccounts(
|
||||
|
|
@ -1135,49 +1404,187 @@ def getUsersForMandate(
|
|||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/admin/transactions/{targetMandateId}", response_model=List[TransactionResponse])
|
||||
def _enrichTransactionRows(transactions) -> List[Dict[str, Any]]:
|
||||
"""Convert raw transaction dicts to enriched TransactionResponse rows with resolved usernames."""
|
||||
result = []
|
||||
for t in transactions:
|
||||
row = TransactionResponse(
|
||||
id=t.get("id"),
|
||||
accountId=t.get("accountId"),
|
||||
transactionType=TransactionTypeEnum(t.get("transactionType", "DEBIT")),
|
||||
amount=t.get("amount", 0.0),
|
||||
description=t.get("description", ""),
|
||||
referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None,
|
||||
workflowId=t.get("workflowId"),
|
||||
featureCode=t.get("featureCode"),
|
||||
featureInstanceId=t.get("featureInstanceId"),
|
||||
aicoreProvider=t.get("aicoreProvider"),
|
||||
aicoreModel=t.get("aicoreModel"),
|
||||
createdByUserId=t.get("createdByUserId"),
|
||||
createdAt=t.get("_createdAt")
|
||||
)
|
||||
result.append(row.model_dump())
|
||||
|
||||
try:
|
||||
from modules.interfaces.interfaceDbUam import _getRootInterface as getUamRoot
|
||||
uamInterface = getUamRoot()
|
||||
userNames: Dict[str, str] = {}
|
||||
for row in result:
|
||||
uid = row.get("createdByUserId")
|
||||
if uid and uid not in userNames:
|
||||
try:
|
||||
user = uamInterface.getUser(uid)
|
||||
userNames[uid] = user.get("username", uid[:8]) if user else uid[:8]
|
||||
except Exception:
|
||||
userNames[uid] = uid[:8]
|
||||
row["userName"] = userNames.get(uid, "") if uid else ""
|
||||
except Exception:
|
||||
for row in result:
|
||||
row["userName"] = row.get("createdByUserId", "")[:8] if row.get("createdByUserId") else ""
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _buildTransactionsList(ctx: RequestContext, targetMandateId: str) -> List[Dict[str, Any]]:
|
||||
"""Build the full enriched transactions list for a mandate."""
|
||||
billingInterface = getBillingInterface(ctx.user, targetMandateId)
|
||||
transactions = billingInterface.getTransactionsByMandate(targetMandateId, limit=5000)
|
||||
|
||||
result = []
|
||||
for t in transactions:
|
||||
row = TransactionResponse(
|
||||
id=t.get("id"),
|
||||
accountId=t.get("accountId"),
|
||||
transactionType=TransactionTypeEnum(t.get("transactionType", "DEBIT")),
|
||||
amount=t.get("amount", 0.0),
|
||||
description=t.get("description", ""),
|
||||
referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None,
|
||||
workflowId=t.get("workflowId"),
|
||||
featureCode=t.get("featureCode"),
|
||||
featureInstanceId=t.get("featureInstanceId"),
|
||||
aicoreProvider=t.get("aicoreProvider"),
|
||||
aicoreModel=t.get("aicoreModel"),
|
||||
createdByUserId=t.get("createdByUserId"),
|
||||
createdAt=t.get("_createdAt")
|
||||
)
|
||||
result.append(row.model_dump())
|
||||
|
||||
# Resolve user names
|
||||
try:
|
||||
from modules.interfaces.interfaceDbUam import _getRootInterface as getUamRoot
|
||||
uamInterface = getUamRoot()
|
||||
userNames: Dict[str, str] = {}
|
||||
for row in result:
|
||||
uid = row.get("createdByUserId")
|
||||
if uid and uid not in userNames:
|
||||
try:
|
||||
user = uamInterface.getUser(uid)
|
||||
userNames[uid] = user.get("username", uid[:8]) if user else uid[:8]
|
||||
except Exception:
|
||||
userNames[uid] = uid[:8]
|
||||
row["userName"] = userNames.get(uid, "") if uid else ""
|
||||
except Exception:
|
||||
for row in result:
|
||||
row["userName"] = row.get("createdByUserId", "")[:8] if row.get("createdByUserId") else ""
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/admin/transactions/{targetMandateId}")
|
||||
@limiter.limit("30/minute")
|
||||
def getTransactionsAdmin(
|
||||
request: Request,
|
||||
targetMandateId: str = Path(..., description="Mandate ID"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams"),
|
||||
limit: int = Query(default=100, ge=1, le=1000),
|
||||
ctx: RequestContext = Depends(getRequestContext),
|
||||
):
|
||||
"""
|
||||
Get all transactions for a mandate.
|
||||
Access: SysAdmin (any mandate) or MandateAdmin (own mandate).
|
||||
"""
|
||||
"""Get all transactions for a mandate with pagination support."""
|
||||
if not _isAdminOfMandate(ctx, targetMandateId):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin role required for this mandate")
|
||||
try:
|
||||
billingInterface = getBillingInterface(ctx.user, targetMandateId)
|
||||
transactions = billingInterface.getTransactionsByMandate(targetMandateId, limit=limit)
|
||||
|
||||
result = []
|
||||
for t in transactions:
|
||||
result.append(TransactionResponse(
|
||||
id=t.get("id"),
|
||||
accountId=t.get("accountId"),
|
||||
transactionType=TransactionTypeEnum(t.get("transactionType", "DEBIT")),
|
||||
amount=t.get("amount", 0.0),
|
||||
description=t.get("description", ""),
|
||||
referenceType=ReferenceTypeEnum(t["referenceType"]) if t.get("referenceType") else None,
|
||||
workflowId=t.get("workflowId"),
|
||||
featureCode=t.get("featureCode"),
|
||||
featureInstanceId=t.get("featureInstanceId"),
|
||||
aicoreProvider=t.get("aicoreProvider"),
|
||||
aicoreModel=t.get("aicoreModel"),
|
||||
createdByUserId=t.get("createdByUserId"),
|
||||
createdAt=t.get("_createdAt")
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
paginationParams: Optional[PaginationParams] = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
if paginationDict:
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
paginationParams = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")
|
||||
|
||||
if paginationParams:
|
||||
# DB-level pagination — enrich only the returned page
|
||||
billingInterface = getBillingInterface(ctx.user, targetMandateId)
|
||||
result = billingInterface.getTransactionsByMandate(targetMandateId, pagination=paginationParams)
|
||||
transactions = result.items if hasattr(result, 'items') else result
|
||||
enrichedItems = _enrichTransactionRows(transactions)
|
||||
return {
|
||||
"items": enrichedItems,
|
||||
"pagination": PaginationMetadata(
|
||||
currentPage=paginationParams.page,
|
||||
pageSize=paginationParams.pageSize,
|
||||
totalItems=result.totalItems if hasattr(result, 'totalItems') else len(enrichedItems),
|
||||
totalPages=result.totalPages if hasattr(result, 'totalPages') else 0,
|
||||
sort=paginationParams.sort,
|
||||
filters=paginationParams.filters,
|
||||
).model_dump(),
|
||||
}
|
||||
|
||||
enriched = _buildTransactionsList(ctx, targetMandateId)
|
||||
return {"items": enriched, "pagination": None}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting billing transactions for mandate {targetMandateId}: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/admin/transactions/{targetMandateId}/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def getTransactionFilterValues(
|
||||
request: Request,
|
||||
targetMandateId: str = Path(..., description="Mandate ID"),
|
||||
column: str = Query(..., description="Column key"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
|
||||
ctx: RequestContext = Depends(getRequestContext),
|
||||
):
|
||||
"""Return distinct filter values for a column in mandate transactions."""
|
||||
if not _isAdminOfMandate(ctx, targetMandateId):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin role required for this mandate")
|
||||
try:
|
||||
crossFilterParams: Optional[PaginationParams] = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
if paginationDict:
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
filters = paginationDict.get("filters", {})
|
||||
filters.pop(column, None)
|
||||
paginationDict["filters"] = filters
|
||||
paginationDict.pop("sort", None)
|
||||
crossFilterParams = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
# Try SQL DISTINCT for native DB columns; fallback to in-memory for enriched columns (e.g. userName)
|
||||
try:
|
||||
rootBillingInterface = _getRootInterface()
|
||||
recordFilter = {"mandateId": targetMandateId}
|
||||
values = rootBillingInterface.db.getDistinctColumnValues(
|
||||
BillingTransaction, column, crossFilterParams, recordFilter
|
||||
)
|
||||
return sorted(values, key=lambda v: str(v).lower())
|
||||
except Exception:
|
||||
enriched = _buildTransactionsList(ctx, targetMandateId)
|
||||
crossFiltered = _applyFiltersAndSort(enriched, crossFilterParams)
|
||||
return _extractDistinctValues(crossFiltered, column)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting filter values for transactions: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Mandate View Endpoints (for Admins)
|
||||
# =============================================================================
|
||||
|
|
@ -1625,3 +2032,50 @@ def getUserViewTransactions(
|
|||
except Exception as e:
|
||||
logger.error(f"Error getting user view transactions: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/view/users/transactions/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def getUserViewTransactionsFilterValues(
|
||||
request: Request,
|
||||
column: str = Query(..., description="Column key"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
|
||||
ctx: RequestContext = Depends(getRequestContext)
|
||||
):
|
||||
"""Return distinct filter values for a column in user transactions."""
|
||||
try:
|
||||
billingInterface = getBillingInterface(ctx.user, ctx.mandateId)
|
||||
scope = _getBillingDataScope(ctx.user)
|
||||
if scope.isGlobalAdmin:
|
||||
mandateIds = None
|
||||
else:
|
||||
mandateIds = scope.adminMandateIds + scope.memberMandateIds
|
||||
if not mandateIds:
|
||||
return []
|
||||
allTransactions = billingInterface.getUserTransactionsForMandates(mandateIds, limit=10000)
|
||||
allTransactions = _filterTransactionsByScope(allTransactions, scope)
|
||||
transactionDicts = []
|
||||
for t in allTransactions:
|
||||
transactionDicts.append({
|
||||
"id": t.get("id"),
|
||||
"accountId": t.get("accountId"),
|
||||
"transactionType": t.get("transactionType", "DEBIT"),
|
||||
"amount": t.get("amount", 0.0),
|
||||
"description": t.get("description", ""),
|
||||
"referenceType": t.get("referenceType"),
|
||||
"workflowId": t.get("workflowId"),
|
||||
"featureCode": t.get("featureCode"),
|
||||
"featureInstanceId": t.get("featureInstanceId"),
|
||||
"aicoreProvider": t.get("aicoreProvider"),
|
||||
"aicoreModel": t.get("aicoreModel"),
|
||||
"createdByUserId": t.get("createdByUserId"),
|
||||
"createdAt": t.get("_createdAt"),
|
||||
"mandateId": t.get("mandateId"),
|
||||
"mandateName": t.get("mandateName"),
|
||||
"userId": t.get("userId"),
|
||||
"userName": t.get("userName"),
|
||||
})
|
||||
return _handleFilterValuesRequest(transactionDicts, column, pagination)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting filter values for user transactions: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
|
|
|||
|
|
@ -147,6 +147,11 @@ async def get_connections(
|
|||
try:
|
||||
interface = getInterface(currentUser)
|
||||
|
||||
# NOTE: Cannot use db.getRecordsetPaginated() here because each connection
|
||||
# is enriched with computed tokenStatus/tokenExpiresAt (requires per-row DB lookup).
|
||||
# Token refresh also may trigger re-fetch. Connections per user are typically < 10,
|
||||
# so in-memory pagination is acceptable.
|
||||
|
||||
# Parse pagination parameter
|
||||
paginationParams = None
|
||||
if pagination:
|
||||
|
|
@ -287,6 +292,42 @@ async def get_connections(
|
|||
detail=f"Failed to get connections: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def get_connection_filter_values(
|
||||
request: Request,
|
||||
column: str = Query(..., description="Column key"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> List[str]:
|
||||
"""Return distinct filter values for a column in connections."""
|
||||
try:
|
||||
from modules.routes.routeDataUsers import _handleFilterValuesRequest
|
||||
interface = getInterface(currentUser)
|
||||
connections = interface.getUserConnections(currentUser.id)
|
||||
items = []
|
||||
for connection in connections:
|
||||
tokenStatus, tokenExpiresAt = getTokenStatusForConnection(interface, connection.id)
|
||||
items.append({
|
||||
"id": connection.id,
|
||||
"userId": connection.userId,
|
||||
"authority": connection.authority.value if hasattr(connection.authority, 'value') else str(connection.authority),
|
||||
"externalId": connection.externalId,
|
||||
"externalUsername": connection.externalUsername or "",
|
||||
"externalEmail": connection.externalEmail,
|
||||
"status": connection.status.value if hasattr(connection.status, 'value') else str(connection.status),
|
||||
"connectedAt": connection.connectedAt,
|
||||
"lastChecked": connection.lastChecked,
|
||||
"expiresAt": connection.expiresAt,
|
||||
"tokenStatus": tokenStatus,
|
||||
"tokenExpiresAt": tokenExpiresAt
|
||||
})
|
||||
return _handleFilterValuesRequest(items, column, pagination)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting filter values for connections: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/", response_model=UserConnection)
|
||||
@limiter.limit("10/minute")
|
||||
def create_connection(
|
||||
|
|
|
|||
|
|
@ -37,6 +37,17 @@ async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user):
|
|||
mgmtInterface.updateFile(fileId, {"status": "active"})
|
||||
return
|
||||
|
||||
file_meta = mgmtInterface.getFile(fileId)
|
||||
feature_instance_id = ""
|
||||
mandate_id = ""
|
||||
if file_meta:
|
||||
if isinstance(file_meta, dict):
|
||||
feature_instance_id = file_meta.get("featureInstanceId") or ""
|
||||
mandate_id = file_meta.get("mandateId") or ""
|
||||
else:
|
||||
feature_instance_id = getattr(file_meta, "featureInstanceId", None) or ""
|
||||
mandate_id = getattr(file_meta, "mandateId", None) or ""
|
||||
|
||||
logger.info(f"Auto-index starting for {fileName} ({len(rawBytes)} bytes, {mimeType})")
|
||||
|
||||
# Step 1: Structure Pre-Scan (AI-free)
|
||||
|
|
@ -47,6 +58,8 @@ async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user):
|
|||
fileId=fileId,
|
||||
fileName=fileName,
|
||||
userId=userId,
|
||||
featureInstanceId=str(feature_instance_id) if feature_instance_id else "",
|
||||
mandateId=str(mandate_id) if mandate_id else "",
|
||||
)
|
||||
logger.info(
|
||||
f"Pre-scan complete for {fileName}: "
|
||||
|
|
@ -105,7 +118,11 @@ async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user):
|
|||
from modules.serviceCenter import getService
|
||||
from modules.serviceCenter.context import ServiceCenterContext
|
||||
|
||||
ctx = ServiceCenterContext(user=user, mandate_id="", feature_instance_id="")
|
||||
ctx = ServiceCenterContext(
|
||||
user=user,
|
||||
mandate_id=str(mandate_id) if mandate_id else "",
|
||||
feature_instance_id=str(feature_instance_id) if feature_instance_id else "",
|
||||
)
|
||||
knowledgeService = getService("knowledge", ctx)
|
||||
|
||||
await knowledgeService.indexFile(
|
||||
|
|
@ -113,6 +130,8 @@ async def _autoIndexFile(fileId: str, fileName: str, mimeType: str, user):
|
|||
fileName=fileName,
|
||||
mimeType=mimeType,
|
||||
userId=userId,
|
||||
featureInstanceId=str(feature_instance_id) if feature_instance_id else "",
|
||||
mandateId=str(mandate_id) if mandate_id else "",
|
||||
contentObjects=contentObjects,
|
||||
structure=contentIndex.structure,
|
||||
)
|
||||
|
|
@ -214,6 +233,56 @@ def get_files(
|
|||
detail=f"Failed to get files: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/list/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def get_file_filter_values(
|
||||
request: Request,
|
||||
column: str = Query(..., description="Column key"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
|
||||
currentUser: User = Depends(getCurrentUser),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> list:
|
||||
"""Return distinct filter values for a column in files."""
|
||||
try:
|
||||
managementInterface = interfaceDbManagement.getInterface(
|
||||
currentUser,
|
||||
mandateId=str(context.mandateId) if context.mandateId else None,
|
||||
featureInstanceId=str(context.featureInstanceId) if context.featureInstanceId else None
|
||||
)
|
||||
|
||||
crossFilterPagination = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
if paginationDict:
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
filters = paginationDict.get("filters", {})
|
||||
filters.pop(column, None)
|
||||
paginationDict["filters"] = filters
|
||||
paginationDict.pop("sort", None)
|
||||
crossFilterPagination = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
try:
|
||||
recordFilter = {"_createdBy": managementInterface.userId}
|
||||
values = managementInterface.db.getDistinctColumnValues(
|
||||
FileItem, column, crossFilterPagination, recordFilter
|
||||
)
|
||||
return sorted(values, key=lambda v: str(v).lower())
|
||||
except Exception:
|
||||
from modules.routes.routeDataUsers import _handleFilterValuesRequest
|
||||
result = managementInterface.getAllFiles(pagination=None)
|
||||
items = [r.model_dump() if hasattr(r, 'model_dump') else r for r in result]
|
||||
return _handleFilterValuesRequest(items, column, pagination)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting filter values for files: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e)
|
||||
)
|
||||
|
||||
|
||||
@router.post("/upload", status_code=status.HTTP_201_CREATED)
|
||||
@limiter.limit("10/minute")
|
||||
async def upload_file(
|
||||
|
|
|
|||
|
|
@ -164,6 +164,66 @@ def get_mandates(
|
|||
detail=f"Failed to get mandates: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def get_mandate_filter_values(
|
||||
request: Request,
|
||||
column: str = Query(..., description="Column key"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> list:
|
||||
"""Return distinct filter values for a column in mandates."""
|
||||
try:
|
||||
from modules.routes.routeDataUsers import _handleFilterValuesRequest
|
||||
isSysAdmin = context.hasSysAdminRole
|
||||
if not isSysAdmin:
|
||||
adminMandateIds = _getAdminMandateIds(context)
|
||||
if not adminMandateIds:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin role required")
|
||||
|
||||
appInterface = interfaceDbApp.getRootInterface()
|
||||
|
||||
if isSysAdmin:
|
||||
# SysAdmin: try SQL DISTINCT for DB columns
|
||||
crossFilterPagination = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
if paginationDict:
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
filters = paginationDict.get("filters", {})
|
||||
filters.pop(column, None)
|
||||
paginationDict["filters"] = filters
|
||||
paginationDict.pop("sort", None)
|
||||
crossFilterPagination = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
try:
|
||||
values = appInterface.db.getDistinctColumnValues(
|
||||
Mandate, column, crossFilterPagination
|
||||
)
|
||||
return sorted(values, key=lambda v: str(v).lower())
|
||||
except Exception:
|
||||
result = appInterface.getAllMandates(pagination=None)
|
||||
items = result if isinstance(result, list) else (result.items if hasattr(result, 'items') else result)
|
||||
items = [i.model_dump() if hasattr(i, 'model_dump') else i for i in items]
|
||||
return _handleFilterValuesRequest(items, column, pagination)
|
||||
else:
|
||||
# MandateAdmin: in-memory (small set of individual mandate lookups)
|
||||
result = []
|
||||
for mid in adminMandateIds:
|
||||
mandate = appInterface.getMandate(mid)
|
||||
if mandate:
|
||||
result.append(mandate if isinstance(mandate, dict) else mandate.model_dump() if hasattr(mandate, 'model_dump') else vars(mandate))
|
||||
items = [i.model_dump() if hasattr(i, 'model_dump') else i for i in result]
|
||||
return _handleFilterValuesRequest(items, column, pagination)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting filter values for mandates: {str(e)}")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{targetMandateId}", response_model=Mandate)
|
||||
@limiter.limit("30/minute")
|
||||
def get_mandate(
|
||||
|
|
@ -540,6 +600,63 @@ def list_mandate_users(
|
|||
)
|
||||
|
||||
|
||||
@router.get("/{targetMandateId}/users/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def get_mandate_users_filter_values(
|
||||
request: Request,
|
||||
targetMandateId: str = Path(..., description="ID of the mandate"),
|
||||
column: str = Query(..., description="Column key"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> list:
|
||||
"""Return distinct filter values for a column in mandate users."""
|
||||
if not _hasMandateAdminRole(context, targetMandateId) and not context.hasSysAdminRole:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Mandate-Admin role required")
|
||||
|
||||
try:
|
||||
from modules.routes.routeDataUsers import _handleFilterValuesRequest
|
||||
rootInterface = interfaceDbApp.getRootInterface()
|
||||
mandate = rootInterface.getMandate(targetMandateId)
|
||||
if not mandate:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Mandate {targetMandateId} not found")
|
||||
|
||||
userMandates = rootInterface.getUserMandatesByMandate(targetMandateId)
|
||||
result = []
|
||||
for um in userMandates:
|
||||
user = rootInterface.getUser(str(um.userId))
|
||||
if not user:
|
||||
continue
|
||||
roleIds = rootInterface.getRoleIdsForUserMandate(str(um.id))
|
||||
roleLabels = []
|
||||
filteredRoleIds = []
|
||||
seenLabels = set()
|
||||
for roleId in roleIds:
|
||||
role = rootInterface.getRole(roleId)
|
||||
if role:
|
||||
if role.featureInstanceId:
|
||||
continue
|
||||
filteredRoleIds.append(roleId)
|
||||
if role.roleLabel not in seenLabels:
|
||||
roleLabels.append(role.roleLabel)
|
||||
seenLabels.add(role.roleLabel)
|
||||
result.append({
|
||||
"id": str(um.id),
|
||||
"userId": str(user.id),
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
"fullName": user.fullName,
|
||||
"roleIds": filteredRoleIds,
|
||||
"roleLabels": roleLabels,
|
||||
"enabled": um.enabled
|
||||
})
|
||||
return _handleFilterValuesRequest(result, column, pagination)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting filter values for mandate users: {str(e)}")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/{targetMandateId}/users", response_model=UserMandateResponse)
|
||||
@limiter.limit("30/minute")
|
||||
def add_user_to_mandate(
|
||||
|
|
|
|||
|
|
@ -81,6 +81,29 @@ def get_prompts(
|
|||
pagination=None
|
||||
)
|
||||
|
||||
@router.get("/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def get_prompt_filter_values(
|
||||
request: Request,
|
||||
column: str = Query(..., description="Column key"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
|
||||
currentUser: User = Depends(getCurrentUser)
|
||||
) -> list:
|
||||
"""Return distinct filter values for a column in prompts.
|
||||
|
||||
NOTE: Cannot use db.getDistinctColumnValues() because visibility rules
|
||||
(own + system for regular users) require pre-filtering the recordset.
|
||||
"""
|
||||
try:
|
||||
from modules.routes.routeDataUsers import _handleFilterValuesRequest
|
||||
managementInterface = interfaceDbManagement.getInterface(currentUser)
|
||||
result = managementInterface.getAllPrompts(pagination=None)
|
||||
items = [r.model_dump() if hasattr(r, 'model_dump') else r for r in result]
|
||||
return _handleFilterValuesRequest(items, column, pagination)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("", response_model=Prompt)
|
||||
@limiter.limit("10/minute")
|
||||
def create_prompt(
|
||||
|
|
|
|||
|
|
@ -69,6 +69,55 @@ def _isAdminForUser(context: RequestContext, targetUserId: str) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def _extractDistinctValues(items: List[Dict[str, Any]], columnKey: str) -> List[str]:
|
||||
"""Extract sorted distinct display values for a column from enriched items."""
|
||||
values = set()
|
||||
for item in items:
|
||||
val = item.get(columnKey)
|
||||
if val is None or val == "":
|
||||
continue
|
||||
if isinstance(val, bool):
|
||||
values.add("true" if val else "false")
|
||||
elif isinstance(val, (int, float)):
|
||||
values.add(str(val))
|
||||
elif isinstance(val, dict):
|
||||
text = val.get("en") or next((v for v in val.values() if isinstance(v, str) and v), None)
|
||||
if text:
|
||||
values.add(str(text))
|
||||
else:
|
||||
values.add(str(val))
|
||||
return sorted(values, key=lambda v: v.lower())
|
||||
|
||||
|
||||
def _handleFilterValuesRequest(
|
||||
items: List[Dict[str, Any]],
|
||||
column: str,
|
||||
paginationJson: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generic handler for /filter-values endpoints.
|
||||
Applies all active filters EXCEPT the one for the requested column (cross-filtering),
|
||||
then extracts distinct values for that column.
|
||||
"""
|
||||
crossFilterParams: Optional[PaginationParams] = None
|
||||
if paginationJson:
|
||||
try:
|
||||
import json
|
||||
paginationDict = json.loads(paginationJson)
|
||||
if paginationDict:
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
filters = paginationDict.get("filters", {})
|
||||
filters.pop(column, None)
|
||||
paginationDict["filters"] = filters
|
||||
paginationDict.pop("sort", None)
|
||||
crossFilterParams = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
crossFiltered = _applyFiltersAndSort(items, crossFilterParams)
|
||||
return _extractDistinctValues(crossFiltered, column)
|
||||
|
||||
|
||||
def _applyFiltersAndSort(items: List[Dict[str, Any]], paginationParams: Optional[PaginationParams]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Apply filters and sorting to a list of items.
|
||||
|
|
@ -121,7 +170,6 @@ def _applyFiltersAndSort(items: List[Dict[str, Any]], paginationParams: Optional
|
|||
if itemValue is None:
|
||||
return False
|
||||
|
||||
# Convert to string for comparison if needed
|
||||
itemStr = str(itemValue).lower()
|
||||
valueStr = str(v).lower()
|
||||
|
||||
|
|
@ -147,6 +195,42 @@ def _applyFiltersAndSort(items: List[Dict[str, Any]], paginationParams: Optional
|
|||
return itemNum <= valueNum
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
elif op == 'between':
|
||||
if isinstance(v, dict):
|
||||
fromVal = v.get('from', '')
|
||||
toVal = v.get('to', '')
|
||||
if not fromVal and not toVal:
|
||||
return True
|
||||
# Date range: from/to are YYYY-MM-DD strings, itemValue may be Unix timestamp
|
||||
try:
|
||||
from datetime import datetime, timezone
|
||||
fromTs = None
|
||||
toTs = None
|
||||
if fromVal:
|
||||
fromTs = datetime.strptime(str(fromVal), '%Y-%m-%d').replace(tzinfo=timezone.utc).timestamp()
|
||||
if toVal:
|
||||
toTs = datetime.strptime(str(toVal), '%Y-%m-%d').replace(hour=23, minute=59, second=59, tzinfo=timezone.utc).timestamp()
|
||||
itemNum = float(itemValue) if not isinstance(itemValue, (int, float)) else itemValue
|
||||
# Normalize: if item looks like a millisecond timestamp, convert to seconds
|
||||
if itemNum > 10000000000:
|
||||
itemNum = itemNum / 1000
|
||||
if fromTs is not None and toTs is not None:
|
||||
return fromTs <= itemNum <= toTs
|
||||
elif fromTs is not None:
|
||||
return itemNum >= fromTs
|
||||
elif toTs is not None:
|
||||
return itemNum <= toTs
|
||||
except (ValueError, TypeError):
|
||||
# Fallback: string comparison (for non-numeric date fields)
|
||||
fromStr = str(fromVal).lower() if fromVal else ''
|
||||
toStr = str(toVal).lower() if toVal else ''
|
||||
if fromStr and toStr:
|
||||
return fromStr <= itemStr <= toStr
|
||||
elif fromStr:
|
||||
return itemStr >= fromStr
|
||||
elif toStr:
|
||||
return itemStr <= toStr
|
||||
return True
|
||||
elif op == 'in':
|
||||
if isinstance(v, list):
|
||||
return itemStr in [str(x).lower() for x in v]
|
||||
|
|
@ -159,23 +243,25 @@ def _applyFiltersAndSort(items: List[Dict[str, Any]], paginationParams: Optional
|
|||
|
||||
result = [item for item in result if matchesFilter(item, field, operator, value)]
|
||||
|
||||
# Apply sorting
|
||||
# Apply sorting — None values always last
|
||||
if paginationParams.sort:
|
||||
for sortField in reversed(paginationParams.sort):
|
||||
fieldName = sortField.field
|
||||
ascending = sortField.direction == 'asc'
|
||||
|
||||
def getSortKey(item: Dict[str, Any]):
|
||||
value = item.get(fieldName)
|
||||
if value is None:
|
||||
return (1, '') # Nulls last
|
||||
if isinstance(value, bool):
|
||||
return (0, not value if ascending else value)
|
||||
if isinstance(value, (int, float)):
|
||||
return (0, value)
|
||||
return (0, str(value).lower())
|
||||
noneItems = [item for item in result if item.get(fieldName) is None]
|
||||
nonNoneItems = [item for item in result if item.get(fieldName) is not None]
|
||||
|
||||
result = sorted(result, key=getSortKey, reverse=not ascending)
|
||||
def getSortKey(item: Dict[str, Any], _fn=fieldName):
|
||||
value = item.get(_fn)
|
||||
if isinstance(value, bool):
|
||||
return (0, int(value), '')
|
||||
if isinstance(value, (int, float)):
|
||||
return (0, value, '')
|
||||
return (1, 0, str(value).lower())
|
||||
|
||||
nonNoneItems = sorted(nonNoneItems, key=getSortKey, reverse=not ascending)
|
||||
result = nonNoneItems + noneItems
|
||||
|
||||
return result
|
||||
|
||||
|
|
@ -291,38 +377,23 @@ def get_users(
|
|||
pagination=None
|
||||
)
|
||||
elif context.hasSysAdminRole:
|
||||
# SysAdmin without mandateId sees all users
|
||||
# Get all users via interface method (returns Pydantic User models)
|
||||
allUserModels = appInterface.getAllUsers()
|
||||
# Convert to dictionaries for filtering/sorting
|
||||
cleanedUsers = [u.model_dump() for u in allUserModels]
|
||||
# SysAdmin without mandateId — DB-level pagination via interface
|
||||
result = appInterface.getAllUsers(paginationParams)
|
||||
|
||||
# Apply server-side filtering and sorting
|
||||
filteredUsers = _applyFiltersAndSort(cleanedUsers, paginationParams)
|
||||
|
||||
# Convert to User objects
|
||||
users = [User(**u) for u in filteredUsers]
|
||||
|
||||
if paginationParams:
|
||||
import math
|
||||
totalItems = len(users)
|
||||
totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0
|
||||
startIdx = (paginationParams.page - 1) * paginationParams.pageSize
|
||||
endIdx = startIdx + paginationParams.pageSize
|
||||
paginatedUsers = users[startIdx:endIdx]
|
||||
|
||||
if paginationParams and hasattr(result, 'items'):
|
||||
return PaginatedResponse(
|
||||
items=paginatedUsers,
|
||||
items=result.items,
|
||||
pagination=PaginationMetadata(
|
||||
currentPage=paginationParams.page,
|
||||
pageSize=paginationParams.pageSize,
|
||||
totalItems=totalItems,
|
||||
totalPages=totalPages,
|
||||
totalItems=result.totalItems,
|
||||
totalPages=result.totalPages,
|
||||
sort=paginationParams.sort,
|
||||
filters=paginationParams.filters
|
||||
)
|
||||
)
|
||||
else:
|
||||
users = result if isinstance(result, list) else (result.items if hasattr(result, 'items') else [])
|
||||
return PaginatedResponse(
|
||||
items=users,
|
||||
pagination=None
|
||||
|
|
@ -407,6 +478,88 @@ def get_users(
|
|||
detail=f"Failed to get users: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def get_user_filter_values(
|
||||
request: Request,
|
||||
column: str = Query(..., description="Column key"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> list:
|
||||
"""Return distinct filter values for a column in users."""
|
||||
try:
|
||||
appInterface = interfaceDbApp.getInterface(context.user, mandateId=context.mandateId)
|
||||
|
||||
# Build cross-filter pagination (all filters except the requested column)
|
||||
crossFilterPagination = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
if paginationDict:
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
filters = paginationDict.get("filters", {})
|
||||
filters.pop(column, None)
|
||||
paginationDict["filters"] = filters
|
||||
paginationDict.pop("sort", None)
|
||||
crossFilterPagination = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
if context.mandateId:
|
||||
# Mandate-scoped: in-memory (users require UserMandate join)
|
||||
result = appInterface.getUsersByMandate(str(context.mandateId), None)
|
||||
users = result if isinstance(result, list) else (result.items if hasattr(result, 'items') else [])
|
||||
items = [u.model_dump() if hasattr(u, 'model_dump') else u for u in users]
|
||||
return _handleFilterValuesRequest(items, column, pagination)
|
||||
elif context.hasSysAdminRole:
|
||||
# SysAdmin: use SQL DISTINCT for DB columns
|
||||
try:
|
||||
rootInterface = getRootInterface()
|
||||
values = rootInterface.db.getDistinctColumnValues(
|
||||
UserInDB, column, crossFilterPagination
|
||||
)
|
||||
return sorted(values, key=lambda v: v.lower())
|
||||
except Exception:
|
||||
users = appInterface.getAllUsers()
|
||||
items = [u.model_dump() if hasattr(u, 'model_dump') else u for u in users]
|
||||
return _handleFilterValuesRequest(items, column, pagination)
|
||||
else:
|
||||
# Non-admin multi-mandate: aggregate across admin mandates (in-memory)
|
||||
rootInterface = getRootInterface()
|
||||
userMandates = rootInterface.getUserMandates(str(context.user.id))
|
||||
adminMandateIds = []
|
||||
for um in userMandates:
|
||||
umId = getattr(um, 'id', None)
|
||||
mandateId = getattr(um, 'mandateId', None)
|
||||
if not umId or not mandateId:
|
||||
continue
|
||||
roleIds = rootInterface.getRoleIdsForUserMandate(str(umId))
|
||||
for roleId in roleIds:
|
||||
role = rootInterface.getRole(roleId)
|
||||
if role and role.roleLabel == "admin" and not role.featureInstanceId:
|
||||
adminMandateIds.append(str(mandateId))
|
||||
break
|
||||
if not adminMandateIds:
|
||||
return []
|
||||
seenUserIds = set()
|
||||
users = []
|
||||
for mid in adminMandateIds:
|
||||
mandateUsers = rootInterface.getUsersByMandate(mid)
|
||||
uList = mandateUsers if isinstance(mandateUsers, list) else (mandateUsers.items if hasattr(mandateUsers, 'items') else [])
|
||||
for u in uList:
|
||||
uid = u.get("id") if isinstance(u, dict) else getattr(u, "id", None)
|
||||
if uid and uid not in seenUserIds:
|
||||
seenUserIds.add(uid)
|
||||
users.append(u)
|
||||
items = [u.model_dump() if hasattr(u, 'model_dump') else u for u in users]
|
||||
return _handleFilterValuesRequest(items, column, pagination)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting filter values for users: {str(e)}")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{userId}", response_model=User)
|
||||
@limiter.limit("30/minute")
|
||||
def get_user(
|
||||
|
|
|
|||
|
|
@ -420,6 +420,13 @@ def list_invitations(
|
|||
|
||||
Requires Mandate-Admin role. Returns all invitations created for this mandate.
|
||||
|
||||
NOTE: Cannot use db.getRecordsetPaginated() because:
|
||||
- Computed status fields (isExpired, isUsedUp) are derived in-memory
|
||||
- Filtering by revoked/used/expired requires post-fetch logic
|
||||
- Invitation volume per mandate is typically low (< 100)
|
||||
When this endpoint needs FormGeneratorTable pagination, add PaginatedResponse
|
||||
support with in-memory slicing (similar to routeDataConnections).
|
||||
|
||||
Args:
|
||||
includeUsed: Include invitations that have reached maxUses
|
||||
includeExpired: Include expired invitations
|
||||
|
|
@ -485,6 +492,54 @@ def list_invitations(
|
|||
)
|
||||
|
||||
|
||||
@router.get("/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def get_invitation_filter_values(
|
||||
request: Request,
|
||||
column: str = Query(..., description="Column key"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters"),
|
||||
frontendUrl: str = Query("", description="Frontend URL for building invite links"),
|
||||
includeUsed: bool = Query(False, description="Include already used invitations"),
|
||||
includeExpired: bool = Query(False, description="Include expired invitations"),
|
||||
context: RequestContext = Depends(getRequestContext)
|
||||
) -> list:
|
||||
"""Return distinct filter values for a column in invitations."""
|
||||
if not context.mandateId:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="X-Mandate-Id header is required")
|
||||
if not _hasMandateAdminRole(context):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Mandate-Admin role required")
|
||||
try:
|
||||
from modules.routes.routeDataUsers import _handleFilterValuesRequest
|
||||
rootInterface = getRootInterface()
|
||||
allInvitations = rootInterface.getInvitationsByMandate(str(context.mandateId))
|
||||
currentTime = getUtcTimestamp()
|
||||
result = []
|
||||
for inv in allInvitations:
|
||||
if inv.revokedAt:
|
||||
continue
|
||||
currentUses = inv.currentUses or 0
|
||||
maxUses = inv.maxUses or 1
|
||||
if not includeUsed and currentUses >= maxUses:
|
||||
continue
|
||||
expiresAt = inv.expiresAt or 0
|
||||
if not includeExpired and expiresAt < currentTime:
|
||||
continue
|
||||
baseUrl = frontendUrl.rstrip("/") if frontendUrl else ""
|
||||
inviteUrl = f"{baseUrl}/invite/{inv.token}" if baseUrl else ""
|
||||
result.append({
|
||||
**inv.model_dump(),
|
||||
"inviteUrl": inviteUrl,
|
||||
"isExpired": expiresAt < currentTime,
|
||||
"isUsedUp": currentUses >= maxUses
|
||||
})
|
||||
return _handleFilterValuesRequest(result, column, pagination)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting filter values for invitations: {e}")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{invitationId}", response_model=Dict[str, str])
|
||||
@limiter.limit("30/minute")
|
||||
def revoke_invitation(
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from modules.datamodels.datamodelMessaging import (
|
|||
MessagingSubscriptionExecutionResult
|
||||
)
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
|
||||
|
||||
# Configure logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -48,6 +48,8 @@ def get_subscriptions(
|
|||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
if paginationDict:
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(
|
||||
|
|
@ -185,6 +187,8 @@ def get_subscription_registrations(
|
|||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
if paginationDict:
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(
|
||||
|
|
@ -277,6 +281,8 @@ def get_my_registrations(
|
|||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
if paginationDict:
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(
|
||||
|
|
@ -450,6 +456,8 @@ def get_deliveries(
|
|||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
if paginationDict:
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
paginationParams = PaginationParams(**paginationDict) if paginationDict else None
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(
|
||||
|
|
|
|||
|
|
@ -227,33 +227,22 @@ async def get_projects(
|
|||
context.user, mandateId=mandateId, featureInstanceId=instanceId
|
||||
)
|
||||
recordFilter = {"featureInstanceId": instanceId}
|
||||
items = interface.getProjekte(recordFilter=recordFilter)
|
||||
paginationParams = _parsePagination(pagination)
|
||||
if paginationParams:
|
||||
if paginationParams.sort:
|
||||
for sort_field in reversed(paginationParams.sort):
|
||||
field_name = sort_field.field
|
||||
direction = sort_field.direction.lower()
|
||||
items.sort(
|
||||
key=lambda x: getattr(x, field_name, None),
|
||||
reverse=(direction == "desc")
|
||||
result = interface.getProjekte(pagination=paginationParams, recordFilter=recordFilter)
|
||||
if hasattr(result, 'items'):
|
||||
return PaginatedResponse(
|
||||
items=result.items,
|
||||
pagination=PaginationMetadata(
|
||||
currentPage=paginationParams.page,
|
||||
pageSize=paginationParams.pageSize,
|
||||
totalItems=result.totalItems,
|
||||
totalPages=result.totalPages,
|
||||
sort=paginationParams.sort or [],
|
||||
filters=paginationParams.filters
|
||||
)
|
||||
total_items = len(items)
|
||||
total_pages = (total_items + paginationParams.pageSize - 1) // paginationParams.pageSize
|
||||
start_idx = (paginationParams.page - 1) * paginationParams.pageSize
|
||||
end_idx = start_idx + paginationParams.pageSize
|
||||
paginated_items = items[start_idx:end_idx]
|
||||
return PaginatedResponse(
|
||||
items=paginated_items,
|
||||
pagination=PaginationMetadata(
|
||||
currentPage=paginationParams.page,
|
||||
pageSize=paginationParams.pageSize,
|
||||
totalItems=total_items,
|
||||
totalPages=total_pages,
|
||||
sort=paginationParams.sort or [],
|
||||
filters=paginationParams.filters
|
||||
)
|
||||
)
|
||||
items = interface.getProjekte(recordFilter=recordFilter)
|
||||
return PaginatedResponse(items=items, pagination=None)
|
||||
|
||||
|
||||
|
|
@ -359,33 +348,22 @@ async def get_parcels(
|
|||
context.user, mandateId=mandateId, featureInstanceId=instanceId
|
||||
)
|
||||
recordFilter = {"featureInstanceId": instanceId}
|
||||
items = interface.getParzellen(recordFilter=recordFilter)
|
||||
paginationParams = _parsePagination(pagination)
|
||||
if paginationParams:
|
||||
if paginationParams.sort:
|
||||
for sort_field in reversed(paginationParams.sort):
|
||||
field_name = sort_field.field
|
||||
direction = sort_field.direction.lower()
|
||||
items.sort(
|
||||
key=lambda x: getattr(x, field_name, None),
|
||||
reverse=(direction == "desc")
|
||||
result = interface.getParzellen(pagination=paginationParams, recordFilter=recordFilter)
|
||||
if hasattr(result, 'items'):
|
||||
return PaginatedResponse(
|
||||
items=result.items,
|
||||
pagination=PaginationMetadata(
|
||||
currentPage=paginationParams.page,
|
||||
pageSize=paginationParams.pageSize,
|
||||
totalItems=result.totalItems,
|
||||
totalPages=result.totalPages,
|
||||
sort=paginationParams.sort or [],
|
||||
filters=paginationParams.filters
|
||||
)
|
||||
total_items = len(items)
|
||||
total_pages = (total_items + paginationParams.pageSize - 1) // paginationParams.pageSize
|
||||
start_idx = (paginationParams.page - 1) * paginationParams.pageSize
|
||||
end_idx = start_idx + paginationParams.pageSize
|
||||
paginated_items = items[start_idx:end_idx]
|
||||
return PaginatedResponse(
|
||||
items=paginated_items,
|
||||
pagination=PaginationMetadata(
|
||||
currentPage=paginationParams.page,
|
||||
pageSize=paginationParams.pageSize,
|
||||
totalItems=total_items,
|
||||
totalPages=total_pages,
|
||||
sort=paginationParams.sort or [],
|
||||
filters=paginationParams.filters
|
||||
)
|
||||
)
|
||||
items = interface.getParzellen(recordFilter=recordFilter)
|
||||
return PaginatedResponse(items=items, pagination=None)
|
||||
|
||||
|
||||
|
|
|
|||
437
modules/routes/routeSubscription.py
Normal file
437
modules/routes/routeSubscription.py
Normal file
|
|
@ -0,0 +1,437 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""
|
||||
Subscription routes — ID-based, state-machine-driven.
|
||||
|
||||
Endpoints:
|
||||
- GET /api/subscription/plans — list selectable plans
|
||||
- GET /api/subscription/status — operative + scheduled subscription for current mandate
|
||||
- POST /api/subscription/activate — start checkout for a plan
|
||||
- POST /api/subscription/cancel — cancel a specific subscription (by ID)
|
||||
- POST /api/subscription/reactivate — reactivate a cancelled subscription (by ID)
|
||||
- POST /api/subscription/force-cancel — sysadmin immediate cancel (by ID)
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Request, Query
|
||||
from fastapi import status
|
||||
from typing import Dict, Any, List, Optional
|
||||
import logging
|
||||
import json
|
||||
import math
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from modules.auth import limiter, getRequestContext, RequestContext
|
||||
from modules.datamodels.datamodelPagination import PaginationParams, PaginatedResponse, PaginationMetadata, normalize_pagination_dict
|
||||
from modules.routes.routeDataUsers import _applyFiltersAndSort, _extractDistinctValues
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolveMandateId(context: RequestContext) -> str:
|
||||
if context.mandateId:
|
||||
return str(context.mandateId)
|
||||
return ""
|
||||
|
||||
|
||||
def _assertMandateAdmin(context: RequestContext, mandateId: str) -> None:
|
||||
if context.hasSysAdminRole:
|
||||
return
|
||||
try:
|
||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||
rootInterface = getRootInterface()
|
||||
userMandates = rootInterface.getUserMandates(str(context.user.id))
|
||||
for um in userMandates:
|
||||
if str(getattr(um, "mandateId", None)) != str(mandateId):
|
||||
continue
|
||||
if not getattr(um, "enabled", True):
|
||||
continue
|
||||
umId = str(getattr(um, "id", ""))
|
||||
roleIds = rootInterface.getRoleIdsForUserMandate(umId)
|
||||
for roleId in roleIds:
|
||||
role = rootInterface.getRole(roleId)
|
||||
if role and role.roleLabel == "admin" and not role.featureInstanceId:
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Mandate admin role required")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Request / Response models
|
||||
# =============================================================================
|
||||
|
||||
class ActivatePlanRequest(BaseModel):
|
||||
planKey: str = Field(..., description="Key of the plan to activate")
|
||||
returnUrl: str = Field(..., description="Frontend URL to redirect back to after Stripe Checkout")
|
||||
|
||||
class CancelRequest(BaseModel):
|
||||
subscriptionId: str = Field(..., description="ID of the subscription to cancel")
|
||||
|
||||
class ReactivateRequest(BaseModel):
|
||||
subscriptionId: str = Field(..., description="ID of the subscription to reactivate")
|
||||
|
||||
class ForceCancelRequest(BaseModel):
|
||||
subscriptionId: str = Field(..., description="ID of the subscription to force-cancel")
|
||||
|
||||
class VerifyCheckoutRequest(BaseModel):
|
||||
sessionId: str = Field(..., description="Stripe Checkout Session ID to verify")
|
||||
|
||||
class SubscriptionStatusResponse(BaseModel):
|
||||
active: bool
|
||||
subscription: Optional[Dict[str, Any]] = None
|
||||
plan: Optional[Dict[str, Any]] = None
|
||||
scheduled: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Router
|
||||
# =============================================================================
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/subscription",
|
||||
tags=["Subscription"],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Endpoints
|
||||
# =============================================================================
|
||||
|
||||
@router.get("/plans", response_model=List[Dict[str, Any]])
|
||||
@limiter.limit("30/minute")
|
||||
def getPlans(request: Request, context: RequestContext = Depends(getRequestContext)):
|
||||
from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
|
||||
getService as getSubscriptionService,
|
||||
)
|
||||
try:
|
||||
mandateId = _resolveMandateId(context)
|
||||
subService = getSubscriptionService(context.user, mandateId)
|
||||
plans = subService.getSelectablePlans()
|
||||
return [p.model_dump() for p in plans]
|
||||
except Exception as e:
|
||||
logger.error("Error fetching plans: %s", e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/status", response_model=SubscriptionStatusResponse)
|
||||
@limiter.limit("60/minute")
|
||||
def getStatus(request: Request, context: RequestContext = Depends(getRequestContext)):
|
||||
"""Return the operative subscription and any scheduled successor for the current mandate."""
|
||||
from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
|
||||
getService as getSubscriptionService,
|
||||
)
|
||||
mandateId = _resolveMandateId(context)
|
||||
if not mandateId:
|
||||
return SubscriptionStatusResponse(active=False)
|
||||
_assertMandateAdmin(context, mandateId)
|
||||
|
||||
try:
|
||||
subService = getSubscriptionService(context.user, mandateId)
|
||||
operative = subService.getOperativeSubscription(mandateId)
|
||||
scheduled = subService.getScheduledSubscription(mandateId)
|
||||
|
||||
if not operative:
|
||||
from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum
|
||||
pending = subService.listSubscriptions(mandateId, [SubscriptionStatusEnum.PENDING])
|
||||
if pending:
|
||||
sub = pending[0]
|
||||
plan = subService.getPlan(sub.get("planKey", ""))
|
||||
return SubscriptionStatusResponse(
|
||||
active=False,
|
||||
subscription=sub,
|
||||
plan=plan.model_dump() if plan else None,
|
||||
scheduled=scheduled,
|
||||
)
|
||||
return SubscriptionStatusResponse(active=False, scheduled=scheduled)
|
||||
|
||||
plan = subService.getPlan(operative.get("planKey", ""))
|
||||
return SubscriptionStatusResponse(
|
||||
active=True,
|
||||
subscription=operative,
|
||||
plan=plan.model_dump() if plan else None,
|
||||
scheduled=scheduled,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Error fetching status: %s", e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/activate", response_model=Dict[str, Any])
|
||||
@limiter.limit("10/minute")
|
||||
def activatePlan(
|
||||
request: Request,
|
||||
data: ActivatePlanRequest,
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
):
|
||||
from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
|
||||
getService as getSubscriptionService,
|
||||
)
|
||||
mandateId = _resolveMandateId(context)
|
||||
if not mandateId:
|
||||
raise HTTPException(status_code=400, detail="X-Mandate-Id header required")
|
||||
_assertMandateAdmin(context, mandateId)
|
||||
|
||||
try:
|
||||
subService = getSubscriptionService(context.user, mandateId)
|
||||
return subService.activatePlan(mandateId, data.planKey, returnUrl=data.returnUrl)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error("Error activating plan %s: %s", data.planKey, e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/cancel", response_model=Dict[str, Any])
|
||||
@limiter.limit("5/minute")
|
||||
def cancelSubscription(
|
||||
request: Request,
|
||||
data: CancelRequest,
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
):
|
||||
"""Cancel a specific subscription by its ID."""
|
||||
from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
|
||||
getService as getSubscriptionService,
|
||||
)
|
||||
mandateId = _resolveMandateId(context)
|
||||
if not mandateId:
|
||||
raise HTTPException(status_code=400, detail="X-Mandate-Id header required")
|
||||
_assertMandateAdmin(context, mandateId)
|
||||
|
||||
try:
|
||||
subService = getSubscriptionService(context.user, mandateId)
|
||||
return subService.cancelSubscription(data.subscriptionId)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error("Error cancelling subscription %s: %s", data.subscriptionId, e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/reactivate", response_model=Dict[str, Any])
|
||||
@limiter.limit("5/minute")
|
||||
def reactivateSubscription(
|
||||
request: Request,
|
||||
data: ReactivateRequest,
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
):
|
||||
"""Reactivate a cancelled (non-recurring) subscription before its period ends."""
|
||||
from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
|
||||
getService as getSubscriptionService,
|
||||
)
|
||||
mandateId = _resolveMandateId(context)
|
||||
if not mandateId:
|
||||
raise HTTPException(status_code=400, detail="X-Mandate-Id header required")
|
||||
_assertMandateAdmin(context, mandateId)
|
||||
|
||||
try:
|
||||
subService = getSubscriptionService(context.user, mandateId)
|
||||
return subService.reactivateSubscription(data.subscriptionId)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error("Error reactivating subscription %s: %s", data.subscriptionId, e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/force-cancel", response_model=Dict[str, Any])
|
||||
@limiter.limit("5/minute")
|
||||
def forceCancel(
|
||||
request: Request,
|
||||
data: ForceCancelRequest,
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
):
|
||||
"""Sysadmin: immediately expire any non-terminal subscription."""
|
||||
if not context.hasSysAdminRole:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required")
|
||||
|
||||
from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
|
||||
getService as getSubscriptionService,
|
||||
)
|
||||
from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
|
||||
sub = getSubRootInterface().getById(data.subscriptionId)
|
||||
if not sub:
|
||||
raise HTTPException(status_code=404, detail="Subscription not found")
|
||||
mandateId = sub["mandateId"]
|
||||
|
||||
try:
|
||||
subService = getSubscriptionService(context.user, mandateId)
|
||||
return subService.forceCancel(data.subscriptionId)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error("Error force-cancelling subscription %s: %s", data.subscriptionId, e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/checkout/verify", response_model=Dict[str, Any])
|
||||
@limiter.limit("20/minute")
|
||||
def verifyCheckout(
|
||||
request: Request,
|
||||
data: VerifyCheckoutRequest,
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
):
|
||||
"""Verify a Stripe Checkout Session and activate the subscription if paid.
|
||||
|
||||
This is the synchronous counterpart to the checkout.session.completed webhook.
|
||||
It's called by the frontend immediately after returning from Stripe to handle
|
||||
environments where webhooks may be delayed or unavailable (e.g. localhost dev).
|
||||
The logic is idempotent — if the webhook already processed the session, this is a no-op.
|
||||
"""
|
||||
mandateId = _resolveMandateId(context)
|
||||
if not mandateId:
|
||||
raise HTTPException(status_code=400, detail="X-Mandate-Id header required")
|
||||
_assertMandateAdmin(context, mandateId)
|
||||
|
||||
try:
|
||||
from modules.shared.stripeClient import getStripeClient
|
||||
stripe = getStripeClient()
|
||||
session = stripe.checkout.Session.retrieve(data.sessionId)
|
||||
except Exception as e:
|
||||
logger.error("Failed to retrieve checkout session %s: %s", data.sessionId, e)
|
||||
raise HTTPException(status_code=400, detail="Invalid session ID")
|
||||
|
||||
if session.get("status") != "complete" or session.get("payment_status") != "paid":
|
||||
return {"status": "pending", "message": "Checkout not yet completed"}
|
||||
|
||||
if session.get("mode") != "subscription":
|
||||
raise HTTPException(status_code=400, detail="Not a subscription checkout session")
|
||||
|
||||
from modules.routes.routeBilling import _handleSubscriptionCheckoutCompleted
|
||||
_handleSubscriptionCheckoutCompleted(session, f"verify-{data.sessionId}")
|
||||
|
||||
return {"status": "activated", "message": "Subscription activated"}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SysAdmin: global subscription overview
|
||||
# =============================================================================
|
||||
|
||||
def _buildEnrichedSubscriptions() -> List[Dict[str, Any]]:
|
||||
"""Build the full enriched subscription list (shared by list + filter-values endpoints)."""
|
||||
from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
|
||||
from modules.datamodels.datamodelSubscription import BUILTIN_PLANS, OPERATIVE_STATUSES
|
||||
|
||||
subInterface = getSubRootInterface()
|
||||
allSubs = subInterface.listAll()
|
||||
|
||||
mandateNames: Dict[str, str] = {}
|
||||
try:
|
||||
from modules.datamodels.datamodelUam import Mandate
|
||||
from modules.security.rootAccess import getRootDbAppConnector
|
||||
appDb = getRootDbAppConnector()
|
||||
for row in appDb.getRecordset(Mandate):
|
||||
r = dict(row)
|
||||
mid = r.get("id", "")
|
||||
mandateNames[mid] = r.get("label") or r.get("name") or mid[:8]
|
||||
except Exception as e:
|
||||
logger.warning("Could not bulk-resolve mandate names: %s", e)
|
||||
|
||||
operativeValues = {s.value for s in OPERATIVE_STATUSES}
|
||||
|
||||
enriched = []
|
||||
for sub in allSubs:
|
||||
mid = sub.get("mandateId", "")
|
||||
planKey = sub.get("planKey", "")
|
||||
plan = BUILTIN_PLANS.get(planKey)
|
||||
|
||||
sub["mandateName"] = mandateNames.get(mid, mid[:8])
|
||||
sub["planTitle"] = (plan.title.get("de") or plan.title.get("en") or planKey) if plan else planKey
|
||||
|
||||
if sub.get("status") in operativeValues:
|
||||
userPrice = sub.get("snapshotPricePerUserCHF", 0) or 0
|
||||
instPrice = sub.get("snapshotPricePerInstanceCHF", 0) or 0
|
||||
try:
|
||||
userCount = subInterface.countActiveUsers(mid)
|
||||
instanceCount = subInterface.countActiveFeatureInstances(mid)
|
||||
except Exception:
|
||||
userCount = 0
|
||||
instanceCount = 0
|
||||
sub["monthlyRevenueCHF"] = round(userPrice * userCount + instPrice * instanceCount, 2)
|
||||
sub["activeUsers"] = userCount
|
||||
sub["activeInstances"] = instanceCount
|
||||
else:
|
||||
sub["monthlyRevenueCHF"] = 0
|
||||
sub["activeUsers"] = 0
|
||||
sub["activeInstances"] = 0
|
||||
|
||||
enriched.append(sub)
|
||||
|
||||
return enriched
|
||||
|
||||
|
||||
@router.get("/admin/all")
|
||||
@limiter.limit("30/minute")
|
||||
def getAllSubscriptions(
|
||||
request: Request,
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded PaginationParams object"),
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
):
|
||||
"""SysAdmin: list ALL subscriptions across all mandates with enriched metadata."""
|
||||
if not context.hasSysAdminRole:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required")
|
||||
|
||||
paginationParams: Optional[PaginationParams] = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
if paginationDict:
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
paginationParams = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")
|
||||
|
||||
enriched = _buildEnrichedSubscriptions()
|
||||
filtered = _applyFiltersAndSort(enriched, paginationParams)
|
||||
|
||||
if paginationParams:
|
||||
totalItems = len(filtered)
|
||||
totalPages = math.ceil(totalItems / paginationParams.pageSize) if totalItems > 0 else 0
|
||||
startIdx = (paginationParams.page - 1) * paginationParams.pageSize
|
||||
endIdx = startIdx + paginationParams.pageSize
|
||||
pageItems = filtered[startIdx:endIdx]
|
||||
return {
|
||||
"items": pageItems,
|
||||
"pagination": PaginationMetadata(
|
||||
currentPage=paginationParams.page,
|
||||
pageSize=paginationParams.pageSize,
|
||||
totalItems=totalItems,
|
||||
totalPages=totalPages,
|
||||
sort=paginationParams.sort,
|
||||
filters=paginationParams.filters,
|
||||
).model_dump(),
|
||||
}
|
||||
|
||||
return {"items": enriched, "pagination": None}
|
||||
|
||||
|
||||
@router.get("/admin/all/filter-values")
|
||||
@limiter.limit("60/minute")
|
||||
def getFilterValues(
|
||||
request: Request,
|
||||
column: str = Query(..., description="Column key to extract distinct values for"),
|
||||
pagination: Optional[str] = Query(None, description="JSON-encoded current filters (applied except for the requested column)"),
|
||||
context: RequestContext = Depends(getRequestContext),
|
||||
):
|
||||
"""Return distinct values for a column, respecting all active filters except the requested one."""
|
||||
if not context.hasSysAdminRole:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Sysadmin role required")
|
||||
|
||||
crossFilterParams: Optional[PaginationParams] = None
|
||||
if pagination:
|
||||
try:
|
||||
paginationDict = json.loads(pagination)
|
||||
if paginationDict:
|
||||
paginationDict = normalize_pagination_dict(paginationDict)
|
||||
filters = paginationDict.get("filters", {})
|
||||
filters.pop(column, None)
|
||||
paginationDict["filters"] = filters
|
||||
paginationDict.pop("sort", None)
|
||||
crossFilterParams = PaginationParams(**paginationDict)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid pagination parameter: {str(e)}")
|
||||
|
||||
enriched = _buildEnrichedSubscriptions()
|
||||
crossFiltered = _applyFiltersAndSort(enriched, crossFilterParams)
|
||||
|
||||
return _extractDistinctValues(crossFiltered, column)
|
||||
|
|
@ -105,6 +105,9 @@ def _getFeatureUiObjects(featureCode: str) -> List[Dict[str, Any]]:
|
|||
elif featureCode == "automation":
|
||||
from modules.features.automation.mainAutomation import UI_OBJECTS
|
||||
return UI_OBJECTS
|
||||
elif featureCode == "automation2":
|
||||
from modules.features.automation2.mainAutomation2 import UI_OBJECTS
|
||||
return UI_OBJECTS
|
||||
elif featureCode == "teamsbot":
|
||||
from modules.features.teamsbot.mainTeamsbot import UI_OBJECTS
|
||||
return UI_OBJECTS
|
||||
|
|
|
|||
|
|
@ -45,10 +45,17 @@ IMPORTABLE_SERVICES: Dict[str, Dict[str, Any]] = {
|
|||
"billing": {
|
||||
"module": "modules.serviceCenter.services.serviceBilling.mainServiceBilling",
|
||||
"class": "BillingService",
|
||||
"dependencies": [],
|
||||
"dependencies": ["subscription"],
|
||||
"objectKey": "service.billing",
|
||||
"label": {"en": "Billing", "de": "Abrechnung", "fr": "Facturation"},
|
||||
},
|
||||
"subscription": {
|
||||
"module": "modules.serviceCenter.services.serviceSubscription.mainServiceSubscription",
|
||||
"class": "SubscriptionService",
|
||||
"dependencies": [],
|
||||
"objectKey": "service.subscription",
|
||||
"label": {"en": "Subscription", "de": "Abonnement", "fr": "Abonnement"},
|
||||
},
|
||||
"sharepoint": {
|
||||
"module": "modules.serviceCenter.services.serviceSharepoint.mainServiceSharepoint",
|
||||
"class": "SharepointService",
|
||||
|
|
@ -92,7 +99,7 @@ IMPORTABLE_SERVICES: Dict[str, Dict[str, Any]] = {
|
|||
"label": {"en": "Web Research", "de": "Web-Recherche", "fr": "Recherche Web"},
|
||||
},
|
||||
"neutralization": {
|
||||
"module": "modules.serviceCenter.services.serviceNeutralization.mainServiceNeutralization",
|
||||
"module": "modules.features.neutralization.serviceNeutralization.mainServiceNeutralization",
|
||||
"class": "NeutralizationService",
|
||||
"dependencies": ["extraction", "generation"],
|
||||
"objectKey": "service.neutralization",
|
||||
|
|
|
|||
|
|
@ -25,6 +25,9 @@ from modules.shared.jsonUtils import closeJsonStructures
|
|||
from modules.serviceCenter.services.serviceBilling.mainServiceBilling import (
|
||||
InsufficientBalanceException,
|
||||
)
|
||||
from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
|
||||
SubscriptionInactiveException,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -191,6 +194,18 @@ async def runAgentLoop(
|
|||
else:
|
||||
aiResponse = await aiCallFn(aiRequest)
|
||||
|
||||
except SubscriptionInactiveException as e:
|
||||
logger.warning(
|
||||
f"Subscription inactive in round {state.currentRound} (mandate={mandateId}): {e.message}"
|
||||
)
|
||||
state.status = AgentStatusEnum.ERROR
|
||||
state.abortReason = e.message
|
||||
yield AgentEvent(
|
||||
type=AgentEventTypeEnum.ERROR,
|
||||
content=e.message,
|
||||
data=e.toClientDict(),
|
||||
)
|
||||
break
|
||||
except InsufficientBalanceException as e:
|
||||
logger.warning(
|
||||
f"Insufficient balance in round {state.currentRound} (mandate={mandateId}): {e.message}"
|
||||
|
|
|
|||
|
|
@ -1432,21 +1432,55 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
return ToolResult(toolCallId="", toolName="uploadToExternal", success=False, error=str(e))
|
||||
|
||||
async def _sendMail(args: Dict[str, Any], context: Dict[str, Any]):
|
||||
import base64 as _b64
|
||||
|
||||
connectionId = args.get("connectionId", "")
|
||||
to = args.get("to", [])
|
||||
subject = args.get("subject", "")
|
||||
body = args.get("body", "")
|
||||
bodyType = "HTML" if args.get("bodyType", "text").lower() == "html" else "Text"
|
||||
draft = args.get("draft", False)
|
||||
attachmentFileIds = args.get("attachmentFileIds") or []
|
||||
|
||||
if not connectionId or not to or not subject:
|
||||
return ToolResult(toolCallId="", toolName="sendMail", success=False, error="connectionId, to, and subject are required")
|
||||
try:
|
||||
graphAttachments: List[Dict[str, Any]] = []
|
||||
if attachmentFileIds:
|
||||
chatService = services.chat
|
||||
dbMgmt = chatService.interfaceDbComponent
|
||||
for fid in attachmentFileIds:
|
||||
fileRow = dbMgmt.getFile(fid)
|
||||
if not fileRow:
|
||||
return ToolResult(toolCallId="", toolName="sendMail", success=False, error=f"Attachment file not found: {fid}")
|
||||
rawBytes = dbMgmt.getFileData(fid)
|
||||
if not rawBytes:
|
||||
return ToolResult(toolCallId="", toolName="sendMail", success=False, error=f"Attachment file has no data: {fid}")
|
||||
graphAttachments.append({
|
||||
"name": fileRow.fileName,
|
||||
"contentBytes": _b64.b64encode(rawBytes).decode("ascii"),
|
||||
"contentType": getattr(fileRow, "mimeType", "application/octet-stream"),
|
||||
})
|
||||
|
||||
from modules.connectors.connectorResolver import ConnectorResolver
|
||||
resolver = ConnectorResolver(
|
||||
services.getService("security"),
|
||||
_buildResolverDb(),
|
||||
)
|
||||
adapter = await resolver.resolveService(connectionId, "outlook")
|
||||
|
||||
if draft and hasattr(adapter, "createDraft"):
|
||||
result = await adapter.createDraft(
|
||||
to=to, subject=subject, body=body, bodyType=bodyType,
|
||||
cc=args.get("cc"), attachments=graphAttachments or None,
|
||||
)
|
||||
return ToolResult(toolCallId="", toolName="sendMail", success=True, data=str(result))
|
||||
|
||||
if hasattr(adapter, "sendMail"):
|
||||
result = await adapter.sendMail(to=to, subject=subject, body=body, cc=args.get("cc"))
|
||||
result = await adapter.sendMail(
|
||||
to=to, subject=subject, body=body, bodyType=bodyType,
|
||||
cc=args.get("cc"), attachments=graphAttachments or None,
|
||||
)
|
||||
return ToolResult(toolCallId="", toolName="sendMail", success=True, data=str(result))
|
||||
return ToolResult(toolCallId="", toolName="sendMail", success=False, error="Mail not supported by this adapter")
|
||||
except Exception as e:
|
||||
|
|
@ -1484,15 +1518,26 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
|
||||
registry.register(
|
||||
"sendMail", _sendMail,
|
||||
description="Send an email via a connected mail service (Outlook, Gmail). Use listConnections to find the connectionId.",
|
||||
description=(
|
||||
"Send or draft an email via a connected mail service (Outlook). "
|
||||
"Supports HTML body and file attachments from the workspace. "
|
||||
"Set draft=true to save as draft without sending. "
|
||||
"Use listConnections to find the connectionId."
|
||||
),
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"connectionId": {"type": "string", "description": "UserConnection ID"},
|
||||
"to": {"type": "array", "items": {"type": "string"}, "description": "Recipient email addresses"},
|
||||
"subject": {"type": "string", "description": "Email subject"},
|
||||
"body": {"type": "string", "description": "Email body text"},
|
||||
"body": {"type": "string", "description": "Email body — plain text or HTML markup"},
|
||||
"bodyType": {"type": "string", "enum": ["text", "html"], "description": "Body format: 'text' (default) or 'html'"},
|
||||
"cc": {"type": "array", "items": {"type": "string"}, "description": "CC addresses"},
|
||||
"attachmentFileIds": {
|
||||
"type": "array", "items": {"type": "string"},
|
||||
"description": "File IDs from the workspace to attach (use listFiles to find IDs)",
|
||||
},
|
||||
"draft": {"type": "boolean", "description": "If true, save as draft in Drafts folder instead of sending"},
|
||||
},
|
||||
"required": ["connectionId", "to", "subject", "body"],
|
||||
},
|
||||
|
|
@ -2471,16 +2516,16 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
|
||||
if not voiceName:
|
||||
try:
|
||||
from modules.interfaces import interfaceDbManagement
|
||||
from modules.features.workspace import interfaceFeatureWorkspace
|
||||
featureInstanceId = context.get("featureInstanceId", "")
|
||||
userId = context.get("userId", "")
|
||||
if userId:
|
||||
dbMgmt = interfaceDbManagement.getInterface(
|
||||
wsIf = interfaceFeatureWorkspace.getInterface(
|
||||
services.user,
|
||||
mandateId=mandateId or None,
|
||||
featureInstanceId=featureInstanceId or None,
|
||||
)
|
||||
vs = dbMgmt.getVoiceSettings(userId) if dbMgmt and hasattr(dbMgmt, "getVoiceSettings") else None
|
||||
vs = wsIf.getVoiceSettings(userId) if wsIf else None
|
||||
if vs:
|
||||
voiceMap = {}
|
||||
if hasattr(vs, "ttsVoiceMap") and vs.ttsVoiceMap:
|
||||
|
|
@ -2914,6 +2959,8 @@ def _registerCoreTools(registry: ToolRegistry, services):
|
|||
neutralizationService = services.getService("neutralization")
|
||||
if not neutralizationService:
|
||||
return ToolResult(toolCallId="", toolName="neutralizeData", success=False, error="Neutralization service not available")
|
||||
if not neutralizationService.interfaceDbComponent:
|
||||
neutralizationService.interfaceDbComponent = services.chat.interfaceDbComponent
|
||||
if text:
|
||||
result = neutralizationService.processText(text)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -27,6 +27,10 @@ from modules.serviceCenter.services.serviceBilling.mainServiceBilling import (
|
|||
ProviderNotAllowedException,
|
||||
BillingContextError
|
||||
)
|
||||
from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
|
||||
SubscriptionInactiveException,
|
||||
SUBSCRIPTION_REASONS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -590,11 +594,25 @@ detectedIntent-Werte:
|
|||
balanceCheck = billingService.checkBalance(estimatedCost)
|
||||
|
||||
if not balanceCheck.allowed:
|
||||
reason = balanceCheck.reason or ""
|
||||
|
||||
if reason in SUBSCRIPTION_REASONS:
|
||||
from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum
|
||||
statusMap = {
|
||||
"SUBSCRIPTION_PAYMENT_REQUIRED": SubscriptionStatusEnum.PAST_DUE,
|
||||
"SUBSCRIPTION_EXPIRED": SubscriptionStatusEnum.EXPIRED,
|
||||
"SUBSCRIPTION_INACTIVE": SubscriptionStatusEnum.EXPIRED,
|
||||
}
|
||||
raise SubscriptionInactiveException(
|
||||
status=statusMap.get(reason, SubscriptionStatusEnum.EXPIRED),
|
||||
mandateId=str(mandateId),
|
||||
)
|
||||
|
||||
balance_str = f"{(balanceCheck.currentBalance or 0):.2f}"
|
||||
logger.warning(
|
||||
f"Billing check failed for user {user.id}: "
|
||||
f"Balance {balance_str} CHF, "
|
||||
f"Reason: {balanceCheck.reason}"
|
||||
f"Reason: {reason}"
|
||||
)
|
||||
if balanceCheck.billingModel == BillingModelEnum.PREPAY_MANDATE:
|
||||
ulabel = (getattr(user, "email", None) or getattr(user, "username", None) or str(user.id))
|
||||
|
|
@ -651,6 +669,8 @@ detectedIntent-Werte:
|
|||
|
||||
logger.debug(f"Provider check passed: {len(rbacAllowedProviders)} providers allowed")
|
||||
|
||||
except SubscriptionInactiveException:
|
||||
raise
|
||||
except InsufficientBalanceException:
|
||||
raise
|
||||
except ProviderNotAllowedException:
|
||||
|
|
@ -658,7 +678,6 @@ detectedIntent-Werte:
|
|||
except BillingContextError:
|
||||
raise
|
||||
except Exception as e:
|
||||
# FAIL-SAFE: Don't silently swallow errors - log at ERROR level
|
||||
logger.error(f"BILLING FAIL-SAFE: Billing check failed with unexpected error: {e}")
|
||||
raise BillingContextError(f"Billing check failed: {e}")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,23 +1,17 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""
|
||||
When the shared mandate pool (PREPAY_MANDATE) is exhausted, notify billing contacts.
|
||||
When the shared mandate pool (PREPAY_MANDATE) is exhausted, notify mandate admins.
|
||||
|
||||
Recipients: BillingSettings.notifyEmails for the mandate (configure as mandate owner / finance).
|
||||
Uses the central notifyMandateAdmins() function for recipient resolution and delivery.
|
||||
Emails are throttled per mandate to avoid spam (one notification per cooldown window).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import html
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from modules.datamodels.datamodelMessaging import MessagingChannel
|
||||
from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface
|
||||
from modules.interfaces.interfaceMessaging import getInterface as getMessagingInterface
|
||||
from modules.security.rootAccess import getRootUser
|
||||
from typing import Dict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -26,29 +20,6 @@ _poolExhaustedEmailLastSent: Dict[str, float] = {}
|
|||
_DEFAULT_COOLDOWN_SEC = 3600
|
||||
|
||||
|
||||
def _normalizeNotifyEmails(raw: Any) -> List[str]:
|
||||
if raw is None:
|
||||
return []
|
||||
if isinstance(raw, list):
|
||||
return [str(e).strip() for e in raw if str(e).strip()]
|
||||
if isinstance(raw, str):
|
||||
s = raw.strip()
|
||||
if not s:
|
||||
return []
|
||||
# JSON array string
|
||||
if s.startswith("["):
|
||||
try:
|
||||
import json
|
||||
|
||||
parsed = json.loads(s)
|
||||
if isinstance(parsed, list):
|
||||
return [str(e).strip() for e in parsed if str(e).strip()]
|
||||
except Exception:
|
||||
pass
|
||||
return [s]
|
||||
return []
|
||||
|
||||
|
||||
def maybeEmailMandatePoolExhausted(
|
||||
mandateId: str,
|
||||
triggeringUserId: str,
|
||||
|
|
@ -58,7 +29,7 @@ def maybeEmailMandatePoolExhausted(
|
|||
cooldownSec: float = _DEFAULT_COOLDOWN_SEC,
|
||||
) -> None:
|
||||
"""
|
||||
Send one email per mandate per cooldown to BillingSettings.notifyEmails.
|
||||
Send one notification per mandate per cooldown window when the pool is exhausted.
|
||||
|
||||
Args:
|
||||
mandateId: Mandate whose pool is empty.
|
||||
|
|
@ -82,59 +53,22 @@ def maybeEmailMandatePoolExhausted(
|
|||
return
|
||||
|
||||
try:
|
||||
billing = getBillingInterface(getRootUser(), mandateId)
|
||||
settings = billing.getSettings(mandateId) or {}
|
||||
recipients = _normalizeNotifyEmails(settings.get("notifyEmails"))
|
||||
if not recipients:
|
||||
logger.warning(
|
||||
"PREPAY_MANDATE pool exhausted for mandate %s but notifyEmails is empty — "
|
||||
"configure BillingSettings.notifyEmails for owner alerts",
|
||||
mandateId,
|
||||
)
|
||||
return
|
||||
from modules.shared.notifyMandateAdmins import notifyMandateAdmins
|
||||
|
||||
subject = f"[PowerOn] Mandanten-Budget aufgebraucht (Mandant {mandateId[:8]}…)"
|
||||
body = (
|
||||
f"Das gemeinsame Guthaben (PREPAY_MANDATE) für diesen Mandanten ist nicht mehr ausreichend.\n\n"
|
||||
f"Mandanten-ID: {mandateId}\n"
|
||||
f"Aktuelles Guthaben (Pool): CHF {currentBalance:.2f}\n"
|
||||
f"Benötigt (mind.): CHF {requiredAmount:.2f}\n\n"
|
||||
f"Auslösende/r Benutzer/in: {triggeringUserLabel} (ID: {triggeringUserId})\n\n"
|
||||
f"Bitte laden Sie das Mandats-Guthaben in der Billing-Verwaltung auf, "
|
||||
f"damit Benutzer wieder AI-Funktionen nutzen können.\n"
|
||||
sent = notifyMandateAdmins(
|
||||
mandateId,
|
||||
"[PowerOn] Mandanten-Budget aufgebraucht",
|
||||
"Budget aufgebraucht",
|
||||
[
|
||||
"Das gemeinsame Guthaben (Prepaid-Pool) für diesen Mandanten ist nicht mehr ausreichend.",
|
||||
f"Aktuelles Guthaben: CHF {currentBalance:.2f}\n"
|
||||
f"Benötigt (mindestens): CHF {requiredAmount:.2f}",
|
||||
f"Ausgelöst durch: {triggeringUserLabel}",
|
||||
"Bitte laden Sie das Mandats-Guthaben in der Billing-Verwaltung auf, "
|
||||
"damit Benutzer wieder AI-Funktionen nutzen können.",
|
||||
],
|
||||
)
|
||||
escaped = html.escape(body)
|
||||
# Cannot use '\\n' inside f-string {…} expression (SyntaxError); build replacement outside.
|
||||
brWithNl = "<br>" + "\n"
|
||||
htmlMessage = f"""<!DOCTYPE html>
|
||||
<html><head><meta charset="utf-8"></head>
|
||||
<body style="font-family: Arial, sans-serif; line-height: 1.6;">
|
||||
{escaped.replace(chr(10), brWithNl)}
|
||||
</body></html>"""
|
||||
|
||||
messaging = getMessagingInterface()
|
||||
any_ok = False
|
||||
for to in recipients:
|
||||
try:
|
||||
ok = messaging.send(
|
||||
channel=MessagingChannel.EMAIL,
|
||||
recipient=to,
|
||||
subject=subject,
|
||||
message=htmlMessage,
|
||||
)
|
||||
if ok:
|
||||
any_ok = True
|
||||
else:
|
||||
logger.warning("Pool exhausted email failed for %s", to)
|
||||
except Exception as send_err:
|
||||
logger.error("Error sending pool exhausted email to %s: %s", to, send_err)
|
||||
|
||||
if any_ok:
|
||||
if sent > 0:
|
||||
_poolExhaustedEmailLastSent[mandateId] = now
|
||||
logger.info(
|
||||
"Sent mandate pool exhausted notification for mandate %s to %s recipient(s)",
|
||||
mandateId,
|
||||
len(recipients),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("maybeEmailMandatePoolExhausted failed: %s", e, exc_info=True)
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ from modules.interfaces.interfaceDbBilling import getInterface as getBillingInte
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Markup percentage for internal pricing (+50% für Infrastruktur und Platform Service + 50% für Währungsrisiko ==> Faktor 2.0)
|
||||
BILLING_MARKUP_PERCENT = 100
|
||||
BILLING_MARKUP_PERCENT = 400
|
||||
|
||||
# Singleton cache
|
||||
_billingServices: Dict[str, "BillingService"] = {}
|
||||
|
|
@ -160,6 +160,10 @@ class BillingService:
|
|||
def checkBalance(self, estimatedCost: float = 0.0) -> BillingCheckResult:
|
||||
"""
|
||||
Check if the current user/mandate has sufficient balance.
|
||||
|
||||
Gate order:
|
||||
1. Subscription active? (fast, cached) — blocks AI if not
|
||||
2. Budget sufficient? (existing prepaid logic)
|
||||
|
||||
Args:
|
||||
estimatedCost: Estimated cost of the operation (with markup applied)
|
||||
|
|
@ -167,11 +171,42 @@ class BillingService:
|
|||
Returns:
|
||||
BillingCheckResult indicating if operation is allowed
|
||||
"""
|
||||
subResult = self._checkSubscription()
|
||||
if subResult is not None:
|
||||
return subResult
|
||||
|
||||
return self._billingInterface.checkBalance(
|
||||
self.mandateId,
|
||||
self.currentUser.id,
|
||||
estimatedCost
|
||||
)
|
||||
|
||||
def _checkSubscription(self) -> Optional[BillingCheckResult]:
|
||||
"""Return a failing BillingCheckResult if subscription is not active, else None."""
|
||||
try:
|
||||
from modules.datamodels.datamodelSubscription import SubscriptionStatusEnum
|
||||
from modules.serviceCenter.services.serviceSubscription.mainServiceSubscription import (
|
||||
getService as getSubscriptionService,
|
||||
_subscriptionReasonForStatus,
|
||||
_subscriptionUserActionForStatus,
|
||||
)
|
||||
|
||||
subService = getSubscriptionService(self.currentUser, self.mandateId)
|
||||
status = subService.assertActive(self.mandateId)
|
||||
|
||||
if status in (SubscriptionStatusEnum.ACTIVE, SubscriptionStatusEnum.TRIALING, SubscriptionStatusEnum.PAST_DUE):
|
||||
return None
|
||||
|
||||
return BillingCheckResult(
|
||||
allowed=False,
|
||||
reason=_subscriptionReasonForStatus(status),
|
||||
upgradeRequired=True,
|
||||
subscriptionUiPath="/admin/billing?tab=subscription",
|
||||
userAction=_subscriptionUserActionForStatus(status),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Subscription check failed (allowing): {e}")
|
||||
return None
|
||||
|
||||
def hasBalance(self, estimatedCost: float = 0.0) -> bool:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -82,17 +82,8 @@ def create_checkout_session(
|
|||
f"Invalid amount {amount_chf} CHF. Allowed: {ALLOWED_AMOUNTS_CHF}"
|
||||
)
|
||||
|
||||
# Pin API version from config (match Stripe Dashboard)
|
||||
api_version = APP_CONFIG.get("STRIPE_API_VERSION")
|
||||
if api_version:
|
||||
stripe.api_version = api_version
|
||||
|
||||
# Get secrets
|
||||
secret_key = APP_CONFIG.get("STRIPE_SECRET_KEY_SECRET") or APP_CONFIG.get("STRIPE_SECRET_KEY")
|
||||
if not secret_key:
|
||||
raise ValueError("STRIPE_SECRET_KEY_SECRET not configured")
|
||||
|
||||
stripe.api_key = secret_key
|
||||
from modules.shared.stripeClient import getStripeClient
|
||||
stripe = getStripeClient()
|
||||
|
||||
base_return_url = _normalizeReturnUrl(return_url)
|
||||
query_separator = "&" if "?" in base_return_url else "?"
|
||||
|
|
|
|||
|
|
@ -253,7 +253,8 @@ class ChatService:
|
|||
logger.error(f"No messages found with documentsLabel: {docRef}")
|
||||
raise ValueError(f"Document reference not found: {docRef}")
|
||||
|
||||
logger.debug(f"Resolved {len(allDocuments)} documents from document list: {documentList}")
|
||||
ref_count = len(getattr(documentList, 'references', [])) if documentList else 0
|
||||
logger.debug(f"Resolved {len(allDocuments)} documents from document list ({ref_count} refs)")
|
||||
return allDocuments
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting documents from document list: {str(e)}")
|
||||
|
|
|
|||
|
|
@ -50,6 +50,38 @@ class ExtractionService:
|
|||
if model is None or model.calculatepriceCHF is None:
|
||||
raise RuntimeError(f"FATAL: Required internal model '{modelDisplayName}' is not available. Check connector registration.")
|
||||
|
||||
def extractContentFromBytes(
|
||||
self,
|
||||
documentBytes: bytes,
|
||||
fileName: str,
|
||||
mimeType: str,
|
||||
documentId: Optional[str] = None,
|
||||
options: Optional[ExtractionOptions] = None,
|
||||
) -> ContentExtracted:
|
||||
"""Extract content from raw bytes (no persistence).
|
||||
Used for inline ActionDocuments from SharePoint/email in automation2."""
|
||||
opts = options or ExtractionOptions(prompt="", mergeStrategy=MergeStrategy())
|
||||
doc_id = documentId or str(uuid.uuid4())
|
||||
ec = runExtraction(
|
||||
self._extractorRegistry,
|
||||
self._chunkerRegistry,
|
||||
documentBytes,
|
||||
fileName,
|
||||
mimeType,
|
||||
opts,
|
||||
)
|
||||
for p in ec.parts:
|
||||
if not p.metadata:
|
||||
p.metadata = {}
|
||||
p.metadata.setdefault("documentId", doc_id)
|
||||
p.metadata.setdefault("documentMimeType", mimeType)
|
||||
p.metadata.setdefault("originalFileName", fileName)
|
||||
p.metadata.setdefault("contentFormat", "extracted")
|
||||
p.metadata.setdefault("intent", "extract")
|
||||
p.metadata.setdefault("usageHint", f"Use extracted content from {fileName}")
|
||||
p.metadata.setdefault("sourceAction", "extraction.extractContentFromBytes")
|
||||
return ec
|
||||
|
||||
def extractContent(
|
||||
self,
|
||||
documents: List[ChatDocument],
|
||||
|
|
|
|||
|
|
@ -56,15 +56,15 @@ class DocumentGenerationPath:
|
|||
|
||||
try:
|
||||
# Schritt 5A: Kläre Dokument-Intents
|
||||
documents = []
|
||||
doc_list = []
|
||||
if documentList:
|
||||
documents = self.services.chat.getChatDocumentsFromDocumentList(documentList)
|
||||
doc_list = self.services.chat.getChatDocumentsFromDocumentList(documentList)
|
||||
|
||||
# Filter: Entferne Original-Dokumente, wenn bereits Pre-Extracted JSONs existieren
|
||||
# (um Duplikate zu vermeiden - Pre-Extracted JSONs enthalten bereits die ContentParts)
|
||||
# Schritt 1: Identifiziere alle Original-Dokument-IDs, die durch Pre-Extracted JSONs abgedeckt werden
|
||||
originalDocIdsCoveredByPreExtracted = set()
|
||||
for doc in documents:
|
||||
for doc in doc_list:
|
||||
preExtracted = self.services.ai.intentAnalyzer.resolvePreExtractedDocument(doc)
|
||||
if preExtracted:
|
||||
originalDocId = preExtracted["originalDocument"]["id"]
|
||||
|
|
@ -73,7 +73,7 @@ class DocumentGenerationPath:
|
|||
|
||||
# Schritt 2: Filtere Dokumente - entferne Original-Dokumente, die bereits durch Pre-Extracted JSONs abgedeckt werden
|
||||
filteredDocuments = []
|
||||
for doc in documents:
|
||||
for doc in doc_list:
|
||||
preExtracted = self.services.ai.intentAnalyzer.resolvePreExtractedDocument(doc)
|
||||
if preExtracted:
|
||||
# Pre-Extracted JSON behalten
|
||||
|
|
@ -85,13 +85,13 @@ class DocumentGenerationPath:
|
|||
# Normales Dokument ohne Pre-Extracted JSON - behalten
|
||||
filteredDocuments.append(doc)
|
||||
|
||||
documents = filteredDocuments
|
||||
doc_list = filteredDocuments
|
||||
|
||||
checkWorkflowStopped(self.services)
|
||||
|
||||
if not documentIntents and documents:
|
||||
if not documentIntents and doc_list:
|
||||
documentIntents = await self.services.ai.clarifyDocumentIntents(
|
||||
documents,
|
||||
doc_list,
|
||||
userPrompt,
|
||||
{"outputFormat": outputFormat},
|
||||
docOperationId
|
||||
|
|
@ -100,9 +100,9 @@ class DocumentGenerationPath:
|
|||
checkWorkflowStopped(self.services)
|
||||
|
||||
# Schritt 5B: Extrahiere und bereite Content vor
|
||||
if documents:
|
||||
if doc_list:
|
||||
preparedContentParts = await self.services.ai.extractAndPrepareContent(
|
||||
documents,
|
||||
doc_list,
|
||||
documentIntents or [],
|
||||
docOperationId
|
||||
)
|
||||
|
|
|
|||
|
|
@ -72,11 +72,18 @@ class RendererRegistry:
|
|||
"""Register a renderer class keyed by (format, outputStyle)."""
|
||||
try:
|
||||
supportedFormats = rendererClass.getSupportedFormats()
|
||||
outputStyle = rendererClass.getOutputStyle() if hasattr(rendererClass, 'getOutputStyle') else 'document'
|
||||
priority = rendererClass.getPriority() if hasattr(rendererClass, 'getPriority') else 0
|
||||
|
||||
for formatName in supportedFormats:
|
||||
formatKey = formatName.lower()
|
||||
# Per-format output style when renderer supports it (e.g. RendererText: txt→document, js→code)
|
||||
if hasattr(rendererClass, 'getOutputStyle'):
|
||||
try:
|
||||
outputStyle = rendererClass.getOutputStyle(formatKey)
|
||||
except TypeError:
|
||||
outputStyle = rendererClass.getOutputStyle() if callable(getattr(rendererClass, 'getOutputStyle')) else 'document'
|
||||
else:
|
||||
outputStyle = 'document'
|
||||
registryKey = (formatKey, outputStyle)
|
||||
|
||||
if registryKey in self._renderers:
|
||||
|
|
|
|||
|
|
@ -692,12 +692,35 @@ class SharepointService:
|
|||
logger.error(f"Error extracting site from standard path '{pathQuery}': {str(e)}")
|
||||
return None
|
||||
|
||||
def _isGraphSiteId(self, sitePath: str) -> bool:
|
||||
"""Check if sitePath is a Graph API site ID (hostname,siteId,webId format with 2 commas)."""
|
||||
if not sitePath or sitePath.count(',') != 2:
|
||||
return False
|
||||
parts = sitePath.split(',')
|
||||
return len(parts) == 3 and all(p.strip() for p in parts)
|
||||
|
||||
async def getSiteByStandardPath(self, sitePath: str, allSites: Optional[List[Dict[str, Any]]] = None) -> Optional[Dict[str, Any]]:
|
||||
"""Get SharePoint site directly by Microsoft-standard path (/sites/SiteName)."""
|
||||
"""Get SharePoint site directly by Microsoft-standard path (/sites/SiteName) or by site ID."""
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
hostname = None
|
||||
|
||||
# When sitePath is a Graph API site ID (host,siteId,webId), use sites/{id} directly
|
||||
if self._isGraphSiteId(sitePath):
|
||||
endpoint = f"sites/{sitePath}"
|
||||
result = await self._makeGraphApiCall(endpoint)
|
||||
if result and "error" not in result:
|
||||
return {
|
||||
"id": result.get("id"),
|
||||
"displayName": result.get("displayName"),
|
||||
"name": result.get("name"),
|
||||
"webUrl": result.get("webUrl"),
|
||||
"description": result.get("description"),
|
||||
"createdDateTime": result.get("createdDateTime"),
|
||||
"lastModifiedDateTime": result.get("lastModifiedDateTime")
|
||||
}
|
||||
return None
|
||||
|
||||
hostname = None
|
||||
if allSites and len(allSites) > 0:
|
||||
webUrl = allSites[0].get("webUrl", "")
|
||||
hostname = urlparse(webUrl).hostname if webUrl else None
|
||||
|
|
@ -777,6 +800,14 @@ class SharepointService:
|
|||
parsedPath = self.extractSiteFromStandardPath(pathQuery)
|
||||
if parsedPath:
|
||||
siteName = parsedPath.get("siteName")
|
||||
# When siteName is Graph API composite ID (host,siteId,webId), match by exact id
|
||||
if siteName and ',' in siteName:
|
||||
exact = [s for s in allSites if s.get("id") == siteName]
|
||||
if exact:
|
||||
logger.info(f"Resolved site by exact ID: {siteName}")
|
||||
return exact
|
||||
logger.warning(f"No site found with exact ID '{siteName}'")
|
||||
return []
|
||||
sites = self.filterSitesByHint(allSites, siteName)
|
||||
if not sites:
|
||||
logger.warning(f"No SharePoint site found matching '{siteName}'")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,758 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""
|
||||
Subscription Service — state-machine-based lifecycle management.
|
||||
|
||||
Every mutation takes an explicit subscriptionId. No status-scan guessing.
|
||||
See wiki/concepts/Subscription-State-Machine.md for the full state machine.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from modules.datamodels.datamodelUam import User
|
||||
from modules.datamodels.datamodelSubscription import (
|
||||
SubscriptionPlan,
|
||||
MandateSubscription,
|
||||
SubscriptionStatusEnum,
|
||||
BillingPeriodEnum,
|
||||
OPERATIVE_STATUSES,
|
||||
_getPlan,
|
||||
_getSelectablePlans,
|
||||
)
|
||||
from modules.interfaces.interfaceDbSubscription import (
|
||||
getInterface as getSubscriptionInterface,
|
||||
InvalidTransitionError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SUBSCRIPTION_CACHE_TTL_SECONDS = 60
|
||||
_STALE_PENDING_SECONDS = 30 * 60
|
||||
|
||||
_subscriptionServices: Dict[str, "SubscriptionService"] = {}
|
||||
_statusCache: Dict[str, tuple] = {}
|
||||
|
||||
|
||||
def getService(currentUser: User, mandateId: str) -> "SubscriptionService":
|
||||
cacheKey = f"{currentUser.id}_{mandateId}"
|
||||
if cacheKey not in _subscriptionServices:
|
||||
_subscriptionServices[cacheKey] = SubscriptionService(currentUser, mandateId)
|
||||
else:
|
||||
_subscriptionServices[cacheKey].setContext(currentUser, mandateId)
|
||||
return _subscriptionServices[cacheKey]
|
||||
|
||||
|
||||
class SubscriptionService:
|
||||
"""State-machine-based subscription service.
|
||||
All mutations use explicit subscriptionId. No scan-based writes."""
|
||||
|
||||
def __init__(self, contextOrUser, mandateId=None, get_service=None):
|
||||
if mandateId is not None and callable(mandateId):
|
||||
ctx = contextOrUser
|
||||
self.currentUser = ctx.user
|
||||
self.mandateId = ctx.mandate_id or ""
|
||||
elif get_service is not None and hasattr(contextOrUser, "user"):
|
||||
ctx = contextOrUser
|
||||
self.currentUser = ctx.user
|
||||
self.mandateId = ctx.mandate_id or ""
|
||||
else:
|
||||
self.currentUser = contextOrUser
|
||||
self.mandateId = mandateId or ""
|
||||
self._interface = getSubscriptionInterface(self.currentUser, self.mandateId)
|
||||
|
||||
def setContext(self, currentUser: User, mandateId: str):
|
||||
self.currentUser = currentUser
|
||||
self.mandateId = mandateId
|
||||
self._interface = getSubscriptionInterface(currentUser, mandateId)
|
||||
|
||||
# =========================================================================
|
||||
# Billing gate (cached, read-only)
|
||||
# =========================================================================
|
||||
|
||||
def assertActive(self, mandateId: str = None) -> SubscriptionStatusEnum:
|
||||
"""Return subscription status for billing decisions. Uses TTL cache.
|
||||
This is the ONLY method that works by mandateId (read-only)."""
|
||||
mid = mandateId or self.mandateId
|
||||
now = time.monotonic()
|
||||
|
||||
cached = _statusCache.get(mid)
|
||||
if cached and cached[1] > now:
|
||||
return cached[0]
|
||||
|
||||
status = self._interface.assertActive(mid)
|
||||
_statusCache[mid] = (status, now + SUBSCRIPTION_CACHE_TTL_SECONDS)
|
||||
return status
|
||||
|
||||
def invalidateCache(self, mandateId: str = None):
|
||||
mid = mandateId or self.mandateId
|
||||
_statusCache.pop(mid, None)
|
||||
|
||||
# =========================================================================
|
||||
# Capacity (delegation)
|
||||
# =========================================================================
|
||||
|
||||
def assertCapacity(self, mandateId: str, resourceType: str, delta: int = 1) -> bool:
|
||||
return self._interface.assertCapacity(mandateId or self.mandateId, resourceType, delta)
|
||||
|
||||
# =========================================================================
|
||||
# Read operations
|
||||
# =========================================================================
|
||||
|
||||
def getById(self, subscriptionId: str) -> Optional[Dict[str, Any]]:
|
||||
return self._interface.getById(subscriptionId)
|
||||
|
||||
def getOperativeSubscription(self, mandateId: str = None) -> Optional[Dict[str, Any]]:
|
||||
return self._interface.getOperativeForMandate(mandateId or self.mandateId)
|
||||
|
||||
def getScheduledSubscription(self, mandateId: str = None) -> Optional[Dict[str, Any]]:
|
||||
return self._interface.getScheduledForMandate(mandateId or self.mandateId)
|
||||
|
||||
def listSubscriptions(self, mandateId: str = None, statusFilter=None) -> List[Dict[str, Any]]:
|
||||
return self._interface.listForMandate(mandateId or self.mandateId, statusFilter)
|
||||
|
||||
def getSelectablePlans(self) -> List[SubscriptionPlan]:
|
||||
return _getSelectablePlans()
|
||||
|
||||
def getPlan(self, planKey: str) -> Optional[SubscriptionPlan]:
|
||||
return _getPlan(planKey)
|
||||
|
||||
# =========================================================================
|
||||
# T1/T2: Plan activation (creates PENDING, returns checkout URL)
|
||||
# =========================================================================
|
||||
|
||||
def activatePlan(self, mandateId: str, planKey: str, returnUrl: str) -> Dict[str, Any]:
|
||||
"""Create a new subscription as PENDING and start the checkout flow.
|
||||
|
||||
- Free/trial plans: immediately ACTIVE/TRIALING (no checkout).
|
||||
- Paid plans with active predecessor: PENDING -> checkout -> SCHEDULED on confirmation.
|
||||
- Paid plans without predecessor: PENDING -> checkout -> ACTIVE on confirmation.
|
||||
|
||||
Cleans up any existing PENDING/SCHEDULED for this mandate first (by ID)."""
|
||||
mid = mandateId or self.mandateId
|
||||
plan = _getPlan(planKey)
|
||||
if not plan:
|
||||
raise ValueError(f"Unknown plan: {planKey}")
|
||||
|
||||
isPaid = plan.billingPeriod != BillingPeriodEnum.NONE and not plan.trialDays
|
||||
currentOperative = self._interface.getOperativeForMandate(mid)
|
||||
|
||||
self._cleanupPreparatorySubscriptions(mid)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
if plan.trialDays:
|
||||
initialStatus = SubscriptionStatusEnum.TRIALING
|
||||
elif isPaid:
|
||||
initialStatus = SubscriptionStatusEnum.PENDING
|
||||
else:
|
||||
initialStatus = SubscriptionStatusEnum.ACTIVE
|
||||
|
||||
sub = MandateSubscription(
|
||||
mandateId=mid,
|
||||
planKey=planKey,
|
||||
status=initialStatus,
|
||||
recurring=plan.autoRenew and not plan.trialDays,
|
||||
startedAt=now,
|
||||
currentPeriodStart=now,
|
||||
snapshotPricePerUserCHF=plan.pricePerUserCHF,
|
||||
snapshotPricePerInstanceCHF=plan.pricePerFeatureInstanceCHF,
|
||||
)
|
||||
|
||||
if plan.trialDays:
|
||||
sub.trialEndsAt = now + timedelta(days=plan.trialDays)
|
||||
|
||||
if plan.billingPeriod == BillingPeriodEnum.MONTHLY:
|
||||
sub.currentPeriodEnd = now + timedelta(days=30)
|
||||
elif plan.billingPeriod == BillingPeriodEnum.YEARLY:
|
||||
sub.currentPeriodEnd = now + timedelta(days=365)
|
||||
|
||||
created = self._interface.createSubscription(sub)
|
||||
|
||||
from urllib.parse import urlparse
|
||||
parsed = urlparse(returnUrl) if returnUrl else None
|
||||
pUrl = f"{parsed.scheme}://{parsed.netloc}" if parsed and parsed.scheme else ""
|
||||
|
||||
if isPaid:
|
||||
try:
|
||||
checkoutUrl = self._createCheckoutSession(mid, plan, created, currentOperative, returnUrl)
|
||||
created["redirectUrl"] = checkoutUrl
|
||||
except Exception as e:
|
||||
self._interface.forceExpire(created["id"])
|
||||
self.invalidateCache(mid)
|
||||
raise ValueError(f"Subscription konnte nicht erstellt werden: {e}") from e
|
||||
else:
|
||||
if currentOperative:
|
||||
self._expireOperative(currentOperative["id"], mid)
|
||||
_notifySubscriptionChange(mid, "activated", plan, subscriptionRecord=created, platformUrl=pUrl)
|
||||
|
||||
self.invalidateCache(mid)
|
||||
return created
|
||||
|
||||
def _cleanupPreparatorySubscriptions(self, mandateId: str) -> None:
|
||||
"""Expire any existing PENDING or SCHEDULED subscriptions for this mandate (by ID)."""
|
||||
preparatory = self._interface.listForMandate(
|
||||
mandateId, [SubscriptionStatusEnum.PENDING, SubscriptionStatusEnum.SCHEDULED],
|
||||
)
|
||||
for sub in preparatory:
|
||||
subId = sub["id"]
|
||||
currentStatus = SubscriptionStatusEnum(sub["status"])
|
||||
stripeSubId = sub.get("stripeSubscriptionId")
|
||||
|
||||
if stripeSubId and currentStatus == SubscriptionStatusEnum.SCHEDULED:
|
||||
try:
|
||||
from modules.shared.stripeClient import getStripeClient
|
||||
stripe = getStripeClient()
|
||||
stripe.Subscription.cancel(stripeSubId)
|
||||
except Exception as e:
|
||||
logger.error("Failed to cancel Stripe sub %s during cleanup: %s", stripeSubId, e)
|
||||
|
||||
self._interface.transitionStatus(subId, currentStatus, SubscriptionStatusEnum.EXPIRED)
|
||||
logger.info("Cleaned up %s subscription %s for mandate %s", currentStatus.value, subId, mandateId)
|
||||
|
||||
def _expireOperative(self, subscriptionId: str, mandateId: str) -> None:
|
||||
"""Expire the current operative subscription (used when a free/trial plan replaces it)."""
|
||||
sub = self._interface.getById(subscriptionId)
|
||||
if not sub:
|
||||
return
|
||||
currentStatus = SubscriptionStatusEnum(sub["status"])
|
||||
if currentStatus in OPERATIVE_STATUSES:
|
||||
stripeSubId = sub.get("stripeSubscriptionId")
|
||||
if stripeSubId:
|
||||
try:
|
||||
from modules.shared.stripeClient import getStripeClient
|
||||
stripe = getStripeClient()
|
||||
stripe.Subscription.cancel(stripeSubId)
|
||||
except Exception as e:
|
||||
logger.error("Failed to cancel Stripe sub %s: %s", stripeSubId, e)
|
||||
self._interface.transitionStatus(subscriptionId, currentStatus, SubscriptionStatusEnum.EXPIRED)
|
||||
|
||||
def _createCheckoutSession(
|
||||
self, mandateId: str, plan: SubscriptionPlan, subRecord: Dict[str, Any],
|
||||
currentOperative: Optional[Dict[str, Any]], returnUrl: str,
|
||||
) -> str:
|
||||
"""Create a Stripe Checkout Session. If a predecessor exists, delays billing
|
||||
via trial_end to start after the predecessor's period end."""
|
||||
from modules.shared.stripeClient import getStripeClient
|
||||
from modules.serviceCenter.services.serviceSubscription.stripeBootstrap import getStripePricesForPlan
|
||||
|
||||
stripe = getStripeClient()
|
||||
priceMapping = getStripePricesForPlan(plan.planKey)
|
||||
if not priceMapping or (not priceMapping.stripePriceIdUsers and not priceMapping.stripePriceIdInstances):
|
||||
raise ValueError(f"Stripe Price IDs not provisioned for plan {plan.planKey}")
|
||||
|
||||
stripeCustomerId = self._resolveStripeCustomer(mandateId)
|
||||
if not stripeCustomerId:
|
||||
raise ValueError(f"Could not resolve Stripe customer for mandate {mandateId}")
|
||||
|
||||
activeUsers = self._interface.countActiveUsers(mandateId)
|
||||
activeInstances = self._interface.countActiveFeatureInstances(mandateId)
|
||||
|
||||
lineItems = []
|
||||
if priceMapping.stripePriceIdUsers:
|
||||
lineItems.append({"price": priceMapping.stripePriceIdUsers, "quantity": max(activeUsers, 1)})
|
||||
if priceMapping.stripePriceIdInstances and activeInstances > 0:
|
||||
lineItems.append({"price": priceMapping.stripePriceIdInstances, "quantity": activeInstances})
|
||||
|
||||
if not returnUrl:
|
||||
raise ValueError("returnUrl is required for paid subscription checkout")
|
||||
|
||||
from urllib.parse import urlparse
|
||||
parsedReturn = urlparse(returnUrl)
|
||||
platformUrl = f"{parsedReturn.scheme}://{parsedReturn.netloc}" if parsedReturn.scheme else ""
|
||||
|
||||
separator = "&" if "?" in returnUrl else "?"
|
||||
successUrl = f"{returnUrl}{separator}success=true&session_id={{CHECKOUT_SESSION_ID}}"
|
||||
cancelUrl = f"{returnUrl}{separator}canceled=true"
|
||||
|
||||
subscriptionData: Dict[str, Any] = {
|
||||
"metadata": {
|
||||
"mandateId": mandateId,
|
||||
"subscriptionRecordId": subRecord["id"],
|
||||
"planKey": plan.planKey,
|
||||
"platformUrl": platformUrl,
|
||||
},
|
||||
}
|
||||
|
||||
if currentOperative and currentOperative.get("currentPeriodEnd"):
|
||||
periodEnd = currentOperative["currentPeriodEnd"]
|
||||
if isinstance(periodEnd, str):
|
||||
periodEnd = datetime.fromisoformat(periodEnd)
|
||||
trialEndTs = int(periodEnd.timestamp())
|
||||
subscriptionData["trial_end"] = trialEndTs
|
||||
self._interface.updateFields(subRecord["id"], {"effectiveFrom": periodEnd.isoformat()})
|
||||
|
||||
session = stripe.checkout.Session.create(
|
||||
mode="subscription",
|
||||
customer=stripeCustomerId,
|
||||
line_items=lineItems,
|
||||
success_url=successUrl,
|
||||
cancel_url=cancelUrl,
|
||||
subscription_data=subscriptionData,
|
||||
)
|
||||
|
||||
if not session or not session.url:
|
||||
raise ValueError("Stripe Checkout Session creation failed")
|
||||
|
||||
logger.info("Checkout session %s created for mandate %s, plan %s", session.id, mandateId, plan.planKey)
|
||||
return session.url
|
||||
|
||||
def _resolveStripeCustomer(self, mandateId: str) -> Optional[str]:
|
||||
try:
|
||||
from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface
|
||||
billingIf = getBillingInterface(self.currentUser, mandateId)
|
||||
settings = billingIf.getSettings(mandateId)
|
||||
if not settings:
|
||||
return None
|
||||
customerId = settings.get("stripeCustomerId")
|
||||
if customerId:
|
||||
return customerId
|
||||
|
||||
from modules.shared.stripeClient import getStripeClient
|
||||
stripe = getStripeClient()
|
||||
|
||||
mandateLabel = mandateId
|
||||
try:
|
||||
from modules.datamodels.datamodelUam import Mandate
|
||||
from modules.security.rootAccess import getRootDbAppConnector
|
||||
appDb = getRootDbAppConnector()
|
||||
rows = appDb.getRecordset(Mandate, recordFilter={"id": mandateId})
|
||||
if rows:
|
||||
mandateLabel = rows[0].get("label") or rows[0].get("name") or mandateId
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
customer = stripe.Customer.create(name=mandateLabel, metadata={"mandateId": mandateId})
|
||||
billingIf.updateSettings(settings["id"], {"stripeCustomerId": customer.id})
|
||||
logger.info("Stripe customer %s created for mandate %s", customer.id, mandateId)
|
||||
return customer.id
|
||||
except Exception as e:
|
||||
logger.error("_resolveStripeCustomer(%s) failed: %s", mandateId, e)
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# T7: Cancel (set recurring=false)
|
||||
# =========================================================================
|
||||
|
||||
def cancelSubscription(self, subscriptionId: str) -> Dict[str, Any]:
|
||||
"""Cancel a subscription (T7: set recurring=false, Stripe cancel_at_period_end).
|
||||
The subscription stays ACTIVE until its period ends."""
|
||||
sub = self._interface.getById(subscriptionId)
|
||||
if not sub:
|
||||
raise ValueError(f"Subscription {subscriptionId} not found")
|
||||
|
||||
status = sub.get("status", "")
|
||||
mandateId = sub["mandateId"]
|
||||
|
||||
if status == SubscriptionStatusEnum.PENDING.value:
|
||||
result = self._interface.transitionStatus(
|
||||
subscriptionId, SubscriptionStatusEnum.PENDING, SubscriptionStatusEnum.EXPIRED,
|
||||
)
|
||||
self.invalidateCache(mandateId)
|
||||
return result
|
||||
|
||||
if status == SubscriptionStatusEnum.SCHEDULED.value:
|
||||
stripeSubId = sub.get("stripeSubscriptionId")
|
||||
if stripeSubId:
|
||||
try:
|
||||
from modules.shared.stripeClient import getStripeClient
|
||||
stripe = getStripeClient()
|
||||
stripe.Subscription.cancel(stripeSubId)
|
||||
except Exception as e:
|
||||
logger.error("Failed to cancel Stripe sub %s: %s", stripeSubId, e)
|
||||
result = self._interface.transitionStatus(
|
||||
subscriptionId, SubscriptionStatusEnum.SCHEDULED, SubscriptionStatusEnum.EXPIRED,
|
||||
)
|
||||
self.invalidateCache(mandateId)
|
||||
return result
|
||||
|
||||
if status != SubscriptionStatusEnum.ACTIVE.value:
|
||||
raise ValueError(f"Cannot cancel subscription in status {status}")
|
||||
|
||||
if not sub.get("recurring", True):
|
||||
raise ValueError("Subscription is already cancelled (non-recurring)")
|
||||
|
||||
stripeSubId = sub.get("stripeSubscriptionId")
|
||||
pUrl = ""
|
||||
if stripeSubId:
|
||||
try:
|
||||
from modules.shared.stripeClient import getStripeClient
|
||||
stripe = getStripeClient()
|
||||
stripeSub = stripe.Subscription.modify(stripeSubId, cancel_at_period_end=True)
|
||||
pUrl = (stripeSub.get("metadata") or {}).get("platformUrl", "")
|
||||
except Exception as e:
|
||||
logger.error("Failed to set cancel_at_period_end for %s: %s", stripeSubId, e)
|
||||
|
||||
result = self._interface.updateFields(subscriptionId, {"recurring": False})
|
||||
self.invalidateCache(mandateId)
|
||||
|
||||
plan = _getPlan(sub.get("planKey", ""))
|
||||
_notifySubscriptionChange(mandateId, "cancelled", plan, subscriptionRecord=sub, platformUrl=pUrl)
|
||||
return result
|
||||
|
||||
# =========================================================================
|
||||
# T8: Reactivate (set recurring=true)
|
||||
# =========================================================================
|
||||
|
||||
def reactivateSubscription(self, subscriptionId: str) -> Dict[str, Any]:
|
||||
"""Reactivate a cancelled subscription before its period ends (T8: recurring=true)."""
|
||||
sub = self._interface.getById(subscriptionId)
|
||||
if not sub:
|
||||
raise ValueError(f"Subscription {subscriptionId} not found")
|
||||
|
||||
if sub.get("status") != SubscriptionStatusEnum.ACTIVE.value:
|
||||
raise ValueError(f"Can only reactivate ACTIVE subscriptions, got {sub.get('status')}")
|
||||
if sub.get("recurring", True):
|
||||
raise ValueError("Subscription is already recurring")
|
||||
|
||||
periodEnd = sub.get("currentPeriodEnd")
|
||||
if periodEnd:
|
||||
if isinstance(periodEnd, str):
|
||||
periodEnd = datetime.fromisoformat(periodEnd)
|
||||
if periodEnd <= datetime.now(timezone.utc):
|
||||
raise ValueError("Cannot reactivate — period has already ended")
|
||||
|
||||
stripeSubId = sub.get("stripeSubscriptionId")
|
||||
if stripeSubId:
|
||||
try:
|
||||
from modules.shared.stripeClient import getStripeClient
|
||||
stripe = getStripeClient()
|
||||
stripe.Subscription.modify(stripeSubId, cancel_at_period_end=False)
|
||||
except Exception as e:
|
||||
logger.error("Failed to reactivate Stripe sub %s: %s", stripeSubId, e)
|
||||
|
||||
result = self._interface.updateFields(subscriptionId, {"recurring": True})
|
||||
self.invalidateCache(sub["mandateId"])
|
||||
return result
|
||||
|
||||
# =========================================================================
|
||||
# T13: Sysadmin force-cancel
|
||||
# =========================================================================
|
||||
|
||||
def forceCancel(self, subscriptionId: str) -> Dict[str, Any]:
|
||||
"""Sysadmin force-cancel: immediately expire any non-terminal subscription."""
|
||||
sub = self._interface.getById(subscriptionId)
|
||||
if not sub:
|
||||
raise ValueError(f"Subscription {subscriptionId} not found")
|
||||
|
||||
stripeSubId = sub.get("stripeSubscriptionId")
|
||||
pUrl = ""
|
||||
if stripeSubId:
|
||||
try:
|
||||
from modules.shared.stripeClient import getStripeClient
|
||||
stripe = getStripeClient()
|
||||
stripeSub = stripe.Subscription.retrieve(stripeSubId)
|
||||
pUrl = (stripeSub.get("metadata") or {}).get("platformUrl", "")
|
||||
stripe.Subscription.cancel(stripeSubId)
|
||||
except Exception as e:
|
||||
logger.error("Failed to cancel Stripe sub %s: %s", stripeSubId, e)
|
||||
|
||||
result = self._interface.forceExpire(subscriptionId)
|
||||
mandateId = sub["mandateId"]
|
||||
self.invalidateCache(mandateId)
|
||||
|
||||
plan = _getPlan(sub.get("planKey", ""))
|
||||
_notifySubscriptionChange(mandateId, "force_cancelled", plan, subscriptionRecord=sub, platformUrl=pUrl)
|
||||
return result
|
||||
|
||||
# =========================================================================
|
||||
# T6: Trial expiry
|
||||
# =========================================================================
|
||||
|
||||
def handleTrialExpiry(self, subscriptionId: str) -> None:
|
||||
"""Expire a trial subscription (T6: TRIALING -> EXPIRED)."""
|
||||
sub = self._interface.getById(subscriptionId)
|
||||
if not sub or sub.get("status") != SubscriptionStatusEnum.TRIALING.value:
|
||||
return
|
||||
|
||||
self._interface.transitionStatus(
|
||||
subscriptionId, SubscriptionStatusEnum.TRIALING, SubscriptionStatusEnum.EXPIRED,
|
||||
)
|
||||
self.invalidateCache(sub["mandateId"])
|
||||
|
||||
plan = _getPlan(sub.get("planKey", ""))
|
||||
successorPlan = _getPlan(plan.successorPlanKey) if plan and plan.successorPlanKey else None
|
||||
_notifySubscriptionChange(sub["mandateId"], "trial_expired", successorPlan)
|
||||
logger.info("Trial expired for subscription %s", subscriptionId)
|
||||
|
||||
# =========================================================================
|
||||
# Stripe quantity sync
|
||||
# =========================================================================
|
||||
|
||||
def syncStripeQuantity(self, subscriptionId: str):
|
||||
self._interface.syncQuantityToStripe(subscriptionId)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Notifications
|
||||
# ============================================================================
|
||||
|
||||
def _notifySubscriptionChange(
|
||||
mandateId: str,
|
||||
event: str,
|
||||
plan: Optional[SubscriptionPlan] = None,
|
||||
subscriptionRecord: Optional[Dict[str, Any]] = None,
|
||||
platformUrl: str = "",
|
||||
) -> None:
|
||||
try:
|
||||
from modules.shared.notifyMandateAdmins import notifyMandateAdmins
|
||||
|
||||
planLabel = (plan.title.get("de") or plan.title.get("en") or plan.planKey) if plan else "—"
|
||||
platformHint = f"Plattform: {platformUrl}" if platformUrl else ""
|
||||
|
||||
rawHtmlBlock: Optional[str] = None
|
||||
|
||||
if event == "activated" and plan and subscriptionRecord:
|
||||
rawHtmlBlock = _buildInvoiceSummaryHtml(plan, subscriptionRecord, mandateId, platformUrl)
|
||||
elif event in ("cancelled", "force_cancelled") and subscriptionRecord:
|
||||
rawHtmlBlock = _buildCancelSummaryHtml(subscriptionRecord, platformUrl)
|
||||
|
||||
templates: Dict[str, Dict[str, Any]] = {
|
||||
"activated": {
|
||||
"subject": f"[PowerOn] Abonnement aktiviert — {planLabel}",
|
||||
"headline": "Abonnement aktiviert",
|
||||
"paragraphs": [
|
||||
p for p in [
|
||||
f"Das Abonnement wurde auf den Plan «{planLabel}» aktiviert.",
|
||||
platformHint,
|
||||
"Sie können Ihr Abonnement jederzeit unter Billing-Verwaltung › Abonnement einsehen und verwalten.",
|
||||
] if p
|
||||
],
|
||||
},
|
||||
"cancelled": {
|
||||
"subject": f"[PowerOn] Abonnement gekündigt — {planLabel}",
|
||||
"headline": "Abonnement gekündigt",
|
||||
"paragraphs": [
|
||||
p for p in [
|
||||
f"Das Abonnement «{planLabel}» wurde gekündigt.",
|
||||
platformHint,
|
||||
"Die Kündigung wird zum Ende der aktuellen bezahlten Periode wirksam. Bis dahin bleibt der volle Zugang bestehen.",
|
||||
] if p
|
||||
],
|
||||
},
|
||||
"force_cancelled": {
|
||||
"subject": f"[PowerOn] Abonnement sofort beendet — {planLabel}",
|
||||
"headline": "Abonnement sofort beendet",
|
||||
"paragraphs": [
|
||||
p for p in [
|
||||
f"Das Abonnement «{planLabel}» wurde durch den Plattform-Administrator sofort beendet.",
|
||||
platformHint,
|
||||
"Der Zugang wurde per sofort deaktiviert. Bei Fragen wenden Sie sich an den Plattform-Support.",
|
||||
] if p
|
||||
],
|
||||
},
|
||||
"trial_expired": {
|
||||
"subject": "[PowerOn] Testphase abgelaufen",
|
||||
"headline": "Testphase abgelaufen",
|
||||
"paragraphs": [
|
||||
p for p in [
|
||||
"Die kostenlose Testphase ist abgelaufen.",
|
||||
platformHint,
|
||||
"Bitte wählen Sie einen Plan unter Billing-Verwaltung › Abonnement, damit der Zugang nicht unterbrochen wird.",
|
||||
] if p
|
||||
],
|
||||
},
|
||||
"payment_failed": {
|
||||
"subject": f"[PowerOn] Zahlung fehlgeschlagen — {planLabel}",
|
||||
"headline": "Zahlung fehlgeschlagen",
|
||||
"paragraphs": [
|
||||
p for p in [
|
||||
f"Die Zahlung für das Abonnement «{planLabel}» ist fehlgeschlagen.",
|
||||
platformHint,
|
||||
"Bitte aktualisieren Sie Ihr Zahlungsmittel unter Billing-Verwaltung.",
|
||||
] if p
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
tpl = templates.get(event, {
|
||||
"subject": f"[PowerOn] Abonnement-Änderung — {planLabel}",
|
||||
"headline": "Abonnement-Änderung",
|
||||
"paragraphs": [f"Änderung am Abonnement «{planLabel}»."],
|
||||
})
|
||||
|
||||
notifyMandateAdmins(
|
||||
mandateId, tpl["subject"], tpl["headline"], tpl["paragraphs"],
|
||||
rawHtmlBlock=rawHtmlBlock,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("_notifySubscriptionChange failed for mandate %s event %s: %s", mandateId, event, e)
|
||||
|
||||
|
||||
def _buildInvoiceSummaryHtml(
|
||||
plan: SubscriptionPlan,
|
||||
subRecord: Dict[str, Any],
|
||||
mandateId: str,
|
||||
platformUrl: str = "",
|
||||
) -> str:
|
||||
"""Build an HTML invoice summary block for inclusion in the activation email."""
|
||||
import html as htmlmod
|
||||
from modules.interfaces.interfaceDbSubscription import _getRootInterface as getSubRootInterface
|
||||
|
||||
subInterface = getSubRootInterface()
|
||||
userCount = subInterface.countActiveUsers(mandateId)
|
||||
instanceCount = subInterface.countActiveFeatureInstances(mandateId)
|
||||
|
||||
userPrice = plan.pricePerUserCHF
|
||||
instancePrice = plan.pricePerFeatureInstanceCHF
|
||||
userTotal = userCount * userPrice
|
||||
instanceTotal = instanceCount * instancePrice
|
||||
netTotal = userTotal + instanceTotal
|
||||
|
||||
periodLabel = {"MONTHLY": "Monatlich", "YEARLY": "Jährlich"}.get(plan.billingPeriod, plan.billingPeriod)
|
||||
|
||||
def _chf(amount: float) -> str:
|
||||
return f"CHF {amount:,.2f}".replace(",", "'")
|
||||
|
||||
rows = ""
|
||||
if userPrice > 0:
|
||||
rows += (
|
||||
f'<tr><td style="padding:6px 0;color:#333;">Benutzer-Lizenzen</td>'
|
||||
f'<td style="padding:6px 8px;color:#555;text-align:right;">{userCount} × {_chf(userPrice)}</td>'
|
||||
f'<td style="padding:6px 0;color:#333;text-align:right;font-weight:600;">{_chf(userTotal)}</td></tr>\n'
|
||||
)
|
||||
if instancePrice > 0:
|
||||
rows += (
|
||||
f'<tr><td style="padding:6px 0;color:#333;">Feature-Instanzen</td>'
|
||||
f'<td style="padding:6px 8px;color:#555;text-align:right;">{instanceCount} × {_chf(instancePrice)}</td>'
|
||||
f'<td style="padding:6px 0;color:#333;text-align:right;font-weight:600;">{_chf(instanceTotal)}</td></tr>\n'
|
||||
)
|
||||
|
||||
invoiceLink = ""
|
||||
stripeSubId = subRecord.get("stripeSubscriptionId")
|
||||
if stripeSubId:
|
||||
try:
|
||||
from modules.shared.stripeClient import getStripeClient
|
||||
stripe = getStripeClient()
|
||||
invoices = stripe.Invoice.list(subscription=stripeSubId, limit=1)
|
||||
if invoices.data:
|
||||
hostedUrl = invoices.data[0].get("hosted_invoice_url", "")
|
||||
if hostedUrl:
|
||||
invoiceLink = (
|
||||
f'<p style="margin:12px 0 0 0;font-size:14px;">'
|
||||
f'<a href="{htmlmod.escape(hostedUrl)}" style="color:#3b82f6;text-decoration:underline;">'
|
||||
f'Vollständige Rechnung mit MwSt-Ausweis anzeigen</a></p>\n'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Could not fetch Stripe invoice URL for sub %s: %s", stripeSubId, e)
|
||||
|
||||
return (
|
||||
f'<table style="width:100%;border-collapse:collapse;font-size:14px;margin:8px 0;">'
|
||||
f'<thead><tr style="border-bottom:2px solid #e5e7eb;">'
|
||||
f'<th style="text-align:left;padding:8px 0;color:#6b7280;font-weight:500;">Position</th>'
|
||||
f'<th style="text-align:right;padding:8px;color:#6b7280;font-weight:500;">Menge × Preis</th>'
|
||||
f'<th style="text-align:right;padding:8px 0;color:#6b7280;font-weight:500;">Total</th>'
|
||||
f'</tr></thead>'
|
||||
f'<tbody>{rows}</tbody>'
|
||||
f'<tfoot><tr style="border-top:2px solid #1a1a2e;">'
|
||||
f'<td style="padding:10px 0;font-weight:700;color:#1a1a2e;">Netto-Total ({periodLabel})</td>'
|
||||
f'<td></td>'
|
||||
f'<td style="padding:10px 0;text-align:right;font-weight:700;color:#1a1a2e;font-size:16px;">{_chf(netTotal)}</td>'
|
||||
f'</tr></tfoot>'
|
||||
f'</table>'
|
||||
f'{invoiceLink}'
|
||||
)
|
||||
|
||||
|
||||
def _buildCancelSummaryHtml(subRecord: Dict[str, Any], platformUrl: str = "") -> str:
|
||||
"""Build an HTML block with billing link and Stripe invoice link for cancel emails."""
|
||||
import html as htmlmod
|
||||
|
||||
parts: list[str] = []
|
||||
|
||||
stripeSubId = subRecord.get("stripeSubscriptionId")
|
||||
if stripeSubId:
|
||||
try:
|
||||
from modules.shared.stripeClient import getStripeClient
|
||||
stripe = getStripeClient()
|
||||
invoices = stripe.Invoice.list(subscription=stripeSubId, limit=1)
|
||||
if invoices.data:
|
||||
hostedUrl = invoices.data[0].get("hosted_invoice_url", "")
|
||||
if hostedUrl:
|
||||
parts.append(
|
||||
f'<p style="margin:4px 0;font-size:14px;">'
|
||||
f'<a href="{htmlmod.escape(hostedUrl)}" style="color:#3b82f6;text-decoration:underline;">'
|
||||
f'Letzte Stripe-Rechnung anzeigen</a></p>'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Could not fetch Stripe invoice URL for sub %s: %s", stripeSubId, e)
|
||||
|
||||
return "\n".join(parts) if parts else ""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Exception Classes
|
||||
# ============================================================================
|
||||
|
||||
SUBSCRIPTION_USER_ACTION_UPGRADE = "UPGRADE_SUBSCRIPTION"
|
||||
SUBSCRIPTION_USER_ACTION_REACTIVATE = "REACTIVATE_SUBSCRIPTION"
|
||||
SUBSCRIPTION_USER_ACTION_ADD_PAYMENT = "ADD_PAYMENT_METHOD"
|
||||
|
||||
SUBSCRIPTION_REASONS = {
|
||||
"SUBSCRIPTION_INACTIVE",
|
||||
"SUBSCRIPTION_PAYMENT_REQUIRED",
|
||||
"SUBSCRIPTION_PAYMENT_PENDING",
|
||||
"SUBSCRIPTION_EXPIRED",
|
||||
}
|
||||
|
||||
|
||||
def _subscriptionReasonForStatus(status: SubscriptionStatusEnum) -> str:
|
||||
if status == SubscriptionStatusEnum.PENDING:
|
||||
return "SUBSCRIPTION_PAYMENT_PENDING"
|
||||
if status == SubscriptionStatusEnum.PAST_DUE:
|
||||
return "SUBSCRIPTION_PAYMENT_REQUIRED"
|
||||
if status == SubscriptionStatusEnum.EXPIRED:
|
||||
return "SUBSCRIPTION_EXPIRED"
|
||||
return "SUBSCRIPTION_INACTIVE"
|
||||
|
||||
|
||||
def _subscriptionUserActionForStatus(status: SubscriptionStatusEnum) -> str:
|
||||
if status in (SubscriptionStatusEnum.PAST_DUE, SubscriptionStatusEnum.PENDING):
|
||||
return SUBSCRIPTION_USER_ACTION_ADD_PAYMENT
|
||||
return SUBSCRIPTION_USER_ACTION_UPGRADE
|
||||
|
||||
|
||||
class SubscriptionInactiveException(Exception):
|
||||
def __init__(self, status: SubscriptionStatusEnum, mandateId: str = "", message: Optional[str] = None):
|
||||
self.status = status
|
||||
self.mandateId = mandateId
|
||||
self.reason = _subscriptionReasonForStatus(status)
|
||||
self.userAction = _subscriptionUserActionForStatus(status)
|
||||
self.message = message or (
|
||||
"Kein aktives Abonnement für diesen Mandanten. Bitte wählen Sie einen Plan unter Billing."
|
||||
)
|
||||
super().__init__(self.message)
|
||||
|
||||
def toClientDict(self) -> Dict[str, Any]:
|
||||
out: Dict[str, Any] = {
|
||||
"error": self.reason, "message": self.message,
|
||||
"userAction": self.userAction, "subscriptionUiPath": "/admin/billing?tab=subscription",
|
||||
}
|
||||
if self.mandateId:
|
||||
out["mandateId"] = self.mandateId
|
||||
return out
|
||||
|
||||
|
||||
class SubscriptionCapacityException(Exception):
|
||||
def __init__(self, resourceType: str, currentCount: int, maxAllowed: int, message: Optional[str] = None):
|
||||
self.resourceType = resourceType
|
||||
self.currentCount = currentCount
|
||||
self.maxAllowed = maxAllowed
|
||||
self.message = message or (
|
||||
f"Ihr Plan erlaubt maximal {maxAllowed} {'Benutzer' if resourceType == 'users' else 'Feature-Instanzen'} "
|
||||
f"(aktuell {currentCount}). Bitte wechseln Sie zu einem grösseren Plan."
|
||||
)
|
||||
super().__init__(self.message)
|
||||
|
||||
def toClientDict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"error": f"SUBSCRIPTION_{self.resourceType.upper()}_LIMIT",
|
||||
"currentCount": self.currentCount, "maxAllowed": self.maxAllowed,
|
||||
"message": self.message, "userAction": SUBSCRIPTION_USER_ACTION_UPGRADE,
|
||||
"subscriptionUiPath": "/admin/billing?tab=subscription",
|
||||
}
|
||||
|
||||
|
||||
SubscriptionService.SubscriptionInactiveException = SubscriptionInactiveException
|
||||
SubscriptionService.SubscriptionCapacityException = SubscriptionCapacityException
|
||||
|
|
@ -0,0 +1,273 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""
|
||||
Auto-provision Stripe Products and Prices from the built-in plan catalog.
|
||||
|
||||
Creates separate Stripe Products for user licenses and feature instances
|
||||
so that invoice line items show clear, descriptive names:
|
||||
- "Benutzer-Lizenzen"
|
||||
- "Feature-Instanzen"
|
||||
|
||||
Idempotent — safe to call on every startup.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Optional
|
||||
|
||||
from modules.connectors.connectorDbPostgre import DatabaseConnector
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
from modules.datamodels.datamodelSubscription import (
|
||||
BUILTIN_PLANS,
|
||||
SubscriptionPlan,
|
||||
BillingPeriodEnum,
|
||||
StripePlanPrice,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_BILLING_DATABASE = "poweron_billing"
|
||||
_METADATA_KEY = "poweron_plan_key"
|
||||
_METADATA_LINE_TYPE = "poweron_line_type"
|
||||
|
||||
_PERIOD_TO_STRIPE = {
|
||||
BillingPeriodEnum.MONTHLY: {"interval": "month", "interval_count": 1},
|
||||
BillingPeriodEnum.YEARLY: {"interval": "year", "interval_count": 1},
|
||||
}
|
||||
|
||||
|
||||
def _getBillingDb() -> DatabaseConnector:
|
||||
return DatabaseConnector(
|
||||
dbDatabase=_BILLING_DATABASE,
|
||||
dbHost=APP_CONFIG.get("DB_HOST", "localhost"),
|
||||
dbPort=int(APP_CONFIG.get("DB_PORT", "5432")),
|
||||
dbUser=APP_CONFIG.get("DB_USER"),
|
||||
dbPassword=APP_CONFIG.get("DB_PASSWORD_SECRET"),
|
||||
)
|
||||
|
||||
|
||||
def _loadExistingMappings(db: DatabaseConnector) -> Dict[str, StripePlanPrice]:
|
||||
try:
|
||||
rows = db.getRecordset(StripePlanPrice)
|
||||
result = {}
|
||||
for row in rows:
|
||||
pk = row.get("planKey")
|
||||
if pk:
|
||||
result[pk] = StripePlanPrice(**{k: v for k, v in row.items() if not k.startswith("_")})
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning("Could not load StripePlanPrice records: %s", e)
|
||||
return {}
|
||||
|
||||
|
||||
def _findStripeProduct(stripe, planKey: str, lineType: str) -> Optional[str]:
|
||||
"""Search Stripe for a product tagged with plan key + line type."""
|
||||
try:
|
||||
products = stripe.Product.search(
|
||||
query=f'metadata["{_METADATA_KEY}"]:"{planKey}" AND metadata["{_METADATA_LINE_TYPE}"]:"{lineType}"',
|
||||
limit=1,
|
||||
)
|
||||
if products.data:
|
||||
return products.data[0].id
|
||||
except Exception:
|
||||
try:
|
||||
products = stripe.Product.search(
|
||||
query=f'metadata["{_METADATA_KEY}"]:"{planKey}"',
|
||||
limit=10,
|
||||
)
|
||||
for p in products.data:
|
||||
meta = p.get("metadata") or {}
|
||||
if meta.get(_METADATA_LINE_TYPE) == lineType:
|
||||
return p.id
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _createStripeProduct(stripe, name: str, description: str, planKey: str, lineType: str) -> str:
|
||||
product = stripe.Product.create(
|
||||
name=name,
|
||||
description=description,
|
||||
metadata={_METADATA_KEY: planKey, _METADATA_LINE_TYPE: lineType},
|
||||
)
|
||||
logger.info("Created Stripe Product %s: %s (%s/%s)", product.id, name, planKey, lineType)
|
||||
return product.id
|
||||
|
||||
|
||||
def _findExistingStripePrice(stripe, productId: str, unitAmount: int, interval: str) -> Optional[str]:
|
||||
try:
|
||||
prices = stripe.Price.list(product=productId, active=True, limit=50)
|
||||
for p in prices.data:
|
||||
recurring = p.get("recurring") or {}
|
||||
if p.get("unit_amount") == unitAmount and recurring.get("interval") == interval:
|
||||
return p.id
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _getStripePriceAmount(stripe, priceId: str) -> Optional[int]:
|
||||
"""Retrieve the unit_amount (in Rappen) of an existing Stripe Price."""
|
||||
try:
|
||||
price = stripe.Price.retrieve(priceId)
|
||||
return price.get("unit_amount") if price else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _reconcilePrice(stripe, productId: str, oldPriceId: str, expectedCHF: float, interval: str, nickname: str) -> str:
|
||||
"""If the stored Stripe Price has a different amount, create a new one and deactivate the old."""
|
||||
expectedCents = int(expectedCHF * 100)
|
||||
actualCents = _getStripePriceAmount(stripe, oldPriceId)
|
||||
|
||||
if actualCents == expectedCents:
|
||||
return oldPriceId
|
||||
|
||||
logger.warning(
|
||||
"Price drift detected for %s: Stripe has %s Rappen, catalog expects %s Rappen. Rotating price.",
|
||||
oldPriceId, actualCents, expectedCents,
|
||||
)
|
||||
|
||||
existingMatch = _findExistingStripePrice(stripe, productId, expectedCents, interval)
|
||||
if existingMatch:
|
||||
newPriceId = existingMatch
|
||||
else:
|
||||
newPriceId = _createStripePrice(stripe, productId, expectedCHF, interval, nickname)
|
||||
|
||||
try:
|
||||
stripe.Price.modify(oldPriceId, active=False)
|
||||
logger.info("Deactivated old Stripe Price %s", oldPriceId)
|
||||
except Exception as e:
|
||||
logger.warning("Could not deactivate old price %s: %s", oldPriceId, e)
|
||||
|
||||
return newPriceId
|
||||
|
||||
|
||||
def _createStripePrice(stripe, productId: str, unitAmountCHF: float, interval: str, nickname: str) -> str:
|
||||
price = stripe.Price.create(
|
||||
product=productId,
|
||||
unit_amount=int(unitAmountCHF * 100),
|
||||
currency="chf",
|
||||
recurring={"interval": interval},
|
||||
nickname=nickname,
|
||||
)
|
||||
logger.info("Created Stripe Price %s (%s, %s CHF/%s)", price.id, nickname, unitAmountCHF, interval)
|
||||
return price.id
|
||||
|
||||
|
||||
def bootstrapStripePrices() -> None:
|
||||
"""Ensure all paid plans have separate Stripe Products for users and instances."""
|
||||
try:
|
||||
from modules.shared.stripeClient import getStripeClient
|
||||
stripe = getStripeClient()
|
||||
except ValueError as e:
|
||||
logger.error("Stripe not configured — cannot bootstrap subscription prices: %s", e)
|
||||
return
|
||||
|
||||
db = _getBillingDb()
|
||||
existing = _loadExistingMappings(db)
|
||||
|
||||
for planKey, plan in BUILTIN_PLANS.items():
|
||||
if plan.billingPeriod == BillingPeriodEnum.NONE:
|
||||
continue
|
||||
if plan.pricePerUserCHF == 0 and plan.pricePerFeatureInstanceCHF == 0:
|
||||
continue
|
||||
|
||||
stripePeriod = _PERIOD_TO_STRIPE.get(plan.billingPeriod)
|
||||
if not stripePeriod:
|
||||
continue
|
||||
|
||||
interval = stripePeriod["interval"]
|
||||
|
||||
if planKey in existing:
|
||||
mapping = existing[planKey]
|
||||
hasAllPrices = mapping.stripePriceIdUsers and mapping.stripePriceIdInstances
|
||||
hasAllProducts = mapping.stripeProductIdUsers and mapping.stripeProductIdInstances
|
||||
if hasAllPrices and hasAllProducts:
|
||||
changed = False
|
||||
reconciledUsers = _reconcilePrice(
|
||||
stripe, mapping.stripeProductIdUsers, mapping.stripePriceIdUsers,
|
||||
plan.pricePerUserCHF, interval, f"{planKey} — Benutzer-Lizenz",
|
||||
)
|
||||
if reconciledUsers != mapping.stripePriceIdUsers:
|
||||
changed = True
|
||||
|
||||
reconciledInstances = _reconcilePrice(
|
||||
stripe, mapping.stripeProductIdInstances, mapping.stripePriceIdInstances,
|
||||
plan.pricePerFeatureInstanceCHF, interval, f"{planKey} — Feature-Instanz",
|
||||
)
|
||||
if reconciledInstances != mapping.stripePriceIdInstances:
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
db.recordModify(StripePlanPrice, mapping.id, {
|
||||
"stripePriceIdUsers": reconciledUsers,
|
||||
"stripePriceIdInstances": reconciledInstances,
|
||||
})
|
||||
logger.info("Reconciled Stripe prices for plan %s: users=%s, instances=%s", planKey, reconciledUsers, reconciledInstances)
|
||||
else:
|
||||
logger.debug("Stripe prices up-to-date for plan %s", planKey)
|
||||
continue
|
||||
|
||||
productIdUsers = None
|
||||
productIdInstances = None
|
||||
priceIdUsers = None
|
||||
priceIdInstances = None
|
||||
|
||||
if plan.pricePerUserCHF > 0:
|
||||
productIdUsers = _findStripeProduct(stripe, planKey, "users")
|
||||
if not productIdUsers:
|
||||
productIdUsers = _createStripeProduct(
|
||||
stripe, "Benutzer-Lizenzen", f"Benutzer-Lizenzen für {plan.title.get('de', planKey)}",
|
||||
planKey, "users",
|
||||
)
|
||||
priceIdUsers = _findExistingStripePrice(stripe, productIdUsers, int(plan.pricePerUserCHF * 100), interval)
|
||||
if not priceIdUsers:
|
||||
priceIdUsers = _createStripePrice(
|
||||
stripe, productIdUsers, plan.pricePerUserCHF, interval, f"{planKey} — Benutzer-Lizenz",
|
||||
)
|
||||
|
||||
if plan.pricePerFeatureInstanceCHF > 0:
|
||||
productIdInstances = _findStripeProduct(stripe, planKey, "instances")
|
||||
if not productIdInstances:
|
||||
productIdInstances = _createStripeProduct(
|
||||
stripe, "Feature-Instanzen", f"Feature-Instanzen für {plan.title.get('de', planKey)}",
|
||||
planKey, "instances",
|
||||
)
|
||||
priceIdInstances = _findExistingStripePrice(
|
||||
stripe, productIdInstances, int(plan.pricePerFeatureInstanceCHF * 100), interval,
|
||||
)
|
||||
if not priceIdInstances:
|
||||
priceIdInstances = _createStripePrice(
|
||||
stripe, productIdInstances, plan.pricePerFeatureInstanceCHF, interval,
|
||||
f"{planKey} — Feature-Instanz",
|
||||
)
|
||||
|
||||
persistData = {
|
||||
"stripeProductId": "",
|
||||
"stripeProductIdUsers": productIdUsers,
|
||||
"stripeProductIdInstances": productIdInstances,
|
||||
"stripePriceIdUsers": priceIdUsers,
|
||||
"stripePriceIdInstances": priceIdInstances,
|
||||
}
|
||||
|
||||
if planKey in existing:
|
||||
db.recordModify(StripePlanPrice, existing[planKey].id, persistData)
|
||||
else:
|
||||
db.recordCreate(StripePlanPrice, StripePlanPrice(planKey=planKey, **persistData).model_dump())
|
||||
|
||||
logger.info(
|
||||
"Stripe bootstrapped for %s: users=%s/%s, instances=%s/%s",
|
||||
planKey, productIdUsers, priceIdUsers, productIdInstances, priceIdInstances,
|
||||
)
|
||||
|
||||
|
||||
def getStripePricesForPlan(planKey: str) -> Optional[StripePlanPrice]:
|
||||
"""Load the persisted Stripe IDs for a plan."""
|
||||
try:
|
||||
db = _getBillingDb()
|
||||
rows = db.getRecordset(StripePlanPrice, recordFilter={"planKey": planKey})
|
||||
if rows:
|
||||
return StripePlanPrice(**{k: v for k, v in rows[0].items() if not k.startswith("_")})
|
||||
except Exception as e:
|
||||
logger.error("Error loading Stripe prices for plan %s: %s", planKey, e)
|
||||
return None
|
||||
|
|
@ -258,6 +258,27 @@ def getModelAttributeDefinitions(modelClass: Type[BaseModel] = None, userLanguag
|
|||
|
||||
attributes.append(attr_def)
|
||||
|
||||
# Append system timestamp fields (set automatically by DatabaseConnector)
|
||||
systemTimestampFields = [
|
||||
("_createdAt", {"en": "Created at", "de": "Erstellt am", "fr": "Créé le"}),
|
||||
("_modifiedAt", {"en": "Modified at", "de": "Geändert am", "fr": "Modifié le"}),
|
||||
]
|
||||
for sysName, sysLabels in systemTimestampFields:
|
||||
attributes.append({
|
||||
"name": sysName,
|
||||
"type": "timestamp",
|
||||
"required": False,
|
||||
"description": "",
|
||||
"label": sysLabels.get(userLanguage, sysLabels["en"]),
|
||||
"placeholder": "",
|
||||
"editable": False,
|
||||
"visible": True,
|
||||
"order": len(attributes),
|
||||
"readonly": True,
|
||||
"options": None,
|
||||
"default": None,
|
||||
})
|
||||
|
||||
return {"model": model_label, "attributes": attributes}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ class Configuration:
|
|||
self._configMtime = currentMtime
|
||||
|
||||
try:
|
||||
with open(configPath, 'r') as f:
|
||||
with open(configPath, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
i = 0
|
||||
|
|
|
|||
285
modules/shared/notifyMandateAdmins.py
Normal file
285
modules/shared/notifyMandateAdmins.py
Normal file
|
|
@ -0,0 +1,285 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""
|
||||
Central notification utility for mandate administrators.
|
||||
|
||||
All mandate-level notifications (subscription changes, billing warnings, etc.)
|
||||
MUST go through notifyMandateAdmins() to ensure consistent recipient resolution
|
||||
and delivery.
|
||||
|
||||
Recipients are the union of:
|
||||
1. BillingSettings.notifyEmails for the mandate (configured contact addresses)
|
||||
2. All users with the mandate-level "admin" RBAC role
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import html
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from modules.datamodels.datamodelMessaging import MessagingChannel
|
||||
from modules.interfaces.interfaceMessaging import getInterface as getMessagingInterface
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Recipient resolution
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _normalizeEmailList(raw: Any) -> List[str]:
|
||||
"""Parse notifyEmails which can be a list, JSON string, or single address."""
|
||||
if raw is None:
|
||||
return []
|
||||
if isinstance(raw, list):
|
||||
return [str(e).strip().lower() for e in raw if str(e).strip()]
|
||||
if isinstance(raw, str):
|
||||
s = raw.strip()
|
||||
if not s:
|
||||
return []
|
||||
if s.startswith("["):
|
||||
try:
|
||||
parsed = json.loads(s)
|
||||
if isinstance(parsed, list):
|
||||
return [str(e).strip().lower() for e in parsed if str(e).strip()]
|
||||
except Exception:
|
||||
pass
|
||||
return [s.lower()]
|
||||
return []
|
||||
|
||||
|
||||
def _resolveMandateContactEmails(mandateId: str) -> List[str]:
|
||||
"""Get the configured notifyEmails from BillingSettings."""
|
||||
try:
|
||||
from modules.interfaces.interfaceDbBilling import getInterface as getBillingInterface
|
||||
from modules.security.rootAccess import getRootUser
|
||||
billingIf = getBillingInterface(getRootUser(), mandateId)
|
||||
settings = billingIf.getSettings(mandateId) or {}
|
||||
return _normalizeEmailList(settings.get("notifyEmails"))
|
||||
except Exception as e:
|
||||
logger.warning("Could not resolve BillingSettings.notifyEmails for mandate %s: %s", mandateId, e)
|
||||
return []
|
||||
|
||||
|
||||
def _resolveMandateAdminEmails(mandateId: str) -> List[str]:
|
||||
"""Resolve all admin users of a mandate via RBAC and return their emails."""
|
||||
emails: List[str] = []
|
||||
try:
|
||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||
rootIf = getRootInterface()
|
||||
userMandates = rootIf.getUserMandatesByMandate(mandateId)
|
||||
for um in userMandates:
|
||||
if not getattr(um, "enabled", True):
|
||||
continue
|
||||
umId = str(getattr(um, "id", ""))
|
||||
userId = getattr(um, "userId", None)
|
||||
if not userId:
|
||||
continue
|
||||
roleIds = rootIf.getRoleIdsForUserMandate(umId)
|
||||
isAdmin = False
|
||||
for roleId in roleIds:
|
||||
role = rootIf.getRole(roleId)
|
||||
if role and role.roleLabel == "admin" and not role.featureInstanceId:
|
||||
isAdmin = True
|
||||
break
|
||||
if not isAdmin:
|
||||
continue
|
||||
user = rootIf.getUser(str(userId))
|
||||
if user and user.email:
|
||||
emails.append(user.email.strip().lower())
|
||||
except Exception as e:
|
||||
logger.warning("Could not resolve admin emails for mandate %s: %s", mandateId, e)
|
||||
return emails
|
||||
|
||||
|
||||
def _resolveAllRecipients(mandateId: str) -> List[str]:
|
||||
"""Union of BillingSettings.notifyEmails + all mandate admin user emails, deduplicated."""
|
||||
seen: Set[str] = set()
|
||||
result: List[str] = []
|
||||
for email in _resolveMandateContactEmails(mandateId) + _resolveMandateAdminEmails(mandateId):
|
||||
if email and email not in seen:
|
||||
seen.add(email)
|
||||
result.append(email)
|
||||
return result
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Mandate name resolution
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _resolveMandateName(mandateId: str) -> str:
|
||||
"""Return the human-readable mandate name (label or name), falling back to a short ID."""
|
||||
try:
|
||||
from modules.datamodels.datamodelUam import Mandate
|
||||
from modules.security.rootAccess import getRootDbAppConnector
|
||||
appDb = getRootDbAppConnector()
|
||||
rows = appDb.getRecordset(Mandate, recordFilter={"id": mandateId})
|
||||
if rows:
|
||||
return rows[0].get("label") or rows[0].get("name") or mandateId[:8]
|
||||
except Exception as e:
|
||||
logger.warning("Could not resolve mandate name for %s: %s", mandateId, e)
|
||||
return mandateId[:8]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# HTML email rendering
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _getOperatorInfo() -> Dict[str, str]:
|
||||
"""Load operator company data from config.ini."""
|
||||
try:
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
return {
|
||||
"companyName": APP_CONFIG.get("Operator_CompanyName", ""),
|
||||
"address": APP_CONFIG.get("Operator_Address", ""),
|
||||
"vatNumber": APP_CONFIG.get("Operator_VatNumber", ""),
|
||||
}
|
||||
except Exception:
|
||||
return {"companyName": "", "address": "", "vatNumber": ""}
|
||||
|
||||
|
||||
def _renderHtmlEmail(
|
||||
headline: str,
|
||||
bodyParagraphs: List[str],
|
||||
mandateName: str,
|
||||
footerNote: Optional[str] = None,
|
||||
rawHtmlBlock: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Render a clean, professional HTML notification email.
|
||||
|
||||
Args:
|
||||
rawHtmlBlock: Optional pre-formatted HTML inserted after bodyParagraphs (e.g. invoice table).
|
||||
"""
|
||||
hl = html.escape(headline)
|
||||
mn = html.escape(mandateName)
|
||||
|
||||
paragraphsHtml = ""
|
||||
for p in bodyParagraphs:
|
||||
escaped = html.escape(p).replace("\n", "<br>")
|
||||
paragraphsHtml += f'<p style="margin: 0 0 14px 0; color: #333333;">{escaped}</p>\n'
|
||||
|
||||
rawBlock = ""
|
||||
if rawHtmlBlock:
|
||||
rawBlock = f'<div style="margin: 16px 0;">{rawHtmlBlock}</div>\n'
|
||||
|
||||
footer = ""
|
||||
if footerNote:
|
||||
footer = (
|
||||
f'<p style="margin: 16px 0 0 0; font-size: 13px; color: #888888;">'
|
||||
f'{html.escape(footerNote)}</p>\n'
|
||||
)
|
||||
|
||||
operator = _getOperatorInfo()
|
||||
operatorLine = ""
|
||||
parts = [p for p in [operator["companyName"], operator["address"], operator["vatNumber"]] if p]
|
||||
if parts:
|
||||
operatorLine = (
|
||||
f'<p style="margin: 4px 0 0 0; font-size: 11px; color: #b0b0b0; text-align: center;">'
|
||||
f'{html.escape(" | ".join(parts))}</p>\n'
|
||||
)
|
||||
|
||||
return f"""<!DOCTYPE html>
|
||||
<html lang="de">
|
||||
<head><meta charset="utf-8"><meta name="viewport" content="width=device-width, initial-scale=1.0"></head>
|
||||
<body style="margin: 0; padding: 0; background-color: #f4f4f7; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;">
|
||||
<table role="presentation" width="100%" cellpadding="0" cellspacing="0" style="background-color: #f4f4f7; padding: 32px 16px;">
|
||||
<tr><td align="center">
|
||||
<table role="presentation" width="560" cellpadding="0" cellspacing="0" style="background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 1px 3px rgba(0,0,0,0.08);">
|
||||
<!-- Header -->
|
||||
<tr><td style="background-color: #1a1a2e; padding: 24px 32px;">
|
||||
<h1 style="margin: 0; font-size: 18px; font-weight: 600; color: #ffffff;">PowerOn</h1>
|
||||
</td></tr>
|
||||
<!-- Body -->
|
||||
<tr><td style="padding: 32px;">
|
||||
<h2 style="margin: 0 0 8px 0; font-size: 20px; font-weight: 600; color: #1a1a2e;">{hl}</h2>
|
||||
<p style="margin: 0 0 24px 0; font-size: 14px; color: #6b7280;">Mandant: <strong>{mn}</strong></p>
|
||||
<div style="font-size: 15px; line-height: 1.6;">
|
||||
{paragraphsHtml}
|
||||
{rawBlock}
|
||||
</div>
|
||||
{footer}
|
||||
</td></tr>
|
||||
<!-- Footer -->
|
||||
<tr><td style="padding: 16px 32px; background-color: #f9fafb; border-top: 1px solid #e5e7eb;">
|
||||
<p style="margin: 0; font-size: 12px; color: #9ca3af; text-align: center;">
|
||||
Diese E-Mail wurde automatisch von PowerOn versendet.
|
||||
</p>
|
||||
{operatorLine}
|
||||
</td></tr>
|
||||
</table>
|
||||
</td></tr>
|
||||
</table>
|
||||
</body>
|
||||
</html>"""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Public API
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def notifyMandateAdmins(
|
||||
mandateId: str,
|
||||
subject: str,
|
||||
headline: str,
|
||||
bodyParagraphs: List[str],
|
||||
*,
|
||||
footerNote: Optional[str] = None,
|
||||
rawHtmlBlock: Optional[str] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Send a styled HTML notification to all mandate admins and configured contacts.
|
||||
|
||||
Args:
|
||||
mandateId: The mandate to notify admins for.
|
||||
subject: Email subject line.
|
||||
headline: Bold headline inside the email body.
|
||||
bodyParagraphs: List of paragraph strings (plain text, auto-escaped).
|
||||
footerNote: Optional small-print note below the main content.
|
||||
rawHtmlBlock: Optional pre-formatted HTML block (e.g. invoice summary table).
|
||||
|
||||
Returns:
|
||||
Number of recipients that were successfully notified.
|
||||
"""
|
||||
if not mandateId:
|
||||
return 0
|
||||
|
||||
recipients = _resolveAllRecipients(mandateId)
|
||||
if not recipients:
|
||||
logger.warning(
|
||||
"notifyMandateAdmins: no recipients found for mandate %s "
|
||||
"(no notifyEmails configured and no admin users with email)",
|
||||
mandateId,
|
||||
)
|
||||
return 0
|
||||
|
||||
mandateName = _resolveMandateName(mandateId)
|
||||
htmlMessage = _renderHtmlEmail(headline, bodyParagraphs, mandateName, footerNote, rawHtmlBlock)
|
||||
messaging = getMessagingInterface()
|
||||
successCount = 0
|
||||
|
||||
for to in recipients:
|
||||
try:
|
||||
ok = messaging.send(
|
||||
channel=MessagingChannel.EMAIL,
|
||||
recipient=to,
|
||||
subject=subject,
|
||||
message=htmlMessage,
|
||||
)
|
||||
if ok:
|
||||
successCount += 1
|
||||
else:
|
||||
logger.warning("notifyMandateAdmins: send failed for %s", to)
|
||||
except Exception as e:
|
||||
logger.error("notifyMandateAdmins: error sending to %s: %s", to, e)
|
||||
|
||||
logger.info(
|
||||
"notifyMandateAdmins: sent '%s' to %d/%d recipients for mandate %s (%s)",
|
||||
subject, successCount, len(recipients), mandateId, mandateName,
|
||||
)
|
||||
return successCount
|
||||
|
|
@ -135,8 +135,8 @@ class ProgressLogger:
|
|||
message = f"{op['service']}"
|
||||
|
||||
workflow = self.services.workflow
|
||||
if not workflow:
|
||||
logger.warning(f"Cannot log progress: no workflow available")
|
||||
if not workflow or not getattr(workflow, "id", None):
|
||||
# No workflow or no workflow.id (e.g. automation2 placeholder) - skip progress logging
|
||||
return None
|
||||
|
||||
# Validate parentOperationId exists in activeOperations or finishedOperations
|
||||
|
|
|
|||
38
modules/shared/stripeClient.py
Normal file
38
modules/shared/stripeClient.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
"""
|
||||
Central Stripe SDK initialization.
|
||||
|
||||
All Stripe interactions MUST use getStripeClient() to ensure consistent
|
||||
API key, API version, and fallback handling across billing and subscription flows.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_stripeInitialized = False
|
||||
|
||||
|
||||
def getStripeClient():
|
||||
"""
|
||||
Initialize and return the configured Stripe SDK module.
|
||||
|
||||
Raises ValueError if no Stripe secret key is configured.
|
||||
"""
|
||||
import stripe
|
||||
from modules.shared.configuration import APP_CONFIG
|
||||
|
||||
apiVersion = APP_CONFIG.get("STRIPE_API_VERSION")
|
||||
if apiVersion:
|
||||
stripe.api_version = apiVersion
|
||||
|
||||
secretKey = APP_CONFIG.get("STRIPE_SECRET_KEY_SECRET") or APP_CONFIG.get("STRIPE_SECRET_KEY")
|
||||
if not secretKey:
|
||||
raise ValueError("STRIPE_SECRET_KEY_SECRET not configured")
|
||||
|
||||
stripe.api_key = secretKey
|
||||
return stripe
|
||||
|
||||
|
||||
|
|
@ -190,7 +190,15 @@ NAVIGATION_SECTIONS = [
|
|||
"path": "/admin/billing",
|
||||
"order": 40,
|
||||
"adminOnly": True,
|
||||
"sysAdminOnly": True,
|
||||
},
|
||||
{
|
||||
"id": "admin-subscriptions",
|
||||
"objectKey": "ui.admin.subscriptions",
|
||||
"label": {"en": "Subscriptions", "de": "Abonnements", "fr": "Abonnements"},
|
||||
"icon": "FaFileContract",
|
||||
"path": "/admin/subscriptions",
|
||||
"order": 50,
|
||||
"adminOnly": True,
|
||||
},
|
||||
],
|
||||
},
|
||||
|
|
@ -274,6 +282,16 @@ NAVIGATION_SECTIONS = [
|
|||
"adminOnly": True,
|
||||
"sysAdminOnly": True,
|
||||
},
|
||||
{
|
||||
"id": "admin-automation-logs",
|
||||
"objectKey": "ui.admin.automationLogs",
|
||||
"label": {"en": "Execution Logs", "de": "Ausführungsprotokolle", "fr": "Journaux d'exécution"},
|
||||
"icon": "FaClipboardList",
|
||||
"path": "/admin/automation-logs",
|
||||
"order": 85,
|
||||
"adminOnly": True,
|
||||
"sysAdminOnly": True,
|
||||
},
|
||||
{
|
||||
"id": "admin-logs",
|
||||
"objectKey": "ui.admin.logs",
|
||||
|
|
@ -437,6 +455,11 @@ RESOURCE_OBJECTS = [
|
|||
"label": {"en": "Store: Automation", "de": "Store: Automation", "fr": "Store: Automatisation"},
|
||||
"meta": {"category": "store", "featureCode": "automation"}
|
||||
},
|
||||
{
|
||||
"objectKey": "resource.store.automation2",
|
||||
"label": {"en": "Store: Automation 2", "de": "Store: Automation 2", "fr": "Store: Automatisation 2"},
|
||||
"meta": {"category": "store", "featureCode": "automation2"}
|
||||
},
|
||||
{
|
||||
"objectKey": "resource.store.teamsbot",
|
||||
"label": {"en": "Store: Teams Bot", "de": "Store: Teams Bot", "fr": "Store: Teams Bot"},
|
||||
|
|
|
|||
2
modules/workflows/automation2/__init__.py
Normal file
2
modules/workflows/automation2/__init__.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# automation2 - n8n-style graph execution engine.
|
||||
244
modules/workflows/automation2/executionEngine.py
Normal file
244
modules/workflows/automation2/executionEngine.py
Normal file
|
|
@ -0,0 +1,244 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# Main execution engine for automation2 graphs.
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Any, List, Set, Optional
|
||||
|
||||
from modules.workflows.automation2.graphUtils import (
|
||||
parseGraph,
|
||||
buildConnectionMap,
|
||||
validateGraph,
|
||||
topoSort,
|
||||
getInputSources,
|
||||
)
|
||||
|
||||
from modules.workflows.automation2.executors import (
|
||||
TriggerExecutor,
|
||||
FlowExecutor,
|
||||
DataExecutor,
|
||||
ActionNodeExecutor,
|
||||
InputExecutor,
|
||||
PauseForHumanTaskError,
|
||||
PauseForEmailWaitError,
|
||||
)
|
||||
from modules.features.automation2.nodeDefinitions import STATIC_NODE_TYPES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _getNodeTypeIds(services: Any = None) -> Set[str]:
|
||||
"""Collect all known node type IDs from static definitions."""
|
||||
return {n["id"] for n in STATIC_NODE_TYPES}
|
||||
|
||||
|
||||
def _getExecutor(
|
||||
nodeType: str,
|
||||
services: Any,
|
||||
automation2_interface: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""Dispatch to correct executor based on node type."""
|
||||
if nodeType.startswith("trigger."):
|
||||
return TriggerExecutor()
|
||||
if nodeType.startswith("flow."):
|
||||
return FlowExecutor()
|
||||
if nodeType.startswith("data."):
|
||||
return DataExecutor()
|
||||
if nodeType.startswith("ai.") or nodeType.startswith("email.") or nodeType.startswith("sharepoint."):
|
||||
return ActionNodeExecutor(services)
|
||||
if nodeType.startswith("input.") and automation2_interface:
|
||||
return InputExecutor(automation2_interface)
|
||||
return None
|
||||
|
||||
|
||||
async def executeGraph(
|
||||
graph: Dict[str, Any],
|
||||
services: Any,
|
||||
workflowId: str = None,
|
||||
instanceId: str = None,
|
||||
userId: str = None,
|
||||
mandateId: str = None,
|
||||
automation2_interface: Optional[Any] = None,
|
||||
initialNodeOutputs: Optional[Dict[str, Any]] = None,
|
||||
startAfterNodeId: Optional[str] = None,
|
||||
runId: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute automation2 graph. Returns { success, nodeOutputs, error?, stopped? }.
|
||||
When an input node is reached and automation2_interface is provided, creates a task,
|
||||
pauses the run, and returns { success: False, paused: True, taskId, runId }.
|
||||
For resume: pass initialNodeOutputs (with result for the human node) and startAfterNodeId.
|
||||
"""
|
||||
logger.info(
|
||||
"executeGraph start: instanceId=%s workflowId=%s userId=%s mandateId=%s resume=%s",
|
||||
instanceId,
|
||||
workflowId,
|
||||
userId,
|
||||
mandateId,
|
||||
startAfterNodeId is not None,
|
||||
)
|
||||
from modules.workflows.processing.shared.methodDiscovery import discoverMethods
|
||||
discoverMethods(services)
|
||||
nodeTypeIds = _getNodeTypeIds(services)
|
||||
logger.debug("executeGraph nodeTypeIds (%d): %s", len(nodeTypeIds), sorted(nodeTypeIds))
|
||||
errors = validateGraph(graph, nodeTypeIds)
|
||||
if errors:
|
||||
logger.warning("executeGraph validation failed: %s", errors)
|
||||
return {"success": False, "error": "; ".join(errors), "nodeOutputs": {}}
|
||||
|
||||
nodes, connections = parseGraph(graph)[:2]
|
||||
connectionMap = buildConnectionMap(connections)
|
||||
inputSources = {n["id"]: getInputSources(n["id"], connectionMap) for n in nodes if n.get("id")}
|
||||
logger.info(
|
||||
"executeGraph parsed: nodes=%d connections=%d connectionMap_targets=%s",
|
||||
len(nodes),
|
||||
len(connections),
|
||||
list(connectionMap.keys()),
|
||||
)
|
||||
|
||||
ordered = topoSort(nodes, connectionMap)
|
||||
ordered_ids = [n.get("id") for n in ordered if n.get("id")]
|
||||
logger.info("executeGraph topoSort order: %s", ordered_ids)
|
||||
|
||||
nodeOutputs: Dict[str, Any] = dict(initialNodeOutputs or {})
|
||||
is_resume = startAfterNodeId is not None
|
||||
if not runId and automation2_interface and workflowId and not is_resume:
|
||||
run_context = {
|
||||
"connectionMap": connectionMap,
|
||||
"inputSources": inputSources,
|
||||
"orderedNodeIds": ordered_ids,
|
||||
}
|
||||
if userId:
|
||||
run_context["ownerId"] = userId
|
||||
if mandateId:
|
||||
run_context["mandateId"] = mandateId
|
||||
if instanceId:
|
||||
run_context["instanceId"] = instanceId
|
||||
run = automation2_interface.createRun(
|
||||
workflowId=workflowId,
|
||||
nodeOutputs=nodeOutputs,
|
||||
context=run_context,
|
||||
)
|
||||
runId = run.get("id") if run else None
|
||||
logger.info("executeGraph created run %s", runId)
|
||||
|
||||
context = {
|
||||
"workflowId": workflowId,
|
||||
"instanceId": instanceId,
|
||||
"userId": userId,
|
||||
"mandateId": mandateId,
|
||||
"nodeOutputs": nodeOutputs,
|
||||
"connectionMap": connectionMap,
|
||||
"inputSources": inputSources,
|
||||
"services": services,
|
||||
"_runId": runId,
|
||||
"_orderedNodes": ordered,
|
||||
}
|
||||
|
||||
skip_until_passed = bool(startAfterNodeId)
|
||||
for i, node in enumerate(ordered):
|
||||
if skip_until_passed:
|
||||
if node.get("id") == startAfterNodeId:
|
||||
skip_until_passed = False
|
||||
continue
|
||||
if context.get("_stopped"):
|
||||
logger.info("executeGraph stopped early (flow.stop) at step %d", i)
|
||||
break
|
||||
nodeId = node.get("id")
|
||||
nodeType = node.get("type", "")
|
||||
executor = _getExecutor(nodeType, services, automation2_interface)
|
||||
logger.info(
|
||||
"executeGraph step %d/%d: nodeId=%s nodeType=%s executor=%s",
|
||||
i + 1,
|
||||
len(ordered),
|
||||
nodeId,
|
||||
nodeType,
|
||||
type(executor).__name__ if executor else "None",
|
||||
)
|
||||
if not executor:
|
||||
nodeOutputs[nodeId] = None
|
||||
logger.debug("executeGraph node %s: no executor, output=None", nodeId)
|
||||
continue
|
||||
try:
|
||||
result = await executor.execute(node, context)
|
||||
nodeOutputs[nodeId] = result
|
||||
logger.info(
|
||||
"executeGraph node %s done: result_type=%s result_keys=%s",
|
||||
nodeId,
|
||||
type(result).__name__,
|
||||
list(result.keys()) if isinstance(result, dict) else "n/a",
|
||||
)
|
||||
except PauseForHumanTaskError as e:
|
||||
logger.info("executeGraph paused for human task %s", e.taskId)
|
||||
return {
|
||||
"success": False,
|
||||
"paused": True,
|
||||
"taskId": e.taskId,
|
||||
"runId": e.runId,
|
||||
"nodeId": e.nodeId,
|
||||
"nodeOutputs": dict(nodeOutputs),
|
||||
}
|
||||
except PauseForEmailWaitError as e:
|
||||
logger.info("executeGraph paused for email wait (run %s, node %s)", e.runId, e.nodeId)
|
||||
# Start email poller on-demand (only runs while workflows wait for email)
|
||||
try:
|
||||
from modules.interfaces.interfaceDbApp import getRootInterface
|
||||
from modules.features.automation2.emailPoller import ensureRunning
|
||||
root = getRootInterface()
|
||||
event_user = root.getUserByUsername("event") if root else None
|
||||
if event_user:
|
||||
ensureRunning(event_user)
|
||||
except Exception as poll_err:
|
||||
logger.warning("Could not start email poller: %s", poll_err)
|
||||
paused_at = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
run_ctx = {
|
||||
"connectionMap": context.get("connectionMap"),
|
||||
"inputSources": context.get("inputSources"),
|
||||
"orderedNodeIds": [n.get("id") for n in context.get("_orderedNodes", []) if n.get("id")],
|
||||
"waitReason": "email",
|
||||
"waitConfig": e.waitConfig,
|
||||
"pausedAt": paused_at,
|
||||
"lastCheckedAt": None,
|
||||
"ownerId": context.get("userId"),
|
||||
"mandateId": context.get("mandateId"),
|
||||
"instanceId": context.get("instanceId"),
|
||||
}
|
||||
automation2_interface.updateRun(
|
||||
e.runId,
|
||||
status="paused",
|
||||
nodeOutputs=dict(nodeOutputs),
|
||||
currentNodeId=e.nodeId,
|
||||
context=run_ctx,
|
||||
)
|
||||
return {
|
||||
"success": False,
|
||||
"paused": True,
|
||||
"waitReason": "email",
|
||||
"runId": e.runId,
|
||||
"nodeId": e.nodeId,
|
||||
"nodeOutputs": dict(nodeOutputs),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception("executeGraph node %s (%s) FAILED: %s", nodeId, nodeType, e)
|
||||
nodeOutputs[nodeId] = {"error": str(e), "success": False}
|
||||
if runId and automation2_interface:
|
||||
automation2_interface.updateRun(runId, status="failed", nodeOutputs=nodeOutputs)
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"nodeOutputs": nodeOutputs,
|
||||
"failedNode": nodeId,
|
||||
}
|
||||
|
||||
if runId and automation2_interface:
|
||||
automation2_interface.updateRun(runId, status="completed", nodeOutputs=nodeOutputs)
|
||||
logger.info(
|
||||
"executeGraph complete: success=True nodeOutputs_keys=%s stopped=%s",
|
||||
list(nodeOutputs.keys()),
|
||||
context.get("_stopped", False),
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"nodeOutputs": nodeOutputs,
|
||||
"stopped": context.get("_stopped", False),
|
||||
}
|
||||
18
modules/workflows/automation2/executors/__init__.py
Normal file
18
modules/workflows/automation2/executors/__init__.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# Executors for automation2 node types.
|
||||
|
||||
from .triggerExecutor import TriggerExecutor
|
||||
from .flowExecutor import FlowExecutor
|
||||
from .dataExecutor import DataExecutor
|
||||
from .actionNodeExecutor import ActionNodeExecutor
|
||||
from .inputExecutor import InputExecutor, PauseForHumanTaskError, PauseForEmailWaitError
|
||||
|
||||
__all__ = [
|
||||
"TriggerExecutor",
|
||||
"FlowExecutor",
|
||||
"DataExecutor",
|
||||
"ActionNodeExecutor",
|
||||
"InputExecutor",
|
||||
"PauseForHumanTaskError",
|
||||
"PauseForEmailWaitError",
|
||||
]
|
||||
599
modules/workflows/automation2/executors/actionNodeExecutor.py
Normal file
599
modules/workflows/automation2/executors/actionNodeExecutor.py
Normal file
|
|
@ -0,0 +1,599 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# Action node executor - maps ai.*, email.*, sharepoint.* to method actions via ActionExecutor.
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _getNodeDefinition(nodeType: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get node definition by type id for _method, _action, _paramMap."""
|
||||
from modules.features.automation2.nodeDefinitions import STATIC_NODE_TYPES
|
||||
for node in STATIC_NODE_TYPES:
|
||||
if node.get("id") == nodeType:
|
||||
return node
|
||||
return None
|
||||
|
||||
|
||||
def _resolveConnectionIdToReference(chatService, connectionId: str, services=None) -> Optional[str]:
|
||||
"""
|
||||
Resolve connectionId (UserConnection.id) to connectionReference format.
|
||||
connectionReference format: connection:{authority}:{externalUsername}
|
||||
Falls back to interfaceDbApp.getUserConnectionById when chatService resolution fails.
|
||||
"""
|
||||
if not connectionId:
|
||||
return None
|
||||
# Already in reference format
|
||||
if isinstance(connectionId, str) and connectionId.startswith("connection:"):
|
||||
return connectionId
|
||||
# Try chatService first
|
||||
if chatService:
|
||||
try:
|
||||
connections = chatService.getUserConnections()
|
||||
for c in connections or []:
|
||||
conn = c if isinstance(c, dict) else (c.model_dump() if hasattr(c, "model_dump") else {})
|
||||
if str(conn.get("id")) == str(connectionId):
|
||||
authority = conn.get("authority")
|
||||
if hasattr(authority, "value"):
|
||||
authority = authority.value
|
||||
username = conn.get("externalUsername", "")
|
||||
return f"connection:{authority}:{username}"
|
||||
except Exception as e:
|
||||
logger.debug("_resolveConnectionIdToReference chatService: %s", e)
|
||||
# Fallback: interfaceDbApp.getUserConnectionById (automation2 may not have chat.getUserConnections)
|
||||
app = getattr(services, "interfaceDbApp", None) if services else None
|
||||
if app and hasattr(app, "getUserConnectionById"):
|
||||
try:
|
||||
conn = app.getUserConnectionById(str(connectionId))
|
||||
if conn:
|
||||
authority = getattr(conn, "authority", None)
|
||||
if hasattr(authority, "value"):
|
||||
authority = authority.value
|
||||
else:
|
||||
authority = str(authority) if authority else "outlook"
|
||||
username = getattr(conn, "externalUsername", "") or ""
|
||||
return f"connection:{authority}:{username}"
|
||||
except Exception as e:
|
||||
logger.debug("_resolveConnectionIdToReference getUserConnectionById: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def _extractEmailContentFromUpstream(inp: Any) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Extract {subject, body, to} from upstream node output (e.g. AI node returning JSON).
|
||||
Expects JSON like {"subject": "...", "body": "...", "to": "..."} in documentData.
|
||||
"""
|
||||
if not inp:
|
||||
return None
|
||||
import json
|
||||
docs = inp.get("documents", inp.get("documentList", [])) if isinstance(inp, dict) else []
|
||||
if not docs:
|
||||
return None
|
||||
doc = docs[0] if isinstance(docs, list) else docs
|
||||
raw = getattr(doc, "documentData", None) if hasattr(doc, "documentData") else (doc.get("documentData") if isinstance(doc, dict) else None)
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
data = json.loads(raw) if isinstance(raw, str) else raw
|
||||
if isinstance(data, dict) and data.get("subject") and data.get("body"):
|
||||
return {
|
||||
"subject": str(data.get("subject", "")),
|
||||
"body": str(data.get("body", "")),
|
||||
"to": data.get("to"),
|
||||
}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _extractContextFromUpstream(inp: Any) -> Optional[str]:
|
||||
"""
|
||||
Extract plain text context from upstream node output (e.g. AI node returning txt).
|
||||
Use when _extractEmailContentFromUpstream returns None – the generated document content
|
||||
(email body, summary, etc.) should be passed as context to email.draftEmail.
|
||||
"""
|
||||
if not inp:
|
||||
return None
|
||||
docs = None
|
||||
if isinstance(inp, dict):
|
||||
docs = inp.get("documents") or inp.get("documentList")
|
||||
if not docs and isinstance(inp.get("data"), dict):
|
||||
docs = inp.get("data", {}).get("documents")
|
||||
if not docs or not isinstance(docs, (list, tuple)):
|
||||
return None
|
||||
doc = docs[0] if docs else None
|
||||
if not doc:
|
||||
return None
|
||||
raw = getattr(doc, "documentData", None) if hasattr(doc, "documentData") else (doc.get("documentData") or doc.get("content") if isinstance(doc, dict) else None)
|
||||
if not raw:
|
||||
return None
|
||||
if isinstance(raw, bytes):
|
||||
return raw.decode("utf-8", errors="replace").strip()
|
||||
s = str(raw).strip()
|
||||
return s if s else None
|
||||
|
||||
|
||||
def _gatherAttachmentDocumentsFromUpstream(
|
||||
nodeId: str,
|
||||
inputSources: Dict[str, Dict[int, tuple]],
|
||||
nodeOutputs: Dict[str, Any],
|
||||
orderedNodes: List[Dict],
|
||||
visited: Optional[set] = None,
|
||||
) -> List[Any]:
|
||||
"""
|
||||
Walk upstream from nodeId through AI nodes to collect file documents (e.g. from sharepoint.downloadFile).
|
||||
Used when email.draftEmail has AI upstream – attachments come from file nodes, not AI output.
|
||||
"""
|
||||
visited = visited or set()
|
||||
if nodeId in visited:
|
||||
return []
|
||||
visited.add(nodeId)
|
||||
docs = []
|
||||
src = inputSources.get(nodeId, {}).get(0)
|
||||
if not src:
|
||||
return []
|
||||
srcId, _ = src
|
||||
srcNode = next((n for n in (orderedNodes or []) if n.get("id") == srcId), None)
|
||||
srcType = (srcNode or {}).get("type", "")
|
||||
out = nodeOutputs.get(srcId)
|
||||
|
||||
if srcType in ("sharepoint.downloadFile", "sharepoint.readFile"):
|
||||
if isinstance(out, dict):
|
||||
for d in out.get("documents") or out.get("documentList") or []:
|
||||
if isinstance(d, dict) and (d.get("documentData") or (d.get("validationMetadata") or {}).get("fileId")):
|
||||
docs.append(d)
|
||||
elif hasattr(d, "documentData") or (getattr(d, "validationMetadata", None) or {}).get("fileId"):
|
||||
docs.append(d.model_dump() if hasattr(d, "model_dump") else d)
|
||||
elif srcType.startswith("ai."):
|
||||
docs.extend(
|
||||
_gatherAttachmentDocumentsFromUpstream(srcId, inputSources, nodeOutputs, orderedNodes, visited)
|
||||
)
|
||||
return docs
|
||||
|
||||
|
||||
def _getIncomingEmailFromUpstream(
|
||||
nodeId: str,
|
||||
inputSources: Dict[str, Dict[int, tuple]],
|
||||
nodeOutputs: Dict[str, Any],
|
||||
orderedNodes: List[Dict],
|
||||
) -> Optional[tuple]:
|
||||
"""
|
||||
Walk upstream from draftEmail to find email.checkEmail/searchEmail and return (context, documentList).
|
||||
context = formatted incoming email(s) for composeAndDraftEmail.
|
||||
documentList = documents from the email node for attachment/context.
|
||||
"""
|
||||
src = inputSources.get(nodeId, {}).get(0)
|
||||
if not src:
|
||||
return None
|
||||
srcId, _ = src
|
||||
srcNode = next((n for n in (orderedNodes or []) if n.get("id") == srcId), None)
|
||||
srcType = (srcNode or {}).get("type", "")
|
||||
|
||||
# Direct connection to email node
|
||||
if srcType in ("email.checkEmail", "email.searchEmail"):
|
||||
out = nodeOutputs.get(srcId)
|
||||
return _formatEmailOutputAsContext(out)
|
||||
|
||||
# Connected via AI node: walk one more step to email source
|
||||
if srcType.startswith("ai."):
|
||||
src2 = inputSources.get(srcId, {}).get(0)
|
||||
if not src2:
|
||||
return None
|
||||
emailNodeId, _ = src2
|
||||
emailNode = next((n for n in (orderedNodes or []) if n.get("id") == emailNodeId), None)
|
||||
if (emailNode or {}).get("type") in ("email.checkEmail", "email.searchEmail"):
|
||||
out = nodeOutputs.get(emailNodeId)
|
||||
return _formatEmailOutputAsContext(out)
|
||||
return None
|
||||
|
||||
|
||||
def _formatEmailOutputAsContext(out: Any) -> Optional[tuple]:
|
||||
"""Format email node output as (context, documentList, reply_to) for composeAndDraftEmail.
|
||||
reply_to = sender address of first email (recipient for the reply).
|
||||
"""
|
||||
if not out:
|
||||
return None
|
||||
docs = out.get("documents", out.get("documentList", [])) if isinstance(out, dict) else []
|
||||
if not docs:
|
||||
return None
|
||||
doc = docs[0] if isinstance(docs, list) else docs
|
||||
raw = getattr(doc, "documentData", None) if hasattr(doc, "documentData") else (doc.get("documentData") if isinstance(doc, dict) else None)
|
||||
if not raw:
|
||||
return None
|
||||
import json
|
||||
try:
|
||||
data = json.loads(raw) if isinstance(raw, str) else raw
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
# readEmails: data.emails.emails | searchEmails: data.searchResults.results
|
||||
emails_data = data.get("emails") or {}
|
||||
emails_list = emails_data.get("emails", []) if isinstance(emails_data, dict) else []
|
||||
if not emails_list:
|
||||
search_results = data.get("searchResults") or {}
|
||||
emails_list = search_results.get("results", []) if isinstance(search_results, dict) else []
|
||||
if not emails_list:
|
||||
return None
|
||||
reply_to = None
|
||||
parts = ["Reply to the following email(s):", ""]
|
||||
for i, em in enumerate(emails_list[:5]): # max 5
|
||||
if not isinstance(em, dict):
|
||||
continue
|
||||
fr = em.get("from", em.get("sender", {}))
|
||||
addr = fr.get("emailAddress", {}) if isinstance(fr, dict) else {}
|
||||
from_str = addr.get("address", "") or addr.get("name", "")
|
||||
if from_str and not reply_to:
|
||||
reply_to = addr.get("address", "") or from_str
|
||||
subj = em.get("subject", "")
|
||||
body = em.get("bodyPreview", "") or (em.get("body") or {}).get("content", "") if isinstance(em.get("body"), dict) else ""
|
||||
if body and len(str(body)) > 1500:
|
||||
body = str(body)[:1500] + "..."
|
||||
parts.append(f"From: {from_str}")
|
||||
parts.append(f"Subject: {subj}")
|
||||
parts.append(f"Content:\n{body}")
|
||||
parts.append("")
|
||||
if reply_to:
|
||||
parts.insert(2, f"Recipient (reply to this address): {reply_to}")
|
||||
parts.insert(3, "")
|
||||
context = "\n".join(parts).strip()
|
||||
return (context, docs, reply_to)
|
||||
|
||||
|
||||
def _buildSearchQuery(
|
||||
query: str = None,
|
||||
fromAddress: str = None,
|
||||
toAddress: str = None,
|
||||
subjectContains: str = None,
|
||||
bodyContains: str = None,
|
||||
hasAttachment: bool = None,
|
||||
filter: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build Microsoft Graph $search query from discrete params.
|
||||
Uses KQL: from:, to:, subject:, body:, hasattachments: (supported by Graph API).
|
||||
"""
|
||||
if filter and str(filter).strip():
|
||||
return str(filter).strip()
|
||||
parts = []
|
||||
if query and str(query).strip():
|
||||
parts.append(str(query).strip())
|
||||
if fromAddress and str(fromAddress).strip():
|
||||
safe = str(fromAddress).strip().replace('"', '')
|
||||
parts.append(f'from:{safe}')
|
||||
if toAddress and str(toAddress).strip():
|
||||
safe = str(toAddress).strip().replace('"', '')
|
||||
parts.append(f'to:{safe}')
|
||||
if subjectContains and str(subjectContains).strip():
|
||||
safe = str(subjectContains).strip().replace('"', '')
|
||||
parts.append(f'subject:{safe}')
|
||||
if bodyContains and str(bodyContains).strip():
|
||||
safe = str(bodyContains).strip().replace('"', '')
|
||||
parts.append(f'body:{safe}')
|
||||
if hasAttachment is True:
|
||||
parts.append("hasattachments:true")
|
||||
return " ".join(parts) if parts else "*"
|
||||
|
||||
|
||||
def _buildEmailFilter(fromAddress: str = None, subjectContains: str = None, hasAttachment: bool = None) -> str:
|
||||
"""
|
||||
Build Microsoft Graph API $filter string from discrete email filter params.
|
||||
Used for email.checkEmail (and trigger.newEmail).
|
||||
"""
|
||||
parts = []
|
||||
if fromAddress and str(fromAddress).strip():
|
||||
safe = str(fromAddress).strip().replace("'", "''")
|
||||
parts.append(f"from/emailAddress/address eq '{safe}'")
|
||||
if subjectContains and str(subjectContains).strip():
|
||||
safe = str(subjectContains).strip().replace("'", "''")
|
||||
parts.append(f"contains(subject,'{safe}')")
|
||||
if hasAttachment is True:
|
||||
parts.append("hasAttachments eq true")
|
||||
return " and ".join(parts) if parts else ""
|
||||
|
||||
|
||||
def _buildActionParams(
|
||||
node: Dict[str, Any],
|
||||
nodeDef: Dict[str, Any],
|
||||
resolvedParams: Dict[str, Any],
|
||||
chatService,
|
||||
services=None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build params for ActionExecutor from node parameters using _paramMap.
|
||||
Resolves connectionId -> connectionReference.
|
||||
Handles _contextFrom for composite params (e.g. email.draftEmail subject+body -> context).
|
||||
"""
|
||||
params = dict(resolvedParams)
|
||||
paramMap = nodeDef.get("_paramMap") or {}
|
||||
contextFrom = nodeDef.get("_contextFrom") or []
|
||||
|
||||
# email.checkEmail: build filter from discrete params (fromAddress, subjectContains, hasAttachment)
|
||||
nodeType = node.get("type", "")
|
||||
if nodeType == "email.checkEmail":
|
||||
built = _buildEmailFilter(
|
||||
fromAddress=params.get("fromAddress"),
|
||||
subjectContains=params.get("subjectContains"),
|
||||
hasAttachment=params.get("hasAttachment"),
|
||||
)
|
||||
raw_filter = (params.get("filter") or "").strip()
|
||||
params["filter"] = built if built else (raw_filter if raw_filter else None)
|
||||
params.pop("fromAddress", None)
|
||||
params.pop("subjectContains", None)
|
||||
params.pop("hasAttachment", None)
|
||||
|
||||
# email.searchEmail: build query from discrete params (fromAddress, toAddress, subjectContains, bodyContains, hasAttachment)
|
||||
if nodeType == "email.searchEmail":
|
||||
built = _buildSearchQuery(
|
||||
query=params.get("query"),
|
||||
fromAddress=params.get("fromAddress"),
|
||||
toAddress=params.get("toAddress"),
|
||||
subjectContains=params.get("subjectContains"),
|
||||
bodyContains=params.get("bodyContains"),
|
||||
hasAttachment=params.get("hasAttachment"),
|
||||
filter=params.get("filter"),
|
||||
)
|
||||
params["query"] = built
|
||||
params.pop("fromAddress", None)
|
||||
params.pop("toAddress", None)
|
||||
params.pop("subjectContains", None)
|
||||
params.pop("bodyContains", None)
|
||||
params.pop("hasAttachment", None)
|
||||
params.pop("filter", None)
|
||||
|
||||
# Resolve connectionId to connectionReference
|
||||
if "connectionId" in params:
|
||||
connId = params.get("connectionId")
|
||||
if connId:
|
||||
ref = _resolveConnectionIdToReference(chatService, connId, services)
|
||||
if ref:
|
||||
params["connectionReference"] = ref
|
||||
else:
|
||||
logger.warning(f"Could not resolve connectionId {connId} to connectionReference")
|
||||
params.pop("connectionId", None)
|
||||
|
||||
# Build context from multiple params (e.g. subject + body for draft email)
|
||||
if contextFrom:
|
||||
parts = []
|
||||
for key in contextFrom:
|
||||
val = params.get(key)
|
||||
if val:
|
||||
if key == "subject":
|
||||
parts.append(f"Subject: {val}")
|
||||
elif key == "body":
|
||||
parts.append(f"Body:\n{val}")
|
||||
else:
|
||||
parts.append(str(val))
|
||||
if parts:
|
||||
params["context"] = "\n\n".join(parts)
|
||||
for k in contextFrom:
|
||||
params.pop(k, None)
|
||||
|
||||
# Apply paramMap: node param name -> action param name
|
||||
result = {}
|
||||
mappedNodeKeys = {nodeKey for nodeKey, actionKey in paramMap.items() if actionKey and nodeKey in params}
|
||||
for nodeKey, actionKey in paramMap.items():
|
||||
if nodeKey in params and actionKey:
|
||||
result[actionKey] = params[nodeKey]
|
||||
# Pass through params not used as source for mapping
|
||||
for k, v in params.items():
|
||||
if k not in mappedNodeKeys and k not in result:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
|
||||
class ActionNodeExecutor:
|
||||
"""Execute ai.*, email.*, sharepoint.* nodes by mapping to method actions."""
|
||||
|
||||
def __init__(self, services: Any):
|
||||
self.services = services
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
node: Dict[str, Any],
|
||||
context: Dict[str, Any],
|
||||
) -> Any:
|
||||
from modules.features.automation2.nodeRegistry import getNodeTypeToMethodAction
|
||||
from modules.workflows.automation2.graphUtils import resolveParameterReferences
|
||||
from modules.workflows.processing.core.actionExecutor import ActionExecutor
|
||||
|
||||
nodeType = node.get("type", "")
|
||||
nodeId = node.get("id", "")
|
||||
logger.info("ActionNodeExecutor node %s type=%s", nodeId, nodeType)
|
||||
|
||||
mapping = getNodeTypeToMethodAction()
|
||||
methodAction = mapping.get(nodeType)
|
||||
if not methodAction:
|
||||
logger.debug("ActionNodeExecutor node %s not in mapping -> None", nodeId)
|
||||
return None
|
||||
|
||||
methodName, actionName = methodAction
|
||||
logger.info("ActionNodeExecutor node %s method=%s action=%s", nodeId, methodName, actionName)
|
||||
|
||||
nodeDef = _getNodeDefinition(nodeType)
|
||||
params = dict(node.get("parameters") or {})
|
||||
resolvedParams = resolveParameterReferences(params, context.get("nodeOutputs", {}))
|
||||
|
||||
# Merge input from connected nodes (documentList, etc.)
|
||||
inputSources = context.get("inputSources", {}).get(nodeId, {})
|
||||
if 0 in inputSources:
|
||||
srcId, _ = inputSources[0]
|
||||
inp = context.get("nodeOutputs", {}).get(srcId)
|
||||
if isinstance(inp, dict):
|
||||
resolvedParams.setdefault("documentList", inp.get("documents", inp.get("documentList", [])))
|
||||
elif inp is not None:
|
||||
resolvedParams.setdefault("input", inp)
|
||||
|
||||
# ai.prompt with email upstream: inject actual email content into prompt so AI has context
|
||||
# (getChatDocumentsFromDocumentList fails in automation2 – workflow has no messages)
|
||||
if nodeType.startswith("ai."):
|
||||
orderedNodes = context.get("_orderedNodes") or []
|
||||
if 0 in inputSources:
|
||||
srcId, _ = inputSources[0]
|
||||
srcNode = next((n for n in orderedNodes if n.get("id") == srcId), None)
|
||||
srcType = (srcNode or {}).get("type", "")
|
||||
if srcType in ("email.checkEmail", "email.searchEmail"):
|
||||
incoming = _getIncomingEmailFromUpstream(
|
||||
nodeId,
|
||||
context.get("inputSources", {}),
|
||||
context.get("nodeOutputs", {}),
|
||||
orderedNodes,
|
||||
)
|
||||
if incoming:
|
||||
ctx, _doc_list, _reply_to = incoming
|
||||
if ctx and ctx.strip():
|
||||
base_prompt = (resolvedParams.get("aiPrompt") or "").strip()
|
||||
resolvedParams["aiPrompt"] = (
|
||||
f"Eingehende E-Mail:\n{ctx}\n\nAufgabe: {base_prompt}"
|
||||
if base_prompt
|
||||
else f"Eingehende E-Mail:\n{ctx}"
|
||||
)
|
||||
logger.debug("ai.prompt: injected email context from upstream %s", srcType)
|
||||
|
||||
chatService = getattr(self.services, "chat", None)
|
||||
actionParams = _buildActionParams(node, nodeDef or {}, resolvedParams, chatService, self.services)
|
||||
|
||||
# email.checkEmail: pause and wait for new email (background poller will resume)
|
||||
if nodeType == "email.checkEmail":
|
||||
runId = context.get("_runId")
|
||||
workflowId = context.get("workflowId")
|
||||
connRef = actionParams.get("connectionReference")
|
||||
if runId and workflowId and connRef:
|
||||
from modules.workflows.automation2.executors import PauseForEmailWaitError
|
||||
waitConfig = {
|
||||
"connectionReference": connRef,
|
||||
"folder": actionParams.get("folder", "Inbox"),
|
||||
"limit": min(int(actionParams.get("limit") or 10), 50),
|
||||
"filter": actionParams.get("filter"),
|
||||
}
|
||||
raise PauseForEmailWaitError(runId=runId, nodeId=nodeId, waitConfig=waitConfig)
|
||||
# Fallback: no pause (calls readEmails directly) – needs runId, workflowId, connectionReference
|
||||
if not runId or not workflowId:
|
||||
logger.warning(
|
||||
"email.checkEmail not pausing (runId=%s workflowId=%s) – run must be saved/executed as workflow",
|
||||
runId,
|
||||
workflowId,
|
||||
)
|
||||
elif not connRef:
|
||||
logger.warning(
|
||||
"email.checkEmail not pausing – connectionReference missing (check connectionId/config)",
|
||||
)
|
||||
|
||||
# email.draftEmail: use AI output as emailContent if available; else pass incoming email as context
|
||||
if nodeType == "email.draftEmail":
|
||||
inputSources = context.get("inputSources", {})
|
||||
nodeOutputs = context.get("nodeOutputs", {})
|
||||
orderedNodes = context.get("_orderedNodes") or []
|
||||
if 0 in inputSources.get(nodeId, {}):
|
||||
srcId, _ = inputSources[nodeId][0]
|
||||
srcNode = next((n for n in orderedNodes if n.get("id") == srcId), None)
|
||||
srcType = (srcNode or {}).get("type", "")
|
||||
if srcType.startswith("ai."):
|
||||
inp = nodeOutputs.get(srcId)
|
||||
email_content = _extractEmailContentFromUpstream(inp)
|
||||
if email_content:
|
||||
actionParams["emailContent"] = email_content
|
||||
actionParams["context"] = email_content.get("body", "") or "(from connected AI node)"
|
||||
# Attachments: gather from file nodes upstream of AI (e.g. downloadFile -> AI -> email)
|
||||
attachment_docs = _gatherAttachmentDocumentsFromUpstream(
|
||||
nodeId, inputSources, nodeOutputs, orderedNodes
|
||||
)
|
||||
if attachment_docs:
|
||||
existing = actionParams.get("documentList") or []
|
||||
# Prefer file docs from upstream; append any existing that look like binary attachments
|
||||
def _is_binary_attachment(d):
|
||||
if isinstance(d, dict) and d.get("documentData"):
|
||||
try:
|
||||
import json
|
||||
json.loads(d["documentData"])
|
||||
return False # JSON = email content, not attachment
|
||||
except (TypeError, ValueError):
|
||||
return True
|
||||
return bool(isinstance(d, dict) and (d.get("validationMetadata") or {}).get("fileId"))
|
||||
extra = [x for x in (existing if isinstance(existing, list) else []) if _is_binary_attachment(x)]
|
||||
actionParams["documentList"] = attachment_docs + extra
|
||||
if not email_content:
|
||||
# AI returns plain text (e.g. email.txt): use as email body directly (no extra AI call)
|
||||
ctx = _extractContextFromUpstream(inp)
|
||||
if ctx:
|
||||
actionParams["emailContent"] = {
|
||||
"subject": actionParams.get("subject", "Draft"),
|
||||
"body": ctx,
|
||||
"to": actionParams.get("to"),
|
||||
}
|
||||
actionParams["context"] = ctx
|
||||
else:
|
||||
# Fallback: incoming email from upstream (if flow is email->AI->draft)
|
||||
incoming = _getIncomingEmailFromUpstream(nodeId, inputSources, nodeOutputs, orderedNodes)
|
||||
if incoming:
|
||||
ctx, doc_list, reply_to = incoming
|
||||
actionParams["context"] = ctx
|
||||
if doc_list and not actionParams.get("documentList"):
|
||||
actionParams["documentList"] = doc_list
|
||||
if reply_to and not actionParams.get("to"):
|
||||
actionParams["to"] = [reply_to]
|
||||
else:
|
||||
doc_count = len(inp.get("documents", [])) if isinstance(inp, dict) else 0
|
||||
logger.warning(
|
||||
"email.draftEmail: AI upstream returned %d doc(s) but context extraction failed (no subject/body, no plain text). "
|
||||
"Ensure AI node outputs document with documentData.",
|
||||
doc_count,
|
||||
)
|
||||
actionParams["context"] = "(no context extracted from upstream – check AI node output)"
|
||||
elif srcType in ("sharepoint.downloadFile", "sharepoint.readFile"):
|
||||
# File itself is the context: pass as attachment, use filename as minimal context (no content extraction)
|
||||
if not actionParams.get("context"):
|
||||
inp = nodeOutputs.get(srcId)
|
||||
docs = (inp.get("documents") or inp.get("documentList", [])) if isinstance(inp, dict) else []
|
||||
doc = docs[0] if docs else None
|
||||
name = None
|
||||
if isinstance(doc, dict):
|
||||
name = doc.get("documentName") or doc.get("fileName")
|
||||
elif doc and hasattr(doc, "documentName"):
|
||||
name = getattr(doc, "documentName", None) or getattr(doc, "fileName", None)
|
||||
ctx = name if name else "Attachment"
|
||||
actionParams["context"] = ctx
|
||||
actionParams["emailContent"] = {
|
||||
"subject": actionParams.get("subject", "Draft"),
|
||||
"body": ctx,
|
||||
"to": actionParams.get("to"),
|
||||
}
|
||||
# documentList already merged from upstream (file as attachment)
|
||||
else:
|
||||
# Direct connection to email.checkEmail/searchEmail: use incoming email as context
|
||||
if not actionParams.get("context"):
|
||||
incoming = _getIncomingEmailFromUpstream(nodeId, inputSources, nodeOutputs, orderedNodes)
|
||||
if incoming:
|
||||
ctx, doc_list, reply_to = incoming
|
||||
actionParams["context"] = ctx
|
||||
if doc_list and not actionParams.get("documentList"):
|
||||
actionParams["documentList"] = doc_list
|
||||
if reply_to and not actionParams.get("to"):
|
||||
actionParams["to"] = [reply_to]
|
||||
|
||||
# Generic context handover: when upstream provides documents, pass first doc as content for actions that expect it
|
||||
docList = actionParams.get("documentList") or resolvedParams.get("documentList")
|
||||
if docList and "content" not in actionParams:
|
||||
first = docList[0] if isinstance(docList, list) and docList else docList
|
||||
# Actions like sharepoint.uploadFile consume content from context
|
||||
actionParams["content"] = first
|
||||
|
||||
executor = ActionExecutor(self.services)
|
||||
logger.info("ActionNodeExecutor node %s calling executeAction(%s, %s)", nodeId, methodName, actionName)
|
||||
result = await executor.executeAction(methodName, actionName, actionParams)
|
||||
|
||||
out = {
|
||||
"success": result.success,
|
||||
"error": result.error,
|
||||
"documents": [d.model_dump() if hasattr(d, "model_dump") else d for d in (result.documents or [])],
|
||||
"data": result.model_dump() if hasattr(result, "model_dump") else {"success": result.success, "error": result.error},
|
||||
}
|
||||
logger.info(
|
||||
"ActionNodeExecutor node %s result: success=%s error=%s doc_count=%d",
|
||||
nodeId,
|
||||
result.success,
|
||||
result.error,
|
||||
len(out.get("documents", [])),
|
||||
)
|
||||
return out
|
||||
120
modules/workflows/automation2/executors/dataExecutor.py
Normal file
120
modules/workflows/automation2/executors/dataExecutor.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# Data transformation node executor (setFields, filter, parseJson, template).
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, Any, List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_nested(obj: Any, path: str) -> Any:
|
||||
"""Get nested key from obj, e.g. 'data.items'."""
|
||||
for k in path.split("."):
|
||||
if not k:
|
||||
continue
|
||||
if isinstance(obj, dict) and k in obj:
|
||||
obj = obj[k]
|
||||
elif isinstance(obj, (list, tuple)) and k.isdigit():
|
||||
obj = obj[int(k)]
|
||||
else:
|
||||
return None
|
||||
return obj
|
||||
|
||||
|
||||
class DataExecutor:
|
||||
"""Execute data transformation nodes."""
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
node: Dict[str, Any],
|
||||
context: Dict[str, Any],
|
||||
) -> Any:
|
||||
nodeType = node.get("type", "")
|
||||
nodeOutputs = context.get("nodeOutputs", {})
|
||||
nodeId = node.get("id", "")
|
||||
inputSources = context.get("inputSources", {}).get(nodeId, {})
|
||||
params = node.get("parameters") or {}
|
||||
logger.info(
|
||||
"DataExecutor node %s type=%s inputSources=%s params=%s",
|
||||
nodeId,
|
||||
nodeType,
|
||||
inputSources,
|
||||
params,
|
||||
)
|
||||
|
||||
inp = None
|
||||
if 0 in inputSources:
|
||||
srcId, _ = inputSources[0]
|
||||
inp = nodeOutputs.get(srcId)
|
||||
|
||||
from modules.workflows.automation2.graphUtils import resolveParameterReferences
|
||||
resolvedParams = {k: resolveParameterReferences(v, nodeOutputs) for k, v in params.items()}
|
||||
|
||||
if nodeType == "data.setFields":
|
||||
out = self._setFields(inp, resolvedParams)
|
||||
logger.info("DataExecutor node %s setFields inp=%s -> %s", nodeId, type(inp).__name__, out)
|
||||
return out
|
||||
if nodeType == "data.filter":
|
||||
out = self._filter(inp, resolvedParams)
|
||||
logger.info("DataExecutor node %s filter inp=%s -> len=%d", nodeId, type(inp).__name__, len(out) if isinstance(out, list) else -1)
|
||||
return out
|
||||
if nodeType == "data.parseJson":
|
||||
out = self._parseJson(inp, resolvedParams)
|
||||
logger.info("DataExecutor node %s parseJson -> %s", nodeId, type(out).__name__)
|
||||
return out
|
||||
if nodeType == "data.template":
|
||||
out = self._template(inp, resolvedParams, nodeOutputs)
|
||||
logger.info("DataExecutor node %s template -> %s", nodeId, out)
|
||||
return out
|
||||
|
||||
logger.debug("DataExecutor node %s unhandled type %s -> passThrough", nodeId, nodeType)
|
||||
return inp
|
||||
|
||||
def _setFields(self, inp: Any, params: Dict) -> Any:
|
||||
fields = params.get("fields", {})
|
||||
if not isinstance(fields, dict):
|
||||
return inp
|
||||
base = dict(inp) if isinstance(inp, dict) else {}
|
||||
base.update(fields)
|
||||
return base
|
||||
|
||||
def _filter(self, inp: Any, params: Dict) -> Any:
|
||||
itemsPath = (params.get("itemsPath") or "").strip()
|
||||
condition = params.get("condition", "True")
|
||||
items = inp
|
||||
if itemsPath:
|
||||
items = _get_nested(inp, itemsPath)
|
||||
if not isinstance(items, list):
|
||||
items = [inp] if inp is not None else []
|
||||
out = []
|
||||
for i, item in enumerate(items):
|
||||
try:
|
||||
local = {"item": item, "index": i, "input": inp}
|
||||
ok = bool(eval(condition, {"__builtins__": {}}, local))
|
||||
if ok:
|
||||
out.append(item)
|
||||
except Exception:
|
||||
pass
|
||||
return out
|
||||
|
||||
def _parseJson(self, inp: Any, params: Dict) -> Any:
|
||||
jsonPath = (params.get("jsonPath") or "").strip()
|
||||
raw = inp
|
||||
if jsonPath:
|
||||
raw = _get_nested(inp, jsonPath) if isinstance(inp, dict) else inp
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
return json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": "Invalid JSON", "raw": raw[:200]}
|
||||
return inp
|
||||
|
||||
def _template(self, inp: Any, params: Dict, nodeOutputs: Dict) -> Any:
|
||||
tpl = params.get("template", "")
|
||||
from modules.workflows.automation2.graphUtils import resolveParameterReferences
|
||||
result = resolveParameterReferences(tpl, nodeOutputs)
|
||||
return {"text": result, "template": tpl}
|
||||
146
modules/workflows/automation2/executors/flowExecutor.py
Normal file
146
modules/workflows/automation2/executors/flowExecutor.py
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# Flow control node executor (ifElse, merge, wait, stop).
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FlowExecutor:
|
||||
"""Execute flow control nodes."""
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
node: Dict[str, Any],
|
||||
context: Dict[str, Any],
|
||||
) -> Any:
|
||||
nodeType = node.get("type", "")
|
||||
nodeOutputs = context.get("nodeOutputs", {})
|
||||
connectionMap = context.get("connectionMap", {})
|
||||
nodeId = node.get("id", "")
|
||||
inputSources = context.get("inputSources", {}).get(nodeId, {})
|
||||
logger.info(
|
||||
"FlowExecutor node %s type=%s inputSources=%s params=%s",
|
||||
nodeId,
|
||||
nodeType,
|
||||
inputSources,
|
||||
node.get("parameters"),
|
||||
)
|
||||
|
||||
if nodeType == "flow.ifElse":
|
||||
out = await self._ifElse(node, nodeOutputs, nodeId, inputSources)
|
||||
logger.info("FlowExecutor node %s ifElse -> %s", nodeId, out)
|
||||
return out
|
||||
if nodeType == "flow.merge":
|
||||
out = await self._merge(node, nodeOutputs, nodeId, inputSources)
|
||||
logger.info("FlowExecutor node %s merge -> %s", nodeId, out)
|
||||
return out
|
||||
if nodeType == "flow.wait":
|
||||
out = await self._wait(node, nodeOutputs, nodeId, inputSources)
|
||||
logger.info("FlowExecutor node %s wait -> %s", nodeId, out)
|
||||
return out
|
||||
if nodeType == "flow.stop":
|
||||
context["_stopped"] = True
|
||||
logger.info("FlowExecutor node %s -> STOP", nodeId)
|
||||
return {"stopped": True}
|
||||
if nodeType == "flow.switch":
|
||||
out = await self._switch(node, nodeOutputs, nodeId, inputSources)
|
||||
logger.info("FlowExecutor node %s switch -> %s", nodeId, out)
|
||||
return out
|
||||
if nodeType == "flow.loop":
|
||||
out = await self._loop(node, nodeOutputs, nodeId, inputSources)
|
||||
logger.info("FlowExecutor node %s loop -> %s", nodeId, out)
|
||||
return out
|
||||
|
||||
logger.debug("FlowExecutor node %s unhandled type %s -> None", nodeId, nodeType)
|
||||
return None
|
||||
|
||||
def _getInputData(self, nodeId: str, inputSources: Dict, nodeOutputs: Dict, outputIndex: int = 0) -> Any:
|
||||
"""Get data from the connected source node."""
|
||||
sources = inputSources.get(nodeId, {})
|
||||
if 0 not in sources:
|
||||
return None
|
||||
srcId, srcOut = sources[0]
|
||||
return nodeOutputs.get(srcId)
|
||||
|
||||
async def _ifElse(
|
||||
self,
|
||||
node: Dict,
|
||||
nodeOutputs: Dict,
|
||||
nodeId: str,
|
||||
inputSources: Dict,
|
||||
) -> Any:
|
||||
condExpr = (node.get("parameters") or {}).get("condition", "")
|
||||
inp = self._getInputData(nodeId, {nodeId: inputSources}, nodeOutputs)
|
||||
# Simple eval - in production use safe evaluation
|
||||
try:
|
||||
# Replace {{nodeId}} refs with actual values
|
||||
from modules.workflows.automation2.graphUtils import resolveParameterReferences
|
||||
resolved = resolveParameterReferences(condExpr, nodeOutputs)
|
||||
# Minimal eval for simple comparisons (e.g. "True", "1 > 0")
|
||||
ok = bool(eval(resolved)) if resolved else False
|
||||
except Exception:
|
||||
ok = False
|
||||
return {"branch": 0 if ok else 1, "conditionResult": ok, "input": inp}
|
||||
|
||||
async def _merge(self, node: Dict, nodeOutputs: Dict, nodeId: str, inputSources: Dict) -> Any:
|
||||
mode = (node.get("parameters") or {}).get("mode", "append")
|
||||
sources = inputSources
|
||||
items = []
|
||||
for inpIdx in sorted(sources.keys()):
|
||||
srcId, _ = sources[inpIdx]
|
||||
data = nodeOutputs.get(srcId)
|
||||
if data is not None:
|
||||
if isinstance(data, list):
|
||||
items.extend(data)
|
||||
else:
|
||||
items.append(data)
|
||||
if mode == "combine" and len(items) == 2:
|
||||
if isinstance(items[0], dict) and isinstance(items[1], dict):
|
||||
return {**items[0], **items[1]}
|
||||
return items
|
||||
|
||||
async def _wait(self, node: Dict, nodeOutputs: Dict) -> Any:
|
||||
secs = (node.get("parameters") or {}).get("seconds", 0)
|
||||
if secs > 0:
|
||||
await asyncio.sleep(min(float(secs), 300))
|
||||
nodeId = node.get("id")
|
||||
from modules.workflows.automation2.graphUtils import getInputSources
|
||||
# Input comes from context
|
||||
inp = context.get("_inputData") if "context" in dir() else None
|
||||
return nodeOutputs.get(nodeId, {})
|
||||
|
||||
async def _wait(
|
||||
self,
|
||||
node: Dict,
|
||||
nodeOutputs: Dict,
|
||||
nodeId: str,
|
||||
inputSources: Dict,
|
||||
) -> Any:
|
||||
secs = (node.get("parameters") or {}).get("seconds", 0)
|
||||
if secs > 0:
|
||||
await asyncio.sleep(min(float(secs), 300))
|
||||
if 0 in inputSources:
|
||||
srcId, _ = inputSources[0]
|
||||
return nodeOutputs.get(srcId)
|
||||
return None
|
||||
|
||||
async def _switch(self, node: Dict, nodeOutputs: Dict, nodeId: str, inputSources: Dict) -> Any:
|
||||
valueExpr = (node.get("parameters") or {}).get("value", "")
|
||||
from modules.workflows.automation2.graphUtils import resolveParameterReferences
|
||||
value = resolveParameterReferences(valueExpr, nodeOutputs)
|
||||
cases = (node.get("parameters") or {}).get("cases", [])
|
||||
for i, c in enumerate(cases):
|
||||
if c == value:
|
||||
return {"match": i, "value": value}
|
||||
return {"match": -1, "value": value}
|
||||
|
||||
async def _loop(self, node: Dict, nodeOutputs: Dict, nodeId: str, inputSources: Dict) -> Any:
|
||||
itemsPath = (node.get("parameters") or {}).get("items", "[]")
|
||||
from modules.workflows.automation2.graphUtils import resolveParameterReferences
|
||||
items = resolveParameterReferences(itemsPath, nodeOutputs)
|
||||
if not isinstance(items, list):
|
||||
items = [items] if items is not None else []
|
||||
return {"items": items, "count": len(items)}
|
||||
80
modules/workflows/automation2/executors/inputExecutor.py
Normal file
80
modules/workflows/automation2/executors/inputExecutor.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# Input/Human node executor - creates tasks and pauses execution.
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PauseForHumanTaskError(Exception):
|
||||
"""Raised when execution must pause for a human task. Contains runId, taskId."""
|
||||
|
||||
def __init__(self, runId: str, taskId: str, nodeId: str):
|
||||
self.runId = runId
|
||||
self.taskId = taskId
|
||||
self.nodeId = nodeId
|
||||
super().__init__(f"Pause for human task {taskId} (run {runId}, node {nodeId})")
|
||||
|
||||
|
||||
class PauseForEmailWaitError(Exception):
|
||||
"""Raised when execution must pause waiting for a new email. Background poller will resume."""
|
||||
|
||||
def __init__(self, runId: str, nodeId: str, waitConfig: Dict[str, Any]):
|
||||
self.runId = runId
|
||||
self.nodeId = nodeId
|
||||
self.waitConfig = waitConfig
|
||||
super().__init__(f"Pause for email wait (run {runId}, node {nodeId})")
|
||||
|
||||
|
||||
class InputExecutor:
|
||||
"""
|
||||
Execute input/human nodes. Creates a HumanTask, pauses the run, and raises
|
||||
PauseForHumanTaskError so the engine returns { paused: true, taskId, runId }.
|
||||
"""
|
||||
|
||||
def __init__(self, automation2_interface: Any):
|
||||
self.automation2 = automation2_interface
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
node: Dict[str, Any],
|
||||
context: Dict[str, Any],
|
||||
) -> Any:
|
||||
nodeType = node.get("type", "")
|
||||
nodeId = node.get("id", "")
|
||||
runId = context.get("_runId")
|
||||
workflowId = context.get("workflowId")
|
||||
instanceId = context.get("instanceId")
|
||||
userId = context.get("userId")
|
||||
|
||||
if not runId or not workflowId:
|
||||
logger.error("InputExecutor: runId/workflowId missing in context - cannot create task")
|
||||
return {"error": "Missing run context", "success": False}
|
||||
|
||||
config = dict(node.get("parameters") or {})
|
||||
logger.info("InputExecutor node %s type=%s creating task", nodeId, nodeType)
|
||||
|
||||
task = self.automation2.createTask(
|
||||
runId=runId,
|
||||
workflowId=workflowId,
|
||||
nodeId=nodeId,
|
||||
nodeType=nodeType,
|
||||
config=config,
|
||||
assigneeId=userId,
|
||||
)
|
||||
taskId = task.get("id")
|
||||
|
||||
self.automation2.updateRun(
|
||||
runId,
|
||||
status="paused",
|
||||
nodeOutputs=context.get("nodeOutputs"),
|
||||
currentNodeId=nodeId,
|
||||
context={
|
||||
"connectionMap": context.get("connectionMap"),
|
||||
"inputSources": context.get("inputSources"),
|
||||
"orderedNodeIds": [n.get("id") for n in context.get("_orderedNodes", []) if n.get("id")],
|
||||
},
|
||||
)
|
||||
logger.info("InputExecutor node %s: created task %s, run %s paused", nodeId, taskId, runId)
|
||||
raise PauseForHumanTaskError(runId=runId, taskId=taskId, nodeId=nodeId)
|
||||
69
modules/workflows/automation2/executors/ioExecutor.py
Normal file
69
modules/workflows/automation2/executors/ioExecutor.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# I/O node executor - delegates to ActionExecutor.
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IOExecutor:
|
||||
"""Execute I/O nodes by calling ActionExecutor.executeAction(method, action, params)."""
|
||||
|
||||
def __init__(self, services: Any):
|
||||
self.services = services
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
node: Dict[str, Any],
|
||||
context: Dict[str, Any],
|
||||
) -> Any:
|
||||
from modules.workflows.processing.core.actionExecutor import ActionExecutor
|
||||
|
||||
nodeType = node.get("type", "")
|
||||
nodeId = node.get("id", "")
|
||||
logger.info("IOExecutor node %s type=%s", nodeId, nodeType)
|
||||
if not nodeType.startswith("io."):
|
||||
logger.debug("IOExecutor node %s not io.* -> None", nodeId)
|
||||
return None
|
||||
|
||||
parts = nodeType.split(".", 2)
|
||||
if len(parts) < 3:
|
||||
logger.debug("IOExecutor node %s invalid type parts -> None", nodeId)
|
||||
return None
|
||||
_, methodName, actionName = parts
|
||||
logger.info("IOExecutor node %s method=%s action=%s", nodeId, methodName, actionName)
|
||||
|
||||
nodeOutputs = context.get("nodeOutputs", {})
|
||||
params = dict(node.get("parameters") or {})
|
||||
|
||||
from modules.workflows.automation2.graphUtils import resolveParameterReferences
|
||||
resolvedParams = resolveParameterReferences(params, nodeOutputs)
|
||||
logger.info("IOExecutor node %s resolvedParams keys=%s", nodeId, list(resolvedParams.keys()))
|
||||
|
||||
inputSources = context.get("inputSources", {}).get(nodeId, {})
|
||||
if 0 in inputSources:
|
||||
srcId, _ = inputSources[0]
|
||||
inp = nodeOutputs.get(srcId)
|
||||
if isinstance(inp, dict):
|
||||
resolvedParams.setdefault("documentList", inp.get("documents", inp.get("documentList", [])))
|
||||
elif inp is not None:
|
||||
resolvedParams.setdefault("input", inp)
|
||||
|
||||
executor = ActionExecutor(self.services)
|
||||
logger.info("IOExecutor node %s calling executeAction(%s, %s)", nodeId, methodName, actionName)
|
||||
result = await executor.executeAction(methodName, actionName, resolvedParams)
|
||||
out = {
|
||||
"success": result.success,
|
||||
"error": result.error,
|
||||
"documents": [d.model_dump() if hasattr(d, "model_dump") else d for d in (result.documents or [])],
|
||||
"data": result.model_dump() if hasattr(result, "model_dump") else {"success": result.success, "error": result.error},
|
||||
}
|
||||
logger.info(
|
||||
"IOExecutor node %s result: success=%s error=%s doc_count=%d",
|
||||
nodeId,
|
||||
result.success,
|
||||
result.error,
|
||||
len(out.get("documents", [])),
|
||||
)
|
||||
return out
|
||||
37
modules/workflows/automation2/executors/triggerExecutor.py
Normal file
37
modules/workflows/automation2/executors/triggerExecutor.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# Trigger node executor.
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TriggerExecutor:
|
||||
"""Execute trigger nodes (manual, schedule, formSubmit)."""
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
node: Dict[str, Any],
|
||||
context: Dict[str, Any],
|
||||
) -> Any:
|
||||
nodeType = node.get("type", "")
|
||||
nodeId = node.get("id", "")
|
||||
logger.info("TriggerExecutor node %s type=%s parameters=%s", nodeId, nodeType, node.get("parameters"))
|
||||
if nodeType == "trigger.manual":
|
||||
out = {"triggered": True, "source": "manual"}
|
||||
logger.info("TriggerExecutor node %s -> manual trigger: %s", nodeId, out)
|
||||
return out
|
||||
if nodeType == "trigger.schedule":
|
||||
out = {"triggered": True, "source": "schedule"}
|
||||
logger.info("TriggerExecutor node %s -> schedule trigger: %s", nodeId, out)
|
||||
return out
|
||||
if nodeType == "trigger.formSubmit":
|
||||
params = node.get("parameters") or {}
|
||||
formId = params.get("formId", "")
|
||||
out = {"triggered": True, "source": "formSubmit", "formId": formId}
|
||||
logger.info("TriggerExecutor node %s -> formSubmit: %s", nodeId, out)
|
||||
return out
|
||||
out = {"triggered": True, "source": "unknown"}
|
||||
logger.info("TriggerExecutor node %s -> unknown: %s", nodeId, out)
|
||||
return out
|
||||
177
modules/workflows/automation2/graphUtils.py
Normal file
177
modules/workflows/automation2/graphUtils.py
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# Graph parsing, validation, and topological sort for automation2.
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Any, Tuple, Set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parseGraph(graph: Dict[str, Any]) -> Tuple[List[Dict], List[Dict], Set[str]]:
|
||||
"""
|
||||
Parse graph into nodes, connections, and node IDs.
|
||||
graph: { nodes: [...], connections: [...] }
|
||||
Returns (nodes, connections, node_ids).
|
||||
"""
|
||||
nodes = graph.get("nodes") or []
|
||||
connections = graph.get("connections") or []
|
||||
nodeIds = {n.get("id") for n in nodes if n.get("id")}
|
||||
logger.debug(
|
||||
"parseGraph: nodes=%d connections=%d nodeIds=%s",
|
||||
len(nodes),
|
||||
len(connections),
|
||||
sorted(nodeIds),
|
||||
)
|
||||
return nodes, connections, nodeIds
|
||||
|
||||
|
||||
def buildConnectionMap(connections: List[Dict]) -> Dict[str, List[Tuple[str, int, int]]]:
|
||||
"""
|
||||
Build map: targetNodeId -> [(sourceNodeId, sourceOutput, targetInput), ...]
|
||||
connection: { source, sourceOutput?, target, targetInput? }
|
||||
"""
|
||||
out: Dict[str, List[Tuple[str, int, int]]] = {}
|
||||
for i, c in enumerate(connections):
|
||||
src = c.get("source") or c.get("sourceNode")
|
||||
tgt = c.get("target") or c.get("targetNode")
|
||||
if not src or not tgt:
|
||||
logger.debug("buildConnectionMap skip conn[%d]: missing source/target %r", i, c)
|
||||
continue
|
||||
so = c.get("sourceOutput", 0)
|
||||
ti = c.get("targetInput", 0)
|
||||
if tgt not in out:
|
||||
out[tgt] = []
|
||||
out[tgt].append((src, so, ti))
|
||||
logger.debug("buildConnectionMap conn[%d]: %s -> %s (so=%d ti=%d)", i, src, tgt, so, ti)
|
||||
logger.debug("buildConnectionMap result: %s", {k: v for k, v in out.items()})
|
||||
return out
|
||||
|
||||
|
||||
def getInputSources(nodeId: str, connectionMap: Dict[str, List[Tuple[str, int, int]]]) -> Dict[int, Tuple[str, int]]:
|
||||
"""
|
||||
For a node, return targetInput -> (sourceNodeId, sourceOutput).
|
||||
"""
|
||||
result: Dict[int, Tuple[str, int]] = {}
|
||||
for src, so, ti in connectionMap.get(nodeId, []):
|
||||
result[ti] = (src, so)
|
||||
return result
|
||||
|
||||
|
||||
def getTriggerNodes(nodes: List[Dict]) -> List[Dict]:
|
||||
"""Return nodes with category=trigger or type starting with trigger."""
|
||||
return [n for n in nodes if (n.get("type", "").startswith("trigger.") or n.get("category") == "trigger")]
|
||||
|
||||
|
||||
def validateGraph(graph: Dict[str, Any], nodeTypeIds: Set[str]) -> List[str]:
|
||||
"""
|
||||
Validate graph: all node IDs referenced in connections exist, all node types in registry.
|
||||
Returns list of error messages (empty if valid).
|
||||
"""
|
||||
errors = []
|
||||
nodes, connections, nodeIds = parseGraph(graph)
|
||||
|
||||
for n in nodes:
|
||||
nid = n.get("id")
|
||||
ntype = n.get("type")
|
||||
if not nid:
|
||||
errors.append("Node missing id")
|
||||
continue
|
||||
if not ntype:
|
||||
errors.append(f"Node {nid} missing type")
|
||||
continue
|
||||
if ntype not in nodeTypeIds:
|
||||
errors.append(f"Unknown node type '{ntype}' for node {nid}")
|
||||
|
||||
connMap = buildConnectionMap(connections)
|
||||
allReferred = set()
|
||||
for tgt, pairs in connMap.items():
|
||||
allReferred.add(tgt)
|
||||
for src, _, _ in pairs:
|
||||
allReferred.add(src)
|
||||
for nid in allReferred:
|
||||
if nid not in nodeIds:
|
||||
errors.append(f"Connection references non-existent node {nid}")
|
||||
|
||||
if errors:
|
||||
logger.debug("validateGraph errors: %s", errors)
|
||||
else:
|
||||
logger.debug("validateGraph: OK")
|
||||
return errors
|
||||
|
||||
|
||||
def topoSort(nodes: List[Dict], connectionMap: Dict[str, List[Tuple[str, int, int]]]) -> List[Dict]:
|
||||
"""
|
||||
Topological sort: start from trigger nodes, then BFS by connections.
|
||||
Returns ordered list of nodes (trigger first, then downstream).
|
||||
"""
|
||||
nodeById = {n["id"]: n for n in nodes if n.get("id")}
|
||||
triggers = getTriggerNodes(nodes)
|
||||
if not triggers:
|
||||
return list(nodes)
|
||||
|
||||
visited: Set[str] = set()
|
||||
order: List[Dict] = []
|
||||
|
||||
def bfs(startIds: List[str]) -> None:
|
||||
from collections import deque
|
||||
q = deque(startIds)
|
||||
for nid in startIds:
|
||||
visited.add(nid)
|
||||
if nid in nodeById:
|
||||
order.append(nodeById[nid])
|
||||
while q:
|
||||
nid = q.popleft()
|
||||
# Find all nodes that receive from nid
|
||||
for tgt, pairs in connectionMap.items():
|
||||
for src, _, _ in pairs:
|
||||
if src == nid and tgt not in visited:
|
||||
visited.add(tgt)
|
||||
q.append(tgt)
|
||||
if tgt in nodeById:
|
||||
order.append(nodeById[tgt])
|
||||
|
||||
triggerIds = [t["id"] for t in triggers]
|
||||
logger.debug("topoSort triggers: %s", triggerIds)
|
||||
bfs(triggerIds)
|
||||
|
||||
# Append any orphan nodes (e.g. disconnected)
|
||||
for n in nodes:
|
||||
if n.get("id") and n["id"] not in visited:
|
||||
order.append(n)
|
||||
logger.debug("topoSort order (%d nodes): %s", len(order), [n.get("id") for n in order])
|
||||
return order
|
||||
|
||||
|
||||
def resolveParameterReferences(value: Any, nodeOutputs: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Resolve {{nodeId.output}} or {{nodeId.output.path}} in strings/structures.
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
if isinstance(value, str):
|
||||
def repl(m):
|
||||
ref = m.group(1).strip()
|
||||
parts = ref.split(".")
|
||||
nodeId = parts[0]
|
||||
data = nodeOutputs.get(nodeId)
|
||||
if data is None:
|
||||
return m.group(0)
|
||||
if len(parts) < 2:
|
||||
return json.dumps(data) if isinstance(data, (dict, list)) else str(data)
|
||||
rest = ".".join(parts[1:])
|
||||
if data is None:
|
||||
return m.group(0)
|
||||
for k in rest.split("."):
|
||||
if isinstance(data, dict) and k in data:
|
||||
data = data[k]
|
||||
elif isinstance(data, (list, tuple)) and k.isdigit():
|
||||
data = data[int(k)]
|
||||
else:
|
||||
return m.group(0)
|
||||
return str(data) if data is not None else m.group(0)
|
||||
return re.sub(r"\{\{\s*([^}]+)\s*\}\}", repl, value)
|
||||
if isinstance(value, dict):
|
||||
return {k: resolveParameterReferences(v, nodeOutputs) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [resolveParameterReferences(v, nodeOutputs) for v in value]
|
||||
return value
|
||||
|
|
@ -1,9 +1,11 @@
|
|||
# Copyright (c) 2025 Patrick Motsch
|
||||
# All rights reserved.
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
import json
|
||||
import uuid
|
||||
from typing import Dict, Any, List, Optional
|
||||
from modules.datamodels.datamodelChat import ActionResult, ActionDocument
|
||||
from modules.datamodels.datamodelAi import AiCallRequest, AiCallOptions, OperationTypeEnum, ProcessingModeEnum
|
||||
|
|
@ -11,6 +13,64 @@ from modules.datamodels.datamodelExtraction import ContentPart
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_action_document_like(obj: Any) -> bool:
|
||||
"""Check if object is ActionDocument-like (has documentData for inline workflow documents)."""
|
||||
if obj is None:
|
||||
return False
|
||||
data = None
|
||||
if isinstance(obj, dict):
|
||||
data = obj.get("documentData") or obj.get("document_data")
|
||||
else:
|
||||
data = getattr(obj, "documentData", None) or getattr(obj, "document_data", None)
|
||||
if data is None:
|
||||
return False
|
||||
if isinstance(data, bytes):
|
||||
return len(data) > 0
|
||||
if isinstance(data, str):
|
||||
return len(data.strip()) > 0
|
||||
return True
|
||||
|
||||
|
||||
def _action_docs_to_content_parts(services, docs: List[Any]) -> List[ContentPart]:
|
||||
"""Extract content from ActionDocument-like objects in memory (no persistence).
|
||||
Decodes base64, runs extraction pipeline, returns ContentParts for AI.
|
||||
"""
|
||||
from modules.datamodels.datamodelExtraction import ExtractionOptions, MergeStrategy
|
||||
|
||||
all_parts = []
|
||||
extraction = getattr(services, "extraction", None)
|
||||
if not extraction:
|
||||
logger.warning("ai.process: No extraction service - cannot extract from inline documents")
|
||||
return []
|
||||
opts = ExtractionOptions(prompt="", mergeStrategy=MergeStrategy())
|
||||
for i, doc in enumerate(docs):
|
||||
raw = (doc.get("documentData") or doc.get("document_data")) if isinstance(doc, dict) else (getattr(doc, "documentData", None) or getattr(doc, "document_data", None))
|
||||
if not raw:
|
||||
continue
|
||||
name = doc.get("documentName", doc.get("fileName", f"document_{i}"))
|
||||
mime = doc.get("mimeType", "application/octet-stream")
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
content = base64.b64decode(raw, validate=True)
|
||||
except Exception:
|
||||
content = raw.encode("utf-8")
|
||||
else:
|
||||
content = raw if isinstance(raw, bytes) else bytes(raw)
|
||||
ec = extraction.extractContentFromBytes(
|
||||
documentBytes=content,
|
||||
fileName=name,
|
||||
mimeType=mime,
|
||||
documentId=str(uuid.uuid4()),
|
||||
options=opts,
|
||||
)
|
||||
for p in ec.parts:
|
||||
if p.data or getattr(p, "typeGroup", "") == "image":
|
||||
p.metadata.setdefault("originalFileName", name)
|
||||
all_parts.append(p)
|
||||
logger.info(f"ai.process: Extracted {len(ec.parts)} parts from {name} (no persistence)")
|
||||
return all_parts
|
||||
|
||||
async def process(self, parameters: Dict[str, Any]) -> ActionResult:
|
||||
operationId = None
|
||||
try:
|
||||
|
|
@ -41,8 +101,21 @@ async def process(self, parameters: Dict[str, Any]) -> ActionResult:
|
|||
from modules.datamodels.datamodelDocref import DocumentReferenceList
|
||||
|
||||
documentListParam = parameters.get("documentList")
|
||||
# Convert to DocumentReferenceList if needed
|
||||
if documentListParam is None:
|
||||
inline_content_parts: Optional[List[ContentPart]] = None
|
||||
|
||||
# Handle inline ActionDocuments (e.g. from SharePoint/email in automation2 – no persistence)
|
||||
is_inline = (
|
||||
isinstance(documentListParam, list)
|
||||
and len(documentListParam) > 0
|
||||
and _is_action_document_like(documentListParam[0])
|
||||
)
|
||||
if is_inline:
|
||||
inline_content_parts = _action_docs_to_content_parts(self.services, documentListParam)
|
||||
documentList = DocumentReferenceList(references=[])
|
||||
logger.info(
|
||||
f"ai.process: Extracted {len(inline_content_parts)} ContentParts from {len(documentListParam)} inline ActionDocuments (no persistence)"
|
||||
)
|
||||
elif documentListParam is None:
|
||||
documentList = DocumentReferenceList(references=[])
|
||||
logger.debug(f"ai.process: documentList is None, using empty DocumentReferenceList")
|
||||
elif isinstance(documentListParam, DocumentReferenceList):
|
||||
|
|
@ -54,6 +127,11 @@ async def process(self, parameters: Dict[str, Any]) -> ActionResult:
|
|||
documentList = DocumentReferenceList.from_string_list([documentListParam])
|
||||
logger.info(f"ai.process: Converted string to DocumentReferenceList with {len(documentList.references)} references")
|
||||
elif isinstance(documentListParam, list):
|
||||
first = documentListParam[0] if documentListParam else None
|
||||
logger.info(
|
||||
f"ai.process: documentList is list of {len(documentListParam)} items, "
|
||||
f"first type={type(first).__name__}, has_documentData={_is_action_document_like(first) if first else False}"
|
||||
)
|
||||
documentList = DocumentReferenceList.from_string_list(documentListParam)
|
||||
logger.info(f"ai.process: Converted list to DocumentReferenceList with {len(documentList.references)} references")
|
||||
else:
|
||||
|
|
@ -65,7 +143,9 @@ async def process(self, parameters: Dict[str, Any]) -> ActionResult:
|
|||
simpleMode = parameters.get("simpleMode", False)
|
||||
|
||||
if not aiPrompt:
|
||||
logger.error(f"aiPrompt is missing or empty. Parameters: {parameters}")
|
||||
param_keys = list(parameters.keys()) if isinstance(parameters, dict) else []
|
||||
doc_count = len(parameters.get("documentList") or []) if isinstance(parameters.get("documentList"), (list, tuple)) else 0
|
||||
logger.error(f"aiPrompt is missing or empty. Parameter keys: {param_keys}, documentList: {doc_count} item(s)")
|
||||
return ActionResult.isFailure(
|
||||
error="AI prompt is required"
|
||||
)
|
||||
|
|
@ -86,10 +166,9 @@ async def process(self, parameters: Dict[str, Any]) -> ActionResult:
|
|||
mimeMap = {"txt": "text/plain", "json": "application/json", "html": "text/html", "md": "text/markdown", "csv": "text/csv", "xml": "application/xml"}
|
||||
output_mime_type = mimeMap.get(normalized_result_type, "text/plain") if normalized_result_type else "text/plain"
|
||||
|
||||
# Phase 7.3: Pass both documentList and contentParts to AI service
|
||||
# (Extraction logic removed - handled by AI service)
|
||||
contentParts: Optional[List[ContentPart]] = None
|
||||
if "contentParts" in parameters:
|
||||
# Phase 7.3: Pass documentList and/or contentParts to AI service
|
||||
contentParts: Optional[List[ContentPart]] = inline_content_parts
|
||||
if "contentParts" in parameters and not inline_content_parts:
|
||||
contentPartsParam = parameters.get("contentParts")
|
||||
if contentPartsParam:
|
||||
if isinstance(contentPartsParam, list):
|
||||
|
|
@ -203,8 +282,8 @@ async def process(self, parameters: Dict[str, Any]) -> ActionResult:
|
|||
aiResponse = await self.services.ai.callAiContent(
|
||||
prompt=aiPrompt,
|
||||
options=options,
|
||||
documentList=documentList, # Pass documentList - AI service handles extraction
|
||||
contentParts=contentParts, # Pass contentParts if provided (or None)
|
||||
documentList=documentList if not inline_content_parts else None, # Skip if using inline extracted parts
|
||||
contentParts=contentParts, # From inline ActionDocuments (extracted in memory) or parameters
|
||||
outputFormat=output_format, # Can be None - AI determines from prompt
|
||||
parentOperationId=operationId,
|
||||
generationIntent=generationIntent # REQUIRED for DATA_GENERATE
|
||||
|
|
|
|||
|
|
@ -20,11 +20,44 @@ async def composeAndDraftEmailWithContext(self, parameters: Dict[str, Any]) -> A
|
|||
bcc = parameters.get("bcc") or []
|
||||
emailStyle = parameters.get("emailStyle") or "business"
|
||||
maxLength = parameters.get("maxLength") or 1000
|
||||
|
||||
# Only connectionReference and context are required - to is optional for drafts
|
||||
if not connectionReference or not context:
|
||||
return ActionResult.isFailure(error="connectionReference and context are required")
|
||||
|
||||
|
||||
# Direct content from upstream (e.g. AI node): skip internal AI, use subject/body/to directly
|
||||
email_content = parameters.get("emailContent")
|
||||
if isinstance(email_content, dict):
|
||||
direct_subject = email_content.get("subject")
|
||||
direct_body = email_content.get("body")
|
||||
direct_to = email_content.get("to")
|
||||
if direct_subject and direct_body:
|
||||
subject = str(direct_subject).strip()
|
||||
body = str(direct_body).strip()
|
||||
to = [direct_to] if isinstance(direct_to, str) else (direct_to or [])
|
||||
if isinstance(to, str):
|
||||
to = [to]
|
||||
ai_attachments = []
|
||||
# Jump to create-email section (see below)
|
||||
else:
|
||||
direct_subject = parameters.get("subject")
|
||||
direct_body = parameters.get("body")
|
||||
if direct_subject and direct_body:
|
||||
subject = str(direct_subject).strip()
|
||||
body = str(direct_body).strip()
|
||||
if isinstance(to, str):
|
||||
to = [to]
|
||||
ai_attachments = []
|
||||
else:
|
||||
subject = None
|
||||
body = None
|
||||
ai_attachments = None
|
||||
|
||||
use_direct_content = bool(subject and body)
|
||||
|
||||
if not use_direct_content:
|
||||
# Original path: require connectionReference and context
|
||||
if not connectionReference or not context:
|
||||
return ActionResult.isFailure(error="connectionReference and context are required")
|
||||
elif not connectionReference:
|
||||
return ActionResult.isFailure(error="connectionReference is required")
|
||||
|
||||
# Convert single values to lists for all recipient parameters
|
||||
if isinstance(to, str):
|
||||
to = [to]
|
||||
|
|
@ -45,10 +78,10 @@ async def composeAndDraftEmailWithContext(self, parameters: Dict[str, Any]) -> A
|
|||
if not permissions_ok:
|
||||
return ActionResult.isFailure(error="Connection lacks necessary permissions for Outlook operations")
|
||||
|
||||
# Prepare documents for AI processing
|
||||
# Prepare documents for AI processing (only when using AI path)
|
||||
from modules.datamodels.datamodelDocref import DocumentReferenceList
|
||||
chatDocuments = []
|
||||
if documentList:
|
||||
if not use_direct_content and documentList:
|
||||
# Convert to DocumentReferenceList if needed
|
||||
if isinstance(documentList, DocumentReferenceList):
|
||||
docRefList = documentList
|
||||
|
|
@ -60,33 +93,34 @@ async def composeAndDraftEmailWithContext(self, parameters: Dict[str, Any]) -> A
|
|||
docRefList = DocumentReferenceList(references=[])
|
||||
chatDocuments = self.services.chat.getChatDocumentsFromDocumentList(docRefList)
|
||||
|
||||
# Create AI prompt for email composition
|
||||
# Build document reference list for AI with expanded list contents when possible
|
||||
doc_references = documentList
|
||||
doc_list_text = ""
|
||||
if doc_references:
|
||||
lines = ["Available_Document_References:"]
|
||||
for ref in doc_references:
|
||||
# Each item is a label: resolve to its document list and render contained items
|
||||
from modules.datamodels.datamodelDocref import DocumentReferenceList
|
||||
list_docs = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list([ref])) or []
|
||||
if list_docs:
|
||||
for d in list_docs:
|
||||
doc_ref_label = self.services.chat.getDocumentReferenceFromChatDocument(d)
|
||||
lines.append(f"- {doc_ref_label}")
|
||||
else:
|
||||
lines.append(" - (no documents)")
|
||||
doc_list_text = "\n" + "\n".join(lines)
|
||||
else:
|
||||
doc_list_text = "Available_Document_References: (No documents available for attachment)"
|
||||
|
||||
# Escape only the user-controlled context to prevent prompt injection
|
||||
escaped_context = context.replace('"', '\\"').replace('\n', '\\n').replace('\r', '\\r')
|
||||
|
||||
# Build recipients text for prompt
|
||||
recipients_text = f"Recipients: {to}" if to else "Recipients: (not specified - this is a draft)"
|
||||
|
||||
ai_prompt = f"""Compose an email based on this context:
|
||||
if not use_direct_content:
|
||||
# Create AI prompt for email composition
|
||||
# Build document reference list for AI with expanded list contents when possible
|
||||
doc_references = documentList
|
||||
doc_list_text = ""
|
||||
if doc_references:
|
||||
lines = ["Available_Document_References:"]
|
||||
for ref in doc_references:
|
||||
# Each item is a label: resolve to its document list and render contained items
|
||||
from modules.datamodels.datamodelDocref import DocumentReferenceList
|
||||
list_docs = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list([ref])) or []
|
||||
if list_docs:
|
||||
for d in list_docs:
|
||||
doc_ref_label = self.services.chat.getDocumentReferenceFromChatDocument(d)
|
||||
lines.append(f"- {doc_ref_label}")
|
||||
else:
|
||||
lines.append(" - (no documents)")
|
||||
doc_list_text = "\n" + "\n".join(lines)
|
||||
else:
|
||||
doc_list_text = "Available_Document_References: (No documents available for attachment)"
|
||||
|
||||
# Escape only the user-controlled context to prevent prompt injection
|
||||
escaped_context = context.replace('"', '\\"').replace('\n', '\\n').replace('\r', '\\r')
|
||||
|
||||
# Build recipients text for prompt
|
||||
recipients_text = f"Recipients: {to}" if to else "Recipients: (not specified - this is a draft)"
|
||||
|
||||
ai_prompt = f"""Compose an email based on this context:
|
||||
-------
|
||||
{escaped_context}
|
||||
-------
|
||||
|
|
@ -107,93 +141,93 @@ Return JSON:
|
|||
"attachments": ["docItem:<documentId>:<filename>"]
|
||||
}}
|
||||
"""
|
||||
|
||||
# Call AI service to generate email content
|
||||
try:
|
||||
ai_response = await self.services.ai.callAiPlanning(
|
||||
prompt=ai_prompt,
|
||||
placeholders=None,
|
||||
debugType="email_composition"
|
||||
)
|
||||
|
||||
# Parse AI response
|
||||
|
||||
# Call AI service to generate email content
|
||||
try:
|
||||
ai_content = ai_response
|
||||
# Extract JSON from AI response
|
||||
if "```json" in ai_content:
|
||||
json_start = ai_content.find("```json") + 7
|
||||
json_end = ai_content.find("```", json_start)
|
||||
json_content = ai_content[json_start:json_end].strip()
|
||||
elif "{" in ai_content and "}" in ai_content:
|
||||
json_start = ai_content.find("{")
|
||||
json_end = ai_content.rfind("}") + 1
|
||||
json_content = ai_content[json_start:json_end]
|
||||
else:
|
||||
json_content = ai_content
|
||||
|
||||
email_data = json.loads(json_content)
|
||||
subject = email_data.get("subject", "")
|
||||
body = email_data.get("body", "")
|
||||
ai_attachments = email_data.get("attachments", [])
|
||||
|
||||
if not subject or not body:
|
||||
return ActionResult.isFailure(error="AI did not generate valid subject and body")
|
||||
|
||||
# Use AI-selected attachments if provided, otherwise use all documents
|
||||
normalized_ai_attachments = []
|
||||
if documentList:
|
||||
try:
|
||||
available_refs = [documentList] if isinstance(documentList, str) else documentList
|
||||
from modules.datamodels.datamodelDocref import DocumentReferenceList
|
||||
available_docs = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list(available_refs)) or []
|
||||
except Exception:
|
||||
available_docs = []
|
||||
ai_response = await self.services.ai.callAiPlanning(
|
||||
prompt=ai_prompt,
|
||||
placeholders=None,
|
||||
debugType="email_composition"
|
||||
)
|
||||
|
||||
# Normalize AI attachments to a list of strings
|
||||
if isinstance(ai_attachments, str):
|
||||
ai_attachments = [ai_attachments]
|
||||
elif isinstance(ai_attachments, list):
|
||||
ai_attachments = [a for a in ai_attachments if isinstance(a, str)]
|
||||
# Parse AI response
|
||||
try:
|
||||
ai_content = ai_response
|
||||
# Extract JSON from AI response
|
||||
if "```json" in ai_content:
|
||||
json_start = ai_content.find("```json") + 7
|
||||
json_end = ai_content.find("```", json_start)
|
||||
json_content = ai_content[json_start:json_end].strip()
|
||||
elif "{" in ai_content and "}" in ai_content:
|
||||
json_start = ai_content.find("{")
|
||||
json_end = ai_content.rfind("}") + 1
|
||||
json_content = ai_content[json_start:json_end]
|
||||
else:
|
||||
json_content = ai_content
|
||||
|
||||
if ai_attachments:
|
||||
email_data = json.loads(json_content)
|
||||
subject = email_data.get("subject", "")
|
||||
body = email_data.get("body", "")
|
||||
ai_attachments = email_data.get("attachments", [])
|
||||
|
||||
if not subject or not body:
|
||||
return ActionResult.isFailure(error="AI did not generate valid subject and body")
|
||||
|
||||
# Use AI-selected attachments if provided, otherwise use all documents
|
||||
normalized_ai_attachments = []
|
||||
if documentList:
|
||||
try:
|
||||
ai_refs = [ai_attachments] if isinstance(ai_attachments, str) else ai_attachments
|
||||
available_refs = [documentList] if isinstance(documentList, str) else documentList
|
||||
from modules.datamodels.datamodelDocref import DocumentReferenceList
|
||||
ai_docs = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list(ai_refs)) or []
|
||||
available_docs = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list(available_refs)) or []
|
||||
except Exception:
|
||||
ai_docs = []
|
||||
available_docs = []
|
||||
|
||||
# Intersect by document id
|
||||
available_ids = {getattr(d, 'id', None) for d in available_docs}
|
||||
selected_docs = [d for d in ai_docs if getattr(d, 'id', None) in available_ids]
|
||||
# Normalize AI attachments to a list of strings
|
||||
if isinstance(ai_attachments, str):
|
||||
ai_attachments = [ai_attachments]
|
||||
elif isinstance(ai_attachments, list):
|
||||
ai_attachments = [a for a in ai_attachments if isinstance(a, str)]
|
||||
|
||||
if selected_docs:
|
||||
# Map selected ChatDocuments back to docItem references (with full filename)
|
||||
documentList = [self.services.chat.getDocumentReferenceFromChatDocument(d) for d in selected_docs]
|
||||
# Normalize ai_attachments to full format for storage
|
||||
normalized_ai_attachments = documentList.copy()
|
||||
logger.info(f"AI selected {len(documentList)} documents for attachment (resolved via ChatDocuments)")
|
||||
if ai_attachments:
|
||||
try:
|
||||
ai_refs = [ai_attachments] if isinstance(ai_attachments, str) else ai_attachments
|
||||
from modules.datamodels.datamodelDocref import DocumentReferenceList
|
||||
ai_docs = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list(ai_refs)) or []
|
||||
except Exception:
|
||||
ai_docs = []
|
||||
|
||||
# Intersect by document id
|
||||
available_ids = {getattr(d, 'id', None) for d in available_docs}
|
||||
selected_docs = [d for d in ai_docs if getattr(d, 'id', None) in available_ids]
|
||||
|
||||
if selected_docs:
|
||||
# Map selected ChatDocuments back to docItem references (with full filename)
|
||||
documentList = [self.services.chat.getDocumentReferenceFromChatDocument(d) for d in selected_docs]
|
||||
# Normalize ai_attachments to full format for storage
|
||||
normalized_ai_attachments = documentList.copy()
|
||||
logger.info(f"AI selected {len(documentList)} documents for attachment (resolved via ChatDocuments)")
|
||||
else:
|
||||
# No intersection; use all available documents
|
||||
documentList = [self.services.chat.getDocumentReferenceFromChatDocument(d) for d in available_docs]
|
||||
normalized_ai_attachments = documentList.copy()
|
||||
logger.warning("AI selected attachments not found in available documents, using all documents")
|
||||
else:
|
||||
# No intersection; use all available documents
|
||||
# No AI selection; use all available documents
|
||||
documentList = [self.services.chat.getDocumentReferenceFromChatDocument(d) for d in available_docs]
|
||||
normalized_ai_attachments = documentList.copy()
|
||||
logger.warning("AI selected attachments not found in available documents, using all documents")
|
||||
logger.warning("AI did not specify attachments, using all available documents")
|
||||
else:
|
||||
# No AI selection; use all available documents
|
||||
documentList = [self.services.chat.getDocumentReferenceFromChatDocument(d) for d in available_docs]
|
||||
normalized_ai_attachments = documentList.copy()
|
||||
logger.warning("AI did not specify attachments, using all available documents")
|
||||
else:
|
||||
logger.info("No documents provided in documentList; skipping attachment processing")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse AI response as JSON: {str(e)}")
|
||||
logger.error(f"AI response content: {ai_response}")
|
||||
return ActionResult.isFailure(error="AI response was not valid JSON format")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling AI service: {str(e)}")
|
||||
return ActionResult.isFailure(error=f"Failed to generate email content: {str(e)}")
|
||||
logger.info("No documents provided in documentList; skipping attachment processing")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse AI response as JSON: {str(e)}")
|
||||
logger.error(f"AI response content: {ai_response}")
|
||||
return ActionResult.isFailure(error="AI response was not valid JSON format")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling AI service: {str(e)}")
|
||||
return ActionResult.isFailure(error=f"Failed to generate email content: {str(e)}")
|
||||
|
||||
# Now create the email with AI-generated content
|
||||
try:
|
||||
|
|
@ -227,35 +261,78 @@ Return JSON:
|
|||
}
|
||||
|
||||
# Add documents as attachments if provided
|
||||
# Supports: 1) inline ActionDocuments (dict with documentData from e.g. sharepoint.downloadFile)
|
||||
# 2) docItem:... references (chat workflow documents)
|
||||
if documentList:
|
||||
message["attachments"] = []
|
||||
for attachment_ref in documentList:
|
||||
# Get attachment document from service center
|
||||
from modules.datamodels.datamodelDocref import DocumentReferenceList
|
||||
attachment_docs = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list([attachment_ref]))
|
||||
base64_content = None
|
||||
attach_name = "attachment"
|
||||
attach_mime = "application/octet-stream"
|
||||
|
||||
# Inline document: dict/object with documentData (from automation2 upstream, e.g. sharepoint.downloadFile)
|
||||
is_inline = isinstance(attachment_ref, dict) and attachment_ref.get("documentData")
|
||||
if not is_inline and hasattr(attachment_ref, "documentData"):
|
||||
is_inline = bool(getattr(attachment_ref, "documentData", None))
|
||||
if is_inline:
|
||||
doc = attachment_ref
|
||||
base64_content = doc.get("documentData") if isinstance(doc, dict) else getattr(doc, "documentData", None)
|
||||
attach_name = (doc.get("documentName") or doc.get("fileName")) if isinstance(doc, dict) else (getattr(doc, "documentName", None) or getattr(doc, "fileName", "attachment"))
|
||||
attach_mime = (doc.get("mimeType") or attach_mime) if isinstance(doc, dict) else (getattr(doc, "mimeType", None) or attach_mime)
|
||||
if base64_content and attach_name:
|
||||
message["attachments"].append({
|
||||
"@odata.type": "#microsoft.graph.fileAttachment",
|
||||
"name": attach_name,
|
||||
"contentType": attach_mime,
|
||||
"contentBytes": base64_content
|
||||
})
|
||||
continue
|
||||
# fileId in validationMetadata: resolve via getFileData (avoids large base64 in pipeline)
|
||||
file_id = None
|
||||
if isinstance(attachment_ref, dict):
|
||||
vm = attachment_ref.get("validationMetadata") or {}
|
||||
file_id = vm.get("fileId")
|
||||
elif hasattr(attachment_ref, "validationMetadata"):
|
||||
vm = getattr(attachment_ref, "validationMetadata") or {}
|
||||
file_id = vm.get("fileId") if isinstance(vm, dict) else None
|
||||
if file_id:
|
||||
try:
|
||||
file_content = self.services.chat.getFileData(file_id)
|
||||
if file_content:
|
||||
base64_content = base64.b64encode(file_content if isinstance(file_content, bytes) else str(file_content).encode("utf-8")).decode("utf-8")
|
||||
name = (attachment_ref.get("documentName") or attachment_ref.get("fileName", "attachment")) if isinstance(attachment_ref, dict) else (getattr(attachment_ref, "documentName", None) or getattr(attachment_ref, "fileName", "attachment"))
|
||||
mime = (attachment_ref.get("mimeType") or attach_mime) if isinstance(attachment_ref, dict) else (getattr(attachment_ref, "mimeType", None) or attach_mime)
|
||||
message["attachments"].append({
|
||||
"@odata.type": "#microsoft.graph.fileAttachment",
|
||||
"name": name,
|
||||
"contentType": mime,
|
||||
"contentBytes": base64_content
|
||||
})
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning("Could not load file %s for attachment: %s", file_id, e)
|
||||
|
||||
# docItem:... reference (chat workflow) – only when it's a string ref
|
||||
attachment_docs = []
|
||||
if isinstance(attachment_ref, str) and attachment_ref.strip():
|
||||
from modules.datamodels.datamodelDocref import DocumentReferenceList
|
||||
attachment_docs = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list([attachment_ref]))
|
||||
if attachment_docs:
|
||||
for doc in attachment_docs:
|
||||
file_id = getattr(doc, 'fileId', None)
|
||||
if file_id:
|
||||
fid = getattr(doc, 'fileId', None)
|
||||
if fid:
|
||||
try:
|
||||
file_content = self.services.chat.getFileData(file_id)
|
||||
file_content = self.services.chat.getFileData(fid)
|
||||
if file_content:
|
||||
if isinstance(file_content, bytes):
|
||||
content_bytes = file_content
|
||||
else:
|
||||
content_bytes = str(file_content).encode('utf-8')
|
||||
|
||||
base64_content = base64.b64encode(content_bytes).decode('utf-8')
|
||||
|
||||
attachment = {
|
||||
cb = file_content if isinstance(file_content, bytes) else str(file_content).encode('utf-8')
|
||||
message["attachments"].append({
|
||||
"@odata.type": "#microsoft.graph.fileAttachment",
|
||||
"name": doc.fileName,
|
||||
"contentType": doc.mimeType or "application/octet-stream",
|
||||
"contentBytes": base64_content
|
||||
}
|
||||
message["attachments"].append(attachment)
|
||||
"contentBytes": base64.b64encode(cb).decode('utf-8')
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading attachment file {doc.fileName}: {str(e)}")
|
||||
logger.error("Error reading attachment file %s: %s", doc.fileName, e)
|
||||
|
||||
# Create the draft message
|
||||
drafts_folder_id = self.folderManagement.getFolderId("Drafts", connection)
|
||||
|
|
@ -297,14 +374,21 @@ Return JSON:
|
|||
attachmentFilenames = []
|
||||
attachmentReferences = []
|
||||
if documentList:
|
||||
try:
|
||||
from modules.datamodels.datamodelDocref import DocumentReferenceList
|
||||
attached_docs = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list(documentList)) or []
|
||||
attachmentFilenames = [getattr(doc, 'fileName', '') for doc in attached_docs if getattr(doc, 'fileName', None)]
|
||||
# Store normalized document references (with filenames) - use normalized_ai_attachments if available
|
||||
attachmentReferences = normalized_ai_attachments if normalized_ai_attachments else [self.services.chat.getDocumentReferenceFromChatDocument(d) for d in attached_docs]
|
||||
except Exception:
|
||||
pass
|
||||
# Inline docs (dict with documentName): use directly
|
||||
string_refs = [r for r in documentList if isinstance(r, str)]
|
||||
inline_docs = [r for r in documentList if isinstance(r, dict)]
|
||||
for d in inline_docs:
|
||||
name = d.get("documentName") or d.get("fileName")
|
||||
if name:
|
||||
attachmentFilenames.append(name)
|
||||
if string_refs:
|
||||
try:
|
||||
from modules.datamodels.datamodelDocref import DocumentReferenceList
|
||||
attached_docs = self.services.chat.getChatDocumentsFromDocumentList(DocumentReferenceList.from_string_list(string_refs)) or []
|
||||
attachmentFilenames.extend(getattr(doc, 'fileName', '') for doc in attached_docs if getattr(doc, 'fileName', None))
|
||||
attachmentReferences = normalized_ai_attachments if normalized_ai_attachments else [self.services.chat.getDocumentReferenceFromChatDocument(d) for d in attached_docs]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Create validation metadata for content validator
|
||||
validationMetadata = {
|
||||
|
|
|
|||
|
|
@ -49,9 +49,9 @@ async def readEmails(self, parameters: Dict[str, Any]) -> ActionResult:
|
|||
if filter:
|
||||
# Remove any potentially dangerous characters that could break the filter
|
||||
filter = filter.strip()
|
||||
if len(filter) > 100:
|
||||
logger.warning(f"Filter too long ({len(filter)} chars), truncating to 100 characters")
|
||||
filter = filter[:100]
|
||||
if len(filter) > 500:
|
||||
logger.warning(f"Filter too long ({len(filter)} chars), truncating to 500 characters")
|
||||
filter = filter[:500]
|
||||
|
||||
|
||||
# Get Microsoft connection
|
||||
|
|
|
|||
|
|
@ -73,8 +73,12 @@ async def searchEmails(self, parameters: Dict[str, Any]) -> ActionResult:
|
|||
logger.warning(f"Could not find folder ID for '{folder}', using folder name directly")
|
||||
|
||||
# Build the search API request
|
||||
api_url = f"{graph_url}/me/messages"
|
||||
params = self.emailProcessing.buildSearchParameters(query, folder_id or folder, limit)
|
||||
# Use folder-specific URL when we have folder_id and $search - avoids InefficientFilter
|
||||
if folder_id and params.get("$search"):
|
||||
api_url = f"{graph_url}/me/mailFolders/{folder_id}/messages"
|
||||
else:
|
||||
api_url = f"{graph_url}/me/messages"
|
||||
|
||||
# Log search parameters for debugging
|
||||
logger.debug(f"Search query: '{query}'")
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ class EmailProcessingHelper:
|
|||
# Handle common search operators
|
||||
# Recognize Graph operators including both singular and plural forms for hasAttachments
|
||||
lowered = clean_query.lower()
|
||||
if any(op in lowered for op in ['from:', 'to:', 'subject:', 'received:', 'hasattachment:', 'hasattachments:']):
|
||||
if any(op in lowered for op in ['from:', 'to:', 'subject:', 'body:', 'received:', 'hasattachment:', 'hasattachments:']):
|
||||
# This is an advanced search query, return as-is
|
||||
return clean_query
|
||||
|
||||
|
|
@ -104,7 +104,7 @@ class EmailProcessingHelper:
|
|||
# Check if this is a complex search query with multiple operators
|
||||
# Recognize Graph operators including both singular and plural forms for hasAttachments
|
||||
lowered = clean_query.lower()
|
||||
if any(op in lowered for op in ['from:', 'to:', 'subject:', 'received:', 'hasattachment:', 'hasattachments:']):
|
||||
if any(op in lowered for op in ['from:', 'to:', 'subject:', 'body:', 'received:', 'hasattachment:', 'hasattachments:']):
|
||||
# This is an advanced search query, use $search
|
||||
# Microsoft Graph API supports complex search syntax
|
||||
params["$search"] = f'"{clean_query}"'
|
||||
|
|
@ -113,34 +113,20 @@ class EmailProcessingHelper:
|
|||
# We'll need to filter results after the API call
|
||||
# Folder filtering will be done after the API call
|
||||
else:
|
||||
# Use $filter for basic text search, but keep it simple to avoid "InefficientFilter" error
|
||||
# Microsoft Graph API has limitations on complex filters
|
||||
# Use $search (KQL) instead of $filter to avoid "InefficientFilter" - Graph rejects
|
||||
# contains(subject,x) + parentFolderId + orderby. $search handles subject:query.
|
||||
if len(clean_query) > 50:
|
||||
# If query is too long, truncate it to avoid complex filter issues
|
||||
clean_query = clean_query[:50]
|
||||
|
||||
|
||||
# Use only subject search to keep filter simple
|
||||
# Handle wildcard queries specially
|
||||
if clean_query == "*" or clean_query == "":
|
||||
# For wildcard or empty query, don't use contains filter
|
||||
# Just use folder filter if specified
|
||||
if folder and folder.lower() != "all":
|
||||
params["$filter"] = f"parentFolderId eq '{folder}'"
|
||||
else:
|
||||
# No filter needed for wildcard search across all folders
|
||||
pass
|
||||
params["$orderby"] = "receivedDateTime desc"
|
||||
else:
|
||||
params["$filter"] = f"contains(subject,'{clean_query}')"
|
||||
|
||||
# Add folder filter if specified
|
||||
if folder and folder.lower() != "all":
|
||||
params["$filter"] = f"{params['$filter']} and parentFolderId eq '{folder}'"
|
||||
|
||||
# Add orderby for basic queries
|
||||
params["$orderby"] = "receivedDateTime desc"
|
||||
# Use $search with subject: to avoid InefficientFilter
|
||||
safe = clean_query.replace('"', '')
|
||||
params["$search"] = f'"subject:{safe}"'
|
||||
# Folder filtering done post-API in searchEmails when $search is used
|
||||
|
||||
|
||||
return params
|
||||
|
||||
def buildGraphFilter(self, filter_text: str) -> Dict[str, str]:
|
||||
|
|
@ -168,7 +154,7 @@ class EmailProcessingHelper:
|
|||
# Handle search queries (from:, to:, subject:, etc.) - check this FIRST
|
||||
# Support both singular and plural forms for hasAttachments
|
||||
lt = filter_text.lower()
|
||||
if any(lt.startswith(prefix) for prefix in ['from:', 'to:', 'subject:', 'received:', 'hasattachment:', 'hasattachments:']):
|
||||
if any(lt.startswith(prefix) for prefix in ['from:', 'to:', 'subject:', 'body:', 'received:', 'hasattachment:', 'hasattachments:']):
|
||||
return {"$search": f'"{filter_text}"'}
|
||||
|
||||
# Handle email address filters (only if it's NOT a search query)
|
||||
|
|
|
|||
|
|
@ -27,10 +27,15 @@ class FolderManagementHelper:
|
|||
|
||||
def getFolderId(self, folder_name: str, connection: Dict[str, Any]) -> Optional[str]:
|
||||
"""
|
||||
Get the folder ID for a given folder name
|
||||
|
||||
This is needed for proper filtering when using advanced search queries
|
||||
Get the folder ID for a given folder name or ID.
|
||||
Returns the input as-is if it already looks like a Microsoft Graph folder ID.
|
||||
"""
|
||||
if not folder_name or not str(folder_name).strip():
|
||||
return None
|
||||
# Graph folder IDs are base64-like strings (e.g. AQMk...); return as-is
|
||||
s = str(folder_name).strip()
|
||||
if s.startswith("AQMk") and len(s) > 20 and " " not in s:
|
||||
return s
|
||||
try:
|
||||
graph_url = "https://graph.microsoft.com/v1.0"
|
||||
headers = {
|
||||
|
|
|
|||
|
|
@ -157,15 +157,22 @@ class MethodOutlook(MethodBase):
|
|||
name="context",
|
||||
type="str",
|
||||
frontendType=FrontendType.TEXTAREA,
|
||||
required=True,
|
||||
description="Detailed context for composing the email"
|
||||
required=False,
|
||||
description="Detailed context for AI composition (omit when emailContent provided)"
|
||||
),
|
||||
"emailContent": WorkflowActionParameter(
|
||||
name="emailContent",
|
||||
type="dict",
|
||||
frontendType=FrontendType.HIDDEN,
|
||||
required=False,
|
||||
description="Direct subject/body/to from upstream (skips AI composition)"
|
||||
),
|
||||
"documentList": WorkflowActionParameter(
|
||||
name="documentList",
|
||||
type="List[str]",
|
||||
type="List[Any]",
|
||||
frontendType=FrontendType.DOCUMENT_REFERENCE,
|
||||
required=False,
|
||||
description="Document references for context/attachments"
|
||||
description="Document references or inline ActionDocuments for attachments"
|
||||
),
|
||||
"cc": WorkflowActionParameter(
|
||||
name="cc",
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue